Added tests

This commit is contained in:
2024-08-17 09:28:17 -05:00
parent 20ecbe4efa
commit 7ba4838522
5 changed files with 434 additions and 11 deletions
+58 -6
View File
@@ -1,6 +1,8 @@
import os import os
import importlib import importlib
import inspect import inspect
import tempfile
import base64
from telegram import Update from telegram import Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
from openai import OpenAI from openai import OpenAI
@@ -12,12 +14,17 @@ load_dotenv()
client = OpenAI() client = OpenAI()
GPT_4O = "gpt-4o"
GPT_4O_MINI = "gpt-4o-mini"
# Set up Telegram bot # Set up Telegram bot
TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN') TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
# Dictionary to store conversation history for each user # Dictionary to store conversation history for each user
conversation_history = {} conversation_history = {}
# Dictionary to store the last image file for each user
user_images = {}
# Load tools # Load tools
tools = [] tools = []
tools_dir = os.path.join(os.path.dirname(__file__), 'tools') tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
@@ -35,13 +42,32 @@ for tool in tools:
functions.extend(tool.get_functions()) functions.extend(tool.get_functions())
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today?") await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.")
async def clear(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def clear(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
user_id = update.effective_user.id user_id = update.effective_user.id
if user_id in conversation_history: if user_id in conversation_history:
del conversation_history[user_id] del conversation_history[user_id]
await update.message.reply_text("Conversation history cleared. Let's start fresh!") if user_id in user_images:
os.remove(user_images[user_id])
del user_images[user_id]
await update.message.reply_text("Conversation history and image cleared. Let's start fresh!")
async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
user_id = update.effective_user.id
# Get the largest available photo
photo = max(update.message.photo, key=lambda x: x.file_size)
# Download the photo
photo_file = await context.bot.get_file(photo.file_id)
# Create a temporary file to store the image
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
await photo_file.download_to_drive(custom_path=temp_file.name)
user_images[user_id] = temp_file.name
await update.message.reply_text("I've received your image. What would you like to know about it?")
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
try: try:
@@ -58,9 +84,34 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
# Prepare messages for OpenAI API # Prepare messages for OpenAI API
messages = [{"role": "system", "content": "You are a helpful assistant."}] + conversation_history[user_id] messages = [{"role": "system", "content": "You are a helpful assistant."}] + conversation_history[user_id]
# Call OpenAI API for inference # Check if there's an image to process
if user_id in user_images:
with open(user_images[user_id], "rb") as image_file:
response = client.chat.completions.create( response = client.chat.completions.create(
model="gpt-4", model=GPT_4O_MINI,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": user_message},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64.b64encode(image_file.read()).decode('utf-8')}"
}
},
],
}
],
max_tokens=2048,
)
# Remove the temporary image file
os.remove(user_images[user_id])
del user_images[user_id]
else:
# Call OpenAI API for inference (text-only)
response = client.chat.completions.create(
model=GPT_4O,
messages=messages, messages=messages,
functions=functions, functions=functions,
function_call="auto" function_call="auto"
@@ -69,7 +120,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
# Extract the assistant's reply # Extract the assistant's reply
assistant_message = response.choices[0].message assistant_message = response.choices[0].message
if assistant_message.function_call: if hasattr(assistant_message, 'function_call') and assistant_message.function_call:
# Execute the function # Execute the function
function_name = assistant_message.function_call.name function_name = assistant_message.function_call.name
function_args = assistant_message.function_call.arguments function_args = assistant_message.function_call.arguments
@@ -83,7 +134,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
}) })
# Call API again to get the final response # Call API again to get the final response
response = client.chat.completions.create( response = client.chat.completions.create(
model="gpt-4", model=GPT_4O,
messages=messages messages=messages
) )
assistant_reply = response.choices[0].message.content assistant_reply = response.choices[0].message.content
@@ -112,6 +163,7 @@ def main() -> None:
# Add handlers # Add handlers
application.add_handler(CommandHandler("start", start)) application.add_handler(CommandHandler("start", start))
application.add_handler(CommandHandler("clear", clear)) application.add_handler(CommandHandler("clear", clear))
application.add_handler(MessageHandler(filters.PHOTO, handle_image))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))
# Start the Bot # Start the Bot
+81
View File
@@ -0,0 +1,81 @@
# tests/test_github_tool.py
import unittest
from unittest.mock import patch, MagicMock
from tools.github_tool import GitHubTool
class TestGitHubTool(unittest.TestCase):
def setUp(self):
self.github_tool = GitHubTool()
def test_get_functions(self):
functions = self.github_tool.get_functions()
self.assertEqual(len(functions), 4)
function_names = [f["name"] for f in functions]
expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"]
self.assertListEqual(function_names, expected_names)
@patch('tools.github_tool.requests.get')
def test_read_file(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"content": "file content"}
mock_get.return_value = mock_response
result = self.github_tool.execute("read_file", path="test.txt")
self.assertEqual(result, "file content")
mock_get.assert_called_once()
@patch('tools.github_tool.requests.get')
@patch('tools.github_tool.requests.post')
def test_create_branch(self, mock_post, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 200
mock_get_response.json.return_value = {"object": {"sha": "test_sha"}}
mock_get.return_value = mock_get_response
mock_post_response = MagicMock()
mock_post_response.status_code = 201
mock_post.return_value = mock_post_response
result = self.github_tool.execute("create_branch", branch_name="test-branch")
self.assertEqual(result, "Branch 'test-branch' created successfully")
mock_get.assert_called_once()
mock_post.assert_called_once()
@patch('tools.github_tool.requests.put')
def test_commit_file(self, mock_put):
mock_response = MagicMock()
mock_response.status_code = 200
mock_put.return_value = mock_response
result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit")
self.assertEqual(result, "File committed successfully to branch 'test-branch'")
mock_put.assert_called_once()
def test_commit_file_to_main(self):
result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit")
self.assertEqual(result, "Cannot commit directly to main branch")
@patch('tools.github_tool.requests.post')
def test_create_pull_request(self, mock_post):
mock_response = MagicMock()
mock_response.status_code = 201
mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"}
mock_post.return_value = mock_response
result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch")
self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1")
mock_post.assert_called_once()
def test_unknown_function(self):
result = self.github_tool.execute("unknown_function")
self.assertEqual(result, "Unknown function: unknown_function")
if __name__ == '__main__':
unittest.main()
View File
+96
View File
@@ -0,0 +1,96 @@
# GitHub Integration Tool
The GitHub Integration Tool provides a simple interface to interact with the GitHub repository for the Cyclop project. This tool allows reading files, creating branches, committing changes, and creating pull requests.
## Functions
The tool provides the following functions:
1. `read_file`: Read a file from the repository
2. `create_branch`: Create a new branch in the repository
3. `commit_file`: Commit a file to a branch (not main)
4. `create_pull_request`: Create a pull request
## Usage
To use this tool, you need to have the `GITHUB_TOKEN` environment variable set with your GitHub personal access token.
### Read File
Reads the content of a file from the repository.
Parameters:
- `path`: The path to the file in the repository
Example:
```python
result = github_tool.execute("read_file", path="README.md")
print(result) # Prints the content of README.md
```
### Create Branch
Creates a new branch in the repository.
Parameters:
- `branch_name`: Name of the new branch
- `base_branch` (optional): Name of the base branch (default is "main")
Example:
```python
result = github_tool.execute("create_branch", branch_name="feature-branch")
print(result) # Prints a success message if the branch was created
```
### Commit File
Commits a file to a specified branch (not main).
Parameters:
- `branch_name`: Name of the branch to commit to
- `file_path`: Path to the file in the repository
- `content`: Content of the file
- `commit_message`: Commit message
Example:
```python
result = github_tool.execute(
"commit_file",
branch_name="feature-branch",
file_path="docs/NEW_FEATURE.md",
content="# New Feature\n\nThis document describes the new feature.",
commit_message="Add documentation for new feature"
)
print(result) # Prints a success message if the file was committed
```
### Create Pull Request
Creates a pull request from one branch to another.
Parameters:
- `title`: Title of the pull request
- `body`: Body of the pull request
- `head`: The name of the branch where your changes are implemented
- `base` (optional): The name of the branch you want the changes pulled into (default is "main")
Example:
```python
result = github_tool.execute(
"create_pull_request",
title="Add new feature documentation",
body="This PR adds documentation for the new feature.",
head="feature-branch"
)
print(result) # Prints the URL of the created pull request
```
## Error Handling
If an error occurs during the execution of any function, an error message will be returned instead of the expected result. Always check the returned value to ensure the operation was successful.
## Notes
- This tool uses the GitHub API v3.
- Make sure your GitHub token has the necessary permissions to perform these operations.
- Committing directly to the main branch is not allowed for safety reasons.
+194
View File
@@ -0,0 +1,194 @@
# tools/github_tool.py
from .base_tool import BaseTool
import requests
import os
class GitHubTool(BaseTool):
def __init__(self):
self.base_url = "https://api.github.com"
self.token = os.environ.get("GITHUB_TOKEN")
self.headers = {
"Authorization": f"token {self.token}",
"Accept": "application/vnd.github.v3+json"
}
self.repo = os.environ.get("GITHUB_REPOSITORY")
def get_functions(self):
return [
{
"name": "read_file",
"description": "Read a file from the repository",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file in the repository"
}
},
"required": ["path"]
}
},
{
"name": "create_branch",
"description": "Create a new branch in the repository",
"parameters": {
"type": "object",
"properties": {
"branch_name": {
"type": "string",
"description": "Name of the new branch"
},
"base_branch": {
"type": "string",
"description": "Name of the base branch",
"default": "main"
}
},
"required": ["branch_name"]
}
},
{
"name": "commit_file",
"description": "Commit a file to a branch (not main)",
"parameters": {
"type": "object",
"properties": {
"branch_name": {
"type": "string",
"description": "Name of the branch to commit to"
},
"file_path": {
"type": "string",
"description": "Path to the file in the repository"
},
"content": {
"type": "string",
"description": "Content of the file"
},
"commit_message": {
"type": "string",
"description": "Commit message"
}
},
"required": ["branch_name", "file_path", "content", "commit_message"]
}
},
{
"name": "create_pull_request",
"description": "Create a pull request",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Title of the pull request"
},
"body": {
"type": "string",
"description": "Body of the pull request"
},
"head": {
"type": "string",
"description": "The name of the branch where your changes are implemented"
},
"base": {
"type": "string",
"description": "The name of the branch you want the changes pulled into",
"default": "main"
}
},
"required": ["title", "body", "head"]
}
}
]
def execute(self, function_name, **kwargs):
if function_name == "read_file":
return self._read_file(kwargs["path"])
elif function_name == "create_branch":
return self._create_branch(kwargs["branch_name"], kwargs.get("base_branch", "main"))
elif function_name == "commit_file":
return self._commit_file(kwargs["branch_name"], kwargs["file_path"], kwargs["content"], kwargs["commit_message"])
elif function_name == "push_branch":
return self._push_branch(kwargs["branch_name"])
elif function_name == "create_pull_request":
return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs["head"], kwargs.get("base", "main"))
else:
return f"Unknown function: {function_name}"
def _read_file(self, path):
url = f"{self.base_url}/repos/{self.repo}/contents/{path}"
response = requests.get(url, headers=self.headers)
if response.status_code == 200:
content = response.json()["content"]
return content
else:
return f"Error reading file: {response.status_code}"
def _create_branch(self, branch_name, base_branch):
url = f"{self.base_url}/repos/{self.repo}/git/refs"
response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers)
if response.status_code != 200:
return f"Error getting base branch: {response.status_code}"
sha = response.json()["object"]["sha"]
data = {
"ref": f"refs/heads/{branch_name}",
"sha": sha
}
response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201:
return f"Branch '{branch_name}' created successfully"
else:
return f"Error creating branch: {response.status_code}"
def _commit_file(self, branch_name, file_path, content, commit_message):
if branch_name == "main":
return "Cannot commit directly to main branch"
url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}"
data = {
"message": commit_message,
"content": content,
"branch": branch_name
}
response = requests.put(url, headers=self.headers, json=data)
if response.status_code in [200, 201]:
return f"File committed successfully to branch '{branch_name}'"
else:
return f"Error committing file: {response.status_code}"
def _create_pull_request(self, title, body, head, base):
url = f"{self.base_url}/repos/{self.repo}/pulls"
data = {
"title": title,
"body": body,
"head": head,
"base": base
}
response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201:
return f"Pull request created successfully: {response.json()['html_url']}"
else:
return f"Error creating pull request: {response.status_code}\nResponse: {response.text}"
def _push_branch(self, branch_name):
url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch_name}"
response = requests.get(url, headers=self.headers)
if response.status_code != 200:
return f"Error getting branch information: {response.status_code}"
sha = response.json()["object"]["sha"]
push_url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch_name}"
data = {
"sha": sha,
"force": True
}
response = requests.patch(push_url, headers=self.headers, json=data)
if response.status_code == 200:
return f"Branch '{branch_name}' pushed successfully"
else:
return f"Error pushing branch: {response.status_code}\nResponse: {response.text}"