timmy revisó este gist 10 months ago. Ir a la revisión
Sin cambios
timmy revisó este gist 10 months ago. Ir a la revisión
1 file changed, 104 insertions
stock_data_downloader.py(archivo creado)
| @@ -0,0 +1,104 @@ | |||
| 1 | + | import unittest | |
| 2 | + | from datetime import datetime, timedelta | |
| 3 | + | from unittest.mock import patch | |
| 4 | + | ||
| 5 | + | import pandas as pd | |
| 6 | + | import yfinance as yf | |
| 7 | + | from dateutil.relativedelta import relativedelta | |
| 8 | + | ||
| 9 | + | import error_printer | |
| 10 | + | ||
| 11 | + | error_printer.configure_pretty_errors() | |
| 12 | + | error_printer.configure_icecream() | |
| 13 | + | ||
| 14 | + | ||
| 15 | + | class StockDataDownloader: | |
| 16 | + | def __init__(self, months: int = 1): | |
| 17 | + | """ | |
| 18 | + | 初始化 StockDataDownloader 類的實例。 | |
| 19 | + | ||
| 20 | + | 參數: | |
| 21 | + | months (int): 設定 start 日期為 end 日期之前的月數。預設值為 1。 | |
| 22 | + | """ | |
| 23 | + | self.months = months | |
| 24 | + | self.start_date, self.end_date = self.set_date_range() | |
| 25 | + | ||
| 26 | + | def set_date_range(self) -> tuple[datetime, datetime]: | |
| 27 | + | """ | |
| 28 | + | 設定日期範圍,用於下載股票資料。 | |
| 29 | + | ||
| 30 | + | 返回: | |
| 31 | + | tuple[datetime, datetime]: 包含 start 和 end 日期的元組。 | |
| 32 | + | """ | |
| 33 | + | end = datetime.now() | |
| 34 | + | start = end - relativedelta(months=self.months) | |
| 35 | + | print("Start date:", start) | |
| 36 | + | print("End date:", end) | |
| 37 | + | return start, end | |
| 38 | + | ||
| 39 | + | def download_stock_data(self, stock_symbol: str) -> pd.DataFrame: | |
| 40 | + | """ | |
| 41 | + | 下載指定股票程式碼在指定日期範圍內的股票資料。 | |
| 42 | + | ||
| 43 | + | 參數: | |
| 44 | + | stock_symbol (str): 要下載的股票程式碼。 | |
| 45 | + | ||
| 46 | + | 返回: | |
| 47 | + | pd.DataFrame: 包含股票資料的 DataFrame。 | |
| 48 | + | """ | |
| 49 | + | return yf.download(stock_symbol, start=self.start_date, end=self.end_date) | |
| 50 | + | ||
| 51 | + | ||
| 52 | + | class TestStockDataDownloader(unittest.TestCase): | |
| 53 | + | def test_initialization(self): | |
| 54 | + | """測試類初始化是否正確設定月份和日期範圍。""" | |
| 55 | + | downloader = StockDataDownloader(3) | |
| 56 | + | self.assertEqual(downloader.months, 3) | |
| 57 | + | ||
| 58 | + | @patch(__name__ + ".datetime") | |
| 59 | + | def test_date_range(self, mock_datetime): | |
| 60 | + | """測試日期範圍是否正確計算。""" | |
| 61 | + | # 設定模擬的現在時間 | |
| 62 | + | mock_now = datetime(2024, 5, 6, 9, 27, 48, 439831) | |
| 63 | + | mock_datetime.now.return_value = mock_now | |
| 64 | + | ||
| 65 | + | downloader = StockDataDownloader(2) | |
| 66 | + | expected_start_date = mock_now - timedelta(days=60) # 假設每月30天 | |
| 67 | + | expected_end_date = mock_now | |
| 68 | + | ||
| 69 | + | actual_start_date, actual_end_date = downloader.set_date_range() | |
| 70 | + | self.assertEqual(actual_start_date, expected_start_date) | |
| 71 | + | self.assertEqual(actual_end_date, expected_end_date) | |
| 72 | + | ||
| 73 | + | @patch("yfinance.download") | |
| 74 | + | def test_download_stock_data(self, mock_download): | |
| 75 | + | """測試是否正確下載股票資料。""" | |
| 76 | + | # 設定模擬的現在時間 | |
| 77 | + | mock_now = datetime(2024, 5, 6, 9, 27, 48, 439831) | |
| 78 | + | mock_start_date = mock_now - timedelta(days=30) | |
| 79 | + | mock_end_date = mock_now | |
| 80 | + | ||
| 81 | + | # 設定模擬的股票資料 | |
| 82 | + | mock_data = pd.DataFrame({"Date": pd.date_range(start=mock_start_date, end=mock_end_date, freq="D"), "Open": [100.0] * 31, "High": [101.0] * 31, "Low": [99.0] * 31, "Close": [100.0] * 31, "Volume": [1000] * 31}) | |
| 83 | + | mock_data.set_index("Date", inplace=True) | |
| 84 | + | mock_download.return_value = mock_data | |
| 85 | + | ||
| 86 | + | downloader = StockDataDownloader(1) | |
| 87 | + | stock_data = downloader.download_stock_data("00929.TW") | |
| 88 | + | ||
| 89 | + | self.assertTrue(stock_data.equals(mock_data)) | |
| 90 | + | ||
| 91 | + | ||
| 92 | + | # 檢查目前程式是否被作為主程式執行 | |
| 93 | + | if __name__ == "__main__": | |
| 94 | + | unittest.main() | |
| 95 | + | ||
| 96 | + | # 建立 StockDataDownloader 實例 | |
| 97 | + | # downloader = StockDataDownloader(months=1) | |
| 98 | + | ||
| 99 | + | # 下載股票資料 | |
| 100 | + | # stock_data = downloader.download_stock_data("00929.TW") | |
| 101 | + | ||
| 102 | + | # 進一步處理股票資料 | |
| 103 | + | # ... | |
| 104 | + | # ic(stock_data) | |
Siguiente
Anterior