From: Adam Dullage Date: Sun, 4 Feb 2024 16:04:49 +0000 (+0000) Subject: Refactor X-Git-Url: http://git.99rst.org/?a=commitdiff_plain;h=21ac3a1b43fdf2baee12f6aed13d379de1dbe2aa;p=flatnotes.git Refactor --- diff --git a/.gitignore b/.gitignore index b1e8c8b..d993309 100644 --- a/.gitignore +++ b/.gitignore @@ -247,4 +247,4 @@ dist # Custom .vscode/ data/ -notes/ +/notes/ diff --git a/server/api_messages.py b/server/api_messages.py new file mode 100644 index 0000000..bf1e300 --- /dev/null +++ b/server/api_messages.py @@ -0,0 +1,12 @@ +login_failed = "Invalid login details." +note_exists = "Cannot create note. A note with the same title already exists." +note_not_found = "The specified note cannot be found." +invalid_note_title = "The specified note title contains invalid characters." +attachment_exists = ( + "Cannot create attachment. An attachment with the same filename already " + "exists." +) +attachment_not_found = "The specified attachment cannot be found." +invalid_attachment_filename = ( + "The specified filename contains invalid characters." +) diff --git a/server/attachments/base.py b/server/attachments/base.py new file mode 100644 index 0000000..fd4864f --- /dev/null +++ b/server/attachments/base.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod + +from fastapi import UploadFile +from fastapi.responses import FileResponse + + +class BaseAttachments(ABC): + @abstractmethod + def create(self, file: UploadFile) -> None: + """Create a new attachment.""" + pass + + @abstractmethod + def get(self, filename: str) -> FileResponse: + """Get a specific attachment.""" + pass diff --git a/server/attachments/file_system/__init__.py b/server/attachments/file_system/__init__.py new file mode 100644 index 0000000..502cf2c --- /dev/null +++ b/server/attachments/file_system/__init__.py @@ -0,0 +1 @@ +from .file_system import FileSystemAttachments # noqa diff --git a/server/attachments/file_system/file_system.py b/server/attachments/file_system/file_system.py new file mode 100644 index 0000000..c57c99c --- /dev/null +++ b/server/attachments/file_system/file_system.py @@ -0,0 +1,37 @@ +import os +import shutil + +from fastapi import UploadFile +from fastapi.responses import FileResponse + +from helpers import get_env, is_valid_filename + +from ..base import BaseAttachments + + +class FileSystemAttachments(BaseAttachments): + def __init__(self): + self.base_path = get_env("FLATNOTES_PATH", mandatory=True) + if not os.path.exists(self.base_path): + raise NotADirectoryError( + f"'{self.base_path}' is not a valid directory." + ) + self.storage_path = os.path.join(self.base_path, "attachments") + os.makedirs(self.storage_path, exist_ok=True) + + def create(self, file: UploadFile) -> None: + """Create a new attachment.""" + is_valid_filename(file.filename) + filepath = os.path.join(self.storage_path, file.filename) + if os.path.exists(filepath): + raise FileExistsError(f"'{file.filename}' already exists.") + with open(filepath, "wb") as f: + shutil.copyfileobj(file.file, f) + + def get(self, filename: str) -> FileResponse: + """Get a specific attachment.""" + is_valid_filename(filename) + filepath = os.path.join(self.storage_path, filename) + if not os.path.isfile(filepath): + raise FileNotFoundError(f"'{filename}' not found.") + return FileResponse(filepath) diff --git a/server/auth.py b/server/auth.py deleted file mode 100644 index 2234e4e..0000000 --- a/server/auth.py +++ /dev/null @@ -1,53 +0,0 @@ -from datetime import datetime, timedelta - -from fastapi import Depends, HTTPException, Request -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt - -from config import AuthType, config - -JWT_ALGORITHM = "HS256" - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token", auto_error=False) - - -def create_access_token(data: dict): - to_encode = data.copy() - expiry_datetime = datetime.utcnow() + timedelta( - days=config.session_expiry_days - ) - to_encode.update({"exp": expiry_datetime}) - encoded_jwt = jwt.encode( - to_encode, config.session_key, algorithm=JWT_ALGORITHM - ) - return encoded_jwt - - -def validate_token(request: Request, token: str = Depends(oauth2_scheme)): - # Skip authentication if auth_type is NONE - if config.auth_type == AuthType.NONE: - return - # If no token is found in the header, check the cookies - if token is None: - token = request.cookies.get("token") - # Validate the token - try: - if token is None: - raise ValueError - payload = jwt.decode( - token, config.session_key, algorithms=[JWT_ALGORITHM] - ) - username = payload.get("sub") - if username is None or username.lower() != config.username.lower(): - raise ValueError - return - except (JWTError, ValueError): - raise HTTPException( - status_code=401, - detail="Invalid authentication credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - - -def no_auth(): - return diff --git a/server/auth/base.py b/server/auth/base.py new file mode 100644 index 0000000..1f52b59 --- /dev/null +++ b/server/auth/base.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from .models import Login, Token + + +class BaseAuth(ABC): + @abstractmethod + def login(self, data: Login) -> Token: + """Login a user.""" + pass + + @abstractmethod + def authenticate(self, token: str) -> bool: + """Authenticate a user.""" + pass diff --git a/server/auth/local/__init__.py b/server/auth/local/__init__.py new file mode 100644 index 0000000..08c1d4b --- /dev/null +++ b/server/auth/local/__init__.py @@ -0,0 +1 @@ +from .local import LocalAuth # noqa diff --git a/server/auth/local/local.py b/server/auth/local/local.py new file mode 100644 index 0000000..6c2eba4 --- /dev/null +++ b/server/auth/local/local.py @@ -0,0 +1,124 @@ +import secrets +from base64 import b32encode +from datetime import datetime, timedelta + +import pyotp +from fastapi import Depends, HTTPException, Request +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from qrcode import QRCode + +from global_config import GlobalConfig +from helpers import get_env + +from ..base import BaseAuth +from ..models import Login, Token + +global_config = GlobalConfig() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token", auto_error=False) + + +class LocalAuth(BaseAuth): + JWT_ALGORITHM = "HS256" + + def __init__(self) -> None: + self.username = get_env("FLATNOTES_USERNAME", mandatory=True).lower() + self.password = get_env("FLATNOTES_PASSWORD", mandatory=True) + self.secret_key = get_env("FLATNOTES_SECRET_KEY", mandatory=True) + self.session_expiry_days = get_env( + "FLATNOTES_SESSION_EXPIRY_DAYS", default=30, cast_int=True + ) + + # TOTP + if global_config.auth_type == "totp": + self.is_totp_enabled = True + self.totp_key = get_env("FLATNOTES_TOTP_KEY", mandatory=True) + self.totp_key = b32encode(self.totp_key.encode("utf-8")) + self.totp = pyotp.TOTP(self.totp_key) + self.last_used_totp = None + self._display_totp_enrolment() + + def login(self, data: Login) -> Token: + # Check Username + username_correct = secrets.compare_digest( + self.username.lower(), data.username.lower() + ) + + # Check Password & TOTP + expected_password = self.password + if self.is_totp_enabled: + current_totp = self.totp.now() + expected_password += current_totp + password_correct = secrets.compare_digest( + expected_password, data.password + ) + + # Raise error if incorrect + if not ( + username_correct + and password_correct + # Prevent TOTP from being reused + and ( + self.is_totp_enabled is False + or current_totp != self.last_used_totp + ) + ): + raise ValueError("Incorrect login credentials.") + if self.is_totp_enabled: + self.last_used_totp = current_totp + + # Create Token + access_token = self._create_access_token(data={"sub": self.username}) + return Token(access_token=access_token) + + def authenticate( + self, request: Request, token: str = Depends(oauth2_scheme) + ): + # If no token is found in the header, check the cookies + if token is None: + token = request.cookies.get("token") + # Validate the token + try: + self._validate_token(token) + except (JWTError, ValueError): + raise HTTPException( + status_code=401, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + def _validate_token(self, token: str) -> bool: + if token is None: + raise ValueError + payload = jwt.decode( + token, self.secret_key, algorithms=[self.JWT_ALGORITHM] + ) + username = payload.get("sub") + if username is None or username.lower() != self.username: + raise ValueError + + def _create_access_token(self, data: dict): + to_encode = data.copy() + expiry_datetime = datetime.utcnow() + timedelta( + days=self.session_expiry_days + ) + to_encode.update({"exp": expiry_datetime}) + encoded_jwt = jwt.encode( + to_encode, self.secret_key, algorithm=self.JWT_ALGORITHM + ) + return encoded_jwt + + def _display_totp_enrolment(self): + uri = self.totp.provisioning_uri( + issuer_name="flatnotes", name=self.username + ) + qr = QRCode() + qr.add_data(uri) + print( + "\nScan this QR code with your TOTP app of choice", + "e.g. Authy or Google Authenticator:", + ) + qr.print_ascii() + print( + f"Or manually enter this key: {self.totp.secret.decode('utf-8')}\n" + ) diff --git a/server/auth/models.py b/server/auth/models.py new file mode 100644 index 0000000..79fc69c --- /dev/null +++ b/server/auth/models.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel, Field + +from helpers import CustomBaseModel + + +class Login(CustomBaseModel): + username: str + password: str + + +class Token(BaseModel): + # Note: OAuth requires keys to be snake_case so we use the standard + # BaseModel here + access_token: str + token_type: str = Field("bearer") diff --git a/server/config.py b/server/config.py deleted file mode 100644 index 4955205..0000000 --- a/server/config.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import sys -from base64 import b32encode -from enum import Enum - -from logger import logger - - -class AuthType(str, Enum): - NONE = "none" - READ_ONLY = "read_only" - PASSWORD = "password" - TOTP = "totp" - - -class Config: - def __init__(self) -> None: - self.data_path = self.get_data_path() - - self.auth_type = self.get_auth_type() - - self.username = self.get_username() - self.password = self.get_password() - - self.session_key = self.get_session_key() - self.session_expiry_days = self.get_session_expiry_days() - - self.totp_key = self.get_totp_key() - - @classmethod - def get_env(cls, key, mandatory=False, default=None, cast_int=False): - """Get an environment variable.""" - value = os.environ.get(key) - if mandatory and not value: - logger.error(f"Environment variable {key} must be set.") - sys.exit(1) - if not mandatory and not value: - return default - if cast_int: - try: - value = int(value) - except (TypeError, ValueError): - logger.error(f"Invalid value '{value}' for {key}.") - sys.exit(1) - return value - - def get_data_path(self): - return self.get_env("FLATNOTES_PATH", mandatory=True) - - def get_auth_type(self): - key = "FLATNOTES_AUTH_TYPE" - auth_type = self.get_env( - key, mandatory=False, default=AuthType.PASSWORD.value - ) - try: - auth_type = AuthType(auth_type.lower()) - except ValueError: - logger.error( - f"Invalid value '{auth_type}' for {key}. " - + "Must be one of: " - + ", ".join([auth_type.value for auth_type in AuthType]) - + "." - ) - sys.exit(1) - return auth_type - - def get_username(self): - return self.get_env( - "FLATNOTES_USERNAME", - mandatory=self.auth_type - not in [AuthType.NONE, AuthType.READ_ONLY], - ) - - def get_password(self): - return self.get_env( - "FLATNOTES_PASSWORD", - mandatory=self.auth_type - not in [AuthType.NONE, AuthType.READ_ONLY], - ) - - def get_session_key(self): - return self.get_env( - "FLATNOTES_SECRET_KEY", - mandatory=self.auth_type - not in [AuthType.NONE, AuthType.READ_ONLY], - ) - - def get_session_expiry_days(self): - return self.get_env( - "FLATNOTES_SESSION_EXPIRY_DAYS", - mandatory=False, - default=30, - cast_int=True, - ) - - def get_totp_key(self): - totp_key = self.get_env( - "FLATNOTES_TOTP_KEY", mandatory=self.auth_type == AuthType.TOTP - ) - if totp_key: - totp_key = b32encode(totp_key.encode("utf-8")) - return totp_key - - -config = Config() diff --git a/server/error_responses.py b/server/error_responses.py deleted file mode 100644 index 2e86a57..0000000 --- a/server/error_responses.py +++ /dev/null @@ -1,15 +0,0 @@ -from fastapi.responses import JSONResponse - -filename_exists_response = JSONResponse( - content={"message": "The specified filename already exists."}, - status_code=409, -) - -invalid_filename_response = JSONResponse( - content={"message": "The specified filename contains invalid characters."}, - status_code=400, -) - -note_not_found_response = JSONResponse( - content={"message": "The specified note cannot be found."}, status_code=404 -) diff --git a/server/global_config.py b/server/global_config.py new file mode 100644 index 0000000..fa2429b --- /dev/null +++ b/server/global_config.py @@ -0,0 +1,59 @@ +import sys +from enum import Enum + +from helpers import CustomBaseModel, get_env +from logger import logger + + +class GlobalConfig: + def __init__(self) -> None: + logger.debug("Loading global config...") + self.auth_type: AuthType = self._load_auth_type() + + def load_auth(self): + if self.auth_type in (AuthType.NONE, AuthType.READ_ONLY): + return None + elif self.auth_type in (AuthType.PASSWORD, AuthType.TOTP): + from auth.local import LocalAuth + + return LocalAuth() + + def load_note_storage(self): + from notes.file_system import FileSystemNotes + + return FileSystemNotes() + + def load_attachment_storage(self): + from attachments.file_system import ( + FileSystemAttachments, + ) + + return FileSystemAttachments() + + def _load_auth_type(self): + key = "FLATNOTES_AUTH_TYPE" + auth_type = get_env( + key, mandatory=False, default=AuthType.PASSWORD.value + ) + try: + auth_type = AuthType(auth_type.lower()) + except ValueError: + logger.error( + f"Invalid value '{auth_type}' for {key}. " + + "Must be one of: " + + ", ".join([auth_type.value for auth_type in AuthType]) + + "." + ) + sys.exit(1) + return auth_type + + +class AuthType(str, Enum): + NONE = "none" + READ_ONLY = "read_only" + PASSWORD = "password" + TOTP = "totp" + + +class GlobalConfigResponseModel(CustomBaseModel): + auth_type: AuthType diff --git a/server/helpers.py b/server/helpers.py index 06f18fb..645f739 100644 --- a/server/helpers.py +++ b/server/helpers.py @@ -1,11 +1,9 @@ import os -import re -import shutil -from typing import List, Tuple +import sys +from pydantic import BaseModel -def strip_ext(filename): - return os.path.splitext(filename)[0] +from logger import logger def camel_case(snake_case_str: str) -> str: @@ -14,27 +12,44 @@ def camel_case(snake_case_str: str) -> str: return parts[0] + "".join(part.title() for part in parts[1:]) -def empty_dir(path): - for item in os.listdir(path): - item_path = os.path.join(path, item) - if os.path.isfile(item_path): - os.remove(item_path) - elif os.path.isdir(item_path): - shutil.rmtree(item_path) - - -def re_extract(pattern, string) -> Tuple[str, List[str]]: - """Similar to re.sub but returns a tuple of: - - - `string` with matches removed - - list of matches""" - matches = [] - text = re.sub(pattern, lambda tag: matches.append(tag.group()), string) - return (text, matches) - - -def is_valid_filename(filename): - r"""Return False if the declared filename contains any of the following - characters: <>:"/\|?*""" +def is_valid_filename(value): + """Raise ValueError if the declared string contains any of the following + characters: <>:"/\\|?*""" invalid_chars = r'<>:"/\|?*' - return not any(invalid_char in filename for invalid_char in invalid_chars) + if any(invalid_char in value for invalid_char in invalid_chars): + raise ValueError( + "title cannot include any of the following characters: " + + invalid_chars + ) + return value + + +def strip_whitespace(value): + """Return the declared string with leading and trailing whitespace + removed.""" + return value.strip() + + +def get_env(key, mandatory=False, default=None, cast_int=False): + """Get an environment variable. If `mandatory` is True and environment + variable isn't set, exit the program""" + value = os.environ.get(key) + if mandatory and not value: + logger.error(f"Environment variable {key} must be set.") + sys.exit(1) + if not mandatory and not value: + return default + if cast_int: + try: + value = int(value) + except (TypeError, ValueError): + logger.error(f"Invalid value '{value}' for {key}.") + sys.exit(1) + return value + + +class CustomBaseModel(BaseModel): + class Config: + alias_generator = camel_case + populate_by_name = True + from_attributes = True diff --git a/server/main.py b/server/main.py index 7e93c3e..99a8eff 100644 --- a/server/main.py +++ b/server/main.py @@ -1,99 +1,26 @@ -import os -import secrets -import shutil -from typing import List, Literal, Union +from typing import List, Literal -import pyotp from fastapi import Depends, FastAPI, HTTPException, UploadFile -from fastapi.responses import FileResponse, HTMLResponse +from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles -from qrcode import QRCode - -from auth import create_access_token, no_auth, validate_token -from config import AuthType, config -from error_responses import ( - filename_exists_response, - invalid_filename_response, - note_not_found_response, -) -from flatnotes import Flatnotes, InvalidTitleError, Note -from helpers import is_valid_filename -from models import ( - ConfigModel, - LoginModel, - NoteContentResponseModel, - NotePatchModel, - NotePostModel, - NoteResponseModel, - SearchResultModel, - TokenModel, -) - -ATTACHMENTS_DIR = os.path.join(config.data_path, "attachments") -os.makedirs(ATTACHMENTS_DIR, exist_ok=True) +import api_messages +from attachments.base import BaseAttachments +from auth.base import BaseAuth +from auth.models import Login, Token +from global_config import AuthType, GlobalConfig, GlobalConfigResponseModel +from notes.base import BaseNotes +from notes.models import Note, NoteCreate, NoteUpdate, SearchResult + +global_config = GlobalConfig() +auth: BaseAuth = global_config.load_auth() +note_storage: BaseNotes = global_config.load_note_storage() +attachment_storage: BaseAttachments = global_config.load_attachment_storage() +auth_deps = [Depends(auth.authenticate)] if auth else [] app = FastAPI() -flatnotes = Flatnotes(config.data_path) - -totp = ( - pyotp.TOTP(config.totp_key) if config.auth_type == AuthType.TOTP else None -) -last_used_totp = None - -if config.auth_type in [AuthType.NONE, AuthType.READ_ONLY]: - authenticate = no_auth -else: - authenticate = validate_token - -# Display TOTP QR code -if config.auth_type == AuthType.TOTP: - uri = totp.provisioning_uri(issuer_name="flatnotes", name=config.username) - qr = QRCode() - qr.add_data(uri) - print( - "\nScan this QR code with your TOTP app of choice", - "e.g. Authy or Google Authenticator:", - ) - qr.print_ascii() - print(f"Or manually enter this key: {totp.secret.decode('utf-8')}\n") - -if config.auth_type not in [AuthType.NONE, AuthType.READ_ONLY]: - - @app.post("/api/token", response_model=TokenModel) - def token(data: LoginModel): - global last_used_totp - - username_correct = secrets.compare_digest( - config.username.lower(), data.username.lower() - ) - - expected_password = config.password - if config.auth_type == AuthType.TOTP: - current_totp = totp.now() - expected_password += current_totp - password_correct = secrets.compare_digest( - expected_password, data.password - ) - - if not ( - username_correct - and password_correct - # Prevent TOTP from being reused - and ( - config.auth_type != AuthType.TOTP - or current_totp != last_used_totp - ) - ): - raise HTTPException( - status_code=400, detail="Incorrect login credentials." - ) - - access_token = create_access_token(data={"sub": config.username}) - if config.auth_type == AuthType.TOTP: - last_used_totp = current_totp - return TokenModel(access_token=access_token) +# region UI @app.get("/", include_in_schema=False) @app.get("/login", include_in_schema=False) @app.get("/search", include_in_schema=False) @@ -105,101 +32,113 @@ def root(title: str = ""): return HTMLResponse(content=html) -if config.auth_type != AuthType.READ_ONLY: +# endregion - @app.post( - "/api/notes", - dependencies=[Depends(authenticate)], - response_model=NoteContentResponseModel, - ) - def post_note(data: NotePostModel): - """Create a new note.""" + +# region Login +if global_config.auth_type not in [AuthType.NONE, AuthType.READ_ONLY]: + + @app.post("/api/token", response_model=Token) + def token(data: Login): try: - note = Note(flatnotes, data.title, new=True) - note.content = data.content - return NoteContentResponseModel.model_validate(note) - except InvalidTitleError: - return invalid_filename_response - except FileExistsError: - return filename_exists_response + return auth.login(data) + except ValueError: + raise HTTPException( + status_code=401, detail=api_messages.login_failed + ) +# endregion + + +# region Notes +# Get Note @app.get( "/api/notes/{title}", - dependencies=[Depends(authenticate)], - response_model=Union[NoteContentResponseModel, NoteResponseModel], + dependencies=auth_deps, + response_model=Note, ) -def get_note( - title: str, - include_content: bool = True, -): +def get_note(title: str): """Get a specific note.""" try: - note = Note(flatnotes, title) - if include_content: - return NoteContentResponseModel.model_validate(note) - else: - return NoteResponseModel.model_validate(note) - except InvalidTitleError: - return invalid_filename_response + return note_storage.get(title) + except ValueError: + raise HTTPException( + status_code=400, detail=api_messages.invalid_note_title + ) except FileNotFoundError: - return note_not_found_response + raise HTTPException(404, api_messages.note_not_found) -if config.auth_type != AuthType.READ_ONLY: +if global_config.auth_type != AuthType.READ_ONLY: + # Create Note + @app.post( + "/api/notes", + dependencies=auth_deps, + response_model=Note, + ) + def post_note(note: NoteCreate): + """Create a new note.""" + try: + return note_storage.create(note) + except ValueError: + raise HTTPException( + status_code=400, + detail=api_messages.invalid_note_title, + ) + except FileExistsError: + raise HTTPException( + status_code=409, detail=api_messages.note_exists + ) + + # Update Note @app.patch( "/api/notes/{title}", - dependencies=[Depends(authenticate)], - response_model=NoteContentResponseModel, + dependencies=auth_deps, + response_model=Note, ) - def patch_note(title: str, new_data: NotePatchModel): + def patch_note(title: str, data: NoteUpdate): try: - note = Note(flatnotes, title) - if new_data.new_title is not None: - note.title = new_data.new_title - if new_data.new_content is not None: - note.content = new_data.new_content - return NoteContentResponseModel.model_validate(note) - except InvalidTitleError: - return invalid_filename_response + return note_storage.update(title, data) + except ValueError: + raise HTTPException( + status_code=400, + detail=api_messages.invalid_note_title, + ) except FileExistsError: - return filename_exists_response + raise HTTPException( + status_code=409, detail=api_messages.note_exists + ) except FileNotFoundError: - return note_not_found_response - - -if config.auth_type != AuthType.READ_ONLY: + raise HTTPException(404, api_messages.note_not_found) + # Delete Note @app.delete( "/api/notes/{title}", - dependencies=[Depends(authenticate)], + dependencies=auth_deps, response_model=None, ) def delete_note(title: str): try: - note = Note(flatnotes, title) - note.delete() - except InvalidTitleError: - return invalid_filename_response + note_storage.delete(title) + except ValueError: + raise HTTPException( + status_code=400, + detail=api_messages.invalid_note_title, + ) except FileNotFoundError: - return note_not_found_response + raise HTTPException(404, api_messages.note_not_found) -@app.get( - "/api/tags", - dependencies=[Depends(authenticate)], - response_model=List[str], -) -def get_tags(): - """Get a list of all indexed tags.""" - return flatnotes.get_tags() +# endregion +# region Search @app.get( "/api/search", - dependencies=[Depends(authenticate)], - response_model=List[SearchResultModel], + dependencies=auth_deps, + response_model=List[SearchResult], ) def search( term: str, @@ -210,51 +149,75 @@ def search( """Perform a full text search on all notes.""" if sort == "lastModified": sort = "last_modified" - return [ - SearchResultModel.model_validate(note_hit) - for note_hit in flatnotes.search( - term, sort=sort, order=order, limit=limit - ) - ] + return note_storage.search(term, sort=sort, order=order, limit=limit) -@app.get("/api/config", response_model=ConfigModel) +@app.get( + "/api/tags", + dependencies=auth_deps, + response_model=List[str], +) +def get_tags(): + """Get a list of all indexed tags.""" + return note_storage.get_tags() + + +# endregion + + +# region Config +@app.get("/api/config", response_model=GlobalConfigResponseModel) def get_config(): """Retrieve server-side config required for the UI.""" - return ConfigModel.model_validate(config) + return GlobalConfigResponseModel(auth_type=global_config.auth_type) -if config.auth_type != AuthType.READ_ONLY: - - @app.post( - "/api/attachments", - dependencies=[Depends(authenticate)], - response_model=None, - ) - def post_attachment(file: UploadFile): - """Upload an attachment.""" - if not is_valid_filename(file.filename): - return invalid_filename_response - filepath = os.path.join(ATTACHMENTS_DIR, file.filename) - if os.path.exists(filepath): - return filename_exists_response - with open(filepath, "wb") as f: - shutil.copyfileobj(file.file, f) +# endregion +# region Attachments +# Get Attachment @app.get( "/attachments/{filename}", - dependencies=[Depends(authenticate)], + dependencies=auth_deps, include_in_schema=False, ) def get_attachment(filename: str): """Download an attachment.""" - if not is_valid_filename(filename): - raise HTTPException(status_code=400, detail="Invalid filename.") - filepath = os.path.join(ATTACHMENTS_DIR, filename) - if not os.path.isfile(filepath): - raise HTTPException(status_code=404, detail="File not found.") - return FileResponse(filepath) + try: + return attachment_storage.get(filename) + except ValueError: + raise HTTPException( + status_code=400, + detail=api_messages.invalid_attachment_filename, + ) + except FileNotFoundError: + raise HTTPException( + status_code=404, detail=api_messages.attachment_not_found + ) + + +if global_config.auth_type != AuthType.READ_ONLY: + + # Create Attachment + @app.post( + "/api/attachments", + dependencies=auth_deps, + response_model=None, + ) + def post_attachment(file: UploadFile): + """Upload an attachment.""" + try: + return attachment_storage.create(file) + except ValueError: + raise HTTPException( + status_code=400, + detail=api_messages.invalid_attachment_filename, + ) + except FileExistsError: + raise HTTPException(409, api_messages.attachment_exists) + +# endregion app.mount("/", StaticFiles(directory="client/dist"), name="dist") diff --git a/server/models.py b/server/models.py deleted file mode 100644 index b87d8bd..0000000 --- a/server/models.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import List, Optional - -from pydantic import BaseModel, Field - -from config import AuthType -from helpers import camel_case - - -class TokenModel(BaseModel): - # Use of BaseModel instead of CustomBaseModel is intentional as OAuth - # requires keys to be snake_case - access_token: str - token_type: str = Field("bearer") - - -class CustomBaseModel(BaseModel): - class Config: - alias_generator = camel_case - populate_by_name = True - from_attributes = True - - -class LoginModel(CustomBaseModel): - username: str - password: str - - -class NotePostModel(CustomBaseModel): - title: str - content: Optional[str] = Field(None) - - -class NoteResponseModel(CustomBaseModel): - title: str - last_modified: float - - -class NoteContentResponseModel(NoteResponseModel): - content: Optional[str] = Field(None) - - -class NotePatchModel(CustomBaseModel): - new_title: Optional[str] = Field(None) - new_content: Optional[str] = Field(None) - - -class SearchResultModel(CustomBaseModel): - score: Optional[float] = Field(None) - title: str - last_modified: float - title_highlights: Optional[str] = Field(None) - content_highlights: Optional[str] = Field(None) - tag_matches: Optional[List[str]] = Field(None) - - -class ConfigModel(CustomBaseModel): - auth_type: AuthType diff --git a/server/notes/base.py b/server/notes/base.py new file mode 100644 index 0000000..ad97398 --- /dev/null +++ b/server/notes/base.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Literal + +from .models import Note, NoteCreate, NoteUpdate, SearchResult + + +class BaseNotes(ABC): + @abstractmethod + def create(self, data: NoteCreate) -> Note: + """Create a new note.""" + pass + + @abstractmethod + def get(self, title: str) -> Note: + """Get a specific note.""" + pass + + @abstractmethod + def update(self, title: str, new_data: NoteUpdate) -> Note: + """Update a specific note.""" + pass + + @abstractmethod + def delete(self, title: str) -> None: + """Delete a specific note.""" "" + pass + + @abstractmethod + def search( + self, + term: str, + sort: Literal["score", "title", "last_modified"] = "score", + order: Literal["asc", "desc"] = "desc", + limit: int = None, + ) -> list[SearchResult]: + """Search for notes.""" + pass + + @abstractmethod + def get_tags(self) -> list[str]: + """Get a list of all indexed tags.""" + pass diff --git a/server/notes/file_system/__init__.py b/server/notes/file_system/__init__.py new file mode 100644 index 0000000..2815b67 --- /dev/null +++ b/server/notes/file_system/__init__.py @@ -0,0 +1 @@ +from .file_system import FileSystemNotes diff --git a/server/flatnotes.py b/server/notes/file_system/file_system.py similarity index 51% rename from server/flatnotes.py rename to server/notes/file_system/file_system.py index 526ad32..0533120 100644 --- a/server/flatnotes.py +++ b/server/notes/file_system/file_system.py @@ -1,6 +1,7 @@ import glob import os import re +import shutil from datetime import datetime from typing import List, Literal, Set, Tuple @@ -16,9 +17,12 @@ from whoosh.query import Every from whoosh.searching import Hit from whoosh.support.charset import accent_map -from helpers import empty_dir, is_valid_filename, re_extract, strip_ext +from helpers import get_env, is_valid_filename from logger import logger +from ..base import BaseNotes +from ..models import Note, NoteCreate, NoteUpdate, SearchResult + MARKDOWN_EXT = ".md" INDEX_SCHEMA_VERSION = "4" @@ -35,174 +39,159 @@ class IndexSchema(SchemaClass): tags = KEYWORD(lowercase=True, field_boost=2.0) -class InvalidTitleError(Exception): - def __init__(self, message="The specified title is invalid"): - self.message = message - super().__init__(self.message) - - -class Note: - def __init__( - self, flatnotes: "Flatnotes", title: str, new: bool = False - ) -> None: - self._flatnotes = flatnotes - self._title = title.strip() - if not is_valid_filename(self._title): - raise InvalidTitleError - exists = os.path.exists(self.filepath) - if new and exists: - raise FileExistsError - if not new and not exists: - raise FileNotFoundError - if new: - open(self.filepath, "w").close() - - @property - def filepath(self): - return os.path.join(self._flatnotes.dir, self.filename) - - @property - def filename(self): - return self._title + MARKDOWN_EXT - - @property - def last_modified(self): - return os.path.getmtime(self.filepath) +class FileSystemNotes(BaseNotes): + TAGS_RE = re.compile(r"(?:(?<=^#)|(?<=\s#))\w+(?=\s|$)") + CODEBLOCK_RE = re.compile(r"`{1,3}.*?`{1,3}", re.DOTALL) + TAGS_WITH_HASH_RE = re.compile(r"(?:(?<=^)|(?<=\s))#\w+(?=\s|$)") - # Editable Properties - @property - def title(self): - return self._title - - @title.setter - def title(self, new_title): - new_title = new_title.strip() - if not is_valid_filename(new_title): - raise InvalidTitleError - new_filepath = os.path.join( - self._flatnotes.dir, new_title + MARKDOWN_EXT + def __init__(self): + self.storage_path = get_env("FLATNOTES_PATH", mandatory=True) + if not os.path.exists(self.storage_path): + raise NotADirectoryError( + f"'{self.storage_path}' is not a valid directory." + ) + self.index = self._load_index() + self._sync_index() + + def create(self, data: NoteCreate) -> Note: + """Create a new note.""" + filepath = self._path_from_title(data.title) + self._write_file(filepath, data.content) + return Note( + title=data.title, + content=data.content, + last_modified=os.path.getmtime(filepath), ) - os.rename(self.filepath, new_filepath) - self._title = new_title - - @property - def content(self): - with open(self.filepath, "r", encoding="utf-8") as f: - return f.read() - - @content.setter - def content(self, new_content): - if not os.path.exists(self.filepath): - raise FileNotFoundError - with open(self.filepath, "w", encoding="utf-8") as f: - f.write(new_content) - - def delete(self): - os.remove(self.filepath) - - -class SearchResult(Note): - def __init__(self, flatnotes: "Flatnotes", hit: Hit) -> None: - super().__init__(flatnotes, strip_ext(hit["filename"])) - self._matched_fields = self._get_matched_fields(hit.matched_terms()) - # If the search was ordered using a text field then hit.score is the - # value of that field. This isn't useful so only set self._score if it - # is a float. - self._score = hit.score if type(hit.score) is float else None - - if "title" in self._matched_fields: - hit.results.fragmenter = WholeFragmenter() - self._title_highlights = hit.highlights("title", text=self.title) - else: - self._title_highlights = None + def get(self, title: str) -> Note: + """Get a specific note.""" + is_valid_filename(title) + filepath = self._path_from_title(title) + content = self._read_file(filepath) + return Note( + title=title, + content=content, + last_modified=os.path.getmtime(filepath), + ) - if "content" in self._matched_fields: - hit.results.fragmenter = ContextFragmenter() - content_ex_tags, _ = Flatnotes.extract_tags(self.content) - self._content_highlights = hit.highlights( - "content", - text=content_ex_tags, - ) + def update(self, title: str, data: NoteUpdate) -> Note: + """Update a specific note.""" + is_valid_filename(title) + filepath = self._path_from_title(title) + if data.new_title is not None: + new_filepath = self._path_from_title(data.new_title) + os.rename(filepath, new_filepath) + title = data.new_title + filepath = new_filepath + if data.new_content is not None: + self._write_file(filepath, data.new_content, overwrite=True) + content = data.new_content else: - self._content_highlights = None - - self._tag_matches = ( - [field[1] for field in hit.matched_terms() if field[0] == "tags"] - if "tags" in self._matched_fields - else None + content = self._read_file(filepath) + return Note( + title=title, + content=content, + last_modified=os.path.getmtime(filepath), ) - @property - def score(self): - return self._score - - @property - def title_highlights(self): - return self._title_highlights + def delete(self, title: str) -> None: + """Delete a specific note.""" + is_valid_filename(title) + filepath = self._path_from_title(title) + os.remove(filepath) - @property - def content_highlights(self): - return self._content_highlights + def search( + self, + term: str, + sort: Literal["score", "title", "last_modified"] = "score", + order: Literal["asc", "desc"] = "desc", + limit: int = None, + ) -> Tuple[SearchResult, ...]: + """Search the index for the given term.""" + self._sync_index() + term = self._pre_process_search_term(term) + with self.index.searcher() as searcher: + # Parse Query + if term == "*": + query = Every() + else: + parser = MultifieldParser( + ["title", "content", "tags"], self.index.schema + ) + parser.add_plugin(DateParserPlugin()) + query = parser.parse(term) - @property - def tag_matches(self): - return self._tag_matches + # Determine Sort By + # Note: For the 'sort' option, "score" is converted to None as + # that is the default for searches anyway and it's quicker for + # Whoosh if you specify None. + sort = sort if sort in ["title", "last_modified"] else None - @staticmethod - def _get_matched_fields(matched_terms): - """Return a set of matched fields from a set of ('field', 'term') " - "tuples generated by whoosh.searching.Hit.matched_terms().""" - return set([matched_term[0] for matched_term in matched_terms]) + # Determine Sort Direction + # Note: Confusingly, when sorting by 'score', reverse = True means + # asc so we have to flip the logic for that case! + reverse = order == "desc" + if sort is None: + reverse = not reverse + # Run Search + results = searcher.search( + query, + sortedby=sort, + reverse=reverse, + limit=limit, + terms=True, + ) + return tuple(self._search_result_from_hit(hit) for hit in results) -class Flatnotes(object): - TAGS_RE = re.compile(r"(?:(?<=^#)|(?<=\s#))\w+(?=\s|$)") - CODEBLOCK_RE = re.compile(r"`{1,3}.*?`{1,3}", re.DOTALL) - TAGS_WITH_HASH_RE = re.compile(r"(?:(?<=^)|(?<=\s))#\w+(?=\s|$)") + def get_tags(self) -> list[str]: + """Return a list of all indexed tags.""" + self._sync_index() + with self.index.reader() as reader: + tags = reader.field_terms("tags") + return [tag for tag in tags] - def __init__(self, dir: str) -> None: - if not os.path.exists(dir): - raise NotADirectoryError(f"'{dir}' is not a valid directory.") - self.dir = dir + @property + def _index_path(self): + return os.path.join(self.storage_path, ".flatnotes") - self.index = self._load_index() - self.update_index() + def _path_from_title(self, title: str) -> str: + return os.path.join(self.storage_path, title + MARKDOWN_EXT) - @property - def index_dir(self): - return os.path.join(self.dir, ".flatnotes") + def _get_by_filename(self, filename: str) -> Note: + """Get a note by its filename.""" + return self.get(self._strip_ext(filename)) def _load_index(self) -> Index: """Load the note index or create new if not exists.""" - index_dir_exists = os.path.exists(self.index_dir) + index_dir_exists = os.path.exists(self._index_path) if index_dir_exists and whoosh.index.exists_in( - self.index_dir, indexname=INDEX_SCHEMA_VERSION + self._index_path, indexname=INDEX_SCHEMA_VERSION ): logger.info("Loading existing index") return whoosh.index.open_dir( - self.index_dir, indexname=INDEX_SCHEMA_VERSION + self._index_path, indexname=INDEX_SCHEMA_VERSION ) else: if index_dir_exists: logger.info("Deleting outdated index") - empty_dir(self.index_dir) + self._clear_dir(self._index_path) else: - os.mkdir(self.index_dir) + os.mkdir(self._index_path) logger.info("Creating new index") return whoosh.index.create_in( - self.index_dir, IndexSchema, indexname=INDEX_SCHEMA_VERSION + self._index_path, IndexSchema, indexname=INDEX_SCHEMA_VERSION ) @classmethod - def extract_tags(cls, content) -> Tuple[str, Set[str]]: + def _extract_tags(cls, content) -> Tuple[str, Set[str]]: """Strip tags from the given content and return a tuple consisting of: - The content without the tags. - A set of tags converted to lowercase.""" content_ex_codeblock = re.sub(cls.CODEBLOCK_RE, "", content) - _, tags = re_extract(cls.TAGS_RE, content_ex_codeblock) - content_ex_tags, _ = re_extract(cls.TAGS_RE, content) + _, tags = cls._re_extract(cls.TAGS_RE, content_ex_codeblock) + content_ex_tags, _ = cls._re_extract(cls.TAGS_RE, content) try: tags = [tag.lower() for tag in tags] return (content_ex_tags, set(tags)) @@ -215,27 +204,26 @@ class Flatnotes(object): """Add a Note object to the index using the given writer. If the filename already exists in the index an update will be performed instead.""" - content_ex_tags, tag_set = self.extract_tags(note.content) + content_ex_tags, tag_set = self._extract_tags(note.content) tag_string = " ".join(tag_set) writer.update_document( - filename=note.filename, + filename=note.title + MARKDOWN_EXT, last_modified=datetime.fromtimestamp(note.last_modified), title=note.title, content=content_ex_tags, tags=tag_string, ) - def _get_notes(self) -> List[Note]: - """Return a list containing a Note object for every file in the notes - directory.""" + def _list_all_note_filenames(self) -> List[str]: + """Return a list of all note filenames.""" return [ - Note(self, strip_ext(os.path.split(filepath)[1])) + os.path.split(filepath)[1] for filepath in glob.glob( - os.path.join(self.dir, "*" + MARKDOWN_EXT) + os.path.join(self.storage_path, "*" + MARKDOWN_EXT) ) ] - def update_index(self, clean: bool = False) -> None: + def _sync_index(self, clean: bool = False) -> None: """Synchronize the index with the notes directory. Specify clean=True to completely rebuild the index""" indexed = set() @@ -245,7 +233,7 @@ class Flatnotes(object): with self.index.searcher() as searcher: for idx_note in searcher.all_stored_fields(): idx_filename = idx_note["filename"] - idx_filepath = os.path.join(self.dir, idx_filename) + idx_filepath = os.path.join(self.storage_path, idx_filename) # Delete missing if not os.path.exists(idx_filepath): writer.delete_by_term("filename", idx_filename) @@ -257,76 +245,115 @@ class Flatnotes(object): ): logger.info(f"'{idx_filename}' updated") self._add_note_to_index( - writer, Note(self, strip_ext(idx_filename)) + writer, self._get_by_filename(idx_filename) ) indexed.add(idx_filename) # Ignore already indexed else: indexed.add(idx_filename) # Add new - for note in self._get_notes(): - if note.filename not in indexed: - self._add_note_to_index(writer, note) - logger.info(f"'{note.filename}' added to index") + for filename in self._list_all_note_filenames(): + if filename not in indexed: + self._add_note_to_index( + writer, self._get_by_filename(filename) + ) + logger.info(f"'{filename}' added to index") writer.commit() - def get_tags(self): - """Return a list of all indexed tags.""" - self.update_index() - with self.index.reader() as reader: - tags = reader.field_terms("tags") - return [tag for tag in tags] - - def pre_process_search_term(self, term): + @classmethod + def _pre_process_search_term(cls, term): term = term.strip() # Replace "#tagname" with "tags:tagname" term = re.sub( - self.TAGS_WITH_HASH_RE, + cls.TAGS_WITH_HASH_RE, lambda tag: "tags:" + tag.group(0)[1:], term, ) return term - def search( - self, - term: str, - sort: Literal["score", "title", "last_modified"] = "score", - order: Literal["asc", "desc"] = "desc", - limit: int = None, - ) -> Tuple[SearchResult, ...]: - """Search the index for the given term.""" - self.update_index() - term = self.pre_process_search_term(term) - with self.index.searcher() as searcher: - # Parse Query - if term == "*": - query = Every() - else: - parser = MultifieldParser( - ["title", "content", "tags"], self.index.schema - ) - parser.add_plugin(DateParserPlugin()) - query = parser.parse(term) + @staticmethod + def _re_extract(pattern, string) -> Tuple[str, List[str]]: + """Similar to re.sub but returns a tuple of: - # Determine Sort By - # Note: For the 'sort' option, "score" is converted to None as - # that is the default for searches anyway and it's quicker for - # Whoosh if you specify None. - sort = sort if sort in ["title", "last_modified"] else None + - `string` with matches removed + - list of matches""" + matches = [] + text = re.sub(pattern, lambda tag: matches.append(tag.group()), string) + return (text, matches) - # Determine Sort Direction - # Note: Confusingly, when sorting by 'score', reverse = True means - # asc so we have to flip the logic for that case! - reverse = order == "desc" - if sort is None: - reverse = not reverse + @staticmethod + def _strip_ext(filename): + """Return the given filename without the extension.""" + return os.path.splitext(filename)[0] - # Run Search - results = searcher.search( - query, - sortedby=sort, - reverse=reverse, - limit=limit, - terms=True, + @staticmethod + def _clear_dir(path): + """Delete all contents of the given directory.""" + for item in os.listdir(path): + item_path = os.path.join(path, item) + if os.path.isfile(item_path): + os.remove(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + + def _search_result_from_hit(self, hit: Hit): + matched_fields = self._get_matched_fields(hit.matched_terms()) + + title = self._strip_ext(hit["filename"]) + last_modified = hit["last_modified"].timestamp() + + # If the search was ordered using a text field then hit.score is the + # value of that field. This isn't useful so only set self._score if it + # is a float. + score = hit.score if type(hit.score) is float else None + + if "title" in matched_fields: + hit.results.fragmenter = WholeFragmenter() + title_highlights = hit.highlights("title", text=title) + else: + title_highlights = None + + if "content" in matched_fields: + hit.results.fragmenter = ContextFragmenter() + content = self._read_file(self._path_from_title(title)) + content_ex_tags, _ = FileSystemNotes._extract_tags(content) + content_highlights = hit.highlights( + "content", + text=content_ex_tags, ) - return tuple(SearchResult(self, hit) for hit in results) + else: + content_highlights = None + + tag_matches = ( + [field[1] for field in hit.matched_terms() if field[0] == "tags"] + if "tags" in matched_fields + else None + ) + + return SearchResult( + title=title, + last_modified=last_modified, + score=score, + title_highlights=title_highlights, + content_highlights=content_highlights, + tag_matches=tag_matches, + ) + + @staticmethod + def _get_matched_fields(matched_terms): + """Return a set of matched fields from a set of ('field', 'term') " + "tuples generated by whoosh.searching.Hit.matched_terms().""" + return set([matched_term[0] for matched_term in matched_terms]) + + @staticmethod + def _read_file(filepath: str): + logger.debug(f"Reading from '{filepath}'") + with open(filepath, "r") as f: + content = f.read() + return content + + @staticmethod + def _write_file(filepath: str, content: str, overwrite: bool = False): + logger.debug(f"Writing to '{filepath}'") + with open(filepath, "w" if overwrite else "x") as f: + f.write(content) diff --git a/server/notes/models.py b/server/notes/models.py new file mode 100644 index 0000000..ba16708 --- /dev/null +++ b/server/notes/models.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from pydantic import Field +from pydantic.functional_validators import AfterValidator +from typing_extensions import Annotated + +from helpers import CustomBaseModel, is_valid_filename, strip_whitespace + + +class NoteBase(CustomBaseModel): + title: str + + +class NoteCreate(CustomBaseModel): + title: Annotated[ + str, + AfterValidator(strip_whitespace), + AfterValidator(is_valid_filename), + ] + content: Optional[str] = Field(None) + + +class Note(CustomBaseModel): + title: str + content: Optional[str] = Field(None) + last_modified: float + + +class NoteUpdate(CustomBaseModel): + new_title: Annotated[ + Optional[str], + AfterValidator(strip_whitespace), + AfterValidator(is_valid_filename), + ] = Field(None) + new_content: Optional[str] = Field(None) + + +class SearchResult(CustomBaseModel): + title: str + last_modified: float + + score: Optional[float] = Field(None) + title_highlights: Optional[str] = Field(None) + content_highlights: Optional[str] = Field(None) + tag_matches: Optional[List[str]] = Field(None)