Mirror of https://github.com/ItzCrazyKns/Perplexica.git (synced 2025-07-18 14:38:32 +00:00)
Implemented a refactoring plan with the configurable delay feature.

1. Created AlternatingMessageValidator (renamed from MessageProcessor):
   - Focused it on handling alternating message patterns
   - Made it model-agnostic with a configuration-driven approach
   - Kept the core validation logic intact
2. Created ReasoningChatModel (renamed from DeepSeekChat):
   - Made it generic for any model with reasoning/thinking capabilities
   - Added a configurable streaming delay parameter (streamDelay)
   - Implemented the delay logic in the streaming process
3. Updated the DeepSeek provider:
   - Now uses ReasoningChatModel for deepseek-reasoner with a 50ms delay
   - Uses the standard ChatOpenAI client for deepseek-chat
   - Added a clear distinction between models that need reasoning capabilities
4. Updated references in metaSearchAgent.ts:
   - Changed the import from messageProcessor to alternatingMessageValidator
   - Updated function calls to use the new validator

The configurable delay makes it possible to control the speed of token generation, which can help with the streaming issue that was reported. The delay is set to 50ms for the deepseek-reasoner model (the class itself defaults to no delay), and you can adjust this value in the deepseek.ts provider file to find the optimal speed.

This refactoring maintains all the existing functionality while making the code more maintainable and future-proof. The separation of concerns between message validation and model implementation will make it easier to add support for other models with similar requirements.
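As a hedged illustration of the tuning knob described above (the import path, key, and endpoint values are placeholders, not part of this commit), constructing the model with a different delay would look roughly like this; the diff below shows the actual provider wiring.

import { ReasoningChatModel } from '../reasoningChatModel'; // same relative path the provider uses below

// Illustrative values only: the provider in this commit passes streamDelay: 50,
// and the class default of 0 disables the delay entirely.
const reasoner = new ReasoningChatModel({
  apiKey: 'sk-placeholder',              // placeholder; the real provider reads the key from config
  baseURL: 'https://api.deepseek.com',   // placeholder; the real provider reads the endpoint from config
  modelName: 'deepseek-reasoner',
  temperature: 0.7,
  streamDelay: 20,                       // smaller values stream faster
});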
@@ -1,4 +1,5 @@
-import { DeepSeekChat } from '../deepseekChat';
+import { ReasoningChatModel } from '../reasoningChatModel';
+import { ChatOpenAI } from '@langchain/openai';
 import logger from '../../utils/logger';
 import { getDeepseekApiKey } from '../../config';
 import axios from 'axios';
@@ -16,9 +17,12 @@ interface ModelListResponse {
 
 interface ChatModelConfig {
   displayName: string;
-  model: DeepSeekChat;
+  model: ReasoningChatModel | ChatOpenAI;
 }
 
+// Define which models require reasoning capabilities
+const REASONING_MODELS = ['deepseek-reasoner'];
+
 const MODEL_DISPLAY_NAMES: Record<string, string> = {
   'deepseek-reasoner': 'DeepSeek R1',
   'deepseek-chat': 'DeepSeek V3'
@@ -48,15 +52,32 @@ export const loadDeepSeekChatModels = async (): Promise<Record<string, ChatModel
   const chatModels = deepSeekModels.reduce<Record<string, ChatModelConfig>>((acc, model) => {
     // Only include models we have display names for
     if (model.id in MODEL_DISPLAY_NAMES) {
-      acc[model.id] = {
-        displayName: MODEL_DISPLAY_NAMES[model.id],
-        model: new DeepSeekChat({
-          apiKey,
-          baseURL: deepSeekEndpoint,
-          modelName: model.id,
-          temperature: 0.7,
-        }),
-      };
+      // Use ReasoningChatModel for models that need reasoning capabilities
+      if (REASONING_MODELS.includes(model.id)) {
+        acc[model.id] = {
+          displayName: MODEL_DISPLAY_NAMES[model.id],
+          model: new ReasoningChatModel({
+            apiKey,
+            baseURL: deepSeekEndpoint,
+            modelName: model.id,
+            temperature: 0.7,
+            streamDelay: 50, // Add a small delay to control streaming speed
+          }),
+        };
+      } else {
+        // Use standard ChatOpenAI for other models
+        acc[model.id] = {
+          displayName: MODEL_DISPLAY_NAMES[model.id],
+          model: new ChatOpenAI({
+            openAIApiKey: apiKey,
+            configuration: {
+              baseURL: deepSeekEndpoint,
+            },
+            modelName: model.id,
+            temperature: 0.7,
+          }),
+        };
+      }
     }
     return acc;
   }, {});
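A hedged note on the extension point this hunk introduces: supporting another reasoning-capable model should only require adding its id to the list (the second id below is hypothetical, purely for illustration):

// Hypothetical example: every id listed here is wired to ReasoningChatModel (with streamDelay)
// instead of the standard ChatOpenAI client.
const REASONING_MODELS = ['deepseek-reasoner', 'some-future-reasoning-model'];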
src/lib/reasoningChatModel.ts (new file, 284 lines)
@@ -0,0 +1,284 @@
import { BaseChatModel, BaseChatModelCallOptions } from '@langchain/core/language_models/chat_models';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { AIMessage, AIMessageChunk, BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ChatResult, ChatGenerationChunk } from '@langchain/core/outputs';
import axios from 'axios';

import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';

interface ReasoningChatModelParams extends BaseChatModelParams {
  apiKey: string;
  baseURL: string;
  modelName: string;
  temperature?: number;
  max_tokens?: number;
  top_p?: number;
  frequency_penalty?: number;
  presence_penalty?: number;
  streamDelay?: number; // Add this parameter for controlling stream delay
}

export class ReasoningChatModel extends BaseChatModel<BaseChatModelCallOptions & { stream?: boolean }> {
  private apiKey: string;
  private baseURL: string;
  private modelName: string;
  private temperature: number;
  private maxTokens: number;
  private topP: number;
  private frequencyPenalty: number;
  private presencePenalty: number;
  private streamDelay: number;

  constructor(params: ReasoningChatModelParams) {
    super(params);
    this.apiKey = params.apiKey;
    this.baseURL = params.baseURL;
    this.modelName = params.modelName;
    this.temperature = params.temperature ?? 0.7;
    this.maxTokens = params.max_tokens ?? 8192;
    this.topP = params.top_p ?? 1;
    this.frequencyPenalty = params.frequency_penalty ?? 0;
    this.presencePenalty = params.presence_penalty ?? 0;
    this.streamDelay = params.streamDelay ?? 0; // Default to no delay
  }

  async _generate(
    messages: BaseMessage[],
    options: this['ParsedCallOptions'],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    const formattedMessages = messages.map(msg => ({
      role: this.getRole(msg),
      content: msg.content.toString(),
    }));
    const response = await this.callAPI(formattedMessages, options.stream);

    if (options.stream) {
      return this.processStreamingResponse(response, messages, options, runManager);
    } else {
      const choice = response.data.choices[0];
      let content = choice.message.content || '';
      if (choice.message.reasoning_content) {
        content = `<think>\n${choice.message.reasoning_content}\n</think>\n\n${content}`;
      }

      // Report usage stats if available
      if (response.data.usage && runManager) {
        runManager.handleLLMEnd({
          generations: [],
          llmOutput: {
            tokenUsage: {
              completionTokens: response.data.usage.completion_tokens,
              promptTokens: response.data.usage.prompt_tokens,
              totalTokens: response.data.usage.total_tokens
            }
          }
        });
      }
      return {
        generations: [
          {
            text: content,
            message: new AIMessage(content),
          },
        ],
      };
    }
  }

  private getRole(msg: BaseMessage): string {
    if (msg instanceof SystemMessage) return 'system';
    if (msg instanceof HumanMessage) return 'user';
    if (msg instanceof AIMessage) return 'assistant';
    return 'user'; // Default to user
  }

  private async callAPI(messages: Array<{ role: string; content: string }>, streaming?: boolean) {
    return axios.post(
      `${this.baseURL}/chat/completions`,
      {
        messages,
        model: this.modelName,
        stream: streaming,
        temperature: this.temperature,
        max_tokens: this.maxTokens,
        top_p: this.topP,
        frequency_penalty: this.frequencyPenalty,
        presence_penalty: this.presencePenalty,
        response_format: { type: 'text' },
        ...(streaming && {
          stream_options: {
            include_usage: true
          }
        })
      },
      {
        headers: {
          'Content-Type': 'application/json',
          'Authorization': `Bearer ${this.apiKey}`,
        },
        responseType: streaming ? 'text' : 'json',
      }
    );
  }

  public async *_streamResponseChunks(messages: BaseMessage[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun) {
    const response = await this.callAPI(messages.map(msg => ({
      role: this.getRole(msg),
      content: msg.content.toString(),
    })), true);

    let thinkState = -1; // -1: not started, 0: thinking, 1: answered
    let currentContent = '';

    // Split the response into lines
    const lines = response.data.split('\n');
    for (const line of lines) {
      if (!line.startsWith('data: ')) continue;
      const jsonStr = line.slice(6);
      if (jsonStr === '[DONE]') break;

      try {
        console.log('Received chunk:', jsonStr);
        const chunk = JSON.parse(jsonStr);
        const delta = chunk.choices[0].delta;
        console.log('Parsed delta:', delta);

        // Handle usage stats in final chunk
        if (chunk.usage && !chunk.choices?.length) {
          runManager?.handleLLMEnd?.({
            generations: [],
            llmOutput: {
              tokenUsage: {
                completionTokens: chunk.usage.completion_tokens,
                promptTokens: chunk.usage.prompt_tokens,
                totalTokens: chunk.usage.total_tokens
              }
            }
          });
          continue;
        }

        // Handle reasoning content
        if (delta.reasoning_content) {
          if (thinkState === -1) {
            thinkState = 0;
            const startTag = '<think>\n';
            currentContent += startTag;
            console.log('Emitting think start:', startTag);
            runManager?.handleLLMNewToken(startTag);
            const chunk = new ChatGenerationChunk({
              text: startTag,
              message: new AIMessageChunk(startTag),
              generationInfo: {}
            });

            // Add configurable delay before yielding the chunk
            if (this.streamDelay > 0) {
              await new Promise(resolve => setTimeout(resolve, this.streamDelay));
            }

            yield chunk;
          }
          currentContent += delta.reasoning_content;
          console.log('Emitting reasoning:', delta.reasoning_content);
          runManager?.handleLLMNewToken(delta.reasoning_content);
          const chunk = new ChatGenerationChunk({
            text: delta.reasoning_content,
            message: new AIMessageChunk(delta.reasoning_content),
            generationInfo: {}
          });

          // Add configurable delay before yielding the chunk
          if (this.streamDelay > 0) {
            await new Promise(resolve => setTimeout(resolve, this.streamDelay));
          }

          yield chunk;
        }

        // Handle regular content
        if (delta.content) {
          if (thinkState === 0) {
            thinkState = 1;
            const endTag = '\n</think>\n\n';
            currentContent += endTag;
            console.log('Emitting think end:', endTag);
            runManager?.handleLLMNewToken(endTag);
            const chunk = new ChatGenerationChunk({
              text: endTag,
              message: new AIMessageChunk(endTag),
              generationInfo: {}
            });

            // Add configurable delay before yielding the chunk
            if (this.streamDelay > 0) {
              await new Promise(resolve => setTimeout(resolve, this.streamDelay));
            }

            yield chunk;
          }
          currentContent += delta.content;
          console.log('Emitting content:', delta.content);
          runManager?.handleLLMNewToken(delta.content);
          const chunk = new ChatGenerationChunk({
            text: delta.content,
            message: new AIMessageChunk(delta.content),
            generationInfo: {}
          });

          // Add configurable delay before yielding the chunk
          if (this.streamDelay > 0) {
            await new Promise(resolve => setTimeout(resolve, this.streamDelay));
          }

          yield chunk;
        }
      } catch (error) {
        const errorMessage = error instanceof Error ? error.message : 'Failed to parse chunk';
        console.error(`Streaming error: ${errorMessage}`);
        if (error instanceof Error && error.message.includes('DeepSeek API Error')) {
          throw error;
        }
      }
    }

    // Handle any unclosed think block
    if (thinkState === 0) {
      const endTag = '\n</think>\n\n';
      currentContent += endTag;
      runManager?.handleLLMNewToken(endTag);
      const chunk = new ChatGenerationChunk({
        text: endTag,
        message: new AIMessageChunk(endTag),
        generationInfo: {}
      });

      // Add configurable delay before yielding the chunk
      if (this.streamDelay > 0) {
        await new Promise(resolve => setTimeout(resolve, this.streamDelay));
      }

      yield chunk;
    }
  }

  private async processStreamingResponse(response: any, messages: BaseMessage[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun): Promise<ChatResult> {
    let accumulatedContent = '';
    for await (const chunk of this._streamResponseChunks(messages, options, runManager)) {
      accumulatedContent += chunk.message.content;
    }
    return {
      generations: [
        {
          text: accumulatedContent,
          message: new AIMessage(accumulatedContent),
        },
      ],
    };
  }

  _llmType(): string {
    return 'reasoning';
  }
}
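For context, a hedged usage sketch (not part of the commit): because ReasoningChatModel extends LangChain's BaseChatModel, it can be driven through the standard stream interface, with reasoning tokens arriving inside the <think> tags emitted above. The key and endpoint values here are placeholders.

import { HumanMessage } from '@langchain/core/messages';
import { ReasoningChatModel } from './reasoningChatModel';

async function demo() {
  const model = new ReasoningChatModel({
    apiKey: process.env.DEEPSEEK_API_KEY ?? '',  // placeholder wiring
    baseURL: 'https://api.deepseek.com',         // placeholder endpoint
    modelName: 'deepseek-reasoner',
    streamDelay: 50,
  });

  // Streamed output: the <think>...</think> block precedes the final answer.
  const stream = await model.stream([new HumanMessage('Why is the sky blue?')]);
  for await (const chunk of stream) {
    process.stdout.write(chunk.content.toString());
  }
}

demo();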
@@ -23,7 +23,7 @@ import fs from 'fs';
 import computeSimilarity from '../utils/computeSimilarity';
 import formatChatHistoryAsString from '../utils/formatHistory';
 import eventEmitter from 'events';
-import { getMessageProcessor } from '../utils/messageProcessor';
+import { getMessageValidator } from '../utils/alternatingMessageValidator';
 import { StreamEvent } from '@langchain/core/tracers/log_stream';
 import { IterableReadableStream } from '@langchain/core/utils/stream';
 
@@ -483,10 +483,10 @@ class MetaSearchAgent implements MetaSearchAgentType {
         new HumanMessage(message)
       ];
 
-      // Get message processor if model needs it
-      const messageProcessor = getMessageProcessor((llm as any).modelName);
-      const processedMessages = messageProcessor
-        ? messageProcessor.processMessages(allMessages)
+      // Get message validator if model needs it
+      const messageValidator = getMessageValidator((llm as any).modelName);
+      const processedMessages = messageValidator
+        ? messageValidator.processMessages(allMessages)
         : allMessages;
 
       // Extract system message and chat history
src/utils/alternatingMessageValidator.ts (new file, 95 lines)
@@ -0,0 +1,95 @@
// Using the import paths that have been working for you
import { BaseMessage, HumanMessage, AIMessage, SystemMessage } from "@langchain/core/messages";
import logger from "./logger";

export interface MessageValidationRules {
  requireAlternating?: boolean;
  firstMessageType?: typeof HumanMessage | typeof AIMessage;
  allowSystem?: boolean;
}

export class AlternatingMessageValidator {
  private rules: MessageValidationRules;
  private modelName: string;

  constructor(modelName: string, rules: MessageValidationRules) {
    this.rules = rules;
    this.modelName = modelName;
  }

  processMessages(messages: BaseMessage[]): BaseMessage[] {
    // Always respect requireAlternating for models that need it
    if (!this.rules.requireAlternating) {
      return messages;
    }

    const processedMessages: BaseMessage[] = [];

    for (let i = 0; i < messages.length; i++) {
      const currentMsg = messages[i];

      // Handle system messages
      if (currentMsg instanceof SystemMessage) {
        if (this.rules.allowSystem) {
          processedMessages.push(currentMsg);
        } else {
          logger.warn(`${this.modelName}: Skipping system message - not allowed`);
        }
        continue;
      }

      // Handle first non-system message
      if (processedMessages.length === 0 ||
          processedMessages[processedMessages.length - 1] instanceof SystemMessage) {
        if (this.rules.firstMessageType &&
            !(currentMsg instanceof this.rules.firstMessageType)) {
          logger.warn(`${this.modelName}: Converting first message to required type`);
          processedMessages.push(new this.rules.firstMessageType({
            content: currentMsg.content,
            additional_kwargs: currentMsg.additional_kwargs
          }));
          continue;
        }
      }

      // Handle alternating pattern
      const lastMsg = processedMessages[processedMessages.length - 1];
      if (lastMsg instanceof HumanMessage && currentMsg instanceof HumanMessage) {
        logger.warn(`${this.modelName}: Skipping consecutive human message`);
        continue;
      }
      if (lastMsg instanceof AIMessage && currentMsg instanceof AIMessage) {
        logger.warn(`${this.modelName}: Skipping consecutive AI message`);
        continue;
      }

      // For deepseek-reasoner, strip out reasoning_content from message history
      if (this.modelName === 'deepseek-reasoner' && currentMsg instanceof AIMessage) {
        const { reasoning_content, ...cleanedKwargs } = currentMsg.additional_kwargs;
        processedMessages.push(new AIMessage({
          content: currentMsg.content,
          additional_kwargs: cleanedKwargs
        }));
      } else {
        processedMessages.push(currentMsg);
      }
    }

    return processedMessages;
  }
}

// Pre-configured validators for specific models
export const getMessageValidator = (modelName: string): AlternatingMessageValidator | null => {
  const validators: Record<string, MessageValidationRules> = {
    'deepseek-reasoner': {
      requireAlternating: true,
      firstMessageType: HumanMessage,
      allowSystem: true
    },
    // Add more model configurations as needed
  };

  const rules = validators[modelName];
  return rules ? new AlternatingMessageValidator(modelName, rules) : null;
};
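And a hedged sketch of how the validator is meant to be used, mirroring the metaSearchAgent.ts change above (the sample history is made up for illustration):

import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { getMessageValidator } from './alternatingMessageValidator';

// A made-up history that violates the alternating pattern (two human turns in a row).
const history = [
  new SystemMessage('You are a helpful assistant.'),
  new HumanMessage('First question'),
  new HumanMessage('Follow-up sent before any reply'),
  new AIMessage('Answer'),
];

const validator = getMessageValidator('deepseek-reasoner');
const safeHistory = validator ? validator.processMessages(history) : history;
// safeHistory keeps the system message and drops the second consecutive human turn,
// so the request satisfies deepseek-reasoner's alternating user/assistant requirement.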