feat(app): add chat functionality

This commit is contained in:
ItzCrazyKns
2025-03-19 13:41:52 +05:30
parent 3150c21f17
commit c24edac16d
4 changed files with 535 additions and 186 deletions

346
ui/app/api/chat/route.ts Normal file
View File

@ -0,0 +1,346 @@
import prompts from '@/lib/prompts';
import MetaSearchAgent from '@/lib/search/metaSearchAgent';
import crypto from 'crypto';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { EventEmitter } from 'stream';
import { chatModelProviders, embeddingModelProviders } from '@/lib/providers';
import db from '@/lib/db';
import { chats, messages as messagesSchema } from '@/lib/db/schema';
import { and, eq, gt } from 'drizzle-orm';
import { getFileDetails } from '@/lib/utils/files';
export const runtime = 'nodejs';
export const dynamic = 'force-dynamic';
export const searchHandlers: Record<string, MetaSearchAgent> = {
webSearch: new MetaSearchAgent({
activeEngines: [],
queryGeneratorPrompt: prompts.webSearchRetrieverPrompt,
responsePrompt: prompts.webSearchResponsePrompt,
rerank: true,
rerankThreshold: 0.3,
searchWeb: true,
summarizer: true,
}),
academicSearch: new MetaSearchAgent({
activeEngines: ['arxiv', 'google scholar', 'pubmed'],
queryGeneratorPrompt: prompts.academicSearchRetrieverPrompt,
responsePrompt: prompts.academicSearchResponsePrompt,
rerank: true,
rerankThreshold: 0,
searchWeb: true,
summarizer: false,
}),
writingAssistant: new MetaSearchAgent({
activeEngines: [],
queryGeneratorPrompt: '',
responsePrompt: prompts.writingAssistantPrompt,
rerank: true,
rerankThreshold: 0,
searchWeb: false,
summarizer: false,
}),
wolframAlphaSearch: new MetaSearchAgent({
activeEngines: ['wolframalpha'],
queryGeneratorPrompt: prompts.wolframAlphaSearchRetrieverPrompt,
responsePrompt: prompts.wolframAlphaSearchResponsePrompt,
rerank: false,
rerankThreshold: 0,
searchWeb: true,
summarizer: false,
}),
youtubeSearch: new MetaSearchAgent({
activeEngines: ['youtube'],
queryGeneratorPrompt: prompts.youtubeSearchRetrieverPrompt,
responsePrompt: prompts.youtubeSearchResponsePrompt,
rerank: true,
rerankThreshold: 0.3,
searchWeb: true,
summarizer: false,
}),
redditSearch: new MetaSearchAgent({
activeEngines: ['reddit'],
queryGeneratorPrompt: prompts.redditSearchRetrieverPrompt,
responsePrompt: prompts.redditSearchResponsePrompt,
rerank: true,
rerankThreshold: 0.3,
searchWeb: true,
summarizer: false,
}),
};
type Message = {
messageId: string;
chatId: string;
content: string;
};
type ChatModel = {
provider: string;
name: string;
};
type EmbeddingModel = {
provider: string;
name: string;
};
type Body = {
message: Message;
optimizationMode: 'speed' | 'balanced' | 'quality';
focusMode: string;
history: Array<[string, string]>;
files: Array<string>;
chatModel: ChatModel;
embeddingModel: EmbeddingModel;
};
const handleEmitterEvents = async (
stream: EventEmitter,
writer: WritableStreamDefaultWriter,
encoder: TextEncoder,
aiMessageId: string,
chatId: string,
) => {
let recievedMessage = '';
let sources: any[] = [];
stream.on('data', (data) => {
const parsedData = JSON.parse(data);
if (parsedData.type === 'response') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'message',
data: parsedData.data,
messageId: aiMessageId,
}) + '\n',
),
);
recievedMessage += parsedData.data;
} else if (parsedData.type === 'sources') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'sources',
data: parsedData.data,
messageId: aiMessageId,
}) + '\n',
),
);
sources = parsedData.data;
}
});
stream.on('end', () => {
writer.write(
encoder.encode(
JSON.stringify({
type: 'messageEnd',
messageId: aiMessageId,
}) + '\n',
),
);
writer.close();
db.insert(messagesSchema)
.values({
content: recievedMessage,
chatId: chatId,
messageId: aiMessageId,
role: 'assistant',
metadata: JSON.stringify({
createdAt: new Date(),
...(sources && sources.length > 0 && { sources }),
}),
})
.execute();
});
stream.on('error', (data) => {
const parsedData = JSON.parse(data);
writer.write(
encoder.encode(
JSON.stringify({
type: 'error',
data: parsedData.data,
}),
),
);
writer.close();
});
};
const handleHistorySave = async (
message: Message,
humanMessageId: string,
focusMode: string,
files: string[],
) => {
const chat = await db.query.chats.findFirst({
where: eq(chats.id, message.chatId),
});
if (!chat) {
await db
.insert(chats)
.values({
id: message.chatId,
title: message.content,
createdAt: new Date().toString(),
focusMode: focusMode,
files: files.map(getFileDetails),
})
.execute();
}
const messageExists = await db.query.messages.findFirst({
where: eq(messagesSchema.messageId, humanMessageId),
});
if (!messageExists) {
await db
.insert(messagesSchema)
.values({
content: message.content,
chatId: message.chatId,
messageId: humanMessageId,
role: 'user',
metadata: JSON.stringify({
createdAt: new Date(),
}),
})
.execute();
} else {
await db
.delete(messagesSchema)
.where(
and(
gt(messagesSchema.id, messageExists.id),
eq(messagesSchema.chatId, message.chatId),
),
)
.execute();
}
};
export const POST = async (req: Request) => {
try {
const body = (await req.json()) as Body;
const { message, chatModel, embeddingModel } = body;
if (message.content === '') {
return Response.json(
{
message: 'Please provide a message to process',
},
{ status: 400 },
);
}
const getProviderChatModels = chatModelProviders[chatModel.provider];
if (!getProviderChatModels) {
return Response.json(
{
message: 'Invalid chat model provider',
},
{ status: 400 },
);
}
const chatModels = await getProviderChatModels();
const llm = chatModels[chatModel.name].model;
if (!llm) {
return Response.json(
{
message: 'Invalid chat model',
},
{ status: 400 },
);
}
const getProviderEmbeddingModels =
embeddingModelProviders[embeddingModel.provider];
if (!getProviderEmbeddingModels) {
return Response.json(
{
message: 'Invalid embedding model provider',
},
{ status: 400 },
);
}
const embeddingModels = await getProviderEmbeddingModels();
const embedding = embeddingModels[embeddingModel.name].model;
if (!embedding) {
return Response.json(
{
message: 'Invalid embedding model',
},
{ status: 400 },
);
}
const humanMessageId =
message.messageId ?? crypto.randomBytes(7).toString('hex');
const aiMessageId = crypto.randomBytes(7).toString('hex');
const history: BaseMessage[] = body.history.map((msg) => {
if (msg[0] === 'human') {
return new HumanMessage({
content: msg[1],
});
} else {
return new AIMessage({
content: msg[1],
});
}
});
const handler = searchHandlers[body.focusMode];
if (!handler) {
return Response.json(
{
message: 'Invalid focus mode',
},
{ status: 400 },
);
}
const stream = await handler.searchAndAnswer(
message.content,
history,
llm,
embedding,
body.optimizationMode,
body.files,
);
const responseStream = new TransformStream();
const writer = responseStream.writable.getWriter();
const encoder = new TextEncoder();
handleEmitterEvents(stream, writer, encoder, aiMessageId, message.chatId);
handleHistorySave(message, humanMessageId, body.focusMode, body.files);
return new Response(responseStream.readable, {
headers: {
'Content-Type': 'text/event-stream',
Connection: 'keep-alive',
'Cache-Control': 'no-cache, no-transform',
},
});
} catch (err) {
console.error('An error ocurred while processing chat request:', err);
return Response.json(
{ message: 'An error ocurred while processing chat request' },
{ status: 500 },
);
}
};

View File

@ -0,0 +1,47 @@
import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
} from '@/lib/providers';
export const GET = async (req: Request) => {
try {
const [chatModelProviders, embeddingModelProviders] = await Promise.all([
getAvailableChatModelProviders(),
getAvailableEmbeddingModelProviders(),
]);
Object.keys(chatModelProviders).forEach((provider) => {
Object.keys(chatModelProviders[provider]).forEach((model) => {
delete (chatModelProviders[provider][model] as { model?: unknown })
.model;
});
});
Object.keys(embeddingModelProviders).forEach((provider) => {
Object.keys(embeddingModelProviders[provider]).forEach((model) => {
delete (embeddingModelProviders[provider][model] as { model?: unknown })
.model;
});
});
return Response.json(
{
chatModelProviders,
embeddingModelProviders,
},
{
status: 200,
},
);
} catch (err) {
console.error('An error ocurred while fetching models', err);
return Response.json(
{
message: 'An error has occurred.',
},
{
status: 500,
},
);
}
};

View File

@ -29,29 +29,24 @@ export interface File {
fileId: string; fileId: string;
} }
const useSocket = ( interface ChatModelProvider {
url: string, name: string;
setIsWSReady: (ready: boolean) => void, provider: string;
setError: (error: boolean) => void,
) => {
const wsRef = useRef<WebSocket | null>(null);
const reconnectTimeoutRef = useRef<NodeJS.Timeout>();
const retryCountRef = useRef(0);
const isCleaningUpRef = useRef(false);
const MAX_RETRIES = 3;
const INITIAL_BACKOFF = 1000; // 1 second
const isConnectionErrorRef = useRef(false);
const getBackoffDelay = (retryCount: number) => {
return Math.min(INITIAL_BACKOFF * Math.pow(2, retryCount), 10000); // Cap at 10 seconds
};
useEffect(() => {
const connectWs = async () => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.close();
} }
interface EmbeddingModelProvider {
name: string;
provider: string;
}
const checkConfig = async (
setChatModelProvider: (provider: ChatModelProvider) => void,
setEmbeddingModelProvider: (provider: EmbeddingModelProvider) => void,
setIsConfigReady: (ready: boolean) => void,
setHasError: (hasError: boolean) => void,
) => {
useEffect(() => {
const checkConfig = async () => {
try { try {
let chatModel = localStorage.getItem('chatModel'); let chatModel = localStorage.getItem('chatModel');
let chatModelProvider = localStorage.getItem('chatModelProvider'); let chatModelProvider = localStorage.getItem('chatModelProvider');
@ -71,14 +66,11 @@ const useSocket = (
localStorage.setItem('autoVideoSearch', 'false'); localStorage.setItem('autoVideoSearch', 'false');
} }
const providers = await fetch( const providers = await fetch(`/api/models`, {
`${process.env.NEXT_PUBLIC_API_URL}/models`,
{
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
}, }).then(async (res) => {
).then(async (res) => {
if (!res.ok) if (!res.ok)
throw new Error( throw new Error(
`Failed to fetch models: ${res.status} ${res.statusText}`, `Failed to fetch models: ${res.status} ${res.statusText}`,
@ -182,127 +174,30 @@ const useSocket = (
} }
} }
const wsURL = new URL(url); setChatModelProvider({
const searchParams = new URLSearchParams({}); name: chatModel!,
provider: chatModelProvider,
searchParams.append('chatModel', chatModel!);
searchParams.append('chatModelProvider', chatModelProvider);
if (chatModelProvider === 'custom_openai') {
searchParams.append(
'openAIApiKey',
localStorage.getItem('openAIApiKey')!,
);
searchParams.append(
'openAIBaseURL',
localStorage.getItem('openAIBaseURL')!,
);
}
searchParams.append('embeddingModel', embeddingModel!);
searchParams.append('embeddingModelProvider', embeddingModelProvider);
wsURL.search = searchParams.toString();
const ws = new WebSocket(wsURL.toString());
wsRef.current = ws;
const timeoutId = setTimeout(() => {
if (ws.readyState !== 1) {
toast.error(
'Failed to connect to the server. Please try again later.',
);
}
}, 10000);
ws.addEventListener('message', (e) => {
const data = JSON.parse(e.data);
if (data.type === 'signal' && data.data === 'open') {
const interval = setInterval(() => {
if (ws.readyState === 1) {
setIsWSReady(true);
setError(false);
if (retryCountRef.current > 0) {
toast.success('Connection restored.');
}
retryCountRef.current = 0;
clearInterval(interval);
}
}, 5);
clearTimeout(timeoutId);
console.debug(new Date(), 'ws:connected');
}
if (data.type === 'error') {
isConnectionErrorRef.current = true;
setError(true);
toast.error(data.data);
}
}); });
ws.onerror = () => { setEmbeddingModelProvider({
clearTimeout(timeoutId); name: embeddingModel!,
setIsWSReady(false); provider: embeddingModelProvider,
toast.error('WebSocket connection error.'); });
};
ws.onclose = () => { setIsConfigReady(true);
clearTimeout(timeoutId); } catch (err) {
setIsWSReady(false); console.error(
console.debug(new Date(), 'ws:disconnected'); 'An error occurred while checking the configuration:',
if (!isCleaningUpRef.current && !isConnectionErrorRef.current) { err,
toast.error('Connection lost. Attempting to reconnect...');
attemptReconnect();
}
};
} catch (error) {
console.debug(new Date(), 'ws:error', error);
setIsWSReady(false);
attemptReconnect();
}
};
const attemptReconnect = () => {
retryCountRef.current += 1;
if (retryCountRef.current > MAX_RETRIES) {
console.debug(new Date(), 'ws:max_retries');
setError(true);
toast.error(
'Unable to connect to server after multiple attempts. Please refresh the page to try again.',
); );
return; setIsConfigReady(false);
} setHasError(true);
const backoffDelay = getBackoffDelay(retryCountRef.current);
console.debug(
new Date(),
`ws:retry attempt=${retryCountRef.current}/${MAX_RETRIES} delay=${backoffDelay}ms`,
);
if (reconnectTimeoutRef.current) {
clearTimeout(reconnectTimeoutRef.current);
}
reconnectTimeoutRef.current = setTimeout(() => {
connectWs();
}, backoffDelay);
};
connectWs();
return () => {
if (reconnectTimeoutRef.current) {
clearTimeout(reconnectTimeoutRef.current);
}
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.close();
isCleaningUpRef.current = true;
console.debug(new Date(), 'ws:cleanup');
} }
}; };
}, [url, setIsWSReady, setError]);
return wsRef.current; checkConfig();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
}; };
const loadMessages = async ( const loadMessages = async (
@ -315,15 +210,12 @@ const loadMessages = async (
setFiles: (files: File[]) => void, setFiles: (files: File[]) => void,
setFileIds: (fileIds: string[]) => void, setFileIds: (fileIds: string[]) => void,
) => { ) => {
const res = await fetch( const res = await fetch(`/api/chats/${chatId}`, {
`${process.env.NEXT_PUBLIC_API_URL}/chats/${chatId}`,
{
method: 'GET', method: 'GET',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
}, });
);
if (res.status === 404) { if (res.status === 404) {
setNotFound(true); setNotFound(true);
@ -373,13 +265,27 @@ const ChatWindow = ({ id }: { id?: string }) => {
const [chatId, setChatId] = useState<string | undefined>(id); const [chatId, setChatId] = useState<string | undefined>(id);
const [newChatCreated, setNewChatCreated] = useState(false); const [newChatCreated, setNewChatCreated] = useState(false);
const [chatModelProvider, setChatModelProvider] = useState<ChatModelProvider>(
{
name: '',
provider: '',
},
);
const [embeddingModelProvider, setEmbeddingModelProvider] =
useState<EmbeddingModelProvider>({
name: '',
provider: '',
});
const [isConfigReady, setIsConfigReady] = useState(false);
const [hasError, setHasError] = useState(false); const [hasError, setHasError] = useState(false);
const [isReady, setIsReady] = useState(false); const [isReady, setIsReady] = useState(false);
const [isWSReady, setIsWSReady] = useState(false); checkConfig(
const ws = useSocket( setChatModelProvider,
process.env.NEXT_PUBLIC_WS_URL!, setEmbeddingModelProvider,
setIsWSReady, setIsConfigReady,
setHasError, setHasError,
); );
@ -399,8 +305,6 @@ const ChatWindow = ({ id }: { id?: string }) => {
const [notFound, setNotFound] = useState(false); const [notFound, setNotFound] = useState(false);
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
useEffect(() => { useEffect(() => {
if ( if (
chatId && chatId &&
@ -426,16 +330,6 @@ const ChatWindow = ({ id }: { id?: string }) => {
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []); }, []);
useEffect(() => {
return () => {
if (ws?.readyState === 1) {
ws.close();
console.debug(new Date(), 'ws:cleanup');
}
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const messagesRef = useRef<Message[]>([]); const messagesRef = useRef<Message[]>([]);
useEffect(() => { useEffect(() => {
@ -443,18 +337,18 @@ const ChatWindow = ({ id }: { id?: string }) => {
}, [messages]); }, [messages]);
useEffect(() => { useEffect(() => {
if (isMessagesLoaded && isWSReady) { if (isMessagesLoaded && isConfigReady) {
setIsReady(true); setIsReady(true);
console.debug(new Date(), 'app:ready'); console.debug(new Date(), 'app:ready');
} else { } else {
setIsReady(false); setIsReady(false);
} }
}, [isMessagesLoaded, isWSReady]); }, [isMessagesLoaded, isConfigReady]);
const sendMessage = async (message: string, messageId?: string) => { const sendMessage = async (message: string, messageId?: string) => {
if (loading) return; if (loading) return;
if (!ws || ws.readyState !== WebSocket.OPEN) { if (!isConfigReady) {
toast.error('Cannot send message while disconnected'); toast.error('Cannot send message before the configuration is ready');
return; return;
} }
@ -467,18 +361,27 @@ const ChatWindow = ({ id }: { id?: string }) => {
messageId = messageId ?? crypto.randomBytes(7).toString('hex'); messageId = messageId ?? crypto.randomBytes(7).toString('hex');
ws.send( console.log(
JSON.stringify({ JSON.stringify({
type: 'message', content: message,
message: { message: {
messageId: messageId, messageId: messageId,
chatId: chatId!, chatId: chatId!,
content: message, content: message,
}, },
chatId: chatId!,
files: fileIds, files: fileIds,
focusMode: focusMode, focusMode: focusMode,
optimizationMode: optimizationMode, optimizationMode: optimizationMode,
history: [...chatHistory, ['human', message]], history: chatHistory,
chatModel: {
name: chatModelProvider.name,
provider: chatModelProvider.provider,
},
embeddingModel: {
name: embeddingModelProvider.name,
provider: embeddingModelProvider.provider,
},
}), }),
); );
@ -493,9 +396,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
}, },
]); ]);
const messageHandler = async (e: MessageEvent) => { const messageHandler = async (data: any) => {
const data = JSON.parse(e.data);
if (data.type === 'error') { if (data.type === 'error') {
toast.error(data.data); toast.error(data.data);
setLoading(false); setLoading(false);
@ -558,7 +459,6 @@ const ChatWindow = ({ id }: { id?: string }) => {
['assistant', recievedMessage], ['assistant', recievedMessage],
]); ]);
ws?.removeEventListener('message', messageHandler);
setLoading(false); setLoading(false);
const lastMsg = messagesRef.current[messagesRef.current.length - 1]; const lastMsg = messagesRef.current[messagesRef.current.length - 1];
@ -584,16 +484,72 @@ const ChatWindow = ({ id }: { id?: string }) => {
const autoVideoSearch = localStorage.getItem('autoVideoSearch'); const autoVideoSearch = localStorage.getItem('autoVideoSearch');
if (autoImageSearch === 'true') { if (autoImageSearch === 'true') {
document.getElementById('search-images')?.click(); document
.getElementById(`search-images-${lastMsg.messageId}`)
?.click();
} }
if (autoVideoSearch === 'true') { if (autoVideoSearch === 'true') {
document.getElementById('search-videos')?.click(); document
.getElementById(`search-videos-${lastMsg.messageId}`)
?.click();
} }
} }
}; };
ws?.addEventListener('message', messageHandler); const res = await fetch('/api/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
content: message,
message: {
messageId: messageId,
chatId: chatId!,
content: message,
},
chatId: chatId!,
files: fileIds,
focusMode: focusMode,
optimizationMode: optimizationMode,
history: chatHistory,
chatModel: {
name: chatModelProvider.name,
provider: chatModelProvider.provider,
},
embeddingModel: {
name: embeddingModelProvider.name,
provider: embeddingModelProvider.provider,
},
}),
});
if (!res.body) throw new Error('No response body');
const reader = res.body?.getReader();
const decoder = new TextDecoder('utf-8');
let partialChunk = '';
while (true) {
const { value, done } = await reader.read();
if (done) break;
partialChunk += decoder.decode(value, { stream: true });
try {
const messages = partialChunk.split('\n');
for (const msg of messages) {
if (!msg.trim()) continue;
const json = JSON.parse(msg);
messageHandler(json);
}
partialChunk = '';
} catch (error) {
console.warn('Incomplete JSON, waiting for next chunk...');
}
}
}; };
const rewrite = (messageId: string) => { const rewrite = (messageId: string) => {
@ -614,11 +570,11 @@ const ChatWindow = ({ id }: { id?: string }) => {
}; };
useEffect(() => { useEffect(() => {
if (isReady && initialMessage && ws?.readyState === 1) { if (isReady && initialMessage && isConfigReady) {
sendMessage(initialMessage); sendMessage(initialMessage);
} }
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [ws?.readyState, isReady, initialMessage, isWSReady]); }, [isConfigReady, isReady, initialMessage]);
if (hasError) { if (hasError) {
return ( return (

View File

@ -22,7 +22,7 @@ export interface EmbeddingModel {
model: Embeddings; model: Embeddings;
} }
const chatModelProviders: Record< export const chatModelProviders: Record<
string, string,
() => Promise<Record<string, ChatModel>> () => Promise<Record<string, ChatModel>>
> = { > = {
@ -30,16 +30,16 @@ const chatModelProviders: Record<
ollama: loadOllamaChatModels, ollama: loadOllamaChatModels,
groq: loadGroqChatModels, groq: loadGroqChatModels,
anthropic: loadAnthropicChatModels, anthropic: loadAnthropicChatModels,
gemini: loadGeminiChatModels gemini: loadGeminiChatModels,
}; };
const embeddingModelProviders: Record< export const embeddingModelProviders: Record<
string, string,
() => Promise<Record<string, EmbeddingModel>> () => Promise<Record<string, EmbeddingModel>>
> = { > = {
openai: loadOpenAIEmbeddingModels, openai: loadOpenAIEmbeddingModels,
ollama: loadOllamaEmbeddingModels, ollama: loadOllamaEmbeddingModels,
gemini: loadGeminiEmbeddingModels gemini: loadGeminiEmbeddingModels,
}; };
export const getAvailableChatModelProviders = async () => { export const getAvailableChatModelProviders = async () => {