From 6220822c7c5b9727d8a82c5a35c93e738f9ccb09 Mon Sep 17 00:00:00 2001 From: Willie Zutz Date: Mon, 5 May 2025 00:05:19 -0600 Subject: [PATCH] feat(app): Allow selecting the AI model at any time without opening the settings page. Allow changing focus mode at any time while chatting. Styling tweaks. --- src/components/Chat.tsx | 12 +- src/components/ChatWindow.tsx | 15 +- src/components/EmptyChat.tsx | 6 +- src/components/EmptyChatMessageInput.tsx | 114 ------- src/components/MessageInput.tsx | 178 +++++----- src/components/MessageInputActions/Attach.tsx | 8 +- src/components/MessageInputActions/Focus.tsx | 6 +- .../MessageInputActions/ModelSelector.tsx | 305 ++++++++++++++++++ .../MessageInputActions/Optimization.tsx | 4 +- src/lib/providers/index.ts | 18 +- 10 files changed, 430 insertions(+), 236 deletions(-) delete mode 100644 src/components/EmptyChatMessageInput.tsx create mode 100644 src/components/MessageInputActions/ModelSelector.tsx diff --git a/src/components/Chat.tsx b/src/components/Chat.tsx index 785e928..fdedc0d 100644 --- a/src/components/Chat.tsx +++ b/src/components/Chat.tsx @@ -19,6 +19,8 @@ const Chat = ({ setFiles, optimizationMode, setOptimizationMode, + focusMode, + setFocusMode, }: { messages: Message[]; sendMessage: ( @@ -38,13 +40,15 @@ const Chat = ({ setFiles: (files: File[]) => void; optimizationMode: string; setOptimizationMode: (mode: string) => void; + focusMode: string; + setFocusMode: (mode: string) => void; }) => { const [dividerWidth, setDividerWidth] = useState(0); const [isAtBottom, setIsAtBottom] = useState(true); const [manuallyScrolledUp, setManuallyScrolledUp] = useState(false); const dividerRef = useRef(null); const messageEnd = useRef(null); - const SCROLL_THRESHOLD = 200; // pixels from bottom to consider "at bottom" + const SCROLL_THRESHOLD = 250; // pixels from bottom to consider "at bottom" // Check if user is at bottom of page useEffect(() => { @@ -146,7 +150,6 @@ const Chat = ({ const position = window.innerHeight + window.scrollY; const height = document.body.scrollHeight; const atBottom = position >= height - SCROLL_THRESHOLD; - console.log('scrollTrigger', scrollTrigger); setIsAtBottom(atBottom); if (isAtBottom && !manuallyScrolledUp && messages.length > 0) { @@ -155,7 +158,7 @@ const Chat = ({ }, [scrollTrigger, isAtBottom, messages.length, manuallyScrolledUp]); return ( -
+
{messages.map((msg, i) => { const isLast = i === messages.length - 1; @@ -217,6 +220,7 @@ const Chat = ({ )}
)} diff --git a/src/components/ChatWindow.tsx b/src/components/ChatWindow.tsx index abb220a..4c706cb 100644 --- a/src/components/ChatWindow.tsx +++ b/src/components/ChatWindow.tsx @@ -531,6 +531,15 @@ const ChatWindow = ({ id }: { id?: string }) => { const ollamaContextWindow = localStorage.getItem('ollamaContextWindow') || '2048'; + // Get the latest model selection from localStorage + const currentChatModelProvider = localStorage.getItem('chatModelProvider'); + const currentChatModel = localStorage.getItem('chatModel'); + + // Use the most current model selection from localStorage, falling back to the state if not available + const modelProvider = + currentChatModelProvider || chatModelProvider.provider; + const modelName = currentChatModel || chatModelProvider.name; + const res = await fetch('/api/chat', { method: 'POST', headers: { @@ -549,8 +558,8 @@ const ChatWindow = ({ id }: { id?: string }) => { optimizationMode: optimizationMode, history: messageChatHistory, chatModel: { - name: chatModelProvider.name, - provider: chatModelProvider.provider, + name: modelName, + provider: modelProvider, ...(chatModelProvider.provider === 'ollama' && { ollamaContextWindow: parseInt(ollamaContextWindow), }), @@ -645,6 +654,8 @@ const ChatWindow = ({ id }: { id?: string }) => { setFiles={setFiles} optimizationMode={optimizationMode} setOptimizationMode={setOptimizationMode} + focusMode={focusMode} + setFocusMode={setFocusMode} /> ) : ( diff --git a/src/components/EmptyChat.tsx b/src/components/EmptyChat.tsx index 838849f..78b3275 100644 --- a/src/components/EmptyChat.tsx +++ b/src/components/EmptyChat.tsx @@ -1,8 +1,8 @@ import { Settings } from 'lucide-react'; -import EmptyChatMessageInput from './EmptyChatMessageInput'; import { useState } from 'react'; import { File } from './ChatWindow'; import Link from 'next/link'; +import MessageInput from './MessageInput'; const EmptyChat = ({ sendMessage, @@ -38,7 +38,9 @@ const EmptyChat = ({

Research begins here.

- void; - focusMode: string; - setFocusMode: (mode: string) => void; - optimizationMode: string; - setOptimizationMode: (mode: string) => void; - fileIds: string[]; - setFileIds: (fileIds: string[]) => void; - files: File[]; - setFiles: (files: File[]) => void; -}) => { - const [copilotEnabled, setCopilotEnabled] = useState(false); - const [message, setMessage] = useState(''); - - const inputRef = useRef(null); - - useEffect(() => { - const handleKeyDown = (e: KeyboardEvent) => { - const activeElement = document.activeElement; - - const isInputFocused = - activeElement?.tagName === 'INPUT' || - activeElement?.tagName === 'TEXTAREA' || - activeElement?.hasAttribute('contenteditable'); - - if (e.key === '/' && !isInputFocused) { - e.preventDefault(); - inputRef.current?.focus(); - } - }; - - document.addEventListener('keydown', handleKeyDown); - - inputRef.current?.focus(); - - return () => { - document.removeEventListener('keydown', handleKeyDown); - }; - }, []); - - return ( -
{ - e.preventDefault(); - sendMessage(message); - setMessage(''); - }} - onKeyDown={(e) => { - if (e.key === 'Enter' && !e.shiftKey) { - e.preventDefault(); - sendMessage(message); - setMessage(''); - } - }} - className="w-full" - > -
- setMessage(e.target.value)} - minRows={2} - className="bg-transparent placeholder:text-black/50 dark:placeholder:text-white/50 text-sm text-black dark:text-white resize-none focus:outline-none w-full max-h-24 lg:max-h-36 xl:max-h-48" - placeholder="Ask anything..." - /> -
-
- - -
-
- - -
-
-
-
- ); -}; - -export default EmptyChatMessageInput; diff --git a/src/components/MessageInput.tsx b/src/components/MessageInput.tsx index a72381a..3303dea 100644 --- a/src/components/MessageInput.tsx +++ b/src/components/MessageInput.tsx @@ -1,12 +1,11 @@ -import { cn } from '@/lib/utils'; -import { ArrowUp } from 'lucide-react'; +import { ArrowRight, ArrowUp } from 'lucide-react'; import { useEffect, useRef, useState } from 'react'; import TextareaAutosize from 'react-textarea-autosize'; -import Attach from './MessageInputActions/Attach'; -import CopilotToggle from './MessageInputActions/Copilot'; -import Optimization from './MessageInputActions/Optimization'; import { File } from './ChatWindow'; -import AttachSmall from './MessageInputActions/AttachSmall'; +import Attach from './MessageInputActions/Attach'; +import Focus from './MessageInputActions/Focus'; +import ModelSelector from './MessageInputActions/ModelSelector'; +import Optimization from './MessageInputActions/Optimization'; const MessageInput = ({ sendMessage, @@ -17,6 +16,9 @@ const MessageInput = ({ setFiles, optimizationMode, setOptimizationMode, + focusMode, + setFocusMode, + firstMessage, }: { sendMessage: (message: string) => void; loading: boolean; @@ -26,19 +28,28 @@ const MessageInput = ({ setFiles: (files: File[]) => void; optimizationMode: string; setOptimizationMode: (mode: string) => void; + focusMode: string; + setFocusMode: (mode: string) => void; + firstMessage: boolean; }) => { - const [copilotEnabled, setCopilotEnabled] = useState(false); const [message, setMessage] = useState(''); - const [textareaRows, setTextareaRows] = useState(1); - const [mode, setMode] = useState<'multi' | 'single'>('single'); + const [selectedModel, setSelectedModel] = useState<{ + provider: string; + model: string; + } | null>(null); useEffect(() => { - if (textareaRows >= 2 && message && mode === 'single') { - setMode('multi'); - } else if (!message && mode === 'multi') { - setMode('single'); + // Load saved model preferences from localStorage + const chatModelProvider = localStorage.getItem('chatModelProvider'); + const chatModel = localStorage.getItem('chatModel'); + + if (chatModelProvider && chatModel) { + setSelectedModel({ + provider: chatModelProvider, + model: chatModel, + }); } - }, [textareaRows, mode, message]); + }, []); const inputRef = useRef(null); @@ -60,117 +71,74 @@ const MessageInput = ({ }; }, []); - return ( + // Function to handle message submission + const handleSubmitMessage = () => { + // Only submit if we have a non-empty message and not currently loading + if (loading || message.trim().length === 0) return; + + // Make sure the selected model is used when sending a message + if (selectedModel) { + localStorage.setItem('chatModelProvider', selectedModel.provider); + localStorage.setItem('chatModel', selectedModel.model); + } + + sendMessage(message); + setMessage(''); + }; + + return (
{ - if (loading) return; e.preventDefault(); - sendMessage(message); - setMessage(''); + handleSubmitMessage(); }} onKeyDown={(e) => { - if (e.key === 'Enter' && !e.shiftKey && !loading) { + if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); - sendMessage(message); - setMessage(''); + handleSubmitMessage(); } }} - className={cn( - 'bg-light-secondary dark:bg-dark-secondary p-4 flex items-center border border-light-200 dark:border-dark-200', - mode === 'multi' - ? 'flex-col rounded-lg' - : 'flex-col md:flex-row rounded-lg md:rounded-full', - )} + className="w-full" > - {mode === 'single' && ( -
-
- - -
-
- -
-
- )} -
+
setMessage(e.target.value)} - onHeightChange={(height, props) => { - setTextareaRows(Math.ceil(height / props.rowHeight)); - }} - className="transition bg-transparent dark:placeholder:text-white/50 placeholder:text-sm text-sm dark:text-white resize-none focus:outline-none w-full px-2 max-h-24 lg:max-h-36 xl:max-h-48 flex-grow flex-shrink" - placeholder="Ask a follow-up" + minRows={2} + className="bg-transparent placeholder:text-black/50 dark:placeholder:text-white/50 text-sm text-black dark:text-white resize-none focus:outline-none w-full max-h-24 lg:max-h-36 xl:max-h-48" + placeholder={firstMessage ? "Ask anything..." :"Ask a follow-up"} /> - {mode === 'single' && ( -
-
- -
- +
+
+ + +
- )} -
- - {mode === 'multi' && ( -
-
-
- - -
-
- -
-
-
-
- -
+
+
- )} +
); }; diff --git a/src/components/MessageInputActions/Attach.tsx b/src/components/MessageInputActions/Attach.tsx index 7e2f7f2..cf3cccd 100644 --- a/src/components/MessageInputActions/Attach.tsx +++ b/src/components/MessageInputActions/Attach.tsx @@ -5,7 +5,7 @@ import { PopoverPanel, Transition, } from '@headlessui/react'; -import { CopyPlus, File, LoaderCircle, Plus, Trash } from 'lucide-react'; +import { File, LoaderCircle, Paperclip, Plus, Trash } from 'lucide-react'; import { Fragment, useRef, useState } from 'react'; import { File as FileType } from '../ChatWindow'; @@ -176,8 +176,10 @@ const Attach = ({ multiple hidden /> - - {showText &&

Attach

} + + {showText && ( +

Attach

+ )} ); }; diff --git a/src/components/MessageInputActions/Focus.tsx b/src/components/MessageInputActions/Focus.tsx index 09d97ac..70db951 100644 --- a/src/components/MessageInputActions/Focus.tsx +++ b/src/components/MessageInputActions/Focus.tsx @@ -93,13 +93,13 @@ const Focus = ({ - +
{focusModes.map((mode, i) => ( void; +}) => { + const [providerModels, setProviderModels] = useState({}); + const [providersList, setProvidersList] = useState([]); + const [loading, setLoading] = useState(true); + const [selectedModelDisplay, setSelectedModelDisplay] = useState(''); + const [selectedProviderDisplay, setSelectedProviderDisplay] = + useState(''); + const [expandedProviders, setExpandedProviders] = useState< + Record + >({}); + + useEffect(() => { + const fetchModels = async () => { + try { + const response = await fetch('/api/models', { + headers: { + 'Content-Type': 'application/json', + }, + }); + + if (!response.ok) { + throw new Error(`Failed to fetch models: ${response.status}`); + } + + const data = await response.json(); + const providersData: ProviderModelMap = {}; + + // Organize models by provider + Object.entries(data.chatModelProviders).forEach( + ([provider, models]: [string, any]) => { + const providerDisplayName = + provider.charAt(0).toUpperCase() + provider.slice(1); + providersData[provider] = { + displayName: providerDisplayName, + models: [], + }; + + Object.entries(models).forEach( + ([modelKey, modelData]: [string, any]) => { + providersData[provider].models.push({ + provider, + model: modelKey, + displayName: modelData.displayName || modelKey, + }); + }, + ); + }, + ); + + // Filter out providers with no models + Object.keys(providersData).forEach((provider) => { + if (providersData[provider].models.length === 0) { + delete providersData[provider]; + } + }); + + // Sort providers by name (only those that have models) + const sortedProviders = Object.keys(providersData).sort(); + setProvidersList(sortedProviders); + + // Initialize expanded state for all providers + const initialExpandedState: Record = {}; + sortedProviders.forEach((provider) => { + initialExpandedState[provider] = selectedModel?.provider === provider; + }); + + // Expand the first provider if none is selected + if (sortedProviders.length > 0 && !selectedModel) { + initialExpandedState[sortedProviders[0]] = true; + } + + setExpandedProviders(initialExpandedState); + setProviderModels(providersData); + + // Find the current model in our options to display its name + if (selectedModel) { + const provider = providersData[selectedModel.provider]; + if (provider) { + const currentModel = provider.models.find( + (option) => option.model === selectedModel.model, + ); + + if (currentModel) { + setSelectedModelDisplay(currentModel.displayName); + setSelectedProviderDisplay(provider.displayName); + } + } + } + + setLoading(false); + } catch (error) { + console.error('Error fetching models:', error); + setLoading(false); + } + }; + + fetchModels(); + }, [selectedModel, setSelectedModel]); + + const toggleProviderExpanded = (provider: string) => { + setExpandedProviders((prev) => ({ + ...prev, + [provider]: !prev[provider], + })); + }; + + const handleSelectModel = (option: ModelOption) => { + setSelectedModel({ + provider: option.provider, + model: option.model, + }); + + setSelectedModelDisplay(option.displayName); + setSelectedProviderDisplay( + providerModels[option.provider]?.displayName || option.provider, + ); + + // Save to localStorage for persistence + localStorage.setItem('chatModelProvider', option.provider); + localStorage.setItem('chatModel', option.model); + }; + + const getDisplayText = () => { + if (loading) return 'Loading...'; + if (!selectedModelDisplay) return 'Select model'; + + return `${selectedModelDisplay} (${selectedProviderDisplay})`; + }; + + return ( + + {({ open }) => ( + <> +
+ + + + {getDisplayText()} + + + +
+ + + +
+
+

+ Select Model +

+

+ Choose a provider and model for your conversation +

+
+
+ {loading ? ( +
+ Loading available models... +
+ ) : providersList.length === 0 ? ( +
+ No models available +
+ ) : ( +
+ {providersList.map((providerKey) => { + const provider = providerModels[providerKey]; + const isExpanded = expandedProviders[providerKey]; + + return ( +
+ {/* Provider header */} + + + {/* Models list */} + {isExpanded && ( +
+ {provider.models.map((modelOption) => ( + + ))} +
+ )} +
+ ); + })} +
+ )} +
+
+
+
+ + )} +
+ ); +}; + +export default ModelSelector; diff --git a/src/components/MessageInputActions/Optimization.tsx b/src/components/MessageInputActions/Optimization.tsx index 0126f53..0f88639 100644 --- a/src/components/MessageInputActions/Optimization.tsx +++ b/src/components/MessageInputActions/Optimization.tsx @@ -56,12 +56,12 @@ const Optimization = ({ OptimizationModes.find((mode) => mode.key === optimizationMode) ?.icon } -

+ {/*

{ OptimizationModes.find((mode) => mode.key === optimizationMode) ?.title } -

+

*/}
diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index e536431..04ad82f 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -96,7 +96,14 @@ export const getAvailableChatModelProviders = async () => { for (const provider in chatModelProviders) { const providerModels = await chatModelProviders[provider](); if (Object.keys(providerModels).length > 0) { - models[provider] = providerModels; + // Sort models alphabetically by their keys + const sortedModels: Record = {}; + Object.keys(providerModels) + .sort() + .forEach((key) => { + sortedModels[key] = providerModels[key]; + }); + models[provider] = sortedModels; } } @@ -131,7 +138,14 @@ export const getAvailableEmbeddingModelProviders = async () => { for (const provider in embeddingModelProviders) { const providerModels = await embeddingModelProviders[provider](); if (Object.keys(providerModels).length > 0) { - models[provider] = providerModels; + // Sort embedding models alphabetically by their keys + const sortedModels: Record = {}; + Object.keys(providerModels) + .sort() + .forEach((key) => { + sortedModels[key] = providerModels[key]; + }); + models[provider] = sortedModels; } }