"""Contains scopes that store data and object references used to execute a
:class:`pybroker.strategy.Strategy`.
"""
"""Copyright (C) 2023 Edward West. All rights reserved.
This code is licensed under Apache 2.0 with Commons Clause license
(see LICENSE for details).
"""
import numpy as np
import pandas as pd
from pybroker.common import (
BarData,
DataCol,
IndicatorSymbol,
ModelSymbol,
PriceType,
TrainedModel,
to_decimal,
)
from pybroker.log import Logger
from collections import defaultdict
from decimal import Decimal
from diskcache import Cache
from numpy.typing import NDArray
from typing import (
Any,
Callable,
Final,
Iterable,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Union,
)
# Unique sentinel used as the default ``value`` of ``param`` so that a call
# with no value (a lookup) can be told apart from an explicit ``None`` (a set).
_EMPTY_PARAM: Final[object] = object()
class StaticScope:
    """A static registry of data and object references.

    The shared instance is obtained through :meth:`StaticScope.instance`.

    Attributes:
        logger: :class:`pybroker.log.Logger`
        data_source_cache: :class:`diskcache.Cache` that stores data retrieved
            from :class:`pybroker.data.DataSource`.
        data_source_cache_ns: Namespace set for :attr:`.data_source_cache`.
        indicator_cache: :class:`diskcache.Cache` that stores
            :class:`pybroker.indicator.Indicator` data.
        indicator_cache_ns: Namespace set for :attr:`.indicator_cache`.
        model_cache: :class:`diskcache.Cache` that stores trained models.
        model_cache_ns: Namespace set for :attr:`.model_cache`.
        default_data_cols: Default data columns in :class:`pandas.DataFrame`
            retrieved from a :class:`pybroker.data.DataSource`.
        custom_data_cols: User-defined data columns in
            :class:`pandas.DataFrame` retrieved from a
            :class:`pybroker.data.DataSource`.
    """

    __instance = None

    def __init__(self):
        self.logger = Logger(self)
        # Caches start unset, with empty namespaces.
        self.data_source_cache: Optional[Cache] = None
        self.data_source_cache_ns: str = ""
        self.indicator_cache: Optional[Cache] = None
        self.indicator_cache_ns: str = ""
        self.model_cache: Optional[Cache] = None
        self.model_cache_ns: str = ""
        # Registered indicators / model sources, keyed by their ``name``.
        self._indicators: dict[str, Any] = {}
        self._model_sources: dict[str, Any] = {}
        builtin_cols = (
            DataCol.DATE,
            DataCol.OPEN,
            DataCol.HIGH,
            DataCol.LOW,
            DataCol.CLOSE,
            DataCol.VOLUME,
            DataCol.VWAP,
        )
        self.default_data_cols = frozenset(col.value for col in builtin_cols)
        self.custom_data_cols: set[str] = set()
        self._cols_frozen: bool = False
        self._params: dict[str, Any] = {}

    def set_indicator(self, indicator):
        """Stores :class:`pybroker.indicator.Indicator` in static scope,
        replacing any indicator registered under the same name.
        """
        self._indicators[indicator.name] = indicator

    def has_indicator(self, name: str) -> bool:
        """Returns whether a :class:`pybroker.indicator.Indicator` named
        ``name`` is stored in static scope.
        """
        return name in self._indicators

    def get_indicator(self, name: str):
        """Retrieves a :class:`pybroker.indicator.Indicator` from static
        scope.

        Raises:
            ValueError: If no indicator named ``name`` is registered.
        """
        if self.has_indicator(name):
            return self._indicators[name]
        raise ValueError(f"Indicator {name!r} does not exist.")

    def get_indicator_names(self, model_name: str) -> tuple[str]:
        r"""Returns the names of all :class:`pybroker.indicator.Indicator`\ s
        registered with the :class:`pybroker.model.ModelSource` named
        ``model_name``.

        Raises:
            KeyError: If no model source named ``model_name`` is registered.
        """
        source = self._model_sources[model_name]
        return source.indicators

    def set_model_source(self, source):
        """Stores :class:`pybroker.model.ModelSource` in static scope,
        replacing any model source registered under the same name.
        """
        self._model_sources[source.name] = source

    def has_model_source(self, name: str) -> bool:
        """Returns whether a :class:`pybroker.model.ModelSource` named
        ``name`` is stored in static scope.
        """
        return name in self._model_sources

    def get_model_source(self, name: str):
        """Retrieves a :class:`pybroker.model.ModelSource` from static
        scope.

        Raises:
            ValueError: If no model source named ``name`` is registered.
        """
        if self.has_model_source(name):
            return self._model_sources[name]
        raise ValueError(f"ModelSource {name!r} does not exist.")

    @staticmethod
    def _collect_names(
        names: Union[str, Iterable[str]], extra: tuple
    ) -> tuple:
        """Flattens a single name or an iterable of names plus ``extra``."""
        if isinstance(names, str):
            return (names, *extra)
        return (*names, *extra)

    def register_custom_cols(self, names: Union[str, Iterable[str]], *args):
        """Registers user-defined column names.

        Names that are already default data columns are ignored.

        Raises:
            ValueError: If data columns are currently frozen.
        """
        self._verify_unfrozen_cols()
        requested = self._collect_names(names, args)
        self.custom_data_cols.update(
            col for col in requested if col not in self.default_data_cols
        )

    def unregister_custom_cols(self, names: Union[str, Iterable[str]], *args):
        """Unregisters user-defined column names.

        Raises:
            ValueError: If data columns are currently frozen.
        """
        self._verify_unfrozen_cols()
        self.custom_data_cols.difference_update(
            self._collect_names(names, args)
        )

    @property
    def all_data_cols(self) -> frozenset[str]:
        """All registered data column names (default and custom)."""
        return self.default_data_cols.union(self.custom_data_cols)

    def _verify_unfrozen_cols(self):
        """Raises ``ValueError`` if data columns have been frozen."""
        if self._cols_frozen:
            raise ValueError("Cannot modify columns when strategy is running.")

    def freeze_data_cols(self):
        """Prevents additional data columns from being registered."""
        self._cols_frozen = True

    def unfreeze_data_cols(self):
        """Allows additional data columns to be registered if
        :func:`pybroker.scope.StaticScope.freeze_data_cols` was called.
        """
        self._cols_frozen = False

    def param(
        self, name: str, value: Optional[Any] = _EMPTY_PARAM
    ) -> Optional[Any]:
        """Get or set a global parameter.

        When called with only ``name``, returns the stored value, or ``None``
        if it was never set. When ``value`` is given (``None`` included), it
        is stored under ``name`` and returned.
        """
        if value is not _EMPTY_PARAM:
            self._params[name] = value
            return value
        return self._params.get(name)

    @classmethod
    def instance(cls) -> "StaticScope":
        """Returns singleton instance, creating it on first call."""
        if cls.__instance is None:
            cls.__instance = StaticScope()
        return cls.__instance
def disable_logging():
    """Disables event logging on the shared :class:`StaticScope` logger."""
    scope = StaticScope.instance()
    scope.logger.disable()
def enable_logging():
    """Enables event logging on the shared :class:`StaticScope` logger."""
    scope = StaticScope.instance()
    scope.logger.enable()
def disable_progress_bar():
    """Disables logging a progress bar via the shared logger."""
    scope = StaticScope.instance()
    scope.logger.disable_progress_bar()
def enable_progress_bar():
    """Enables logging a progress bar via the shared logger."""
    scope = StaticScope.instance()
    scope.logger.enable_progress_bar()
def register_columns(names: Union[str, Iterable[str]], *args):
    """Registers ``names`` of user-defined data columns.

    Args:
        names: A single column name or an iterable of column names.
        *args: Additional column names.
    """
    scope = StaticScope.instance()
    scope.register_custom_cols(names, *args)
def unregister_columns(names: Union[str, Iterable[str]], *args):
    """Unregisters ``names`` of user-defined data columns.

    Args:
        names: A single column name or an iterable of column names.
        *args: Additional column names.
    """
    scope = StaticScope.instance()
    scope.unregister_custom_cols(names, *args)
def param(name: str, value: Optional[Any] = _EMPTY_PARAM) -> Optional[Any]:
    """Get or set a global parameter.

    Args:
        name: Name of the parameter.
        value: Value to store under ``name``. When omitted, the current value
            is looked up instead (``None`` if it was never set).

    Returns:
        The stored value for a lookup, or ``value`` when setting.
    """
    scope = StaticScope.instance()
    return scope.param(name, value)
class ColumnScope:
    """Caches and retrieves column data queried from :class:`pandas.DataFrame`.

    Args:
        df: :class:`pandas.DataFrame` containing the column data, indexed by
            a :class:`pandas.MultiIndex` whose first level is the ticker
            symbol.
    """

    def __init__(self, df: pd.DataFrame):
        self._df = df.sort_index()
        self._symbols = frozenset(df.index.get_level_values(0).unique())
        # symbol -> column name -> full column array (None if column missing).
        self._sym_cols: dict[str, dict[str, Optional[NDArray]]] = defaultdict(
            dict
        )

    def fetch_dict(
        self,
        symbol: str,
        names: Iterable[str],
        end_index: Optional[int] = None,
    ) -> dict[str, Optional[NDArray]]:
        r"""Fetches a ``dict`` of column data for ``symbol``.

        Full columns are cached per symbol, so each column is read from the
        :class:`pandas.DataFrame` at most once. Columns absent from the data
        map to ``None``. Returned arrays are slices of the cached arrays.

        Args:
            symbol: Ticker symbol to query.
            names: Names of columns to query.
            end_index: Truncates column values (exclusive). If ``None``, then
                column values are not truncated.

        Returns:
            ``dict`` mapping column names to :class:`numpy.ndarray`\ s of
            column values.

        Raises:
            ValueError: If ``symbol`` is not in the data and a requested
                column is not already cached.
        """
        fetched: dict[str, Optional[NDArray]] = {}
        if not names:
            return fetched
        # Lazily-built frame of this symbol's rows, shared by all names below.
        frame: Optional[pd.DataFrame] = None
        for col in names:
            known = self._sym_cols.get(symbol)
            if known is not None and col in known:
                cached = known[col]
                fetched[col] = None if cached is None else cached[:end_index]
                continue
            if frame is None:
                if symbol not in self._symbols:
                    raise ValueError(f"Symbol not found: {symbol}.")
                frame = self._df.loc[pd.IndexSlice[symbol, :]].reset_index()
            if col in frame.columns:
                values = frame[col].to_numpy()
                self._sym_cols[symbol][col] = values
                fetched[col] = values[:end_index]
            else:
                self._sym_cols[symbol][col] = None
                fetched[col] = None
        return fetched

    def fetch(
        self, symbol: str, name: str, end_index: Optional[int] = None
    ) -> Optional[NDArray]:
        """Fetches a :class:`numpy.ndarray` of column data for ``symbol``.

        Args:
            symbol: Ticker symbol to query.
            name: Name of column to query.
            end_index: Truncates column values (exclusive). If ``None``, then
                column values are not truncated.

        Returns:
            :class:`numpy.ndarray` of column data for every bar until
            ``end_index`` (when specified), or ``None`` if the column does
            not exist.
        """
        return self.fetch_dict(symbol, (name,), end_index).get(name)

    def bar_data_from_data_columns(
        self, symbol: str, end_index: int
    ) -> BarData:
        """Returns a new :class:`pybroker.common.BarData` instance containing
        column data of default and custom data columns registered with
        :class:`.StaticScope`.

        Args:
            symbol: Ticker symbol to query.
            end_index: Truncates column values (exclusive). If ``None``, then
                column values are not truncated.
        """
        scope = StaticScope.instance()
        return BarData(
            **self.fetch_dict(  # type: ignore[arg-type]
                symbol, scope.default_data_cols, end_index
            ),
            **self.fetch_dict(  # type: ignore[arg-type]
                symbol, scope.custom_data_cols, end_index
            ),
        )
class IndicatorScope:
    """Caches and retrieves :class:`pybroker.indicator.Indicator` data.

    Args:
        indicator_data: :class:`Mapping` of
            :class:`pybroker.common.IndicatorSymbol` pairs to ``pandas.Series``
            of :class:`pybroker.indicator.Indicator` values.
        filter_dates: Filters :class:`pybroker.indicator.Indicator` data on
            :class:`Sequence` of dates.
    """

    def __init__(
        self,
        indicator_data: Mapping[IndicatorSymbol, pd.Series],
        filter_dates: Sequence[np.datetime64],
    ):
        self._indicator_data = indicator_data
        self._filter_dates = filter_dates
        # Filtered indicator values, computed once per indicator/symbol pair.
        self._sym_inds: dict[IndicatorSymbol, NDArray[np.float64]] = {}

    def fetch(
        self, symbol: str, name: str, end_index: Optional[int] = None
    ) -> NDArray[np.float64]:
        """Fetches :class:`pybroker.indicator.Indicator` data.

        Args:
            symbol: Ticker symbol to query.
            name: Name of :class:`pybroker.indicator.Indicator` to query.
            end_index: Truncates the array of
                :class:`pybroker.indicator.Indicator` data returned
                (exclusive). If ``None``, then indicator data is not truncated.

        Returns:
            :class:`numpy.ndarray` of :class:`pybroker.indicator.Indicator`
            data for every bar until ``end_index`` (when specified).

        Raises:
            ValueError: If no data exists for ``name`` and ``symbol``.
        """
        key = IndicatorSymbol(name, symbol)
        if key not in self._sym_inds:
            if key not in self._indicator_data:
                raise ValueError(f"Indicator {name!r} not found for {symbol}.")
            series = self._indicator_data[key]
            # Keep only values whose index date is one of the filter dates.
            on_dates = series.index.isin(self._filter_dates)
            self._sym_inds[key] = series[on_dates].values
        return self._sym_inds[key][:end_index]
class PredictionScope:
    r"""Caches and retrieves model predictions.

    Args:
        models: :class:`Mapping` of
            :class:`pybroker.common.ModelSymbol` pairs to
            :class:`pybroker.common.TrainedModel`\ s.
        input_scope: :class:`.ModelInputScope`.
    """

    def __init__(
        self,
        models: Mapping[ModelSymbol, TrainedModel],
        # Quoted forward reference: ModelInputScope is neither defined nor
        # imported above this class, so an unquoted annotation would raise
        # NameError when the module is imported.
        input_scope: "ModelInputScope",
    ):
        self._models = models
        self._input_scope = input_scope
        # Full prediction arrays, computed once per model/symbol pair.
        self._sym_preds: dict[ModelSymbol, NDArray] = {}

    def fetch(
        self, symbol: str, name: str, end_index: Optional[int] = None
    ) -> NDArray:
        """Fetches model predictions.

        Args:
            symbol: Ticker symbol to query.
            name: Name of :class:`pybroker.model.ModelSource` that made the
                predictions.
            end_index: Truncates the array of predictions returned (exclusive).
                If ``None``, then predictions are not truncated.

        Returns:
            :class:`numpy.ndarray` of model predictions for every bar until
            ``end_index`` (when specified).

        Raises:
            ValueError: If no model named ``name`` was trained for ``symbol``,
                if the model's input data is empty, or if the trained model
                has neither a ``predict_fn`` nor a callable ``predict``.
        """
        model_sym = ModelSymbol(name, symbol)
        if model_sym in self._sym_preds:
            return self._sym_preds[model_sym][:end_index]
        # Check for the model first: the lookup is cheap and avoids building
        # input data for a model that does not exist for this symbol.
        if model_sym not in self._models:
            raise ValueError(f"Model {name!r} not found for {symbol}.")
        input_ = self._input_scope.fetch(symbol, name)
        if input_.empty:
            raise ValueError(
                f"No input data found for model {name!r}. Consider "
                "passing input_data_fn to pybroker#model() if custom columns "
                "were registered."
            )
        pred = self._predict(self._models[model_sym], model_sym, input_)
        self._sym_preds[model_sym] = pred
        return pred[:end_index]

    @staticmethod
    def _predict(
        trained_model: TrainedModel, model_sym: ModelSymbol, input_
    ) -> NDArray:
        """Runs ``trained_model`` on ``input_`` and returns predictions with
        size-1 extra dimensions removed.
        """
        if trained_model.predict_fn is not None:
            pred = trained_model.predict_fn(trained_model.instance, input_)
        else:
            predict_fn = getattr(trained_model.instance, "predict", None)
            if not callable(predict_fn):
                raise ValueError(
                    f"Model instance trained for {model_sym.model_name!r} "
                    "does not define a predict function. Please pass a "
                    "predict_fn to pybroker.model()."
                )
            pred = predict_fn(input_)
        if not hasattr(pred, "shape"):
            # e.g. a plain list returned by a custom predict_fn.
            pred = np.asarray(pred)
        if len(pred.shape) > 1:
            # Collapse column-vector output such as (n, 1) to (n,).
            pred = np.squeeze(pred)
            if np.ndim(pred) == 0:
                # A single (1, 1) prediction squeezes to a 0-d value, which
                # cannot be sliced; keep it as a length-1 array.
                pred = np.reshape(pred, (1,))
        return pred
class PriceScope:
    """Retrieves most recent prices.

    Args:
        col_scope: :class:`.ColumnScope` used to read bar data.
        sym_end_index: :class:`Mapping` of ticker symbols to the end index
            (exclusive) of the bar data visible for that symbol.
        round_fill_price: Whether fill prices are rounded to two decimal
            places.
    """

    def __init__(
        self,
        col_scope: ColumnScope,
        sym_end_index: Mapping[str, int],
        round_fill_price: bool,
    ):
        self._col_scope = col_scope
        self._sym_end_index = sym_end_index
        self._round_fill_price = round_fill_price

    def fetch(
        self,
        symbol: str,
        price: Union[
            int,
            float,
            np.floating,
            Decimal,
            PriceType,
            Callable[[str, BarData], Union[int, float, Decimal]],
        ],
    ) -> Decimal:
        """Returns the fill price of ``symbol`` at its current bar.

        Args:
            symbol: Ticker symbol to price.
            price: A :class:`pybroker.common.PriceType` resolved from the last
                bar before the symbol's end index, a number used as is, or a
                callable that receives ``symbol`` and the symbol's
                :class:`pybroker.common.BarData` and returns the price.

        Returns:
            Fill price as a :class:`Decimal`, passed through ``round(..., 2)``
            first when ``round_fill_price`` is enabled.

        Raises:
            ValueError: If ``price`` has an unsupported type, or a column
                required by the :class:`pybroker.common.PriceType` is missing.
        """
        end_index = self._sym_end_index[symbol]
        price_type = type(price)
        if isinstance(price, PriceType):
            fill_price = self._price_from_type(symbol, price, end_index)
        elif (
            price_type is float
            or price_type is int
            or isinstance(price, (np.floating, Decimal))
        ):
            # Exact type checks mean e.g. ``bool`` is rejected, not treated
            # as an int.
            fill_price = price
        elif callable(price):
            bar_data = self._col_scope.bar_data_from_data_columns(
                symbol, end_index
            )
            fill_price = price(symbol, bar_data)
        else:
            raise ValueError(f"Unknown price: {price_type}")
        if self._round_fill_price:
            fill_price = round(fill_price, 2)
        return to_decimal(fill_price)

    def _price_from_type(
        self, symbol: str, price: PriceType, end_index: int
    ):
        """Computes the price described by ``price`` from the last bar before
        ``end_index``.
        """
        if price == PriceType.OPEN:
            return self._last(symbol, DataCol.OPEN, end_index, "Open")
        if price == PriceType.HIGH:
            return self._last(symbol, DataCol.HIGH, end_index, "High")
        if price == PriceType.LOW:
            return self._last(symbol, DataCol.LOW, end_index, "Low")
        if price == PriceType.CLOSE:
            return self._last(symbol, DataCol.CLOSE, end_index, "Close")
        if price == PriceType.MIDDLE:
            low = self._last(symbol, DataCol.LOW, end_index, "Low")
            high = self._last(symbol, DataCol.HIGH, end_index, "High")
            # Midpoint of the bar's low-high range.
            return low + (high - low) / 2.0
        if price == PriceType.AVERAGE:
            open_ = self._last(symbol, DataCol.OPEN, end_index, "Open")
            high = self._last(symbol, DataCol.HIGH, end_index, "High")
            low = self._last(symbol, DataCol.LOW, end_index, "Low")
            close = self._last(symbol, DataCol.CLOSE, end_index, "Close")
            # Mean of the bar's open, low, high and close.
            return (open_ + low + high + close) / 4.0
        raise ValueError(f"Unknown price: {type(price)}")

    def _last(self, symbol: str, col: DataCol, end_index: int, label: str):
        """Returns the last value of ``col`` before ``end_index``.

        Raises:
            ValueError: If ``col`` is not present in the data for ``symbol``.
        """
        values = self._col_scope.fetch(symbol, col.value, end_index)
        if values is None:
            raise ValueError(f"{label} price not found.")
        return values[-1]
class PendingOrder(NamedTuple):
    """Data of an order that has been placed but not yet executed.

    Attributes:
        id: Unique ID.
        type: Type of order, either ``buy`` or ``sell``.
        symbol: Ticker symbol of the order.
        created: Date the order was created.
        exec_date: Date the order will be executed.
        shares: Number of shares to be bought or sold.
        limit_price: Limit price to use for the order, or ``None``.
        fill_price: Price that the order will be filled at: a number, a
            :class:`pybroker.common.PriceType`, or a callable taking the
            symbol and its :class:`pybroker.common.BarData`.
    """

    id: int
    type: Literal["buy", "sell"]
    symbol: str
    created: np.datetime64
    exec_date: np.datetime64
    shares: Decimal
    limit_price: Optional[Decimal]
    fill_price: Union[
        int, float, np.floating, Decimal, PriceType,
        Callable[[str, BarData], Union[int, float, Decimal]],
    ]
class PendingOrderScope:
    r"""Stores :class:`.PendingOrder`\ s, indexed by ID and by symbol."""

    _order_id: int = 0

    def __init__(self):
        # order ID -> order, plus symbol -> set of that symbol's orders.
        self._orders: dict[int, PendingOrder] = {}
        self._sym_orders: dict[str, set[PendingOrder]] = defaultdict(set)

    def contains(self, order_id: int) -> bool:
        """Returns whether a :class:`.PendingOrder` exists with
        ``order_id``.
        """
        return order_id in self._orders

    def add(
        self,
        type: Literal["buy", "sell"],
        symbol: str,
        created: np.datetime64,
        exec_date: np.datetime64,
        shares: Decimal,
        limit_price: Optional[Decimal],
        fill_price: Union[
            int,
            float,
            np.floating,
            Decimal,
            PriceType,
            Callable[[str, BarData], Union[int, float, Decimal]],
        ],
    ) -> int:
        """Creates a :class:`.PendingOrder`.

        Args:
            type: Type of order, either ``buy`` or ``sell``.
            symbol: Ticker symbol of the order.
            created: Date the order was created.
            exec_date: Date the order will be executed.
            shares: Number of shares to be bought or sold.
            limit_price: Limit price to use for the order.
            fill_price: Price that the order will be filled at.

        Returns:
            ID of the :class:`.PendingOrder`.
        """
        # The first increment creates an instance attribute shadowing the
        # class-level default of 0, so IDs count up per scope instance.
        self._order_id += 1
        new_id = self._order_id
        order = PendingOrder(
            id=new_id,
            type=type,
            symbol=symbol,
            created=created,
            exec_date=exec_date,
            shares=shares,
            limit_price=limit_price,
            fill_price=fill_price,
        )
        self._orders[new_id] = order
        self._sym_orders[symbol].add(order)
        return new_id

    def remove(self, order_id: int) -> bool:
        """Removes a :class:`.PendingOrder` with ``order_id``.

        Returns:
            ``True`` if an order was removed, ``False`` if none had
            ``order_id``.
        """
        order = self._orders.pop(order_id, None)
        if order is None:
            return False
        # .get() avoids creating an empty entry in the defaultdict.
        same_symbol = self._sym_orders.get(order.symbol)
        if same_symbol is not None:
            same_symbol.discard(order)
        return True

    def remove_all(self, symbol: Optional[str] = None):
        r"""Removes all :class:`.PendingOrder`\ s, or only those for
        ``symbol`` when it is given.
        """
        if symbol is None:
            doomed = list(self._orders)
        else:
            doomed = [order.id for order in self._sym_orders.get(symbol, ())]
        # IDs are snapshotted above because remove() mutates the containers.
        for order_id in doomed:
            self.remove(order_id)

    def orders(self, symbol: Optional[str] = None) -> Iterable[PendingOrder]:
        r"""Returns an :class:`Iterable` of :class:`.PendingOrder`\ s, either
        all of them or only those for ``symbol``.

        NOTE: the returned collection is a live view of internal state, not a
        copy.
        """
        if symbol is None:
            return self._orders.values()
        return self._sym_orders.get(symbol, [])
def get_signals(
    symbols: Iterable[str],
    col_scope: ColumnScope,
    ind_scope: IndicatorScope,
    pred_scope: PredictionScope,
) -> dict[str, pd.DataFrame]:
    r"""Retrieves dictionary of :class:`pandas.DataFrame`\ s
    containing bar data, indicator data, and model predictions for each symbol.

    Args:
        symbols: Ticker symbols to retrieve signals for.
        col_scope: :class:`.ColumnScope` providing bar data columns.
        ind_scope: :class:`.IndicatorScope` providing indicator data.
        pred_scope: :class:`.PredictionScope` providing model predictions.

    Returns:
        ``dict`` mapping each symbol to a :class:`pandas.DataFrame`. Its
        columns are the date first, then the other registered data columns
        sorted by name, then indicators, then one ``<model>_pred`` column per
        model. Indicators and models whose fetch raises ``ValueError`` for a
        symbol are left out of that symbol's frame.
    """
    static_scope = StaticScope.instance()
    date_col = DataCol.DATE.value
    # all_data_cols is a frozenset, whose iteration order depends on string
    # hashing (randomized per process), so fix the column order explicitly.
    data_cols = (
        date_col,
        *sorted(col for col in static_scope.all_data_cols if col != date_col),
    )
    inds = tuple(static_scope._indicators)
    models = tuple(static_scope._model_sources)
    dfs: dict[str, pd.DataFrame] = {}
    for sym in symbols:
        # One fetch_dict call slices the symbol's rows once for all columns;
        # the date column comes from the symbol's own rows, so its length
        # matches the other columns.
        data: dict[str, Any] = col_scope.fetch_dict(sym, data_cols)
        for ind in inds:
            try:
                data[ind] = ind_scope.fetch(sym, ind)
            except ValueError:
                # Indicator not available for this symbol.
                continue
        for model in models:
            try:
                data[f"{model}_pred"] = pred_scope.fetch(sym, model)
            except ValueError:
                # No trained model or input data for this symbol.
                continue
        dfs[sym] = pd.DataFrame(data)
    return dfs