# haikal/test_ollama.py
# 2025-06-22 13:25:54 +03:00
#
# 207 lines
# 8.1 KiB
# Python

import json
import os
from unittest.mock import patch, MagicMock

import django

# Django must be configured before anything that touches the app registry is
# imported: ``django.contrib.auth.models`` and the ``haikalbot`` modules below
# need installed apps to be loaded, so ``django.setup()`` runs first.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "car_inventory.settings")
django.setup()

from django.test import TestCase, RequestFactory  # noqa: E402
from django.contrib.auth.models import User  # noqa: E402
from django.http import JsonResponse  # noqa: E402

from haikalbot.views import ModelAnalystView  # noqa: E402
from haikalbot.models import AnalysisCache  # noqa: E402
class ModelAnalystViewTest(TestCase):
    """Tests for ``ModelAnalystView`` (imported from ``haikalbot.views``).

    The view is instantiated directly and its methods are called without URL
    routing.  ``post`` is exercised for request validation (missing prompt,
    malformed JSON) and for the cache-miss / cache-hit paths with the view's
    private helpers mocked out; the prompt-analysis, permission and
    model-name helpers are called directly.
    """

    # Prompt shared by the ``post`` tests and the count-analysis test.
    CAR_COUNT_PROMPT = "How many cars do we have?"

    def setUp(self):
        """Create a request factory, a regular user, a superuser and the view."""
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="testpass"
        )
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@example.com", password="adminpass"
        )
        self.view = ModelAnalystView()

    def _json_post(self, body):
        """Return a POST request to ``/analyze/`` whose raw body is ``body``.

        ``body`` is sent verbatim with ``Content-Type: application/json`` so
        malformed payloads can be tested too.  RequestFactory does not run
        middleware, so ``request.user`` is attached manually (the regular,
        non-superuser test user).
        """
        request = self.factory.post(
            "/analyze/", data=body, content_type="application/json"
        )
        request.user = self.user
        return request

    def test_post_without_prompt(self):
        """Test that the view returns an error when no prompt is provided."""
        response = self.view.post(self._json_post(json.dumps({})))
        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "error")
        self.assertEqual(content["message"], "Prompt is required")

    def test_post_with_invalid_json(self):
        """Test that the view handles invalid JSON properly."""
        response = self.view.post(self._json_post("invalid json"))
        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "error")
        self.assertEqual(content["message"], "Invalid JSON in request body")

    # Patch through the imported class object rather than a dotted-path string
    # so the patch target cannot drift from the module actually under test.
    # Decorators inject mocks bottom-up: the one nearest the ``def`` supplies
    # the first mock argument.
    @patch.object(ModelAnalystView, "_process_prompt")
    @patch.object(ModelAnalystView, "_check_permissions")
    @patch.object(ModelAnalystView, "_generate_hash")
    @patch.object(ModelAnalystView, "_get_cached_result")
    @patch.object(ModelAnalystView, "_cache_result")
    def test_post_with_valid_prompt(
        self,
        mock_cache_result,
        mock_get_cached,
        mock_generate_hash,
        mock_check_permissions,
        mock_process_prompt,
    ):
        """Test that the view processes a valid prompt correctly."""
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = "test_hash"
        mock_get_cached.return_value = None  # force a cache miss
        mock_process_prompt.return_value = {
            "status": "success",
            "insights": [{"type": "test_insight"}],
        }

        request = self._json_post(
            json.dumps({"prompt": self.CAR_COUNT_PROMPT, "dealer_id": 1})
        )
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "success")
        self.assertEqual(len(content["insights"]), 1)

        # The patched attributes are plain MagicMocks (not bound methods), so
        # the recorded calls do not include ``self``.
        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with(self.CAR_COUNT_PROMPT, 1)
        mock_get_cached.assert_called_once_with("test_hash", self.user, 1)
        mock_process_prompt.assert_called_once_with(
            self.CAR_COUNT_PROMPT, self.user, 1
        )
        mock_cache_result.assert_called_once()

    @patch.object(ModelAnalystView, "_get_cached_result")
    @patch.object(ModelAnalystView, "_check_permissions")
    @patch.object(ModelAnalystView, "_generate_hash")
    def test_post_with_cached_result(
        self, mock_generate_hash, mock_check_permissions, mock_get_cached
    ):
        """Test that the view returns cached results when available."""
        mock_check_permissions.return_value = True
        mock_generate_hash.return_value = "test_hash"
        mock_get_cached.return_value = {
            "status": "success",
            "insights": [{"type": "cached_insight"}],
            "cached": True,
        }

        request = self._json_post(
            json.dumps({"prompt": self.CAR_COUNT_PROMPT, "dealer_id": 1})
        )
        response = self.view.post(request)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(content["status"], "success")
        self.assertIs(content["cached"], True)

        mock_check_permissions.assert_called_once_with(self.user, 1)
        mock_generate_hash.assert_called_once_with(self.CAR_COUNT_PROMPT, 1)
        mock_get_cached.assert_called_once_with("test_hash", self.user, 1)

    def test_check_permissions_superuser(self):
        """Test that superusers have permission to access any dealer data."""
        self.assertTrue(self.view._check_permissions(self.superuser, 1))
        self.assertTrue(self.view._check_permissions(self.superuser, None))

    def test_analyze_prompt_count(self):
        """Test that the prompt analyzer correctly identifies count queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            self.CAR_COUNT_PROMPT
        )
        self.assertEqual(analysis_type, "count")
        self.assertEqual(target_models, ["Car"])
        self.assertEqual(query_params, {})

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Count the number of users with active status"
        )
        self.assertEqual(analysis_type, "count")
        self.assertEqual(target_models, ["User"])
        self.assertTrue(
            "active" in query_params or "status" in query_params,
            msg=f"expected an 'active' or 'status' filter, got {query_params!r}",
        )

    def test_analyze_prompt_relationship(self):
        """Test that the prompt analyzer correctly identifies relationship queries."""
        analysis_type, target_models, _ = self.view._analyze_prompt(
            "Show relationship between User and Profile"
        )
        self.assertEqual(analysis_type, "relationship")
        self.assertIn("User", target_models)
        self.assertIn("Profile", target_models)

        analysis_type, target_models, _ = self.view._analyze_prompt(
            "What is the User to Order relationship?"
        )
        self.assertEqual(analysis_type, "relationship")
        self.assertIn("User", target_models)
        self.assertIn("Order", target_models)

    def test_analyze_prompt_statistics(self):
        """Test that the prompt analyzer correctly identifies statistics queries."""
        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "What is the average price of cars?"
        )
        self.assertEqual(analysis_type, "statistics")
        self.assertEqual(target_models, ["Car"])
        self.assertEqual(query_params["field"], "price")
        self.assertEqual(query_params["operation"], "average")

        analysis_type, target_models, query_params = self.view._analyze_prompt(
            "Show maximum age of users"
        )
        self.assertEqual(analysis_type, "statistics")
        self.assertEqual(target_models, ["User"])
        self.assertEqual(query_params["field"], "age")
        self.assertEqual(query_params["operation"], "maximum")

    def test_normalize_model_name(self):
        """Test that model names are correctly normalized."""
        self.assertEqual(self.view._normalize_model_name("users"), "User")
        self.assertEqual(self.view._normalize_model_name("car"), "Car")
        # NOTE(review): camelCase input requires the normalizer to capitalize
        # the first letter and singularize only the trailing word; confirm
        # _normalize_model_name actually handles this case.
        self.assertEqual(
            self.view._normalize_model_name("orderItems"), "OrderItem"
        )