feat(ollama-llm): process ollama messages with tool calls

This commit is contained in:
ItzCrazyKns
2025-12-06 15:31:35 +05:30
parent 3c524b0f98
commit 4c4c1d1930

View File

@@ -7,8 +7,10 @@ import {
GenerateTextOutput, GenerateTextOutput,
StreamTextOutput, StreamTextOutput,
} from '../../types'; } from '../../types';
import { Ollama, Tool as OllamaTool } from 'ollama'; import { Ollama, Tool as OllamaTool, Message as OllamaMessage } from 'ollama';
import { parse } from 'partial-json'; import { parse } from 'partial-json';
import crypto from 'crypto';
import { Message } from '@/lib/types';
type OllamaConfig = { type OllamaConfig = {
baseURL: string; baseURL: string;
@@ -35,6 +37,33 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
}); });
} }
convertToOllamaMessages(messages: Message[]): OllamaMessage[] {
return messages.map((msg) => {
if (msg.role === 'tool') {
return {
role: 'tool',
tool_name: msg.name,
content: msg.content,
} as OllamaMessage;
} else if (msg.role === 'assistant') {
return {
role: 'assistant',
content: msg.content,
tool_calls:
msg.tool_calls?.map((tc, i) => ({
function: {
index: i,
name: tc.name,
arguments: tc.arguments,
},
})) || [],
};
}
return msg;
});
}
async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> { async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> {
const ollamaTools: OllamaTool[] = []; const ollamaTools: OllamaTool[] = [];
@@ -51,8 +80,11 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
const res = await this.ollamaClient.chat({ const res = await this.ollamaClient.chat({
model: this.config.model, model: this.config.model,
messages: input.messages, messages: this.convertToOllamaMessages(input.messages),
tools: ollamaTools.length > 0 ? ollamaTools : undefined, tools: ollamaTools.length > 0 ? ollamaTools : undefined,
...(reasoningModels.find((m) => this.config.model.includes(m))
? { think: false }
: {}),
options: { options: {
top_p: input.options?.topP ?? this.config.options?.topP, top_p: input.options?.topP ?? this.config.options?.topP,
temperature: temperature:
@@ -74,6 +106,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
content: res.message.content, content: res.message.content,
toolCalls: toolCalls:
res.message.tool_calls?.map((tc) => ({ res.message.tool_calls?.map((tc) => ({
id: crypto.randomUUID(),
name: tc.function.name, name: tc.function.name,
arguments: tc.function.arguments, arguments: tc.function.arguments,
})) || [], })) || [],
@@ -101,8 +134,11 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
const stream = await this.ollamaClient.chat({ const stream = await this.ollamaClient.chat({
model: this.config.model, model: this.config.model,
messages: input.messages, messages: this.convertToOllamaMessages(input.messages),
stream: true, stream: true,
...(reasoningModels.find((m) => this.config.model.includes(m))
? { think: false }
: {}),
tools: ollamaTools.length > 0 ? ollamaTools : undefined, tools: ollamaTools.length > 0 ? ollamaTools : undefined,
options: { options: {
top_p: input.options?.topP ?? this.config.options?.topP, top_p: input.options?.topP ?? this.config.options?.topP,
@@ -126,6 +162,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
contentChunk: chunk.message.content, contentChunk: chunk.message.content,
toolCallChunk: toolCallChunk:
chunk.message.tool_calls?.map((tc) => ({ chunk.message.tool_calls?.map((tc) => ({
id: crypto.randomUUID(),
name: tc.function.name, name: tc.function.name,
arguments: tc.function.arguments, arguments: tc.function.arguments,
})) || [], })) || [],
@@ -140,7 +177,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
async generateObject<T>(input: GenerateObjectInput): Promise<T> { async generateObject<T>(input: GenerateObjectInput): Promise<T> {
const response = await this.ollamaClient.chat({ const response = await this.ollamaClient.chat({
model: this.config.model, model: this.config.model,
messages: input.messages, messages: this.convertToOllamaMessages(input.messages),
format: z.toJSONSchema(input.schema), format: z.toJSONSchema(input.schema),
...(reasoningModels.find((m) => this.config.model.includes(m)) ...(reasoningModels.find((m) => this.config.model.includes(m))
? { think: false } ? { think: false }
@@ -173,7 +210,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
const stream = await this.ollamaClient.chat({ const stream = await this.ollamaClient.chat({
model: this.config.model, model: this.config.model,
messages: input.messages, messages: this.convertToOllamaMessages(input.messages),
format: z.toJSONSchema(input.schema), format: z.toJSONSchema(input.schema),
stream: true, stream: true,
...(reasoningModels.find((m) => this.config.model.includes(m)) ...(reasoningModels.find((m) => this.config.model.includes(m))