import asyncio
import sqlite3
import json
import re
import logging
import os
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, asdict
from enum import Enum

# Pydantic and AI imports
from pydantic import BaseModel, Field
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

# Optional Django imports (if available)
try:
    from django.apps import apps
    from django.db import models, connection
    from django.db.models import QuerySet, Q, F, Sum, Avg, Count, Max, Min
    from django.core.exceptions import FieldDoesNotExist
    from django.conf import settings

    DJANGO_AVAILABLE = True
except ImportError:
    DJANGO_AVAILABLE = False

# Optional database drivers
try:
    import psycopg2

    POSTGRESQL_AVAILABLE = True
except ImportError:
    POSTGRESQL_AVAILABLE = False

try:
    import pymysql

    MYSQL_AVAILABLE = True
except ImportError:
    MYSQL_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Configuration
class DatabaseConfig:
    # The model name comes from Django settings when available; the environment
    # variable fallback and its "local-model" default are deployment-specific
    # placeholders (the original referenced settings unconditionally, which
    # raises a NameError when Django is not installed).
    LLM_MODEL = (
        getattr(settings, "MODEL_ANALYZER_LLM_MODEL", None) if DJANGO_AVAILABLE else None
    ) or os.environ.get("MODEL_ANALYZER_LLM_MODEL", "local-model")
    LLM_BASE_URL = "http://localhost:11434/v1"
    LLM_TEMPERATURE = 0.3
    MAX_RESULTS = 1000
    SUPPORTED_CHART_TYPES = ["bar", "line", "pie", "doughnut", "radar", "scatter"]


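# A hypothetical override for a different OpenAI-compatible endpoint; the class
# name and URL below are illustrative only, not part of this module:
#
#     class RemoteConfig(DatabaseConfig):
#         LLM_BASE_URL = "http://llm.internal:8000/v1"
#         LLM_TEMPERATURE = 0.0
#
#     system = DatabaseInsightSystem(config=RemoteConfig())

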
class DatabaseType(Enum):
    SQLITE = "sqlite"
    POSTGRESQL = "postgresql"
    MYSQL = "mysql"


@dataclass
class DatabaseConnection:
    db_type: DatabaseType
    connection_string: str
    database_name: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    user: Optional[str] = None
    password: Optional[str] = None
    schema_info: Optional[Dict] = None


@dataclass
class QueryResult:
    status: str
    data: Union[List[Dict], Dict]
    metadata: Dict[str, Any]
    chart_data: Optional[Dict] = None
    language: str = "en"
    error: Optional[str] = None

    def to_dict(self):
        """Convert to a dictionary for JSON serialization."""
        return asdict(self)


class DatabaseSchema(BaseModel):
    tables: Dict[str, List[Dict[str, Any]]] = Field(
        description="Database schema with table names as keys and column info as values"
    )
    relationships: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Foreign key relationships between tables"
    )


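# Illustrative shape of a populated schema (example values, not output from a
# real database):
#
#     DatabaseSchema(
#         tables={"orders": [
#             {"name": "id", "type": "INTEGER", "notnull": True,
#              "default_value": None, "primary_key": True},
#         ]},
#         relationships=[
#             {"from_table": "orders", "from_column": "customer_id",
#              "to_table": "customers", "to_column": "id"},
#         ],
#     )

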
class InsightRequest(BaseModel):
    prompt: str = Field(description="Natural language query from user")
    database_path: Optional[str] = Field(default=None, description="Path to database file (for SQLite)")
    chart_type: Optional[str] = Field(default=None, description="Preferred chart type")
    limit: Optional[int] = Field(default=1000, description="Maximum number of results")
    language: Optional[str] = Field(default="auto", description="Response language preference")
    use_django: Optional[bool] = Field(default=True, description="Use Django database if available")


class DatabaseInsightSystem:
    def __init__(self, config: Optional[DatabaseConfig] = None):
        self.config = config or DatabaseConfig()
        self.model = OpenAIModel(
            model_name=self.config.LLM_MODEL,
            provider=OpenAIProvider(base_url=self.config.LLM_BASE_URL),
        )
        self.db_connection = None
        self._setup_agents()

    def _setup_agents(self):
        """Initialize the AI agent for query generation and analysis."""

        # Query generation agent; the database schema is supplied as the
        # agent's dependencies at run time
        self.query_agent = Agent(
            self.model,
            deps_type=DatabaseSchema,
            output_type=str,
            system_prompt="""You are an intelligent database query generator and analyst.
Given a natural language prompt and a database schema, you must:

1. ANALYZE the user's request in English or Arabic
2. IDENTIFY relevant tables and columns from the schema
3. GENERATE an appropriate SQL query or analysis approach
4. DETERMINE whether aggregation, grouping, or joins are needed
5. SUGGEST an appropriate visualization type
6. SUMMARIZE the insights the query should surface

Respond in JSON:
{
    "analysis": "Brief analysis of the request",
    "query_type": "select|aggregate|join|complex",
    "sql_query": "Generated SQL query",
    "chart_suggestion": "bar|line|pie|etc",
    "expected_fields": ["field1", "field2"],
    "language": "en|ar"
}

Handle both English and Arabic prompts. For Arabic text, respond in Arabic.
Focus on providing actionable insights, not just raw data.""",
        )

    def _get_django_database_config(self) -> Optional[DatabaseConnection]:
        """Extract database configuration from Django settings."""
        if not DJANGO_AVAILABLE:
            return None

        try:
            # Get the default database configuration
            db_config = settings.DATABASES.get('default', {})
            if not db_config:
                logger.warning("No default database configuration found in Django settings")
                return None

            engine = db_config.get('ENGINE', '')
            db_name = db_config.get('NAME', '')
            host = db_config.get('HOST', 'localhost')
            port = db_config.get('PORT', None)
            user = db_config.get('USER', '')
            password = db_config.get('PASSWORD', '')

            # Determine the database type from the engine
            if 'sqlite' in engine.lower():
                db_type = DatabaseType.SQLITE
                connection_string = db_name  # For SQLite, NAME is the file path
            elif 'postgresql' in engine.lower():
                db_type = DatabaseType.POSTGRESQL
                port = port or 5432
                connection_string = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
            elif 'mysql' in engine.lower():
                db_type = DatabaseType.MYSQL
                port = port or 3306
                connection_string = f"mysql://{user}:{password}@{host}:{port}/{db_name}"
            else:
                logger.warning(f"Unsupported database engine: {engine}")
                return None

            return DatabaseConnection(
                db_type=db_type,
                connection_string=connection_string,
                database_name=db_name,
                host=host,
                port=port,
                user=user,
                password=password,
            )

        except Exception as e:
            logger.error(f"Failed to get Django database config: {e}")
            return None

    def analyze_database_schema_sync(self, request: InsightRequest) -> DatabaseSchema:
        """Synchronous wrapper for schema analysis."""
        return asyncio.run(self.analyze_database_schema(request))

    async def analyze_database_schema(self, request: InsightRequest) -> DatabaseSchema:
        """Extract and analyze the database schema."""
        try:
            # Try Django first if available and requested
            if request.use_django and DJANGO_AVAILABLE:
                django_config = self._get_django_database_config()
                if django_config:
                    self.db_connection = django_config
                    return await self._analyze_django_schema()

            # Fall back to a direct database connection
            if request.database_path:
                # Assume SQLite for a direct file path
                self.db_connection = DatabaseConnection(
                    db_type=DatabaseType.SQLITE,
                    connection_string=request.database_path,
                )
                return await self._analyze_sqlite_schema(request.database_path)

            raise ValueError("No database configuration available")

        except Exception as e:
            logger.error(f"Schema analysis failed: {e}")
            raise

    async def _analyze_sqlite_schema(self, db_path: str) -> DatabaseSchema:
        """Analyze a SQLite database schema."""
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()

            # Get table names, skipping SQLite's internal tables
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            )
            tables = [row[0] for row in cursor.fetchall()]

            schema_data = {}
            relationships = []

            for table in tables:
                # Get column information (the name is quoted to tolerate
                # unusual table names)
                cursor.execute(f'PRAGMA table_info("{table}")')
                columns = []
                for col in cursor.fetchall():
                    columns.append({
                        "name": col[1],
                        "type": col[2],
                        "notnull": bool(col[3]),
                        "default_value": col[4],
                        "primary_key": bool(col[5]),
                    })
                schema_data[table] = columns

                # Get foreign key relationships
                cursor.execute(f'PRAGMA foreign_key_list("{table}")')
                for fk in cursor.fetchall():
                    relationships.append({
                        "from_table": table,
                        "from_column": fk[3],
                        "to_table": fk[2],
                        "to_column": fk[4],
                    })

            conn.close()
            return DatabaseSchema(tables=schema_data, relationships=relationships)

        except Exception as e:
            logger.error(f"SQLite schema analysis failed: {e}")
            raise

    async def _analyze_django_schema(self) -> DatabaseSchema:
        """Analyze the schema of the registered Django models."""
        if not DJANGO_AVAILABLE:
            raise ImportError("Django is not available")

        schema_data = {}
        relationships = []

        for model in apps.get_models():
            table_name = model._meta.db_table
            columns = []

            for field in model._meta.get_fields():
                if not field.is_relation:
                    columns.append({
                        "name": field.name,
                        "type": field.get_internal_type(),
                        "notnull": not getattr(field, 'null', True),
                        "primary_key": getattr(field, 'primary_key', False),
                    })
                else:
                    # Record relationships; reverse relations do not define
                    # get_internal_type, so fall back to the class name
                    if hasattr(field, 'related_model') and field.related_model:
                        relationships.append({
                            "from_table": table_name,
                            "from_column": field.name,
                            "to_table": field.related_model._meta.db_table,
                            "relationship_type": (
                                field.get_internal_type()
                                if hasattr(field, 'get_internal_type')
                                else type(field).__name__
                            ),
                        })

            schema_data[table_name] = columns

        return DatabaseSchema(tables=schema_data, relationships=relationships)

    async def _analyze_postgresql_schema(self, connection_string: str) -> DatabaseSchema:
        """Analyze a PostgreSQL database schema."""
        if not POSTGRESQL_AVAILABLE:
            raise ImportError("psycopg2 is not available")

        try:
            import psycopg2
            from psycopg2.extras import RealDictCursor

            conn = psycopg2.connect(connection_string)
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # Get table names
            cursor.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
            """)
            tables = [row['table_name'] for row in cursor.fetchall()]

            schema_data = {}
            relationships = []

            for table in tables:
                # Get column information
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                """, (table,))

                columns = []
                for col in cursor.fetchall():
                    columns.append({
                        "name": col['column_name'],
                        "type": col['data_type'],
                        "notnull": col['is_nullable'] == 'NO',
                        "default_value": col['column_default'],
                        "primary_key": False,  # Updated below
                    })

                # Get primary key information via the constraint type rather
                # than a "%_pkey" name pattern: the literal '%' clashes with
                # psycopg2's parameter placeholders, and not every PK follows
                # that naming convention
                cursor.execute("""
                    SELECT kcu.column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                    WHERE tc.constraint_type = 'PRIMARY KEY'
                        AND tc.table_name = %s
                """, (table,))

                pk_columns = [row['column_name'] for row in cursor.fetchall()]
                for col in columns:
                    if col['name'] in pk_columns:
                        col['primary_key'] = True

                schema_data[table] = columns

                # Get foreign key relationships
                cursor.execute("""
                    SELECT kcu.column_name,
                           ccu.table_name AS foreign_table_name,
                           ccu.column_name AS foreign_column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                    JOIN information_schema.constraint_column_usage AS ccu
                        ON ccu.constraint_name = tc.constraint_name
                    WHERE tc.constraint_type = 'FOREIGN KEY'
                        AND tc.table_name = %s
                """, (table,))

                for fk in cursor.fetchall():
                    relationships.append({
                        "from_table": table,
                        "from_column": fk['column_name'],
                        "to_table": fk['foreign_table_name'],
                        "to_column": fk['foreign_column_name'],
                    })

            conn.close()
            return DatabaseSchema(tables=schema_data, relationships=relationships)

        except Exception as e:
            logger.error(f"PostgreSQL schema analysis failed: {e}")
            raise

    async def _analyze_mysql_schema(self, connection_string: str) -> DatabaseSchema:
        """Analyze a MySQL database schema."""
        if not MYSQL_AVAILABLE:
            raise ImportError("pymysql is not available")

        try:
            import pymysql
            import urllib.parse

            # Parse the connection string: mysql://user:password@host:port/database
            parsed = urllib.parse.urlparse(connection_string)

            conn = pymysql.connect(
                host=parsed.hostname,
                port=parsed.port or 3306,
                user=parsed.username,
                password=parsed.password,
                database=parsed.path[1:],  # Remove the leading slash
                cursorclass=pymysql.cursors.DictCursor,
            )

            cursor = conn.cursor()

            # Get table names
            cursor.execute("SHOW TABLES")
            tables = [list(row.values())[0] for row in cursor.fetchall()]

            schema_data = {}
            relationships = []

            for table in tables:
                # Get column information (backticks tolerate unusual table names)
                cursor.execute(f"DESCRIBE `{table}`")
                columns = []
                for col in cursor.fetchall():
                    columns.append({
                        "name": col['Field'],
                        "type": col['Type'],
                        "notnull": col['Null'] == 'NO',
                        "default_value": col['Default'],
                        "primary_key": col['Key'] == 'PRI',
                    })

                schema_data[table] = columns

                # Get foreign key relationships; parameterized instead of
                # interpolated, and scoped to the current schema so rows from
                # other databases on the same server are not picked up
                cursor.execute("""
                    SELECT COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                        AND TABLE_NAME = %s
                        AND REFERENCED_TABLE_NAME IS NOT NULL
                """, (table,))

                for fk in cursor.fetchall():
                    relationships.append({
                        "from_table": table,
                        "from_column": fk['COLUMN_NAME'],
                        "to_table": fk['REFERENCED_TABLE_NAME'],
                        "to_column": fk['REFERENCED_COLUMN_NAME'],
                    })

            conn.close()
            return DatabaseSchema(tables=schema_data, relationships=relationships)

        except Exception as e:
            logger.error(f"MySQL schema analysis failed: {e}")
            raise

    def _detect_language(self, text: str) -> str:
        """Detect whether text is Arabic or English."""
        # Treat the text as Arabic when more than 30% of its characters fall
        # in the Arabic Unicode block
        arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
        return "ar" if len(arabic_chars) > len(text) * 0.3 else "en"

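    # Illustrative behaviour of the heuristic above (not executed):
    #   _detect_language("ما هو إجمالي المبيعات؟")  -> "ar"
    #   _detect_language("total sales by region")   -> "en"
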
    def _execute_query_sync(self, query: str) -> List[Dict]:
        """Synchronous wrapper for query execution."""
        return asyncio.run(self._execute_query(query))

    async def _execute_query(self, query: str) -> List[Dict]:
        """Execute a query against the current database connection."""
        if not self.db_connection:
            raise ValueError("No database connection established")

        if self.db_connection.db_type == DatabaseType.SQLITE:
            return await self._execute_sqlite_query(self.db_connection.connection_string, query)
        elif self.db_connection.db_type == DatabaseType.POSTGRESQL:
            return await self._execute_postgresql_query(self.db_connection.connection_string, query)
        elif self.db_connection.db_type == DatabaseType.MYSQL:
            return await self._execute_mysql_query(self.db_connection.connection_string, query)
        else:
            # Raw SQL can also be routed through Django's own connection; see
            # _execute_django_query below
            raise ValueError(f"Unsupported database type: {self.db_connection.db_type}")

    async def _execute_sqlite_query(self, db_path: str, query: str) -> List[Dict]:
        """Execute a SQL query on a SQLite database."""
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            cursor.execute(query)

            # Get column names (cursor.description is None for non-SELECT statements)
            columns = [d[0] for d in cursor.description] if cursor.description else []

            # Fetch results and convert to dictionaries
            results = cursor.fetchall()
            data = [dict(zip(columns, row)) for row in results]

            conn.close()
            return data

        except Exception as e:
            logger.error(f"SQLite query execution failed: {e}")
            raise

    async def _execute_django_query(self, query: str) -> List[Dict]:
        """Execute a raw SQL query using Django's database connection."""
        try:
            from django.db import connection

            with connection.cursor() as cursor:
                cursor.execute(query)
                columns = [col[0] for col in cursor.description]
                results = cursor.fetchall()
                data = [dict(zip(columns, row)) for row in results]

            return data

        except Exception as e:
            logger.error(f"Django query execution failed: {e}")
            raise

    async def _execute_postgresql_query(self, connection_string: str, query: str) -> List[Dict]:
        """Execute a SQL query on a PostgreSQL database."""
        try:
            import psycopg2
            from psycopg2.extras import RealDictCursor

            conn = psycopg2.connect(connection_string)
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            cursor.execute(query)

            results = cursor.fetchall()
            data = [dict(row) for row in results]

            conn.close()
            return data

        except Exception as e:
            logger.error(f"PostgreSQL query execution failed: {e}")
            raise

    async def _execute_mysql_query(self, connection_string: str, query: str) -> List[Dict]:
        """Execute a SQL query on a MySQL database."""
        try:
            import pymysql
            import urllib.parse

            parsed = urllib.parse.urlparse(connection_string)

            conn = pymysql.connect(
                host=parsed.hostname,
                port=parsed.port or 3306,
                user=parsed.username,
                password=parsed.password,
                database=parsed.path[1:],
                cursorclass=pymysql.cursors.DictCursor,
            )

            cursor = conn.cursor()
            cursor.execute(query)
            results = cursor.fetchall()

            conn.close()
            # DictCursor already yields dictionaries
            return list(results)

        except Exception as e:
            logger.error(f"MySQL query execution failed: {e}")
            raise

    def _prepare_chart_data(self, data: List[Dict], chart_type: str, fields: List[str]) -> Optional[Dict]:
        """Prepare data for chart visualization."""
        if not data or not fields:
            return None

        chart_type = (chart_type or "bar").lower()
        if chart_type not in self.config.SUPPORTED_CHART_TYPES:
            chart_type = "bar"

        try:
            # The first field supplies the labels; the remaining fields supply values
            labels = [str(item.get(fields[0], "")) for item in data]
            datasets = []

            if chart_type in ["pie", "doughnut"]:
                # Single dataset for pie charts
                values = []
                for item in data:
                    if len(fields) > 1:
                        try:
                            value = float(item.get(fields[1], 0) or 0)
                        except (ValueError, TypeError):
                            value = 1
                        values.append(value)
                    else:
                        values.append(1)

                return {
                    "type": chart_type,
                    "labels": labels,
                    "data": values,
                    "backgroundColor": [
                        f"rgba({50 + i * 30}, {100 + i * 25}, {200 + i * 20}, 0.7)"
                        for i in range(len(values))
                    ],
                }
            else:
                # One dataset per value field for the other chart types
                for i, field in enumerate(fields[1:], 1):
                    try:
                        dataset_values = []
                        for item in data:
                            try:
                                value = float(item.get(field, 0) or 0)
                            except (ValueError, TypeError):
                                value = 0
                            dataset_values.append(value)

                        datasets.append({
                            "label": field,
                            "data": dataset_values,
                            "backgroundColor": f"rgba({50 + i * 40}, {100 + i * 30}, 235, 0.6)",
                            "borderColor": f"rgba({50 + i * 40}, {100 + i * 30}, 235, 1.0)",
                            "borderWidth": 2,
                        })
                    except Exception as e:
                        logger.warning(f"Error processing field {field}: {e}")

                return {
                    "type": chart_type,
                    "labels": labels,
                    "datasets": datasets,
                }

        except Exception as e:
            logger.error(f"Chart preparation failed: {e}")
            return None

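    # Illustrative payload produced above for a bar chart (example values):
    #
    #     {"type": "bar",
    #      "labels": ["North", "South"],
    #      "datasets": [{"label": "total", "data": [120.0, 80.0],
    #                    "backgroundColor": "rgba(90, 130, 235, 0.6)",
    #                    "borderColor": "rgba(90, 130, 235, 1.0)",
    #                    "borderWidth": 2}]}
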
    def get_insights_sync(self, request: InsightRequest) -> Dict[str, Any]:
        """Synchronous wrapper for get_insights, e.g. for Django views."""
        try:
            result = asyncio.run(self.get_insights(request))
            return result.to_dict()
        except Exception as e:
            logger.error(f"Synchronous insight generation failed: {e}")
            return {
                "status": "error",
                "data": [],
                "metadata": {},
                "error": str(e),
                "language": "en",
            }

    async def get_insights(self, request: InsightRequest) -> QueryResult:
        """Main entry point: get database insights from a natural language prompt."""
        language = "en"  # Default for the error path, before detection runs
        try:
            # Detect language
            language = self._detect_language(request.prompt) if request.language == "auto" else request.language

            # Analyze the database schema
            schema = await self.analyze_database_schema(request)

            # Generate a query plan using the AI agent; the schema is passed
            # as the agent's dependencies (matching deps_type above)
            query_response = await self.query_agent.run(
                f"User prompt: {request.prompt}\nLanguage: {language}",
                deps=schema,
            )

            # Parse the AI response
            try:
                query_plan = json.loads(query_response.output)
            except json.JSONDecodeError:
                # Fallback: extract the SQL statement from the raw response
                sql_match = re.search(r'SELECT.*?;', query_response.output, re.IGNORECASE | re.DOTALL)
                if sql_match:
                    query_plan = {
                        "sql_query": sql_match.group(0),
                        "chart_suggestion": "bar",
                        "expected_fields": [],
                        "language": language,
                    }
                else:
                    raise ValueError("Could not parse AI response")

            # Execute the query
            sql_query = query_plan.get("sql_query", "")
            if not sql_query:
                raise ValueError("No SQL query generated")

            data = await self._execute_query(sql_query)

            # Prepare chart data
            chart_data = None
            chart_type = request.chart_type or query_plan.get("chart_suggestion", "bar")
            expected_fields = query_plan.get("expected_fields", [])

            if data and expected_fields:
                chart_data = self._prepare_chart_data(data, chart_type, expected_fields)
            elif data:
                # Use the first few fields if none were suggested
                available_fields = list(data[0].keys())
                chart_data = self._prepare_chart_data(data, chart_type, available_fields[:3])

            # Prepare the result
            return QueryResult(
                status="success",
                data=data[:request.limit] if data else [],
                metadata={
                    "total_count": len(data) if data else 0,
                    "query": sql_query,
                    "analysis": query_plan.get("analysis", ""),
                    "fields": expected_fields or (list(data[0].keys()) if data else []),
                    "database_type": self.db_connection.db_type.value if self.db_connection else "unknown",
                },
                chart_data=chart_data,
                language=language,
            )

        except Exception as e:
            logger.error(f"Insight generation failed: {e}")
            return QueryResult(
                status="error",
                data=[],
                metadata={},
                error=str(e),
                language=language,
            )

    # Static-method variant kept for Django view compatibility (commented out
    # because it would shadow the async get_insights above):
    #
    # @staticmethod
    # def get_insights(django_request, prompt: str, **kwargs) -> Dict[str, Any]:
    #     """
    #     Static method matching the signature a Django view would call.
    #
    #     Args:
    #         django_request: Django HttpRequest object (unused; kept for compatibility)
    #         prompt: Natural language query string
    #         **kwargs: Additional parameters
    #
    #     Returns:
    #         Dictionary with query results
    #     """
    #     try:
    #         # Create a system instance
    #         system = DatabaseInsightSystem()
    #
    #         # Extract the language from the Django request if available
    #         language = "auto"
    #         if hasattr(django_request, 'LANGUAGE_CODE'):
    #             language = django_request.LANGUAGE_CODE
    #
    #         # Create the insight request
    #         insight_request = InsightRequest(
    #             prompt=prompt,
    #             language=language,
    #             use_django=True,
    #             **kwargs
    #         )
    #
    #         # Get insights synchronously
    #         return system.get_insights_sync(insight_request)
    #
    #     except Exception as e:
    #         logger.error(f"Static get_insights failed: {e}")
    #         return {
    #             "status": "error",
    #             "data": [],
    #             "metadata": {},
    #             "error": str(e),
    #             "language": language if 'language' in locals() else "en",
    #         }


# Convenience function for Django views (alternative approach)
def analyze_prompt_sync(prompt: str, **kwargs) -> Dict[str, Any]:
    """
    Synchronously analyze a prompt and return insights; convenient for Django views.

    Args:
        prompt: Natural language query
        **kwargs: Additional parameters for InsightRequest

    Returns:
        Dictionary with query results
    """
    system = DatabaseInsightSystem()
    request = InsightRequest(prompt=prompt, **kwargs)
    return system.get_insights_sync(request)
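

# Minimal smoke-test sketch. It assumes a local SQLite file named "example.db"
# and an OpenAI-compatible endpoint at DatabaseConfig.LLM_BASE_URL; both names
# are illustrative placeholders, not part of this module's contract.
if __name__ == "__main__":
    result = analyze_prompt_sync(
        "Show the top 5 customers by total order value",
        database_path="example.db",
        use_django=False,
        chart_type="bar",
    )
    print(json.dumps(result, indent=2, default=str))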