# import shutil

# import matplotlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rcParams

import error_printer
from stock_data_downloader import StockDataDownloader

error_printer.configure_icecream()


rcParams["font.family"] = "sans-serif"
# rcParams["font.sans-serif"] = ["SimHei"]  # 或其他支持中文的字型
rcParams["font.sans-serif"] = ["STHeiti", "PingFang"]


class MACDCalculator:
    def __init__(self, stock_data: pd.DataFrame):
        """
        初始化 MACDCalculator 類的實例。

        參數:
        stock_data (pd.DataFrame): 包含股票資料的 DataFrame。
        """
        self.stock_data = stock_data

    def calculate_macd(self) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
        """
        計算股票資料的 MACD 指標。

        返回:
        tuple[pd.Series, pd.Series, pd.Series, pd.Series]: 包含快線 EMA、慢線 EMA、MACD 線和訊號線的元組。
        """

        short_ema = self.stock_data["Close"].ewm(span=12, adjust=False).mean()  # 計算股票收盤價的快速指數移動平均線（EMA）
        long_ema = self.stock_data["Close"].ewm(span=26, adjust=False).mean()  # 計算股票收盤價的慢速指數移動平均線（EMA）
        macd = short_ema - long_ema  # 計算 MACD 線
        signal_line = macd.ewm(span=9, adjust=False).mean()  # 計算訊號線
        return short_ema, long_ema, macd, signal_line  # 返回短期EMA、長期EMA、MACD、訊號線

    def find_golden_crosses(self, macd: pd.Series, signal_line: pd.Series) -> np.ndarray:
        """
        找到 MACD 線和訊號線的黃金交叉點。

        參數:
        macd (pd.Series): MACD 線。
        signal_line (pd.Series): 訊號線。

        返回:
        np.ndarray: 包含黃金交叉點索引的 NumPy 陣列。
        """
        crosses = (macd.shift(1) < signal_line.shift(1)) & (macd > signal_line)
        cross_indices = np.where(crosses)[0]
        return cross_indices

    def find_death_crosses(self, macd: pd.Series, signal_line: pd.Series) -> np.ndarray:
        """
        找到 MACD 線和訊號線的死亡交叉點。

        參數:
        macd (pd.Series): MACD 線。
        signal_line (pd.Series): 訊號線。

        返回:
        np.ndarray: 包含死亡交叉點索引的 NumPy 陣列。
        """
        crosses = (macd.shift(1) > signal_line.shift(1)) & (macd < signal_line)
        cross_indices = np.where(crosses)[0]
        return cross_indices

    def get_golden_cross_dates(self, golden_crosses: np.ndarray) -> list:
        """
        獲取黃金交叉點的日期。

        參數:
        golden_crosses (np.ndarray): 包含黃金交叉點索引的 NumPy 陣列。

        返回:
        list: 包含黃金交叉點日期的列表。
        """
        cross_dates = []
        for cross_index in golden_crosses:
            cross_date = str(self.stock_data.index[cross_index])[:10]
            cross_dates.append(cross_date)
        return cross_dates

    def get_death_cross_dates(self, death_crosses: np.ndarray) -> list:
        """
        獲取死亡交叉點的日期。

        參數:
        death_crosses (np.ndarray): 包含死亡交叉點索引的 NumPy 陣列。

        返回:
        list: 包含死亡交叉點日期的列表。
        """
        cross_dates = []
        for cross_index in death_crosses:
            cross_date = str(self.stock_data.index[cross_index])[:10]
            cross_dates.append(cross_date)
        return cross_dates

    def plot_macd(self):
        # 從實例中獲取資料
        short_ema, long_ema, macd, signal_line = self.calculate_macd()

        # 定義顏色變數
        macd_color = (243 / 255, 158 / 255, 55 / 255)  # 橙色
        signal_line_color = (38 / 255, 128 / 255, 218 / 255)  # 藍色
        positive_bar_color = (247 / 255, 65 / 255, 84 / 255)  # 紅色
        negative_bar_color = (51 / 255, 184 / 255, 90 / 255)  # 綠色

        plt.figure(figsize=(10, 5))
        plt.plot(macd.index, macd, label="MACD線", color=macd_color)
        plt.plot(signal_line.index, signal_line, label="訊號線", color=signal_line_color)

        # 繪製 MACD 柱狀圖，使用顏色變數
        bar_colors = [positive_bar_color if v >= 0 else negative_bar_color for v in macd - signal_line]
        plt.bar(macd.index, macd - signal_line, color=bar_colors)

        plt.legend()
        plt.show()


# 檢查目前程式是否被作為主程式執行
if __name__ == "__main__":
    # 建立 StockDataDownloader 實例
    downloader = StockDataDownloader(months=1)

    # 下載股票資料
    stock_data = downloader.download_stock_data("00929.TW")

    # 建立 MACDCalculator 實例
    macd_calculator = MACDCalculator(stock_data)

    # 計算 MACD 指標
    short_ema, long_ema, macd, signal_line = macd_calculator.calculate_macd()

    # 找到黃金交叉點和死亡交叉點
    golden_crosses = macd_calculator.find_golden_crosses(macd, signal_line)
    death_crosses = macd_calculator.find_death_crosses(macd, signal_line)

    # 獲取黃金交叉點和死亡交叉點的日期
    golden_cross_dates = macd_calculator.get_golden_cross_dates(golden_crosses)
    death_cross_dates = macd_calculator.get_death_cross_dates(death_crosses)

    # 列印結果
    # print("短期EMA: ", short_ema)
    # print("長期EMA: ", long_ema)
    # print("MACD: ", macd)
    # print("訊號線: ", signal_line)
    # print("黃金交叉點: ", golden_crosses)
    # print("死亡交叉點: ", death_crosses)
    print("黃金交叉日期: ", golden_cross_dates)
    print("死亡交叉日期: ", death_cross_dates)

    macd_calculator.plot_macd()  # 繪製 MACD 圖表
