diff --git a/src/lib/agents/search/index.ts b/src/lib/agents/search/index.ts index 26fc13d..1ecfe51 100644 --- a/src/lib/agents/search/index.ts +++ b/src/lib/agents/search/index.ts @@ -4,9 +4,53 @@ import { classify } from './classifier'; import Researcher from './researcher'; import { getWriterPrompt } from '@/lib/prompts/search/writer'; import { WidgetExecutor } from './widgets'; +import db from '@/lib/db'; +import { chats, messages } from '@/lib/db/schema'; +import { and, eq, gt } from 'drizzle-orm'; +import { TextBlock } from '@/lib/types'; class SearchAgent { async searchAsync(session: SessionManager, input: SearchAgentInput) { + const exists = await db.query.messages.findFirst({ + where: and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), + }); + + if (!exists) { + await db.insert(messages).values({ + chatId: input.chatId, + messageId: input.messageId, + backendId: session.id, + query: input.followUp, + createdAt: new Date().toISOString(), + status: 'answering', + responseBlocks: [], + }); + } else { + await db + .delete(messages) + .where( + and(eq(messages.chatId, input.chatId), gt(messages.id, exists.id)), + ) + .execute(); + await db + .update(messages) + .set({ + status: 'answering', + backendId: session.id, + responseBlocks: [], + }) + .where( + and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), + ) + .execute(); + } + const classification = await classify({ chatHistory: input.chatHistory, enabledSources: input.config.sources, @@ -85,18 +129,41 @@ class SearchAgent { ], }); - let accumulatedText = ''; + const block: TextBlock = { + id: crypto.randomUUID(), + type: 'text', + data: '', + }; + + session.emitBlock(block); for await (const chunk of answerStream) { - accumulatedText += chunk.contentChunk; + block.data += chunk.contentChunk; - session.emit('data', { - type: 'response', - data: chunk.contentChunk, - }); + session.updateBlock(block.id, [ + { + op: 'replace', + path: '/data', + value: block.data, + }, + ]); } session.emit('end', {}); + + await db + .update(messages) + .set({ + status: 'completed', + responseBlocks: session.getAllBlocks(), + }) + .where( + and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), + ) + .execute(); } }