# Unit tests for haikalbot.views.ModelAnalystView.
import json
import os
from unittest.mock import MagicMock, patch

import django

# Point Django at the project settings and initialise the app registry at
# import time (presumably so this module also works when run without
# ``manage.py test``). ``setdefault`` keeps any settings module already chosen
# by the environment or test runner.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "car_inventory.settings")
django.setup()

# Django/project imports deliberately follow django.setup(): importing model
# classes (e.g. ``User``) requires a ready app registry, so they cannot be
# hoisted to the top of the file (hence the E402 suppressions).
from django.test import TestCase, RequestFactory  # noqa: E402
from django.contrib.auth.models import User  # noqa: E402
from django.http import JsonResponse  # noqa: E402

from haikalbot.views import ModelAnalystView  # noqa: E402
from haikalbot.models import AnalysisCache  # noqa: E402
class ModelAnalystViewTest(TestCase):
    """Tests for ``ModelAnalystView`` as imported from ``haikalbot.views``.

    The view's ``post`` handler is called directly with requests built by
    ``RequestFactory``, so URL routing and middleware are not exercised.

    Collaborator methods are patched with ``patch.object`` on the imported
    class itself. This guarantees the mocks replace the methods the view under
    test actually calls; string targets (previously
    ``'ai_analyst.views.ModelAnalystView...'``) silently drift from the real
    import path and end up patching the wrong module or failing to import.
    """

    # Shared payload for the happy-path and cache tests.
    PROMPT = 'How many cars do we have?'
    DEALER_ID = 1

    def setUp(self):
        """Create a request factory, a regular user, a superuser and the view."""
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username='testuser', email='test@example.com', password='testpass'
        )
        self.superuser = User.objects.create_superuser(
            username='admin', email='admin@example.com', password='adminpass'
        )
        self.view = ModelAnalystView()

    def _json_post(self, body):
        """Return a JSON POST to ``/analyze/`` attributed to ``self.user``.

        A ``str`` body is sent verbatim (so malformed JSON can be tested);
        any other value is serialized with ``json.dumps`` first.
        """
        payload = body if isinstance(body, str) else json.dumps(body)
        request = self.factory.post(
            '/analyze/', data=payload, content_type='application/json'
        )
        # RequestFactory bypasses auth middleware, so attach the user by hand.
        request.user = self.user
        return request

    def test_post_without_prompt(self):
        """Test that the view returns an error when no prompt is provided."""
        response = self.view.post(self._json_post({}))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'error')
        self.assertEqual(content['message'], 'Prompt is required')

    def test_post_with_invalid_json(self):
        """Test that the view handles invalid JSON properly."""
        response = self.view.post(self._json_post('invalid json'))

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'error')
        self.assertEqual(content['message'], 'Invalid JSON in request body')

    # Patch decorators apply bottom-up: the bottom-most patch supplies the
    # first mock argument after ``self``.
    @patch.object(ModelAnalystView, '_process_prompt')
    @patch.object(ModelAnalystView, '_check_permissions')
    @patch.object(ModelAnalystView, '_generate_hash')
    @patch.object(ModelAnalystView, '_get_cached_result')
    @patch.object(ModelAnalystView, '_cache_result')
    def test_post_with_valid_prompt(self, mock_cache_result, mock_get_cached,
                                    mock_generate_hash, mock_check_permissions,
                                    mock_process_prompt):
        """Test that the view processes a valid prompt correctly on a cache miss."""
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = 'test_hash'
        mock_get_cached.return_value = None  # force the cache-miss path
        mock_process_prompt.return_value = {
            'status': 'success',
            'insights': [{'type': 'test_insight'}],
        }

        request = self._json_post({'prompt': self.PROMPT, 'dealer_id': self.DEALER_ID})
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'success')
        self.assertEqual(len(content['insights']), 1)

        # These patches are plain MagicMocks set on the class (not autospecced);
        # MagicMock is not a descriptor, so the calls carry no ``self`` argument.
        mock_check_permissions.assert_called_once_with(self.user, self.DEALER_ID)
        mock_generate_hash.assert_called_once_with(self.PROMPT, self.DEALER_ID)
        mock_get_cached.assert_called_once_with('test_hash', self.user, self.DEALER_ID)
        mock_process_prompt.assert_called_once_with(self.PROMPT, self.user, self.DEALER_ID)
        mock_cache_result.assert_called_once()

    @patch.object(ModelAnalystView, '_process_prompt')
    @patch.object(ModelAnalystView, '_get_cached_result')
    @patch.object(ModelAnalystView, '_check_permissions')
    @patch.object(ModelAnalystView, '_generate_hash')
    def test_post_with_cached_result(self, mock_generate_hash, mock_check_permissions,
                                     mock_get_cached, mock_process_prompt):
        """Test that the view returns cached results without reprocessing the prompt."""
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = 'test_hash'
        mock_get_cached.return_value = {
            'status': 'success',
            'insights': [{'type': 'cached_insight'}],
            'cached': True,
        }

        request = self._json_post({'prompt': self.PROMPT, 'dealer_id': self.DEALER_ID})
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'success')
        self.assertIs(content['cached'], True)

        mock_check_permissions.assert_called_once_with(self.user, self.DEALER_ID)
        mock_generate_hash.assert_called_once_with(self.PROMPT, self.DEALER_ID)
        mock_get_cached.assert_called_once_with('test_hash', self.user, self.DEALER_ID)
        # A cache hit must short-circuit the (potentially expensive) analysis.
        mock_process_prompt.assert_not_called()

    def test_check_permissions_superuser(self):
        """Test that superusers have permission to access any dealer data."""
        self.assertTrue(self.view._check_permissions(self.superuser, 1))
        self.assertTrue(self.view._check_permissions(self.superuser, None))

    def test_analyze_prompt_count(self):
        """Test that the prompt analyzer correctly identifies count queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "How many cars do we have?")
        self.assertEqual(analysis_type, 'count')
        self.assertEqual(target_models, ['Car'])
        self.assertEqual(query_params, {})

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Count the number of users with active status")
        self.assertEqual(analysis_type, 'count')
        self.assertEqual(target_models, ['User'])
        self.assertTrue(
            'active' in query_params or 'status' in query_params,
            msg='expected an "active" or "status" filter, got %r' % (query_params,),
        )

    def test_analyze_prompt_relationship(self):
        """Test that the prompt analyzer correctly identifies relationship queries."""
        analysis_type, target_models, _ = self.view._analyze_prompt(
            "Show relationship between User and Profile")
        self.assertEqual(analysis_type, 'relationship')
        self.assertIn('User', target_models)
        self.assertIn('Profile', target_models)

        analysis_type, target_models, _ = self.view._analyze_prompt(
            "What is the User to Order relationship?")
        self.assertEqual(analysis_type, 'relationship')
        self.assertIn('User', target_models)
        self.assertIn('Order', target_models)

    def test_analyze_prompt_statistics(self):
        """Test that the prompt analyzer correctly identifies statistics queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the average price of cars?")
        self.assertEqual(analysis_type, 'statistics')
        self.assertEqual(target_models, ['Car'])
        self.assertEqual(query_params['field'], 'price')
        self.assertEqual(query_params['operation'], 'average')

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show maximum age of users")
        self.assertEqual(analysis_type, 'statistics')
        self.assertEqual(target_models, ['User'])
        self.assertEqual(query_params['field'], 'age')
        self.assertEqual(query_params['operation'], 'maximum')

    def test_normalize_model_name(self):
        """Test that model names are correctly normalized."""
        self.assertEqual(self.view._normalize_model_name('users'), 'User')
        self.assertEqual(self.view._normalize_model_name('car'), 'Car')
        # NOTE: camelCase input needs extra handling in the normalizer beyond
        # singularize + capitalize for this case to pass.
        self.assertEqual(self.view._normalize_model_name('orderItems'), 'OrderItem')