207 lines
8.1 KiB
Python
207 lines
8.1 KiB
Python
import os
|
|
import django
|
|
|
|
|
|
# Select the project settings module unless the environment already names
# one, then initialise Django's app registry. This must run before the
# django.contrib / haikalbot model and view imports further down the file.
if "DJANGO_SETTINGS_MODULE" not in os.environ:
    os.environ["DJANGO_SETTINGS_MODULE"] = "car_inventory.settings"
django.setup()
|
|
|
|
from django.test import TestCase, RequestFactory
|
|
from django.contrib.auth.models import User
|
|
from django.http import JsonResponse
|
|
import json
|
|
|
|
from unittest.mock import patch, MagicMock
|
|
from haikalbot.views import ModelAnalystView
|
|
from haikalbot.models import AnalysisCache
|
|
|
|
|
|
class ModelAnalystViewTest(TestCase):
    """Unit tests for ``ModelAnalystView`` (imported from ``haikalbot.views``).

    The view is instantiated directly, and ``post`` is called with requests
    built by ``RequestFactory``. No URL routing or middleware is exercised.

    Private helpers are patched with ``patch.object`` on the imported class.
    Patching by dotted string would break if the module path drifted from the
    import above (it previously pointed at ``ai_analyst.views``). Because the
    class attribute is replaced by a plain mock (not a function descriptor),
    calls reach the mock without ``self``. That is why the assertions below
    omit the view instance.
    """

    def setUp(self):
        """Create a regular user, a superuser, and a fresh view instance."""
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="testpass"
        )
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@example.com", password="adminpass"
        )
        self.view = ModelAnalystView()

    def _post(self, body):
        """Build a POST to ``/analyze/`` sent as ``application/json`` by ``self.user``.

        ``body`` is passed through verbatim, so callers supply either a
        ``json.dumps`` string or deliberately malformed text.
        """
        request = self.factory.post(
            "/analyze/", data=body, content_type="application/json"
        )
        request.user = self.user
        return request

    def test_post_without_prompt(self):
        """Test that the view returns an error when no prompt is provided."""
        response = self.view.post(self._post(json.dumps({})))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "error")
        self.assertEqual(content["message"], "Prompt is required")

    def test_post_with_invalid_json(self):
        """Test that the view handles invalid JSON properly."""
        response = self.view.post(self._post("invalid json"))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "error")
        self.assertEqual(content["message"], "Invalid JSON in request body")

    # Decorators apply bottom-up, so the bottom-most patch supplies the first
    # mock argument after ``self``.
    @patch.object(ModelAnalystView, "_process_prompt")
    @patch.object(ModelAnalystView, "_check_permissions")
    @patch.object(ModelAnalystView, "_generate_hash")
    @patch.object(ModelAnalystView, "_get_cached_result")
    @patch.object(ModelAnalystView, "_cache_result")
    def test_post_with_valid_prompt(
        self,
        mock_cache_result,
        mock_get_cached,
        mock_generate_hash,
        mock_check_permissions,
        mock_process_prompt,
    ):
        """Test that the view processes a valid prompt correctly."""
        # Setup mocks: permitted user, cache miss, successful processing.
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = "test_hash"
        mock_get_cached.return_value = None
        mock_process_prompt.return_value = {
            "status": "success",
            "insights": [{"type": "test_insight"}],
        }

        request = self._post(
            json.dumps({"prompt": "How many cars do we have?", "dealer_id": 1})
        )
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "success")
        self.assertEqual(len(content["insights"]), 1)

        # Verify the helper pipeline was driven with the request's values.
        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with("How many cars do we have?", 1)
        mock_get_cached.assert_called_once_with("test_hash", self.user, 1)
        mock_process_prompt.assert_called_once_with(
            "How many cars do we have?", self.user, 1
        )
        mock_cache_result.assert_called_once()

    @patch.object(ModelAnalystView, "_process_prompt")
    @patch.object(ModelAnalystView, "_get_cached_result")
    @patch.object(ModelAnalystView, "_check_permissions")
    @patch.object(ModelAnalystView, "_generate_hash")
    def test_post_with_cached_result(
        self,
        mock_generate_hash,
        mock_check_permissions,
        mock_get_cached,
        mock_process_prompt,
    ):
        """Test that the view returns cached results when available."""
        # Setup mocks: permitted user and a cache hit.
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = "test_hash"
        mock_get_cached.return_value = {
            "status": "success",
            "insights": [{"type": "cached_insight"}],
            "cached": True,
        }

        request = self._post(
            json.dumps({"prompt": "How many cars do we have?", "dealer_id": 1})
        )
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "success")
        self.assertIs(content["cached"], True)

        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with("How many cars do we have?", 1)
        mock_get_cached.assert_called_once_with("test_hash", self.user, 1)
        # A cache hit should short-circuit processing. Patching it also keeps
        # the real processing from running if that expectation ever breaks.
        mock_process_prompt.assert_not_called()

    def test_check_permissions_superuser(self):
        """Test that superusers have permission to access any dealer data."""
        self.assertTrue(self.view._check_permissions(self.superuser, 1))
        self.assertTrue(self.view._check_permissions(self.superuser, None))

    def test_analyze_prompt_count(self):
        """Test that the prompt analyzer correctly identifies count queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "How many cars do we have?"
        )
        self.assertEqual(analysis_type, "count")
        self.assertEqual(target_models, ["Car"])
        self.assertEqual(query_params, {})

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Count the number of users with active status"
        )
        self.assertEqual(analysis_type, "count")
        self.assertEqual(target_models, ["User"])
        self.assertTrue(
            "active" in query_params or "status" in query_params,
            msg=f"expected an 'active' or 'status' filter, got {query_params!r}",
        )

    def test_analyze_prompt_relationship(self):
        """Test that the prompt analyzer correctly identifies relationship queries."""
        analysis_type, target_models, _ = self.view._analyze_prompt(
            "Show relationship between User and Profile"
        )
        self.assertEqual(analysis_type, "relationship")
        self.assertIn("User", target_models)
        self.assertIn("Profile", target_models)

        analysis_type, target_models, _ = self.view._analyze_prompt(
            "What is the User to Order relationship?"
        )
        self.assertEqual(analysis_type, "relationship")
        self.assertIn("User", target_models)
        self.assertIn("Order", target_models)

    def test_analyze_prompt_statistics(self):
        """Test that the prompt analyzer correctly identifies statistics queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the average price of cars?"
        )
        self.assertEqual(analysis_type, "statistics")
        self.assertEqual(target_models, ["Car"])
        self.assertEqual(query_params["field"], "price")
        self.assertEqual(query_params["operation"], "average")

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show maximum age of users"
        )
        self.assertEqual(analysis_type, "statistics")
        self.assertEqual(target_models, ["User"])
        self.assertEqual(query_params["field"], "age")
        self.assertEqual(query_params["operation"], "maximum")

    def test_normalize_model_name(self):
        """Test that model names are correctly normalized."""
        self.assertEqual(self.view._normalize_model_name("users"), "User")
        self.assertEqual(self.view._normalize_model_name("car"), "Car")
        # NOTE: camelCase input needs the normalizer to preserve inner
        # capitals while singularizing; this may require extra logic.
        self.assertEqual(
            self.view._normalize_model_name("orderItems"), "OrderItem"
        )
|