diff --git a/src/routes/images.ts b/src/routes/images.ts
index 5671657..2e8e912 100644
--- a/src/routes/images.ts
+++ b/src/routes/images.ts
@@ -5,6 +5,7 @@ import { getAvailableChatModelProviders } from '../lib/providers';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
 import { ChatOpenAI } from '@langchain/openai';
+import { ChatOllama } from '@langchain/community/chat_models/ollama';
 import {
   getCustomOpenaiApiKey,
   getCustomOpenaiApiUrl,
@@ -16,6 +17,7 @@ const router = express.Router();
 interface ChatModel {
   provider: string;
   model: string;
+  ollamaContextWindow?: number;
 }
 
 interface ImageSearchBody {
@@ -61,6 +63,10 @@ router.post('/', async (req, res) => {
     ) {
       llm = chatModelProviders[chatModelProvider][chatModel]
         .model as unknown as BaseChatModel | undefined;
+
+      if (llm instanceof ChatOllama) {
+        llm.numCtx = body.chatModel?.ollamaContextWindow || 2048;
+      }
     }
 
     if (!llm) {
diff --git a/src/routes/search.ts b/src/routes/search.ts
index 57d90a3..daefece 100644
--- a/src/routes/search.ts
+++ b/src/routes/search.ts
@@ -15,12 +15,14 @@ import {
   getCustomOpenaiApiUrl,
   getCustomOpenaiModelName,
 } from '../config';
+import { ChatOllama } from '@langchain/community/chat_models/ollama';
 
 const router = express.Router();
 
 interface chatModel {
   provider: string;
   model: string;
+  ollamaContextWindow?: number;
   customOpenAIKey?: string;
   customOpenAIBaseURL?: string;
 }
@@ -78,6 +80,7 @@ router.post('/', async (req, res) => {
     const embeddingModel =
       body.embeddingModel?.model ||
       Object.keys(embeddingModelProviders[embeddingModelProvider])[0];
+    const ollamaContextWindow = body.chatModel?.ollamaContextWindow || 2048;
 
     let llm: BaseChatModel | undefined;
     let embeddings: Embeddings | undefined;
@@ -99,6 +102,9 @@ router.post('/', async (req, res) => {
     ) {
       llm = chatModelProviders[chatModelProvider][chatModel]
         .model as unknown as BaseChatModel | undefined;
+      if (llm instanceof ChatOllama) {
+        llm.numCtx = ollamaContextWindow;
+      }
     }
 
     if (
diff --git a/src/routes/suggestions.ts b/src/routes/suggestions.ts
index 7dd1739..c7a1409 100644
--- a/src/routes/suggestions.ts
+++ b/src/routes/suggestions.ts
@@ -10,12 +10,14 @@ import {
   getCustomOpenaiApiUrl,
   getCustomOpenaiModelName,
 } from '../config';
+import { ChatOllama } from '@langchain/community/chat_models/ollama';
 
 const router = express.Router();
 
 interface ChatModel {
   provider: string;
   model: string;
+  ollamaContextWindow?: number;
 }
 
 interface SuggestionsBody {
@@ -60,6 +62,9 @@ router.post('/', async (req, res) => {
     ) {
       llm = chatModelProviders[chatModelProvider][chatModel]
         .model as unknown as BaseChatModel | undefined;
+      if (llm instanceof ChatOllama) {
+        llm.numCtx = body.chatModel?.ollamaContextWindow || 2048;
+      }
     }
 
     if (!llm) {
diff --git a/src/routes/videos.ts b/src/routes/videos.ts
index b631f26..debe3cd 100644
--- a/src/routes/videos.ts
+++ b/src/routes/videos.ts
@@ -10,12 +10,14 @@ import {
   getCustomOpenaiApiUrl,
   getCustomOpenaiModelName,
 } from '../config';
+import { ChatOllama } from '@langchain/community/chat_models/ollama';
 
 const router = express.Router();
 
 interface ChatModel {
   provider: string;
   model: string;
+  ollamaContextWindow?: number;
 }
 
 interface VideoSearchBody {
@@ -61,6 +63,10 @@ router.post('/', async (req, res) => {
     ) {
       llm = chatModelProviders[chatModelProvider][chatModel]
         .model as unknown as BaseChatModel | undefined;
+
+      if (llm instanceof ChatOllama) {
+        llm.numCtx = body.chatModel?.ollamaContextWindow || 2048;
+      }
     }
 
     if (!llm) {
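The four routes above repeat the same guard after resolving the model, and the WebSocket handler below repeats it a fifth time. A shared helper along these lines could remove the duplication — a sketch only; `applyOllamaContextWindow` is hypothetical and not part of this diff:

```ts
import { ChatOllama } from '@langchain/community/chat_models/ollama';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';

// Hypothetical helper, not in the diff: numCtx maps to Ollama's num_ctx
// option (the prompt context length in tokens); 2048 mirrors Ollama's default.
const applyOllamaContextWindow = (
  llm: BaseChatModel | undefined,
  contextWindow?: number,
): void => {
  if (llm instanceof ChatOllama) {
    llm.numCtx = contextWindow || 2048;
  }
};
```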
diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts
index bb8f242..979b8a0 100644
--- a/src/websocket/connectionManager.ts
+++ b/src/websocket/connectionManager.ts
@@ -14,6 +14,7 @@ import {
   getCustomOpenaiApiUrl,
   getCustomOpenaiModelName,
 } from '../config';
+import { ChatOllama } from '@langchain/community/chat_models/ollama';
 
 export const handleConnection = async (
   ws: WebSocket,
@@ -42,6 +43,8 @@ export const handleConnection = async (
       searchParams.get('embeddingModel') ||
       Object.keys(embeddingModelProviders[embeddingModelProvider])[0];
 
+    const ollamaContextWindow = searchParams.get('ollamaContextWindow');
+
     let llm: BaseChatModel | undefined;
     let embeddings: Embeddings | undefined;
 
@@ -52,6 +55,9 @@ export const handleConnection = async (
     ) {
       llm = chatModelProviders[chatModelProvider][chatModel]
         .model as unknown as BaseChatModel | undefined;
+      if (llm instanceof ChatOllama) {
+        llm.numCtx = ollamaContextWindow ? parseInt(ollamaContextWindow) : 2048;
+      }
     } else if (chatModelProvider == 'custom_openai') {
       const customOpenaiApiKey = getCustomOpenaiApiKey();
       const customOpenaiApiUrl = getCustomOpenaiApiUrl();
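Unlike the HTTP routes, the WebSocket handler reads the value from the connection URL, so it arrives as a string and `parseInt` yields `NaN` for a malformed parameter, which would then be assigned to `numCtx`. A stricter parse would fall back to the default instead — hypothetical, not in the diff, which keeps the plain `parseInt`:

```ts
// Hypothetical stricter parse: rejects NaN and values below the 512-token
// minimum that the settings UI enforces, falling back to 2048.
const parseContextWindow = (raw: string | null): number => {
  const parsed = raw === null ? NaN : parseInt(raw, 10);
  return Number.isFinite(parsed) && parsed >= 512 ? parsed : 2048;
};
```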
diff --git a/ui/app/settings/page.tsx b/ui/app/settings/page.tsx
index 371d091..26cdbd6 100644
--- a/ui/app/settings/page.tsx
+++ b/ui/app/settings/page.tsx
@@ -23,6 +23,7 @@ interface SettingsType {
   customOpenaiApiKey: string;
   customOpenaiApiUrl: string;
   customOpenaiModelName: string;
+  ollamaContextWindow: number;
 }
 
 interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
@@ -112,6 +113,11 @@ const Page = () => {
   const [automaticImageSearch, setAutomaticImageSearch] = useState(false);
   const [automaticVideoSearch, setAutomaticVideoSearch] = useState(false);
   const [savingStates, setSavingStates] = useState<Record<string, boolean>>({});
+  const [contextWindowSize, setContextWindowSize] = useState(2048);
+  const [isCustomContextWindow, setIsCustomContextWindow] = useState(false);
+  const predefinedContextSizes = [
+    1024, 2048, 3072, 4096, 8192, 16384, 32768, 65536, 131072,
+  ];
 
   useEffect(() => {
     const fetchConfig = async () => {
@@ -123,6 +129,7 @@ const Page = () => {
       });
 
       const data = (await res.json()) as SettingsType;
+
       setConfig(data);
 
       const chatModelProvidersKeys = Object.keys(data.chatModelProviders || {});
@@ -171,6 +178,13 @@ const Page = () => {
       setAutomaticVideoSearch(
         localStorage.getItem('autoVideoSearch') === 'true',
       );
+      const storedContextWindow = parseInt(
+        localStorage.getItem('ollamaContextWindow') ?? '2048',
+      );
+      setContextWindowSize(storedContextWindow);
+      setIsCustomContextWindow(
+        !predefinedContextSizes.includes(storedContextWindow),
+      );
 
       setIsLoading(false);
     };
@@ -331,6 +345,8 @@ const Page = () => {
         localStorage.setItem('embeddingModelProvider', value);
       } else if (key === 'embeddingModel') {
         localStorage.setItem('embeddingModel', value);
+      } else if (key === 'ollamaContextWindow') {
+        localStorage.setItem('ollamaContextWindow', value.toString());
       }
     } catch (err) {
       console.error('Failed to save:', err);
@@ -548,6 +564,78 @@ const Page = () => {
                       ];
                     })()}
                   />
+                  {selectedChatModelProvider === 'ollama' && (
+                    <div className="flex flex-col space-y-1">
+                      <p className="text-black/70 dark:text-white/70 text-sm">
+                        Chat Context Window Size
+                      </p>
+                      <Select
+                        value={
+                          isCustomContextWindow
+                            ? 'custom'
+                            : contextWindowSize.toString()
+                        }
+                        onChange={(e) => {
+                          const value = e.target.value;
+                          if (value === 'custom') {
+                            setIsCustomContextWindow(true);
+                          } else {
+                            setIsCustomContextWindow(false);
+                            const numValue = parseInt(value);
+                            setContextWindowSize(numValue);
+                            setConfig((prev) => ({
+                              ...prev!,
+                              ollamaContextWindow: numValue,
+                            }));
+                            saveConfig('ollamaContextWindow', numValue);
+                          }
+                        }}
+                        options={[
+                          ...predefinedContextSizes.map((size) => ({
+                            value: size.toString(),
+                            label: `${size} tokens`,
+                          })),
+                          { value: 'custom', label: 'Custom...' },
+                        ]}
+                      />
+                      {isCustomContextWindow && (
+                        <div className="mt-2">
+                          <Input
+                            type="number"
+                            min={512}
+                            value={contextWindowSize}
+                            placeholder="Custom context window size"
+                            isSaving={savingStates['ollamaContextWindow']}
+                            onChange={(e) => {
+                              // Allow any value to be typed
+                              const value =
+                                parseInt(e.target.value) ||
+                                contextWindowSize;
+                              setContextWindowSize(value);
+                            }}
+                            onSave={(value) => {
+                              // Validate only when saving
+                              const numValue = Math.max(
+                                512,
+                                parseInt(value) || 2048,
+                              );
+                              setContextWindowSize(numValue);
+                              setConfig((prev) => ({
+                                ...prev!,
+                                ollamaContextWindow: numValue,
+                              }));
+                              saveConfig('ollamaContextWindow', numValue);
+                            }}
+                          />
+                        </div>
+                      )}
+                      <p className="text-xs text-black/60 dark:text-white/60 mt-0.5">
+                        {isCustomContextWindow
+                          ? 'Adjust the context window size for Ollama models (minimum 512 tokens)'
+                          : 'Adjust the context window size for Ollama models'}
+                      </p>
+                    </div>
+                  )}
                   </div>
                 )}
 
diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx
index 1940f42..7151383 100644
--- a/ui/components/ChatWindow.tsx
+++ b/ui/components/ChatWindow.tsx
@@ -197,6 +197,11 @@ const useSocket = (
         'openAIBaseURL',
         localStorage.getItem('openAIBaseURL')!,
       );
+    } else {
+      searchParams.append(
+        'ollamaContextWindow',
+        localStorage.getItem('ollamaContextWindow') || '2048',
+      );
     }
 
     searchParams.append('embeddingModel', embeddingModel!);
diff --git a/ui/components/SearchImages.tsx b/ui/components/SearchImages.tsx
index 383f780..dfd387e 100644
--- a/ui/components/SearchImages.tsx
+++ b/ui/components/SearchImages.tsx
@@ -33,9 +33,10 @@ const SearchImages = ({
 
       const chatModelProvider = localStorage.getItem('chatModelProvider');
       const chatModel = localStorage.getItem('chatModel');
-
       const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
       const customOpenAIKey = localStorage.getItem('openAIApiKey');
+      const ollamaContextWindow =
+        localStorage.getItem('ollamaContextWindow') || '2048';
 
       const res = await fetch(
         `${process.env.NEXT_PUBLIC_API_URL}/images`,
@@ -54,6 +55,9 @@ const SearchImages = ({
               customOpenAIBaseURL: customOpenAIBaseURL,
               customOpenAIKey: customOpenAIKey,
             }),
+            ...(chatModelProvider === 'ollama' && {
+              ollamaContextWindow: parseInt(ollamaContextWindow),
+            }),
           },
         }),
       },
diff --git a/ui/components/SearchVideos.tsx b/ui/components/SearchVideos.tsx
index c284dc2..00e8f2e 100644
--- a/ui/components/SearchVideos.tsx
+++ b/ui/components/SearchVideos.tsx
@@ -48,9 +48,10 @@ const Searchvideos = ({
 
       const chatModelProvider = localStorage.getItem('chatModelProvider');
       const chatModel = localStorage.getItem('chatModel');
-
       const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
       const customOpenAIKey = localStorage.getItem('openAIApiKey');
+      const ollamaContextWindow =
+        localStorage.getItem('ollamaContextWindow') || '2048';
 
       const res = await fetch(
         `${process.env.NEXT_PUBLIC_API_URL}/videos`,
@@ -69,6 +70,9 @@ const Searchvideos = ({
               customOpenAIBaseURL: customOpenAIBaseURL,
               customOpenAIKey: customOpenAIKey,
             }),
+            ...(chatModelProvider === 'ollama' && {
+              ollamaContextWindow: parseInt(ollamaContextWindow),
+            }),
           },
         }),
       },
diff --git a/ui/lib/actions.ts b/ui/lib/actions.ts
index a4409b0..0a43ef8 100644
--- a/ui/lib/actions.ts
+++ b/ui/lib/actions.ts
@@ -6,6 +6,8 @@ export const getSuggestions = async (chatHisory: Message[]) => {
 
   const customOpenAIKey = localStorage.getItem('openAIApiKey');
   const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+  const ollamaContextWindow =
+    localStorage.getItem('ollamaContextWindow') || '2048';
 
   const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/suggestions`, {
     method: 'POST',
@@ -21,6 +23,9 @@ export const getSuggestions = async (chatHisory: Message[]) => {
         customOpenAIKey,
         customOpenAIBaseURL,
       }),
+      ...(chatModelProvider === 'ollama' && {
+        ollamaContextWindow: parseInt(ollamaContextWindow),
+      }),
     },
   }),
 });
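Every client-side call site follows the same convention: the value is persisted in `localStorage` under `ollamaContextWindow` and attached to the request body only when the active provider is Ollama. A condensed sketch of that pattern, using the same names as the diff:

```ts
// The context window travels as a number inside the chatModel object; the
// conditional spread keeps it out of requests for non-Ollama providers.
const chatModelProvider = localStorage.getItem('chatModelProvider');
const ollamaContextWindow =
  localStorage.getItem('ollamaContextWindow') || '2048';

const chatModel = {
  provider: chatModelProvider,
  model: localStorage.getItem('chatModel'),
  ...(chatModelProvider === 'ollama' && {
    ollamaContextWindow: parseInt(ollamaContextWindow),
  }),
};
```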