from decimal import Decimal

from .database import db


class RiskManager:
    """Pre-trade risk checks backed by the ``trades`` table and DB settings.

    Limits are read from the mapping returned by ``db.setting_map()`` (kept in
    ``self.settings``), falling back to hard-coded string defaults when a key
    is missing. Live figures (open margin, open trade count, today's loss,
    open trades per symbol) are queried from the ``trades`` table on every call.
    """

    def __init__(self) -> None:
        # Snapshot of the settings mapping; replaced wholesale by refresh().
        self.settings = db.setting_map()

    def refresh(self) -> None:
        """Reload the settings snapshot from the database."""
        self.settings = db.setting_map()

    def max_open_slots_per_coin(self) -> int:
        """Return how many trades may be open at once on one symbol (default 1)."""
        return int(self.settings.get("max_open_slots_per_coin", "1"))

    def global_exposure_limit(self) -> Decimal:
        """Return the cap on summed open margin (default 20000).

        Compared against ``margin_amount_inr`` totals plus the candidate slot's
        ``max_amount_inr`` in ``can_open_trade``.
        """
        return Decimal(str(self.settings.get("global_max_capital_exposure", "20000")))

    def daily_loss_limit(self) -> Decimal:
        """Return the loss amount at which new trades are blocked for the day (default 500)."""
        return Decimal(str(self.settings.get("daily_loss_limit", "500")))

    def max_total_open_trades(self) -> int:
        """Return the cap on concurrently open trades across all symbols (default 3)."""
        return int(self.settings.get("max_total_open_trades", "3"))

    def current_open_margin(self) -> Decimal:
        """Return the summed ``margin_amount_inr`` of all open trades (0 when none)."""
        row = db.fetch_one(
            "SELECT COALESCE(SUM(margin_amount_inr),0) AS total FROM trades WHERE status = 'open'"
        )
        return Decimal(str(row["total"] or 0))

    def todays_loss(self) -> Decimal:
        """Return today's realised loss as a non-negative amount.

        Sums ``net_pl_inr`` over closed trades with a negative result whose
        ``entry_time`` falls on the database's current date (``CURDATE()``),
        then drops the sign. Winning trades do not offset the total.
        """
        # NOTE(review): filtering on entry_time means a trade opened on an
        # earlier day but closed at a loss today is never counted toward any
        # day's limit. If the table has a close/exit timestamp, that is
        # probably the column to filter on -- confirm against the schema.
        row = db.fetch_one(
            """
            SELECT COALESCE(SUM(net_pl_inr),0) AS total
            FROM trades
            WHERE DATE(entry_time) = CURDATE() AND status = 'closed' AND net_pl_inr < 0
            """
        )
        return abs(Decimal(str(row["total"] or 0)))

    def open_trade_count(self) -> int:
        """Return the number of trades whose status is ``'open'``."""
        row = db.fetch_one("SELECT COUNT(*) AS total FROM trades WHERE status = 'open'")
        return int(row["total"] or 0)

    def open_slots_for_symbol(self, symbol: str) -> list[dict]:
        """Return ``slot_id`` / ``pump_pct`` rows for open trades on *symbol*."""
        return db.fetch_all(
            "SELECT slot_id, pump_pct FROM trades WHERE status = 'open' AND symbol = %s",
            (symbol,),
        )

    def can_open_trade(self, slot: dict, symbol: str, pump_pct: Decimal) -> tuple[bool, str]:
        """Decide whether a new trade for *slot* on *symbol* may be opened.

        Settings are reloaded first. Checks then run in this order, and the
        first failure wins:

        1. the daily loss limit,
        2. the total number of open trades,
        3. the same slot already open on the symbol,
        4. the per-symbol slot cap,
        5. pump strength versus existing open trades on the symbol,
        6. global margin exposure.

        Args:
            slot: Slot record. ``slot["id"]`` is read only when the symbol
                already has open trades; ``slot["max_amount_inr"]`` is read
                for the exposure check.
            symbol: Symbol (coin) the trade would be opened on.
            pump_pct: Pump percentage of the candidate signal. Any value
                accepted by ``Decimal(str(...))`` works (Decimal, int, float,
                numeric str).

        Returns:
            ``(True, "OK")`` when the trade is allowed, otherwise
            ``(False, reason)``.
        """
        self.refresh()  # pick up setting changes made since the last check
        if self.todays_loss() >= self.daily_loss_limit():
            return False, "Daily loss limit reached"
        if self.open_trade_count() >= self.max_total_open_trades():
            return False, "Max total open trades reached"

        open_for_symbol = self.open_slots_for_symbol(symbol)
        if any(int(row["slot_id"]) == int(slot["id"]) for row in open_for_symbol):
            return False, "Duplicate same-slot trade on same coin blocked"
        # With a cap of 1 (or less) this already rejects any additional trade
        # on a symbol that has an open one, so no separate single-slot branch
        # is needed.
        if len(open_for_symbol) >= self.max_open_slots_per_coin():
            return False, "Max open slots per coin reached"
        if open_for_symbol:
            # An additional slot on the same coin must be at least as strong
            # as the strongest trade already open there. Normalise both sides
            # the same way so str/float inputs compare like the DB values.
            best_existing = max(Decimal(str(row["pump_pct"])) for row in open_for_symbol)
            if Decimal(str(pump_pct)) < best_existing:
                return False, "Second slot pump logic is weaker than existing open trade"

        next_margin = Decimal(str(slot["max_amount_inr"]))
        if self.current_open_margin() + next_margin > self.global_exposure_limit():
            return False, "Global exposure limit reached"
        return True, "OK"


# Shared module-level instance. Constructing it calls db.setting_map(), so
# importing this module reads the settings from the database at import time.
risk_manager: RiskManager = RiskManager()

