macd_calculator.py
· 5.6 KiB · Python
原始檔案
# import shutil
# import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rcParams
import error_printer
from stock_data_downloader import StockDataDownloader
error_printer.configure_icecream()
rcParams["font.family"] = "sans-serif"
# rcParams["font.sans-serif"] = ["SimHei"] # 或其他支持中文的字型
rcParams["font.sans-serif"] = ["STHeiti", "PingFang"]
class MACDCalculator:
def __init__(self, stock_data: pd.DataFrame):
"""
初始化 MACDCalculator 類的實例。
參數:
stock_data (pd.DataFrame): 包含股票資料的 DataFrame。
"""
self.stock_data = stock_data
def calculate_macd(self) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
"""
計算股票資料的 MACD 指標。
返回:
tuple[pd.Series, pd.Series, pd.Series, pd.Series]: 包含快線 EMA、慢線 EMA、MACD 線和訊號線的元組。
"""
short_ema = self.stock_data["Close"].ewm(span=12, adjust=False).mean() # 計算股票收盤價的快速指數移動平均線(EMA)
long_ema = self.stock_data["Close"].ewm(span=26, adjust=False).mean() # 計算股票收盤價的慢速指數移動平均線(EMA)
macd = short_ema - long_ema # 計算 MACD 線
signal_line = macd.ewm(span=9, adjust=False).mean() # 計算訊號線
return short_ema, long_ema, macd, signal_line # 返回短期EMA、長期EMA、MACD、訊號線
def find_golden_crosses(self, macd: pd.Series, signal_line: pd.Series) -> np.ndarray:
"""
找到 MACD 線和訊號線的黃金交叉點。
參數:
macd (pd.Series): MACD 線。
signal_line (pd.Series): 訊號線。
返回:
np.ndarray: 包含黃金交叉點索引的 NumPy 陣列。
"""
crosses = (macd.shift(1) < signal_line.shift(1)) & (macd > signal_line)
cross_indices = np.where(crosses)[0]
return cross_indices
def find_death_crosses(self, macd: pd.Series, signal_line: pd.Series) -> np.ndarray:
"""
找到 MACD 線和訊號線的死亡交叉點。
參數:
macd (pd.Series): MACD 線。
signal_line (pd.Series): 訊號線。
返回:
np.ndarray: 包含死亡交叉點索引的 NumPy 陣列。
"""
crosses = (macd.shift(1) > signal_line.shift(1)) & (macd < signal_line)
cross_indices = np.where(crosses)[0]
return cross_indices
def get_golden_cross_dates(self, golden_crosses: np.ndarray) -> list:
"""
獲取黃金交叉點的日期。
參數:
golden_crosses (np.ndarray): 包含黃金交叉點索引的 NumPy 陣列。
返回:
list: 包含黃金交叉點日期的列表。
"""
cross_dates = []
for cross_index in golden_crosses:
cross_date = str(self.stock_data.index[cross_index])[:10]
cross_dates.append(cross_date)
return cross_dates
def get_death_cross_dates(self, death_crosses: np.ndarray) -> list:
"""
獲取死亡交叉點的日期。
參數:
death_crosses (np.ndarray): 包含死亡交叉點索引的 NumPy 陣列。
返回:
list: 包含死亡交叉點日期的列表。
"""
cross_dates = []
for cross_index in death_crosses:
cross_date = str(self.stock_data.index[cross_index])[:10]
cross_dates.append(cross_date)
return cross_dates
def plot_macd(self):
# 從實例中獲取資料
short_ema, long_ema, macd, signal_line = self.calculate_macd()
# 定義顏色變數
macd_color = (243 / 255, 158 / 255, 55 / 255) # 橙色
signal_line_color = (38 / 255, 128 / 255, 218 / 255) # 藍色
positive_bar_color = (247 / 255, 65 / 255, 84 / 255) # 紅色
negative_bar_color = (51 / 255, 184 / 255, 90 / 255) # 綠色
plt.figure(figsize=(10, 5))
plt.plot(macd.index, macd, label="MACD線", color=macd_color)
plt.plot(signal_line.index, signal_line, label="訊號線", color=signal_line_color)
# 繪製 MACD 柱狀圖,使用顏色變數
bar_colors = [positive_bar_color if v >= 0 else negative_bar_color for v in macd - signal_line]
plt.bar(macd.index, macd - signal_line, color=bar_colors)
plt.legend()
plt.show()
# 檢查目前程式是否被作為主程式執行
if __name__ == "__main__":
# 建立 StockDataDownloader 實例
downloader = StockDataDownloader(months=1)
# 下載股票資料
stock_data = downloader.download_stock_data("00929.TW")
# 建立 MACDCalculator 實例
macd_calculator = MACDCalculator(stock_data)
# 計算 MACD 指標
short_ema, long_ema, macd, signal_line = macd_calculator.calculate_macd()
# 找到黃金交叉點和死亡交叉點
golden_crosses = macd_calculator.find_golden_crosses(macd, signal_line)
death_crosses = macd_calculator.find_death_crosses(macd, signal_line)
# 獲取黃金交叉點和死亡交叉點的日期
golden_cross_dates = macd_calculator.get_golden_cross_dates(golden_crosses)
death_cross_dates = macd_calculator.get_death_cross_dates(death_crosses)
# 列印結果
# print("短期EMA: ", short_ema)
# print("長期EMA: ", long_ema)
# print("MACD: ", macd)
# print("訊號線: ", signal_line)
# print("黃金交叉點: ", golden_crosses)
# print("死亡交叉點: ", death_crosses)
print("黃金交叉日期: ", golden_cross_dates)
print("死亡交叉日期: ", death_cross_dates)
macd_calculator.plot_macd() # 繪製 MACD 圖表
| 1 | # import shutil |
| 2 | |
| 3 | # import matplotlib |
| 4 | |
| 5 | import matplotlib.pyplot as plt |
| 6 | import numpy as np |
| 7 | import pandas as pd |
| 8 | from matplotlib import rcParams |
| 9 | |
| 10 | import error_printer |
| 11 | from stock_data_downloader import StockDataDownloader |
| 12 | |
| 13 | error_printer.configure_icecream() |
| 14 | |
| 15 | |
| 16 | rcParams["font.family"] = "sans-serif" |
| 17 | # rcParams["font.sans-serif"] = ["SimHei"] # 或其他支持中文的字型 |
| 18 | rcParams["font.sans-serif"] = ["STHeiti", "PingFang"] |
| 19 | |
| 20 | |
| 21 | class MACDCalculator: |
| 22 | def __init__(self, stock_data: pd.DataFrame): |
| 23 | """ |
| 24 | 初始化 MACDCalculator 類的實例。 |
| 25 | |
| 26 | 參數: |
| 27 | stock_data (pd.DataFrame): 包含股票資料的 DataFrame。 |
| 28 | """ |
| 29 | self.stock_data = stock_data |
| 30 | |
| 31 | def calculate_macd(self) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]: |
| 32 | """ |
| 33 | 計算股票資料的 MACD 指標。 |
| 34 | |
| 35 | 返回: |
| 36 | tuple[pd.Series, pd.Series, pd.Series, pd.Series]: 包含快線 EMA、慢線 EMA、MACD 線和訊號線的元組。 |
| 37 | """ |
| 38 | |
| 39 | short_ema = self.stock_data["Close"].ewm(span=12, adjust=False).mean() # 計算股票收盤價的快速指數移動平均線(EMA) |
| 40 | long_ema = self.stock_data["Close"].ewm(span=26, adjust=False).mean() # 計算股票收盤價的慢速指數移動平均線(EMA) |
| 41 | macd = short_ema - long_ema # 計算 MACD 線 |
| 42 | signal_line = macd.ewm(span=9, adjust=False).mean() # 計算訊號線 |
| 43 | return short_ema, long_ema, macd, signal_line # 返回短期EMA、長期EMA、MACD、訊號線 |
| 44 | |
| 45 | def find_golden_crosses(self, macd: pd.Series, signal_line: pd.Series) -> np.ndarray: |
| 46 | """ |
| 47 | 找到 MACD 線和訊號線的黃金交叉點。 |
| 48 | |
| 49 | 參數: |
| 50 | macd (pd.Series): MACD 線。 |
| 51 | signal_line (pd.Series): 訊號線。 |
| 52 | |
| 53 | 返回: |
| 54 | np.ndarray: 包含黃金交叉點索引的 NumPy 陣列。 |
| 55 | """ |
| 56 | crosses = (macd.shift(1) < signal_line.shift(1)) & (macd > signal_line) |
| 57 | cross_indices = np.where(crosses)[0] |
| 58 | return cross_indices |
| 59 | |
| 60 | def find_death_crosses(self, macd: pd.Series, signal_line: pd.Series) -> np.ndarray: |
| 61 | """ |
| 62 | 找到 MACD 線和訊號線的死亡交叉點。 |
| 63 | |
| 64 | 參數: |
| 65 | macd (pd.Series): MACD 線。 |
| 66 | signal_line (pd.Series): 訊號線。 |
| 67 | |
| 68 | 返回: |
| 69 | np.ndarray: 包含死亡交叉點索引的 NumPy 陣列。 |
| 70 | """ |
| 71 | crosses = (macd.shift(1) > signal_line.shift(1)) & (macd < signal_line) |
| 72 | cross_indices = np.where(crosses)[0] |
| 73 | return cross_indices |
| 74 | |
| 75 | def get_golden_cross_dates(self, golden_crosses: np.ndarray) -> list: |
| 76 | """ |
| 77 | 獲取黃金交叉點的日期。 |
| 78 | |
| 79 | 參數: |
| 80 | golden_crosses (np.ndarray): 包含黃金交叉點索引的 NumPy 陣列。 |
| 81 | |
| 82 | 返回: |
| 83 | list: 包含黃金交叉點日期的列表。 |
| 84 | """ |
| 85 | cross_dates = [] |
| 86 | for cross_index in golden_crosses: |
| 87 | cross_date = str(self.stock_data.index[cross_index])[:10] |
| 88 | cross_dates.append(cross_date) |
| 89 | return cross_dates |
| 90 | |
| 91 | def get_death_cross_dates(self, death_crosses: np.ndarray) -> list: |
| 92 | """ |
| 93 | 獲取死亡交叉點的日期。 |
| 94 | |
| 95 | 參數: |
| 96 | death_crosses (np.ndarray): 包含死亡交叉點索引的 NumPy 陣列。 |
| 97 | |
| 98 | 返回: |
| 99 | list: 包含死亡交叉點日期的列表。 |
| 100 | """ |
| 101 | cross_dates = [] |
| 102 | for cross_index in death_crosses: |
| 103 | cross_date = str(self.stock_data.index[cross_index])[:10] |
| 104 | cross_dates.append(cross_date) |
| 105 | return cross_dates |
| 106 | |
| 107 | def plot_macd(self): |
| 108 | # 從實例中獲取資料 |
| 109 | short_ema, long_ema, macd, signal_line = self.calculate_macd() |
| 110 | |
| 111 | # 定義顏色變數 |
| 112 | macd_color = (243 / 255, 158 / 255, 55 / 255) # 橙色 |
| 113 | signal_line_color = (38 / 255, 128 / 255, 218 / 255) # 藍色 |
| 114 | positive_bar_color = (247 / 255, 65 / 255, 84 / 255) # 紅色 |
| 115 | negative_bar_color = (51 / 255, 184 / 255, 90 / 255) # 綠色 |
| 116 | |
| 117 | plt.figure(figsize=(10, 5)) |
| 118 | plt.plot(macd.index, macd, label="MACD線", color=macd_color) |
| 119 | plt.plot(signal_line.index, signal_line, label="訊號線", color=signal_line_color) |
| 120 | |
| 121 | # 繪製 MACD 柱狀圖,使用顏色變數 |
| 122 | bar_colors = [positive_bar_color if v >= 0 else negative_bar_color for v in macd - signal_line] |
| 123 | plt.bar(macd.index, macd - signal_line, color=bar_colors) |
| 124 | |
| 125 | plt.legend() |
| 126 | plt.show() |
| 127 | |
| 128 | |
| 129 | # 檢查目前程式是否被作為主程式執行 |
| 130 | if __name__ == "__main__": |
| 131 | # 建立 StockDataDownloader 實例 |
| 132 | downloader = StockDataDownloader(months=1) |
| 133 | |
| 134 | # 下載股票資料 |
| 135 | stock_data = downloader.download_stock_data("00929.TW") |
| 136 | |
| 137 | # 建立 MACDCalculator 實例 |
| 138 | macd_calculator = MACDCalculator(stock_data) |
| 139 | |
| 140 | # 計算 MACD 指標 |
| 141 | short_ema, long_ema, macd, signal_line = macd_calculator.calculate_macd() |
| 142 | |
| 143 | # 找到黃金交叉點和死亡交叉點 |
| 144 | golden_crosses = macd_calculator.find_golden_crosses(macd, signal_line) |
| 145 | death_crosses = macd_calculator.find_death_crosses(macd, signal_line) |
| 146 | |
| 147 | # 獲取黃金交叉點和死亡交叉點的日期 |
| 148 | golden_cross_dates = macd_calculator.get_golden_cross_dates(golden_crosses) |
| 149 | death_cross_dates = macd_calculator.get_death_cross_dates(death_crosses) |
| 150 | |
| 151 | # 列印結果 |
| 152 | # print("短期EMA: ", short_ema) |
| 153 | # print("長期EMA: ", long_ema) |
| 154 | # print("MACD: ", macd) |
| 155 | # print("訊號線: ", signal_line) |
| 156 | # print("黃金交叉點: ", golden_crosses) |
| 157 | # print("死亡交叉點: ", death_crosses) |
| 158 | print("黃金交叉日期: ", golden_cross_dates) |
| 159 | print("死亡交叉日期: ", death_cross_dates) |
| 160 | |
| 161 | macd_calculator.plot_macd() # 繪製 MACD 圖表 |
| 162 |