diff --git a/src/lib/agents/search/researcher/actions/registry.ts b/src/lib/agents/search/researcher/actions/registry.ts index 3a8eda6..8e0530c 100644 --- a/src/lib/agents/search/researcher/actions/registry.ts +++ b/src/lib/agents/search/researcher/actions/registry.ts @@ -20,6 +20,7 @@ class ActionRegistry { static getAvailableActions(config: { classification: ClassifierOutput; + fileIds: string[]; mode: SearchAgentConfig['mode']; }): ResearchAction[] { return Array.from( @@ -29,6 +30,7 @@ class ActionRegistry { static getAvailableActionTools(config: { classification: ClassifierOutput; + fileIds: string[]; mode: SearchAgentConfig['mode']; }): Tool[] { const availableActions = this.getAvailableActions(config); @@ -42,19 +44,26 @@ class ActionRegistry { static getAvailableActionsDescriptions(config: { classification: ClassifierOutput; + fileIds: string[]; mode: SearchAgentConfig['mode']; }): string { const availableActions = this.getAvailableActions(config); return availableActions - .map((action) => `\n${action.getDescription({ mode: config.mode })}\n`) + .map( + (action) => + `\n${action.getDescription({ mode: config.mode })}\n`, + ) .join('\n\n'); } static async execute( name: string, params: any, - additionalConfig: AdditionalConfig & { researchBlockId: string }, + additionalConfig: AdditionalConfig & { + researchBlockId: string; + fileIds: string[]; + }, ) { const action = this.actions.get(name); @@ -67,7 +76,10 @@ class ActionRegistry { static async executeAll( actions: ToolCall[], - additionalConfig: AdditionalConfig & { researchBlockId: string }, + additionalConfig: AdditionalConfig & { + researchBlockId: string; + fileIds: string[]; + }, ): Promise { const results: ActionOutput[] = []; diff --git a/src/lib/agents/search/types.ts b/src/lib/agents/search/types.ts index f1ae862..0733de3 100644 --- a/src/lib/agents/search/types.ts +++ b/src/lib/agents/search/types.ts @@ -8,6 +8,7 @@ export type SearchSources = 'web' | 'discussions' | 'academic'; export type SearchAgentConfig = { sources: SearchSources[]; + fileIds: string[]; llm: BaseLLM; embedding: BaseEmbedding; mode: 'speed' | 'balanced' | 'quality'; @@ -102,11 +103,16 @@ export interface ResearchAction< schema: z.ZodObject; getToolDescription: (config: { mode: SearchAgentConfig['mode'] }) => string; getDescription: (config: { mode: SearchAgentConfig['mode'] }) => string; - enabled: (config: { classification: ClassifierOutput, mode: SearchAgentConfig['mode'] }) => boolean; + enabled: (config: { + classification: ClassifierOutput; + fileIds: string[]; + mode: SearchAgentConfig['mode']; + }) => boolean; execute: ( params: z.infer, additionalConfig: AdditionalConfig & { researchBlockId: string; + fileIds: string[]; }, ) => Promise; }