import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from openai import OpenAI
| """ |
| EXAMPLE OUTPUT: |
| |
| **************************************** |
| RUNNING QUERY: What's the weather for Paris, TX in fahrenheit? |
| Step 1 |
| ---------------------------------------- |
| |
| Executing: get_geo_coordinates |
| Arguments: {'city': 'Paris', 'state': 'TX'} |
| Response: The coordinates for Paris, TX are: latitude 33.6609, longitude 95.5555 |
| |
| Step 2 |
| ---------------------------------------- |
| |
| Executing: get_current_weather |
| Arguments: {'latitude': [33.6609], 'longitude': [95.5555], 'unit': 'fahrenheit'} |
| Response: The weather is 85 degrees fahrenheit. It is partly cloudy, with highs in the 90's. |
| |
| Step 3 |
| ---------------------------------------- |
| Conversation Complete |
| |
| |
| **************************************** |
| RUNNING QUERY: Who won the most recent PGA? |
| Step 1 |
| ---------------------------------------- |
| |
| Executing: no_relevant_function |
| Arguments: {'user_query_span': 'Who won the most recent PGA?'} |
| Response: No relevant function for your request was found. We will stop here. |
| |
| Step 2 |
| ---------------------------------------- |
| Conversation Complete |
| """ |
|
|
@dataclass
class WeatherConfig:
    """Settings for the OpenAI-compatible client and the agent loop."""

    # Credentials and endpoint handed to the OpenAI client constructor.
    api_key: str = ""
    api_base: str = ""

    # Model identifier; when left unset, the agent picks the first model
    # listed by the endpoint.
    model: Optional[str] = None

    # Upper bound on request / tool-execution rounds for a single query.
    max_steps: int = 5
|
|
class WeatherTools:
    """Collection of tools (plain functions) the weather agent can call.

    Every tool returns a human-readable string, which the agent feeds back
    to the model as the tool result.
    """

    # Demo lookup table: city name -> state abbreviation -> (latitude, longitude).
    _COORDINATES = {
        "Dallas": {"TX": (32.7767, -96.7970)},
        "San Francisco": {"CA": (37.7749, -122.4194)},
        # Western-hemisphere longitudes are negative, like the entries above.
        "Paris": {"TX": (33.6609, -95.5555)},
    }

    @staticmethod
    def get_current_weather(latitude: List[float], longitude: List[float], unit: str) -> str:
        """Return a weather report for the given coordinates.

        This is a stub: ``latitude`` and ``longitude`` are accepted but not
        used, and the same canned report is returned for every location.

        Args:
            latitude: Latitude value(s) of the location.
            longitude: Longitude value(s) of the location.
            unit: Temperature unit name, echoed verbatim in the report.

        Returns:
            The canned report sentence mentioning ``unit``.
        """
        return f"The weather is 85 degrees {unit}. It is partly cloudy, with highs in the 90's."

    @staticmethod
    def get_geo_coordinates(city: str, state: str) -> str:
        """Look up the coordinates of a city in the built-in demo table.

        Args:
            city: City name, matched exactly (case-sensitive).
            state: State abbreviation, matched exactly (case-sensitive).

        Returns:
            A sentence stating the latitude and longitude, or a sentence
            saying no coordinates are available when the city/state pair is
            not in the table.
        """
        coords = WeatherTools._COORDINATES.get(city, {}).get(state)
        if coords is None:
            # Say so explicitly instead of reporting (0, 0) as if it were real.
            return f"No coordinates are available for {city}, {state}."

        lat, lon = coords
        return f"The coordinates for {city}, {state} are: latitude {lat}, longitude {lon}"

    @staticmethod
    def no_relevant_function(user_query_span: str) -> str:
        """Return the fixed reply used when no other tool fits the query.

        Args:
            user_query_span: The part of the query that could not be handled;
                accepted but not used.
        """
        return "No relevant function for your request was found. We will stop here."
|
|
class ToolRegistry:
    """Registry of the available tools and the schemas advertised to the model."""

    @property
    def available_functions(self) -> Dict[str, Callable[..., str]]:
        """Map each tool name used in ``tool_schemas`` to its implementation."""
        return {
            "get_current_weather": WeatherTools.get_current_weather,
            "get_geo_coordinates": WeatherTools.get_geo_coordinates,
            "no_relevant_function": WeatherTools.no_relevant_function,
        }

    @property
    def tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the function-tool definitions passed as ``tools`` to the API.

        Each entry has ``"type": "function"`` and a ``"function"`` object with
        the tool's name, description and a JSON-Schema ``parameters`` object.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location. Use exact coordinates.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # Arrays declare their element type explicitly so
                            # the schema is fully specified.
                            "latitude": {
                                "type": "array",
                                "items": {"type": "number"},
                                "description": "The latitude for the city.",
                            },
                            "longitude": {
                                "type": "array",
                                "items": {"type": "number"},
                                "description": "The longitude for the city.",
                            },
                            "unit": {
                                "type": "string",
                                "description": "The unit to fetch the temperature in",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["latitude", "longitude", "unit"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_geo_coordinates",
                    "description": "Get the latitude and longitude for a given city",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {"type": "string", "description": "The city to find coordinates for"},
                            "state": {"type": "string", "description": "The two-letter state abbreviation"},
                        },
                        "required": ["city", "state"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "no_relevant_function",
                    "description": "Call this when no other provided function can be called to answer the user query.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "user_query_span": {
                                "type": "string",
                                "description": "The part of the user_query that cannot be answered by any other function calls.",
                            },
                        },
                        "required": ["user_query_span"],
                    },
                },
            },
        ]
|
|
class WeatherAgent:
    """Agent that drives the chat-completion loop and executes requested tools."""

    def __init__(self, config: "WeatherConfig"):
        """Create the API client and resolve which model to use.

        Args:
            config: Client and loop settings. When ``config.model`` is unset,
                it is filled in (on the passed object) with the id of the
                first model listed by the endpoint.
        """
        self.config = config
        self.client = OpenAI(api_key=config.api_key, base_url=config.api_base)
        self.tools = ToolRegistry()
        self.messages = []

        if not config.model:
            models = self.client.models.list()
            self.config.model = models.data[0].id

    def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
        """Convert a tool-call object into a plain dict for the message history."""
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    def _call_tool(self, function_name: str, function_args: Any) -> Any:
        """Invoke one registered tool, turning model-caused failures into text.

        The tool name and arguments come from model output, so an unknown
        name or mismatched arguments are reported as an error string (which
        becomes the tool result) instead of raising.
        """
        tool = self.tools.available_functions.get(function_name)
        if tool is None:
            return f"Error: unknown function '{function_name}'."
        if not isinstance(function_args, dict):
            return f"Error: arguments for {function_name} must be a JSON object."
        try:
            return tool(**function_args)
        except TypeError as exc:
            # Typically missing or unexpected keyword arguments.
            return f"Error: invalid arguments for {function_name}: {exc}"

    def process_tool_calls(self, message) -> None:
        """Execute every tool call on ``message`` and record the results.

        For each call, the name, arguments and response are printed and a
        ``tool`` message carrying the result is appended to ``self.messages``.
        Malformed JSON arguments, unknown tool names and mismatched arguments
        are reported back as the tool result rather than raised.

        Args:
            message: Assistant message whose ``tool_calls`` are to be run.
        """
        for tool_call in message.tool_calls:
            function_name = tool_call.function.name
            raw_arguments = tool_call.function.arguments

            print(f"\nExecuting: {function_name}")
            try:
                function_args = json.loads(raw_arguments)
            except (TypeError, ValueError) as exc:
                # json.JSONDecodeError is a ValueError; TypeError covers None.
                print(f"Arguments: {raw_arguments}")
                function_response = (
                    f"Error: could not parse arguments for {function_name}: {exc}"
                )
            else:
                print(f"Arguments: {function_args}")
                function_response = self._call_tool(function_name, function_args)
            print(f"Response: {function_response}")

            # Strings are sent as-is; only non-string results are JSON-encoded,
            # so text is not wrapped in an extra layer of quotes and escapes.
            if isinstance(function_response, str):
                content = function_response
            else:
                content = json.dumps(function_response)

            self.messages.append({
                "role": "tool",
                "content": content,
                "tool_call_id": tool_call.id,
                "name": function_name,
            })

    def run_conversation(self, initial_query: str) -> None:
        """Run the request / tool-execution loop for a single user query.

        The message history is reset to just ``initial_query``. Each step
        requests a completion; if the reply contains tool calls they are
        executed and their results appended, otherwise the loop stops.
        Progress is printed to stdout; nothing is returned.

        Args:
            initial_query: The user's question.
        """
        self.messages = [{"role": "user", "content": initial_query}]

        print("\n" * 5)
        print("*" * 40)
        print(f"RUNNING QUERY: {initial_query}")

        for step in range(self.config.max_steps):
            print(f"\nStep {step + 1}")
            print("-" * 40)

            response = self.client.chat.completions.create(
                messages=self.messages,
                model=self.config.model,
                tools=self.tools.tool_schemas,
                temperature=0.0,
            )
            message = response.choices[0].message

            if not message.tool_calls:
                print("Conversation Complete")
                break

            self.messages.append({
                "role": "assistant",
                # Pass the text through unchanged; absent content becomes an
                # empty string rather than the JSON text "null".
                "content": message.content or "",
                "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls],
            })
            self.process_tool_calls(message)
        else:
            # Runs only when the loop was not ended by ``break``, i.e. the
            # step budget ran out while the model was still calling tools.
            print("Maximum steps reached")
|
|
def main():
    """Run the two demo queries through an agent built from default settings."""
    agent = WeatherAgent(WeatherConfig())

    demo_queries = (
        "What's the weather for Paris, TX in fahrenheit?",
        "Who won the most recent PGA?",
    )
    for query in demo_queries:
        agent.run_conversation(query)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|