diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index 6037c1a..7329299 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -17,35 +17,71 @@ import { getCustomOpenaiModelName, } from '@/lib/config'; import { searchHandlers } from '@/lib/search'; +import { z } from 'zod'; export const runtime = 'nodejs'; export const dynamic = 'force-dynamic'; -type Message = { - messageId: string; - chatId: string; - content: string; -}; +const messageSchema = z.object({ + messageId: z.string().min(1, 'Message ID is required'), + chatId: z.string().min(1, 'Chat ID is required'), + content: z.string().min(1, 'Message content is required'), +}); -type ChatModel = { - provider: string; - name: string; -}; +const chatModelSchema = z.object({ + provider: z.string().optional(), + name: z.string().optional(), +}); -type EmbeddingModel = { - provider: string; - name: string; -}; +const embeddingModelSchema = z.object({ + provider: z.string().optional(), + name: z.string().optional(), +}); -type Body = { - message: Message; - optimizationMode: 'speed' | 'balanced' | 'quality'; - focusMode: string; - history: Array<[string, string]>; - files: Array; - chatModel: ChatModel; - embeddingModel: EmbeddingModel; - systemInstructions: string; +const bodySchema = z.object({ + message: messageSchema, + optimizationMode: z.enum(['speed', 'balanced', 'quality'], { + errorMap: () => ({ + message: 'Optimization mode must be one of: speed, balanced, quality', + }), + }), + focusMode: z.string().min(1, 'Focus mode is required'), + history: z + .array( + z.tuple([z.string(), z.string()], { + errorMap: () => ({ + message: 'History items must be tuples of two strings', + }), + }), + ) + .optional() + .default([]), + files: z.array(z.string()).optional().default([]), + chatModel: chatModelSchema.optional().default({}), + embeddingModel: embeddingModelSchema.optional().default({}), + systemInstructions: z.string().nullable().optional().default(''), +}); + +type Message = z.infer; +type Body = z.infer; + +const safeValidateBody = (data: unknown) => { + const result = bodySchema.safeParse(data); + + if (!result.success) { + return { + success: false, + error: result.error.errors.map((e) => ({ + path: e.path.join('.'), + message: e.message, + })), + }; + } + + return { + success: true, + data: result.data, + }; }; const handleEmitterEvents = async ( @@ -190,7 +226,17 @@ const handleHistorySave = async ( export const POST = async (req: Request) => { try { - const body = (await req.json()) as Body; + const reqBody = (await req.json()) as Body; + + const parseBody = safeValidateBody(reqBody); + if (!parseBody.success) { + return Response.json( + { message: 'Invalid request body', error: parseBody.error }, + { status: 400 }, + ); + } + + const body = parseBody.data as Body; const { message } = body; if (message.content === '') { @@ -285,7 +331,7 @@ export const POST = async (req: Request) => { embedding, body.optimizationMode, body.files, - body.systemInstructions, + body.systemInstructions as string, ); const responseStream = new TransformStream();