pybroker.model module
Contains model related functionality.
- class CachedModel(model: Any, input_cols: tuple[str] | None)
Bases: NamedTuple
Stores cached model data.
- model
Trained model instance.
- Type:
Any
- input_cols
Names of the columns to be used as input for the model when making predictions.
- Type:
tuple[str] | None
- class ModelLoader(name: str, load_fn: Callable[[...], Any | tuple[Any, Iterable[str]]], indicator_names: Iterable[str], input_data_fn: Callable[[DataFrame], DataFrame] | None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None, kwargs: dict[str, Any])
Bases: ModelSource
Loads a pre-trained model.
- Parameters:
name – Name of model.
load_fn – Callable[[symbol: str, train_start_date: datetime, train_end_date: datetime, ...], Any | tuple[Any, Iterable[str]]] used to load and return a pre-trained model. This is expected to return either a trained model instance, or a tuple containing a trained model instance and an Iterable of column names to be used as input for the model when making predictions.
indicator_names – Iterable of names of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn – Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn – Callable[[Model, DataFrame], ndarray] that overrides calling the model’s default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
kwargs – dict of kwargs to pass to load_fn.
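As a rough illustration of the load_fn contract described above, the sketch below defines a function that takes the documented symbol, train_start_date, and train_end_date arguments and returns a (model, input columns) tuple. ThresholdModel, _PRETRAINED, the default_threshold keyword, and the column names are all hypothetical; a real loader would more likely deserialize a saved model for the given training window.

```python
from datetime import datetime
from typing import Any, Iterable


class ThresholdModel:
    """Hypothetical stand-in for a pre-trained model object."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def predict(self, rows: list[float]) -> list[bool]:
        # Flag every value above the stored threshold.
        return [x > self.threshold for x in rows]


# Hypothetical table of "pre-trained" parameters, keyed by symbol.
_PRETRAINED = {"AAPL": 0.5, "MSFT": 0.7}


def load_fn(
    symbol: str,
    train_start_date: datetime,
    train_end_date: datetime,
    default_threshold: float = 0.6,  # hypothetical extra kwarg
) -> tuple[Any, Iterable[str]]:
    # Build the model for this symbol; unknown symbols use the default.
    # The training window is accepted but not used in this sketch.
    model = ThresholdModel(_PRETRAINED.get(symbol, default_threshold))
    # Returning a tuple also names the columns to feed the model.
    return model, ["rsi_20", "cmma_20"]
```

Assuming pybroker is installed, a function like this would be registered with pybroker.model(name, load_fn, pretrained=True, default_threshold=0.6), with the extra keyword argument forwarded to load_fn as described for kwargs.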
- __call__(symbol: str, train_start_date: datetime, train_end_date: datetime) → Any | tuple[Any, Iterable[str]]
Loads pre-trained model.
- Parameters:
symbol – Ticker symbol for loading the pre-trained model.
train_start_date – Start date of training window.
train_end_date – End date of training window.
- Returns:
Pre-trained model.
- class ModelSource(name: str, indicator_names: Iterable[str], input_data_fn: Callable[[DataFrame], DataFrame] | None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None, kwargs: dict[str, Any])
Bases: object
Base class of a model source. A model source provides a model instance either by training one or by loading a pre-trained model.
- Parameters:
name – Name of model.
indicator_names – Iterable of names of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn – Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn – Callable[[Model, DataFrame], ndarray] that overrides calling the model’s default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
kwargs – dict of additional kwargs.
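The two optional hooks can be sketched with pandas and NumPy as follows. The column names (hhv_10, llv_10, rsi_14) and ProbaModel are hypothetical, and ProbaModel only imitates a scikit-learn-style predict_proba. Each hook is shown in isolation; the sketch does not assume anything about the order in which pybroker calls them.

```python
import numpy as np
import pandas as pd


def input_data_fn(df: pd.DataFrame) -> pd.DataFrame:
    # Derive a feature from two hypothetical indicator columns and keep
    # only the columns the model expects.
    out = pd.DataFrame(index=df.index)
    out["spread"] = df["hhv_10"] - df["llv_10"]
    out["rsi_14"] = df["rsi_14"]
    return out


class ProbaModel:
    """Hypothetical classifier with a scikit-learn-style predict_proba."""

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        # Treat scaled RSI as the positive-class probability.
        p = np.clip(X["rsi_14"].to_numpy() / 100.0, 0.0, 1.0)
        return np.column_stack([1.0 - p, p])


def predict_fn(model: ProbaModel, df: pd.DataFrame) -> np.ndarray:
    # Return the positive-class probability instead of calling the
    # model's default predict().
    return model.predict_proba(df)[:, 1]
```

Returning a one-dimensional array of probabilities from predict_fn is a common reason to override the default predict call, since many classifiers' predict returns only hard labels.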
- prepare_input_data(df: DataFrame) → DataFrame
Prepares a pandas.DataFrame of input data for passing to a model when making predictions. If set, the input_data_fn is used to preprocess the input data. If input_data_fn is not set, the indicator columns in df are used as input features.
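The precedence described for prepare_input_data can be expressed as a small standalone function. This is a sketch of the documented behavior only, not pybroker's implementation; prepare_input_data_sketch and its parameters are hypothetical names.

```python
from typing import Callable, Iterable, Optional

import pandas as pd


def prepare_input_data_sketch(
    df: pd.DataFrame,
    indicator_names: Iterable[str],
    input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
) -> pd.DataFrame:
    # A custom preprocessing hook takes precedence when it is set.
    if input_data_fn is not None:
        return input_data_fn(df)
    # Otherwise, pass only the indicator columns to the model.
    return df[list(indicator_names)]
```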
- class ModelTrainer(name: str, train_fn: Callable[[...], Any | tuple[Any, Iterable[str]]], indicator_names: Iterable[str], input_data_fn: Callable[[DataFrame], DataFrame] | None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None, kwargs: dict[str, Any])
Bases: ModelSource
Trains a model.
- Parameters:
name – Name of model.
train_fn – Callable[[symbol: str, train_data: DataFrame, test_data: DataFrame, ...], Any | tuple[Any, Iterable[str]]] used to train and return a model. This is expected to return either a trained model instance, or a tuple containing a trained model instance and an Iterable of column names to be used as input for the model when making predictions.
indicator_names – Iterable of names of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn – Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn – Callable[[Model, DataFrame], ndarray] that overrides calling the model’s default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
kwargs – dict of kwargs to pass to train_fn.
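A minimal train_fn matching the documented symbol, train_data, test_data arguments might fit an ordinary least-squares model with NumPy and return it together with its input columns. LinearModel, the feature names, and the target_col keyword are assumptions for illustration; test_data is accepted but unused in this sketch.

```python
from typing import Any, Iterable

import numpy as np
import pandas as pd


class LinearModel:
    """Hypothetical minimal linear model fitted by least squares."""

    def __init__(self, coef: np.ndarray, intercept: float) -> None:
        self.coef = coef
        self.intercept = intercept

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        return X.to_numpy() @ self.coef + self.intercept


def train_fn(
    symbol: str,
    train_data: pd.DataFrame,
    test_data: pd.DataFrame,
    target_col: str = "return_1",  # hypothetical extra kwarg
) -> tuple[Any, Iterable[str]]:
    feature_cols = ["rsi_14", "cmma_20"]  # hypothetical indicator names
    X = train_data[feature_cols].to_numpy()
    y = train_data[target_col].to_numpy()
    # Append a column of ones so lstsq also estimates the intercept.
    A = np.column_stack([X, np.ones(len(X))])
    solution, *_ = np.linalg.lstsq(A, y, rcond=None)
    model = LinearModel(solution[:-1], float(solution[-1]))
    # Return the model plus the columns it should receive at predict time.
    return model, feature_cols
```

Returning the feature list alongside the model keeps the training and prediction inputs aligned without relying on a separate input_data_fn.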
- class ModelsMixin
Bases: object
Mixin implementing model related functionality.
- train_models(model_syms: Iterable[ModelSymbol], train_data: DataFrame, test_data: DataFrame, indicator_data: Mapping[IndicatorSymbol, Series], cache_date_fields: CacheDateFields) → dict[ModelSymbol, TrainedModel]
Trains models for the provided pybroker.common.ModelSymbol pairs.
- Parameters:
model_syms – Iterable of pybroker.common.ModelSymbol pairs of models to train.
train_data – pandas.DataFrame of training data.
test_data – pandas.DataFrame of test data.
indicator_data – Mapping of pybroker.common.IndicatorSymbol pairs to pandas.Series of pybroker.indicator.Indicator values.
cache_date_fields – Date fields used to key cache data.
- Returns:
dict mapping each pybroker.common.ModelSymbol pair to a pybroker.common.TrainedModel.
- model(name: str, fn: Callable[[...], Any | tuple[Any, Iterable[str]]], indicators: Iterable[Indicator] | None = None, input_data_fn: Callable[[DataFrame], DataFrame] | None = None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None = None, pretrained: bool = False, **kwargs) → ModelSource
Creates a ModelSource instance and registers it globally with name.
- Parameters:
name – Name for referencing the model globally.
fn – Callable used to either train or load a model instance. If for training, then fn has signature Callable[[symbol: str, train_data: DataFrame, test_data: DataFrame, ...], Any | tuple[Any, Iterable[str]]]. If for loading, then fn has signature Callable[[symbol: str, train_start_date: datetime, train_end_date: datetime, ...], Any | tuple[Any, Iterable[str]]]. This is expected to return either a trained model instance, or a tuple containing a trained model instance and an Iterable of column names to be used as input for the model when making predictions.
indicators – Iterable of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn – Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn – Callable[[Model, DataFrame], ndarray] that overrides calling the model’s default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
pretrained – If True, then fn is used to load and return a pre-trained model. If False, fn is used to train and return a new model. Defaults to False.
**kwargs – Additional arguments to pass to fn.
- Returns:
ModelSource instance.