Refactor out global settings usage in update logic
author	gfyoung <redacted>
Tue, 8 Aug 2017 04:18:35 +0000 (21:18 -0700)
committer	gfyoung <redacted>
Wed, 9 Aug 2017 15:13:22 +0000 (08:13 -0700)
testUpdateHostsFile.py
updateHostsFile.py
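
The change moves lookups of the module-level settings dict out of the update helpers and into main(), which now passes the values in as arguments. A minimal, self-contained sketch of the pattern (trimmed stand-ins that only show the signature change, not the real download logic in updateHostsFile.py; the filenames are taken from the test setUp below):

    # Hypothetical, trimmed stand-ins illustrating the refactor pattern.
    settings = {"sourcedatafilename": "data.json", "hostfilename": "hosts.txt"}

    def update_all_sources_global():
        # Old shape: the helper reaches into the module-level dict,
        # so tests have to patch the global to exercise it.
        return settings["sourcedatafilename"], settings["hostfilename"]

    def update_all_sources(source_data_filename, host_filename):
        # New shape: the caller looks the values up once and injects them,
        # so the helper can be tested with plain arguments.
        return source_data_filename, host_filename

    assert update_all_sources(settings["sourcedatafilename"],
                              settings["hostfilename"]) == update_all_sources_global()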

index 139bd89555e8532b957b388f4f7b6987f3cedb9e..444c71b355c59a1674e8ba1904a16a5604bd4924 100644 (file)
@@ -12,8 +12,8 @@ from updateHostsFile import (
     normalize_rule, path_join_robust, print_failure, print_success,
     prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
     prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
-    supports_color, strip_rule, update_readme_data, write_data,
-    write_opening_header)
+    supports_color, strip_rule, update_all_sources, update_readme_data,
+    write_data, write_opening_header)
 
 import updateHostsFile
 import unittest
@@ -190,9 +190,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
 
                 sys.stdout = StringIO()
 
-    @mock.patch("updateHostsFile.update_all_sources", return_value=0)
     @mock.patch("updateHostsFile.query_yes_no", return_value=False)
-    def test_freshen_no_update(self, _, mock_update):
+    def test_freshen_no_update(self, _):
         hosts_file = os.path.join(self.test_dir, "hosts")
         hosts_data = "This data should not be overwritten"
 
@@ -204,10 +203,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
 
             dir_count = self.dir_count
 
-            prompt_for_update(freshen=True, update_auto=False)
-
-            mock_update.assert_not_called()
-            mock_update.reset_mock()
+            update_sources = prompt_for_update(freshen=True, update_auto=False)
+            self.assertFalse(update_sources)
 
             output = sys.stdout.getvalue()
             expected = ("OK, we'll stick with "
@@ -222,9 +219,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
                 contents = f.read()
                 self.assertEqual(contents, hosts_data)
 
-    @mock.patch("updateHostsFile.update_all_sources", return_value=0)
     @mock.patch("updateHostsFile.query_yes_no", return_value=True)
-    def test_freshen_update(self, _, mock_update):
+    def test_freshen_update(self, _):
         hosts_file = os.path.join(self.test_dir, "hosts")
         hosts_data = "This data should not be overwritten"
 
@@ -237,10 +233,9 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
             dir_count = self.dir_count
 
             for update_auto in (False, True):
-                prompt_for_update(freshen=True, update_auto=update_auto)
-
-                self.assert_called_once(mock_update)
-                mock_update.reset_mock()
+                update_sources = prompt_for_update(freshen=True,
+                                                   update_auto=update_auto)
+                self.assertTrue(update_sources)
 
                 output = sys.stdout.getvalue()
                 self.assertEqual(output, "")
@@ -547,6 +542,76 @@ class TestMatchesExclusions(Base):
 # End Exclusion Logic
 
 
+# Update Logic
+class TestUpdateAllSources(BaseStdout):
+
+    def setUp(self):
+        BaseStdout.setUp(self)
+
+        self.source_data_filename = "data.json"
+        self.host_filename = "hosts.txt"
+
+    @mock.patch(builtins() + ".open")
+    @mock.patch("updateHostsFile.recursive_glob", return_value=[])
+    def test_no_sources(self, _, mock_open):
+        update_all_sources(self.source_data_filename, self.host_filename)
+        mock_open.assert_not_called()
+
+    @mock.patch(builtins() + ".open", return_value=mock.Mock())
+    @mock.patch("json.load", return_value={"url": "example.com"})
+    @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
+    @mock.patch("updateHostsFile.write_data", return_value=0)
+    @mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
+    def test_one_source(self, mock_get, mock_write, *_):
+        update_all_sources(self.source_data_filename, self.host_filename)
+        self.assert_called_once(mock_write)
+        self.assert_called_once(mock_get)
+
+        output = sys.stdout.getvalue()
+        expected = "Updating source  from example.com"
+
+        self.assertIn(expected, output)
+
+    @mock.patch(builtins() + ".open", return_value=mock.Mock())
+    @mock.patch("json.load", return_value={"url": "example.com"})
+    @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
+    @mock.patch("updateHostsFile.write_data", return_value=0)
+    @mock.patch("updateHostsFile.get_file_by_url",
+                return_value=Exception("fail"))
+    def test_source_fail(self, mock_get, mock_write, *_):
+        update_all_sources(self.source_data_filename, self.host_filename)
+        mock_write.assert_not_called()
+        self.assert_called_once(mock_get)
+
+        output = sys.stdout.getvalue()
+        expecteds = ["Updating source  from example.com",
+                     "Error in updating source:  example.com"]
+        for expected in expecteds:
+            self.assertIn(expected, output)
+
+    @mock.patch(builtins() + ".open", return_value=mock.Mock())
+    @mock.patch("json.load", side_effect=[{"url": "example.com"},
+                                          {"url": "example2.com"}])
+    @mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
+    @mock.patch("updateHostsFile.write_data", return_value=0)
+    @mock.patch("updateHostsFile.get_file_by_url",
+                side_effect=[Exception("fail"), "file_data"])
+    def test_sources_fail_succeed(self, mock_get, mock_write, *_):
+        update_all_sources(self.source_data_filename, self.host_filename)
+        self.assert_called_once(mock_write)
+
+        get_calls = [mock.call("example.com"), mock.call("example2.com")]
+        mock_get.assert_has_calls(get_calls)
+
+        output = sys.stdout.getvalue()
+        expecteds = ["Updating source  from example.com",
+                     "Error in updating source:  example.com",
+                     "Updating source  from example2.com"]
+        for expected in expecteds:
+            self.assertIn(expected, output)
+# End Update Logic
+
+
 # File Logic
 class TestNormalizeRule(BaseStdout):
 
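The new TestUpdateAllSources cases above patch open, json.load, recursive_glob, write_data, and get_file_by_url, so they run without touching the filesystem or the network. A minimal, self-contained illustration of the side_effect technique used in test_sources_fail_succeed (fetch_all here is a hypothetical stand-in, not project code):

    from unittest import mock

    def fetch_all(urls, get_file_by_url):
        # Mirrors the failure handling exercised above: a failing source is
        # reported and skipped, the remaining sources are still fetched.
        fetched = []
        for url in urls:
            try:
                fetched.append(get_file_by_url(url))
            except Exception:
                print("Error in updating source: ", url)
        return fetched

    # side_effect lets the first mocked call raise and the second succeed,
    # covering both the error path and the success path in one test.
    mock_get = mock.Mock(side_effect=[Exception("fail"), "file_data"])
    assert fetch_all(["example.com", "example2.com"], mock_get) == ["file_data"]
    mock_get.assert_has_calls([mock.call("example.com"),
                               mock.call("example2.com")])
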
index dbcfe7696ffa41d1af1f50b3da940c6b9201d385..7f429ce8694305daf13b18f64832b6971329ab23 100644 (file)
@@ -149,7 +149,12 @@ def main():
     auto = settings["auto"]
     exclusion_regexes = settings["exclusionregexs"]
 
-    prompt_for_update(freshen=settings["freshen"], update_auto=auto)
+    update_sources = prompt_for_update(freshen=settings["freshen"],
+                                       update_auto=auto)
+    if update_sources:
+        update_all_sources(settings["sourcedatafilename"],
+                           settings["hostfilename"])
+
     gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
 
     if gather_exclusions:
@@ -221,6 +226,11 @@ def prompt_for_update(freshen, update_auto):
         if it is requested that data sources not be updated.
     update_auto : bool
         Whether or not to automatically update all data sources.
+
+    Returns
+    -------
+    update_sources : bool
+        Whether or not we should update the data sources.
     """
 
     # Create a hosts file if it doesn't exist.
@@ -242,10 +252,12 @@ def prompt_for_update(freshen, update_auto):
     prompt = "Do you want to update all data sources?"
 
     if update_auto or query_yes_no(prompt):
-        update_all_sources()
+        return True
     elif not update_auto:
         print("OK, we'll stick with what we've got locally.")
 
+    return False
+
 
 def prompt_for_exclusions(skip_prompt):
     """
@@ -478,12 +490,23 @@ def matches_exclusions(stripped_rule, exclusion_regexes):
 
 
 # Update Logic
-def update_all_sources():
+def update_all_sources(source_data_filename, host_filename):
     """
     Update all host files, regardless of folder depth.
+
+    Parameters
+    ----------
+    source_data_filename : str
+        The name of the file where information regarding updating
+        sources for a particular URL is stored. This filename is assumed
+        to be the same for all sources.
+    host_filename : str
+        The name of the file in which the updated source information
+        is stored for a particular URL. This filename is assumed to be
+        the same for all sources.
     """
 
-    all_sources = recursive_glob("*", settings["sourcedatafilename"])
+    all_sources = recursive_glob("*", source_data_filename)
 
     for source in all_sources:
         update_file = open(source, "r")
@@ -502,7 +525,7 @@ def update_all_sources():
 
             hosts_file = open(path_join_robust(BASEDIR_PATH,
                                                os.path.dirname(source),
-                                               settings["hostfilename"]), "wb")
+                                               host_filename), "wb")
             write_data(hosts_file, updated_file)
             hosts_file.close()
         except Exception:
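
With the hunks above, prompt_for_update only reports the decision and main() performs the update. A trimmed stand-in of the new return contract (the real function also creates a missing hosts file and calls query_yes_no; answer_yes below replaces that prompt and is not a real parameter):

    def should_update_sources(update_auto, answer_yes):
        # True means the caller is expected to run update_all_sources(...).
        if update_auto or answer_yes:
            return True
        print("OK, we'll stick with what we've got locally.")
        return False

    assert should_update_sources(update_auto=True, answer_yes=False)
    assert not should_update_sources(update_auto=False, answer_yes=False)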