feat(providers): implement custom classes

This commit is contained in:
ItzCrazyKns
2025-11-18 14:39:04 +05:30
parent 5272c7fd3e
commit 4bcbdad6cb
8 changed files with 417 additions and 88 deletions

View File

@@ -1,45 +0,0 @@
import { Embeddings } from '@langchain/core/embeddings';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Model, ModelList, ProviderMetadata } from '../types';
import { UIConfigField } from '@/lib/config/types';
abstract class BaseModelProvider<CONFIG> {
constructor(
protected id: string,
protected name: string,
protected config: CONFIG,
) {}
abstract getDefaultModels(): Promise<ModelList>;
abstract getModelList(): Promise<ModelList>;
abstract loadChatModel(modelName: string): Promise<BaseChatModel>;
abstract loadEmbeddingModel(modelName: string): Promise<Embeddings>;
static getProviderConfigFields(): UIConfigField[] {
throw new Error('Method not implemented.');
}
static getProviderMetadata(): ProviderMetadata {
throw new Error('Method not Implemented.');
}
static parseAndValidate(raw: any): any {
/* Static methods can't access class type parameters */
throw new Error('Method not Implemented.');
}
}
export type ProviderConstructor<CONFIG> = {
new (id: string, name: string, config: CONFIG): BaseModelProvider<CONFIG>;
parseAndValidate(raw: any): CONFIG;
getProviderConfigFields: () => UIConfigField[];
getProviderMetadata: () => ProviderMetadata;
};
export const createProviderInstance = <P extends ProviderConstructor<any>>(
Provider: P,
id: string,
name: string,
rawConfig: unknown,
): InstanceType<P> => {
const cfg = Provider.parseAndValidate(rawConfig);
return new Provider(id, name, cfg) as InstanceType<P>;
};
export default BaseModelProvider;

View File

@@ -1,27 +1,11 @@
import { ModelProviderUISection } from '@/lib/config/types';
import { ProviderConstructor } from './baseProvider';
import { ProviderConstructor } from '../base/provider';
import OpenAIProvider from './openai';
import OllamaProvider from './ollama';
import TransformersProvider from './transformers';
import AnthropicProvider from './anthropic';
import GeminiProvider from './gemini';
import GroqProvider from './groq';
import DeepSeekProvider from './deepseek';
import LMStudioProvider from './lmstudio';
import LemonadeProvider from './lemonade';
import AimlProvider from '@/lib/models/providers/aiml';
export const providers: Record<string, ProviderConstructor<any>> = {
openai: OpenAIProvider,
ollama: OllamaProvider,
transformers: TransformersProvider,
anthropic: AnthropicProvider,
gemini: GeminiProvider,
groq: GroqProvider,
deepseek: DeepSeekProvider,
aiml: AimlProvider,
lmstudio: LMStudioProvider,
lemonade: LemonadeProvider,
};
export const getModelProvidersUIConfigSection =

View File

@@ -1,10 +1,11 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Model, ModelList, ProviderMetadata } from '../types';
import BaseModelProvider from './baseProvider';
import { ChatOllama, OllamaEmbeddings } from '@langchain/ollama';
import { Embeddings } from '@langchain/core/embeddings';
import { UIConfigField } from '@/lib/config/types';
import { getConfiguredModelProviderById } from '@/lib/config/serverRegistry';
import BaseModelProvider from '../../base/provider';
import { Model, ModelList, ProviderMetadata } from '../../types';
import BaseLLM from '../../base/llm';
import BaseEmbedding from '../../base/embedding';
import OllamaLLM from './ollamaLLM';
import OllamaEmbedding from './ollamaEmbedding';
interface OllamaConfig {
baseURL: string;
@@ -76,7 +77,7 @@ class OllamaProvider extends BaseModelProvider<OllamaConfig> {
};
}
async loadChatModel(key: string): Promise<BaseChatModel> {
async loadChatModel(key: string): Promise<BaseLLM<any>> {
const modelList = await this.getModelList();
const exists = modelList.chat.find((m) => m.key === key);
@@ -87,14 +88,13 @@ class OllamaProvider extends BaseModelProvider<OllamaConfig> {
);
}
return new ChatOllama({
temperature: 0.7,
return new OllamaLLM({
baseURL: this.config.baseURL,
model: key,
baseUrl: this.config.baseURL,
});
}
async loadEmbeddingModel(key: string): Promise<Embeddings> {
async loadEmbeddingModel(key: string): Promise<BaseEmbedding<any>> {
const modelList = await this.getModelList();
const exists = modelList.embedding.find((m) => m.key === key);
@@ -104,9 +104,9 @@ class OllamaProvider extends BaseModelProvider<OllamaConfig> {
);
}
return new OllamaEmbeddings({
return new OllamaEmbedding({
model: key,
baseUrl: this.config.baseURL,
baseURL: this.config.baseURL,
});
}

View File

@@ -0,0 +1,39 @@
import { Ollama } from 'ollama';
import BaseEmbedding from '../../base/embedding';
type OllamaConfig = {
model: string;
baseURL?: string;
};
class OllamaEmbedding extends BaseEmbedding<OllamaConfig> {
ollamaClient: Ollama;
constructor(protected config: OllamaConfig) {
super(config);
this.ollamaClient = new Ollama({
host: this.config.baseURL || 'http://localhost:11434',
});
}
async embedText(texts: string[]): Promise<number[][]> {
const response = await this.ollamaClient.embed({
input: texts,
model: this.config.model,
});
return response.embeddings;
}
async embedChunks(chunks: Chunk[]): Promise<number[][]> {
const response = await this.ollamaClient.embed({
input: chunks.map((c) => c.content),
model: this.config.model,
});
return response.embeddings;
}
}
export default OllamaEmbedding;

View File

@@ -0,0 +1,149 @@
import z from 'zod';
import BaseLLM from '../../base/llm';
import {
GenerateObjectInput,
GenerateOptions,
GenerateTextInput,
GenerateTextOutput,
StreamTextOutput,
} from '../../types';
import { Ollama } from 'ollama';
import { parse } from 'partial-json';
type OllamaConfig = {
baseURL: string;
model: string;
options?: GenerateOptions;
};
class OllamaLLM extends BaseLLM<OllamaConfig> {
ollamaClient: Ollama;
constructor(protected config: OllamaConfig) {
super(config);
this.ollamaClient = new Ollama({
host: this.config.baseURL || 'http://localhost:11434',
});
}
withOptions(options: GenerateOptions) {
this.config.options = {
...this.config.options,
...options,
};
return this;
}
async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> {
this.withOptions(input.options || {});
const res = await this.ollamaClient.chat({
model: this.config.model,
messages: input.messages,
options: {
top_p: this.config.options?.topP,
temperature: this.config.options?.temperature,
num_predict: this.config.options?.maxTokens,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
stop: this.config.options?.stopSequences,
},
});
return {
content: res.message.content,
additionalInfo: {
reasoning: res.message.thinking,
},
};
}
async *streamText(
input: GenerateTextInput,
): AsyncGenerator<StreamTextOutput> {
this.withOptions(input.options || {});
const stream = await this.ollamaClient.chat({
model: this.config.model,
messages: input.messages,
stream: true,
options: {
top_p: this.config.options?.topP,
temperature: this.config.options?.temperature,
num_predict: this.config.options?.maxTokens,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
stop: this.config.options?.stopSequences,
},
});
for await (const chunk of stream) {
yield {
contentChunk: chunk.message.content,
done: chunk.done,
additionalInfo: {
reasoning: chunk.message.thinking,
},
};
}
}
async generateObject<T>(input: GenerateObjectInput): Promise<T> {
this.withOptions(input.options || {});
const response = await this.ollamaClient.chat({
model: this.config.model,
messages: input.messages,
format: z.toJSONSchema(input.schema),
options: {
top_p: this.config.options?.topP,
temperature: this.config.options?.temperature,
num_predict: this.config.options?.maxTokens,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
stop: this.config.options?.stopSequences,
},
});
try {
return input.schema.parse(JSON.parse(response.message.content)) as T;
} catch (err) {
throw new Error(`Error parsing response from Ollama: ${err}`);
}
}
async *streamObject<T>(input: GenerateObjectInput): AsyncGenerator<T> {
let recievedObj: string = '';
this.withOptions(input.options || {});
const stream = await this.ollamaClient.chat({
model: this.config.model,
messages: input.messages,
format: z.toJSONSchema(input.schema),
stream: true,
options: {
top_p: this.config.options?.topP,
temperature: this.config.options?.temperature,
num_predict: this.config.options?.maxTokens,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
stop: this.config.options?.stopSequences,
},
});
for await (const chunk of stream) {
recievedObj += chunk.message.content;
try {
yield parse(recievedObj) as T;
} catch (err) {
console.log('Error parsing partial object from Ollama:', err);
yield {} as T;
}
}
}
}
export default OllamaLLM;

View File

@@ -1,10 +1,13 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Model, ModelList, ProviderMetadata } from '../types';
import BaseModelProvider from './baseProvider';
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { Embeddings } from '@langchain/core/embeddings';
import { UIConfigField } from '@/lib/config/types';
import { getConfiguredModelProviderById } from '@/lib/config/serverRegistry';
import { Model, ModelList, ProviderMetadata } from '../../types';
import OpenAIEmbedding from './openaiEmbedding';
import BaseEmbedding from '../../base/embedding';
import BaseModelProvider from '../../base/provider';
import BaseLLM from '../../base/llm';
import OpenAILLM from './openaiLLM';
interface OpenAIConfig {
apiKey: string;
@@ -145,7 +148,7 @@ class OpenAIProvider extends BaseModelProvider<OpenAIConfig> {
};
}
async loadChatModel(key: string): Promise<BaseChatModel> {
async loadChatModel(key: string): Promise<BaseLLM<any>> {
const modelList = await this.getModelList();
const exists = modelList.chat.find((m) => m.key === key);
@@ -156,17 +159,14 @@ class OpenAIProvider extends BaseModelProvider<OpenAIConfig> {
);
}
return new ChatOpenAI({
return new OpenAILLM({
apiKey: this.config.apiKey,
temperature: 0.7,
model: key,
configuration: {
baseURL: this.config.baseURL,
},
baseURL: this.config.baseURL,
});
}
async loadEmbeddingModel(key: string): Promise<Embeddings> {
async loadEmbeddingModel(key: string): Promise<BaseEmbedding<any>> {
const modelList = await this.getModelList();
const exists = modelList.embedding.find((m) => m.key === key);
@@ -176,12 +176,10 @@ class OpenAIProvider extends BaseModelProvider<OpenAIConfig> {
);
}
return new OpenAIEmbeddings({
return new OpenAIEmbedding({
apiKey: this.config.apiKey,
model: key,
configuration: {
baseURL: this.config.baseURL,
},
baseURL: this.config.baseURL,
});
}

View File

@@ -0,0 +1,41 @@
import OpenAI from 'openai';
import BaseEmbedding from '../../base/embedding';
type OpenAIConfig = {
apiKey: string;
model: string;
baseURL?: string;
};
class OpenAIEmbedding extends BaseEmbedding<OpenAIConfig> {
openAIClient: OpenAI;
constructor(protected config: OpenAIConfig) {
super(config);
this.openAIClient = new OpenAI({
apiKey: config.apiKey,
baseURL: config.baseURL,
});
}
async embedText(texts: string[]): Promise<number[][]> {
const response = await this.openAIClient.embeddings.create({
model: this.config.model,
input: texts,
});
return response.data.map((embedding) => embedding.embedding);
}
async embedChunks(chunks: Chunk[]): Promise<number[][]> {
const response = await this.openAIClient.embeddings.create({
model: this.config.model,
input: chunks.map((c) => c.content),
});
return response.data.map((embedding) => embedding.embedding);
}
}
export default OpenAIEmbedding;

View File

@@ -0,0 +1,163 @@
import OpenAI from 'openai';
import BaseLLM from '../../base/llm';
import { zodTextFormat, zodResponseFormat } from 'openai/helpers/zod';
import {
GenerateObjectInput,
GenerateOptions,
GenerateTextInput,
GenerateTextOutput,
StreamTextOutput,
} from '../../types';
import { parse } from 'partial-json';
type OpenAIConfig = {
apiKey: string;
model: string;
baseURL?: string;
options?: GenerateOptions;
};
class OpenAILLM extends BaseLLM<OpenAIConfig> {
openAIClient: OpenAI;
constructor(protected config: OpenAIConfig) {
super(config);
this.openAIClient = new OpenAI({
apiKey: this.config.apiKey,
baseURL: this.config.baseURL || 'https://api.openai.com/v1',
});
}
withOptions(options: GenerateOptions) {
this.config.options = {
...this.config.options,
...options,
};
return this;
}
async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> {
this.withOptions(input.options || {});
const response = await this.openAIClient.chat.completions.create({
model: this.config.model,
messages: input.messages,
temperature: this.config.options?.temperature || 1.0,
top_p: this.config.options?.topP,
max_completion_tokens: this.config.options?.maxTokens,
stop: this.config.options?.stopSequences,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
});
if (response.choices && response.choices.length > 0) {
return {
content: response.choices[0].message.content!,
additionalInfo: {
finishReason: response.choices[0].finish_reason,
},
};
}
throw new Error('No response from OpenAI');
}
async *streamText(
input: GenerateTextInput,
): AsyncGenerator<StreamTextOutput> {
this.withOptions(input.options || {});
const stream = await this.openAIClient.chat.completions.create({
model: this.config.model,
messages: input.messages,
temperature: this.config.options?.temperature || 1.0,
top_p: this.config.options?.topP,
max_completion_tokens: this.config.options?.maxTokens,
stop: this.config.options?.stopSequences,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
stream: true,
});
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
yield {
contentChunk: chunk.choices[0].delta.content || '',
done: chunk.choices[0].finish_reason !== null,
additionalInfo: {
finishReason: chunk.choices[0].finish_reason,
},
};
}
}
}
async generateObject<T>(input: GenerateObjectInput): Promise<T> {
this.withOptions(input.options || {});
const response = await this.openAIClient.chat.completions.parse({
messages: input.messages,
model: this.config.model,
temperature: this.config.options?.temperature || 1.0,
top_p: this.config.options?.topP,
max_completion_tokens: this.config.options?.maxTokens,
stop: this.config.options?.stopSequences,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
response_format: zodResponseFormat(input.schema, 'object'),
});
if (response.choices && response.choices.length > 0) {
try {
return input.schema.parse(response.choices[0].message.parsed) as T;
} catch (err) {
throw new Error(`Error parsing response from OpenAI: ${err}`);
}
}
throw new Error('No response from OpenAI');
}
async *streamObject<T>(input: GenerateObjectInput): AsyncGenerator<T> {
let recievedObj: string = '';
this.withOptions(input.options || {});
const stream = this.openAIClient.responses.stream({
model: this.config.model,
input: input.messages,
temperature: this.config.options?.temperature || 1.0,
top_p: this.config.options?.topP,
max_completion_tokens: this.config.options?.maxTokens,
stop: this.config.options?.stopSequences,
frequency_penalty: this.config.options?.frequencyPenalty,
presence_penalty: this.config.options?.presencePenalty,
text: {
format: zodTextFormat(input.schema, 'object'),
},
});
for await (const chunk of stream) {
if (chunk.type === 'response.output_text.delta' && chunk.delta) {
recievedObj += chunk.delta;
try {
yield parse(recievedObj) as T;
} catch (err) {
console.log('Error parsing partial object from OpenAI:', err);
yield {} as T;
}
} else if (chunk.type === 'response.output_text.done' && chunk.text) {
try {
yield parse(chunk.text) as T;
} catch (err) {
throw new Error(`Error parsing response from OpenAI: ${err}`);
}
}
}
}
}
export default OpenAILLM;