Compare commits

...

10 Commits

Author SHA1 Message Date
ItzCrazyKns
c3b74a3fd0 feat(assistant-steps): only open last comp 2025-12-23 17:17:56 +05:30
ItzCrazyKns
5f04034650 feat(chat-hook): handle reconnect 2025-12-23 17:17:19 +05:30
ItzCrazyKns
5847379db0 Update types.ts 2025-12-23 17:15:46 +05:30
ItzCrazyKns
8520ea6fe5 feat(researcher): emit sources as block 2025-12-23 17:15:42 +05:30
ItzCrazyKns
a6d4f47130 feat(search-agent): save history 2025-12-23 17:15:32 +05:30
ItzCrazyKns
f278eb8bf1 feat(routes): add reconnect route 2025-12-23 17:15:02 +05:30
ItzCrazyKns
0e176e0b78 feat(chat-route): add history saving, disconnect on abort, use subscribe method 2025-12-23 17:14:02 +05:30
ItzCrazyKns
8ba64be446 feat(session): fix sessions getting disregarded due to reload 2025-12-23 17:12:56 +05:30
ItzCrazyKns
216332fb20 feat(session): add subscribe method, getAllBlocks 2025-12-23 17:12:15 +05:30
ItzCrazyKns
68a9e048ac feat(schema): change focusMode to sources 2025-12-23 17:11:38 +05:30
10 changed files with 416 additions and 187 deletions

View File

@@ -5,6 +5,10 @@ import SearchAgent from '@/lib/agents/search';
import SessionManager from '@/lib/session';
import { ChatTurnMessage } from '@/lib/types';
import { SearchSources } from '@/lib/agents/search/types';
import db from '@/lib/db';
import { eq } from 'drizzle-orm';
import { chats } from '@/lib/db/schema';
import UploadManager from '@/lib/uploads/manager';
export const runtime = 'nodejs';
export const dynamic = 'force-dynamic';
@@ -64,6 +68,38 @@ const safeValidateBody = (data: unknown) => {
};
};
const ensureChatExists = async (input: {
id: string;
sources: SearchSources[];
query: string;
fileIds: string[];
}) => {
try {
const exists = await db.query.chats
.findFirst({
where: eq(chats.id, input.id),
})
.execute();
if (!exists) {
await db.insert(chats).values({
id: input.id,
createdAt: new Date().toISOString(),
sources: input.sources,
title: input.query,
files: input.fileIds.map((id) => {
return {
fileId: id,
name: UploadManager.getFile(id)?.name || 'Uploaded File',
};
}),
});
}
} catch (err) {
console.error('Failed to check/save chat:', err);
}
};
export const POST = async (req: Request) => {
try {
const reqBody = (await req.json()) as Body;
@@ -120,29 +156,9 @@ export const POST = async (req: Request) => {
const writer = responseStream.writable.getWriter();
const encoder = new TextEncoder();
let receivedMessage = '';
session.addListener('data', (data: any) => {
if (data.type === 'response') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'message',
data: data.data,
}) + '\n',
),
);
receivedMessage += data.data;
} else if (data.type === 'sources') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'sources',
data: data.data,
}) + '\n',
),
);
} else if (data.type === 'block') {
const disconnect = session.subscribe((event: string, data: any) => {
if (event === 'data') {
if (data.type === 'block') {
writer.write(
encoder.encode(
JSON.stringify({
@@ -170,9 +186,7 @@ export const POST = async (req: Request) => {
),
);
}
});
session.addListener('end', () => {
} else if (event === 'end') {
writer.write(
encoder.encode(
JSON.stringify({
@@ -182,9 +196,7 @@ export const POST = async (req: Request) => {
);
writer.close();
session.removeAllListeners();
});
session.addListener('error', (data: any) => {
} else if (event === 'error') {
writer.write(
encoder.encode(
JSON.stringify({
@@ -195,11 +207,14 @@ export const POST = async (req: Request) => {
);
writer.close();
session.removeAllListeners();
}
});
agent.searchAsync(session, {
chatHistory: history,
followUp: message.content,
chatId: body.message.chatId,
messageId: body.message.messageId,
config: {
llm,
embedding: embedding,
@@ -209,7 +224,17 @@ export const POST = async (req: Request) => {
},
});
/* handleHistorySave(message, humanMessageId, body.focusMode, body.files); */
ensureChatExists({
id: body.message.chatId,
sources: body.sources as SearchSources[],
fileIds: body.files,
query: body.message.content,
});
req.signal.addEventListener('abort', () => {
disconnect();
writer.close();
});
return new Response(responseStream.readable, {
headers: {

View File

@@ -0,0 +1,93 @@
import SessionManager from '@/lib/session';
export const POST = async (
req: Request,
{ params }: { params: Promise<{ id: string }> },
) => {
try {
const { id } = await params;
const session = SessionManager.getSession(id);
if (!session) {
return Response.json({ message: 'Session not found' }, { status: 404 });
}
const responseStream = new TransformStream();
const writer = responseStream.writable.getWriter();
const encoder = new TextEncoder();
const disconnect = session.subscribe((event, data) => {
if (event === 'data') {
if (data.type === 'block') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'block',
block: data.block,
}) + '\n',
),
);
} else if (data.type === 'updateBlock') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'updateBlock',
blockId: data.blockId,
patch: data.patch,
}) + '\n',
),
);
} else if (data.type === 'researchComplete') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'researchComplete',
}) + '\n',
),
);
}
} else if (event === 'end') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'messageEnd',
}) + '\n',
),
);
writer.close();
disconnect();
} else if (event === 'error') {
writer.write(
encoder.encode(
JSON.stringify({
type: 'error',
data: data.data,
}) + '\n',
),
);
writer.close();
disconnect();
}
});
req.signal.addEventListener('abort', () => {
disconnect();
writer.close();
});
return new Response(responseStream.readable, {
headers: {
'Content-Type': 'text/event-stream',
Connection: 'keep-alive',
'Cache-Control': 'no-cache, no-transform',
},
});
} catch (err) {
console.error('Error in reconnecting to session stream: ', err);
return Response.json(
{ message: 'An error has occurred.' },
{ status: 500 },
);
}
};

View File

@@ -54,17 +54,21 @@ const getStepTitle = (
const AssistantSteps = ({
block,
status,
isLast,
}: {
block: ResearchBlock;
status: 'answering' | 'completed' | 'error';
isLast: boolean;
}) => {
const [isExpanded, setIsExpanded] = useState(true);
const [isExpanded, setIsExpanded] = useState(
isLast && status === 'answering' ? true : false,
);
const { researchEnded, loading } = useChat();
useEffect(() => {
if (researchEnded) {
if (researchEnded && isLast) {
setIsExpanded(false);
} else if (status === 'answering') {
} else if (status === 'answering' && isLast) {
setIsExpanded(true);
}
}, [researchEnded, status]);

View File

@@ -131,6 +131,7 @@ const MessageBox = ({
<AssistantSteps
block={researchBlock}
status={section.message.status}
isLast={isLast}
/>
</div>
))}

View File

@@ -4,9 +4,53 @@ import { classify } from './classifier';
import Researcher from './researcher';
import { getWriterPrompt } from '@/lib/prompts/search/writer';
import { WidgetExecutor } from './widgets';
import db from '@/lib/db';
import { chats, messages } from '@/lib/db/schema';
import { and, eq, gt } from 'drizzle-orm';
import { TextBlock } from '@/lib/types';
class SearchAgent {
async searchAsync(session: SessionManager, input: SearchAgentInput) {
const exists = await db.query.messages.findFirst({
where: and(
eq(messages.chatId, input.chatId),
eq(messages.messageId, input.messageId),
),
});
if (!exists) {
await db.insert(messages).values({
chatId: input.chatId,
messageId: input.messageId,
backendId: session.id,
query: input.followUp,
createdAt: new Date().toISOString(),
status: 'answering',
responseBlocks: [],
});
} else {
await db
.delete(messages)
.where(
and(eq(messages.chatId, input.chatId), gt(messages.id, exists.id)),
)
.execute();
await db
.update(messages)
.set({
status: 'answering',
backendId: session.id,
responseBlocks: [],
})
.where(
and(
eq(messages.chatId, input.chatId),
eq(messages.messageId, input.messageId),
),
)
.execute();
}
const classification = await classify({
chatHistory: input.chatHistory,
enabledSources: input.config.sources,
@@ -85,18 +129,41 @@ class SearchAgent {
],
});
let accumulatedText = '';
const block: TextBlock = {
id: crypto.randomUUID(),
type: 'text',
data: '',
};
session.emitBlock(block);
for await (const chunk of answerStream) {
accumulatedText += chunk.contentChunk;
block.data += chunk.contentChunk;
session.emit('data', {
type: 'response',
data: chunk.contentChunk,
});
session.updateBlock(block.id, [
{
op: 'replace',
path: '/data',
value: block.data,
},
]);
}
session.emit('end', {});
await db
.update(messages)
.set({
status: 'completed',
responseBlocks: session.getAllBlocks(),
})
.where(
and(
eq(messages.chatId, input.chatId),
eq(messages.messageId, input.messageId),
),
)
.execute();
}
}

View File

@@ -206,8 +206,9 @@ class Researcher {
})
.filter((r) => r !== undefined);
session.emit('data', {
type: 'sources',
session.emitBlock({
id: crypto.randomUUID(),
type: 'source',
data: filteredSearchResults,
});

View File

@@ -18,6 +18,8 @@ export type SearchAgentInput = {
chatHistory: ChatTurnMessage[];
followUp: string;
config: SearchAgentConfig;
chatId: string;
messageId: string;
};
export type WidgetInput = {

View File

@@ -1,6 +1,7 @@
import { sql } from 'drizzle-orm';
import { text, integer, sqliteTable } from 'drizzle-orm/sqlite-core';
import { Block } from '../types';
import { SearchSources } from '../agents/search/types';
export const messages = sqliteTable('messages', {
id: integer('id').primaryKey(),
@@ -26,7 +27,11 @@ export const chats = sqliteTable('chats', {
id: text('id').primaryKey(),
title: text('title').notNull(),
createdAt: text('createdAt').notNull(),
focusMode: text('focusMode').notNull(),
sources: text('sources', {
mode: 'json',
})
.$type<SearchSources[]>()
.default(sql`'[]'`),
files: text('files', { mode: 'json' })
.$type<DBFile[]>()
.default(sql`'[]'`),

View File

@@ -401,6 +401,50 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
});
}, [messages]);
const checkReconnect = async () => {
if (messages.length > 0) {
const lastMsg = messages[messages.length - 1];
if (lastMsg.status === 'answering') {
setLoading(true);
setResearchEnded(false);
setMessageAppeared(false);
const res = await fetch(`/api/reconnect/${lastMsg.backendId}`, {
method: 'POST',
});
if (!res.body) throw new Error('No response body');
const reader = res.body?.getReader();
const decoder = new TextDecoder('utf-8');
let partialChunk = '';
const messageHandler = getMessageHandler(lastMsg);
while (true) {
const { value, done } = await reader.read();
if (done) break;
partialChunk += decoder.decode(value, { stream: true });
try {
const messages = partialChunk.split('\n');
for (const msg of messages) {
if (!msg.trim()) continue;
const json = JSON.parse(msg);
messageHandler(json);
}
partialChunk = '';
} catch (error) {
console.warn('Incomplete JSON, waiting for next chunk...');
}
}
}
}
};
useEffect(() => {
checkConfig(
setChatModelProvider,
@@ -454,13 +498,22 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
}, [messages]);
useEffect(() => {
if (isMessagesLoaded && isConfigReady) {
if (isMessagesLoaded && isConfigReady && newChatCreated) {
setIsReady(true);
console.debug(new Date(), 'app:ready');
} else if (isMessagesLoaded && isConfigReady && !newChatCreated) {
checkReconnect()
.then(() => {
setIsReady(true);
console.debug(new Date(), 'app:ready');
})
.catch((err) => {
console.error('Error during reconnect:', err);
});
} else {
setIsReady(false);
}
}, [isMessagesLoaded, isConfigReady]);
}, [isMessagesLoaded, isConfigReady, newChatCreated]);
const rewrite = (messageId: string) => {
const index = messages.findIndex((msg) => msg.messageId === messageId);
@@ -488,38 +541,10 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [isConfigReady, isReady, initialMessage]);
const sendMessage: ChatContext['sendMessage'] = async (
message,
messageId,
rewrite = false,
) => {
if (loading || !message) return;
setLoading(true);
setResearchEnded(false);
setMessageAppeared(false);
const getMessageHandler = (message: Message) => {
const messageId = message.messageId;
if (messages.length <= 1) {
window.history.replaceState(null, '', `/c/${chatId}`);
}
messageId = messageId ?? crypto.randomBytes(7).toString('hex');
const backendId = crypto.randomBytes(20).toString('hex');
const newMessage: Message = {
messageId,
chatId: chatId!,
backendId,
query: message,
responseBlocks: [],
status: 'answering',
createdAt: new Date(),
};
setMessages((prevMessages) => [...prevMessages, newMessage]);
const receivedTextRef = { current: '' };
const messageHandler = async (data: any) => {
return async (data: any) => {
if (data.type === 'error') {
toast.error(data.data);
setLoading(false);
@@ -536,7 +561,7 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
if (data.type === 'researchComplete') {
setResearchEnded(true);
if (
newMessage.responseBlocks.find(
message.responseBlocks.find(
(b) => b.type === 'source' && b.data.length > 0,
)
) {
@@ -556,6 +581,13 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
return msg;
}),
);
if (
(data.block.type === 'source' && data.block.data.length > 0) ||
data.block.type === 'text'
) {
setMessageAppeared(true);
}
}
if (data.type === 'updateBlock') {
@@ -577,72 +609,19 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
);
}
if (data.type === 'sources') {
const sourceBlock: Block = {
id: crypto.randomBytes(7).toString('hex'),
type: 'source',
data: data.data,
};
setMessages((prev) =>
prev.map((msg) => {
if (msg.messageId === messageId) {
return {
...msg,
responseBlocks: [...msg.responseBlocks, sourceBlock],
};
}
return msg;
}),
);
if (data.data.length > 0) {
setMessageAppeared(true);
}
}
if (data.type === 'message') {
receivedTextRef.current += data.data;
setMessages((prev) =>
prev.map((msg) => {
if (msg.messageId === messageId) {
const existingTextBlockIndex = msg.responseBlocks.findIndex(
(b) => b.type === 'text',
);
if (existingTextBlockIndex >= 0) {
const updatedBlocks = [...msg.responseBlocks];
const existingBlock = updatedBlocks[
existingTextBlockIndex
] as Block & { type: 'text' };
updatedBlocks[existingTextBlockIndex] = {
...existingBlock,
data: existingBlock.data + data.data,
};
return { ...msg, responseBlocks: updatedBlocks };
} else {
const textBlock: Block = {
id: crypto.randomBytes(7).toString('hex'),
type: 'text',
data: data.data,
};
return {
...msg,
responseBlocks: [...msg.responseBlocks, textBlock],
};
}
}
return msg;
}),
);
setMessageAppeared(true);
}
if (data.type === 'messageEnd') {
const currentMsg = messagesRef.current.find(
(msg) => msg.messageId === messageId,
);
const newHistory: [string, string][] = [
...chatHistory,
['human', message],
['assistant', receivedTextRef.current],
['human', message.query],
[
'assistant',
currentMsg?.responseBlocks.find((b) => b.type === 'text')?.data ||
'',
],
];
setChatHistory(newHistory);
@@ -672,9 +651,6 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
}
// Check if there are sources and no suggestions
const currentMsg = messagesRef.current.find(
(msg) => msg.messageId === messageId,
);
const hasSourceBlocks = currentMsg?.responseBlocks.some(
(block) => block.type === 'source' && block.data.length > 0,
@@ -705,6 +681,36 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
}
}
};
};
const sendMessage: ChatContext['sendMessage'] = async (
message,
messageId,
rewrite = false,
) => {
if (loading || !message) return;
setLoading(true);
setResearchEnded(false);
setMessageAppeared(false);
if (messages.length <= 1) {
window.history.replaceState(null, '', `/c/${chatId}`);
}
messageId = messageId ?? crypto.randomBytes(7).toString('hex');
const backendId = crypto.randomBytes(20).toString('hex');
const newMessage: Message = {
messageId,
chatId: chatId!,
backendId,
query: message,
responseBlocks: [],
status: 'answering',
createdAt: new Date(),
};
setMessages((prevMessages) => [...prevMessages, newMessage]);
const messageIndex = messages.findIndex((m) => m.messageId === messageId);
@@ -746,6 +752,8 @@ export const ChatProvider = ({ children }: { children: React.ReactNode }) => {
let partialChunk = '';
const messageHandler = getMessageHandler(newMessage);
while (true) {
const { value, done } = await reader.read();
if (done) break;

View File

@@ -2,8 +2,14 @@ import { EventEmitter } from 'stream';
import { applyPatch } from 'rfc6902';
import { Block } from './types';
const sessions =
(global as any)._sessionManagerSessions || new Map<string, SessionManager>();
if (process.env.NODE_ENV !== 'production') {
(global as any)._sessionManagerSessions = sessions;
}
class SessionManager {
private static sessions = new Map<string, SessionManager>();
private static sessions: Map<string, SessionManager> = sessions;
readonly id: string;
private blocks = new Map<string, Block>();
private events: { event: string; data: any }[] = [];
@@ -67,15 +73,32 @@ class SessionManager {
}
}
addListener(event: string, listener: (data: any) => void) {
this.emitter.addListener(event, listener);
getAllBlocks() {
return Array.from(this.blocks.values());
}
replay() {
for (const { event, data } of this.events) {
/* Using emitter directly to avoid infinite loop */
this.emitter.emit(event, data);
subscribe(listener: (event: string, data: any) => void): () => void {
const currentEventsLength = this.events.length;
const handler = (event: string) => (data: any) => listener(event, data);
const dataHandler = handler('data');
const endHandler = handler('end');
const errorHandler = handler('error');
this.emitter.on('data', dataHandler);
this.emitter.on('end', endHandler);
this.emitter.on('error', errorHandler);
for (let i = 0; i < currentEventsLength; i++) {
const { event, data } = this.events[i];
listener(event, data);
}
return () => {
this.emitter.off('data', dataHandler);
this.emitter.off('end', endHandler);
this.emitter.off('error', errorHandler);
};
}
}