feat(app): allow stopping requests

This commit is contained in:
Willie Zutz
2025-05-14 12:19:22 -06:00
parent 380216e062
commit 9c7ccf42fc
7 changed files with 227 additions and 68 deletions

View File

@@ -0,0 +1,50 @@
import { NextRequest } from 'next/server';
// In-memory registry of AbortControllers keyed by user message id.
// NOTE(review): module-level state — assumes a single Node.js process;
// a multi-instance deployment would need shared storage. TODO confirm.
const cancelTokens = new Map<string, AbortController>();

/**
 * Register an AbortController for an in-flight request so it can later
 * be aborted via its message id. Exported for use in chat/route.ts.
 */
export function registerCancelToken(
  messageId: string,
  controller: AbortController,
) {
  cancelTokens.set(messageId, controller);
}

/**
 * Remove a previously registered controller without aborting it.
 *
 * @returns true when an entry existed and was removed, false otherwise.
 */
export function cleanupCancelToken(messageId: string) {
  // Map.delete reports whether the key was present, replacing the
  // original `in` check + `delete` + mutable flag.
  return cancelTokens.delete(messageId);
}

/**
 * Abort the in-flight request associated with `messageId`, if any.
 * The entry is left in the registry; cleanupCancelToken removes it.
 *
 * @returns true when a controller was found (abort attempted), false otherwise.
 */
export function cancelRequest(messageId: string) {
  const controller = cancelTokens.get(messageId);
  if (!controller) {
    return false;
  }
  try {
    controller.abort();
  } catch (e) {
    // abort() should not throw, but a listener might; log and still
    // report success so the caller returns a 200.
    console.error(`Error aborting request for messageId ${messageId}:`, e);
  }
  return true;
}
/**
 * POST /api/chat/cancel — abort an in-progress chat request.
 *
 * Body: `{ messageId: string }`.
 * Responds 400 on a missing or malformed body, 404 when no request is
 * registered for the id, and `{ success: true }` on a successful abort.
 */
export async function POST(req: NextRequest) {
  let messageId: string | undefined;
  try {
    ({ messageId } = await req.json());
  } catch {
    // Malformed JSON previously escaped as an unhandled 500.
    return Response.json({ error: 'Invalid JSON body' }, { status: 400 });
  }
  if (!messageId) {
    return Response.json({ error: 'Missing messageId' }, { status: 400 });
  }
  if (cancelRequest(messageId)) {
    return Response.json({ success: true });
  }
  return Response.json(
    { error: 'No in-progress request for this messageId' },
    { status: 404 },
  );
}

View File

@@ -18,6 +18,10 @@ import { ChatOpenAI } from '@langchain/openai';
import crypto from 'crypto';
import { and, eq, gt } from 'drizzle-orm';
import { EventEmitter } from 'stream';
import {
registerCancelToken,
cleanupCancelToken,
} from './cancel/route';
export const runtime = 'nodejs';
export const dynamic = 'force-dynamic';
@@ -62,6 +66,7 @@ const handleEmitterEvents = async (
aiMessageId: string,
chatId: string,
startTime: number,
userMessageId: string,
) => {
let recievedMessage = '';
let sources: any[] = [];
@@ -139,6 +144,9 @@ const handleEmitterEvents = async (
);
writer.close();
// Clean up the abort controller reference
cleanupCancelToken(userMessageId);
db.insert(messagesSchema)
.values({
content: recievedMessage,
@@ -329,6 +337,28 @@ export const POST = async (req: Request) => {
);
}
const responseStream = new TransformStream();
const writer = responseStream.writable.getWriter();
const encoder = new TextEncoder();
// --- Cancellation logic ---
const abortController = new AbortController();
registerCancelToken(message.messageId, abortController);
abortController.signal.addEventListener('abort', () => {
console.log('Stream aborted, sending cancel event');
writer.write(
encoder.encode(
JSON.stringify({
type: 'error',
data: 'Request cancelled by user',
}),
),
);
cleanupCancelToken(message.messageId);
});
// Pass the abort signal to the search handler
const stream = await handler.searchAndAnswer(
message.content,
history,
@@ -337,12 +367,9 @@ export const POST = async (req: Request) => {
body.optimizationMode,
body.files,
body.systemInstructions,
abortController.signal,
);
const responseStream = new TransformStream();
const writer = responseStream.writable.getWriter();
const encoder = new TextEncoder();
handleEmitterEvents(
stream,
writer,
@@ -350,7 +377,9 @@ export const POST = async (req: Request) => {
aiMessageId,
message.chatId,
startTime,
message.messageId,
);
handleHistorySave(message, humanMessageId, body.focusMode, body.files);
return new Response(responseStream.readable, {

View File

@@ -124,6 +124,8 @@ export const POST = async (req: Request) => {
if (!searchHandler) {
return Response.json({ message: 'Invalid focus mode' }, { status: 400 });
}
const abortController = new AbortController();
const { signal } = abortController;
const emitter = await searchHandler.searchAndAnswer(
body.query,
@@ -133,6 +135,7 @@ export const POST = async (req: Request) => {
body.optimizationMode,
[],
body.systemInstructions || '',
signal,
);
if (!body.stream) {
@@ -180,9 +183,6 @@ export const POST = async (req: Request) => {
const encoder = new TextEncoder();
const abortController = new AbortController();
const { signal } = abortController;
const stream = new ReadableStream({
start(controller) {
let sources: any[] = [];

View File

@@ -50,6 +50,9 @@ const Chat = ({
const messageEnd = useRef<HTMLDivElement | null>(null);
const containerRef = useRef<HTMLDivElement | null>(null);
const SCROLL_THRESHOLD = 250; // pixels from bottom to consider "at bottom"
const [currentMessageId, setCurrentMessageId] = useState<string | undefined>(
undefined,
);
// Check if user is at bottom of page
useEffect(() => {
@@ -166,6 +169,33 @@ const Chat = ({
};
}, []);
// Remember which user message is currently awaiting a response so the
// cancel endpoint can be told which request to abort.
useEffect(() => {
  if (!loading) {
    setCurrentMessageId(undefined);
    return;
  }
  // Walk backwards to locate the most recent user message.
  for (let i = messages.length - 1; i >= 0; i--) {
    if (messages[i].role === 'user') {
      setCurrentMessageId(messages[i].messageId);
      return;
    }
  }
  // No user message yet — mirror the original `find` miss (undefined).
  setCurrentMessageId(undefined);
}, [loading, messages]);
// POST to the cancel endpoint for the message currently being generated.
const handleCancel = async () => {
  const messageId = currentMessageId;
  if (!messageId) return;
  const payload = JSON.stringify({ messageId });
  try {
    await fetch('/api/chat/cancel', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: payload,
    });
  } catch (e) {
    // Best-effort: a failed cancel request is intentionally ignored.
  }
};
return (
<div ref={containerRef} className="space-y-6 pt-8 pb-48 sm:mx-4 md:mx-8">
{messages.map((msg, i) => {
@@ -234,6 +264,7 @@ const Chat = ({
setOptimizationMode={setOptimizationMode}
focusMode={focusMode}
setFocusMode={setFocusMode}
onCancel={handleCancel}
/>
</div>
<div ref={messageEnd} className="h-0" />

View File

@@ -67,6 +67,7 @@ const MessageBox = ({
className="w-full p-3 text-lg bg-light-100 dark:bg-dark-100 rounded-lg border border-light-secondary dark:border-dark-secondary text-black dark:text-white focus:outline-none focus:border-[#24A0ED] transition duration-200 min-h-[120px] font-medium"
value={editedContent}
onChange={(e) => setEditedContent(e.target.value)}
placeholder="Edit your message..."
autoFocus
/>
<div className="flex flex-row space-x-2 mt-3 justify-end">

View File

@@ -1,4 +1,4 @@
import { ArrowRight, ArrowUp } from 'lucide-react';
import { ArrowRight, ArrowUp, Square } from 'lucide-react';
import { useEffect, useRef, useState } from 'react';
import TextareaAutosize from 'react-textarea-autosize';
import { File } from './ChatWindow';
@@ -19,6 +19,7 @@ const MessageInput = ({
focusMode,
setFocusMode,
firstMessage,
onCancel,
}: {
sendMessage: (message: string) => void;
loading: boolean;
@@ -31,6 +32,7 @@ const MessageInput = ({
focusMode: string;
setFocusMode: (mode: string) => void;
firstMessage: boolean;
onCancel?: () => void;
}) => {
const [message, setMessage] = useState('');
const [selectedModel, setSelectedModel] = useState<{
@@ -129,17 +131,33 @@ const MessageInput = ({
optimizationMode={optimizationMode}
setOptimizationMode={setOptimizationMode}
/>
<button
disabled={message.trim().length === 0}
className="bg-[#24A0ED] text-white disabled:text-black/50 dark:disabled:text-white/50 disabled:bg-[#e0e0dc] dark:disabled:bg-[#ececec21] hover:bg-opacity-85 transition duration-100 rounded-full p-2"
type="submit"
>
{firstMessage ? (
<ArrowRight className="bg-background" size={17} />
) : (
<ArrowUp className="bg-background" size={17} />
)}
</button>
{loading ? (
<button
type="button"
className="bg-red-700 text-white hover:bg-red-800 transition duration-100 rounded-full p-2 relative group"
onClick={onCancel}
aria-label="Cancel"
>
{loading && (
<div className="absolute inset-0 rounded-full border-2 border-white/30 border-t-white animate-spin" />
)}
<span className="relative flex items-center justify-center w-[17px] h-[17px]">
<Square size={17} className="text-white" />
</span>
</button>
) : (
<button
disabled={message.trim().length === 0}
className="bg-[#24A0ED] text-white disabled:text-black/50 dark:disabled:text-white/50 disabled:bg-[#e0e0dc] dark:disabled:bg-[#ececec21] hover:bg-opacity-85 transition duration-100 rounded-full p-2"
type="submit"
>
{firstMessage ? (
<ArrowRight className="bg-background" size={17} />
) : (
<ArrowUp className="bg-background" size={17} />
)}
</button>
)}
</div>
</div>
</div>

View File

@@ -1,6 +1,7 @@
import { ChatOpenAI } from '@langchain/openai';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { BaseMessage } from '@langchain/core/messages';
import { StringOutputParser } from '@langchain/core/output_parsers';
import {
ChatPromptTemplate,
MessagesPlaceholder,
@@ -11,19 +12,18 @@ import {
RunnableMap,
RunnableSequence,
} from '@langchain/core/runnables';
import { BaseMessage } from '@langchain/core/messages';
import { StringOutputParser } from '@langchain/core/output_parsers';
import LineListOutputParser from '../outputParsers/listLineOutputParser';
import LineOutputParser from '../outputParsers/lineOutputParser';
import { getDocumentsFromLinks } from '../utils/documents';
import { Document } from 'langchain/document';
import { searchSearxng } from '../searxng';
import path from 'node:path';
import fs from 'node:fs';
import computeSimilarity from '../utils/computeSimilarity';
import formatChatHistoryAsString from '../utils/formatHistory';
import eventEmitter from 'events';
import { StreamEvent } from '@langchain/core/tracers/log_stream';
import { ChatOpenAI } from '@langchain/openai';
import eventEmitter from 'events';
import { Document } from 'langchain/document';
import fs from 'node:fs';
import path from 'node:path';
import LineOutputParser from '../outputParsers/lineOutputParser';
import LineListOutputParser from '../outputParsers/listLineOutputParser';
import { searchSearxng } from '../searxng';
import computeSimilarity from '../utils/computeSimilarity';
import { getDocumentsFromLinks } from '../utils/documents';
import formatChatHistoryAsString from '../utils/formatHistory';
export interface MetaSearchAgentType {
searchAndAnswer: (
@@ -34,6 +34,7 @@ export interface MetaSearchAgentType {
optimizationMode: 'speed' | 'balanced' | 'quality',
fileIds: string[],
systemInstructions: string,
signal: AbortSignal,
) => Promise<eventEmitter>;
}
@@ -247,6 +248,7 @@ class MetaSearchAgent implements MetaSearchAgentType {
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
systemInstructions: string,
signal: AbortSignal,
) {
return RunnableSequence.from([
RunnableMap.from({
@@ -254,43 +256,58 @@ class MetaSearchAgent implements MetaSearchAgentType {
query: (input: BasicChainInput) => input.query,
chat_history: (input: BasicChainInput) => input.chat_history,
date: () => new Date().toISOString(),
context: RunnableLambda.from(async (input: BasicChainInput) => {
const processedHistory = formatChatHistoryAsString(
input.chat_history,
);
let docs: Document[] | null = null;
let query = input.query;
if (this.config.searchWeb) {
const searchRetrieverChain =
await this.createSearchRetrieverChain(llm);
var date = new Date().toISOString();
const searchRetrieverResult = await searchRetrieverChain.invoke({
chat_history: processedHistory,
query,
date,
});
query = searchRetrieverResult.query;
docs = searchRetrieverResult.docs;
// Store the search query in the context for emitting to the client
if (searchRetrieverResult.searchQuery) {
this.searchQuery = searchRetrieverResult.searchQuery;
context: RunnableLambda.from(
async (
input: BasicChainInput,
options?: { signal?: AbortSignal },
) => {
// Check if the request was aborted
if (options?.signal?.aborted || signal?.aborted) {
console.log('Request cancelled by user');
throw new Error('Request cancelled by user');
}
}
const sortedDocs = await this.rerankDocs(
query,
docs ?? [],
fileIds,
embeddings,
optimizationMode,
);
const processedHistory = formatChatHistoryAsString(
input.chat_history,
);
return sortedDocs;
})
let docs: Document[] | null = null;
let query = input.query;
if (this.config.searchWeb) {
const searchRetrieverChain =
await this.createSearchRetrieverChain(llm);
var date = new Date().toISOString();
const searchRetrieverResult = await searchRetrieverChain.invoke(
{
chat_history: processedHistory,
query,
date,
},
{ signal: options?.signal },
);
query = searchRetrieverResult.query;
docs = searchRetrieverResult.docs;
// Store the search query in the context for emitting to the client
if (searchRetrieverResult.searchQuery) {
this.searchQuery = searchRetrieverResult.searchQuery;
}
}
const sortedDocs = await this.rerankDocs(
query,
docs ?? [],
fileIds,
embeddings,
optimizationMode,
);
return sortedDocs;
},
)
.withConfig({
runName: 'FinalSourceRetriever',
})
@@ -450,8 +467,17 @@ class MetaSearchAgent implements MetaSearchAgentType {
stream: AsyncGenerator<StreamEvent, any, any>,
emitter: eventEmitter,
llm: BaseChatModel,
signal: AbortSignal,
) {
if (signal.aborted) {
return;
}
for await (const event of stream) {
if (signal.aborted) {
return;
}
if (
event.event === 'on_chain_end' &&
event.name === 'FinalSourceRetriever'
@@ -544,6 +570,7 @@ class MetaSearchAgent implements MetaSearchAgentType {
optimizationMode: 'speed' | 'balanced' | 'quality',
fileIds: string[],
systemInstructions: string,
signal: AbortSignal,
) {
const emitter = new eventEmitter();
@@ -553,6 +580,7 @@ class MetaSearchAgent implements MetaSearchAgentType {
embeddings,
optimizationMode,
systemInstructions,
signal,
);
const stream = answeringChain.streamEvents(
@@ -562,10 +590,12 @@ class MetaSearchAgent implements MetaSearchAgentType {
},
{
version: 'v1',
// Pass the abort signal to the LLM streaming chain
signal,
},
);
this.handleStream(stream, emitter, llm);
this.handleStream(stream, emitter, llm, signal);
return emitter;
}