normalize_rule, path_join_robust, print_failure, print_success,
prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
- supports_color, strip_rule, update_readme_data, write_data,
- write_opening_header)
+ supports_color, strip_rule, update_all_sources, update_readme_data,
+ write_data, write_opening_header)
import updateHostsFile
import unittest
sys.stdout = StringIO()
- @mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
- def test_freshen_no_update(self, _, mock_update):
+ def test_freshen_no_update(self, _):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
dir_count = self.dir_count
- prompt_for_update(freshen=True, update_auto=False)
-
- mock_update.assert_not_called()
- mock_update.reset_mock()
+ update_sources = prompt_for_update(freshen=True, update_auto=False)
+ self.assertFalse(update_sources)
output = sys.stdout.getvalue()
expected = ("OK, we'll stick with "
contents = f.read()
self.assertEqual(contents, hosts_data)
- @mock.patch("updateHostsFile.update_all_sources", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
- def test_freshen_update(self, _, mock_update):
+ def test_freshen_update(self, _):
hosts_file = os.path.join(self.test_dir, "hosts")
hosts_data = "This data should not be overwritten"
dir_count = self.dir_count
for update_auto in (False, True):
- prompt_for_update(freshen=True, update_auto=update_auto)
-
- self.assert_called_once(mock_update)
- mock_update.reset_mock()
+ update_sources = prompt_for_update(freshen=True,
+ update_auto=update_auto)
+ self.assertTrue(update_sources)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
# End Exclusion Logic
+# Update Logic
+class TestUpdateAllSources(BaseStdout):
+
+ def setUp(self):
+ BaseStdout.setUp(self)
+
+ self.source_data_filename = "data.json"
+ self.host_filename = "hosts.txt"
+
+ @mock.patch(builtins() + ".open")
+ @mock.patch("updateHostsFile.recursive_glob", return_value=[])
+ def test_no_sources(self, _, mock_open):
+ update_all_sources(self.source_data_filename, self.host_filename)
+ mock_open.assert_not_called()
+
+ @mock.patch(builtins() + ".open", return_value=mock.Mock())
+ @mock.patch("json.load", return_value={"url": "example.com"})
+ @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
+ @mock.patch("updateHostsFile.write_data", return_value=0)
+ @mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
+ def test_one_source(self, mock_get, mock_write, *_):
+ update_all_sources(self.source_data_filename, self.host_filename)
+ self.assert_called_once(mock_write)
+ self.assert_called_once(mock_get)
+
+ output = sys.stdout.getvalue()
+ expected = "Updating source from example.com"
+
+ self.assertIn(expected, output)
+
+ @mock.patch(builtins() + ".open", return_value=mock.Mock())
+ @mock.patch("json.load", return_value={"url": "example.com"})
+ @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
+ @mock.patch("updateHostsFile.write_data", return_value=0)
+ @mock.patch("updateHostsFile.get_file_by_url",
+ return_value=Exception("fail"))
+ def test_source_fail(self, mock_get, mock_write, *_):
+ update_all_sources(self.source_data_filename, self.host_filename)
+ mock_write.assert_not_called()
+ self.assert_called_once(mock_get)
+
+ output = sys.stdout.getvalue()
+ expecteds = ["Updating source from example.com",
+ "Error in updating source: example.com"]
+ for expected in expecteds:
+ self.assertIn(expected, output)
+
+ @mock.patch(builtins() + ".open", return_value=mock.Mock())
+ @mock.patch("json.load", side_effect=[{"url": "example.com"},
+ {"url": "example2.com"}])
+ @mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
+ @mock.patch("updateHostsFile.write_data", return_value=0)
+ @mock.patch("updateHostsFile.get_file_by_url",
+ side_effect=[Exception("fail"), "file_data"])
+ def test_sources_fail_succeed(self, mock_get, mock_write, *_):
+ update_all_sources(self.source_data_filename, self.host_filename)
+ self.assert_called_once(mock_write)
+
+ get_calls = [mock.call("example.com"), mock.call("example2.com")]
+ mock_get.assert_has_calls(get_calls)
+
+ output = sys.stdout.getvalue()
+ expecteds = ["Updating source from example.com",
+ "Error in updating source: example.com",
+ "Updating source from example2.com"]
+ for expected in expecteds:
+ self.assertIn(expected, output)
+# End Update Logic
+
+
# File Logic
class TestNormalizeRule(BaseStdout):
auto = settings["auto"]
exclusion_regexes = settings["exclusionregexs"]
- prompt_for_update(freshen=settings["freshen"], update_auto=auto)
+ update_sources = prompt_for_update(freshen=settings["freshen"],
+ update_auto=auto)
+ if update_sources:
+ update_all_sources(settings["sourcedatafilename"],
+ settings["hostfilename"])
+
gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
if gather_exclusions:
if it is requested that data sources not be updated.
update_auto : bool
Whether or not to automatically update all data sources.
+
+ Returns
+ -------
+ update_sources : bool
+ Whether or not we should update data sources for exclusion files.
"""
# Create a hosts file if it doesn't exist.
prompt = "Do you want to update all data sources?"
if update_auto or query_yes_no(prompt):
- update_all_sources()
+ return True
elif not update_auto:
print("OK, we'll stick with what we've got locally.")
+ return False
+
def prompt_for_exclusions(skip_prompt):
"""
# Update Logic
-def update_all_sources():
+def update_all_sources(source_data_filename, host_filename):
"""
Update all host files, regardless of folder depth.
+
+ Parameters
+ ----------
+ source_data_filename : str
+ The name of the filename where information regarding updating
+ sources for a particular URL is stored. This filename is assumed
+ to be the same for all sources.
+ host_filename : str
+ The name of the file in which the updated source information
+ in stored for a particular URL. This filename is assumed to be
+ the same for all sources.
"""
- all_sources = recursive_glob("*", settings["sourcedatafilename"])
+ all_sources = recursive_glob("*", source_data_filename)
for source in all_sources:
update_file = open(source, "r")
hosts_file = open(path_join_robust(BASEDIR_PATH,
os.path.dirname(source),
- settings["hostfilename"]), "wb")
+ host_filename), "wb")
write_data(hosts_file, updated_file)
hosts_file.close()
except Exception: