Mirror of https://github.com/ItzCrazyKns/Perplexica.git, synced 2025-12-14 15:48:15 +00:00
feat(ollama-llm): process ollama messages with tool calls
@@ -7,8 +7,10 @@ import {
   GenerateTextOutput,
   StreamTextOutput,
 } from '../../types';
-import { Ollama, Tool as OllamaTool } from 'ollama';
+import { Ollama, Tool as OllamaTool, Message as OllamaMessage } from 'ollama';
 import { parse } from 'partial-json';
+import crypto from 'crypto';
+import { Message } from '@/lib/types';

 type OllamaConfig = {
   baseURL: string;
@@ -35,6 +37,33 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
     });
   }

+  convertToOllamaMessages(messages: Message[]): OllamaMessage[] {
+    return messages.map((msg) => {
+      if (msg.role === 'tool') {
+        return {
+          role: 'tool',
+          tool_name: msg.name,
+          content: msg.content,
+        } as OllamaMessage;
+      } else if (msg.role === 'assistant') {
+        return {
+          role: 'assistant',
+          content: msg.content,
+          tool_calls:
+            msg.tool_calls?.map((tc, i) => ({
+              function: {
+                index: i,
+                name: tc.name,
+                arguments: tc.arguments,
+              },
+            })) || [],
+        };
+      }
+
+      return msg;
+    });
+  }
+
   async generateText(input: GenerateTextInput): Promise<GenerateTextOutput> {
     const ollamaTools: OllamaTool[] = [];

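For context, a minimal sketch of the round trip this converter enables. The Message and ToolCall shapes below are assumptions inferred from how the diff reads msg.name, msg.content and msg.tool_calls, not the actual definitions in '@/lib/types':

// Assumed internal shapes, illustrative only.
type ToolCall = { id: string; name: string; arguments: Record<string, unknown> };
type Message = {
  role: 'system' | 'user' | 'assistant' | 'tool';
  content: string;
  name?: string;           // tool name, present on role === 'tool' messages
  tool_calls?: ToolCall[]; // present on assistant messages that requested tools
};

const history: Message[] = [
  { role: 'user', content: 'What is the weather in Paris?' },
  {
    role: 'assistant',
    content: '',
    tool_calls: [{ id: 'call-1', name: 'get_weather', arguments: { city: 'Paris' } }],
  },
  { role: 'tool', name: 'get_weather', content: '{"tempC":18}' },
];

// convertToOllamaMessages(history) would then yield, roughly:
// [
//   { role: 'user', content: 'What is the weather in Paris?' },
//   { role: 'assistant', content: '',
//     tool_calls: [{ function: { index: 0, name: 'get_weather',
//                                arguments: { city: 'Paris' } } }] },
//   { role: 'tool', tool_name: 'get_weather', content: '{"tempC":18}' },
// ]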
@@ -51,8 +80,11 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {

     const res = await this.ollamaClient.chat({
       model: this.config.model,
-      messages: input.messages,
+      messages: this.convertToOllamaMessages(input.messages),
       tools: ollamaTools.length > 0 ? ollamaTools : undefined,
+      ...(reasoningModels.find((m) => this.config.model.includes(m))
+        ? { think: false }
+        : {}),
       options: {
         top_p: input.options?.topP ?? this.config.options?.topP,
         temperature:
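The conditional spread added here (and to the streaming call below) mirrors what generateObject and streamObject already did: when the configured model name matches an entry in reasoningModels, { think: false } is merged into the request so Ollama suppresses thinking output for that call. In isolation the pattern works like this; the list contents are illustrative assumptions, not the file's actual entries:

// reasoningModels is a list defined elsewhere in this file; these entries
// are assumptions for illustration.
const reasoningModels = ['deepseek-r1', 'qwen3'];

function thinkOverride(model: string): { think?: boolean } {
  // find() returns undefined (falsy) when nothing matches, so the caller's
  // spread contributes { think: false } only for matching models.
  return reasoningModels.find((m) => model.includes(m)) ? { think: false } : {};
}

thinkOverride('qwen3:8b');    // { think: false }
thinkOverride('llama3.1:8b'); // {}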
@@ -74,6 +106,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
       content: res.message.content,
       toolCalls:
         res.message.tool_calls?.map((tc) => ({
+          id: crypto.randomUUID(),
           name: tc.function.name,
           arguments: tc.function.arguments,
         })) || [],
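Ollama returns tool calls as bare function name/arguments pairs with no identifier, which is presumably why the commit synthesizes one with crypto.randomUUID() when mapping to toolCalls. A sketch of how a caller might use it to feed results back through the converter; runTool and the surrounding shapes are assumptions reusing the earlier sketch:

// Hypothetical tool dispatcher, not part of this repo.
declare function runTool(
  name: string,
  args: Record<string, unknown>,
): Promise<unknown>;

async function handleToolCalls(
  toolCalls: ToolCall[], // assumed shape from the sketch above
  history: Message[],
): Promise<void> {
  for (const call of toolCalls) {
    // The synthesized id lets a caller correlate each call with its result,
    // even when the same tool is invoked twice in one turn.
    const result = await runTool(call.name, call.arguments);
    history.push({ role: 'tool', name: call.name, content: JSON.stringify(result) });
  }
}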
@@ -101,8 +134,11 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {

     const stream = await this.ollamaClient.chat({
       model: this.config.model,
-      messages: input.messages,
+      messages: this.convertToOllamaMessages(input.messages),
       stream: true,
+      ...(reasoningModels.find((m) => this.config.model.includes(m))
+        ? { think: false }
+        : {}),
       tools: ollamaTools.length > 0 ? ollamaTools : undefined,
       options: {
         top_p: input.options?.topP ?? this.config.options?.topP,
@@ -126,6 +162,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
       contentChunk: chunk.message.content,
       toolCallChunk:
         chunk.message.tool_calls?.map((tc) => ({
+          id: crypto.randomUUID(),
           name: tc.function.name,
           arguments: tc.function.arguments,
         })) || [],
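The same mapping now applies per streamed chunk. One caveat worth noting: crypto.randomUUID() runs for every chunk, so a single logical tool call spread across several chunks would carry different ids, and consumers merging partial calls should probably key on name or index instead. A sketch of consuming the stream, assuming streamText yields an async iterable of objects shaped like the fields this hunk populates:

// Assumes an OllamaLLM instance `llm`; the iteration shape of streamText
// ({ contentChunk, toolCallChunk } per chunk) is an assumption.
let text = '';
for await (const chunk of llm.streamText({ messages: history })) {
  text += chunk.contentChunk;
  for (const tc of chunk.toolCallChunk) {
    console.log(`tool call ${tc.id}: ${tc.name}(${JSON.stringify(tc.arguments)})`);
  }
}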
@@ -140,7 +177,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {
   async generateObject<T>(input: GenerateObjectInput): Promise<T> {
     const response = await this.ollamaClient.chat({
       model: this.config.model,
-      messages: input.messages,
+      messages: this.convertToOllamaMessages(input.messages),
       format: z.toJSONSchema(input.schema),
       ...(reasoningModels.find((m) => this.config.model.includes(m))
         ? { think: false }
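generateObject only changes its message handling here; it still constrains the reply by passing the Zod schema through z.toJSONSchema (Zod v4) as Ollama's format option. Roughly, the structured-output flow it supports; the schema and the parsing behavior attributed to generateObject are assumptions:

import { z } from 'zod';

const Weather = z.object({ city: z.string(), tempC: z.number() });

// Assumes generateObject parses the schema-constrained JSON reply into T.
const weather = await llm.generateObject<z.infer<typeof Weather>>({
  messages: history,
  schema: Weather,
});
// weather.city and weather.tempC are now typed.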
@@ -173,7 +210,7 @@ class OllamaLLM extends BaseLLM<OllamaConfig> {

     const stream = await this.ollamaClient.chat({
       model: this.config.model,
-      messages: input.messages,
+      messages: this.convertToOllamaMessages(input.messages),
       format: z.toJSONSchema(input.schema),
       stream: true,
       ...(reasoningModels.find((m) => this.config.model.includes(m))
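Taken together, the changes enable a multi-turn tool loop over the Ollama backend. An end-to-end sketch reusing the assumed helpers from the earlier sketches; generateText's input/output fields beyond those visible in this diff are assumptions:

// Call the model, run any requested tools, and loop until it answers in text.
let out = await llm.generateText({ messages: history });
while (out.toolCalls.length > 0) {
  await handleToolCalls(out.toolCalls, history); // append tool results
  out = await llm.generateText({ messages: history }); // let the model continue
}
console.log(out.content);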