import re

import numpy as np
from docx import Document
from docx.enum.text import WD_COLOR_INDEX
from docx.shared import RGBColor
from sklearn.metrics.pairwise import cosine_similarity

from app.ml.inference import detect_tortured_phrases
from app.ml.model_loader import get_model_components, load_model
from app.services.phrase_extractor import generate_ngrams

# Similarity cutoff. Not referenced by process_document in this module;
# presumably a cosine-similarity threshold consumed elsewhere or left over
# from an earlier matching approach — TODO confirm before changing it.
SIMILARITY_THRESHOLD = 0.8

def _compile_detection_patterns(detections):
    """Pair each detection with a case-insensitive search pattern for its phrase.

    Detections whose ``phrase`` is empty are skipped: an empty pattern matches
    at every offset and would produce zero-width highlights and bogus
    annotations.

    Returns:
        A list of ``(detection, compiled_pattern)`` tuples.
    """
    compiled = []
    for det in detections:
        phrase = det["phrase"]
        if not phrase:
            continue
        # The capturing group inside a zero-width lookahead lets finditer report
        # overlapping occurrences, while the reported offsets refer to the
        # original text. Matching on text.lower() instead would break offsets
        # whenever lowercasing changes the string length (e.g. "İ" -> "i̇").
        pattern = re.compile(f"(?=({re.escape(phrase)}))", re.IGNORECASE)
        compiled.append((det, pattern))
    return compiled


def _find_phrase_spans(text: str, compiled_patterns):
    """Return every case-insensitive occurrence of each detected phrase in *text*.

    Each span is a dict with ``start``/``end`` offsets into *text* plus the
    detection's ``phrase``, ``matched_with`` and ``confidence``. Overlapping
    occurrences are all reported; the caller prunes them.
    """
    spans = []
    for det, pattern in compiled_patterns:
        for match in pattern.finditer(text):
            spans.append({
                "start": match.start(1),
                "end": match.end(1),
                "phrase": det["phrase"],
                "matched_with": det["matched_with"],
                "confidence": det["confidence"],
            })
    return spans


def _select_non_overlapping(spans):
    """Keep a set of mutually non-overlapping spans, preferring longer ones.

    Spans are considered longest first, with an earlier start breaking ties.
    Python's sort is stable, so the original detection order breaks the
    remaining ties. A span is kept only if it does not overlap an
    already-kept span. The result is sorted by start offset.
    """
    kept = []
    by_priority = sorted(spans, key=lambda s: (s["start"] - s["end"], s["start"]))
    for span in by_priority:
        if all(span["end"] <= k["start"] or span["start"] >= k["end"] for k in kept):
            kept.append(span)
    kept.sort(key=lambda s: s["start"])
    return kept


def _render_paragraph(new_para, text: str, spans) -> None:
    """Append *text* to *new_para*, highlighting *spans* and summarising them.

    Each span is added as a bold, yellow-highlighted run. If there was at least
    one span, a red-highlighted ``[Detected: ...]`` run with black text is
    appended after a line break.
    """
    cursor = 0
    annotations = []

    for span in spans:
        # Plain text between the previous detection and this one.
        if span["start"] > cursor:
            new_para.add_run(text[cursor:span["start"]])

        run = new_para.add_run(text[span["start"]:span["end"]])
        run.font.highlight_color = WD_COLOR_INDEX.YELLOW
        run.bold = True

        annotations.append(f"'{span['phrase']}' -> '{span['matched_with']}' ({span['confidence']:.2f})")
        cursor = span["end"]

    # Trailing text after the last detection (or the whole text if none).
    if cursor < len(text):
        new_para.add_run(text[cursor:])

    if annotations:
        ann_run = new_para.add_run(f"\n[Detected: {'; '.join(annotations)}]")
        ann_run.font.highlight_color = WD_COLOR_INDEX.RED
        ann_run.font.color.rgb = RGBColor(0, 0, 0)  # black text for better contrast on red


def process_document(uploaded_file, output_path: str):
    """Annotate detected tortured phrases in an uploaded .docx and save the result.

    The text of all paragraphs is joined with newlines, split into candidate
    n-grams by ``generate_ngrams`` and scored by ``detect_tortured_phrases``.
    A new document is then built with one paragraph per source paragraph.
    Only the plain paragraph text is copied; the source formatting is not.
    In each paragraph, every detected phrase (matched case-insensitively;
    longer matches win over overlapping shorter ones) is bold and highlighted
    yellow. A red ``[Detected: ...]`` summary is appended to each paragraph
    that has detections.

    NOTE(review): matching is done per paragraph, so a detected phrase that
    spans a paragraph break in ``full_text`` is never highlighted. Whether
    ``generate_ngrams`` can emit such n-grams is not visible here.

    Args:
        uploaded_file: Object whose ``.file`` attribute is the .docx binary
            stream (presumably a FastAPI ``UploadFile`` — confirm with caller).
        output_path: Path the annotated .docx is written to.

    Returns:
        None. The annotated document is saved to *output_path*.
    """
    # Ensure the detection model is loaded. Its components are not used
    # directly here; detect_tortured_phrases does the scoring.
    load_model()

    document = Document(uploaded_file.file)

    # Detect phrases once over the whole document text.
    full_text = "\n".join(para.text for para in document.paragraphs)
    candidate_phrases = generate_ngrams(full_text)
    detections = detect_tortured_phrases(candidate_phrases)

    # Compile the search patterns once; they are reused for every paragraph.
    compiled_patterns = _compile_detection_patterns(detections)

    annotated_doc = Document()

    for para in document.paragraphs:
        text = para.text

        # Blank / whitespace-only paragraphs are kept as empty paragraphs.
        if not text.strip():
            annotated_doc.add_paragraph("")
            continue

        new_para = annotated_doc.add_paragraph()
        spans = _select_non_overlapping(_find_phrase_spans(text, compiled_patterns))
        _render_paragraph(new_para, text, spans)

    annotated_doc.save(output_path)
