Refactor out global settings usage in exclusions
authorgfyoung <redacted>
Fri, 21 Jul 2017 08:19:59 +0000 (01:19 -0700)
committergfyoung <redacted>
Mon, 7 Aug 2017 15:43:47 +0000 (08:43 -0700)
.travis.yml
testUpdateHostsFile.py
updateHostsFile.py

index 3f4c0778a1f994530184201a5c6d08d4a93e7053..6b9036e8694a48c5a0f6b641b6bacb4eb15dc117 100644 (file)
@@ -19,6 +19,7 @@ os:
 
 env:
   - PYTHON_VERSION="2.7"
+  - PYTHON_VERSION="3.5"
   - PYTHON_VERSION="3.6"
 
 before_install:
index ff313318643e73d917ab7203f289ff9df9956be9..139bd89555e8532b957b388f4f7b6987f3cedb9e 100644 (file)
@@ -5,17 +5,16 @@
 #
 # Python script for testing updateHostFiles.py
 
-from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
-                             gather_custom_exclusions, get_defaults,
-                             get_file_by_url, is_valid_domain_format,
-                             move_hosts_file_into_place, normalize_rule,
-                             path_join_robust, print_failure, print_success,
-                             prompt_for_exclusions, prompt_for_move,
-                             prompt_for_flush_dns_cache, prompt_for_update,
-                             query_yes_no, recursive_glob,
-                             remove_old_hosts_file, supports_color, strip_rule,
-                             update_readme_data, write_data,
-                             write_opening_header)
+from updateHostsFile import (
+    Colors, PY3, colorize, display_exclusion_options, exclude_domain,
+    flush_dns_cache, gather_custom_exclusions, get_defaults, get_file_by_url,
+    is_valid_domain_format, matches_exclusions, move_hosts_file_into_place,
+    normalize_rule, path_join_robust, print_failure, print_success,
+    prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
+    prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
+    supports_color, strip_rule, update_readme_data, write_data,
+    write_opening_header)
+
 import updateHostsFile
 import unittest
 import tempfile
@@ -24,6 +23,7 @@ import shutil
 import json
 import sys
 import os
+import re
 
 if PY3:
     from io import BytesIO, StringIO
@@ -260,21 +260,20 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
 
 class TestPromptForExclusions(BaseStdout):
 
-    @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
     @mock.patch("updateHostsFile.query_yes_no", return_value=False)
-    def testSkipPrompt(self, mock_query, mock_display):
-        prompt_for_exclusions(skip_prompt=True)
+    def testSkipPrompt(self, mock_query):
+        gather_exclusions = prompt_for_exclusions(skip_prompt=True)
+        self.assertFalse(gather_exclusions)
 
         output = sys.stdout.getvalue()
         self.assertEqual(output, "")
 
         mock_query.assert_not_called()
-        mock_display.assert_not_called()
 
-    @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
     @mock.patch("updateHostsFile.query_yes_no", return_value=False)
-    def testNoSkipPromptNoDisplay(self, mock_query, mock_display):
-        prompt_for_exclusions(skip_prompt=False)
+    def testNoSkipPromptNoDisplay(self, mock_query):
+        gather_exclusions = prompt_for_exclusions(skip_prompt=False)
+        self.assertFalse(gather_exclusions)
 
         output = sys.stdout.getvalue()
         expected = ("OK, we'll only exclude "
@@ -282,18 +281,16 @@ class TestPromptForExclusions(BaseStdout):
         self.assertIn(expected, output)
 
         self.assert_called_once(mock_query)
-        mock_display.assert_not_called()
 
-    @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
     @mock.patch("updateHostsFile.query_yes_no", return_value=True)
-    def testNoSkipPromptDisplay(self, mock_query, mock_display):
-        prompt_for_exclusions(skip_prompt=False)
+    def testNoSkipPromptDisplay(self, mock_query):
+        gather_exclusions = prompt_for_exclusions(skip_prompt=False)
+        self.assertTrue(gather_exclusions)
 
         output = sys.stdout.getvalue()
         self.assertEqual(output, "")
 
         self.assert_called_once(mock_query)
-        self.assert_called_once(mock_display)
 
 
 class TestPromptForFlushDnsCache(Base):
@@ -420,6 +417,52 @@ class TestPromptForMove(Base):
 
 
 # Exclusion Logic
+class TestDisplayExclusionsOptions(Base):
+
+    @mock.patch("updateHostsFile.query_yes_no", return_value=0)
+    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+    def test_no_exclusions(self, mock_gather, mock_exclude, _):
+        common_exclusions = []
+        display_exclusion_options(common_exclusions, "foo", [])
+
+        mock_gather.assert_not_called()
+        mock_exclude.assert_not_called()
+
+    @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 1, 0])
+    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+    def test_only_common_exclusions(self, mock_gather, mock_exclude, _):
+        common_exclusions = ["foo", "bar"]
+        display_exclusion_options(common_exclusions, "foo", [])
+
+        mock_gather.assert_not_called()
+
+        exclude_calls = [mock.call("foo", "foo", []),
+                         mock.call("bar", "foo", None)]
+        mock_exclude.assert_has_calls(exclude_calls)
+
+    @mock.patch("updateHostsFile.query_yes_no", side_effect=[0, 0, 1])
+    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+    def test_gather_exclusions(self, mock_gather, mock_exclude, _):
+        common_exclusions = ["foo", "bar"]
+        display_exclusion_options(common_exclusions, "foo", [])
+
+        mock_exclude.assert_not_called()
+        self.assert_called_once(mock_gather)
+
+    @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 0, 1])
+    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+    def test_mixture_gather_exclusions(self, mock_gather, mock_exclude, _):
+        common_exclusions = ["foo", "bar"]
+        display_exclusion_options(common_exclusions, "foo", [])
+
+        mock_exclude.assert_called_once_with("foo", "foo", [])
+        self.assert_called_once(mock_gather)
+
+
 class TestGatherCustomExclusions(BaseStdout):
 
     # Can only test in the invalid domain case
@@ -427,7 +470,7 @@ class TestGatherCustomExclusions(BaseStdout):
     @mock.patch("updateHostsFile.raw_input", side_effect=["foo", "no"])
     @mock.patch("updateHostsFile.is_valid_domain_format", return_value=False)
     def test_basic(self, *_):
-        gather_custom_exclusions()
+        gather_custom_exclusions("foo", [])
 
         expected = "Do you have more domains you want to enter? [Y/n]"
         output = sys.stdout.getvalue()
@@ -437,12 +480,70 @@ class TestGatherCustomExclusions(BaseStdout):
                                                           "bar", "no"])
     @mock.patch("updateHostsFile.is_valid_domain_format", return_value=False)
     def test_multiple(self, *_):
-        gather_custom_exclusions()
+        gather_custom_exclusions("foo", [])
 
         expected = ("Do you have more domains you want to enter? [Y/n] "
                     "Do you have more domains you want to enter? [Y/n]")
         output = sys.stdout.getvalue()
         self.assertIn(expected, output)
+
+
+class TestExcludeDomain(Base):
+
+    def test_invalid_exclude_domain(self):
+        exclusion_regexes = []
+        exclusion_pattern = "*.com"
+
+        for domain in ["google.com", "hulu.com", "adaway.org"]:
+            self.assertRaises(re.error, exclude_domain, domain,
+                              exclusion_pattern, exclusion_regexes)
+
+        self.assertListEqual(exclusion_regexes, [])
+
+    def test_valid_exclude_domain(self):
+        exp_count = 0
+        expected_regexes = []
+        exclusion_regexes = []
+        exclusion_pattern = "[a-z]\."
+
+        for domain in ["google.com", "hulu.com", "adaway.org"]:
+            self.assertEqual(len(exclusion_regexes), exp_count)
+
+            exclusion_regexes = exclude_domain(domain, exclusion_pattern,
+                                               exclusion_regexes)
+            expected_regex = re.compile(exclusion_pattern + domain)
+
+            expected_regexes.append(expected_regex)
+            exp_count += 1
+
+        self.assertEqual(len(exclusion_regexes), exp_count)
+        self.assertListEqual(exclusion_regexes, expected_regexes)
+
+
+class TestMatchesExclusions(Base):
+
+    def test_no_match_empty_list(self):
+        exclusion_regexes = []
+
+        for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com",
+                       "9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]:
+            self.assertFalse(matches_exclusions(domain, exclusion_regexes))
+
+    def test_no_match_list(self):
+        exclusion_regexes = [".*\.org", ".*\.edu"]
+        exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
+
+        for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com",
+                       "9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]:
+            self.assertFalse(matches_exclusions(domain, exclusion_regexes))
+
+    def test_match_list(self):
+        exclusion_regexes = [".*\.com", ".*\.org", ".*\.edu"]
+        exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
+
+        for domain in ["5.6.7.8 hulu.com", "9.1.2.3 yahoo.com",
+                       "4.5.6.7 adaway.org", "8.9.1.2 education.edu"]:
+            self.assertTrue(matches_exclusions(domain, exclusion_regexes))
 # End Exclusion Logic
 
 
index 96c38706a7a747d38097992d34bd548484fde785..dbcfe7696ffa41d1af1f50b3da940c6b9201d385 100644 (file)
@@ -147,9 +147,18 @@ def main():
         set(options["extensions"]).intersection(settings["extensions"])))
 
     auto = settings["auto"]
+    exclusion_regexes = settings["exclusionregexs"]
 
     prompt_for_update(freshen=settings["freshen"], update_auto=auto)
-    prompt_for_exclusions(skip_prompt=auto)
+    gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
+
+    if gather_exclusions:
+        common_exclusions = settings["commonexclusions"]
+        exclusion_pattern = settings["exclusionpattern"]
+        exclusion_regexes = display_exclusion_options(
+            common_exclusions=common_exclusions,
+            exclusion_pattern=exclusion_pattern,
+            exclusion_regexes=exclusion_regexes)
 
     merge_file = create_initial_file()
     remove_old_hosts_file(settings["backup"])
@@ -157,7 +166,7 @@ def main():
     extensions = settings["extensions"]
     output_subfolder = settings["outputsubfolder"]
 
-    final_file = remove_dups_and_excl(merge_file)
+    final_file = remove_dups_and_excl(merge_file, exclusion_regexes)
 
     number_of_rules = settings["numberofrules"]
     skip_static_hosts = settings["skipstatichosts"]
@@ -247,6 +256,12 @@ def prompt_for_exclusions(skip_prompt):
     skip_prompt : bool
         Whether or not to skip prompting for custom domains to be excluded.
         If true, the function returns immediately.
+
+    Returns
+    -------
+    gather_exclusions : bool
+        Whether or not we should proceed to prompt the user to exclude any
+        custom domains beyond those in the whitelist.
     """
 
     prompt = ("Do you want to exclude any domains?\n"
@@ -255,10 +270,12 @@ def prompt_for_exclusions(skip_prompt):
 
     if not skip_prompt:
         if query_yes_no(prompt):
-            display_exclusion_options()
+            return True
         else:
             print("OK, we'll only exclude domains in the whitelist.")
 
+    return False
+
 
 def prompt_for_flush_dns_cache(flush_cache, prompt_flush):
     """
@@ -322,29 +339,65 @@ def prompt_for_move(final_file, **move_params):
 
 
 # Exclusion logic
-def display_exclusion_options():
+def display_exclusion_options(common_exclusions, exclusion_pattern,
+                              exclusion_regexes):
     """
     Display the exclusion options to the user.
 
     This function checks whether a user wants to exclude particular domains,
     and if so, excludes them.
+
+    Parameters
+    ----------
+    common_exclusions : list
+        A list of common domains that are excluded from being blocked. One
+        example is Hulu. This setting is set directly in the script and cannot
+        be overwritten by the user.
+    exclusion_pattern : str
+        The exclusion pattern with which to create the domain regex.
+    exclusion_regexes : list
+        The list of regex patterns used to exclude domains.
+
+    Returns
+    -------
+    aug_exclusion_regexes : list
+        The original list of regex patterns potentially with additional
+        patterns from domains that user chooses to exclude.
     """
 
-    for exclusion_option in settings["commonexclusions"]:
+    for exclusion_option in common_exclusions:
         prompt = "Do you want to exclude the domain " + exclusion_option + " ?"
 
         if query_yes_no(prompt):
-            exclude_domain(exclusion_option)
+            exclusion_regexes = exclude_domain(exclusion_option,
+                                               exclusion_pattern,
+                                               exclusion_regexes)
         else:
             continue
 
     if query_yes_no("Do you want to exclude any other domains?"):
-        gather_custom_exclusions()
+        exclusion_regexes = gather_custom_exclusions(exclusion_pattern,
+                                                     exclusion_regexes)
+
+    return exclusion_regexes
 
 
-def gather_custom_exclusions():
+def gather_custom_exclusions(exclusion_pattern, exclusion_regexes):
     """
     Gather custom exclusions from the user.
+
+    Parameters
+    ----------
+    exclusion_pattern : str
+        The exclusion pattern with which to create the domain regex.
+    exclusion_regexes : list
+        The list of regex patterns used to exclude domains.
+
+    Returns
+    -------
+    aug_exclusion_regexes : list
+        The original list of regex patterns potentially with additional
+        patterns from domains that user chooses to exclude.
     """
 
     # We continue running this while-loop until the user
@@ -355,28 +408,46 @@ def gather_custom_exclusions():
         user_domain = raw_input(domain_prompt)
 
         if is_valid_domain_format(user_domain):
-            exclude_domain(user_domain)
+            exclusion_regexes = exclude_domain(user_domain, exclusion_pattern,
+                                               exclusion_regexes)
 
         continue_prompt = "Do you have more domains you want to enter?"
         if not query_yes_no(continue_prompt):
-            return
+            break
+
+    return exclusion_regexes
 
 
-def exclude_domain(domain):
+def exclude_domain(domain, exclusion_pattern, exclusion_regexes):
     """
     Exclude a domain from being blocked.
 
+    This create the domain regex by which to exclude this domain and appends
+    it a list of already-existing exclusion regexes.
+
     Parameters
     ----------
     domain : str
         The filename or regex pattern to exclude.
+    exclusion_pattern : str
+        The exclusion pattern with which to create the domain regex.
+    exclusion_regexes : list
+        The list of regex patterns used to exclude domains.
+
+    Returns
+    -------
+    aug_exclusion_regexes : list
+        The original list of regex patterns with one additional pattern from
+        the `domain` input.
     """
 
-    settings["exclusionregexs"].append(re.compile(
-        settings["exclusionpattern"] + domain))
+    exclusion_regex = re.compile(exclusion_pattern + domain)
+    exclusion_regexes.append(exclusion_regex)
 
+    return exclusion_regexes
 
-def matches_exclusions(stripped_rule):
+
+def matches_exclusions(stripped_rule, exclusion_regexes):
     """
     Check whether a rule matches an exclusion rule we already provided.
 
@@ -387,6 +458,8 @@ def matches_exclusions(stripped_rule):
     ----------
     stripped_rule : str
         The rule that we are checking.
+    exclusion_regexes : list
+        The list of regex patterns used to exclude domains.
 
     Returns
     -------
@@ -395,9 +468,11 @@ def matches_exclusions(stripped_rule):
     """
 
     stripped_domain = stripped_rule.split()[1]
-    for exclusionRegex in settings["exclusionregexs"]:
+
+    for exclusionRegex in exclusion_regexes:
         if exclusionRegex.search(stripped_domain):
             return True
+
     return False
 # End Exclusion Logic
 
@@ -479,7 +554,7 @@ def create_initial_file():
     return merge_file
 
 
-def remove_dups_and_excl(merge_file):
+def remove_dups_and_excl(merge_file, exclusion_regexes):
     """
     Remove duplicates and remove hosts that we are excluding.
 
@@ -490,6 +565,8 @@ def remove_dups_and_excl(merge_file):
     ----------
     merge_file : file
         The file object that contains the hostnames that we are pruning.
+    exclusion_regexes : list
+        The list of regex patterns used to exclude domains.
     """
 
     number_of_rules = settings["numberofrules"]
@@ -532,7 +609,8 @@ def remove_dups_and_excl(merge_file):
             continue
 
         stripped_rule = strip_rule(line)  # strip comments
-        if not stripped_rule or matches_exclusions(stripped_rule):
+        if not stripped_rule or matches_exclusions(stripped_rule,
+                                                   exclusion_regexes):
             continue
 
         # Normalize rule
git clone https://git.99rst.org/PROJECT