import json, time, os, asyncio
import numpy as np
from okx_framework import OKXBaseConfig

class contractInfoConfig(OKXBaseConfig):
    """Collector configuration for OKX instrument ("contract info") data.

    Subscribes to the OKX ``instruments`` channel for SWAP and SPOT, converts
    every instrument record into a row whose value order matches ``colNames``
    in :meth:`get_create_table_script`, and pushes the row onto
    ``self.realtime_q``. Rows that cannot be queued are appended to
    ``BUFFER_FILE`` as JSON lines.

    ``self.realtime_q`` and ``self.file_lock`` are not defined in this class;
    they are presumably provided by ``OKXBaseConfig`` -- confirm against the
    framework.
    """

    # Consumed by the base framework; its unit is not visible in this file
    # (presumably seconds -- TODO confirm).
    TIMEOUT = 3600

    # 1. Table name and fail-over buffer file
    tableName = "Cryptocurrency_contractInfoST"
    BUFFER_FILE = "./OKX_contractInfo_fail_buffer.jsonl"

    # 2. Table creation script
    def get_create_table_script(self) -> str:
        """Return the DolphinDB script that prepares storage for this feed.

        The script creates the RANGE-partitioned database and the partitioned
        table if they do not exist, shares a persisted keyed stream table
        (keyed on symbolSource/symbol/eventTime) and subscribes the
        partitioned table to that stream table.
        """
        return '''
dbName = "dfs://CryptocurrencyDay"
tbName = "contractInfo"
streamtbName = "Cryptocurrency_contractInfoST"
colNames = 
`eventTime`collectionTime`symbolSource`symbol`contractType`contractDirection`deliveryDatetime`onboardDatetime`contractStatus`notionalBracket`floorNotional`capNotional`maintenanceRatio`auxiliaryNumber`minLeverage`maxLeverage`settleCurrency`tickSize`lotSize`minSize`contractValue`contractMultiplier`contractValueCurrency
colTypes = [TIMESTAMP, TIMESTAMP, SYMBOL, SYMBOL,STRING,STRING,TIMESTAMP, TIMESTAMP, STRING, INT, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, STRING, DOUBLE, DOUBLE, DOUBLE, DOUBLE,  DOUBLE, STRING]  

if(!existsDatabase(dbName)){
    db = database(dbName,RANGE,2010.01M+(0..20)*60)
}else{ db=database(dbName)}
if(!existsTable(dbName,tbName)){
    createPartitionedTable(db,table(1:0,colNames,colTypes),tbName,`eventTime) 
}
enableTableShareAndPersistence(table=keyedStreamTable(`symbolSource`symbol`eventTime, 10000:0, colNames, colTypes), tableName=streamtbName, cacheSize=100000, retentionMinutes=2880)
go
contractInfoTb = loadTable(dbName, tbName)
subscribeTable(tableName=streamtbName, actionName="insertDB", offset=-2, handler=contractInfoTb, msgAsTable=true, batchSize=10000, throttle=1, persistOffset=true)
        '''

    # 3. Subscription parameters
    def get_subscription_args(self, inst_ids: list) -> list:
        """Return the channel subscriptions for this feed.

        ``inst_ids`` is accepted for interface compatibility but is not used:
        the ``instruments`` channel is subscribed per instrument type.
        """
        return [
            {"channel": "instruments", "instType": "SWAP"},
            {"channel": "instruments", "instType": "SPOT"},
        ]

    def _bj(self, ts):
        """Parse an OKX millisecond timestamp (string or int) into an int.

        Despite the method name, no time-zone offset is applied: the value is
        returned as the same millisecond count that was received.

        Returns None when ``ts`` is None, an empty string, or zero (including
        the string ``"0"``).
        Raises ValueError if ``ts`` is a non-numeric string.
        """
        if ts is None or ts == "":
            return None
        # ``or None`` makes "0" behave like the integer 0 (treated as missing).
        return int(ts) or None

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a float, or None when it is None or ""."""
        if value is None or value == "":
            return None
        return float(value)

    def _norm_symbol_okx(self, inst_id: str):
        """
        OKX -> normalize to BTCUSDT
        - SPOT: 'BTC-USDT' -> 'BTCUSDT'
        - SWAP: 'BTC-USDT-SWAP' -> 'BTCUSDT'

        An id without any hyphen is returned unchanged.
        """
        parts = inst_id.split("-")
        if len(parts) < 2:
            return inst_id
        # Drop the trailing 'SWAP' segment (if present), then remove hyphens.
        if parts[-1] == "SWAP":
            parts = parts[:-1]
        return "".join(parts)

    def _symbol_source(self, inst_type: str):
        """Map an OKX instType to the ``symbolSource`` column value."""
        if inst_type == "SWAP":
            return "OKX-Futures"
        if inst_type == "SPOT":
            return "OKX-Spot"
        # Any other (or missing) type is passed through with an OKX- prefix.
        return f"OKX-{inst_type or 'Unknown'}"

    def _contract_type_from_instType(self, inst_type: str):
        """Normalize contractType: SWAP -> PERPETUAL, SPOT -> SPOT, else None."""
        if inst_type == "SWAP":
            return "PERPETUAL"
        if inst_type == "SPOT":
            return "SPOT"
        return None

    def _build_row(self, j: dict, default_inst_type, now_ms: int) -> list:
        """Convert one instrument record into a row in ``colNames`` order.

        ``default_inst_type`` is used when the record itself carries no
        ``instType``. Raises KeyError if the record has no ``instId``.
        """
        inst_type = j.get("instType") or default_inst_type
        lever = j.get("lever")

        return [
            now_ms,                                        # eventTime: no server-side event timestamp
            now_ms,                                        # collectionTime
            self._symbol_source(inst_type),                # symbolSource
            self._norm_symbol_okx(j["instId"]),            # symbol
            self._contract_type_from_instType(inst_type),  # contractType
            j.get("ctType") or "",                         # contractDirection: linear / inverse
            self._bj(j.get("expTime")),                    # deliveryDatetime: usually empty for perpetuals
            self._bj(j.get("listTime") or j.get("contTdSwTime")),  # onboardDatetime
            (j.get("state") or "").upper(),                # contractStatus: live / suspend / expired / preopen / test
            # Bracket-related columns are never filled from this feed.
            None,                                          # notionalBracket
            None,                                          # floorNotional
            None,                                          # capNotional
            None,                                          # maintenanceRatio
            None,                                          # auxiliaryNumber
            None,                                          # minLeverage
            self._to_float(lever),                         # maxLeverage
            j.get("settleCcy") or None,                    # settleCurrency
            self._to_float(j.get("tickSz")),               # tickSize
            self._to_float(j.get("lotSz")),                # lotSize
            self._to_float(j.get("minSz")),                # minSize
            self._to_float(j.get("ctVal")),                # contractValue
            self._to_float(j.get("ctMult")),               # contractMultiplier
            j.get("ctValCcy") or None,                     # contractValueCurrency
        ]

    # 4. Message handling
    def handle_message(self, message: str):
        """Parse one WebSocket message and enqueue a row per instrument.

        Messages without a non-empty ``data`` list are ignored. Each row is
        put on ``self.realtime_q`` without blocking; if that fails, the row is
        appended to ``BUFFER_FILE`` as one JSON line under ``self.file_lock``.

        Raises:
            json.JSONDecodeError: if ``message`` is not valid JSON.
            KeyError: if an instrument record has no ``instId``.
        """
        d = json.loads(message)
        if "data" not in d or not d["data"]:
            return

        # Epoch milliseconds from the local clock, shared by every row of
        # this message.
        now_ms = int(time.time() * 1000)
        # ``or {}`` also covers a message whose "arg" is present but null.
        arg_inst_type = (d.get("arg") or {}).get("instType")  # SWAP / SPOT

        for j in d["data"]:
            row = self._build_row(j, arg_inst_type, now_ms)

            try:
                self.realtime_q.put(row, block=False)
            except Exception:
                # Best-effort fallback: keep the row on disk instead of
                # dropping it. ``Exception`` (not a bare except) so that
                # KeyboardInterrupt / SystemExit are not swallowed here.
                with self.file_lock, open(self.BUFFER_FILE, "a", encoding="utf-8") as f:
                    f.write(json.dumps(row) + "\n")

if __name__ == "__main__":
    # Build the contract-info collector configuration.
    config = contractInfoConfig()

    # Start the data pipeline. No instrument ids are passed: the subscription
    # built by this config is per instrument type, not per instrument.
    io_thread, okx_thread = config.start(inst_ids=[])

    # Keep the main thread alive until Ctrl-C.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        # flush=True on both messages: os._exit() below terminates without
        # flushing stdio buffers, so buffered output would otherwise be lost
        # when stdout is redirected to a file or pipe.
        print(f"\n[{time.strftime('%H:%M:%S')}] Stopping...", flush=True)
        if okx_thread.ws:
            # Schedule the stop coroutine on the loop owned by okx_thread.
            okx_thread.loop.call_soon_threadsafe(
                asyncio.create_task, okx_thread.ws.stop()
            )
        # Give the scheduled stop a moment to run before the hard exit.
        time.sleep(1.0)
        print(f"[{time.strftime('%H:%M:%S')}] Program exited", flush=True)
        # Hard exit: terminates the process without waiting for other threads.
        os._exit(0)
