"""Unit tests for ``haikalbot.views.ModelAnalystView``.

Django is configured at import time (settings module
``car_inventory.settings``), so Django/project imports must come after
``django.setup()``.
"""
import json
import os
from unittest.mock import patch, MagicMock

import django

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "car_inventory.settings")
django.setup()

# Imports below require a configured Django, hence their position after setup().
from django.test import TestCase, RequestFactory  # noqa: E402
from django.contrib.auth.models import User  # noqa: E402
from django.http import JsonResponse  # noqa: E402
from haikalbot.views import ModelAnalystView  # noqa: E402
from haikalbot.models import AnalysisCache  # noqa: E402


class ModelAnalystViewTest(TestCase):
    """Tests for ModelAnalystView.post and its private helper methods."""

    def setUp(self):
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username='testuser', email='test@example.com', password='testpass'
        )
        self.superuser = User.objects.create_superuser(
            username='admin', email='admin@example.com', password='adminpass'
        )
        # The view is instantiated directly and its handlers are called
        # without going through as_view()/dispatch().
        self.view = ModelAnalystView()

    def _post(self, data):
        """Return a JSON POST request to '/analyze/' carrying raw *data*, sent by ``self.user``."""
        request = self.factory.post(
            '/analyze/', data=data, content_type='application/json'
        )
        request.user = self.user
        return request

    def test_post_without_prompt(self):
        """Test that the view returns an error when no prompt is provided."""
        response = self.view.post(self._post(json.dumps({})))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'error')
        self.assertEqual(content['message'], 'Prompt is required')

    def test_post_with_invalid_json(self):
        """Test that the view handles invalid JSON properly."""
        response = self.view.post(self._post('invalid json'))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'error')
        self.assertEqual(content['message'], 'Invalid JSON in request body')

    # patch.object targets the class actually imported above; the previous
    # string targets ('ai_analyst.views...') pointed at a different module.
    # Decorators apply bottom-up, so the first mock argument is the bottom patch.
    @patch.object(ModelAnalystView, '_process_prompt')
    @patch.object(ModelAnalystView, '_check_permissions')
    @patch.object(ModelAnalystView, '_generate_hash')
    @patch.object(ModelAnalystView, '_get_cached_result')
    @patch.object(ModelAnalystView, '_cache_result')
    def test_post_with_valid_prompt(self, mock_cache_result, mock_get_cached,
                                    mock_generate_hash, mock_check_permissions,
                                    mock_process_prompt):
        """Test that the view processes a valid prompt correctly."""
        prompt = 'How many cars do we have?'
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = 'test_hash'
        mock_get_cached.return_value = None  # cache miss
        mock_process_prompt.return_value = {
            'status': 'success',
            'insights': [{'type': 'test_insight'}],
        }

        request = self._post(json.dumps({'prompt': prompt, 'dealer_id': 1}))
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'success')
        self.assertEqual(len(content['insights']), 1)

        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with(prompt, 1)
        mock_get_cached.assert_called_once_with('test_hash', self.user, 1)
        mock_process_prompt.assert_called_once_with(prompt, self.user, 1)
        mock_cache_result.assert_called_once()

    @patch.object(ModelAnalystView, '_process_prompt')
    @patch.object(ModelAnalystView, '_get_cached_result')
    @patch.object(ModelAnalystView, '_check_permissions')
    @patch.object(ModelAnalystView, '_generate_hash')
    def test_post_with_cached_result(self, mock_generate_hash,
                                     mock_check_permissions, mock_get_cached,
                                     mock_process_prompt):
        """Test that the view returns cached results when available."""
        prompt = 'How many cars do we have?'
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = 'test_hash'
        mock_get_cached.return_value = {
            'status': 'success',
            'insights': [{'type': 'cached_insight'}],
            'cached': True,
        }

        request = self._post(json.dumps({'prompt': prompt, 'dealer_id': 1}))
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'success')
        self.assertIs(content['cached'], True)

        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with(prompt, 1)
        mock_get_cached.assert_called_once_with('test_hash', self.user, 1)
        # A cache hit should short-circuit the (expensive) prompt processing.
        mock_process_prompt.assert_not_called()

    def test_check_permissions_superuser(self):
        """Test that superusers have permission to access any dealer data."""
        self.assertTrue(self.view._check_permissions(self.superuser, 1))
        self.assertTrue(self.view._check_permissions(self.superuser, None))

    def test_analyze_prompt_count(self):
        """Test that the prompt analyzer correctly identifies count queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "How many cars do we have?")
        self.assertEqual(analysis_type, 'count')
        self.assertEqual(target_models, ['Car'])
        self.assertEqual(query_params, {})

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Count the number of users with active status")
        self.assertEqual(analysis_type, 'count')
        self.assertEqual(target_models, ['User'])
        self.assertTrue(
            'active' in query_params or 'status' in query_params,
            f"expected an 'active' or 'status' filter, got {query_params!r}",
        )

    def test_analyze_prompt_relationship(self):
        """Test that the prompt analyzer correctly identifies relationship queries."""
        analysis_type, target_models, _ = self.view._analyze_prompt(
            "Show relationship between User and Profile")
        self.assertEqual(analysis_type, 'relationship')
        self.assertIn('User', target_models)
        self.assertIn('Profile', target_models)

        analysis_type, target_models, _ = self.view._analyze_prompt(
            "What is the User to Order relationship?")
        self.assertEqual(analysis_type, 'relationship')
        self.assertIn('User', target_models)
        self.assertIn('Order', target_models)

    def test_analyze_prompt_statistics(self):
        """Test that the prompt analyzer correctly identifies statistics queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the average price of cars?")
        self.assertEqual(analysis_type, 'statistics')
        self.assertEqual(target_models, ['Car'])
        self.assertEqual(query_params['field'], 'price')
        self.assertEqual(query_params['operation'], 'average')

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show maximum age of users")
        self.assertEqual(analysis_type, 'statistics')
        self.assertEqual(target_models, ['User'])
        self.assertEqual(query_params['field'], 'age')
        self.assertEqual(query_params['operation'], 'maximum')

    def test_normalize_model_name(self):
        """Test that model names are correctly normalized."""
        cases = [
            ('users', 'User'),
            ('car', 'Car'),
            # NOTE(review): camelCase plural input; the original author noted
            # _normalize_model_name may need extra logic to handle this.
            ('orderItems', 'OrderItem'),
        ]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(self.view._normalize_model_name(raw), expected)