Compare commits

..

8 Commits

Author SHA1 Message Date
ItzCrazyKns
d97fa708f1 feat(config-route): use new config manager & model registry 2025-10-16 20:42:04 +05:30
ItzCrazyKns
0c7566bb87 feat(sidebar): fix colors on smaller devices 2025-10-16 18:03:40 +05:30
ItzCrazyKns
0ff1be47bf feat(routes): use new model registry 2025-10-16 18:01:25 +05:30
ItzCrazyKns
768578951c feat(chat-route): use new model registry 2025-10-16 17:58:13 +05:30
ItzCrazyKns
9706079ed4 feat(config): add searxngURL 2025-10-16 17:57:30 +05:30
ItzCrazyKns
9219593ee1 feat(model-registry): add loading method 2025-10-16 17:56:57 +05:30
ItzCrazyKns
36fdb6491d feat(model-types): add ModelWithProvider type 2025-10-16 17:56:14 +05:30
ItzCrazyKns
0d2cd4bb1e feat(app): remove compute-dot, make cosine default 2025-10-16 17:53:31 +05:30
16 changed files with 184 additions and 370 deletions

View File

@@ -29,8 +29,8 @@
"better-sqlite3": "^11.9.1", "better-sqlite3": "^11.9.1",
"clsx": "^2.1.0", "clsx": "^2.1.0",
"compute-cosine-similarity": "^1.1.0", "compute-cosine-similarity": "^1.1.0",
"compute-dot": "^1.1.0",
"drizzle-orm": "^0.40.1", "drizzle-orm": "^0.40.1",
"framer-motion": "^12.23.24",
"html-to-text": "^9.0.5", "html-to-text": "^9.0.5",
"jspdf": "^3.0.1", "jspdf": "^3.0.1",
"langchain": "^0.3.30", "langchain": "^0.3.30",

View File

@@ -1,23 +1,14 @@
import crypto from 'crypto'; import crypto from 'crypto';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { EventEmitter } from 'stream'; import { EventEmitter } from 'stream';
import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
} from '@/lib/providers';
import db from '@/lib/db'; import db from '@/lib/db';
import { chats, messages as messagesSchema } from '@/lib/db/schema'; import { chats, messages as messagesSchema } from '@/lib/db/schema';
import { and, eq, gt } from 'drizzle-orm'; import { and, eq, gt } from 'drizzle-orm';
import { getFileDetails } from '@/lib/utils/files'; import { getFileDetails } from '@/lib/utils/files';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatOpenAI } from '@langchain/openai';
import {
getCustomOpenaiApiKey,
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
} from '@/lib/config';
import { searchHandlers } from '@/lib/search'; import { searchHandlers } from '@/lib/search';
import { z } from 'zod'; import { z } from 'zod';
import ModelRegistry from '@/lib/models/registry';
import { ModelWithProvider } from '@/lib/models/types';
export const runtime = 'nodejs'; export const runtime = 'nodejs';
export const dynamic = 'force-dynamic'; export const dynamic = 'force-dynamic';
@@ -28,14 +19,30 @@ const messageSchema = z.object({
content: z.string().min(1, 'Message content is required'), content: z.string().min(1, 'Message content is required'),
}); });
const chatModelSchema = z.object({ const chatModelSchema: z.ZodType<ModelWithProvider> = z.object({
provider: z.string().optional(), providerId: z.string({
name: z.string().optional(), errorMap: () => ({
message: 'Chat model provider id must be provided',
}),
}),
key: z.string({
errorMap: () => ({
message: 'Chat model key must be provided',
}),
}),
}); });
const embeddingModelSchema = z.object({ const embeddingModelSchema: z.ZodType<ModelWithProvider> = z.object({
provider: z.string().optional(), providerId: z.string({
name: z.string().optional(), errorMap: () => ({
message: 'Embedding model provider id must be provided',
}),
}),
key: z.string({
errorMap: () => ({
message: 'Embedding model key must be provided',
}),
}),
}); });
const bodySchema = z.object({ const bodySchema = z.object({
@@ -57,8 +64,8 @@ const bodySchema = z.object({
.optional() .optional()
.default([]), .default([]),
files: z.array(z.string()).optional().default([]), files: z.array(z.string()).optional().default([]),
chatModel: chatModelSchema.optional().default({}), chatModel: chatModelSchema,
embeddingModel: embeddingModelSchema.optional().default({}), embeddingModel: embeddingModelSchema,
systemInstructions: z.string().nullable().optional().default(''), systemInstructions: z.string().nullable().optional().default(''),
}); });
@@ -248,56 +255,16 @@ export const POST = async (req: Request) => {
); );
} }
const [chatModelProviders, embeddingModelProviders] = await Promise.all([ const registry = new ModelRegistry();
getAvailableChatModelProviders(),
getAvailableEmbeddingModelProviders(), const [llm, embedding] = await Promise.all([
registry.loadChatModel(body.chatModel.providerId, body.chatModel.key),
registry.loadEmbeddingModel(
body.embeddingModel.providerId,
body.embeddingModel.key,
),
]); ]);
const chatModelProvider =
chatModelProviders[
body.chatModel?.provider || Object.keys(chatModelProviders)[0]
];
const chatModel =
chatModelProvider[
body.chatModel?.name || Object.keys(chatModelProvider)[0]
];
const embeddingProvider =
embeddingModelProviders[
body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0]
];
const embeddingModel =
embeddingProvider[
body.embeddingModel?.name || Object.keys(embeddingProvider)[0]
];
let llm: BaseChatModel | undefined;
let embedding = embeddingModel.model;
if (body.chatModel?.provider === 'custom_openai') {
llm = new ChatOpenAI({
apiKey: getCustomOpenaiApiKey(),
modelName: getCustomOpenaiModelName(),
temperature: 0.7,
configuration: {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}
if (!llm) {
return Response.json({ error: 'Invalid chat model' }, { status: 400 });
}
if (!embedding) {
return Response.json(
{ error: 'Invalid embedding model' },
{ status: 400 },
);
}
const humanMessageId = const humanMessageId =
message.messageId ?? crypto.randomBytes(7).toString('hex'); message.messageId ?? crypto.randomBytes(7).toString('hex');

View File

@@ -1,134 +1,33 @@
import { import configManager from '@/lib/config';
getAnthropicApiKey, import ModelRegistry from '@/lib/models/registry';
getCustomOpenaiApiKey, import { NextRequest, NextResponse } from 'next/server';
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getGeminiApiKey,
getGroqApiKey,
getOllamaApiEndpoint,
getOpenaiApiKey,
getDeepseekApiKey,
getAimlApiKey,
getLMStudioApiEndpoint,
getLemonadeApiEndpoint,
getLemonadeApiKey,
updateConfig,
getOllamaApiKey,
} from '@/lib/config';
import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
} from '@/lib/providers';
export const GET = async (req: Request) => { export const GET = async (req: NextRequest) => {
try { try {
const config: Record<string, any> = {}; const values = configManager.currentConfig;
const fields = configManager.getUIConfigSections();
const [chatModelProviders, embeddingModelProviders] = await Promise.all([ const modelRegistry = new ModelRegistry();
getAvailableChatModelProviders(), const modelProviders = await modelRegistry.getActiveProviders();
getAvailableEmbeddingModelProviders(),
]);
config['chatModelProviders'] = {}; values.modelProviders = values.modelProviders.map((mp) => {
config['embeddingModelProviders'] = {}; const activeProvider = modelProviders.find((p) => p.id === mp.id)
for (const provider in chatModelProviders) {
config['chatModelProviders'][provider] = Object.keys(
chatModelProviders[provider],
).map((model) => {
return { return {
name: model, ...mp,
displayName: chatModelProviders[provider][model].displayName, chatModels: activeProvider?.chatModels ?? mp.chatModels,
}; embeddingModels: activeProvider?.embeddingModels ?? mp.embeddingModels
});
} }
})
for (const provider in embeddingModelProviders) { return NextResponse.json({
config['embeddingModelProviders'][provider] = Object.keys( values,
embeddingModelProviders[provider], fields,
).map((model) => { })
return {
name: model,
displayName: embeddingModelProviders[provider][model].displayName,
};
});
}
config['openaiApiKey'] = getOpenaiApiKey();
config['ollamaApiUrl'] = getOllamaApiEndpoint();
config['ollamaApiKey'] = getOllamaApiKey();
config['lmStudioApiUrl'] = getLMStudioApiEndpoint();
config['lemonadeApiUrl'] = getLemonadeApiEndpoint();
config['lemonadeApiKey'] = getLemonadeApiKey();
config['anthropicApiKey'] = getAnthropicApiKey();
config['groqApiKey'] = getGroqApiKey();
config['geminiApiKey'] = getGeminiApiKey();
config['deepseekApiKey'] = getDeepseekApiKey();
config['aimlApiKey'] = getAimlApiKey();
config['customOpenaiApiUrl'] = getCustomOpenaiApiUrl();
config['customOpenaiApiKey'] = getCustomOpenaiApiKey();
config['customOpenaiModelName'] = getCustomOpenaiModelName();
return Response.json({ ...config }, { status: 200 });
} catch (err) { } catch (err) {
console.error('An error occurred while getting config:', err); console.error('Error in getting config: ', err);
return Response.json( return Response.json(
{ message: 'An error occurred while getting config' }, { message: 'An error has occurred.' },
{ status: 500 },
);
}
};
export const POST = async (req: Request) => {
try {
const config = await req.json();
const updatedConfig = {
MODELS: {
OPENAI: {
API_KEY: config.openaiApiKey,
},
GROQ: {
API_KEY: config.groqApiKey,
},
ANTHROPIC: {
API_KEY: config.anthropicApiKey,
},
GEMINI: {
API_KEY: config.geminiApiKey,
},
OLLAMA: {
API_URL: config.ollamaApiUrl,
API_KEY: config.ollamaApiKey,
},
DEEPSEEK: {
API_KEY: config.deepseekApiKey,
},
AIMLAPI: {
API_KEY: config.aimlApiKey,
},
LM_STUDIO: {
API_URL: config.lmStudioApiUrl,
},
LEMONADE: {
API_URL: config.lemonadeApiUrl,
API_KEY: config.lemonadeApiKey,
},
CUSTOM_OPENAI: {
API_URL: config.customOpenaiApiUrl,
API_KEY: config.customOpenaiApiKey,
MODEL_NAME: config.customOpenaiModelName,
},
},
};
updateConfig(updatedConfig);
return Response.json({ message: 'Config updated' }, { status: 200 });
} catch (err) {
console.error('An error occurred while updating config:', err);
return Response.json(
{ message: 'An error occurred while updating config' },
{ status: 500 }, { status: 500 },
); );
} }

View File

@@ -1,23 +1,12 @@
import handleImageSearch from '@/lib/chains/imageSearchAgent'; import handleImageSearch from '@/lib/chains/imageSearchAgent';
import { import ModelRegistry from '@/lib/models/registry';
getCustomOpenaiApiKey, import { ModelWithProvider } from '@/lib/models/types';
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
} from '@/lib/config';
import { getAvailableChatModelProviders } from '@/lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
interface ChatModel {
provider: string;
model: string;
}
interface ImageSearchBody { interface ImageSearchBody {
query: string; query: string;
chatHistory: any[]; chatHistory: any[];
chatModel?: ChatModel; chatModel: ModelWithProvider;
} }
export const POST = async (req: Request) => { export const POST = async (req: Request) => {
@@ -34,35 +23,12 @@ export const POST = async (req: Request) => {
}) })
.filter((msg) => msg !== undefined) as BaseMessage[]; .filter((msg) => msg !== undefined) as BaseMessage[];
const chatModelProviders = await getAvailableChatModelProviders(); const registry = new ModelRegistry();
const chatModelProvider = const llm = await registry.loadChatModel(
chatModelProviders[ body.chatModel.providerId,
body.chatModel?.provider || Object.keys(chatModelProviders)[0] body.chatModel.key,
]; );
const chatModel =
chatModelProvider[
body.chatModel?.model || Object.keys(chatModelProvider)[0]
];
let llm: BaseChatModel | undefined;
if (body.chatModel?.provider === 'custom_openai') {
llm = new ChatOpenAI({
apiKey: getCustomOpenaiApiKey(),
modelName: getCustomOpenaiModelName(),
temperature: 0.7,
configuration: {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}
if (!llm) {
return Response.json({ error: 'Invalid chat model' }, { status: 400 });
}
const images = await handleImageSearch( const images = await handleImageSearch(
{ {

View File

@@ -1,22 +1,12 @@
import generateSuggestions from '@/lib/chains/suggestionGeneratorAgent'; import generateSuggestions from '@/lib/chains/suggestionGeneratorAgent';
import { import ModelRegistry from '@/lib/models/registry';
getCustomOpenaiApiKey, import { ModelWithProvider } from '@/lib/models/types';
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
} from '@/lib/config';
import { getAvailableChatModelProviders } from '@/lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
interface ChatModel {
provider: string;
model: string;
}
interface SuggestionsGenerationBody { interface SuggestionsGenerationBody {
chatHistory: any[]; chatHistory: any[];
chatModel?: ChatModel; chatModel: ModelWithProvider;
} }
export const POST = async (req: Request) => { export const POST = async (req: Request) => {
@@ -33,35 +23,12 @@ export const POST = async (req: Request) => {
}) })
.filter((msg) => msg !== undefined) as BaseMessage[]; .filter((msg) => msg !== undefined) as BaseMessage[];
const chatModelProviders = await getAvailableChatModelProviders(); const registry = new ModelRegistry();
const chatModelProvider = const llm = await registry.loadChatModel(
chatModelProviders[ body.chatModel.providerId,
body.chatModel?.provider || Object.keys(chatModelProviders)[0] body.chatModel.key,
]; );
const chatModel =
chatModelProvider[
body.chatModel?.model || Object.keys(chatModelProvider)[0]
];
let llm: BaseChatModel | undefined;
if (body.chatModel?.provider === 'custom_openai') {
llm = new ChatOpenAI({
apiKey: getCustomOpenaiApiKey(),
modelName: getCustomOpenaiModelName(),
temperature: 0.7,
configuration: {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}
if (!llm) {
return Response.json({ error: 'Invalid chat model' }, { status: 400 });
}
const suggestions = await generateSuggestions( const suggestions = await generateSuggestions(
{ {

View File

@@ -1,23 +1,12 @@
import handleVideoSearch from '@/lib/chains/videoSearchAgent'; import handleVideoSearch from '@/lib/chains/videoSearchAgent';
import { import ModelRegistry from '@/lib/models/registry';
getCustomOpenaiApiKey, import { ModelWithProvider } from '@/lib/models/types';
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
} from '@/lib/config';
import { getAvailableChatModelProviders } from '@/lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
interface ChatModel {
provider: string;
model: string;
}
interface VideoSearchBody { interface VideoSearchBody {
query: string; query: string;
chatHistory: any[]; chatHistory: any[];
chatModel?: ChatModel; chatModel: ModelWithProvider;
} }
export const POST = async (req: Request) => { export const POST = async (req: Request) => {
@@ -34,35 +23,12 @@ export const POST = async (req: Request) => {
}) })
.filter((msg) => msg !== undefined) as BaseMessage[]; .filter((msg) => msg !== undefined) as BaseMessage[];
const chatModelProviders = await getAvailableChatModelProviders(); const registry = new ModelRegistry();
const chatModelProvider = const llm = await registry.loadChatModel(
chatModelProviders[ body.chatModel.providerId,
body.chatModel?.provider || Object.keys(chatModelProviders)[0] body.chatModel.key,
]; );
const chatModel =
chatModelProvider[
body.chatModel?.model || Object.keys(chatModelProvider)[0]
];
let llm: BaseChatModel | undefined;
if (body.chatModel?.provider === 'custom_openai') {
llm = new ChatOpenAI({
apiKey: getCustomOpenaiApiKey(),
modelName: getCustomOpenaiModelName(),
temperature: 0.7,
configuration: {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}
if (!llm) {
return Response.json({ error: 'Invalid chat model' }, { status: 400 });
}
const videos = await handleVideoSearch( const videos = await handleVideoSearch(
{ {

View File

@@ -33,11 +33,10 @@ const SearchImages = ({
onClick={async () => { onClick={async () => {
setLoading(true); setLoading(true);
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem(
const chatModel = localStorage.getItem('chatModel'); 'chatModelProviderId',
);
const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL'); const chatModel = localStorage.getItem('chatModelKey');
const customOpenAIKey = localStorage.getItem('openAIApiKey');
const res = await fetch(`/api/images`, { const res = await fetch(`/api/images`, {
method: 'POST', method: 'POST',
@@ -48,12 +47,8 @@ const SearchImages = ({
query: query, query: query,
chatHistory: chatHistory, chatHistory: chatHistory,
chatModel: { chatModel: {
provider: chatModelProvider, providerId: chatModelProvider,
model: chatModel, key: chatModel,
...(chatModelProvider === 'custom_openai' && {
customOpenAIBaseURL: customOpenAIBaseURL,
customOpenAIKey: customOpenAIKey,
}),
}, },
}), }),
}); });

View File

@@ -48,11 +48,10 @@ const Searchvideos = ({
onClick={async () => { onClick={async () => {
setLoading(true); setLoading(true);
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem(
const chatModel = localStorage.getItem('chatModel'); 'chatModelProviderId',
);
const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL'); const chatModel = localStorage.getItem('chatModelKey');
const customOpenAIKey = localStorage.getItem('openAIApiKey');
const res = await fetch(`/api/videos`, { const res = await fetch(`/api/videos`, {
method: 'POST', method: 'POST',
@@ -63,12 +62,8 @@ const Searchvideos = ({
query: query, query: query,
chatHistory: chatHistory, chatHistory: chatHistory,
chatModel: { chatModel: {
provider: chatModelProvider, providerId: chatModelProvider,
model: chatModel, key: chatModel,
...(chatModelProvider === 'custom_openai' && {
customOpenAIBaseURL: customOpenAIBaseURL,
customOpenAIKey: customOpenAIKey,
}),
}, },
}), }),
}); });

View File

@@ -70,7 +70,7 @@ const Sidebar = ({ children }: { children: React.ReactNode }) => {
</div> </div>
</div> </div>
<div className="fixed bottom-0 w-full z-50 flex flex-row items-center gap-x-6 bg-light-primary dark:bg-dark-primary px-4 py-4 shadow-sm lg:hidden"> <div className="fixed bottom-0 w-full z-50 flex flex-row items-center gap-x-6 bg-light-secondary dark:bg-dark-secondary px-4 py-4 shadow-sm lg:hidden">
{navLinks.map((link, i) => ( {navLinks.map((link, i) => (
<Link <Link
href={link.href} href={link.href}

View File

@@ -1,11 +1,8 @@
import { Message } from '@/components/ChatWindow'; import { Message } from '@/components/ChatWindow';
export const getSuggestions = async (chatHistory: Message[]) => { export const getSuggestions = async (chatHistory: Message[]) => {
const chatModel = localStorage.getItem('chatModel'); const chatModel = localStorage.getItem('chatModelKey');
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem('chatModelProviderId');
const customOpenAIKey = localStorage.getItem('openAIApiKey');
const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
const res = await fetch(`/api/suggestions`, { const res = await fetch(`/api/suggestions`, {
method: 'POST', method: 'POST',
@@ -15,12 +12,8 @@ export const getSuggestions = async (chatHistory: Message[]) => {
body: JSON.stringify({ body: JSON.stringify({
chatHistory: chatHistory, chatHistory: chatHistory,
chatModel: { chatModel: {
provider: chatModelProvider, providerId: chatModelProvider,
model: chatModel, key: chatModel,
...(chatModelProvider === 'custom_openai' && {
customOpenAIKey,
customOpenAIBaseURL,
}),
}, },
}), }),
}); });

View File

@@ -15,10 +15,26 @@ class ConfigManager {
setupComplete: false, setupComplete: false,
general: {}, general: {},
modelProviders: [], modelProviders: [],
search: {
searxngURL: '',
},
}; };
uiConfigSections: UIConfigSections = { uiConfigSections: UIConfigSections = {
general: [], general: [],
modelProviders: [], modelProviders: [],
search: [
{
name: 'SearXNG URL',
key: 'searxngURL',
type: 'string',
required: false,
description: 'The URL of your SearXNG instance',
placeholder: 'http://localhost:4000',
default: '',
scope: 'server',
env: 'SEARXNG_API_URL',
},
],
}; };
constructor() { constructor() {
@@ -78,6 +94,7 @@ class ConfigManager {
} }
private initializeFromEnv() { private initializeFromEnv() {
/* providers section*/
const providerConfigSections = getModelProvidersUIConfigSection(); const providerConfigSections = getModelProvidersUIConfigSection();
this.uiConfigSections.modelProviders = providerConfigSections; this.uiConfigSections.modelProviders = providerConfigSections;
@@ -130,6 +147,14 @@ class ConfigManager {
this.currentConfig.modelProviders.push(...newProviders); this.currentConfig.modelProviders.push(...newProviders);
/* search section */
this.uiConfigSections.search.forEach((f) => {
if (f.env && !this.currentConfig.search[f.key]) {
this.currentConfig.search[f.key] =
process.env[f.env] ?? f.default ?? '';
}
});
this.saveConfig(); this.saveConfig();
} }
@@ -196,15 +221,19 @@ class ConfigManager {
} }
public isSetupComplete() { public isSetupComplete() {
return this.currentConfig.setupComplete return this.currentConfig.setupComplete;
} }
public markSetupComplete() { public markSetupComplete() {
if (!this.currentConfig.setupComplete) { if (!this.currentConfig.setupComplete) {
this.currentConfig.setupComplete = true this.currentConfig.setupComplete = true;
} }
this.saveConfig() this.saveConfig();
}
public getUIConfigSections(): UIConfigSections {
return this.uiConfigSections;
} }
} }

View File

@@ -53,6 +53,26 @@ class ModelRegistry {
return providers; return providers;
} }
async loadChatModel(providerId: string, modelName: string) {
const provider = this.activeProviders.find((p) => p.id === providerId);
if (!provider) throw new Error('Invalid provider id');
const model = await provider.provider.loadChatModel(modelName);
return model;
}
async loadEmbeddingModel(providerId: string, modelName: string) {
const provider = this.activeProviders.find((p) => p.id === providerId);
if (!provider) throw new Error('Invalid provider id');
const model = await provider.provider.loadEmbeddingModel(modelName);
return model;
}
} }
export default ModelRegistry; export default ModelRegistry;

View File

@@ -20,4 +20,15 @@ type MinimalProvider = {
embeddingModels: Model[]; embeddingModels: Model[];
}; };
export type { Model, ModelList, ProviderMetadata, MinimalProvider }; type ModelWithProvider = {
key: string;
providerId: string;
};
export type {
Model,
ModelList,
ProviderMetadata,
MinimalProvider,
ModelWithProvider,
};

View File

@@ -1,5 +0,0 @@
declare function computeDot(vectorA: number[], vectorB: number[]): number;
declare module 'compute-dot' {
export default computeDot;
}

View File

@@ -1,17 +1,7 @@
import dot from 'compute-dot';
import cosineSimilarity from 'compute-cosine-similarity'; import cosineSimilarity from 'compute-cosine-similarity';
import { getSimilarityMeasure } from '../config';
const computeSimilarity = (x: number[], y: number[]): number => { const computeSimilarity = (x: number[], y: number[]): number => {
const similarityMeasure = getSimilarityMeasure();
if (similarityMeasure === 'cosine') {
return cosineSimilarity(x, y) as number; return cosineSimilarity(x, y) as number;
} else if (similarityMeasure === 'dot') {
return dot(x, y);
}
throw new Error('Invalid similarity measure');
}; };
export default computeSimilarity; export default computeSimilarity;

View File

@@ -2709,6 +2709,15 @@ fraction.js@^4.3.7:
resolved "https://registry.yarnpkg.com/fraction.js/-/fraction.js-4.3.7.tgz#06ca0085157e42fda7f9e726e79fefc4068840f7" resolved "https://registry.yarnpkg.com/fraction.js/-/fraction.js-4.3.7.tgz#06ca0085157e42fda7f9e726e79fefc4068840f7"
integrity sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew== integrity sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==
framer-motion@^12.23.24:
version "12.23.24"
resolved "https://registry.yarnpkg.com/framer-motion/-/framer-motion-12.23.24.tgz#4895b67e880bd2b1089e61fbaa32ae802fc24b8c"
integrity sha512-HMi5HRoRCTou+3fb3h9oTLyJGBxHfW+HnNE25tAXOvVx/IvwMHK0cx7IR4a2ZU6sh3IX1Z+4ts32PcYBOqka8w==
dependencies:
motion-dom "^12.23.23"
motion-utils "^12.23.6"
tslib "^2.4.0"
fs-constants@^1.0.0: fs-constants@^1.0.0:
version "1.0.0" version "1.0.0"
resolved "https://registry.yarnpkg.com/fs-constants/-/fs-constants-1.0.0.tgz#6be0de9be998ce16af8afc24497b9ee9b7ccd9ad" resolved "https://registry.yarnpkg.com/fs-constants/-/fs-constants-1.0.0.tgz#6be0de9be998ce16af8afc24497b9ee9b7ccd9ad"
@@ -3674,6 +3683,18 @@ mkdirp-classic@^0.5.2, mkdirp-classic@^0.5.3:
resolved "https://registry.yarnpkg.com/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz#fa10c9115cc6d8865be221ba47ee9bed78601113" resolved "https://registry.yarnpkg.com/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz#fa10c9115cc6d8865be221ba47ee9bed78601113"
integrity sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A== integrity sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==
motion-dom@^12.23.23:
version "12.23.23"
resolved "https://registry.yarnpkg.com/motion-dom/-/motion-dom-12.23.23.tgz#8f874333ea1a04ee3a89eb928f518b463d589e0e"
integrity sha512-n5yolOs0TQQBRUFImrRfs/+6X4p3Q4n1dUEqt/H58Vx7OW6RF+foWEgmTVDhIWJIMXOuNNL0apKH2S16en9eiA==
dependencies:
motion-utils "^12.23.6"
motion-utils@^12.23.6:
version "12.23.6"
resolved "https://registry.yarnpkg.com/motion-utils/-/motion-utils-12.23.6.tgz#fafef80b4ea85122dd0d6c599a0c63d72881f312"
integrity sha512-eAWoPgr4eFEOFfg2WjIsMoqJTW6Z8MTUCgn/GZ3VRpClWBdnbjryiA3ZSNLyxCTmCQx4RmYX6jX1iWHbenUPNQ==
ms@2.1.2: ms@2.1.2:
version "2.1.2" version "2.1.2"
resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009" resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009"