diff --git a/src/app/api/images/route.ts b/src/app/api/images/route.ts index e02854d..d3416ca 100644 --- a/src/app/api/images/route.ts +++ b/src/app/api/images/route.ts @@ -1,23 +1,12 @@ import handleImageSearch from '@/lib/chains/imageSearchAgent'; -import { - getCustomOpenaiApiKey, - getCustomOpenaiApiUrl, - getCustomOpenaiModelName, -} from '@/lib/config'; -import { getAvailableChatModelProviders } from '@/lib/providers'; -import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import ModelRegistry from '@/lib/models/registry'; +import { ModelWithProvider } from '@/lib/models/types'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; -import { ChatOpenAI } from '@langchain/openai'; - -interface ChatModel { - provider: string; - model: string; -} interface ImageSearchBody { query: string; chatHistory: any[]; - chatModel?: ChatModel; + chatModel: ModelWithProvider; } export const POST = async (req: Request) => { @@ -34,35 +23,12 @@ export const POST = async (req: Request) => { }) .filter((msg) => msg !== undefined) as BaseMessage[]; - const chatModelProviders = await getAvailableChatModelProviders(); + const registry = new ModelRegistry(); - const chatModelProvider = - chatModelProviders[ - body.chatModel?.provider || Object.keys(chatModelProviders)[0] - ]; - const chatModel = - chatModelProvider[ - body.chatModel?.model || Object.keys(chatModelProvider)[0] - ]; - - let llm: BaseChatModel | undefined; - - if (body.chatModel?.provider === 'custom_openai') { - llm = new ChatOpenAI({ - apiKey: getCustomOpenaiApiKey(), - modelName: getCustomOpenaiModelName(), - temperature: 0.7, - configuration: { - baseURL: getCustomOpenaiApiUrl(), - }, - }) as unknown as BaseChatModel; - } else if (chatModelProvider && chatModel) { - llm = chatModel.model; - } - - if (!llm) { - return Response.json({ error: 'Invalid chat model' }, { status: 400 }); - } + const llm = await registry.loadChatModel( + body.chatModel.providerId, + body.chatModel.key, + ); const images = await handleImageSearch( { diff --git a/src/app/api/suggestions/route.ts b/src/app/api/suggestions/route.ts index 99179d2..d8312cf 100644 --- a/src/app/api/suggestions/route.ts +++ b/src/app/api/suggestions/route.ts @@ -1,22 +1,12 @@ import generateSuggestions from '@/lib/chains/suggestionGeneratorAgent'; -import { - getCustomOpenaiApiKey, - getCustomOpenaiApiUrl, - getCustomOpenaiModelName, -} from '@/lib/config'; -import { getAvailableChatModelProviders } from '@/lib/providers'; +import ModelRegistry from '@/lib/models/registry'; +import { ModelWithProvider } from '@/lib/models/types'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; -import { ChatOpenAI } from '@langchain/openai'; - -interface ChatModel { - provider: string; - model: string; -} interface SuggestionsGenerationBody { chatHistory: any[]; - chatModel?: ChatModel; + chatModel: ModelWithProvider; } export const POST = async (req: Request) => { @@ -33,35 +23,12 @@ export const POST = async (req: Request) => { }) .filter((msg) => msg !== undefined) as BaseMessage[]; - const chatModelProviders = await getAvailableChatModelProviders(); + const registry = new ModelRegistry(); - const chatModelProvider = - chatModelProviders[ - body.chatModel?.provider || Object.keys(chatModelProviders)[0] - ]; - const chatModel = - chatModelProvider[ - body.chatModel?.model || Object.keys(chatModelProvider)[0] - ]; - - let llm: BaseChatModel | undefined; - - if (body.chatModel?.provider === 'custom_openai') { - llm = new ChatOpenAI({ - apiKey: getCustomOpenaiApiKey(), - modelName: getCustomOpenaiModelName(), - temperature: 0.7, - configuration: { - baseURL: getCustomOpenaiApiUrl(), - }, - }) as unknown as BaseChatModel; - } else if (chatModelProvider && chatModel) { - llm = chatModel.model; - } - - if (!llm) { - return Response.json({ error: 'Invalid chat model' }, { status: 400 }); - } + const llm = await registry.loadChatModel( + body.chatModel.providerId, + body.chatModel.key, + ); const suggestions = await generateSuggestions( { diff --git a/src/app/api/videos/route.ts b/src/app/api/videos/route.ts index 7e8288b..02e5909 100644 --- a/src/app/api/videos/route.ts +++ b/src/app/api/videos/route.ts @@ -1,23 +1,12 @@ import handleVideoSearch from '@/lib/chains/videoSearchAgent'; -import { - getCustomOpenaiApiKey, - getCustomOpenaiApiUrl, - getCustomOpenaiModelName, -} from '@/lib/config'; -import { getAvailableChatModelProviders } from '@/lib/providers'; -import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import ModelRegistry from '@/lib/models/registry'; +import { ModelWithProvider } from '@/lib/models/types'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; -import { ChatOpenAI } from '@langchain/openai'; - -interface ChatModel { - provider: string; - model: string; -} interface VideoSearchBody { query: string; chatHistory: any[]; - chatModel?: ChatModel; + chatModel: ModelWithProvider; } export const POST = async (req: Request) => { @@ -34,35 +23,12 @@ export const POST = async (req: Request) => { }) .filter((msg) => msg !== undefined) as BaseMessage[]; - const chatModelProviders = await getAvailableChatModelProviders(); + const registry = new ModelRegistry(); - const chatModelProvider = - chatModelProviders[ - body.chatModel?.provider || Object.keys(chatModelProviders)[0] - ]; - const chatModel = - chatModelProvider[ - body.chatModel?.model || Object.keys(chatModelProvider)[0] - ]; - - let llm: BaseChatModel | undefined; - - if (body.chatModel?.provider === 'custom_openai') { - llm = new ChatOpenAI({ - apiKey: getCustomOpenaiApiKey(), - modelName: getCustomOpenaiModelName(), - temperature: 0.7, - configuration: { - baseURL: getCustomOpenaiApiUrl(), - }, - }) as unknown as BaseChatModel; - } else if (chatModelProvider && chatModel) { - llm = chatModel.model; - } - - if (!llm) { - return Response.json({ error: 'Invalid chat model' }, { status: 400 }); - } + const llm = await registry.loadChatModel( + body.chatModel.providerId, + body.chatModel.key, + ); const videos = await handleVideoSearch( { diff --git a/src/components/SearchImages.tsx b/src/components/SearchImages.tsx index 08c16ee..ca4d477 100644 --- a/src/components/SearchImages.tsx +++ b/src/components/SearchImages.tsx @@ -33,11 +33,10 @@ const SearchImages = ({ onClick={async () => { setLoading(true); - const chatModelProvider = localStorage.getItem('chatModelProvider'); - const chatModel = localStorage.getItem('chatModel'); - - const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL'); - const customOpenAIKey = localStorage.getItem('openAIApiKey'); + const chatModelProvider = localStorage.getItem( + 'chatModelProviderId', + ); + const chatModel = localStorage.getItem('chatModelKey'); const res = await fetch(`/api/images`, { method: 'POST', @@ -48,12 +47,8 @@ const SearchImages = ({ query: query, chatHistory: chatHistory, chatModel: { - provider: chatModelProvider, - model: chatModel, - ...(chatModelProvider === 'custom_openai' && { - customOpenAIBaseURL: customOpenAIBaseURL, - customOpenAIKey: customOpenAIKey, - }), + providerId: chatModelProvider, + key: chatModel, }, }), }); diff --git a/src/components/SearchVideos.tsx b/src/components/SearchVideos.tsx index a09a0d2..4084383 100644 --- a/src/components/SearchVideos.tsx +++ b/src/components/SearchVideos.tsx @@ -48,11 +48,10 @@ const Searchvideos = ({ onClick={async () => { setLoading(true); - const chatModelProvider = localStorage.getItem('chatModelProvider'); - const chatModel = localStorage.getItem('chatModel'); - - const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL'); - const customOpenAIKey = localStorage.getItem('openAIApiKey'); + const chatModelProvider = localStorage.getItem( + 'chatModelProviderId', + ); + const chatModel = localStorage.getItem('chatModelKey'); const res = await fetch(`/api/videos`, { method: 'POST', @@ -63,12 +62,8 @@ const Searchvideos = ({ query: query, chatHistory: chatHistory, chatModel: { - provider: chatModelProvider, - model: chatModel, - ...(chatModelProvider === 'custom_openai' && { - customOpenAIBaseURL: customOpenAIBaseURL, - customOpenAIKey: customOpenAIKey, - }), + providerId: chatModelProvider, + key: chatModel, }, }), }); diff --git a/src/lib/actions.ts b/src/lib/actions.ts index 93d0b38..cb75d88 100644 --- a/src/lib/actions.ts +++ b/src/lib/actions.ts @@ -1,11 +1,8 @@ import { Message } from '@/components/ChatWindow'; export const getSuggestions = async (chatHistory: Message[]) => { - const chatModel = localStorage.getItem('chatModel'); - const chatModelProvider = localStorage.getItem('chatModelProvider'); - - const customOpenAIKey = localStorage.getItem('openAIApiKey'); - const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL'); + const chatModel = localStorage.getItem('chatModelKey'); + const chatModelProvider = localStorage.getItem('chatModelProviderId'); const res = await fetch(`/api/suggestions`, { method: 'POST', @@ -15,12 +12,8 @@ export const getSuggestions = async (chatHistory: Message[]) => { body: JSON.stringify({ chatHistory: chatHistory, chatModel: { - provider: chatModelProvider, - model: chatModel, - ...(chatModelProvider === 'custom_openai' && { - customOpenAIKey, - customOpenAIBaseURL, - }), + providerId: chatModelProvider, + key: chatModel, }, }), });