import torch
import torch.nn as nn
import torch.optim as optim
import dolphindb as ddb
from dolphindb_tools.dataloader import DDBDataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import Subset
import pandas as pd
import warnings
import time

# Define the LSTM model
class LSTMModel(nn.Module):
    """Three stacked single-layer LSTMs followed by a one-unit linear head.

    The hidden sizes are ``units`` -> 128 -> 32, and only the last time step
    of the final LSTM feeds the linear layer.

    Dropout is applied through explicit ``nn.Dropout`` modules. The
    ``dropout`` argument of ``nn.LSTM`` only acts between the stacked layers
    of a single module, so with the default ``num_layers=1`` it is ignored
    (PyTorch emits a ``UserWarning``) and no regularisation would happen.
    ``nn.Dropout`` holds no parameters, so the ``state_dict`` keys
    (``lstm1``, ``lstm2``, ``lstm3``, ``fc``) are unaffected.

    Args:
        inputSize: Number of features per time step.
        units: Hidden size of the first LSTM.
    """

    def __init__(self, inputSize, units):
        super(LSTMModel, self).__init__()
        self.lstm1 = nn.LSTM(input_size=inputSize, hidden_size=units, batch_first=True)
        self.drop1 = nn.Dropout(p=0.4)
        self.lstm2 = nn.LSTM(input_size=units, hidden_size=128, batch_first=True)
        self.drop2 = nn.Dropout(p=0.3)
        self.lstm3 = nn.LSTM(input_size=128, hidden_size=32, batch_first=True)
        self.drop3 = nn.Dropout(p=0.1)
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        """Map a ``(batch, seq, inputSize)`` tensor to a ``(batch, 1)`` tensor.

        Dropout is only active in training mode; in eval mode the result is
        the plain LSTM -> LSTM -> LSTM -> Linear computation.
        """
        out, _ = self.lstm1(x)
        out = self.drop1(out)
        out, _ = self.lstm2(out)
        out = self.drop2(out)
        out, _ = self.lstm3(out)
        # Keep only the final time step before regularising and projecting.
        out = self.drop3(out[:, -1, :])
        out = self.fc(out)
        return out

def _prepareLoaders(conn, device):
    """Stage the factor data on the server and build the train/test loaders.

    Runs a DolphinDB script that defines ``startDay``, ``endDay``,
    ``splitDay`` and the pivoted ``Data`` table in the session, then creates
    one ``DDBDataLoader`` per date range (up to and including ``splitDay``
    for training, after it for testing).

    Args:
        conn: Open DolphinDB session; it is also handed to both loaders.
        device: Torch device passed to the loaders.

    Returns:
        A ``(trainLoader, testLoader)`` tuple.
    """
    conn.run("""startDay = 2021.01.01
                endDay = 2021.12.31
                splitDay = (startDay..endDay)[((endDay-startDay)*0.8).floor()]
                Data = select FactorValues from loadTable("dfs://tenMinutesFactorDB", "tenMinutesFactorTB") where date(DateTime) >= objByName('startDay') and date(DateTime) <= objByName('endDay') and SecurityID=`600030 pivot by DateTime, SecurityID, FactorNames 
                Data = Data[each(isValid, Data.values()).rowAnd()]
                """)

    # Dataloader parameters shared by both splits
    loaderArgs = dict(
        ddbSession=conn,
        targetCol=["LogReturn0_realizedVolatility"],
        excludeCol=["SecurityID", "DateTime", "LogReturn0_realizedVolatility"],
        batchSize=256,
        device=device,
        windowSize=[120, 1],
        windowStride=[1, 1],
        offset=120,
    )
    trainSql = """select * from objByName('Data') where date(DateTime) >= objByName('startDay') and date(DateTime) <= objByName('splitDay')"""
    testSql = """select * from objByName('Data') where date(DateTime) > objByName('splitDay') and date(DateTime) <= objByName('endDay')"""

    trainLoader = DDBDataLoader(sql=trainSql, **loaderArgs)
    testLoader = DDBDataLoader(sql=testSql, **loaderArgs)
    return trainLoader, testLoader


def _runEpoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: Network to train or evaluate.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Torch device the batches are moved to.
        optimizer: When given, the pass trains the model (backward + step);
            when ``None``, the pass is evaluation-only with gradients off.

    Returns:
        The sum of the batch losses divided by the number of batches.

    Raises:
        RuntimeError: If ``loader`` yields no batches, since a mean loss is
            undefined and would otherwise surface as a bare ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)
    totalLoss = 0.0
    batchCount = 0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs = inputs.float().to(device)
            targets = targets.float().to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            # NOTE(review): the loss assumes outputs and targets have matching
            # shapes -- confirm the target shape produced by the loader.
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()
            totalLoss += loss.item()
            batchCount += 1
    if batchCount == 0:
        phase = "training" if training else "test"
        raise RuntimeError(f"The {phase} data loader produced no batches")
    return totalLoss / batchCount


# Define the main function
def main():
    """Train the LSTM on DolphinDB factor data and save the best model.

    Connects to DolphinDB, builds the train/test loaders, trains for 100
    epochs with Adam and ReduceLROnPlateau, and saves a TorchScript copy of
    the model whenever the test loss improves. The DolphinDB session is
    closed when training ends or fails.

    Raises:
        RuntimeError: If either loader yields no batches.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): host and credentials are hard-coded; consider loading them
    # from configuration before sharing or deploying this script.
    conn = ddb.session("192.168.100.201", 8848, "admin", "123456")
    try:
        # Data preprocessing and split train and test sets
        trainLoader, testLoader = _prepareLoaders(conn, device)
        print("Using DDBDataLoader, data is ready")

        # Initialize model
        # NOTE(review): inputSize=675 presumably equals the number of feature
        # columns left after excludeCol -- confirm against the table schema.
        model = LSTMModel(inputSize=675, units=256)
        model.to(device)

        # Set loss function, optimizer and learning rate scheduler
        criterion = nn.SmoothL1Loss()
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

        # Model training
        startTime = time.time()
        bestLoss = float('inf')
        for epoch in range(100):
            epochStartTime = time.time()

            trainLoss = _runEpoch(model, trainLoader, criterion, device, optimizer)
            testLoss = _runEpoch(model, testLoader, criterion, device)
            print(f'Epoch {epoch+1}, Train Loss: {trainLoss}, Test Loss: {testLoss}')

            # Save the model if the test loss is the best so far
            if testLoss < bestLoss:
                bestLoss = testLoss
                torchScriptModel = torch.jit.script(model)
                torchScriptModel.save("/home/lnfu/ytxie/LSTMmodel.pt")
                print(f'New best model saved with test loss: {bestLoss}')

            # Adjust learning rate
            scheduler.step(testLoss)

            epochEndTime = time.time()
            print(f"Epoch {epoch+1} training time: {epochEndTime - epochStartTime} seconds")

        endTime = time.time()
        print(f"Total training time: {endTime - startTime} seconds")
    finally:
        # Both loaders were given this session, so it is released only after
        # training has finished or failed.
        conn.close()

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()