diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index 7329299..bab34fa 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -1,23 +1,14 @@ import crypto from 'crypto'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; import { EventEmitter } from 'stream'; -import { - getAvailableChatModelProviders, - getAvailableEmbeddingModelProviders, -} from '@/lib/providers'; import db from '@/lib/db'; import { chats, messages as messagesSchema } from '@/lib/db/schema'; import { and, eq, gt } from 'drizzle-orm'; import { getFileDetails } from '@/lib/utils/files'; -import { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { ChatOpenAI } from '@langchain/openai'; -import { - getCustomOpenaiApiKey, - getCustomOpenaiApiUrl, - getCustomOpenaiModelName, -} from '@/lib/config'; import { searchHandlers } from '@/lib/search'; import { z } from 'zod'; +import ModelRegistry from '@/lib/models/registry'; +import { ModelWithProvider } from '@/lib/models/types'; export const runtime = 'nodejs'; export const dynamic = 'force-dynamic'; @@ -28,14 +19,30 @@ const messageSchema = z.object({ content: z.string().min(1, 'Message content is required'), }); -const chatModelSchema = z.object({ - provider: z.string().optional(), - name: z.string().optional(), +const chatModelSchema: z.ZodType = z.object({ + providerId: z.string({ + errorMap: () => ({ + message: 'Chat model provider id must be provided', + }), + }), + key: z.string({ + errorMap: () => ({ + message: 'Chat model key must be provided', + }), + }), }); -const embeddingModelSchema = z.object({ - provider: z.string().optional(), - name: z.string().optional(), +const embeddingModelSchema: z.ZodType = z.object({ + providerId: z.string({ + errorMap: () => ({ + message: 'Embedding model provider id must be provided', + }), + }), + key: z.string({ + errorMap: () => ({ + message: 'Embedding model key must be provided', + }), + }), }); const bodySchema = z.object({ @@ -57,8 +64,8 @@ const bodySchema = z.object({ .optional() .default([]), files: z.array(z.string()).optional().default([]), - chatModel: chatModelSchema.optional().default({}), - embeddingModel: embeddingModelSchema.optional().default({}), + chatModel: chatModelSchema, + embeddingModel: embeddingModelSchema, systemInstructions: z.string().nullable().optional().default(''), }); @@ -248,56 +255,16 @@ export const POST = async (req: Request) => { ); } - const [chatModelProviders, embeddingModelProviders] = await Promise.all([ - getAvailableChatModelProviders(), - getAvailableEmbeddingModelProviders(), + const registry = new ModelRegistry(); + + const [llm, embedding] = await Promise.all([ + registry.loadChatModel(body.chatModel.providerId, body.chatModel.key), + registry.loadEmbeddingModel( + body.embeddingModel.providerId, + body.embeddingModel.key, + ), ]); - const chatModelProvider = - chatModelProviders[ - body.chatModel?.provider || Object.keys(chatModelProviders)[0] - ]; - const chatModel = - chatModelProvider[ - body.chatModel?.name || Object.keys(chatModelProvider)[0] - ]; - - const embeddingProvider = - embeddingModelProviders[ - body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0] - ]; - const embeddingModel = - embeddingProvider[ - body.embeddingModel?.name || Object.keys(embeddingProvider)[0] - ]; - - let llm: BaseChatModel | undefined; - let embedding = embeddingModel.model; - - if (body.chatModel?.provider === 'custom_openai') { - llm = new ChatOpenAI({ - apiKey: getCustomOpenaiApiKey(), - modelName: getCustomOpenaiModelName(), - temperature: 0.7, - configuration: { - baseURL: getCustomOpenaiApiUrl(), - }, - }) as unknown as BaseChatModel; - } else if (chatModelProvider && chatModel) { - llm = chatModel.model; - } - - if (!llm) { - return Response.json({ error: 'Invalid chat model' }, { status: 400 }); - } - - if (!embedding) { - return Response.json( - { error: 'Invalid embedding model' }, - { status: 400 }, - ); - } - const humanMessageId = message.messageId ?? crypto.randomBytes(7).toString('hex');