r"""Contains extension classes."""
"""Copyright (C) 2023 Edward West. All rights reserved.
This code is licensed under Apache 2.0 with Commons Clause license
(see LICENSE for details).
"""
from datetime import datetime
from typing import Optional
import akshare
import pandas as pd
from yahooquery import Ticker
from pybroker.common import DataCol, to_datetime
from pybroker.data import DataSource
# [docs]
class AKShare(DataSource):
    r"""Retrieves data from `AKShare <https://akshare.akfamily.xyz/>`_.

    Only the timeframes listed in ``_tf_to_period`` are queried: daily
    (``""`` or ``"1day"``) and weekly (``"1week"``).
    """

    # Maps a formatted timeframe to the ``period`` argument passed to
    # ``akshare.stock_zh_a_hist``.
    _tf_to_period = {
        "": "daily",
        "1day": "daily",
        "1week": "weekly",
    }

    def _fetch_data(
        self,
        symbols: frozenset[str],
        start_date: datetime,
        end_date: datetime,
        timeframe: Optional[str],
        adjust: Optional[str],
    ) -> pd.DataFrame:
        """Fetches bars for ``symbols`` with ``akshare.stock_zh_a_hist``.

        Args:
            symbols: Symbols to query. Anything after the first ``"."`` is
                dropped before the request is made, while the returned
                symbol column keeps the full, original symbol.
            start_date: Start of the range; sent formatted as ``%Y%m%d``.
            end_date: End of the range; sent formatted as ``%Y%m%d``.
            timeframe: Timeframe of the bars. When its formatted value is
                not a key of ``_tf_to_period``, no request is made and an
                empty :class:`pandas.DataFrame` is returned.
            adjust: Adjustment passed through to AKShare; ``None`` is sent
                as an empty string.

        Returns:
            :class:`pandas.DataFrame` with the date, symbol, open, high,
            low, close and volume columns.

        :meta private:
        """
        start_date_str = to_datetime(start_date).strftime("%Y%m%d")
        end_date_str = to_datetime(end_date).strftime("%Y%m%d")
        formatted_tf = self._format_timeframe(timeframe)
        period = AKShare._tf_to_period.get(formatted_tf)
        # Collect the per-symbol frames and concatenate once afterwards,
        # rather than growing a DataFrame inside the loop, which re-copies
        # all accumulated rows on every iteration and concatenates with an
        # empty placeholder frame.
        frames = []
        if period is not None:
            adjust_arg = "" if adjust is None else adjust
            for symbol in symbols:
                temp_df = akshare.stock_zh_a_hist(
                    symbol=symbol.split(".")[0],
                    start_date=start_date_str,
                    end_date=end_date_str,
                    period=period,
                    adjust=adjust_arg,
                )
                # A frame without columns carries no data for this symbol.
                if temp_df.columns.empty:
                    continue
                temp_df[DataCol.SYMBOL.value] = symbol
                frames.append(temp_df)
        if not frames:
            # Nothing was fetched (or the timeframe is unsupported).
            return pd.DataFrame(
                columns=[
                    DataCol.SYMBOL.value,
                    DataCol.DATE.value,
                    DataCol.OPEN.value,
                    DataCol.HIGH.value,
                    DataCol.LOW.value,
                    DataCol.CLOSE.value,
                    DataCol.VOLUME.value,
                ]
            )
        result = pd.concat(frames, ignore_index=True)
        if result.empty:
            # NOTE(review): a frame that has columns but no rows is returned
            # as-is, i.e. still carrying the provider's column names.
            return result
        # Translate the provider's Chinese column headers.
        result = result.rename(
            columns={
                "日期": DataCol.DATE.value,
                "开盘": DataCol.OPEN.value,
                "收盘": DataCol.CLOSE.value,
                "最高": DataCol.HIGH.value,
                "最低": DataCol.LOW.value,
                "成交量": DataCol.VOLUME.value,
            }
        )
        result[DataCol.DATE.value] = pd.to_datetime(
            result[DataCol.DATE.value]
        )
        return result[
            [
                DataCol.DATE.value,
                DataCol.SYMBOL.value,
                DataCol.OPEN.value,
                DataCol.HIGH.value,
                DataCol.LOW.value,
                DataCol.CLOSE.value,
                DataCol.VOLUME.value,
            ]
        ]
# [docs]
class YQuery(DataSource):
    r"""Retrieves data from Yahoo Finance using
    `Yahooquery <https://github.com/dpguthrie/yahooquery>`_\ .

    Args:
        proxies: Optional proxies ``dict`` forwarded to
            ``yahooquery.Ticker``.
    """

    # Maps a formatted timeframe to the ``interval`` argument passed to
    # ``Ticker.history``.
    _tf_to_period = {
        "": "1d",
        "1hour": "1h",
        "1day": "1d",
        "5day": "5d",
        "1week": "1wk",
    }

    def __init__(self, proxies: Optional[dict] = None):
        super().__init__()
        self.proxies = proxies

    def _fetch_data(
        self,
        symbols: frozenset[str],
        start_date: datetime,
        end_date: datetime,
        timeframe: Optional[str],
        adjust: Optional[bool],
    ) -> pd.DataFrame:
        """Fetches price history for ``symbols`` through ``Ticker.history``.

        Args:
            symbols: Symbols to query.
            start_date: Start of the requested range.
            end_date: End of the requested range.
            timeframe: Timeframe of the bars; its formatted value must be a
                key of ``_tf_to_period``.
            adjust: Passed to ``Ticker.history`` as ``adj_ohlc``.

        Returns:
            :class:`pandas.DataFrame` with the symbol, date, open, high,
            low, close and volume columns.

        Raises:
            ValueError: If the timeframe is not supported.

        :meta private:
        """
        # Show the download progress bar only when neither logging nor
        # progress bars are disabled on the logger.
        progress_enabled = not (
            self._logger._disabled or self._logger._progress_bar_disabled
        )
        ticker = Ticker(
            symbols,
            asynchronous=True,
            progress=progress_enabled,
            proxies=self.proxies,
        )
        timeframe = self._format_timeframe(timeframe)
        if timeframe not in self._tf_to_period:
            raise ValueError(
                f"Unsupported timeframe: '{timeframe}'.\n"
                f"Supported timeframes: {list(self._tf_to_period.keys())}."
            )
        interval = self._tf_to_period[timeframe]
        history = ticker.history(
            start=start_date,
            end=end_date,
            interval=interval,
            adj_ohlc=adjust,
        )
        output_cols = [
            col.value
            for col in (
                DataCol.SYMBOL,
                DataCol.DATE,
                DataCol.OPEN,
                DataCol.HIGH,
                DataCol.LOW,
                DataCol.CLOSE,
                DataCol.VOLUME,
            )
        ]
        if history.columns.empty:
            return pd.DataFrame(columns=output_cols)
        if history.empty:
            return history
        # Move the index levels into regular columns before selecting.
        flat = history.reset_index()
        flat[DataCol.DATE.value] = pd.to_datetime(flat[DataCol.DATE.value])
        return flat[output_cols]