diff --git a/src/lib/models/base/llm.ts b/src/lib/models/base/llm.ts
index e701fa9..fb24af4 100644
--- a/src/lib/models/base/llm.ts
+++ b/src/lib/models/base/llm.ts
@@ -8,7 +8,6 @@ import {
 abstract class BaseLLM {
   constructor(protected config: CONFIG) {}
 
-  abstract withOptions(options: GenerateOptions): this;
   abstract generateText(input: GenerateTextInput): Promise;
   abstract streamText(
     input: GenerateTextInput,
diff --git a/src/lib/models/providers/ollama/ollamaLLM.ts b/src/lib/models/providers/ollama/ollamaLLM.ts
index 548818e..491dfcd 100644
--- a/src/lib/models/providers/ollama/ollamaLLM.ts
+++ b/src/lib/models/providers/ollama/ollamaLLM.ts
@@ -35,28 +35,24 @@ class OllamaLLM extends BaseLLM {
     });
   }
 
-  withOptions(options: GenerateOptions) {
-    this.config.options = {
-      ...this.config.options,
-      ...options,
-    };
-    return this;
-  }
-
   async generateText(input: GenerateTextInput): Promise {
-    this.withOptions(input.options || {});
-
     const res = await this.ollamaClient.chat({
       model: this.config.model,
       messages: input.messages,
       options: {
-        top_p: this.config.options?.topP,
-        temperature: this.config.options?.temperature,
-        num_predict: this.config.options?.maxTokens,
+        top_p: input.options?.topP ?? this.config.options?.topP,
+        temperature:
+          input.options?.temperature ?? this.config.options?.temperature ?? 0.7,
+        num_predict: input.options?.maxTokens ?? this.config.options?.maxTokens,
         num_ctx: 32000,
-        frequency_penalty: this.config.options?.frequencyPenalty,
-        presence_penalty: this.config.options?.presencePenalty,
-        stop: this.config.options?.stopSequences,
+        frequency_penalty:
+          input.options?.frequencyPenalty ??
+          this.config.options?.frequencyPenalty,
+        presence_penalty:
+          input.options?.presencePenalty ??
+          this.config.options?.presencePenalty,
+        stop:
+          input.options?.stopSequences ?? this.config.options?.stopSequences,
       },
     });
 
@@ -71,20 +67,24 @@ class OllamaLLM extends BaseLLM {
   async *streamText(
     input: GenerateTextInput,
   ): AsyncGenerator {
-    this.withOptions(input.options || {});
-
     const stream = await this.ollamaClient.chat({
       model: this.config.model,
       messages: input.messages,
       stream: true,
       options: {
-        top_p: this.config.options?.topP,
-        temperature: this.config.options?.temperature,
+        top_p: input.options?.topP ?? this.config.options?.topP,
+        temperature:
+          input.options?.temperature ?? this.config.options?.temperature ?? 0.7,
         num_ctx: 32000,
-        num_predict: this.config.options?.maxTokens,
-        frequency_penalty: this.config.options?.frequencyPenalty,
-        presence_penalty: this.config.options?.presencePenalty,
-        stop: this.config.options?.stopSequences,
+        num_predict: input.options?.maxTokens ?? this.config.options?.maxTokens,
+        frequency_penalty:
+          input.options?.frequencyPenalty ??
+          this.config.options?.frequencyPenalty,
+        presence_penalty:
+          input.options?.presencePenalty ??
+          this.config.options?.presencePenalty,
+        stop:
+          input.options?.stopSequences ?? this.config.options?.stopSequences,
       },
     });
 
@@ -100,8 +100,6 @@ class OllamaLLM extends BaseLLM {
   }
 
   async generateObject(input: GenerateObjectInput): Promise {
-    this.withOptions(input.options || {});
-
     const response = await this.ollamaClient.chat({
       model: this.config.model,
       messages: input.messages,
@@ -110,12 +108,18 @@ class OllamaLLM extends BaseLLM {
         ? { think: false }
         : {}),
       options: {
-        top_p: this.config.options?.topP,
-        temperature: 0.7,
-        num_predict: this.config.options?.maxTokens,
-        frequency_penalty: this.config.options?.frequencyPenalty,
-        presence_penalty: this.config.options?.presencePenalty,
-        stop: this.config.options?.stopSequences,
+        top_p: input.options?.topP ?? this.config.options?.topP,
+        temperature:
+          input.options?.temperature ?? this.config.options?.temperature ?? 0.7,
+        num_predict: input.options?.maxTokens ?? this.config.options?.maxTokens,
+        frequency_penalty:
+          input.options?.frequencyPenalty ??
+          this.config.options?.frequencyPenalty,
+        presence_penalty:
+          input.options?.presencePenalty ??
+          this.config.options?.presencePenalty,
+        stop:
+          input.options?.stopSequences ?? this.config.options?.stopSequences,
       },
     });
 
@@ -129,8 +133,6 @@ class OllamaLLM extends BaseLLM {
 
   async *streamObject(input: GenerateObjectInput): AsyncGenerator {
     let recievedObj: string = '';
-    this.withOptions(input.options || {});
-
     const stream = await this.ollamaClient.chat({
       model: this.config.model,
       messages: input.messages,
@@ -140,12 +142,18 @@ class OllamaLLM extends BaseLLM {
         ? { think: false }
         : {}),
       options: {
-        top_p: this.config.options?.topP,
-        temperature: 0.7,
-        num_predict: this.config.options?.maxTokens,
-        frequency_penalty: this.config.options?.frequencyPenalty,
-        presence_penalty: this.config.options?.presencePenalty,
-        stop: this.config.options?.stopSequences,
+        top_p: input.options?.topP ?? this.config.options?.topP,
+        temperature:
+          input.options?.temperature ?? this.config.options?.temperature ?? 0.7,
+        num_predict: input.options?.maxTokens ?? this.config.options?.maxTokens,
+        frequency_penalty:
+          input.options?.frequencyPenalty ??
+          this.config.options?.frequencyPenalty,
+        presence_penalty:
+          input.options?.presencePenalty ??
+          this.config.options?.presencePenalty,
+        stop:
+          input.options?.stopSequences ?? this.config.options?.stopSequences,
       },
     });
 
diff --git a/src/lib/models/providers/openai/openaiLLM.ts b/src/lib/models/providers/openai/openaiLLM.ts
index 95594e6..22f23d4 100644
--- a/src/lib/models/providers/openai/openaiLLM.ts
+++ b/src/lib/models/providers/openai/openaiLLM.ts
@@ -29,27 +29,21 @@ class OpenAILLM extends BaseLLM {
     });
   }
 
-  withOptions(options: GenerateOptions) {
-    this.config.options = {
-      ...this.config.options,
-      ...options,
-    };
-
-    return this;
-  }
-
   async generateText(input: GenerateTextInput): Promise {
-    this.withOptions(input.options || {});
-
     const response = await this.openAIClient.chat.completions.create({
       model: this.config.model,
       messages: input.messages,
-      temperature: this.config.options?.temperature || 1.0,
-      top_p: this.config.options?.topP,
-      max_completion_tokens: this.config.options?.maxTokens,
-      stop: this.config.options?.stopSequences,
-      frequency_penalty: this.config.options?.frequencyPenalty,
-      presence_penalty: this.config.options?.presencePenalty,
+      temperature:
+        input.options?.temperature ?? this.config.options?.temperature ?? 1.0,
+      top_p: input.options?.topP ?? this.config.options?.topP,
+      max_completion_tokens:
+        input.options?.maxTokens ?? this.config.options?.maxTokens,
+      stop: input.options?.stopSequences ?? this.config.options?.stopSequences,
+      frequency_penalty:
+        input.options?.frequencyPenalty ??
+        this.config.options?.frequencyPenalty,
+      presence_penalty:
+        input.options?.presencePenalty ?? this.config.options?.presencePenalty,
     });
 
     if (response.choices && response.choices.length > 0) {
@@ -67,17 +61,20 @@ class OpenAILLM extends BaseLLM {
   async *streamText(
     input: GenerateTextInput,
   ): AsyncGenerator {
-    this.withOptions(input.options || {});
-
     const stream = await this.openAIClient.chat.completions.create({
       model: this.config.model,
       messages: input.messages,
-      temperature: this.config.options?.temperature || 1.0,
-      top_p: this.config.options?.topP,
-      max_completion_tokens: this.config.options?.maxTokens,
-      stop: this.config.options?.stopSequences,
-      frequency_penalty: this.config.options?.frequencyPenalty,
-      presence_penalty: this.config.options?.presencePenalty,
+      temperature:
+        input.options?.temperature ?? this.config.options?.temperature ?? 1.0,
+      top_p: input.options?.topP ?? this.config.options?.topP,
+      max_completion_tokens:
+        input.options?.maxTokens ?? this.config.options?.maxTokens,
+      stop: input.options?.stopSequences ?? this.config.options?.stopSequences,
+      frequency_penalty:
+        input.options?.frequencyPenalty ??
+        this.config.options?.frequencyPenalty,
+      presence_penalty:
+        input.options?.presencePenalty ?? this.config.options?.presencePenalty,
       stream: true,
     });
 
@@ -95,17 +92,20 @@ class OpenAILLM extends BaseLLM {
   }
 
   async generateObject(input: GenerateObjectInput): Promise {
-    this.withOptions(input.options || {});
-
     const response = await this.openAIClient.chat.completions.parse({
       messages: input.messages,
       model: this.config.model,
-      temperature: this.config.options?.temperature || 1.0,
-      top_p: this.config.options?.topP,
-      max_completion_tokens: this.config.options?.maxTokens,
-      stop: this.config.options?.stopSequences,
-      frequency_penalty: this.config.options?.frequencyPenalty,
-      presence_penalty: this.config.options?.presencePenalty,
+      temperature:
+        input.options?.temperature ?? this.config.options?.temperature ?? 1.0,
+      top_p: input.options?.topP ?? this.config.options?.topP,
+      max_completion_tokens:
+        input.options?.maxTokens ?? this.config.options?.maxTokens,
+      stop: input.options?.stopSequences ?? this.config.options?.stopSequences,
+      frequency_penalty:
+        input.options?.frequencyPenalty ??
+        this.config.options?.frequencyPenalty,
+      presence_penalty:
+        input.options?.presencePenalty ?? this.config.options?.presencePenalty,
       response_format: zodResponseFormat(input.schema, 'object'),
     });
 
@@ -123,17 +123,20 @@ class OpenAILLM extends BaseLLM {
 
   async *streamObject(input: GenerateObjectInput): AsyncGenerator {
     let recievedObj: string = '';
-    this.withOptions(input.options || {});
-
     const stream = this.openAIClient.responses.stream({
       model: this.config.model,
       input: input.messages,
-      temperature: this.config.options?.temperature || 1.0,
-      top_p: this.config.options?.topP,
-      max_completion_tokens: this.config.options?.maxTokens,
-      stop: this.config.options?.stopSequences,
-      frequency_penalty: this.config.options?.frequencyPenalty,
-      presence_penalty: this.config.options?.presencePenalty,
+      temperature:
+        input.options?.temperature ?? this.config.options?.temperature ?? 1.0,
+      top_p: input.options?.topP ?? this.config.options?.topP,
+      max_completion_tokens:
+        input.options?.maxTokens ?? this.config.options?.maxTokens,
+      stop: input.options?.stopSequences ?? this.config.options?.stopSequences,
+      frequency_penalty:
+        input.options?.frequencyPenalty ??
+        this.config.options?.frequencyPenalty,
+      presence_penalty:
+        input.options?.presencePenalty ?? this.config.options?.presencePenalty,
       text: {
         format: zodTextFormat(input.schema, 'object'),
       },
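
Note: the snippet below is a minimal TypeScript sketch, not code from this repository. It only restates the precedence that the diff now applies inline at every provider call: per-call input.options wins, this.config.options is the fallback, and a literal default (0.7 for Ollama, 1.0 for OpenAI in the diff) applies to temperature only. This replaces the removed withOptions, which mutated this.config.options and so let per-call options leak into later calls. The GenerateOptions shape mirrors the option names used in the diff; resolveOptions is a hypothetical helper introduced here purely for illustration.

// Hypothetical helper (not part of the diff) showing the same ?? precedence
// that the diff writes inline at each call site.
interface GenerateOptions {
  temperature?: number;
  topP?: number;
  maxTokens?: number;
  frequencyPenalty?: number;
  presencePenalty?: number;
  stopSequences?: string[];
}

function resolveOptions(
  call: GenerateOptions | undefined,
  config: GenerateOptions | undefined,
  defaultTemperature: number, // 0.7 for Ollama, 1.0 for OpenAI in the diff
): GenerateOptions {
  return {
    temperature:
      call?.temperature ?? config?.temperature ?? defaultTemperature,
    topP: call?.topP ?? config?.topP,
    maxTokens: call?.maxTokens ?? config?.maxTokens,
    frequencyPenalty: call?.frequencyPenalty ?? config?.frequencyPenalty,
    presencePenalty: call?.presencePenalty ?? config?.presencePenalty,
    stopSequences: call?.stopSequences ?? config?.stopSequences,
  };
}

// A per-call temperature of 0 is kept, because ?? only falls through on
// null/undefined (unlike the removed `|| 1.0` default), while topP still
// falls back to the config value.
const resolved = resolveOptions(
  { temperature: 0 },
  { temperature: 0.9, topP: 0.95 },
  1.0,
);
// resolved.temperature === 0, resolved.topP === 0.95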