diff --git a/telegram_inference_bot.py b/telegram_inference_bot.py index e94dd34..2818191 100644 --- a/telegram_inference_bot.py +++ b/telegram_inference_bot.py @@ -1,6 +1,8 @@ import os import importlib import inspect +import tempfile +import base64 from telegram import Update from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from openai import OpenAI @@ -12,12 +14,17 @@ load_dotenv() client = OpenAI() +GPT_4O = "gpt-4o" +GPT_4O_MINI = "gpt-4o-mini" # Set up Telegram bot TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN') # Dictionary to store conversation history for each user conversation_history = {} +# Dictionary to store the last image file for each user +user_images = {} + # Load tools tools = [] tools_dir = os.path.join(os.path.dirname(__file__), 'tools') @@ -35,13 +42,32 @@ for tool in tools: functions.extend(tool.get_functions()) async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today?") + await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.") async def clear(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: user_id = update.effective_user.id if user_id in conversation_history: del conversation_history[user_id] - await update.message.reply_text("Conversation history cleared. Let's start fresh!") + if user_id in user_images: + os.remove(user_images[user_id]) + del user_images[user_id] + await update.message.reply_text("Conversation history and image cleared. Let's start fresh!") + +async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + user_id = update.effective_user.id + + # Get the largest available photo + photo = max(update.message.photo, key=lambda x: x.file_size) + + # Download the photo + photo_file = await context.bot.get_file(photo.file_id) + + # Create a temporary file to store the image + with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file: + await photo_file.download_to_drive(custom_path=temp_file.name) + user_images[user_id] = temp_file.name + + await update.message.reply_text("I've received your image. What would you like to know about it?") async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: try: @@ -58,18 +84,43 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> # Prepare messages for OpenAI API messages = [{"role": "system", "content": "You are a helpful assistant."}] + conversation_history[user_id] - # Call OpenAI API for inference - response = client.chat.completions.create( - model="gpt-4", - messages=messages, - functions=functions, - function_call="auto" - ) + # Check if there's an image to process + if user_id in user_images: + with open(user_images[user_id], "rb") as image_file: + response = client.chat.completions.create( + model=GPT_4O_MINI, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": user_message}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64.b64encode(image_file.read()).decode('utf-8')}" + } + }, + ], + } + ], + max_tokens=2048, + ) + # Remove the temporary image file + os.remove(user_images[user_id]) + del user_images[user_id] + else: + # Call OpenAI API for inference (text-only) + response = client.chat.completions.create( + model=GPT_4O, + messages=messages, + functions=functions, + function_call="auto" + ) # Extract the assistant's reply assistant_message = response.choices[0].message - if assistant_message.function_call: + if hasattr(assistant_message, 'function_call') and assistant_message.function_call: # Execute the function function_name = assistant_message.function_call.name function_args = assistant_message.function_call.arguments @@ -83,7 +134,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> }) # Call API again to get the final response response = client.chat.completions.create( - model="gpt-4", + model=GPT_4O, messages=messages ) assistant_reply = response.choices[0].message.content @@ -112,6 +163,7 @@ def main() -> None: # Add handlers application.add_handler(CommandHandler("start", start)) application.add_handler(CommandHandler("clear", clear)) + application.add_handler(MessageHandler(filters.PHOTO, handle_image)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)) # Start the Bot diff --git a/tests/test_github_tool.py b/tests/test_github_tool.py new file mode 100644 index 0000000..a012ed1 --- /dev/null +++ b/tests/test_github_tool.py @@ -0,0 +1,81 @@ +# tests/test_github_tool.py + +import unittest +from unittest.mock import patch, MagicMock +from tools.github_tool import GitHubTool + +class TestGitHubTool(unittest.TestCase): + + def setUp(self): + self.github_tool = GitHubTool() + + def test_get_functions(self): + functions = self.github_tool.get_functions() + self.assertEqual(len(functions), 4) + function_names = [f["name"] for f in functions] + expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"] + self.assertListEqual(function_names, expected_names) + + @patch('tools.github_tool.requests.get') + def test_read_file(self, mock_get): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"content": "file content"} + mock_get.return_value = mock_response + + result = self.github_tool.execute("read_file", path="test.txt") + self.assertEqual(result, "file content") + + mock_get.assert_called_once() + + @patch('tools.github_tool.requests.get') + @patch('tools.github_tool.requests.post') + def test_create_branch(self, mock_post, mock_get): + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = {"object": {"sha": "test_sha"}} + mock_get.return_value = mock_get_response + + mock_post_response = MagicMock() + mock_post_response.status_code = 201 + mock_post.return_value = mock_post_response + + result = self.github_tool.execute("create_branch", branch_name="test-branch") + self.assertEqual(result, "Branch 'test-branch' created successfully") + + mock_get.assert_called_once() + mock_post.assert_called_once() + + @patch('tools.github_tool.requests.put') + def test_commit_file(self, mock_put): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_put.return_value = mock_response + + result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit") + self.assertEqual(result, "File committed successfully to branch 'test-branch'") + + mock_put.assert_called_once() + + def test_commit_file_to_main(self): + result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit") + self.assertEqual(result, "Cannot commit directly to main branch") + + @patch('tools.github_tool.requests.post') + def test_create_pull_request(self, mock_post): + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"} + mock_post.return_value = mock_response + + result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch") + self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1") + + mock_post.assert_called_once() + + def test_unknown_function(self): + result = self.github_tool.execute("unknown_function") + self.assertEqual(result, "Unknown function: unknown_function") + +if __name__ == '__main__': + unittest.main() diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 0000000..e69de29 diff --git a/tools/github_tool.md b/tools/github_tool.md new file mode 100644 index 0000000..e3548a2 --- /dev/null +++ b/tools/github_tool.md @@ -0,0 +1,96 @@ +# GitHub Integration Tool + +The GitHub Integration Tool provides a simple interface to interact with the GitHub repository for the Cyclop project. This tool allows reading files, creating branches, committing changes, and creating pull requests. + +## Functions + +The tool provides the following functions: + +1. `read_file`: Read a file from the repository +2. `create_branch`: Create a new branch in the repository +3. `commit_file`: Commit a file to a branch (not main) +4. `create_pull_request`: Create a pull request + +## Usage + +To use this tool, you need to have the `GITHUB_TOKEN` environment variable set with your GitHub personal access token. + +### Read File + +Reads the content of a file from the repository. + +Parameters: +- `path`: The path to the file in the repository + +Example: +``‍`python +result = github_tool.execute("read_file", path="README.md") +print(result) # Prints the content of README.md +``‍` + +### Create Branch + +Creates a new branch in the repository. + +Parameters: +- `branch_name`: Name of the new branch +- `base_branch` (optional): Name of the base branch (default is "main") + +Example: +``‍`python +result = github_tool.execute("create_branch", branch_name="feature-branch") +print(result) # Prints a success message if the branch was created +``‍` + +### Commit File + +Commits a file to a specified branch (not main). + +Parameters: +- `branch_name`: Name of the branch to commit to +- `file_path`: Path to the file in the repository +- `content`: Content of the file +- `commit_message`: Commit message + +Example: +``‍`python +result = github_tool.execute( + "commit_file", + branch_name="feature-branch", + file_path="docs/NEW_FEATURE.md", + content="# New Feature\n\nThis document describes the new feature.", + commit_message="Add documentation for new feature" +) +print(result) # Prints a success message if the file was committed +``‍` + +### Create Pull Request + +Creates a pull request from one branch to another. + +Parameters: +- `title`: Title of the pull request +- `body`: Body of the pull request +- `head`: The name of the branch where your changes are implemented +- `base` (optional): The name of the branch you want the changes pulled into (default is "main") + +Example: +``‍`python +result = github_tool.execute( + "create_pull_request", + title="Add new feature documentation", + body="This PR adds documentation for the new feature.", + head="feature-branch" +) +print(result) # Prints the URL of the created pull request +``‍` + +## Error Handling + +If an error occurs during the execution of any function, an error message will be returned instead of the expected result. Always check the returned value to ensure the operation was successful. + +## Notes + +- This tool uses the GitHub API v3. +- Make sure your GitHub token has the necessary permissions to perform these operations. +- Committing directly to the main branch is not allowed for safety reasons. diff --git a/tools/github_tool.py b/tools/github_tool.py new file mode 100644 index 0000000..6e9c081 --- /dev/null +++ b/tools/github_tool.py @@ -0,0 +1,194 @@ +# tools/github_tool.py +from .base_tool import BaseTool +import requests +import os + +class GitHubTool(BaseTool): + def __init__(self): + self.base_url = "https://api.github.com" + self.token = os.environ.get("GITHUB_TOKEN") + self.headers = { + "Authorization": f"token {self.token}", + "Accept": "application/vnd.github.v3+json" + } + self.repo = os.environ.get("GITHUB_REPOSITORY") + + def get_functions(self): + return [ + { + "name": "read_file", + "description": "Read a file from the repository", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file in the repository" + } + }, + "required": ["path"] + } + }, + { + "name": "create_branch", + "description": "Create a new branch in the repository", + "parameters": { + "type": "object", + "properties": { + "branch_name": { + "type": "string", + "description": "Name of the new branch" + }, + "base_branch": { + "type": "string", + "description": "Name of the base branch", + "default": "main" + } + }, + "required": ["branch_name"] + } + }, + { + "name": "commit_file", + "description": "Commit a file to a branch (not main)", + "parameters": { + "type": "object", + "properties": { + "branch_name": { + "type": "string", + "description": "Name of the branch to commit to" + }, + "file_path": { + "type": "string", + "description": "Path to the file in the repository" + }, + "content": { + "type": "string", + "description": "Content of the file" + }, + "commit_message": { + "type": "string", + "description": "Commit message" + } + }, + "required": ["branch_name", "file_path", "content", "commit_message"] + } + }, + { + "name": "create_pull_request", + "description": "Create a pull request", + "parameters": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the pull request" + }, + "body": { + "type": "string", + "description": "Body of the pull request" + }, + "head": { + "type": "string", + "description": "The name of the branch where your changes are implemented" + }, + "base": { + "type": "string", + "description": "The name of the branch you want the changes pulled into", + "default": "main" + } + }, + "required": ["title", "body", "head"] + } + } + ] + + def execute(self, function_name, **kwargs): + if function_name == "read_file": + return self._read_file(kwargs["path"]) + elif function_name == "create_branch": + return self._create_branch(kwargs["branch_name"], kwargs.get("base_branch", "main")) + elif function_name == "commit_file": + return self._commit_file(kwargs["branch_name"], kwargs["file_path"], kwargs["content"], kwargs["commit_message"]) + elif function_name == "push_branch": + return self._push_branch(kwargs["branch_name"]) + elif function_name == "create_pull_request": + return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs["head"], kwargs.get("base", "main")) + else: + return f"Unknown function: {function_name}" + + + def _read_file(self, path): + url = f"{self.base_url}/repos/{self.repo}/contents/{path}" + response = requests.get(url, headers=self.headers) + if response.status_code == 200: + content = response.json()["content"] + return content + else: + return f"Error reading file: {response.status_code}" + + def _create_branch(self, branch_name, base_branch): + url = f"{self.base_url}/repos/{self.repo}/git/refs" + response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers) + if response.status_code != 200: + return f"Error getting base branch: {response.status_code}" + + sha = response.json()["object"]["sha"] + data = { + "ref": f"refs/heads/{branch_name}", + "sha": sha + } + response = requests.post(url, headers=self.headers, json=data) + if response.status_code == 201: + return f"Branch '{branch_name}' created successfully" + else: + return f"Error creating branch: {response.status_code}" + + def _commit_file(self, branch_name, file_path, content, commit_message): + if branch_name == "main": + return "Cannot commit directly to main branch" + + url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}" + data = { + "message": commit_message, + "content": content, + "branch": branch_name + } + response = requests.put(url, headers=self.headers, json=data) + if response.status_code in [200, 201]: + return f"File committed successfully to branch '{branch_name}'" + else: + return f"Error committing file: {response.status_code}" + + def _create_pull_request(self, title, body, head, base): + url = f"{self.base_url}/repos/{self.repo}/pulls" + data = { + "title": title, + "body": body, + "head": head, + "base": base + } + response = requests.post(url, headers=self.headers, json=data) + if response.status_code == 201: + return f"Pull request created successfully: {response.json()['html_url']}" + else: + return f"Error creating pull request: {response.status_code}\nResponse: {response.text}" + + def _push_branch(self, branch_name): + url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch_name}" + response = requests.get(url, headers=self.headers) + if response.status_code != 200: + return f"Error getting branch information: {response.status_code}" + + sha = response.json()["object"]["sha"] + + push_url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch_name}" + data = { + "sha": sha, + "force": True + } + response = requests.patch(push_url, headers=self.headers, json=data) + if response.status_code == 200: + return f"Branch '{branch_name}' pushed successfully" + else: + return f"Error pushing branch: {response.status_code}\nResponse: {response.text}"