@@ -363,6 +363,10 @@ def get_answer(llm: AzureChatOpenAI,
363
363
364
364
class SearchInput (BaseModel ):
365
365
query : str = Field (description = "should be a search query" )
366
+ return_direct : bool = Field (
367
+ description = "Whether or the result of this should be returned directly to the user without you seeing what it is" ,
368
+ default = False ,
369
+ )
366
370
367
371
class GetDocSearchResults_Tool (BaseTool ):
368
372
name = "docsearch"
@@ -375,7 +379,7 @@ class GetDocSearchResults_Tool(BaseTool):
375
379
sas_token : str = ""
376
380
377
381
def _run (
378
- self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None
382
+ self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None
379
383
) -> str :
380
384
381
385
retriever = CustomAzureSearchRetriever (indexes = self .indexes , topK = self .k , reranker_threshold = self .reranker_th ,
@@ -385,7 +389,7 @@ def _run(
385
389
return results
386
390
387
391
async def _arun (
388
- self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None
392
+ self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None
389
393
) -> str :
390
394
"""Use the tool asynchronously."""
391
395
@@ -424,15 +428,15 @@ def __init__(self, **data):
424
428
self .agent_executor = AgentExecutor (agent = agent , tools = tools , verbose = self .verbose , callback_manager = self .callbacks , handle_parsing_errors = True )
425
429
426
430
427
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
431
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
428
432
try :
429
433
result = self .agent_executor .invoke ({"question" : query })
430
434
return result ['output' ]
431
435
except Exception as e :
432
436
print (e )
433
437
return str (e ) # Return an empty string or some error indicator
434
438
435
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
439
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
436
440
try :
437
441
result = await self .agent_executor .ainvoke ({"question" : query })
438
442
return result ['output' ]
@@ -465,7 +469,7 @@ def __init__(self, **data):
465
469
callback_manager = self .callbacks ,
466
470
)
467
471
468
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
472
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
469
473
try :
470
474
# Use the initialized agent_executor to invoke the query
471
475
result = self .agent_executor .invoke (query )
@@ -474,7 +478,7 @@ def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = No
474
478
print (e )
475
479
return str (e ) # Return an error indicator
476
480
477
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
481
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
478
482
# Note: Implementation assumes the agent_executor and its methods support async operations
479
483
try :
480
484
# Use the initialized agent_executor to asynchronously invoke the query
@@ -528,7 +532,7 @@ def get_db_config(self):
528
532
'query' : {'driver' : 'ODBC Driver 17 for SQL Server' }
529
533
}
530
534
531
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
535
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
532
536
try :
533
537
# Use the initialized agent_executor to invoke the query
534
538
result = self .agent_executor .invoke (query )
@@ -537,7 +541,7 @@ def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = No
537
541
print (e )
538
542
return str (e ) # Return an error indicator
539
543
540
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
544
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
541
545
# Note: Implementation assumes the agent_executor and its methods support async operations
542
546
try :
543
547
# Use the initialized agent_executor to asynchronously invoke the query
@@ -567,15 +571,15 @@ def __init__(self, **data):
567
571
output_parser = StrOutputParser ()
568
572
self .chatgpt_chain = CHATGPT_PROMPT | self .llm | output_parser
569
573
570
- def _run (self , query : str ) -> str :
574
+ def _run (self , query : str , return_direct = False , run_manager : Optional [ CallbackManagerForToolRun ] = None ) -> str :
571
575
try :
572
576
response = self .chatgpt_chain .invoke ({"question" : query })
573
577
return response
574
578
except Exception as e :
575
579
print (e )
576
580
return str (e ) # Return an error indicator
577
581
578
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
582
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
579
583
"""Implement the tool to be used asynchronously."""
580
584
try :
581
585
response = await self .chatgpt_chain .ainvoke ({"question" : query })
@@ -595,14 +599,14 @@ class GetBingSearchResults_Tool(BaseTool):
595
599
596
600
k : int = 5
597
601
598
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
602
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
599
603
bing = BingSearchAPIWrapper (k = self .k )
600
604
try :
601
605
return bing .results (query ,num_results = self .k )
602
606
except :
603
607
return "No Results Found"
604
608
605
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
609
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
606
610
bing = BingSearchAPIWrapper (k = self .k )
607
611
loop = asyncio .get_event_loop ()
608
612
try :
@@ -635,7 +639,9 @@ def __init__(self, **data):
635
639
description = "useful to fetch the content of a url"
636
640
)
637
641
638
- tools = [GetBingSearchResults_Tool (k = self .k ), web_fetch_tool ]
642
+ tools = [GetBingSearchResults_Tool (k = self .k )]
643
+ # tools = [GetBingSearchResults_Tool(k=self.k), web_fetch_tool] # Uncomment if using GPT-4
644
+
639
645
agent = create_openai_tools_agent (self .llm , tools , BINGSEARCH_PROMPT )
640
646
641
647
self .agent_executor = AgentExecutor (agent = agent , tools = tools ,
@@ -656,15 +662,15 @@ def fetch_web_page(self, url: str) -> str:
656
662
response = requests .get (url , headers = HEADERS )
657
663
return self .parse_html (response .content )
658
664
659
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
665
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
660
666
try :
661
667
response = self .agent_executor .invoke ({"question" : query })
662
668
return response ['output' ]
663
669
except Exception as e :
664
670
print (e )
665
671
return str (e ) # Return an error indicator
666
672
667
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
673
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
668
674
"""Implements the tool to be used asynchronously."""
669
675
try :
670
676
response = await self .agent_executor .ainvoke ({"question" : query })
@@ -701,7 +707,7 @@ def __init__(self, **data):
701
707
limit_to_domains = self .limit_to_domains
702
708
)
703
709
704
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
710
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
705
711
try :
706
712
# Optionally sleep to avoid possible TPM rate limits
707
713
sleep (2 )
@@ -711,7 +717,7 @@ def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = No
711
717
712
718
return response
713
719
714
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
720
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
715
721
"""Use the tool asynchronously."""
716
722
loop = asyncio .get_event_loop ()
717
723
try :
@@ -757,7 +763,7 @@ def __init__(self, **data):
757
763
return_intermediate_steps = True ,
758
764
callback_manager = self .callbacks )
759
765
760
- def _run (self , query : str , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
766
+ def _run (self , query : str , return_direct = False , run_manager : Optional [CallbackManagerForToolRun ] = None ) -> str :
761
767
try :
762
768
# Use the initialized agent_executor to invoke the query
763
769
response = self .agent_executor .invoke ({"question" :query })
@@ -766,7 +772,7 @@ def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = No
766
772
print (e )
767
773
return str (e ) # Return an error indicator
768
774
769
- async def _arun (self , query : str , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
775
+ async def _arun (self , query : str , return_direct = False , run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ) -> str :
770
776
# Note: Implementation assumes the agent_executor and its methods support async operations
771
777
try :
772
778
# Use the initialized agent_executor to asynchronously invoke the query
0 commit comments