# haikal/haikalbot/ai_agent.py
from dataclasses import dataclass
from typing import List, Dict, Optional, Any, Union
from django.apps import apps
from django.db.models import QuerySet, Q, Sum, Avg, Count, Max, Min
from django.core.exceptions import FieldDoesNotExist
from django.conf import settings
from langchain_ollama import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage
import json
import re
import logging
from functools import reduce
import operator

logger = logging.getLogger(__name__)
# Configuration settings
LLM_MODEL = getattr(settings, "MODEL_ANALYZER_LLM_MODEL", "qwen3:8b")
LLM_TEMPERATURE = getattr(settings, "MODEL_ANALYZER_LLM_TEMPERATURE", 0.3)
LLM_MAX_TOKENS = getattr(settings, "MODEL_ANALYZER_LLM_MAX_TOKENS", 2048)
CACHE_TIMEOUT = getattr(settings, "MODEL_ANALYZER_CACHE_TIMEOUT", 3600)
system_instruction = """
You are a specialized AI agent designed to analyze Django models and extract relevant information based on user input in Arabic or English. You must:
1. Model Analysis:
- Parse the user's natural language prompt to understand the analysis requirements
- Identify the relevant Django model(s) from the provided model structure
- Extract only the fields needed for the specific analysis
- Handle both direct fields and relationship fields appropriately
2. Field Selection:
- Determine relevant fields based on:
* Analysis type (count, average, sum, etc.)
* Explicit field mentions in the prompt
* Related fields needed for joins
* Common fields for the requested analysis type
3. Return Structure:
Return only the following JSON object, with no surrounding commentary:
{
    "status": "success",
    "analysis_requirements": {
        "app_label": "<django_app_name>",
        "model_name": "<model_name>",
        "fields": ["field1", "field2", ...],
        "relationships": [{"field": "related_field", "type": "relation_type", "to": "related_model"}]
    },
    "language": "<ar|en>"
}
4. Analysis Types:
- COUNT queries: Return id field
- AGGREGATE queries (avg, sum): Return numeric fields
- DATE queries: Return date/timestamp fields
- RELATIONSHIP queries: Return foreign key and related fields
- TEXT queries: Return relevant text fields
5. Special Considerations:
- Handle both Arabic and English inputs
- Consider model relationships for joined queries
- Include only fields necessary for the requested analysis
- Support filtering and grouping requirements
"""
@dataclass
class FieldAnalysis:
name: str
field_type: str
is_required: bool
is_relation: bool
related_model: Optional[str] = None
analysis_relevance: float = 0.0
@dataclass
class ModelAnalysis:
app_label: str
model_name: str
relevant_fields: List[FieldAnalysis]
relationships: List[Dict[str, str]]
confidence_score: float
class DjangoModelAnalyzer:
def __init__(self):
self.analysis_patterns = {
"count": {
"patterns": [r"\b(count|number|how many)\b"],
"fields": ["id"],
"weight": 1.0,
},
"aggregate": {
"patterns": [r"\b(average|avg|mean|sum|total)\b"],
"fields": ["price", "amount", "value", "cost", "quantity"],
"weight": 0.8,
},
"temporal": {
"patterns": [r"\b(date|time|when|period)\b"],
"fields": ["created_at", "updated_at", "date", "timestamp"],
"weight": 0.7,
},
}
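        # Illustrative matches for the patterns above (hypothetical prompts):
        #   "how many orders?"      -> "count"     -> candidate fields ["id"]
        #   "average product price" -> "aggregate" -> candidate numeric fields
        #   "sales by date"         -> "temporal"  -> candidate date fields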
    def analyze_prompt(self, prompt: str, model_structure: List) -> Optional[ModelAnalysis]:
        # Initialize LLM
        llm = ChatOllama(
            model=LLM_MODEL,
            temperature=LLM_TEMPERATURE,
            num_predict=LLM_MAX_TOKENS,
        )
        # Send the model structure along with the prompt; the system
        # instruction tells the LLM to pick a model "from the provided model
        # structure", so it must actually be provided.
        messages = [
            SystemMessage(content=system_instruction),
            HumanMessage(
                content=f"Model structure:\n{json.dumps(model_structure, default=str)}\n\nUser prompt:\n{prompt}"
            ),
        ]
        try:
            response = llm.invoke(messages)
            if (
                not response
                or not hasattr(response, "content")
                or response.content is None
            ):
                raise ValueError("Empty response from LLM")
            analysis_requirements = self._parse_llm_response(response.content)
        except Exception as e:
            logger.error(f"Error in LLM analysis: {e}")
            analysis_requirements = self._pattern_based_analysis(
                prompt, model_structure
            )
        return self._enhance_analysis(analysis_requirements, model_structure)
    def _parse_llm_response(self, response: str) -> Dict:
        try:
            # Greedy match grabs the outermost JSON object in the reply;
            # re.DOTALL lets "." span newlines, so no pre-processing is needed.
            json_match = re.search(r"({.*})", response, re.DOTALL)
            if json_match:
                return json.loads(json_match.group(1))
            return {}
        except Exception as e:
            logger.error(f"Error parsing LLM response: {e}")
            return {}
    def _pattern_based_analysis(self, prompt: str, model_structure: List) -> Dict:
        analysis_type = None
        relevant_fields = []
        for analysis_name, config in self.analysis_patterns.items():
            for pattern in config["patterns"]:
                if re.search(pattern, prompt.lower()):
                    relevant_fields.extend(config["fields"])
                    analysis_type = analysis_name
                    break
            if analysis_type:
                break
        fields = list(set(relevant_fields)) or ["id", "name"]
        # Best-effort model guess so the fallback result has the same shape as
        # the LLM response expected by _enhance_analysis: pick the first model
        # whose fields overlap the candidate fields.
        app_label = model_name = None
        for structure in model_structure:
            if any(f in structure.get("fields", {}) for f in fields):
                app_label = structure["app_label"]
                model_name = structure["model_name"]
                break
        return {
            "analysis_type": analysis_type or "basic",
            "analysis_requirements": {
                "app_label": app_label,
                "model_name": model_name,
                "fields": fields,
            },
        }
    def _enhance_analysis(
        self, requirements: Dict, model_structure: List
    ) -> Optional[ModelAnalysis]:
        analysis_requirements = requirements.get("analysis_requirements", {})
        app_label = analysis_requirements.get("app_label")
        model_name = analysis_requirements.get("model_name")
        fields = analysis_requirements.get("fields") or []
        if not isinstance(fields, list):
            raise ValueError(f"Invalid fields in analysis requirements: {fields}")
        if not app_label or not model_name:
            # apps.get_model() raises ValueError, not LookupError, when given
            # None, so guard explicitly and fail soft.
            logger.error("Missing app_label or model_name in analysis requirements")
            return None
        try:
            model = apps.get_model(app_label, model_name)
        except LookupError as e:
            logger.error(f"Model lookup error: {e}")
            return None
relevant_fields = []
relationships = []
for field_name in fields:
try:
field = model._meta.get_field(field_name)
field_analysis = FieldAnalysis(
name=field_name,
field_type=field.get_internal_type(),
is_required=not field.null if hasattr(field, "null") else True,
is_relation=field.is_relation,
related_model=field.related_model.__name__
if field.is_relation
and hasattr(field, "related_model")
and field.related_model
else None,
)
field_analysis.analysis_relevance = self._calculate_field_relevance(
field_analysis, requirements.get("analysis_type", "basic")
)
relevant_fields.append(field_analysis)
if field.is_relation:
relationships.append(
{
"field": field_name,
"type": field.get_internal_type(),
"to": field.related_model.__name__
if hasattr(field, "related_model") and field.related_model
else "",
}
)
except FieldDoesNotExist:
logger.warning(f"Field {field_name} not found in {model_name}")
return ModelAnalysis(
app_label=app_label,
model_name=model_name,
relevant_fields=sorted(
relevant_fields, key=lambda x: x.analysis_relevance, reverse=True
),
relationships=relationships,
confidence_score=self._calculate_confidence_score(relevant_fields),
)
def _calculate_field_relevance(
self, field: FieldAnalysis, analysis_type: str
) -> float:
base_score = 0.5
if analysis_type in self.analysis_patterns:
if field.name in self.analysis_patterns[analysis_type]["fields"]:
base_score += self.analysis_patterns[analysis_type]["weight"]
if field.is_required:
base_score += 0.2
if field.is_relation:
base_score += 0.1
return min(base_score, 1.0)
def _calculate_confidence_score(self, fields: List[FieldAnalysis]) -> float:
if not fields:
return 0.0
return sum(field.analysis_relevance for field in fields) / len(fields)
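    # Worked example of the scoring above (hypothetical "aggregate" analysis):
    # a required "price" field scores 0.5 (base) + 0.8 (pattern weight)
    # + 0.2 (required) = 1.5, capped at 1.0; a required "name" field scores
    # 0.5 + 0.2 = 0.7; the model's confidence is then (1.0 + 0.7) / 2 = 0.85.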
def get_all_model_structures(filtered_apps: Optional[List[str]] = None) -> List[Dict]:
    """
    Retrieve structure information for all Django models, optionally filtered by app label.

    Args:
        filtered_apps: Optional list of app labels to filter models by

    Returns:
        List of dictionaries containing model structure information
    """
    structures = []
    for model in apps.get_models():
        app_label = model._meta.app_label
        if filtered_apps and app_label not in filtered_apps:
            continue
        fields = {}
        relationships = []
        for field in model._meta.get_fields():
            if field.is_relation:
                # related_model is set on forward fields and reverse relation
                # objects alike; skip relations without one (e.g. generic
                # foreign keys) rather than guessing from field.model, which
                # is the owning model, not the target.
                related_model = getattr(field, "related_model", None)
                if related_model:
                    relationships.append(
                        {
                            "field": field.name,
                            # Reverse relation objects lack get_internal_type(),
                            # so fall back to the class name (e.g. "ManyToOneRel").
                            "type": field.get_internal_type()
                            if hasattr(field, "get_internal_type")
                            else type(field).__name__,
                            "to": related_model.__name__,
                        }
                    )
            else:
                fields[field.name] = field.get_internal_type()
        structures.append(
            {
                "app_label": app_label,
                "model_name": model.__name__,
                "fields": fields,
                "relationships": relationships,
            }
        )
    return structures
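# A sketch of the structure returned above, for a hypothetical inventory app
# (model and field names are illustrative, not part of this project):
#
#     [{"app_label": "inventory",
#       "model_name": "Product",
#       "fields": {"id": "AutoField", "name": "CharField", "price": "DecimalField"},
#       "relationships": [{"field": "category", "type": "ForeignKey", "to": "Category"}]}]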
def apply_joins(queryset: QuerySet, joins: List[Dict[str, str]]) -> QuerySet:
    """
    Apply joins to the queryset based on the provided join specifications.

    Args:
        queryset: The base queryset to apply joins to
        joins: List of join specifications with path and type

    Returns:
        Updated queryset with joins applied
    """
    if not joins:
        return queryset
    for join in joins:
        path = join.get("path")
        join_type = (join.get("type") or "").upper()
        if not path:
            continue
        try:
            # "type" carries the field's internal type (e.g. "ForeignKey"),
            # not a SQL join keyword. Forward single-valued relations can be
            # followed with a SQL join via select_related; many-valued and
            # reverse relations must be prefetched instead.
            if join_type in ("FOREIGNKEY", "ONETOONEFIELD", "LEFT"):
                queryset = queryset.select_related(path)
            else:
                queryset = queryset.prefetch_related(path)
        except Exception as e:
            logger.warning(f"Failed to apply join for {path}: {e}")
    return queryset
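# Example join spec consumed above (the Product model and its relation paths
# are illustrative assumptions):
#
#     apply_joins(Product.objects.all(), [
#         {"path": "category", "type": "ForeignKey"},   # -> select_related
#         {"path": "tags", "type": "ManyToManyField"},  # -> prefetch_related
#     ])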
def apply_filters(queryset: QuerySet, filters: Dict[str, Any]) -> QuerySet:
"""
Apply filters to queryset with advanced filter operations.
Args:
queryset: The base queryset to apply filters to
filters: Dictionary of field:value pairs or complex filter operations
Returns:
Filtered queryset
"""
if not filters:
return queryset
q_objects = []
for key, value in filters.items():
if isinstance(value, dict):
# Handle complex filters
operation = value.get("operation", "exact")
filter_value = value.get("value")
            # Use an explicit None check so falsy values (0, False, "") still
            # filter; only "isnull" is meaningful without a value.
            if filter_value is None and operation != "isnull":
                continue
if operation == "contains":
q_objects.append(Q(**{f"{key}__icontains": filter_value}))
elif operation == "in":
if isinstance(filter_value, list) and filter_value:
q_objects.append(Q(**{f"{key}__in": filter_value}))
elif operation in [
"gt",
"gte",
"lt",
"lte",
"exact",
"iexact",
"startswith",
"endswith",
]:
q_objects.append(Q(**{f"{key}__{operation}": filter_value}))
elif (
operation == "between"
and isinstance(filter_value, list)
and len(filter_value) >= 2
):
q_objects.append(
Q(
**{
f"{key}__gte": filter_value[0],
f"{key}__lte": filter_value[1],
}
)
)
elif operation == "isnull":
q_objects.append(Q(**{f"{key}__isnull": bool(filter_value)}))
else:
# Simple exact match
q_objects.append(Q(**{key: value}))
if not q_objects:
return queryset
return queryset.filter(reduce(operator.and_, q_objects))
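# Example filter payload accepted above; all Q objects are ANDed together
# (field names are illustrative):
#
#     apply_filters(queryset, {
#         "status": "active",                                   # simple exact match
#         "name": {"operation": "contains", "value": "phone"},  # icontains
#         "price": {"operation": "between", "value": [10, 99]}, # gte/lte pair
#         "deleted_at": {"operation": "isnull", "value": True},
#     })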
def process_aggregation(
queryset: QuerySet,
aggregation: str,
fields: List[str],
group_by: Optional[List[str]] = None,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""
Process aggregation queries with support for grouping.
Args:
queryset: The base queryset to aggregate
aggregation: Aggregation type (sum, avg, count, max, min)
fields: Fields to aggregate
group_by: Optional fields to group by
Returns:
Dictionary of aggregation results or list of grouped results
"""
if not fields:
return {"error": "No fields specified for aggregation"}
agg_func_map = {"sum": Sum, "avg": Avg, "count": Count, "max": Max, "min": Min}
agg_func = agg_func_map.get(aggregation.lower())
if not agg_func:
return {"error": f"Unsupported aggregation: {aggregation}"}
try:
if group_by:
# Create aggregation dictionary for valid fields
agg_dict = {}
for field in fields:
if field not in group_by:
agg_dict[f"{aggregation}_{field}"] = agg_func(field)
if not agg_dict:
return {
"error": "No valid fields for aggregation after excluding group_by fields"
}
# Apply group_by and aggregation
return list(queryset.values(*group_by).annotate(**agg_dict))
else:
# Simple aggregation without grouping
return queryset.aggregate(
**{f"{aggregation}_{field}": agg_func(field) for field in fields}
)
except Exception as e:
logger.error(f"Aggregation error: {e}")
return {"error": f"Aggregation failed: {str(e)}"}
def prepare_chart_data(
    data: Union[List[Dict], Dict], fields: List[str], chart_type: str
) -> Optional[Dict[str, Any]]:
    """
    Prepare data for chart visualization.

    Args:
        data: List of data dictionaries, or a single aggregation result dict
        fields: Fields to include in the chart
        chart_type: Type of chart (pie, bar, line, doughnut, radar, scatter)

    Returns:
        Dictionary with chart configuration, or None if there is nothing to chart
    """
    if not data or not fields or not chart_type:
        return None
    # Validate chart type
    chart_type = chart_type.lower()
    if chart_type not in ["pie", "bar", "line", "doughnut", "radar", "scatter"]:
        chart_type = "bar"  # Default to bar chart for unsupported types
try:
# For aggregation results that come as a dictionary
if isinstance(data, dict):
# Convert single dict to list for chart processing
labels = list(data.keys())
values = list(data.values())
return {
"type": chart_type,
"labels": [str(label).replace(f"{fields[0]}_", "") for label in labels],
"data": [
float(value) if isinstance(value, (int, float)) else 0
for value in values
],
"backgroundColor": [
"rgba(54, 162, 235, 0.6)",
"rgba(255, 99, 132, 0.6)",
"rgba(255, 206, 86, 0.6)",
"rgba(75, 192, 192, 0.6)",
"rgba(153, 102, 255, 0.6)",
"rgba(255, 159, 64, 0.6)",
],
}
# For regular query results as list of dictionaries
# Create labels from first field values
labels = [str(item.get(fields[0], "")) for item in data]
if chart_type == "pie" or chart_type == "doughnut":
# For pie charts, we need just one data series
data_values = []
for item in data:
# Use second field for values if available, otherwise use 1
if len(fields) > 1:
try:
value = float(item.get(fields[1], 0))
except (ValueError, TypeError):
value = 0
data_values.append(value)
else:
data_values.append(1) # Default count if no value field
return {
"type": chart_type,
"labels": labels,
"data": data_values,
"backgroundColor": [
"rgba(54, 162, 235, 0.6)",
"rgba(255, 99, 132, 0.6)",
"rgba(255, 206, 86, 0.6)",
"rgba(75, 192, 192, 0.6)",
"rgba(153, 102, 255, 0.6)",
"rgba(255, 159, 64, 0.6)",
]
* (len(data_values) // 6 + 1), # Repeat colors as needed
}
else:
# For other charts, create dataset for each field after the first
datasets = []
for i, field in enumerate(fields[1:], 1):
try:
dataset = {
"label": field,
"data": [float(item.get(field, 0) or 0) for item in data],
"backgroundColor": f"rgba({50 + i * 50}, {100 + i * 40}, 235, 0.6)",
"borderColor": f"rgba({50 + i * 50}, {100 + i * 40}, 235, 1.0)",
"borderWidth": 1,
}
datasets.append(dataset)
except (ValueError, TypeError) as e:
logger.warning(f"Error processing field {field} for chart: {e}")
return {"type": chart_type, "labels": labels, "datasets": datasets}
except Exception as e:
logger.error(f"Error preparing chart data: {e}")
return None
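# The returned dict follows a Chart.js-like shape; a sketch for a two-field
# bar chart (labels and values are illustrative):
#
#     {"type": "bar",
#      "labels": ["Laptops", "Phones"],
#      "datasets": [{"label": "quantity", "data": [12.0, 30.0],
#                    "backgroundColor": "rgba(100, 140, 235, 0.6)", ...}]}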
def query_django_model(parsed: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute Django model queries based on parsed analysis requirements.
Args:
parsed: Dictionary containing query parameters:
- app_label: Django app label
            - model_name: Model name
- fields: List of fields to query
- filters: Query filters
- aggregation: Aggregation type
- chart: Chart type for visualization
- joins: List of joins to apply
- group_by: Fields to group by
- order_by: Fields to order by
- limit: Maximum number of results
Returns:
Dictionary with query results
"""
try:
# Extract parameters with defaults
app_label = parsed.get("app_label")
model_name = parsed.get("model_name")
fields = parsed.get("fields", [])
filters = parsed.get("filters", {})
aggregation = parsed.get("aggregation")
chart = parsed.get("chart")
joins = parsed.get("joins", [])
group_by = parsed.get("group_by", [])
order_by = parsed.get("order_by", [])
        limit = int(parsed.get("limit") or 1000)
language = parsed.get("language", "en")
# Validate required parameters
if not app_label or not model_name:
return {
"status": "error",
"error": "app_label and model are required",
"language": language,
}
# Get model class
try:
model = apps.get_model(app_label=app_label, model_name=model_name)
except LookupError:
return {
"status": "error",
"error": f"Model '{model_name}' not found in app '{app_label}'",
"language": language,
}
# Validate fields against model
if fields:
model_fields = [f.name for f in model._meta.fields]
invalid_fields = [f for f in fields if f not in model_fields]
if invalid_fields:
logger.warning(f"Invalid fields requested: {invalid_fields}")
fields = [f for f in fields if f in model_fields]
# Build queryset
queryset = model.objects.all()
# Apply joins
queryset = apply_joins(queryset, joins)
# Apply filters
if filters:
try:
queryset = apply_filters(queryset, filters)
except Exception as e:
logger.error(f"Error applying filters: {e}")
return {
"status": "error",
"error": f"Invalid filters: {str(e)}",
"language": language,
}
# Handle aggregations
if aggregation:
result = process_aggregation(queryset, aggregation, fields, group_by)
if isinstance(result, dict) and "error" in result:
return {
"status": "error",
"error": result["error"],
"language": language,
}
chart_data = None
if chart:
chart_data = prepare_chart_data(result, fields, chart)
return {
"status": "success",
"data": result,
"chart": chart_data,
"language": language,
}
# Handle regular queries
try:
# Apply field selection
if fields:
queryset = queryset.values(*fields)
# Apply ordering
if order_by:
queryset = queryset.order_by(*order_by)
# Apply limit (with safety check)
if limit <= 0:
limit = 1000
queryset = queryset[:limit]
# Convert queryset to list
data = list(queryset)
# Prepare chart data if needed
chart_data = None
if chart and data and fields:
chart_data = prepare_chart_data(data, fields, chart)
return {
"status": "success",
"data": data,
"count": len(data),
"chart": chart_data,
"metadata": {
"total_count": len(data),
"fields": fields,
"model": model_name,
"app": app_label,
},
"language": language,
}
except Exception as e:
logger.error(f"Error executing query: {e}")
return {
"status": "error",
"error": f"Query execution failed: {str(e)}",
"language": language,
}
except Exception as e:
logger.error(f"Unexpected error in query_django_model: {e}")
return {
"status": "error",
"error": f"Unexpected error: {str(e)}",
"language": parsed.get("language", "en"),
}
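# Example of a fully parsed query payload (app, model, and fields are
# illustrative assumptions, not part of this module):
#
#     query_django_model({
#         "app_label": "inventory",
#         "model_name": "Product",
#         "fields": ["name", "price"],
#         "filters": {"price": {"operation": "gt", "value": 100}},
#         "order_by": ["-price"],
#         "limit": 10,
#         "chart": "bar",
#     })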
def determine_aggregation_type(
    prompt: str, fields: List[FieldAnalysis]
) -> Optional[str]:
    """
    Determine the appropriate aggregation type based on the prompt and fields.

    Args:
        prompt: User prompt text
        fields: List of field analysis objects

    Returns:
        Aggregation type or None
    """
    prompt_lower = prompt.lower()
    # Keyword tables cover English and Arabic phrasings; order matters, since
    # the first match wins.
    keyword_map = [
        ("avg", ["average", "avg", "mean", "معدل", "متوسط"]),
        ("sum", ["sum", "total", "مجموع", "إجمالي"]),
        ("count", ["count", "number", "how many", "عدد", "كم"]),
        ("max", ["maximum", "max", "highest", "أقصى", "أعلى"]),
        ("min", ["minimum", "min", "lowest", "أدنى", "أقل"]),
    ]
    for aggregation, keywords in keyword_map:
        if any(keyword in prompt_lower for keyword in keywords):
            return aggregation
    # Check field types for numeric fields to determine a default aggregation
    numeric_fields = [
        field
        for field in fields
        if field.field_type in ["DecimalField", "FloatField", "IntegerField"]
    ]
    if numeric_fields:
        return "sum"  # Default to sum for numeric fields
    return None
def determine_chart_type(prompt: str, fields: List[FieldAnalysis]) -> Optional[str]:
    """
    Determine the appropriate chart type based on the prompt and fields.

    Args:
        prompt: User prompt text
        fields: List of field analysis objects

    Returns:
        Chart type or None
    """
    prompt_lower = prompt.lower()
    # Check for explicit chart type mentions in the prompt (English or Arabic)
    keyword_map = [
        ("line", ["line chart", "time series", "trend", "رسم خطي", "اتجاه"]),
        ("bar", ["bar chart", "histogram", "column", "رسم شريطي", "أعمدة"]),
        ("pie", ["pie chart", "circle chart", "رسم دائري", "فطيرة"]),
        ("doughnut", ["doughnut", "دونات"]),
        ("radar", ["radar", "spider", "رادار"]),
    ]
    for chart, keywords in keyword_map:
        if any(term in prompt_lower for term in keywords):
            return chart
    # Otherwise infer the chart type from field types and count
    date_fields = [
        field for field in fields if field.field_type in ["DateField", "DateTimeField"]
    ]
    numeric_fields = [
        field
        for field in fields
        if field.field_type in ["DecimalField", "FloatField", "IntegerField"]
    ]
    if date_fields and numeric_fields:
        return "line"  # Time series data
    if len(fields) == 1:
        return "pie"  # Single dimension data
    if len(fields) == 2 and numeric_fields:
        return "bar"  # Category and value
    # Default for multi-dimensional data
    return "bar"
def analyze_prompt(prompt: str) -> Dict[str, Any]:
    """
    Analyze a natural language prompt and execute the appropriate Django model query.

    Args:
        prompt: Natural language prompt from user

    Returns:
        Dictionary with query results
    """
    # Detect language by the presence of Arabic characters
    language = "ar" if re.search(r"[\u0600-\u06FF]", prompt) else "en"
    # App labels exposed to the analyzer; overridable via the (project-specific)
    # MODEL_ANALYZER_APPS setting, defaulting to the previously hard-coded app.
    filtered_apps = getattr(settings, "MODEL_ANALYZER_APPS", ["inventory"])
    try:
        analyzer = DjangoModelAnalyzer()
        model_structure = get_all_model_structures(filtered_apps=filtered_apps)
        logger.debug("Model structure: %s", model_structure)
        analysis = analyzer.analyze_prompt(prompt, model_structure)
        logger.debug("Analysis: %s", analysis)
        if not analysis or not analysis.app_label or not analysis.model_name:
            return {
                "status": "error",
                "message": "تعذر العثور على النموذج المطلوب"
                if language == "ar"
                else "Missing model information",
                "language": language,
            }
        query_params = {
            "app_label": analysis.app_label,
            "model_name": analysis.model_name,
            "fields": [field.name for field in analysis.relevant_fields],
            "joins": [
                {"path": rel["field"], "type": rel["type"]}
                for rel in analysis.relationships
            ],
            "filters": {},
            "aggregation": determine_aggregation_type(prompt, analysis.relevant_fields),
            "chart": determine_chart_type(prompt, analysis.relevant_fields),
            "language": language,
        }
        return query_django_model(query_params)
    except Exception as e:
        logger.error(f"Error analyzing prompt: {e}")
        return {
            "status": "error",
            "error": "حدث خطأ أثناء تحليل الاستعلام"
            if language == "ar"
            else f"Error analyzing prompt: {str(e)}",
            "language": language,
        }
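# Illustrative end-to-end usage from a Django shell (``python manage.py
# shell``), assuming the app is importable as ``haikalbot``; the prompt and
# resulting payload are assumptions:
#
#     >>> from haikalbot.ai_agent import analyze_prompt
#     >>> analyze_prompt("كم عدد المنتجات؟")  # "How many products?"
#     {"status": "success", "data": {...}, "chart": {...}, "language": "ar"}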