diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py index 1692a2d..7cdabfa 100644 --- a/openai_compatible_inference_bot.py +++ b/openai_compatible_inference_bot.py @@ -20,7 +20,8 @@ class OpenAICompatibleInferenceBot(InferenceBot): large_model_name: str | None = None, large_model_max_tokens: str | None = None, allowed_function_tags: list[str] | None = None, - system_prompt_path: str | None = None + system_prompt_path: str | None = None, + use_large_model: bool = False # New argument ): self.model_config = { "small_model_name": small_model_name, @@ -42,10 +43,16 @@ class OpenAICompatibleInferenceBot(InferenceBot): logging.info(log_msg) # Configure the actual model name and max_tokens for API calls - self._configure_model_and_tokens( - self.model_config["small_model_name"], - self.model_config["small_model_max_tokens"] - ) + if use_large_model: + self._configure_model_and_tokens( + self.model_config["large_model_name"], + self.model_config["large_model_max_tokens"] + ) + else: + self._configure_model_and_tokens( + self.model_config["small_model_name"], + self.model_config["small_model_max_tokens"] + ) @property def processing_status(self): """ @@ -238,7 +245,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): module_name = f'tools.{filename[:-3]}' try: module = importlib.import_module(module_name) - for name, obj in inspect.getmembers(module): + for name, obj in inspect.getmembers(module):\ if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: try: tools.append(obj()) # This instantiation might be an issue for tools needing config @@ -350,6 +357,7 @@ def main(): parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True) parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False) parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False) + parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model') # New argument # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate" # Parse command line arguments args = parser.parse_args() @@ -361,6 +369,7 @@ def main(): allowed_function_tags=args.tools if args.tools else None config_prepend = args.config if args.config else None messenger = args.messenger if args.messenger else None + use_large_model = args.use_large_model # Get the value of the new argument # Initialize model and max tokens based on the config prepend if config_prepend: @@ -379,7 +388,8 @@ def main(): large_model_name=large_model_name, large_model_max_tokens=large_model_max_tokens, system_prompt_path=system_prompt_path, - allowed_function_tags=allowed_function_tags + allowed_function_tags=allowed_function_tags, + use_large_model=use_large_model # Pass the new argument ) full_code_file = importlib.import_module(f'{messenger.lower()}_helper') messenger_helper_class_name = f"{messenger.capitalize()}Helper"