diff --git a/src/lib/models/providers/transformers/transformerEmbedding.ts b/src/lib/models/providers/transformers/transformerEmbedding.ts index b3f43f0..2675f9c 100644 --- a/src/lib/models/providers/transformers/transformerEmbedding.ts +++ b/src/lib/models/providers/transformers/transformerEmbedding.ts @@ -1,6 +1,6 @@ import { Chunk } from '@/lib/types'; import BaseEmbedding from '../../base/embedding'; -import { FeatureExtractionPipeline, pipeline } from '@huggingface/transformers'; +import { FeatureExtractionPipeline } from '@huggingface/transformers'; type TransformerConfig = { model: string; @@ -21,21 +21,19 @@ class TransformerEmbedding extends BaseEmbedding { return this.embed(chunks.map((c) => c.content)); } - async embed(texts: string[]): Promise { + private async embed(texts: string[]) { if (!this.pipelinePromise) { this.pipelinePromise = (async () => { - const transformers = await import('@huggingface/transformers'); - return (await transformers.pipeline( - 'feature-extraction', - this.config.model, - )) as unknown as FeatureExtractionPipeline; + const { pipeline } = await import('@huggingface/transformers'); + const result = await pipeline('feature-extraction', this.config.model, { + dtype: 'fp32', + }); + return result as FeatureExtractionPipeline; })(); } - const pipeline = await this.pipelinePromise; - - const output = await pipeline(texts, { pooling: 'mean', normalize: true }); - + const pipe = await this.pipelinePromise; + const output = await pipe(texts, { pooling: 'mean', normalize: true }); return output.tolist() as number[][]; } }