Refactor
authorAdam Dullage <redacted>
Sun, 4 Feb 2024 16:04:49 +0000 (16:04 +0000)
committerAdam Dullage <redacted>
Sun, 4 Feb 2024 16:04:49 +0000 (16:04 +0000)
20 files changed:
.gitignore
server/api_messages.py [new file with mode: 0644]
server/attachments/base.py [new file with mode: 0644]
server/attachments/file_system/__init__.py [new file with mode: 0644]
server/attachments/file_system/file_system.py [new file with mode: 0644]
server/auth.py [deleted file]
server/auth/base.py [new file with mode: 0644]
server/auth/local/__init__.py [new file with mode: 0644]
server/auth/local/local.py [new file with mode: 0644]
server/auth/models.py [new file with mode: 0644]
server/config.py [deleted file]
server/error_responses.py [deleted file]
server/global_config.py [new file with mode: 0644]
server/helpers.py
server/main.py
server/models.py [deleted file]
server/notes/base.py [new file with mode: 0644]
server/notes/file_system/__init__.py [new file with mode: 0644]
server/notes/file_system/file_system.py [moved from server/flatnotes.py with 51% similarity]
server/notes/models.py [new file with mode: 0644]

index b1e8c8bd1fba6d9bc3f22875455f45c7fee96cb5..d993309c27463ffffd5b9d2736260686c7fb54af 100644 (file)
@@ -247,4 +247,4 @@ dist
 # Custom
 .vscode/
 data/
-notes/
+/notes/
diff --git a/server/api_messages.py b/server/api_messages.py
new file mode 100644 (file)
index 0000000..bf1e300
--- /dev/null
@@ -0,0 +1,12 @@
+login_failed = "Invalid login details."
+note_exists = "Cannot create note. A note with the same title already exists."
+note_not_found = "The specified note cannot be found."
+invalid_note_title = "The specified note title contains invalid characters."
+attachment_exists = (
+    "Cannot create attachment. An attachment with the same filename already "
+    "exists."
+)
+attachment_not_found = "The specified attachment cannot be found."
+invalid_attachment_filename = (
+    "The specified filename contains invalid characters."
+)
diff --git a/server/attachments/base.py b/server/attachments/base.py
new file mode 100644 (file)
index 0000000..fd4864f
--- /dev/null
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+
+from fastapi import UploadFile
+from fastapi.responses import FileResponse
+
+
+class BaseAttachments(ABC):
+    @abstractmethod
+    def create(self, file: UploadFile) -> None:
+        """Create a new attachment."""
+        pass
+
+    @abstractmethod
+    def get(self, filename: str) -> FileResponse:
+        """Get a specific attachment."""
+        pass
diff --git a/server/attachments/file_system/__init__.py b/server/attachments/file_system/__init__.py
new file mode 100644 (file)
index 0000000..502cf2c
--- /dev/null
@@ -0,0 +1 @@
+from .file_system import FileSystemAttachments  # noqa
diff --git a/server/attachments/file_system/file_system.py b/server/attachments/file_system/file_system.py
new file mode 100644 (file)
index 0000000..c57c99c
--- /dev/null
@@ -0,0 +1,37 @@
+import os
+import shutil
+
+from fastapi import UploadFile
+from fastapi.responses import FileResponse
+
+from helpers import get_env, is_valid_filename
+
+from ..base import BaseAttachments
+
+
+class FileSystemAttachments(BaseAttachments):
+    def __init__(self):
+        self.base_path = get_env("FLATNOTES_PATH", mandatory=True)
+        if not os.path.exists(self.base_path):
+            raise NotADirectoryError(
+                f"'{self.base_path}' is not a valid directory."
+            )
+        self.storage_path = os.path.join(self.base_path, "attachments")
+        os.makedirs(self.storage_path, exist_ok=True)
+
+    def create(self, file: UploadFile) -> None:
+        """Create a new attachment."""
+        is_valid_filename(file.filename)
+        filepath = os.path.join(self.storage_path, file.filename)
+        if os.path.exists(filepath):
+            raise FileExistsError(f"'{file.filename}' already exists.")
+        with open(filepath, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+
+    def get(self, filename: str) -> FileResponse:
+        """Get a specific attachment."""
+        is_valid_filename(filename)
+        filepath = os.path.join(self.storage_path, filename)
+        if not os.path.isfile(filepath):
+            raise FileNotFoundError(f"'{filename}' not found.")
+        return FileResponse(filepath)
diff --git a/server/auth.py b/server/auth.py
deleted file mode 100644 (file)
index 2234e4e..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-from datetime import datetime, timedelta
-
-from fastapi import Depends, HTTPException, Request
-from fastapi.security import OAuth2PasswordBearer
-from jose import JWTError, jwt
-
-from config import AuthType, config
-
-JWT_ALGORITHM = "HS256"
-
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token", auto_error=False)
-
-
-def create_access_token(data: dict):
-    to_encode = data.copy()
-    expiry_datetime = datetime.utcnow() + timedelta(
-        days=config.session_expiry_days
-    )
-    to_encode.update({"exp": expiry_datetime})
-    encoded_jwt = jwt.encode(
-        to_encode, config.session_key, algorithm=JWT_ALGORITHM
-    )
-    return encoded_jwt
-
-
-def validate_token(request: Request, token: str = Depends(oauth2_scheme)):
-    # Skip authentication if auth_type is NONE
-    if config.auth_type == AuthType.NONE:
-        return
-    # If no token is found in the header, check the cookies
-    if token is None:
-        token = request.cookies.get("token")
-    # Validate the token
-    try:
-        if token is None:
-            raise ValueError
-        payload = jwt.decode(
-            token, config.session_key, algorithms=[JWT_ALGORITHM]
-        )
-        username = payload.get("sub")
-        if username is None or username.lower() != config.username.lower():
-            raise ValueError
-        return
-    except (JWTError, ValueError):
-        raise HTTPException(
-            status_code=401,
-            detail="Invalid authentication credentials",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
-
-
-def no_auth():
-    return
diff --git a/server/auth/base.py b/server/auth/base.py
new file mode 100644 (file)
index 0000000..1f52b59
--- /dev/null
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+from .models import Login, Token
+
+
+class BaseAuth(ABC):
+    @abstractmethod
+    def login(self, data: Login) -> Token:
+        """Login a user."""
+        pass
+
+    @abstractmethod
+    def authenticate(self, token: str) -> bool:
+        """Authenticate a user."""
+        pass
diff --git a/server/auth/local/__init__.py b/server/auth/local/__init__.py
new file mode 100644 (file)
index 0000000..08c1d4b
--- /dev/null
@@ -0,0 +1 @@
+from .local import LocalAuth  # noqa
diff --git a/server/auth/local/local.py b/server/auth/local/local.py
new file mode 100644 (file)
index 0000000..6c2eba4
--- /dev/null
@@ -0,0 +1,124 @@
+import secrets
+from base64 import b32encode
+from datetime import datetime, timedelta
+
+import pyotp
+from fastapi import Depends, HTTPException, Request
+from fastapi.security import OAuth2PasswordBearer
+from jose import JWTError, jwt
+from qrcode import QRCode
+
+from global_config import GlobalConfig
+from helpers import get_env
+
+from ..base import BaseAuth
+from ..models import Login, Token
+
+global_config = GlobalConfig()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token", auto_error=False)
+
+
+class LocalAuth(BaseAuth):
+    JWT_ALGORITHM = "HS256"
+
+    def __init__(self) -> None:
+        self.username = get_env("FLATNOTES_USERNAME", mandatory=True).lower()
+        self.password = get_env("FLATNOTES_PASSWORD", mandatory=True)
+        self.secret_key = get_env("FLATNOTES_SECRET_KEY", mandatory=True)
+        self.session_expiry_days = get_env(
+            "FLATNOTES_SESSION_EXPIRY_DAYS", default=30, cast_int=True
+        )
+
+        # TOTP
+        if global_config.auth_type == "totp":
+            self.is_totp_enabled = True
+            self.totp_key = get_env("FLATNOTES_TOTP_KEY", mandatory=True)
+            self.totp_key = b32encode(self.totp_key.encode("utf-8"))
+            self.totp = pyotp.TOTP(self.totp_key)
+            self.last_used_totp = None
+            self._display_totp_enrolment()
+
+    def login(self, data: Login) -> Token:
+        # Check Username
+        username_correct = secrets.compare_digest(
+            self.username.lower(), data.username.lower()
+        )
+
+        # Check Password & TOTP
+        expected_password = self.password
+        if self.is_totp_enabled:
+            current_totp = self.totp.now()
+            expected_password += current_totp
+        password_correct = secrets.compare_digest(
+            expected_password, data.password
+        )
+
+        # Raise error if incorrect
+        if not (
+            username_correct
+            and password_correct
+            # Prevent TOTP from being reused
+            and (
+                self.is_totp_enabled is False
+                or current_totp != self.last_used_totp
+            )
+        ):
+            raise ValueError("Incorrect login credentials.")
+        if self.is_totp_enabled:
+            self.last_used_totp = current_totp
+
+        # Create Token
+        access_token = self._create_access_token(data={"sub": self.username})
+        return Token(access_token=access_token)
+
+    def authenticate(
+        self, request: Request, token: str = Depends(oauth2_scheme)
+    ):
+        # If no token is found in the header, check the cookies
+        if token is None:
+            token = request.cookies.get("token")
+        # Validate the token
+        try:
+            self._validate_token(token)
+        except (JWTError, ValueError):
+            raise HTTPException(
+                status_code=401,
+                detail="Invalid authentication credentials",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
+
+    def _validate_token(self, token: str) -> bool:
+        if token is None:
+            raise ValueError
+        payload = jwt.decode(
+            token, self.secret_key, algorithms=[self.JWT_ALGORITHM]
+        )
+        username = payload.get("sub")
+        if username is None or username.lower() != self.username:
+            raise ValueError
+
+    def _create_access_token(self, data: dict):
+        to_encode = data.copy()
+        expiry_datetime = datetime.utcnow() + timedelta(
+            days=self.session_expiry_days
+        )
+        to_encode.update({"exp": expiry_datetime})
+        encoded_jwt = jwt.encode(
+            to_encode, self.secret_key, algorithm=self.JWT_ALGORITHM
+        )
+        return encoded_jwt
+
+    def _display_totp_enrolment(self):
+        uri = self.totp.provisioning_uri(
+            issuer_name="flatnotes", name=self.username
+        )
+        qr = QRCode()
+        qr.add_data(uri)
+        print(
+            "\nScan this QR code with your TOTP app of choice",
+            "e.g. Authy or Google Authenticator:",
+        )
+        qr.print_ascii()
+        print(
+            f"Or manually enter this key: {self.totp.secret.decode('utf-8')}\n"
+        )
diff --git a/server/auth/models.py b/server/auth/models.py
new file mode 100644 (file)
index 0000000..79fc69c
--- /dev/null
@@ -0,0 +1,15 @@
+from pydantic import BaseModel, Field
+
+from helpers import CustomBaseModel
+
+
+class Login(CustomBaseModel):
+    username: str
+    password: str
+
+
+class Token(BaseModel):
+    # Note: OAuth requires keys to be snake_case so we use the standard
+    # BaseModel here
+    access_token: str
+    token_type: str = Field("bearer")
diff --git a/server/config.py b/server/config.py
deleted file mode 100644 (file)
index 4955205..0000000
+++ /dev/null
@@ -1,105 +0,0 @@
-import os
-import sys
-from base64 import b32encode
-from enum import Enum
-
-from logger import logger
-
-
-class AuthType(str, Enum):
-    NONE = "none"
-    READ_ONLY = "read_only"
-    PASSWORD = "password"
-    TOTP = "totp"
-
-
-class Config:
-    def __init__(self) -> None:
-        self.data_path = self.get_data_path()
-
-        self.auth_type = self.get_auth_type()
-
-        self.username = self.get_username()
-        self.password = self.get_password()
-
-        self.session_key = self.get_session_key()
-        self.session_expiry_days = self.get_session_expiry_days()
-
-        self.totp_key = self.get_totp_key()
-
-    @classmethod
-    def get_env(cls, key, mandatory=False, default=None, cast_int=False):
-        """Get an environment variable."""
-        value = os.environ.get(key)
-        if mandatory and not value:
-            logger.error(f"Environment variable {key} must be set.")
-            sys.exit(1)
-        if not mandatory and not value:
-            return default
-        if cast_int:
-            try:
-                value = int(value)
-            except (TypeError, ValueError):
-                logger.error(f"Invalid value '{value}' for {key}.")
-                sys.exit(1)
-        return value
-
-    def get_data_path(self):
-        return self.get_env("FLATNOTES_PATH", mandatory=True)
-
-    def get_auth_type(self):
-        key = "FLATNOTES_AUTH_TYPE"
-        auth_type = self.get_env(
-            key, mandatory=False, default=AuthType.PASSWORD.value
-        )
-        try:
-            auth_type = AuthType(auth_type.lower())
-        except ValueError:
-            logger.error(
-                f"Invalid value '{auth_type}' for {key}. "
-                + "Must be one of: "
-                + ", ".join([auth_type.value for auth_type in AuthType])
-                + "."
-            )
-            sys.exit(1)
-        return auth_type
-
-    def get_username(self):
-        return self.get_env(
-            "FLATNOTES_USERNAME",
-            mandatory=self.auth_type
-            not in [AuthType.NONE, AuthType.READ_ONLY],
-        )
-
-    def get_password(self):
-        return self.get_env(
-            "FLATNOTES_PASSWORD",
-            mandatory=self.auth_type
-            not in [AuthType.NONE, AuthType.READ_ONLY],
-        )
-
-    def get_session_key(self):
-        return self.get_env(
-            "FLATNOTES_SECRET_KEY",
-            mandatory=self.auth_type
-            not in [AuthType.NONE, AuthType.READ_ONLY],
-        )
-
-    def get_session_expiry_days(self):
-        return self.get_env(
-            "FLATNOTES_SESSION_EXPIRY_DAYS",
-            mandatory=False,
-            default=30,
-            cast_int=True,
-        )
-
-    def get_totp_key(self):
-        totp_key = self.get_env(
-            "FLATNOTES_TOTP_KEY", mandatory=self.auth_type == AuthType.TOTP
-        )
-        if totp_key:
-            totp_key = b32encode(totp_key.encode("utf-8"))
-        return totp_key
-
-
-config = Config()
diff --git a/server/error_responses.py b/server/error_responses.py
deleted file mode 100644 (file)
index 2e86a57..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-from fastapi.responses import JSONResponse
-
-filename_exists_response = JSONResponse(
-    content={"message": "The specified filename already exists."},
-    status_code=409,
-)
-
-invalid_filename_response = JSONResponse(
-    content={"message": "The specified filename contains invalid characters."},
-    status_code=400,
-)
-
-note_not_found_response = JSONResponse(
-    content={"message": "The specified note cannot be found."}, status_code=404
-)
diff --git a/server/global_config.py b/server/global_config.py
new file mode 100644 (file)
index 0000000..fa2429b
--- /dev/null
@@ -0,0 +1,59 @@
+import sys
+from enum import Enum
+
+from helpers import CustomBaseModel, get_env
+from logger import logger
+
+
+class GlobalConfig:
+    def __init__(self) -> None:
+        logger.debug("Loading global config...")
+        self.auth_type: AuthType = self._load_auth_type()
+
+    def load_auth(self):
+        if self.auth_type in (AuthType.NONE, AuthType.READ_ONLY):
+            return None
+        elif self.auth_type in (AuthType.PASSWORD, AuthType.TOTP):
+            from auth.local import LocalAuth
+
+            return LocalAuth()
+
+    def load_note_storage(self):
+        from notes.file_system import FileSystemNotes
+
+        return FileSystemNotes()
+
+    def load_attachment_storage(self):
+        from attachments.file_system import (
+            FileSystemAttachments,
+        )
+
+        return FileSystemAttachments()
+
+    def _load_auth_type(self):
+        key = "FLATNOTES_AUTH_TYPE"
+        auth_type = get_env(
+            key, mandatory=False, default=AuthType.PASSWORD.value
+        )
+        try:
+            auth_type = AuthType(auth_type.lower())
+        except ValueError:
+            logger.error(
+                f"Invalid value '{auth_type}' for {key}. "
+                + "Must be one of: "
+                + ", ".join([auth_type.value for auth_type in AuthType])
+                + "."
+            )
+            sys.exit(1)
+        return auth_type
+
+
+class AuthType(str, Enum):
+    NONE = "none"
+    READ_ONLY = "read_only"
+    PASSWORD = "password"
+    TOTP = "totp"
+
+
+class GlobalConfigResponseModel(CustomBaseModel):
+    auth_type: AuthType
index 06f18fb98a4e8fbfd759e38dc06dda7a2aa8841e..645f7390ca1f7479cc7155cf79dc7a9fd921b75c 100644 (file)
@@ -1,11 +1,9 @@
 import os
-import re
-import shutil
-from typing import List, Tuple
+import sys
 
+from pydantic import BaseModel
 
-def strip_ext(filename):
-    return os.path.splitext(filename)[0]
+from logger import logger
 
 
 def camel_case(snake_case_str: str) -> str:
@@ -14,27 +12,44 @@ def camel_case(snake_case_str: str) -> str:
     return parts[0] + "".join(part.title() for part in parts[1:])
 
 
-def empty_dir(path):
-    for item in os.listdir(path):
-        item_path = os.path.join(path, item)
-        if os.path.isfile(item_path):
-            os.remove(item_path)
-        elif os.path.isdir(item_path):
-            shutil.rmtree(item_path)
-
-
-def re_extract(pattern, string) -> Tuple[str, List[str]]:
-    """Similar to re.sub but returns a tuple of:
-
-    - `string` with matches removed
-    - list of matches"""
-    matches = []
-    text = re.sub(pattern, lambda tag: matches.append(tag.group()), string)
-    return (text, matches)
-
-
-def is_valid_filename(filename):
-    r"""Return False if the declared filename contains any of the following
-    characters: <>:"/\|?*"""
+def is_valid_filename(value):
+    """Raise ValueError if the declared string contains any of the following
+    characters: <>:"/\\|?*"""
     invalid_chars = r'<>:"/\|?*'
-    return not any(invalid_char in filename for invalid_char in invalid_chars)
+    if any(invalid_char in value for invalid_char in invalid_chars):
+        raise ValueError(
+            "title cannot include any of the following characters: "
+            + invalid_chars
+        )
+    return value
+
+
+def strip_whitespace(value):
+    """Return the declared string with leading and trailing whitespace
+    removed."""
+    return value.strip()
+
+
+def get_env(key, mandatory=False, default=None, cast_int=False):
+    """Get an environment variable. If `mandatory` is True and environment
+    variable isn't set, exit the program"""
+    value = os.environ.get(key)
+    if mandatory and not value:
+        logger.error(f"Environment variable {key} must be set.")
+        sys.exit(1)
+    if not mandatory and not value:
+        return default
+    if cast_int:
+        try:
+            value = int(value)
+        except (TypeError, ValueError):
+            logger.error(f"Invalid value '{value}' for {key}.")
+            sys.exit(1)
+    return value
+
+
+class CustomBaseModel(BaseModel):
+    class Config:
+        alias_generator = camel_case
+        populate_by_name = True
+        from_attributes = True
index 7e93c3eded323b29dfc33fc344aa232a603e17a7..99a8eff8c701824e1c6d45446ecb0db247120fb2 100644 (file)
@@ -1,99 +1,26 @@
-import os
-import secrets
-import shutil
-from typing import List, Literal, Union
+from typing import List, Literal
 
-import pyotp
 from fastapi import Depends, FastAPI, HTTPException, UploadFile
-from fastapi.responses import FileResponse, HTMLResponse
+from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
-from qrcode import QRCode
-
-from auth import create_access_token, no_auth, validate_token
-from config import AuthType, config
-from error_responses import (
-    filename_exists_response,
-    invalid_filename_response,
-    note_not_found_response,
-)
-from flatnotes import Flatnotes, InvalidTitleError, Note
-from helpers import is_valid_filename
-from models import (
-    ConfigModel,
-    LoginModel,
-    NoteContentResponseModel,
-    NotePatchModel,
-    NotePostModel,
-    NoteResponseModel,
-    SearchResultModel,
-    TokenModel,
-)
-
-ATTACHMENTS_DIR = os.path.join(config.data_path, "attachments")
-os.makedirs(ATTACHMENTS_DIR, exist_ok=True)
 
+import api_messages
+from attachments.base import BaseAttachments
+from auth.base import BaseAuth
+from auth.models import Login, Token
+from global_config import AuthType, GlobalConfig, GlobalConfigResponseModel
+from notes.base import BaseNotes
+from notes.models import Note, NoteCreate, NoteUpdate, SearchResult
+
+global_config = GlobalConfig()
+auth: BaseAuth = global_config.load_auth()
+note_storage: BaseNotes = global_config.load_note_storage()
+attachment_storage: BaseAttachments = global_config.load_attachment_storage()
+auth_deps = [Depends(auth.authenticate)] if auth else []
 app = FastAPI()
-flatnotes = Flatnotes(config.data_path)
-
-totp = (
-    pyotp.TOTP(config.totp_key) if config.auth_type == AuthType.TOTP else None
-)
-last_used_totp = None
-
-if config.auth_type in [AuthType.NONE, AuthType.READ_ONLY]:
-    authenticate = no_auth
-else:
-    authenticate = validate_token
-
-# Display TOTP QR code
-if config.auth_type == AuthType.TOTP:
-    uri = totp.provisioning_uri(issuer_name="flatnotes", name=config.username)
-    qr = QRCode()
-    qr.add_data(uri)
-    print(
-        "\nScan this QR code with your TOTP app of choice",
-        "e.g. Authy or Google Authenticator:",
-    )
-    qr.print_ascii()
-    print(f"Or manually enter this key: {totp.secret.decode('utf-8')}\n")
-
-if config.auth_type not in [AuthType.NONE, AuthType.READ_ONLY]:
-
-    @app.post("/api/token", response_model=TokenModel)
-    def token(data: LoginModel):
-        global last_used_totp
-
-        username_correct = secrets.compare_digest(
-            config.username.lower(), data.username.lower()
-        )
-
-        expected_password = config.password
-        if config.auth_type == AuthType.TOTP:
-            current_totp = totp.now()
-            expected_password += current_totp
-        password_correct = secrets.compare_digest(
-            expected_password, data.password
-        )
-
-        if not (
-            username_correct
-            and password_correct
-            # Prevent TOTP from being reused
-            and (
-                config.auth_type != AuthType.TOTP
-                or current_totp != last_used_totp
-            )
-        ):
-            raise HTTPException(
-                status_code=400, detail="Incorrect login credentials."
-            )
-
-        access_token = create_access_token(data={"sub": config.username})
-        if config.auth_type == AuthType.TOTP:
-            last_used_totp = current_totp
-        return TokenModel(access_token=access_token)
 
 
+# region UI
 @app.get("/", include_in_schema=False)
 @app.get("/login", include_in_schema=False)
 @app.get("/search", include_in_schema=False)
@@ -105,101 +32,113 @@ def root(title: str = ""):
     return HTMLResponse(content=html)
 
 
-if config.auth_type != AuthType.READ_ONLY:
+# endregion
 
-    @app.post(
-        "/api/notes",
-        dependencies=[Depends(authenticate)],
-        response_model=NoteContentResponseModel,
-    )
-    def post_note(data: NotePostModel):
-        """Create a new note."""
+
+# region Login
+if global_config.auth_type not in [AuthType.NONE, AuthType.READ_ONLY]:
+
+    @app.post("/api/token", response_model=Token)
+    def token(data: Login):
         try:
-            note = Note(flatnotes, data.title, new=True)
-            note.content = data.content
-            return NoteContentResponseModel.model_validate(note)
-        except InvalidTitleError:
-            return invalid_filename_response
-        except FileExistsError:
-            return filename_exists_response
+            return auth.login(data)
+        except ValueError:
+            raise HTTPException(
+                status_code=401, detail=api_messages.login_failed
+            )
 
 
+# endregion
+
+
+# region Notes
+# Get Note
 @app.get(
     "/api/notes/{title}",
-    dependencies=[Depends(authenticate)],
-    response_model=Union[NoteContentResponseModel, NoteResponseModel],
+    dependencies=auth_deps,
+    response_model=Note,
 )
-def get_note(
-    title: str,
-    include_content: bool = True,
-):
+def get_note(title: str):
     """Get a specific note."""
     try:
-        note = Note(flatnotes, title)
-        if include_content:
-            return NoteContentResponseModel.model_validate(note)
-        else:
-            return NoteResponseModel.model_validate(note)
-    except InvalidTitleError:
-        return invalid_filename_response
+        return note_storage.get(title)
+    except ValueError:
+        raise HTTPException(
+            status_code=400, detail=api_messages.invalid_note_title
+        )
     except FileNotFoundError:
-        return note_not_found_response
+        raise HTTPException(404, api_messages.note_not_found)
 
 
-if config.auth_type != AuthType.READ_ONLY:
+if global_config.auth_type != AuthType.READ_ONLY:
 
+    # Create Note
+    @app.post(
+        "/api/notes",
+        dependencies=auth_deps,
+        response_model=Note,
+    )
+    def post_note(note: NoteCreate):
+        """Create a new note."""
+        try:
+            return note_storage.create(note)
+        except ValueError:
+            raise HTTPException(
+                status_code=400,
+                detail=api_messages.invalid_note_title,
+            )
+        except FileExistsError:
+            raise HTTPException(
+                status_code=409, detail=api_messages.note_exists
+            )
+
+    # Update Note
     @app.patch(
         "/api/notes/{title}",
-        dependencies=[Depends(authenticate)],
-        response_model=NoteContentResponseModel,
+        dependencies=auth_deps,
+        response_model=Note,
     )
-    def patch_note(title: str, new_data: NotePatchModel):
+    def patch_note(title: str, data: NoteUpdate):
         try:
-            note = Note(flatnotes, title)
-            if new_data.new_title is not None:
-                note.title = new_data.new_title
-            if new_data.new_content is not None:
-                note.content = new_data.new_content
-            return NoteContentResponseModel.model_validate(note)
-        except InvalidTitleError:
-            return invalid_filename_response
+            return note_storage.update(title, data)
+        except ValueError:
+            raise HTTPException(
+                status_code=400,
+                detail=api_messages.invalid_note_title,
+            )
         except FileExistsError:
-            return filename_exists_response
+            raise HTTPException(
+                status_code=409, detail=api_messages.note_exists
+            )
         except FileNotFoundError:
-            return note_not_found_response
-
-
-if config.auth_type != AuthType.READ_ONLY:
+            raise HTTPException(404, api_messages.note_not_found)
 
+    # Delete Note
     @app.delete(
         "/api/notes/{title}",
-        dependencies=[Depends(authenticate)],
+        dependencies=auth_deps,
         response_model=None,
     )
     def delete_note(title: str):
         try:
-            note = Note(flatnotes, title)
-            note.delete()
-        except InvalidTitleError:
-            return invalid_filename_response
+            note_storage.delete(title)
+        except ValueError:
+            raise HTTPException(
+                status_code=400,
+                detail=api_messages.invalid_note_title,
+            )
         except FileNotFoundError:
-            return note_not_found_response
+            raise HTTPException(404, api_messages.note_not_found)
 
 
-@app.get(
-    "/api/tags",
-    dependencies=[Depends(authenticate)],
-    response_model=List[str],
-)
-def get_tags():
-    """Get a list of all indexed tags."""
-    return flatnotes.get_tags()
+# endregion
 
 
+# region Search
 @app.get(
     "/api/search",
-    dependencies=[Depends(authenticate)],
-    response_model=List[SearchResultModel],
+    dependencies=auth_deps,
+    response_model=List[SearchResult],
 )
 def search(
     term: str,
@@ -210,51 +149,75 @@ def search(
     """Perform a full text search on all notes."""
     if sort == "lastModified":
         sort = "last_modified"
-    return [
-        SearchResultModel.model_validate(note_hit)
-        for note_hit in flatnotes.search(
-            term, sort=sort, order=order, limit=limit
-        )
-    ]
+    return note_storage.search(term, sort=sort, order=order, limit=limit)
 
 
-@app.get("/api/config", response_model=ConfigModel)
+@app.get(
+    "/api/tags",
+    dependencies=auth_deps,
+    response_model=List[str],
+)
+def get_tags():
+    """Get a list of all indexed tags."""
+    return note_storage.get_tags()
+
+
+# endregion
+
+
+# region Config
+@app.get("/api/config", response_model=GlobalConfigResponseModel)
 def get_config():
     """Retrieve server-side config required for the UI."""
-    return ConfigModel.model_validate(config)
+    return GlobalConfigResponseModel(auth_type=global_config.auth_type)
 
 
-if config.auth_type != AuthType.READ_ONLY:
-
-    @app.post(
-        "/api/attachments",
-        dependencies=[Depends(authenticate)],
-        response_model=None,
-    )
-    def post_attachment(file: UploadFile):
-        """Upload an attachment."""
-        if not is_valid_filename(file.filename):
-            return invalid_filename_response
-        filepath = os.path.join(ATTACHMENTS_DIR, file.filename)
-        if os.path.exists(filepath):
-            return filename_exists_response
-        with open(filepath, "wb") as f:
-            shutil.copyfileobj(file.file, f)
+# endregion
 
 
+# region Attachments
+# Get Attachment
 @app.get(
     "/attachments/{filename}",
-    dependencies=[Depends(authenticate)],
+    dependencies=auth_deps,
     include_in_schema=False,
 )
 def get_attachment(filename: str):
     """Download an attachment."""
-    if not is_valid_filename(filename):
-        raise HTTPException(status_code=400, detail="Invalid filename.")
-    filepath = os.path.join(ATTACHMENTS_DIR, filename)
-    if not os.path.isfile(filepath):
-        raise HTTPException(status_code=404, detail="File not found.")
-    return FileResponse(filepath)
+    try:
+        return attachment_storage.get(filename)
+    except ValueError:
+        raise HTTPException(
+            status_code=400,
+            detail=api_messages.invalid_attachment_filename,
+        )
+    except FileNotFoundError:
+        raise HTTPException(
+            status_code=404, detail=api_messages.attachment_not_found
+        )
+
+
+if global_config.auth_type != AuthType.READ_ONLY:
+
+    # Create Attachment
+    @app.post(
+        "/api/attachments",
+        dependencies=auth_deps,
+        response_model=None,
+    )
+    def post_attachment(file: UploadFile):
+        """Upload an attachment."""
+        try:
+            return attachment_storage.create(file)
+        except ValueError:
+            raise HTTPException(
+                status_code=400,
+                detail=api_messages.invalid_attachment_filename,
+            )
+        except FileExistsError:
+            raise HTTPException(409, api_messages.attachment_exists)
+
 
+# endregion
 
 app.mount("/", StaticFiles(directory="client/dist"), name="dist")
diff --git a/server/models.py b/server/models.py
deleted file mode 100644 (file)
index b87d8bd..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import List, Optional
-
-from pydantic import BaseModel, Field
-
-from config import AuthType
-from helpers import camel_case
-
-
-class TokenModel(BaseModel):
-    # Use of BaseModel instead of CustomBaseModel is intentional as OAuth
-    # requires keys to be snake_case
-    access_token: str
-    token_type: str = Field("bearer")
-
-
-class CustomBaseModel(BaseModel):
-    class Config:
-        alias_generator = camel_case
-        populate_by_name = True
-        from_attributes = True
-
-
-class LoginModel(CustomBaseModel):
-    username: str
-    password: str
-
-
-class NotePostModel(CustomBaseModel):
-    title: str
-    content: Optional[str] = Field(None)
-
-
-class NoteResponseModel(CustomBaseModel):
-    title: str
-    last_modified: float
-
-
-class NoteContentResponseModel(NoteResponseModel):
-    content: Optional[str] = Field(None)
-
-
-class NotePatchModel(CustomBaseModel):
-    new_title: Optional[str] = Field(None)
-    new_content: Optional[str] = Field(None)
-
-
-class SearchResultModel(CustomBaseModel):
-    score: Optional[float] = Field(None)
-    title: str
-    last_modified: float
-    title_highlights: Optional[str] = Field(None)
-    content_highlights: Optional[str] = Field(None)
-    tag_matches: Optional[List[str]] = Field(None)
-
-
-class ConfigModel(CustomBaseModel):
-    auth_type: AuthType
diff --git a/server/notes/base.py b/server/notes/base.py
new file mode 100644 (file)
index 0000000..ad97398
--- /dev/null
@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from typing import Literal
+
+from .models import Note, NoteCreate, NoteUpdate, SearchResult
+
+
+class BaseNotes(ABC):
+    @abstractmethod
+    def create(self, data: NoteCreate) -> Note:
+        """Create a new note."""
+        pass
+
+    @abstractmethod
+    def get(self, title: str) -> Note:
+        """Get a specific note."""
+        pass
+
+    @abstractmethod
+    def update(self, title: str, new_data: NoteUpdate) -> Note:
+        """Update a specific note."""
+        pass
+
+    @abstractmethod
+    def delete(self, title: str) -> None:
+        """Delete a specific note.""" ""
+        pass
+
+    @abstractmethod
+    def search(
+        self,
+        term: str,
+        sort: Literal["score", "title", "last_modified"] = "score",
+        order: Literal["asc", "desc"] = "desc",
+        limit: int = None,
+    ) -> list[SearchResult]:
+        """Search for notes."""
+        pass
+
+    @abstractmethod
+    def get_tags(self) -> list[str]:
+        """Get a list of all indexed tags."""
+        pass
diff --git a/server/notes/file_system/__init__.py b/server/notes/file_system/__init__.py
new file mode 100644 (file)
index 0000000..2815b67
--- /dev/null
@@ -0,0 +1 @@
+from .file_system import FileSystemNotes
similarity index 51%
rename from server/flatnotes.py
rename to server/notes/file_system/file_system.py
index 526ad325b775d2c27856407c1bb8a56476980159..0533120ef1d09f1bb475fc03a75133f14932c47a 100644 (file)
@@ -1,6 +1,7 @@
 import glob
 import os
 import re
+import shutil
 from datetime import datetime
 from typing import List, Literal, Set, Tuple
 
@@ -16,9 +17,12 @@ from whoosh.query import Every
 from whoosh.searching import Hit
 from whoosh.support.charset import accent_map
 
-from helpers import empty_dir, is_valid_filename, re_extract, strip_ext
+from helpers import get_env, is_valid_filename
 from logger import logger
 
+from ..base import BaseNotes
+from ..models import Note, NoteCreate, NoteUpdate, SearchResult
+
 MARKDOWN_EXT = ".md"
 INDEX_SCHEMA_VERSION = "4"
 
@@ -35,174 +39,159 @@ class IndexSchema(SchemaClass):
     tags = KEYWORD(lowercase=True, field_boost=2.0)
 
 
-class InvalidTitleError(Exception):
-    def __init__(self, message="The specified title is invalid"):
-        self.message = message
-        super().__init__(self.message)
-
-
-class Note:
-    def __init__(
-        self, flatnotes: "Flatnotes", title: str, new: bool = False
-    ) -> None:
-        self._flatnotes = flatnotes
-        self._title = title.strip()
-        if not is_valid_filename(self._title):
-            raise InvalidTitleError
-        exists = os.path.exists(self.filepath)
-        if new and exists:
-            raise FileExistsError
-        if not new and not exists:
-            raise FileNotFoundError
-        if new:
-            open(self.filepath, "w").close()
-
-    @property
-    def filepath(self):
-        return os.path.join(self._flatnotes.dir, self.filename)
-
-    @property
-    def filename(self):
-        return self._title + MARKDOWN_EXT
-
-    @property
-    def last_modified(self):
-        return os.path.getmtime(self.filepath)
+class FileSystemNotes(BaseNotes):
+    TAGS_RE = re.compile(r"(?:(?<=^#)|(?<=\s#))\w+(?=\s|$)")
+    CODEBLOCK_RE = re.compile(r"`{1,3}.*?`{1,3}", re.DOTALL)
+    TAGS_WITH_HASH_RE = re.compile(r"(?:(?<=^)|(?<=\s))#\w+(?=\s|$)")
 
-    # Editable Properties
-    @property
-    def title(self):
-        return self._title
-
-    @title.setter
-    def title(self, new_title):
-        new_title = new_title.strip()
-        if not is_valid_filename(new_title):
-            raise InvalidTitleError
-        new_filepath = os.path.join(
-            self._flatnotes.dir, new_title + MARKDOWN_EXT
+    def __init__(self):
+        self.storage_path = get_env("FLATNOTES_PATH", mandatory=True)
+        if not os.path.exists(self.storage_path):
+            raise NotADirectoryError(
+                f"'{self.storage_path}' is not a valid directory."
+            )
+        self.index = self._load_index()
+        self._sync_index()
+
+    def create(self, data: NoteCreate) -> Note:
+        """Create a new note."""
+        filepath = self._path_from_title(data.title)
+        self._write_file(filepath, data.content)
+        return Note(
+            title=data.title,
+            content=data.content,
+            last_modified=os.path.getmtime(filepath),
         )
-        os.rename(self.filepath, new_filepath)
-        self._title = new_title
-
-    @property
-    def content(self):
-        with open(self.filepath, "r", encoding="utf-8") as f:
-            return f.read()
-
-    @content.setter
-    def content(self, new_content):
-        if not os.path.exists(self.filepath):
-            raise FileNotFoundError
-        with open(self.filepath, "w", encoding="utf-8") as f:
-            f.write(new_content)
-
-    def delete(self):
-        os.remove(self.filepath)
-
-
-class SearchResult(Note):
-    def __init__(self, flatnotes: "Flatnotes", hit: Hit) -> None:
-        super().__init__(flatnotes, strip_ext(hit["filename"]))
 
-        self._matched_fields = self._get_matched_fields(hit.matched_terms())
-        # If the search was ordered using a text field then hit.score is the
-        # value of that field. This isn't useful so only set self._score if it
-        # is a float.
-        self._score = hit.score if type(hit.score) is float else None
-
-        if "title" in self._matched_fields:
-            hit.results.fragmenter = WholeFragmenter()
-            self._title_highlights = hit.highlights("title", text=self.title)
-        else:
-            self._title_highlights = None
+    def get(self, title: str) -> Note:
+        """Get a specific note."""
+        is_valid_filename(title)
+        filepath = self._path_from_title(title)
+        content = self._read_file(filepath)
+        return Note(
+            title=title,
+            content=content,
+            last_modified=os.path.getmtime(filepath),
+        )
 
-        if "content" in self._matched_fields:
-            hit.results.fragmenter = ContextFragmenter()
-            content_ex_tags, _ = Flatnotes.extract_tags(self.content)
-            self._content_highlights = hit.highlights(
-                "content",
-                text=content_ex_tags,
-            )
+    def update(self, title: str, data: NoteUpdate) -> Note:
+        """Update a specific note."""
+        is_valid_filename(title)
+        filepath = self._path_from_title(title)
+        if data.new_title is not None:
+            new_filepath = self._path_from_title(data.new_title)
+            os.rename(filepath, new_filepath)
+            title = data.new_title
+            filepath = new_filepath
+        if data.new_content is not None:
+            self._write_file(filepath, data.new_content, overwrite=True)
+            content = data.new_content
         else:
-            self._content_highlights = None
-
-        self._tag_matches = (
-            [field[1] for field in hit.matched_terms() if field[0] == "tags"]
-            if "tags" in self._matched_fields
-            else None
+            content = self._read_file(filepath)
+        return Note(
+            title=title,
+            content=content,
+            last_modified=os.path.getmtime(filepath),
         )
 
-    @property
-    def score(self):
-        return self._score
-
-    @property
-    def title_highlights(self):
-        return self._title_highlights
+    def delete(self, title: str) -> None:
+        """Delete a specific note."""
+        is_valid_filename(title)
+        filepath = self._path_from_title(title)
+        os.remove(filepath)
 
-    @property
-    def content_highlights(self):
-        return self._content_highlights
+    def search(
+        self,
+        term: str,
+        sort: Literal["score", "title", "last_modified"] = "score",
+        order: Literal["asc", "desc"] = "desc",
+        limit: int = None,
+    ) -> Tuple[SearchResult, ...]:
+        """Search the index for the given term."""
+        self._sync_index()
+        term = self._pre_process_search_term(term)
+        with self.index.searcher() as searcher:
+            # Parse Query
+            if term == "*":
+                query = Every()
+            else:
+                parser = MultifieldParser(
+                    ["title", "content", "tags"], self.index.schema
+                )
+                parser.add_plugin(DateParserPlugin())
+                query = parser.parse(term)
 
-    @property
-    def tag_matches(self):
-        return self._tag_matches
+            # Determine Sort By
+            # Note: For the 'sort' option, "score" is converted to None as
+            # that is the default for searches anyway and it's quicker for
+            # Whoosh if you specify None.
+            sort = sort if sort in ["title", "last_modified"] else None
 
-    @staticmethod
-    def _get_matched_fields(matched_terms):
-        """Return a set of matched fields from a set of ('field', 'term') "
-        "tuples generated by whoosh.searching.Hit.matched_terms()."""
-        return set([matched_term[0] for matched_term in matched_terms])
+            # Determine Sort Direction
+            # Note: Confusingly, when sorting by 'score', reverse = True means
+            # asc so we have to flip the logic for that case!
+            reverse = order == "desc"
+            if sort is None:
+                reverse = not reverse
 
+            # Run Search
+            results = searcher.search(
+                query,
+                sortedby=sort,
+                reverse=reverse,
+                limit=limit,
+                terms=True,
+            )
+            return tuple(self._search_result_from_hit(hit) for hit in results)
 
-class Flatnotes(object):
-    TAGS_RE = re.compile(r"(?:(?<=^#)|(?<=\s#))\w+(?=\s|$)")
-    CODEBLOCK_RE = re.compile(r"`{1,3}.*?`{1,3}", re.DOTALL)
-    TAGS_WITH_HASH_RE = re.compile(r"(?:(?<=^)|(?<=\s))#\w+(?=\s|$)")
+    def get_tags(self) -> list[str]:
+        """Return a list of all indexed tags."""
+        self._sync_index()
+        with self.index.reader() as reader:
+            tags = reader.field_terms("tags")
+            return [tag for tag in tags]
 
-    def __init__(self, dir: str) -> None:
-        if not os.path.exists(dir):
-            raise NotADirectoryError(f"'{dir}' is not a valid directory.")
-        self.dir = dir
+    @property
+    def _index_path(self):
+        return os.path.join(self.storage_path, ".flatnotes")
 
-        self.index = self._load_index()
-        self.update_index()
+    def _path_from_title(self, title: str) -> str:
+        return os.path.join(self.storage_path, title + MARKDOWN_EXT)
 
-    @property
-    def index_dir(self):
-        return os.path.join(self.dir, ".flatnotes")
+    def _get_by_filename(self, filename: str) -> Note:
+        """Get a note by its filename."""
+        return self.get(self._strip_ext(filename))
 
     def _load_index(self) -> Index:
         """Load the note index or create new if not exists."""
-        index_dir_exists = os.path.exists(self.index_dir)
+        index_dir_exists = os.path.exists(self._index_path)
         if index_dir_exists and whoosh.index.exists_in(
-            self.index_dir, indexname=INDEX_SCHEMA_VERSION
+            self._index_path, indexname=INDEX_SCHEMA_VERSION
         ):
             logger.info("Loading existing index")
             return whoosh.index.open_dir(
-                self.index_dir, indexname=INDEX_SCHEMA_VERSION
+                self._index_path, indexname=INDEX_SCHEMA_VERSION
             )
         else:
             if index_dir_exists:
                 logger.info("Deleting outdated index")
-                empty_dir(self.index_dir)
+                self._clear_dir(self._index_path)
             else:
-                os.mkdir(self.index_dir)
+                os.mkdir(self._index_path)
             logger.info("Creating new index")
             return whoosh.index.create_in(
-                self.index_dir, IndexSchema, indexname=INDEX_SCHEMA_VERSION
+                self._index_path, IndexSchema, indexname=INDEX_SCHEMA_VERSION
             )
 
     @classmethod
-    def extract_tags(cls, content) -> Tuple[str, Set[str]]:
+    def _extract_tags(cls, content) -> Tuple[str, Set[str]]:
         """Strip tags from the given content and return a tuple consisting of:
 
         - The content without the tags.
         - A set of tags converted to lowercase."""
         content_ex_codeblock = re.sub(cls.CODEBLOCK_RE, "", content)
-        _, tags = re_extract(cls.TAGS_RE, content_ex_codeblock)
-        content_ex_tags, _ = re_extract(cls.TAGS_RE, content)
+        _, tags = cls._re_extract(cls.TAGS_RE, content_ex_codeblock)
+        content_ex_tags, _ = cls._re_extract(cls.TAGS_RE, content)
         try:
             tags = [tag.lower() for tag in tags]
             return (content_ex_tags, set(tags))
@@ -215,27 +204,26 @@ class Flatnotes(object):
         """Add a Note object to the index using the given writer. If the
         filename already exists in the index an update will be performed
         instead."""
-        content_ex_tags, tag_set = self.extract_tags(note.content)
+        content_ex_tags, tag_set = self._extract_tags(note.content)
         tag_string = " ".join(tag_set)
         writer.update_document(
-            filename=note.filename,
+            filename=note.title + MARKDOWN_EXT,
             last_modified=datetime.fromtimestamp(note.last_modified),
             title=note.title,
             content=content_ex_tags,
             tags=tag_string,
         )
 
-    def _get_notes(self) -> List[Note]:
-        """Return a list containing a Note object for every file in the notes
-        directory."""
+    def _list_all_note_filenames(self) -> List[str]:
+        """Return a list of all note filenames."""
         return [
-            Note(self, strip_ext(os.path.split(filepath)[1]))
+            os.path.split(filepath)[1]
             for filepath in glob.glob(
-                os.path.join(self.dir, "*" + MARKDOWN_EXT)
+                os.path.join(self.storage_path, "*" + MARKDOWN_EXT)
             )
         ]
 
-    def update_index(self, clean: bool = False) -> None:
+    def _sync_index(self, clean: bool = False) -> None:
         """Synchronize the index with the notes directory.
         Specify clean=True to completely rebuild the index"""
         indexed = set()
@@ -245,7 +233,7 @@ class Flatnotes(object):
         with self.index.searcher() as searcher:
             for idx_note in searcher.all_stored_fields():
                 idx_filename = idx_note["filename"]
-                idx_filepath = os.path.join(self.dir, idx_filename)
+                idx_filepath = os.path.join(self.storage_path, idx_filename)
                 # Delete missing
                 if not os.path.exists(idx_filepath):
                     writer.delete_by_term("filename", idx_filename)
@@ -257,76 +245,115 @@ class Flatnotes(object):
                 ):
                     logger.info(f"'{idx_filename}' updated")
                     self._add_note_to_index(
-                        writer, Note(self, strip_ext(idx_filename))
+                        writer, self._get_by_filename(idx_filename)
                     )
                     indexed.add(idx_filename)
                 # Ignore already indexed
                 else:
                     indexed.add(idx_filename)
         # Add new
-        for note in self._get_notes():
-            if note.filename not in indexed:
-                self._add_note_to_index(writer, note)
-                logger.info(f"'{note.filename}' added to index")
+        for filename in self._list_all_note_filenames():
+            if filename not in indexed:
+                self._add_note_to_index(
+                    writer, self._get_by_filename(filename)
+                )
+                logger.info(f"'{filename}' added to index")
         writer.commit()
 
-    def get_tags(self):
-        """Return a list of all indexed tags."""
-        self.update_index()
-        with self.index.reader() as reader:
-            tags = reader.field_terms("tags")
-            return [tag for tag in tags]
-
-    def pre_process_search_term(self, term):
+    @classmethod
+    def _pre_process_search_term(cls, term):
         term = term.strip()
         # Replace "#tagname" with "tags:tagname"
         term = re.sub(
-            self.TAGS_WITH_HASH_RE,
+            cls.TAGS_WITH_HASH_RE,
             lambda tag: "tags:" + tag.group(0)[1:],
             term,
         )
         return term
 
-    def search(
-        self,
-        term: str,
-        sort: Literal["score", "title", "last_modified"] = "score",
-        order: Literal["asc", "desc"] = "desc",
-        limit: int = None,
-    ) -> Tuple[SearchResult, ...]:
-        """Search the index for the given term."""
-        self.update_index()
-        term = self.pre_process_search_term(term)
-        with self.index.searcher() as searcher:
-            # Parse Query
-            if term == "*":
-                query = Every()
-            else:
-                parser = MultifieldParser(
-                    ["title", "content", "tags"], self.index.schema
-                )
-                parser.add_plugin(DateParserPlugin())
-                query = parser.parse(term)
+    @staticmethod
+    def _re_extract(pattern, string) -> Tuple[str, List[str]]:
+        """Similar to re.sub but returns a tuple of:
 
-            # Determine Sort By
-            # Note: For the 'sort' option, "score" is converted to None as
-            # that is the default for searches anyway and it's quicker for
-            # Whoosh if you specify None.
-            sort = sort if sort in ["title", "last_modified"] else None
+        - `string` with matches removed
+        - list of matches"""
+        matches = []
+        text = re.sub(pattern, lambda tag: matches.append(tag.group()), string)
+        return (text, matches)
 
-            # Determine Sort Direction
-            # Note: Confusingly, when sorting by 'score', reverse = True means
-            # asc so we have to flip the logic for that case!
-            reverse = order == "desc"
-            if sort is None:
-                reverse = not reverse
+    @staticmethod
+    def _strip_ext(filename):
+        """Return the given filename without the extension."""
+        return os.path.splitext(filename)[0]
 
-            # Run Search
-            results = searcher.search(
-                query,
-                sortedby=sort,
-                reverse=reverse,
-                limit=limit,
-                terms=True,
+    @staticmethod
+    def _clear_dir(path):
+        """Delete all contents of the given directory."""
+        for item in os.listdir(path):
+            item_path = os.path.join(path, item)
+            if os.path.isfile(item_path):
+                os.remove(item_path)
+            elif os.path.isdir(item_path):
+                shutil.rmtree(item_path)
+
+    def _search_result_from_hit(self, hit: Hit):
+        matched_fields = self._get_matched_fields(hit.matched_terms())
+
+        title = self._strip_ext(hit["filename"])
+        last_modified = hit["last_modified"].timestamp()
+
+        # If the search was ordered using a text field then hit.score is the
+        # value of that field. This isn't useful so only set self._score if it
+        # is a float.
+        score = hit.score if type(hit.score) is float else None
+
+        if "title" in matched_fields:
+            hit.results.fragmenter = WholeFragmenter()
+            title_highlights = hit.highlights("title", text=title)
+        else:
+            title_highlights = None
+
+        if "content" in matched_fields:
+            hit.results.fragmenter = ContextFragmenter()
+            content = self._read_file(self._path_from_title(title))
+            content_ex_tags, _ = FileSystemNotes._extract_tags(content)
+            content_highlights = hit.highlights(
+                "content",
+                text=content_ex_tags,
             )
-            return tuple(SearchResult(self, hit) for hit in results)
+        else:
+            content_highlights = None
+
+        tag_matches = (
+            [field[1] for field in hit.matched_terms() if field[0] == "tags"]
+            if "tags" in matched_fields
+            else None
+        )
+
+        return SearchResult(
+            title=title,
+            last_modified=last_modified,
+            score=score,
+            title_highlights=title_highlights,
+            content_highlights=content_highlights,
+            tag_matches=tag_matches,
+        )
+
+    @staticmethod
+    def _get_matched_fields(matched_terms):
+        """Return a set of matched fields from a set of ('field', 'term') "
+        "tuples generated by whoosh.searching.Hit.matched_terms()."""
+        return set([matched_term[0] for matched_term in matched_terms])
+
+    @staticmethod
+    def _read_file(filepath: str):
+        logger.debug(f"Reading from '{filepath}'")
+        with open(filepath, "r") as f:
+            content = f.read()
+        return content
+
+    @staticmethod
+    def _write_file(filepath: str, content: str, overwrite: bool = False):
+        logger.debug(f"Writing to '{filepath}'")
+        with open(filepath, "w" if overwrite else "x") as f:
+            f.write(content)
diff --git a/server/notes/models.py b/server/notes/models.py
new file mode 100644 (file)
index 0000000..ba16708
--- /dev/null
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from pydantic import Field
+from pydantic.functional_validators import AfterValidator
+from typing_extensions import Annotated
+
+from helpers import CustomBaseModel, is_valid_filename, strip_whitespace
+
+
+class NoteBase(CustomBaseModel):
+    title: str
+
+
+class NoteCreate(CustomBaseModel):
+    title: Annotated[
+        str,
+        AfterValidator(strip_whitespace),
+        AfterValidator(is_valid_filename),
+    ]
+    content: Optional[str] = Field(None)
+
+
+class Note(CustomBaseModel):
+    title: str
+    content: Optional[str] = Field(None)
+    last_modified: float
+
+
+class NoteUpdate(CustomBaseModel):
+    new_title: Annotated[
+        Optional[str],
+        AfterValidator(strip_whitespace),
+        AfterValidator(is_valid_filename),
+    ] = Field(None)
+    new_content: Optional[str] = Field(None)
+
+
+class SearchResult(CustomBaseModel):
+    title: str
+    last_modified: float
+
+    score: Optional[float] = Field(None)
+    title_highlights: Optional[str] = Field(None)
+    content_highlights: Optional[str] = Field(None)
+    tag_matches: Optional[List[str]] = Field(None)
git clone https://git.99rst.org/PROJECT