Refactor & minor bug fix
authorxliang <redacted>
Sat, 15 Mar 2025 04:21:40 +0000 (21:21 -0700)
committerkalken <redacted>
Tue, 18 Mar 2025 10:47:28 +0000 (11:47 +0100)
wg-mullvad.py

index 88fa8758238be5326c852eaf3f3d3b31d3f31b98..84eefdc4b3fa9280d4c3c6168d774df8c4775ed5 100755 (executable)
@@ -1,19 +1,21 @@
 #!/usr/bin/env python3
 
-import urllib.request
-import configparser
 import argparse
-import pathlib
-import json
-import sys
-import ipaddress
 import base64
+import configparser
+import functools
 import gzip
-from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
+import ipaddress
+import json
+import pathlib
+import sys
+import urllib.request
+
 from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
 
 
-_version = '1.1.1'
+_version = '1.1.2'
 
 
 def generate_publickey(privatekey: str) -> str:
@@ -24,41 +26,159 @@ def generate_publickey(privatekey: str) -> str:
         encoding=serialization.Encoding.Raw,
         format=serialization.PublicFormat.Raw,
     )
-    wgpublickey = base64.b64encode(public_key_bytes).decode('utf-8')
-    return wgpublickey
+    return base64.b64encode(public_key_bytes).decode('utf-8')
 
 
 def generate_privatekey() -> str:
-    privatekey = X25519PrivateKey.generate()
-    private_key_bytes = privatekey.private_bytes(
+    private_key = X25519PrivateKey.generate()
+    private_key_bytes = private_key.private_bytes(
         encoding=serialization.Encoding.Raw,
         format=serialization.PrivateFormat.Raw,
         encryption_algorithm=serialization.NoEncryption(),
     )
-    wgprivatekey = base64.b64encode(private_key_bytes).decode('utf-8')
-    return wgprivatekey
+    return base64.b64encode(private_key_bytes).decode('utf-8')
+
+
+class MullvadApi:
+    HOST = 'https://api.mullvad.net'
+
+    def __init__(self, account_number):
+        self.account_number = account_number
+
+    def new_device(self, public_key, hijack_dns):
+        body = {
+            'pubkey': public_key,
+            'hijack_dns': hijack_dns,
+        }
+
+        return self._api(f'{MullvadApi.HOST}/accounts/v1/devices', body)
+
+    def list_devices(self):
+        return self._api(f'{MullvadApi.HOST}/accounts/v1/devices')
+
+    @functools.cached_property
+    def web_token(self):
+        body = {
+            'account_number': self.account_number,
+        }
+        req = urllib.request.Request(f'{MullvadApi.HOST}/auth/v1/webtoken')
+        req.add_header('Content-Type', 'application/json')
+        with urllib.request.urlopen(req, json.dumps(body).encode()) as response:
+            data = json.load(response)
+        return data['access_token']
+
+    def _api(self, url, body=None):
+        req = urllib.request.Request(url)
+        req.add_header('Authorization', f'Bearer {self.web_token}')
+        req.add_header('Accept-Encoding', 'gzip')
+
+        if body:
+            req.add_header('Content-Type', 'application/json')
+
+        with urllib.request.urlopen(req, data=json.dumps(body).encode() if body else None) as response:
+            return self.get_response(response)
+
+    @functools.cache
+    @staticmethod
+    def wireguard_relays():
+        req = urllib.request.Request(f'{MullvadApi.HOST}/public/relays/wireguard/v2/')
+        req.add_header('Accept-Encoding', 'gzip')
+        with urllib.request.urlopen(req) as response:
+            data = MullvadApi.get_response(response)
+        return data['wireguard']
+
+    @functools.cache
+    @staticmethod
+    def multihop_info():
+        req = urllib.request.Request(f'{MullvadApi.HOST}/www/relays/all')
+        req.add_header('Accept-Encoding', 'gzip')
+        with urllib.request.urlopen(req) as response:
+            data = MullvadApi.get_response(response)
+        return [i for i in data if i['type'] == 'wireguard']
+
+    @staticmethod
+    def get_response(response):
+        if response.headers.get('Content-Encoding') == 'gzip':
+            return json.loads(gzip.decompress(response.read()))
+        else:
+            return json.load(response)
+
+
+class MullvadConfig:
+    def __init__(self, output_dir, wg_dns, wg_relay_port, wg_relay_ipv6):
+        self.output_dir = output_dir
+        self.wg_dns = wg_dns
+        self.wg_relay_port = wg_relay_port
+        self.wg_relay_ipv6 = wg_relay_ipv6
+
+    def create_wg_configs(self, relays, device, privatekey, multihop_server) -> None:
+        wg = MullvadApi.wireguard_relays()
+        output_dir = pathlib.Path(self.output_dir).expanduser()
+        output_dir.mkdir(exist_ok=True, parents=True)
+        config = configparser.ConfigParser()
+        config.add_section('Interface')
+        config.set('Interface', '#device', device['name'])
+        config.set('Interface', 'privateKey', privatekey)
+        config.set('Interface', 'address', ','.join([device['ipv4_address'], device['ipv6_address']]))
+        if self.wg_dns:
+            config.set('Interface', 'dns', ','.join([str(x) for x in self.wg_dns]))
+        else:
+            config.set('Interface', 'dns', ','.join([wg['ipv4_gateway'], wg['ipv6_gateway']]))
+        config.add_section('Peer')
+
+        print(f'Creating files in: {output_dir}')
+        for relay in relays:
+            self.create_wg_config(config, relay, multihop_server)
+
+    def create_wg_config(self, config, relay, multihop_server=None) -> None:
+        output_dir = pathlib.Path(self.output_dir).expanduser()
+        hostname = relay['hostname']
+        if multihop_server:
+            server_name = multihop_server['hostname']
+            file_path = pathlib.Path.joinpath(output_dir, f'{hostname}-via-{server_name}.conf')
+        else:
+            file_path = pathlib.Path.joinpath(output_dir, f'{hostname}.conf')
+
+        file_path.touch(mode=0o600, exist_ok=True)
+
+        if multihop_server:
+            remote_server = multihop_server
+            remote_port = relay['multihop_port']
+        else:
+            remote_server = relay
+            remote_port = self.wg_relay_port
+
+        if self.wg_relay_ipv6:
+            wg_relay_address = remote_server['ipv6_addr_in']
+        else:
+            wg_relay_address = remote_server['ipv4_addr_in']
+
+        with file_path.open('w') as _file:
+            config.set('Peer', '#owned', str(relay['owned']))
+            config.set('Peer', '#provider', relay['provider'])
+            config.set('Peer', 'publickey', relay['pubkey'])
+            config.set('Peer', 'allowedips', '0.0.0.0/0,::/0')
+            config.set('Peer', 'endpoint', f'{wg_relay_address}:{remote_port}')
+            config.write(_file)
 
 
 class Mullvad:
     def __init__(self, args):
-        self._account_number = args.account_number
-        self._output_dir = args.output_dir
+        self.mullvad_api = MullvadApi(args.account_number)
+        self.mullvad_config = MullvadConfig(args.output_dir, args.wg_dns, args.wg_relay_port, args.wg_relay_ipv6)
+
         self._settings_file = args.settings_file
-        self._wg_relay_port = args.wg_relay_port
-        self._wg_relay_ipv6 = args.wg_relay_ipv6
-        self._wg_dns = args.wg_dns
         self._wg_hijack_dns = args.wg_hijack_dns
         self._wg_multihop_server = args.wg_multihop_server
-        self._webtoken = None
         self._filter = args.filter
 
         self._config = configparser.ConfigParser()
         self._settings_file = pathlib.Path(self._settings_file).expanduser()
 
     def run(self):
-        multihop_server = False
+        multihop_server = None
         if self._wg_multihop_server:
-            multihop_servers = self.filter(self.get_multihop_info(), self._wg_multihop_server)
+            multihop_servers = self.filter(MullvadApi.multihop_info(), self._wg_multihop_server)
             if len(multihop_servers) == 1:
                 multihop_server = multihop_servers[0]
             elif len(multihop_servers) >= 1:
@@ -70,16 +190,21 @@ class Mullvad:
                 print(f'No multihop-server matching hostname: {self._wg_multihop_server}')
                 sys.exit(1)
 
+        relays = self.filter(MullvadApi.multihop_info(), self._filter)
+        if not relays:
+            print(f'No relays matching filter: {self._filter}')
+            sys.exit(1)
+
         if self._settings_file.is_file():
-            privatekey = self.get_privatekey()
+            private_key = self.get_privatekey()
         else:
-            privatekey = generate_privatekey()
-            self.save_privatekey(privatekey)
+            private_key = generate_privatekey()
+            self.save_privatekey(private_key)
 
-        publickey = generate_publickey(privatekey)
-        device = self.get_device(publickey) or self.create_device(publickey)
+        public_key = generate_publickey(private_key)
+        device = self.get_device(public_key) or self.create_device(public_key)
         if device:
-            self.create_wg_configs(device, privatekey, multihop_server)
+            self.mullvad_config.create_wg_configs(relays, device, private_key, multihop_server)
 
     def get_privatekey(self) -> str:
         print(f'Reading settings from: {self._settings_file}')
@@ -101,166 +226,47 @@ class Mullvad:
             self._config.write(_file)
         return True
 
-    def get_webtoken(self) -> str:
-        if not self._webtoken:
-            self.generate_webtoken()
-        return self._webtoken
-
-    def generate_webtoken(self) -> str:
-        body = {
-            'account_number': self._account_number,
-        }
-        req = urllib.request.Request('https://api.mullvad.net/auth/v1/webtoken')
-        req.add_header('Content-Type', 'application/json')
-        with urllib.request.urlopen(req, json.dumps(body).encode()) as response:
-            data = json.load(response)
-        self._webtoken = data['access_token']
-
-    def api(self, url, body=None):
-        webtoken = self.get_webtoken()
-        req = urllib.request.Request(url)
-        req.add_header('Authorization', f'Bearer {webtoken}')
-        req.add_header('Accept-Encoding', 'gzip')
-
-        if body:
-            req.add_header('Content-Type', 'application/json')
-
-        with urllib.request.urlopen(req, data=json.dumps(body).encode() if body else None) as response:
-            return self.get_response(response)
-
-    def get_response(self, response):
-        if response.headers.get('Content-Encoding') == 'gzip':
-            data = json.loads(gzip.decompress(response.read()))
-        else:
-            data = json.load(response)
-        return data
-
     def get_device(self, publickey):
         print(f'Trying to find device: {publickey}')
         try:
-            for device in self.api('https://api.mullvad.net/accounts/v1/devices'):
+            for device in self.mullvad_api.list_devices():
                 if publickey == device['pubkey']:
-                    _name = device['name']
-                    _pubkey = device['pubkey']
-                    print(f'Device found: ({_name}) {_pubkey}')
+                    name = device['name']
+                    pubkey = device['pubkey']
+                    print(f'Device found: ({name}) {pubkey}')
                     return device
             print(f'Device is not registered: {publickey}')
             return None
         except urllib.error.HTTPError as e:
-            error_message = json.load(e)
-            _code = error_message.get('code')
-            _message = error_message.get('detail')
-            if _message:
-                print(_message)
-            if _code == 'INVALID_ACCOUNT':
-                print(f'Invalid account: {self._account_number}')
-            sys.exit(1)
+            self.handle_mullvad_api_error(e)
 
     def create_device(self, publickey):
         print(f'Trying to create device: {publickey}')
-        body = {
-            'pubkey': publickey,
-            'hijack_dns': self._wg_hijack_dns,
-        }
         try:
-            response = self.api('https://api.mullvad.net/accounts/v1/devices', body)
+            response = self.mullvad_api.new_device(publickey, self._wg_hijack_dns)
             print(f'Device created: ({response["name"]}) {response["pubkey"]}')
             return response
         except urllib.error.HTTPError as e:
-            error_message = json.load(e)
-            _code = error_message.get('code')
-            _message = error_message.get('detail')
-            if _message:
-                print(_message)
-            if _code == 'PUBKEY_IN_USE':
-                print(f'Error: Private key settings exits in {self._settings_file} but device has been removed')
-                print('Solution 1: Wait for grace period to pass before using this key (5 min)')
-                print('Solution 2: Remove setting file if you want to create a new device')
-            sys.exit(1)
-
-    def get_wireguard_info(self):
-        try:
-            req = urllib.request.Request('https://api.mullvad.net/public/relays/wireguard/v2/')
-            req.add_header('Accept-Encoding', 'gzip')
-            with urllib.request.urlopen(req) as response:
-                data = self.get_response(response)
-            return data['wireguard']
-        except urllib.error.HTTPError as e:
-            error_message = self.get_response(e)
-            print(error_message)
-            sys.exit(1)
-
-    def get_multihop_info(self):
-        try:
-            req = urllib.request.Request('https://api.mullvad.net/www/relays/all')
-            req.add_header('Accept-Encoding', 'gzip')
-            with urllib.request.urlopen(req) as response:
-                data = self.get_response(response)
-            return [i for i in data if i['type'] == 'wireguard']
-        except urllib.error.HTTPError as e:
-            error_message = self.get_response(e)
-            print(error_message)
-            sys.exit(1)
+            self.handle_mullvad_api_error(e)
 
-    def filter(self, data: list, _filter) -> list:
-        if _filter:
-            return [d for d in data if d['hostname'].startswith(_filter)]
+    def filter(self, data: list, prefix_filter) -> list:
+        if prefix_filter:
+            return [d for d in data if d['hostname'].startswith(prefix_filter)]
         return data
 
-    def create_wg_configs(self, device, privatekey, multihop_server) -> None:
-        wg = self.get_wireguard_info()
-        output_dir = pathlib.Path(self._output_dir).expanduser()
-        output_dir.mkdir(exist_ok=True, parents=True)
-        config = configparser.ConfigParser()
-        config.add_section('Interface')
-        config.set('Interface', '#device', device['name'])
-        config.set('Interface', 'privateKey', privatekey)
-        config.set('Interface', 'address',  ','.join([device['ipv4_address'], device['ipv6_address']]))
-        if self._wg_dns:
-            config.set('Interface', 'dns', ','.join([str(x) for x in self._wg_dns]))
-        else:
-            config.set('Interface', 'dns', ','.join([wg['ipv4_gateway'], wg['ipv6_gateway']]))
-        config.add_section('Peer')
-
-        relays = self.filter(self.get_multihop_info(), self._filter)
-        if not relays:
-            print(f'No relays matching filter: {self._filter}')
-            return
-
-        print(f'Creating files in: {output_dir}')
-        for relay in relays:
-            self.create_wg_config(config, relay, multihop_server)
-
-    def create_wg_config(self, config, relay, multihop_server=None) -> None:
-        output_dir = pathlib.Path(self._output_dir).expanduser()
-        _hostname = relay['hostname']
-        if multihop_server:
-            _servername = multihop_server['hostname']
-            _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}-via-{_servername}.conf')
-        else:
-            _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}.conf')
-
-        _filepath.touch(mode=0o600, exist_ok=True)
-
-        if multihop_server:
-            remote_server = multihop_server
-            remote_port = relay['multihop_port']
-        else:
-            remote_server = relay
-            remote_port = self._wg_relay_port
-
-        if self._wg_relay_ipv6:
-            wg_relay_address = remote_server['ipv6_addr_in']
-        else:
-            wg_relay_address = remote_server['ipv4_addr_in']
-
-        with _filepath.open('w') as _file:
-            config.set('Peer', '#owned', str(relay['owned']))
-            config.set('Peer', '#provider', relay['provider'])
-            config.set('Peer', 'publickey', relay['pubkey'])
-            config.set('Peer', 'allowedips', '0.0.0.0/0,::/0')
-            config.set('Peer', 'endpoint', f'{wg_relay_address}:{remote_port}')
-            config.write(_file)
+    def handle_mullvad_api_error(self, err):
+        error_message = MullvadApi.get_response(err)
+        error_code = error_message.get('code')
+        detail_message = error_message.get('detail')
+        if detail_message:
+            print(detail_message)
+        if error_code == 'PUBKEY_IN_USE':
+            print(f'Private key settings exits in {self._settings_file} but device has been removed')
+            print('Solution 1: Wait for grace period to pass before using this key (5 min)')
+            print('Solution 2: Remove setting file if you want to create a new device')
+        elif error_code == 'INVALID_ACCOUNT':
+            print(f'Invalid account number: {self.mullvad_api.account_number}')
+        sys.exit(1)
 
 
 def validate_account(value: str) -> str:
git clone https://git.99rst.org/PROJECT