diff --git a/src/lib/models/providers/gemini/geminiEmbedding.ts b/src/lib/models/providers/gemini/geminiEmbedding.ts new file mode 100644 index 0000000..0054853 --- /dev/null +++ b/src/lib/models/providers/gemini/geminiEmbedding.ts @@ -0,0 +1,5 @@ +import OpenAIEmbedding from '../openai/openaiEmbedding'; + +class GeminiEmbedding extends OpenAIEmbedding {} + +export default GeminiEmbedding; diff --git a/src/lib/models/providers/gemini/geminiLLM.ts b/src/lib/models/providers/gemini/geminiLLM.ts new file mode 100644 index 0000000..0a0d4ff --- /dev/null +++ b/src/lib/models/providers/gemini/geminiLLM.ts @@ -0,0 +1,5 @@ +import OpenAILLM from '../openai/openaiLLM'; + +class GeminiLLM extends OpenAILLM {} + +export default GeminiLLM; diff --git a/src/lib/models/providers/gemini/index.ts b/src/lib/models/providers/gemini/index.ts new file mode 100644 index 0000000..2eb92cd --- /dev/null +++ b/src/lib/models/providers/gemini/index.ts @@ -0,0 +1,144 @@ +import { UIConfigField } from '@/lib/config/types'; +import { getConfiguredModelProviderById } from '@/lib/config/serverRegistry'; +import { Model, ModelList, ProviderMetadata } from '../../types'; +import GeminiEmbedding from './geminiEmbedding'; +import BaseEmbedding from '../../base/embedding'; +import BaseModelProvider from '../../base/provider'; +import BaseLLM from '../../base/llm'; +import GeminiLLM from './geminiLLM'; + +interface GeminiConfig { + apiKey: string; +} + +const providerConfigFields: UIConfigField[] = [ + { + type: 'password', + name: 'API Key', + key: 'apiKey', + description: 'Your Gemini API key', + required: true, + placeholder: 'Gemini API Key', + env: 'GEMINI_API_KEY', + scope: 'server', + }, +]; + +class GeminiProvider extends BaseModelProvider { + constructor(id: string, name: string, config: GeminiConfig) { + super(id, name, config); + } + + async getDefaultModels(): Promise { + const res = await fetch( + `https://generativelanguage.googleapis.com/v1beta/models?key=${this.config.apiKey}`, + { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + + const data = await res.json(); + + let defaultEmbeddingModels: Model[] = []; + let defaultChatModels: Model[] = []; + + data.models.forEach((m: any) => { + if ( + m.supportedGenerationMethods.some( + (genMethod: string) => + genMethod === 'embedText' || genMethod === 'embedContent', + ) + ) { + defaultEmbeddingModels.push({ + key: m.name, + name: m.displayName, + }); + } else if (m.supportedGenerationMethods.includes('generateContent')) { + defaultChatModels.push({ + key: m.name, + name: m.displayName, + }); + } + }); + + return { + embedding: defaultEmbeddingModels, + chat: defaultChatModels, + }; + } + + async getModelList(): Promise { + const defaultModels = await this.getDefaultModels(); + const configProvider = getConfiguredModelProviderById(this.id)!; + + return { + embedding: [ + ...defaultModels.embedding, + ...configProvider.embeddingModels, + ], + chat: [...defaultModels.chat, ...configProvider.chatModels], + }; + } + + async loadChatModel(key: string): Promise> { + const modelList = await this.getModelList(); + + const exists = modelList.chat.find((m) => m.key === key); + + if (!exists) { + throw new Error( + 'Error Loading Gemini Chat Model. Invalid Model Selected', + ); + } + + return new GeminiLLM({ + apiKey: this.config.apiKey, + model: key, + baseURL: 'https://generativelanguage.googleapis.com/v1beta/openai', + }); + } + + async loadEmbeddingModel(key: string): Promise> { + const modelList = await this.getModelList(); + const exists = modelList.embedding.find((m) => m.key === key); + + if (!exists) { + throw new Error( + 'Error Loading Gemini Embedding Model. Invalid Model Selected.', + ); + } + + return new GeminiEmbedding({ + apiKey: this.config.apiKey, + model: key, + baseURL: 'https://generativelanguage.googleapis.com/v1beta/openai', + }); + } + + static parseAndValidate(raw: any): GeminiConfig { + if (!raw || typeof raw !== 'object') + throw new Error('Invalid config provided. Expected object'); + if (!raw.apiKey) + throw new Error('Invalid config provided. API key must be provided'); + + return { + apiKey: String(raw.apiKey), + }; + } + + static getProviderConfigFields(): UIConfigField[] { + return providerConfigFields; + } + + static getProviderMetadata(): ProviderMetadata { + return { + key: 'gemini', + name: 'Gemini', + }; + } +} + +export default GeminiProvider;