diff --git a/src/lib/providers/deepseek.ts b/src/lib/providers/deepseek.ts
index 13bf147..55d594b 100644
--- a/src/lib/providers/deepseek.ts
+++ b/src/lib/providers/deepseek.ts
@@ -1,4 +1,5 @@
-import { DeepSeekChat } from '../deepseekChat';
+import { ReasoningChatModel } from '../reasoningChatModel';
+import { ChatOpenAI } from '@langchain/openai';
 import logger from '../../utils/logger';
 import { getDeepseekApiKey } from '../../config';
 import axios from 'axios';
@@ -16,9 +17,12 @@ interface ModelListResponse {
 
 interface ChatModelConfig {
   displayName: string;
-  model: DeepSeekChat;
+  model: ReasoningChatModel | ChatOpenAI;
 }
 
+// Define which models require reasoning capabilities
+const REASONING_MODELS = ['deepseek-reasoner'];
+
 const MODEL_DISPLAY_NAMES: Record<string, string> = {
   'deepseek-reasoner': 'DeepSeek R1',
   'deepseek-chat': 'DeepSeek V3'
@@ -48,15 +52,32 @@ export const loadDeepSeekChatModels = async (): Promise<Record<string, ChatModelConfig>> => {
     .reduce<Record<string, ChatModelConfig>>((acc, model) => {
       // Only include models we have display names for
       if (model.id in MODEL_DISPLAY_NAMES) {
-        acc[model.id] = {
-          displayName: MODEL_DISPLAY_NAMES[model.id],
-          model: new DeepSeekChat({
-            apiKey,
-            baseURL: deepSeekEndpoint,
-            modelName: model.id,
-            temperature: 0.7,
-          }),
-        };
+        // Use ReasoningChatModel for models that need reasoning capabilities
+        if (REASONING_MODELS.includes(model.id)) {
+          acc[model.id] = {
+            displayName: MODEL_DISPLAY_NAMES[model.id],
+            model: new ReasoningChatModel({
+              apiKey,
+              baseURL: deepSeekEndpoint,
+              modelName: model.id,
+              temperature: 0.7,
+              streamDelay: 50, // Add a small delay to control streaming speed
+            }),
+          };
+        } else {
+          // Use standard ChatOpenAI for other models
+          acc[model.id] = {
+            displayName: MODEL_DISPLAY_NAMES[model.id],
+            model: new ChatOpenAI({
+              openAIApiKey: apiKey,
+              configuration: {
+                baseURL: deepSeekEndpoint,
+              },
+              modelName: model.id,
+              temperature: 0.7,
+            }),
+          };
+        }
       }
       return acc;
     }, {});
diff --git a/src/lib/reasoningChatModel.ts b/src/lib/reasoningChatModel.ts
new file mode 100644
index 0000000..903650d
--- /dev/null
+++ b/src/lib/reasoningChatModel.ts
@@ -0,0 +1,284 @@
+import { BaseChatModel, BaseChatModelCallOptions } from '@langchain/core/language_models/chat_models';
+import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
+import { AIMessage, AIMessageChunk, BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
+import { ChatResult, ChatGenerationChunk } from '@langchain/core/outputs';
+import axios from 'axios';
+
+import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
+
+interface ReasoningChatModelParams extends BaseChatModelParams {
+  apiKey: string;
+  baseURL: string;
+  modelName: string;
+  temperature?: number;
+  max_tokens?: number;
+  top_p?: number;
+  frequency_penalty?: number;
+  presence_penalty?: number;
+  streamDelay?: number; // Add this parameter for controlling stream delay
+}
+
+export class ReasoningChatModel extends BaseChatModel {
+  private apiKey: string;
+  private baseURL: string;
+  private modelName: string;
+  private temperature: number;
+  private maxTokens: number;
+  private topP: number;
+  private frequencyPenalty: number;
+  private presencePenalty: number;
+  private streamDelay: number;
+
+  constructor(params: ReasoningChatModelParams) {
+    super(params);
+    this.apiKey = params.apiKey;
+    this.baseURL = params.baseURL;
+    this.modelName = params.modelName;
+    this.temperature = params.temperature ?? 0.7;
+    this.maxTokens = params.max_tokens ?? 8192;
+    this.topP = params.top_p ?? 1;
+    this.frequencyPenalty = params.frequency_penalty ?? 0;
+    this.presencePenalty = params.presence_penalty ?? 0;
+    this.streamDelay = params.streamDelay ?? 0; // Default to no delay
+  }
+
+  async _generate(
+    messages: BaseMessage[],
+    options: this['ParsedCallOptions'],
+    runManager?: CallbackManagerForLLMRun
+  ): Promise<ChatResult> {
+    const formattedMessages = messages.map(msg => ({
+      role: this.getRole(msg),
+      content: msg.content.toString(),
+    }));
+    const response = await this.callAPI(formattedMessages, options.stream);
+
+    if (options.stream) {
+      return this.processStreamingResponse(response, messages, options, runManager);
+    } else {
+      const choice = response.data.choices[0];
+      let content = choice.message.content || '';
+      if (choice.message.reasoning_content) {
+        content = `<think>\n${choice.message.reasoning_content}\n</think>\n\n${content}`;
+      }
+
+      // Report usage stats if available
+      if (response.data.usage && runManager) {
+        runManager.handleLLMEnd({
+          generations: [],
+          llmOutput: {
+            tokenUsage: {
+              completionTokens: response.data.usage.completion_tokens,
+              promptTokens: response.data.usage.prompt_tokens,
+              totalTokens: response.data.usage.total_tokens
+            }
+          }
+        });
+      }
+      return {
+        generations: [
+          {
+            text: content,
+            message: new AIMessage(content),
+          },
+        ],
+      };
+    }
+  }
+
+  private getRole(msg: BaseMessage): string {
+    if (msg instanceof SystemMessage) return 'system';
+    if (msg instanceof HumanMessage) return 'user';
+    if (msg instanceof AIMessage) return 'assistant';
+    return 'user'; // Default to user
+  }
+
+  private async callAPI(messages: Array<{ role: string; content: string }>, streaming?: boolean) {
+    return axios.post(
+      `${this.baseURL}/chat/completions`,
+      {
+        messages,
+        model: this.modelName,
+        stream: streaming,
+        temperature: this.temperature,
+        max_tokens: this.maxTokens,
+        top_p: this.topP,
+        frequency_penalty: this.frequencyPenalty,
+        presence_penalty: this.presencePenalty,
+        response_format: { type: 'text' },
+        ...(streaming && {
+          stream_options: {
+            include_usage: true
+          }
+        })
+      },
+      {
+        headers: {
+          'Content-Type': 'application/json',
+          'Authorization': `Bearer ${this.apiKey}`,
+        },
+        responseType: streaming ? 'text' : 'json',
+      }
+    );
+  }
+
+  public async *_streamResponseChunks(messages: BaseMessage[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun) {
+    const response = await this.callAPI(messages.map(msg => ({
+      role: this.getRole(msg),
+      content: msg.content.toString(),
+    })), true);
+
+    let thinkState = -1; // -1: not started, 0: thinking, 1: answered
+    let currentContent = '';
+
+    // Split the response into lines
+    const lines = response.data.split('\n');
+    for (const line of lines) {
+      if (!line.startsWith('data: ')) continue;
+      const jsonStr = line.slice(6);
+      if (jsonStr === '[DONE]') break;
+
+      try {
+        console.log('Received chunk:', jsonStr);
+        const chunk = JSON.parse(jsonStr);
+        const delta = chunk.choices[0].delta;
+        console.log('Parsed delta:', delta);
+
+        // Handle usage stats in final chunk
+        if (chunk.usage && !chunk.choices?.length) {
+          runManager?.handleLLMEnd?.({
+            generations: [],
+            llmOutput: {
+              tokenUsage: {
+                completionTokens: chunk.usage.completion_tokens,
+                promptTokens: chunk.usage.prompt_tokens,
+                totalTokens: chunk.usage.total_tokens
+              }
+            }
+          });
+          continue;
+        }
+
+        // Handle reasoning content
+        if (delta.reasoning_content) {
+          if (thinkState === -1) {
+            thinkState = 0;
+            const startTag = '<think>\n';
+            currentContent += startTag;
+            console.log('Emitting think start:', startTag);
+            runManager?.handleLLMNewToken(startTag);
+            const chunk = new ChatGenerationChunk({
+              text: startTag,
+              message: new AIMessageChunk(startTag),
+              generationInfo: {}
+            });
+
+            // Add configurable delay before yielding the chunk
+            if (this.streamDelay > 0) {
+              await new Promise(resolve => setTimeout(resolve, this.streamDelay));
+            }
+
+            yield chunk;
+          }
+          currentContent += delta.reasoning_content;
+          console.log('Emitting reasoning:', delta.reasoning_content);
+          runManager?.handleLLMNewToken(delta.reasoning_content);
+          const chunk = new ChatGenerationChunk({
+            text: delta.reasoning_content,
+            message: new AIMessageChunk(delta.reasoning_content),
+            generationInfo: {}
+          });
+
+          // Add configurable delay before yielding the chunk
+          if (this.streamDelay > 0) {
+            await new Promise(resolve => setTimeout(resolve, this.streamDelay));
+          }
+
+          yield chunk;
+        }
+
+        // Handle regular content
+        if (delta.content) {
+          if (thinkState === 0) {
+            thinkState = 1;
+            const endTag = '\n</think>\n\n';
+            currentContent += endTag;
+            console.log('Emitting think end:', endTag);
+            runManager?.handleLLMNewToken(endTag);
+            const chunk = new ChatGenerationChunk({
+              text: endTag,
+              message: new AIMessageChunk(endTag),
+              generationInfo: {}
+            });
+
+            // Add configurable delay before yielding the chunk
+            if (this.streamDelay > 0) {
+              await new Promise(resolve => setTimeout(resolve, this.streamDelay));
+            }
+
+            yield chunk;
+          }
+          currentContent += delta.content;
+          console.log('Emitting content:', delta.content);
+          runManager?.handleLLMNewToken(delta.content);
+          const chunk = new ChatGenerationChunk({
+            text: delta.content,
+            message: new AIMessageChunk(delta.content),
+            generationInfo: {}
+          });
+
+          // Add configurable delay before yielding the chunk
+          if (this.streamDelay > 0) {
+            await new Promise(resolve => setTimeout(resolve, this.streamDelay));
+          }
+
+          yield chunk;
+        }
+      } catch (error) {
+        const errorMessage = error instanceof Error ? error.message : 'Failed to parse chunk';
+        console.error(`Streaming error: ${errorMessage}`);
+        if (error instanceof Error && error.message.includes('DeepSeek API Error')) {
+          throw error;
+        }
+      }
+    }
+
+    // Handle any unclosed think block
+    if (thinkState === 0) {
+      const endTag = '\n</think>\n\n';
+      currentContent += endTag;
+      runManager?.handleLLMNewToken(endTag);
+      const chunk = new ChatGenerationChunk({
+        text: endTag,
+        message: new AIMessageChunk(endTag),
+        generationInfo: {}
+      });
+
+      // Add configurable delay before yielding the chunk
+      if (this.streamDelay > 0) {
+        await new Promise(resolve => setTimeout(resolve, this.streamDelay));
+      }
+
+      yield chunk;
+    }
+  }
+
+  private async processStreamingResponse(response: any, messages: BaseMessage[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun): Promise<ChatResult> {
+    let accumulatedContent = '';
+    for await (const chunk of this._streamResponseChunks(messages, options, runManager)) {
+      accumulatedContent += chunk.message.content;
+    }
+    return {
+      generations: [
+        {
+          text: accumulatedContent,
+          message: new AIMessage(accumulatedContent),
+        },
+      ],
+    };
+  }
+
+  _llmType(): string {
+    return 'reasoning';
+  }
+}
diff --git a/src/search/metaSearchAgent.ts b/src/search/metaSearchAgent.ts
index 2a904f0..0cf1f52 100644
--- a/src/search/metaSearchAgent.ts
+++ b/src/search/metaSearchAgent.ts
@@ -23,7 +23,7 @@ import fs from 'fs';
 import computeSimilarity from '../utils/computeSimilarity';
 import formatChatHistoryAsString from '../utils/formatHistory';
 import eventEmitter from 'events';
-import { getMessageProcessor } from '../utils/messageProcessor';
+import { getMessageValidator } from '../utils/alternatingMessageValidator';
 import { StreamEvent } from '@langchain/core/tracers/log_stream';
 import { IterableReadableStream } from '@langchain/core/utils/stream';
@@ -483,10 +483,10 @@ class MetaSearchAgent implements MetaSearchAgentType {
       new HumanMessage(message)
     ];
 
-    // Get message processor if model needs it
-    const messageProcessor = getMessageProcessor((llm as any).modelName);
-    const processedMessages = messageProcessor
-      ? messageProcessor.processMessages(allMessages)
+    // Get message validator if model needs it
+    const messageValidator = getMessageValidator((llm as any).modelName);
+    const processedMessages = messageValidator
+      ? messageValidator.processMessages(allMessages)
       : allMessages;
 
     // Extract system message and chat history
diff --git a/src/utils/alternatingMessageValidator.ts b/src/utils/alternatingMessageValidator.ts
new file mode 100644
index 0000000..785244c
--- /dev/null
+++ b/src/utils/alternatingMessageValidator.ts
@@ -0,0 +1,95 @@
+// Using the import paths that have been working for you
+import { BaseMessage, HumanMessage, AIMessage, SystemMessage } from "@langchain/core/messages";
+import logger from "./logger";
+
+export interface MessageValidationRules {
+  requireAlternating?: boolean;
+  firstMessageType?: typeof HumanMessage | typeof AIMessage;
+  allowSystem?: boolean;
+}
+
+export class AlternatingMessageValidator {
+  private rules: MessageValidationRules;
+  private modelName: string;
+
+  constructor(modelName: string, rules: MessageValidationRules) {
+    this.rules = rules;
+    this.modelName = modelName;
+  }
+
+  processMessages(messages: BaseMessage[]): BaseMessage[] {
+    // Always respect requireAlternating for models that need it
+    if (!this.rules.requireAlternating) {
+      return messages;
+    }
+
+    const processedMessages: BaseMessage[] = [];
+
+    for (let i = 0; i < messages.length; i++) {
+      const currentMsg = messages[i];
+
+      // Handle system messages
+      if (currentMsg instanceof SystemMessage) {
+        if (this.rules.allowSystem) {
+          processedMessages.push(currentMsg);
+        } else {
+          logger.warn(`${this.modelName}: Skipping system message - not allowed`);
+        }
+        continue;
+      }
+
+      // Handle first non-system message
+      if (processedMessages.length === 0 ||
+          processedMessages[processedMessages.length - 1] instanceof SystemMessage) {
+        if (this.rules.firstMessageType &&
+            !(currentMsg instanceof this.rules.firstMessageType)) {
+          logger.warn(`${this.modelName}: Converting first message to required type`);
+          processedMessages.push(new this.rules.firstMessageType({
+            content: currentMsg.content,
+            additional_kwargs: currentMsg.additional_kwargs
+          }));
+          continue;
+        }
+      }
+
+      // Handle alternating pattern
+      const lastMsg = processedMessages[processedMessages.length - 1];
+      if (lastMsg instanceof HumanMessage && currentMsg instanceof HumanMessage) {
+        logger.warn(`${this.modelName}: Skipping consecutive human message`);
+        continue;
+      }
+      if (lastMsg instanceof AIMessage && currentMsg instanceof AIMessage) {
+        logger.warn(`${this.modelName}: Skipping consecutive AI message`);
+        continue;
+      }
+
+      // For deepseek-reasoner, strip out reasoning_content from message history
+      if (this.modelName === 'deepseek-reasoner' && currentMsg instanceof AIMessage) {
+        const { reasoning_content, ...cleanedKwargs } = currentMsg.additional_kwargs;
+        processedMessages.push(new AIMessage({
+          content: currentMsg.content,
+          additional_kwargs: cleanedKwargs
+        }));
+      } else {
+        processedMessages.push(currentMsg);
+      }
+    }
+
+    return processedMessages;
+  }
+}
+
+// Pre-configured validators for specific models
+export const getMessageValidator = (modelName: string): AlternatingMessageValidator | null => {
+  const validators: Record<string, MessageValidationRules> = {
+    'deepseek-reasoner': {
+      requireAlternating: true,
+      firstMessageType: HumanMessage,
+      allowSystem: true
+    },
+    // Add more model configurations as needed
+  };
+
+  const rules = validators[modelName];
+  return rules ? new AlternatingMessageValidator(modelName, rules) : null;
+};
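
Reviewer note: the sketch below is not part of the diff; it is a minimal illustration of how the pieces above are expected to compose. It assumes the file lives under src/, that DEEPSEEK_API_KEY is set in the environment, and that https://api.deepseek.com is the configured endpoint; model.invoke() comes from the LangChain BaseChatModel base class rather than from this PR.

// sketch.ts — illustrative only
import { ReasoningChatModel } from './lib/reasoningChatModel';
import { getMessageValidator } from './utils/alternatingMessageValidator';
import { HumanMessage, SystemMessage } from '@langchain/core/messages';

async function main() {
  // Instantiate the custom model the same way the provider does for deepseek-reasoner
  const model = new ReasoningChatModel({
    apiKey: process.env.DEEPSEEK_API_KEY ?? '', // placeholder
    baseURL: 'https://api.deepseek.com',        // assumed endpoint value
    modelName: 'deepseek-reasoner',
    temperature: 0.7,
    streamDelay: 50,
  });

  // Enforce the alternating human/AI pattern required by deepseek-reasoner
  const validator = getMessageValidator('deepseek-reasoner');
  const messages = [
    new SystemMessage('You are a helpful assistant.'),
    new HumanMessage('Why is the sky blue?'),
  ];
  const prepared = validator ? validator.processMessages(messages) : messages;

  // Non-streaming call; the reasoning trace comes back wrapped in <think>...</think>
  const result = await model.invoke(prepared);
  console.log(result.content);
}

main();

For deepseek-reasoner, downstream UI code can fold or strip the <think> block, while deepseek-chat continues to go through the stock ChatOpenAI client.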