General refactor/improvements
authorxliang <redacted>
Wed, 5 Mar 2025 07:33:23 +0000 (23:33 -0800)
committerkalken <redacted>
Thu, 6 Mar 2025 13:15:34 +0000 (14:15 +0100)
wg-mullvad.py

index a8917732254014d6ef3bf2588d4facedebc768ef..c7e04a82160a84a585001944ac0fa090cc706e60 100755 (executable)
@@ -16,24 +16,24 @@ from cryptography.hazmat.primitives import serialization
 _version = '1.1'
 
 
-def generate_publickey(privatekey):
+def generate_publickey(privatekey: str) -> str:
     private_key_bytes = base64.b64decode(privatekey)
     private_key = X25519PrivateKey.from_private_bytes(private_key_bytes)
     public_key = private_key.public_key()
     public_key_bytes = public_key.public_bytes(
         encoding=serialization.Encoding.Raw,
-        format=serialization.PublicFormat.Raw
+        format=serialization.PublicFormat.Raw,
     )
     wgpublickey = base64.b64encode(public_key_bytes).decode('utf-8')
     return wgpublickey
 
 
-def generate_privatekey():
+def generate_privatekey() -> str:
     privatekey = X25519PrivateKey.generate()
     private_key_bytes = privatekey.private_bytes(
         encoding=serialization.Encoding.Raw,
         format=serialization.PrivateFormat.Raw,
-        encryption_algorithm=serialization.NoEncryption()
+        encryption_algorithm=serialization.NoEncryption(),
     )
     wgprivatekey = base64.b64encode(private_key_bytes).decode('utf-8')
     return wgprivatekey
@@ -81,7 +81,7 @@ class Mullvad:
         if device:
             self.create_wg_configs(device, privatekey, multihop_server)
 
-    def get_privatekey(self):
+    def get_privatekey(self) -> str:
         print(f'Reading settings from: {self._settings_file}')
         self._config.read(self._settings_file)
         try:
@@ -91,7 +91,7 @@ class Mullvad:
             print('Solution: add it or remove the file completely to generate a new device')
             sys.exit(1)
 
-    def save_privatekey(self, privatekey):
+    def save_privatekey(self, privatekey) -> bool:
         self._settings_file.parent.mkdir(parents=True, exist_ok=True)
         self._settings_file.touch(mode=0o600, exist_ok=True)
         with self._settings_file.open('w') as _file:
@@ -101,18 +101,19 @@ class Mullvad:
             self._config.write(_file)
         return True
 
-    def get_webtoken(self):
+    def get_webtoken(self) -> str:
         if not self._webtoken:
             self.generate_webtoken()
         return self._webtoken
 
-    def generate_webtoken(self):
-        body = {}
-        body['account_number'] = self._account_number
+    def generate_webtoken(self) -> str:
+        body = {
+            'account_number': self._account_number,
+        }
         req = urllib.request.Request('https://api.mullvad.net/auth/v1/webtoken')
         req.add_header('Content-Type', 'application/json')
-        response = urllib.request.urlopen(req, json.dumps(body).encode()).read()
-        data = json.loads(response)
+        with urllib.request.urlopen(req, json.dumps(body).encode()) as response:
+            data = json.load(response)
         self._webtoken = data['access_token']
 
     def api(self, url, body=None):
@@ -123,17 +124,15 @@ class Mullvad:
 
         if body:
             req.add_header('Content-Type', 'application/json')
-            response = urllib.request.urlopen(req, json.dumps(body).encode())
-        else:
-            response = urllib.request.urlopen(req)
 
-        return self.get_response(response)
+        with urllib.request.urlopen(req, data=json.dumps(body).encode() if body else None) as response:
+            return self.get_response(response)
 
     def get_response(self, response):
         if response.headers.get('Content-Encoding') == 'gzip':
             data = json.loads(gzip.decompress(response.read()))
         else:
-            data = json.loads(response.read())
+            data = json.load(response)
         return data
 
     def get_device(self, publickey):
@@ -148,7 +147,7 @@ class Mullvad:
             print(f'Device is not registered: {publickey}')
             return None
         except urllib.error.HTTPError as e:
-            error_message = json.loads(e.read())
+            error_message = json.load(e)
             _code = error_message.get('code')
             _message = error_message.get('detail')
             if _message:
@@ -159,15 +158,16 @@ class Mullvad:
 
     def create_device(self, publickey):
         print(f'Trying to create device: {publickey}')
-        body = {}
-        body['pubkey'] = publickey
-        body['hijack_dns'] = self._wg_hijack_dns
+        body = {
+            'pubkey': publickey,
+            'hijack_dns': self._wg_hijack_dns,
+        }
         try:
             response = self.api('https://api.mullvad.net/accounts/v1/devices', body)
-            print(f'Device created: ({response['name']}) {response['pubkey']}')
+            print(f'Device created: ({response["name"]}) {response["pubkey"]}')
             return response
         except urllib.error.HTTPError as e:
-            error_message = json.loads(e.read())
+            error_message = json.load(e)
             _code = error_message.get('code')
             _message = error_message.get('detail')
             if _message:
@@ -182,8 +182,8 @@ class Mullvad:
         try:
             req = urllib.request.Request('https://api.mullvad.net/public/relays/wireguard/v2/')
             req.add_header('Accept-Encoding', 'gzip')
-            response = urllib.request.urlopen(req)
-            data = self.get_response(response)
+            with urllib.request.urlopen(req) as response:
+                data = self.get_response(response)
             return data['wireguard']
         except urllib.error.HTTPError as e:
             error_message = self.get_response(e)
@@ -194,20 +194,20 @@ class Mullvad:
         try:
             req = urllib.request.Request('https://api.mullvad.net/www/relays/all')
             req.add_header('Accept-Encoding', 'gzip')
-            response = urllib.request.urlopen(req)
-            data = self.get_response(response)
+            with urllib.request.urlopen(req) as response:
+                data = self.get_response(response)
             return [i for i in data if i['type'] == 'wireguard']
         except urllib.error.HTTPError as e:
             error_message = self.get_response(e)
             print(error_message)
             sys.exit(1)
 
-    def filter(self, data, _filter):
+    def filter(self, data: list, _filter) -> list:
         if _filter:
-            return [d for d in data if _filter in d['hostname']]
+            return [d for d in data if d['hostname'].startswith(_filter)]
         return data
 
-    def create_wg_configs(self, device, privatekey, multihop_server):
+    def create_wg_configs(self, device, privatekey, multihop_server) -> None:
         wg = self.get_wireguard_info()
         output_dir = pathlib.Path(self._output_dir).expanduser()
         output_dir.mkdir(exist_ok=True, parents=True)
@@ -222,65 +222,61 @@ class Mullvad:
             config.set('Interface', 'dns', ','.join([wg['ipv4_gateway'], wg['ipv6_gateway']]))
         config.add_section('Peer')
 
+        relays = self.filter(self.get_multihop_info(), self._filter)
+        if not relays:
+            print(f'No relays matching filter: {self._filter}')
+            return
+
+        print(f'Creating files in: {output_dir}')
+        for relay in relays:
+            self.create_wg_config(config, relay, multihop_server)
+
+    def create_wg_config(self, config, relay, multihop_server=None) -> None:
+        output_dir = pathlib.Path(self._output_dir).expanduser()
+        _hostname = relay['hostname']
         if multihop_server:
-            _servername = multihop_server["hostname"]
-            relays = self.filter(self.get_multihop_info(), self._filter)
-            if relays:
-                print(f'Creating files in: {output_dir}')
-                for relay in relays:
-                    _hostname = relay['hostname']
-                    _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}-via-{_servername}.conf')
-                    _filepath.touch(mode=0o600, exist_ok=True)
-                    with _filepath.open('w') as _file:
-                        config.set('Peer', '#owned', str(relay['owned']))
-                        config.set('Peer', '#provider', relay['provider'])
-                        config.set('Peer', 'publickey', relay['pubkey'])
-                        config.set('Peer', 'allowedips', '0.0.0.0/0,::/0')
-                        if self._wg_relay_ipv6:
-                            wg_relay_address = multihop_server['ipv6_addr_in']
-                        else:
-                            wg_relay_address = multihop_server['ipv4_addr_in']
-                        config.set('Peer', 'endpoint', f'{wg_relay_address}:{relay["multihop_port"]}')
-                        config.write(_file)
-            else:
-                print(f'No relays matching filter: {self._filter}')
+            _servername = multihop_server['hostname']
+            _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}-via-{_servername}.conf')
         else:
-            relays = self.filter(wg['relays'], self._filter)
-            if relays:
-                print(f'Creating files in: {output_dir}')
-                for relay in relays:
-                    _hostname = relay['hostname']
-                    _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}.conf')
-                    _filepath.touch(mode=0o600, exist_ok=True)
-                    with _filepath.open('w') as _file:
-                        config.set('Peer', '#owned', str(relay['owned']))
-                        config.set('Peer', '#provider', relay['provider'])
-                        config.set('Peer', 'publickey', relay['public_key'])
-                        config.set('Peer', 'allowedips', '0.0.0.0/0,::/0')
-                        if self._wg_relay_ipv6:
-                            wg_relay_address = relay['ipv6_addr_in']
-                        else:
-                            wg_relay_address = relay['ipv4_addr_in']
-                        config.set('Peer', 'endpoint', f'{wg_relay_address}:{self._wg_relay_port}')
-                        config.write(_file)
-            else:
-                print(f'No relays matching filter: {self._filter}')
+            _filepath = pathlib.Path.joinpath(output_dir, f'{_hostname}.conf')
+
+        _filepath.touch(mode=0o600, exist_ok=True)
+
+        if multihop_server:
+            remote_server = multihop_server
+            remote_port = relay['multihop_port']
+        else:
+            remote_server = relay
+            remote_port = self._wg_relay_port
+
+        if self._wg_relay_ipv6:
+            wg_relay_address = remote_server['ipv6_addr_in']
+        else:
+            wg_relay_address = remote_server['ipv4_addr_in']
+
+        with _filepath.open('w') as _file:
+            config.set('Peer', '#owned', str(relay['owned']))
+            config.set('Peer', '#provider', relay['provider'])
+            config.set('Peer', 'publickey', relay['pubkey'])
+            config.set('Peer', 'allowedips', '0.0.0.0/0,::/0')
+            config.set('Peer', 'endpoint', f'{wg_relay_address}:{remote_port}')
+            config.write(_file)
 
 
-def validate_account(value):
+def validate_account(value: str) -> str:
     if not value.isdigit():
-        raise argparse.ArgumentTypeError("The string must contain only numbers.")
+        raise argparse.ArgumentTypeError('The string must contain only numbers.')
     return value
 
 
-def validate_port(value):
+def validate_port(value: str) -> int:
     try:
         port = int(value)
-    except ValueError:
-        raise argparse.ArgumentTypeError(f"Port must be an integer, but got '{value}'.")
+    except ValueError as e:
+        raise argparse.ArgumentTypeError(f'Port must be an integer, but got \'{value}\'.') from e
 
     if port < 1 or port > 65535:
-        raise argparse.ArgumentTypeError("Port number must be between 1 and 65535.")
+        raise argparse.ArgumentTypeError('Port number must be between 1 and 65535.')
     return port
 
 
@@ -302,14 +298,14 @@ def main():
         default='~/.config/mullvad/wg0', help='directory to write settings')
     parser.add_argument(
         '--wg-relay-port', dest='wg_relay_port', action='store', type=validate_port,
-        default=51820, help='use custom port for relays in wireguard configs')
+        default=51820, help='use custom port for relays in WireGuard configs')
     parser.add_argument(
         '--dns', dest='wg_dns', action='store', nargs='+', type=ipaddress.ip_address,
-        help='use custom dns server in wireguard configs')
+        help='use custom dns server in WireGuard configs')
     parser.add_argument(
         '--hijack-dns', dest='wg_hijack_dns', help='activate hijack dns when creating device', action='store_true')
     parser.add_argument(
-        '--ipv6', dest='wg_relay_ipv6', help='use ipv6 address for relays in wireguard configs', action='store_true')
+        '--ipv6', dest='wg_relay_ipv6', help='use ipv6 address for relays in WireGuard configs', action='store_true')
     parser.add_argument(
         '--multihop-server', dest='wg_multihop_server', action='store', default=None, help='use multihop server')
     parser.add_argument(
git clone https://git.99rst.org/PROJECT