From: Stefan Gasser
Date: Mon, 19 Jan 2026 17:30:22 +0000 (+0100)
Subject: Refactor for multi-provider architecture (#49)
X-Git-Url: http://git.99rst.org/?a=commitdiff_plain;h=879fd9da06fb80149cb6961dbabe7755f948e84b;p=sgasser-llm-shield.git

Refactor for multi-provider architecture (#49)

Reorganizes the codebase to support multiple LLM providers with a clean,
extensible architecture. This is a preparatory refactor that improves code
organization without adding new provider support.

Architecture changes:
- Move masking utilities to src/masking/ (conflict-resolver, placeholders, context)
- Add provider-specific directories: src/providers/openai/
- Create shared provider utilities: src/providers/errors.ts, src/routes/utils.ts
- Extract OpenAI-specific code to src/masking/extractors/openai.ts
- Add service layer: src/services/pii.ts, src/services/secrets.ts
- Move stream transformer to provider directory

New patterns:
- Provider-agnostic text extraction with the RequestExtractor interface
- Shared error handling with the ProviderError class
- Centralized timeout constants in src/constants/timeouts.ts
- Unified logging helpers in src/routes/utils.ts

Removed:
- src/services/decision.ts (logic moved to service layer)
- src/providers/openai-client.ts (replaced by src/providers/openai/client.ts)

All 219 tests pass.
---

diff --git a/CLAUDE.md b/CLAUDE.md
index 68bedcb..74d7d60 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -18,22 +18,47 @@ OpenAI-compatible proxy with two privacy modes: route to local LLM or mask PII f
 src/
 ├── index.ts                     # Hono server entry
 ├── config.ts                    # YAML config + Zod validation
+├── constants/                   # Shared constants
+│   ├── languages.ts             # Supported languages
+│   └── timeouts.ts              # HTTP timeout values
 ├── routes/
-│   ├── proxy.ts                 # /openai/v1/* (chat completions + wildcard proxy)
+│   ├── openai.ts                # /openai/v1/* (chat completions + wildcard proxy)
 │   ├── dashboard.tsx            # Dashboard routes + API
 │   ├── health.ts                # GET /health
-│   └── info.ts
-├── views/
-│   └── dashboard/
-│       └── page.tsx
-└── services/
-    ├── decision.ts              # Route/mask logic
-    ├── pii-detector.ts          # Presidio client
-    ├── llm-client.ts            # OpenAI/Ollama client
-    ├── masking.ts               # PII mask/unmask
-    ├── stream-transformer.ts    # SSE unmask for streaming
-    ├── language-detector.ts     # Auto language detection
-    └── logger.ts                # SQLite logging
+│   ├── info.ts                  # GET /info
+│   └── utils.ts                 # Shared route utilities
+├── providers/
+│   ├── errors.ts                # Shared provider errors
+│   ├── local.ts                 # Local LLM client (Ollama/OpenAI-compatible)
+│   └── openai/
+│       ├── client.ts            # OpenAI API client
+│       ├── stream-transformer.ts # SSE unmasking for streaming
+│       └── types.ts             # OpenAI request/response types
+├── masking/
+│   ├── service.ts               # Masking orchestration
+│   ├── context.ts               # Masking context management
+│   ├── placeholders.ts          # Placeholder generation
+│   ├── conflict-resolver.ts     # Overlapping entity resolution
+│   ├── types.ts                 # Shared masking types
+│   └── extractors/
+│       └── openai.ts            # OpenAI text extraction/insertion
+├── pii/
+│   ├── detect.ts                # Presidio client
+│   └── mask.ts                  # PII masking logic
+├── secrets/
+│   ├── detect.ts                # Secret detection
+│   ├── mask.ts                  # Secret masking
+│   └── patterns/                # Secret pattern definitions
+├── services/
+│   ├── pii.ts                   # PII detection service
+│   ├── secrets.ts               # Secrets processing service
+│   ├── language-detector.ts     # Auto language detection
+│   └── logger.ts                # SQLite logging
+├── utils/
+│   └── content.ts               # Content utilities
+└── views/
+    └── dashboard/
+        └── page.tsx             # Dashboard UI
 ```

 Tests are colocated (`*.test.ts`).
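For readers skimming the diff below, the extension seam is the `RequestExtractor` interface added in `src/masking/types.ts`: the core masking service only ever sees `TextSpan`s, so supporting a new provider means writing one extractor. The following is a minimal, hypothetical sketch of what an Anthropic extractor could look like against that interface — the `AnthropicRequest`/`AnthropicResponse` shapes and the `messageIndex: -1` convention for an out-of-band system prompt are illustrative assumptions, not code shipped in this commit:

```ts
// Hypothetical sketch only — not part of this commit.
import { restorePlaceholders, type PlaceholderContext } from "../masking/context";
import type { MaskedSpan, RequestExtractor, TextSpan } from "../masking/types";

// Assumed request/response shapes, for illustration.
interface AnthropicRequest {
  model: string;
  system?: string; // Anthropic-style APIs keep the system prompt outside `messages`
  messages: Array<{ role: "user" | "assistant"; content: string }>;
}

interface AnthropicResponse {
  content: Array<{ type: "text"; text: string }>;
}

export const anthropicExtractor: RequestExtractor<AnthropicRequest, AnthropicResponse> = {
  extractTexts(request) {
    const spans: TextSpan[] = [];
    // messageIndex -1 marks text that lives outside the messages array;
    // detect.ts filters `messageIndex >= 0` when sampling text for
    // language detection.
    if (request.system) {
      spans.push({ text: request.system, path: "system", messageIndex: -1, partIndex: 0 });
    }
    request.messages.forEach((msg, i) => {
      spans.push({ text: msg.content, path: `messages[${i}].content`, messageIndex: i, partIndex: 0 });
    });
    return spans;
  },

  applyMasked(request, maskedSpans: MaskedSpan[]) {
    const next = { ...request, messages: request.messages.map((m) => ({ ...m })) };
    for (const span of maskedSpans) {
      if (span.messageIndex === -1) next.system = span.maskedText;
      else next.messages[span.messageIndex].content = span.maskedText;
    }
    return next;
  },

  unmaskResponse(response, context: PlaceholderContext, formatValue?: (original: string) => string) {
    return {
      ...response,
      content: response.content.map((part) =>
        part.type === "text" ? { ...part, text: restorePlaceholders(part.text, context, formatValue) } : part,
      ),
    };
  },
};
```

The masking service, PII detector, and unmasking helpers would then work unchanged, which is the point of the refactor.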
diff --git a/config.example.yaml b/config.example.yaml index ec1b4e3..6a660d1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -12,8 +12,8 @@ server: port: 3000 host: "0.0.0.0" -# Providers - OpenAI-compatible API endpoints -# Can be cloud (OpenAI, Azure) or self-hosted (vLLM, LiteLLM proxy, etc.) +# Providers - API endpoints +# Can be cloud (OpenAI, Anthropic, Azure) or self-hosted (vLLM, LiteLLM proxy, etc.) providers: # OpenAI-compatible endpoint (required) # The proxy forwards your client's Authorization header @@ -21,6 +21,12 @@ providers: base_url: https://api.openai.com/v1 # api_key: ${OPENAI_API_KEY} # Optional fallback if client doesn't send auth header + # Anthropic endpoint (optional) + # Enable to use /anthropic/v1/messages endpoint + anthropic: + base_url: https://api.anthropic.com + # api_key: ${ANTHROPIC_API_KEY} # Optional fallback if client doesn't send auth header + # Local provider - only used when mode: route # Supports: ollama (native), openai (for vLLM, LocalAI, LM Studio, etc.) local: diff --git a/src/config.ts b/src/config.ts index 461c188..6b338c7 100644 --- a/src/config.ts +++ b/src/config.ts @@ -38,6 +38,7 @@ const LanguagesSchema = z .default(["en"]); const PIIDetectionSchema = z.object({ + enabled: z.boolean().default(true), presidio_url: z.string().url(), languages: LanguagesSchema, fallback_language: LanguageEnum.default("en"), @@ -103,7 +104,7 @@ const ConfigSchema = z .object({ mode: z.enum(["route", "mask"]).default("route"), server: ServerSchema.default({}), - // Providers - OpenAI-compatible endpoints + // Providers providers: z.object({ openai: OpenAIProviderSchema.default({}), }), diff --git a/src/constants/timeouts.ts b/src/constants/timeouts.ts new file mode 100644 index 0000000..e4dd8d0 --- /dev/null +++ b/src/constants/timeouts.ts @@ -0,0 +1,2 @@ +export const REQUEST_TIMEOUT_MS = 120_000; +export const HEALTH_CHECK_TIMEOUT_MS = 5_000; diff --git a/src/index.ts b/src/index.ts index 88c2f69..3eb1c97 100644 --- a/src/index.ts +++ b/src/index.ts @@ -177,7 +177,7 @@ Provider: ╚═══════════════════════════════════════════════════════════╝ Server: http://${host}:${port} -API: http://${host}:${port}/openai/v1/chat/completions +OpenAI API: http://${host}:${port}/openai/v1/chat/completions Health: http://${host}:${port}/health Info: http://${host}:${port}/info Dashboard: http://${host}:${port}/dashboard diff --git a/src/utils/conflict-resolver.test.ts b/src/masking/conflict-resolver.test.ts similarity index 100% rename from src/utils/conflict-resolver.test.ts rename to src/masking/conflict-resolver.test.ts diff --git a/src/utils/conflict-resolver.ts b/src/masking/conflict-resolver.ts similarity index 100% rename from src/utils/conflict-resolver.ts rename to src/masking/conflict-resolver.ts diff --git a/src/utils/message-transform.test.ts b/src/masking/context.test.ts similarity index 70% rename from src/utils/message-transform.test.ts rename to src/masking/context.test.ts index 72c4359..1871e11 100644 --- a/src/utils/message-transform.test.ts +++ b/src/masking/context.test.ts @@ -1,5 +1,4 @@ import { describe, expect, test } from "bun:test"; -import type { ChatMessage } from "../providers/openai-client"; import type { Span } from "./conflict-resolver"; import { createPlaceholderContext, @@ -8,9 +7,7 @@ import { processStreamChunk, replaceWithPlaceholders, restorePlaceholders, - restoreResponsePlaceholders, - transformMessagesPerPart, -} from "./message-transform"; +} from "./context"; /** * Simple placeholder format for testing: [[TYPE_N]] @@ 
-260,162 +257,6 @@ describe("replace -> restore roundtrip", () => { }); }); -describe("transformMessagesPerPart", () => { - test("transforms string content", () => { - const messages: ChatMessage[] = [{ role: "user", content: "Hello world" }]; - const perPartData = [[[{ marker: true }]]]; - - const result = transformMessagesPerPart( - messages, - perPartData, - (text, data) => (data.length > 0 ? text.toUpperCase() : text), - {}, - ); - - expect(result[0].content).toBe("HELLO WORLD"); - }); - - test("skips messages without data", () => { - const messages: ChatMessage[] = [ - { role: "user", content: "Keep this" }, - { role: "assistant", content: "And this" }, - ]; - const perPartData = [[[]], [[]]]; - - const result = transformMessagesPerPart( - messages, - perPartData, - (text) => text.toUpperCase(), - {}, - ); - - expect(result[0].content).toBe("Keep this"); - expect(result[1].content).toBe("And this"); - }); - - test("transforms array content (multimodal)", () => { - const messages: ChatMessage[] = [ - { - role: "user", - content: [ - { type: "text", text: "Hello" }, - { type: "image_url", image_url: { url: "https://example.com/img.jpg" } }, - ], - }, - ]; - const perPartData = [[[{ marker: true }], []]]; - - const result = transformMessagesPerPart( - messages, - perPartData, - (text, data) => (data.length > 0 ? text.toUpperCase() : text), - {}, - ); - - const content = result[0].content as Array<{ type: string; text?: string }>; - expect(content[0].text).toBe("HELLO"); - expect(content[1].type).toBe("image_url"); - }); - - test("preserves message roles", () => { - const messages: ChatMessage[] = [ - { role: "system", content: "sys" }, - { role: "user", content: "usr" }, - { role: "assistant", content: "ast" }, - ]; - const perPartData = [[[]], [[]], [[]]]; - - const result = transformMessagesPerPart(messages, perPartData, (t) => t, {}); - - expect(result[0].role).toBe("system"); - expect(result[1].role).toBe("user"); - expect(result[2].role).toBe("assistant"); - }); - - test("passes context to transform function", () => { - const messages: ChatMessage[] = [{ role: "user", content: "test" }]; - const perPartData = [[[{ id: 1 }]]]; - const ctx = { prefix: ">> " }; - - const result = transformMessagesPerPart( - messages, - perPartData, - (text, _data, context: { prefix: string }) => context.prefix + text, - ctx, - ); - - expect(result[0].content).toBe(">> test"); - }); -}); - -describe("restoreResponsePlaceholders", () => { - test("restores placeholders in response choices", () => { - const ctx = createPlaceholderContext(); - ctx.mapping["[[X_1]]"] = "secret"; - - const response = { - id: "test", - choices: [{ message: { content: "Value: [[X_1]]" } }], - }; - - const result = restoreResponsePlaceholders(response, ctx); - expect(result.choices[0].message.content).toBe("Value: secret"); - }); - - test("handles multiple choices", () => { - const ctx = createPlaceholderContext(); - ctx.mapping["[[X_1]]"] = "val"; - - const response = { - id: "test", - choices: [{ message: { content: "A: [[X_1]]" } }, { message: { content: "B: [[X_1]]" } }], - }; - - const result = restoreResponsePlaceholders(response, ctx); - expect(result.choices[0].message.content).toBe("A: val"); - expect(result.choices[1].message.content).toBe("B: val"); - }); - - test("preserves response structure", () => { - const ctx = createPlaceholderContext(); - const response = { - id: "resp-123", - model: "test-model", - choices: [{ message: { content: "text" } }], - usage: { tokens: 10 }, - }; - - const result = 
restoreResponsePlaceholders(response, ctx); - expect(result.id).toBe("resp-123"); - expect(result.model).toBe("test-model"); - expect(result.usage).toEqual({ tokens: 10 }); - }); - - test("applies formatValue function", () => { - const ctx = createPlaceholderContext(); - ctx.mapping["[[X_1]]"] = "secret"; - - const response = { - id: "test", - choices: [{ message: { content: "[[X_1]]" } }], - }; - - const result = restoreResponsePlaceholders(response, ctx, (v) => `<${v}>`); - expect(result.choices[0].message.content).toBe("<secret>"); - }); - - test("handles non-string content", () => { - const ctx = createPlaceholderContext(); - const response = { - id: "test", - choices: [{ message: { content: null } }], - }; - - const result = restoreResponsePlaceholders(response, ctx); - expect(result.choices[0].message.content).toBe(null); - }); -}); - describe("processStreamChunk", () => { test("processes complete text without placeholders", () => { const ctx = createPlaceholderContext(); diff --git a/src/utils/message-transform.ts b/src/masking/context.ts similarity index 66% rename from src/utils/message-transform.ts rename to src/masking/context.ts index 408b4c1..63f8b02 100644 --- a/src/utils/message-transform.ts +++ b/src/masking/context.ts @@ -1,18 +1,9 @@ /** - * Generic utilities for per-part message transformations - * - * Both PII masking and secrets masking need to: - * 1. Iterate over messages and their content parts - * 2. Apply transformations based on per-part detection data - * 3. Handle string vs array content uniformly - * - * This module provides shared infrastructure to avoid duplication. + * Placeholder context and text transformation utilities */ -import type { ChatMessage } from "../providers/openai-client"; +import { findPartialPlaceholderStart } from "../masking/placeholders"; import type { Span } from "./conflict-resolver"; -import type { ContentPart } from "./content"; -import { findPartialPlaceholderStart } from "./placeholders"; /** * Generic context for placeholder-based transformations @@ -64,54 +55,6 @@ export function incrementAndGenerate( return format(type, count); } -/** - * Transforms messages using per-part data - * - * Generic function that handles the common pattern of: - * - Iterating over messages - * - Handling string vs array content - * - Applying a transform function per text part - * - * @param messages - Chat messages to transform - * @param perPartData - Per-message, per-part data: data[msgIdx][partIdx] - * @param transform - Function to transform text using the part data - * @param context - Shared context passed to all transform calls - */ -export function transformMessagesPerPart<TData, TContext>( - messages: ChatMessage[], - perPartData: TData[][][], - transform: (text: string, data: TData[], context: TContext) => string, - context: TContext, -): ChatMessage[] { - return messages.map((msg, msgIdx) => { - const partData = perPartData[msgIdx] || []; - - // String content → data is in partData[0] - if (typeof msg.content === "string") { - const data = partData[0] || []; - if (data.length === 0) return msg; - const transformed = transform(msg.content, data, context); - return { ...msg, content: transformed }; - } - - // Array content (multimodal) → data is per-part - if (Array.isArray(msg.content)) { - const transformedContent = msg.content.map((part: ContentPart, partIdx: number) => { - const data = partData[partIdx] || []; - if (part.type === "text" && typeof part.text === "string" && data.length > 0) { - const transformed = transform(part.text, data, context); - return {
...part, text: transformed }; - } - return part; - }); - return { ...msg, content: transformedContent }; - } - - // Null/undefined content - return msg; - }); -} - /** * Restores placeholders in text with original values * @@ -141,31 +84,6 @@ export function restorePlaceholders( return result; } -/** - * Restores placeholders in a chat completion response - * - * @param response - The response object with choices - * @param context - Context with placeholder mappings - * @param formatValue - Optional function to format restored values - */ -export function restoreResponsePlaceholders< - T extends { choices: Array<{ message: { content: unknown } }> }, ->(response: T, context: PlaceholderContext, formatValue?: (original: string) => string): T { - return { - ...response, - choices: response.choices.map((choice) => ({ - ...choice, - message: { - ...choice.message, - content: - typeof choice.message.content === "string" - ? restorePlaceholders(choice.message.content, context, formatValue) - : choice.message.content, - }, - })), - } as T; -} - /** * Replaces items in text with placeholders * diff --git a/src/masking/extractors/openai.test.ts b/src/masking/extractors/openai.test.ts new file mode 100644 index 0000000..5640a46 --- /dev/null +++ b/src/masking/extractors/openai.test.ts @@ -0,0 +1,295 @@ +import { describe, expect, test } from "bun:test"; +import type { PlaceholderContext } from "../../masking/context"; +import type { OpenAIMessage, OpenAIRequest, OpenAIResponse } from "../../providers/openai/types"; +import { openaiExtractor } from "./openai"; + +/** Helper to create a minimal request from messages */ +function createRequest(messages: OpenAIMessage[]): OpenAIRequest { + return { model: "gpt-4", messages }; +} + +describe("OpenAI Text Extractor", () => { + describe("extractTexts", () => { + test("extracts text from string content", () => { + const request = createRequest([ + { role: "system", content: "You are helpful" }, + { role: "user", content: "Hello world" }, + ]); + + const spans = openaiExtractor.extractTexts(request); + + expect(spans).toHaveLength(2); + expect(spans[0]).toEqual({ + text: "You are helpful", + path: "messages[0].content", + messageIndex: 0, + partIndex: 0, + }); + expect(spans[1]).toEqual({ + text: "Hello world", + path: "messages[1].content", + messageIndex: 1, + partIndex: 0, + }); + }); + + test("extracts text from multimodal array content", () => { + const request = createRequest([ + { + role: "user", + content: [ + { type: "text", text: "Describe this image:" }, + { type: "image_url", image_url: { url: "https://example.com/img.jpg" } }, + { type: "text", text: "Be detailed" }, + ], + }, + ]); + + const spans = openaiExtractor.extractTexts(request); + + expect(spans).toHaveLength(2); + expect(spans[0]).toEqual({ + text: "Describe this image:", + path: "messages[0].content[0].text", + messageIndex: 0, + partIndex: 0, + }); + expect(spans[1]).toEqual({ + text: "Be detailed", + path: "messages[0].content[2].text", + messageIndex: 0, + partIndex: 2, + }); + }); + + test("handles mixed string and array content", () => { + const request = createRequest([ + { role: "system", content: "System prompt" }, + { + role: "user", + content: [{ type: "text", text: "User message with image" }], + }, + { role: "assistant", content: "Assistant response" }, + ]); + + const spans = openaiExtractor.extractTexts(request); + + expect(spans).toHaveLength(3); + expect(spans[0].messageIndex).toBe(0); + expect(spans[1].messageIndex).toBe(1); + expect(spans[2].messageIndex).toBe(2); + 
}); + + test("skips null/undefined content", () => { + const request = createRequest([ + { role: "user", content: "Hello" }, + { role: "assistant", content: null as unknown as string }, + ]); + + const spans = openaiExtractor.extractTexts(request); + + expect(spans).toHaveLength(1); + expect(spans[0].text).toBe("Hello"); + }); + }); + + describe("applyMasked", () => { + test("applies masked text to string content", () => { + const request = createRequest([{ role: "user", content: "My email is john@example.com" }]); + + const maskedSpans = [ + { + path: "messages[0].content", + maskedText: "My email is [[EMAIL_ADDRESS_1]]", + messageIndex: 0, + partIndex: 0, + }, + ]; + + const result = openaiExtractor.applyMasked(request, maskedSpans); + + expect(result.messages[0].content).toBe("My email is [[EMAIL_ADDRESS_1]]"); + }); + + test("applies masked text to multimodal content", () => { + const request = createRequest([ + { + role: "user", + content: [ + { type: "text", text: "Contact: john@example.com" }, + { type: "image_url", image_url: { url: "https://example.com/img.jpg" } }, + { type: "text", text: "Phone: 555-1234" }, + ], + }, + ]); + + const maskedSpans = [ + { + path: "messages[0].content[0].text", + maskedText: "Contact: [[EMAIL_ADDRESS_1]]", + messageIndex: 0, + partIndex: 0, + }, + { + path: "messages[0].content[2].text", + maskedText: "Phone: [[PHONE_NUMBER_1]]", + messageIndex: 0, + partIndex: 2, + }, + ]; + + const result = openaiExtractor.applyMasked(request, maskedSpans); + const content = result.messages[0].content as Array<{ type: string; text?: string }>; + + expect(content[0].text).toBe("Contact: [[EMAIL_ADDRESS_1]]"); + expect(content[1].type).toBe("image_url"); // Unchanged + expect(content[2].text).toBe("Phone: [[PHONE_NUMBER_1]]"); + }); + + test("preserves messages without masked spans", () => { + const request = createRequest([ + { role: "system", content: "You are helpful" }, + { role: "user", content: "My email is john@example.com" }, + ]); + + const maskedSpans = [ + { + path: "messages[1].content", + maskedText: "My email is [[EMAIL_ADDRESS_1]]", + messageIndex: 1, + partIndex: 0, + }, + ]; + + const result = openaiExtractor.applyMasked(request, maskedSpans); + + expect(result.messages[0].content).toBe("You are helpful"); // Unchanged + expect(result.messages[1].content).toBe("My email is [[EMAIL_ADDRESS_1]]"); + }); + }); + + describe("unmaskResponse", () => { + test("unmasks placeholders in response content", () => { + const response: OpenAIResponse = { + id: "test-id", + object: "chat.completion", + created: 123456, + model: "gpt-4", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "Hello [[PERSON_1]], your email is [[EMAIL_ADDRESS_1]]", + }, + finish_reason: "stop", + }, + ], + }; + + const context: PlaceholderContext = { + mapping: { + "[[PERSON_1]]": "John", + "[[EMAIL_ADDRESS_1]]": "john@example.com", + }, + reverseMapping: { + John: "[[PERSON_1]]", + "john@example.com": "[[EMAIL_ADDRESS_1]]", + }, + counters: { PERSON: 1, EMAIL_ADDRESS: 1 }, + }; + + const result = openaiExtractor.unmaskResponse(response, context); + + expect(result.choices[0].message.content).toBe("Hello John, your email is john@example.com"); + }); + + test("applies formatValue function when provided", () => { + const response: OpenAIResponse = { + id: "test-id", + object: "chat.completion", + created: 123456, + model: "gpt-4", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Hello [[PERSON_1]]" }, + finish_reason: "stop", + }, + ], + }; + + 
const context: PlaceholderContext = { + mapping: { "[[PERSON_1]]": "John" }, + reverseMapping: { John: "[[PERSON_1]]" }, + counters: { PERSON: 1 }, + }; + + const result = openaiExtractor.unmaskResponse( + response, + context, + (val) => `[protected]${val}`, + ); + + expect(result.choices[0].message.content).toBe("Hello [protected]John"); + }); + + test("handles multiple choices", () => { + const response: OpenAIResponse = { + id: "test-id", + object: "chat.completion", + created: 123456, + model: "gpt-4", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Option A: [[PERSON_1]]" }, + finish_reason: "stop", + }, + { + index: 1, + message: { role: "assistant", content: "Option B: [[PERSON_1]]" }, + finish_reason: "stop", + }, + ], + }; + + const context: PlaceholderContext = { + mapping: { "[[PERSON_1]]": "John" }, + reverseMapping: { John: "[[PERSON_1]]" }, + counters: { PERSON: 1 }, + }; + + const result = openaiExtractor.unmaskResponse(response, context); + + expect(result.choices[0].message.content).toBe("Option A: John"); + expect(result.choices[1].message.content).toBe("Option B: John"); + }); + + test("preserves non-string content", () => { + const response: OpenAIResponse = { + id: "test-id", + object: "chat.completion", + created: 123456, + model: "gpt-4", + choices: [ + { + index: 0, + message: { role: "assistant", content: null as unknown as string }, + finish_reason: "stop", + }, + ], + }; + + const context: PlaceholderContext = { + mapping: {}, + reverseMapping: {}, + counters: {}, + }; + + const result = openaiExtractor.unmaskResponse(response, context); + + expect(result.choices[0].message.content).toBeNull(); + }); + }); +}); diff --git a/src/masking/extractors/openai.ts b/src/masking/extractors/openai.ts new file mode 100644 index 0000000..4cca57e --- /dev/null +++ b/src/masking/extractors/openai.ts @@ -0,0 +1,111 @@ +/** + * OpenAI request extractor for format-agnostic masking + * + * Extracts text content from OpenAI-format requests and responses, + * enabling the core masking service to work without knowledge of + * the specific request structure. + * + * For OpenAI, system prompts are regular messages with role "system", + * so no special handling is needed. + */ + +import { type PlaceholderContext, restorePlaceholders } from "../../masking/context"; +import type { OpenAIRequest, OpenAIResponse } from "../../providers/openai/types"; +import type { OpenAIContentPart } from "../../utils/content"; +import type { MaskedSpan, RequestExtractor, TextSpan } from "../types"; + +/** + * OpenAI request extractor + * + * Handles both string content and multimodal array content. + * System prompts are just messages with role "system". 
+ */ +export const openaiExtractor: RequestExtractor<OpenAIRequest, OpenAIResponse> = { + extractTexts(request: OpenAIRequest): TextSpan[] { + const spans: TextSpan[] = []; + + for (let msgIdx = 0; msgIdx < request.messages.length; msgIdx++) { + const msg = request.messages[msgIdx]; + + if (typeof msg.content === "string") { + spans.push({ + text: msg.content, + path: `messages[${msgIdx}].content`, + messageIndex: msgIdx, + partIndex: 0, + }); + continue; + } + + if (Array.isArray(msg.content)) { + for (let partIdx = 0; partIdx < msg.content.length; partIdx++) { + const part = msg.content[partIdx] as OpenAIContentPart; + if (part.type === "text" && typeof part.text === "string") { + spans.push({ + text: part.text, + path: `messages[${msgIdx}].content[${partIdx}].text`, + messageIndex: msgIdx, + partIndex: partIdx, + }); + } + } + } + } + + return spans; + }, + + applyMasked(request: OpenAIRequest, maskedSpans: MaskedSpan[]): OpenAIRequest { + const lookup = new Map<string, string>(); + for (const span of maskedSpans) { + lookup.set(`${span.messageIndex}:${span.partIndex}`, span.maskedText); + } + + const maskedMessages = request.messages.map((msg, msgIdx) => { + if (typeof msg.content === "string") { + const key = `${msgIdx}:0`; + const masked = lookup.get(key); + if (masked !== undefined) { + return { ...msg, content: masked }; + } + return msg; + } + + if (Array.isArray(msg.content)) { + const transformedContent = msg.content.map((part: OpenAIContentPart, partIdx: number) => { + const key = `${msgIdx}:${partIdx}`; + const masked = lookup.get(key); + if (part.type === "text" && masked !== undefined) { + return { ...part, text: masked }; + } + return part; + }); + return { ...msg, content: transformedContent }; + } + + return msg; + }); + + return { ...request, messages: maskedMessages }; + }, + + unmaskResponse( + response: OpenAIResponse, + context: PlaceholderContext, + formatValue?: (original: string) => string, + ): OpenAIResponse { + return { + ...response, + choices: response.choices.map((choice) => ({ + ...choice, + message: { + ...choice.message, + content: + typeof choice.message.content === "string" + ? restorePlaceholders(choice.message.content, context, formatValue) + : choice.message.content, + }, + })), + }; + }, +}; diff --git a/src/utils/placeholders.test.ts b/src/masking/placeholders.test.ts similarity index 100% rename from src/utils/placeholders.test.ts rename to src/masking/placeholders.test.ts diff --git a/src/utils/placeholders.ts b/src/masking/placeholders.ts similarity index 93% rename from src/utils/placeholders.ts rename to src/masking/placeholders.ts index 708c84b..149c37d 100644 --- a/src/utils/placeholders.ts +++ b/src/masking/placeholders.ts @@ -1,6 +1,5 @@ /** - * Placeholder constants for PII masking and secrets masking - * Single source of truth for all placeholder-related logic + * Placeholder constants and utilities */ export const PLACEHOLDER_DELIMITERS = { diff --git a/src/masking/service.ts b/src/masking/service.ts new file mode 100644 index 0000000..3ee35a8 --- /dev/null +++ b/src/masking/service.ts @@ -0,0 +1,148 @@ +/** + * Core masking service + * + * Provides masking operations that work on text spans.
Handles: + * - Replacing sensitive data with placeholders + * - Storing mappings for later unmasking + * - Processing streaming chunks with buffering + */ + +import type { Span } from "../masking/conflict-resolver"; +import { + createPlaceholderContext, + flushBuffer, + type PlaceholderContext, + processStreamChunk, + replaceWithPlaceholders, + restorePlaceholders, +} from "../masking/context"; +import type { MaskedSpan, TextSpan } from "./types"; + +export type { PlaceholderContext } from "../masking/context"; + +/** + * Result of masking text spans + */ +export interface MaskSpansResult { + /** Masked text spans ready to apply back to messages */ + maskedSpans: MaskedSpan[]; + /** Context for unmasking (maps placeholders to original values) */ + context: PlaceholderContext; +} + +/** + * Masks text spans using per-span entity data + * + * This is the core masking operation that: + * 1. Takes extracted text spans + * 2. Applies entity-based replacement for each span + * 3. Returns masked spans ready to be applied back to messages + * + * @param spans - Text spans extracted from messages + * @param perSpanData - Per-span entity/location data: perSpanData[spanIndex] = items + * @param getType - Function to get type string from an item + * @param generatePlaceholder - Function to generate placeholder for a type + * @param resolveConflicts - Function to resolve overlapping items + * @param context - Optional existing context (for combining PII + secrets masking) + */ +export function maskSpans<T>( + spans: TextSpan[], + perSpanData: T[][], + getType: (item: T) => string, + generatePlaceholder: (type: string, context: PlaceholderContext) => string, + resolveConflicts: (items: T[]) => T[], + context?: PlaceholderContext, ): MaskSpansResult { + const ctx = context || createPlaceholderContext(); + const maskedSpans: MaskedSpan[] = []; + + for (let i = 0; i < spans.length; i++) { + const span = spans[i]; + const items = perSpanData[i] || []; + + if (items.length === 0) { + // No items to mask, but still include the span for completeness + maskedSpans.push({ + path: span.path, + maskedText: span.text, + messageIndex: span.messageIndex, + partIndex: span.partIndex, + }); + continue; + } + + const maskedText = replaceWithPlaceholders( + span.text, + items, + ctx, + getType, + generatePlaceholder, + resolveConflicts, + ); + + maskedSpans.push({ + path: span.path, + maskedText, + messageIndex: span.messageIndex, + partIndex: span.partIndex, + }); + } + + return { maskedSpans, context: ctx }; +} + +/** + * Creates a new masking context + */ +export function createMaskingContext(): PlaceholderContext { + return createPlaceholderContext(); +} + +/** + * Unmasks text by replacing placeholders with original values + * + * @param text - Text containing placeholders + * @param context - Masking context with mappings + * @param formatValue - Optional function to format restored values + */ +export function unmask( + text: string, + context: PlaceholderContext, + formatValue?: (original: string) => string, +): string { + return restorePlaceholders(text, context, formatValue); +} + +/** + * Processes a stream chunk, buffering partial placeholders + * + * @param buffer - Previous buffer content + * @param newChunk - New chunk to process + * @param context - Placeholder context + * @param formatValue - Optional function to format restored values + */ +export function unmaskStreamChunk( + buffer: string, + newChunk: string, + context: PlaceholderContext, + formatValue?: (original: string) => string, +): { output:
string; remainingBuffer: string } { + return processStreamChunk(buffer, newChunk, context, (text, ctx) => + restorePlaceholders(text, ctx, formatValue), + ); } + +/** + * Flushes remaining buffer at end of stream + * + * @param buffer - Remaining buffer content + * @param context - Placeholder context + * @param formatValue - Optional function to format restored values + */ +export function flushMaskingBuffer( + buffer: string, + context: PlaceholderContext, + formatValue?: (original: string) => string, +): string { + return flushBuffer(buffer, context, (text, ctx) => restorePlaceholders(text, ctx, formatValue)); } diff --git a/src/masking/types.ts b/src/masking/types.ts new file mode 100644 index 0000000..e63c84d --- /dev/null +++ b/src/masking/types.ts @@ -0,0 +1,31 @@ +/** + * Masking types + */ + +import type { PlaceholderContext } from "../masking/context"; + +export interface TextSpan { + text: string; + path: string; + messageIndex: number; + partIndex: number; + nestedPartIndex?: number; +} + +export interface MaskedSpan { + path: string; + maskedText: string; + messageIndex: number; + partIndex: number; + nestedPartIndex?: number; +} + +export interface RequestExtractor<TRequest, TResponse> { + extractTexts(request: TRequest): TextSpan[]; + applyMasked(request: TRequest, maskedSpans: MaskedSpan[]): TRequest; + unmaskResponse( + response: TResponse, + context: PlaceholderContext, + formatValue?: (original: string) => string, + ): TResponse; +} diff --git a/src/pii/detect.test.ts b/src/pii/detect.test.ts index 46be2ff..9b2169b 100644 --- a/src/pii/detect.test.ts +++ b/src/pii/detect.test.ts @@ -1,4 +1,6 @@ import { afterEach, describe, expect, mock, test } from "bun:test"; +import { openaiExtractor } from "../masking/extractors/openai"; +import type { OpenAIMessage, OpenAIRequest } from "../providers/openai/types"; import { PIIDetector } from "./detect"; const originalFetch = globalThis.fetch; @@ -39,12 +41,16 @@ function mockPresidio( }) as unknown as typeof fetch; } +function createRequest(messages: OpenAIMessage[]): OpenAIRequest { + return { model: "gpt-4", messages }; +} + describe("PIIDetector", () => { afterEach(() => { globalThis.fetch = originalFetch; }); - describe("analyzeMessages", () => { + describe("analyzeRequest", () => { test("scans all message roles", async () => { mockPresidio({ "system-pii": [{ entity_type: "PERSON", start: 0, end: 10, score: 0.9 }], "user-pii": [{ entity_type: "PERSON", start: 0, end: 8, score: 0.9 }], "assistant-pii": [{ entity_type: "PERSON", start: 0, end: 13, score: 0.9 }], }); const detector = new PIIDetector(); - const messages = [ + const request = createRequest([ { role: "system", content: "system-pii here" }, { role: "user", content: "user-pii here" }, { role: "assistant", content: "assistant-pii here" }, - ]; + ]); - const result = await detector.analyzeMessages(messages); + const result = await detector.analyzeRequest(request, openaiExtractor); expect(result.hasPII).toBe(true); - // Per-message, per-part: messageEntities[msgIdx][partIdx] = entities - expect(result.messageEntities).toHaveLength(3); - // Each message has 1 part (string content) - expect(result.messageEntities[0]).toHaveLength(1); - expect(result.messageEntities[1]).toHaveLength(1); - expect(result.messageEntities[2]).toHaveLength(1); - // Each part has 1 entity - expect(result.messageEntities[0][0]).toHaveLength(1); - expect(result.messageEntities[1][0]).toHaveLength(1); - expect(result.messageEntities[2][0]).toHaveLength(1); + expect(result.spanEntities).toHaveLength(3); + expect(result.spanEntities[0]).toHaveLength(1); + expect(result.spanEntities[1]).toHaveLength(1); +
expect(result.spanEntities[2]).toHaveLength(1); }); test("detects PII in system message when user message has none", async () => { @@ -80,16 +80,16 @@ describe("PIIDetector", () => { }); const detector = new PIIDetector(); - const messages = [ + const request = createRequest([ { role: "system", content: "Context from PDF: John Doe lives at 123 Main St" }, { role: "user", content: "Extract the data into JSON" }, - ]; + ]); - const result = await detector.analyzeMessages(messages); + const result = await detector.analyzeRequest(request, openaiExtractor); expect(result.hasPII).toBe(true); - expect(result.messageEntities[0][0]).toHaveLength(1); - expect(result.messageEntities[0][0][0].entity_type).toBe("PERSON"); + expect(result.spanEntities[0]).toHaveLength(1); + expect(result.spanEntities[0][0].entity_type).toBe("PERSON"); }); test("detects PII in earlier user message", async () => { @@ -98,26 +98,28 @@ describe("PIIDetector", () => { }); const detector = new PIIDetector(); - const messages = [ + const request = createRequest([ { role: "user", content: "My email is secret@email.com" }, { role: "assistant", content: "Got it." }, { role: "user", content: "Now do something else" }, - ]; + ]); - const result = await detector.analyzeMessages(messages); + const result = await detector.analyzeRequest(request, openaiExtractor); expect(result.hasPII).toBe(true); - expect(result.messageEntities[0][0]).toHaveLength(1); + expect(result.spanEntities[0]).toHaveLength(1); }); test("returns empty result for no messages", async () => { mockPresidio({}); const detector = new PIIDetector(); - const result = await detector.analyzeMessages([]); + const request = createRequest([]); + + const result = await detector.analyzeRequest(request, openaiExtractor); expect(result.hasPII).toBe(false); - expect(result.messageEntities).toHaveLength(0); + expect(result.spanEntities).toHaveLength(0); expect(result.allEntities).toHaveLength(0); }); @@ -127,7 +129,7 @@ describe("PIIDetector", () => { }); const detector = new PIIDetector(); - const messages = [ + const request = createRequest([ { role: "user", content: [ @@ -135,17 +137,14 @@ describe("PIIDetector", () => { { type: "image_url", image_url: { url: "data:image/png;base64,..." 
} }, ], }, - ]; + ]); - const result = await detector.analyzeMessages(messages); + const result = await detector.analyzeRequest(request, openaiExtractor); expect(result.hasPII).toBe(true); - // Multimodal message has 2 parts - expect(result.messageEntities[0]).toHaveLength(2); - // First part (text) has 1 entity - expect(result.messageEntities[0][0]).toHaveLength(1); - // Second part (image) has no entities - expect(result.messageEntities[0][1]).toHaveLength(0); + // Only text parts are extracted as spans (image is skipped) + expect(result.spanEntities).toHaveLength(1); + expect(result.spanEntities[0]).toHaveLength(1); }); test("skips messages with empty content", async () => { @@ -154,17 +153,16 @@ }); const detector = new PIIDetector(); - const messages = [ + const request = createRequest([ { role: "user", content: "" }, { role: "assistant", content: "test response" }, - ]; + ]); - const result = await detector.analyzeMessages(messages); + const result = await detector.analyzeRequest(request, openaiExtractor); - expect(result.messageEntities).toHaveLength(2); - // First message (empty string) has 1 part with no entities - expect(result.messageEntities[0]).toHaveLength(1); - expect(result.messageEntities[0][0]).toHaveLength(0); + expect(result.spanEntities).toHaveLength(2); + // First message (empty string) has no entities + expect(result.spanEntities[0]).toHaveLength(0); }); }); diff --git a/src/pii/detect.ts b/src/pii/detect.ts index ae078f5..ecd9bff 100644 --- a/src/pii/detect.ts +++ b/src/pii/detect.ts @@ -1,6 +1,7 @@ import { getConfig } from "../config"; +import { HEALTH_CHECK_TIMEOUT_MS } from "../constants/timeouts"; +import type { RequestExtractor } from "../masking/types"; import { getLanguageDetector, type SupportedLanguage } from "../services/language-detector"; -import { extractTextContent, type MessageContent } from "../utils/content"; export interface PIIEntity { entity_type: string; @@ -16,15 +17,9 @@ interface AnalyzeRequest { score_threshold?: number; } -/** - * Per-message, per-part PII detection result - * Structure: messageEntities[msgIdx][partIdx] = entities for that part - */ export interface PIIDetectionResult { hasPII: boolean; - /** Per-message, per-part entities */ - messageEntities: PIIEntity[][][]; - /** Flattened list of all entities (for summary/logging) */ + spanEntities: PIIEntity[][]; allEntities: PIIEntity[]; scanTimeMs: number; language: SupportedLanguage; @@ -85,63 +80,38 @@ } /** - * Analyzes messages for PII with per-part granularity - * - * For string content, entities are in messageEntities[msgIdx][0]. - * For array content (multimodal), each text part is scanned separately. + * Analyzes a request for PII using an extractor */ - async analyzeMessages( - messages: Array<{ role: string; content: MessageContent }>, + async analyzeRequest<TRequest>( + request: TRequest, + extractor: RequestExtractor<TRequest, unknown>, ): Promise<PIIDetectionResult> { const startTime = Date.now(); const config = getConfig(); - // Detect language from the last user message - const lastUserMsg = messages.findLast((m) => m.role === "user"); - const langText = lastUserMsg ? extractTextContent(lastUserMsg.content) : ""; + // Extract all text spans from request + const spans = extractor.extractTexts(request); + + // Detect language from message content (skip system spans with messageIndex -1) + const messageSpans = spans.filter((span) => span.messageIndex >= 0); + const langText = messageSpans.map((s) => s.text).join("\n"); const langResult = langText ?
getLanguageDetector().detect(langText) : { language: config.pii_detection.fallback_language, usedFallback: true }; - const scannedRoles = ["system", "developer", "user", "assistant", "tool"]; - - // Detect PII per message, per content part - const messageEntities: PIIEntity[][][] = await Promise.all( - messages.map(async (message) => { - if (!scannedRoles.includes(message.role)) { - return []; - } - - // String content → wrap in single-element array - if (typeof message.content === "string") { - const entities = message.content - ? await this.detectPII(message.content, langResult.language) - : []; - return [entities]; - } - - // Array content (multimodal) → per-part detection - if (Array.isArray(message.content)) { - return await Promise.all( - message.content.map(async (part) => { - if (part.type === "text" && typeof part.text === "string") { - return await this.detectPII(part.text, langResult.language); - } - return []; - }), - ); - } - - // Null/undefined content - return []; + // Detect PII for each span independently + const spanEntities: PIIEntity[][] = await Promise.all( + spans.map(async (span) => { + if (!span.text) return []; + return this.detectPII(span.text, langResult.language); }), ); - const allEntities = messageEntities.flat(2); + const allEntities = spanEntities.flat(); return { hasPII: allEntities.length > 0, - messageEntities, + spanEntities, allEntities, scanTimeMs: Date.now() - startTime, language: langResult.language, @@ -154,7 +124,7 @@ export class PIIDetector { try { const response = await fetch(`${this.presidioUrl}/health`, { method: "GET", - signal: AbortSignal.timeout(5000), + signal: AbortSignal.timeout(HEALTH_CHECK_TIMEOUT_MS), }); return response.ok; } catch { @@ -199,7 +169,7 @@ export class PIIDetector { language, entities: ["PERSON"], }), - signal: AbortSignal.timeout(5000), + signal: AbortSignal.timeout(HEALTH_CHECK_TIMEOUT_MS), }); // If we get a response (even empty array), the language is supported diff --git a/src/pii/mask.test.ts b/src/pii/mask.test.ts index 7cac72c..39ba8b8 100644 --- a/src/pii/mask.test.ts +++ b/src/pii/mask.test.ts @@ -1,13 +1,14 @@ import { describe, expect, test } from "bun:test"; import type { MaskingConfig } from "../config"; -import type { ChatMessage } from "../providers/openai-client"; -import { createPIIResult } from "../test-utils/detection-results"; +import { openaiExtractor } from "../masking/extractors/openai"; +import type { OpenAIMessage, OpenAIRequest, OpenAIResponse } from "../providers/openai/types"; +import { createPIIResultFromSpans } from "../test-utils/detection-results"; import type { PIIEntity } from "./detect"; import { createMaskingContext, flushMaskingBuffer, mask, - maskMessages, + maskRequest, unmask, unmaskResponse, unmaskStreamChunk, @@ -23,6 +24,11 @@ const configWithMarkers: MaskingConfig = { marker_text: "[protected]", }; +/** Helper to create a minimal request from messages */ +function createRequest(messages: OpenAIMessage[]): OpenAIRequest { + return { model: "gpt-4", messages }; +} + describe("PII placeholder format", () => { test("uses [[TYPE_N]] format", () => { const entities: PIIEntity[] = [{ entity_type: "EMAIL_ADDRESS", start: 0, end: 16, score: 1.0 }]; @@ -83,50 +89,51 @@ describe("marker feature", () => { const context = createMaskingContext(); context.mapping["[[PERSON_1]]"] = "John Doe"; - const response = { + const response: OpenAIResponse = { id: "test", - object: "chat.completion" as const, + object: "chat.completion", created: 1234567890, model: "gpt-4", choices: [ { index: 0, - 
message: { role: "assistant" as const, content: "Hello [[PERSON_1]]" }, - finish_reason: "stop" as const, + message: { role: "assistant", content: "Hello [[PERSON_1]]" }, + finish_reason: "stop", }, ], }; - const result = unmaskResponse(response, context, configWithMarkers); + const result = unmaskResponse(response, context, configWithMarkers, openaiExtractor); expect(result.choices[0].message.content).toBe("Hello [protected]John Doe"); }); }); -describe("maskMessages with PIIDetectionResult", () => { +describe("maskRequest with PIIDetectionResult", () => { test("masks multiple messages using detection result", () => { - const messages: ChatMessage[] = [ + const request = createRequest([ { role: "user", content: "My email is test@example.com" }, { role: "assistant", content: "Got it" }, { role: "user", content: "Also john@test.com" }, - ]; + ]); - const detection = createPIIResult([ - [[{ entity_type: "EMAIL_ADDRESS", start: 12, end: 28, score: 1.0 }]], - [[]], - [[{ entity_type: "EMAIL_ADDRESS", start: 5, end: 18, score: 1.0 }]], + // spanEntities[0] = first message, [1] = second message, [2] = third message + const detection = createPIIResultFromSpans([ + [{ entity_type: "EMAIL_ADDRESS", start: 12, end: 28, score: 1.0 }], + [], + [{ entity_type: "EMAIL_ADDRESS", start: 5, end: 18, score: 1.0 }], ]); - const { masked, context } = maskMessages(messages, detection); + const { request: masked, context } = maskRequest(request, detection, openaiExtractor); - expect(masked[0].content).toBe("My email is [[EMAIL_ADDRESS_1]]"); - expect(masked[1].content).toBe("Got it"); - expect(masked[2].content).toBe("Also [[EMAIL_ADDRESS_2]]"); + expect(masked.messages[0].content).toBe("My email is [[EMAIL_ADDRESS_1]]"); + expect(masked.messages[1].content).toBe("Got it"); + expect(masked.messages[2].content).toBe("Also [[EMAIL_ADDRESS_2]]"); expect(context.mapping["[[EMAIL_ADDRESS_1]]"]).toBe("test@example.com"); expect(context.mapping["[[EMAIL_ADDRESS_2]]"]).toBe("john@test.com"); }); test("handles multimodal content", () => { - const messages: ChatMessage[] = [ + const request = createRequest([ { role: "user", content: [ @@ -134,15 +141,16 @@ describe("maskMessages with PIIDetectionResult", () => { { type: "image_url", image_url: { url: "https://example.com/img.jpg" } }, ], }, - ]; + ]); - const detection = createPIIResult([ - [[{ entity_type: "EMAIL_ADDRESS", start: 8, end: 21, score: 1.0 }], []], + // One span for the text content (image is skipped) + const detection = createPIIResultFromSpans([ + [{ entity_type: "EMAIL_ADDRESS", start: 8, end: 21, score: 1.0 }], ]); - const { masked } = maskMessages(messages, detection); + const { request: masked } = maskRequest(request, detection, openaiExtractor); - const content = masked[0].content as Array<{ type: string; text?: string }>; + const content = masked.messages[0].content as Array<{ type: string; text?: string }>; expect(content[0].text).toBe("Contact [[EMAIL_ADDRESS_1]]"); expect(content[1].type).toBe("image_url"); }); @@ -288,25 +296,25 @@ describe("unmaskResponse", () => { context.mapping["[[EMAIL_ADDRESS_1]]"] = "test@test.com"; context.mapping["[[PERSON_1]]"] = "John Doe"; - const response = { + const response: OpenAIResponse = { id: "chatcmpl-123", - object: "chat.completion" as const, + object: "chat.completion", created: 1234567890, model: "gpt-4", choices: [ { index: 0, message: { - role: "assistant" as const, + role: "assistant", content: "Contact [[PERSON_1]] at [[EMAIL_ADDRESS_1]]", }, - finish_reason: "stop" as const, + finish_reason: "stop", 
}, ], usage: { prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 }, }; - const result = unmaskResponse(response, context, defaultConfig); + const result = unmaskResponse(response, context, defaultConfig, openaiExtractor); expect(result.choices[0].message.content).toBe("Contact John Doe at test@test.com"); expect(result.id).toBe("chatcmpl-123"); diff --git a/src/pii/mask.ts b/src/pii/mask.ts index 6133e71..d318293 100644 --- a/src/pii/mask.ts +++ b/src/pii/mask.ts @@ -1,31 +1,33 @@ +/** + * PII masking + */ + import type { MaskingConfig } from "../config"; -import type { ChatCompletionResponse, ChatMessage } from "../providers/openai-client"; -import { resolveConflicts } from "../utils/conflict-resolver"; -import { - createPlaceholderContext, - flushBuffer, - incrementAndGenerate, - type MaskResult, - type PlaceholderContext, - processStreamChunk, - replaceWithPlaceholders, - restorePlaceholders, - restoreResponsePlaceholders, - transformMessagesPerPart, -} from "../utils/message-transform"; +import { resolveConflicts } from "../masking/conflict-resolver"; +import { incrementAndGenerate } from "../masking/context"; import { generatePlaceholder as generatePlaceholderFromFormat, PII_PLACEHOLDER_FORMAT, -} from "../utils/placeholders"; +} from "../masking/placeholders"; +import { + flushMaskingBuffer as flushBuffer, + type MaskSpansResult, + maskSpans, + type PlaceholderContext, + unmaskStreamChunk as unmaskChunk, + unmask as unmaskText, +} from "../masking/service"; +import type { RequestExtractor, TextSpan } from "../masking/types"; import type { PIIDetectionResult, PIIEntity } from "./detect"; -export type { MaskResult } from "../utils/message-transform"; +export { createMaskingContext, type PlaceholderContext } from "../masking/service"; /** - * Creates a new masking context for a request + * Result of masking operation */ -export function createMaskingContext(): PlaceholderContext { - return createPlaceholderContext(); +export interface MaskResult { + masked: string; + context: PlaceholderContext; } /** @@ -52,52 +54,33 @@ export function mask( entities: PIIEntity[], context?: PlaceholderContext, ): MaskResult { - const ctx = context || createMaskingContext(); - const masked = replaceWithPlaceholders( - text, - entities, - ctx, + const spans: TextSpan[] = [{ text, path: "text", messageIndex: 0, partIndex: 0 }]; + const perSpanData = [entities]; + + const result = maskSpans( + spans, + perSpanData, (e) => e.entity_type, generatePlaceholder, resolveConflicts, + context, ); - return { masked, context: ctx }; + + return { + masked: result.maskedSpans[0]?.maskedText ?? text, + context: result.context, + }; } /** * Unmasks text by replacing placeholders with original values - * - * Optionally adds markers to indicate protected content */ export function unmask(text: string, context: PlaceholderContext, config: MaskingConfig): string { - return restorePlaceholders(text, context, getFormatValue(config)); -} - -/** - * Masks messages using per-part entity detection results - * - * Uses transformMessagesPerPart for the common iteration pattern. 
- */ -export function maskMessages( - messages: ChatMessage[], - detection: PIIDetectionResult, -): { masked: ChatMessage[]; context: PlaceholderContext } { - const context = createMaskingContext(); - - const masked = transformMessagesPerPart( - messages, - detection.messageEntities, - (text, entities, ctx) => mask(text, entities, ctx).masked, - context, - ); - - return { masked, context }; + return unmaskText(text, context, getFormatValue(config)); } /** * Streaming unmask helper - processes chunks and unmasks when complete placeholders are found - * - * Returns the unmasked portion and any remaining buffer that might contain partial placeholders */ export function unmaskStreamChunk( buffer: string, @@ -105,7 +88,7 @@ context: PlaceholderContext, config: MaskingConfig, ): { output: string; remainingBuffer: string } { - return processStreamChunk(buffer, newChunk, context, (text, ctx) => unmask(text, ctx, config)); + return unmaskChunk(buffer, newChunk, context, getFormatValue(config)); } /** @@ -116,16 +99,68 @@ export function flushMaskingBuffer( context: PlaceholderContext, config: MaskingConfig, ): string { - return flushBuffer(buffer, context, (text, ctx) => unmask(text, ctx, config)); + return flushBuffer(buffer, context, getFormatValue(config)); +} + +/** + * Result of masking a request + */ +export interface MaskRequestResult<TRequest> { + /** The masked request */ + request: TRequest; + /** Masking context for unmasking response */ + context: PlaceholderContext; +} + +/** + * Masks PII in a request using an extractor + */ +export function maskRequest<TRequest>( + request: TRequest, + detection: PIIDetectionResult, + extractor: RequestExtractor<TRequest, unknown>, + existingContext?: PlaceholderContext, +): MaskRequestResult<TRequest> { + const spans = extractor.extractTexts(request); + const { maskedSpans, context } = maskSpansWithEntities( + spans, + detection.spanEntities, + existingContext, + ); + + // Filter to only spans that were actually masked + const changedSpans = maskedSpans.filter((_, i) => { + const entities = detection.spanEntities[i] || []; + return entities.length > 0; + }); + + const maskedRequest = extractor.applyMasked(request, changedSpans); + return { request: maskedRequest, context }; +} + +function maskSpansWithEntities( + spans: TextSpan[], + spanEntities: PIIEntity[][], + existingContext?: PlaceholderContext, +): MaskSpansResult { + return maskSpans( + spans, + spanEntities, + (e) => e.entity_type, + generatePlaceholder, + resolveConflicts, + existingContext, + ); } /** - * Unmasks a chat completion response by replacing placeholders in all choices + * Unmasks a response using a request extractor */ -export function unmaskResponse( - response: ChatCompletionResponse, +export function unmaskResponse<TResponse>( + response: TResponse, context: PlaceholderContext, config: MaskingConfig, -): ChatCompletionResponse { - return restoreResponsePlaceholders(response, context, getFormatValue(config)); + extractor: RequestExtractor<unknown, TResponse>, +): TResponse { + return extractor.unmaskResponse(response, context, getFormatValue(config)); } diff --git a/src/providers/errors.ts b/src/providers/errors.ts new file mode 100644 index 0000000..d9a5a1e --- /dev/null +++ b/src/providers/errors.ts @@ -0,0 +1,17 @@ +/** + * Shared provider errors + */ + +/** + * Error from upstream provider (OpenAI, etc.)
+ */ +export class ProviderError extends Error { + constructor( + public readonly status: number, + public readonly statusText: string, + public readonly body: string, + ) { + super(`Provider error: ${status} ${statusText}`); + this.name = "ProviderError"; + } +} diff --git a/src/providers/local.ts b/src/providers/local.ts new file mode 100644 index 0000000..d00fd26 --- /dev/null +++ b/src/providers/local.ts @@ -0,0 +1,76 @@ +/** + * Local provider - simple functions for forwarding to local LLM + * Used in route mode for PII-containing requests (no masking needed) + */ + +import type { LocalProviderConfig } from "../config"; +import { HEALTH_CHECK_TIMEOUT_MS, REQUEST_TIMEOUT_MS } from "../constants/timeouts"; +import { ProviderError, type ProviderResult } from "./openai/client"; +import type { OpenAIRequest } from "./openai/types"; + +/** + * Call local LLM (Ollama or OpenAI-compatible) + */ +export async function callLocal( + request: OpenAIRequest, + config: LocalProviderConfig, +): Promise<ProviderResult> { + const baseUrl = config.base_url.replace(/\/$/, ""); + const endpoint = + config.type === "ollama" ? `${baseUrl}/v1/chat/completions` : `${baseUrl}/chat/completions`; + + const headers: Record<string, string> = { "Content-Type": "application/json" }; + if (config.api_key) { + headers.Authorization = `Bearer ${config.api_key}`; + } + + const isStreaming = request.stream ?? false; + + const response = await fetch(endpoint, { + method: "POST", + headers, + body: JSON.stringify({ ...request, model: config.model, stream: isStreaming }), + signal: AbortSignal.timeout(REQUEST_TIMEOUT_MS), + }); + + if (!response.ok) { + throw new ProviderError(response.status, response.statusText, await response.text()); + } + + if (isStreaming) { + if (!response.body) { + throw new Error("No response body for streaming request"); + } + return { response: response.body, isStreaming: true, model: config.model }; + } + + return { response: await response.json(), isStreaming: false, model: config.model }; +} + +/** + * Check if local provider is reachable + */ +export async function checkLocalHealth(config: LocalProviderConfig): Promise<boolean> { + try { + const baseUrl = config.base_url.replace(/\/$/, ""); + const endpoint = config.type === "ollama" ?
`${baseUrl}/api/tags` : `${baseUrl}/models`; + + const response = await fetch(endpoint, { + method: "GET", + signal: AbortSignal.timeout(HEALTH_CHECK_TIMEOUT_MS), + }); + return response.ok; + } catch { + return false; + } +} + +/** + * Get local provider info for /info endpoint + */ +export function getLocalInfo(config: LocalProviderConfig): { type: string; baseUrl: string } { + return { + type: config.type, + baseUrl: config.base_url, + }; +} diff --git a/src/providers/openai-client.ts b/src/providers/openai-client.ts deleted file mode 100644 index a7467e4..0000000 --- a/src/providers/openai-client.ts +++ /dev/null @@ -1,200 +0,0 @@ -import type { LocalProviderConfig, OpenAIProviderConfig } from "../config"; -import type { MessageContent } from "../utils/content"; - -/** - * OpenAI-compatible message format - * Supports both text-only (content: string) and multimodal (content: array) formats - */ -export interface ChatMessage { - role: "system" | "developer" | "user" | "assistant"; - content: MessageContent; -} - -/** - * OpenAI-compatible chat completion request - * Only required field is messages - all other params pass through to provider - */ -export interface ChatCompletionRequest { - messages: ChatMessage[]; - model?: string; - stream?: boolean; - [key: string]: unknown; -} - -/** - * OpenAI-compatible chat completion response - */ -export interface ChatCompletionResponse { - id: string; - object: "chat.completion"; - created: number; - model: string; - choices: Array<{ - index: number; - message: ChatMessage; - finish_reason: "stop" | "length" | "content_filter" | null; - }>; - usage?: { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - }; -} - -/** - * Result from LLM client including metadata (Discriminated Union) - */ -export type LLMResult = - | { - isStreaming: true; - response: ReadableStream; - model: string; - provider: "openai" | "local"; - } - | { - isStreaming: false; - response: ChatCompletionResponse; - model: string; - provider: "openai" | "local"; - }; - -/** - * Error from upstream LLM provider with original status code and response - */ -export class LLMError extends Error { - constructor( - public readonly status: number, - public readonly statusText: string, - public readonly body: string, - ) { - super(`API error: ${status} ${statusText}`); - this.name = "LLMError"; - } -} - -/** - * LLM Client for OpenAI-compatible APIs (OpenAI, Ollama, etc.) - */ -export class LLMClient { - private baseUrl: string; - private apiKey?: string; - private providerType: "openai" | "ollama"; - private providerName: "openai" | "local"; - private defaultModel?: string; - - constructor( - provider: OpenAIProviderConfig | LocalProviderConfig, - providerName: "openai" | "local", - defaultModel?: string, - ) { - this.baseUrl = provider.base_url.replace(/\/$/, ""); - this.apiKey = provider.api_key; - // Configured providers (openai) always use openai protocol - // Local providers specify their type (ollama or openai-compatible) - this.providerType = "type" in provider ? 
provider.type : "openai"; - this.providerName = providerName; - this.defaultModel = defaultModel; - } - - /** - * Sends a chat completion request - * @param request The chat completion request - * @param authHeader Optional Authorization header from client (forwarded for openai provider) - */ - async chatCompletion(request: ChatCompletionRequest, authHeader?: string): Promise { - // Local uses configured model, openai uses request model - const model = this.defaultModel || request.model; - const isStreaming = request.stream ?? false; - - if (!model) { - throw new Error("Model is required in request"); - } - - // Build the endpoint URL - const endpoint = - this.providerType === "ollama" - ? `${this.baseUrl}/v1/chat/completions` - : `${this.baseUrl}/chat/completions`; - - // Build headers - const headers: Record = { - "Content-Type": "application/json", - }; - - // Use client's auth header if provided, otherwise fall back to config - if (authHeader) { - headers.Authorization = authHeader; - } else if (this.apiKey) { - headers.Authorization = `Bearer ${this.apiKey}`; - } - - // Build request body - convert max_tokens to max_completion_tokens for OpenAI - const body: Record = { - ...request, - model, - stream: isStreaming, - }; - - // OpenAI newer models use max_completion_tokens instead of max_tokens - if (this.providerType === "openai" && body.max_tokens) { - body.max_completion_tokens = body.max_tokens; - delete body.max_tokens; - } - - const response = await fetch(endpoint, { - method: "POST", - headers, - body: JSON.stringify(body), - signal: AbortSignal.timeout(120_000), // 2 minute timeout for LLM requests - }); - - if (!response.ok) { - const errorText = await response.text(); - throw new LLMError(response.status, response.statusText, errorText); - } - - if (isStreaming) { - if (!response.body) { - throw new Error("No response body for streaming request"); - } - - return { - response: response.body, - isStreaming: true, - model, - provider: this.providerName, - }; - } - - const data = (await response.json()) as ChatCompletionResponse; - return { - response: data, - isStreaming: false, - model, - provider: this.providerName, - }; - } - - /** - * Checks if the local LLM service is healthy (Ollama) - */ - async healthCheck(): Promise { - try { - const response = await fetch(`${this.baseUrl}/api/tags`, { - method: "GET", - signal: AbortSignal.timeout(5000), - }); - return response.ok; - } catch { - return false; - } - } - - getInfo(): { name: "openai" | "local"; type: "openai" | "ollama"; baseUrl: string } { - return { - name: this.providerName, - type: this.providerType, - baseUrl: this.baseUrl, - }; - } -} diff --git a/src/providers/openai/client.ts b/src/providers/openai/client.ts new file mode 100644 index 0000000..a5bf050 --- /dev/null +++ b/src/providers/openai/client.ts @@ -0,0 +1,117 @@ +/** + * OpenAI client - simple functions for OpenAI API + */ + +import type { OpenAIProviderConfig } from "../../config"; +import { HEALTH_CHECK_TIMEOUT_MS, REQUEST_TIMEOUT_MS } from "../../constants/timeouts"; +import { ProviderError } from "../errors"; +import type { OpenAIRequest, OpenAIResponse } from "./types"; + +export { ProviderError } from "../errors"; + +/** + * Result from provider (streaming or non-streaming) + */ +export type ProviderResult = + | { + isStreaming: true; + response: ReadableStream; + model: string; + } + | { + isStreaming: false; + response: OpenAIResponse; + model: string; + }; + +/** + * Call OpenAI chat completion API + */ +export async function callOpenAI( + request: 
+  request: OpenAIRequest,
+  config: OpenAIProviderConfig,
+  authHeader?: string,
+): Promise<ProviderResult> {
+  const model = request.model;
+  const isStreaming = request.stream ?? false;
+
+  if (!model) {
+    throw new Error("Model is required in request");
+  }
+
+  const baseUrl = config.base_url.replace(/\/$/, "");
+  const endpoint = `${baseUrl}/chat/completions`;
+
+  const headers: Record<string, string> = {
+    "Content-Type": "application/json",
+  };
+
+  // Use client's auth header if provided, otherwise fall back to config
+  if (authHeader) {
+    headers.Authorization = authHeader;
+  } else if (config.api_key) {
+    headers.Authorization = `Bearer ${config.api_key}`;
+  }
+
+  // Build request body
+  const body: Record<string, unknown> = {
+    ...request,
+    model,
+    stream: isStreaming,
+  };
+
+  // OpenAI newer models use max_completion_tokens instead of max_tokens
+  if (body.max_tokens) {
+    body.max_completion_tokens = body.max_tokens;
+    delete body.max_tokens;
+  }
+
+  const response = await fetch(endpoint, {
+    method: "POST",
+    headers,
+    body: JSON.stringify(body),
+    signal: AbortSignal.timeout(REQUEST_TIMEOUT_MS),
+  });
+
+  if (!response.ok) {
+    throw new ProviderError(response.status, response.statusText, await response.text());
+  }
+
+  if (isStreaming) {
+    if (!response.body) {
+      throw new Error("No response body for streaming request");
+    }
+    return { response: response.body, isStreaming: true, model };
+  }
+
+  return { response: await response.json(), isStreaming: false, model };
+}
+
+/**
+ * Check if OpenAI API is reachable
+ */
+export async function checkOpenAIHealth(config: OpenAIProviderConfig): Promise<boolean> {
+  try {
+    const baseUrl = config.base_url.replace(/\/$/, "");
+    // Use models endpoint - returns 401 if no auth, 200 if OK
+    const response = await fetch(`${baseUrl}/models`, {
+      method: "GET",
+      signal: AbortSignal.timeout(HEALTH_CHECK_TIMEOUT_MS),
+    });
+
+    // 401 means API is up but no auth - that's OK for health check
+    // 200 means API is up with valid auth
+    return response.status === 401 || response.status === 200;
+  } catch {
+    return false;
+  }
+}
+
+/**
+ * Get OpenAI provider info for /info endpoint
+ */
+export function getOpenAIInfo(config: OpenAIProviderConfig): { baseUrl: string } {
+  return {
+    baseUrl: config.base_url,
+  };
+}
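ProviderResult is a discriminated union on isStreaming, so callers narrow once and get correct types on both branches. A hedged usage sketch; the inline config literal stands in for a parsed OpenAIProviderConfig:

    import { callOpenAI } from "./src/providers/openai/client";

    const result = await callOpenAI(
      { model: "gpt-4o-mini", messages: [{ role: "user", content: "Hello" }], stream: false },
      { base_url: "https://api.openai.com/v1" }, // api_key omitted: the client's auth header wins anyway
      `Bearer ${process.env.OPENAI_API_KEY}`,
    );

    if (result.isStreaming) {
      // Narrowed to ReadableStream: hand the SSE bytes straight through.
      await result.response.pipeTo(new WritableStream());
    } else {
      // Narrowed to OpenAIResponse: a parsed chat.completion body.
      console.log(result.response.choices[0]?.message.content);
    }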
diff --git a/src/services/stream-transformer.test.ts b/src/providers/openai/stream-transformer.test.ts
similarity index 97%
rename from src/services/stream-transformer.test.ts
rename to src/providers/openai/stream-transformer.test.ts
index eba9922..df9134c 100644
--- a/src/services/stream-transformer.test.ts
+++ b/src/providers/openai/stream-transformer.test.ts
@@ -1,6 +1,6 @@
 import { describe, expect, test } from "bun:test";
-import type { MaskingConfig } from "../config";
-import { createMaskingContext } from "../pii/mask";
+import type { MaskingConfig } from "../../config";
+import { createMaskingContext } from "../../pii/mask";
 import { createUnmaskingStream } from "./stream-transformer";
 
 const defaultConfig: MaskingConfig = {
diff --git a/src/services/stream-transformer.ts b/src/providers/openai/stream-transformer.ts
similarity index 95%
rename from src/services/stream-transformer.ts
rename to src/providers/openai/stream-transformer.ts
index ea64add..b807a86 100644
--- a/src/services/stream-transformer.ts
+++ b/src/providers/openai/stream-transformer.ts
@@ -1,7 +1,7 @@
-import type { MaskingConfig } from "../config";
-import { flushMaskingBuffer, unmaskStreamChunk } from "../pii/mask";
-import { flushSecretsMaskingBuffer, unmaskSecretsStreamChunk } from "../secrets/mask";
-import type { PlaceholderContext } from "../utils/message-transform";
+import type { MaskingConfig } from "../../config";
+import type { PlaceholderContext } from "../../masking/context";
+import { flushMaskingBuffer, unmaskStreamChunk } from "../../pii/mask";
+import { flushSecretsMaskingBuffer, unmaskSecretsStreamChunk } from "../../secrets/mask";
 
 /**
  * Creates a transform stream that unmasks SSE content
diff --git a/src/providers/openai/types.ts b/src/providers/openai/types.ts
new file mode 100644
index 0000000..7500d8d
--- /dev/null
+++ b/src/providers/openai/types.ts
@@ -0,0 +1,69 @@
+/**
+ * OpenAI API Types
+ * Based on: https://platform.openai.com/docs/api-reference/chat
+ */
+
+import { z } from "zod";
+
+// Content part for multimodal messages
+export const OpenAIContentPartSchema = z.object({
+  type: z.string(),
+  text: z.string().optional(),
+  image_url: z
+    .object({
+      url: z.string(),
+      detail: z.string().optional(),
+    })
+    .optional(),
+});
+
+// Message content: string, array (multimodal), or null
+export const OpenAIMessageContentSchema = z.union([
+  z.string(),
+  z.array(OpenAIContentPartSchema),
+  z.null(),
+]);
+
+// Chat message
+export const OpenAIMessageSchema = z.object({
+  role: z.enum(["system", "developer", "user", "assistant", "tool", "function"]),
+  content: OpenAIMessageContentSchema.optional(),
+});
+
+// Chat completion request - minimal required fields, rest passthrough
+export const OpenAIRequestSchema = z
+  .object({
+    messages: z.array(OpenAIMessageSchema.passthrough()).min(1, "At least one message is required"),
+    model: z.string().optional(),
+    stream: z.boolean().optional(),
+  })
+  .passthrough();
+
+// Chat completion response
+export const OpenAIResponseSchema = z.object({
+  id: z.string(),
+  object: z.literal("chat.completion"),
+  created: z.number(),
+  model: z.string(),
+  choices: z.array(
+    z.object({
+      index: z.number(),
+      message: OpenAIMessageSchema.passthrough(),
+      finish_reason: z.enum(["stop", "length", "content_filter"]).nullable(),
+    }),
+  ),
+  usage: z
+    .object({
+      prompt_tokens: z.number(),
+      completion_tokens: z.number(),
+      total_tokens: z.number(),
+    })
+    .optional(),
+});
+
+// Inferred types
+export type OpenAIContentPart = z.infer<typeof OpenAIContentPartSchema>;
+export type OpenAIMessageContent = z.infer<typeof OpenAIMessageContentSchema>;
+export type OpenAIMessage = z.infer<typeof OpenAIMessageSchema>;
+export type OpenAIRequest = z.infer<typeof OpenAIRequestSchema>;
+export type OpenAIResponse = z.infer<typeof OpenAIResponseSchema>;
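The request schema is deliberately permissive: it pins down only messages, model, and stream and lets everything else pass through. Roughly how validation behaves:

    import { OpenAIRequestSchema } from "./src/providers/openai/types";

    // Unknown fields such as temperature survive .passthrough() untouched.
    const ok = OpenAIRequestSchema.safeParse({
      model: "gpt-4o",
      temperature: 0.2,
      messages: [{ role: "user", content: "Hi" }],
    });
    console.log(ok.success); // true

    // An empty messages array (or an unknown role) is rejected.
    const bad = OpenAIRequestSchema.safeParse({ messages: [] });
    console.log(bad.success); // false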
"up" : "down"; + } if (config.mode === "route") { - services.local_llm = health.local ? "up" : "down"; + services.local_llm = localHealth ? "up" : "down"; } return c.json( diff --git a/src/routes/info.ts b/src/routes/info.ts index 76d525c..8d80c01 100644 --- a/src/routes/info.ts +++ b/src/routes/info.ts @@ -2,27 +2,28 @@ import { Hono } from "hono"; import pkg from "../../package.json"; import { getConfig } from "../config"; import { getPIIDetector } from "../pii/detect"; -import { getRouter } from "../services/decision"; +import { getLocalInfo } from "../providers/local"; +import { getOpenAIInfo } from "../providers/openai/client"; export const infoRoutes = new Hono(); infoRoutes.get("/info", (c) => { const config = getConfig(); - const router = getRouter(); - const providers = router.getProvidersInfo(); const detector = getPIIDetector(); const languageValidation = detector.getLanguageValidation(); + const providers: Record = { + openai: { + base_url: getOpenAIInfo(config.providers.openai).baseUrl, + }, + }; + const info: Record = { name: "PasteGuard", version: pkg.version, description: "Privacy proxy for LLMs", mode: config.mode, - providers: { - openai: { - base_url: providers.openai.baseUrl, - }, - }, + providers, pii_detection: { languages: languageValidation ? { @@ -37,10 +38,11 @@ infoRoutes.get("/info", (c) => { }, }; - if (config.mode === "route" && providers.local) { + if (config.mode === "route" && config.local) { + const localInfo = getLocalInfo(config.local); info.local = { - type: providers.local.type, - base_url: providers.local.baseUrl, + type: localInfo.type, + base_url: localInfo.baseUrl, }; } diff --git a/src/routes/openai.test.ts b/src/routes/openai.test.ts index b9fe618..aaf80ac 100644 --- a/src/routes/openai.test.ts +++ b/src/routes/openai.test.ts @@ -34,6 +34,7 @@ describe("POST /openai/v1/chat/completions", () => { const res = await app.request("/openai/v1/chat/completions", { method: "POST", body: JSON.stringify({ + model: "gpt-5.2", messages: [{ role: "invalid", content: "test" }], }), headers: { "Content-Type": "application/json" }, @@ -41,30 +42,4 @@ describe("POST /openai/v1/chat/completions", () => { expect(res.status).toBe(400); }); - - test("accepts developer role (GPT-5.x compatibility)", async () => { - const res = await app.request("/openai/v1/chat/completions", { - method: "POST", - body: JSON.stringify({ - messages: [ - { role: "developer", content: "You are a helpful assistant" }, - { role: "user", content: "Hello" }, - ], - model: "gpt-5.2", - }), - headers: { "Content-Type": "application/json" }, - }); - - // Should not be 400 (validation passed) - // Will be 401/502 without API key, but that's fine - we're testing validation - expect(res.status).not.toBe(400); - }); -}); - -describe("GET /openai/v1/models", () => { - test("forwards to upstream (returns error without auth)", async () => { - const res = await app.request("/openai/v1/models"); - // Without auth, upstream returns 401 - expect([200, 401, 500, 502]).toContain(res.status); - }); }); diff --git a/src/routes/openai.ts b/src/routes/openai.ts index b49a4ec..b4e7b55 100644 --- a/src/routes/openai.ts +++ b/src/routes/openai.ts @@ -1,488 +1,416 @@ +/** + * OpenAI-compatible chat completion route + * + * Flow: + * 1. Validate request + * 2. Process secrets (detect, maybe block or mask) + * 3. Detect PII + * 4. Based on mode: + * - mask: mask PII, send to OpenAI, unmask response + * - route: send to local (if PII) or OpenAI (if clean) + * 5. 
Return response + */ + import { zValidator } from "@hono/zod-validator"; import type { Context } from "hono"; import { Hono } from "hono"; import { proxy } from "hono/proxy"; -import { z } from "zod"; import { getConfig, type MaskingConfig } from "../config"; +import type { PlaceholderContext } from "../masking/context"; +import { openaiExtractor } from "../masking/extractors/openai"; import { unmaskResponse as unmaskPIIResponse } from "../pii/mask"; +import { callLocal } from "../providers/local"; +import { callOpenAI, getOpenAIInfo, type ProviderResult } from "../providers/openai/client"; +import { createUnmaskingStream } from "../providers/openai/stream-transformer"; import { - type ChatCompletionRequest, - type ChatCompletionResponse, - type ChatMessage, - LLMError, - type LLMResult, -} from "../providers/openai-client"; -import { detectSecretsInMessages, type MessageSecretsResult } from "../secrets/detect"; -import { maskMessages as maskSecretsMessages, unmaskSecretsResponse } from "../secrets/mask"; -import { getRouter, type MaskDecision, type RoutingDecision } from "../services/decision"; -import { logRequest, type RequestLogData } from "../services/logger"; -import { createUnmaskingStream } from "../services/stream-transformer"; + type OpenAIMessage, + type OpenAIRequest, + OpenAIRequestSchema, + type OpenAIResponse, +} from "../providers/openai/types"; +import { unmaskSecretsResponse } from "../secrets/mask"; +import { logRequest } from "../services/logger"; +import { detectPII, maskPII, type PIIDetectResult } from "../services/pii"; +import { processSecretsRequest, type SecretsProcessResult } from "../services/secrets"; import { extractTextContent } from "../utils/content"; -import type { PlaceholderContext } from "../utils/message-transform"; - -// Request validation schema -const ChatCompletionSchema = z - .object({ - messages: z - .array( - z - .object({ - role: z.enum(["system", "developer", "user", "assistant", "tool", "function"]), - content: z.union([z.string(), z.array(z.any()), z.null()]).optional(), - }) - .passthrough(), // Allow additional fields like name, tool_calls, etc. - ) - .min(1, "At least one message is required"), - }) - .passthrough(); +import { + createLogData, + errorFormats, + handleProviderError, + setBlockedHeaders, + setResponseHeaders, + toPIIHeaderData, + toPIILogData, + toSecretsHeaderData, + toSecretsLogData, +} from "./utils"; export const openaiRoutes = new Hono(); /** - * Type guard for MaskDecision - */ -function isMaskDecision(decision: RoutingDecision): decision is MaskDecision { - return decision.mode === "mask"; -} - -/** - * Create log data for error responses - */ -function createErrorLogData( - body: ChatCompletionRequest, - startTime: number, - statusCode: number, - errorMessage: string, - decision?: RoutingDecision, - secretsResult?: MessageSecretsResult, - maskedContent?: string, -): RequestLogData { - const config = getConfig(); - return { - timestamp: new Date().toISOString(), - mode: decision?.mode ?? config.mode, - provider: decision?.provider ?? "openai", - model: body.model || "unknown", - piiDetected: decision?.piiResult.hasPII ?? false, - entities: decision - ? [...new Set(decision.piiResult.allEntities.map((e) => e.entity_type))] - : [], - latencyMs: Date.now() - startTime, - scanTimeMs: decision?.piiResult.scanTimeMs ?? 0, - language: decision?.piiResult.language ?? config.pii_detection.fallback_language, - languageFallback: decision?.piiResult.languageFallback ?? 
false, - detectedLanguage: decision?.piiResult.detectedLanguage, - maskedContent, - secretsDetected: secretsResult?.detected, - secretsTypes: secretsResult?.matches.map((m) => m.type), - statusCode, - errorMessage, - }; -} - -/** - * POST /v1/chat/completions - OpenAI-compatible chat completion endpoint + * POST /v1/chat/completions */ openaiRoutes.post( "/chat/completions", - zValidator("json", ChatCompletionSchema, (result, c) => { + zValidator("json", OpenAIRequestSchema, (result, c) => { if (!result.success) { return c.json( - { - error: { - message: "Invalid request body", - type: "invalid_request_error", - param: null, - code: null, - }, - }, + errorFormats.openai.error( + `Invalid request body: ${result.error.message}`, + "invalid_request_error", + ), 400, ); } }), async (c) => { const startTime = Date.now(); - let body = c.req.valid("json") as ChatCompletionRequest; + let request = c.req.valid("json") as OpenAIRequest; const config = getConfig(); - const router = getRouter(); - - // Track secrets detection state for response handling - let secretsResult: MessageSecretsResult | undefined; - let secretsMaskingContext: PlaceholderContext | undefined; - let secretsMasked = false; - - // Secrets detection runs before PII detection (per-part) - if (config.secrets_detection.enabled) { - secretsResult = detectSecretsInMessages(body.messages, config.secrets_detection); - - if (secretsResult.detected) { - const secretTypes = secretsResult.matches.map((m) => m.type); - const secretTypesStr = secretTypes.join(","); - - // Block action - return 400 error - if (config.secrets_detection.action === "block") { - c.header("X-PasteGuard-Secrets-Detected", "true"); - c.header("X-PasteGuard-Secrets-Types", secretTypesStr); - - logRequest( - { - timestamp: new Date().toISOString(), - mode: config.mode, - provider: "openai", - model: body.model || "unknown", - piiDetected: false, - entities: [], - latencyMs: Date.now() - startTime, - scanTimeMs: 0, - language: config.pii_detection.fallback_language, - languageFallback: false, - secretsDetected: true, - secretsTypes: secretTypes, - }, - c.req.header("User-Agent") || null, - ); - - return c.json( - { - error: { - message: `Request blocked: detected secret material (${secretTypesStr}). 
Remove secrets and retry.`, - type: "invalid_request_error", - param: null, - code: "secrets_detected", - }, - }, - 400, - ); - } - - // Mask action - replace secrets with placeholders (per-part) - if (config.secrets_detection.action === "mask") { - const result = maskSecretsMessages(body.messages, secretsResult); - body = { ...body, messages: result.masked }; - secretsMaskingContext = result.context; - secretsMasked = true; - } - - // route_local action is handled in handleCompletion via secretsResult - } + + // Step 1: Process secrets + const secretsResult = processSecretsRequest(request, config.secrets_detection, openaiExtractor); + + if (secretsResult.blocked) { + return respondBlocked(c, request, secretsResult, startTime); } - let decision: RoutingDecision; - try { - decision = await router.decide(body.messages, secretsResult); - } catch (error) { - console.error("PII detection error:", error); - const errorMessage = "PII detection service unavailable"; - logRequest( - createErrorLogData(body, startTime, 503, errorMessage, undefined, secretsResult), - c.req.header("User-Agent") || null, - ); + // Apply secrets masking to request + if (secretsResult.masked) { + request = secretsResult.request; + } - return c.json( - { - error: { - message: errorMessage, - type: "server_error", - param: null, - code: "service_unavailable", - }, + // Step 2: Detect PII (skip if disabled) + let piiResult: PIIDetectResult; + if (!config.pii_detection.enabled) { + piiResult = { + detection: { + hasPII: false, + spanEntities: [], + allEntities: [], + scanTimeMs: 0, + language: "en", + languageFallback: false, }, - 503, - ); + hasPII: false, + }; + } else { + try { + piiResult = await detectPII(request, openaiExtractor); + } catch (error) { + console.error("PII detection error:", error); + return respondDetectionError(c, request, startTime); + } } - return handleCompletion( - c, - body, - decision, - startTime, - router, + // Step 3: Process based on mode + if (config.mode === "mask") { + const piiMasked = maskPII(request, piiResult.detection, openaiExtractor); + return sendToOpenAI(c, request, { + request: piiMasked.request, + piiResult, + piiMaskingContext: piiMasked.maskingContext, + secretsResult, + startTime, + authHeader: c.req.header("Authorization"), + }); + } + + // Route mode: send to local if PII/secrets detected, otherwise OpenAI + const shouldRouteLocal = + piiResult.hasPII || + (secretsResult.detection?.detected && config.secrets_detection.action === "route_local"); + + if (shouldRouteLocal) { + return sendToLocal(c, request, { + request, + piiResult, + secretsResult, + startTime, + }); + } + + return sendToOpenAI(c, request, { + request, + piiResult, secretsResult, - secretsMaskingContext, - secretsMasked, - ); + startTime, + authHeader: c.req.header("Authorization"), + }); }, ); /** - * Handle chat completion for both route and mask modes + * Wildcard proxy for /models, /embeddings, /audio/*, /images/*, etc. 
*/ -async function handleCompletion( +openaiRoutes.all("/*", (c) => { + const config = getConfig(); + const { baseUrl } = getOpenAIInfo(config.providers.openai); + const path = c.req.path.replace(/^\/openai\/v1/, ""); + + return proxy(`${baseUrl}${path}`, { + ...c.req, + headers: { + "Content-Type": c.req.header("Content-Type"), + Authorization: c.req.header("Authorization"), + }, + }); +}); + +// --- Types --- + +interface OpenAIOptions { + request: OpenAIRequest; + piiResult: PIIDetectResult; + piiMaskingContext?: PlaceholderContext; + secretsResult: SecretsProcessResult; + startTime: number; + authHeader?: string; +} + +interface LocalOptions { + request: OpenAIRequest; + piiResult: PIIDetectResult; + secretsResult: SecretsProcessResult; + startTime: number; +} + +// --- Helpers --- + +function formatMessagesForLog(messages: OpenAIMessage[]): string { + return messages + .map((m) => { + const text = extractTextContent(m.content); + const isMultimodal = Array.isArray(m.content); + return `[${m.role}${isMultimodal ? " multimodal" : ""}] ${text}`; + }) + .join("\n"); +} + +// --- Response handlers --- + +function respondBlocked( c: Context, - body: ChatCompletionRequest, - decision: RoutingDecision, + body: OpenAIRequest, + secretsResult: SecretsProcessResult, startTime: number, - router: ReturnType, - secretsResult?: MessageSecretsResult, - secretsMaskingContext?: PlaceholderContext, - secretsMasked?: boolean, ) { - const client = router.getClient(decision.provider); - const maskingConfig = router.getMaskingConfig(); - const authHeader = decision.provider === "openai" ? c.req.header("Authorization") : undefined; + const secretTypes = secretsResult.blockedTypes ?? []; - // Prepare request and masked content for logging - let request: ChatCompletionRequest = body; - let maskedContent: string | undefined; + setBlockedHeaders(c, secretTypes); - if (isMaskDecision(decision)) { - request = { ...body, messages: decision.maskedMessages }; - maskedContent = formatMessagesForLog(decision.maskedMessages); - } + logRequest( + createLogData({ + provider: "openai", + model: body.model || "unknown", + startTime, + secrets: { detected: true, matches: secretTypes.map((t) => ({ type: t })), masked: false }, + statusCode: 400, + errorMessage: secretsResult.blockedReason, + }), + c.req.header("User-Agent") || null, + ); - // Determine secrets state - const secretsDetected = secretsResult?.detected ?? false; - const secretsTypes = secretsResult?.matches.map((m) => m.type) ?? []; - - // Set response headers (included automatically by c.json/c.body) - c.header("X-PasteGuard-Mode", decision.mode); - c.header("X-PasteGuard-Provider", decision.provider); - c.header("X-PasteGuard-PII-Detected", decision.piiResult.hasPII.toString()); - c.header("X-PasteGuard-Language", decision.piiResult.language); - if (decision.piiResult.languageFallback) { - c.header("X-PasteGuard-Language-Fallback", "true"); - } - if (decision.mode === "mask") { - c.header("X-PasteGuard-PII-Masked", decision.piiResult.hasPII.toString()); - } - if (secretsDetected && secretsTypes.length > 0) { - c.header("X-PasteGuard-Secrets-Detected", "true"); - c.header("X-PasteGuard-Secrets-Types", secretsTypes.join(",")); - } - if (secretsMasked) { - c.header("X-PasteGuard-Secrets-Masked", "true"); - } + return c.json( + errorFormats.openai.error( + `Request blocked: detected secret material (${secretTypes.join(",")}). 
Remove secrets and retry.`, + "invalid_request_error", + "secrets_detected", + ), + 400, + ); +} + +function respondDetectionError(c: Context, body: OpenAIRequest, startTime: number) { + logRequest( + createLogData({ + provider: "openai", + model: body.model || "unknown", + startTime, + statusCode: 503, + errorMessage: "Detection service unavailable", + }), + c.req.header("User-Agent") || null, + ); + + return c.json( + errorFormats.openai.error( + "Detection service unavailable", + "server_error", + "service_unavailable", + ), + 503, + ); +} + +// --- Provider handlers --- + +async function sendToOpenAI(c: Context, originalRequest: OpenAIRequest, opts: OpenAIOptions) { + const config = getConfig(); + const { request, piiResult, piiMaskingContext, secretsResult, startTime, authHeader } = opts; + + const maskedContent = + piiResult.hasPII || secretsResult.masked ? formatMessagesForLog(request.messages) : undefined; + + setResponseHeaders( + c, + config.mode, + "openai", + toPIIHeaderData(piiResult), + toSecretsHeaderData(secretsResult), + ); try { - const result = await client.chatCompletion(request, authHeader); + const result = await callOpenAI(request, config.providers.openai, authHeader); + + logRequest( + createLogData({ + provider: "openai", + model: result.model || originalRequest.model || "unknown", + startTime, + pii: toPIILogData(piiResult), + secrets: toSecretsLogData(secretsResult), + maskedContent, + }), + c.req.header("User-Agent") || null, + ); if (result.isStreaming) { - return handleStreamingResponse( + return respondStreaming( c, result, - decision, - startTime, - maskedContent, - maskingConfig, - secretsDetected, - secretsTypes, - secretsMaskingContext, + piiMaskingContext, + secretsResult.maskingContext, + config.masking, ); } - return handleJsonResponse( + return respondJson( c, - result, - decision, - startTime, - maskedContent, - maskingConfig, - secretsDetected, - secretsTypes, - secretsMaskingContext, + result.response, + piiMaskingContext, + secretsResult.maskingContext, + config.masking, ); } catch (error) { - console.error("LLM request error:", error); - - // Pass through upstream LLM errors with original status code - if (error instanceof LLMError) { - logRequest( - createErrorLogData( - body, - startTime, - error.status, - error.message, - decision, - secretsResult, - maskedContent, - ), - c.req.header("User-Agent") || null, - ); + return handleProviderError( + c, + error, + { + provider: "openai", + model: originalRequest.model || "unknown", + startTime, + pii: toPIILogData(piiResult), + secrets: toSecretsLogData(secretsResult), + maskedContent, + userAgent: c.req.header("User-Agent") || null, + }, + (msg) => errorFormats.openai.error(msg, "server_error", "upstream_error"), + ); + } +} - // Pass through upstream error - must use Response for dynamic status code - return new Response(error.body, { - status: error.status, - headers: c.res.headers, - }); - } +async function sendToLocal(c: Context, originalRequest: OpenAIRequest, opts: LocalOptions) { + const config = getConfig(); + const { request, piiResult, secretsResult, startTime } = opts; + + if (!config.local) { + throw new Error("Local provider not configured"); + } + + const maskedContent = + piiResult.hasPII || secretsResult.masked ? 
formatMessagesForLog(request.messages) : undefined; + + setResponseHeaders( + c, + config.mode, + "local", + toPIIHeaderData(piiResult), + toSecretsHeaderData(secretsResult), + ); + + try { + const result = await callLocal(request, config.local); - // For other errors (network, timeout, etc.), return 502 in OpenAI-compatible format - const message = error instanceof Error ? error.message : "Unknown error"; - const errorMessage = `Provider error: ${message}`; logRequest( - createErrorLogData( - body, + createLogData({ + provider: "local", + model: result.model || originalRequest.model || "unknown", startTime, - 502, - errorMessage, - decision, - secretsResult, + pii: toPIILogData(piiResult), + secrets: toSecretsLogData(secretsResult), maskedContent, - ), + }), c.req.header("User-Agent") || null, ); - return c.json( + if (result.isStreaming) { + c.header("Content-Type", "text/event-stream"); + c.header("Cache-Control", "no-cache"); + c.header("Connection", "keep-alive"); + return c.body(result.response as ReadableStream); + } + + return c.json(result.response); + } catch (error) { + return handleProviderError( + c, + error, { - error: { - message: errorMessage, - type: "server_error", - param: null, - code: "upstream_error", - }, + provider: "local", + model: originalRequest.model || "unknown", + startTime, + pii: toPIILogData(piiResult), + secrets: toSecretsLogData(secretsResult), + maskedContent, + userAgent: c.req.header("User-Agent") || null, }, - 502, + (msg) => errorFormats.openai.error(msg, "server_error", "upstream_error"), ); } } -/** - * Handle streaming response - */ -function handleStreamingResponse( +// --- Response formatters --- + +function respondStreaming( c: Context, - result: LLMResult & { isStreaming: true }, - decision: RoutingDecision, - startTime: number, - maskedContent: string | undefined, - maskingConfig: MaskingConfig, - secretsDetected?: boolean, - secretsTypes?: string[], - secretsMaskingContext?: PlaceholderContext, + result: ProviderResult & { isStreaming: true }, + piiContext?: PlaceholderContext, + secretsContext?: PlaceholderContext, + maskingConfig?: MaskingConfig, ) { - logRequest( - createLogData( - decision, - result, - startTime, - undefined, - maskedContent, - secretsDetected, - secretsTypes, - ), - c.req.header("User-Agent") || null, - ); - c.header("Content-Type", "text/event-stream"); c.header("Cache-Control", "no-cache"); c.header("Connection", "keep-alive"); - // Determine if we need to transform the stream - const needsPIIUnmasking = isMaskDecision(decision); - const needsSecretsUnmasking = secretsMaskingContext !== undefined; - - if (needsPIIUnmasking || needsSecretsUnmasking) { - const unmaskingStream = createUnmaskingStream( + if (piiContext || secretsContext) { + const stream = createUnmaskingStream( result.response, - needsPIIUnmasking ? 
decision.maskingContext : undefined, - maskingConfig, - secretsMaskingContext, + piiContext, + maskingConfig!, + secretsContext, ); - return c.body(unmaskingStream); + return c.body(stream); } return c.body(result.response); } -/** - * Handle JSON response - */ -function handleJsonResponse( +function respondJson( c: Context, - result: LLMResult & { isStreaming: false }, - decision: RoutingDecision, - startTime: number, - maskedContent: string | undefined, - maskingConfig: MaskingConfig, - secretsDetected?: boolean, - secretsTypes?: string[], - secretsMaskingContext?: PlaceholderContext, + response: OpenAIResponse, + piiContext?: PlaceholderContext, + secretsContext?: PlaceholderContext, + maskingConfig?: MaskingConfig, ) { - logRequest( - createLogData( - decision, - result, - startTime, - result.response, - maskedContent, - secretsDetected, - secretsTypes, - ), - c.req.header("User-Agent") || null, - ); - - let response = result.response; + let result = response; - // First unmask PII if needed - if (isMaskDecision(decision)) { - response = unmaskPIIResponse(response, decision.maskingContext, maskingConfig); + if (piiContext) { + result = unmaskPIIResponse(result, piiContext, maskingConfig!, openaiExtractor); } - - // Then unmask secrets if needed - if (secretsMaskingContext) { - response = unmaskSecretsResponse(response, secretsMaskingContext); + if (secretsContext) { + result = unmaskSecretsResponse(result, secretsContext, openaiExtractor); } - return c.json(response); -} - -/** - * Create log data from decision and result - */ -function createLogData( - decision: RoutingDecision, - result: LLMResult, - startTime: number, - response?: ChatCompletionResponse, - maskedContent?: string, - secretsDetected?: boolean, - secretsTypes?: string[], -): RequestLogData { - return { - timestamp: new Date().toISOString(), - mode: decision.mode, - provider: decision.provider, - model: result.model, - piiDetected: decision.piiResult.hasPII, - entities: [...new Set(decision.piiResult.allEntities.map((e) => e.entity_type))], - latencyMs: Date.now() - startTime, - scanTimeMs: decision.piiResult.scanTimeMs, - promptTokens: response?.usage?.prompt_tokens, - completionTokens: response?.usage?.completion_tokens, - language: decision.piiResult.language, - languageFallback: decision.piiResult.languageFallback, - detectedLanguage: decision.piiResult.detectedLanguage, - maskedContent, - secretsDetected, - secretsTypes, - }; -} - -/** - * Format messages for logging - */ -function formatMessagesForLog(messages: ChatMessage[]): string { - return messages - .map((m) => { - const text = extractTextContent(m.content); - const isMultimodal = Array.isArray(m.content); - return `[${m.role}${isMultimodal ? " multimodal" : ""}] ${text}`; - }) - .join("\n"); + return c.json(result); } - -/** - * Wildcard proxy for /models, /embeddings, /audio/*, /images/*, etc. - */ -openaiRoutes.all("/*", (c) => { - const { openai } = getRouter().getProvidersInfo(); - const path = c.req.path.replace(/^\/openai\/v1/, ""); - - return proxy(`${openai.baseUrl}${path}`, { - ...c.req, - headers: { - "Content-Type": c.req.header("Content-Type"), - Authorization: c.req.header("Authorization"), - }, - }); -}); diff --git a/src/routes/utils.ts b/src/routes/utils.ts new file mode 100644 index 0000000..4f1e891 --- /dev/null +++ b/src/routes/utils.ts @@ -0,0 +1,293 @@ +/** + * Shared route utilities + * + * Common utilities for route handlers including error formatting, + * response headers, and logging helpers. 
+ */ + +import type { Context } from "hono"; +import { getConfig } from "../config"; +import { ProviderError } from "../providers/errors"; +import type { RequestLogData } from "../services/logger"; +import { logRequest } from "../services/logger"; +import type { PIIDetectResult } from "../services/pii"; +import type { SecretsProcessResult } from "../services/secrets"; + +// ============================================================================ +// Error Response Types & Formatting +// ============================================================================ + +/** + * Error response format for OpenAI + */ +export interface OpenAIErrorResponse { + error: { + message: string; + type: "invalid_request_error" | "server_error"; + param: null; + code: string | null; + }; +} + +/** + * Format adapters for different API schemas + */ +export const errorFormats = { + openai: { + error( + message: string, + type: "invalid_request_error" | "server_error", + code?: string, + ): OpenAIErrorResponse { + return { + error: { + message, + type, + param: null, + code: code ?? null, + }, + }; + }, + }, +}; + +// ============================================================================ +// Response Headers +// ============================================================================ + +export interface PIIHeaderData { + hasPII: boolean; + language: string; + languageFallback: boolean; +} + +export interface SecretsHeaderData { + detected: boolean; + types: string[]; + masked: boolean; +} + +/** + * Set common PasteGuard response headers + */ +export function setResponseHeaders( + c: Context, + mode: string, + provider: string, + pii: PIIHeaderData, + secrets?: SecretsHeaderData, +): void { + c.header("X-PasteGuard-Mode", mode); + c.header("X-PasteGuard-Provider", provider); + c.header("X-PasteGuard-PII-Detected", pii.hasPII.toString()); + c.header("X-PasteGuard-Language", pii.language); + + if (pii.languageFallback) { + c.header("X-PasteGuard-Language-Fallback", "true"); + } + if (mode === "mask" && pii.hasPII) { + c.header("X-PasteGuard-PII-Masked", "true"); + } + if (secrets?.detected) { + c.header("X-PasteGuard-Secrets-Detected", "true"); + c.header("X-PasteGuard-Secrets-Types", secrets.types.join(",")); + } + if (secrets?.masked) { + c.header("X-PasteGuard-Secrets-Masked", "true"); + } +} + +/** + * Set headers for blocked request (secrets detected) + */ +export function setBlockedHeaders(c: Context, secretTypes: string[]): void { + c.header("X-PasteGuard-Secrets-Detected", "true"); + c.header("X-PasteGuard-Secrets-Types", secretTypes.join(",")); +} + +// ============================================================================ +// Logging Helpers +// ============================================================================ + +/** + * PII detection result for logging + */ +export interface PIILogData { + hasPII: boolean; + allEntities: { entity_type: string }[]; + language: string; + languageFallback: boolean; + detectedLanguage?: string; + scanTimeMs: number; +} + +/** + * Secrets detection result for logging + */ +export interface SecretsLogData { + detected?: boolean; + matches?: { type: string }[]; + masked: boolean; +} + +/** + * Convert PIIDetectResult to PIILogData + */ +export function toPIILogData(piiResult: PIIDetectResult): PIILogData { + return { + hasPII: piiResult.hasPII, + allEntities: piiResult.detection.allEntities, + language: piiResult.detection.language, + languageFallback: piiResult.detection.languageFallback, + detectedLanguage: 
piiResult.detection.detectedLanguage, + scanTimeMs: piiResult.detection.scanTimeMs, + }; +} + +/** + * Convert PIIDetectResult to PIIHeaderData + */ +export function toPIIHeaderData(piiResult: PIIDetectResult): PIIHeaderData { + return { + hasPII: piiResult.hasPII, + language: piiResult.detection.language, + languageFallback: piiResult.detection.languageFallback, + }; +} + +/** + * Convert SecretsProcessResult to SecretsLogData + */ +export function toSecretsLogData( + secretsResult: SecretsProcessResult, +): SecretsLogData | undefined { + if (!secretsResult.detection) return undefined; + return { + detected: secretsResult.detection.detected, + matches: secretsResult.detection.matches, + masked: secretsResult.masked, + }; +} + +/** + * Convert SecretsProcessResult to SecretsHeaderData + */ +export function toSecretsHeaderData( + secretsResult: SecretsProcessResult, +): SecretsHeaderData | undefined { + if (!secretsResult.detection?.detected) return undefined; + return { + detected: true, + types: secretsResult.detection.matches.map((m) => m.type), + masked: secretsResult.masked, + }; +} + +export interface CreateLogDataOptions { + provider: "openai" | "local"; + model: string; + startTime: number; + pii?: PIILogData; + secrets?: SecretsLogData; + maskedContent?: string; + statusCode?: number; + errorMessage?: string; +} + +/** + * Create log data object for request logging + */ +export function createLogData(options: CreateLogDataOptions): RequestLogData { + const config = getConfig(); + const { provider, model, startTime, pii, secrets, maskedContent, statusCode, errorMessage } = + options; + + return { + timestamp: new Date().toISOString(), + mode: config.mode, + provider, + model: model || "unknown", + piiDetected: pii?.hasPII ?? false, + entities: pii ? [...new Set(pii.allEntities.map((e) => e.entity_type))] : [], + latencyMs: Date.now() - startTime, + scanTimeMs: pii?.scanTimeMs ?? 0, + language: pii?.language ?? config.pii_detection.fallback_language, + languageFallback: pii?.languageFallback ?? false, + detectedLanguage: pii?.detectedLanguage, + maskedContent, + secretsDetected: secrets?.detected, + secretsTypes: secrets?.matches?.map((m) => m.type), + statusCode, + errorMessage, + }; +} + +// ============================================================================ +// Provider Error Handling +// ============================================================================ + +export interface ProviderErrorContext { + provider: "openai" | "local"; + model: string; + startTime: number; + pii?: PIILogData; + secrets?: SecretsLogData; + maskedContent?: string; + userAgent: string | null; +} + +/** + * Handle provider errors with logging + * + * Returns the appropriate response for the error type. + * For ProviderError, returns the original error body. + * For other errors, returns a formatted error response. + */ +export function handleProviderError( + c: Context, + error: unknown, + ctx: ProviderErrorContext, + formatError: (message: string) => object, +): Response { + console.error(`${ctx.provider} request error:`, error); + + if (error instanceof ProviderError) { + logRequest( + createLogData({ + provider: ctx.provider, + model: ctx.model, + startTime: ctx.startTime, + pii: ctx.pii, + secrets: ctx.secrets, + maskedContent: ctx.maskedContent, + statusCode: error.status, + errorMessage: error.message, + }), + ctx.userAgent, + ); + + return new Response(error.body, { + status: error.status, + headers: c.res.headers, + }); + } + + const message = error instanceof Error ? 
error.message : "Unknown error"; + const errorMessage = `Provider error: ${message}`; + + logRequest( + createLogData({ + provider: ctx.provider, + model: ctx.model, + startTime: ctx.startTime, + pii: ctx.pii, + secrets: ctx.secrets, + maskedContent: ctx.maskedContent, + statusCode: 502, + errorMessage, + }), + ctx.userAgent, + ); + + return c.json(formatError(errorMessage), 502); +} diff --git a/src/secrets/detect.ts b/src/secrets/detect.ts index dfe106e..35bfc05 100644 --- a/src/secrets/detect.ts +++ b/src/secrets/detect.ts @@ -1,6 +1,5 @@ import type { SecretsDetectionConfig } from "../config"; -import type { ChatMessage } from "../providers/openai-client"; -import type { ContentPart } from "../utils/content"; +import type { RequestExtractor, TextSpan } from "../masking/types"; import { patternDetectors } from "./patterns"; import type { MessageSecretsResult, @@ -69,64 +68,53 @@ export function detectSecrets( } /** - * Detects secrets in chat messages with per-part granularity - * - * For string content, partIdx is always 0. - * For array content (multimodal), each text part is scanned separately. - * This avoids complex offset mapping when applying masks. + * Detects secrets in a request using an extractor */ -export function detectSecretsInMessages( - messages: ChatMessage[], +export function detectSecretsInRequest( + request: TRequest, + config: SecretsDetectionConfig, + extractor: RequestExtractor, +): MessageSecretsResult { + const spans = extractor.extractTexts(request); + return detectSecretsInSpans(spans, config); +} + +/** + * Detects secrets in text spans (low-level) + */ +export function detectSecretsInSpans( + spans: TextSpan[], config: SecretsDetectionConfig, ): MessageSecretsResult { if (!config.enabled) { return { detected: false, matches: [], - messageLocations: messages.map(() => []), + spanLocations: spans.map(() => []), }; } + // Detect secrets in each span const matchCounts = new Map(); - - const messageLocations: SecretLocation[][][] = messages.map((message) => { - // String content → single part at index 0 - if (typeof message.content === "string") { - const result = detectSecrets(message.content, config); - for (const match of result.matches) { - matchCounts.set(match.type, (matchCounts.get(match.type) || 0) + match.count); - } - return [result.locations || []]; + const spanLocations: SecretLocation[][] = spans.map((span) => { + const result = detectSecrets(span.text, config); + for (const match of result.matches) { + matchCounts.set(match.type, (matchCounts.get(match.type) || 0) + match.count); } - - // Array content (multimodal) → one array per part - if (Array.isArray(message.content)) { - return message.content.map((part: ContentPart) => { - if (part.type !== "text" || typeof part.text !== "string") { - return []; - } - const result = detectSecrets(part.text, config); - for (const match of result.matches) { - matchCounts.set(match.type, (matchCounts.get(match.type) || 0) + match.count); - } - return result.locations || []; - }); - } - - // Null/undefined content - return []; + return result.locations || []; }); + // Build matches array const allMatches: SecretsMatch[] = []; for (const [type, count] of matchCounts) { allMatches.push({ type: type as SecretLocation["type"], count }); } - const hasLocations = messageLocations.some((msg) => msg.some((part) => part.length > 0)); + const hasLocations = spanLocations.some((locs) => locs.length > 0); return { detected: hasLocations, matches: allMatches, - messageLocations, + spanLocations, }; } diff --git 
diff --git a/src/secrets/mask.test.ts b/src/secrets/mask.test.ts
index e58cb46..2fc7e3d 100644
--- a/src/secrets/mask.test.ts
+++ b/src/secrets/mask.test.ts
@@ -1,10 +1,12 @@
 import { describe, expect, test } from "bun:test";
-import { createSecretsResult } from "../test-utils/detection-results";
+import { openaiExtractor } from "../masking/extractors/openai";
+import type { OpenAIMessage, OpenAIRequest, OpenAIResponse } from "../providers/openai/types";
+import { createSecretsResultFromSpans } from "../test-utils/detection-results";
 import type { SecretLocation } from "./detect";
 import {
   createSecretsMaskingContext,
   flushSecretsMaskingBuffer,
-  maskMessages,
+  maskRequest,
   maskSecrets,
   unmaskSecrets,
   unmaskSecretsResponse,
@@ -13,6 +15,11 @@ import {
 
 const sampleSecret = "sk-proj-abc123def456ghi789jkl012mno345pqr678stu901vwx";
 
+/** Helper to create a minimal request from messages */
+function createRequest(messages: OpenAIMessage[]): OpenAIRequest {
+  return { model: "gpt-4", messages };
+}
+
 describe("secrets placeholder format", () => {
   test("uses [[SECRET_MASKED_TYPE_N]] format", () => {
     const text = `My API key is ${sampleSecret}`;
@@ -59,59 +66,61 @@
   });
 });
 
-describe("maskMessages with MessageSecretsResult", () => {
+describe("maskRequest with MessageSecretsResult", () => {
   test("masks secrets in multiple messages", () => {
-    const messages = [
-      { role: "user" as const, content: `My key is ${sampleSecret}` },
-      { role: "assistant" as const, content: "I'll help you with that." },
-    ];
-    const detection = createSecretsResult([
-      [[{ start: 10, end: 10 + sampleSecret.length, type: "API_KEY_OPENAI" }]],
-      [[]],
+    const request = createRequest([
+      { role: "user", content: `My key is ${sampleSecret}` },
+      { role: "assistant", content: "I'll help you with that."
}, + ]); + // spanLocations[0] = first message (user), spanLocations[1] = second message (assistant) + const detection = createSecretsResultFromSpans([ + [{ start: 10, end: 10 + sampleSecret.length, type: "API_KEY_OPENAI" }], + [], ]); - const { masked, context } = maskMessages(messages, detection); + const { masked, context } = maskRequest(request, detection, openaiExtractor); - expect(masked[0].content).toContain("[[SECRET_MASKED_API_KEY_OPENAI_1]]"); - expect(masked[0].content).not.toContain(sampleSecret); - expect(masked[1].content).toBe("I'll help you with that."); + expect(masked.messages[0].content).toContain("[[SECRET_MASKED_API_KEY_OPENAI_1]]"); + expect(masked.messages[0].content).not.toContain(sampleSecret); + expect(masked.messages[1].content).toBe("I'll help you with that."); expect(Object.keys(context.mapping)).toHaveLength(1); }); test("shares context across messages - same secret gets same placeholder", () => { - const messages = [ - { role: "user" as const, content: `Key1: ${sampleSecret}` }, - { role: "user" as const, content: `Key2: ${sampleSecret}` }, - ]; - const detection = createSecretsResult([ - [[{ start: 6, end: 6 + sampleSecret.length, type: "API_KEY_OPENAI" }]], - [[{ start: 6, end: 6 + sampleSecret.length, type: "API_KEY_OPENAI" }]], + const request = createRequest([ + { role: "user", content: `Key1: ${sampleSecret}` }, + { role: "user", content: `Key2: ${sampleSecret}` }, + ]); + const detection = createSecretsResultFromSpans([ + [{ start: 6, end: 6 + sampleSecret.length, type: "API_KEY_OPENAI" }], + [{ start: 6, end: 6 + sampleSecret.length, type: "API_KEY_OPENAI" }], ]); - const { masked, context } = maskMessages(messages, detection); + const { masked, context } = maskRequest(request, detection, openaiExtractor); - expect(masked[0].content).toBe("Key1: [[SECRET_MASKED_API_KEY_OPENAI_1]]"); - expect(masked[1].content).toBe("Key2: [[SECRET_MASKED_API_KEY_OPENAI_1]]"); + expect(masked.messages[0].content).toBe("Key1: [[SECRET_MASKED_API_KEY_OPENAI_1]]"); + expect(masked.messages[1].content).toBe("Key2: [[SECRET_MASKED_API_KEY_OPENAI_1]]"); expect(Object.keys(context.mapping)).toHaveLength(1); }); test("handles multimodal array content", () => { - const messages = [ + const request = createRequest([ { - role: "user" as const, + role: "user", content: [ { type: "text", text: `Key: ${sampleSecret}` }, { type: "image_url", image_url: { url: "https://example.com/img.jpg" } }, ], }, - ]; - const detection = createSecretsResult([ - [[{ start: 5, end: 5 + sampleSecret.length, type: "API_KEY_OPENAI" }], []], + ]); + // Two spans: text content at index 0, image is skipped + const detection = createSecretsResultFromSpans([ + [{ start: 5, end: 5 + sampleSecret.length, type: "API_KEY_OPENAI" }], ]); - const { masked } = maskMessages(messages, detection); + const { masked } = maskRequest(request, detection, openaiExtractor); - const content = masked[0].content as Array<{ type: string; text?: string }>; + const content = masked.messages[0].content as Array<{ type: string; text?: string }>; expect(content[0].text).toBe("Key: [[SECRET_MASKED_API_KEY_OPENAI_1]]"); expect(content[1].type).toBe("image_url"); }); @@ -179,45 +188,45 @@ describe("unmaskSecretsResponse", () => { const context = createSecretsMaskingContext(); context.mapping["[[SECRET_MASKED_API_KEY_OPENAI_1]]"] = sampleSecret; - const response = { + const response: OpenAIResponse = { id: "test", - object: "chat.completion" as const, + object: "chat.completion", created: Date.now(), model: "gpt-4", choices: [ { index: 
0, message: { - role: "assistant" as const, + role: "assistant", content: "Your key is [[SECRET_MASKED_API_KEY_OPENAI_1]]", }, - finish_reason: "stop" as const, + finish_reason: "stop", }, ], }; - const result = unmaskSecretsResponse(response, context); + const result = unmaskSecretsResponse(response, context, openaiExtractor); expect(result.choices[0].message.content).toBe(`Your key is ${sampleSecret}`); }); test("preserves response structure", () => { const context = createSecretsMaskingContext(); - const response = { + const response: OpenAIResponse = { id: "test-id", - object: "chat.completion" as const, + object: "chat.completion", created: 12345, model: "gpt-4-turbo", choices: [ { index: 0, - message: { role: "assistant" as const, content: "Hello" }, - finish_reason: "stop" as const, + message: { role: "assistant", content: "Hello" }, + finish_reason: "stop", }, ], usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, }; - const result = unmaskSecretsResponse(response, context); + const result = unmaskSecretsResponse(response, context, openaiExtractor); expect(result.id).toBe("test-id"); expect(result.model).toBe("gpt-4-turbo"); expect(result.usage).toEqual({ prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }); diff --git a/src/secrets/mask.ts b/src/secrets/mask.ts index 2d3139c..f905a6e 100644 --- a/src/secrets/mask.ts +++ b/src/secrets/mask.ts @@ -1,33 +1,36 @@ -import type { ChatCompletionResponse, ChatMessage } from "../providers/openai-client"; -import { resolveOverlaps } from "../utils/conflict-resolver"; +/** + * Secrets masking + */ + +import { resolveOverlaps } from "../masking/conflict-resolver"; +import { incrementAndGenerate } from "../masking/context"; +import { generateSecretPlaceholder } from "../masking/placeholders"; import { - createPlaceholderContext, - flushBuffer, - incrementAndGenerate, - type MaskResult, + createMaskingContext, + flushMaskingBuffer as flushBuffer, + maskSpans, type PlaceholderContext, - processStreamChunk, - replaceWithPlaceholders, - restorePlaceholders, - restoreResponsePlaceholders, - transformMessagesPerPart, -} from "../utils/message-transform"; -import { generateSecretPlaceholder } from "../utils/placeholders"; + unmaskStreamChunk as unmaskChunk, + unmask as unmaskText, +} from "../masking/service"; +import type { RequestExtractor, TextSpan } from "../masking/types"; import type { MessageSecretsResult, SecretLocation } from "./detect"; -export type { MaskResult } from "../utils/message-transform"; +export { + createMaskingContext as createSecretsMaskingContext, + type PlaceholderContext, +} from "../masking/service"; /** - * Creates a new secrets masking context for a request + * Result of masking operation */ -export function createSecretsMaskingContext(): PlaceholderContext { - return createPlaceholderContext(); +export interface MaskResult { + masked: string; + context: PlaceholderContext; } /** * Generates a placeholder for a secret type - * - * Format: [[SECRET_MASKED_{TYPE}_{N}]] e.g. 
[[SECRET_MASKED_API_KEY_OPENAI_1]] */ function generatePlaceholder(secretType: string, context: PlaceholderContext): string { return incrementAndGenerate(secretType, context, generateSecretPlaceholder); @@ -41,75 +44,105 @@ export function maskSecrets( locations: SecretLocation[], context?: PlaceholderContext, ): MaskResult { - const ctx = context || createSecretsMaskingContext(); - const masked = replaceWithPlaceholders( - text, - locations, - ctx, + const spans: TextSpan[] = [{ text, path: "text", messageIndex: 0, partIndex: 0 }]; + const perSpanData = [locations]; + + const result = maskSpans( + spans, + perSpanData, (loc) => loc.type, generatePlaceholder, resolveOverlaps, + context, ); - return { masked, context: ctx }; + + return { + masked: result.maskedSpans[0]?.maskedText ?? text, + context: result.context, + }; } /** * Unmasks text by replacing placeholders with original secrets - * - * @param text - Text containing secret placeholders - * @param context - Masking context with mappings */ export function unmaskSecrets(text: string, context: PlaceholderContext): string { - return restorePlaceholders(text, context); -} - -/** - * Masks secrets in messages using per-part detection results - * - * Uses transformMessagesPerPart for the common iteration pattern. - */ -export function maskMessages( - messages: ChatMessage[], - detection: MessageSecretsResult, -): { masked: ChatMessage[]; context: PlaceholderContext } { - const context = createSecretsMaskingContext(); - - const masked = transformMessagesPerPart( - messages, - detection.messageLocations, - (text, locations, ctx) => maskSecrets(text, locations, ctx).masked, - context, - ); - - return { masked, context }; + return unmaskText(text, context); } /** * Streaming unmask helper - processes chunks and unmasks when complete placeholders are found - * - * Returns the unmasked portion and any remaining buffer that might contain partial placeholders. 
 */
 export function unmaskSecretsStreamChunk(
   buffer: string,
   newChunk: string,
   context: PlaceholderContext,
 ): { output: string; remainingBuffer: string } {
-  return processStreamChunk(buffer, newChunk, context, unmaskSecrets);
+  return unmaskChunk(buffer, newChunk, context);
 }
 
 /**
  * Flushes remaining buffer at end of stream
  */
 export function flushSecretsMaskingBuffer(buffer: string, context: PlaceholderContext): string {
-  return flushBuffer(buffer, context, unmaskSecrets);
+  return flushBuffer(buffer, context);
 }
 
 /**
- * Unmasks a chat completion response by replacing placeholders in all choices
+ * Unmasks secrets in a response using an extractor
  */
-export function unmaskSecretsResponse(
-  response: ChatCompletionResponse,
+export function unmaskSecretsResponse<TResponse>(
+  response: TResponse,
   context: PlaceholderContext,
-): ChatCompletionResponse {
-  return restoreResponsePlaceholders(response, context);
+  extractor: RequestExtractor<TResponse>,
+): TResponse {
+  return extractor.unmaskResponse(response, context);
+}
+
+/**
+ * Result of masking a request
+ */
+export interface MaskRequestResult<TRequest> {
+  /** The masked request */
+  masked: TRequest;
+  /** Masking context for unmasking response */
+  context: PlaceholderContext;
+}
+
+/**
+ * Masks secrets in a request using an extractor
+ */
+export function maskRequest<TRequest>(
+  request: TRequest,
+  detection: MessageSecretsResult,
+  extractor: RequestExtractor<TRequest>,
+): MaskRequestResult<TRequest> {
+  const context = createMaskingContext();
+
+  if (!detection.spanLocations) {
+    return { masked: request, context };
+  }
+
+  // Extract text spans from request
+  const spans = extractor.extractTexts(request);
+
+  // Mask the spans
+  const { maskedSpans } = maskSpans(
+    spans,
+    detection.spanLocations,
+    (loc) => loc.type,
+    generatePlaceholder,
+    resolveOverlaps,
+    context,
+  );
+
+  // Filter to only spans that were actually masked (have locations)
+  const changedSpans = maskedSpans.filter((_, i) => {
+    const locations = detection.spanLocations![i] || [];
+    return locations.length > 0;
+  });
+
+  // Apply masked text back to request
+  const masked = extractor.applyMasked(request, changedSpans);
+
+  return { masked, context };
+}
diff --git a/src/secrets/multimodal.test.ts b/src/secrets/multimodal.test.ts
index 6a23a8e..b738ac6 100644
--- a/src/secrets/multimodal.test.ts
+++ b/src/secrets/multimodal.test.ts
@@ -1,27 +1,19 @@
 import { describe, expect, test } from "bun:test";
-import type { PIIDetectionResult, PIIEntity } from "../pii/detect";
-import { maskMessages } from "../pii/mask";
-import type { ChatMessage } from "../providers/openai-client";
-import type { ContentPart } from "../utils/content";
-
-/**
- * Helper to create PIIDetectionResult from per-part entities
- */
-function createPIIResult(messageEntities: PIIEntity[][][]): PIIDetectionResult {
-  return {
-    hasPII: messageEntities.flat(2).length > 0,
-    messageEntities,
-    allEntities: messageEntities.flat(2),
-    scanTimeMs: 0,
-    language: "en",
-    languageFallback: false,
-  };
+import { openaiExtractor } from "../masking/extractors/openai";
+import { maskRequest } from "../pii/mask";
+import type { OpenAIMessage, OpenAIRequest } from "../providers/openai/types";
+import { createPIIResultFromSpans } from "../test-utils/detection-results";
+import type { OpenAIContentPart } from "../utils/content";
+
+/** Helper to create a minimal request from messages */
+function createRequest(messages: OpenAIMessage[]): OpenAIRequest {
+  return { model: "gpt-4", messages };
 }
 
 describe("Multimodal content handling", () => {
-  describe("PII
diff --git a/src/secrets/multimodal.test.ts b/src/secrets/multimodal.test.ts
index 6a23a8e..b738ac6 100644
--- a/src/secrets/multimodal.test.ts
+++ b/src/secrets/multimodal.test.ts
@@ -1,27 +1,19 @@
 import { describe, expect, test } from "bun:test";
-import type { PIIDetectionResult, PIIEntity } from "../pii/detect";
-import { maskMessages } from "../pii/mask";
-import type { ChatMessage } from "../providers/openai-client";
-import type { ContentPart } from "../utils/content";
-
-/**
- * Helper to create PIIDetectionResult from per-part entities
- */
-function createPIIResult(messageEntities: PIIEntity[][][]): PIIDetectionResult {
-  return {
-    hasPII: messageEntities.flat(2).length > 0,
-    messageEntities,
-    allEntities: messageEntities.flat(2),
-    scanTimeMs: 0,
-    language: "en",
-    languageFallback: false,
-  };
+import { openaiExtractor } from "../masking/extractors/openai";
+import { maskRequest } from "../pii/mask";
+import type { OpenAIMessage, OpenAIRequest } from "../providers/openai/types";
+import { createPIIResultFromSpans } from "../test-utils/detection-results";
+import type { OpenAIContentPart } from "../utils/content";
+
+/** Helper to create a minimal request from messages */
+function createRequest(messages: OpenAIMessage[]): OpenAIRequest {
+  return { model: "gpt-4", messages };
 }
 
 describe("Multimodal content handling", () => {
-  describe("PII masking with per-part entities", () => {
+  describe("PII masking with per-span entities", () => {
     test("masks PII in multimodal array content", () => {
-      const messages: ChatMessage[] = [
+      const request = createRequest([
         {
           role: "user",
           content: [
@@ -30,26 +22,23 @@
             { type: "text", text: "my phone is 555-1234" },
           ],
         },
-      ];
+      ]);
 
-      // Per-part entities: messageEntities[msgIdx][partIdx] = entities
-      const detection = createPIIResult([
-        [
-          // Part 0: email entity (positions relative to part text)
-          [{ entity_type: "EMAIL_ADDRESS", start: 12, end: 28, score: 0.9 }],
-          // Part 1: image, no entities
-          [],
-          // Part 2: phone entity (positions relative to part text)
-          [{ entity_type: "PHONE_NUMBER", start: 12, end: 20, score: 0.85 }],
-        ],
+      // Per-span entities: spanEntities[spanIdx] = entities
+      // Span 0: first text part, Span 1: second text part (image skipped)
+      const detection = createPIIResultFromSpans([
+        // Span 0: email entity (positions relative to span text)
+        [{ entity_type: "EMAIL_ADDRESS", start: 12, end: 28, score: 0.9 }],
+        // Span 1: phone entity (positions relative to span text)
+        [{ entity_type: "PHONE_NUMBER", start: 12, end: 20, score: 0.85 }],
       ]);
 
-      const { masked } = maskMessages(messages, detection);
+      const { request: masked } = maskRequest(request, detection, openaiExtractor);
 
       // Verify the content is still an array
-      expect(Array.isArray(masked[0].content)).toBe(true);
+      expect(Array.isArray(masked.messages[0].content)).toBe(true);
 
-      const maskedContent = masked[0].content as ContentPart[];
+      const maskedContent = masked.messages[0].content as OpenAIContentPart[];
 
       // Part 0 should have email masked
       expect(maskedContent[0].type).toBe("text");
@@ -67,29 +56,27 @@
     });
 
     test("returns masked array instead of original unmasked array", () => {
-      const messages: ChatMessage[] = [
+      const request = createRequest([
         {
           role: "user",
           content: [{ type: "text", text: "Contact Alice at alice@secret.com" }],
         },
-      ];
+      ]);
 
-      const detection = createPIIResult([
+      const detection = createPIIResultFromSpans([
+        // Span 0 entities
         [
-        // Part 0 entities
-        [
-          { entity_type: "PERSON", start: 8, end: 13, score: 0.9 },
-          { entity_type: "EMAIL_ADDRESS", start: 17, end: 33, score: 0.95 },
-        ],
+          { entity_type: "PERSON", start: 8, end: 13, score: 0.9 },
+          { entity_type: "EMAIL_ADDRESS", start: 17, end: 33, score: 0.95 },
         ],
       ]);
 
-      const { masked } = maskMessages(messages, detection);
+      const { request: masked } = maskRequest(request, detection, openaiExtractor);
 
       // Verify content is still array
-      expect(Array.isArray(masked[0].content)).toBe(true);
+      expect(Array.isArray(masked.messages[0].content)).toBe(true);
 
-      const maskedContent = masked[0].content as ContentPart[];
+      const maskedContent = masked.messages[0].content as OpenAIContentPart[];
 
       // Verify the text is actually masked (not the original)
       expect(maskedContent[0].text).not.toContain("Alice");
@@ -99,7 +86,7 @@
     });
 
     test("handles multiple text parts independently", () => {
-      const messages: ChatMessage[] = [
+      const request = createRequest([
         {
           role: "user",
           content: [
@@ -107,49 +94,50 @@
             { type: "text", text: "Second: jane@example.com" },
           ],
         },
-      ];
+      ]);
 
-      const detection = createPIIResult([
-        [
-          // Part 0 entity
-          [{ entity_type: "EMAIL_ADDRESS", start: 7, end: 23, score: 0.9 }],
-          // Part 1 entity
-          [{ entity_type: "EMAIL_ADDRESS", start: 8, end: 24, score: 0.9 }],
-        ],
+      const detection = createPIIResultFromSpans([
+        // Span 0 entity
+        [{ entity_type: "EMAIL_ADDRESS", start: 7, end: 23, score: 0.9 }],
+        // Span 1 entity
+        [{ entity_type: "EMAIL_ADDRESS", start: 8, end: 24, score: 0.9 }],
       ]);
 
-      const { masked } = maskMessages(messages, detection);
+      const { request: masked } = maskRequest(request, detection, openaiExtractor);
 
-      const maskedContent = masked[0].content as ContentPart[];
+      const maskedContent = masked.messages[0].content as OpenAIContentPart[];
       expect(maskedContent[0].text).toBe("First: [[EMAIL_ADDRESS_1]]");
       expect(maskedContent[1].text).toBe("Second: [[EMAIL_ADDRESS_2]]");
     });
 
     test("handles mixed string and array content messages", () => {
-      const messages: ChatMessage[] = [
+      const request = createRequest([
         { role: "system", content: "You are helpful" },
         {
           role: "user",
           content: [{ type: "text", text: "My name is John" }],
         },
         { role: "assistant", content: "Hello John!" },
-      ];
-
-      const detection = createPIIResult([
-        // Message 0 (system): no PII
-        [[]],
-        // Message 1 (user multimodal): PII in part 0
-        [[{ entity_type: "PERSON", start: 11, end: 15, score: 0.9 }]],
-        // Message 2 (assistant): PII in part 0
-        [[{ entity_type: "PERSON", start: 6, end: 10, score: 0.9 }]],
       ]);
 
-      const { masked } = maskMessages(messages, detection);
+      // Spans: 0=system, 1=user text, 2=assistant
+      const detection = createPIIResultFromSpans([
+        // Span 0 (system): no PII
+        [],
+        // Span 1 (user multimodal text): PII
+        [{ entity_type: "PERSON", start: 11, end: 15, score: 0.9 }],
+        // Span 2 (assistant): PII
+        [{ entity_type: "PERSON", start: 6, end: 10, score: 0.9 }],
+      ]);
+
+      const { request: masked } = maskRequest(request, detection, openaiExtractor);
 
-      expect(masked[0].content).toBe("You are helpful");
-      expect((masked[1].content as ContentPart[])[0].text).toBe("My name is [[PERSON_1]]");
-      expect(masked[2].content).toBe("Hello [[PERSON_1]]!");
+      expect(masked.messages[0].content).toBe("You are helpful");
+      expect((masked.messages[1].content as OpenAIContentPart[])[0].text).toBe(
+        "My name is [[PERSON_1]]",
+      );
+      expect(masked.messages[2].content).toBe("Hello [[PERSON_1]]!");
     });
   });
 });
diff --git a/src/secrets/patterns/index.ts b/src/secrets/patterns/index.ts
index 5c38950..5505ea2 100644
--- a/src/secrets/patterns/index.ts
+++ b/src/secrets/patterns/index.ts
@@ -17,6 +17,5 @@ export const patternDetectors: PatternDetector[] = [
   envVarsDetector,
 ];
 
-// Re-export types and utilities for convenience
 export type { PatternDetector, SecretEntityType, SecretsDetectionResult } from "./types";
 export { detectPattern } from "./utils";
diff --git a/src/secrets/patterns/types.ts b/src/secrets/patterns/types.ts
index 1c19985..500016a 100644
--- a/src/secrets/patterns/types.ts
+++ b/src/secrets/patterns/types.ts
@@ -34,14 +34,13 @@ export interface SecretsDetectionResult {
 }
 
 /**
- * Per-message, per-part secrets detection result
- * Structure: messageLocations[msgIdx][partIdx] = locations for that part
+ * Per-span secrets detection result
  */
 export interface MessageSecretsResult {
   detected: boolean;
   matches: SecretsMatch[];
-  /** Per-message, per-part secret locations */
-  messageLocations: SecretLocation[][][];
+  /** Per-span secret locations: spanLocations[spanIdx] = locations */
+  spanLocations?: SecretLocation[][];
 }
 
 /**
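
For orientation, a per-span result now looks like the literal below. Values are illustrative only, and the SecretLocation fields shown (type/start/end offsets) are assumed, since that interface body is not part of this diff.

// Illustrative only: spanLocations is indexed by extracted span, not by message/part.
const detection: MessageSecretsResult = {
  detected: true,
  matches: [{ type: "API_KEY_OPENAI", count: 1 }],
  spanLocations: [
    [], // span 0: clean
    [{ type: "API_KEY_OPENAI", start: 11, end: 62 }], // span 1: one secret
  ],
};
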
diff --git a/src/services/decision.test.ts b/src/services/decision.test.ts
deleted file mode 100644
index 8a87c24..0000000
--- a/src/services/decision.test.ts
+++ /dev/null
@@ -1,207 +0,0 @@
-import { describe, expect, test } from "bun:test";
-import type { PIIDetectionResult } from "../pii/detect";
-import type { MessageSecretsResult, SecretsMatch } from "../secrets/detect";
-
-/**
- * Pure routing logic extracted for testing
- * This mirrors the logic in Router.decideRoute()
- */
-function decideRoute(
-  piiResult: PIIDetectionResult,
-  secretsResult?: MessageSecretsResult,
-  secretsAction?: "block" | "mask" | "route_local",
-): { provider: "openai" | "local"; reason: string } {
-  // Check for secrets route_local action first (takes precedence)
-  if (secretsResult?.detected && secretsAction === "route_local") {
-    const secretTypes = secretsResult.matches.map((m) => m.type);
-    return {
-      provider: "local",
-      reason: `Secrets detected (route_local): ${secretTypes.join(", ")}`,
-    };
-  }
-
-  if (piiResult.hasPII) {
-    const entityTypes = [...new Set(piiResult.allEntities.map((e) => e.entity_type))];
-    return {
-      provider: "local",
-      reason: `PII detected: ${entityTypes.join(", ")}`,
-    };
-  }
-
-  return {
-    provider: "openai",
-    reason: "No PII detected",
-  };
-}
-
-/**
- * Helper to create a mock PIIDetectionResult
- */
-function createPIIResult(
-  hasPII: boolean,
-  entities: Array<{ entity_type: string }> = [],
-): PIIDetectionResult {
-  const allEntities = entities.map((e) => ({
-    entity_type: e.entity_type,
-    start: 0,
-    end: 10,
-    score: 0.9,
-  }));
-
-  return {
-    hasPII,
-    allEntities,
-    messageEntities: [[allEntities]],
-    language: "en",
-    languageFallback: false,
-    scanTimeMs: 50,
-  };
-}
-
-describe("decideRoute", () => {
-  test("routes to openai when no PII detected", () => {
-    const result = decideRoute(createPIIResult(false));
-
-    expect(result.provider).toBe("openai");
-    expect(result.reason).toBe("No PII detected");
-  });
-
-  test("routes to local when PII detected", () => {
-    const result = decideRoute(createPIIResult(true, [{ entity_type: "PERSON" }]));
-
-    expect(result.provider).toBe("local");
-    expect(result.reason).toContain("PII detected");
-    expect(result.reason).toContain("PERSON");
-  });
-
-  test("includes all entity types in reason", () => {
-    const result = decideRoute(
-      createPIIResult(true, [
-        { entity_type: "PERSON" },
-        { entity_type: "EMAIL_ADDRESS" },
-        { entity_type: "PHONE_NUMBER" },
-      ]),
-    );
-
-    expect(result.reason).toContain("PERSON");
-    expect(result.reason).toContain("EMAIL_ADDRESS");
-    expect(result.reason).toContain("PHONE_NUMBER");
-  });
-
-  test("deduplicates entity types in reason", () => {
-    const result = decideRoute(
-      createPIIResult(true, [
-        { entity_type: "PERSON" },
-        { entity_type: "PERSON" },
-        { entity_type: "PERSON" },
-      ]),
-    );
-
-    // Should only contain PERSON once
-    const matches = result.reason.match(/PERSON/g);
-    expect(matches?.length).toBe(1);
-  });
-});
-
-/**
- * Helper to create a mock MessageSecretsResult
- */
-function createSecretsResult(
-  detected: boolean,
-  matches: SecretsMatch[] = [],
-): MessageSecretsResult {
-  return {
-    detected,
-    matches,
-    messageLocations: [],
-  };
-}
-
-describe("decideRoute with secrets", () => {
-  describe("with route_local action", () => {
-    test("routes to local when secrets detected", () => {
-      const piiResult = createPIIResult(false);
-      const secretsResult = createSecretsResult(true, [{ type: "API_KEY_OPENAI", count: 1 }]);
-
-      const result = decideRoute(piiResult, secretsResult, "route_local");
-
-      expect(result.provider).toBe("local");
-      expect(result.reason).toContain("Secrets detected");
-      expect(result.reason).toContain("route_local");
-      expect(result.reason).toContain("API_KEY_OPENAI");
-    });
-
-    test("secrets routing takes precedence over PII routing", () => {
-      const piiResult = createPIIResult(true, [{ entity_type: "PERSON" }]);
-      const secretsResult = createSecretsResult(true, [{ type: "API_KEY_AWS", count: 1 }]);
-
-      const result = decideRoute(piiResult, secretsResult, "route_local");
-
-      expect(result.provider).toBe("local");
-      expect(result.reason).toContain("Secrets detected");
-    });
-
-    test("routes based on PII when no secrets detected", () => {
-      const piiResult = createPIIResult(true, [{ entity_type: "EMAIL_ADDRESS" }]);
-      const secretsResult = createSecretsResult(false);
-
-      const result = decideRoute(piiResult, secretsResult, "route_local");
-
-      expect(result.provider).toBe("local"); // PII detected -> local
-      expect(result.reason).toContain("PII detected");
-    });
-
-    test("routes to openai when no secrets and no PII detected", () => {
-      const piiResult = createPIIResult(false);
-      const secretsResult = createSecretsResult(false);
-
-      const result = decideRoute(piiResult, secretsResult, "route_local");
-
-      expect(result.provider).toBe("openai");
-      expect(result.reason).toBe("No PII detected");
-    });
-  });
-
-  describe("with block action", () => {
-    test("ignores secrets detection for routing (block happens earlier)", () => {
-      const piiResult = createPIIResult(false);
-      const secretsResult = createSecretsResult(true, [{ type: "JWT_TOKEN", count: 1 }]);
-
-      const result = decideRoute(piiResult, secretsResult, "block");
-
-      // With block action, we shouldn't route based on secrets
-      expect(result.provider).toBe("openai");
-      expect(result.reason).toBe("No PII detected");
-    });
-  });
-
-  describe("with mask action", () => {
-    test("ignores secrets detection for routing (masked before PII check)", () => {
-      const piiResult = createPIIResult(false);
-      const secretsResult = createSecretsResult(true, [{ type: "BEARER_TOKEN", count: 1 }]);
-
-      const result = decideRoute(piiResult, secretsResult, "mask");
-
-      // With mask action, we route based on PII, not secrets
-      expect(result.provider).toBe("openai");
-      expect(result.reason).toBe("No PII detected");
-    });
-  });
-
-  describe("with multiple secret types", () => {
-    test("includes all secret types in reason", () => {
-      const piiResult = createPIIResult(false);
-      const secretsResult = createSecretsResult(true, [
-        { type: "API_KEY_OPENAI", count: 1 },
-        { type: "API_KEY_GITHUB", count: 2 },
-        { type: "JWT_TOKEN", count: 1 },
-      ]);
-
-      const result = decideRoute(piiResult, secretsResult, "route_local");
-
-      expect(result.reason).toContain("API_KEY_OPENAI");
-      expect(result.reason).toContain("API_KEY_GITHUB");
-      expect(result.reason).toContain("JWT_TOKEN");
-    });
-  });
-});
diff --git a/src/services/decision.ts b/src/services/decision.ts
deleted file mode 100644
index a07286d..0000000
--- a/src/services/decision.ts
+++ /dev/null
@@ -1,199 +0,0 @@
-import { type Config, getConfig } from "../config";
-import { getPIIDetector, type PIIDetectionResult } from "../pii/detect";
-import { createMaskingContext, maskMessages } from "../pii/mask";
-import { type ChatMessage, LLMClient } from "../providers/openai-client";
-import type { MessageSecretsResult } from "../secrets/detect";
-import type { PlaceholderContext } from "../utils/message-transform";
-
-/**
- * Routing decision result for route mode
- */
-export interface RouteDecision {
-  mode: "route";
-  provider: "openai" | "local";
-  reason: string;
-  piiResult: PIIDetectionResult;
-}
-
-/**
- * Masking decision result for mask mode
- */
-export interface MaskDecision {
-  mode: "mask";
-  provider: "openai";
-  reason: string;
-  piiResult: PIIDetectionResult;
-  maskedMessages: ChatMessage[];
-  maskingContext: PlaceholderContext;
-}
-
-export type RoutingDecision = RouteDecision | MaskDecision;
-
-/**
- * Router that decides how to handle requests based on PII detection
- * Supports two modes: route (to local LLM) or mask (anonymize for provider)
- */
-export class Router {
-  private openaiClient: LLMClient;
-  private localClient: LLMClient | null;
-  private config: Config;
-
-  constructor() {
-    this.config = getConfig();
-
-    this.openaiClient = new LLMClient(this.config.providers.openai, "openai");
-    this.localClient = this.config.local
-      ? new LLMClient(this.config.local, "local", this.config.local.model)
-      : null;
-  }
-
-  /**
-   * Returns the current mode
-   */
-  getMode(): "route" | "mask" {
-    return this.config.mode;
-  }
-
-  /**
-   * Decides how to handle messages based on mode, PII detection, and secrets detection
-   *
-   * @param messages - The chat messages to process
-   * @param secretsResult - Optional secrets detection result (for route_local action)
-   */
-  async decide(
-    messages: ChatMessage[],
-    secretsResult?: MessageSecretsResult,
-  ): Promise<RoutingDecision> {
-    const detector = getPIIDetector();
-    const piiResult = await detector.analyzeMessages(messages);
-
-    if (this.config.mode === "mask") {
-      return this.decideMask(messages, piiResult);
-    }
-
-    return this.decideRoute(piiResult, secretsResult);
-  }
-
-  /**
-   * Route mode: decides which provider to use
-   *
-   * - No PII/Secrets → use configured provider (openai)
-   * - PII detected → use local provider
-   * - Secrets detected with route_local action → use local provider (takes precedence)
-   */
-  private decideRoute(
-    piiResult: PIIDetectionResult,
-    secretsResult?: MessageSecretsResult,
-  ): RouteDecision {
-    // Check for secrets route_local action first (takes precedence)
-    if (secretsResult?.detected && this.config.secrets_detection.action === "route_local") {
-      const secretTypes = secretsResult.matches.map((m) => m.type);
-      return {
-        mode: "route",
-        provider: "local",
-        reason: `Secrets detected (route_local): ${secretTypes.join(", ")}`,
-        piiResult,
-      };
-    }
-
-    // Route based on PII detection
-    if (piiResult.hasPII) {
-      const entityTypes = [...new Set(piiResult.allEntities.map((e) => e.entity_type))];
-      return {
-        mode: "route",
-        provider: "local",
-        reason: `PII detected: ${entityTypes.join(", ")}`,
-        piiResult,
-      };
-    }
-
-    // No PII detected, use configured provider
-    return {
-      mode: "route",
-      provider: "openai",
-      reason: "No PII detected",
-      piiResult,
-    };
-  }
-
-  private decideMask(messages: ChatMessage[], piiResult: PIIDetectionResult): MaskDecision {
-    if (!piiResult.hasPII) {
-      return {
-        mode: "mask",
-        provider: "openai",
-        reason: "No PII detected",
-        piiResult,
-        maskedMessages: messages,
-        maskingContext: createMaskingContext(),
-      };
-    }
-
-    const { masked, context } = maskMessages(messages, piiResult);
-
-    const entityTypes = [...new Set(piiResult.allEntities.map((e) => e.entity_type))];
-
-    return {
-      mode: "mask",
-      provider: "openai",
-      reason: `PII masked: ${entityTypes.join(", ")}`,
-      piiResult,
-      maskedMessages: masked,
-      maskingContext: context,
-    };
-  }
-
-  getClient(provider: "openai" | "local"): LLMClient {
-    if (provider === "local") {
-      if (!this.localClient) {
-        throw new Error("Local provider not configured");
-      }
-      return this.localClient;
-    }
-    return this.openaiClient;
-  }
-
-  /**
-   * Gets masking config
-   */
-  getMaskingConfig() {
-    return this.config.masking;
-  }
-
-  /**
-   * Checks health of services (Presidio required, local LLM only in route mode)
-   */
-  async healthCheck(): Promise<{
-    local: boolean;
-    presidio: boolean;
-  }> {
-    const detector = getPIIDetector();
-
-    const [presidioHealth, localHealth] = await Promise.all([
-      detector.healthCheck(),
-      this.localClient?.healthCheck() ?? Promise.resolve(true),
-    ]);
-
-    return {
-      local: localHealth,
-      presidio: presidioHealth,
-    };
-  }
-
-  getProvidersInfo() {
-    return {
-      mode: this.config.mode,
-      openai: this.openaiClient.getInfo(),
-      local: this.localClient?.getInfo() ?? null,
-    };
-  }
-}
-
-// Singleton instance
-let routerInstance: Router | null = null;
-
-export function getRouter(): Router {
-  if (!routerInstance) {
-    routerInstance = new Router();
-  }
-  return routerInstance;
-}
diff --git a/src/services/pii.ts b/src/services/pii.ts
new file mode 100644
index 0000000..b557685
--- /dev/null
+++ b/src/services/pii.ts
@@ -0,0 +1,70 @@
+/**
+ * PII Service - detect and mask PII in requests
+ */
+
+import type { PlaceholderContext } from "../masking/context";
+import type { RequestExtractor } from "../masking/types";
+import { getPIIDetector, type PIIDetectionResult } from "../pii/detect";
+import { createMaskingContext, maskRequest } from "../pii/mask";
+
+export interface PIIDetectResult {
+  detection: PIIDetectionResult;
+  hasPII: boolean;
+}
+
+export interface PIIMaskResult<TRequest> {
+  request: TRequest;
+  maskingContext: PlaceholderContext;
+}
+
+/**
+ * Detect PII in a request
+ */
+export async function detectPII<TRequest>(
+  request: TRequest,
+  extractor: RequestExtractor,
+): Promise<PIIDetectResult> {
+  const detector = getPIIDetector();
+  const detection = await detector.analyzeRequest(request, extractor);
+
+  return {
+    detection,
+    hasPII: detection.hasPII,
+  };
+}
+
+/**
+ * Mask PII in a request
+ */
+export function maskPII<TRequest>(
+  request: TRequest,
+  detection: PIIDetectionResult,
+  extractor: RequestExtractor,
+  existingContext?: PlaceholderContext,
+): PIIMaskResult<TRequest> {
+  if (!detection.hasPII) {
+    return {
+      request,
+      maskingContext: existingContext ?? createMaskingContext(),
+    };
+  }
+
+  const result = maskRequest(request, detection, extractor, existingContext);
+
+  return {
+    request: result.request,
+    maskingContext: result.context,
+  };
+}
+
+export type { PlaceholderContext } from "../masking/context";
+export type { PIIDetectionResult, PIIEntity } from "../pii/detect";
+export { createMaskingContext } from "../pii/mask";
+
+/**
+ * Check if Presidio is healthy
+ */
+export async function healthCheck(): Promise<boolean> {
+  const detector = getPIIDetector();
+  return detector.healthCheck();
+}
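
A route handler would drive this service roughly as below. This is a hypothetical handler fragment: only detectPII, maskPII, and openaiExtractor come from this commit, and the forwarding step is elided.

// Sketch only: detect first, then mask; a no-PII request passes through untouched.
import { openaiExtractor } from "../masking/extractors/openai";
import { detectPII, maskPII } from "../services/pii";

const { detection, hasPII } = await detectPII(request, openaiExtractor);
const { request: outbound, maskingContext } = maskPII(request, detection, openaiExtractor);
// When hasPII is false, maskPII returns the request unchanged plus a fresh context,
// so `outbound` can be forwarded unconditionally.
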
diff --git a/src/services/secrets.ts b/src/services/secrets.ts
new file mode 100644
index 0000000..ca44eac
--- /dev/null
+++ b/src/services/secrets.ts
@@ -0,0 +1,67 @@
+/**
+ * Secrets Service - detect and mask secrets in requests
+ */
+
+import type { SecretsDetectionConfig } from "../config";
+import type { PlaceholderContext } from "../masking/context";
+import type { RequestExtractor } from "../masking/types";
+import { detectSecretsInRequest, type MessageSecretsResult } from "../secrets/detect";
+import { maskRequest } from "../secrets/mask";
+
+export interface SecretsProcessResult<TRequest> {
+  blocked: boolean;
+  blockedReason?: string;
+  blockedTypes?: string[];
+  request: TRequest;
+  detection?: MessageSecretsResult;
+  maskingContext?: PlaceholderContext;
+  masked: boolean;
+}
+
+/**
+ * Process a request for secrets detection
+ */
+export function processSecretsRequest<TRequest>(
+  request: TRequest,
+  config: SecretsDetectionConfig,
+  extractor: RequestExtractor,
+): SecretsProcessResult<TRequest> {
+  if (!config.enabled) {
+    return { blocked: false, request, masked: false };
+  }
+
+  const detection = detectSecretsInRequest(request, config, extractor);
+
+  if (!detection.detected) {
+    return { blocked: false, request, detection, masked: false };
+  }
+
+  const secretTypes = detection.matches.map((m) => m.type);
+
+  // Block action
+  if (config.action === "block") {
+    return {
+      blocked: true,
+      blockedReason: `Secrets detected: ${secretTypes.join(", ")}`,
+      blockedTypes: secretTypes,
+      request,
+      detection,
+      masked: false,
+    };
+  }
+
+  // Mask action
+  if (config.action === "mask") {
+    const result = maskRequest(request, detection, extractor);
+    return {
+      blocked: false,
+      request: result.masked,
+      detection,
+      maskingContext: result.context,
+      masked: true,
+    };
+  }
+
+  // route_local action - just pass through with detection info
+  return { blocked: false, request, detection, masked: false };
+}
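
Callers have to handle three outcomes — block, mask, and route_local — roughly as follows. Hypothetical fragment: `respondWithError` and `forward` are placeholders, not functions from this repo.

// Sketch only: reacting to the three possible action outcomes.
const result = processSecretsRequest(request, config.secrets_detection, openaiExtractor);

if (result.blocked) {
  return respondWithError(400, result.blockedReason); // action: block
}
// action: mask        -> result.request carries placeholders; keep result.maskingContext
// action: route_local -> result.detection feeds the routing decision instead
return forward(result.request, result.maskingContext);
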
diff --git a/src/test-utils/detection-results.ts b/src/test-utils/detection-results.ts
index 30cfc48..5088c54 100644
--- a/src/test-utils/detection-results.ts
+++ b/src/test-utils/detection-results.ts
@@ -1,8 +1,5 @@
 /**
  * Test utilities for creating detection results
- *
- * Shared helpers for creating PIIDetectionResult and MessageSecretsResult
- * from per-message, per-part data in tests.
  */
 
 import type { SupportedLanguage } from "../constants/languages";
@@ -10,13 +7,10 @@ import type { PIIDetectionResult, PIIEntity } from "../pii/detect";
 import type { MessageSecretsResult, SecretLocation } from "../secrets/detect";
 
 /**
- * Creates a PIIDetectionResult from per-message, per-part entities
- *
- * @param messageEntities - Nested array: messageEntities[msgIdx][partIdx] = entities[]
- * @param options - Optional overrides for language, scanTimeMs, etc.
+ * Creates a PIIDetectionResult from per-span entities
  */
-export function createPIIResult(
-  messageEntities: PIIEntity[][][],
+export function createPIIResultFromSpans(
+  spanEntities: PIIEntity[][],
   options: {
     language?: SupportedLanguage;
     languageFallback?: boolean;
@@ -24,10 +18,10 @@
     scanTimeMs?: number;
   } = {},
 ): PIIDetectionResult {
-  const allEntities = messageEntities.flat(2);
+  const allEntities = spanEntities.flat();
 
   return {
     hasPII: allEntities.length > 0,
-    messageEntities,
+    spanEntities,
     allEntities,
     scanTimeMs: options.scanTimeMs ?? 0,
     language: options.language ?? "en",
@@ -37,15 +31,15 @@
 }
 
 /**
- * Creates a MessageSecretsResult from per-message, per-part locations
- *
- * @param messageLocations - Nested array: messageLocations[msgIdx][partIdx] = locations[]
+ * Creates a MessageSecretsResult from per-span locations
  */
-export function createSecretsResult(messageLocations: SecretLocation[][][]): MessageSecretsResult {
-  const hasLocations = messageLocations.some((msg) => msg.some((part) => part.length > 0));
+export function createSecretsResultFromSpans(
+  spanLocations: SecretLocation[][],
+): MessageSecretsResult {
+  const hasLocations = spanLocations.some((span) => span.length > 0);
 
   return {
     detected: hasLocations,
-    matches: [], // Matches are aggregated separately in real detection
-    messageLocations,
+    matches: [],
+    spanLocations,
   };
 }
diff --git a/src/utils/content.test.ts b/src/utils/content.test.ts
index 3b60a2b..2ce3af5 100644
--- a/src/utils/content.test.ts
+++ b/src/utils/content.test.ts
@@ -1,5 +1,5 @@
 import { describe, expect, test } from "bun:test";
-import { type ContentPart, extractTextContent } from "./content";
+import { extractTextContent, type OpenAIContentPart } from "./content";
 
 describe("extractTextContent", () => {
   test("returns empty string for null", () => {
@@ -15,12 +15,12 @@
   });
 
   test("extracts text from single text part", () => {
-    const content: ContentPart[] = [{ type: "text", text: "What's in this image?" }];
+    const content: OpenAIContentPart[] = [{ type: "text", text: "What's in this image?" }];
     expect(extractTextContent(content)).toBe("What's in this image?");
   });
 
   test("extracts and joins multiple text parts", () => {
-    const content: ContentPart[] = [
+    const content: OpenAIContentPart[] = [
      { type: "text", text: "First part" },
      { type: "text", text: "Second part" },
     ];
@@ -28,7 +28,7 @@
   });
 
   test("skips image_url parts", () => {
-    const content: ContentPart[] = [
+    const content: OpenAIContentPart[] = [
       { type: "text", text: "Look at this" },
       { type: "image_url", image_url: { url: "https://example.com/image.jpg" } },
       { type: "text", text: "What is it?" },
@@ -37,7 +37,7 @@
   });
 
   test("returns empty string for array with no text parts", () => {
-    const content: ContentPart[] = [
+    const content: OpenAIContentPart[] = [
       { type: "image_url", image_url: { url: "https://example.com/image.jpg" } },
     ];
     expect(extractTextContent(content)).toBe("");
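
The rename tracks the shape change: one nesting level drops out. Both calls below use the signatures from this diff; the entity values are illustrative.

// Old helper: entities[msgIdx][partIdx][entityIdx]
createPIIResult([[[{ entity_type: "PERSON", start: 0, end: 4, score: 0.9 }]]]);

// New helper: entities[spanIdx][entityIdx] - spans are already flattened
createPIIResultFromSpans([[{ entity_type: "PERSON", start: 0, end: 4, score: 0.9 }]]);
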
diff --git a/src/utils/content.ts b/src/utils/content.ts
index 7a256de..86a41d3 100644
--- a/src/utils/content.ts
+++ b/src/utils/content.ts
@@ -1,62 +1,23 @@
 /**
- * Utility functions for handling OpenAI message content
- *
- * OpenAI's Chat Completions API supports two content formats:
- * 1. String content (text-only messages)
- * 2. Array content (multimodal messages with text and images)
+ * Message content utilities
  */
 
-/**
- * Content part for multimodal messages
- */
-export interface ContentPart {
-  type: string;
-  text?: string;
-  image_url?: {
-    url: string;
-    detail?: string;
-  };
-}
+import type { OpenAIContentPart, OpenAIMessageContent } from "../providers/openai/types";
 
-/**
- * Message content can be a string (text-only) or array (multimodal)
- */
-export type MessageContent = string | ContentPart[] | null | undefined;
+export type { OpenAIContentPart, OpenAIMessageContent };
 
 /**
- * Safely extracts text content from a message
- *
- * Handles both string content and array content (multimodal messages).
- * For array content, extracts and concatenates all text parts.
- *
- * @param content - The message content (string, array, null, or undefined)
- * @returns Extracted text content, or empty string if no text found
- *
- * @example
- * // Text-only message
- * extractTextContent("Hello world") // => "Hello world"
- *
- * // Multimodal message
- * extractTextContent([
- *   { type: "text", text: "What's in this image?" },
- *   { type: "image_url", image_url: { url: "..." } }
- * ]) // => "What's in this image?"
- *
- * // Null/undefined
- * extractTextContent(null) // => ""
+ * Extracts text content from a message (handles string and array content)
  */
-export function extractTextContent(content: MessageContent): string {
-  // Handle null/undefined
+export function extractTextContent(content: OpenAIMessageContent | undefined): string {
   if (!content) {
     return "";
   }
 
-  // Handle string content (simple case)
   if (typeof content === "string") {
     return content;
   }
 
-  // Handle array content (multimodal messages)
   if (Array.isArray(content)) {
     return content
       .filter((part) => part.type === "text" && typeof part.text === "string")
@@ -64,6 +25,5 @@ export function extractTextContent(content: MessageContent): string {
       .join("\n");
   }
 
-  // Unexpected type - return empty string
   return "";
 }
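
The helper's behavior is unchanged by the slimmer doc comment; for quick reference, its observable behavior (as exercised by content.test.ts above, with the joined result inferred from the `.join("\n")` in the implementation):

// Behavior as pinned down by content.test.ts:
extractTextContent("Hello world"); // => "Hello world"
extractTextContent([
  { type: "text", text: "Look at this" },
  { type: "image_url", image_url: { url: "https://example.com/image.jpg" } },
  { type: "text", text: "What is it?" },
]); // => "Look at this\nWhat is it?"
extractTextContent(null); // => ""
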