"""Tests for ``haikalbot.views.ModelAnalystView``.

Covers request validation in ``post`` (missing prompt, malformed JSON),
the cached / uncached processing paths (with the view's private helpers
mocked out), the superuser permission check, prompt analysis and model-name
normalization.
"""

import os

import django

# Django has to be configured before any module that touches models or the
# app registry is imported, so the bootstrap deliberately precedes the
# remaining imports (hence the E402 suppressions below).
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "car_inventory.settings")
django.setup()

from django.test import TestCase, RequestFactory  # noqa: E402
from django.contrib.auth.models import User  # noqa: E402
from django.http import JsonResponse  # noqa: E402
import json  # noqa: E402
from unittest.mock import patch, MagicMock  # noqa: E402
from haikalbot.views import ModelAnalystView  # noqa: E402
from haikalbot.models import AnalysisCache  # noqa: E402


class ModelAnalystViewTest(TestCase):
    """Exercise ``ModelAnalystView`` directly, without URL routing."""

    def setUp(self):
        """Create a regular user, a superuser and a fresh view instance."""
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="testpass"
        )
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@example.com", password="adminpass"
        )
        self.view = ModelAnalystView()

    def _post(self, raw_body):
        """POST ``raw_body`` as JSON to the view as ``self.user``.

        Args:
            raw_body: The request body, already serialized to a string. It is
                sent verbatim so callers can also submit malformed JSON.

        Returns:
            The response object returned by ``ModelAnalystView.post``.
        """
        request = self.factory.post(
            "/analyze/", data=raw_body, content_type="application/json"
        )
        # RequestFactory bypasses middleware, so the user is attached by hand.
        request.user = self.user
        return self.view.post(request)

    def test_post_without_prompt(self):
        """Test that the view returns an error when no prompt is provided."""
        response = self._post(json.dumps({}))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "error")
        self.assertEqual(content["message"], "Prompt is required")

    def test_post_with_invalid_json(self):
        """Test that the view handles invalid JSON properly."""
        response = self._post("invalid json")

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "error")
        self.assertEqual(content["message"], "Invalid JSON in request body")

    # The helpers are patched on the imported class object itself rather than
    # through a dotted module path, so the mocks always apply to the very
    # class that ``self.view`` is an instance of. Decorators are applied
    # bottom-up, which fixes the order of the mock parameters.
    @patch.object(ModelAnalystView, "_process_prompt")
    @patch.object(ModelAnalystView, "_check_permissions")
    @patch.object(ModelAnalystView, "_generate_hash")
    @patch.object(ModelAnalystView, "_get_cached_result")
    @patch.object(ModelAnalystView, "_cache_result")
    def test_post_with_valid_prompt(
        self,
        mock_cache_result,
        mock_get_cached,
        mock_generate_hash,
        mock_check_permissions,
        mock_process_prompt,
    ):
        """Test that the view processes a valid prompt correctly."""
        # Setup mocks: permitted, cache miss, processing succeeds.
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = "test_hash"
        mock_get_cached.return_value = None
        mock_process_prompt.return_value = {
            "status": "success",
            "insights": [{"type": "test_insight"}],
        }

        # Call view
        response = self._post(
            json.dumps({"prompt": "How many cars do we have?", "dealer_id": 1})
        )

        # Assertions
        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "success")
        self.assertEqual(len(content["insights"]), 1)

        # Verify function calls
        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with("How many cars do we have?", 1)
        mock_get_cached.assert_called_once_with("test_hash", self.user, 1)
        mock_process_prompt.assert_called_once_with(
            "How many cars do we have?", self.user, 1
        )
        mock_cache_result.assert_called_once()

    @patch.object(ModelAnalystView, "_get_cached_result")
    @patch.object(ModelAnalystView, "_check_permissions")
    @patch.object(ModelAnalystView, "_generate_hash")
    def test_post_with_cached_result(
        self, mock_generate_hash, mock_check_permissions, mock_get_cached
    ):
        """Test that the view returns cached results when available."""
        # Setup mocks: permitted, cache hit.
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = "test_hash"
        mock_get_cached.return_value = {
            "status": "success",
            "insights": [{"type": "cached_insight"}],
            "cached": True,
        }

        # Call view
        response = self._post(
            json.dumps({"prompt": "How many cars do we have?", "dealer_id": 1})
        )

        # Assertions
        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "success")
        self.assertEqual(content["cached"], True)

        # Verify function calls
        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with("How many cars do we have?", 1)
        mock_get_cached.assert_called_once_with("test_hash", self.user, 1)

    def test_check_permissions_superuser(self):
        """Test that superusers have permission to access any dealer data."""
        self.assertTrue(self.view._check_permissions(self.superuser, 1))
        self.assertTrue(self.view._check_permissions(self.superuser, None))

    def test_analyze_prompt_count(self):
        """Test that the prompt analyzer correctly identifies count queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "How many cars do we have?"
        )
        self.assertEqual(analysis_type, "count")
        self.assertEqual(target_models, ["Car"])
        self.assertEqual(query_params, {})

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Count the number of users with active status"
        )
        self.assertEqual(analysis_type, "count")
        self.assertEqual(target_models, ["User"])
        # Either key is accepted as evidence that the filter was extracted.
        self.assertTrue("active" in query_params or "status" in query_params)

    def test_analyze_prompt_relationship(self):
        """Test that the prompt analyzer correctly identifies relationship queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show relationship between User and Profile"
        )
        self.assertEqual(analysis_type, "relationship")
        self.assertIn("User", target_models)
        self.assertIn("Profile", target_models)

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the User to Order relationship?"
        )
        self.assertEqual(analysis_type, "relationship")
        self.assertIn("User", target_models)
        self.assertIn("Order", target_models)

    def test_analyze_prompt_statistics(self):
        """Test that the prompt analyzer correctly identifies statistics queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the average price of cars?"
        )
        self.assertEqual(analysis_type, "statistics")
        self.assertEqual(target_models, ["Car"])
        self.assertEqual(query_params["field"], "price")
        self.assertEqual(query_params["operation"], "average")

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show maximum age of users"
        )
        self.assertEqual(analysis_type, "statistics")
        self.assertEqual(target_models, ["User"])
        self.assertEqual(query_params["field"], "age")
        self.assertEqual(query_params["operation"], "maximum")

    def test_normalize_model_name(self):
        """Test that model names are correctly normalized."""
        self.assertEqual(self.view._normalize_model_name("users"), "User")
        self.assertEqual(self.view._normalize_model_name("car"), "Car")
        self.assertEqual(
            self.view._normalize_model_name("orderItems"), "OrderItem"
        )  # This would actually need more logic to handle camelCase