866 lines
29 KiB
Python
866 lines
29 KiB
Python
from dataclasses import dataclass
|
|
from typing import List, Dict, Optional, Any, Union
|
|
from django.apps import apps
|
|
from django.db import models
|
|
from django.db.models import QuerySet, Q, F, Value, CharField, Sum, Avg, Count, Max, Min
|
|
from django.db.models.functions import Concat, Cast
|
|
from django.core.exceptions import FieldDoesNotExist
|
|
from django.core.serializers import serialize
|
|
from django.conf import settings
|
|
from langchain_ollama import ChatOllama
|
|
from langchain_core.messages import SystemMessage, HumanMessage
|
|
import json
|
|
import re
|
|
import logging
|
|
from functools import reduce
|
|
import operator
|
|
|
|
from sqlalchemy.orm import relationship
|
|
|
|
logger = logging.getLogger(__name__)


# Configuration settings.
# Every value below can be overridden by defining the matching
# MODEL_ANALYZER_* attribute in Django settings; the literal is the fallback.
LLM_MODEL = getattr(settings, "MODEL_ANALYZER_LLM_MODEL", "qwen3:8b")  # Ollama model tag
LLM_TEMPERATURE = getattr(settings, "MODEL_ANALYZER_LLM_TEMPERATURE", 0.3)
LLM_MAX_TOKENS = getattr(settings, "MODEL_ANALYZER_LLM_MAX_TOKENS", 2048)
# NOTE(review): presumably seconds; not read anywhere in this section — confirm.
CACHE_TIMEOUT = getattr(settings, "MODEL_ANALYZER_CACHE_TIMEOUT", 3600)
|
|
|
|
system_instruction = """
|
|
You are a specialized AI agent designed to analyze Django models and extract relevant information based on user input in Arabic or English. You must:
|
|
|
|
1. Model Analysis:
|
|
- Parse the user's natural language prompt to understand the analysis requirements
|
|
- Identify the relevant Django model(s) from the provided model structure
|
|
- Extract only the fields needed for the specific analysis
|
|
- Handle both direct fields and relationship fields appropriately
|
|
|
|
2. Field Selection:
|
|
- Determine relevant fields based on:
|
|
* Analysis type (count, average, sum, etc.)
|
|
* Explicit field mentions in the prompt
|
|
* Related fields needed for joins
|
|
* Common fields for the requested analysis type
|
|
|
|
3. Return Structure:
|
|
Return a JSON response with:
|
|
{
|
|
"status": "success",
|
|
"analysis_requirements": {
|
|
"app_label": "<django_app_name>",
|
|
"model_name": "<model_name>",
|
|
"fields": ["field1", "field2", ...],
|
|
"relationships": [{"field": "related_field", "type": "relation_type", "to": "related_model"}]
|
|
},
|
|
"language": "<ar|en>"
|
|
}
|
|
|
|
4. Analysis Types:
|
|
- COUNT queries: Return id field
|
|
- AGGREGATE queries (avg, sum): Return numeric fields
|
|
- DATE queries: Return date/timestamp fields
|
|
- RELATIONSHIP queries: Return foreign key and related fields
|
|
- TEXT queries: Return relevant text fields
|
|
|
|
5. Special Considerations:
|
|
- Handle both Arabic and English inputs
|
|
- Consider model relationships for joined queries
|
|
- Include only fields necessary for the requested analysis
|
|
- Support filtering and grouping requirements
|
|
"""
|
|
|
|
|
|
@dataclass
class FieldAnalysis:
    """Summary of one model field and its relevance to the requested analysis.

    ``analysis_relevance`` defaults to 0.0 and is meant to be assigned after
    the record is created, once the analysis type is known.
    """

    # Field name as looked up on the model.
    name: str
    # Django internal type name of the field, e.g. "CharField".
    field_type: str
    # True when the field is considered mandatory (not nullable).
    is_required: bool
    # True when the field is a relation to another model.
    is_relation: bool
    # Class name of the related model for relations that have a target.
    related_model: Optional[str] = None
    # Relevance score for the current analysis; 0.0 until scored.
    analysis_relevance: float = 0.0
|
|
|
|
|
|
@dataclass
class ModelAnalysis:
    """Result of analysing a prompt: the selected model and its scored fields."""

    # Django app label of the selected model.
    app_label: str
    # Class name of the selected model.
    model_name: str
    # Per-field analysis records for the fields judged relevant.
    relevant_fields: List[FieldAnalysis]
    # Relation descriptors, each a dict with "field", "type" and "to" keys.
    relationships: List[Dict[str, str]]
    # Overall confidence in the selection.
    confidence_score: float
|
|
|
|
|
|
class DjangoModelAnalyzer:
    """Translate a natural-language prompt into a ``ModelAnalysis``.

    The Ollama model configured by ``LLM_MODEL`` / ``LLM_TEMPERATURE`` /
    ``LLM_MAX_TOKENS`` is asked (with ``system_instruction``) to name the model
    and fields to query. When that call fails, or its answer does not identify
    a model that can be resolved, a keyword heuristic over
    ``analysis_patterns`` and the supplied model structure is used instead.
    """

    def __init__(self):
        # Keyword rules for the fallback heuristic: regex patterns that trigger
        # the rule, field names typically relevant to it, and the relevance
        # bonus those fields receive in _calculate_field_relevance.
        self.analysis_patterns = {
            "count": {
                "patterns": [r"\b(count|number|how many)\b"],
                "fields": ["id"],
                "weight": 1.0,
            },
            "aggregate": {
                "patterns": [r"\b(average|avg|mean|sum|total)\b"],
                "fields": ["price", "amount", "value", "cost", "quantity"],
                "weight": 0.8,
            },
            "temporal": {
                "patterns": [r"\b(date|time|when|period)\b"],
                "fields": ["created_at", "updated_at", "date", "timestamp"],
                "weight": 0.7,
            },
        }

    def analyze_prompt(
        self, prompt: str, model_structure: List
    ) -> Optional["ModelAnalysis"]:
        """Analyse *prompt* and return the model and fields it refers to.

        Args:
            prompt: User request in Arabic or English.
            model_structure: Model descriptions (dicts with ``app_label``,
                ``model_name``, ``fields`` and ``relationships`` keys), sent to
                the LLM and used by the fallback heuristic.

        Returns:
            A ModelAnalysis, or None when no model could be resolved by either
            the LLM answer or the fallback heuristic.
        """
        requirements: Dict = {}
        try:
            llm = ChatOllama(
                model=LLM_MODEL,
                temperature=LLM_TEMPERATURE,
                num_predict=LLM_MAX_TOKENS,
            )
            messages = [
                SystemMessage(content=system_instruction),
                # The system prompt tells the LLM to choose "from the provided
                # model structure", so that structure must accompany the prompt.
                HumanMessage(content=self._build_user_message(prompt, model_structure)),
            ]
            response = llm.invoke(messages)
            content = getattr(response, "content", None)
            if content is None:
                raise ValueError("Empty response from LLM")
            requirements = self._parse_llm_response(content)
        except Exception as e:
            logger.error(f"Error in LLM analysis: {e}")

        analysis = None
        if self._has_model_reference(requirements):
            try:
                analysis = self._enhance_analysis(requirements, model_structure)
            except ValueError as e:
                logger.warning(f"Unusable LLM analysis, falling back to patterns: {e}")

        if analysis is None:
            fallback = self._pattern_based_analysis(prompt, model_structure)
            analysis = self._enhance_analysis(fallback, model_structure)
        return analysis

    @staticmethod
    def _build_user_message(prompt: str, model_structure: List) -> str:
        """Append the available model structure (as JSON) to the user prompt."""
        if not model_structure:
            return prompt
        structure = json.dumps(model_structure, ensure_ascii=False, default=str)
        return f"{prompt}\n\nAvailable Django model structure (JSON):\n{structure}"

    @staticmethod
    def _has_model_reference(requirements: Dict) -> bool:
        """True when *requirements* names both an app label and a model."""
        if not isinstance(requirements, dict):
            return False
        reqs = requirements.get("analysis_requirements")
        return (
            isinstance(reqs, dict)
            and bool(reqs.get("app_label"))
            and bool(reqs.get("model_name"))
        )

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten an LLM message content (str or list of parts) to text."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("text"):
                    parts.append(str(part["text"]))
            return "\n".join(parts)
        return str(content)

    def _parse_llm_response(self, response: Any) -> Dict:
        """Extract the first JSON object from an LLM reply.

        Reasoning output up to the last ``</think>`` tag is discarded first,
        because "thinking" models may put braces there. Every ``{`` is then
        tried as the start of a JSON object. Control characters (raw newlines)
        inside JSON strings are tolerated.

        Args:
            response: Message content; a string or a list of content parts.

        Returns:
            The decoded object, or an empty dict if none is found.
        """
        text = self._content_to_text(response)
        if "</think>" in text:
            text = text.rsplit("</think>", 1)[-1]

        decoder = json.JSONDecoder(strict=False)
        start = text.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text, start)
            except ValueError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed
            start = text.find("{", start + 1)
        return {}

    @staticmethod
    def _guess_model(prompt_lower: str, model_structure: List) -> Optional[Dict]:
        """Pick the structure whose model name (optionally pluralised) appears
        in the prompt as a whole word, else the only structure if there is
        exactly one."""
        candidates = [
            s
            for s in (model_structure or [])
            if isinstance(s, dict) and s.get("app_label") and s.get("model_name")
        ]
        for structure in candidates:
            name = re.escape(str(structure["model_name"]).lower())
            if re.search(rf"\b{name}(?:s|es)?\b", prompt_lower):
                return structure
        return candidates[0] if len(candidates) == 1 else None

    def _pattern_based_analysis(self, prompt: str, model_structure: List) -> Dict:
        """Keyword-based fallback used when the LLM gives no usable answer.

        Returns:
            A dict with top-level ``analysis_type`` and ``fields`` keys and, if
            a model could be guessed from *model_structure*, an
            ``analysis_requirements`` entry in the same shape the LLM is asked
            to produce.
        """
        text = prompt.lower()
        analysis_type = None
        relevant_fields: List[str] = []

        for analysis_name, config in self.analysis_patterns.items():
            if any(re.search(pattern, text) for pattern in config["patterns"]):
                analysis_type = analysis_name
                relevant_fields = list(config["fields"])
                break

        # dict.fromkeys de-duplicates while keeping a deterministic order.
        fields = list(dict.fromkeys(relevant_fields)) or ["id", "name"]
        result: Dict[str, Any] = {
            "analysis_type": analysis_type or "basic",
            "fields": fields,
        }

        target = self._guess_model(text, model_structure)
        if target:
            result["analysis_requirements"] = {
                "app_label": target["app_label"],
                "model_name": target["model_name"],
                "fields": list(fields),
            }
        return result

    @staticmethod
    def _internal_type(field) -> str:
        """``field.get_internal_type()``, or the class name when unavailable."""
        getter = getattr(field, "get_internal_type", None)
        return getter() if callable(getter) else type(field).__name__

    @staticmethod
    def _related_model_name(field) -> Optional[str]:
        """Class name of the relation target, or None for non-relations and
        relations without a fixed target."""
        related = getattr(field, "related_model", None) if field.is_relation else None
        if not related:
            return None
        return getattr(related, "__name__", str(related))

    def _enhance_analysis(
        self, requirements: Dict, model_structure: List
    ) -> Optional["ModelAnalysis"]:
        """Resolve *requirements* against the app registry and score its fields.

        Args:
            requirements: Analysis dict. Model info is read from its
                ``analysis_requirements`` entry and the analysis type from the
                top-level ``analysis_type`` key (default "basic").
            model_structure: Not used by this method; kept for interface
                compatibility.

        Returns:
            A ModelAnalysis with fields sorted by relevance (highest first), or
            None when the app label or model name is missing or unknown.

        Raises:
            ValueError: If ``fields`` is neither a list nor a string.
        """
        reqs = requirements.get("analysis_requirements") if isinstance(requirements, dict) else None
        if not isinstance(reqs, dict):
            reqs = {}
        app_label = reqs.get("app_label")
        model_name = reqs.get("model_name")
        fields = reqs.get("fields") or []
        if isinstance(fields, str):
            fields = [fields]
        if not isinstance(fields, list):
            raise ValueError(f"Invalid fields in analysis requirements: {fields}")

        if not (isinstance(app_label, str) and app_label and isinstance(model_name, str) and model_name):
            logger.error("Analysis requirements do not name an app_label and model_name")
            return None

        try:
            model = apps.get_model(app_label, model_name)
        except LookupError as e:
            logger.error(f"Model lookup error: {e}")
            return None

        analysis_type = requirements.get("analysis_type", "basic")
        relevant_fields = []
        relationships = []

        # Ignore non-string entries and duplicates the LLM may produce.
        for field_name in dict.fromkeys(f for f in fields if isinstance(f, str)):
            try:
                field = model._meta.get_field(field_name)
            except FieldDoesNotExist:
                logger.warning(f"Field {field_name} not found in {model_name}")
                continue

            related_model = self._related_model_name(field)
            field_analysis = FieldAnalysis(
                name=field_name,
                field_type=self._internal_type(field),
                is_required=not field.null if hasattr(field, "null") else True,
                is_relation=field.is_relation,
                related_model=related_model,
            )
            field_analysis.analysis_relevance = self._calculate_field_relevance(
                field_analysis, analysis_type
            )
            relevant_fields.append(field_analysis)

            if field.is_relation:
                relationships.append(
                    {
                        "field": field_name,
                        "type": field_analysis.field_type,
                        "to": related_model or "",
                    }
                )

        return ModelAnalysis(
            app_label=app_label,
            model_name=model_name,
            relevant_fields=sorted(
                relevant_fields, key=lambda x: x.analysis_relevance, reverse=True
            ),
            relationships=relationships,
            confidence_score=self._calculate_confidence_score(relevant_fields),
        )

    def _calculate_field_relevance(
        self, field: "FieldAnalysis", analysis_type: str
    ) -> float:
        """Score a field: 0.5 base, plus the rule weight when the field is a
        typical field for *analysis_type*, +0.2 if required, +0.1 if a
        relation; capped at 1.0."""
        base_score = 0.5
        if analysis_type in self.analysis_patterns:
            if field.name in self.analysis_patterns[analysis_type]["fields"]:
                base_score += self.analysis_patterns[analysis_type]["weight"]
        if field.is_required:
            base_score += 0.2
        if field.is_relation:
            base_score += 0.1
        return min(base_score, 1.0)

    def _calculate_confidence_score(self, fields: List["FieldAnalysis"]) -> float:
        """Mean relevance of *fields*; 0.0 when there are none."""
        if not fields:
            return 0.0
        return sum(field.analysis_relevance for field in fields) / len(fields)
|
|
|
|
|
|
def _model_field_internal_type(field) -> str:
    """Return ``field.get_internal_type()``, or the field's class name when the
    object does not implement it (relation descriptors that are not regular
    model fields may not)."""
    getter = getattr(field, "get_internal_type", None)
    return getter() if callable(getter) else type(field).__name__


def get_all_model_structures(filtered_apps: Optional[List[str]] = None) -> List[Dict]:
    """
    Retrieve structure information for all Django models, optionally filtered by app names.

    Args:
        filtered_apps: Optional list of app labels to keep; falsy means all apps.

    Returns:
        One dict per model with keys ``app_label``, ``model_name``,
        ``fields`` (name -> internal type, non-relation fields only) and
        ``relationships`` (list of ``{"field", "type", "to"}`` dicts).
        Relations without a fixed target model are omitted.
    """
    wanted_apps = set(filtered_apps) if filtered_apps else None
    structures = []

    for model in apps.get_models():
        app_label = model._meta.app_label
        if wanted_apps is not None and app_label not in wanted_apps:
            continue

        fields = {}
        relationships = []

        for field in model._meta.get_fields():
            if not field.is_relation:
                fields[field.name] = _model_field_internal_type(field)
                continue

            related_model = getattr(field, "related_model", None)
            if not related_model:
                # No fixed target (e.g. a generic foreign key). ``field.model``
                # is the model declaring the field, not its target, so it must
                # not be reported as "to"; skip the relation instead.
                continue

            relationships.append(
                {
                    "field": field.name,
                    "type": _model_field_internal_type(field),
                    "to": getattr(related_model, "__name__", str(related_model)),
                }
            )

        structures.append(
            {
                "app_label": app_label,
                "model_name": model.__name__,
                "fields": fields,
                "relationships": relationships,
            }
        )

    return structures
|
|
|
|
|
|
def apply_joins(queryset: QuerySet, joins: List[Dict[str, str]]) -> QuerySet:
|
|
"""
|
|
Apply joins to the queryset based on the provided join specifications.
|
|
|
|
Args:
|
|
queryset: The base queryset to apply joins to
|
|
joins: List of join specifications with path and type
|
|
|
|
Returns:
|
|
Updated queryset with joins applied
|
|
"""
|
|
if not joins:
|
|
return queryset
|
|
|
|
for join in joins:
|
|
path = join.get("path")
|
|
join_type = join.get("type", "LEFT").upper()
|
|
|
|
if not path:
|
|
continue
|
|
|
|
try:
|
|
if join_type == "LEFT":
|
|
queryset = queryset.select_related(path)
|
|
else:
|
|
queryset = queryset.prefetch_related(path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to apply join for {path}: {e}")
|
|
|
|
return queryset
|
|
|
|
|
|
def apply_filters(queryset: QuerySet, filters: Dict[str, Any]) -> QuerySet:
|
|
"""
|
|
Apply filters to queryset with advanced filter operations.
|
|
|
|
Args:
|
|
queryset: The base queryset to apply filters to
|
|
filters: Dictionary of field:value pairs or complex filter operations
|
|
|
|
Returns:
|
|
Filtered queryset
|
|
"""
|
|
if not filters:
|
|
return queryset
|
|
|
|
q_objects = []
|
|
|
|
for key, value in filters.items():
|
|
if isinstance(value, dict):
|
|
# Handle complex filters
|
|
operation = value.get("operation", "exact")
|
|
filter_value = value.get("value")
|
|
|
|
if not filter_value and operation != "isnull":
|
|
continue
|
|
|
|
if operation == "contains":
|
|
q_objects.append(Q(**{f"{key}__icontains": filter_value}))
|
|
elif operation == "in":
|
|
if isinstance(filter_value, list) and filter_value:
|
|
q_objects.append(Q(**{f"{key}__in": filter_value}))
|
|
elif operation in [
|
|
"gt",
|
|
"gte",
|
|
"lt",
|
|
"lte",
|
|
"exact",
|
|
"iexact",
|
|
"startswith",
|
|
"endswith",
|
|
]:
|
|
q_objects.append(Q(**{f"{key}__{operation}": filter_value}))
|
|
elif (
|
|
operation == "between"
|
|
and isinstance(filter_value, list)
|
|
and len(filter_value) >= 2
|
|
):
|
|
q_objects.append(
|
|
Q(
|
|
**{
|
|
f"{key}__gte": filter_value[0],
|
|
f"{key}__lte": filter_value[1],
|
|
}
|
|
)
|
|
)
|
|
elif operation == "isnull":
|
|
q_objects.append(Q(**{f"{key}__isnull": bool(filter_value)}))
|
|
else:
|
|
# Simple exact match
|
|
q_objects.append(Q(**{key: value}))
|
|
|
|
if not q_objects:
|
|
return queryset
|
|
|
|
return queryset.filter(reduce(operator.and_, q_objects))
|
|
|
|
|
|
def process_aggregation(
|
|
queryset: QuerySet,
|
|
aggregation: str,
|
|
fields: List[str],
|
|
group_by: Optional[List[str]] = None,
|
|
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
|
"""
|
|
Process aggregation queries with support for grouping.
|
|
|
|
Args:
|
|
queryset: The base queryset to aggregate
|
|
aggregation: Aggregation type (sum, avg, count, max, min)
|
|
fields: Fields to aggregate
|
|
group_by: Optional fields to group by
|
|
|
|
Returns:
|
|
Dictionary of aggregation results or list of grouped results
|
|
"""
|
|
if not fields:
|
|
return {"error": "No fields specified for aggregation"}
|
|
|
|
agg_func_map = {"sum": Sum, "avg": Avg, "count": Count, "max": Max, "min": Min}
|
|
|
|
agg_func = agg_func_map.get(aggregation.lower())
|
|
if not agg_func:
|
|
return {"error": f"Unsupported aggregation: {aggregation}"}
|
|
|
|
try:
|
|
if group_by:
|
|
# Create aggregation dictionary for valid fields
|
|
agg_dict = {}
|
|
for field in fields:
|
|
if field not in group_by:
|
|
agg_dict[f"{aggregation}_{field}"] = agg_func(field)
|
|
|
|
if not agg_dict:
|
|
return {
|
|
"error": "No valid fields for aggregation after excluding group_by fields"
|
|
}
|
|
|
|
# Apply group_by and aggregation
|
|
return list(queryset.values(*group_by).annotate(**agg_dict))
|
|
else:
|
|
# Simple aggregation without grouping
|
|
return queryset.aggregate(
|
|
**{f"{aggregation}_{field}": agg_func(field) for field in fields}
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Aggregation error: {e}")
|
|
return {"error": f"Aggregation failed: {str(e)}"}
|
|
|
|
|
|
# Palette cycled for per-item (pie/doughnut and aggregate) chart colours.
_CHART_BACKGROUND_COLORS = (
    "rgba(54, 162, 235, 0.6)",
    "rgba(255, 99, 132, 0.6)",
    "rgba(255, 206, 86, 0.6)",
    "rgba(75, 192, 192, 0.6)",
    "rgba(153, 102, 255, 0.6)",
    "rgba(255, 159, 64, 0.6)",
)


def _chart_colors(count: int) -> List[str]:
    """Cycle the palette to produce exactly *count* colours."""
    palette = _CHART_BACKGROUND_COLORS
    return [palette[i % len(palette)] for i in range(count)]


def _chart_number(value: Any) -> float:
    """Convert int/float/Decimal/numeric strings to float; anything else -> 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0


def prepare_chart_data(
    data: List[Dict], fields: List[str], chart_type: str
) -> Optional[Dict[str, Any]]:
    """
    Prepare data for chart visualization.

    Numeric values are converted with ``float`` (so Decimal aggregates such as
    Sum/Avg over DecimalField are charted correctly); missing or non-numeric
    values become 0.0.

    Args:
        data: List of row dicts, or a single dict of aggregation results
        fields: Fields to include; the first supplies labels for row data
        chart_type: Chart type (pie, bar, line, doughnut, radar, scatter);
            unsupported types fall back to "bar"

    Returns:
        Dictionary with chart configuration, or None if inputs are empty or
        the data cannot be processed
    """
    if not data or not fields or not chart_type:
        return None

    chart_type = str(chart_type).lower()
    if chart_type not in ("pie", "bar", "line", "doughnut", "radar", "scatter"):
        chart_type = "bar"  # Default to bar chart for unsupported types

    try:
        # Aggregation results arrive as a single {key: value} dict.
        if isinstance(data, dict):
            values = [_chart_number(value) for value in data.values()]
            return {
                "type": chart_type,
                "labels": [str(label).replace(f"{fields[0]}_", "") for label in data],
                "data": values,
                "backgroundColor": _chart_colors(len(values)),
            }

        # Row results: labels come from the first field.
        labels = [str(item.get(fields[0], "")) for item in data]

        if chart_type in ("pie", "doughnut"):
            # One series: the second field's values, or 1 per row if absent.
            if len(fields) > 1:
                data_values = [_chart_number(item.get(fields[1], 0)) for item in data]
            else:
                data_values = [1] * len(data)
            return {
                "type": chart_type,
                "labels": labels,
                "data": data_values,
                "backgroundColor": _chart_colors(len(data_values)),
            }

        # Other charts: one dataset per field after the first.
        datasets = []
        for i, field in enumerate(fields[1:], 1):
            # Wrap components into 0-255 so every series gets a valid colour.
            red = (50 + i * 50) % 256
            green = (100 + i * 40) % 256
            datasets.append(
                {
                    "label": field,
                    "data": [_chart_number(item.get(field, 0)) for item in data],
                    "backgroundColor": f"rgba({red}, {green}, 235, 0.6)",
                    "borderColor": f"rgba({red}, {green}, 235, 1.0)",
                    "borderWidth": 1,
                }
            )

        return {"type": chart_type, "labels": labels, "datasets": datasets}
    except Exception as e:
        logger.error(f"Error preparing chart data: {e}")
        return None
|
|
|
|
|
|
def query_django_model(parsed: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Execute Django model queries based on parsed analysis requirements.
|
|
|
|
Args:
|
|
parsed: Dictionary containing query parameters:
|
|
- app_label: Django app label
|
|
- model: Model name
|
|
- fields: List of fields to query
|
|
- filters: Query filters
|
|
- aggregation: Aggregation type
|
|
- chart: Chart type for visualization
|
|
- joins: List of joins to apply
|
|
- group_by: Fields to group by
|
|
- order_by: Fields to order by
|
|
- limit: Maximum number of results
|
|
|
|
Returns:
|
|
Dictionary with query results
|
|
"""
|
|
try:
|
|
# Extract parameters with defaults
|
|
app_label = parsed.get("app_label")
|
|
model_name = parsed.get("model_name")
|
|
fields = parsed.get("fields", [])
|
|
filters = parsed.get("filters", {})
|
|
aggregation = parsed.get("aggregation")
|
|
chart = parsed.get("chart")
|
|
joins = parsed.get("joins", [])
|
|
group_by = parsed.get("group_by", [])
|
|
order_by = parsed.get("order_by", [])
|
|
limit = int(parsed.get("limit", 1000))
|
|
language = parsed.get("language", "en")
|
|
|
|
# Validate required parameters
|
|
if not app_label or not model_name:
|
|
return {
|
|
"status": "error",
|
|
"error": "app_label and model are required",
|
|
"language": language,
|
|
}
|
|
|
|
# Get model class
|
|
try:
|
|
model = apps.get_model(app_label=app_label, model_name=model_name)
|
|
except LookupError:
|
|
return {
|
|
"status": "error",
|
|
"error": f"Model '{model_name}' not found in app '{app_label}'",
|
|
"language": language,
|
|
}
|
|
|
|
# Validate fields against model
|
|
if fields:
|
|
model_fields = [f.name for f in model._meta.fields]
|
|
invalid_fields = [f for f in fields if f not in model_fields]
|
|
if invalid_fields:
|
|
logger.warning(f"Invalid fields requested: {invalid_fields}")
|
|
fields = [f for f in fields if f in model_fields]
|
|
|
|
# Build queryset
|
|
queryset = model.objects.all()
|
|
|
|
# Apply joins
|
|
queryset = apply_joins(queryset, joins)
|
|
|
|
# Apply filters
|
|
if filters:
|
|
try:
|
|
queryset = apply_filters(queryset, filters)
|
|
except Exception as e:
|
|
logger.error(f"Error applying filters: {e}")
|
|
return {
|
|
"status": "error",
|
|
"error": f"Invalid filters: {str(e)}",
|
|
"language": language,
|
|
}
|
|
|
|
# Handle aggregations
|
|
if aggregation:
|
|
result = process_aggregation(queryset, aggregation, fields, group_by)
|
|
|
|
if isinstance(result, dict) and "error" in result:
|
|
return {
|
|
"status": "error",
|
|
"error": result["error"],
|
|
"language": language,
|
|
}
|
|
|
|
chart_data = None
|
|
if chart:
|
|
chart_data = prepare_chart_data(result, fields, chart)
|
|
|
|
return {
|
|
"status": "success",
|
|
"data": result,
|
|
"chart": chart_data,
|
|
"language": language,
|
|
}
|
|
|
|
# Handle regular queries
|
|
try:
|
|
# Apply field selection
|
|
if fields:
|
|
queryset = queryset.values(*fields)
|
|
|
|
# Apply ordering
|
|
if order_by:
|
|
queryset = queryset.order_by(*order_by)
|
|
|
|
# Apply limit (with safety check)
|
|
if limit <= 0:
|
|
limit = 1000
|
|
queryset = queryset[:limit]
|
|
|
|
# Convert queryset to list
|
|
data = list(queryset)
|
|
|
|
# Prepare chart data if needed
|
|
chart_data = None
|
|
if chart and data and fields:
|
|
chart_data = prepare_chart_data(data, fields, chart)
|
|
|
|
return {
|
|
"status": "success",
|
|
"data": data,
|
|
"count": len(data),
|
|
"chart": chart_data,
|
|
"metadata": {
|
|
"total_count": len(data),
|
|
"fields": fields,
|
|
"model": model_name,
|
|
"app": app_label,
|
|
},
|
|
"language": language,
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error executing query: {e}")
|
|
return {
|
|
"status": "error",
|
|
"error": f"Query execution failed: {str(e)}",
|
|
"language": language,
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error in query_django_model: {e}")
|
|
return {
|
|
"status": "error",
|
|
"error": f"Unexpected error: {str(e)}",
|
|
"language": parsed.get("language", "en"),
|
|
}
|
|
|
|
|
|
def determine_aggregation_type(
    prompt: str, fields: List["FieldAnalysis"]
) -> Optional[str]:
    """
    Determine the appropriate aggregation type based on the prompt and fields.

    Keywords are checked in priority order: avg, sum, count, max, min. English
    keywords (and the Arabic question word "كم") must appear as whole words,
    optionally with a trailing "s". This avoids false hits such as "country"
    -> count, "summary" -> sum, "administration" -> min, or "كمية" (quantity)
    -> count. Arabic stems are matched as substrings so that prefixed forms
    such as "المتوسط" still match.

    Args:
        prompt: User prompt text
        fields: Field analysis objects (only ``field_type`` is read)

    Returns:
        Aggregation type, "sum" when no keyword matches but a numeric field is
        present, or None
    """
    text = (prompt or "").lower()

    # (aggregation, whole-word terms, substring terms)
    rules = (
        ("avg", ("average", "avg", "mean"), ("معدل", "متوسط")),
        ("sum", ("sum", "total"), ("مجموع", "إجمالي")),
        ("count", ("count", "number", "how many", "كم"), ("عدد",)),
        ("max", ("maximum", "max", "highest"), ("أقصى", "أعلى")),
        ("min", ("minimum", "min", "lowest"), ("أدنى", "أقل")),
    )
    for aggregation, words, stems in rules:
        if any(re.search(rf"\b{re.escape(word)}s?\b", text) for word in words):
            return aggregation
        if any(stem in text for stem in stems):
            return aggregation

    # Check field types for numeric fields to determine default aggregation.
    # Auto fields (primary keys) are intentionally not treated as numeric.
    numeric_types = {
        "DecimalField",
        "FloatField",
        "IntegerField",
        "BigIntegerField",
        "SmallIntegerField",
        "PositiveIntegerField",
        "PositiveSmallIntegerField",
        "PositiveBigIntegerField",
    }
    if any(field.field_type in numeric_types for field in (fields or [])):
        return "sum"  # Default to sum for numeric fields

    return None
|
|
|
|
|
|
def determine_chart_type(prompt: str, fields: List["FieldAnalysis"]) -> Optional[str]:
    """
    Determine the appropriate chart type based on the prompt and fields.

    Explicit mentions win, in the order line, bar, pie, doughnut, radar,
    scatter. Otherwise the choice follows the field types:
      - a date field plus a numeric field gives "line";
      - two fields including a numeric one give "bar";
      - a single field gives "pie";
      - anything else gives "bar".

    Args:
        prompt: User prompt text
        fields: Field analysis objects (only ``field_type`` is read)

    Returns:
        Chart type
    """
    text = (prompt or "").lower()

    explicit = (
        ("line", ("line chart", "time series", "trend", "رسم خطي", "اتجاه")),
        ("bar", ("bar chart", "histogram", "column", "رسم شريطي", "أعمدة")),
        ("pie", ("pie chart", "circle chart", "رسم دائري", "فطيرة")),
        ("doughnut", ("doughnut", "دونات")),
        ("radar", ("radar", "spider", "رادار")),
        ("scatter", ("scatter",)),
    )
    for chart, terms in explicit:
        if any(term in text for term in terms):
            return chart

    fields = fields or []
    # Auto fields (primary keys) are intentionally not treated as numeric.
    numeric_types = {
        "DecimalField",
        "FloatField",
        "IntegerField",
        "BigIntegerField",
        "SmallIntegerField",
        "PositiveIntegerField",
        "PositiveSmallIntegerField",
        "PositiveBigIntegerField",
    }
    has_date = any(f.field_type in ("DateField", "DateTimeField") for f in fields)
    has_numeric = any(f.field_type in numeric_types for f in fields)

    if has_date and has_numeric:
        return "line"  # Time series data
    if len(fields) == 2 and has_numeric:
        return "bar"  # Category and value
    if len(fields) == 1:
        return "pie"  # Single dimension data
    return "bar"  # Multi-dimensional data / default
|
|
|
|
|
|
def analyze_prompt(prompt: str, filtered_apps: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Analyze a natural language prompt and execute the appropriate Django model query.

    Args:
        prompt: Natural language prompt from user (Arabic script selects
            Arabic error messages)
        filtered_apps: App labels whose models are considered. Defaults to
            ``settings.MODEL_ANALYZER_APPS`` if defined, else ["inventory"].

    Returns:
        Dictionary with query results (see query_django_model), or an error
        dict with ``status``, ``error`` and ``language`` keys. The
        model-not-found error also carries a ``message`` key.
    """
    # Detect language
    language = "ar" if re.search(r"[\u0600-\u06FF]", prompt or "") else "en"

    try:
        if filtered_apps is None:
            filtered_apps = getattr(settings, "MODEL_ANALYZER_APPS", ["inventory"])

        analyzer = DjangoModelAnalyzer()
        model_structure = get_all_model_structures(filtered_apps=filtered_apps)
        logger.debug("Model structure: %s", model_structure)
        analysis = analyzer.analyze_prompt(prompt, model_structure)
        logger.debug("Model analysis: %s", analysis)

        if not analysis or not analysis.app_label or not analysis.model_name:
            message = (
                "تعذر العثور على النموذج المطلوب"
                if language == "ar"
                else "Missing model information"
            )
            # "message" is kept for existing callers; "error" matches the
            # key used by every other error response.
            return {
                "status": "error",
                "message": message,
                "error": message,
                "language": language,
            }

        query_params = {
            "app_label": analysis.app_label,
            "model_name": analysis.model_name,
            "fields": [field.name for field in analysis.relevant_fields],
            # The relation's field type (e.g. "ForeignKey") is never "LEFT", so
            # apply_joins uses prefetch_related for these paths.
            "joins": [
                {"path": rel["field"], "type": rel["type"]}
                for rel in analysis.relationships
            ],
            "filters": {},
            "aggregation": determine_aggregation_type(prompt, analysis.relevant_fields),
            "chart": determine_chart_type(prompt, analysis.relevant_fields),
            "language": language,
        }

        return query_django_model(query_params)
    except Exception as e:
        logger.error(f"Error analyzing prompt: {e}")
        return {
            "status": "error",
            "error": "حدث خطأ أثناء تحليل الاستعلام"
            if language == "ar"
            else f"Error analyzing prompt: {str(e)}",
            "language": language,
        }
|