From: gfyoung Date: Fri, 23 Jun 2017 07:25:40 +0000 (-0700) Subject: Add unittests for updateHostsFile.py X-Git-Url: http://git.99rst.org/?a=commitdiff_plain;h=dce24af399cf2ddc81a8343b490c84fec7cd9d88;p=stevenblack-hosts.git Add unittests for updateHostsFile.py --- diff --git a/.gitignore b/.gitignore index e6993e4ce..ec766b793 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ readmeData.json hosts hosts-* /web.config +/__pycache__ diff --git a/.travis.yml b/.travis.yml index 37732ad70..379356929 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,8 @@ python: - "3.6" install: - pip install flake8 + pip install flake8 mock nose script: - flake8 + - nosetests + - flake8 diff --git a/readme_template.md b/readme_template.md index 656cb66e7..2bc9425e2 100644 --- a/readme_template.md +++ b/readme_template.md @@ -8,6 +8,20 @@ Like this: git clone --depth 5 https://github.com/StevenBlack/hosts.git +To run unit tests, in the top level directory, just run: + + python testUpdateHostsFile.py + +You can also install `nose` with `pip` and then just run: + + nosetests + +**Note** if you are using Python 2, you must first install the `mock` library: + + pip install mock + +Afterwards, you can follow the instructions above. + # Unified hosts file @EXTENSIONS_HEADER@ This repository consolidates several reputable `hosts` files, and merges them diff --git a/testUpdateHostsFile.py b/testUpdateHostsFile.py new file mode 100644 index 000000000..e875737d8 --- /dev/null +++ b/testUpdateHostsFile.py @@ -0,0 +1,790 @@ +#!/usr/bin/env python + +# Script by gfyoung +# https://github.com/gfyoung +# +# Python script for testing updateHostFiles.py + +from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache, + gather_custom_exclusions, get_defaults, + get_file_by_url, is_valid_domain_format, + move_hosts_file_into_place, normalize_rule, + path_join_robust, print_failure, print_success, + supports_color, query_yes_no, recursive_glob, + strip_rule, write_data) +import updateHostsFile +import unittest +import locale +import sys +import os + +if PY3: + from io import BytesIO, StringIO + import unittest.mock as mock + unicode = str +else: + from StringIO import StringIO + BytesIO = StringIO + import mock + + +# Base Test Classes +class Base(unittest.TestCase): + + @staticmethod + def mock_property(name): + return mock.patch(name, new_callable=mock.PropertyMock) + + @property + def sep(self): + return "\\" if sys.platform == "win32" else "/" + + +class BaseStdout(Base): + + def setUp(self): + sys.stdout = StringIO() + + def tearDown(self): + sys.stdout.close() + sys.stdout = sys.__stdout__ +# End Base Test Classes + + +# Project Settings +class TestGetDefaults(Base): + + def test_get_defaults(self): + with self.mock_property("updateHostsFile.BASEDIR_PATH"): + updateHostsFile.BASEDIR_PATH = "foo" + actual = get_defaults() + expected = {"numberofrules": 0, + "datapath": "foo" + self.sep + "data", + "freshen": True, + "replace": False, + "backup": False, + "skipstatichosts": False, + "keepdomaincomments": False, + "extensionspath": "foo" + self.sep + "extensions", + "extensions": [], + "outputsubfolder": "", + "hostfilename": "hosts", + "targetip": "0.0.0.0", + "ziphosts": False, + "sourcedatafilename": "update.json", + "sourcesdata": [], + "readmefilename": "readme.md", + "readmetemplate": ("foo" + self.sep + + "readme_template.md"), + "readmedata": {}, + "readmedatafilename": ("foo" + self.sep + + "readmeData.json"), + "exclusionpattern": "([a-zA-Z\d-]+\.){0,}", + "exclusionregexs": [], + "exclusions": [], + "commonexclusions": ["hulu.com"], + "blacklistfile": "foo" + self.sep + "blacklist", + "whitelistfile": "foo" + self.sep + "whitelist"} + self.assertDictEqual(actual, expected) +# End Project Settings + + +# Exclusion Logic +class TestGatherCustomExclusions(BaseStdout): + + # Can only test in the invalid domain case + # because of the settings global variable. + @mock.patch("updateHostsFile.raw_input", side_effect=["foo", "no"]) + @mock.patch("updateHostsFile.is_valid_domain_format", return_value=False) + def test_basic(self, *_): + gather_custom_exclusions() + + expected = "Do you have more domains you want to enter? [Y/n]" + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("updateHostsFile.raw_input", side_effect=["foo", "yes", + "bar", "no"]) + @mock.patch("updateHostsFile.is_valid_domain_format", return_value=False) + def test_multiple(self, *_): + gather_custom_exclusions() + + expected = ("Do you have more domains you want to enter? [Y/n] " + "Do you have more domains you want to enter? [Y/n]") + output = sys.stdout.getvalue() + self.assertIn(expected, output) +# End Exclusion Logic + + +# File Logic +class TestNormalizeRule(BaseStdout): + + # Can only test non-matches because they don't + # interact with the settings global variable. + def test_no_match(self): + for rule in ["foo", "128.0.0.1", "bar.com/usa", "0.0.0 google", + "0.1.2.3.4 foo/bar", "twitter.com"]: + self.assertEqual(normalize_rule(rule), (None, None)) + + output = sys.stdout.getvalue() + sys.stdout = StringIO() + + expected = "==>" + rule + "<==" + self.assertIn(expected, output) + + +class TestStripRule(Base): + + def test_strip_empty(self): + for line in ["0.0.0.0", "domain.com", "foo"]: + output = strip_rule(line) + self.assertEqual(output, "") + + def test_strip_exactly_two(self): + for line in ["0.0.0.0 twitter.com", "127.0.0.1 facebook.com", + "8.8.8.8 google.com", "1.2.3.4 foo.bar.edu"]: + output = strip_rule(line) + self.assertEqual(output, line) + + def test_strip_more_than_two(self): + for line in ["0.0.0.0 twitter.com", "127.0.0.1 facebook.com", + "8.8.8.8 google.com", "1.2.3.4 foo.bar.edu"]: + output = strip_rule(line + " # comments here galore") + self.assertEqual(output, line) + + +class TestMoveHostsFile(BaseStdout): + + @mock.patch("os.path.abspath", side_effect=lambda f: f) + def test_move_hosts_no_name(self, _): + with self.mock_property("os.name"): + os.name = "foo" + + mock_file = mock.Mock(name="foo") + move_hosts_file_into_place(mock_file) + + expected = "" + output = sys.stdout.getvalue() + + self.assertEqual(output, expected) + + @mock.patch("os.path.abspath", side_effect=lambda f: f) + def test_move_hosts_windows(self, _): + with self.mock_property("os.name"): + os.name = "nt" + + mock_file = mock.Mock(name="foo") + move_hosts_file_into_place(mock_file) + + expected = ("Automatically moving the hosts " + "file in place is not yet supported.\n" + "Please move the generated file to " + "%SystemRoot%\system32\drivers\etc\hosts") + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("os.path.abspath", side_effect=lambda f: f) + @mock.patch("subprocess.call", return_value=0) + def test_move_hosts_posix(self, *_): + with self.mock_property("os.name"): + os.name = "posix" + + mock_file = mock.Mock(name="foo") + move_hosts_file_into_place(mock_file) + + expected = ("Moving the file requires administrative " + "privileges. You might need to enter your password.") + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("os.path.abspath", side_effect=lambda f: f) + @mock.patch("subprocess.call", return_value=1) + def test_move_hosts_posix_fail(self, *_): + with self.mock_property("os.name"): + os.name = "posix" + + mock_file = mock.Mock(name="foo") + move_hosts_file_into_place(mock_file) + + expected = "Moving the file failed." + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + +class TestFlushDnsCache(BaseStdout): + + @mock.patch("subprocess.call", return_value=0) + def test_flush_darwin(self, _): + with self.mock_property("platform.system") as obj: + obj.return_value = "Darwin" + flush_dns_cache() + + expected = ("Flushing the DNS cache to utilize new hosts " + "file...\nFlushing the DNS cache requires " + "administrative privileges. You might need to " + "enter your password.") + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("subprocess.call", return_value=1) + def test_flush_darwin_fail(self, _): + with self.mock_property("platform.system") as obj: + obj.return_value = "Darwin" + flush_dns_cache() + + expected = "Flushing the DNS cache failed." + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + def test_flush_windows(self): + with self.mock_property("platform.system") as obj: + obj.return_value = "win32" + + with self.mock_property("os.name"): + os.name = "nt" + flush_dns_cache() + + expected = ("Automatically flushing the DNS cache is " + "not yet supported.\nPlease copy and paste " + "the command 'ipconfig /flushdns' in " + "administrator command prompt after running " + "this script.") + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("os.path.isfile", return_value=False) + def test_flush_no_tool(self, _): + with self.mock_property("platform.system") as obj: + obj.return_value = "Linux" + + with self.mock_property("os.name"): + os.name = "posix" + flush_dns_cache() + + expected = "Unable to determine DNS management tool." + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("os.path.isfile", side_effect=[True] + [False] * 10) + @mock.patch("subprocess.call", return_value=0) + def test_flush_posix(self, *_): + with self.mock_property("platform.system") as obj: + obj.return_value = "Linux" + + with self.mock_property("os.name"): + os.name = "posix" + flush_dns_cache() + + expected = ("Flushing the DNS cache by " + "restarting nscd succeeded") + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("os.path.isfile", side_effect=[True] + [False] * 10) + @mock.patch("subprocess.call", return_value=1) + def test_flush_posix_fail(self, *_): + with self.mock_property("platform.system") as obj: + obj.return_value = "Linux" + + with self.mock_property("os.name"): + os.name = "posix" + flush_dns_cache() + + expected = ("Flushing the DNS cache by " + "restarting nscd failed") + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + @mock.patch("os.path.isfile", side_effect=[True, False, + True] + [False] * 10) + @mock.patch("subprocess.call", side_effect=[1, 0]) + def test_flush_posix_fail_then_succeed(self, *_): + with self.mock_property("platform.system") as obj: + obj.return_value = "Linux" + + with self.mock_property("os.name"): + os.name = "posix" + flush_dns_cache() + + output = sys.stdout.getvalue() + for expected in [("Flushing the DNS cache by " + "restarting nscd failed"), + ("Flushing the DNS cache by restarting " + "NetworkManager.service succeeded")]: + self.assertIn(expected, output) +# End File Logic + + +# Helper Functions +def mock_url_open(url): + """ + Mock of `urlopen` that returns the url in a `BytesIO` stream. + + Parameters + ---------- + url : str + The URL associated with the file to open. + + Returns + ------- + bytes_stream : BytesIO + The `url` input wrapped in a `BytesIO` stream. + """ + + return BytesIO(url) + + +def mock_url_open_fail(_): + """ + Mock of `urlopen` that fails with an Exception. + """ + + raise Exception() + + +def mock_url_open_read_fail(_): + """ + Mock of `urlopen` that returns an object that fails on `read`. + + Returns + ------- + file_mock : mock.Mock + A mock of a file object that fails when reading. + """ + + def fail_read(): + raise Exception() + + m = mock.Mock() + + m.read = fail_read + return m + + +def mock_url_open_decode_fail(_): + """ + Mock of `urlopen` that returns an object that fails on during decoding + the output of `urlopen`. + + Returns + ------- + file_mock : mock.Mock + A mock of a file object that fails when decoding the output. + """ + + def fail_decode(_): + raise Exception() + + def read(): + s = mock.Mock() + s.decode = fail_decode + + return s + + m = mock.Mock() + m.read = read + return m + + +class GetFileByUrl(BaseStdout): + + @mock.patch("updateHostsFile.urlopen", + side_effect=mock_url_open) + def test_read_url(self, _): + url = b"www.google.com" + + expected = "www.google.com" + actual = get_file_by_url(url) + + self.assertEqual(actual, expected) + + @mock.patch("updateHostsFile.urlopen", + side_effect=mock_url_open_fail) + def test_read_url_fail(self, _): + url = b"www.google.com" + self.assertIsNone(get_file_by_url(url)) + + expected = "Problem getting file:" + output = sys.stdout.getvalue() + + self.assertIn(expected, output) + + @mock.patch("updateHostsFile.urlopen", + side_effect=mock_url_open_read_fail) + def test_read_url_read_fail(self, _): + url = b"www.google.com" + self.assertIsNone(get_file_by_url(url)) + + expected = "Problem getting file:" + output = sys.stdout.getvalue() + + self.assertIn(expected, output) + + @mock.patch("updateHostsFile.urlopen", + side_effect=mock_url_open_decode_fail) + def test_read_url_decode_fail(self, _): + url = b"www.google.com" + self.assertIsNone(get_file_by_url(url)) + + expected = "Problem getting file:" + output = sys.stdout.getvalue() + + self.assertIn(expected, output) + + +class TestWriteData(Base): + + def test_write_basic(self): + f = BytesIO() + + data = "foo" + write_data(f, data) + + expected = b"foo" + actual = f.getvalue() + + self.assertEqual(actual, expected) + + def test_write_unicode(self): + f = BytesIO() + + data = u"foo" + write_data(f, data) + + expected = b"foo" + actual = f.getvalue() + + self.assertEqual(actual, expected) + + +class TestQueryYesOrNo(BaseStdout): + + def test_invalid_default(self): + for invalid_default in ["foo", "bar", "baz", 1, 2, 3]: + self.assertRaises(ValueError, query_yes_no, "?", invalid_default) + + @mock.patch("updateHostsFile.raw_input", side_effect=["yes"] * 3) + def test_valid_default(self, _): + for valid_default, expected in [(None, "[y/n]"), ("yes", "[Y/n]"), + ("no", "[y/N]")]: + self.assertTrue(query_yes_no("?", valid_default)) + + output = sys.stdout.getvalue() + sys.stdout = StringIO() + + self.assertIn(expected, output) + + @mock.patch("updateHostsFile.raw_input", side_effect=([""] * 2)) + def test_use_valid_default(self, _): + for valid_default in ["yes", "no"]: + expected = (valid_default == "yes") + actual = query_yes_no("?", valid_default) + + self.assertEqual(actual, expected) + + @mock.patch("updateHostsFile.raw_input", side_effect=["no", "NO", "N", + "n", "No", "nO"]) + def test_valid_no(self, _): + self.assertFalse(query_yes_no("?", None)) + + @mock.patch("updateHostsFile.raw_input", side_effect=["yes", "YES", "Y", + "yeS", "y", "YeS", + "yES", "YEs"]) + def test_valid_yes(self, _): + self.assertTrue(query_yes_no("?", None)) + + @mock.patch("updateHostsFile.raw_input", side_effect=["foo", "yes", + "foo", "no"]) + def test_invalid_then_valid(self, _): + expected = "Please respond with 'yes' or 'no'" + + # The first time, we respond "yes" + self.assertTrue(query_yes_no("?", None)) + + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + sys.stdout = StringIO() + + # The second time, we respond "no" + self.assertFalse(query_yes_no("?", None)) + + output = sys.stdout.getvalue() + self.assertIn(expected, output) + + +class TestIsValidDomainFormat(BaseStdout): + + def test_empty_domain(self): + self.assertFalse(is_valid_domain_format("")) + + output = sys.stdout.getvalue() + expected = "You didn't enter a domain. Try again." + + self.assertTrue(expected in output) + + def test_invalid_domain(self): + expected = ("Do not include www.domain.com or " + "http(s)://domain.com. Try again.") + + for invalid_domain in ["www.subdomain.domain", "https://github.com", + "http://www.google.com"]: + self.assertFalse(is_valid_domain_format(invalid_domain)) + + output = sys.stdout.getvalue() + sys.stdout = StringIO() + + self.assertIn(expected, output) + + def test_valid_domain(self): + for valid_domain in ["github.com", "travis.org", "twitter.com"]: + self.assertTrue(is_valid_domain_format(valid_domain)) + + output = sys.stdout.getvalue() + sys.stdout = StringIO() + + self.assertEqual(output, "") + + +def mock_walk(stem): + """ + Mock method for `os.walk`. + + Please refer to the documentation of `os.walk` for information about + the provided parameters. + """ + + files = ["foo.txt", "bar.bat", "baz.py", "foo/foo.c", "foo/bar.doc", + "foo/baz/foo.py", "bar/foo/baz.c", "bar/bar/foo.bat"] + + if stem == ".": + stem = "" + + matches = [] + + for f in files: + if not stem or f.startswith(stem + "/"): + matches.append(("", "_", [f])) + + return matches + + +class TestRecursiveGlob(Base): + + @staticmethod + def sorted_recursive_glob(stem, file_pattern): + actual = recursive_glob(stem, file_pattern) + actual.sort() + + return actual + + @mock.patch("os.walk", side_effect=mock_walk) + def test_all_match(self, _): + with self.mock_property("sys.version_info"): + sys.version_info = (2, 6) + + expected = ["bar.bat", "bar/bar/foo.bat", + "bar/foo/baz.c", "baz.py", + "foo.txt", "foo/bar.doc", + "foo/baz/foo.py", "foo/foo.c"] + actual = self.sorted_recursive_glob("*", "*") + self.assertListEqual(actual, expected) + + expected = ["bar/bar/foo.bat", "bar/foo/baz.c"] + actual = self.sorted_recursive_glob("bar", "*") + self.assertListEqual(actual, expected) + + expected = ["foo/bar.doc", "foo/baz/foo.py", "foo/foo.c"] + actual = self.sorted_recursive_glob("foo", "*") + self.assertListEqual(actual, expected) + + @mock.patch("os.walk", side_effect=mock_walk) + def test_file_ending(self, _): + with self.mock_property("sys.version_info"): + sys.version_info = (2, 6) + + expected = ["foo/baz/foo.py"] + actual = self.sorted_recursive_glob("foo", "*.py") + self.assertListEqual(actual, expected) + + expected = ["bar/foo/baz.c", "foo/foo.c"] + actual = self.sorted_recursive_glob("*", "*.c") + self.assertListEqual(actual, expected) + + expected = [] + actual = self.sorted_recursive_glob("*", ".xlsx") + self.assertListEqual(actual, expected) + + +def mock_path_join(*_): + """ + Mock method for `os.path.join`. + + Please refer to the documentation of `os.path.join` for information about + the provided parameters. + """ + + raise UnicodeDecodeError("foo", b"", 1, 5, "foo") + + +class TestPathJoinRobust(Base): + + def test_basic(self): + expected = "path1" + actual = path_join_robust("path1") + self.assertEqual(actual, expected) + + actual = path_join_robust(u"path1") + self.assertEqual(actual, expected) + + def test_join(self): + for i in range(1, 4): + paths = ["pathNew"] * i + expected = "path1" + (self.sep + "pathNew") * i + actual = path_join_robust("path1", *paths) + + self.assertEqual(actual, expected) + + def test_join_unicode(self): + for i in range(1, 4): + paths = [u"pathNew"] * i + expected = "path1" + (self.sep + "pathNew") * i + actual = path_join_robust("path1", *paths) + + self.assertEqual(actual, expected) + + @mock.patch("os.path.join", side_effect=mock_path_join) + def test_join_error(self, _): + self.assertRaises(locale.Error, path_join_robust, "path") + + +# Colors +class TestSupportsColor(BaseStdout): + + def test_posix(self): + with self.mock_property("sys.platform"): + sys.platform = "Linux" + + with self.mock_property("sys.stdout.isatty") as obj: + obj.return_value = True + self.assertTrue(supports_color()) + + def test_pocket_pc(self): + with self.mock_property("sys.platform"): + sys.platform = "Pocket PC" + self.assertFalse(supports_color()) + + def test_windows_no_ansicon(self): + with self.mock_property("sys.platform"): + sys.platform = "win32" + + with self.mock_property("os.environ"): + os.environ = [] + + self.assertFalse(supports_color()) + + def test_windows_ansicon(self): + with self.mock_property("sys.platform"): + sys.platform = "win32" + + with self.mock_property("os.environ"): + os.environ = ["ANSICON"] + + with self.mock_property("sys.stdout.isatty") as obj: + obj.return_value = True + self.assertTrue(supports_color()) + + def test_no_isatty_attribute(self): + with self.mock_property("sys.platform"): + sys.platform = "Linux" + + with self.mock_property("sys.stdout"): + sys.stdout = list() + self.assertFalse(supports_color()) + + def test_no_isatty(self): + with self.mock_property("sys.platform"): + sys.platform = "Linux" + + with self.mock_property("sys.stdout.isatty") as obj: + obj.return_value = False + self.assertFalse(supports_color()) + + +class TestColorize(Base): + + def setUp(self): + self.text = "house" + self.colors = ["red", "orange", "yellow", + "green", "blue", "purple"] + + @mock.patch("updateHostsFile.supports_color", return_value=False) + def test_colorize_no_support(self, _): + for color in self.colors: + expected = self.text + actual = colorize(self.text, color) + + self.assertEqual(actual, expected) + + @mock.patch("updateHostsFile.supports_color", return_value=True) + def test_colorize_support(self, _): + for color in self.colors: + expected = color + self.text + Colors.ENDC + actual = colorize(self.text, color) + + self.assertEqual(actual, expected) + + +class TestPrintSuccess(BaseStdout): + + def setUp(self): + super(TestPrintSuccess, self).setUp() + self.text = "house" + + @mock.patch("updateHostsFile.supports_color", return_value=False) + def test_print_success_no_support(self, _): + print_success(self.text) + + expected = self.text + "\n" + actual = sys.stdout.getvalue() + + self.assertEqual(actual, expected) + + @mock.patch("updateHostsFile.supports_color", return_value=True) + def test_print_success_support(self, _): + print_success(self.text) + + expected = Colors.SUCCESS + self.text + Colors.ENDC + "\n" + actual = sys.stdout.getvalue() + + self.assertEqual(actual, expected) + + +class TestPrintFailure(BaseStdout): + + def setUp(self): + super(TestPrintFailure, self).setUp() + self.text = "house" + + @mock.patch("updateHostsFile.supports_color", return_value=False) + def test_print_failure_no_support(self, _): + print_failure(self.text) + + expected = self.text + "\n" + actual = sys.stdout.getvalue() + + self.assertEqual(actual, expected) + + @mock.patch("updateHostsFile.supports_color", return_value=True) + def test_print_failure_support(self, _): + print_failure(self.text) + + expected = Colors.FAIL + self.text + Colors.ENDC + "\n" + actual = sys.stdout.getvalue() + + self.assertEqual(actual, expected) +# End Helper Functions + + +if __name__ == "__main__": + unittest.main() diff --git a/updateHostsFile.py b/updateHostsFile.py index 59c9b6d61..2fdbcd37e 100644 --- a/updateHostsFile.py +++ b/updateHostsFile.py @@ -33,7 +33,7 @@ if PY3: raw_input = input else: # Python 2 from urllib2 import urlopen - + raw_input = raw_input # noqa # Syntactic sugar for "sudo" command in UNIX / Linux SUDO = "/usr/bin/sudo" @@ -185,7 +185,7 @@ def prompt_for_update(): if not os.path.isfile(path_join_robust(BASEDIR_PATH, "hosts")): try: open(path_join_robust(BASEDIR_PATH, "hosts"), "w+").close() - except: + except Exception: print_failure("ERROR: No 'hosts' file in the folder," "try creating one manually") @@ -363,7 +363,7 @@ def update_all_sources(): settings["hostfilename"]), "wb") write_data(hosts_file, updated_file) hosts_file.close() - except: + except Exception: print("Error in updating source: ", update_url) # End Update Logic @@ -763,7 +763,7 @@ def get_file_by_url(url): try: f = urlopen(url) return f.read().decode("UTF-8") - except: + except Exception: print("Problem getting file: ", url)