223 lines
8.6 KiB
Python
223 lines
8.6 KiB
Python
"""
|
|
Management command to classify departments using AI.
|
|
|
|
Classifies each department's category (nursing, medical, non_medical, support_services)
|
|
based on the department name, then sets staff.department_type to match.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
|
|
from django.core.management.base import BaseCommand
|
|
from django.db import transaction
|
|
|
|
from apps.core.ai_service import AIService, AIServiceError
|
|
from apps.organizations.models import Department, Hospital
|
|
|
|
logger = logging.getLogger(__name__)

# The only category values accepted from the AI; any other value is skipped.
VALID_CATEGORIES = ["nursing", "medical", "non_medical", "support_services"]

# System message for the classifier. Departments are sent as a numbered list,
# and the model must reply with a JSON object keyed by those numbers.
SYSTEM_PROMPT = "\n".join(
    [
        "You are a healthcare organization classifier. Given a list of department names, classify each one into exactly one of these categories:",
        "",
        '- "nursing": Nursing departments (e.g., nursing services, patient care nursing, ICU nursing, inpatient nursing)',
        '- "medical": Clinical/medical departments with physicians (e.g., cardiology, surgery, emergency, radiology, pediatrics, internal medicine, oncology, dermatology)',
        '- "non_medical": Administrative and non-clinical departments without direct patient care (e.g., HR, finance, IT, medical records, billing, reception, quality management, training)',
        '- "support_services": Support and ancillary services (e.g., housekeeping, maintenance, food services, security, logistics, sterilization, pharmacy supply, patient transport)',
        "",
        'IMPORTANT: Return ONLY a valid JSON object where keys are the department NUMBERS (as strings like "1", "2", etc.) and values are one of: "nursing", "medical", "non_medical", "support_services".',
        "Every department must be classified. Do not include any explanation, markdown formatting, or code blocks. Return ONLY the JSON object.",
        'Example: {"1": "medical", "2": "nursing", "3": "non_medical", "4": "support_services"}',
    ]
)
class Command(BaseCommand):
    """Classify departments with the AI service and propagate the result to staff.

    For every selected department, the AI-assigned category is stored on
    ``Department.category`` and copied to ``department_type`` on that
    department's active staff members (unless ``--dry-run`` is given).
    """

    help = "Classify departments into categories using AI and set staff.department_type accordingly"

    def add_arguments(self, parser):
        """Register the command-line options."""
        parser.add_argument(
            "--hospital-code",
            type=str,
            help="Target hospital code (default: all active hospitals)",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Preview classifications without making changes",
        )
        parser.add_argument(
            "--overwrite",
            action="store_true",
            help="Reclassify departments that already have a category",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=50,
            help="Number of departments per AI call (default: 50)",
        )

    def handle(self, *args, **options):
        """Select departments, classify them in batches, apply results and report."""
        hospital_code = options["hospital_code"]
        dry_run = options["dry_run"]
        overwrite = options["overwrite"]
        batch_size = options["batch_size"]

        self.stdout.write(f"\n{'=' * 60}")
        self.stdout.write("AI Department Classification Command")
        self.stdout.write(f"{'=' * 60}\n")

        # range() raises on a zero step and silently yields nothing for a
        # negative one, so reject non-positive batch sizes up front.
        if batch_size < 1:
            self.stdout.write(self.style.ERROR("--batch-size must be a positive integer"))
            return

        hospitals = self._get_hospitals(hospital_code)
        if hospitals is None:
            return

        # Materialize once; the list is reused for the count and for batching.
        dept_list = list(self._get_departments(hospitals, overwrite))
        if not dept_list:
            self.stdout.write(self.style.WARNING("No departments to classify."))
            return

        self.stdout.write(f" Hospitals: {hospitals.count()}")
        self.stdout.write(f" Departments to classify: {len(dept_list)}")
        self.stdout.write(f" Batch size: {batch_size}")
        self.stdout.write(f" Overwrite existing: {overwrite}")
        self.stdout.write(f" Dry run: {dry_run}\n")

        stats = {"classified": 0, "updated_staff": 0, "skipped": 0, "errors": 0}
        total_batches = (len(dept_list) + batch_size - 1) // batch_size  # ceiling division

        for start in range(0, len(dept_list), batch_size):
            batch = dept_list[start : start + batch_size]
            batch_num = start // batch_size + 1

            self.stdout.write(
                self.style.NOTICE(f"Batch {batch_num}/{total_batches} — {len(batch)} departments")
            )

            classifications = self._classify_batch(batch)
            if not classifications:
                # The whole AI call failed; count every department in it as an error.
                stats["errors"] += len(batch)
                continue

            # Keys in the AI response are the 1-based positions used in the prompt.
            for idx, dept in enumerate(batch, start=1):
                raw = classifications.get(str(idx))
                # Accept harmless variations such as "Medical" or " nursing ".
                category = raw.strip().lower() if isinstance(raw, str) else raw

                if not category or category not in VALID_CATEGORIES:
                    self.stdout.write(
                        self.style.WARNING(
                            f" ⊘ Skipped '{dept.name}': "
                            f"{'invalid category' if raw else 'not in AI response'}"
                            f" ({raw})"
                        )
                    )
                    stats["skipped"] += 1
                    continue

                self._apply_classification(dept, category, dry_run, stats)

        self._print_summary(stats, dry_run)

    def _get_hospitals(self, hospital_code):
        """Return the hospital queryset to process, or None after printing an error.

        With ``hospital_code`` the hospital with that code is used. Otherwise
        all hospitals with ``status="active"`` are used.
        """
        if hospital_code:
            hospitals = Hospital.objects.filter(code=hospital_code)
            if not hospitals.exists():
                self.stdout.write(
                    self.style.ERROR(f"Hospital with code '{hospital_code}' not found")
                )
                return None
        else:
            hospitals = Hospital.objects.filter(status="active")
            if not hospitals.exists():
                self.stdout.write(self.style.ERROR("No active hospitals found."))
                return None

        return hospitals

    def _get_departments(self, hospitals, overwrite):
        """Return departments of ``hospitals``, ordered by hospital name then name.

        Unless ``overwrite`` is set, only departments whose ``category`` is the
        empty string are returned.
        """
        qs = Department.objects.filter(hospital__in=hospitals).select_related("hospital")
        if not overwrite:
            qs = qs.filter(category="")
        return qs.order_by("hospital__name", "name")

    def _classify_batch(self, batch):
        """Ask the AI service to classify the departments in ``batch``.

        Departments are numbered from 1 in the prompt. On success, returns the
        parsed JSON object mapping those numbers (as strings) to category
        names. Returns None, after printing an error, when the AI call fails or
        the response is not a JSON object.
        """
        dept_lines = []
        for idx, dept in enumerate(batch, start=1):
            name = dept.name_en or dept.name
            if dept.name_ar:
                name += f" / {dept.name_ar}"
            # Collapse newlines and runs of whitespace so each department stays
            # on its own numbered line in the prompt.
            dept_lines.append(f"{idx}. {' '.join(name.split())}")

        prompt = "Classify these departments:\n" + "\n".join(dept_lines)

        # Each response entry (e.g. '"123": "support_services", ') is roughly ten
        # tokens, so scale the budget for large --batch-size values. The floor
        # keeps the original 2000-token budget for the default batch size.
        max_tokens = max(2000, 20 * len(batch))

        try:
            response = AIService.chat_completion(
                prompt=prompt,
                system_prompt=SYSTEM_PROMPT,
                temperature=0.1,
                max_tokens=max_tokens,
                response_format="json_object",
            )

            # A missing response becomes "" so json.loads reports it as a
            # decode error instead of crashing on None.strip().
            cleaned = (response or "").strip()
            # Tolerate models that wrap the JSON in a Markdown code fence anyway.
            if cleaned.startswith("```"):
                cleaned = cleaned.split("\n", 1)[-1]
            if cleaned.endswith("```"):
                cleaned = cleaned.rsplit("```", 1)[0]

            parsed = json.loads(cleaned.strip())

        except (AIServiceError, json.JSONDecodeError) as e:
            self.stdout.write(self.style.ERROR(f" AI classification failed: {e}"))
            return None

        # Valid JSON is not necessarily an object; callers rely on dict.get().
        if not isinstance(parsed, dict):
            self.stdout.write(
                self.style.ERROR(
                    " AI classification failed: expected a JSON object, "
                    f"got {type(parsed).__name__}"
                )
            )
            return None

        return parsed

    def _apply_classification(self, dept, category, dry_run, stats):
        """Save ``category`` on ``dept`` and its active staff, updating ``stats``.

        In dry-run mode nothing is written; the department and the number of
        active staff that would be updated are only reported and counted.
        """
        if dry_run:
            staff_count = dept.staff.filter(status="active").count()
            self.stdout.write(
                f" Would classify '{dept.name}' → {category}"
                f" ({staff_count} staff → department_type={category})"
            )
            stats["classified"] += 1
            stats["updated_staff"] += staff_count
            return

        try:
            # The department and its staff are updated together or not at all.
            with transaction.atomic():
                dept.category = category
                dept.save(update_fields=["category", "updated_at"])
                updated = dept.staff.filter(status="active").update(department_type=category)

            self.stdout.write(
                self.style.SUCCESS(
                    f" ✓ '{dept.name}' → {category}"
                    f" | {updated} staff → department_type={category}"
                )
            )
            stats["classified"] += 1
            stats["updated_staff"] += updated

        except Exception as e:
            # Per-department boundary: report the failure and keep processing the
            # rest, but keep the traceback in the logs.
            logger.exception("Failed to apply classification to department %s", dept.pk)
            self.stdout.write(self.style.ERROR(f" ✗ Failed to update '{dept.name}': {e}"))
            stats["errors"] += 1

    def _print_summary(self, stats, dry_run):
        """Print the final counters and whether changes were persisted."""
        self.stdout.write(f"\n{'=' * 60}")
        self.stdout.write("Summary:")
        self.stdout.write(f" Departments classified: {stats['classified']}")
        self.stdout.write(f" Staff updated: {stats['updated_staff']}")
        self.stdout.write(f" Departments skipped: {stats['skipped']}")
        self.stdout.write(f" Errors: {stats['errors']}")
        self.stdout.write(f"{'=' * 60}\n")

        if dry_run:
            self.stdout.write(self.style.WARNING("DRY RUN: No changes were made\n"))
        else:
            self.stdout.write(self.style.SUCCESS("Classification completed!\n"))