"""
binance_framework.py – A general framework for accessing Binance data.
Place this file in your project; other data types can be supported by importing it and extending it via inheritance.
"""

import time, json, threading, os, math, queue
from typing import List, Any, Optional
from abc import ABC, abstractmethod
from ..configs import DDB,BINANCE_BASE_CONFIG
import numpy as np
import dolphindb as ddb


# ========================= JSON-safe serialization =========================
def normalize_scalar(x: Any) -> Any:
    """Map a single value onto something ``json.dumps`` can emit.

    * ``None`` stays ``None``.
    * ``np.datetime64`` becomes an ``int`` count of milliseconds.
    * NumPy integers / booleans become plain ``int`` / ``bool``.
    * Floats (NumPy or built-in) that are NaN or +/-Inf become ``None``;
      finite NumPy floats become plain ``float``.
    * Anything else is handed back untouched.
    """
    if x is None:
        return None

    if isinstance(x, np.datetime64):
        # Force millisecond resolution before reinterpreting as an integer.
        millis = np.datetime64(x, 'ms').astype('datetime64[ms]')
        return int(millis.astype(np.int64))

    if isinstance(x, np.integer):
        return int(x)

    if isinstance(x, np.floating):
        # Checked before the built-in float branch because np.float64 is
        # also a ``float`` subclass and must be unwrapped to a plain float.
        unwrapped = float(x)
        return unwrapped if math.isfinite(unwrapped) else None

    if isinstance(x, np.bool_):
        return bool(x)

    if isinstance(x, float):
        return x if math.isfinite(x) else None

    # int / bool / str and every unrecognised type pass through unchanged.
    return x

def normalize_row(row: Any) -> Any:
    """Walk a nested row and make every leaf serializable with ``json.dumps``.

    Lists and tuples come back as lists, dicts keep their keys and get their
    values converted, ndarrays are expanded through ``tolist()``, and every
    other value is delegated to ``normalize_scalar``.
    """
    if row is None:
        return None
    if isinstance(row, np.ndarray):
        # tolist() yields nested Python containers (or a bare scalar for 0-d
        # arrays); run the result through the same conversion again.
        return normalize_row(row.tolist())
    if isinstance(row, dict):
        return {key: normalize_row(value) for key, value in row.items()}
    if isinstance(row, (list, tuple)):
        return list(map(normalize_row, row))
    return normalize_scalar(row)


# ========================= General IOThread =========================
class IOThread(threading.Thread):
    """
    Single serialized write path for "write to MTW / persist locally / replay backfill".

    Because one thread performs all three activities, rows are never written
    out of order by concurrent writers.

    Three-state state machine (see ``run``):
        live:    DDB is healthy; take rows from the real-time queue and insert them
                 into the MTW. An insert failure spills pending data to the local
                 cache file and switches to offline.
        offline: DDB is unavailable; drain the real-time queue into the local cache
                 file. Once the cooldown has elapsed, probe DDB by inserting the
                 first cached row; a successful probe switches to replay.
        replay:  insert the cached rows back into the MTW without consuming the
                 real-time queue; switch to live when the cache has been cleared,
                 or back to offline when the replay fails.
    """

    # Initialize the thread state (daemon thread, starts in 'live' mode); the caller starts the thread.
    def __init__(self, config):
        """Set up a daemon writer thread bound to ``config``.

        ``config`` supplies the cache file path (``BUFFER_FILE``) now and the
        shared queue / lock / writer objects later, when the thread runs.
        """
        super().__init__(daemon=True)
        # State machine starts in live mode with no probe cooldown pending.
        self.mode, self.next_probe_ts = 'live', 0.0
        self.config = config
        self.path = config.BUFFER_FILE

    # Run the table-creation script so the stream table can be rebuilt after a DDB outage or failure.
    def ensure_table_exists(self):
        """Run the config's table-creation script on a short-lived DDB session.

        Script failures are best-effort: they are logged and swallowed so that
        the caller can still try to build the writer. A connection failure
        raised by ``connect`` propagates to the caller, as before.

        The session is always closed, including when the script raises.
        """
        s = ddb.session()
        try:
            s.connect(
                self.config.dolphindb_address,
                self.config.dolphindb_port,
                self.config.dolphindb_user,
                self.config.dolphindb_password
            )
            try:
                s.run(self.config.get_create_table_script())
            except Exception as e:
                # Deliberately non-fatal, but no longer silent.
                print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Table creation script failed (ignored): {e}")
        finally:
            # Release the connection on every path; a failing close must not
            # mask the outcome above.
            try:
                s.close()
            except Exception:
                pass

    def build_mtw(self) -> bool:
        """(Re)create the MultithreadedTableWriter and store it on the config.

        The table-creation script is run first. Any exception along the way is
        logged, ``config.writer`` is reset to ``None`` and ``False`` is
        returned; otherwise the new writer is stored and ``True`` is returned.
        """
        cfg = self.config
        try:
            self.ensure_table_exists()
            new_writer = ddb.MultithreadedTableWriter(
                cfg.dolphindb_address,
                cfg.dolphindb_port,
                cfg.dolphindb_user,
                cfg.dolphindb_password,
                dbPath="",
                tableName=cfg.tableName,
                batchSize=10000,
                throttle=1,
                threadCount=1,
                reconnect=False,
            )
            cfg.writer = new_writer
        except Exception as e:
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Rebuilding MTW failed: {e}")
            cfg.writer = None
            return False
        return True

    def save_unwritten_to_local(self):
        """Move rows still queued inside the MTW into the local cache file.

        Ordering matters: after a failed insert this is called before the
        current row and the real-time queue are spilled, so rows that entered
        the MTW first also land in the cache first.

        Does nothing when there is no writer or nothing is pending. If the
        writer cannot report its pending rows, the failure is logged and the
        writer is discarded.
        """
        writer = self.config.writer
        if writer is None:
            return

        try:
            pending = writer.getUnwrittenData()
        except Exception as e:
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] getUnwrittenData failed: {e}")
            self.config.writer = None
            return

        if not pending:
            return

        # Same locked append-and-flush used for every other spill to disk.
        self._append_rows_to_local(pending)
        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Successfully transferred {len(pending)} unwritten MTW rows to local storage.")

    # Append a batch of data to the local file, saving any rows that fail to write.
    def _append_rows_to_local(self, rows: List[List[Any]]):
        if not rows:
            return
        with self.config.file_lock, open(self.path, "a", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(normalize_row(row), ensure_ascii=False) + "\n")
            f.flush()

    # Batch-transfer the data from the queue to local storage.
    def _save_queue_to_local(self, max_n: int = 50000):
        moved = 0
        buf: List[List[Any]] = []
        while moved < max_n:
            try:
                row = self.config.realtime_q.get_nowait()
            except queue.Empty:
                break
            buf.append(row)
            moved += 1
        if buf:
            self._append_rows_to_local(buf)

    #  Write a single record to DDB.
    def _insert_one(self, row: List[Any]):
        """Hand one row to the MTW, building the writer first if it is missing.

        Raises:
            RuntimeError: when no writer can be built, or when the result of
                ``insert`` exposes ``hasError()`` and it reports an error (the
                message is the result's ``errorInfo``).
        """
        if self.config.writer is None:
            if not self.build_mtw():
                raise RuntimeError("MTW unavailable")

        outcome = self.config.writer.insert(*row)
        # Not every result object is assumed to expose hasError().
        if not hasattr(outcome, "hasError"):
            return
        if outcome.hasError():
            raise RuntimeError(outcome.errorInfo)

    # Health probe – use the first line of the local file to test if DDB has recovered.
    def _probe_from_local_first_line(self) -> bool:
        if not os.path.exists(self.path):
            return False
        if self.config.writer is None and not self.build_mtw():
            return False

        with self.config.file_lock:
            with open(self.path, "r", encoding="utf-8", newline="\n") as f:
                line = f.readline()
                if not line or not line.endswith("\n"):
                    return False
                try:
                    row = json.loads(line)
                    self._insert_one(row)
                    print("The first local test write succeeded.")
                    return True
                except Exception:
                    self.config.writer = None
                    return False
    # Backfill all local data to DDB.
    def _replay_all_local(self) -> bool:
        if not os.path.exists(self.path):
            return True
        if self.config.writer is None and not self.build_mtw():
            return False
        total = 0
        try:
            with self.config.file_lock:
                src = open(self.path, "r", encoding="utf-8", newline="\n")
                while True:
                    # Read in batches (each batch of READ_BATCH_SIZE rows).
                    batch_lines = []
                    for _ in range(self.config.READ_BATCH_SIZE):
                        line = src.readline()
                        if not line:
                            break
                        if not line.endswith("\n"):
                            break
                        batch_lines.append(line)
                    if not batch_lines:
                        break

                    # Write to DDB row by row.
                    i = 0
                    try:
                        for i, line in enumerate(batch_lines):
                            row = json.loads(line)
                            self._insert_one(row)
                        total += len(batch_lines)
                        continue
                    except Exception as e:
                        # On write failure, write the unprocessed data back to a temporary file.
                        print("An issue occurred during backfill.")
                        tmp_path = self.path + ".tmp"
                        try:
                            uw = self.config.writer.getUnwrittenData()
                        except:
                            print("Failed to get unwritten data during backfill.")
                            self.config.writer = None
                        # Write to a temporary file.
                        with open(tmp_path, "w", encoding="utf-8") as tmp:
                            # 1. First write the unwritten MTW data.
                            for row in uw:
                                tmp.write(json.dumps(normalize_row(row), ensure_ascii=False) + "\n")
                            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Backfill interrupted: {e}. UnwrittenData has been written back to local storage.")
                            # 2. Then write the unprocessed rows from the current batch.
                            tmp.writelines(batch_lines[i:])
                            # 3. Finally, write all remaining rows.
                            for rest in src:
                                tmp.write(rest)
                        src.close()
                        os.replace(tmp_path, self.path)
                        self.config.writer = None
                        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Backfill interrupted: {e}. Unprocessed data has been written back to local storage.")
                        return False
                    
            # All succeeded; clear the file.
            src.close()
            with self.config.file_lock, open(self.path, "w", encoding="utf-8"):
                pass
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Backfill completed; a total of {total} rows were written, and the cache has been cleared.")
            return True

        except Exception as e:
            try:
                src.close()
            except:
                pass
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Backfill failed: {e} (cache retained for retry later).")
            self.config.writer = None
            return False

    def run(self):
        """Thread body: drive the live / offline / replay state machine forever.

        Each loop iteration does one of the following:
          * cache file non-empty and still in probe cooldown: drain the queue
            to disk and wait;
          * cache file non-empty otherwise: probe DDB with the first cached
            row, and on success replay the whole cache before any real-time
            row is consumed;
          * live mode with an empty/missing cache: take one row from the
            real-time queue and insert it, spilling to disk on failure;
          * offline mode with an empty/missing cache: drain the queue to disk.

        The loop never returns. Only the queue read and the live insert are
        wrapped in try/except here, so an exception raised by the other helper
        calls propagates out of this method and ends the thread.
        """
        while True:
            # If a non-empty local cache exists, prioritize "health probe/backfill" first.
            if os.path.exists(self.path) and os.path.getsize(self.path) > 0:
                if self.mode != 'replay':
                    now = time.time()
                    # During the cooldown period: flush the queue to local storage and wait for the next health probe.
                    if self.mode == 'offline' and now < self.next_probe_ts:
                        self._save_queue_to_local()
                        time.sleep(0.1)
                        continue

                    # At health probe time: use the first local row for a test write.
                    ok = self._probe_from_local_first_line()
                    if ok:
                        self.mode = 'replay'
                        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Health check succeeded; enter backfill mode.")
                    else:
                        # Probe failed: (re)arm the cooldown and keep spilling incoming rows to disk.
                        self.mode = 'offline'
                        self.next_probe_ts = time.time() + self.config.PROBE_COOLDOWN_SECS
                        self._save_queue_to_local()
                        time.sleep(0.2)
                        continue

                # Backfill mode: strictly do not consume the real-time queue.
                if self.mode == 'replay':
                    success = self._replay_all_local()
                    if success:
                        self.mode = 'live'  # Backfill cleared; switch back to live mode.
                    else:
                        # Replay failed part-way: go back to offline and wait out the cooldown.
                        self.mode = 'offline'
                        self.next_probe_ts = time.time() + self.config.PROBE_COOLDOWN_SECS
                        self._save_queue_to_local()
                        time.sleep(0.2)
                        continue

            # No local data or already cleared; live mode writes real-time data.
            if self.mode == 'live':
                try:
                    row = self.config.realtime_q.get(timeout=self.config.LIVE_GET_TIMEOUT)
                except queue.Empty:
                    time.sleep(0.05)
                    continue

                try:
                    self._insert_one(row)
                except Exception as e:
                    print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Real-time write failed: {e}. Switching to offline and transferring data.")
                    # The order below keeps the cache file in arrival order.
                    self.save_unwritten_to_local()          # 1. First, persist the MTW cache.
                    self._append_rows_to_local([row])       # 2. Then, save the current row.
                    self._save_queue_to_local()             # 3. Finally, flush the queue.
                    self.config.writer = None
                    self.mode = 'offline'
                    self.next_probe_ts = time.time() + self.config.PROBE_COOLDOWN_SECS
                    time.sleep(0.2)
                    continue

            elif self.mode == 'offline':
                # Offline with nothing cached yet: keep moving incoming rows to disk.
                self._save_queue_to_local()
                time.sleep(0.1)
                continue


# ========================= Basic config class =========================
class BinanceBaseConfig(ABC):
    """
    Base class for all Binance data ingestion configs.

    To use it, inherit from this class, set ``tableName`` and ``BUFFER_FILE``,
    and implement the abstract methods. The class-level settings below are
    read from the project config mappings when the class is defined.
    """
    
    # DolphinDB connection config
    dolphindb_address = DDB["HOST"]
    dolphindb_port = DDB["PORT"]
    dolphindb_user = DDB["USER"]
    dolphindb_password = DDB["PWD"]
    
    # Proxy config
    proxy_address = BINANCE_BASE_CONFIG["PROXY"]
    
    # Table name and cache file (must be set by subclasses).
    # tableName is the table the MultithreadedTableWriter targets;
    # BUFFER_FILE is the local cache file used while DDB is unavailable.
    tableName = ""
    BUFFER_FILE = ""
    
    # Behavior parameters (adjustable as needed).
    # TIMEOUT: monitor poll interval and max silence (seconds) before the client is rebuilt.
    TIMEOUT = BINANCE_BASE_CONFIG["TIMEOUT"]
    # PROBE_COOLDOWN_SECS: wait between DDB health probes while offline.
    PROBE_COOLDOWN_SECS = BINANCE_BASE_CONFIG["PROBE_COOLDOWN_SECS"]
    # READ_BATCH_SIZE: number of cache-file lines read per batch during replay.
    READ_BATCH_SIZE = BINANCE_BASE_CONFIG["READ_BATCH_SIZE"]
    # LIVE_GET_TIMEOUT: timeout (seconds) for each blocking read of the real-time queue.
    LIVE_GET_TIMEOUT = BINANCE_BASE_CONFIG["LIVE_GET_TIMEOUT"]
    
    def __init__(self):
        """Create the per-instance shared state and the cache directory.

        Sets up the bounded real-time queue, the lock guarding the cache file,
        an empty writer slot and the "last message received" timestamp. When
        ``BUFFER_FILE`` is set, its parent directory is created if missing.
        """
        # Objects shared between the producer callbacks and the writer thread.
        self.writer = None
        self.file_lock = threading.Lock()
        self.realtime_q = queue.Queue(maxsize=1_000_000)
        self.last_received_time = time.time()

        if not self.BUFFER_FILE:
            return
        # A bare file name has no directory part; fall back to the current directory.
        parent_dir = os.path.dirname(self.BUFFER_FILE)
        os.makedirs(parent_dir if parent_dir else '.', exist_ok=True)
    
    @abstractmethod
    def get_create_table_script(self) -> str:
        """Return the DolphinDB script that creates this config's table.

        The writer thread runs the returned script on a DolphinDB session
        each time it (re)builds its table writer.
        """
    
    @abstractmethod
    def create_message_handler(self):
        """Build and return the callable that handles incoming messages."""
    
    @abstractmethod
    def start_client_and_subscribe(self):
        """Start the client, subscribe, and return the client object.

        The returned object is stored as ``self.client`` and must provide a
        ``stop()`` method, which is called on reconnect and on exit.
        """
    
    def start_all(self):
        """Bring up the whole pipeline and return the subscribed client.

        Order of operations: try to build the MTW once, start the writer
        thread, start the client and subscribe, then start the daemon thread
        that watches for WebSocket silence.
        """
        writer_thread = IOThread(self)

        # A failed first build is not fatal: the writer thread rebuilds on demand.
        mtw_ready = writer_thread.build_mtw()
        if not mtw_ready:
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Initial MTW creation failed; it will be automatically retried in the write thread.")

        # Single write channel.
        writer_thread.start()

        # Start the client and subscribe.
        self.client = self.start_client_and_subscribe()

        # Timeout watchdog.
        threading.Thread(target=self.monitor_timeout, daemon=True).start()
        return self.client
    
    def monitor_timeout(self):
        """Watchdog loop: rebuild the WebSocket client after prolonged silence.

        Every ``TIMEOUT`` seconds, checks how long ago ``last_received_time``
        was updated. If more than ``TIMEOUT`` seconds have passed, the current
        client is stopped (best effort), a new one is started via
        ``start_client_and_subscribe`` and the timestamp is reset.

        Never returns. Any exception inside an iteration is logged and the
        loop resumes after a further ``TIMEOUT`` seconds.
        """
        while True:
            try:
                time.sleep(self.TIMEOUT)
                if time.time() - self.last_received_time <= self.TIMEOUT:
                    continue

                print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] WebSocket timeout, attempting to reconnect…")
                try:
                    self.client.stop()
                except Exception as stop_err:
                    # A stale client that cannot be stopped must not prevent the
                    # reconnect, but the failure should be visible in the log.
                    print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Failed to stop the stale WebSocket client: {stop_err}")
                time.sleep(1)
                # Rebuild the WebSocket client.
                self.client = self.start_client_and_subscribe()
                self.last_received_time = time.time()
                print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] WebSocket reconnect")
            except Exception as e:
                print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] WebSocket monitoring exception:{e}")
                time.sleep(self.TIMEOUT)

    def quick_exit(self):
        """Stop the client, wait briefly, and terminate the process.

        The process is ended with ``os._exit(0)``, which skips normal
        interpreter cleanup. This method does not return.

        NOTE(review): nothing here drains the real-time queue or the MTW's
        pending rows; the one-second pause only gives the background threads
        time to make progress on their own.
        """
        print(f"\n[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Stopping...")

        # Stop key components.
        client = getattr(self, 'client', None)
        if client:
            try:
                client.stop()
            except Exception as e:
                # Best effort: a failing stop must not block the exit.
                print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Failed to stop client during exit: {e}")

        # Grace period for in-flight work before the hard exit.
        time.sleep(1.0)

        # os._exit() does not flush stdio buffers, so flush explicitly or the
        # final log lines are lost when stdout is a file or a pipe.
        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Program has exited.", flush=True)
        os._exit(0)