diff --git a/src/lib/models/providers/ollama/ollamaLLM.ts b/src/lib/models/providers/ollama/ollamaLLM.ts index 491dfcd..c0028a6 100644 --- a/src/lib/models/providers/ollama/ollamaLLM.ts +++ b/src/lib/models/providers/ollama/ollamaLLM.ts @@ -7,7 +7,7 @@ import { GenerateTextOutput, StreamTextOutput, } from '../../types'; -import { Ollama } from 'ollama'; +import { Ollama, Tool as OllamaTool } from 'ollama'; import { parse } from 'partial-json'; type OllamaConfig = { @@ -36,9 +36,23 @@ class OllamaLLM extends BaseLLM { } async generateText(input: GenerateTextInput): Promise { + const ollamaTools: OllamaTool[] = []; + + input.tools?.forEach((tool) => { + ollamaTools.push({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: z.toJSONSchema(tool.schema).properties, + }, + }); + }); + const res = await this.ollamaClient.chat({ model: this.config.model, messages: input.messages, + tools: ollamaTools.length > 0 ? ollamaTools : undefined, options: { top_p: input.options?.topP ?? this.config.options?.topP, temperature: @@ -58,6 +72,11 @@ class OllamaLLM extends BaseLLM { return { content: res.message.content, + toolCalls: + res.message.tool_calls?.map((tc) => ({ + name: tc.function.name, + arguments: tc.function.arguments, + })) || [], additionalInfo: { reasoning: res.message.thinking, }, @@ -67,10 +86,24 @@ class OllamaLLM extends BaseLLM { async *streamText( input: GenerateTextInput, ): AsyncGenerator { + const ollamaTools: OllamaTool[] = []; + + input.tools?.forEach((tool) => { + ollamaTools.push({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: z.toJSONSchema(tool.schema) as any, + }, + }); + }); + const stream = await this.ollamaClient.chat({ model: this.config.model, messages: input.messages, stream: true, + tools: ollamaTools.length > 0 ? ollamaTools : undefined, options: { top_p: input.options?.topP ?? this.config.options?.topP, temperature: @@ -91,6 +124,11 @@ class OllamaLLM extends BaseLLM { for await (const chunk of stream) { yield { contentChunk: chunk.message.content, + toolCallChunk: + chunk.message.tool_calls?.map((tc) => ({ + name: tc.function.name, + arguments: tc.function.arguments, + })) || [], done: chunk.done, additionalInfo: { reasoning: chunk.message.thinking,