from datetime import datetime
from decimal import Decimal, ROUND_HALF_UP

from .database import db
from .risk_manager import risk_manager
from .telegram_alert import send_telegram
from .trailing_manager import next_trailing_stop, short_pl_inr, short_profit_pct


def money(value: Decimal) -> Decimal:
    """Round *value* to two decimal places (paise), ties rounded away from zero."""
    two_places = Decimal("0.01")
    rounded = value.quantize(two_places, rounding=ROUND_HALF_UP)
    return rounded


class PaperTradeManager:
    """Simulate (paper-trade) short SELL positions and track their outcome.

    Trades, their running P/L and per-slot statistics are persisted through
    ``db``; the manager itself keeps no per-instance state.
    """

    def open_trade(self, tracker, entry_price: Decimal, now: datetime) -> int | None:
        """Open a paper SELL trade for ``tracker`` if the risk manager allows it.

        ``tracker`` must expose ``slot`` (a mapping of the slot's configuration
        columns), ``symbol``, ``event_id``, ``start_price``, ``high_price`` and
        ``pump_pct`` -- the attributes read below.

        Args:
            tracker: Pump tracker that triggered the entry.
            entry_price: Price at which the short is (virtually) opened; also
                used as the initial ``best_price``.
            now: Entry timestamp stored on the trade.

        Returns:
            Whatever ``db.execute`` returns for the INSERT (presumably the new
            trade id -- confirm against ``db``), or ``None`` if the risk manager
            blocked the trade.
        """
        slot = tracker.slot
        allowed, reason = risk_manager.can_open_trade(slot, tracker.symbol, tracker.pump_pct)
        if not allowed:
            db.log_trade_event(
                f"Trade blocked: {reason}",
                level="WARNING",
                pump_event_id=tracker.event_id,
                context={"symbol": tracker.symbol, "slot_id": slot["id"]},
            )
            # Only limit-type rejections are pushed to Telegram. Guard against a
            # missing reason so a blocked trade never raises after being logged.
            if reason and "limit" in reason.lower():
                send_telegram(reason)
            return None

        margin = Decimal(str(slot["max_amount_inr"]))
        # Notional size of the short: margin scaled by the slot's leverage.
        exposure = margin * Decimal(str(slot["leverage"]))
        trade_id = db.execute(
            """
            INSERT INTO trades (
              pump_event_id, slot_id, slot_name, symbol, direction, mode, leverage,
              margin_amount_inr, effective_exposure_inr, pump_start_price, pump_high_price,
              pump_pct, entry_price, stop_loss_amount_inr, profit_rule, profit_pct,
              trailing_start_pct, trailing_distance_pct, max_trade_duration_seconds, best_price, entry_time
            ) VALUES (%s,%s,%s,%s,'SELL','paper',%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """,
            (
                tracker.event_id,
                slot["id"],
                slot["slot_name"],
                tracker.symbol,
                slot["leverage"],
                margin,
                exposure,
                tracker.start_price,
                tracker.high_price,
                tracker.pump_pct,
                entry_price,
                slot["fixed_stop_loss_inr"],
                slot["profit_rule"],
                slot["profit_pct"],
                slot["trailing_start_pct"],
                slot["trailing_distance_pct"],
                slot["max_trade_duration_seconds"],
                entry_price,
                now,
            ),
        )
        db.execute("UPDATE pump_events SET status = 'entered' WHERE id = %s", (tracker.event_id,))
        db.log_trade_event(
            "Paper SELL trade opened",
            trade_id=trade_id,
            pump_event_id=tracker.event_id,
            context={"entry_price": entry_price, "exposure_inr": exposure},
        )
        send_telegram(f"Paper SELL opened: {tracker.symbol} | {entry_price} | Slot {slot['slot_name']}")
        return trade_id

    def update_open_trades(self, prices: dict[str, Decimal], now: datetime) -> None:
        """Mark every open trade to market.

        Args:
            prices: Latest price per symbol. Trades whose symbol is missing
                from this mapping are skipped for this tick.
            now: Timestamp of the price snapshot.
        """
        trades = db.fetch_all("SELECT * FROM trades WHERE status = 'open'")
        for trade in trades:
            price = prices.get(trade["symbol"])
            if price is None:
                continue
            self._update_trade(trade, price, now)

    def _update_trade(self, trade: dict, current_price: Decimal, now: datetime) -> None:
        """Evaluate exit rules for one open short and persist or close it.

        Exit checks run in priority order: fixed INR stop loss, then the
        ``profit_rule`` ("fixed" target or "trailing" stop), then the maximum
        trade duration. If any fires, the trade is closed at ``current_price``.
        Otherwise the running P/L, best price and trailing stop are saved.
        """
        entry = Decimal(str(trade["entry_price"]))
        exposure = Decimal(str(trade["effective_exposure_inr"]))
        pl = money(short_pl_inr(entry, current_price, exposure))
        profit_pct = short_profit_pct(entry, current_price)
        best_price = Decimal(str(trade["best_price"] or entry))
        trailing_stop = trade["trailing_stop_price"]
        trailing_stop_price = Decimal(str(trailing_stop)) if trailing_stop else None
        exit_reason = None

        # For a short, the "best" price is the lowest price seen so far.
        if current_price < best_price:
            best_price = current_price

        if pl <= -Decimal(str(trade["stop_loss_amount_inr"])):
            exit_reason = "Fixed INR stop loss hit"

        if exit_reason is None and trade["profit_rule"] == "fixed":
            if profit_pct >= Decimal(str(trade["profit_pct"])):
                exit_reason = "Fixed profit hit"

        if exit_reason is None and trade["profit_rule"] == "trailing":
            start_pct = Decimal(str(trade["trailing_start_pct"]))
            distance_pct = Decimal(str(trade["trailing_distance_pct"]))
            # The trailing stop is (re)computed only once the activation profit
            # is reached. After that, a previously stored stop keeps applying
            # even if profit falls back below the threshold.
            if profit_pct >= start_pct:
                trailing_stop_price = next_trailing_stop(best_price, distance_pct)
            if trailing_stop_price is not None and current_price >= trailing_stop_price:
                exit_reason = "Trailing profit exit"

        max_duration = trade.get("max_trade_duration_seconds")
        if exit_reason is None and max_duration:
            elapsed = int((now - trade["entry_time"]).total_seconds())
            if elapsed >= int(max_duration):
                exit_reason = "Max trade duration reached"

        if exit_reason:
            # Pass along this tick's best price / trailing stop so the closed row
            # does not keep values from the previous tick.
            self.close_trade(
                trade,
                current_price,
                now,
                exit_reason,
                pl,
                best_price=best_price,
                trailing_stop_price=trailing_stop_price,
            )
            return

        db.execute(
            """
            UPDATE trades
            SET gross_pl_inr = %s, net_pl_inr = %s, best_price = %s, trailing_stop_price = %s
            WHERE id = %s
            """,
            (pl, pl, best_price, trailing_stop_price, trade["id"]),
        )

    def close_trade(
        self,
        trade: dict,
        exit_price: Decimal,
        now: datetime,
        exit_reason: str,
        pl: Decimal,
        *,
        best_price: Decimal | None = None,
        trailing_stop_price: Decimal | None = None,
    ) -> None:
        """Mark ``trade`` as closed, log it, refresh slot stats and notify.

        Args:
            trade: The trade row (must contain ``id``, ``entry_time``,
                ``slot_id`` and ``symbol``).
            exit_price: Price the paper position is closed at.
            now: Exit timestamp; also used to compute the holding duration.
            exit_reason: Human-readable reason stored on the trade.
            pl: Final P/L in INR. It is written to both ``gross_pl_inr`` and
                ``net_pl_inr``; no fees are applied here.
            best_price: Optional final best price to persist. ``None`` keeps
                the value already stored on the row.
            trailing_stop_price: Optional final trailing stop to persist.
                ``None`` keeps the value already stored on the row.
        """
        holding = int((now - trade["entry_time"]).total_seconds())
        db.execute(
            """
            UPDATE trades
            SET status = 'closed',
                exit_price = %s,
                exit_time = %s,
                holding_duration_seconds = %s,
                exit_reason = %s,
                gross_pl_inr = %s,
                net_pl_inr = %s,
                best_price = COALESCE(%s, best_price),
                trailing_stop_price = COALESCE(%s, trailing_stop_price)
            WHERE id = %s
            """,
            (
                exit_price,
                now,
                holding,
                exit_reason,
                pl,
                pl,
                best_price,
                trailing_stop_price,
                trade["id"],
            ),
        )
        db.log_trade_event(
            "Paper trade closed",
            trade_id=trade["id"],
            context={"exit_price": exit_price, "exit_reason": exit_reason, "net_pl_inr": pl},
        )
        self.rebuild_slot_performance(trade["slot_id"])
        send_telegram(f"Paper SELL closed: {trade['symbol']} | {exit_reason} | P/L INR {pl}")

    def rebuild_slot_performance(self, slot_id: int) -> None:
        """Recompute a slot's aggregate stats from its closed trades and upsert them.

        Wins and losses are counted by the sign of ``net_pl_inr`` (zero counts
        as neither). The best and worst coins are the symbols with the highest
        and lowest summed net P/L. The upsert relies on a unique key in
        ``slot_performance`` (presumably ``slot_id`` -- confirm in the schema).
        """
        row = db.fetch_one(
            """
            SELECT
              COUNT(*) total_trades,
              SUM(CASE WHEN net_pl_inr > 0 THEN 1 ELSE 0 END) wins,
              SUM(CASE WHEN net_pl_inr < 0 THEN 1 ELSE 0 END) losses,
              COALESCE(SUM(net_pl_inr),0) profit_loss,
              COALESCE(AVG(CASE WHEN net_pl_inr > 0 THEN net_pl_inr END),0) avg_profit,
              COALESCE(AVG(CASE WHEN net_pl_inr < 0 THEN net_pl_inr END),0) avg_loss
            FROM trades
            WHERE slot_id = %s AND status = 'closed'
            """,
            (slot_id,),
        )
        best = db.fetch_one(
            """
            SELECT symbol FROM trades
            WHERE slot_id = %s AND status = 'closed'
            GROUP BY symbol ORDER BY SUM(net_pl_inr) DESC LIMIT 1
            """,
            (slot_id,),
        )
        worst = db.fetch_one(
            """
            SELECT symbol FROM trades
            WHERE slot_id = %s AND status = 'closed'
            GROUP BY symbol ORDER BY SUM(net_pl_inr) ASC LIMIT 1
            """,
            (slot_id,),
        )
        db.execute(
            """
            INSERT INTO slot_performance
              (slot_id, total_trades, wins, losses, profit_loss_inr, average_profit_inr,
               average_loss_inr, best_coin, worst_coin)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s)
            ON DUPLICATE KEY UPDATE
              total_trades = VALUES(total_trades),
              wins = VALUES(wins),
              losses = VALUES(losses),
              profit_loss_inr = VALUES(profit_loss_inr),
              average_profit_inr = VALUES(average_profit_inr),
              average_loss_inr = VALUES(average_loss_inr),
              best_coin = VALUES(best_coin),
              worst_coin = VALUES(worst_coin)
            """,
            (
                slot_id,
                row["total_trades"],
                # SUM over zero rows is NULL, so default the counts to 0.
                row["wins"] or 0,
                row["losses"] or 0,
                row["profit_loss"],
                row["avg_profit"],
                row["avg_loss"],
                best["symbol"] if best else None,
                worst["symbol"] if worst else None,
            ),
        )


# Shared module-level instance. PaperTradeManager defines no __init__ and keeps
# no per-instance attributes, so this single instance can be reused by importers.
paper_trade_manager = PaperTradeManager()
