diff --git a/src/lib/models/base/embedding.ts b/src/lib/models/base/embedding.ts new file mode 100644 index 0000000..35413ed --- /dev/null +++ b/src/lib/models/base/embedding.ts @@ -0,0 +1,7 @@ +abstract class BaseEmbedding { + constructor(protected config: CONFIG) {} + abstract embedText(texts: string[]): Promise; + abstract embedChunks(chunks: Chunk[]): Promise; +} + +export default BaseEmbedding; diff --git a/src/lib/models/base/llm.ts b/src/lib/models/base/llm.ts new file mode 100644 index 0000000..5d6f52d --- /dev/null +++ b/src/lib/models/base/llm.ts @@ -0,0 +1,26 @@ +import { + GenerateObjectInput, + GenerateObjectOutput, + GenerateOptions, + GenerateTextInput, + GenerateTextOutput, + StreamObjectOutput, + StreamTextOutput, +} from '../types'; + +abstract class BaseLLM { + constructor(protected config: CONFIG) {} + abstract withOptions(options: GenerateOptions): this; + abstract generateText(input: GenerateTextInput): Promise; + abstract streamText( + input: GenerateTextInput, + ): AsyncGenerator; + abstract generateObject( + input: GenerateObjectInput, + ): Promise>; + abstract streamObject( + input: GenerateObjectInput, + ): AsyncGenerator>; +} + +export default BaseLLM; diff --git a/src/lib/models/base/provider.ts b/src/lib/models/base/provider.ts new file mode 100644 index 0000000..950525e --- /dev/null +++ b/src/lib/models/base/provider.ts @@ -0,0 +1,47 @@ +import { Embeddings } from '@langchain/core/embeddings'; +import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { ModelList, ProviderMetadata } from '../types'; +import { UIConfigField } from '@/lib/config/types'; +import BaseLLM from './llm'; +import BaseEmbedding from './embedding'; + +abstract class BaseModelProvider { + constructor( + protected id: string, + protected name: string, + protected config: CONFIG, + ) {} + abstract getDefaultModels(): Promise; + abstract getModelList(): Promise; + abstract loadChatModel(modelName: string): Promise>; + abstract loadEmbeddingModel(modelName: string): Promise>; + static getProviderConfigFields(): UIConfigField[] { + throw new Error('Method not implemented.'); + } + static getProviderMetadata(): ProviderMetadata { + throw new Error('Method not Implemented.'); + } + static parseAndValidate(raw: any): any { + /* Static methods can't access class type parameters */ + throw new Error('Method not Implemented.'); + } +} + +export type ProviderConstructor = { + new (id: string, name: string, config: CONFIG): BaseModelProvider; + parseAndValidate(raw: any): CONFIG; + getProviderConfigFields: () => UIConfigField[]; + getProviderMetadata: () => ProviderMetadata; +}; + +export const createProviderInstance =

>( + Provider: P, + id: string, + name: string, + rawConfig: unknown, +): InstanceType

=> { + const cfg = Provider.parseAndValidate(rawConfig); + return new Provider(id, name, cfg) as InstanceType

; +}; + +export default BaseModelProvider;