Refactor out global settings usage in user prompt
authorgfyoung <redacted>
Sun, 9 Jul 2017 19:00:11 +0000 (12:00 -0700)
committergfyoung <redacted>
Thu, 13 Jul 2017 15:20:03 +0000 (08:20 -0700)
testUpdateHostsFile.py
updateHostsFile.py

index c1f2e07dee635d3d59ae0b7614ac373772b86cc1..9e8f5e30f13f6c15819475a5461db05ee3f598c0 100644 (file)
@@ -10,8 +10,10 @@ from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
                              get_file_by_url, is_valid_domain_format,
                              move_hosts_file_into_place, normalize_rule,
                              path_join_robust, print_failure, print_success,
-                             supports_color, query_yes_no, recursive_glob,
-                             remove_old_hosts_file, strip_rule,
+                             prompt_for_exclusions, prompt_for_move,
+                             prompt_for_flush_dns_cache, prompt_for_update,
+                             query_yes_no, recursive_glob,
+                             remove_old_hosts_file, supports_color, strip_rule,
                              update_readme_data, write_data,
                              write_opening_header)
 import updateHostsFile
@@ -33,7 +35,7 @@ else:
     import mock
 
 
-# Base Test Classes
+# Test Helper Objects
 class Base(unittest.TestCase):
 
     @staticmethod
@@ -66,7 +68,14 @@ class BaseMockDir(Base):
 
     def tearDown(self):
         shutil.rmtree(self.test_dir)
-# End Base Test Classes
+
+
+def builtins():
+    if PY3:
+        return "builtins"
+    else:
+        return "__builtin__"
+# End Test Helper Objects
 
 
 # Project Settings
@@ -107,6 +116,303 @@ class TestGetDefaults(Base):
 # End Project Settings
 
 
+# Prompt the User
+class TestPromptForUpdate(BaseStdout, BaseMockDir):
+
+    def setUp(self):
+        BaseStdout.setUp(self)
+        BaseMockDir.setUp(self)
+
+    def test_no_freshen_no_new_file(self):
+        hosts_file = os.path.join(self.test_dir, "hosts")
+        hosts_data = "This data should not be overwritten"
+
+        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
+            updateHostsFile.BASEDIR_PATH = self.test_dir
+
+            with open(hosts_file, "w") as f:
+                f.write(hosts_data)
+
+        for update_auto in (False, True):
+            dir_count = self.dir_count
+            prompt_for_update(freshen=False, update_auto=update_auto)
+
+            output = sys.stdout.getvalue()
+            self.assertEqual(output, "")
+
+            sys.stdout = StringIO()
+
+            self.assertEqual(self.dir_count, dir_count)
+
+            with open(hosts_file, "r") as f:
+                contents = f.read()
+                self.assertEqual(contents, hosts_data)
+
+    def test_no_freshen_new_file(self):
+        hosts_file = os.path.join(self.test_dir, "hosts")
+
+        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
+            updateHostsFile.BASEDIR_PATH = self.test_dir
+
+            dir_count = self.dir_count
+            prompt_for_update(freshen=False, update_auto=False)
+
+            output = sys.stdout.getvalue()
+            self.assertEqual(output, "")
+
+            sys.stdout = StringIO()
+
+            self.assertEqual(self.dir_count, dir_count + 1)
+
+            with open(hosts_file, "r") as f:
+                contents = f.read()
+                self.assertEqual(contents, "")
+
+    @mock.patch(builtins() + ".open")
+    def test_no_freshen_fail_new_file(self, mock_open):
+        for exc in (IOError, OSError):
+            mock_open.side_effect = exc("failed open")
+
+            with self.mock_property("updateHostsFile.BASEDIR_PATH"):
+                updateHostsFile.BASEDIR_PATH = self.test_dir
+                prompt_for_update(freshen=False, update_auto=False)
+
+                output = sys.stdout.getvalue()
+                expected = ("ERROR: No 'hosts' file in the folder. "
+                            "Try creating one manually.")
+                self.assertIn(expected, output)
+
+                sys.stdout = StringIO()
+
+    @mock.patch("updateHostsFile.update_all_sources", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def test_freshen_no_update(self, _, mock_update):
+        hosts_file = os.path.join(self.test_dir, "hosts")
+        hosts_data = "This data should not be overwritten"
+
+        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
+            updateHostsFile.BASEDIR_PATH = self.test_dir
+
+            with open(hosts_file, "w") as f:
+                f.write(hosts_data)
+
+            dir_count = self.dir_count
+
+            prompt_for_update(freshen=True, update_auto=False)
+
+            mock_update.assert_not_called()
+            mock_update.reset_mock()
+
+            output = sys.stdout.getvalue()
+            expected = ("OK, we'll stick with "
+                        "what we've got locally.")
+            self.assertIn(expected, output)
+
+            sys.stdout = StringIO()
+
+            self.assertEqual(self.dir_count, dir_count)
+
+            with open(hosts_file, "r") as f:
+                contents = f.read()
+                self.assertEqual(contents, hosts_data)
+
+    @mock.patch("updateHostsFile.update_all_sources", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
+    def test_freshen_update(self, _, mock_update):
+        hosts_file = os.path.join(self.test_dir, "hosts")
+        hosts_data = "This data should not be overwritten"
+
+        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
+            updateHostsFile.BASEDIR_PATH = self.test_dir
+
+            with open(hosts_file, "w") as f:
+                f.write(hosts_data)
+
+            dir_count = self.dir_count
+
+            for update_auto in (False, True):
+                prompt_for_update(freshen=True, update_auto=update_auto)
+
+                mock_update.assert_called_once()
+                mock_update.reset_mock()
+
+                output = sys.stdout.getvalue()
+                self.assertEqual(output, "")
+
+                sys.stdout = StringIO()
+
+                self.assertEqual(self.dir_count, dir_count)
+
+                with open(hosts_file, "r") as f:
+                    contents = f.read()
+                    self.assertEqual(contents, hosts_data)
+
+    def tearDown(self):
+        BaseStdout.tearDown(self)
+        BaseStdout.tearDown(self)
+
+
+class TestPromptForExclusions(BaseStdout):
+
+    @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testSkipPrompt(self, mock_query, mock_display):
+        prompt_for_exclusions(skip_prompt=True)
+
+        output = sys.stdout.getvalue()
+        self.assertEqual(output, "")
+
+        mock_query.assert_not_called()
+        mock_display.assert_not_called()
+
+    @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testNoSkipPromptNoDisplay(self, mock_query, mock_display):
+        prompt_for_exclusions(skip_prompt=False)
+
+        output = sys.stdout.getvalue()
+        expected = ("OK, we'll only exclude "
+                    "domains in the whitelist.")
+        self.assertIn(expected, output)
+
+        mock_query.assert_called_once()
+        mock_display.assert_not_called()
+
+    @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
+    def testNoSkipPromptDisplay(self, mock_query, mock_display):
+        prompt_for_exclusions(skip_prompt=False)
+
+        output = sys.stdout.getvalue()
+        self.assertEqual(output, "")
+
+        mock_query.assert_called_once()
+        mock_display.assert_called_once()
+
+
+class TestPromptForFlushDnsCache(Base):
+
+    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testFlushCache(self, mock_query, mock_flush):
+        for prompt_flush in (False, True):
+            prompt_for_flush_dns_cache(flush_cache=True,
+                                       prompt_flush=prompt_flush)
+
+            mock_query.assert_not_called()
+            mock_flush.assert_called_once()
+
+            mock_query.reset_mock()
+            mock_flush.reset_mock()
+
+    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testNoFlushCacheNoPrompt(self, mock_query, mock_flush):
+        prompt_for_flush_dns_cache(flush_cache=False,
+                                   prompt_flush=False)
+
+        mock_query.assert_not_called()
+        mock_flush.assert_not_called()
+
+    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testNoFlushCachePromptNoFlush(self, mock_query, mock_flush):
+        prompt_for_flush_dns_cache(flush_cache=False,
+                                   prompt_flush=True)
+
+        mock_query.assert_called_once()
+        mock_flush.assert_not_called()
+
+    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
+    def testNoFlushCachePromptFlush(self, mock_query, mock_flush):
+        prompt_for_flush_dns_cache(flush_cache=False,
+                                   prompt_flush=True)
+
+        mock_query.assert_called_once()
+        mock_flush.assert_called_once()
+
+
+class TestPromptForMove(Base):
+
+    def setUp(self):
+        Base.setUp(self)
+        self.final_file = "final.txt"
+
+    def prompt_for_move(self, **move_params):
+        return prompt_for_move(self.final_file, **move_params)
+
+    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testSkipStaticHosts(self, mock_query, mock_move):
+        for replace in (False, True):
+            for auto in (False, True):
+                move_file = self.prompt_for_move(replace=replace, auto=auto,
+                                                 skipstatichosts=True)
+                self.assertFalse(move_file)
+
+                mock_query.assert_not_called()
+                mock_move.assert_not_called()
+
+                mock_query.reset_mock()
+                mock_move.reset_mock()
+
+    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testReplaceNoSkipStaticHosts(self, mock_query, mock_move):
+        for auto in (False, True):
+            move_file = self.prompt_for_move(replace=True, auto=auto,
+                                             skipstatichosts=False)
+            self.assertTrue(move_file)
+
+            mock_query.assert_not_called()
+            mock_move.assert_called_once()
+
+            mock_query.reset_mock()
+            mock_move.reset_mock()
+
+    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testAutoNoSkipStaticHosts(self, mock_query, mock_move):
+        for replace in (False, True):
+            move_file = self.prompt_for_move(replace=replace, auto=True,
+                                             skipstatichosts=True)
+            self.assertFalse(move_file)
+
+            mock_query.assert_not_called()
+            mock_move.assert_not_called()
+
+            mock_query.reset_mock()
+            mock_move.reset_mock()
+
+    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
+    def testPromptNoMove(self, mock_query, mock_move):
+        move_file = self.prompt_for_move(replace=False, auto=False,
+                                         skipstatichosts=False)
+        self.assertFalse(move_file)
+
+        mock_query.assert_called_once()
+        mock_move.assert_not_called()
+
+        mock_query.reset_mock()
+        mock_move.reset_mock()
+
+    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
+    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
+    def testPromptMove(self, mock_query, mock_move):
+        move_file = self.prompt_for_move(replace=False, auto=False,
+                                         skipstatichosts=False)
+        self.assertTrue(move_file)
+
+        mock_query.assert_called_once()
+        mock_move.assert_called_once()
+
+        mock_query.reset_mock()
+        mock_move.reset_mock()
+# End Prompt the User
+
+
 # Exclusion Logic
 class TestGatherCustomExclusions(BaseStdout):
 
index 4d795e7ec99207618a0f2b842449fa0483cc2aa0..96c38706a7a747d38097992d34bd548484fde785 100644 (file)
@@ -146,8 +146,10 @@ def main():
     settings["extensions"] = sorted(list(
         set(options["extensions"]).intersection(settings["extensions"])))
 
-    prompt_for_update()
-    prompt_for_exclusions()
+    auto = settings["auto"]
+
+    prompt_for_update(freshen=settings["freshen"], update_auto=auto)
+    prompt_for_exclusions(skip_prompt=auto)
 
     merge_file = create_initial_file()
     remove_old_hosts_file(settings["backup"])
@@ -158,11 +160,12 @@ def main():
     final_file = remove_dups_and_excl(merge_file)
 
     number_of_rules = settings["numberofrules"]
+    skip_static_hosts = settings["skipstatichosts"]
 
     write_opening_header(final_file, extensions=extensions,
                          numberofrules=number_of_rules,
                          outputsubfolder=output_subfolder,
-                         skipstatichosts=settings["skipstatichosts"])
+                         skipstatichosts=skip_static_hosts)
     final_file.close()
 
     if settings["ziphosts"]:
@@ -183,63 +186,101 @@ def main():
                   "{:,}".format(number_of_rules) +
                   " unique entries.")
 
-    prompt_for_move(final_file)
+    move_file = prompt_for_move(final_file, auto=auto,
+                                replace=settings["replace"],
+                                skipstatichosts=skip_static_hosts)
+
+    # We only flush the DNS cache if we have
+    # moved a new hosts file into place.
+    if move_file:
+        prompt_for_flush_dns_cache(flush_cache=settings["flushdnscache"],
+                                   prompt_flush=not auto)
 
 
 # Prompt the User
-def prompt_for_update():
+def prompt_for_update(freshen, update_auto):
     """
     Prompt the user to update all hosts files.
+
+    If requested, the function will update all data sources after it
+    checks that a hosts file does indeed exist.
+
+    Parameters
+    ----------
+    freshen : bool
+        Whether data sources should be updated. This function will return
+        if it is requested that data sources not be updated.
+    update_auto : bool
+        Whether or not to automatically update all data sources.
     """
 
-    # Create hosts file if it doesn't exist.
-    if not os.path.isfile(path_join_robust(BASEDIR_PATH, "hosts")):
-        try:
-            open(path_join_robust(BASEDIR_PATH, "hosts"), "w+").close()
-        except Exception:
-            print_failure("ERROR: No 'hosts' file in the folder,"
-                          "try creating one manually")
+    # Create a hosts file if it doesn't exist.
+    hosts_file = path_join_robust(BASEDIR_PATH, "hosts")
 
-    if not settings["freshen"]:
+    if not os.path.isfile(hosts_file):
+        try:
+            open(hosts_file, "w+").close()
+        except (IOError, OSError):
+            # Starting in Python 3.3, IOError is aliased
+            # OSError. However, we have to catch both for
+            # Python 2.x failures.
+            print_failure("ERROR: No 'hosts' file in the folder. "
+                          "Try creating one manually.")
+
+    if not freshen:
         return
 
     prompt = "Do you want to update all data sources?"
-    if settings["auto"] or query_yes_no(prompt):
+
+    if update_auto or query_yes_no(prompt):
         update_all_sources()
-    elif not settings["auto"]:
+    elif not update_auto:
         print("OK, we'll stick with what we've got locally.")
 
 
-def prompt_for_exclusions():
+def prompt_for_exclusions(skip_prompt):
     """
     Prompt the user to exclude any custom domains from being blocked.
+
+    Parameters
+    ----------
+    skip_prompt : bool
+        Whether or not to skip prompting for custom domains to be excluded.
+        If true, the function returns immediately.
     """
 
     prompt = ("Do you want to exclude any domains?\n"
               "For example, hulu.com video streaming must be able to access "
               "its tracking and ad servers in order to play video.")
 
-    if not settings["auto"]:
+    if not skip_prompt:
         if query_yes_no(prompt):
             display_exclusion_options()
         else:
             print("OK, we'll only exclude domains in the whitelist.")
 
 
-def prompt_for_flush_dns_cache():
+def prompt_for_flush_dns_cache(flush_cache, prompt_flush):
     """
     Prompt the user to flush the DNS cache.
+
+    Parameters
+    ----------
+    flush_cache : bool
+        Whether to flush the DNS cache without prompting.
+    prompt_flush : bool
+        If `flush_cache` is False, whether we should prompt for flushing the
+        cache. Otherwise, the function returns immediately.
     """
 
-    if settings["flushdnscache"]:
+    if flush_cache:
         flush_dns_cache()
-
-    if not settings["auto"]:
+    elif prompt_flush:
         if query_yes_no("Attempt to flush the DNS cache?"):
             flush_dns_cache()
 
 
-def prompt_for_move(final_file):
+def prompt_for_move(final_file, **move_params):
     """
     Prompt the user to move the newly created hosts file to its designated
     location in the OS.
@@ -248,11 +289,25 @@ def prompt_for_move(final_file):
     ----------
     final_file : file
         The file object that contains the newly created hosts data.
+    move_params : kwargs
+        Dictionary providing additional parameters for moving the hosts file
+        into place. Currently, those fields are:
+
+        1) auto
+        2) replace
+        3) skipstatichosts
+
+    Returns
+    -------
+    move_file : bool
+        Whether or not the final hosts file was moved.
     """
 
-    if settings["replace"] and not settings["skipstatichosts"]:
+    skip_static_hosts = move_params["skipstatichosts"]
+
+    if move_params["replace"] and not skip_static_hosts:
         move_file = True
-    elif settings["auto"] or settings["skipstatichosts"]:
+    elif move_params["auto"] or skip_static_hosts:
         move_file = False
     else:
         prompt = ("Do you want to replace your existing hosts file " +
@@ -261,7 +316,8 @@ def prompt_for_move(final_file):
 
     if move_file:
         move_hosts_file_into_place(final_file)
-        prompt_for_flush_dns_cache()
+
+    return move_file
 # End Prompt the User
 
 
git clone https://git.99rst.org/PROJECT