feat(chat-route): use new model registry

This commit is contained in:
ItzCrazyKns
2025-10-16 17:58:13 +05:30
parent 9706079ed4
commit 768578951c

View File

@@ -1,23 +1,14 @@
import crypto from 'crypto';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { EventEmitter } from 'stream';
import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
} from '@/lib/providers';
import db from '@/lib/db';
import { chats, messages as messagesSchema } from '@/lib/db/schema';
import { and, eq, gt } from 'drizzle-orm';
import { getFileDetails } from '@/lib/utils/files';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatOpenAI } from '@langchain/openai';
import {
getCustomOpenaiApiKey,
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
} from '@/lib/config';
import { searchHandlers } from '@/lib/search';
import { z } from 'zod';
import ModelRegistry from '@/lib/models/registry';
import { ModelWithProvider } from '@/lib/models/types';
export const runtime = 'nodejs';
export const dynamic = 'force-dynamic';
@@ -28,14 +19,30 @@ const messageSchema = z.object({
content: z.string().min(1, 'Message content is required'),
});
const chatModelSchema = z.object({
provider: z.string().optional(),
name: z.string().optional(),
const chatModelSchema: z.ZodType<ModelWithProvider> = z.object({
providerId: z.string({
errorMap: () => ({
message: 'Chat model provider id must be provided',
}),
}),
key: z.string({
errorMap: () => ({
message: 'Chat model key must be provided',
}),
}),
});
const embeddingModelSchema = z.object({
provider: z.string().optional(),
name: z.string().optional(),
const embeddingModelSchema: z.ZodType<ModelWithProvider> = z.object({
providerId: z.string({
errorMap: () => ({
message: 'Embedding model provider id must be provided',
}),
}),
key: z.string({
errorMap: () => ({
message: 'Embedding model key must be provided',
}),
}),
});
const bodySchema = z.object({
@@ -57,8 +64,8 @@ const bodySchema = z.object({
.optional()
.default([]),
files: z.array(z.string()).optional().default([]),
chatModel: chatModelSchema.optional().default({}),
embeddingModel: embeddingModelSchema.optional().default({}),
chatModel: chatModelSchema,
embeddingModel: embeddingModelSchema,
systemInstructions: z.string().nullable().optional().default(''),
});
@@ -248,56 +255,16 @@ export const POST = async (req: Request) => {
);
}
const [chatModelProviders, embeddingModelProviders] = await Promise.all([
getAvailableChatModelProviders(),
getAvailableEmbeddingModelProviders(),
const registry = new ModelRegistry();
const [llm, embedding] = await Promise.all([
registry.loadChatModel(body.chatModel.providerId, body.chatModel.key),
registry.loadEmbeddingModel(
body.embeddingModel.providerId,
body.embeddingModel.key,
),
]);
const chatModelProvider =
chatModelProviders[
body.chatModel?.provider || Object.keys(chatModelProviders)[0]
];
const chatModel =
chatModelProvider[
body.chatModel?.name || Object.keys(chatModelProvider)[0]
];
const embeddingProvider =
embeddingModelProviders[
body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0]
];
const embeddingModel =
embeddingProvider[
body.embeddingModel?.name || Object.keys(embeddingProvider)[0]
];
let llm: BaseChatModel | undefined;
let embedding = embeddingModel.model;
if (body.chatModel?.provider === 'custom_openai') {
llm = new ChatOpenAI({
apiKey: getCustomOpenaiApiKey(),
modelName: getCustomOpenaiModelName(),
temperature: 0.7,
configuration: {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}
if (!llm) {
return Response.json({ error: 'Invalid chat model' }, { status: 400 });
}
if (!embedding) {
return Response.json(
{ error: 'Invalid embedding model' },
{ status: 400 },
);
}
const humanMessageId =
message.messageId ?? crypto.randomBytes(7).toString('hex');