import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings';
import { ChatOpenAI } from '@langchain/openai';
import {
  getAvailableChatModelProviders,
  getAvailableEmbeddingModelProviders,
} from '@/lib/providers';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { MetaSearchAgentType } from '@/lib/search/metaSearchAgent';
import {
  getCustomOpenaiApiKey,
  getCustomOpenaiApiUrl,
  getCustomOpenaiModelName,
} from '@/lib/config';
import { searchHandlers } from '@/lib/search';

/** Chat model selection sent by the client; the custom OpenAI fields are optional overrides. */
interface chatModel {
  provider: string;
  name: string;
  customOpenAIKey?: string;
  customOpenAIBaseURL?: string;
}

/** Embedding model selection sent by the client. */
interface embeddingModel {
  provider: string;
  name: string;
}

/** JSON body accepted by {@link POST}. */
interface ChatRequestBody {
  optimizationMode: 'speed' | 'balanced';
  focusMode: string;
  chatModel?: chatModel;
  embeddingModel?: embeddingModel;
  query: string;
  /** Prior turns as `[role, content]`; a role of `'human'` maps to a human message, anything else to an AI message. */
  history: Array<[string, string]>;
  stream?: boolean;
}

/** Emitter type produced by a search handler, derived so no extra import is needed. */
type SearchEmitter = Awaited<
  ReturnType<MetaSearchAgentType['searchAndAnswer']>
>;

/**
 * Buffers the emitter's output and produces a single JSON response.
 *
 * The promise always resolves, including on failure, where it resolves with a
 * 500 response. Rejecting with a `Response` would make the route handler
 * itself reject, so the client would never receive the intended error body.
 *
 * @param emitter - Emitter returned by the search handler.
 * @returns `{ message, sources }` with status 200, or a 500 JSON error.
 */
const collectResponse = (emitter: SearchEmitter): Promise<Response> =>
  new Promise<Response>((resolve) => {
    let message = '';
    let sources: unknown[] = [];

    emitter.on('data', (data: string) => {
      try {
        const parsedData = JSON.parse(data);
        if (parsedData.type === 'response') {
          message += parsedData.data;
        } else if (parsedData.type === 'sources') {
          sources = parsedData.data;
        }
      } catch {
        // Only the first resolve takes effect; later events become no-ops.
        resolve(
          Response.json({ message: 'Error parsing data' }, { status: 500 }),
        );
      }
    });

    emitter.on('end', () => {
      resolve(Response.json({ message, sources }, { status: 200 }));
    });

    emitter.on('error', (error: unknown) => {
      resolve(
        Response.json({ message: 'Search error', error }, { status: 500 }),
      );
    });
  });

/**
 * Forwards the emitter's output as newline-delimited JSON chunks.
 *
 * Chunks are emitted in this order: one `init`, then any number of `response`
 * and `sources` chunks, then `done`.
 *
 * @param emitter - Emitter returned by the search handler.
 * @returns A byte stream suitable for use as a `Response` body.
 */
const createResponseStream = (
  emitter: SearchEmitter,
): ReadableStream<Uint8Array> => {
  const encoder = new TextEncoder();
  const abortController = new AbortController();
  const { signal } = abortController;

  const encodeLine = (payload: unknown): Uint8Array =>
    encoder.encode(JSON.stringify(payload) + '\n');

  return new ReadableStream<Uint8Array>({
    start(controller) {
      // Latched once the controller is closed or errored, so late emitter
      // events cannot call enqueue/close on a finished stream (which throws).
      let finished = false;

      controller.enqueue(
        encodeLine({ type: 'init', data: 'Stream connected' }),
      );

      signal.addEventListener('abort', () => {
        // Detach only 'data' and 'end'. The 'error' listener stays attached
        // (it returns early once aborted) so a late 'error' emit still has a
        // handler. Presumably this is a Node EventEmitter, which throws on
        // an unhandled 'error' event; verify against the search agent.
        emitter.removeAllListeners('data');
        emitter.removeAllListeners('end');
        try {
          controller.close();
        } catch {
          // Best effort: the stream may already be closed or cancelled.
        }
      });

      emitter.on('data', (data: string) => {
        if (signal.aborted || finished) return;
        try {
          const parsedData = JSON.parse(data);
          if (
            parsedData.type === 'response' ||
            parsedData.type === 'sources'
          ) {
            controller.enqueue(
              encodeLine({ type: parsedData.type, data: parsedData.data }),
            );
          }
        } catch (error) {
          finished = true;
          controller.error(error);
        }
      });

      emitter.on('end', () => {
        if (signal.aborted || finished) return;
        finished = true;
        controller.enqueue(encodeLine({ type: 'done' }));
        controller.close();
      });

      emitter.on('error', (error: unknown) => {
        if (signal.aborted || finished) return;
        finished = true;
        controller.error(error);
      });
    },
    cancel() {
      abortController.abort();
    },
  });
};

/**
 * Handles a search request: validates the body, resolves the chat and
 * embedding models, runs the search handler for the requested focus mode,
 * and returns the answer either buffered or streamed.
 *
 * @param req - Request whose JSON body matches `ChatRequestBody`.
 * @returns
 *  - 400 when `focusMode` or `query` is missing, the model selection cannot
 *    be resolved, or the focus mode has no handler;
 *  - 200 JSON `{ message, sources }` when `stream` is falsy;
 *  - a newline-delimited JSON stream when `stream` is truthy;
 *  - 500 for any other failure.
 */
export const POST = async (req: Request) => {
  try {
    const body: ChatRequestBody = await req.json();

    if (!body.focusMode || !body.query) {
      return Response.json(
        { message: 'Missing focus mode or query' },
        { status: 400 },
      );
    }

    const optimizationMode = body.optimizationMode || 'balanced';
    const wantsStream = body.stream || false;

    const history: BaseMessage[] = (body.history || []).map(
      ([role, content]) =>
        role === 'human'
          ? new HumanMessage({ content })
          : new AIMessage({ content }),
    );

    const [chatModelProviders, embeddingModelProviders] = await Promise.all([
      getAvailableChatModelProviders(),
      getAvailableEmbeddingModelProviders(),
    ]);

    // `?? {}` keeps an unknown provider from throwing in Object.keys, so the
    // request falls through to the 400 below rather than the generic 500.
    const chatModelProvider =
      body.chatModel?.provider || Object.keys(chatModelProviders)[0];
    const chatModel =
      body.chatModel?.name ||
      Object.keys(chatModelProviders[chatModelProvider] ?? {})[0];

    const embeddingModelProvider =
      body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0];
    const embeddingModel =
      body.embeddingModel?.name ||
      Object.keys(embeddingModelProviders[embeddingModelProvider] ?? {})[0];

    let llm: BaseChatModel | undefined;
    let embeddings: Embeddings | undefined;

    if (body.chatModel?.provider === 'custom_openai') {
      // Request-supplied values win; configured values are the fallback.
      llm = new ChatOpenAI({
        modelName: body.chatModel?.name || getCustomOpenaiModelName(),
        openAIApiKey:
          body.chatModel?.customOpenAIKey || getCustomOpenaiApiKey(),
        temperature: 0.7,
        configuration: {
          baseURL:
            body.chatModel?.customOpenAIBaseURL || getCustomOpenaiApiUrl(),
        },
      }) as unknown as BaseChatModel;
    } else if (
      chatModelProviders[chatModelProvider] &&
      chatModelProviders[chatModelProvider][chatModel]
    ) {
      llm = chatModelProviders[chatModelProvider][chatModel]
        .model as unknown as BaseChatModel | undefined;
    }

    if (
      embeddingModelProviders[embeddingModelProvider] &&
      embeddingModelProviders[embeddingModelProvider][embeddingModel]
    ) {
      embeddings = embeddingModelProviders[embeddingModelProvider][
        embeddingModel
      ].model as Embeddings | undefined;
    }

    if (!llm || !embeddings) {
      return Response.json(
        { message: 'Invalid model selected' },
        { status: 400 },
      );
    }

    const searchHandler: MetaSearchAgentType = searchHandlers[body.focusMode];

    if (!searchHandler) {
      return Response.json({ message: 'Invalid focus mode' }, { status: 400 });
    }

    const emitter = await searchHandler.searchAndAnswer(
      body.query,
      history,
      llm,
      embeddings,
      optimizationMode,
      [],
      '',
    );

    if (!wantsStream) {
      // Awaited so any unexpected failure is handled by the catch below.
      return await collectResponse(emitter);
    }

    // NOTE(review): the body is newline-delimited JSON, not SSE `data:`
    // frames, yet the content type says event-stream. Left unchanged because
    // existing clients may depend on it; confirm before altering.
    return new Response(createResponseStream(emitter), {
      headers: {
        'Content-Type': 'text/event-stream',
        'Cache-Control': 'no-cache, no-transform',
        Connection: 'keep-alive',
      },
    });
  } catch (err: unknown) {
    const reason = err instanceof Error ? err.message : String(err);
    console.error(`Error in getting search results: ${reason}`);
    return Response.json(
      { message: 'An error has occurred.' },
      { status: 500 },
    );
  }
};