From 41d056e755fb4d752df334d1c3bd53848bc4fce6 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:29:08 +0530 Subject: [PATCH] feat(handlers): use new custom openai --- src/lib/providers/index.ts | 28 +++++++++++++++++++++++- src/routes/config.ts | 35 +++++++++++++++++++++++------- src/routes/images.ts | 22 +++++++------------ src/routes/search.ts | 22 ++++++++----------- src/routes/suggestions.ts | 22 +++++++------------ src/routes/videos.ts | 22 +++++++------------ src/websocket/connectionManager.ts | 27 ++++++++++++++++------- 7 files changed, 106 insertions(+), 72 deletions(-) diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index 98846e7..57e9185 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -4,6 +4,12 @@ import { loadOpenAIChatModels, loadOpenAIEmbeddingsModels } from './openai'; import { loadAnthropicChatModels } from './anthropic'; import { loadTransformersEmbeddingsModels } from './transformers'; import { loadGeminiChatModels, loadGeminiEmbeddingsModels } from './gemini'; +import { + getCustomOpenaiApiKey, + getCustomOpenaiApiUrl, + getCustomOpenaiModelName, +} from '../../config'; +import { ChatOpenAI } from '@langchain/openai'; const chatModelProviders = { openai: loadOpenAIChatModels, @@ -30,7 +36,27 @@ export const getAvailableChatModelProviders = async () => { } } - models['custom_openai'] = {}; + const customOpenAiApiKey = getCustomOpenaiApiKey(); + const customOpenAiApiUrl = getCustomOpenaiApiUrl(); + const customOpenAiModelName = getCustomOpenaiModelName(); + + models['custom_openai'] = { + ...(customOpenAiApiKey && customOpenAiApiUrl && customOpenAiModelName + ? { + [customOpenAiModelName]: { + displayName: customOpenAiModelName, + model: new ChatOpenAI({ + openAIApiKey: customOpenAiApiKey, + modelName: customOpenAiModelName, + temperature: 0.7, + configuration: { + baseURL: customOpenAiApiUrl, + }, + }), + }, + } + : {}), + }; return models; }; diff --git a/src/routes/config.ts b/src/routes/config.ts index 6ff80c6..18b370d 100644 --- a/src/routes/config.ts +++ b/src/routes/config.ts @@ -10,6 +10,9 @@ import { getGeminiApiKey, getOpenaiApiKey, updateConfig, + getCustomOpenaiApiUrl, + getCustomOpenaiApiKey, + getCustomOpenaiModelName, } from '../config'; import logger from '../utils/logger'; @@ -54,6 +57,9 @@ router.get('/', async (_, res) => { config['anthropicApiKey'] = getAnthropicApiKey(); config['groqApiKey'] = getGroqApiKey(); config['geminiApiKey'] = getGeminiApiKey(); + config['customOpenaiApiUrl'] = getCustomOpenaiApiUrl(); + config['customOpenaiApiKey'] = getCustomOpenaiApiKey(); + config['customOpenaiModelName'] = getCustomOpenaiModelName(); res.status(200).json(config); } catch (err: any) { @@ -66,14 +72,27 @@ router.post('/', async (req, res) => { const config = req.body; const updatedConfig = { - API_KEYS: { - OPENAI: config.openaiApiKey, - GROQ: config.groqApiKey, - ANTHROPIC: config.anthropicApiKey, - GEMINI: config.geminiApiKey, - }, - API_ENDPOINTS: { - OLLAMA: config.ollamaApiUrl, + MODELS: { + OPENAI: { + API_KEY: config.openaiApiKey, + }, + GROQ: { + API_KEY: config.groqApiKey, + }, + ANTHROPIC: { + API_KEY: config.anthropicApiKey, + }, + GEMINI: { + API_KEY: config.geminiApiKey, + }, + OLLAMA: { + API_URL: config.ollamaApiUrl, + }, + CUSTOM_OPENAI: { + API_URL: config.customOpenaiApiUrl, + API_KEY: config.customOpenaiApiKey, + MODEL_NAME: config.customOpenaiModelName, + }, }, }; diff --git a/src/routes/images.ts b/src/routes/images.ts index efa095a..5671657 100644 --- a/src/routes/images.ts +++ b/src/routes/images.ts @@ -5,14 +5,17 @@ import { getAvailableChatModelProviders } from '../lib/providers'; import { HumanMessage, AIMessage } from '@langchain/core/messages'; import logger from '../utils/logger'; import { ChatOpenAI } from '@langchain/openai'; +import { + getCustomOpenaiApiKey, + getCustomOpenaiApiUrl, + getCustomOpenaiModelName, +} from '../config'; const router = express.Router(); interface ChatModel { provider: string; model: string; - customOpenAIBaseURL?: string; - customOpenAIKey?: string; } interface ImageSearchBody { @@ -44,21 +47,12 @@ router.post('/', async (req, res) => { let llm: BaseChatModel | undefined; if (body.chatModel?.provider === 'custom_openai') { - if ( - !body.chatModel?.customOpenAIBaseURL || - !body.chatModel?.customOpenAIKey - ) { - return res - .status(400) - .json({ message: 'Missing custom OpenAI base URL or key' }); - } - llm = new ChatOpenAI({ - modelName: body.chatModel.model, - openAIApiKey: body.chatModel.customOpenAIKey, + modelName: getCustomOpenaiModelName(), + openAIApiKey: getCustomOpenaiApiKey(), temperature: 0.7, configuration: { - baseURL: body.chatModel.customOpenAIBaseURL, + baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; } else if ( diff --git a/src/routes/search.ts b/src/routes/search.ts index e24b3f9..a29f64b 100644 --- a/src/routes/search.ts +++ b/src/routes/search.ts @@ -10,14 +10,19 @@ import { import { searchHandlers } from '../websocket/messageHandler'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; import { MetaSearchAgentType } from '../search/metaSearchAgent'; +import { + getCustomOpenaiApiKey, + getCustomOpenaiApiUrl, + getCustomOpenaiModelName, +} from '../config'; const router = express.Router(); interface chatModel { provider: string; model: string; - customOpenAIBaseURL?: string; customOpenAIKey?: string; + customOpenAIBaseURL?: string; } interface embeddingModel { @@ -78,21 +83,12 @@ router.post('/', async (req, res) => { let embeddings: Embeddings | undefined; if (body.chatModel?.provider === 'custom_openai') { - if ( - !body.chatModel?.customOpenAIBaseURL || - !body.chatModel?.customOpenAIKey - ) { - return res - .status(400) - .json({ message: 'Missing custom OpenAI base URL or key' }); - } - llm = new ChatOpenAI({ - modelName: body.chatModel.model, - openAIApiKey: body.chatModel.customOpenAIKey, + modelName: body.chatModel?.model || getCustomOpenaiModelName(), + openAIApiKey: body.chatModel?.customOpenAIKey || getCustomOpenaiApiKey(), temperature: 0.7, configuration: { - baseURL: body.chatModel.customOpenAIBaseURL, + baseURL: body.chatModel?.customOpenAIBaseURL || getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; } else if ( diff --git a/src/routes/suggestions.ts b/src/routes/suggestions.ts index 1d46e5b..7dd1739 100644 --- a/src/routes/suggestions.ts +++ b/src/routes/suggestions.ts @@ -5,14 +5,17 @@ import { getAvailableChatModelProviders } from '../lib/providers'; import { HumanMessage, AIMessage } from '@langchain/core/messages'; import logger from '../utils/logger'; import { ChatOpenAI } from '@langchain/openai'; +import { + getCustomOpenaiApiKey, + getCustomOpenaiApiUrl, + getCustomOpenaiModelName, +} from '../config'; const router = express.Router(); interface ChatModel { provider: string; model: string; - customOpenAIBaseURL?: string; - customOpenAIKey?: string; } interface SuggestionsBody { @@ -43,21 +46,12 @@ router.post('/', async (req, res) => { let llm: BaseChatModel | undefined; if (body.chatModel?.provider === 'custom_openai') { - if ( - !body.chatModel?.customOpenAIBaseURL || - !body.chatModel?.customOpenAIKey - ) { - return res - .status(400) - .json({ message: 'Missing custom OpenAI base URL or key' }); - } - llm = new ChatOpenAI({ - modelName: body.chatModel.model, - openAIApiKey: body.chatModel.customOpenAIKey, + modelName: getCustomOpenaiModelName(), + openAIApiKey: getCustomOpenaiApiKey(), temperature: 0.7, configuration: { - baseURL: body.chatModel.customOpenAIBaseURL, + baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; } else if ( diff --git a/src/routes/videos.ts b/src/routes/videos.ts index ad87460..b631f26 100644 --- a/src/routes/videos.ts +++ b/src/routes/videos.ts @@ -5,14 +5,17 @@ import { HumanMessage, AIMessage } from '@langchain/core/messages'; import logger from '../utils/logger'; import handleVideoSearch from '../chains/videoSearchAgent'; import { ChatOpenAI } from '@langchain/openai'; +import { + getCustomOpenaiApiKey, + getCustomOpenaiApiUrl, + getCustomOpenaiModelName, +} from '../config'; const router = express.Router(); interface ChatModel { provider: string; model: string; - customOpenAIBaseURL?: string; - customOpenAIKey?: string; } interface VideoSearchBody { @@ -44,21 +47,12 @@ router.post('/', async (req, res) => { let llm: BaseChatModel | undefined; if (body.chatModel?.provider === 'custom_openai') { - if ( - !body.chatModel?.customOpenAIBaseURL || - !body.chatModel?.customOpenAIKey - ) { - return res - .status(400) - .json({ message: 'Missing custom OpenAI base URL or key' }); - } - llm = new ChatOpenAI({ - modelName: body.chatModel.model, - openAIApiKey: body.chatModel.customOpenAIKey, + modelName: getCustomOpenaiModelName(), + openAIApiKey: getCustomOpenaiApiKey(), temperature: 0.7, configuration: { - baseURL: body.chatModel.customOpenAIBaseURL, + baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; } else if ( diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index d980500..bb8f242 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -9,6 +9,11 @@ import type { Embeddings } from '@langchain/core/embeddings'; import type { IncomingMessage } from 'http'; import logger from '../utils/logger'; import { ChatOpenAI } from '@langchain/openai'; +import { + getCustomOpenaiApiKey, + getCustomOpenaiApiUrl, + getCustomOpenaiModelName, +} from '../config'; export const handleConnection = async ( ws: WebSocket, @@ -48,14 +53,20 @@ export const handleConnection = async ( llm = chatModelProviders[chatModelProvider][chatModel] .model as unknown as BaseChatModel | undefined; } else if (chatModelProvider == 'custom_openai') { - llm = new ChatOpenAI({ - modelName: chatModel, - openAIApiKey: searchParams.get('openAIApiKey'), - temperature: 0.7, - configuration: { - baseURL: searchParams.get('openAIBaseURL'), - }, - }) as unknown as BaseChatModel; + const customOpenaiApiKey = getCustomOpenaiApiKey(); + const customOpenaiApiUrl = getCustomOpenaiApiUrl(); + const customOpenaiModelName = getCustomOpenaiModelName(); + + if (customOpenaiApiKey && customOpenaiApiUrl && customOpenaiModelName) { + llm = new ChatOpenAI({ + modelName: customOpenaiModelName, + openAIApiKey: customOpenaiApiKey, + temperature: 0.7, + configuration: { + baseURL: customOpenaiApiUrl, + }, + }) as unknown as BaseChatModel; + } } if (