# haikal/haikalbot/haikal_agent.py
# Last modified: 2025-06-13 01:58:40 +03:00
import asyncio
import json
import logging
import operator
import os
import re
import sqlite3
from dataclasses import asdict, dataclass
from enum import Enum
from functools import reduce
from typing import Any, Dict, List, Optional, Union

# Pydantic and AI imports
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

# Optional Django imports (if available)
try:
    from django.apps import apps
    from django.conf import settings
    from django.core.exceptions import FieldDoesNotExist
    from django.db import connection, models
    from django.db.models import Avg, Count, F, Max, Min, Q, QuerySet, Sum

    DJANGO_AVAILABLE = True
except ImportError:
    DJANGO_AVAILABLE = False

# Optional database drivers
try:
    import psycopg2

    POSTGRESQL_AVAILABLE = True
except ImportError:
    POSTGRESQL_AVAILABLE = False

try:
    import pymysql

    MYSQL_AVAILABLE = True
except ImportError:
    MYSQL_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
class DatabaseConfig:
    """Static configuration: LLM endpoint, result limits, supported charts."""

    # `settings` is only bound when the optional Django import at the top of
    # the module succeeded; guard on DJANGO_AVAILABLE so importing this module
    # outside a Django project no longer raises NameError at class creation.
    if DJANGO_AVAILABLE:
        LLM_MODEL = settings.MODEL_ANALYZER_LLM_MODEL
    else:
        # Non-Django fallback; may be None — callers must configure a model.
        LLM_MODEL = os.environ.get("MODEL_ANALYZER_LLM_MODEL")
    LLM_BASE_URL = "http://localhost:11434/v1"  # local OpenAI-compatible server (e.g. Ollama)
    LLM_TEMPERATURE = 0.3
    MAX_RESULTS = 1000
    SUPPORTED_CHART_TYPES = ["bar", "line", "pie", "doughnut", "radar", "scatter"]
class DatabaseType(Enum):
    """Supported backend database engines."""

    SQLITE = "sqlite"
    POSTGRESQL = "postgresql"
    MYSQL = "mysql"
@dataclass
class DatabaseConnection:
    """Connection details for one backing database."""

    db_type: DatabaseType
    connection_string: str  # file path (SQLite) or URL (PostgreSQL/MySQL)
    database_name: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    user: Optional[str] = None
    password: Optional[str] = None
    schema_info: Optional[Dict] = None  # cached schema, when populated
@dataclass
class QueryResult:
    """Outcome of one insight request, serializable for API responses."""

    status: str  # "success" or "error"
    data: Union[List[Dict], Dict]  # query rows (or aggregate payload)
    metadata: Dict[str, Any]  # query text, counts, analysis notes
    chart_data: Optional[Dict] = None  # chart payload, when prepared
    language: str = "en"  # response language code
    error: Optional[str] = None  # error message when status == "error"

    def to_dict(self):
        """Convert to dictionary for JSON serialization."""
        return asdict(self)
class DatabaseSchema(BaseModel):
    """Structured description of a database, fed to the query agent as deps."""

    tables: Dict[str, List[Dict[str, Any]]] = Field(
        description="Database schema with table names as keys and column info as values"
    )
    relationships: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Foreign key relationships between tables",
    )
class InsightRequest(BaseModel):
    """Incoming request: a natural-language prompt plus output preferences."""

    prompt: str = Field(description="Natural language query from user")
    database_path: Optional[str] = Field(default=None, description="Path to database file (for SQLite)")
    chart_type: Optional[str] = Field(default=None, description="Preferred chart type")
    limit: Optional[int] = Field(default=1000, description="Maximum number of results")
    language: Optional[str] = Field(default="auto", description="Response language preference")
    use_django: Optional[bool] = Field(default=True, description="Use Django database if available")
class DatabaseInsightSystem:
    """Turns natural-language prompts into executed SQL queries and chart data."""

    def __init__(self, config: Optional[DatabaseConfig] = None):
        """Initialize the system.

        Args:
            config: Connection/LLM settings; a fresh DatabaseConfig is used
                when None.  (Annotation fixed: default None requires Optional.)
        """
        self.config = config or DatabaseConfig()
        # Local OpenAI-compatible endpoint (e.g. Ollama) hosting the analyzer model.
        self.model = OpenAIModel(
            model_name=self.config.LLM_MODEL,
            provider=OpenAIProvider(base_url=self.config.LLM_BASE_URL),
        )
        self.db_connection = None  # set by analyze_database_schema()
        self._setup_agents()

    def _setup_agents(self):
        """Initialize the AI agents for schema analysis and query generation."""
        # Query generation agent; deps_type=DatabaseSchema means callers pass
        # the schema via Agent.run(..., deps=schema).
        self.query_agent = Agent(
            self.model,
            deps_type=DatabaseSchema,
            output_type=str,
            system_prompt="""You are an intelligent database query generator and analyst.
Given a natural language prompt and database schema, you must:
1. ANALYZE the user's request in English or Arabic
2. IDENTIFY relevant tables and columns from the schema
3. GENERATE appropriate SQL query or analysis approach
4. DETERMINE if aggregation, grouping, or joins are needed
5. SUGGEST appropriate visualization type
6. EXECUTE the query and provide insights
Response format should be JSON:
{
"analysis": "Brief analysis of the request",
"query_type": "select|aggregate|join|complex",
"sql_query": "Generated SQL query",
"chart_suggestion": "bar|line|pie|etc",
"expected_fields": ["field1", "field2"],
"language": "en|ar"
}
Handle both English and Arabic prompts. For Arabic text, respond in Arabic.
Focus on providing actionable insights, not just raw data.""",
        )
def _get_django_database_config(self) -> Optional[DatabaseConnection]:
    """Build a DatabaseConnection from Django's `default` database settings.

    Returns:
        A populated DatabaseConnection, or None when Django is absent,
        unconfigured, or uses an unsupported engine.
    """
    if not DJANGO_AVAILABLE:
        return None
    try:
        db_config = settings.DATABASES.get('default', {})
        if not db_config:
            logger.warning("No default database configuration found in Django settings")
            return None

        engine = db_config.get('ENGINE', '')
        db_name = db_config.get('NAME', '')
        host = db_config.get('HOST', 'localhost')
        port = db_config.get('PORT', None)
        user = db_config.get('USER', '')
        password = db_config.get('PASSWORD', '')

        engine_lower = engine.lower()
        if 'sqlite' in engine_lower:
            db_type = DatabaseType.SQLITE
            # For SQLite, NAME is the file path itself.
            connection_string = db_name
        elif 'postgresql' in engine_lower:
            db_type = DatabaseType.POSTGRESQL
            port = port or 5432
            connection_string = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
        elif 'mysql' in engine_lower:
            db_type = DatabaseType.MYSQL
            port = port or 3306
            connection_string = f"mysql://{user}:{password}@{host}:{port}/{db_name}"
        else:
            logger.warning(f"Unsupported database engine: {engine}")
            return None

        return DatabaseConnection(
            db_type=db_type,
            connection_string=connection_string,
            database_name=db_name,
            host=host,
            port=port,
            user=user,
            password=password,
        )
    except Exception as e:
        logger.error(f"Failed to get Django database config: {e}")
        return None
def analyze_database_schema_sync(self, request: InsightRequest) -> DatabaseSchema:
    """Blocking wrapper around analyze_database_schema() for sync callers."""
    return asyncio.run(self.analyze_database_schema(request))
async def analyze_database_schema(self, request: InsightRequest) -> DatabaseSchema:
    """Resolve a database connection for *request* and extract its schema.

    Django's configured database is preferred when available and requested;
    otherwise a SQLite file path from the request is used.

    Raises:
        ValueError: when no usable database configuration exists.
    """
    try:
        # Prefer the Django-configured database when possible.
        if request.use_django and DJANGO_AVAILABLE:
            django_config = self._get_django_database_config()
            if django_config is not None:
                self.db_connection = django_config
                return await self._analyze_django_schema()

        # Fall back to a direct connection; a bare file path implies SQLite.
        if request.database_path:
            self.db_connection = DatabaseConnection(
                db_type=DatabaseType.SQLITE,
                connection_string=request.database_path,
            )
            return await self._analyze_sqlite_schema(request.database_path)

        raise ValueError("No database configuration available")
    except Exception as e:
        logger.error(f"Schema analysis failed: {e}")
        raise
async def _analyze_sqlite_schema(self, db_path: str) -> DatabaseSchema:
    """Read table, column, and foreign-key metadata from a SQLite file.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        DatabaseSchema with per-table column info and FK relationships.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [row[0] for row in cursor.fetchall()]
            schema_data = {}
            relationships = []
            for table in tables:
                # Quote the identifier: names come from sqlite_master but may
                # still contain spaces or reserved words.
                cursor.execute(f'PRAGMA table_info("{table}")')
                columns = []
                for col in cursor.fetchall():
                    # table_info rows: (cid, name, type, notnull, dflt_value, pk)
                    columns.append({
                        "name": col[1],
                        "type": col[2],
                        "notnull": bool(col[3]),
                        "default_value": col[4],
                        "primary_key": bool(col[5]),
                    })
                schema_data[table] = columns
                cursor.execute(f'PRAGMA foreign_key_list("{table}")')
                for fk in cursor.fetchall():
                    # foreign_key_list rows: (id, seq, table, from, to, ...)
                    relationships.append({
                        "from_table": table,
                        "from_column": fk[3],
                        "to_table": fk[2],
                        "to_column": fk[4],
                    })
        finally:
            # Close even when a PRAGMA/SELECT raises (original leaked here).
            conn.close()
        return DatabaseSchema(tables=schema_data, relationships=relationships)
    except Exception as e:
        logger.error(f"SQLite schema analysis failed: {e}")
        raise
async def _analyze_django_schema(self) -> DatabaseSchema:
    """Build a schema description from the registered Django models.

    Raises:
        ImportError: when Django is not installed.
    """
    if not DJANGO_AVAILABLE:
        raise ImportError("Django is not available")
    schema_data = {}
    relationships = []
    for model in apps.get_models():
        table_name = model._meta.db_table
        columns = []
        for field in model._meta.get_fields():
            if field.is_relation:
                # Relation fields are recorded as relationships, not columns.
                related = getattr(field, 'related_model', None)
                if related:
                    relationships.append({
                        "from_table": table_name,
                        "from_column": field.name,
                        "to_table": related._meta.db_table,
                        "relationship_type": field.get_internal_type(),
                    })
            else:
                columns.append({
                    "name": field.name,
                    "type": field.get_internal_type(),
                    "notnull": not getattr(field, 'null', True),
                    "primary_key": getattr(field, 'primary_key', False),
                })
        schema_data[table_name] = columns
    return DatabaseSchema(tables=schema_data, relationships=relationships)
async def _analyze_postgresql_schema(self, connection_string: str) -> DatabaseSchema:
    """Read table, column, primary-key, and foreign-key metadata from PostgreSQL.

    Args:
        connection_string: A psycopg2-compatible connection string/URL.
    """
    if not POSTGRESQL_AVAILABLE:
        raise ImportError("psycopg2 is not available")
    try:
        import psycopg2
        from psycopg2.extras import RealDictCursor

        conn = psycopg2.connect(connection_string)
        try:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            cursor.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
            """)
            tables = [row['table_name'] for row in cursor.fetchall()]
            schema_data = {}
            relationships = []
            for table in tables:
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                """, (table,))
                columns = [
                    {
                        "name": col['column_name'],
                        "type": col['data_type'],
                        "notnull": col['is_nullable'] == 'NO',
                        "default_value": col['column_default'],
                        "primary_key": False,  # filled in below
                    }
                    for col in cursor.fetchall()
                ]
                # Identify PK columns via constraint_type.  The original matched
                # constraint_name LIKE '%_pkey', which relies on a naming
                # convention and treats '_' as a LIKE wildcard.
                cursor.execute("""
                    SELECT kcu.column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                      ON tc.constraint_name = kcu.constraint_name
                    WHERE tc.constraint_type = 'PRIMARY KEY'
                      AND tc.table_name = %s
                """, (table,))
                pk_columns = {row['column_name'] for row in cursor.fetchall()}
                for col in columns:
                    if col['name'] in pk_columns:
                        col['primary_key'] = True
                schema_data[table] = columns
                # Foreign keys for this table.
                cursor.execute("""
                    SELECT kcu.column_name,
                           ccu.table_name AS foreign_table_name,
                           ccu.column_name AS foreign_column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                      ON tc.constraint_name = kcu.constraint_name
                    JOIN information_schema.constraint_column_usage AS ccu
                      ON ccu.constraint_name = tc.constraint_name
                    WHERE tc.constraint_type = 'FOREIGN KEY'
                      AND tc.table_name = %s
                """, (table,))
                for fk in cursor.fetchall():
                    relationships.append({
                        "from_table": table,
                        "from_column": fk['column_name'],
                        "to_table": fk['foreign_table_name'],
                        "to_column": fk['foreign_column_name'],
                    })
        finally:
            conn.close()  # close even when a metadata query raised
        return DatabaseSchema(tables=schema_data, relationships=relationships)
    except Exception as e:
        logger.error(f"PostgreSQL schema analysis failed: {e}")
        raise
async def _analyze_mysql_schema(self, connection_string: str) -> DatabaseSchema:
    """Read table, column, and foreign-key metadata from a MySQL database.

    Args:
        connection_string: URL of the form mysql://user:password@host:port/db.
    """
    if not MYSQL_AVAILABLE:
        raise ImportError("pymysql is not available")
    try:
        import pymysql
        import urllib.parse

        parsed = urllib.parse.urlparse(connection_string)
        conn = pymysql.connect(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=parsed.username,
            password=parsed.password,
            database=parsed.path[1:],  # strip the leading '/'
            cursorclass=pymysql.cursors.DictCursor,
        )
        try:
            cursor = conn.cursor()
            cursor.execute("SHOW TABLES")
            tables = [list(row.values())[0] for row in cursor.fetchall()]
            schema_data = {}
            relationships = []
            for table in tables:
                # Backtick-quote the identifier; DESCRIBE cannot take placeholders.
                cursor.execute(f"DESCRIBE `{table}`")
                columns = []
                for col in cursor.fetchall():
                    columns.append({
                        "name": col['Field'],
                        "type": col['Type'],
                        "notnull": col['Null'] == 'NO',
                        "default_value": col['Default'],
                        "primary_key": col['Key'] == 'PRI',
                    })
                schema_data[table] = columns
                # Parameterized (original interpolated the table name into the
                # SQL string) and scoped to the current schema so FKs of
                # same-named tables in other databases are not picked up.
                cursor.execute(
                    """
                    SELECT COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                      AND TABLE_NAME = %s
                      AND REFERENCED_TABLE_NAME IS NOT NULL
                    """,
                    (table,),
                )
                for fk in cursor.fetchall():
                    relationships.append({
                        "from_table": table,
                        "from_column": fk['COLUMN_NAME'],
                        "to_table": fk['REFERENCED_TABLE_NAME'],
                        "to_column": fk['REFERENCED_COLUMN_NAME'],
                    })
        finally:
            conn.close()  # close even when a metadata query raised
        return DatabaseSchema(tables=schema_data, relationships=relationships)
    except Exception as e:
        logger.error(f"MySQL schema analysis failed: {e}")
        raise
def _detect_language(self, text: str) -> str:
"""Detect if text is Arabic or English."""
arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
return "ar" if len(arabic_chars) > len(text) * 0.3 else "en"
def _execute_query_sync(self, query: str) -> List[Dict]:
    """Blocking wrapper around _execute_query() for sync callers."""
    return asyncio.run(self._execute_query(query))
async def _execute_query(self, query: str) -> List[Dict]:
    """Route *query* to the executor matching the active connection type.

    Raises:
        ValueError: when no connection is established or its type is
            unsupported.
    """
    active = self.db_connection
    if active is None:
        raise ValueError("No database connection established")
    if active.db_type == DatabaseType.SQLITE:
        return await self._execute_sqlite_query(active.connection_string, query)
    if active.db_type == DatabaseType.POSTGRESQL:
        return await self._execute_postgresql_query(active.connection_string, query)
    if active.db_type == DatabaseType.MYSQL:
        return await self._execute_mysql_query(active.connection_string, query)
    raise ValueError(f"Unsupported database type: {active.db_type}")
async def _execute_sqlite_query(self, db_path: str, query: str) -> List[Dict]:
"""Execute SQL query on SQLite database."""
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(query)
# Get column names
columns = [description[0] for description in cursor.description]
# Fetch results and convert to dictionaries
results = cursor.fetchall()
data = [dict(zip(columns, row)) for row in results]
conn.close()
return data
except Exception as e:
logger.error(f"SQLite query execution failed: {e}")
raise
async def _execute_django_query(self, query: str) -> List[Dict]:
    """Run raw SQL through Django's default connection; rows become dicts."""
    try:
        from django.db import connection
        # Django's cursor context manager handles close/cleanup.
        with connection.cursor() as cursor:
            cursor.execute(query)
            column_names = [col[0] for col in cursor.description]
            return [dict(zip(column_names, row)) for row in cursor.fetchall()]
    except Exception as e:
        logger.error(f"Django query execution failed: {e}")
        raise
async def _execute_postgresql_query(self, connection_string: str, query: str) -> List[Dict]:
    """Execute *query* on PostgreSQL and return rows as plain dicts."""
    try:
        import psycopg2
        from psycopg2.extras import RealDictCursor

        conn = psycopg2.connect(connection_string)
        try:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            cursor.execute(query)
            return [dict(row) for row in cursor.fetchall()]
        finally:
            # The original leaked the connection when execute()/fetchall() raised.
            conn.close()
    except Exception as e:
        logger.error(f"PostgreSQL query execution failed: {e}")
        raise
async def _execute_mysql_query(self, connection_string: str, query: str) -> List[Dict]:
    """Execute *query* on the MySQL database described by *connection_string*."""
    try:
        import pymysql
        import urllib.parse

        parsed = urllib.parse.urlparse(connection_string)
        conn = pymysql.connect(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=parsed.username,
            password=parsed.password,
            database=parsed.path[1:],  # strip the leading '/'
            cursorclass=pymysql.cursors.DictCursor,
        )
        try:
            cursor = conn.cursor()
            cursor.execute(query)
            # DictCursor already yields dicts; coerce to a plain list so the
            # return type matches the List[Dict] annotation.
            return list(cursor.fetchall())
        finally:
            # The original leaked the connection when execute() raised.
            conn.close()
    except Exception as e:
        logger.error(f"MySQL query execution failed: {e}")
        raise
def _prepare_chart_data(self, data: List[Dict], chart_type: str, fields: List[str]) -> Optional[Dict]:
"""Prepare data for chart visualization."""
if not data or not fields:
return None
chart_type = chart_type.lower()
if chart_type not in self.config.SUPPORTED_CHART_TYPES:
chart_type = "bar"
try:
# Extract labels and values
labels = []
datasets = []
if len(fields) >= 1:
labels = [str(item.get(fields[0], "")) for item in data]
if chart_type in ["pie", "doughnut"]:
# Single dataset for pie charts
values = []
for item in data:
if len(fields) > 1:
try:
value = float(item.get(fields[1], 0) or 0)
except (ValueError, TypeError):
value = 1
values.append(value)
else:
values.append(1)
return {
"type": chart_type,
"labels": labels,
"data": values,
"backgroundColor": [
f"rgba({50 + i * 30}, {100 + i * 25}, {200 + i * 20}, 0.7)"
for i in range(len(values))
]
}
else:
# Multiple datasets for other chart types
for i, field in enumerate(fields[1:], 1):
try:
dataset_values = []
for item in data:
try:
value = float(item.get(field, 0) or 0)
except (ValueError, TypeError):
value = 0
dataset_values.append(value)
datasets.append({
"label": field,
"data": dataset_values,
"backgroundColor": f"rgba({50 + i * 40}, {100 + i * 30}, 235, 0.6)",
"borderColor": f"rgba({50 + i * 40}, {100 + i * 30}, 235, 1.0)",
"borderWidth": 2
})
except Exception as e:
logger.warning(f"Error processing field {field}: {e}")
return {
"type": chart_type,
"labels": labels,
"datasets": datasets
}
except Exception as e:
logger.error(f"Chart preparation failed: {e}")
return None
def get_insights_sync(self, request: InsightRequest) -> Dict[str, Any]:
    """Blocking wrapper around get_insights() for Django views.

    Returns:
        The QueryResult serialized to a plain dict; on failure, an error
        payload with the same shape.
    """
    try:
        return asyncio.run(self.get_insights(request)).to_dict()
    except Exception as e:
        logger.error(f"Synchronous insight generation failed: {e}")
        return {
            "status": "error",
            "data": [],
            "metadata": {},
            "error": str(e),
            "language": "en",
        }
async def get_insights(self, request: InsightRequest) -> QueryResult:
    """End-to-end pipeline: prompt -> schema -> AI query plan -> execution -> chart.

    Args:
        request: The natural-language request with output preferences.

    Returns:
        QueryResult with status "success" or "error"; never raises.
    """
    language = "en"  # safe default for the error path
    try:
        if request.language == "auto":
            language = self._detect_language(request.prompt)
        else:
            language = request.language

        schema = await self.analyze_database_schema(request)

        # Ask the LLM for a query plan.  The agent is declared with
        # deps_type=DatabaseSchema, so the schema is passed as `deps` —
        # the original `database_schema=` keyword is not a valid
        # Agent.run() argument and raised TypeError.
        query_response = await self.query_agent.run(
            f"User prompt: {request.prompt}\nLanguage: {language}",
            deps=schema,
        )

        try:
            query_plan = json.loads(query_response.output)
        except json.JSONDecodeError:
            # Fallback: pull the first SELECT statement out of free text.
            sql_match = re.search(r'SELECT.*?;', query_response.output, re.IGNORECASE | re.DOTALL)
            if not sql_match:
                raise ValueError("Could not parse AI response")
            query_plan = {
                "sql_query": sql_match.group(0),
                "chart_suggestion": "bar",
                "expected_fields": [],
                "language": language,
            }

        sql_query = query_plan.get("sql_query", "")
        if not sql_query:
            raise ValueError("No SQL query generated")
        data = await self._execute_query(sql_query)

        # Pick chart type/fields; the trailing `or "bar"` also covers an
        # explicit null chart_suggestion in the plan.
        chart_type = request.chart_type or query_plan.get("chart_suggestion") or "bar"
        expected_fields = query_plan.get("expected_fields", [])
        chart_data = None
        if data:
            # Fall back to the first few result columns when the plan named none.
            fields_for_chart = expected_fields or list(data[0].keys())[:3]
            chart_data = self._prepare_chart_data(data, chart_type, fields_for_chart)

        return QueryResult(
            status="success",
            data=data[:request.limit] if data else [],
            metadata={
                "total_count": len(data) if data else 0,
                "query": sql_query,
                "analysis": query_plan.get("analysis", ""),
                "fields": expected_fields or (list(data[0].keys()) if data else []),
                "database_type": self.db_connection.db_type.value if self.db_connection else "unknown",
            },
            chart_data=chart_data,
            language=language,
        )
    except Exception as e:
        logger.error(f"Insight generation failed: {e}")
        return QueryResult(
            status="error",
            data=[],
            metadata={},
            error=str(e),
            language=language,
        )
# # Static method for Django view compatibility
# @staticmethod
# def get_insights(django_request, prompt: str, **kwargs) -> Dict[str, Any]:
# """
# Static method compatible with your Django view.
# This method signature matches what your view is calling.
#
# Args:
# django_request: Django HttpRequest object (not used but kept for compatibility)
# prompt: Natural language query string
# **kwargs: Additional parameters
#
# Returns:
# Dictionary with query results
# """
# try:
# # Create system instance
# system = DatabaseInsightSystem()
#
# # Extract language from Django request if available
# language = "auto"
# if hasattr(django_request, 'LANGUAGE_CODE'):
# language = django_request.LANGUAGE_CODE
#
# # Create insight request
# insight_request = InsightRequest(
# prompt=prompt,
# language=language,
# use_django=True,
# **kwargs
# )
#
# # Get insights synchronously
# return system.get_insights_sync(insight_request)
#
# except Exception as e:
# logger.error(f"Static get_insights failed: {e}")
# return {
# "status": "error",
# "data": [],
# "metadata": {},
# "error": str(e),
# "language": language if 'language' in locals() else "en"
# }
# Convenience function for Django views (alternative approach)
def analyze_prompt_sync(prompt: str, **kwargs) -> Dict[str, Any]:
    """Analyze a prompt synchronously and return an insight payload.

    Convenience entry point for Django views.

    Args:
        prompt: Natural language query.
        **kwargs: Additional parameters for InsightRequest.

    Returns:
        Dictionary with query results.
    """
    insight_request = InsightRequest(prompt=prompt, **kwargs)
    return DatabaseInsightSystem().get_insights_sync(insight_request)