From 046f1595286f92f8a9245e29fb36224092658886 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:52:40 +0530 Subject: [PATCH] feat(widgets): use new classifier, implement new widget executor, delete registry --- src/lib/agents/search/index.ts | 18 ++-- .../search/researcher/actions/webSearch.ts | 2 +- src/lib/agents/search/researcher/index.ts | 4 +- src/lib/agents/search/types.ts | 25 ++--- .../search/widgets/calculationWidget.ts | 96 +++++++++---------- src/lib/agents/search/widgets/executor.ts | 36 +++++++ src/lib/agents/search/widgets/index.ts | 10 +- src/lib/agents/search/widgets/registry.ts | 65 ------------- src/lib/agents/search/widgets/stockWidget.ts | 91 +++++++++--------- .../agents/search/widgets/weatherWidget.ts | 79 +++++++++------ 10 files changed, 204 insertions(+), 222 deletions(-) create mode 100644 src/lib/agents/search/widgets/executor.ts delete mode 100644 src/lib/agents/search/widgets/registry.ts diff --git a/src/lib/agents/search/index.ts b/src/lib/agents/search/index.ts index ab12bfe..bc6ff5b 100644 --- a/src/lib/agents/search/index.ts +++ b/src/lib/agents/search/index.ts @@ -1,26 +1,24 @@ import { ResearcherOutput, SearchAgentInput } from './types'; import SessionManager from '@/lib/session'; -import Classifier from './classifier'; -import { WidgetRegistry } from './widgets'; +import { classify } from './classifier'; import Researcher from './researcher'; import { getWriterPrompt } from '@/lib/prompts/search/writer'; -import fs from 'fs'; +import { WidgetExecutor } from './widgets'; class SearchAgent { async searchAsync(session: SessionManager, input: SearchAgentInput) { - const classifier = new Classifier(); - - const classification = await classifier.classify({ + const classification = await classify({ chatHistory: input.chatHistory, enabledSources: input.config.sources, query: input.followUp, llm: input.config.llm, }); - const widgetPromise = WidgetRegistry.executeAll(classification.widgets, { + const widgetPromise = WidgetExecutor.executeAll({ + classification, + chatHistory: input.chatHistory, + followUp: input.followUp, llm: input.config.llm, - embedding: input.config.embedding, - session: session, }).then((widgetOutputs) => { widgetOutputs.forEach((o) => { session.emitBlock({ @@ -37,7 +35,7 @@ class SearchAgent { let searchPromise: Promise | null = null; - if (!classification.skipSearch) { + if (!classification.classification.skipSearch) { const researcher = new Researcher(); searchPromise = researcher.research(session, { chatHistory: input.chatHistory, diff --git a/src/lib/agents/search/researcher/actions/webSearch.ts b/src/lib/agents/search/researcher/actions/webSearch.ts index e5ffdd3..17f5e61 100644 --- a/src/lib/agents/search/researcher/actions/webSearch.ts +++ b/src/lib/agents/search/researcher/actions/webSearch.ts @@ -26,7 +26,7 @@ const webSearchAction: ResearchAction = { name: 'web_search', description: actionDescription, schema: actionSchema, - enabled: (config) => config.classification.intents.includes('web_search'), + enabled: (config) => true, execute: async (input, _) => { let results: Chunk[] = []; diff --git a/src/lib/agents/search/researcher/index.ts b/src/lib/agents/search/researcher/index.ts index 3eb9478..9ca1fed 100644 --- a/src/lib/agents/search/researcher/index.ts +++ b/src/lib/agents/search/researcher/index.ts @@ -61,9 +61,7 @@ class Researcher { maxIteration, ); - const actionStream = input.config.llm.streamObject< - z.infer - >({ + const actionStream = input.config.llm.streamObject({ messages: [ { role: 'system', diff --git a/src/lib/agents/search/types.ts b/src/lib/agents/search/types.ts index c30d4f7..589fa2d 100644 --- a/src/lib/agents/search/types.ts +++ b/src/lib/agents/search/types.ts @@ -19,26 +19,17 @@ export type SearchAgentInput = { config: SearchAgentConfig; }; -export interface Intent { - name: string; - description: string; - requiresSearch: boolean; - enabled: (config: { sources: SearchSources[] }) => boolean; -} - -export type Widget = z.ZodObject> = { - name: string; - description: string; - schema: TSchema; - execute: ( - params: z.infer, - additionalConfig: AdditionalConfig, - ) => Promise; +export type WidgetInput = { + chatHistory: ChatTurnMessage[]; + followUp: string; + classification: ClassifierOutput; + llm: BaseLLM; }; -export type WidgetConfig = { +export type Widget = { type: string; - params: Record; + shouldExecute: (classification: ClassifierOutput) => boolean; + execute: (input: WidgetInput) => Promise; }; export type WidgetOutput = { diff --git a/src/lib/agents/search/widgets/calculationWidget.ts b/src/lib/agents/search/widgets/calculationWidget.ts index 1c1ba51..0026741 100644 --- a/src/lib/agents/search/widgets/calculationWidget.ts +++ b/src/lib/agents/search/widgets/calculationWidget.ts @@ -1,66 +1,66 @@ import z from 'zod'; import { Widget } from '../types'; -import { evaluate as mathEval } from 'mathjs'; +import formatChatHistoryAsString from '@/lib/utils/formatHistory'; +import { exp, evaluate as mathEval } from 'mathjs'; const schema = z.object({ - type: z.literal('calculation'), expression: z .string() - .describe( - "A valid mathematical expression to be evaluated (e.g., '2 + 2', '3 * (4 + 5)').", - ), + .describe('Mathematical expression to calculate or evaluate.'), + notPresent: z + .boolean() + .describe('Whether there is any need for the calculation widget.'), }); -const calculationWidget: Widget = { - name: 'calculation', - description: `Performs mathematical calculations and evaluates mathematical expressions. Supports arithmetic operations, algebraic equations, functions, and complex mathematical computations. +const system = ` + +Assistant is a calculation expression extractor. You will recieve a user follow up and a conversation history. +Your task is to determine if there is a mathematical expression that needs to be calculated or evaluated. If there is, extract the expression and return it. If there is no need for any calculation, set notPresent to true. + -**What it provides:** -- Evaluates mathematical expressions and returns computed results -- Handles basic arithmetic (+, -, *, /) -- Supports functions (sqrt, sin, cos, log, etc.) -- Can process complex expressions with parentheses and order of operations + +Make sure that the extracted expression is valid and can be used to calculate the result with Math JS library (https://mathjs.org/). If the expression is not valid, set notPresent to true. +If you feel like you cannot extract a valid expression, set notPresent to true. + -**When to use:** -- User asks to calculate, compute, or evaluate a mathematical expression -- Questions like "what is X", "calculate Y", "how much is Z" where X/Y/Z are math expressions -- Any request involving numbers and mathematical operations - -**Example call:** + +You must respond in the following JSON format without any extra text, explanations or filler sentences: { - "type": "calculation", - "expression": "25% of 480" + "expression": string, + "notPresent": boolean } + +`; -{ - "type": "calculation", - "expression": "sqrt(144) + 5 * 2" -} - -**Important:** The expression must be valid mathematical syntax that can be evaluated by mathjs. Format percentages as "0.25 * 480" or "25% of 480". Do not include currency symbols, units, or non-mathematical text in the expression.`, - schema: schema, - execute: async (params, _) => { - try { - const result = mathEval(params.expression); - - return { - type: 'calculation_result', - llmContext: `The result of the expression "${params.expression}" is ${result}.`, - data: { - expression: params.expression, - result: result, +const calculationWidget: Widget = { + type: 'calculationWidget', + shouldExecute: (classification) => + classification.classification.showCalculationWidget, + execute: async (input) => { + const output = await input.llm.generateObject({ + messages: [ + { + role: 'system', + content: system, }, - }; - } catch (error) { - return { - type: 'calculation_result', - llmContext: 'Failed to evaluate mathematical expression.', - data: { - expression: params.expression, - result: `Error evaluating expression: ${error}`, + { + role: 'user', + content: `\n${formatChatHistoryAsString(input.chatHistory)}\n\n\n${input.followUp}\n`, }, - }; - } + ], + schema, + }); + + const result = mathEval(output.expression); + + return { + type: 'calculation_result', + llmContext: `The result of the calculation for the expression "${output.expression}" is: ${result}`, + data: { + expression: output.expression, + result, + }, + }; }, }; diff --git a/src/lib/agents/search/widgets/executor.ts b/src/lib/agents/search/widgets/executor.ts new file mode 100644 index 0000000..89f1830 --- /dev/null +++ b/src/lib/agents/search/widgets/executor.ts @@ -0,0 +1,36 @@ +import { Widget, WidgetInput, WidgetOutput } from '../types'; + +class WidgetExecutor { + static widgets = new Map(); + + static register(widget: Widget) { + this.widgets.set(widget.type, widget); + } + + static getWidget(type: string): Widget | undefined { + return this.widgets.get(type); + } + + static async executeAll(input: WidgetInput): Promise { + const results: WidgetOutput[] = []; + + await Promise.all( + Array.from(this.widgets.values()).map(async (widget) => { + try { + if (widget.shouldExecute(input.classification)) { + const output = await widget.execute(input); + if (output) { + results.push(output); + } + } + } catch (e) { + console.log(`Error executing widget ${widget.type}:`, e); + } + }), + ); + + return results; + } +} + +export default WidgetExecutor; diff --git a/src/lib/agents/search/widgets/index.ts b/src/lib/agents/search/widgets/index.ts index ff18d40..9958b0d 100644 --- a/src/lib/agents/search/widgets/index.ts +++ b/src/lib/agents/search/widgets/index.ts @@ -1,10 +1,10 @@ import calculationWidget from './calculationWidget'; -import WidgetRegistry from './registry'; +import WidgetExecutor from './executor'; import weatherWidget from './weatherWidget'; import stockWidget from './stockWidget'; -WidgetRegistry.register(weatherWidget); -WidgetRegistry.register(calculationWidget); -WidgetRegistry.register(stockWidget); +WidgetExecutor.register(weatherWidget); +WidgetExecutor.register(calculationWidget); +WidgetExecutor.register(stockWidget); -export { WidgetRegistry }; +export { WidgetExecutor }; diff --git a/src/lib/agents/search/widgets/registry.ts b/src/lib/agents/search/widgets/registry.ts deleted file mode 100644 index d8ceaba..0000000 --- a/src/lib/agents/search/widgets/registry.ts +++ /dev/null @@ -1,65 +0,0 @@ -import { - AdditionalConfig, - SearchAgentConfig, - Widget, - WidgetConfig, - WidgetOutput, -} from '../types'; - -class WidgetRegistry { - private static widgets = new Map(); - - static register(widget: Widget) { - this.widgets.set(widget.name, widget); - } - - static get(name: string): Widget | undefined { - return this.widgets.get(name); - } - - static getAll(): Widget[] { - return Array.from(this.widgets.values()); - } - - static getDescriptions(): string { - return Array.from(this.widgets.values()) - .map((widget) => `${widget.name}: ${widget.description}`) - .join('\n\n'); - } - - static async execute( - name: string, - params: any, - config: AdditionalConfig, - ): Promise { - const widget = this.get(name); - - if (!widget) { - throw new Error(`Widget with name ${name} not found`); - } - - return widget.execute(params, config); - } - - static async executeAll( - widgets: WidgetConfig[], - additionalConfig: AdditionalConfig, - ): Promise { - const results: WidgetOutput[] = []; - - await Promise.all( - widgets.map(async (widgetConfig) => { - const output = await this.execute( - widgetConfig.type, - widgetConfig.params, - additionalConfig, - ); - results.push(output); - }), - ); - - return results; - } -} - -export default WidgetRegistry; diff --git a/src/lib/agents/search/widgets/stockWidget.ts b/src/lib/agents/search/widgets/stockWidget.ts index d728460..4ac2059 100644 --- a/src/lib/agents/search/widgets/stockWidget.ts +++ b/src/lib/agents/search/widgets/stockWidget.ts @@ -1,13 +1,13 @@ import z from 'zod'; import { Widget } from '../types'; import YahooFinance from 'yahoo-finance2'; +import formatChatHistoryAsString from '@/lib/utils/formatHistory'; const yf = new YahooFinance({ suppressNotices: ['yahooSurvey'], }); const schema = z.object({ - type: z.literal('stock'), name: z .string() .describe( @@ -19,60 +19,59 @@ const schema = z.object({ .describe( "Optional array of up to 3 stock names to compare against the base name (e.g., ['Microsoft', 'GOOGL', 'Meta']). Charts will show percentage change comparison.", ), + notPresent: z + .boolean() + .describe('Whether there is no need for the stock widget.'), }); -const stockWidget: Widget = { - name: 'stock', - description: `Provides comprehensive real-time stock market data and financial information for any publicly traded company. Returns detailed quote data, market status, trading metrics, and company fundamentals. +const systemPrompt = ` + +You are a stock ticker/name extractor. You will receive a user follow up and a conversation history. +Your task is to determine if the user is asking about stock information and extract the stock name(s) they want data for. + -You can set skipSearch to true if the stock widget can fully answer the user's query without needing additional web search. + +- If the user is asking about a stock, extract the primary stock name or ticker. +- If the user wants to compare stocks, extract up to 3 comparison stock names in comparisonNames. +- You can use either stock names (e.g., "Nvidia", "Apple") or tickers (e.g., "NVDA", "AAPL"). +- If you cannot determine a valid stock or the query is not stock-related, set notPresent to true. +- If no comparison is needed, set comparisonNames to an empty array. + -**What it provides:** -- **Real-time Price Data**: Current price, previous close, open price, day's range (high/low) -- **Market Status**: Whether market is currently open or closed, trading sessions -- **Trading Metrics**: Volume, average volume, bid/ask prices and sizes -- **Performance**: Price changes (absolute and percentage), 52-week high/low range -- **Valuation**: Market capitalization, P/E ratio, earnings per share (EPS) -- **Dividends**: Dividend rate, dividend yield, ex-dividend date -- **Company Info**: Full company name, exchange, currency, sector/industry (when available) -- **Advanced Metrics**: Beta, trailing/forward P/E, book value, price-to-book ratio -- **Charts Data**: Historical price movements for visualization -- **Comparison**: Compare up to 3 stocks side-by-side with percentage-based performance visualization - -**When to use:** -- User asks about a stock price ("What's AAPL stock price?", "How is Tesla doing?") -- Questions about company market performance ("Is Microsoft up or down today?") -- Requests for stock market data, trading info, or company valuation -- Queries about dividends, P/E ratio, market cap, or other financial metrics -- Any stock/equity-related question for a specific company -- Stock comparisons ("Compare AAPL vs MSFT", "How is TSLA doing vs RIVN and LCID?") - -**Example calls:** + +You must respond in the following JSON format without any extra text, explanations or filler sentences: { - "type": "stock", - "name": "AAPL" + "name": string, + "comparisonNames": string[], + "notPresent": boolean } + +`; -{ - "type": "stock", - "name": "TSLA", - "comparisonNames": ["RIVN", "LCID"] -} +const stockWidget: Widget = { + type: 'stockWidget', + shouldExecute: (classification) => + classification.classification.showStockWidget, + execute: async (input) => { + const output = await input.llm.generateObject({ + messages: [ + { + role: 'system', + content: systemPrompt, + }, + { + role: 'user', + content: `\n${formatChatHistoryAsString(input.chatHistory)}\n\n\n${input.followUp}\n`, + }, + ], + schema, + }); -{ - "type": "stock", - "name": "Google", - "comparisonNames": ["Microsoft", "Meta", "Amazon"] -} + if (output.notPresent) { + return; + } -**Important:** -- You can use both tickers and names (prefer name when you're not aware of the ticker). -- For companies with multiple share classes, use the most common one. -- The widget works for stocks listed on major exchanges (NYSE, NASDAQ, etc.) -- Returns comprehensive data; the UI will display relevant metrics based on availability -- Market data may be delayed by 15-20 minutes for free data sources during trading hours`, - schema: schema, - execute: async (params, _) => { + const params = output; try { const name = params.name; diff --git a/src/lib/agents/search/widgets/weatherWidget.ts b/src/lib/agents/search/widgets/weatherWidget.ts index 2c6d7ab..4739324 100644 --- a/src/lib/agents/search/widgets/weatherWidget.ts +++ b/src/lib/agents/search/widgets/weatherWidget.ts @@ -1,8 +1,8 @@ import z from 'zod'; import { Widget } from '../types'; +import formatChatHistoryAsString from '@/lib/utils/formatHistory'; -const WeatherWidgetSchema = z.object({ - type: z.literal('weather'), +const schema = z.object({ location: z .string() .describe( @@ -18,38 +18,63 @@ const WeatherWidgetSchema = z.object({ .describe( 'Longitude coordinate in decimal degrees (e.g., -74.0060). Only use when location name is empty.', ), + notPresent: z + .boolean() + .describe('Whether there is no need for the weather widget.'), }); -const weatherWidget: Widget = { - name: 'weather', - description: `Provides comprehensive current weather information and forecasts for any location worldwide. Returns real-time weather data including temperature, conditions, humidity, wind, and multi-day forecasts. +const systemPrompt = ` + +You are a location extractor for weather queries. You will receive a user follow up and a conversation history. +Your task is to determine if the user is asking about weather and extract the location they want weather for. + -You can set skipSearch to true if the weather widget can fully answer the user's query without needing additional web search. + +- If the user is asking about weather, extract the location name OR coordinates (never both). +- If using location name, set lat and lon to 0. +- If using coordinates, set location to empty string. +- If you cannot determine a valid location or the query is not weather-related, set notPresent to true. +- Location should be specific (city, state/region, country) for best results. +- You have to give the location so that it can be used to fetch weather data, it cannot be left empty unless notPresent is true. +- Make sure to infer short forms of location names (e.g., "NYC" -> "New York City", "LA" -> "Los Angeles"). + -**What it provides:** -- Current weather conditions (temperature, feels-like, humidity, precipitation) -- Wind speed, direction, and gusts -- Weather codes/conditions (clear, cloudy, rainy, etc.) -- Hourly forecast for next 24 hours -- Daily forecast for next 7 days (high/low temps, precipitation probability) -- Timezone information - -**When to use:** -- User asks about weather in a location ("weather in X", "is it raining in Y") -- Questions about temperature, conditions, or forecast -- Any weather-related query for a specific place - -**Example call:** + +You must respond in the following JSON format without any extra text, explanations or filler sentences: { - "type": "weather", - "location": "San Francisco, CA, USA", - "lat": 0, - "lon": 0 + "location": string, + "lat": number, + "lon": number, + "notPresent": boolean } + +`; + +const weatherWidget: Widget = { + type: 'weatherWidget', + shouldExecute: (classification) => + classification.classification.showWeatherWidget, + execute: async (input) => { + const output = await input.llm.generateObject({ + messages: [ + { + role: 'system', + content: systemPrompt, + }, + { + role: 'user', + content: `\n${formatChatHistoryAsString(input.chatHistory)}\n\n\n${input.followUp}\n`, + }, + ], + schema, + }); + + if (output.notPresent) { + return; + } + + const params = output; -**Important:** Provide EITHER a location name OR latitude/longitude coordinates, never both. If using location name, set lat/lon to 0. Location should be specific (city, state/region, country) for best results.`, - schema: WeatherWidgetSchema, - execute: async (params, _) => { try { if ( params.location === '' &&