From: gfyoung Date: Sun, 13 Aug 2017 11:22:42 +0000 (-0700) Subject: Refactor out source data updating X-Git-Url: http://git.99rst.org/?a=commitdiff_plain;h=24ab22e139027ea6b7afe56933a61af1324eec68;p=stevenblack-hosts.git Refactor out source data updating --- diff --git a/testUpdateHostsFile.py b/testUpdateHostsFile.py index 2c6514c42..836e949c4 100644 --- a/testUpdateHostsFile.py +++ b/testUpdateHostsFile.py @@ -13,7 +13,7 @@ from updateHostsFile import ( prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache, prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file, supports_color, strip_rule, update_all_sources, update_readme_data, - write_data, write_opening_header) + update_sources_data, write_data, write_opening_header) import updateHostsFile import unittest @@ -542,6 +542,80 @@ class TestMatchesExclusions(Base): # Update Logic +class TestUpdateSourcesData(Base): + + def setUp(self): + Base.setUp(self) + + self.data_path = "data" + self.extensions_path = "extensions" + self.source_data_filename = "update.json" + + self.update_kwargs = dict(datapath=self.data_path, + extensionspath=self.extensions_path, + sourcedatafilename=self.source_data_filename) + + def update_sources_data(self, sources_data, extensions): + return update_sources_data(sources_data[:], extensions=extensions, + **self.update_kwargs) + + @mock.patch("updateHostsFile.recursive_glob", return_value=[]) + @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath") + @mock.patch(builtins() + ".open", return_value=mock.Mock()) + def test_no_update(self, mock_open, mock_join_robust, _): + extensions = [] + sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}] + + new_sources_data = self.update_sources_data(sources_data, extensions) + self.assertEqual(new_sources_data, sources_data) + mock_join_robust.assert_not_called() + mock_open.assert_not_called() + + extensions = [".json", ".txt"] + new_sources_data = self.update_sources_data(sources_data, extensions) + + self.assertEqual(new_sources_data, sources_data) + join_calls = [mock.call(self.extensions_path, ".json"), + mock.call(self.extensions_path, ".txt")] + mock_join_robust.assert_has_calls(join_calls) + mock_open.assert_not_called() + + @mock.patch("updateHostsFile.recursive_glob", + side_effect=[[], ["update1.txt", "update2.txt"]]) + @mock.patch("json.load", return_value={"mock_source": "mock_source.ext"}) + @mock.patch(builtins() + ".open", return_value=mock.Mock()) + @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath") + def test_update_only_extensions(self, mock_join_robust, *_): + extensions = [".json"] + sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}] + new_sources_data = self.update_sources_data(sources_data, extensions) + + expected = sources_data + [{"mock_source": "mock_source.ext"}] * 2 + self.assertEqual(new_sources_data, expected) + self.assert_called_once(mock_join_robust) + + @mock.patch("updateHostsFile.recursive_glob", + side_effect=[["update1.txt", "update2.txt"], + ["update3.txt", "update4.txt"]]) + @mock.patch("json.load", side_effect=[{"mock_source": "mock_source.txt"}, + {"mock_source": "mock_source2.txt"}, + {"mock_source": "mock_source3.txt"}, + {"mock_source": "mock_source4.txt"}]) + @mock.patch(builtins() + ".open", return_value=mock.Mock()) + @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath") + def test_update_both_pathways(self, mock_join_robust, *_): + extensions = [".json"] + sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}] + new_sources_data = self.update_sources_data(sources_data, extensions) + + expected = sources_data + [{"mock_source": "mock_source.txt"}, + {"mock_source": "mock_source2.txt"}, + {"mock_source": "mock_source3.txt"}, + {"mock_source": "mock_source4.txt"}] + self.assertEqual(new_sources_data, expected) + self.assert_called_once(mock_join_robust) + + class TestUpdateAllSources(BaseStdout): def setUp(self): diff --git a/updateHostsFile.py b/updateHostsFile.py index c52dee202..d9aa12090 100644 --- a/updateHostsFile.py +++ b/updateHostsFile.py @@ -129,25 +129,27 @@ def main(): settings = get_defaults() settings.update(options) - settings["sources"] = list_dir_no_hidden(settings["datapath"]) - settings["extensionsources"] = list_dir_no_hidden( - settings["extensionspath"]) + data_path = settings["datapath"] + extensions_path = settings["extensionspath"] + + settings["sources"] = list_dir_no_hidden(data_path) + settings["extensionsources"] = list_dir_no_hidden(extensions_path) # All our extensions folders... settings["extensions"] = [os.path.basename(item) for item in - list_dir_no_hidden(settings["extensionspath"])] + list_dir_no_hidden(extensions_path)] # ... intersected with the extensions passed-in as arguments, then sorted. settings["extensions"] = sorted(list( set(options["extensions"]).intersection(settings["extensions"]))) auto = settings["auto"] exclusion_regexes = settings["exclusionregexs"] + source_data_filename = settings["sourcedatafilename"] update_sources = prompt_for_update(freshen=settings["freshen"], update_auto=auto) if update_sources: - update_all_sources(settings["sourcedatafilename"], - settings["hostfilename"]) + update_all_sources(source_data_filename, settings["hostfilename"]) gather_exclusions = prompt_for_exclusions(skip_prompt=auto) @@ -159,15 +161,19 @@ def main(): exclusion_pattern=exclusion_pattern, exclusion_regexes=exclusion_regexes) - merge_file = create_initial_file() - remove_old_hosts_file(settings["backup"]) - extensions = settings["extensions"] - output_subfolder = settings["outputsubfolder"] + sources_data = update_sources_data(settings["sourcesdata"], + datapath=data_path, + extensions=extensions, + extensionspath=extensions_path, + sourcedatafilename=source_data_filename) + merge_file = create_initial_file() + remove_old_hosts_file(settings["backup"]) final_file = remove_dups_and_excl(merge_file, exclusion_regexes) number_of_rules = settings["numberofrules"] + output_subfolder = settings["outputsubfolder"] skip_static_hosts = settings["skipstatichosts"] write_opening_header(final_file, extensions=extensions, @@ -180,7 +186,7 @@ def main(): extensions=extensions, numberofrules=number_of_rules, outputsubfolder=output_subfolder, - sourcesdata=settings["sourcesdata"]) + sourcesdata=sources_data) print_success("Success! The hosts file has been saved in folder " + output_subfolder + "\nIt contains " + @@ -477,6 +483,52 @@ def matches_exclusions(stripped_rule, exclusion_regexes): # Update Logic +def update_sources_data(sources_data, **sources_params): + """ + Update the sources data and information for each source. + + Parameters + ---------- + sources_data : list + The list of sources data that we are to update. + sources_params : kwargs + Dictionary providing additional parameters for updating the + sources data. Currently, those fields are: + + 1) datapath + 2) extensions + 3) extensionspath + 4) sourcedatafilename + + Returns + ------- + update_sources_data : list + The original source data list with new source data appended. + """ + + source_data_filename = sources_params["sourcedatafilename"] + + for source in recursive_glob(sources_params["datapath"], + source_data_filename): + update_file = open(source, "r") + update_data = json.load(update_file) + sources_data.append(update_data) + update_file.close() + + for source in sources_params["extensions"]: + source_dir = path_join_robust( + sources_params["extensionspath"], source) + for update_file_path in recursive_glob(source_dir, + source_data_filename): + update_file = open(update_file_path, "r") + update_data = json.load(update_file) + + sources_data.append(update_data) + update_file.close() + + return sources_data + + def update_all_sources(source_data_filename, host_filename): """ Update all host files, regardless of folder depth. @@ -534,13 +586,6 @@ def create_initial_file(): with open(source, "r") as curFile: write_data(merge_file, curFile.read()) - for source in recursive_glob(settings["datapath"], - settings["sourcedatafilename"]): - update_file = open(source, "r") - update_data = json.load(update_file) - settings["sourcesdata"].append(update_data) - update_file.close() - # spin the sources for extensions to the base file for source in settings["extensions"]: for filename in recursive_glob(path_join_robust( @@ -548,15 +593,6 @@ def create_initial_file(): with open(filename, "r") as curFile: write_data(merge_file, curFile.read()) - for update_file_path in recursive_glob(path_join_robust( - settings["extensionspath"], source), - settings["sourcedatafilename"]): - update_file = open(update_file_path, "r") - update_data = json.load(update_file) - - settings["sourcesdata"].append(update_data) - update_file.close() - if os.path.isfile(settings["blacklistfile"]): with open(settings["blacklistfile"], "r") as curFile: write_data(merge_file, curFile.read())