import time
import tqdm
import numpy as np
import pandas as pd
from numba import jit, prange
from typing import Dict
import matplotlib.pyplot as plt
# Plot text uses the KaiTi font (a Chinese font), and axis tick labels use an
# ASCII hyphen instead of the Unicode minus sign.
plt.rcParams.update({
    'font.sans-serif': ['KaiTi'],
    'axes.unicode_minus': False,
})
# NOTE(review): this seeds NumPy's global generator only. Numba-compiled code
# (generate_path) keeps its own RNG state, so this call does not by itself make
# the simulated price paths reproducible.
np.random.seed(42)

@jit(nopython=True, parallel=True)
def generate_path(s: float, r: float, t: float, sigma: float, n: int, days_in_year: int,
                  up_limit: float = None, down_limit: float = None) -> np.ndarray:
    """
    Simulate geometric Brownian motion price paths on a daily grid.

    Row i of the result is trading day i (time i / days_in_year years); row 0 is
    the initial price ``s``. Each column is one independent path. Prices follow
    S_i = s * exp((r - sigma^2 / 2) * i * dt + sigma * sqrt(dt) * sum_{k<=i} Z_k).

    Parameters
    ----------
    s : initial underlying price.
    r : annualized risk-free rate (drift).
    t : maturity in years; the grid has int(t * days_in_year) steps.
    sigma : annualized volatility.
    n : number of simulated paths (columns).
    days_in_year : trading days per year (dt = 1 / days_in_year).
    up_limit, down_limit : optional daily limits as fractions (e.g. 0.1 = 10%).
        When given, each day's price is clamped to within
        [prev * (1 - down_limit), prev * (1 + up_limit)] of the previous
        *clamped* price. The unconstrained path keeps evolving, so a limited
        price catches up on later days (subject to the same limits).

    Returns
    -------
    Array of shape (int(t * days_in_year) + 1, n).
    """
    dt = 1 / days_in_year
    tdays = int(t * days_in_year)

    # Elapsed time (in years) of every grid row, broadcast across all paths.
    time_mat = np.zeros((tdays + 1, n), dtype=np.float32)
    for i in prange(tdays + 1):
        time_mat[i, :] = i * dt

    # Standard normal shocks; row 0 is never used because the path starts at s.
    norm_mat = np.random.standard_normal((tdays + 1, n)).astype(np.float32)

    # Cumulative sum of shocks along time (row 0 stays 0 so price_mat[0] == s).
    cum_brownian = np.zeros_like(norm_mat)
    for i in range(1, tdays + 1):
        cum_brownian[i, :] = cum_brownian[i - 1, :] + norm_mat[i, :]

    # Exact GBM solution evaluated on the grid.
    drift = (r - 0.5 * sigma ** 2) * time_mat
    diffusion = sigma * np.sqrt(dt) * cum_brownian
    price_mat = s * np.exp(drift + diffusion)

    # Apply daily price limit constraints (up/down limits).
    # Day i depends on the already-clamped day i - 1, so time must be walked
    # sequentially; only the independent paths (columns) are run in parallel.
    # Parallelizing over time would race on price_mat[i - 1, j].
    if up_limit is not None or down_limit is not None:
        for i in range(1, tdays + 1):
            for j in prange(n):
                prev = price_mat[i - 1, j]
                curr = price_mat[i, j]

                if up_limit is not None:
                    upper = prev * (1 + up_limit)
                    if curr > upper:
                        price_mat[i, j] = upper
                if down_limit is not None:
                    lower = prev * (1 - down_limit)
                    if curr < lower:
                        price_mat[i, j] = lower

    return price_mat

def snowball_pricing(s: float, r: float, t: float, days_in_year: int, sigma: float, n: int,
                     k_in: float, k_out: float, lock_period: int, coupon: float,
                     up_limit: float = None, down_limit: float = None) -> Dict:
    """
    Monte Carlo statistics and discounted mean payoff of a snowball note.

    Paths come from ``generate_path``; row k of a path is trading day k.

    Scenarios per path:
      * knock-out: on a monthly observation day after ``lock_period`` the price
        is >= ``k_out``; the holder receives coupon * elapsed years, discounted
        from the knock-out day.
      * neither knock-out nor knock-in (price never < ``k_in``): the holder
        receives coupon * t, discounted from maturity.
      * knock-in without knock-out: the holder bears any loss of the final price
        relative to ``s`` (no upside), discounted from maturity.

    Payoffs are expressed as fractions of notional (the initial price ``s``).

    Parameters
    ----------
    s, r, t, days_in_year, sigma, n, up_limit, down_limit :
        forwarded to ``generate_path``.
    k_in, k_out : knock-in / knock-out barriers, as absolute price levels.
    lock_period : lock-up period in trading days; observation days <= this
        value are skipped.
    coupon : annualized coupon rate (knock-out and dividend coupon assumed equal).

    Returns
    -------
    dict with scenario counts ``knock_out_times``, ``knock_in_times``,
    ``existence_times``, the count of losing paths ``lose_times``, and the mean
    discounted ``payoff``.

    Raises
    ------
    ValueError : if days_in_year < 12 (no whole-day monthly observation step).
    """
    # Cast to int to match generate_path's grid, so fractional t still yields integer indices.
    tdays = int(t * days_in_year)
    month_day = int(days_in_year / 12)
    if month_day < 1:
        raise ValueError("days_in_year must be at least 12 to form monthly observations")
    price_path = generate_path(s, r, t, sigma, n, days_in_year, up_limit, down_limit)
    n_paths = price_path.shape[1]

    # Monthly knock-out observation days. Row k of price_path is trading day k
    # (row 0 is inception), so the last observation lands on maturity (row tdays).
    observation_out_idx = np.arange(month_day, tdays + 1, month_day)
    observation_out_idx = observation_out_idx[observation_out_idx > lock_period]

    # First knock-out day per path (NaN when the path never knocks out).
    knock_out_day = np.full(n_paths, np.nan)
    if observation_out_idx.size > 0:
        knock_out_matrix = price_path[observation_out_idx, :] >= k_out
        hit_any = knock_out_matrix.any(axis=0)
        first_hit = knock_out_matrix.argmax(axis=0)  # index of first True per column
        knock_out_day[hit_any] = observation_out_idx[first_hit[hit_any]]
    is_knocked_out = ~np.isnan(knock_out_day)

    # Knock-in is monitored daily: count days strictly below the barrier.
    knock_in_day = np.sum(price_path < k_in, axis=0)
    is_knocked_in = knock_in_day > 0

    final_price = price_path[-1, :]

    # Count occurrences of the three scenarios.
    knock_out_times = np.sum(is_knocked_out)
    existence_times = np.sum(~is_knocked_out & ~is_knocked_in)
    knock_in_times = np.sum(~is_knocked_out & is_knocked_in)

    # Count number of losses for the snowball buyer (knocked in, never knocked out, ends below s).
    lose_times = np.sum(~is_knocked_out & is_knocked_in & (final_price < s))

    # Discounted payoffs, as fractions of notional.
    knock_out_years = knock_out_day / days_in_year
    # Knock-out occurs: annualized coupon accrued up to the knock-out day.
    payoff1 = coupon * knock_out_years * np.exp(-r * knock_out_years)
    # No knock-in and no knock-out: annualized coupon accrued over the full tenor.
    payoff2 = coupon * t * np.exp(-r * t)
    # Knock-in occurs but no knock-out: relative loss at maturity (gains are not paid).
    payoff3 = (final_price / s - 1) * np.exp(-r * t)
    payoff = np.where(is_knocked_out, payoff1,
                      np.where(~is_knocked_in, payoff2, np.minimum(payoff3, 0)))

    return {
        "knock_out_times": knock_out_times,
        "knock_in_times": knock_in_times,
        "existence_times": existence_times,
        "lose_times": lose_times,
        "payoff": np.mean(payoff)
    }

if __name__ == "__main__":
    from pathlib import Path

    _n = 300000             # Number of simulation paths
    _s = 1.0                # Initial underlying price
    _k_in = 0.85            # Knock-in barrier
    _k_out = 1.03           # Knock-out barrier
    _t = 1                  # Maturity (years)
    _days_in_year = 252     # Number of trading days in a year
    _sigma = 0.13           # Annualized volatility of underlying
    _r = 0.03               # Annualized risk-free rate
    _coupon = 0.2           # Annualized coupon rate (knock-out and dividend coupon)
    _lock_period = 0        # Lock-up period (trading days)
    _up_limit = 0.1         # Daily upper limit for simulated price paths
    _down_limit = 0.1       # Daily lower limit for simulated price paths

    # Output locations, relative to the working directory; create them if missing
    # so to_csv / savefig do not fail on a fresh checkout.
    result_dir = Path("Result")
    figure_dir = Path("Figure")
    result_dir.mkdir(parents=True, exist_ok=True)
    figure_dir.mkdir(parents=True, exist_ok=True)
    result_csv = result_dir / "result(Python_JIT).csv"
    figure_png = figure_dir / "result(Python_JIT).png"

    # Warm-up run: the first call JIT-compiles generate_path for this argument
    # signature, so it is excluded from the timing below.
    snowball_pricing(_s, _r, _t, _days_in_year, _sigma, _n, _k_in, _k_out, _lock_period, _coupon,
                     _up_limit, _down_limit)
    t0 = time.perf_counter()
    res_dict = snowball_pricing(_s, _r, _t, _days_in_year, _sigma, _n, _k_in, _k_out, _lock_period, _coupon,
                                _up_limit, _down_limit)
    t1 = time.perf_counter()
    print(res_dict)
    print("Elapsed time:", t1 - t0)

    n_list = range(10000, 500001, 10000)
    # The sweep omits the price limits, which changes generate_path's argument
    # types (omitted None instead of float) and triggers a separate JIT
    # compilation. Warm that signature up once so no timed run includes it.
    snowball_pricing(_s, _r, _t, _days_in_year, _sigma, n_list[0], _k_in, _k_out, _lock_period, _coupon)

    payoff_record = []
    time_record = []
    for _n in tqdm.tqdm(n_list, desc="Iterating..."):
        t0 = time.perf_counter()
        res_dict = snowball_pricing(_s, _r, _t, _days_in_year, _sigma, _n, _k_in, _k_out, _lock_period, _coupon)
        t1 = time.perf_counter()
        payoff_record.append(res_dict["payoff"])
        time_record.append(1000 * (t1 - t0))   # milliseconds

    result = pd.DataFrame({"N": list(n_list),
                           "payoff": payoff_record,
                           "time": time_record,
                           })
    result.to_csv(result_csv, index=False)
    result = pd.read_csv(result_csv, index_col=None)

    # Create the figure and the primary axis
    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Primary Y-axis: plot time (in milliseconds)
    color = 'tab:blue'
    ax1.set_xlabel('N (Simulated Paths)')
    ax1.set_ylabel('Time (ms)', color=color)
    ax1.plot(result["N"], result["time"], color=color, label='Time (ms)', linewidth=2)
    ax1.tick_params(axis='y', labelcolor=color)

    # Secondary Y-axis: plot the payoff
    ax2 = ax1.twinx()        # Share the same X-axis
    color = 'tab:red'
    ax2.set_ylabel('Payoff', color=color)
    ax2.plot(result["N"], result["payoff"], color=color, linestyle='-', label='Payoff', linewidth=2)
    ax2.tick_params(axis='y', labelcolor=color)

    # Layout and a figure-level legend covering both axes
    fig.tight_layout()
    fig.legend(loc="upper left", bbox_to_anchor=(0.1, 0.9), frameon=False)

    # Save and display the figure
    plt.savefig(figure_png)
    plt.show()
