diff --git a/src/lib/models/providers/openai/openaiLLM.ts b/src/lib/models/providers/openai/openaiLLM.ts index 22f23d4..e1dc531 100644 --- a/src/lib/models/providers/openai/openaiLLM.ts +++ b/src/lib/models/providers/openai/openaiLLM.ts @@ -7,8 +7,16 @@ import { GenerateTextInput, GenerateTextOutput, StreamTextOutput, + ToolCall, } from '../../types'; import { parse } from 'partial-json'; +import z from 'zod'; +import { + ChatCompletionMessageParam, + ChatCompletionTool, + ChatCompletionToolMessageParam, +} from 'openai/resources/index.mjs'; +import { Message } from '@/lib/types'; type OpenAIConfig = { apiKey: string; @@ -29,10 +37,38 @@ class OpenAILLM extends BaseLLM { }); } + convertToOpenAIMessages(messages: Message[]): ChatCompletionMessageParam[] { + return messages.map((msg) => { + if (msg.role === 'tool') { + return { + role: 'tool', + tool_call_id: msg.id, + content: msg.content, + } as ChatCompletionToolMessageParam; + } + + return msg; + }); + } + async generateText(input: GenerateTextInput): Promise { + const openaiTools: ChatCompletionTool[] = []; + + input.tools?.forEach((tool) => { + openaiTools.push({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: z.toJSONSchema(tool.schema), + }, + }); + }); + const response = await this.openAIClient.chat.completions.create({ model: this.config.model, - messages: input.messages, + tools: openaiTools.length > 0 ? openaiTools : undefined, + messages: this.convertToOpenAIMessages(input.messages), temperature: input.options?.temperature ?? this.config.options?.temperature ?? 1.0, top_p: input.options?.topP ?? this.config.options?.topP, @@ -49,6 +85,18 @@ class OpenAILLM extends BaseLLM { if (response.choices && response.choices.length > 0) { return { content: response.choices[0].message.content!, + toolCalls: + response.choices[0].message.tool_calls + ?.map((tc) => { + if (tc.type === 'function') { + return { + name: tc.function.name, + id: tc.id, + arguments: JSON.parse(tc.function.arguments), + }; + } + }) + .filter((tc) => tc !== undefined) || [], additionalInfo: { finishReason: response.choices[0].finish_reason, }, @@ -61,9 +109,23 @@ class OpenAILLM extends BaseLLM { async *streamText( input: GenerateTextInput, ): AsyncGenerator { + const openaiTools: ChatCompletionTool[] = []; + + input.tools?.forEach((tool) => { + openaiTools.push({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: z.toJSONSchema(tool.schema), + }, + }); + }); + const stream = await this.openAIClient.chat.completions.create({ model: this.config.model, - messages: input.messages, + messages: this.convertToOpenAIMessages(input.messages), + tools: openaiTools.length > 0 ? openaiTools : undefined, temperature: input.options?.temperature ?? this.config.options?.temperature ?? 1.0, top_p: input.options?.topP ?? this.config.options?.topP, @@ -78,10 +140,33 @@ class OpenAILLM extends BaseLLM { stream: true, }); + let recievedToolCalls: { name: string; id: string; arguments: string }[] = + []; + for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { + const toolCalls = chunk.choices[0].delta.tool_calls; yield { contentChunk: chunk.choices[0].delta.content || '', + toolCallChunk: + toolCalls?.map((tc) => { + if (tc.type === 'function') { + const call = { + name: tc.function?.name!, + id: tc.id!, + arguments: tc.function?.arguments || '', + }; + recievedToolCalls.push(call); + return { ...call, arguments: parse(call.arguments || '{}') }; + } else { + const existingCall = recievedToolCalls[tc.index]; + existingCall.arguments += tc.function?.arguments || ''; + return { + ...existingCall, + arguments: parse(existingCall.arguments), + }; + } + }) || [], done: chunk.choices[0].finish_reason !== null, additionalInfo: { finishReason: chunk.choices[0].finish_reason,