From 86a43086cce9f3ce35be745de2ec444a740d2c69 Mon Sep 17 00:00:00 2001
From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com>
Date: Wed, 24 Dec 2025 13:55:56 +0530
Subject: [PATCH] feat(providers): add transformers

---
 .../models/providers/transformers/index.ts    | 88 +++++++++++++++++++
 .../transformers/transformerEmbedding.ts      | 43 +++++++++
 2 files changed, 131 insertions(+)
 create mode 100644 src/lib/models/providers/transformers/index.ts
 create mode 100644 src/lib/models/providers/transformers/transformerEmbedding.ts

diff --git a/src/lib/models/providers/transformers/index.ts b/src/lib/models/providers/transformers/index.ts
new file mode 100644
index 0000000..e60e94f
--- /dev/null
+++ b/src/lib/models/providers/transformers/index.ts
@@ -0,0 +1,88 @@
+import { UIConfigField } from '@/lib/config/types';
+import { getConfiguredModelProviderById } from '@/lib/config/serverRegistry';
+import { Model, ModelList, ProviderMetadata } from '../../types';
+import BaseModelProvider from '../../base/provider';
+import BaseLLM from '../../base/llm';
+import BaseEmbedding from '../../base/embedding';
+import TransformerEmbedding from './transformerEmbedding';
+
+interface TransformersConfig {}
+
+const defaultEmbeddingModels: Model[] = [
+  {
+    name: 'all-MiniLM-L6-v2',
+    key: 'Xenova/all-MiniLM-L6-v2',
+  },
+  {
+    name: 'mxbai-embed-large-v1',
+    key: 'mixedbread-ai/mxbai-embed-large-v1',
+  },
+  {
+    name: 'nomic-embed-text-v1',
+    key: 'Xenova/nomic-embed-text-v1',
+  },
+];
+
+const providerConfigFields: UIConfigField[] = [];
+
+class TransformersProvider extends BaseModelProvider<TransformersConfig> {
+  constructor(id: string, name: string, config: TransformersConfig) {
+    super(id, name, config);
+  }
+
+  async getDefaultModels(): Promise<ModelList> {
+    return {
+      embedding: [...defaultEmbeddingModels],
+      chat: [],
+    };
+  }
+
+  async getModelList(): Promise<ModelList> {
+    const defaultModels = await this.getDefaultModels();
+    const configProvider = getConfiguredModelProviderById(this.id)!;
+
+    return {
+      embedding: [
+        ...defaultModels.embedding,
+        ...configProvider.embeddingModels,
+      ],
+      chat: [],
+    };
+  }
+
+  async loadChatModel(key: string): Promise<BaseLLM<any>> {
+    throw new Error('Transformers Provider does not support chat models.');
+  }
+
+  async loadEmbeddingModel(key: string): Promise<BaseEmbedding<any>> {
+    const modelList = await this.getModelList();
+    const exists = modelList.embedding.find((m) => m.key === key);
+
+    if (!exists) {
+      throw new Error(
+        'Error Loading Transformers Embedding Model. Invalid Model Selected.',
+      );
+    }
+
+    return new TransformerEmbedding({
+      model: key,
+    });
+  }
+
+  static parseAndValidate(raw: any): TransformersConfig {
+    return {};
+  }
+
+  static getProviderConfigFields(): UIConfigField[] {
+    return providerConfigFields;
+  }
+
+  static getProviderMetadata(): ProviderMetadata {
+    return {
+      key: 'transformers',
+      name: 'Transformers',
+    };
+  }
+}
+
+export default TransformersProvider;
diff --git a/src/lib/models/providers/transformers/transformerEmbedding.ts b/src/lib/models/providers/transformers/transformerEmbedding.ts
new file mode 100644
index 0000000..b3f43f0
--- /dev/null
+++ b/src/lib/models/providers/transformers/transformerEmbedding.ts
@@ -0,0 +1,43 @@
+import { Chunk } from '@/lib/types';
+import BaseEmbedding from '../../base/embedding';
+import type { FeatureExtractionPipeline } from '@huggingface/transformers';
+
+type TransformerConfig = {
+  model: string;
+};
+
+class TransformerEmbedding extends BaseEmbedding<TransformerConfig> {
+  private pipelinePromise: Promise<FeatureExtractionPipeline> | null = null;
+
+  constructor(protected config: TransformerConfig) {
+    super(config);
+  }
+
+  async embedText(texts: string[]): Promise<number[][]> {
+    return this.embed(texts);
+  }
+
+  async embedChunks(chunks: Chunk[]): Promise<number[][]> {
+    return this.embed(chunks.map((c) => c.content));
+  }
+
+  async embed(texts: string[]): Promise<number[][]> {
+    if (!this.pipelinePromise) {
+      this.pipelinePromise = (async () => {
+        const transformers = await import('@huggingface/transformers');
+        return (await transformers.pipeline(
+          'feature-extraction',
+          this.config.model,
+        )) as unknown as FeatureExtractionPipeline;
+      })();
+    }
+
+    const pipeline = await this.pipelinePromise;
+
+    const output = await pipeline(texts, { pooling: 'mean', normalize: true });
+
+    return output.tolist() as number[][];
+  }
+}
+
+export default TransformerEmbedding;