from datetime import datetime
from decimal import Decimal

from .database import db
from .highest_price_tracker import highest_price_tracker
from .paper_trade_manager import paper_trade_manager


class PumpReversalStrategy:
    def on_prices(self, prices: dict[str, Decimal], slots: list[dict], now: datetime) -> None:
        highest_price_tracker.expire_old(now)
        self._store_ticks(prices, now)
        paper_trade_manager.update_open_trades(prices, now)

        for symbol, current_price in prices.items():
            candidates = []
            for slot in slots:
                if not self._symbol_supports_leverage(symbol, int(slot["leverage"])):
                    continue
                pump = self._pump_for_slot(symbol, current_price, slot)
                if pump:
                    candidates.append((slot, pump[0], pump[1]))

            candidates.sort(key=lambda item: (-Decimal(str(item[0]["price_rise_pct"])), int(item[0]["priority"])))
            for slot, low_price, pump_pct in candidates:
                tracker = highest_price_tracker.add_or_update(
                    slot, symbol, low_price, current_price, pump_pct, now
                )
                fall_pct = ((tracker.high_price - current_price) / tracker.high_price) * Decimal("100")
                if fall_pct >= Decimal(str(slot["price_fall_pct"])):
                    trade_id = paper_trade_manager.open_trade(tracker, current_price, now)
                    if trade_id:
                        highest_price_tracker.remove(tracker)

    def _store_ticks(self, prices: dict[str, Decimal], now: datetime) -> None:
        rows = [(symbol, price, now) for symbol, price in prices.items()]
        db.execute_many(
            "INSERT INTO price_ticks (symbol, price, tick_time) VALUES (%s,%s,%s)",
            rows,
        )
        db.execute(
            """
            DELETE FROM price_ticks
            WHERE tick_time < DATE_SUB(NOW(), INTERVAL 12 HOUR)
            """
        )

    def _symbol_supports_leverage(self, symbol: str, leverage: int) -> bool:
        row = db.fetch_one("SELECT max_leverage FROM symbols WHERE symbol = %s", (symbol,))
        return bool(row and int(row["max_leverage"]) >= leverage)

    def _pump_for_slot(
        self, symbol: str, current_price: Decimal, slot: dict
    ) -> tuple[Decimal, Decimal] | None:
        row = db.fetch_one(
            """
            SELECT MIN(price) AS low_price
            FROM price_ticks
            WHERE symbol = %s
              AND tick_time >= DATE_SUB(NOW(), INTERVAL %s SECOND)
            """,
            (symbol, int(slot["pump_time_seconds"])),
        )
        if not row or row["low_price"] is None:
            return None
        low_price = Decimal(str(row["low_price"]))
        if low_price <= 0:
            return None
        pump_pct = ((current_price - low_price) / low_price) * Decimal("100")
        if pump_pct >= Decimal(str(slot["price_rise_pct"])):
            return low_price, pump_pct
        return None


# Shared module-level instance. PumpReversalStrategy defines no __init__ and
# sets no instance attributes, so one instance can serve every caller.
strategy = PumpReversalStrategy()

