From 4bcbdad6cb3346faabc1ac3b1e0439ac0f00d10d Mon Sep 17 00:00:00 2001
From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com>
Date: Tue, 18 Nov 2025 14:39:04 +0530
Subject: [PATCH] feat(providers): implement custom classes

---
 src/lib/models/providers/baseProvider.ts          |  45 -----
 src/lib/models/providers/index.ts                 |  18 +-
 .../providers/{ollama.ts => ollama/index.ts}      |  24 +--
 .../models/providers/ollama/ollamaEmbedding.ts    |  39 +++++
 src/lib/models/providers/ollama/ollamaLLM.ts      | 149 ++++++++++++++++
 .../providers/{openai.ts => openai/index.ts}      |  26 ++-
 .../models/providers/openai/openaiEmbedding.ts    |  41 +++++
 src/lib/models/providers/openai/openaiLLM.ts      | 163 ++++++++++++++++++
 8 files changed, 417 insertions(+), 88 deletions(-)
 delete mode 100644 src/lib/models/providers/baseProvider.ts
 rename src/lib/models/providers/{ollama.ts => ollama/index.ts} (83%)
 create mode 100644 src/lib/models/providers/ollama/ollamaEmbedding.ts
 create mode 100644 src/lib/models/providers/ollama/ollamaLLM.ts
 rename src/lib/models/providers/{openai.ts => openai/index.ts} (87%)
 create mode 100644 src/lib/models/providers/openai/openaiEmbedding.ts
 create mode 100644 src/lib/models/providers/openai/openaiLLM.ts

diff --git a/src/lib/models/providers/baseProvider.ts b/src/lib/models/providers/baseProvider.ts
deleted file mode 100644
index 980a2b2..0000000
--- a/src/lib/models/providers/baseProvider.ts
+++ /dev/null
@@ -1,45 +0,0 @@
-import { Embeddings } from '@langchain/core/embeddings';
-import { BaseChatModel } from '@langchain/core/language_models/chat_models';
-import { Model, ModelList, ProviderMetadata } from '../types';
-import { UIConfigField } from '@/lib/config/types';
-
-abstract class BaseModelProvider<CONFIG = any> {
-  constructor(
-    protected id: string,
-    protected name: string,
-    protected config: CONFIG,
-  ) {}
-  abstract getDefaultModels(): Promise<Model[]>;
-  abstract getModelList(): Promise<ModelList>;
-  abstract loadChatModel(modelName: string): Promise<BaseChatModel>;
-  abstract loadEmbeddingModel(modelName: string): Promise<Embeddings>;
-  static getProviderConfigFields(): UIConfigField[] {
-    throw new Error('Method not implemented.');
-  }
-  static getProviderMetadata(): ProviderMetadata {
-    throw new Error('Method not Implemented.');
-  }
-  static parseAndValidate(raw: any): any {
-    /* Static methods can't access class type parameters */
-    throw new Error('Method not Implemented.');
-  }
-}
-
-export type ProviderConstructor<CONFIG = any> = {
-  new (id: string, name: string, config: CONFIG): BaseModelProvider<CONFIG>;
-  parseAndValidate(raw: any): CONFIG;
-  getProviderConfigFields: () => UIConfigField[];
-  getProviderMetadata: () => ProviderMetadata;
-};
-
-export const createProviderInstance = <
-  P extends ProviderConstructor<any>,
->(
-  Provider: P,
-  id: string,
-  name: string,
-  rawConfig: unknown,
-): InstanceType<P> => {
-  const cfg = Provider.parseAndValidate(rawConfig);
-  return new Provider(id, name, cfg) as InstanceType<P>;
-};
-
-export default BaseModelProvider;
diff --git a/src/lib/models/providers/index.ts b/src/lib/models/providers/index.ts
index addca61..6e508e1 100644
--- a/src/lib/models/providers/index.ts
+++ b/src/lib/models/providers/index.ts
@@ -1,27 +1,11 @@
 import { ModelProviderUISection } from '@/lib/config/types';
-import { ProviderConstructor } from './baseProvider';
+import { ProviderConstructor } from '../base/provider';
 import OpenAIProvider from './openai';
 import OllamaProvider from './ollama';
-import TransformersProvider from './transformers';
-import AnthropicProvider from './anthropic';
-import GeminiProvider from './gemini';
-import GroqProvider from './groq';
-import DeepSeekProvider from './deepseek';
-import LMStudioProvider from './lmstudio';
-import LemonadeProvider from './lemonade';
-import AimlProvider from '@/lib/models/providers/aiml';
 
 export const providers: Record<string, ProviderConstructor<any>> = {
   openai: OpenAIProvider,
   ollama: OllamaProvider,
-  transformers: TransformersProvider,
-  anthropic: AnthropicProvider,
-  gemini: GeminiProvider,
-  groq: GroqProvider,
-  deepseek: DeepSeekProvider,
-  aiml: AimlProvider,
-  lmstudio: LMStudioProvider,
-  lemonade: LemonadeProvider,
 };
 
 export const getModelProvidersUIConfigSection =
diff --git a/src/lib/models/providers/ollama.ts b/src/lib/models/providers/ollama/index.ts
similarity index 83%
rename from src/lib/models/providers/ollama.ts
rename to src/lib/models/providers/ollama/index.ts
index 9ae5899..762c2bf 100644
--- a/src/lib/models/providers/ollama.ts
+++ b/src/lib/models/providers/ollama/index.ts
@@ -1,10 +1,11 @@
-import { BaseChatModel } from '@langchain/core/language_models/chat_models';
-import { Model, ModelList, ProviderMetadata } from '../types';
-import BaseModelProvider from './baseProvider';
-import { ChatOllama, OllamaEmbeddings } from '@langchain/ollama';
-import { Embeddings } from '@langchain/core/embeddings';
 import { UIConfigField } from '@/lib/config/types';
 import { getConfiguredModelProviderById } from '@/lib/config/serverRegistry';
+import BaseModelProvider from '../../base/provider';
+import { Model, ModelList, ProviderMetadata } from '../../types';
+import BaseLLM from '../../base/llm';
+import BaseEmbedding from '../../base/embedding';
+import OllamaLLM from './ollamaLLM';
+import OllamaEmbedding from './ollamaEmbedding';
 
 interface OllamaConfig {
   baseURL: string;
@@ -76,7 +77,7 @@ class OllamaProvider extends BaseModelProvider<OllamaConfig> {
     };
   }
 
-  async loadChatModel(key: string): Promise<BaseChatModel> {
+  async loadChatModel(key: string): Promise<BaseLLM<any>> {
     const modelList = await this.getModelList();
 
     const exists = modelList.chat.find((m) => m.key === key);
 
     if (!exists) {
       throw new Error(
       );
     }
 
-    return new ChatOllama({
-      temperature: 0.7,
+    return new OllamaLLM({
+      baseURL: this.config.baseURL,
       model: key,
-      baseUrl: this.config.baseURL,
     });
   }
 
-  async loadEmbeddingModel(key: string): Promise<Embeddings> {
+  async loadEmbeddingModel(key: string): Promise<BaseEmbedding<any>> {
     const modelList = await this.getModelList();
     const exists = modelList.embedding.find((m) => m.key === key);
 
     if (!exists) {
       throw new Error(
       );
     }
 
-    return new OllamaEmbeddings({
+    return new OllamaEmbedding({
       model: key,
-      baseUrl: this.config.baseURL,
+      baseURL: this.config.baseURL,
     });
   }
diff --git a/src/lib/models/providers/ollama/ollamaEmbedding.ts b/src/lib/models/providers/ollama/ollamaEmbedding.ts
new file mode 100644
index 0000000..0fd306a
--- /dev/null
+++ b/src/lib/models/providers/ollama/ollamaEmbedding.ts
@@ -0,0 +1,39 @@
+import { Ollama } from 'ollama';
+import BaseEmbedding from '../../base/embedding';
+
+type OllamaConfig = {
+  model: string;
+  baseURL?: string;
+};
+
+class OllamaEmbedding extends BaseEmbedding<OllamaConfig> {
+  ollamaClient: Ollama;
+
+  constructor(protected config: OllamaConfig) {
+    super(config);
+
+    this.ollamaClient = new Ollama({
+      host: this.config.baseURL || 'http://localhost:11434',
+    });
+  }
+
+  async embedText(texts: string[]): Promise<number[][]> {
+    const response = await this.ollamaClient.embed({
+      input: texts,
+      model: this.config.model,
+    });
+
+    return response.embeddings;
+  }
+
+  async embedChunks(chunks: Chunk[]): Promise<number[][]> {
+    const response = await this.ollamaClient.embed({
+      input: chunks.map((c) => c.content),
+      model: this.config.model,
+    });
+
+    return response.embeddings;
+  }
+}
+
+export default OllamaEmbedding;
diff --git a/src/lib/models/providers/ollama/ollamaLLM.ts b/src/lib/models/providers/ollama/ollamaLLM.ts
new file mode 100644
index 0000000..fd12b77
--- /dev/null
+++ b/src/lib/models/providers/ollama/ollamaLLM.ts
@@ -0,0 +1,149 @@
+import z from 'zod';
+import BaseLLM from '../../base/llm';
+import {
+  GenerateObjectInput,
+  GenerateOptions,
+  GenerateTextInput,
+  GenerateTextOutput,
+  StreamTextOutput,
+} from '../../types';
+import { Ollama } from 'ollama';
+import { parse } from 'partial-json';
+
+type OllamaConfig = {
+  baseURL: string;
+  model: string;
+  options?: GenerateOptions;
+};
+
+class OllamaLLM extends BaseLLM<OllamaConfig> {
+  ollamaClient: Ollama;
+
+  constructor(protected config: OllamaConfig) {
+    super(config);
+
+    this.ollamaClient = new Ollama({
+      host: this.config.baseURL || 'http://localhost:11434',
+    });
+  }
+
+  withOptions(options: GenerateOptions) {
+    this.config.options = {
+      ...this.config.options,
+      ...options,
+    };
+    return this;
+  }
+
+  async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> {
+    this.withOptions(input.options || {});
+
+    const res = await this.ollamaClient.chat({
+      model: this.config.model,
+      messages: input.messages,
+      options: {
+        top_p: this.config.options?.topP,
+        temperature: this.config.options?.temperature,
+        num_predict: this.config.options?.maxTokens,
+        frequency_penalty: this.config.options?.frequencyPenalty,
+        presence_penalty: this.config.options?.presencePenalty,
+        stop: this.config.options?.stopSequences,
+      },
+    });
+
+    return {
+      content: res.message.content,
+      additionalInfo: {
+        reasoning: res.message.thinking,
+      },
+    };
+  }
+
+  async *streamText(
+    input: GenerateTextInput,
+  ): AsyncGenerator<StreamTextOutput> {
+    this.withOptions(input.options || {});
+
+    const stream = await this.ollamaClient.chat({
+      model: this.config.model,
+      messages: input.messages,
+      stream: true,
+      options: {
+        top_p: this.config.options?.topP,
+        temperature: this.config.options?.temperature,
+        num_predict: this.config.options?.maxTokens,
+        frequency_penalty: this.config.options?.frequencyPenalty,
+        presence_penalty: this.config.options?.presencePenalty,
+        stop: this.config.options?.stopSequences,
+      },
+    });
+
+    for await (const chunk of stream) {
+      yield {
+        contentChunk: chunk.message.content,
+        done: chunk.done,
+        additionalInfo: {
+          reasoning: chunk.message.thinking,
+        },
+      };
+    }
+  }
+
+  async generateObject<T>(input: GenerateObjectInput<T>): Promise<T> {
+    this.withOptions(input.options || {});
+
+    const response = await this.ollamaClient.chat({
+      model: this.config.model,
+      messages: input.messages,
+      format: z.toJSONSchema(input.schema),
+      options: {
+        top_p: this.config.options?.topP,
+        temperature: this.config.options?.temperature,
+        num_predict: this.config.options?.maxTokens,
+        frequency_penalty: this.config.options?.frequencyPenalty,
+        presence_penalty: this.config.options?.presencePenalty,
+        stop: this.config.options?.stopSequences,
+      },
+    });
+
+    try {
+      return input.schema.parse(JSON.parse(response.message.content)) as T;
+    } catch (err) {
+      throw new Error(`Error parsing response from Ollama: ${err}`);
+    }
+  }
+
+  async *streamObject<T>(input: GenerateObjectInput<T>): AsyncGenerator<T> {
+    let recievedObj: string = '';
+
+    this.withOptions(input.options || {});
+
+    const stream = await this.ollamaClient.chat({
+      model: this.config.model,
+      messages: input.messages,
+      format: z.toJSONSchema(input.schema),
+      stream: true,
+      options: {
+        top_p: this.config.options?.topP,
+        temperature: this.config.options?.temperature,
+        num_predict: this.config.options?.maxTokens,
+        frequency_penalty: this.config.options?.frequencyPenalty,
+        presence_penalty: this.config.options?.presencePenalty,
+        stop: this.config.options?.stopSequences,
+      },
+    });
+
+    for await (const chunk of stream) {
+      recievedObj += chunk.message.content;
+
+      try {
+        yield parse(recievedObj) as T;
+      } catch (err) {
+        console.log('Error parsing partial object from Ollama:', err);
+        yield {} as T;
+      }
+    }
+  }
+}
+
+export default OllamaLLM;
diff --git a/src/lib/models/providers/openai.ts b/src/lib/models/providers/openai/index.ts
similarity index 87%
rename from src/lib/models/providers/openai.ts
rename to src/lib/models/providers/openai/index.ts
index 6055b34..8b5eacb 100644
--- a/src/lib/models/providers/openai.ts
+++ b/src/lib/models/providers/openai/index.ts
@@ -1,10 +1,13 @@
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
-import { Model, ModelList, ProviderMetadata } from '../types';
-import BaseModelProvider from './baseProvider';
-import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
 import { Embeddings } from '@langchain/core/embeddings';
 import { UIConfigField } from '@/lib/config/types';
 import { getConfiguredModelProviderById } from '@/lib/config/serverRegistry';
+import { Model, ModelList, ProviderMetadata } from '../../types';
+import OpenAIEmbedding from './openaiEmbedding';
+import BaseEmbedding from '../../base/embedding';
+import BaseModelProvider from '../../base/provider';
+import BaseLLM from '../../base/llm';
+import OpenAILLM from './openaiLLM';
 
 interface OpenAIConfig {
   apiKey: string;
@@ -145,7 +148,7 @@ class OpenAIProvider extends BaseModelProvider<OpenAIConfig> {
     };
   }
 
-  async loadChatModel(key: string): Promise<BaseChatModel> {
+  async loadChatModel(key: string): Promise<BaseLLM<any>> {
     const modelList = await this.getModelList();
 
     const exists = modelList.chat.find((m) => m.key === key);
 
     if (!exists) {
       throw new Error(
       );
     }
 
-    return new ChatOpenAI({
+    return new OpenAILLM({
       apiKey: this.config.apiKey,
-      temperature: 0.7,
       model: key,
-      configuration: {
-        baseURL: this.config.baseURL,
-      },
+      baseURL: this.config.baseURL,
     });
   }
 
-  async loadEmbeddingModel(key: string): Promise<Embeddings> {
+  async loadEmbeddingModel(key: string): Promise<BaseEmbedding<any>> {
     const modelList = await this.getModelList();
     const exists = modelList.embedding.find((m) => m.key === key);
 
     if (!exists) {
       throw new Error(
       );
     }
 
-    return new OpenAIEmbeddings({
+    return new OpenAIEmbedding({
       apiKey: this.config.apiKey,
       model: key,
-      configuration: {
-        baseURL: this.config.baseURL,
-      },
+      baseURL: this.config.baseURL,
     });
   }
diff --git a/src/lib/models/providers/openai/openaiEmbedding.ts b/src/lib/models/providers/openai/openaiEmbedding.ts
new file mode 100644
index 0000000..ea15680
--- /dev/null
+++ b/src/lib/models/providers/openai/openaiEmbedding.ts
@@ -0,0 +1,41 @@
+import OpenAI from 'openai';
+import BaseEmbedding from '../../base/embedding';
+
+type OpenAIConfig = {
+  apiKey: string;
+  model: string;
+  baseURL?: string;
+};
+
+class OpenAIEmbedding extends BaseEmbedding<OpenAIConfig> {
+  openAIClient: OpenAI;
+
+  constructor(protected config: OpenAIConfig) {
+    super(config);
+
+    this.openAIClient = new OpenAI({
+      apiKey: config.apiKey,
+      baseURL: config.baseURL,
+    });
+  }
+
+  async embedText(texts: string[]): Promise<number[][]> {
+    const response = await this.openAIClient.embeddings.create({
+      model: this.config.model,
+      input: texts,
+    });
+
+    return response.data.map((embedding) => embedding.embedding);
+  }
+
+  async embedChunks(chunks: Chunk[]): Promise<number[][]> {
+    const response = await this.openAIClient.embeddings.create({
+      model: this.config.model,
+      input: chunks.map((c) => c.content),
+    });
+
+    return response.data.map((embedding) => embedding.embedding);
+  }
+}
+
+export default OpenAIEmbedding;
diff --git a/src/lib/models/providers/openai/openaiLLM.ts b/src/lib/models/providers/openai/openaiLLM.ts
new file mode 100644
index 0000000..95594e6
--- /dev/null
+++ b/src/lib/models/providers/openai/openaiLLM.ts
@@ -0,0 +1,163 @@
+import OpenAI from 'openai';
+import BaseLLM from '../../base/llm';
+import { zodTextFormat, zodResponseFormat } from 'openai/helpers/zod';
+import {
+  GenerateObjectInput,
+  GenerateOptions,
+  GenerateTextInput,
+  GenerateTextOutput,
+  StreamTextOutput,
+} from '../../types';
+import { parse } from 'partial-json';
+
+type OpenAIConfig = {
+  apiKey: string;
+  model: string;
+  baseURL?: string;
+  options?: GenerateOptions;
+};
+
+class OpenAILLM extends BaseLLM<OpenAIConfig> {
+  openAIClient: OpenAI;
+
+  constructor(protected config: OpenAIConfig) {
+    super(config);
+
+    this.openAIClient = new OpenAI({
+      apiKey: this.config.apiKey,
+      baseURL: this.config.baseURL || 'https://api.openai.com/v1',
+    });
+  }
+
+  withOptions(options: GenerateOptions) {
+    this.config.options = {
+      ...this.config.options,
+      ...options,
+    };
+
+    return this;
+  }
+
+  async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> {
+    this.withOptions(input.options || {});
+
+    const response = await this.openAIClient.chat.completions.create({
+      model: this.config.model,
+      messages: input.messages,
+      temperature: this.config.options?.temperature || 1.0,
+      top_p: this.config.options?.topP,
+      max_completion_tokens: this.config.options?.maxTokens,
+      stop: this.config.options?.stopSequences,
+      frequency_penalty: this.config.options?.frequencyPenalty,
+      presence_penalty: this.config.options?.presencePenalty,
+    });
+
+    if (response.choices && response.choices.length > 0) {
+      return {
+        content: response.choices[0].message.content!,
+        additionalInfo: {
+          finishReason: response.choices[0].finish_reason,
+        },
+      };
+    }
+
+    throw new Error('No response from OpenAI');
+  }
+
+  async *streamText(
+    input: GenerateTextInput,
+  ): AsyncGenerator<StreamTextOutput> {
+    this.withOptions(input.options || {});
+
+    const stream = await this.openAIClient.chat.completions.create({
+      model: this.config.model,
+      messages: input.messages,
+      temperature: this.config.options?.temperature || 1.0,
+      top_p: this.config.options?.topP,
+      max_completion_tokens: this.config.options?.maxTokens,
+      stop: this.config.options?.stopSequences,
+      frequency_penalty: this.config.options?.frequencyPenalty,
+      presence_penalty: this.config.options?.presencePenalty,
+      stream: true,
+    });
+
+    for await (const chunk of stream) {
+      if (chunk.choices && chunk.choices.length > 0) {
+        yield {
+          contentChunk: chunk.choices[0].delta.content || '',
+          done: chunk.choices[0].finish_reason !== null,
+          additionalInfo: {
+            finishReason: chunk.choices[0].finish_reason,
+          },
+        };
+      }
+    }
+  }
+
+  async generateObject<T>(input: GenerateObjectInput<T>): Promise<T> {
+    this.withOptions(input.options || {});
+
+    const response = await this.openAIClient.chat.completions.parse({
+      messages: input.messages,
+      model: this.config.model,
+      temperature: this.config.options?.temperature || 1.0,
+      top_p: this.config.options?.topP,
+      max_completion_tokens: this.config.options?.maxTokens,
+      stop: this.config.options?.stopSequences,
+      frequency_penalty: this.config.options?.frequencyPenalty,
+      presence_penalty: this.config.options?.presencePenalty,
+      response_format: zodResponseFormat(input.schema, 'object'),
+    });
+
+    if (response.choices && response.choices.length > 0) {
+      try {
+        return input.schema.parse(response.choices[0].message.parsed) as T;
+      } catch (err) {
+        throw new Error(`Error parsing response from OpenAI: ${err}`);
+      }
+    }
+
+    throw new Error('No response from OpenAI');
+  }
+
+  async *streamObject<T>(input: GenerateObjectInput<T>): AsyncGenerator<T> {
+    let recievedObj: string = '';
+
+    this.withOptions(input.options || {});
+
+    const stream = this.openAIClient.responses.stream({
+      model: this.config.model,
+      input: input.messages,
+      temperature: this.config.options?.temperature || 1.0,
+      top_p: this.config.options?.topP,
+      max_completion_tokens: this.config.options?.maxTokens,
+      stop: this.config.options?.stopSequences,
+      frequency_penalty: this.config.options?.frequencyPenalty,
+      presence_penalty: this.config.options?.presencePenalty,
+      text: {
+        format: zodTextFormat(input.schema, 'object'),
+      },
+    });
+
+    for await (const chunk of stream) {
+      if (chunk.type === 'response.output_text.delta' && chunk.delta) {
+        recievedObj += chunk.delta;
+
+        try {
+          yield parse(recievedObj) as T;
+        } catch (err) {
+          console.log('Error parsing partial object from OpenAI:', err);
+          yield {} as T;
+        }
+      } else if (chunk.type === 'response.output_text.done' && chunk.text) {
+        try {
+          yield parse(chunk.text) as T;
+        } catch (err) {
+          throw new Error(`Error parsing response from OpenAI: ${err}`);
+        }
+      }
+    }
+  }
+}
+
+export default OpenAILLM;
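
Usage sketch: a minimal caller for the new classes, assuming the BaseModelProvider constructor keeps its (id, name, config) signature in src/lib/models/base/provider and that GenerateTextInput from src/lib/models/types is a { messages, options? } shape; the model name, baseURL, and id/name strings below are illustrative only, not part of this patch:

    import OllamaProvider from '@/lib/models/providers/ollama';

    const provider = new OllamaProvider('ollama', 'Ollama', {
      baseURL: 'http://localhost:11434',
    });

    // loadChatModel now resolves to the custom OllamaLLM wrapper
    // rather than a LangChain ChatOllama instance.
    const llm = await provider.loadChatModel('llama3.1');

    const { content } = await llm.generateText({
      messages: [{ role: 'user', content: 'Hello!' }],
      options: { temperature: 0.7 },
    });

    for await (const chunk of llm.streamText({
      messages: [{ role: 'user', content: 'Hello again!' }],
    })) {
      process.stdout.write(chunk.contentChunk);
    }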