feat(search-agent): use function calling

This commit is contained in:
ItzCrazyKns
2025-12-06 15:38:40 +05:30
parent 2d82cd65d9
commit 9afea48d31

View File

@@ -1,43 +1,28 @@
import z from 'zod'; import { ActionOutput, ResearcherInput, ResearcherOutput } from '../types';
import {
ActionConfig,
ActionOutput,
ResearcherInput,
ResearcherOutput,
} from '../types';
import { ActionRegistry } from './actions'; import { ActionRegistry } from './actions';
import { getResearcherPrompt } from '@/lib/prompts/search/researcher'; import { getResearcherPrompt } from '@/lib/prompts/search/researcher';
import SessionManager from '@/lib/session'; import SessionManager from '@/lib/session';
import { ReasoningResearchBlock } from '@/lib/types'; import { Message, ReasoningResearchBlock } from '@/lib/types';
import formatChatHistoryAsString from '@/lib/utils/formatHistory'; import formatChatHistoryAsString from '@/lib/utils/formatHistory';
import { ToolCall } from '@/lib/models/types';
class Researcher { class Researcher {
async research( async research(
session: SessionManager, session: SessionManager,
input: ResearcherInput, input: ResearcherInput,
): Promise<ResearcherOutput> { ): Promise<ResearcherOutput> {
let findings: string = '';
let actionOutput: ActionOutput[] = []; let actionOutput: ActionOutput[] = [];
let maxIteration = let maxIteration =
input.config.mode === 'speed' input.config.mode === 'speed'
? 1 ? 2
: input.config.mode === 'balanced' : input.config.mode === 'balanced'
? 3 ? 6
: 25; : 25;
const availableActions = ActionRegistry.getAvailableActions({ const availableTools = ActionRegistry.getAvailableActionTools({
classification: input.classification, classification: input.classification,
}); });
const schema = z.object({
reasoning: z
.string()
.describe('The reasoning behind choosing the next action.'),
action: z
.union(availableActions.map((a) => a.schema))
.describe('The action to be performed next.'),
});
const availableActionsDescription = const availableActionsDescription =
ActionRegistry.getAvailableActionsDescriptions({ ActionRegistry.getAvailableActionsDescriptions({
classification: input.classification, classification: input.classification,
@@ -53,6 +38,18 @@ class Researcher {
}, },
}); });
const agentMessageHistory: Message[] = [
{
role: 'user',
content: `
<conversation>
${formatChatHistoryAsString(input.chatHistory.slice(-10))}
User: ${input.followUp} (Standalone question: ${input.classification.standaloneFollowUp})
</conversation>
`,
},
];
for (let i = 0; i < maxIteration; i++) { for (let i = 0; i < maxIteration; i++) {
const researcherPrompt = getResearcherPrompt( const researcherPrompt = getResearcherPrompt(
availableActionsDescription, availableActionsDescription,
@@ -61,27 +58,15 @@ class Researcher {
maxIteration, maxIteration,
); );
const actionStream = input.config.llm.streamObject<typeof schema>({ const actionStream = input.config.llm.streamText({
messages: [ messages: [
{ {
role: 'system', role: 'system',
content: researcherPrompt, content: researcherPrompt,
}, },
{ ...agentMessageHistory,
role: 'user',
content: `
<conversation>
${formatChatHistoryAsString(input.chatHistory.slice(-10))}
User: ${input.followUp} (Standalone question: ${input.classification.standaloneFollowUp})
</conversation>
<previous_actions>
${findings}
</previous_actions>
`,
},
], ],
schema, tools: availableTools,
}); });
const block = session.getBlock(researchBlockId); const block = session.getBlock(researchBlockId);
@@ -89,43 +74,26 @@ class Researcher {
let reasoningEmitted = false; let reasoningEmitted = false;
let reasoningId = crypto.randomUUID(); let reasoningId = crypto.randomUUID();
let finalActionRes: any; let finalToolCalls: ToolCall[] = [];
for await (const partialRes of actionStream) { for await (const partialRes of actionStream) {
try { if (partialRes.toolCallChunk.length > 0) {
if ( partialRes.toolCallChunk.forEach((tc) => {
partialRes.reasoning && if (
!reasoningEmitted && tc.name === '___plan' &&
block && tc.arguments['plan'] &&
block.type === 'research' !reasoningEmitted &&
) { block &&
reasoningEmitted = true; block.type === 'research'
block.data.subSteps.push({ ) {
id: reasoningId, reasoningEmitted = true;
type: 'reasoning',
reasoning: partialRes.reasoning, block.data.subSteps.push({
}); id: reasoningId,
session.updateBlock(researchBlockId, [ type: 'reasoning',
{ reasoning: tc.arguments['plan'],
op: 'replace', });
path: '/data/subSteps',
value: block.data.subSteps,
},
]);
} else if (
partialRes.reasoning &&
reasoningEmitted &&
block &&
block.type === 'research'
) {
const subStepIndex = block.data.subSteps.findIndex(
(step: any) => step.id === reasoningId,
);
if (subStepIndex !== -1) {
const subStep = block.data.subSteps[
subStepIndex
] as ReasoningResearchBlock;
subStep.reasoning = partialRes.reasoning;
session.updateBlock(researchBlockId, [ session.updateBlock(researchBlockId, [
{ {
op: 'replace', op: 'replace',
@@ -133,77 +101,118 @@ class Researcher {
value: block.data.subSteps, value: block.data.subSteps,
}, },
]); ]);
} } else if (
} tc.name === '___plan' &&
tc.arguments['plan'] &&
reasoningEmitted &&
block &&
block.type === 'research'
) {
const subStepIndex = block.data.subSteps.findIndex(
(step: any) => step.id === reasoningId,
);
finalActionRes = partialRes; if (subStepIndex !== -1) {
} catch (e) { const subStep = block.data.subSteps[
// nothing subStepIndex
] as ReasoningResearchBlock;
subStep.reasoning = tc.arguments['plan'];
session.updateBlock(researchBlockId, [
{
op: 'replace',
path: '/data/subSteps',
value: block.data.subSteps,
},
]);
}
}
const existingIndex = finalToolCalls.findIndex(
(ftc) => ftc.id === tc.id,
);
if (existingIndex !== -1) {
finalToolCalls[existingIndex].arguments = tc.arguments;
} else {
finalToolCalls.push(tc);
}
});
} }
} }
if (finalActionRes.action.type === 'done') { if (finalToolCalls.length === 0) {
break; break;
} }
const actionConfig: ActionConfig = { if (finalToolCalls[finalToolCalls.length - 1].name === 'done') {
type: finalActionRes.action.type as string, break;
params: finalActionRes.action, }
};
const queries = actionConfig.params.queries || []; agentMessageHistory.push({
if (block && block.type === 'research') { role: 'assistant',
content: '',
tool_calls: finalToolCalls,
});
const searchCalls = finalToolCalls.filter(
(tc) =>
tc.name === 'web_search' ||
tc.name === 'academic_search' ||
tc.name === 'discussion_search',
);
if (searchCalls.length > 0 && block && block.type === 'research') {
block.data.subSteps.push({ block.data.subSteps.push({
id: crypto.randomUUID(), id: crypto.randomUUID(),
type: 'searching', type: 'searching',
searching: queries, searching: searchCalls.map((sc) => sc.arguments.queries).flat(),
}); });
session.updateBlock(researchBlockId, [ session.updateBlock(researchBlockId, [
{ op: 'replace', path: '/data/subSteps', value: block.data.subSteps }, {
op: 'replace',
path: '/data/subSteps',
value: block.data.subSteps,
},
]); ]);
} }
findings += `\n---\nIteration ${i + 1}:\n`; const actionResults = await ActionRegistry.executeAll(finalToolCalls, {
findings += 'Reasoning: ' + finalActionRes.reasoning + '\n'; llm: input.config.llm,
findings += `Executing Action: ${actionConfig.type} with params ${JSON.stringify(actionConfig.params)}\n`; embedding: input.config.embedding,
session: session,
});
const actionResult = await ActionRegistry.execute( actionOutput.push(...actionResults);
actionConfig.type,
actionConfig.params, actionResults.forEach((action, i) => {
{ agentMessageHistory.push({
llm: input.config.llm, role: 'tool',
embedding: input.config.embedding, id: finalToolCalls[i].id,
session: session, name: finalToolCalls[i].name,
}, content: JSON.stringify(action),
});
});
const searchResults = actionResults.filter(
(a) => a.type === 'search_results',
); );
actionOutput.push(actionResult); if (searchResults.length > 0 && block && block.type === 'research') {
block.data.subSteps.push({
id: crypto.randomUUID(),
type: 'reading',
reading: searchResults.flatMap((a) => a.results),
});
if (actionResult.type === 'search_results') { session.updateBlock(researchBlockId, [
if (block && block.type === 'research') { {
block.data.subSteps.push({ op: 'replace',
id: crypto.randomUUID(), path: '/data/subSteps',
type: 'reading', value: block.data.subSteps,
reading: actionResult.results, },
}); ]);
session.updateBlock(researchBlockId, [
{
op: 'replace',
path: '/data/subSteps',
value: block.data.subSteps,
},
]);
}
findings += actionResult.results
.map(
(r) =>
`Title: ${r.metadata.title}\nURL: ${r.metadata.url}\nContent: ${r.content}\n`,
)
.join('\n');
} }
findings += '\n---------\n';
} }
const searchResults = actionOutput.filter( const searchResults = actionOutput.filter(
@@ -212,12 +221,7 @@ class Researcher {
session.emit('data', { session.emit('data', {
type: 'sources', type: 'sources',
data: searchResults data: searchResults.flatMap((a) => a.results),
.flatMap((a) => a.results)
.map((r) => ({
content: r.content,
metadata: r.metadata,
})),
}); });
return { return {