From: Stefan Gasser Date: Sun, 11 Jan 2026 18:42:09 +0000 (+0100) Subject: Fix PII detection to scan all message roles (#25) X-Git-Url: http://git.99rst.org/?a=commitdiff_plain;h=92e2f720c82784be6df0dd2127622d972f9785d3;p=sgasser-llm-shield.git Fix PII detection to scan all message roles (#25) Previously, PII detection only scanned the last user message initially, then did a full scan only if PII was found. This caused PII in system messages (e.g., RAG context from PDFs) to be missed entirely when the user message contained no PII. Changes: - Consolidate analyzeMessages() to always scan all messages - Scan system, developer, user, and assistant roles - Remove analyzeAllMessages() as it's no longer needed - Simplify decision.ts by removing the redundant full scan call This ensures PII in system messages (common in RAG patterns) is properly detected and masked before being sent to upstream LLMs. Fixes #17 --- diff --git a/src/services/decision.ts b/src/services/decision.ts index 40be5ad..9be2b5a 100644 --- a/src/services/decision.ts +++ b/src/services/decision.ts @@ -133,13 +133,7 @@ export class Router { }; } - const detector = getPIIDetector(); - const fullScan = await detector.analyzeAllMessages(messages, { - language: piiResult.language, - usedFallback: piiResult.languageFallback, - }); - - const { masked, context } = maskMessages(messages, fullScan.entitiesByMessage); + const { masked, context } = maskMessages(messages, piiResult.entitiesByMessage); const entityTypes = [...new Set(piiResult.newEntities.map((e) => e.entity_type))]; diff --git a/src/services/pii-detector.test.ts b/src/services/pii-detector.test.ts new file mode 100644 index 0000000..6c748d7 --- /dev/null +++ b/src/services/pii-detector.test.ts @@ -0,0 +1,202 @@ +import { afterEach, describe, expect, mock, test } from "bun:test"; +import { PIIDetector } from "./pii-detector"; + +const originalFetch = globalThis.fetch; + +function mockPresidio( + responses: Record< + string, + Array<{ entity_type: string; start: number; end: number; score: number }> + >, +) { + globalThis.fetch = mock(async (url: string | URL | Request, init?: RequestInit) => { + const urlStr = url.toString(); + + if (urlStr.includes("/health")) { + return new Response("OK", { status: 200 }); + } + + if (urlStr.includes("/analyze") && init?.body) { + const body = JSON.parse(init.body as string); + const text = body.text as string; + + for (const [key, entities] of Object.entries(responses)) { + if (text.includes(key)) { + return new Response(JSON.stringify(entities), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + } + } + + return new Response(JSON.stringify([]), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + } + + return originalFetch(url, init); + }) as unknown as typeof fetch; +} + +describe("PIIDetector", () => { + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + describe("analyzeMessages", () => { + test("scans all message roles", async () => { + mockPresidio({ + "system-pii": [{ entity_type: "PERSON", start: 0, end: 10, score: 0.9 }], + "user-pii": [{ entity_type: "EMAIL_ADDRESS", start: 0, end: 8, score: 0.9 }], + "assistant-pii": [{ entity_type: "PHONE_NUMBER", start: 0, end: 13, score: 0.9 }], + }); + + const detector = new PIIDetector(); + const messages = [ + { role: "system", content: "system-pii here" }, + { role: "user", content: "user-pii here" }, + { role: "assistant", content: "assistant-pii here" }, + ]; + + const result = await detector.analyzeMessages(messages); + + expect(result.hasPII).toBe(true); + expect(result.entitiesByMessage).toHaveLength(3); + expect(result.entitiesByMessage[0]).toHaveLength(1); + expect(result.entitiesByMessage[1]).toHaveLength(1); + expect(result.entitiesByMessage[2]).toHaveLength(1); + }); + + test("detects PII in system message when user message has none", async () => { + mockPresidio({ + "John Doe": [{ entity_type: "PERSON", start: 18, end: 26, score: 0.95 }], + }); + + const detector = new PIIDetector(); + const messages = [ + { role: "system", content: "Context from PDF: John Doe lives at 123 Main St" }, + { role: "user", content: "Extract the data into JSON" }, + ]; + + const result = await detector.analyzeMessages(messages); + + expect(result.hasPII).toBe(true); + expect(result.entitiesByMessage[0]).toHaveLength(1); + expect(result.entitiesByMessage[0][0].entity_type).toBe("PERSON"); + }); + + test("detects PII in earlier user message", async () => { + mockPresidio({ + "secret@email.com": [{ entity_type: "EMAIL_ADDRESS", start: 12, end: 28, score: 0.99 }], + }); + + const detector = new PIIDetector(); + const messages = [ + { role: "user", content: "My email is secret@email.com" }, + { role: "assistant", content: "Got it." }, + { role: "user", content: "Now do something else" }, + ]; + + const result = await detector.analyzeMessages(messages); + + expect(result.hasPII).toBe(true); + expect(result.entitiesByMessage[0]).toHaveLength(1); + }); + + test("returns empty result for no messages", async () => { + mockPresidio({}); + + const detector = new PIIDetector(); + const result = await detector.analyzeMessages([]); + + expect(result.hasPII).toBe(false); + expect(result.entitiesByMessage).toHaveLength(0); + expect(result.newEntities).toHaveLength(0); + }); + + test("handles multimodal content", async () => { + mockPresidio({ + "Hans Müller": [{ entity_type: "PERSON", start: 0, end: 11, score: 0.9 }], + }); + + const detector = new PIIDetector(); + const messages = [ + { + role: "user", + content: [ + { type: "text", text: "Hans Müller in this image" }, + { type: "image_url", image_url: { url: "data:image/png;base64,..." } }, + ], + }, + ]; + + const result = await detector.analyzeMessages(messages); + + expect(result.hasPII).toBe(true); + expect(result.entitiesByMessage[0]).toHaveLength(1); + }); + + test("skips messages with empty content", async () => { + mockPresidio({ + test: [{ entity_type: "PERSON", start: 0, end: 4, score: 0.9 }], + }); + + const detector = new PIIDetector(); + const messages = [ + { role: "user", content: "" }, + { role: "assistant", content: "test response" }, + ]; + + const result = await detector.analyzeMessages(messages); + + expect(result.entitiesByMessage).toHaveLength(2); + expect(result.entitiesByMessage[0]).toHaveLength(0); + }); + }); + + describe("detectPII", () => { + test("returns entities from Presidio", async () => { + mockPresidio({ + "test@example.com": [{ entity_type: "EMAIL_ADDRESS", start: 0, end: 16, score: 0.99 }], + }); + + const detector = new PIIDetector(); + const entities = await detector.detectPII("test@example.com", "en"); + + expect(entities).toHaveLength(1); + expect(entities[0].entity_type).toBe("EMAIL_ADDRESS"); + }); + + test("returns empty array for text without PII", async () => { + mockPresidio({}); + + const detector = new PIIDetector(); + const entities = await detector.detectPII("Hello world", "en"); + + expect(entities).toHaveLength(0); + }); + }); + + describe("healthCheck", () => { + test("returns true when Presidio is healthy", async () => { + mockPresidio({}); + + const detector = new PIIDetector(); + const healthy = await detector.healthCheck(); + + expect(healthy).toBe(true); + }); + + test("returns false when Presidio is unavailable", async () => { + globalThis.fetch = mock(async () => { + throw new Error("Connection refused"); + }) as unknown as typeof fetch; + + const detector = new PIIDetector(); + const healthy = await detector.healthCheck(); + + expect(healthy).toBe(false); + }); + }); +}); diff --git a/src/services/pii-detector.ts b/src/services/pii-detector.ts index 4d965b8..444f130 100644 --- a/src/services/pii-detector.ts +++ b/src/services/pii-detector.ts @@ -1,10 +1,6 @@ import { getConfig } from "../config"; import { extractTextContent, type MessageContent } from "../utils/content"; -import { - getLanguageDetector, - type LanguageDetectionResult, - type SupportedLanguage, -} from "./language-detector"; +import { getLanguageDetector, type SupportedLanguage } from "./language-detector"; export interface PIIEntity { entity_type: string; @@ -86,48 +82,20 @@ export class PIIDetector { messages: Array<{ role: string; content: MessageContent }>, ): Promise { const startTime = Date.now(); + const config = getConfig(); - const lastUserIndex = messages.findLastIndex((m) => m.role === "user"); - - if (lastUserIndex === -1 || !messages[lastUserIndex].content) { - const config = getConfig(); - return { - hasPII: false, - entitiesByMessage: messages.map(() => []), - newEntities: [], - scanTimeMs: Date.now() - startTime, - language: config.pii_detection.fallback_language, - languageFallback: false, - }; - } - - const text = extractTextContent(messages[lastUserIndex].content); - const langResult = getLanguageDetector().detect(text); - const newEntities = await this.detectPII(text, langResult.language); - - const entitiesByMessage = messages.map((_, i) => (i === lastUserIndex ? newEntities : [])); - - return { - hasPII: newEntities.length > 0, - entitiesByMessage, - newEntities, - scanTimeMs: Date.now() - startTime, - language: langResult.language, - languageFallback: langResult.usedFallback, - detectedLanguage: langResult.detectedLanguage, - }; - } + const lastUserMsg = messages.findLast((m) => m.role === "user"); + const langText = lastUserMsg ? extractTextContent(lastUserMsg.content) : ""; + const langResult = langText + ? getLanguageDetector().detect(langText) + : { language: config.pii_detection.fallback_language, usedFallback: true }; - async analyzeAllMessages( - messages: Array<{ role: string; content: MessageContent }>, - langResult: LanguageDetectionResult, - ): Promise { - const startTime = Date.now(); + const scannedRoles = ["system", "developer", "user", "assistant"]; const entitiesByMessage = await Promise.all( messages.map((message) => { const text = extractTextContent(message.content); - return text && (message.role === "user" || message.role === "assistant") + return text && scannedRoles.includes(message.role) ? this.detectPII(text, langResult.language) : Promise.resolve([]); }), @@ -136,7 +104,7 @@ export class PIIDetector { return { hasPII: entitiesByMessage.some((e) => e.length > 0), entitiesByMessage, - newEntities: [], + newEntities: entitiesByMessage.flat(), scanTimeMs: Date.now() - startTime, language: langResult.language, languageFallback: langResult.usedFallback,