move_hosts_file_into_place, normalize_rule,
path_join_robust, print_failure, print_success,
supports_color, query_yes_no, recursive_glob,
- strip_rule, write_data)
+ remove_old_hosts_file, strip_rule,
+ update_readme_data, write_data,
+ write_opening_header)
import updateHostsFile
import unittest
+import tempfile
import locale
+import shutil
+import json
import sys
import os
def tearDown(self):
    # Undo the stdout capture installed for the test: close the
    # stand-in stream first, then restore the interpreter's real
    # stdout saved in sys.__stdout__.
    sys.stdout.close()
    sys.stdout = sys.__stdout__
+
+
class BaseMockDir(Base):
    """Base test class that provisions a throwaway scratch directory
    for each test and removes it again afterwards."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @property
    def dir_count(self):
        # Number of entries currently present in the scratch directory.
        return len(os.listdir(self.test_dir))
# End Base Test Classes
# File Logic
class TestNormalizeRule(BaseStdout):
    """Tests for normalize_rule, which turns a raw hosts-file line into
    a (hostname, formatted rule) pair."""

    def test_no_match(self):
        # Unparseable rules come back as (None, None) and are echoed to
        # stdout between "==>" and "<==" markers.
        no_match_kwargs = dict(target_ip="0.0.0.0",
                               keep_domain_comments=False)

        bad_rules = ["foo", "128.0.0.1", "bar.com/usa", "0.0.0 google",
                     "0.1.2.3.4 foo/bar", "twitter.com"]

        for rule in bad_rules:
            result = normalize_rule(rule, **no_match_kwargs)
            self.assertEqual(result, (None, None))

            output = sys.stdout.getvalue()
            sys.stdout = StringIO()

            self.assertIn("==>" + rule + "<==", output)

    def test_no_comments(self):
        for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
            rule = "127.0.0.1 google foo"
            expected = ("google", str(target_ip) + " google\n")

            actual = normalize_rule(rule, target_ip=target_ip,
                                    keep_domain_comments=False)
            self.assertEqual(actual, expected)

            # A successful match produces no console output.
            self.assertEqual(sys.stdout.getvalue(), "")

            sys.stdout = StringIO()

    def test_with_comments(self):
        for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
            for comment in ("foo", "bar", "baz"):
                rule = "127.0.0.1 google " + comment
                expected = ("google",
                            str(target_ip) + " google # " + comment + "\n")

                actual = normalize_rule(rule, target_ip=target_ip,
                                        keep_domain_comments=True)
                self.assertEqual(actual, expected)

                # A successful match produces no console output.
                self.assertEqual(sys.stdout.getvalue(), "")

                sys.stdout = StringIO()
class TestStripRule(Base):
self.assertEqual(output, line)
class TestWriteOpeningHeader(BaseMockDir):
    """Tests for write_opening_header, which writes the boilerplate
    comment block (plus optional static host entries and a "myhosts"
    preamble) at the top of a generated hosts file."""

    def setUp(self):
        super(TestWriteOpeningHeader, self).setUp()
        self.final_file = BytesIO()

    def tearDown(self):
        super(TestWriteOpeningHeader, self).tearDown()
        self.final_file.close()

    def _header(self):
        # Everything written to the in-memory hosts file, as text.
        return self.final_file.getvalue().decode("UTF-8")

    def _common_snippets(self, number_of_rules):
        # Boilerplate fragments expected in every header variant.
        return [
            "# This hosts file is a merged collection",
            "# with a dash of crowd sourcing via Github",
            "# Number of unique domains: {count}".format(
                count=number_of_rules),
            "Fetch the latest version of this file:",
            "Project home page: https://github.com/StevenBlack/hosts",
        ]

    def test_missing_keyword(self):
        params = dict(extensions="", outputsubfolder="",
                      numberofrules=5, skipstatichosts=False)

        # Omitting any one keyword must surface as a KeyError.
        for missing in params:
            incomplete = {k: v for k, v in params.items()
                          if k != missing}
            self.assertRaises(KeyError, write_opening_header,
                              self.final_file, **incomplete)

    def test_basic(self):
        params = dict(extensions="", outputsubfolder="",
                      numberofrules=5, skipstatichosts=True)
        write_opening_header(self.final_file, **params)

        contents = self._header()

        for snippet in self._common_snippets(params["numberofrules"]):
            self.assertIn(snippet, contents)

        # No extensions were supplied and static hosts were skipped.
        for snippet in [
            "# Extensions added to this file:",
            "127.0.0.1 localhost",
            "127.0.0.1 local",
            "127.0.0.53",
            "127.0.1.1",
        ]:
            self.assertNotIn(snippet, contents)

    def test_basic_include_static_hosts(self):
        params = dict(extensions="", outputsubfolder="",
                      numberofrules=5, skipstatichosts=False)

        with self.mock_property("platform.system") as system:
            system.return_value = "Windows"
            write_opening_header(self.final_file, **params)

        contents = self._header()

        expected_present = (["127.0.0.1 local", "127.0.0.1 localhost"] +
                            self._common_snippets(params["numberofrules"]))
        for snippet in expected_present:
            self.assertIn(snippet, contents)

        # No extensions, and no Linux-only loopback entries on Windows.
        for snippet in [
            "# Extensions added to this file:",
            "127.0.0.53",
            "127.0.1.1",
        ]:
            self.assertNotIn(snippet, contents)

    def test_basic_include_static_hosts_linux(self):
        params = dict(extensions="", outputsubfolder="",
                      numberofrules=5, skipstatichosts=False)

        with self.mock_property("platform.system") as system:
            system.return_value = "Linux"

            with self.mock_property("socket.gethostname") as hostname:
                hostname.return_value = "steven-hosts"
                write_opening_header(self.final_file, **params)

        contents = self._header()

        expected_present = (["127.0.1.1", "127.0.0.53", "steven-hosts",
                             "127.0.0.1 local", "127.0.0.1 localhost"] +
                            self._common_snippets(params["numberofrules"]))
        for snippet in expected_present:
            self.assertIn(snippet, contents)

        # No extensions were supplied.
        self.assertNotIn("# Extensions added to this file:", contents)

    def test_extensions(self):
        params = dict(extensions=["epsilon", "gamma", "mu", "phi"],
                      outputsubfolder="", numberofrules=5,
                      skipstatichosts=True)
        write_opening_header(self.final_file, **params)

        contents = self._header()

        expected_present = ([", ".join(params["extensions"]),
                             "# Extensions added to this file:"] +
                            self._common_snippets(params["numberofrules"]))
        for snippet in expected_present:
            self.assertIn(snippet, contents)

        # Static hosts were skipped.
        for snippet in [
            "127.0.0.1 localhost",
            "127.0.0.1 local",
            "127.0.0.53",
            "127.0.1.1",
        ]:
            self.assertNotIn(snippet, contents)

    def test_no_preamble(self):
        # A directory named "myhosts" must not be read as a preamble.
        os.mkdir(os.path.join(self.test_dir, "myhosts"))

        params = dict(extensions="", outputsubfolder="",
                      numberofrules=5, skipstatichosts=True)

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir
            write_opening_header(self.final_file, **params)

        contents = self._header()

        for snippet in self._common_snippets(params["numberofrules"]):
            self.assertIn(snippet, contents)

        for snippet in [
            "# Extensions added to this file:",
            "127.0.0.1 localhost",
            "127.0.0.1 local",
            "127.0.0.53",
            "127.0.1.1",
        ]:
            self.assertNotIn(snippet, contents)

    def test_preamble(self):
        # An existing "myhosts" file is copied verbatim into the header.
        with open(os.path.join(self.test_dir, "myhosts"), "w") as f:
            f.write("peter-piper-picked-a-pepper")

        params = dict(extensions="", outputsubfolder="",
                      numberofrules=5, skipstatichosts=True)

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir
            write_opening_header(self.final_file, **params)

        contents = self._header()

        expected_present = (["peter-piper-picked-a-pepper"] +
                            self._common_snippets(params["numberofrules"]))
        for snippet in expected_present:
            self.assertIn(snippet, contents)

        for snippet in [
            "# Extensions added to this file:",
            "127.0.0.1 localhost",
            "127.0.0.1 local",
            "127.0.0.53",
            "127.0.1.1",
        ]:
            self.assertNotIn(snippet, contents)
+
+
class TestUpdateReadmeData(BaseMockDir):
    """Tests for update_readme_data, which records generation metadata
    in the readmeData.json file."""

    def setUp(self):
        super(TestUpdateReadmeData, self).setUp()
        self.readme_file = os.path.join(self.test_dir, "readmeData.json")

    def _write_json(self, data):
        # Seed the README data file with the given JSON payload.
        with open(self.readme_file, "w") as f:
            json.dump(data, f)

    def _assert_json_equals(self, expected):
        # Reload the README data from disk and compare.
        with open(self.readme_file, "r") as f:
            self.assertEqual(json.load(f), expected)

    def test_missing_keyword(self):
        params = dict(extensions="", outputsubfolder="",
                      numberofrules="", sourcesdata="")

        # Omitting any one keyword must surface as a KeyError.
        for missing in params:
            incomplete = {k: v for k, v in params.items()
                          if k != missing}
            self.assertRaises(KeyError, update_readme_data,
                              self.readme_file, **incomplete)

    def test_add_fields(self):
        self._write_json({"foo": "bar"})

        update_readme_data(self.readme_file, extensions=None,
                           outputsubfolder="foo", numberofrules=5,
                           sourcesdata="hosts")

        # Existing unrelated keys are preserved alongside "base".
        self._assert_json_equals({
            "base": {
                "location": "foo" + self.sep,
                "sourcesdata": "hosts",
                "entries": 5,
            },
            "foo": "bar",
        })

    def test_modify_fields(self):
        self._write_json({"base": "soprano"})

        update_readme_data(self.readme_file, extensions=None,
                           outputsubfolder="foo", numberofrules=5,
                           sourcesdata="hosts")

        # An existing "base" entry is overwritten wholesale.
        self._assert_json_equals({
            "base": {
                "location": "foo" + self.sep,
                "sourcesdata": "hosts",
                "entries": 5,
            },
        })

    def test_set_extensions(self):
        self._write_json({})

        update_readme_data(self.readme_file, extensions=["com", "org"],
                           outputsubfolder="foo", numberofrules=5,
                           sourcesdata="hosts")

        # Extension runs are keyed by the joined extension names.
        self._assert_json_equals({
            "com-org": {
                "location": "foo" + self.sep,
                "sourcesdata": "hosts",
                "entries": 5,
            },
        })
+
+
class TestMoveHostsFile(BaseStdout):
@mock.patch("os.path.abspath", side_effect=lambda f: f)
("Flushing the DNS cache by restarting "
"NetworkManager.service succeeded")]:
self.assertIn(expected, output)
+
+
def mock_path_join_robust(*parts):
    """
    Stand-in for path_join_robust used in tests.

    We want to hard-code the backup hosts filename instead of
    parametrizing based on current time, so any two-part join whose
    second component is a timestamped "hosts-*" name is rewritten to
    the predictable name "hosts-new".
    """

    is_backup_name = len(parts) == 2 and parts[1].startswith("hosts-")

    if is_backup_name:
        return os.path.join(parts[0], "hosts-new")

    return os.path.join(*parts)
+
+
class TestRemoveOldHostsFile(BaseMockDir):
    """Tests for remove_old_hosts_file, which empties (and optionally
    backs up) any pre-existing hosts file in BASEDIR_PATH."""

    def setUp(self):
        super(TestRemoveOldHostsFile, self).setUp()
        self.hosts_file = os.path.join(self.test_dir, "hosts")

    def _remove_hosts(self, backup):
        # Run remove_old_hosts_file with BASEDIR_PATH pointed at the
        # scratch directory.
        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir
            remove_old_hosts_file(backup=backup)

    def _assert_file_contents(self, path, expected):
        with open(path, "r") as f:
            self.assertEqual(f.read(), expected)

    def test_remove_hosts_file(self):
        count_before = self.dir_count

        self._remove_hosts(backup=False)

        # A fresh, empty hosts file is created when none existed.
        self.assertEqual(self.dir_count, count_before + 1)
        self._assert_file_contents(self.hosts_file, "")

    def test_remove_hosts_file_exists(self):
        with open(self.hosts_file, "w") as f:
            f.write("foo")

        count_before = self.dir_count

        self._remove_hosts(backup=False)

        # The existing file is emptied in place; nothing new appears.
        self.assertEqual(self.dir_count, count_before)
        self._assert_file_contents(self.hosts_file, "")

    @mock.patch("updateHostsFile.path_join_robust",
                side_effect=mock_path_join_robust)
    def test_remove_hosts_file_backup(self, _):
        with open(self.hosts_file, "w") as f:
            f.write("foo")

        count_before = self.dir_count

        self._remove_hosts(backup=True)

        # The backup adds exactly one file to the directory.
        self.assertEqual(self.dir_count, count_before + 1)

        # The live hosts file is emptied...
        self._assert_file_contents(self.hosts_file, "")

        # ...and the old contents survive in the hard-coded backup name.
        self._assert_file_contents(self.hosts_file + "-new", "foo")
# End File Logic
settings["extensions"] = sorted(list(
set(options["extensions"]).intersection(settings["extensions"])))
- with open(settings["readmedatafilename"], "r") as f:
- settings["readmedata"] = json.load(f)
-
prompt_for_update()
prompt_for_exclusions()
merge_file = create_initial_file()
- remove_old_hosts_file()
+ remove_old_hosts_file(settings["backup"])
+
+ extensions = settings["extensions"]
+ number_of_rules = settings["numberofrules"]
+ output_subfolder = settings["outputsubfolder"]
final_file = remove_dups_and_excl(merge_file)
- write_opening_header(final_file)
+ write_opening_header(final_file, extensions=extensions,
+ numberofrules=number_of_rules,
+ outputsubfolder=output_subfolder,
+ skipstatichosts=settings["skipstatichosts"])
final_file.close()
if settings["ziphosts"]:
- zf = zipfile.ZipFile(path_join_robust(settings["outputsubfolder"],
+ zf = zipfile.ZipFile(path_join_robust(output_subfolder,
"hosts.zip"), mode='w')
- zf.write(path_join_robust(settings["outputsubfolder"], "hosts"),
+ zf.write(path_join_robust(output_subfolder, "hosts"),
compress_type=zipfile.ZIP_DEFLATED, arcname='hosts')
zf.close()
- update_readme_data()
+ update_readme_data(settings["readmedatafilename"],
+ extensions=extensions,
+ numberofrules=number_of_rules,
+ outputsubfolder=output_subfolder,
+ sourcesdata=settings["sourcesdata"])
+
print_success("Success! The hosts file has been saved in folder " +
- settings["outputsubfolder"] + "\nIt contains " +
- "{:,}".format(settings["numberofrules"]) +
+ output_subfolder + "\nIt contains " +
+ "{:,}".format(number_of_rules) +
" unique entries.")
prompt_for_move(final_file)
if settings["auto"] or query_yes_no(prompt):
update_all_sources()
elif not settings["auto"]:
- print("OK, we'll stick with what we've got locally.")
+ print("OK, we'll stick with what we've got locally.")
def prompt_for_exclusions():
continue
# Normalize rule
- hostname, normalized_rule = normalize_rule(stripped_rule)
+ hostname, normalized_rule = normalize_rule(
+ stripped_rule, target_ip=settings["targetip"],
+ keep_domain_comments=settings["keepdomaincomments"])
+
for exclude in exclusions:
if exclude in line:
write_line = False
return final_file
def normalize_rule(rule, target_ip, keep_domain_comments):
    """
    Standardize and format the rule string provided.

    Parameters
    ----------
    rule : str
        The rule whose spelling and spacing we are standardizing.
    target_ip : str
        The target IP address for the rule.
    keep_domain_comments : bool
        Whether or not to keep comments regarding these domains in
        the normalized rule.

    Returns
    -------
    normalized_rule : tuple
        A tuple of the hostname and the rule string with spelling
        and spacing reformatted.
    """

    match = re.search(r'^\s*(\d{1,3}\.){3}\d{1,3}\s+([\w\.-]+[a-zA-Z])(.*)',
                      rule)

    if not match:
        # Not an "IP hostname" line: echo it and signal no match.
        print("==>%s<==" % rule)
        return None, None

    hostname, suffix = match.group(2, 3)

    # Explicitly lowercase and trim the hostname.
    hostname = hostname.lower().strip()
    normalized = "%s %s" % (target_ip, hostname)

    if suffix and keep_domain_comments:
        # Keep any trailing text as a comment, never as extra hosts.
        normalized += " #%s" % suffix

    return hostname, normalized + "\n"
return split_line[0] + " " + split_line[1]
def write_opening_header(final_file, **header_params):
    """
    Write the header information into the newly-created hosts file.

    Parameters
    ----------
    final_file : file
        The file object that points to the newly-created hosts file.
    header_params : kwargs
        Dictionary providing additional parameters for populating the
        header information. Currently, those fields are:

        1) extensions
        2) numberofrules
        3) outputsubfolder
        4) skipstatichosts
    """

    # Save whatever has already been written so the header can be
    # inserted in front of it.
    final_file.seek(0)
    body = final_file.read()
    final_file.seek(0)

    write_data(final_file, "# This hosts file is a merged collection "
                           "of hosts from reputable sources,\n")
    write_data(final_file, "# with a dash of crowd sourcing via Github\n#\n")
    write_data(final_file, "# Date: " + time.strftime(
        "%B %d %Y", time.gmtime()) + "\n")

    if header_params["extensions"]:
        write_data(final_file, "# Extensions added to this file: " +
                   ", ".join(header_params["extensions"]) + "\n")

    write_data(final_file, "# Number of unique domains: " +
               "{:,}\n#\n".format(header_params["numberofrules"]))
    write_data(final_file, "# Fetch the latest version of this file: "
                           "https://raw.githubusercontent.com/"
                           "StevenBlack/hosts/master/" +
               path_join_robust(header_params["outputsubfolder"], "") +
               "hosts\n")
    write_data(final_file, "# Project home page: https://github.com/"
                           "StevenBlack/hosts\n#\n")
    write_data(final_file, "# ==============================="
                           "================================\n")
    write_data(final_file, "\n")

    if not header_params["skipstatichosts"]:
        # Standard loopback entries for machines that resolve their
        # own names through the hosts file.
        write_data(final_file, "127.0.0.1 localhost\n")
        write_data(final_file, "127.0.0.1 localhost.localdomain\n")
        write_data(final_file, "127.0.0.1 local\n")
        write_data(final_file, "::1 localhost\n")
        write_data(final_file, "fe80::1%lo0 localhost\n")
        write_data(final_file, "0.0.0.0 0.0.0.0\n")

        if platform.system() == "Linux":
            write_data(final_file,
                       "127.0.1.1 " + socket.gethostname() + "\n")
            write_data(final_file,
                       "127.0.0.53 " + socket.gethostname() + "\n")

        write_data(final_file, "\n")

    preamble_path = path_join_robust(BASEDIR_PATH, "myhosts")

    if os.path.isfile(preamble_path):
        # Prepend the user's personal hosts snippet when one exists.
        with open(preamble_path, "r") as preamble:
            write_data(final_file, preamble.read())

    final_file.write(body)
def update_readme_data(readme_file, **readme_updates):
    """
    Update the host and website information provided in the README JSON data.

    Parameters
    ----------
    readme_file : str
        The name of the README file to update.
    readme_updates : kwargs
        Dictionary providing additional JSON fields to update before
        saving the data. Currently, those fields are:

        1) extensions
        2) sourcesdata
        3) numberofrules
        4) outputsubfolder
    """

    # Runs for a set of extensions are keyed by the joined extension
    # names; the plain run is keyed as "base".
    extensions = readme_updates["extensions"]
    extensions_key = "-".join(extensions) if extensions else "base"

    generation_data = {
        "location": path_join_robust(readme_updates["outputsubfolder"], ""),
        "entries": readme_updates["numberofrules"],
        "sourcesdata": readme_updates["sourcesdata"],
    }

    with open(readme_file, "r") as f:
        readme_data = json.load(f)

    readme_data[extensions_key] = generation_data

    with open(readme_file, "w") as f:
        json.dump(readme_data, f)
def move_hosts_file_into_place(final_file):
print_failure("Unable to determine DNS management tool.")
-def remove_old_hosts_file():
+def remove_old_hosts_file(backup):
"""
Remove the old hosts file.
This is a hotfix because merging with an already existing hosts file leads
to artifacts and duplicates.
+
+ Parameters
+ ----------
+ backup : boolean, default False
+ Whether or not to backup the existing hosts file.
"""
old_file_path = path_join_robust(BASEDIR_PATH, "hosts")
- # create if already removed, so remove wont raise an error
+
+ # Create if already removed, so remove won't raise an error.
open(old_file_path, "a").close()
- if settings["backup"]:
+ if backup:
backup_file_path = path_join_robust(BASEDIR_PATH, "hosts-{}".format(
time.strftime("%Y-%m-%d-%H-%M-%S")))