#!/usr/bin/env python3
"""Check main session health and log tiered results."""
from __future__ import annotations

import json
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo

# All OpenClaw state lives under the current user's home directory. Deriving
# every path from one base keeps the sessions file and the workspace in sync
# instead of mixing Path.home() with a hard-coded /home/<user> path.
OPENCLAW_DIR = Path.home() / ".openclaw"
SESSIONS_PATH = OPENCLAW_DIR / "agents" / "main" / "sessions" / "sessions.json"
WORKSPACE = OPENCLAW_DIR / "workspace"
MEMORY_DIR = WORKSPACE / "memory"  # daily "<YYYY-MM-DD>.md" logs are appended here

# Thresholds — tuned for 1M token context window (claude-sonnet-4-6).
# Each *_MAX is an inclusive upper bound for its tier (tiers are chosen with ">"):
# ⚪ OK:       ≤ 300K tokens
# 🟡 Elevated: > 300K and ≤ 500K tokens
# 🟠 Warning:  > 500K and ≤ 750K tokens
# 🔴 Critical: > 750K tokens
OK_TOKENS_MAX        = 300_000
ELEVATED_TOKENS_MAX  = 500_000
WARNING_TOKENS_MAX   = 750_000

# Per-tier presentation. Each entry has:
#   "emoji":    the tier marker used in logs and output,
#   "template": a user-facing message formatted with `tokens` and `compactions`,
#   "alert":    whether this tier should be surfaced beyond the log.
LEVELS = {
    "OK": {
        "emoji": "⚪",
        "template": "⚪ Session health: OK — {tokens:,} tokens, {compactions} compactions. Log only.",
        "alert": False,
    },
    "ELEVATED": {
        "emoji": "🟡",
        "template": "🟡 Session health: Elevated — {tokens:,} tokens, {compactions} compactions. No action needed yet, just a heads up.",
        "alert": True,
    },
    "WARNING": {
        "emoji": "🟠",
        "template": "🟠 Session health: Warning — {tokens:,} tokens, {compactions} compactions. Consider resetting the session soon.",
        "alert": True,
    },
    "CRITICAL": {
        "emoji": "🔴",
        "template": "🔴 Session health: Critical — {tokens:,} tokens, {compactions} compactions. Session reset recommended now to avoid context window errors.",
        "alert": True,
    },
}


def load_metrics() -> tuple[int, int]:
    data = json.loads(SESSIONS_PATH.read_text())
    entry = data.get("agent:main:main", {})
    total_tokens = int(entry.get("totalTokens") or 0)
    compaction = entry.get("compactionCount")
    if compaction is None:
        compaction = entry.get("authProfileOverrideCompactionCount")
    if compaction is None:
        compaction = 0
    return total_tokens, int(compaction)


def determine_level(tokens: int, compactions: int) -> tuple[str, dict]:
    """Classify a token count into a health tier.

    Thresholds are checked from highest to lowest. The first one that
    ``tokens`` strictly exceeds selects the tier. If none is exceeded, the
    result is "OK". ``compactions`` is accepted for interface compatibility
    but does not currently influence the tier.

    Returns:
        ``(level_name, level_details)``, where ``level_details`` is the
        matching entry from ``LEVELS``.
    """
    ceilings = (
        (WARNING_TOKENS_MAX, "CRITICAL"),
        (ELEVATED_TOKENS_MAX, "WARNING"),
        (OK_TOKENS_MAX, "ELEVATED"),
    )
    for ceiling, name in ceilings:
        if tokens > ceiling:
            return name, LEVELS[name]
    return "OK", LEVELS["OK"]


def log_result(ts: datetime, emoji: str, level: str, tokens: int, compactions: int) -> None:
    """Append one Markdown bullet describing this check to the day's memory log.

    The log file is ``MEMORY_DIR/<YYYY-MM-DD>.md``, named from ``ts``. The
    directory is created if needed, and an existing file is appended to,
    never truncated.
    """
    MEMORY_DIR.mkdir(parents=True, exist_ok=True)
    entry = "".join((
        f"- {ts:%H:%M %Z}: {emoji} Session health check — level={level}, ",
        f"tokens={tokens:,}, compactions={compactions}.\n",
    ))
    with (MEMORY_DIR / f"{ts:%Y-%m-%d}.md").open("a", encoding="utf-8") as log_file:
        log_file.write(entry)


def main() -> None:
    """Run one health check and emit its result.

    Reads the session metrics and classifies them into a tier. Appends a
    line to today's memory log, then prints a single JSON object on stdout
    for the caller to consume.
    """
    now = datetime.now(ZoneInfo("America/New_York"))
    tokens, compactions = load_metrics()
    level, spec = determine_level(tokens, compactions)
    rendered = spec["template"].format(tokens=tokens, compactions=compactions)

    log_result(now, spec["emoji"], level, tokens, compactions)

    print(json.dumps({
        "timestamp": now.isoformat(),
        "level": level,
        "emoji": spec["emoji"],
        "totalTokens": tokens,
        "compactionCount": compactions,
        "message": rendered,
        "shouldAlert": spec["alert"],
    }))


if __name__ == "__main__":
    main()
