#
# Python script for testing updateHostFiles.py
-from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
- gather_custom_exclusions, get_defaults,
- get_file_by_url, is_valid_domain_format,
- move_hosts_file_into_place, normalize_rule,
- path_join_robust, print_failure, print_success,
- prompt_for_exclusions, prompt_for_move,
- prompt_for_flush_dns_cache, prompt_for_update,
- query_yes_no, recursive_glob,
- remove_old_hosts_file, supports_color, strip_rule,
- update_readme_data, write_data,
- write_opening_header)
+from updateHostsFile import (
+ Colors, PY3, colorize, display_exclusion_options, exclude_domain,
+ flush_dns_cache, gather_custom_exclusions, get_defaults, get_file_by_url,
+ is_valid_domain_format, matches_exclusions, move_hosts_file_into_place,
+ normalize_rule, path_join_robust, print_failure, print_success,
+ prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
+ prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
+ supports_color, strip_rule, update_readme_data, write_data,
+ write_opening_header)
+
import updateHostsFile
import unittest
import tempfile
import json
import sys
import os
+import re
if PY3:
from io import BytesIO, StringIO
class TestPromptForExclusions(BaseStdout):
- @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
- def testSkipPrompt(self, mock_query, mock_display):
- prompt_for_exclusions(skip_prompt=True)
+ def testSkipPrompt(self, mock_query):
+ gather_exclusions = prompt_for_exclusions(skip_prompt=True)
+ self.assertFalse(gather_exclusions)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
mock_query.assert_not_called()
- mock_display.assert_not_called()
- @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
- def testNoSkipPromptNoDisplay(self, mock_query, mock_display):
- prompt_for_exclusions(skip_prompt=False)
+ def testNoSkipPromptNoDisplay(self, mock_query):
+ gather_exclusions = prompt_for_exclusions(skip_prompt=False)
+ self.assertFalse(gather_exclusions)
output = sys.stdout.getvalue()
expected = ("OK, we'll only exclude "
self.assertIn(expected, output)
self.assert_called_once(mock_query)
- mock_display.assert_not_called()
- @mock.patch("updateHostsFile.display_exclusion_options", return_value=0)
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
- def testNoSkipPromptDisplay(self, mock_query, mock_display):
- prompt_for_exclusions(skip_prompt=False)
+ def testNoSkipPromptDisplay(self, mock_query):
+ gather_exclusions = prompt_for_exclusions(skip_prompt=False)
+ self.assertTrue(gather_exclusions)
output = sys.stdout.getvalue()
self.assertEqual(output, "")
self.assert_called_once(mock_query)
- self.assert_called_once(mock_display)
class TestPromptForFlushDnsCache(Base):
# Exclusion Logic
+class TestDisplayExclusionsOptions(Base):
+
+ @mock.patch("updateHostsFile.query_yes_no", return_value=0)
+ @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+ @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+ def test_no_exclusions(self, mock_gather, mock_exclude, _):
+ common_exclusions = []
+ display_exclusion_options(common_exclusions, "foo", [])
+
+ mock_gather.assert_not_called()
+ mock_exclude.assert_not_called()
+
+ @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 1, 0])
+ @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+ @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+ def test_only_common_exclusions(self, mock_gather, mock_exclude, _):
+ common_exclusions = ["foo", "bar"]
+ display_exclusion_options(common_exclusions, "foo", [])
+
+ mock_gather.assert_not_called()
+
+ exclude_calls = [mock.call("foo", "foo", []),
+ mock.call("bar", "foo", None)]
+ mock_exclude.assert_has_calls(exclude_calls)
+
+ @mock.patch("updateHostsFile.query_yes_no", side_effect=[0, 0, 1])
+ @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+ @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+ def test_gather_exclusions(self, mock_gather, mock_exclude, _):
+ common_exclusions = ["foo", "bar"]
+ display_exclusion_options(common_exclusions, "foo", [])
+
+ mock_exclude.assert_not_called()
+ self.assert_called_once(mock_gather)
+
+ @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 0, 1])
+ @mock.patch("updateHostsFile.exclude_domain", return_value=None)
+ @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
+ def test_mixture_gather_exclusions(self, mock_gather, mock_exclude, _):
+ common_exclusions = ["foo", "bar"]
+ display_exclusion_options(common_exclusions, "foo", [])
+
+ mock_exclude.assert_called_once_with("foo", "foo", [])
+ self.assert_called_once(mock_gather)
+
+
class TestGatherCustomExclusions(BaseStdout):
# Can only test in the invalid domain case
@mock.patch("updateHostsFile.raw_input", side_effect=["foo", "no"])
@mock.patch("updateHostsFile.is_valid_domain_format", return_value=False)
def test_basic(self, *_):
- gather_custom_exclusions()
+ gather_custom_exclusions("foo", [])
expected = "Do you have more domains you want to enter? [Y/n]"
output = sys.stdout.getvalue()
"bar", "no"])
@mock.patch("updateHostsFile.is_valid_domain_format", return_value=False)
def test_multiple(self, *_):
- gather_custom_exclusions()
+ gather_custom_exclusions("foo", [])
expected = ("Do you have more domains you want to enter? [Y/n] "
"Do you have more domains you want to enter? [Y/n]")
output = sys.stdout.getvalue()
self.assertIn(expected, output)
+
+
+class TestExcludeDomain(Base):
+
+ def test_invalid_exclude_domain(self):
+ exclusion_regexes = []
+ exclusion_pattern = "*.com"
+
+ for domain in ["google.com", "hulu.com", "adaway.org"]:
+ self.assertRaises(re.error, exclude_domain, domain,
+ exclusion_pattern, exclusion_regexes)
+
+ self.assertListEqual(exclusion_regexes, [])
+
+ def test_valid_exclude_domain(self):
+ exp_count = 0
+ expected_regexes = []
+ exclusion_regexes = []
+ exclusion_pattern = "[a-z]\."
+
+ for domain in ["google.com", "hulu.com", "adaway.org"]:
+ self.assertEqual(len(exclusion_regexes), exp_count)
+
+ exclusion_regexes = exclude_domain(domain, exclusion_pattern,
+ exclusion_regexes)
+ expected_regex = re.compile(exclusion_pattern + domain)
+
+ expected_regexes.append(expected_regex)
+ exp_count += 1
+
+ self.assertEqual(len(exclusion_regexes), exp_count)
+ self.assertListEqual(exclusion_regexes, expected_regexes)
+
+
+class TestMatchesExclusions(Base):
+
+ def test_no_match_empty_list(self):
+ exclusion_regexes = []
+
+ for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com",
+ "9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]:
+ self.assertFalse(matches_exclusions(domain, exclusion_regexes))
+
+ def test_no_match_list(self):
+ exclusion_regexes = [".*\.org", ".*\.edu"]
+ exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
+
+ for domain in ["1.2.3.4 localhost", "5.6.7.8 hulu.com",
+ "9.1.2.3 yahoo.com", "4.5.6.7 cloudfront.net"]:
+ self.assertFalse(matches_exclusions(domain, exclusion_regexes))
+
+ def test_match_list(self):
+ exclusion_regexes = [".*\.com", ".*\.org", ".*\.edu"]
+ exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]
+
+ for domain in ["5.6.7.8 hulu.com", "9.1.2.3 yahoo.com",
+ "4.5.6.7 adaway.org", "8.9.1.2 education.edu"]:
+ self.assertTrue(matches_exclusions(domain, exclusion_regexes))
# End Exclusion Logic
set(options["extensions"]).intersection(settings["extensions"])))
auto = settings["auto"]
+ exclusion_regexes = settings["exclusionregexs"]
prompt_for_update(freshen=settings["freshen"], update_auto=auto)
- prompt_for_exclusions(skip_prompt=auto)
+ gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
+
+ if gather_exclusions:
+ common_exclusions = settings["commonexclusions"]
+ exclusion_pattern = settings["exclusionpattern"]
+ exclusion_regexes = display_exclusion_options(
+ common_exclusions=common_exclusions,
+ exclusion_pattern=exclusion_pattern,
+ exclusion_regexes=exclusion_regexes)
merge_file = create_initial_file()
remove_old_hosts_file(settings["backup"])
extensions = settings["extensions"]
output_subfolder = settings["outputsubfolder"]
- final_file = remove_dups_and_excl(merge_file)
+ final_file = remove_dups_and_excl(merge_file, exclusion_regexes)
number_of_rules = settings["numberofrules"]
skip_static_hosts = settings["skipstatichosts"]
skip_prompt : bool
Whether or not to skip prompting for custom domains to be excluded.
If true, the function returns immediately.
+
+ Returns
+ -------
+ gather_exclusions : bool
+ Whether or not we should proceed to prompt the user to exclude any
+ custom domains beyond those in the whitelist.
"""
prompt = ("Do you want to exclude any domains?\n"
if not skip_prompt:
if query_yes_no(prompt):
- display_exclusion_options()
+ return True
else:
print("OK, we'll only exclude domains in the whitelist.")
+ return False
+
def prompt_for_flush_dns_cache(flush_cache, prompt_flush):
"""
# Exclusion logic
-def display_exclusion_options():
+def display_exclusion_options(common_exclusions, exclusion_pattern,
+ exclusion_regexes):
"""
Display the exclusion options to the user.
This function checks whether a user wants to exclude particular domains,
and if so, excludes them.
+
+ Parameters
+ ----------
+ common_exclusions : list
+ A list of common domains that are excluded from being blocked. One
+ example is Hulu. This setting is set directly in the script and cannot
+ be overwritten by the user.
+ exclusion_pattern : str
+ The exclusion pattern with which to create the domain regex.
+ exclusion_regexes : list
+ The list of regex patterns used to exclude domains.
+
+ Returns
+ -------
+ aug_exclusion_regexes : list
+ The original list of regex patterns potentially with additional
+ patterns from domains that user chooses to exclude.
"""
- for exclusion_option in settings["commonexclusions"]:
+ for exclusion_option in common_exclusions:
prompt = "Do you want to exclude the domain " + exclusion_option + " ?"
if query_yes_no(prompt):
- exclude_domain(exclusion_option)
+ exclusion_regexes = exclude_domain(exclusion_option,
+ exclusion_pattern,
+ exclusion_regexes)
else:
continue
if query_yes_no("Do you want to exclude any other domains?"):
- gather_custom_exclusions()
+ exclusion_regexes = gather_custom_exclusions(exclusion_pattern,
+ exclusion_regexes)
+
+ return exclusion_regexes
-def gather_custom_exclusions():
+def gather_custom_exclusions(exclusion_pattern, exclusion_regexes):
"""
Gather custom exclusions from the user.
+
+ Parameters
+ ----------
+ exclusion_pattern : str
+ The exclusion pattern with which to create the domain regex.
+ exclusion_regexes : list
+ The list of regex patterns used to exclude domains.
+
+ Returns
+ -------
+ aug_exclusion_regexes : list
+ The original list of regex patterns potentially with additional
+ patterns from domains that user chooses to exclude.
"""
# We continue running this while-loop until the user
user_domain = raw_input(domain_prompt)
if is_valid_domain_format(user_domain):
- exclude_domain(user_domain)
+ exclusion_regexes = exclude_domain(user_domain, exclusion_pattern,
+ exclusion_regexes)
continue_prompt = "Do you have more domains you want to enter?"
if not query_yes_no(continue_prompt):
- return
+ break
+
+ return exclusion_regexes
-def exclude_domain(domain):
+def exclude_domain(domain, exclusion_pattern, exclusion_regexes):
"""
Exclude a domain from being blocked.
+ This create the domain regex by which to exclude this domain and appends
+ it a list of already-existing exclusion regexes.
+
Parameters
----------
domain : str
The filename or regex pattern to exclude.
+ exclusion_pattern : str
+ The exclusion pattern with which to create the domain regex.
+ exclusion_regexes : list
+ The list of regex patterns used to exclude domains.
+
+ Returns
+ -------
+ aug_exclusion_regexes : list
+ The original list of regex patterns with one additional pattern from
+ the `domain` input.
"""
- settings["exclusionregexs"].append(re.compile(
- settings["exclusionpattern"] + domain))
+ exclusion_regex = re.compile(exclusion_pattern + domain)
+ exclusion_regexes.append(exclusion_regex)
+ return exclusion_regexes
-def matches_exclusions(stripped_rule):
+
+def matches_exclusions(stripped_rule, exclusion_regexes):
"""
Check whether a rule matches an exclusion rule we already provided.
----------
stripped_rule : str
The rule that we are checking.
+ exclusion_regexes : list
+ The list of regex patterns used to exclude domains.
Returns
-------
"""
stripped_domain = stripped_rule.split()[1]
- for exclusionRegex in settings["exclusionregexs"]:
+
+ for exclusionRegex in exclusion_regexes:
if exclusionRegex.search(stripped_domain):
return True
+
return False
# End Exclusion Logic
return merge_file
-def remove_dups_and_excl(merge_file):
+def remove_dups_and_excl(merge_file, exclusion_regexes):
"""
Remove duplicates and remove hosts that we are excluding.
----------
merge_file : file
The file object that contains the hostnames that we are pruning.
+ exclusion_regexes : list
+ The list of regex patterns used to exclude domains.
"""
number_of_rules = settings["numberofrules"]
continue
stripped_rule = strip_rule(line) # strip comments
- if not stripped_rule or matches_exclusions(stripped_rule):
+ if not stripped_rule or matches_exclusions(stripped_rule,
+ exclusion_regexes):
continue
# Normalize rule