import prompts from '@/lib/prompts';
import MetaSearchAgent from '@/lib/search/metaSearchAgent';
import crypto from 'crypto';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { EventEmitter } from 'stream';
import { chatModelProviders, embeddingModelProviders } from '@/lib/providers';
import db from '@/lib/db';
import { chats, messages as messagesSchema } from '@/lib/db/schema';
import { and, eq, gt } from 'drizzle-orm';
import { getFileDetails } from '@/lib/utils/files';

export const runtime = 'nodejs';
export const dynamic = 'force-dynamic';

/**
 * Search agents keyed by the `focusMode` string a client sends in the request
 * body. The option values are passed straight to the `MetaSearchAgent`
 * constructor (engines to query, prompts, rerank/summarizer settings).
 */
export const searchHandlers: Record<string, MetaSearchAgent> = {
  webSearch: new MetaSearchAgent({
    activeEngines: [],
    queryGeneratorPrompt: prompts.webSearchRetrieverPrompt,
    responsePrompt: prompts.webSearchResponsePrompt,
    rerank: true,
    rerankThreshold: 0.3,
    searchWeb: true,
    summarizer: true,
  }),
  academicSearch: new MetaSearchAgent({
    activeEngines: ['arxiv', 'google scholar', 'pubmed'],
    queryGeneratorPrompt: prompts.academicSearchRetrieverPrompt,
    responsePrompt: prompts.academicSearchResponsePrompt,
    rerank: true,
    rerankThreshold: 0,
    searchWeb: true,
    summarizer: false,
  }),
  writingAssistant: new MetaSearchAgent({
    activeEngines: [],
    queryGeneratorPrompt: '',
    responsePrompt: prompts.writingAssistantPrompt,
    rerank: true,
    rerankThreshold: 0,
    searchWeb: false,
    summarizer: false,
  }),
  wolframAlphaSearch: new MetaSearchAgent({
    activeEngines: ['wolframalpha'],
    queryGeneratorPrompt: prompts.wolframAlphaSearchRetrieverPrompt,
    responsePrompt: prompts.wolframAlphaSearchResponsePrompt,
    rerank: false,
    rerankThreshold: 0,
    searchWeb: true,
    summarizer: false,
  }),
  youtubeSearch: new MetaSearchAgent({
    activeEngines: ['youtube'],
    queryGeneratorPrompt: prompts.youtubeSearchRetrieverPrompt,
    responsePrompt: prompts.youtubeSearchResponsePrompt,
    rerank: true,
    rerankThreshold: 0.3,
    searchWeb: true,
    summarizer: false,
  }),
  redditSearch: new MetaSearchAgent({
    activeEngines: ['reddit'],
    queryGeneratorPrompt: prompts.redditSearchRetrieverPrompt,
    responsePrompt: prompts.redditSearchResponsePrompt,
    rerank: true,
    rerankThreshold: 0.3,
    searchWeb: true,
    summarizer: false,
  }),
};

type Message = {
  messageId: string;
  chatId: string;
  content: string;
};

type ChatModel = {
  provider: string;
  name: string;
};

type EmbeddingModel = {
  provider: string;
  name: string;
};

type Body = {
  message: Message;
  optimizationMode: 'speed' | 'balanced' | 'quality';
  focusMode: string;
  history: Array<[string, string]>;
  files: Array<string>;
  chatModel: ChatModel;
  embeddingModel: EmbeddingModel;
};

/**
 * Relays a search agent's event emitter to the HTTP response as
 * newline-delimited JSON, and stores the assistant reply when the agent ends.
 *
 * Incoming `data` payloads are JSON strings with a `type` field:
 * - `response` chunks are forwarded as `message` events and concatenated into
 *   the reply that gets saved;
 * - `sources` are forwarded as-is and saved in the message metadata.
 *
 * @param stream - emitter returned by `MetaSearchAgent.searchAndAnswer`
 * @param writer - writer for the response body; closed on `end` or `error`
 * @param encoder - encodes each JSON line to bytes
 * @param aiMessageId - id attached to every event and to the saved reply
 * @param chatId - chat the saved reply belongs to
 */
const handleEmitterEvents = async (
  stream: EventEmitter,
  writer: WritableStreamDefaultWriter,
  encoder: TextEncoder,
  aiMessageId: string,
  chatId: string,
) => {
  let receivedMessage = '';
  let sources: any[] = [];
  // Once the writer is closed, further writes/closes would reject; this flag
  // turns late events into no-ops instead.
  let closed = false;

  // Writes one NDJSON line. write() rejects when the client has gone away, so
  // log that rather than leaving an unhandled promise rejection.
  const send = (payload: Record<string, unknown>) => {
    if (closed) return;
    writer
      .write(encoder.encode(JSON.stringify(payload) + '\n'))
      .catch((err) => console.error('Failed to write chat stream chunk:', err));
  };

  const close = () => {
    if (closed) return;
    closed = true;
    writer
      .close()
      .catch((err) => console.error('Failed to close chat stream:', err));
  };

  stream.on('data', (data) => {
    const parsedData = JSON.parse(data);
    if (parsedData.type === 'response') {
      send({ type: 'message', data: parsedData.data, messageId: aiMessageId });
      receivedMessage += parsedData.data;
    } else if (parsedData.type === 'sources') {
      send({ type: 'sources', data: parsedData.data, messageId: aiMessageId });
      sources = parsedData.data;
    }
  });

  stream.on('end', async () => {
    send({ type: 'messageEnd', messageId: aiMessageId });
    close();

    // Nobody awaits this listener, so a failed insert must be caught here.
    try {
      await db
        .insert(messagesSchema)
        .values({
          content: receivedMessage,
          chatId: chatId,
          messageId: aiMessageId,
          role: 'assistant',
          metadata: JSON.stringify({
            createdAt: new Date(),
            ...(sources && sources.length > 0 ? { sources } : {}),
          }),
        })
        .execute();
    } catch (err) {
      console.error('Failed to save assistant message:', err);
    }
  });

  stream.on('error', (data) => {
    // Error payloads are expected to be JSON strings with a `data` field; fall
    // back to the raw value so a non-JSON payload doesn't throw in here.
    let errorData: unknown;
    try {
      errorData = JSON.parse(data).data;
    } catch {
      errorData = String(data);
    }
    send({ type: 'error', data: errorData });
    close();
  });
};

/**
 * Persists the user's side of an exchange.
 *
 * Creates the chat row if it does not exist yet, titled with this message's
 * content. It then either:
 * - inserts the user message; or
 * - if a message with `humanMessageId` already exists, keeps it and deletes
 *   every later message (higher row id) in the same chat. Presumably this is
 *   the rewrite/regenerate path; confirm with the client.
 *
 * @param message - the incoming user message
 * @param humanMessageId - id under which the user message is stored
 * @param focusMode - stored on a newly created chat
 * @param files - file ids, resolved through `getFileDetails` for a new chat
 */
const handleHistorySave = async (
  message: Message,
  humanMessageId: string,
  focusMode: string,
  files: string[],
) => {
  const chat = await db.query.chats.findFirst({
    where: eq(chats.id, message.chatId),
  });

  if (!chat) {
    await db
      .insert(chats)
      .values({
        id: message.chatId,
        title: message.content,
        createdAt: new Date().toString(),
        focusMode: focusMode,
        // Wrap the call so Array.map's index/array arguments are not passed on.
        files: files.map((file) => getFileDetails(file)),
      })
      .execute();
  }

  const messageExists = await db.query.messages.findFirst({
    where: eq(messagesSchema.messageId, humanMessageId),
  });

  if (!messageExists) {
    await db
      .insert(messagesSchema)
      .values({
        content: message.content,
        chatId: message.chatId,
        messageId: humanMessageId,
        role: 'user',
        metadata: JSON.stringify({
          createdAt: new Date(),
        }),
      })
      .execute();
  } else {
    await db
      .delete(messagesSchema)
      .where(
        and(
          gt(messagesSchema.id, messageExists.id),
          eq(messagesSchema.chatId, message.chatId),
        ),
      )
      .execute();
  }
};

/**
 * Chat endpoint.
 *
 * Validates the requested chat model, embedding model and focus mode, starts
 * the matching search agent, and streams its output back as newline-delimited
 * JSON. The user message is saved to the database in the background.
 *
 * Responds with 400 for an empty message or any unknown
 * provider/model/focus mode. Any other failure is logged and returned as 500.
 */
export const POST = async (req: Request) => {
  try {
    const body = (await req.json()) as Body;
    const { message, chatModel, embeddingModel } = body;

    if (!message?.content) {
      return Response.json(
        {
          message: 'Please provide a message to process',
        },
        { status: 400 },
      );
    }

    const getProviderChatModels = chatModelProviders[chatModel.provider];
    if (!getProviderChatModels) {
      return Response.json(
        {
          message: 'Invalid chat model provider',
        },
        { status: 400 },
      );
    }

    const chatModels = await getProviderChatModels();
    // `?.` so an unknown model name reaches the 400 below instead of throwing.
    const llm = chatModels[chatModel.name]?.model;
    if (!llm) {
      return Response.json(
        {
          message: 'Invalid chat model',
        },
        { status: 400 },
      );
    }

    const getProviderEmbeddingModels =
      embeddingModelProviders[embeddingModel.provider];
    if (!getProviderEmbeddingModels) {
      return Response.json(
        {
          message: 'Invalid embedding model provider',
        },
        { status: 400 },
      );
    }

    const embeddingModels = await getProviderEmbeddingModels();
    const embedding = embeddingModels[embeddingModel.name]?.model;
    if (!embedding) {
      return Response.json(
        {
          message: 'Invalid embedding model',
        },
        { status: 400 },
      );
    }

    const humanMessageId =
      message.messageId ?? crypto.randomBytes(7).toString('hex');
    const aiMessageId = crypto.randomBytes(7).toString('hex');

    const history: BaseMessage[] = (body.history ?? []).map(
      ([role, content]) =>
        role === 'human'
          ? new HumanMessage({ content })
          : new AIMessage({ content }),
    );

    // Own-property check: inherited keys such as "constructor" must be treated
    // as an invalid focus mode, not as a handler.
    const handler = Object.prototype.hasOwnProperty.call(
      searchHandlers,
      body.focusMode,
    )
      ? searchHandlers[body.focusMode]
      : undefined;

    if (!handler) {
      return Response.json(
        {
          message: 'Invalid focus mode',
        },
        { status: 400 },
      );
    }

    const files = body.files ?? [];

    const stream = await handler.searchAndAnswer(
      message.content,
      history,
      llm,
      embedding,
      body.optimizationMode,
      files,
    );

    const responseStream = new TransformStream();
    const writer = responseStream.writable.getWriter();
    const encoder = new TextEncoder();

    void handleEmitterEvents(
      stream,
      writer,
      encoder,
      aiMessageId,
      message.chatId,
    );
    // Fire-and-forget so streaming starts immediately; a failed save is logged
    // rather than surfacing as an unhandled promise rejection.
    handleHistorySave(message, humanMessageId, body.focusMode, files).catch(
      (err) => console.error('Failed to save chat history:', err),
    );

    return new Response(responseStream.readable, {
      headers: {
        'Content-Type': 'text/event-stream',
        Connection: 'keep-alive',
        'Cache-Control': 'no-cache, no-transform',
      },
    });
  } catch (err) {
    console.error('An error ocurred while processing chat request:', err);
    return Response.json(
      { message: 'An error ocurred while processing chat request' },
      { status: 500 },
    );
  }
};