From d0719429b45f2ed24f9118e776cabeacb3460898 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Fri, 24 Oct 2025 22:56:23 +0530 Subject: [PATCH] feat(app): fix issues with model selection --- .../MessageInputActions/ChatModelSelector.tsx | 58 ++++++++----------- .../Settings/Sections/Models/ModelSelect.tsx | 29 +++++++--- src/lib/hooks/useChat.tsx | 24 ++++++-- 3 files changed, 65 insertions(+), 46 deletions(-) diff --git a/src/components/MessageInputActions/ChatModelSelector.tsx b/src/components/MessageInputActions/ChatModelSelector.tsx index afc3b3b..35f7001 100644 --- a/src/components/MessageInputActions/ChatModelSelector.tsx +++ b/src/components/MessageInputActions/ChatModelSelector.tsx @@ -10,15 +10,14 @@ import { } from '@headlessui/react'; import { Fragment, useEffect, useState } from 'react'; import { MinimalProvider } from '@/lib/models/types'; +import { useChat } from '@/lib/hooks/useChat'; const ModelSelector = () => { const [providers, setProviders] = useState([]); const [isLoading, setIsLoading] = useState(true); const [searchQuery, setSearchQuery] = useState(''); - const [selectedModel, setSelectedModel] = useState<{ - providerId: string; - modelKey: string; - } | null>(null); + + const { setChatModelProvider, chatModelProvider } = useChat(); useEffect(() => { const loadProviders = async () => { @@ -30,28 +29,21 @@ const ModelSelector = () => { throw new Error('Failed to fetch providers'); } - const data = await res.json(); - setProviders(data.providers || []); + const data: { providers: MinimalProvider[] } = await res.json(); - const savedProviderId = localStorage.getItem('chatModelProviderId'); - const savedModelKey = localStorage.getItem('chatModelKey'); + const currentProviderIndex = data.providers.findIndex((p: MinimalProvider) => { + return p.id === chatModelProvider?.providerId + }) - if (savedProviderId && savedModelKey) { - setSelectedModel({ - providerId: savedProviderId, - modelKey: savedModelKey, - }); - } else if (data.providers && data.providers.length > 0) { - const firstProvider = data.providers.find( - (p: MinimalProvider) => p.chatModels.length > 0, - ); - if (firstProvider && firstProvider.chatModels[0]) { - setSelectedModel({ - providerId: firstProvider.id, - modelKey: firstProvider.chatModels[0].key, - }); - } + if (currentProviderIndex === -1) { + setProviders(data.providers); + return; } + + const selectedProvider = data.providers[currentProviderIndex] + const remainingProviders = data.providers.filter((_, index) => index !== currentProviderIndex) + + setProviders([selectedProvider, ...remainingProviders]); } catch (error) { console.error('Error loading providers:', error); } finally { @@ -60,10 +52,10 @@ const ModelSelector = () => { }; loadProviders(); - }, []); + }, [chatModelProvider]); const handleModelSelect = (providerId: string, modelKey: string) => { - setSelectedModel({ providerId, modelKey }); + setChatModelProvider({ providerId, key: modelKey }); localStorage.setItem('chatModelProviderId', providerId); localStorage.setItem('chatModelKey', modelKey); }; @@ -140,15 +132,15 @@ const ModelSelector = () => {
{provider.chatModels.map((model) => ( - handleModelSelect(provider.id, model.key) } className={cn( 'px-3 py-2 flex items-center justify-between text-start duration-200 cursor-pointer transition rounded-lg group', - selectedModel?.providerId === provider.id && - selectedModel?.modelKey === model.key + chatModelProvider?.providerId === provider.id && + chatModelProvider?.key === model.key ? 'bg-light-secondary dark:bg-dark-secondary' : 'hover:bg-light-secondary dark:hover:bg-dark-secondary', )} @@ -158,8 +150,8 @@ const ModelSelector = () => { size={15} className={cn( 'shrink-0', - selectedModel?.providerId === provider.id && - selectedModel?.modelKey === model.key + chatModelProvider?.providerId === provider.id && + chatModelProvider?.key === model.key ? 'text-sky-500' : 'text-black/50 dark:text-white/50 group-hover:text-black/70 group-hover:dark:text-white/70', )} @@ -167,8 +159,8 @@ const ModelSelector = () => {

{ {model.name}

- + ))} diff --git a/src/components/Settings/Sections/Models/ModelSelect.tsx b/src/components/Settings/Sections/Models/ModelSelect.tsx index dae84c7..96eaa55 100644 --- a/src/components/Settings/Sections/Models/ModelSelect.tsx +++ b/src/components/Settings/Sections/Models/ModelSelect.tsx @@ -1,5 +1,6 @@ import Select from '@/components/ui/Select'; import { ConfigModelProvider } from '@/lib/config/types'; +import { useChat } from '@/lib/hooks/useChat'; import { useState } from 'react'; import { toast } from 'sonner'; @@ -16,6 +17,7 @@ const ModelSelect = ({ : `${localStorage.getItem('embeddingModelProviderId')}/${localStorage.getItem('embeddingModelKey')}`, ); const [loading, setLoading] = useState(false); + const { setChatModelProvider, setEmbeddingModelProvider } = useChat(); const handleSave = async (newValue: string) => { setLoading(true); @@ -23,20 +25,33 @@ const ModelSelect = ({ try { if (type === 'chat') { - localStorage.setItem('chatModelProviderId', newValue.split('/')[0]); - localStorage.setItem( - 'chatModelKey', - newValue.split('/').slice(1).join('/'), - ); + const providerId = newValue.split('/')[0]; + const modelKey = newValue.split('/').slice(1).join('/'); + + localStorage.setItem('chatModelProviderId', providerId); + localStorage.setItem('chatModelKey', modelKey); + + setChatModelProvider({ + providerId: providerId, + key: modelKey, + }); } else { + const providerId = newValue.split('/')[0]; + const modelKey = newValue.split('/').slice(1).join('/'); + localStorage.setItem( 'embeddingModelProviderId', - newValue.split('/')[0], + providerId, ); localStorage.setItem( 'embeddingModelKey', - newValue.split('/').slice(1).join('/'), + modelKey, ); + + setEmbeddingModelProvider({ + providerId: providerId, + key: modelKey, + }); } } catch (error) { console.error('Error saving config:', error); diff --git a/src/lib/hooks/useChat.tsx b/src/lib/hooks/useChat.tsx index 04b17b5..ab346ec 100644 --- a/src/lib/hooks/useChat.tsx +++ b/src/lib/hooks/useChat.tsx @@ -48,6 +48,8 @@ type ChatContext = { messageAppeared: boolean; isReady: boolean; hasError: boolean; + chatModelProvider: ChatModelProvider; + embeddingModelProvider: EmbeddingModelProvider; setOptimizationMode: (mode: string) => void; setFocusMode: (mode: string) => void; setFiles: (files: File[]) => void; @@ -58,6 +60,8 @@ type ChatContext = { rewrite?: boolean, ) => Promise; rewrite: (messageId: string) => void; + setChatModelProvider: (provider: ChatModelProvider) => void; + setEmbeddingModelProvider: (provider: EmbeddingModelProvider) => void; }; export interface File { @@ -256,12 +260,16 @@ export const chatContext = createContext({ sections: [], notFound: false, optimizationMode: '', - rewrite: () => {}, - sendMessage: async () => {}, - setFileIds: () => {}, - setFiles: () => {}, - setFocusMode: () => {}, - setOptimizationMode: () => {}, + chatModelProvider: { key: '', providerId: '' }, + embeddingModelProvider: { key: '', providerId: '' }, + rewrite: () => { }, + sendMessage: async () => { }, + setFileIds: () => { }, + setFiles: () => { }, + setFocusMode: () => { }, + setOptimizationMode: () => { }, + setChatModelProvider: () => { }, + setEmbeddingModelProvider: () => { }, }); export const ChatProvider = ({ @@ -743,6 +751,10 @@ export const ChatProvider = ({ setOptimizationMode, rewrite, sendMessage, + setChatModelProvider, + chatModelProvider, + embeddingModelProvider, + setEmbeddingModelProvider, }} > {children}