mirror of
https://github.com/ItzCrazyKns/Perplexica.git
synced 2025-10-25 16:38:16 +00:00
feat(app): fix issues with model selection
This commit is contained in:
@@ -10,15 +10,14 @@ import {
|
|||||||
} from '@headlessui/react';
|
} from '@headlessui/react';
|
||||||
import { Fragment, useEffect, useState } from 'react';
|
import { Fragment, useEffect, useState } from 'react';
|
||||||
import { MinimalProvider } from '@/lib/models/types';
|
import { MinimalProvider } from '@/lib/models/types';
|
||||||
|
import { useChat } from '@/lib/hooks/useChat';
|
||||||
|
|
||||||
const ModelSelector = () => {
|
const ModelSelector = () => {
|
||||||
const [providers, setProviders] = useState<MinimalProvider[]>([]);
|
const [providers, setProviders] = useState<MinimalProvider[]>([]);
|
||||||
const [isLoading, setIsLoading] = useState(true);
|
const [isLoading, setIsLoading] = useState(true);
|
||||||
const [searchQuery, setSearchQuery] = useState('');
|
const [searchQuery, setSearchQuery] = useState('');
|
||||||
const [selectedModel, setSelectedModel] = useState<{
|
|
||||||
providerId: string;
|
const { setChatModelProvider, chatModelProvider } = useChat();
|
||||||
modelKey: string;
|
|
||||||
} | null>(null);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const loadProviders = async () => {
|
const loadProviders = async () => {
|
||||||
@@ -30,28 +29,21 @@ const ModelSelector = () => {
|
|||||||
throw new Error('Failed to fetch providers');
|
throw new Error('Failed to fetch providers');
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await res.json();
|
const data: { providers: MinimalProvider[] } = await res.json();
|
||||||
setProviders(data.providers || []);
|
|
||||||
|
|
||||||
const savedProviderId = localStorage.getItem('chatModelProviderId');
|
const currentProviderIndex = data.providers.findIndex((p: MinimalProvider) => {
|
||||||
const savedModelKey = localStorage.getItem('chatModelKey');
|
return p.id === chatModelProvider?.providerId
|
||||||
|
})
|
||||||
|
|
||||||
if (savedProviderId && savedModelKey) {
|
if (currentProviderIndex === -1) {
|
||||||
setSelectedModel({
|
setProviders(data.providers);
|
||||||
providerId: savedProviderId,
|
return;
|
||||||
modelKey: savedModelKey,
|
|
||||||
});
|
|
||||||
} else if (data.providers && data.providers.length > 0) {
|
|
||||||
const firstProvider = data.providers.find(
|
|
||||||
(p: MinimalProvider) => p.chatModels.length > 0,
|
|
||||||
);
|
|
||||||
if (firstProvider && firstProvider.chatModels[0]) {
|
|
||||||
setSelectedModel({
|
|
||||||
providerId: firstProvider.id,
|
|
||||||
modelKey: firstProvider.chatModels[0].key,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const selectedProvider = data.providers[currentProviderIndex]
|
||||||
|
const remainingProviders = data.providers.filter((_, index) => index !== currentProviderIndex)
|
||||||
|
|
||||||
|
setProviders([selectedProvider, ...remainingProviders]);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error loading providers:', error);
|
console.error('Error loading providers:', error);
|
||||||
} finally {
|
} finally {
|
||||||
@@ -60,10 +52,10 @@ const ModelSelector = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
loadProviders();
|
loadProviders();
|
||||||
}, []);
|
}, [chatModelProvider]);
|
||||||
|
|
||||||
const handleModelSelect = (providerId: string, modelKey: string) => {
|
const handleModelSelect = (providerId: string, modelKey: string) => {
|
||||||
setSelectedModel({ providerId, modelKey });
|
setChatModelProvider({ providerId, key: modelKey });
|
||||||
localStorage.setItem('chatModelProviderId', providerId);
|
localStorage.setItem('chatModelProviderId', providerId);
|
||||||
localStorage.setItem('chatModelKey', modelKey);
|
localStorage.setItem('chatModelKey', modelKey);
|
||||||
};
|
};
|
||||||
@@ -140,15 +132,15 @@ const ModelSelector = () => {
|
|||||||
|
|
||||||
<div className="flex flex-col px-2 py-2 space-y-0.5">
|
<div className="flex flex-col px-2 py-2 space-y-0.5">
|
||||||
{provider.chatModels.map((model) => (
|
{provider.chatModels.map((model) => (
|
||||||
<PopoverButton
|
<button
|
||||||
key={model.key}
|
key={model.key}
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
handleModelSelect(provider.id, model.key)
|
handleModelSelect(provider.id, model.key)
|
||||||
}
|
}
|
||||||
className={cn(
|
className={cn(
|
||||||
'px-3 py-2 flex items-center justify-between text-start duration-200 cursor-pointer transition rounded-lg group',
|
'px-3 py-2 flex items-center justify-between text-start duration-200 cursor-pointer transition rounded-lg group',
|
||||||
selectedModel?.providerId === provider.id &&
|
chatModelProvider?.providerId === provider.id &&
|
||||||
selectedModel?.modelKey === model.key
|
chatModelProvider?.key === model.key
|
||||||
? 'bg-light-secondary dark:bg-dark-secondary'
|
? 'bg-light-secondary dark:bg-dark-secondary'
|
||||||
: 'hover:bg-light-secondary dark:hover:bg-dark-secondary',
|
: 'hover:bg-light-secondary dark:hover:bg-dark-secondary',
|
||||||
)}
|
)}
|
||||||
@@ -158,8 +150,8 @@ const ModelSelector = () => {
|
|||||||
size={15}
|
size={15}
|
||||||
className={cn(
|
className={cn(
|
||||||
'shrink-0',
|
'shrink-0',
|
||||||
selectedModel?.providerId === provider.id &&
|
chatModelProvider?.providerId === provider.id &&
|
||||||
selectedModel?.modelKey === model.key
|
chatModelProvider?.key === model.key
|
||||||
? 'text-sky-500'
|
? 'text-sky-500'
|
||||||
: 'text-black/50 dark:text-white/50 group-hover:text-black/70 group-hover:dark:text-white/70',
|
: 'text-black/50 dark:text-white/50 group-hover:text-black/70 group-hover:dark:text-white/70',
|
||||||
)}
|
)}
|
||||||
@@ -167,8 +159,8 @@ const ModelSelector = () => {
|
|||||||
<p
|
<p
|
||||||
className={cn(
|
className={cn(
|
||||||
'text-sm truncate',
|
'text-sm truncate',
|
||||||
selectedModel?.providerId === provider.id &&
|
chatModelProvider?.providerId === provider.id &&
|
||||||
selectedModel?.modelKey === model.key
|
chatModelProvider?.key === model.key
|
||||||
? 'text-sky-500 font-medium'
|
? 'text-sky-500 font-medium'
|
||||||
: 'text-black/70 dark:text-white/70 group-hover:text-black dark:group-hover:text-white',
|
: 'text-black/70 dark:text-white/70 group-hover:text-black dark:group-hover:text-white',
|
||||||
)}
|
)}
|
||||||
@@ -176,7 +168,7 @@ const ModelSelector = () => {
|
|||||||
{model.name}
|
{model.name}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</PopoverButton>
|
</button>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import Select from '@/components/ui/Select';
|
import Select from '@/components/ui/Select';
|
||||||
import { ConfigModelProvider } from '@/lib/config/types';
|
import { ConfigModelProvider } from '@/lib/config/types';
|
||||||
|
import { useChat } from '@/lib/hooks/useChat';
|
||||||
import { useState } from 'react';
|
import { useState } from 'react';
|
||||||
import { toast } from 'sonner';
|
import { toast } from 'sonner';
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ const ModelSelect = ({
|
|||||||
: `${localStorage.getItem('embeddingModelProviderId')}/${localStorage.getItem('embeddingModelKey')}`,
|
: `${localStorage.getItem('embeddingModelProviderId')}/${localStorage.getItem('embeddingModelKey')}`,
|
||||||
);
|
);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
|
const { setChatModelProvider, setEmbeddingModelProvider } = useChat();
|
||||||
|
|
||||||
const handleSave = async (newValue: string) => {
|
const handleSave = async (newValue: string) => {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
@@ -23,20 +25,33 @@ const ModelSelect = ({
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
if (type === 'chat') {
|
if (type === 'chat') {
|
||||||
localStorage.setItem('chatModelProviderId', newValue.split('/')[0]);
|
const providerId = newValue.split('/')[0];
|
||||||
localStorage.setItem(
|
const modelKey = newValue.split('/').slice(1).join('/');
|
||||||
'chatModelKey',
|
|
||||||
newValue.split('/').slice(1).join('/'),
|
localStorage.setItem('chatModelProviderId', providerId);
|
||||||
);
|
localStorage.setItem('chatModelKey', modelKey);
|
||||||
|
|
||||||
|
setChatModelProvider({
|
||||||
|
providerId: providerId,
|
||||||
|
key: modelKey,
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
|
const providerId = newValue.split('/')[0];
|
||||||
|
const modelKey = newValue.split('/').slice(1).join('/');
|
||||||
|
|
||||||
localStorage.setItem(
|
localStorage.setItem(
|
||||||
'embeddingModelProviderId',
|
'embeddingModelProviderId',
|
||||||
newValue.split('/')[0],
|
providerId,
|
||||||
);
|
);
|
||||||
localStorage.setItem(
|
localStorage.setItem(
|
||||||
'embeddingModelKey',
|
'embeddingModelKey',
|
||||||
newValue.split('/').slice(1).join('/'),
|
modelKey,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
setEmbeddingModelProvider({
|
||||||
|
providerId: providerId,
|
||||||
|
key: modelKey,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error saving config:', error);
|
console.error('Error saving config:', error);
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ type ChatContext = {
|
|||||||
messageAppeared: boolean;
|
messageAppeared: boolean;
|
||||||
isReady: boolean;
|
isReady: boolean;
|
||||||
hasError: boolean;
|
hasError: boolean;
|
||||||
|
chatModelProvider: ChatModelProvider;
|
||||||
|
embeddingModelProvider: EmbeddingModelProvider;
|
||||||
setOptimizationMode: (mode: string) => void;
|
setOptimizationMode: (mode: string) => void;
|
||||||
setFocusMode: (mode: string) => void;
|
setFocusMode: (mode: string) => void;
|
||||||
setFiles: (files: File[]) => void;
|
setFiles: (files: File[]) => void;
|
||||||
@@ -58,6 +60,8 @@ type ChatContext = {
|
|||||||
rewrite?: boolean,
|
rewrite?: boolean,
|
||||||
) => Promise<void>;
|
) => Promise<void>;
|
||||||
rewrite: (messageId: string) => void;
|
rewrite: (messageId: string) => void;
|
||||||
|
setChatModelProvider: (provider: ChatModelProvider) => void;
|
||||||
|
setEmbeddingModelProvider: (provider: EmbeddingModelProvider) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export interface File {
|
export interface File {
|
||||||
@@ -256,12 +260,16 @@ export const chatContext = createContext<ChatContext>({
|
|||||||
sections: [],
|
sections: [],
|
||||||
notFound: false,
|
notFound: false,
|
||||||
optimizationMode: '',
|
optimizationMode: '',
|
||||||
rewrite: () => {},
|
chatModelProvider: { key: '', providerId: '' },
|
||||||
sendMessage: async () => {},
|
embeddingModelProvider: { key: '', providerId: '' },
|
||||||
setFileIds: () => {},
|
rewrite: () => { },
|
||||||
setFiles: () => {},
|
sendMessage: async () => { },
|
||||||
setFocusMode: () => {},
|
setFileIds: () => { },
|
||||||
setOptimizationMode: () => {},
|
setFiles: () => { },
|
||||||
|
setFocusMode: () => { },
|
||||||
|
setOptimizationMode: () => { },
|
||||||
|
setChatModelProvider: () => { },
|
||||||
|
setEmbeddingModelProvider: () => { },
|
||||||
});
|
});
|
||||||
|
|
||||||
export const ChatProvider = ({
|
export const ChatProvider = ({
|
||||||
@@ -743,6 +751,10 @@ export const ChatProvider = ({
|
|||||||
setOptimizationMode,
|
setOptimizationMode,
|
||||||
rewrite,
|
rewrite,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
|
setChatModelProvider,
|
||||||
|
chatModelProvider,
|
||||||
|
embeddingModelProvider,
|
||||||
|
setEmbeddingModelProvider,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
|
|||||||
Reference in New Issue
Block a user