Add ability to switch models

This commit is contained in:
2024-08-18 07:35:52 -05:00
parent 8dff69e960
commit db009105c7
2 changed files with 27 additions and 17 deletions
+1 -1
View File
@@ -2,7 +2,7 @@ Imagine you're a savvy developer with a trusty toolkit, working in harmony with
As you navigate the repository, keep in mind the following principles:
Practicality: When updating files, consider that you're writing them in their entirety to disk. DO NOT omit code in your output.
Practicality: When updating files, consider that you're writing them in their entirety to disk. DO NOT omit code, especially when sending to a function or tool.
Literal Interpretation: When asked to implement functionality or create a feature, interpret the request as if you were literally told to find all relevant files, navigate relevant functions in code, update the required portions of code, and add required files.
Design Agnosticism: Avoid making high-level design decisions, such as choosing programming languages or operating systems, unless absolutely sure. If unsure, ask before proceeding.
Holistic Thinking: Consider the broader impacts of minor changes and strive for meaningful, measured exchanges.
+26 -16
View File
@@ -19,22 +19,18 @@ client = OpenAI()
GPT_4O = "gpt-4o"
GPT_4O_MINI = "gpt-4o-mini"
class StringFilter(logging.Filter):
def __init__(self, strings_to_filter):
super().__init__()
self.strings_to_filter = strings_to_filter
model_max_tokens = {
GPT_4O: 4096,
GPT_4O_MINI: 16384
}
def filter(self, record):
return not any(s in record.getMessage() for s in self.strings_to_filter)
strings_to_filter = ['unwanted_string_1', 'unwanted_string_2'] # Change these to the specific strings you want to filter out
use_smart_model = True
# Set up logging to console and file
logging.basicConfig(level=logging.INFO, handlers=[
logging.basicConfig(level=logging.WARNING, handlers=[
logging.StreamHandler(),
logging.FileHandler('logs/output.log', mode='a')
])
logging.getLogger().addFilter(StringFilter(strings_to_filter))
# Set up Telegram bot
TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
@@ -139,14 +135,14 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
del user_images[user_id]
else:
# Call OpenAI API for inference (text-only)
response = get_chat_response(client, messages, 4096, GPT_4O)
response = get_chat_response(client, messages, GPT_4O if use_smart_model else GPT_4O_MINI)
# Extract the assistant's reply
assistant_message = response.choices[0].message
toolUseCount = 0
if hasattr(assistant_message, 'function_call') and assistant_message.function_call:
while hasattr(assistant_message, 'function_call') and assistant_message.function_call and toolUseCount < 50: # Todo: put amount in env
tool_response = call_tool(assistant_message.function_call, messages)
tool_response = call_tool(assistant_message.function_call)
conversation_history[user_id].append({"role": "function", "name": assistant_message.function_call.name, "content": json.dumps(tool_response)})
messages.append({
@@ -156,7 +152,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
})
# Call API again to get the final response
assistant_message = get_chat_response(client, messages, 4096, GPT_4O).choices[0].message
assistant_message = get_chat_response(client, messages, GPT_4O if use_smart_model else GPT_4O_MINI).choices[0].message
if not hasattr(assistant_message, 'function_call') or not assistant_message.function_call:
assistant_reply = assistant_message.content
conversation_history[user_id].append({"role": "assistant", "content": assistant_reply})
@@ -178,7 +174,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
logging.error(f"An error occurred: {str(e)}")
await update.message.reply_text("Sorry, an error occurred while processing your request.")
def call_tool(function_call, messages):
def call_tool(function_call):
# Execute the function
function_name = function_call.name
function_args = function_call.arguments
@@ -186,16 +182,28 @@ def call_tool(function_call, messages):
if function_name in [f["name"] for f in tool.get_functions()]:
return tool.execute(function_name, **eval(function_args))
def get_chat_response(client, messages, max_tokens, model):
def get_chat_response(client, messages, model):
response = client.chat.completions.create(
model=model,
messages=messages,
functions=functions,
function_call="auto",
max_tokens=max_tokens
max_tokens=model_max_tokens[model]
)
return response
def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
global use_smart_model
use_smart_model = not use_smart_model
model = GPT_4O if use_smart_model else GPT_4O_MINI
logging.info(f"Switched to model: {model}")
update.message.reply_text(f"Switched to model: {model}")
async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
model = GPT_4O if use_smart_model else GPT_4O_MINI
await update.message.reply_text(f"Currently using model model: {model}")
pass
def main() -> None:
# Create the Application and pass it your bot's token
application = Application.builder().token(TELEGRAM_BOT_TOKEN).build()
@@ -203,6 +211,8 @@ def main() -> None:
# Add handlers
application.add_handler(CommandHandler("start", start))
application.add_handler(CommandHandler("clear", clear))
application.add_handler(CommandHandler("switch", switch))
application.add_handler(CommandHandler("status", status))
application.add_handler(MessageHandler(filters.PHOTO, handle_image))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))