From c24edac16da9e8c720eb4cbc5f599e0b51b88ecc Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Wed, 19 Mar 2025 13:41:52 +0530 Subject: [PATCH] feat(app): add chat functionality --- ui/app/api/chat/route.ts | 346 +++++++++++++++++++++++++++++++++++ ui/app/api/models/route.ts | 47 +++++ ui/components/ChatWindow.tsx | 320 ++++++++++++++------------------ ui/lib/providers/index.ts | 8 +- 4 files changed, 535 insertions(+), 186 deletions(-) create mode 100644 ui/app/api/chat/route.ts create mode 100644 ui/app/api/models/route.ts diff --git a/ui/app/api/chat/route.ts b/ui/app/api/chat/route.ts new file mode 100644 index 0000000..0b130de --- /dev/null +++ b/ui/app/api/chat/route.ts @@ -0,0 +1,346 @@ +import prompts from '@/lib/prompts'; +import MetaSearchAgent from '@/lib/search/metaSearchAgent'; +import crypto from 'crypto'; +import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; +import { EventEmitter } from 'stream'; +import { chatModelProviders, embeddingModelProviders } from '@/lib/providers'; +import db from '@/lib/db'; +import { chats, messages as messagesSchema } from '@/lib/db/schema'; +import { and, eq, gt } from 'drizzle-orm'; +import { getFileDetails } from '@/lib/utils/files'; + +export const runtime = 'nodejs'; +export const dynamic = 'force-dynamic'; + +export const searchHandlers: Record = { + webSearch: new MetaSearchAgent({ + activeEngines: [], + queryGeneratorPrompt: prompts.webSearchRetrieverPrompt, + responsePrompt: prompts.webSearchResponsePrompt, + rerank: true, + rerankThreshold: 0.3, + searchWeb: true, + summarizer: true, + }), + academicSearch: new MetaSearchAgent({ + activeEngines: ['arxiv', 'google scholar', 'pubmed'], + queryGeneratorPrompt: prompts.academicSearchRetrieverPrompt, + responsePrompt: prompts.academicSearchResponsePrompt, + rerank: true, + rerankThreshold: 0, + searchWeb: true, + summarizer: false, + }), + writingAssistant: new MetaSearchAgent({ + activeEngines: [], + queryGeneratorPrompt: '', + responsePrompt: prompts.writingAssistantPrompt, + rerank: true, + rerankThreshold: 0, + searchWeb: false, + summarizer: false, + }), + wolframAlphaSearch: new MetaSearchAgent({ + activeEngines: ['wolframalpha'], + queryGeneratorPrompt: prompts.wolframAlphaSearchRetrieverPrompt, + responsePrompt: prompts.wolframAlphaSearchResponsePrompt, + rerank: false, + rerankThreshold: 0, + searchWeb: true, + summarizer: false, + }), + youtubeSearch: new MetaSearchAgent({ + activeEngines: ['youtube'], + queryGeneratorPrompt: prompts.youtubeSearchRetrieverPrompt, + responsePrompt: prompts.youtubeSearchResponsePrompt, + rerank: true, + rerankThreshold: 0.3, + searchWeb: true, + summarizer: false, + }), + redditSearch: new MetaSearchAgent({ + activeEngines: ['reddit'], + queryGeneratorPrompt: prompts.redditSearchRetrieverPrompt, + responsePrompt: prompts.redditSearchResponsePrompt, + rerank: true, + rerankThreshold: 0.3, + searchWeb: true, + summarizer: false, + }), +}; + +type Message = { + messageId: string; + chatId: string; + content: string; +}; + +type ChatModel = { + provider: string; + name: string; +}; + +type EmbeddingModel = { + provider: string; + name: string; +}; + +type Body = { + message: Message; + optimizationMode: 'speed' | 'balanced' | 'quality'; + focusMode: string; + history: Array<[string, string]>; + files: Array; + chatModel: ChatModel; + embeddingModel: EmbeddingModel; +}; + +const handleEmitterEvents = async ( + stream: EventEmitter, + writer: WritableStreamDefaultWriter, + encoder: TextEncoder, + aiMessageId: string, + chatId: string, +) => { + let recievedMessage = ''; + let sources: any[] = []; + + stream.on('data', (data) => { + const parsedData = JSON.parse(data); + if (parsedData.type === 'response') { + writer.write( + encoder.encode( + JSON.stringify({ + type: 'message', + data: parsedData.data, + messageId: aiMessageId, + }) + '\n', + ), + ); + + recievedMessage += parsedData.data; + } else if (parsedData.type === 'sources') { + writer.write( + encoder.encode( + JSON.stringify({ + type: 'sources', + data: parsedData.data, + messageId: aiMessageId, + }) + '\n', + ), + ); + + sources = parsedData.data; + } + }); + stream.on('end', () => { + writer.write( + encoder.encode( + JSON.stringify({ + type: 'messageEnd', + messageId: aiMessageId, + }) + '\n', + ), + ); + writer.close(); + + db.insert(messagesSchema) + .values({ + content: recievedMessage, + chatId: chatId, + messageId: aiMessageId, + role: 'assistant', + metadata: JSON.stringify({ + createdAt: new Date(), + ...(sources && sources.length > 0 && { sources }), + }), + }) + .execute(); + }); + stream.on('error', (data) => { + const parsedData = JSON.parse(data); + writer.write( + encoder.encode( + JSON.stringify({ + type: 'error', + data: parsedData.data, + }), + ), + ); + writer.close(); + }); +}; + +const handleHistorySave = async ( + message: Message, + humanMessageId: string, + focusMode: string, + files: string[], +) => { + const chat = await db.query.chats.findFirst({ + where: eq(chats.id, message.chatId), + }); + + if (!chat) { + await db + .insert(chats) + .values({ + id: message.chatId, + title: message.content, + createdAt: new Date().toString(), + focusMode: focusMode, + files: files.map(getFileDetails), + }) + .execute(); + } + + const messageExists = await db.query.messages.findFirst({ + where: eq(messagesSchema.messageId, humanMessageId), + }); + + if (!messageExists) { + await db + .insert(messagesSchema) + .values({ + content: message.content, + chatId: message.chatId, + messageId: humanMessageId, + role: 'user', + metadata: JSON.stringify({ + createdAt: new Date(), + }), + }) + .execute(); + } else { + await db + .delete(messagesSchema) + .where( + and( + gt(messagesSchema.id, messageExists.id), + eq(messagesSchema.chatId, message.chatId), + ), + ) + .execute(); + } +}; + +export const POST = async (req: Request) => { + try { + const body = (await req.json()) as Body; + const { message, chatModel, embeddingModel } = body; + + if (message.content === '') { + return Response.json( + { + message: 'Please provide a message to process', + }, + { status: 400 }, + ); + } + + const getProviderChatModels = chatModelProviders[chatModel.provider]; + + if (!getProviderChatModels) { + return Response.json( + { + message: 'Invalid chat model provider', + }, + { status: 400 }, + ); + } + + const chatModels = await getProviderChatModels(); + + const llm = chatModels[chatModel.name].model; + + if (!llm) { + return Response.json( + { + message: 'Invalid chat model', + }, + { status: 400 }, + ); + } + + const getProviderEmbeddingModels = + embeddingModelProviders[embeddingModel.provider]; + + if (!getProviderEmbeddingModels) { + return Response.json( + { + message: 'Invalid embedding model provider', + }, + { status: 400 }, + ); + } + + const embeddingModels = await getProviderEmbeddingModels(); + const embedding = embeddingModels[embeddingModel.name].model; + + if (!embedding) { + return Response.json( + { + message: 'Invalid embedding model', + }, + { status: 400 }, + ); + } + + const humanMessageId = + message.messageId ?? crypto.randomBytes(7).toString('hex'); + const aiMessageId = crypto.randomBytes(7).toString('hex'); + + const history: BaseMessage[] = body.history.map((msg) => { + if (msg[0] === 'human') { + return new HumanMessage({ + content: msg[1], + }); + } else { + return new AIMessage({ + content: msg[1], + }); + } + }); + + const handler = searchHandlers[body.focusMode]; + + if (!handler) { + return Response.json( + { + message: 'Invalid focus mode', + }, + { status: 400 }, + ); + } + + const stream = await handler.searchAndAnswer( + message.content, + history, + llm, + embedding, + body.optimizationMode, + body.files, + ); + + const responseStream = new TransformStream(); + const writer = responseStream.writable.getWriter(); + const encoder = new TextEncoder(); + + handleEmitterEvents(stream, writer, encoder, aiMessageId, message.chatId); + handleHistorySave(message, humanMessageId, body.focusMode, body.files); + + return new Response(responseStream.readable, { + headers: { + 'Content-Type': 'text/event-stream', + Connection: 'keep-alive', + 'Cache-Control': 'no-cache, no-transform', + }, + }); + } catch (err) { + console.error('An error ocurred while processing chat request:', err); + return Response.json( + { message: 'An error ocurred while processing chat request' }, + { status: 500 }, + ); + } +}; diff --git a/ui/app/api/models/route.ts b/ui/app/api/models/route.ts new file mode 100644 index 0000000..a5e5b43 --- /dev/null +++ b/ui/app/api/models/route.ts @@ -0,0 +1,47 @@ +import { + getAvailableChatModelProviders, + getAvailableEmbeddingModelProviders, +} from '@/lib/providers'; + +export const GET = async (req: Request) => { + try { + const [chatModelProviders, embeddingModelProviders] = await Promise.all([ + getAvailableChatModelProviders(), + getAvailableEmbeddingModelProviders(), + ]); + + Object.keys(chatModelProviders).forEach((provider) => { + Object.keys(chatModelProviders[provider]).forEach((model) => { + delete (chatModelProviders[provider][model] as { model?: unknown }) + .model; + }); + }); + + Object.keys(embeddingModelProviders).forEach((provider) => { + Object.keys(embeddingModelProviders[provider]).forEach((model) => { + delete (embeddingModelProviders[provider][model] as { model?: unknown }) + .model; + }); + }); + + return Response.json( + { + chatModelProviders, + embeddingModelProviders, + }, + { + status: 200, + }, + ); + } catch (err) { + console.error('An error ocurred while fetching models', err); + return Response.json( + { + message: 'An error has occurred.', + }, + { + status: 500, + }, + ); + } +}; diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx index 1940f42..f642525 100644 --- a/ui/components/ChatWindow.tsx +++ b/ui/components/ChatWindow.tsx @@ -29,29 +29,24 @@ export interface File { fileId: string; } -const useSocket = ( - url: string, - setIsWSReady: (ready: boolean) => void, - setError: (error: boolean) => void, +interface ChatModelProvider { + name: string; + provider: string; +} + +interface EmbeddingModelProvider { + name: string; + provider: string; +} + +const checkConfig = async ( + setChatModelProvider: (provider: ChatModelProvider) => void, + setEmbeddingModelProvider: (provider: EmbeddingModelProvider) => void, + setIsConfigReady: (ready: boolean) => void, + setHasError: (hasError: boolean) => void, ) => { - const wsRef = useRef(null); - const reconnectTimeoutRef = useRef(); - const retryCountRef = useRef(0); - const isCleaningUpRef = useRef(false); - const MAX_RETRIES = 3; - const INITIAL_BACKOFF = 1000; // 1 second - const isConnectionErrorRef = useRef(false); - - const getBackoffDelay = (retryCount: number) => { - return Math.min(INITIAL_BACKOFF * Math.pow(2, retryCount), 10000); // Cap at 10 seconds - }; - useEffect(() => { - const connectWs = async () => { - if (wsRef.current?.readyState === WebSocket.OPEN) { - wsRef.current.close(); - } - + const checkConfig = async () => { try { let chatModel = localStorage.getItem('chatModel'); let chatModelProvider = localStorage.getItem('chatModelProvider'); @@ -71,14 +66,11 @@ const useSocket = ( localStorage.setItem('autoVideoSearch', 'false'); } - const providers = await fetch( - `${process.env.NEXT_PUBLIC_API_URL}/models`, - { - headers: { - 'Content-Type': 'application/json', - }, + const providers = await fetch(`/api/models`, { + headers: { + 'Content-Type': 'application/json', }, - ).then(async (res) => { + }).then(async (res) => { if (!res.ok) throw new Error( `Failed to fetch models: ${res.status} ${res.statusText}`, @@ -182,127 +174,30 @@ const useSocket = ( } } - const wsURL = new URL(url); - const searchParams = new URLSearchParams({}); - - searchParams.append('chatModel', chatModel!); - searchParams.append('chatModelProvider', chatModelProvider); - - if (chatModelProvider === 'custom_openai') { - searchParams.append( - 'openAIApiKey', - localStorage.getItem('openAIApiKey')!, - ); - searchParams.append( - 'openAIBaseURL', - localStorage.getItem('openAIBaseURL')!, - ); - } - - searchParams.append('embeddingModel', embeddingModel!); - searchParams.append('embeddingModelProvider', embeddingModelProvider); - - wsURL.search = searchParams.toString(); - - const ws = new WebSocket(wsURL.toString()); - wsRef.current = ws; - - const timeoutId = setTimeout(() => { - if (ws.readyState !== 1) { - toast.error( - 'Failed to connect to the server. Please try again later.', - ); - } - }, 10000); - - ws.addEventListener('message', (e) => { - const data = JSON.parse(e.data); - if (data.type === 'signal' && data.data === 'open') { - const interval = setInterval(() => { - if (ws.readyState === 1) { - setIsWSReady(true); - setError(false); - if (retryCountRef.current > 0) { - toast.success('Connection restored.'); - } - retryCountRef.current = 0; - clearInterval(interval); - } - }, 5); - clearTimeout(timeoutId); - console.debug(new Date(), 'ws:connected'); - } - if (data.type === 'error') { - isConnectionErrorRef.current = true; - setError(true); - toast.error(data.data); - } + setChatModelProvider({ + name: chatModel!, + provider: chatModelProvider, }); - ws.onerror = () => { - clearTimeout(timeoutId); - setIsWSReady(false); - toast.error('WebSocket connection error.'); - }; + setEmbeddingModelProvider({ + name: embeddingModel!, + provider: embeddingModelProvider, + }); - ws.onclose = () => { - clearTimeout(timeoutId); - setIsWSReady(false); - console.debug(new Date(), 'ws:disconnected'); - if (!isCleaningUpRef.current && !isConnectionErrorRef.current) { - toast.error('Connection lost. Attempting to reconnect...'); - attemptReconnect(); - } - }; - } catch (error) { - console.debug(new Date(), 'ws:error', error); - setIsWSReady(false); - attemptReconnect(); - } - }; - - const attemptReconnect = () => { - retryCountRef.current += 1; - - if (retryCountRef.current > MAX_RETRIES) { - console.debug(new Date(), 'ws:max_retries'); - setError(true); - toast.error( - 'Unable to connect to server after multiple attempts. Please refresh the page to try again.', + setIsConfigReady(true); + } catch (err) { + console.error( + 'An error occurred while checking the configuration:', + err, ); - return; - } - - const backoffDelay = getBackoffDelay(retryCountRef.current); - console.debug( - new Date(), - `ws:retry attempt=${retryCountRef.current}/${MAX_RETRIES} delay=${backoffDelay}ms`, - ); - - if (reconnectTimeoutRef.current) { - clearTimeout(reconnectTimeoutRef.current); - } - - reconnectTimeoutRef.current = setTimeout(() => { - connectWs(); - }, backoffDelay); - }; - - connectWs(); - - return () => { - if (reconnectTimeoutRef.current) { - clearTimeout(reconnectTimeoutRef.current); - } - if (wsRef.current?.readyState === WebSocket.OPEN) { - wsRef.current.close(); - isCleaningUpRef.current = true; - console.debug(new Date(), 'ws:cleanup'); + setIsConfigReady(false); + setHasError(true); } }; - }, [url, setIsWSReady, setError]); - return wsRef.current; + checkConfig(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); }; const loadMessages = async ( @@ -315,15 +210,12 @@ const loadMessages = async ( setFiles: (files: File[]) => void, setFileIds: (fileIds: string[]) => void, ) => { - const res = await fetch( - `${process.env.NEXT_PUBLIC_API_URL}/chats/${chatId}`, - { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, + const res = await fetch(`/api/chats/${chatId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', }, - ); + }); if (res.status === 404) { setNotFound(true); @@ -373,13 +265,27 @@ const ChatWindow = ({ id }: { id?: string }) => { const [chatId, setChatId] = useState(id); const [newChatCreated, setNewChatCreated] = useState(false); + const [chatModelProvider, setChatModelProvider] = useState( + { + name: '', + provider: '', + }, + ); + + const [embeddingModelProvider, setEmbeddingModelProvider] = + useState({ + name: '', + provider: '', + }); + + const [isConfigReady, setIsConfigReady] = useState(false); const [hasError, setHasError] = useState(false); const [isReady, setIsReady] = useState(false); - const [isWSReady, setIsWSReady] = useState(false); - const ws = useSocket( - process.env.NEXT_PUBLIC_WS_URL!, - setIsWSReady, + checkConfig( + setChatModelProvider, + setEmbeddingModelProvider, + setIsConfigReady, setHasError, ); @@ -399,8 +305,6 @@ const ChatWindow = ({ id }: { id?: string }) => { const [notFound, setNotFound] = useState(false); - const [isSettingsOpen, setIsSettingsOpen] = useState(false); - useEffect(() => { if ( chatId && @@ -426,16 +330,6 @@ const ChatWindow = ({ id }: { id?: string }) => { // eslint-disable-next-line react-hooks/exhaustive-deps }, []); - useEffect(() => { - return () => { - if (ws?.readyState === 1) { - ws.close(); - console.debug(new Date(), 'ws:cleanup'); - } - }; - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - const messagesRef = useRef([]); useEffect(() => { @@ -443,18 +337,18 @@ const ChatWindow = ({ id }: { id?: string }) => { }, [messages]); useEffect(() => { - if (isMessagesLoaded && isWSReady) { + if (isMessagesLoaded && isConfigReady) { setIsReady(true); console.debug(new Date(), 'app:ready'); } else { setIsReady(false); } - }, [isMessagesLoaded, isWSReady]); + }, [isMessagesLoaded, isConfigReady]); const sendMessage = async (message: string, messageId?: string) => { if (loading) return; - if (!ws || ws.readyState !== WebSocket.OPEN) { - toast.error('Cannot send message while disconnected'); + if (!isConfigReady) { + toast.error('Cannot send message before the configuration is ready'); return; } @@ -467,18 +361,27 @@ const ChatWindow = ({ id }: { id?: string }) => { messageId = messageId ?? crypto.randomBytes(7).toString('hex'); - ws.send( + console.log( JSON.stringify({ - type: 'message', + content: message, message: { messageId: messageId, chatId: chatId!, content: message, }, + chatId: chatId!, files: fileIds, focusMode: focusMode, optimizationMode: optimizationMode, - history: [...chatHistory, ['human', message]], + history: chatHistory, + chatModel: { + name: chatModelProvider.name, + provider: chatModelProvider.provider, + }, + embeddingModel: { + name: embeddingModelProvider.name, + provider: embeddingModelProvider.provider, + }, }), ); @@ -493,9 +396,7 @@ const ChatWindow = ({ id }: { id?: string }) => { }, ]); - const messageHandler = async (e: MessageEvent) => { - const data = JSON.parse(e.data); - + const messageHandler = async (data: any) => { if (data.type === 'error') { toast.error(data.data); setLoading(false); @@ -558,7 +459,6 @@ const ChatWindow = ({ id }: { id?: string }) => { ['assistant', recievedMessage], ]); - ws?.removeEventListener('message', messageHandler); setLoading(false); const lastMsg = messagesRef.current[messagesRef.current.length - 1]; @@ -584,16 +484,72 @@ const ChatWindow = ({ id }: { id?: string }) => { const autoVideoSearch = localStorage.getItem('autoVideoSearch'); if (autoImageSearch === 'true') { - document.getElementById('search-images')?.click(); + document + .getElementById(`search-images-${lastMsg.messageId}`) + ?.click(); } if (autoVideoSearch === 'true') { - document.getElementById('search-videos')?.click(); + document + .getElementById(`search-videos-${lastMsg.messageId}`) + ?.click(); } } }; - ws?.addEventListener('message', messageHandler); + const res = await fetch('/api/chat', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + content: message, + message: { + messageId: messageId, + chatId: chatId!, + content: message, + }, + chatId: chatId!, + files: fileIds, + focusMode: focusMode, + optimizationMode: optimizationMode, + history: chatHistory, + chatModel: { + name: chatModelProvider.name, + provider: chatModelProvider.provider, + }, + embeddingModel: { + name: embeddingModelProvider.name, + provider: embeddingModelProvider.provider, + }, + }), + }); + + if (!res.body) throw new Error('No response body'); + + const reader = res.body?.getReader(); + const decoder = new TextDecoder('utf-8'); + + let partialChunk = ''; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + partialChunk += decoder.decode(value, { stream: true }); + + try { + const messages = partialChunk.split('\n'); + for (const msg of messages) { + if (!msg.trim()) continue; + const json = JSON.parse(msg); + messageHandler(json); + } + partialChunk = ''; + } catch (error) { + console.warn('Incomplete JSON, waiting for next chunk...'); + } + } }; const rewrite = (messageId: string) => { @@ -614,11 +570,11 @@ const ChatWindow = ({ id }: { id?: string }) => { }; useEffect(() => { - if (isReady && initialMessage && ws?.readyState === 1) { + if (isReady && initialMessage && isConfigReady) { sendMessage(initialMessage); } // eslint-disable-next-line react-hooks/exhaustive-deps - }, [ws?.readyState, isReady, initialMessage, isWSReady]); + }, [isConfigReady, isReady, initialMessage]); if (hasError) { return ( diff --git a/ui/lib/providers/index.ts b/ui/lib/providers/index.ts index caa8074..e45c09d 100644 --- a/ui/lib/providers/index.ts +++ b/ui/lib/providers/index.ts @@ -22,7 +22,7 @@ export interface EmbeddingModel { model: Embeddings; } -const chatModelProviders: Record< +export const chatModelProviders: Record< string, () => Promise> > = { @@ -30,16 +30,16 @@ const chatModelProviders: Record< ollama: loadOllamaChatModels, groq: loadGroqChatModels, anthropic: loadAnthropicChatModels, - gemini: loadGeminiChatModels + gemini: loadGeminiChatModels, }; -const embeddingModelProviders: Record< +export const embeddingModelProviders: Record< string, () => Promise> > = { openai: loadOpenAIEmbeddingModels, ollama: loadOllamaEmbeddingModels, - gemini: loadGeminiEmbeddingModels + gemini: loadGeminiEmbeddingModels, }; export const getAvailableChatModelProviders = async () => {