"""
Management command to classify departments using AI.

Classifies each department's category (nursing, medical, non_medical,
support_services) based on the department name, then sets
staff.department_type to match.
"""

import json
import logging

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from django.db.models import Q

from apps.core.ai_service import AIService, AIServiceError
from apps.organizations.models import Department, Hospital

logger = logging.getLogger(__name__)

# The only values accepted from the AI and written to Department.category /
# staff.department_type.
VALID_CATEGORIES = ["nursing", "medical", "non_medical", "support_services"]

SYSTEM_PROMPT = """You are a healthcare organization classifier. Given a list of department names, classify each one into exactly one of these categories:

- "nursing": Nursing departments (e.g., nursing services, patient care nursing, ICU nursing, inpatient nursing)
- "medical": Clinical/medical departments with physicians (e.g., cardiology, surgery, emergency, radiology, pediatrics, internal medicine, oncology, dermatology)
- "non_medical": Administrative and non-clinical departments without direct patient care (e.g., HR, finance, IT, medical records, billing, reception, quality management, training)
- "support_services": Support and ancillary services (e.g., housekeeping, maintenance, food services, security, logistics, sterilization, pharmacy supply, patient transport)

IMPORTANT: Return ONLY a valid JSON object where keys are the department NUMBERS (as strings like "1", "2", etc.) and values are one of: "nursing", "medical", "non_medical", "support_services".
Every department must be classified.
Do not include any explanation, markdown formatting, or code blocks. Return ONLY the JSON object.

Example: {"1": "medical", "2": "nursing", "3": "non_medical", "4": "support_services"}"""


class Command(BaseCommand):
    """Classify departments via the AI service and propagate the result to active staff."""

    help = "Classify departments into categories using AI and set staff.department_type accordingly"

    def add_arguments(self, parser):
        parser.add_argument(
            "--hospital-code",
            type=str,
            help="Target hospital code (default: all hospitals)",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Preview classifications without making changes",
        )
        parser.add_argument(
            "--overwrite",
            action="store_true",
            help="Reclassify departments that already have a category",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=50,
            help="Number of departments per AI call (default: 50)",
        )

    def handle(self, *args, **options):
        """Entry point: select departments, classify them in batches, report stats.

        Raises:
            CommandError: if ``--batch-size`` is not a positive integer.
        """
        hospital_code = options["hospital_code"]
        dry_run = options["dry_run"]
        overwrite = options["overwrite"]
        batch_size = options["batch_size"]

        # range() raises on a zero step and silently yields nothing for a
        # negative one, so reject non-positive batch sizes up front.
        if batch_size < 1:
            raise CommandError("--batch-size must be a positive integer")

        self.stdout.write(f"\n{'=' * 60}")
        self.stdout.write("AI Department Classification Command")
        self.stdout.write(f"{'=' * 60}\n")

        hospitals = self._get_hospitals(hospital_code)
        if hospitals is None:
            return

        # Materialise once: the list drives batching and supplies the count
        # without further queries.
        dept_list = list(self._get_departments(hospitals, overwrite))
        if not dept_list:
            self.stdout.write(self.style.WARNING("No departments to classify."))
            return

        self.stdout.write(f" Hospitals: {hospitals.count()}")
        self.stdout.write(f" Departments to classify: {len(dept_list)}")
        self.stdout.write(f" Batch size: {batch_size}")
        self.stdout.write(f" Overwrite existing: {overwrite}")
        self.stdout.write(f" Dry run: {dry_run}\n")

        stats = {"classified": 0, "updated_staff": 0, "skipped": 0, "errors": 0}
        total_batches = (len(dept_list) + batch_size - 1) // batch_size

        for batch_num, start in enumerate(range(0, len(dept_list), batch_size), start=1):
            batch = dept_list[start : start + batch_size]
            self.stdout.write(
                self.style.NOTICE(f"Batch {batch_num}/{total_batches} — {len(batch)} departments")
            )

            classifications = self._classify_batch(batch)
            if not classifications:
                # The whole batch failed (AI error, bad JSON, or empty reply).
                stats["errors"] += len(batch)
                continue

            # Prompt lines were numbered from 1, matching the AI's keys.
            for idx, dept in enumerate(batch, start=1):
                raw = classifications.get(str(idx))
                category = self._normalize_category(raw)
                if category is None:
                    self.stdout.write(
                        self.style.WARNING(
                            f" ⊘ Skipped '{dept.name}': "
                            f"{'invalid category' if raw else 'not in AI response'}"
                            f" ({raw})"
                        )
                    )
                    stats["skipped"] += 1
                    continue

                self._apply_classification(dept, category, dry_run, stats)

        self._print_summary(stats, dry_run)

    def _get_hospitals(self, hospital_code):
        """Return the hospitals to process, or None (after printing an error) if none match.

        With a code, filters on ``code`` regardless of status; without one,
        selects hospitals whose ``status`` is ``"active"``.
        """
        if hospital_code:
            hospitals = Hospital.objects.filter(code=hospital_code)
            if not hospitals.exists():
                self.stdout.write(
                    self.style.ERROR(f"Hospital with code '{hospital_code}' not found")
                )
                return None
        else:
            hospitals = Hospital.objects.filter(status="active")
            if not hospitals.exists():
                self.stdout.write(self.style.ERROR("No active hospitals found."))
                return None

        return hospitals

    def _get_departments(self, hospitals, overwrite):
        """Return departments of *hospitals*, ordered by hospital name then department name.

        Unless *overwrite* is set, only departments without a category are
        returned.
        """
        qs = Department.objects.filter(hospital__in=hospitals).select_related("hospital")
        if not overwrite:
            # Treat both blank and NULL as "unclassified"; the NULL test simply
            # matches nothing if the column is NOT NULL.
            qs = qs.filter(Q(category="") | Q(category__isnull=True))
        return qs.order_by("hospital__name", "name")

    @staticmethod
    def _normalize_category(value):
        """Return *value* as one of VALID_CATEGORIES, or None if it is not one.

        Tolerates surrounding whitespace and letter case in the model's output;
        non-string values are rejected.
        """
        if not isinstance(value, str):
            return None
        category = value.strip().lower()
        return category if category in VALID_CATEGORIES else None

    @staticmethod
    def _strip_code_fence(text):
        """Strip whitespace and a surrounding Markdown code fence (```/```json) from *text*.

        A falsy *text* becomes ``""`` so that json.loads reports it as a
        JSONDecodeError rather than crashing on ``None.strip()``.
        """
        cleaned = (text or "").strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line, e.g. "```json".
            cleaned = cleaned.split("\n", 1)[-1]
        if cleaned.endswith("```"):
            cleaned = cleaned.rsplit("```", 1)[0]
        return cleaned.strip()

    def _classify_batch(self, batch):
        """Ask the AI service to classify *batch* in a single call.

        Departments are listed in the prompt numbered from 1 (English name,
        falling back to ``name``, plus the Arabic name when present).

        Returns:
            The parsed JSON object mapping those numbers (as strings) to the
            model's category, or None if the AI call fails, the reply is not
            valid JSON, or the reply is not a JSON object.
        """
        dept_lines = []
        for idx, dept in enumerate(batch, start=1):
            name = dept.name_en or dept.name or ""
            if dept.name_ar:
                name += f" / {dept.name_ar}"
            dept_lines.append(f"{idx}. {name}")

        prompt = "Classify these departments:\n" + "\n".join(dept_lines)

        try:
            response = AIService.chat_completion(
                prompt=prompt,
                system_prompt=SYSTEM_PROMPT,
                temperature=0.1,
                max_tokens=2000,
                response_format="json_object",
            )
            parsed = json.loads(self._strip_code_fence(response))
        except (AIServiceError, json.JSONDecodeError) as e:
            self.stdout.write(self.style.ERROR(f" AI classification failed: {e}"))
            return None

        # Valid JSON is not necessarily an object; callers rely on dict.get().
        if not isinstance(parsed, dict):
            self.stdout.write(
                self.style.ERROR(
                    f" AI classification failed: expected a JSON object, "
                    f"got {type(parsed).__name__}"
                )
            )
            return None

        return parsed

    def _apply_classification(self, dept, category, dry_run, stats):
        """Set *dept*'s category and copy it to its active staff's department_type.

        In dry-run mode only reports what would change. Mutates *stats* in
        place. Each department is saved in its own transaction, so a failure
        is counted as an error and does not stop the remaining departments.
        """
        if dry_run:
            staff_count = dept.staff.filter(status="active").count()
            self.stdout.write(
                f" Would classify '{dept.name}' → {category}"
                f" ({staff_count} staff → department_type={category})"
            )
            stats["classified"] += 1
            stats["updated_staff"] += staff_count
            return

        try:
            with transaction.atomic():
                dept.category = category
                dept.save(update_fields=["category", "updated_at"])
                updated = dept.staff.filter(status="active").update(department_type=category)
        except Exception as e:  # per-department boundary: record, keep going
            logger.exception("Failed to update department %s", dept.pk)
            self.stdout.write(
                self.style.ERROR(f" ✗ Failed to update '{dept.name}': {e}")
            )
            stats["errors"] += 1
        else:
            self.stdout.write(
                self.style.SUCCESS(
                    f" ✓ '{dept.name}' → {category}"
                    f" | {updated} staff → department_type={category}"
                )
            )
            stats["classified"] += 1
            stats["updated_staff"] += updated

    def _print_summary(self, stats, dry_run):
        """Print the final counters and a dry-run / completion banner."""
        self.stdout.write(f"\n{'=' * 60}")
        self.stdout.write("Summary:")
        self.stdout.write(f" Departments classified: {stats['classified']}")
        self.stdout.write(f" Staff updated: {stats['updated_staff']}")
        self.stdout.write(f" Departments skipped: {stats['skipped']}")
        self.stdout.write(f" Errors: {stats['errors']}")
        self.stdout.write(f"{'=' * 60}\n")

        if dry_run:
            self.stdout.write(self.style.WARNING("DRY RUN: No changes were made\n"))
        else:
            self.stdout.write(self.style.SUCCESS("Classification completed!\n"))