"""Answer natural-language questions about a SQLite database with two pydantic_ai agents.

1. ``schema_agent`` extracts the database schema (via the ``get_database_schema`` tool).
2. ``sql_agent`` receives that schema as deps, writes SQL and runs it through the
   ``execute_sql_query`` tool against a read-only connection.
"""
import asyncio
import json
import os
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Any, Dict, List, Optional

import logfire
from pydantic import BaseModel, Field, ValidationError
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_pydantic_ai()

# Define the OpenAI model (replace with your actual model if needed)
model = OpenAIModel(
    model_name="qwen2.5:14b",  # Or your preferred model
    provider=OpenAIProvider(base_url='http://localhost:11434/v1')  # Or your provider
)


class DatabaseSchema(BaseModel):
    """Schema of a SQLite database as produced by ``get_database_schema``."""

    # Column dicts carry mixed value types: 'notnull'/'pk' are ints and
    # 'dflt_value' may be None, so the values cannot be typed as ``str``.
    tables: Dict[str, List[Dict[str, Any]]] = Field(
        description="A dictionary where keys are table names and values are lists of column "
                    "dictionaries (name, type, notnull, dflt_value, pk)")


def _default_db_path() -> str:
    """Return the path of the database file the script operates on (``./db.sqlite3``)."""
    return os.path.join(os.getcwd(), 'db.sqlite3')


def _connect_readonly(db_path: str) -> sqlite3.Connection:
    """Open *db_path* read-only.

    ``mode=ro`` makes a missing file an error (plain ``sqlite3.connect`` would
    silently create an empty database) and prevents model-generated SQL from
    modifying data.

    Raises:
        sqlite3.OperationalError: if the file cannot be opened.
    """
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def _quote_identifier(name: str) -> str:
    """Quote an SQL identifier so names with spaces, quotes or keywords are safe."""
    return '"' + name.replace('"', '""') + '"'


def _read_schema(db_path: str) -> Dict[str, List[Dict[str, Any]]]:
    """Return ``{table_or_view: [column dict, ...]}`` for every user table and view."""
    with closing(_connect_readonly(db_path)) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type IN ('table', 'view');")
        # Names starting with "sqlite_" are reserved for SQLite's internal tables.
        tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]
        print(tables)
        schema: Dict[str, List[Dict[str, Any]]] = {}
        for table in tables:
            cursor.execute(f"PRAGMA table_info({_quote_identifier(table)})")
            schema[table] = [
                {
                    "name": col[1],
                    "type": col[2],
                    "notnull": col[3],
                    "dflt_value": col[4],
                    "pk": col[5],
                }
                for col in cursor.fetchall()
            ]
    print(schema)
    return schema


def _run_query(db_path: str, query: str) -> List[Dict[str, Any]]:
    """Execute one SQL statement read-only and return its rows as column->value dicts."""
    with closing(_connect_readonly(db_path)) as conn:
        cursor = conn.execute(query)
        results = cursor.fetchall()
        # cursor.description is None for statements that return no result set.
        columns = [description[0] for description in cursor.description or ()]
    return [dict(zip(columns, row)) for row in results]


def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (e.g. ```json ... ```) if present."""
    stripped = text.strip()
    if stripped.startswith("```"):
        first_newline = stripped.find("\n")
        stripped = stripped[first_newline + 1:] if first_newline != -1 else stripped[3:]
        if stripped.endswith("```"):
            stripped = stripped[:-3]
    return stripped.strip()


def _error_messages(payload: Any) -> Optional[List[str]]:
    """Return the messages of an ``{"error": [str, ...]}`` payload, or None otherwise.

    Checking the value's shape (a list of strings) keeps a table that happens to
    be named "error" (whose value is a list of column dicts) from being
    mistaken for an error report.
    """
    if not isinstance(payload, dict):
        return None
    messages = payload.get("error")
    if isinstance(messages, list) and all(isinstance(m, str) for m in messages):
        return messages
    return None


# Agent to get the database schema
schema_agent = Agent(
    model,
    deps_type=str,
    output_type=str,
    system_prompt="""You are a helpful assistant that extracts the schema of a SQLite database. When the user provides a database path, use the get_database_schema to retrieve the schema. Your ONLY response should be the raw JSON string representing the database schema. Do not include any other text. The JSON should be a dictionary where keys are table names, and values are lists of column dictionaries. Each column dictionary should include 'name', 'type', 'notnull', 'dflt_value', and 'pk' keys.
If there is an error, return a JSON string containing an "error" key with a list of error messages."""
)


@schema_agent.tool
async def get_database_schema(ctx: RunContext[str], db_path: str) -> str:
    """Retrieves the schema of the SQLite database and returns it as a JSON string.

    Args:
        ctx: Run context; when the caller supplied ``deps`` (the trusted database
            path) it takes precedence over the model-supplied argument.
        db_path: Database path as passed by the model.

    Returns:
        JSON ``{table: [column dict, ...]}`` or ``{"error": [message]}``.
    """
    path = ctx.deps or db_path
    print(f"Database path: {path}")
    try:
        # sqlite3 is blocking; run it off the event loop.
        schema = await asyncio.to_thread(_read_schema, path)
    except Exception as e:  # tool boundary: report any failure back to the model
        return json.dumps({"error": [str(e)]})
    return json.dumps(schema)


# Agent to generate and execute SQL queries
sql_agent = Agent(
    model,
    deps_type=DatabaseSchema,
    output_type=str,
    system_prompt="""You are a highly precise SQL query generator for a SQLite database. You are given the EXACT database schema, which is a dictionary where keys are table names and values are lists of column dictionaries (with 'name' and 'type'). Your ABSOLUTE priority is to generate SQL queries that ONLY use the table and column names exactly as they appear in this schema to answer the user's question.
Follow these strict steps:
1. **Analyze User Question:** Understand the user's request.
2. **Match Schema EXACTLY:** Identify the specific table(s) and column(s) in the provided schema whose names EXACTLY match the entities and information requested in the user's question.
3. **Generate STRICT SQL:** Construct a valid SQL query that selects the identified column(s) from the identified table(s). You MUST use the exact names from the schema.
Do not use aliases or make any assumptions about naming conventions. Aim for the simplest possible query.
4. **Execute Query:** Use the execute_sql_query to run your generated SQL.
5. **Return interactive Answer as if you are a sports person:** Provide a direct and simple answer to the user's question based on the query results.
6. **No Results:** If the query returns empty list, respond with: 'No matching entries found.'
7. **Error Handling:** If there's any error in generating or executing the SQL, return a JSON string with an "error" key and a list of error messages.
"""
)

# Example:
# Schema: {'Country': [{'name': 'id', 'type': 'INTEGER'}, {'name': 'name', 'type': 'TEXT'}]}
# User Question: "What are the country names?"
# Generated SQL: SELECT name FROM Country;
# Expected Answer: The countries are Belgium, England, France, ...


@sql_agent.system_prompt
def add_database_schema(ctx: RunContext[DatabaseSchema]) -> str:
    """Append the extracted schema to the system prompt.

    The static prompt tells the model it is "given the EXACT database schema";
    this is where it actually receives it (from the run's ``deps``).
    """
    if ctx.deps is None:
        return ""
    return "Database schema (JSON):\n" + json.dumps(ctx.deps.tables)


@sql_agent.tool
async def execute_sql_query(ctx: RunContext[DatabaseSchema], query: str) -> str:
    """Executes the SQL query and returns the result rows as a JSON string.

    Returns:
        JSON list of ``{column: value}`` rows, or ``{"error": [message]}`` if the
        query fails (the model is instructed to relay such errors).
    """
    db_path = _default_db_path()
    print(query)
    try:
        rows = await asyncio.to_thread(_run_query, db_path, query)
    except Exception as e:  # tool boundary: report any failure back to the model
        print(e)
        return json.dumps({"error": [str(e)]})
    print(rows)
    # default=str: values json cannot encode natively (e.g. BLOB bytes) are
    # rendered as text, which is good enough for the model to read.
    return json.dumps(rows, default=str)


async def main():
    """Extract the schema, then ask the SQL agent the hard-coded question."""
    db_path = _default_db_path()
    print(f"Database path: {db_path}")
    user_question = "how many cars do we have in the inventory"

    # 1. Get the database schema. The path is also passed as deps so the tool
    #    reads the intended file even if the model garbles the path argument.
    schema_result = await schema_agent.run(db_path, deps=db_path)
    print("Schema Agent Response:", schema_result)
    print("Schema Agent Output:", schema_result.output)

    try:
        schema_data = json.loads(_strip_code_fence(schema_result.output))
    except json.JSONDecodeError:
        print(f"Error: Could not parse schema agent response as JSON: {schema_result.output}")
        return

    if _error_messages(schema_data) is not None:
        print(f"Error getting schema: {schema_result.output}")
        return

    try:
        database_schema = DatabaseSchema(tables=schema_data)
    except ValidationError as e:
        print(f"Error: Schema agent response does not have the expected shape: {e}")
        return
    print("Parsed Database Schema:", database_schema)

    # 2. Use the schema to answer the user question (schema travels as deps).
    sql_response = await sql_agent.run(user_question, deps=database_schema)
    print("SQL Agent Response:", sql_response)
    print("SQL Agent Output:", sql_response.output)

    try:
        sql_payload = json.loads(_strip_code_fence(sql_response.output))
    except json.JSONDecodeError:
        sql_payload = None  # a plain-text answer, i.e. not an error payload
    if _error_messages(sql_payload) is not None:
        print(f"Error executing SQL: {sql_response.output}")


if __name__ == "__main__":
    asyncio.run(main())