feat(providers): add optimization modes

This commit is contained in:
ItzCrazyKns
2024-10-11 10:35:59 +05:30
parent 877735b852
commit 7cce853618
9 changed files with 294 additions and 88 deletions

View File

@ -118,7 +118,6 @@ const createBasicAcademicSearchRetrieverChain = (llm: BaseChatModel) => {
engines: [
'arxiv',
'google scholar',
'internetarchivescholar',
'pubmed',
],
});
@ -143,6 +142,7 @@ const createBasicAcademicSearchRetrieverChain = (llm: BaseChatModel) => {
const createBasicAcademicSearchAnsweringChain = (
llm: BaseChatModel,
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => {
const basicAcademicSearchRetrieverChain =
createBasicAcademicSearchRetrieverChain(llm);
@ -168,26 +168,33 @@ const createBasicAcademicSearchAnsweringChain = (
(doc) => doc.pageContent && doc.pageContent.length > 0,
);
const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(docsWithContent.map((doc) => doc.pageContent)),
embeddings.embedQuery(query),
]);
if (optimizationMode === 'speed') {
return docsWithContent.slice(0, 15);
} else if (optimizationMode === 'balanced') {
console.log('Balanced mode');
const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(
docsWithContent.map((doc) => doc.pageContent),
),
embeddings.embedQuery(query),
]);
const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding);
const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding);
return {
index: i,
similarity: sim,
};
});
return {
index: i,
similarity: sim,
};
});
const sortedDocs = similarity
.sort((a, b) => b.similarity - a.similarity)
.slice(0, 15)
.map((sim) => docsWithContent[sim.index]);
const sortedDocs = similarity
.sort((a, b) => b.similarity - a.similarity)
.slice(0, 15)
.map((sim) => docsWithContent[sim.index]);
return sortedDocs;
return sortedDocs;
}
};
return RunnableSequence.from([
@ -224,12 +231,17 @@ const basicAcademicSearch = (
history: BaseMessage[],
llm: BaseChatModel,
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => {
const emitter = new eventEmitter();
try {
const basicAcademicSearchAnsweringChain =
createBasicAcademicSearchAnsweringChain(llm, embeddings);
createBasicAcademicSearchAnsweringChain(
llm,
embeddings,
optimizationMode,
);
const stream = basicAcademicSearchAnsweringChain.streamEvents(
{
@ -258,8 +270,15 @@ const handleAcademicSearch = (
history: BaseMessage[],
llm: BaseChatModel,
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
) => {
const emitter = basicAcademicSearch(message, history, llm, embeddings);
const emitter = basicAcademicSearch(
message,
history,
llm,
embeddings,
optimizationMode,
);
return emitter;
};