# Custom
.vscode/
data/
-notes/
+/notes/
--- /dev/null
+login_failed = "Invalid login details."
+note_exists = "Cannot create note. A note with the same title already exists."
+note_not_found = "The specified note cannot be found."
+invalid_note_title = "The specified note title contains invalid characters."
+attachment_exists = (
+ "Cannot create attachment. An attachment with the same filename already "
+ "exists."
+)
+attachment_not_found = "The specified attachment cannot be found."
+invalid_attachment_filename = (
+ "The specified filename contains invalid characters."
+)
--- /dev/null
+from abc import ABC, abstractmethod
+
+from fastapi import UploadFile
+from fastapi.responses import FileResponse
+
+
+class BaseAttachments(ABC):
+ @abstractmethod
+ def create(self, file: UploadFile) -> None:
+ """Create a new attachment."""
+ pass
+
+ @abstractmethod
+ def get(self, filename: str) -> FileResponse:
+ """Get a specific attachment."""
+ pass
--- /dev/null
+from .file_system import FileSystemAttachments # noqa
--- /dev/null
+import os
+import shutil
+
+from fastapi import UploadFile
+from fastapi.responses import FileResponse
+
+from helpers import get_env, is_valid_filename
+
+from ..base import BaseAttachments
+
+
+class FileSystemAttachments(BaseAttachments):
+ def __init__(self):
+ self.base_path = get_env("FLATNOTES_PATH", mandatory=True)
+ if not os.path.exists(self.base_path):
+ raise NotADirectoryError(
+ f"'{self.base_path}' is not a valid directory."
+ )
+ self.storage_path = os.path.join(self.base_path, "attachments")
+ os.makedirs(self.storage_path, exist_ok=True)
+
+ def create(self, file: UploadFile) -> None:
+ """Create a new attachment."""
+ is_valid_filename(file.filename)
+ filepath = os.path.join(self.storage_path, file.filename)
+ if os.path.exists(filepath):
+ raise FileExistsError(f"'{file.filename}' already exists.")
+ with open(filepath, "wb") as f:
+ shutil.copyfileobj(file.file, f)
+
+ def get(self, filename: str) -> FileResponse:
+ """Get a specific attachment."""
+ is_valid_filename(filename)
+ filepath = os.path.join(self.storage_path, filename)
+ if not os.path.isfile(filepath):
+ raise FileNotFoundError(f"'{filename}' not found.")
+ return FileResponse(filepath)
+++ /dev/null
-from datetime import datetime, timedelta
-
-from fastapi import Depends, HTTPException, Request
-from fastapi.security import OAuth2PasswordBearer
-from jose import JWTError, jwt
-
-from config import AuthType, config
-
-JWT_ALGORITHM = "HS256"
-
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token", auto_error=False)
-
-
-def create_access_token(data: dict):
- to_encode = data.copy()
- expiry_datetime = datetime.utcnow() + timedelta(
- days=config.session_expiry_days
- )
- to_encode.update({"exp": expiry_datetime})
- encoded_jwt = jwt.encode(
- to_encode, config.session_key, algorithm=JWT_ALGORITHM
- )
- return encoded_jwt
-
-
-def validate_token(request: Request, token: str = Depends(oauth2_scheme)):
- # Skip authentication if auth_type is NONE
- if config.auth_type == AuthType.NONE:
- return
- # If no token is found in the header, check the cookies
- if token is None:
- token = request.cookies.get("token")
- # Validate the token
- try:
- if token is None:
- raise ValueError
- payload = jwt.decode(
- token, config.session_key, algorithms=[JWT_ALGORITHM]
- )
- username = payload.get("sub")
- if username is None or username.lower() != config.username.lower():
- raise ValueError
- return
- except (JWTError, ValueError):
- raise HTTPException(
- status_code=401,
- detail="Invalid authentication credentials",
- headers={"WWW-Authenticate": "Bearer"},
- )
-
-
-def no_auth():
- return
--- /dev/null
+from abc import ABC, abstractmethod
+from .models import Login, Token
+
+
+class BaseAuth(ABC):
+ @abstractmethod
+ def login(self, data: Login) -> Token:
+ """Login a user."""
+ pass
+
+ @abstractmethod
+ def authenticate(self, token: str) -> bool:
+ """Authenticate a user."""
+ pass
--- /dev/null
+from .local import LocalAuth # noqa
--- /dev/null
+import secrets
+from base64 import b32encode
+from datetime import datetime, timedelta
+
+import pyotp
+from fastapi import Depends, HTTPException, Request
+from fastapi.security import OAuth2PasswordBearer
+from jose import JWTError, jwt
+from qrcode import QRCode
+
+from global_config import GlobalConfig
+from helpers import get_env
+
+from ..base import BaseAuth
+from ..models import Login, Token
+
+global_config = GlobalConfig()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token", auto_error=False)
+
+
+class LocalAuth(BaseAuth):
+ JWT_ALGORITHM = "HS256"
+
+ def __init__(self) -> None:
+ self.username = get_env("FLATNOTES_USERNAME", mandatory=True).lower()
+ self.password = get_env("FLATNOTES_PASSWORD", mandatory=True)
+ self.secret_key = get_env("FLATNOTES_SECRET_KEY", mandatory=True)
+ self.session_expiry_days = get_env(
+ "FLATNOTES_SESSION_EXPIRY_DAYS", default=30, cast_int=True
+ )
+
+ # TOTP
+ if global_config.auth_type == "totp":
+ self.is_totp_enabled = True
+ self.totp_key = get_env("FLATNOTES_TOTP_KEY", mandatory=True)
+ self.totp_key = b32encode(self.totp_key.encode("utf-8"))
+ self.totp = pyotp.TOTP(self.totp_key)
+ self.last_used_totp = None
+ self._display_totp_enrolment()
+
+ def login(self, data: Login) -> Token:
+ # Check Username
+ username_correct = secrets.compare_digest(
+ self.username.lower(), data.username.lower()
+ )
+
+ # Check Password & TOTP
+ expected_password = self.password
+ if self.is_totp_enabled:
+ current_totp = self.totp.now()
+ expected_password += current_totp
+ password_correct = secrets.compare_digest(
+ expected_password, data.password
+ )
+
+ # Raise error if incorrect
+ if not (
+ username_correct
+ and password_correct
+ # Prevent TOTP from being reused
+ and (
+ self.is_totp_enabled is False
+ or current_totp != self.last_used_totp
+ )
+ ):
+ raise ValueError("Incorrect login credentials.")
+ if self.is_totp_enabled:
+ self.last_used_totp = current_totp
+
+ # Create Token
+ access_token = self._create_access_token(data={"sub": self.username})
+ return Token(access_token=access_token)
+
+ def authenticate(
+ self, request: Request, token: str = Depends(oauth2_scheme)
+ ):
+ # If no token is found in the header, check the cookies
+ if token is None:
+ token = request.cookies.get("token")
+ # Validate the token
+ try:
+ self._validate_token(token)
+ except (JWTError, ValueError):
+ raise HTTPException(
+ status_code=401,
+ detail="Invalid authentication credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ def _validate_token(self, token: str) -> bool:
+ if token is None:
+ raise ValueError
+ payload = jwt.decode(
+ token, self.secret_key, algorithms=[self.JWT_ALGORITHM]
+ )
+ username = payload.get("sub")
+ if username is None or username.lower() != self.username:
+ raise ValueError
+
+ def _create_access_token(self, data: dict):
+ to_encode = data.copy()
+ expiry_datetime = datetime.utcnow() + timedelta(
+ days=self.session_expiry_days
+ )
+ to_encode.update({"exp": expiry_datetime})
+ encoded_jwt = jwt.encode(
+ to_encode, self.secret_key, algorithm=self.JWT_ALGORITHM
+ )
+ return encoded_jwt
+
+ def _display_totp_enrolment(self):
+ uri = self.totp.provisioning_uri(
+ issuer_name="flatnotes", name=self.username
+ )
+ qr = QRCode()
+ qr.add_data(uri)
+ print(
+ "\nScan this QR code with your TOTP app of choice",
+ "e.g. Authy or Google Authenticator:",
+ )
+ qr.print_ascii()
+ print(
+ f"Or manually enter this key: {self.totp.secret.decode('utf-8')}\n"
+ )
--- /dev/null
+from pydantic import BaseModel, Field
+
+from helpers import CustomBaseModel
+
+
+class Login(CustomBaseModel):
+ username: str
+ password: str
+
+
+class Token(BaseModel):
+ # Note: OAuth requires keys to be snake_case so we use the standard
+ # BaseModel here
+ access_token: str
+ token_type: str = Field("bearer")
+++ /dev/null
-import os
-import sys
-from base64 import b32encode
-from enum import Enum
-
-from logger import logger
-
-
-class AuthType(str, Enum):
- NONE = "none"
- READ_ONLY = "read_only"
- PASSWORD = "password"
- TOTP = "totp"
-
-
-class Config:
- def __init__(self) -> None:
- self.data_path = self.get_data_path()
-
- self.auth_type = self.get_auth_type()
-
- self.username = self.get_username()
- self.password = self.get_password()
-
- self.session_key = self.get_session_key()
- self.session_expiry_days = self.get_session_expiry_days()
-
- self.totp_key = self.get_totp_key()
-
- @classmethod
- def get_env(cls, key, mandatory=False, default=None, cast_int=False):
- """Get an environment variable."""
- value = os.environ.get(key)
- if mandatory and not value:
- logger.error(f"Environment variable {key} must be set.")
- sys.exit(1)
- if not mandatory and not value:
- return default
- if cast_int:
- try:
- value = int(value)
- except (TypeError, ValueError):
- logger.error(f"Invalid value '{value}' for {key}.")
- sys.exit(1)
- return value
-
- def get_data_path(self):
- return self.get_env("FLATNOTES_PATH", mandatory=True)
-
- def get_auth_type(self):
- key = "FLATNOTES_AUTH_TYPE"
- auth_type = self.get_env(
- key, mandatory=False, default=AuthType.PASSWORD.value
- )
- try:
- auth_type = AuthType(auth_type.lower())
- except ValueError:
- logger.error(
- f"Invalid value '{auth_type}' for {key}. "
- + "Must be one of: "
- + ", ".join([auth_type.value for auth_type in AuthType])
- + "."
- )
- sys.exit(1)
- return auth_type
-
- def get_username(self):
- return self.get_env(
- "FLATNOTES_USERNAME",
- mandatory=self.auth_type
- not in [AuthType.NONE, AuthType.READ_ONLY],
- )
-
- def get_password(self):
- return self.get_env(
- "FLATNOTES_PASSWORD",
- mandatory=self.auth_type
- not in [AuthType.NONE, AuthType.READ_ONLY],
- )
-
- def get_session_key(self):
- return self.get_env(
- "FLATNOTES_SECRET_KEY",
- mandatory=self.auth_type
- not in [AuthType.NONE, AuthType.READ_ONLY],
- )
-
- def get_session_expiry_days(self):
- return self.get_env(
- "FLATNOTES_SESSION_EXPIRY_DAYS",
- mandatory=False,
- default=30,
- cast_int=True,
- )
-
- def get_totp_key(self):
- totp_key = self.get_env(
- "FLATNOTES_TOTP_KEY", mandatory=self.auth_type == AuthType.TOTP
- )
- if totp_key:
- totp_key = b32encode(totp_key.encode("utf-8"))
- return totp_key
-
-
-config = Config()
+++ /dev/null
-from fastapi.responses import JSONResponse
-
-filename_exists_response = JSONResponse(
- content={"message": "The specified filename already exists."},
- status_code=409,
-)
-
-invalid_filename_response = JSONResponse(
- content={"message": "The specified filename contains invalid characters."},
- status_code=400,
-)
-
-note_not_found_response = JSONResponse(
- content={"message": "The specified note cannot be found."}, status_code=404
-)
--- /dev/null
+import sys
+from enum import Enum
+
+from helpers import CustomBaseModel, get_env
+from logger import logger
+
+
+class GlobalConfig:
+ def __init__(self) -> None:
+ logger.debug("Loading global config...")
+ self.auth_type: AuthType = self._load_auth_type()
+
+ def load_auth(self):
+ if self.auth_type in (AuthType.NONE, AuthType.READ_ONLY):
+ return None
+ elif self.auth_type in (AuthType.PASSWORD, AuthType.TOTP):
+ from auth.local import LocalAuth
+
+ return LocalAuth()
+
+ def load_note_storage(self):
+ from notes.file_system import FileSystemNotes
+
+ return FileSystemNotes()
+
+ def load_attachment_storage(self):
+ from attachments.file_system import (
+ FileSystemAttachments,
+ )
+
+ return FileSystemAttachments()
+
+ def _load_auth_type(self):
+ key = "FLATNOTES_AUTH_TYPE"
+ auth_type = get_env(
+ key, mandatory=False, default=AuthType.PASSWORD.value
+ )
+ try:
+ auth_type = AuthType(auth_type.lower())
+ except ValueError:
+ logger.error(
+ f"Invalid value '{auth_type}' for {key}. "
+ + "Must be one of: "
+ + ", ".join([auth_type.value for auth_type in AuthType])
+ + "."
+ )
+ sys.exit(1)
+ return auth_type
+
+
+class AuthType(str, Enum):
+ NONE = "none"
+ READ_ONLY = "read_only"
+ PASSWORD = "password"
+ TOTP = "totp"
+
+
+class GlobalConfigResponseModel(CustomBaseModel):
+ auth_type: AuthType
import os
-import re
-import shutil
-from typing import List, Tuple
+import sys
+from pydantic import BaseModel
-def strip_ext(filename):
- return os.path.splitext(filename)[0]
+from logger import logger
def camel_case(snake_case_str: str) -> str:
return parts[0] + "".join(part.title() for part in parts[1:])
-def empty_dir(path):
- for item in os.listdir(path):
- item_path = os.path.join(path, item)
- if os.path.isfile(item_path):
- os.remove(item_path)
- elif os.path.isdir(item_path):
- shutil.rmtree(item_path)
-
-
-def re_extract(pattern, string) -> Tuple[str, List[str]]:
- """Similar to re.sub but returns a tuple of:
-
- - `string` with matches removed
- - list of matches"""
- matches = []
- text = re.sub(pattern, lambda tag: matches.append(tag.group()), string)
- return (text, matches)
-
-
-def is_valid_filename(filename):
- r"""Return False if the declared filename contains any of the following
- characters: <>:"/\|?*"""
+def is_valid_filename(value):
+ """Raise ValueError if the declared string contains any of the following
+ characters: <>:"/\\|?*"""
invalid_chars = r'<>:"/\|?*'
- return not any(invalid_char in filename for invalid_char in invalid_chars)
+ if any(invalid_char in value for invalid_char in invalid_chars):
+ raise ValueError(
+ "title cannot include any of the following characters: "
+ + invalid_chars
+ )
+ return value
+
+
+def strip_whitespace(value):
+ """Return the declared string with leading and trailing whitespace
+ removed."""
+ return value.strip()
+
+
+def get_env(key, mandatory=False, default=None, cast_int=False):
+ """Get an environment variable. If `mandatory` is True and environment
+ variable isn't set, exit the program"""
+ value = os.environ.get(key)
+ if mandatory and not value:
+ logger.error(f"Environment variable {key} must be set.")
+ sys.exit(1)
+ if not mandatory and not value:
+ return default
+ if cast_int:
+ try:
+ value = int(value)
+ except (TypeError, ValueError):
+ logger.error(f"Invalid value '{value}' for {key}.")
+ sys.exit(1)
+ return value
+
+
+class CustomBaseModel(BaseModel):
+ class Config:
+ alias_generator = camel_case
+ populate_by_name = True
+ from_attributes = True
-import os
-import secrets
-import shutil
-from typing import List, Literal, Union
+from typing import List, Literal
-import pyotp
from fastapi import Depends, FastAPI, HTTPException, UploadFile
-from fastapi.responses import FileResponse, HTMLResponse
+from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
-from qrcode import QRCode
-
-from auth import create_access_token, no_auth, validate_token
-from config import AuthType, config
-from error_responses import (
- filename_exists_response,
- invalid_filename_response,
- note_not_found_response,
-)
-from flatnotes import Flatnotes, InvalidTitleError, Note
-from helpers import is_valid_filename
-from models import (
- ConfigModel,
- LoginModel,
- NoteContentResponseModel,
- NotePatchModel,
- NotePostModel,
- NoteResponseModel,
- SearchResultModel,
- TokenModel,
-)
-
-ATTACHMENTS_DIR = os.path.join(config.data_path, "attachments")
-os.makedirs(ATTACHMENTS_DIR, exist_ok=True)
+import api_messages
+from attachments.base import BaseAttachments
+from auth.base import BaseAuth
+from auth.models import Login, Token
+from global_config import AuthType, GlobalConfig, GlobalConfigResponseModel
+from notes.base import BaseNotes
+from notes.models import Note, NoteCreate, NoteUpdate, SearchResult
+
+global_config = GlobalConfig()
+auth: BaseAuth = global_config.load_auth()
+note_storage: BaseNotes = global_config.load_note_storage()
+attachment_storage: BaseAttachments = global_config.load_attachment_storage()
+auth_deps = [Depends(auth.authenticate)] if auth else []
app = FastAPI()
-flatnotes = Flatnotes(config.data_path)
-
-totp = (
- pyotp.TOTP(config.totp_key) if config.auth_type == AuthType.TOTP else None
-)
-last_used_totp = None
-
-if config.auth_type in [AuthType.NONE, AuthType.READ_ONLY]:
- authenticate = no_auth
-else:
- authenticate = validate_token
-
-# Display TOTP QR code
-if config.auth_type == AuthType.TOTP:
- uri = totp.provisioning_uri(issuer_name="flatnotes", name=config.username)
- qr = QRCode()
- qr.add_data(uri)
- print(
- "\nScan this QR code with your TOTP app of choice",
- "e.g. Authy or Google Authenticator:",
- )
- qr.print_ascii()
- print(f"Or manually enter this key: {totp.secret.decode('utf-8')}\n")
-
-if config.auth_type not in [AuthType.NONE, AuthType.READ_ONLY]:
-
- @app.post("/api/token", response_model=TokenModel)
- def token(data: LoginModel):
- global last_used_totp
-
- username_correct = secrets.compare_digest(
- config.username.lower(), data.username.lower()
- )
-
- expected_password = config.password
- if config.auth_type == AuthType.TOTP:
- current_totp = totp.now()
- expected_password += current_totp
- password_correct = secrets.compare_digest(
- expected_password, data.password
- )
-
- if not (
- username_correct
- and password_correct
- # Prevent TOTP from being reused
- and (
- config.auth_type != AuthType.TOTP
- or current_totp != last_used_totp
- )
- ):
- raise HTTPException(
- status_code=400, detail="Incorrect login credentials."
- )
-
- access_token = create_access_token(data={"sub": config.username})
- if config.auth_type == AuthType.TOTP:
- last_used_totp = current_totp
- return TokenModel(access_token=access_token)
+# region UI
@app.get("/", include_in_schema=False)
@app.get("/login", include_in_schema=False)
@app.get("/search", include_in_schema=False)
return HTMLResponse(content=html)
-if config.auth_type != AuthType.READ_ONLY:
+# endregion
- @app.post(
- "/api/notes",
- dependencies=[Depends(authenticate)],
- response_model=NoteContentResponseModel,
- )
- def post_note(data: NotePostModel):
- """Create a new note."""
+
+# region Login
+if global_config.auth_type not in [AuthType.NONE, AuthType.READ_ONLY]:
+
+ @app.post("/api/token", response_model=Token)
+ def token(data: Login):
try:
- note = Note(flatnotes, data.title, new=True)
- note.content = data.content
- return NoteContentResponseModel.model_validate(note)
- except InvalidTitleError:
- return invalid_filename_response
- except FileExistsError:
- return filename_exists_response
+ return auth.login(data)
+ except ValueError:
+ raise HTTPException(
+ status_code=401, detail=api_messages.login_failed
+ )
+# endregion
+
+
+# region Notes
+# Get Note
@app.get(
"/api/notes/{title}",
- dependencies=[Depends(authenticate)],
- response_model=Union[NoteContentResponseModel, NoteResponseModel],
+ dependencies=auth_deps,
+ response_model=Note,
)
-def get_note(
- title: str,
- include_content: bool = True,
-):
+def get_note(title: str):
"""Get a specific note."""
try:
- note = Note(flatnotes, title)
- if include_content:
- return NoteContentResponseModel.model_validate(note)
- else:
- return NoteResponseModel.model_validate(note)
- except InvalidTitleError:
- return invalid_filename_response
+ return note_storage.get(title)
+ except ValueError:
+ raise HTTPException(
+ status_code=400, detail=api_messages.invalid_note_title
+ )
except FileNotFoundError:
- return note_not_found_response
+ raise HTTPException(404, api_messages.note_not_found)
-if config.auth_type != AuthType.READ_ONLY:
+if global_config.auth_type != AuthType.READ_ONLY:
+ # Create Note
+ @app.post(
+ "/api/notes",
+ dependencies=auth_deps,
+ response_model=Note,
+ )
+ def post_note(note: NoteCreate):
+ """Create a new note."""
+ try:
+ return note_storage.create(note)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=api_messages.invalid_note_title,
+ )
+ except FileExistsError:
+ raise HTTPException(
+ status_code=409, detail=api_messages.note_exists
+ )
+
+ # Update Note
@app.patch(
"/api/notes/{title}",
- dependencies=[Depends(authenticate)],
- response_model=NoteContentResponseModel,
+ dependencies=auth_deps,
+ response_model=Note,
)
- def patch_note(title: str, new_data: NotePatchModel):
+ def patch_note(title: str, data: NoteUpdate):
try:
- note = Note(flatnotes, title)
- if new_data.new_title is not None:
- note.title = new_data.new_title
- if new_data.new_content is not None:
- note.content = new_data.new_content
- return NoteContentResponseModel.model_validate(note)
- except InvalidTitleError:
- return invalid_filename_response
+ return note_storage.update(title, data)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=api_messages.invalid_note_title,
+ )
except FileExistsError:
- return filename_exists_response
+ raise HTTPException(
+ status_code=409, detail=api_messages.note_exists
+ )
except FileNotFoundError:
- return note_not_found_response
-
-
-if config.auth_type != AuthType.READ_ONLY:
+ raise HTTPException(404, api_messages.note_not_found)
+ # Delete Note
@app.delete(
"/api/notes/{title}",
- dependencies=[Depends(authenticate)],
+ dependencies=auth_deps,
response_model=None,
)
def delete_note(title: str):
try:
- note = Note(flatnotes, title)
- note.delete()
- except InvalidTitleError:
- return invalid_filename_response
+ note_storage.delete(title)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=api_messages.invalid_note_title,
+ )
except FileNotFoundError:
- return note_not_found_response
+ raise HTTPException(404, api_messages.note_not_found)
-@app.get(
- "/api/tags",
- dependencies=[Depends(authenticate)],
- response_model=List[str],
-)
-def get_tags():
- """Get a list of all indexed tags."""
- return flatnotes.get_tags()
+# endregion
+# region Search
@app.get(
"/api/search",
- dependencies=[Depends(authenticate)],
- response_model=List[SearchResultModel],
+ dependencies=auth_deps,
+ response_model=List[SearchResult],
)
def search(
term: str,
"""Perform a full text search on all notes."""
if sort == "lastModified":
sort = "last_modified"
- return [
- SearchResultModel.model_validate(note_hit)
- for note_hit in flatnotes.search(
- term, sort=sort, order=order, limit=limit
- )
- ]
+ return note_storage.search(term, sort=sort, order=order, limit=limit)
-@app.get("/api/config", response_model=ConfigModel)
+@app.get(
+ "/api/tags",
+ dependencies=auth_deps,
+ response_model=List[str],
+)
+def get_tags():
+ """Get a list of all indexed tags."""
+ return note_storage.get_tags()
+
+
+# endregion
+
+
+# region Config
+@app.get("/api/config", response_model=GlobalConfigResponseModel)
def get_config():
"""Retrieve server-side config required for the UI."""
- return ConfigModel.model_validate(config)
+ return GlobalConfigResponseModel(auth_type=global_config.auth_type)
-if config.auth_type != AuthType.READ_ONLY:
-
- @app.post(
- "/api/attachments",
- dependencies=[Depends(authenticate)],
- response_model=None,
- )
- def post_attachment(file: UploadFile):
- """Upload an attachment."""
- if not is_valid_filename(file.filename):
- return invalid_filename_response
- filepath = os.path.join(ATTACHMENTS_DIR, file.filename)
- if os.path.exists(filepath):
- return filename_exists_response
- with open(filepath, "wb") as f:
- shutil.copyfileobj(file.file, f)
+# endregion
+# region Attachments
+# Get Attachment
@app.get(
"/attachments/{filename}",
- dependencies=[Depends(authenticate)],
+ dependencies=auth_deps,
include_in_schema=False,
)
def get_attachment(filename: str):
"""Download an attachment."""
- if not is_valid_filename(filename):
- raise HTTPException(status_code=400, detail="Invalid filename.")
- filepath = os.path.join(ATTACHMENTS_DIR, filename)
- if not os.path.isfile(filepath):
- raise HTTPException(status_code=404, detail="File not found.")
- return FileResponse(filepath)
+ try:
+ return attachment_storage.get(filename)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=api_messages.invalid_attachment_filename,
+ )
+ except FileNotFoundError:
+ raise HTTPException(
+ status_code=404, detail=api_messages.attachment_not_found
+ )
+
+
+if global_config.auth_type != AuthType.READ_ONLY:
+
+ # Create Attachment
+ @app.post(
+ "/api/attachments",
+ dependencies=auth_deps,
+ response_model=None,
+ )
+ def post_attachment(file: UploadFile):
+ """Upload an attachment."""
+ try:
+ return attachment_storage.create(file)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=api_messages.invalid_attachment_filename,
+ )
+ except FileExistsError:
+ raise HTTPException(409, api_messages.attachment_exists)
+
+# endregion
app.mount("/", StaticFiles(directory="client/dist"), name="dist")
+++ /dev/null
-from typing import List, Optional
-
-from pydantic import BaseModel, Field
-
-from config import AuthType
-from helpers import camel_case
-
-
-class TokenModel(BaseModel):
- # Use of BaseModel instead of CustomBaseModel is intentional as OAuth
- # requires keys to be snake_case
- access_token: str
- token_type: str = Field("bearer")
-
-
-class CustomBaseModel(BaseModel):
- class Config:
- alias_generator = camel_case
- populate_by_name = True
- from_attributes = True
-
-
-class LoginModel(CustomBaseModel):
- username: str
- password: str
-
-
-class NotePostModel(CustomBaseModel):
- title: str
- content: Optional[str] = Field(None)
-
-
-class NoteResponseModel(CustomBaseModel):
- title: str
- last_modified: float
-
-
-class NoteContentResponseModel(NoteResponseModel):
- content: Optional[str] = Field(None)
-
-
-class NotePatchModel(CustomBaseModel):
- new_title: Optional[str] = Field(None)
- new_content: Optional[str] = Field(None)
-
-
-class SearchResultModel(CustomBaseModel):
- score: Optional[float] = Field(None)
- title: str
- last_modified: float
- title_highlights: Optional[str] = Field(None)
- content_highlights: Optional[str] = Field(None)
- tag_matches: Optional[List[str]] = Field(None)
-
-
-class ConfigModel(CustomBaseModel):
- auth_type: AuthType
--- /dev/null
+from abc import ABC, abstractmethod
+from typing import Literal
+
+from .models import Note, NoteCreate, NoteUpdate, SearchResult
+
+
+class BaseNotes(ABC):
+ @abstractmethod
+ def create(self, data: NoteCreate) -> Note:
+ """Create a new note."""
+ pass
+
+ @abstractmethod
+ def get(self, title: str) -> Note:
+ """Get a specific note."""
+ pass
+
+ @abstractmethod
+ def update(self, title: str, new_data: NoteUpdate) -> Note:
+ """Update a specific note."""
+ pass
+
+ @abstractmethod
+ def delete(self, title: str) -> None:
+ """Delete a specific note.""" ""
+ pass
+
+ @abstractmethod
+ def search(
+ self,
+ term: str,
+ sort: Literal["score", "title", "last_modified"] = "score",
+ order: Literal["asc", "desc"] = "desc",
+ limit: int = None,
+ ) -> list[SearchResult]:
+ """Search for notes."""
+ pass
+
+ @abstractmethod
+ def get_tags(self) -> list[str]:
+ """Get a list of all indexed tags."""
+ pass
--- /dev/null
+from .file_system import FileSystemNotes
import glob
import os
import re
+import shutil
from datetime import datetime
from typing import List, Literal, Set, Tuple
from whoosh.searching import Hit
from whoosh.support.charset import accent_map
-from helpers import empty_dir, is_valid_filename, re_extract, strip_ext
+from helpers import get_env, is_valid_filename
from logger import logger
+from ..base import BaseNotes
+from ..models import Note, NoteCreate, NoteUpdate, SearchResult
+
MARKDOWN_EXT = ".md"
INDEX_SCHEMA_VERSION = "4"
tags = KEYWORD(lowercase=True, field_boost=2.0)
-class InvalidTitleError(Exception):
- def __init__(self, message="The specified title is invalid"):
- self.message = message
- super().__init__(self.message)
-
-
-class Note:
- def __init__(
- self, flatnotes: "Flatnotes", title: str, new: bool = False
- ) -> None:
- self._flatnotes = flatnotes
- self._title = title.strip()
- if not is_valid_filename(self._title):
- raise InvalidTitleError
- exists = os.path.exists(self.filepath)
- if new and exists:
- raise FileExistsError
- if not new and not exists:
- raise FileNotFoundError
- if new:
- open(self.filepath, "w").close()
-
- @property
- def filepath(self):
- return os.path.join(self._flatnotes.dir, self.filename)
-
- @property
- def filename(self):
- return self._title + MARKDOWN_EXT
-
- @property
- def last_modified(self):
- return os.path.getmtime(self.filepath)
+class FileSystemNotes(BaseNotes):
+ TAGS_RE = re.compile(r"(?:(?<=^#)|(?<=\s#))\w+(?=\s|$)")
+ CODEBLOCK_RE = re.compile(r"`{1,3}.*?`{1,3}", re.DOTALL)
+ TAGS_WITH_HASH_RE = re.compile(r"(?:(?<=^)|(?<=\s))#\w+(?=\s|$)")
- # Editable Properties
- @property
- def title(self):
- return self._title
-
- @title.setter
- def title(self, new_title):
- new_title = new_title.strip()
- if not is_valid_filename(new_title):
- raise InvalidTitleError
- new_filepath = os.path.join(
- self._flatnotes.dir, new_title + MARKDOWN_EXT
+ def __init__(self):
+ self.storage_path = get_env("FLATNOTES_PATH", mandatory=True)
+ if not os.path.exists(self.storage_path):
+ raise NotADirectoryError(
+ f"'{self.storage_path}' is not a valid directory."
+ )
+ self.index = self._load_index()
+ self._sync_index()
+
+ def create(self, data: NoteCreate) -> Note:
+ """Create a new note."""
+ filepath = self._path_from_title(data.title)
+ self._write_file(filepath, data.content)
+ return Note(
+ title=data.title,
+ content=data.content,
+ last_modified=os.path.getmtime(filepath),
)
- os.rename(self.filepath, new_filepath)
- self._title = new_title
-
- @property
- def content(self):
- with open(self.filepath, "r", encoding="utf-8") as f:
- return f.read()
-
- @content.setter
- def content(self, new_content):
- if not os.path.exists(self.filepath):
- raise FileNotFoundError
- with open(self.filepath, "w", encoding="utf-8") as f:
- f.write(new_content)
-
- def delete(self):
- os.remove(self.filepath)
-
-
-class SearchResult(Note):
- def __init__(self, flatnotes: "Flatnotes", hit: Hit) -> None:
- super().__init__(flatnotes, strip_ext(hit["filename"]))
- self._matched_fields = self._get_matched_fields(hit.matched_terms())
- # If the search was ordered using a text field then hit.score is the
- # value of that field. This isn't useful so only set self._score if it
- # is a float.
- self._score = hit.score if type(hit.score) is float else None
-
- if "title" in self._matched_fields:
- hit.results.fragmenter = WholeFragmenter()
- self._title_highlights = hit.highlights("title", text=self.title)
- else:
- self._title_highlights = None
+ def get(self, title: str) -> Note:
+ """Get a specific note."""
+ is_valid_filename(title)
+ filepath = self._path_from_title(title)
+ content = self._read_file(filepath)
+ return Note(
+ title=title,
+ content=content,
+ last_modified=os.path.getmtime(filepath),
+ )
- if "content" in self._matched_fields:
- hit.results.fragmenter = ContextFragmenter()
- content_ex_tags, _ = Flatnotes.extract_tags(self.content)
- self._content_highlights = hit.highlights(
- "content",
- text=content_ex_tags,
- )
+ def update(self, title: str, data: NoteUpdate) -> Note:
+ """Update a specific note."""
+ is_valid_filename(title)
+ filepath = self._path_from_title(title)
+ if data.new_title is not None:
+ new_filepath = self._path_from_title(data.new_title)
+ os.rename(filepath, new_filepath)
+ title = data.new_title
+ filepath = new_filepath
+ if data.new_content is not None:
+ self._write_file(filepath, data.new_content, overwrite=True)
+ content = data.new_content
else:
- self._content_highlights = None
-
- self._tag_matches = (
- [field[1] for field in hit.matched_terms() if field[0] == "tags"]
- if "tags" in self._matched_fields
- else None
+ content = self._read_file(filepath)
+ return Note(
+ title=title,
+ content=content,
+ last_modified=os.path.getmtime(filepath),
)
- @property
- def score(self):
- return self._score
-
- @property
- def title_highlights(self):
- return self._title_highlights
+ def delete(self, title: str) -> None:
+ """Delete a specific note."""
+ is_valid_filename(title)
+ filepath = self._path_from_title(title)
+ os.remove(filepath)
- @property
- def content_highlights(self):
- return self._content_highlights
+ def search(
+ self,
+ term: str,
+ sort: Literal["score", "title", "last_modified"] = "score",
+ order: Literal["asc", "desc"] = "desc",
+ limit: int = None,
+ ) -> Tuple[SearchResult, ...]:
+ """Search the index for the given term."""
+ self._sync_index()
+ term = self._pre_process_search_term(term)
+ with self.index.searcher() as searcher:
+ # Parse Query
+ if term == "*":
+ query = Every()
+ else:
+ parser = MultifieldParser(
+ ["title", "content", "tags"], self.index.schema
+ )
+ parser.add_plugin(DateParserPlugin())
+ query = parser.parse(term)
- @property
- def tag_matches(self):
- return self._tag_matches
+ # Determine Sort By
+ # Note: For the 'sort' option, "score" is converted to None as
+ # that is the default for searches anyway and it's quicker for
+ # Whoosh if you specify None.
+ sort = sort if sort in ["title", "last_modified"] else None
- @staticmethod
- def _get_matched_fields(matched_terms):
- """Return a set of matched fields from a set of ('field', 'term') "
- "tuples generated by whoosh.searching.Hit.matched_terms()."""
- return set([matched_term[0] for matched_term in matched_terms])
+ # Determine Sort Direction
+ # Note: Confusingly, when sorting by 'score', reverse = True means
+ # asc so we have to flip the logic for that case!
+ reverse = order == "desc"
+ if sort is None:
+ reverse = not reverse
+ # Run Search
+ results = searcher.search(
+ query,
+ sortedby=sort,
+ reverse=reverse,
+ limit=limit,
+ terms=True,
+ )
+ return tuple(self._search_result_from_hit(hit) for hit in results)
-class Flatnotes(object):
- TAGS_RE = re.compile(r"(?:(?<=^#)|(?<=\s#))\w+(?=\s|$)")
- CODEBLOCK_RE = re.compile(r"`{1,3}.*?`{1,3}", re.DOTALL)
- TAGS_WITH_HASH_RE = re.compile(r"(?:(?<=^)|(?<=\s))#\w+(?=\s|$)")
+ def get_tags(self) -> list[str]:
+ """Return a list of all indexed tags."""
+ self._sync_index()
+ with self.index.reader() as reader:
+ tags = reader.field_terms("tags")
+ return [tag for tag in tags]
- def __init__(self, dir: str) -> None:
- if not os.path.exists(dir):
- raise NotADirectoryError(f"'{dir}' is not a valid directory.")
- self.dir = dir
+ @property
+ def _index_path(self):
+ return os.path.join(self.storage_path, ".flatnotes")
- self.index = self._load_index()
- self.update_index()
+ def _path_from_title(self, title: str) -> str:
+ return os.path.join(self.storage_path, title + MARKDOWN_EXT)
- @property
- def index_dir(self):
- return os.path.join(self.dir, ".flatnotes")
+ def _get_by_filename(self, filename: str) -> Note:
+ """Get a note by its filename."""
+ return self.get(self._strip_ext(filename))
def _load_index(self) -> Index:
"""Load the note index or create new if not exists."""
- index_dir_exists = os.path.exists(self.index_dir)
+ index_dir_exists = os.path.exists(self._index_path)
if index_dir_exists and whoosh.index.exists_in(
- self.index_dir, indexname=INDEX_SCHEMA_VERSION
+ self._index_path, indexname=INDEX_SCHEMA_VERSION
):
logger.info("Loading existing index")
return whoosh.index.open_dir(
- self.index_dir, indexname=INDEX_SCHEMA_VERSION
+ self._index_path, indexname=INDEX_SCHEMA_VERSION
)
else:
if index_dir_exists:
logger.info("Deleting outdated index")
- empty_dir(self.index_dir)
+ self._clear_dir(self._index_path)
else:
- os.mkdir(self.index_dir)
+ os.mkdir(self._index_path)
logger.info("Creating new index")
return whoosh.index.create_in(
- self.index_dir, IndexSchema, indexname=INDEX_SCHEMA_VERSION
+ self._index_path, IndexSchema, indexname=INDEX_SCHEMA_VERSION
)
@classmethod
- def extract_tags(cls, content) -> Tuple[str, Set[str]]:
+ def _extract_tags(cls, content) -> Tuple[str, Set[str]]:
"""Strip tags from the given content and return a tuple consisting of:
- The content without the tags.
- A set of tags converted to lowercase."""
content_ex_codeblock = re.sub(cls.CODEBLOCK_RE, "", content)
- _, tags = re_extract(cls.TAGS_RE, content_ex_codeblock)
- content_ex_tags, _ = re_extract(cls.TAGS_RE, content)
+ _, tags = cls._re_extract(cls.TAGS_RE, content_ex_codeblock)
+ content_ex_tags, _ = cls._re_extract(cls.TAGS_RE, content)
try:
tags = [tag.lower() for tag in tags]
return (content_ex_tags, set(tags))
"""Add a Note object to the index using the given writer. If the
filename already exists in the index an update will be performed
instead."""
- content_ex_tags, tag_set = self.extract_tags(note.content)
+ content_ex_tags, tag_set = self._extract_tags(note.content)
tag_string = " ".join(tag_set)
writer.update_document(
- filename=note.filename,
+ filename=note.title + MARKDOWN_EXT,
last_modified=datetime.fromtimestamp(note.last_modified),
title=note.title,
content=content_ex_tags,
tags=tag_string,
)
- def _get_notes(self) -> List[Note]:
- """Return a list containing a Note object for every file in the notes
- directory."""
+ def _list_all_note_filenames(self) -> List[str]:
+ """Return a list of all note filenames."""
return [
- Note(self, strip_ext(os.path.split(filepath)[1]))
+ os.path.split(filepath)[1]
for filepath in glob.glob(
- os.path.join(self.dir, "*" + MARKDOWN_EXT)
+ os.path.join(self.storage_path, "*" + MARKDOWN_EXT)
)
]
- def update_index(self, clean: bool = False) -> None:
+ def _sync_index(self, clean: bool = False) -> None:
"""Synchronize the index with the notes directory.
Specify clean=True to completely rebuild the index"""
indexed = set()
with self.index.searcher() as searcher:
for idx_note in searcher.all_stored_fields():
idx_filename = idx_note["filename"]
- idx_filepath = os.path.join(self.dir, idx_filename)
+ idx_filepath = os.path.join(self.storage_path, idx_filename)
# Delete missing
if not os.path.exists(idx_filepath):
writer.delete_by_term("filename", idx_filename)
):
logger.info(f"'{idx_filename}' updated")
self._add_note_to_index(
- writer, Note(self, strip_ext(idx_filename))
+ writer, self._get_by_filename(idx_filename)
)
indexed.add(idx_filename)
# Ignore already indexed
else:
indexed.add(idx_filename)
# Add new
- for note in self._get_notes():
- if note.filename not in indexed:
- self._add_note_to_index(writer, note)
- logger.info(f"'{note.filename}' added to index")
+ for filename in self._list_all_note_filenames():
+ if filename not in indexed:
+ self._add_note_to_index(
+ writer, self._get_by_filename(filename)
+ )
+ logger.info(f"'{filename}' added to index")
writer.commit()
- def get_tags(self):
- """Return a list of all indexed tags."""
- self.update_index()
- with self.index.reader() as reader:
- tags = reader.field_terms("tags")
- return [tag for tag in tags]
-
- def pre_process_search_term(self, term):
+ @classmethod
+ def _pre_process_search_term(cls, term):
term = term.strip()
# Replace "#tagname" with "tags:tagname"
term = re.sub(
- self.TAGS_WITH_HASH_RE,
+ cls.TAGS_WITH_HASH_RE,
lambda tag: "tags:" + tag.group(0)[1:],
term,
)
return term
- def search(
- self,
- term: str,
- sort: Literal["score", "title", "last_modified"] = "score",
- order: Literal["asc", "desc"] = "desc",
- limit: int = None,
- ) -> Tuple[SearchResult, ...]:
- """Search the index for the given term."""
- self.update_index()
- term = self.pre_process_search_term(term)
- with self.index.searcher() as searcher:
- # Parse Query
- if term == "*":
- query = Every()
- else:
- parser = MultifieldParser(
- ["title", "content", "tags"], self.index.schema
- )
- parser.add_plugin(DateParserPlugin())
- query = parser.parse(term)
+ @staticmethod
+ def _re_extract(pattern, string) -> Tuple[str, List[str]]:
+ """Similar to re.sub but returns a tuple of:
- # Determine Sort By
- # Note: For the 'sort' option, "score" is converted to None as
- # that is the default for searches anyway and it's quicker for
- # Whoosh if you specify None.
- sort = sort if sort in ["title", "last_modified"] else None
+ - `string` with matches removed
+ - list of matches"""
+ matches = []
+ text = re.sub(pattern, lambda tag: matches.append(tag.group()), string)
+ return (text, matches)
- # Determine Sort Direction
- # Note: Confusingly, when sorting by 'score', reverse = True means
- # asc so we have to flip the logic for that case!
- reverse = order == "desc"
- if sort is None:
- reverse = not reverse
+ @staticmethod
+ def _strip_ext(filename):
+ """Return the given filename without the extension."""
+ return os.path.splitext(filename)[0]
- # Run Search
- results = searcher.search(
- query,
- sortedby=sort,
- reverse=reverse,
- limit=limit,
- terms=True,
+ @staticmethod
+ def _clear_dir(path):
+ """Delete all contents of the given directory."""
+ for item in os.listdir(path):
+ item_path = os.path.join(path, item)
+ if os.path.isfile(item_path):
+ os.remove(item_path)
+ elif os.path.isdir(item_path):
+ shutil.rmtree(item_path)
+
+ def _search_result_from_hit(self, hit: Hit):
+ matched_fields = self._get_matched_fields(hit.matched_terms())
+
+ title = self._strip_ext(hit["filename"])
+ last_modified = hit["last_modified"].timestamp()
+
+ # If the search was ordered using a text field then hit.score is the
+ # value of that field. This isn't useful so only set self._score if it
+ # is a float.
+ score = hit.score if type(hit.score) is float else None
+
+ if "title" in matched_fields:
+ hit.results.fragmenter = WholeFragmenter()
+ title_highlights = hit.highlights("title", text=title)
+ else:
+ title_highlights = None
+
+ if "content" in matched_fields:
+ hit.results.fragmenter = ContextFragmenter()
+ content = self._read_file(self._path_from_title(title))
+ content_ex_tags, _ = FileSystemNotes._extract_tags(content)
+ content_highlights = hit.highlights(
+ "content",
+ text=content_ex_tags,
)
- return tuple(SearchResult(self, hit) for hit in results)
+ else:
+ content_highlights = None
+
+ tag_matches = (
+ [field[1] for field in hit.matched_terms() if field[0] == "tags"]
+ if "tags" in matched_fields
+ else None
+ )
+
+ return SearchResult(
+ title=title,
+ last_modified=last_modified,
+ score=score,
+ title_highlights=title_highlights,
+ content_highlights=content_highlights,
+ tag_matches=tag_matches,
+ )
+
+ @staticmethod
+ def _get_matched_fields(matched_terms):
+ """Return a set of matched fields from a set of ('field', 'term') "
+ "tuples generated by whoosh.searching.Hit.matched_terms()."""
+ return set([matched_term[0] for matched_term in matched_terms])
+
+ @staticmethod
+ def _read_file(filepath: str):
+ logger.debug(f"Reading from '{filepath}'")
+ with open(filepath, "r") as f:
+ content = f.read()
+ return content
+
+ @staticmethod
+ def _write_file(filepath: str, content: str, overwrite: bool = False):
+ logger.debug(f"Writing to '{filepath}'")
+ with open(filepath, "w" if overwrite else "x") as f:
+ f.write(content)
--- /dev/null
+from typing import List, Optional
+
+from pydantic import Field
+from pydantic.functional_validators import AfterValidator
+from typing_extensions import Annotated
+
+from helpers import CustomBaseModel, is_valid_filename, strip_whitespace
+
+
+class NoteBase(CustomBaseModel):
+ title: str
+
+
+class NoteCreate(CustomBaseModel):
+ title: Annotated[
+ str,
+ AfterValidator(strip_whitespace),
+ AfterValidator(is_valid_filename),
+ ]
+ content: Optional[str] = Field(None)
+
+
+class Note(CustomBaseModel):
+ title: str
+ content: Optional[str] = Field(None)
+ last_modified: float
+
+
+class NoteUpdate(CustomBaseModel):
+ new_title: Annotated[
+ Optional[str],
+ AfterValidator(strip_whitespace),
+ AfterValidator(is_valid_filename),
+ ] = Field(None)
+ new_content: Optional[str] = Field(None)
+
+
+class SearchResult(CustomBaseModel):
+ title: str
+ last_modified: float
+
+ score: Optional[float] = Field(None)
+ title_highlights: Optional[str] = Field(None)
+ content_highlights: Optional[str] = Field(None)
+ tag_matches: Optional[List[str]] = Field(None)