/**
File name: StreamComputing.dos
Application: Stream compute 10-minute frequency factors and predict stock volatility in real time.
Author: Yanyan Xu
Company: DolphinDB Inc.
DolphinDB server version: 3.00.2.3 2024.12.26 LINUX_ABI x86_64
Storage engine: TSDB
Last modification time: 2025.03.09
*/

use ops 

// Load the plugin
// The LibTorch:: calls further down (load, setDevice, predict) come from this plugin
loadPlugin("LibTorch") 

// Set replay date
// mdDate selects the snapshots that are replayed and anchors the historical factor range (mdDate-20 .. mdDate-1)
mdDate = 2021.01.29

// Stream computing environment cleanup
// Removes the subscriptions, the time-series engine and the shared stream tables
// that this script creates. Every step is attempted on its own: a failure is
// printed and the remaining steps still run.
def cleanEnvironment(){
	streamTableNames = `SnapshotStream`aggrFeatures10min`result10min
	// Subscriptions are removed first, before the engine and tables they refer to
	for(tbl in streamTableNames){
		try{ unsubscribeAll(tbName=tbl) } catch(ex){ print(ex) }
	}
	try{ dropStreamEngine("aggrFeatures10min") } catch(ex){ print(ex) }
	for(tbl in streamTableNames){
		try{ dropStreamTable(tbl) } catch(ex){ print(ex) }
	}
}
cleanEnvironment()

// Create original snapshot stream table
// The stream table mirrors the column names and types of the DFS snapshot table.
// The column definitions are fetched once and reused for both vectors instead of
// loading the table and its schema twice.
snapshotColDefs = loadTable("dfs://l2StockSHDB", "snapshot").schema().colDefs
name = snapshotColDefs.name
type = snapshotColDefs.typeString
share(streamTable(100000:0, name, type), `SnapshotStream)
go

/** 
	4.1 Real-time feature engineering
 */ 

// Define functions
// Log return of s: log of each element minus log of the element before it.
// The first element has no predecessor, so prev() leaves its result empty.
def logReturn(s){
	logged = log(s)
	return logged - prev(logged)
}

// Square root of the sum of squares of s.
def realizedVolatility(s){
	return s.sum2().sqrt()
}

// Build the select-list metacode for the aggregated features.
// aggDict maps a column name to the names of the functions to apply to that column.
// Returns two parallel results: the metacode (one sqlCol per column/function pair)
// and the matching output column names, each of the form <column>_<function>.
def createAggMetaCode(aggDict){
	selectList = []
	outNames = []
	for(colName in aggDict.keys()){
		funcNames = aggDict[colName]
		for(funcName in funcNames){
			// The same name labels the metacode column and the returned name list
			outName = colName + `_ + funcName$STRING
			selectList.append!(sqlCol(colName, funcByName(funcName), outName))
			outNames.append!(outName)
		}
	}
	return selectList, outNames$STRING
}

// Feature Engineering
// features maps each derived column to the functions aggregated over it.
basicStats = [`sum, `mean, `std]
returnStats = [`sum, `realizedVolatility, `mean, `std]
features = {
	"DateTime":[`count]
}
// Columns carrying a numeric suffix 0..9
for(level in 0..9){
	features["Wap"+level] = basicStats
	features["LogReturn"+level] = returnStats
	features["LogReturnOffer"+level] = returnStats
	features["LogReturnBid"+level] = returnStats
}
// Columns without a suffix, added in the same order as before
for(colName in `WapBalance`PriceSpread`BidSpread`OfferSpread`TotalVolume`VolumeImbalance){
	features[colName] = basicStats
}
aggMetaCode, metaCodeColName = createAggMetaCode(features)

// Define aggregate function
// featureEngineering derives order-book columns from its inputs, evaluates aggMetaCode
// over them four times and returns the four results concatenated into one matrix:
//   1. over all rows,
//   2-4. over the rows whose time is at least 150 s, 300 s and 450 s past the start
//        of their 10-minute bar.
// Parameters:
//   DateTime                  - time column of the rows being aggregated
//   BidPrice, BidOrderQty,
//   OfferPrice, OfferOrderQty - expected to be matrices with ten columns each (the
//                               rename below supplies ten names, suffixes 0..9, per input)
//   aggMetaCode               - select-list metacode evaluated against the derived table
defg featureEngineering(DateTime, BidPrice, BidOrderQty, OfferPrice, OfferOrderQty, aggMetaCode){
	// Price weighted by the opposite side's quantity, computed per column
	wap = (BidPrice * OfferOrderQty + BidOrderQty * OfferPrice) \ (BidOrderQty + OfferOrderQty)
	// Index 0 and 1 select the first two columns of each matrix
	wapBalance = abs(wap[0] - wap[1])
	// Spread between the first offer and bid columns, relative to their midpoint
	priceSpread = (OfferPrice[0] - BidPrice[0]) \ ((OfferPrice[0] + BidPrice[0]) \ 2)
	BidSpread = BidPrice[0] - BidPrice[1]
	OfferSpread = OfferPrice[0] - OfferPrice[1]
	// Row-wise totals across all columns of both quantity matrices
	totalVolume = OfferOrderQty.rowSum() + BidOrderQty.rowSum()
	volumeImbalance = abs(OfferOrderQty.rowSum() - BidOrderQty.rowSum())
	logReturnWap = logReturn(wap)
	logReturnOffer = logReturn(OfferPrice)
	logReturnBid = logReturn(BidPrice)
	// Columns are named positionally by rename! below, so the order of the arguments
	// here must match the order in which colName is assembled
	subTable = table(DateTime as `DateTime, BidPrice, BidOrderQty, OfferPrice, OfferOrderQty, wap, wapBalance, priceSpread, BidSpread, OfferSpread, totalVolume, volumeImbalance, logReturnWap, logReturnOffer, logReturnBid)
	colNum = 0..9$STRING
	// NOTE(review): the log-return columns are named logReturn… (lower-case l) here while the
	// feature dictionary refers to LogReturn…; this presumably relies on case-insensitive
	// column-name resolution — confirm
	colName = `DateTime <- (`BidPrice + colNum) <- (`BidOrderQty + colNum) <- (`OfferPrice + colNum) <- (`OfferOrderQty + colNum) <- (`Wap + colNum) <- `WapBalance`PriceSpread`BidSpread`OfferSpread`TotalVolume`VolumeImbalance <- (`logReturn + colNum) <- (`logReturnOffer + colNum) <- (`logReturnBid + colNum)
	subTable.rename!(colName)
	// Start of the 10-minute bar each row belongs to; the filters below are offsets from it
	subTable['BarDateTime'] = bar(subTable['DateTime'], 10m)
	result = sql(select = aggMetaCode, from = subTable).eval().matrix()
	// The offsets are written in milliseconds: 150*1000, 300*1000, 450*1000
	result150 = sql(select = aggMetaCode, from = subTable, where = <time(DateTime) >= (time(BarDateTime) + 150*1000) >).eval().matrix()
	result300 = sql(select = aggMetaCode, from = subTable, where = <time(DateTime) >= (time(BarDateTime) + 300*1000) >).eval().matrix()
	result450 = sql(select = aggMetaCode, from = subTable, where = <time(DateTime) >= (time(BarDateTime) + 450*1000) >).eval().matrix()
	return concatMatrix([result, result150, result300, result450])
}

// Aggregate function that returns the value of now(true); the engine metrics use it
// to fill the HandleTime column.
defg getHandleTime(){
	handleTime = now(true)
	return handleTime
}
// Create a stream table for real-time feature engineering
// Every feature appears four times: once unsuffixed and once each with the
// _150, _300 and _450 suffixes. The number of DOUBLE columns is taken from the
// name vector itself, so the schema follows the feature dictionary rather than
// a hard-coded count.
featureColNames = metaCodeColName <- (metaCodeColName+"_150") <- (metaCodeColName+"_300") <- (metaCodeColName+"_450")
aggrColNames = `DateTime`SecurityID`ReceiveTime <- featureColNames <- `HandleTime
aggrColTypes = `TIMESTAMP`SYMBOL`NANOTIMESTAMP <- take(`DOUBLE, size(featureColNames)) <- `NANOTIMESTAMP
share(streamTable(100000:0, aggrColNames, aggrColTypes), `aggrFeatures10min)
go
// Create time-series engine
// metrics is a single featureEngineering call whose result is labelled with the feature
// names and their _150/_300/_450 variants. Elements 0..9 of each price/quantity column
// are passed as the ten columns of one matrix per input.
// NOTE(review): BidPrice/BidOrderQty/OfferPrice/OfferOrderQty are presumably array-vector
// columns of SnapshotStream — confirm against the snapshot schema
metrics=sqlColAlias(<featureEngineering(DateTime,
		matrix(BidPrice[0],BidPrice[1],BidPrice[2],BidPrice[3],BidPrice[4],BidPrice[5],BidPrice[6],BidPrice[7],BidPrice[8],BidPrice[9]),
		matrix(BidOrderQty[0],BidOrderQty[1],BidOrderQty[2],BidOrderQty[3],BidOrderQty[4],BidOrderQty[5],BidOrderQty[6],BidOrderQty[7],BidOrderQty[8],BidOrderQty[9]),
		matrix(OfferPrice[0],OfferPrice[1],OfferPrice[2],OfferPrice[3],OfferPrice[4],OfferPrice[5],OfferPrice[6],OfferPrice[7],OfferPrice[8],OfferPrice[9]),
		matrix(OfferOrderQty[0],OfferOrderQty[1],OfferOrderQty[2],OfferOrderQty[3],OfferOrderQty[4],OfferOrderQty[5],OfferOrderQty[6],OfferOrderQty[7],OfferOrderQty[8],OfferOrderQty[9]), aggMetaCode)>, metaCodeColName <- (metaCodeColName+"_150") <- (metaCodeColName+"_300") <- (metaCodeColName+"_450"))	

// windowSize equals step, so consecutive windows do not overlap. Windows are driven by
// the DateTime column (useSystemTime=false), computed per SecurityID, and labelled with
// the window start time. Each output row carries ReceiveTime, the features and HandleTime.
// NOTE(review): 600000 is 10 minutes only if DateTime has millisecond precision — confirm
createTimeSeriesEngine(name="aggrFeatures10min", windowSize=600000, step=600000, metrics=[<now(true) as ReceiveTime>, metrics, <getHandleTime() as HandleTime>], useSystemTime=false ,dummyTable=SnapshotStream, outputTable=aggrFeatures10min, timeColumn=`DateTime, useWindowStartTime=true, keyColumn=`SecurityID)

// Subscribe to the snapshot stream table
// Incoming batches are handed to the aggrFeatures10min time-series engine
featureEngine = getStreamEngine("aggrFeatures10min")
subscribeTable(tableName="SnapshotStream", actionName="aggrFeatures10min",
	offset=-1, handler=featureEngine, msgAsTable=true,
	batchSize=2000, throttle=0.01, hash=0, reconnect=true)
/** 
	4.2 Real-time prediction
 */

 // Load model
modelPath = "/home/lnfu/ytxie/LSTMmodel.pt"
model = LibTorch::load(modelPath)
// Inference is placed on the CUDA device
LibTorch::setDevice(model, "CUDA")

// Create real-time prediction result table
// One row per prediction: the predicted value, the key and bar time it belongs to,
// and the three timestamps carried through or added along the pipeline
resultColNames = `Predicted`SecurityID`DateTime`ReceiveTime`HandleTime`PredictedTime
resultColTypes = `FLOAT`SYMBOL`TIMESTAMP`NANOTIMESTAMP`NANOTIMESTAMP`NANOTIMESTAMP
share(streamTable(100000:0, resultColNames, resultColTypes), `result10min)
go
// Define real-time processing method: data preprocessing and model inference
//   model  - model handle passed to LibTorch::predict
//   window - number of most recent historyData rows used as model input
//   msg    - batch of rows delivered by the subscription
// For each row of msg: the row is inserted into the shared historyData table, the
// last window rows are taken, the non-feature columns are dropped, rows containing
// invalid values are removed, and the model is run on what remains. The predictions
// are appended to result10min together with the timing columns of msg.
def predictRV(model, window, msg){
	// Work on a copy so msg keeps ReceiveTime and HandleTime for the result rows
	featureRows = select * from msg 
	featureRows.dropColumns!(`ReceiveTime`HandleTime)
	// Match the column order of historyData before inserting into it
	featureRows.reorderColumns!(objByName(`historyData).columnNames())
	predictions = []
	for(featureRow in featureRows){
		objByName(`historyData).tableInsert(featureRow)
		windowData = tail(objByName(`historyData), window)
		windowData.dropColumns!(`SecurityID`DateTime`LogReturn0_realizedVolatility)
		// Keep only the rows in which every remaining column is valid
		windowData = windowData[each(isValid, windowData.values()).rowAnd()]
		modelInput = tensor([matrix(windowData)$FLOAT])
		modelOutput = LibTorch::predict(model, modelInput)
		predictions.append!(modelOutput[0][0])
	}
	resultRows = select predictions as predicted, SecurityID, DateTime, ReceiveTime, HandleTime, now(true) as PredictedTime from msg 
	objByName("result10min").append!(resultRows)
}
// Prepare historical data for sliding window
// Factor values for security 600030 from mdDate-20 through mdDate-1 are pivoted so that
// each FactorNames value becomes a column, keyed by DateTime and SecurityID
data = select FactorValues from loadTable("dfs://tenMinutesFactorDB", "tenMinutesFactorTB") where date(DateTime) between (mdDate-20):(mdDate-1) and SecurityID=`600030 pivot by DateTime, SecurityID, FactorNames 
// Keep only the rows in which every column is valid
data = data[each(isValid, data.values()).rowAnd()]
// Shared under the name historyData so predictRV can reach it through objByName
share(data, `historyData)
// Set window size for predicting the next volatility using 120 time points
window = 120
// Warm up the model to avoid high latency during the first inference call
// The warm-up input is built the same way as in predictRV: last window rows, with the
// key, time and LogReturn0_realizedVolatility columns removed
// NOTE(review): LogReturn0_realizedVolatility is presumably the prediction target — confirm
warmupData = tail(objByName(`historyData), window)
warmupData.dropColumns!(`SecurityID`DateTime`LogReturn0_realizedVolatility)
input = tensor([matrix(warmupData)$FLOAT])
// timer(10) runs the prediction ten times
timer(10) out = LibTorch::predict(model, input)

// Subscribe to the aggregated feature stream table (aggrFeatures10min)
// predictRV is partially applied with model and window; each batch arrives as its msg argument
subscribeTable(tableName="aggrFeatures10min", actionName="predictRV", offset=-1, handler=predictRV{model, window}, msgAsTable=true, batchSize=100, throttle=0.001, hash=1, reconnect=true)

// Replay historical data
// Snapshots of security 600030 on mdDate are loaded and replayed into SnapshotStream
// as a background job, using DateTime as both the date and the time column
testSnapshot = select * from loadTable("dfs://l2StockSHDB", "snapshot") where SecurityID=`600030 and date(DateTime)=mdDate
// NOTE(review): the positional arguments after the two column names are presumably
// replayRate=100, absoluteRate=false, parallelLevel=1, sortColumns omitted and
// preciseRate=true — confirm against the replay signature of the server version in use
submitJob("replaySnapshot", "replay 1 day snapshot", replay{testSnapshot, SnapshotStream, `DateTime, `DateTime, 100, false, 1, , true})