feat(app): fix issues with model selection

This commit is contained in:
ItzCrazyKns
2025-10-24 22:56:23 +05:30
parent 600d4ceb29
commit d0719429b4
3 changed files with 65 additions and 46 deletions

View File

@@ -10,15 +10,14 @@ import {
} from '@headlessui/react'; } from '@headlessui/react';
import { Fragment, useEffect, useState } from 'react'; import { Fragment, useEffect, useState } from 'react';
import { MinimalProvider } from '@/lib/models/types'; import { MinimalProvider } from '@/lib/models/types';
import { useChat } from '@/lib/hooks/useChat';
const ModelSelector = () => { const ModelSelector = () => {
const [providers, setProviders] = useState<MinimalProvider[]>([]); const [providers, setProviders] = useState<MinimalProvider[]>([]);
const [isLoading, setIsLoading] = useState(true); const [isLoading, setIsLoading] = useState(true);
const [searchQuery, setSearchQuery] = useState(''); const [searchQuery, setSearchQuery] = useState('');
const [selectedModel, setSelectedModel] = useState<{
providerId: string; const { setChatModelProvider, chatModelProvider } = useChat();
modelKey: string;
} | null>(null);
useEffect(() => { useEffect(() => {
const loadProviders = async () => { const loadProviders = async () => {
@@ -30,28 +29,21 @@ const ModelSelector = () => {
throw new Error('Failed to fetch providers'); throw new Error('Failed to fetch providers');
} }
const data = await res.json(); const data: { providers: MinimalProvider[] } = await res.json();
setProviders(data.providers || []);
const savedProviderId = localStorage.getItem('chatModelProviderId'); const currentProviderIndex = data.providers.findIndex((p: MinimalProvider) => {
const savedModelKey = localStorage.getItem('chatModelKey'); return p.id === chatModelProvider?.providerId
})
if (savedProviderId && savedModelKey) { if (currentProviderIndex === -1) {
setSelectedModel({ setProviders(data.providers);
providerId: savedProviderId, return;
modelKey: savedModelKey,
});
} else if (data.providers && data.providers.length > 0) {
const firstProvider = data.providers.find(
(p: MinimalProvider) => p.chatModels.length > 0,
);
if (firstProvider && firstProvider.chatModels[0]) {
setSelectedModel({
providerId: firstProvider.id,
modelKey: firstProvider.chatModels[0].key,
});
}
} }
const selectedProvider = data.providers[currentProviderIndex]
const remainingProviders = data.providers.filter((_, index) => index !== currentProviderIndex)
setProviders([selectedProvider, ...remainingProviders]);
} catch (error) { } catch (error) {
console.error('Error loading providers:', error); console.error('Error loading providers:', error);
} finally { } finally {
@@ -60,10 +52,10 @@ const ModelSelector = () => {
}; };
loadProviders(); loadProviders();
}, []); }, [chatModelProvider]);
const handleModelSelect = (providerId: string, modelKey: string) => { const handleModelSelect = (providerId: string, modelKey: string) => {
setSelectedModel({ providerId, modelKey }); setChatModelProvider({ providerId, key: modelKey });
localStorage.setItem('chatModelProviderId', providerId); localStorage.setItem('chatModelProviderId', providerId);
localStorage.setItem('chatModelKey', modelKey); localStorage.setItem('chatModelKey', modelKey);
}; };
@@ -140,15 +132,15 @@ const ModelSelector = () => {
<div className="flex flex-col px-2 py-2 space-y-0.5"> <div className="flex flex-col px-2 py-2 space-y-0.5">
{provider.chatModels.map((model) => ( {provider.chatModels.map((model) => (
<PopoverButton <button
key={model.key} key={model.key}
onClick={() => onClick={() =>
handleModelSelect(provider.id, model.key) handleModelSelect(provider.id, model.key)
} }
className={cn( className={cn(
'px-3 py-2 flex items-center justify-between text-start duration-200 cursor-pointer transition rounded-lg group', 'px-3 py-2 flex items-center justify-between text-start duration-200 cursor-pointer transition rounded-lg group',
selectedModel?.providerId === provider.id && chatModelProvider?.providerId === provider.id &&
selectedModel?.modelKey === model.key chatModelProvider?.key === model.key
? 'bg-light-secondary dark:bg-dark-secondary' ? 'bg-light-secondary dark:bg-dark-secondary'
: 'hover:bg-light-secondary dark:hover:bg-dark-secondary', : 'hover:bg-light-secondary dark:hover:bg-dark-secondary',
)} )}
@@ -158,8 +150,8 @@ const ModelSelector = () => {
size={15} size={15}
className={cn( className={cn(
'shrink-0', 'shrink-0',
selectedModel?.providerId === provider.id && chatModelProvider?.providerId === provider.id &&
selectedModel?.modelKey === model.key chatModelProvider?.key === model.key
? 'text-sky-500' ? 'text-sky-500'
: 'text-black/50 dark:text-white/50 group-hover:text-black/70 group-hover:dark:text-white/70', : 'text-black/50 dark:text-white/50 group-hover:text-black/70 group-hover:dark:text-white/70',
)} )}
@@ -167,8 +159,8 @@ const ModelSelector = () => {
<p <p
className={cn( className={cn(
'text-sm truncate', 'text-sm truncate',
selectedModel?.providerId === provider.id && chatModelProvider?.providerId === provider.id &&
selectedModel?.modelKey === model.key chatModelProvider?.key === model.key
? 'text-sky-500 font-medium' ? 'text-sky-500 font-medium'
: 'text-black/70 dark:text-white/70 group-hover:text-black dark:group-hover:text-white', : 'text-black/70 dark:text-white/70 group-hover:text-black dark:group-hover:text-white',
)} )}
@@ -176,7 +168,7 @@ const ModelSelector = () => {
{model.name} {model.name}
</p> </p>
</div> </div>
</PopoverButton> </button>
))} ))}
</div> </div>

View File

@@ -1,5 +1,6 @@
import Select from '@/components/ui/Select'; import Select from '@/components/ui/Select';
import { ConfigModelProvider } from '@/lib/config/types'; import { ConfigModelProvider } from '@/lib/config/types';
import { useChat } from '@/lib/hooks/useChat';
import { useState } from 'react'; import { useState } from 'react';
import { toast } from 'sonner'; import { toast } from 'sonner';
@@ -16,6 +17,7 @@ const ModelSelect = ({
: `${localStorage.getItem('embeddingModelProviderId')}/${localStorage.getItem('embeddingModelKey')}`, : `${localStorage.getItem('embeddingModelProviderId')}/${localStorage.getItem('embeddingModelKey')}`,
); );
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const { setChatModelProvider, setEmbeddingModelProvider } = useChat();
const handleSave = async (newValue: string) => { const handleSave = async (newValue: string) => {
setLoading(true); setLoading(true);
@@ -23,20 +25,33 @@ const ModelSelect = ({
try { try {
if (type === 'chat') { if (type === 'chat') {
localStorage.setItem('chatModelProviderId', newValue.split('/')[0]); const providerId = newValue.split('/')[0];
localStorage.setItem( const modelKey = newValue.split('/').slice(1).join('/');
'chatModelKey',
newValue.split('/').slice(1).join('/'), localStorage.setItem('chatModelProviderId', providerId);
); localStorage.setItem('chatModelKey', modelKey);
setChatModelProvider({
providerId: providerId,
key: modelKey,
});
} else { } else {
const providerId = newValue.split('/')[0];
const modelKey = newValue.split('/').slice(1).join('/');
localStorage.setItem( localStorage.setItem(
'embeddingModelProviderId', 'embeddingModelProviderId',
newValue.split('/')[0], providerId,
); );
localStorage.setItem( localStorage.setItem(
'embeddingModelKey', 'embeddingModelKey',
newValue.split('/').slice(1).join('/'), modelKey,
); );
setEmbeddingModelProvider({
providerId: providerId,
key: modelKey,
});
} }
} catch (error) { } catch (error) {
console.error('Error saving config:', error); console.error('Error saving config:', error);

View File

@@ -48,6 +48,8 @@ type ChatContext = {
messageAppeared: boolean; messageAppeared: boolean;
isReady: boolean; isReady: boolean;
hasError: boolean; hasError: boolean;
chatModelProvider: ChatModelProvider;
embeddingModelProvider: EmbeddingModelProvider;
setOptimizationMode: (mode: string) => void; setOptimizationMode: (mode: string) => void;
setFocusMode: (mode: string) => void; setFocusMode: (mode: string) => void;
setFiles: (files: File[]) => void; setFiles: (files: File[]) => void;
@@ -58,6 +60,8 @@ type ChatContext = {
rewrite?: boolean, rewrite?: boolean,
) => Promise<void>; ) => Promise<void>;
rewrite: (messageId: string) => void; rewrite: (messageId: string) => void;
setChatModelProvider: (provider: ChatModelProvider) => void;
setEmbeddingModelProvider: (provider: EmbeddingModelProvider) => void;
}; };
export interface File { export interface File {
@@ -256,12 +260,16 @@ export const chatContext = createContext<ChatContext>({
sections: [], sections: [],
notFound: false, notFound: false,
optimizationMode: '', optimizationMode: '',
rewrite: () => {}, chatModelProvider: { key: '', providerId: '' },
sendMessage: async () => {}, embeddingModelProvider: { key: '', providerId: '' },
setFileIds: () => {}, rewrite: () => { },
setFiles: () => {}, sendMessage: async () => { },
setFocusMode: () => {}, setFileIds: () => { },
setOptimizationMode: () => {}, setFiles: () => { },
setFocusMode: () => { },
setOptimizationMode: () => { },
setChatModelProvider: () => { },
setEmbeddingModelProvider: () => { },
}); });
export const ChatProvider = ({ export const ChatProvider = ({
@@ -743,6 +751,10 @@ export const ChatProvider = ({
setOptimizationMode, setOptimizationMode,
rewrite, rewrite,
sendMessage, sendMessage,
setChatModelProvider,
chatModelProvider,
embeddingModelProvider,
setEmbeddingModelProvider,
}} }}
> >
{children} {children}