haikal/test_ollama.py
Marwan Alwali 250e0aa7bb update
2025-05-26 15:17:10 +03:00

195 lines
8.0 KiB
Python

import os
import django
# Configure Django for standalone execution. setdefault() leaves an existing
# DJANGO_SETTINGS_MODULE untouched, so an externally chosen settings module
# still wins. django.setup() is called before the Django/project imports
# below; keep this ordering, since those imports (auth models, haikalbot
# models/views) are expected to need an initialised app registry.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "car_inventory.settings")
django.setup()
from django.test import TestCase, RequestFactory
from django.contrib.auth.models import User
from django.http import JsonResponse
import json
from unittest.mock import patch, MagicMock
# View under test and its cache model, both from the haikalbot app.
from haikalbot.views import ModelAnalystView
from haikalbot.models import AnalysisCache
class ModelAnalystViewTest(TestCase):
    """Tests for ``ModelAnalystView``: request handling, caching and prompt parsing.

    The view is instantiated directly and its ``post`` method is called with
    requests built by ``RequestFactory``, so no URL routing or middleware is
    involved. Collaborator methods are mocked with ``patch.object`` on the
    imported ``ModelAnalystView`` class so the mocks always apply to the
    same class that ``self.view`` is an instance of.
    """

    ENDPOINT = '/analyze/'
    PROMPT = 'How many cars do we have?'
    DEALER_ID = 1

    def setUp(self):
        """Create a request factory, a regular user, a superuser and the view."""
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username='testuser', email='test@example.com', password='testpass'
        )
        self.superuser = User.objects.create_superuser(
            username='admin', email='admin@example.com', password='adminpass'
        )
        self.view = ModelAnalystView()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _post(self, body):
        """POST the raw string ``body`` as JSON on behalf of ``self.user``.

        ``body`` is sent verbatim (it may deliberately be invalid JSON).
        Returns whatever ``self.view.post`` returns.
        """
        request = self.factory.post(
            self.ENDPOINT,
            data=body,
            content_type='application/json',
        )
        request.user = self.user
        return self.view.post(request)

    def _assert_error(self, response, message):
        """Assert ``response`` is a 400 JSON error carrying ``message``."""
        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'error')
        self.assertEqual(content['message'], message)

    # ------------------------------------------------------------------
    # post(): validation
    # ------------------------------------------------------------------
    def test_post_without_prompt(self):
        """Test that the view returns an error when no prompt is provided."""
        response = self._post(json.dumps({}))
        self._assert_error(response, 'Prompt is required')

    def test_post_with_invalid_json(self):
        """Test that the view handles invalid JSON properly."""
        response = self._post('invalid json')
        self._assert_error(response, 'Invalid JSON in request body')

    # ------------------------------------------------------------------
    # post(): processing and caching
    # ------------------------------------------------------------------
    # Decorators are applied bottom-up, so the mock arguments arrive in the
    # reverse of the order in which the decorators are listed.
    @patch.object(ModelAnalystView, '_process_prompt')
    @patch.object(ModelAnalystView, '_check_permissions')
    @patch.object(ModelAnalystView, '_generate_hash')
    @patch.object(ModelAnalystView, '_get_cached_result')
    @patch.object(ModelAnalystView, '_cache_result')
    def test_post_with_valid_prompt(self, mock_cache_result, mock_get_cached,
                                    mock_generate_hash, mock_check_permissions,
                                    mock_process_prompt):
        """Test that the view processes a valid prompt correctly."""
        # Setup mocks: permitted user, cache miss, successful processing.
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = 'test_hash'
        mock_get_cached.return_value = None
        mock_process_prompt.return_value = {
            'status': 'success',
            'insights': [{'type': 'test_insight'}],
        }

        response = self._post(
            json.dumps({'prompt': self.PROMPT, 'dealer_id': self.DEALER_ID})
        )

        # Response assertions
        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'success')
        self.assertEqual(len(content['insights']), 1)

        # Verify collaborator calls
        mock_check_permissions.assert_called_once_with(self.user, self.DEALER_ID)
        mock_generate_hash.assert_called_once_with(self.PROMPT, self.DEALER_ID)
        mock_get_cached.assert_called_once_with(
            'test_hash', self.user, self.DEALER_ID
        )
        mock_process_prompt.assert_called_once_with(
            self.PROMPT, self.user, self.DEALER_ID
        )
        mock_cache_result.assert_called_once()

    @patch.object(ModelAnalystView, '_get_cached_result')
    @patch.object(ModelAnalystView, '_check_permissions')
    @patch.object(ModelAnalystView, '_generate_hash')
    def test_post_with_cached_result(self, mock_generate_hash,
                                     mock_check_permissions, mock_get_cached):
        """Test that the view returns cached results when available."""
        # Setup mocks: permitted user and a cache hit.
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = 'test_hash'
        mock_get_cached.return_value = {
            'status': 'success',
            'insights': [{'type': 'cached_insight'}],
            'cached': True,
        }

        response = self._post(
            json.dumps({'prompt': self.PROMPT, 'dealer_id': self.DEALER_ID})
        )

        # Response assertions
        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content['status'], 'success')
        self.assertIs(content['cached'], True)

        # Verify collaborator calls
        mock_check_permissions.assert_called_once_with(self.user, self.DEALER_ID)
        mock_generate_hash.assert_called_once_with(self.PROMPT, self.DEALER_ID)
        mock_get_cached.assert_called_once_with(
            'test_hash', self.user, self.DEALER_ID
        )

    # ------------------------------------------------------------------
    # _check_permissions()
    # ------------------------------------------------------------------
    def test_check_permissions_superuser(self):
        """Test that superusers have permission to access any dealer data."""
        for dealer_id in (1, None):
            with self.subTest(dealer_id=dealer_id):
                self.assertTrue(
                    self.view._check_permissions(self.superuser, dealer_id)
                )

    # ------------------------------------------------------------------
    # _analyze_prompt()
    # ------------------------------------------------------------------
    def test_analyze_prompt_count(self):
        """Test that the prompt analyzer correctly identifies count queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "How many cars do we have?")
        self.assertEqual(analysis_type, 'count')
        self.assertEqual(target_models, ['Car'])
        self.assertEqual(query_params, {})

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Count the number of users with active status")
        self.assertEqual(analysis_type, 'count')
        self.assertEqual(target_models, ['User'])
        # Either key is accepted as evidence that the filter was recognised.
        self.assertTrue('active' in query_params or 'status' in query_params)

    def test_analyze_prompt_relationship(self):
        """Test that the prompt analyzer correctly identifies relationship queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show relationship between User and Profile")
        self.assertEqual(analysis_type, 'relationship')
        self.assertIn('User', target_models)
        self.assertIn('Profile', target_models)

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the User to Order relationship?")
        self.assertEqual(analysis_type, 'relationship')
        self.assertIn('User', target_models)
        self.assertIn('Order', target_models)

    def test_analyze_prompt_statistics(self):
        """Test that the prompt analyzer correctly identifies statistics queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the average price of cars?")
        self.assertEqual(analysis_type, 'statistics')
        self.assertEqual(target_models, ['Car'])
        self.assertEqual(query_params['field'], 'price')
        self.assertEqual(query_params['operation'], 'average')

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show maximum age of users")
        self.assertEqual(analysis_type, 'statistics')
        self.assertEqual(target_models, ['User'])
        self.assertEqual(query_params['field'], 'age')
        self.assertEqual(query_params['operation'], 'maximum')

    # ------------------------------------------------------------------
    # _normalize_model_name()
    # ------------------------------------------------------------------
    def test_normalize_model_name(self):
        """Test that model names are correctly normalized."""
        self.assertEqual(self.view._normalize_model_name('users'), 'User')
        self.assertEqual(self.view._normalize_model_name('car'), 'Car')
        # NOTE: this would actually need more logic to handle camelCase.
        self.assertEqual(
            self.view._normalize_model_name('orderItems'), 'OrderItem'
        )