From: xliang Date: Wed, 5 Mar 2025 07:33:23 +0000 (-0800) Subject: General refactor/improvements X-Git-Url: http://git.99rst.org/?a=commitdiff_plain;h=c3a999329c1b718cfbc02ac07af42795ad6de629;p=mullvad-wg-tools.git General refactor/improvements --- diff --git a/wg-mullvad.py b/wg-mullvad.py index a891773..c7e04a8 100755 --- a/wg-mullvad.py +++ b/wg-mullvad.py @@ -16,24 +16,24 @@ from cryptography.hazmat.primitives import serialization _version = '1.1' -def generate_publickey(privatekey): +def generate_publickey(privatekey: str) -> str: private_key_bytes = base64.b64decode(privatekey) private_key = X25519PrivateKey.from_private_bytes(private_key_bytes) public_key = private_key.public_key() public_key_bytes = public_key.public_bytes( encoding=serialization.Encoding.Raw, - format=serialization.PublicFormat.Raw + format=serialization.PublicFormat.Raw, ) wgpublickey = base64.b64encode(public_key_bytes).decode('utf-8') return wgpublickey -def generate_privatekey(): +def generate_privatekey() -> str: privatekey = X25519PrivateKey.generate() private_key_bytes = privatekey.private_bytes( encoding=serialization.Encoding.Raw, format=serialization.PrivateFormat.Raw, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) wgprivatekey = base64.b64encode(private_key_bytes).decode('utf-8') return wgprivatekey @@ -81,7 +81,7 @@ class Mullvad: if device: self.create_wg_configs(device, privatekey, multihop_server) - def get_privatekey(self): + def get_privatekey(self) -> str: print(f'Reading settings from: {self._settings_file}') self._config.read(self._settings_file) try: @@ -91,7 +91,7 @@ class Mullvad: print('Solution: add it or remove the file completely to generate a new device') sys.exit(1) - def save_privatekey(self, privatekey): + def save_privatekey(self, privatekey) -> bool: self._settings_file.parent.mkdir(parents=True, exist_ok=True) self._settings_file.touch(mode=0o600, exist_ok=True) with self._settings_file.open('w') as _file: @@ -101,18 +101,19 @@ class Mullvad: self._config.write(_file) return True - def get_webtoken(self): + def get_webtoken(self) -> str: if not self._webtoken: self.generate_webtoken() return self._webtoken - def generate_webtoken(self): - body = {} - body['account_number'] = self._account_number + def generate_webtoken(self) -> str: + body = { + 'account_number': self._account_number, + } req = urllib.request.Request('https://api.mullvad.net/auth/v1/webtoken') req.add_header('Content-Type', 'application/json') - response = urllib.request.urlopen(req, json.dumps(body).encode()).read() - data = json.loads(response) + with urllib.request.urlopen(req, json.dumps(body).encode()) as response: + data = json.load(response) self._webtoken = data['access_token'] def api(self, url, body=None): @@ -123,17 +124,15 @@ class Mullvad: if body: req.add_header('Content-Type', 'application/json') - response = urllib.request.urlopen(req, json.dumps(body).encode()) - else: - response = urllib.request.urlopen(req) - return self.get_response(response) + with urllib.request.urlopen(req, data=json.dumps(body).encode() if body else None) as response: + return self.get_response(response) def get_response(self, response): if response.headers.get('Content-Encoding') == 'gzip': data = json.loads(gzip.decompress(response.read())) else: - data = json.loads(response.read()) + data = json.load(response) return data def get_device(self, publickey): @@ -148,7 +147,7 @@ class Mullvad: print(f'Device is not registered: {publickey}') return None except urllib.error.HTTPError as e: - error_message = json.loads(e.read()) + error_message = json.load(e) _code = error_message.get('code') _message = error_message.get('detail') if _message: @@ -159,15 +158,16 @@ class Mullvad: def create_device(self, publickey): print(f'Trying to create device: {publickey}') - body = {} - body['pubkey'] = publickey - body['hijack_dns'] = self._wg_hijack_dns + body = { + 'pubkey': publickey, + 'hijack_dns': self._wg_hijack_dns, + } try: response = self.api('https://api.mullvad.net/accounts/v1/devices', body) - print(f'Device created: ({response['name']}) {response['pubkey']}') + print(f'Device created: ({response["name"]}) {response["pubkey"]}') return response except urllib.error.HTTPError as e: - error_message = json.loads(e.read()) + error_message = json.load(e) _code = error_message.get('code') _message = error_message.get('detail') if _message: @@ -182,8 +182,8 @@ class Mullvad: try: req = urllib.request.Request('https://api.mullvad.net/public/relays/wireguard/v2/') req.add_header('Accept-Encoding', 'gzip') - response = urllib.request.urlopen(req) - data = self.get_response(response) + with urllib.request.urlopen(req) as response: + data = self.get_response(response) return data['wireguard'] except urllib.error.HTTPError as e: error_message = self.get_response(e) @@ -194,20 +194,20 @@ class Mullvad: try: req = urllib.request.Request('https://api.mullvad.net/www/relays/all') req.add_header('Accept-Encoding', 'gzip') - response = urllib.request.urlopen(req) - data = self.get_response(response) + with urllib.request.urlopen(req) as response: + data = self.get_response(response) return [i for i in data if i['type'] == 'wireguard'] except urllib.error.HTTPError as e: error_message = self.get_response(e) print(error_message) sys.exit(1) - def filter(self, data, _filter): + def filter(self, data: list, _filter) -> list: if _filter: - return [d for d in data if _filter in d['hostname']] + return [d for d in data if d['hostname'].startswith(_filter)] return data - def create_wg_configs(self, device, privatekey, multihop_server): + def create_wg_configs(self, device, privatekey, multihop_server) -> None: wg = self.get_wireguard_info() output_dir = pathlib.Path(self._output_dir).expanduser() output_dir.mkdir(exist_ok=True, parents=True) @@ -222,65 +222,61 @@ class Mullvad: config.set('Interface', 'dns', ','.join([wg['ipv4_gateway'], wg['ipv6_gateway']])) config.add_section('Peer') + relays = self.filter(self.get_multihop_info(), self._filter) + if not relays: + print(f'No relays matching filter: {self._filter}') + return + + print(f'Creating files in: {output_dir}') + for relay in relays: + self.create_wg_config(config, relay, multihop_server) + + def create_wg_config(self, config, relay, multihop_server=None) -> None: + output_dir = pathlib.Path(self._output_dir).expanduser() + _hostname = relay['hostname'] if multihop_server: - _servername = multihop_server["hostname"] - relays = self.filter(self.get_multihop_info(), self._filter) - if relays: - print(f'Creating files in: {output_dir}') - for relay in relays: - _hostname = relay['hostname'] - _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}-via-{_servername}.conf') - _filepath.touch(mode=0o600, exist_ok=True) - with _filepath.open('w') as _file: - config.set('Peer', '#owned', str(relay['owned'])) - config.set('Peer', '#provider', relay['provider']) - config.set('Peer', 'publickey', relay['pubkey']) - config.set('Peer', 'allowedips', '0.0.0.0/0,::/0') - if self._wg_relay_ipv6: - wg_relay_address = multihop_server['ipv6_addr_in'] - else: - wg_relay_address = multihop_server['ipv4_addr_in'] - config.set('Peer', 'endpoint', f'{wg_relay_address}:{relay["multihop_port"]}') - config.write(_file) - else: - print(f'No relays matching filter: {self._filter}') + _servername = multihop_server['hostname'] + _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}-via-{_servername}.conf') else: - relays = self.filter(wg['relays'], self._filter) - if relays: - print(f'Creating files in: {output_dir}') - for relay in relays: - _hostname = relay['hostname'] - _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}.conf') - _filepath.touch(mode=0o600, exist_ok=True) - with _filepath.open('w') as _file: - config.set('Peer', '#owned', str(relay['owned'])) - config.set('Peer', '#provider', relay['provider']) - config.set('Peer', 'publickey', relay['public_key']) - config.set('Peer', 'allowedips', '0.0.0.0/0,::/0') - if self._wg_relay_ipv6: - wg_relay_address = relay['ipv6_addr_in'] - else: - wg_relay_address = relay['ipv4_addr_in'] - config.set('Peer', 'endpoint', f'{wg_relay_address}:{self._wg_relay_port}') - config.write(_file) - else: - print(f'No relays matching filter: {self._filter}') + _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}.conf') + + _filepath.touch(mode=0o600, exist_ok=True) + + if multihop_server: + remote_server = multihop_server + remote_port = relay['multihop_port'] + else: + remote_server = relay + remote_port = self._wg_relay_port + + if self._wg_relay_ipv6: + wg_relay_address = remote_server['ipv6_addr_in'] + else: + wg_relay_address = remote_server['ipv4_addr_in'] + + with _filepath.open('w') as _file: + config.set('Peer', '#owned', str(relay['owned'])) + config.set('Peer', '#provider', relay['provider']) + config.set('Peer', 'publickey', relay['pubkey']) + config.set('Peer', 'allowedips', '0.0.0.0/0,::/0') + config.set('Peer', 'endpoint', f'{wg_relay_address}:{remote_port}') + config.write(_file) -def validate_account(value): +def validate_account(value: str) -> str: if not value.isdigit(): - raise argparse.ArgumentTypeError("The string must contain only numbers.") + raise argparse.ArgumentTypeError('The string must contain only numbers.') return value -def validate_port(value): +def validate_port(value: str) -> int: try: port = int(value) - except ValueError: - raise argparse.ArgumentTypeError(f"Port must be an integer, but got '{value}'.") + except ValueError as e: + raise argparse.ArgumentTypeError(f'Port must be an integer, but got \'{value}\'.') from e if port < 1 or port > 65535: - raise argparse.ArgumentTypeError("Port number must be between 1 and 65535.") + raise argparse.ArgumentTypeError('Port number must be between 1 and 65535.') return port @@ -302,14 +298,14 @@ def main(): default='~/.config/mullvad/wg0', help='directory to write settings') parser.add_argument( '--wg-relay-port', dest='wg_relay_port', action='store', type=validate_port, - default=51820, help='use custom port for relays in wireguard configs') + default=51820, help='use custom port for relays in WireGuard configs') parser.add_argument( '--dns', dest='wg_dns', action='store', nargs='+', type=ipaddress.ip_address, - help='use custom dns server in wireguard configs') + help='use custom dns server in WireGuard configs') parser.add_argument( '--hijack-dns', dest='wg_hijack_dns', help='activate hijack dns when creating device', action='store_true') parser.add_argument( - '--ipv6', dest='wg_relay_ipv6', help='use ipv6 address for relays in wireguard configs', action='store_true') + '--ipv6', dest='wg_relay_ipv6', help='use ipv6 address for relays in WireGuard configs', action='store_true') parser.add_argument( '--multihop-server', dest='wg_multihop_server', action='store', default=None, help='use multihop server') parser.add_argument(