This commit is contained in:
2025-06-05 18:07:42 -05:00
+12 -2
View File
@@ -20,7 +20,8 @@ class OpenAICompatibleInferenceBot(InferenceBot):
large_model_name: str | None = None, large_model_name: str | None = None,
large_model_max_tokens: str | None = None, large_model_max_tokens: str | None = None,
allowed_function_tags: list[str] | None = None, allowed_function_tags: list[str] | None = None,
system_prompt_path: str | None = None system_prompt_path: str | None = None,
use_large_model: bool = False # New argument
): ):
self.model_config = { self.model_config = {
"small_model_name": small_model_name, "small_model_name": small_model_name,
@@ -42,6 +43,12 @@ class OpenAICompatibleInferenceBot(InferenceBot):
logging.info(log_msg) logging.info(log_msg)
# Configure the actual model name and max_tokens for API calls # Configure the actual model name and max_tokens for API calls
if use_large_model:
self._configure_model_and_tokens(
self.model_config["large_model_name"],
self.model_config["large_model_max_tokens"]
)
else:
self._configure_model_and_tokens( self._configure_model_and_tokens(
self.model_config["small_model_name"], self.model_config["small_model_name"],
self.model_config["small_model_max_tokens"] self.model_config["small_model_max_tokens"]
@@ -350,6 +357,7 @@ def main():
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True) parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False) parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False) parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model') # New argument
# Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate" # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
# Parse command line arguments # Parse command line arguments
args = parser.parse_args() args = parser.parse_args()
@@ -361,6 +369,7 @@ def main():
allowed_function_tags=args.tools if args.tools else None allowed_function_tags=args.tools if args.tools else None
config_prepend = args.config if args.config else None config_prepend = args.config if args.config else None
messenger = args.messenger if args.messenger else None messenger = args.messenger if args.messenger else None
use_large_model = args.use_large_model # Get the value of the new argument
# Initialize model and max tokens based on the config prepend # Initialize model and max tokens based on the config prepend
if config_prepend: if config_prepend:
@@ -379,7 +388,8 @@ def main():
large_model_name=large_model_name, large_model_name=large_model_name,
large_model_max_tokens=large_model_max_tokens, large_model_max_tokens=large_model_max_tokens,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
allowed_function_tags=allowed_function_tags allowed_function_tags=allowed_function_tags,
use_large_model=use_large_model # Pass the new argument
) )
full_code_file = importlib.import_module(f'{messenger.lower()}_helper') full_code_file = importlib.import_module(f'{messenger.lower()}_helper')
messenger_helper_class_name = f"{messenger.capitalize()}Helper" messenger_helper_class_name = f"{messenger.capitalize()}Helper"