diff --git a/src/app/api/chat/cancel/route.ts b/src/app/api/chat/cancel/route.ts new file mode 100644 index 0000000..c7614a9 --- /dev/null +++ b/src/app/api/chat/cancel/route.ts @@ -0,0 +1,50 @@ +import { NextRequest } from 'next/server'; + +// In-memory map to store cancel tokens by messageId +const cancelTokens: Record<string, AbortController> = {}; + +// Export for use in chat/route.ts +export function registerCancelToken( + messageId: string, + controller: AbortController, +) { + cancelTokens[messageId] = controller; +} + +export function cleanupCancelToken(messageId: string) { + var cancelled = false; + if (messageId in cancelTokens) { + delete cancelTokens[messageId]; + cancelled = true; + } + return cancelled; +} + +export function cancelRequest(messageId: string) { + const controller = cancelTokens[messageId]; + if (controller) { + try { + controller.abort(); + } catch (e) { + console.error(`Error aborting request for messageId ${messageId}:`, e); + } + return true; + } + return false; +} + +export async function POST(req: NextRequest) { + const { messageId } = await req.json(); + if (!messageId) { + return Response.json({ error: 'Missing messageId' }, { status: 400 }); + } + const cancelled = cancelRequest(messageId); + if (cancelled) { + return Response.json({ success: true }); + } else { + return Response.json( + { error: 'No in-progress request for this messageId' }, + { status: 404 }, + ); + } +} diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index b6f8b7d..c242002 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -18,6 +18,10 @@ import { ChatOpenAI } from '@langchain/openai'; import crypto from 'crypto'; import { and, eq, gt } from 'drizzle-orm'; import { EventEmitter } from 'stream'; +import { + registerCancelToken, + cleanupCancelToken, +} from './cancel/route'; export const runtime = 'nodejs'; export const dynamic = 'force-dynamic'; @@ -62,6 +66,7 @@ const handleEmitterEvents = async ( aiMessageId: string, chatId: string,
startTime: number, + userMessageId: string, ) => { let recievedMessage = ''; let sources: any[] = []; @@ -139,6 +144,9 @@ const handleEmitterEvents = async ( ); writer.close(); + // Clean up the abort controller reference + cleanupCancelToken(userMessageId); + db.insert(messagesSchema) .values({ content: recievedMessage, @@ -329,6 +337,28 @@ export const POST = async (req: Request) => { ); } + const responseStream = new TransformStream(); + const writer = responseStream.writable.getWriter(); + const encoder = new TextEncoder(); + + // --- Cancellation logic --- + const abortController = new AbortController(); + registerCancelToken(message.messageId, abortController); + + abortController.signal.addEventListener('abort', () => { + console.log('Stream aborted, sending cancel event'); + writer.write( + encoder.encode( + JSON.stringify({ + type: 'error', + data: 'Request cancelled by user', + }), + ), + ); + cleanupCancelToken(message.messageId); + }); + + // Pass the abort signal to the search handler const stream = await handler.searchAndAnswer( message.content, history, @@ -337,12 +367,9 @@ export const POST = async (req: Request) => { body.optimizationMode, body.files, body.systemInstructions, + abortController.signal, ); - const responseStream = new TransformStream(); - const writer = responseStream.writable.getWriter(); - const encoder = new TextEncoder(); - handleEmitterEvents( stream, writer, @@ -350,7 +377,9 @@ export const POST = async (req: Request) => { aiMessageId, message.chatId, startTime, + message.messageId, ); + handleHistorySave(message, humanMessageId, body.focusMode, body.files); return new Response(responseStream.readable, { diff --git a/src/app/api/search/route.ts b/src/app/api/search/route.ts index c4f6970..96e9814 100644 --- a/src/app/api/search/route.ts +++ b/src/app/api/search/route.ts @@ -124,6 +124,8 @@ export const POST = async (req: Request) => { if (!searchHandler) { return Response.json({ message: 'Invalid focus mode' }, { status: 400 
}); } + const abortController = new AbortController(); + const { signal } = abortController; const emitter = await searchHandler.searchAndAnswer( body.query, @@ -133,6 +135,7 @@ body.optimizationMode, [], body.systemInstructions || '', + signal, ); if (!body.stream) { @@ -180,9 +183,6 @@ const encoder = new TextEncoder(); - const abortController = new AbortController(); - const { signal } = abortController; - const stream = new ReadableStream({ start(controller) { let sources: any[] = []; diff --git a/src/components/Chat.tsx b/src/components/Chat.tsx index e8f3a05..fd6ac18 100644 --- a/src/components/Chat.tsx +++ b/src/components/Chat.tsx @@ -50,6 +50,9 @@ const Chat = ({ const messageEnd = useRef(null); const containerRef = useRef(null); const SCROLL_THRESHOLD = 250; // pixels from bottom to consider "at bottom" + const [currentMessageId, setCurrentMessageId] = useState<string | undefined>( + undefined, + ); // Check if user is at bottom of page useEffect(() => { @@ -166,6 +169,33 @@ }; }, []); + // Track the last user messageId when loading starts + useEffect(() => { + if (loading) { + // Find the last user message + const lastUserMsg = [...messages] + .reverse() + .find((m) => m.role === 'user'); + setCurrentMessageId(lastUserMsg?.messageId); + } else { + setCurrentMessageId(undefined); + } + }, [loading, messages]); + + // Cancel handler + const handleCancel = async () => { + if (!currentMessageId) return; + try { + await fetch('/api/chat/cancel', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ messageId: currentMessageId }), + }); + } catch (e) { + // Optionally handle error + } + }; + return (
{messages.map((msg, i) => { @@ -234,6 +264,7 @@ const Chat = ({ setOptimizationMode={setOptimizationMode} focusMode={focusMode} setFocusMode={setFocusMode} + onCancel={handleCancel} />
diff --git a/src/components/MessageBox.tsx b/src/components/MessageBox.tsx index 312a4c3..793ff8f 100644 --- a/src/components/MessageBox.tsx +++ b/src/components/MessageBox.tsx @@ -67,6 +67,7 @@ const MessageBox = ({ className="w-full p-3 text-lg bg-light-100 dark:bg-dark-100 rounded-lg border border-light-secondary dark:border-dark-secondary text-black dark:text-white focus:outline-none focus:border-[#24A0ED] transition duration-200 min-h-[120px] font-medium" value={editedContent} onChange={(e) => setEditedContent(e.target.value)} + placeholder="Edit your message..." autoFocus />
diff --git a/src/components/MessageInput.tsx b/src/components/MessageInput.tsx index 2bd0dec..3a1fd26 100644 --- a/src/components/MessageInput.tsx +++ b/src/components/MessageInput.tsx @@ -1,4 +1,4 @@ -import { ArrowRight, ArrowUp } from 'lucide-react'; +import { ArrowRight, ArrowUp, Square } from 'lucide-react'; import { useEffect, useRef, useState } from 'react'; import TextareaAutosize from 'react-textarea-autosize'; import { File } from './ChatWindow'; @@ -19,6 +19,7 @@ const MessageInput = ({ focusMode, setFocusMode, firstMessage, + onCancel, }: { sendMessage: (message: string) => void; loading: boolean; @@ -31,6 +32,7 @@ const MessageInput = ({ focusMode: string; setFocusMode: (mode: string) => void; firstMessage: boolean; + onCancel?: () => void; }) => { const [message, setMessage] = useState(''); const [selectedModel, setSelectedModel] = useState<{ @@ -129,17 +131,33 @@ const MessageInput = ({ optimizationMode={optimizationMode} setOptimizationMode={setOptimizationMode} /> - + {loading ? ( + + ) : ( + + )}
diff --git a/src/lib/search/metaSearchAgent.ts b/src/lib/search/metaSearchAgent.ts index 77adfff..b21a1e4 100644 --- a/src/lib/search/metaSearchAgent.ts +++ b/src/lib/search/metaSearchAgent.ts @@ -1,6 +1,7 @@ -import { ChatOpenAI } from '@langchain/openai'; -import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { BaseMessage } from '@langchain/core/messages'; +import { StringOutputParser } from '@langchain/core/output_parsers'; import { ChatPromptTemplate, MessagesPlaceholder, @@ -11,19 +12,18 @@ import { RunnableMap, RunnableSequence, } from '@langchain/core/runnables'; -import { BaseMessage } from '@langchain/core/messages'; -import { StringOutputParser } from '@langchain/core/output_parsers'; -import LineListOutputParser from '../outputParsers/listLineOutputParser'; -import LineOutputParser from '../outputParsers/lineOutputParser'; -import { getDocumentsFromLinks } from '../utils/documents'; -import { Document } from 'langchain/document'; -import { searchSearxng } from '../searxng'; -import path from 'node:path'; -import fs from 'node:fs'; -import computeSimilarity from '../utils/computeSimilarity'; -import formatChatHistoryAsString from '../utils/formatHistory'; -import eventEmitter from 'events'; import { StreamEvent } from '@langchain/core/tracers/log_stream'; +import { ChatOpenAI } from '@langchain/openai'; +import eventEmitter from 'events'; +import { Document } from 'langchain/document'; +import fs from 'node:fs'; +import path from 'node:path'; +import LineOutputParser from '../outputParsers/lineOutputParser'; +import LineListOutputParser from '../outputParsers/listLineOutputParser'; +import { searchSearxng } from '../searxng'; +import computeSimilarity from '../utils/computeSimilarity'; +import { getDocumentsFromLinks } from '../utils/documents'; +import 
formatChatHistoryAsString from '../utils/formatHistory'; export interface MetaSearchAgentType { searchAndAnswer: ( @@ -34,6 +34,7 @@ optimizationMode: 'speed' | 'balanced' | 'quality', fileIds: string[], systemInstructions: string, + signal: AbortSignal, ) => Promise<eventEmitter>; } @@ -247,6 +248,7 @@ class MetaSearchAgent implements MetaSearchAgentType { embeddings: Embeddings, optimizationMode: 'speed' | 'balanced' | 'quality', systemInstructions: string, + signal: AbortSignal, ) { return RunnableSequence.from([ RunnableMap.from({ @@ -254,43 +256,58 @@ query: (input: BasicChainInput) => input.query, chat_history: (input: BasicChainInput) => input.chat_history, date: () => new Date().toISOString(), - context: RunnableLambda.from(async (input: BasicChainInput) => { - const processedHistory = formatChatHistoryAsString( - input.chat_history, - ); - - let docs: Document[] | null = null; - let query = input.query; - - if (this.config.searchWeb) { - const searchRetrieverChain = - await this.createSearchRetrieverChain(llm); - var date = new Date().toISOString(); - const searchRetrieverResult = await searchRetrieverChain.invoke({ - chat_history: processedHistory, - query, - date, - }); - - query = searchRetrieverResult.query; - docs = searchRetrieverResult.docs; - - // Store the search query in the context for emitting to the client - if (searchRetrieverResult.searchQuery) { - this.searchQuery = searchRetrieverResult.searchQuery; + context: RunnableLambda.from( + async ( + input: BasicChainInput, + options?: { signal?: AbortSignal }, + ) => { + // Check if the request was aborted + if (options?.signal?.aborted || signal?.aborted) { + console.log('Request cancelled by user'); + throw new Error('Request cancelled by user'); } - } - const sortedDocs = await this.rerankDocs( - query, - docs ?? 
[], - fileIds, - embeddings, - optimizationMode, - ); + const processedHistory = formatChatHistoryAsString( + input.chat_history, + ); - return sortedDocs; - }) + let docs: Document[] | null = null; + let query = input.query; + + if (this.config.searchWeb) { + const searchRetrieverChain = + await this.createSearchRetrieverChain(llm); + var date = new Date().toISOString(); + + const searchRetrieverResult = await searchRetrieverChain.invoke( + { + chat_history: processedHistory, + query, + date, + }, + { signal: options?.signal }, + ); + + query = searchRetrieverResult.query; + docs = searchRetrieverResult.docs; + + // Store the search query in the context for emitting to the client + if (searchRetrieverResult.searchQuery) { + this.searchQuery = searchRetrieverResult.searchQuery; + } + } + + const sortedDocs = await this.rerankDocs( + query, + docs ?? [], + fileIds, + embeddings, + optimizationMode, + ); + + return sortedDocs; + }, + ) .withConfig({ runName: 'FinalSourceRetriever', }) @@ -450,8 +467,17 @@ class MetaSearchAgent implements MetaSearchAgentType { stream: AsyncGenerator<StreamEvent, any, any>, emitter: eventEmitter, llm: BaseChatModel, + signal: AbortSignal, ) { + if (signal.aborted) { + return; + } + for await (const event of stream) { + if (signal.aborted) { + return; + } + if ( event.event === 'on_chain_end' && event.name === 'FinalSourceRetriever' @@ -544,6 +570,7 @@ class MetaSearchAgent implements MetaSearchAgentType { optimizationMode: 'speed' | 'balanced' | 'quality', fileIds: string[], systemInstructions: string, + signal: AbortSignal, ) { const emitter = new eventEmitter(); @@ -553,6 +580,7 @@ embeddings, optimizationMode, systemInstructions, + signal, ); const stream = answeringChain.streamEvents( { chat_history: history, query: message, }, { version: 'v1', + // Pass the abort signal to the LLM streaming chain + signal, }, ); - this.handleStream(stream, emitter, llm); +
this.handleStream(stream, emitter, llm, signal); return emitter; }