Refactor out source data updating
authorgfyoung <redacted>
Sun, 13 Aug 2017 11:22:42 +0000 (04:22 -0700)
committergfyoung <redacted>
Sun, 20 Aug 2017 18:52:28 +0000 (11:52 -0700)
testUpdateHostsFile.py
updateHostsFile.py

index 2c6514c42a59ba74b63b9c80be6c6e8d579ff837..836e949c43858e5ec9a02665fedfdf59e8f86144 100644 (file)
@@ -13,7 +13,7 @@ from updateHostsFile import (
     prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
     prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
     supports_color, strip_rule, update_all_sources, update_readme_data,
-    write_data, write_opening_header)
+    update_sources_data, write_data, write_opening_header)
 
 import updateHostsFile
 import unittest
@@ -542,6 +542,80 @@ class TestMatchesExclusions(Base):
 
 
 # Update Logic
+class TestUpdateSourcesData(Base):
+
+    def setUp(self):
+        Base.setUp(self)
+
+        self.data_path = "data"
+        self.extensions_path = "extensions"
+        self.source_data_filename = "update.json"
+
+        self.update_kwargs = dict(datapath=self.data_path,
+                                  extensionspath=self.extensions_path,
+                                  sourcedatafilename=self.source_data_filename)
+
+    def update_sources_data(self, sources_data, extensions):
+        return update_sources_data(sources_data[:], extensions=extensions,
+                                   **self.update_kwargs)
+
+    @mock.patch("updateHostsFile.recursive_glob", return_value=[])
+    @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
+    @mock.patch(builtins() + ".open", return_value=mock.Mock())
+    def test_no_update(self, mock_open, mock_join_robust, _):
+        extensions = []
+        sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
+
+        new_sources_data = self.update_sources_data(sources_data, extensions)
+        self.assertEqual(new_sources_data, sources_data)
+        mock_join_robust.assert_not_called()
+        mock_open.assert_not_called()
+
+        extensions = [".json", ".txt"]
+        new_sources_data = self.update_sources_data(sources_data, extensions)
+
+        self.assertEqual(new_sources_data, sources_data)
+        join_calls = [mock.call(self.extensions_path, ".json"),
+                      mock.call(self.extensions_path, ".txt")]
+        mock_join_robust.assert_has_calls(join_calls)
+        mock_open.assert_not_called()
+
+    @mock.patch("updateHostsFile.recursive_glob",
+                side_effect=[[], ["update1.txt", "update2.txt"]])
+    @mock.patch("json.load", return_value={"mock_source": "mock_source.ext"})
+    @mock.patch(builtins() + ".open", return_value=mock.Mock())
+    @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
+    def test_update_only_extensions(self, mock_join_robust, *_):
+        extensions = [".json"]
+        sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
+        new_sources_data = self.update_sources_data(sources_data, extensions)
+
+        expected = sources_data + [{"mock_source": "mock_source.ext"}] * 2
+        self.assertEqual(new_sources_data, expected)
+        self.assert_called_once(mock_join_robust)
+
+    @mock.patch("updateHostsFile.recursive_glob",
+                side_effect=[["update1.txt", "update2.txt"],
+                             ["update3.txt", "update4.txt"]])
+    @mock.patch("json.load", side_effect=[{"mock_source": "mock_source.txt"},
+                                          {"mock_source": "mock_source2.txt"},
+                                          {"mock_source": "mock_source3.txt"},
+                                          {"mock_source": "mock_source4.txt"}])
+    @mock.patch(builtins() + ".open", return_value=mock.Mock())
+    @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
+    def test_update_both_pathways(self, mock_join_robust, *_):
+        extensions = [".json"]
+        sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
+        new_sources_data = self.update_sources_data(sources_data, extensions)
+
+        expected = sources_data + [{"mock_source": "mock_source.txt"},
+                                   {"mock_source": "mock_source2.txt"},
+                                   {"mock_source": "mock_source3.txt"},
+                                   {"mock_source": "mock_source4.txt"}]
+        self.assertEqual(new_sources_data, expected)
+        self.assert_called_once(mock_join_robust)
+
+
 class TestUpdateAllSources(BaseStdout):
 
     def setUp(self):
index c52dee202d9a1656ba9ba79ba280e5f786154310..d9aa12090eb7c16a2b539d7a32b3fdf97791942a 100644 (file)
@@ -129,25 +129,27 @@ def main():
     settings = get_defaults()
     settings.update(options)
 
-    settings["sources"] = list_dir_no_hidden(settings["datapath"])
-    settings["extensionsources"] = list_dir_no_hidden(
-        settings["extensionspath"])
+    data_path = settings["datapath"]
+    extensions_path = settings["extensionspath"]
+
+    settings["sources"] = list_dir_no_hidden(data_path)
+    settings["extensionsources"] = list_dir_no_hidden(extensions_path)
 
     # All our extensions folders...
     settings["extensions"] = [os.path.basename(item) for item in
-                              list_dir_no_hidden(settings["extensionspath"])]
+                              list_dir_no_hidden(extensions_path)]
     # ... intersected with the extensions passed-in as arguments, then sorted.
     settings["extensions"] = sorted(list(
         set(options["extensions"]).intersection(settings["extensions"])))
 
     auto = settings["auto"]
     exclusion_regexes = settings["exclusionregexs"]
+    source_data_filename = settings["sourcedatafilename"]
 
     update_sources = prompt_for_update(freshen=settings["freshen"],
                                        update_auto=auto)
     if update_sources:
-        update_all_sources(settings["sourcedatafilename"],
-                           settings["hostfilename"])
+        update_all_sources(source_data_filename, settings["hostfilename"])
 
     gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
 
@@ -159,15 +161,19 @@ def main():
             exclusion_pattern=exclusion_pattern,
             exclusion_regexes=exclusion_regexes)
 
-    merge_file = create_initial_file()
-    remove_old_hosts_file(settings["backup"])
-
     extensions = settings["extensions"]
-    output_subfolder = settings["outputsubfolder"]
+    sources_data = update_sources_data(settings["sourcesdata"],
+                                       datapath=data_path,
+                                       extensions=extensions,
+                                       extensionspath=extensions_path,
+                                       sourcedatafilename=source_data_filename)
 
+    merge_file = create_initial_file()
+    remove_old_hosts_file(settings["backup"])
     final_file = remove_dups_and_excl(merge_file, exclusion_regexes)
 
     number_of_rules = settings["numberofrules"]
+    output_subfolder = settings["outputsubfolder"]
     skip_static_hosts = settings["skipstatichosts"]
 
     write_opening_header(final_file, extensions=extensions,
@@ -180,7 +186,7 @@ def main():
                        extensions=extensions,
                        numberofrules=number_of_rules,
                        outputsubfolder=output_subfolder,
-                       sourcesdata=settings["sourcesdata"])
+                       sourcesdata=sources_data)
 
     print_success("Success! The hosts file has been saved in folder " +
                   output_subfolder + "\nIt contains " +
@@ -477,6 +483,52 @@ def matches_exclusions(stripped_rule, exclusion_regexes):
 
 
 # Update Logic
+def update_sources_data(sources_data, **sources_params):
+    """
+    Update the sources data and information for each source.
+
+    Parameters
+    ----------
+    sources_data : list
+        The list of sources data that we are to update.
+    sources_params : kwargs
+        Dictionary providing additional parameters for updating the
+        sources data. Currently, those fields are:
+
+        1) datapath
+        2) extensions
+        3) extensionspath
+        4) sourcedatafilename
+
+    Returns
+    -------
+    update_sources_data : list
+        The original source data list with new source data appended.
+    """
+
+    source_data_filename = sources_params["sourcedatafilename"]
+
+    for source in recursive_glob(sources_params["datapath"],
+                                 source_data_filename):
+        update_file = open(source, "r")
+        update_data = json.load(update_file)
+        sources_data.append(update_data)
+        update_file.close()
+
+    for source in sources_params["extensions"]:
+        source_dir = path_join_robust(
+            sources_params["extensionspath"], source)
+        for update_file_path in recursive_glob(source_dir,
+                                               source_data_filename):
+            update_file = open(update_file_path, "r")
+            update_data = json.load(update_file)
+
+            sources_data.append(update_data)
+            update_file.close()
+
+    return sources_data
+
+
 def update_all_sources(source_data_filename, host_filename):
     """
     Update all host files, regardless of folder depth.
@@ -534,13 +586,6 @@ def create_initial_file():
         with open(source, "r") as curFile:
             write_data(merge_file, curFile.read())
 
-    for source in recursive_glob(settings["datapath"],
-                                 settings["sourcedatafilename"]):
-        update_file = open(source, "r")
-        update_data = json.load(update_file)
-        settings["sourcesdata"].append(update_data)
-        update_file.close()
-
     # spin the sources for extensions to the base file
     for source in settings["extensions"]:
         for filename in recursive_glob(path_join_robust(
@@ -548,15 +593,6 @@ def create_initial_file():
             with open(filename, "r") as curFile:
                 write_data(merge_file, curFile.read())
 
-        for update_file_path in recursive_glob(path_join_robust(
-                settings["extensionspath"], source),
-                settings["sourcedatafilename"]):
-            update_file = open(update_file_path, "r")
-            update_data = json.load(update_file)
-
-            settings["sourcesdata"].append(update_data)
-            update_file.close()
-
     if os.path.isfile(settings["blacklistfile"]):
         with open(settings["blacklistfile"], "r") as curFile:
             write_data(merge_file, curFile.read())
git clone https://git.99rst.org/PROJECT