pybroker.model module
Contains model related functionality.
- class CachedModel(model: Any, input_cols: tuple[str] | None)[source]
Bases: NamedTuple
Stores cached model data.
- model
Trained model instance.
- Type:
Any
- input_cols
Names of the columns to be used as input for the model when making predictions.
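The following is a minimal sketch of how a record shaped like CachedModel might be consumed. It uses a hypothetical stand-in class, CachedModelSketch, with the same two fields; it is not pybroker's class. The fallback to default columns when input_cols is None is an assumption for illustration.

```python
from typing import Any, NamedTuple, Optional


class CachedModelSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as CachedModel."""

    model: Any
    # None means the model source did not return explicit input columns.
    input_cols: Optional[tuple[str, ...]]


def resolve_input_cols(
    entry: CachedModelSketch, default_cols: tuple[str, ...]
) -> tuple[str, ...]:
    """Use the cached input columns if present, else the defaults (assumed behavior)."""
    return entry.input_cols if entry.input_cols is not None else default_cols


with_cols = CachedModelSketch(model="trained-model", input_cols=("rsi_14",))
without_cols = CachedModelSketch(model="trained-model", input_cols=None)

print(resolve_input_cols(with_cols, ("rsi_14", "cmma_20")))     # ('rsi_14',)
print(resolve_input_cols(without_cols, ("rsi_14", "cmma_20")))  # ('rsi_14', 'cmma_20')

# Being a NamedTuple, an entry also unpacks like a plain tuple.
model, cols = with_cols
```

Using a NamedTuple keeps cache entries immutable and lightweight while still giving named field access.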
- class ModelLoader(name: str, load_fn: Callable[[...], Any | tuple[Any, Iterable[str]]], indicator_names: Iterable[str], input_data_fn: Callable[[DataFrame], DataFrame] | None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None, kwargs: dict[str, Any])[source]
Bases: ModelSource
Loads a pre-trained model.
- Parameters:
name -- Name of model.
load_fn -- Callable[[symbol: str, train_start_date: datetime, train_end_date: datetime, ...], Any | tuple[Any, Iterable[str]]] used to load and return a pre-trained model. This is expected to return either a trained model instance, or a tuple containing a trained model instance and an Iterable of column names to be used as input for the model when making predictions.
indicator_names -- Iterable of names of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn -- Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn -- Callable[[Model, DataFrame], ndarray] that overrides calling the model's default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
kwargs -- dict of kwargs to pass to load_fn.
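A load_fn following the documented signature might look like the sketch below. PRETRAINED, ThresholdModel, and the input_cols keyword are hypothetical; a real loader would more likely deserialize a model from disk. The function returns the (model, input columns) tuple form.

```python
from datetime import datetime
from typing import Any


class ThresholdModel:
    """Toy stand-in for a pre-trained model (hypothetical)."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold


# Hypothetical store of pre-trained models keyed by ticker symbol.
PRETRAINED = {"AAPL": ThresholdModel(70.0), "MSFT": ThresholdModel(65.0)}


def load_fn(
    symbol: str,
    train_start_date: datetime,
    train_end_date: datetime,
    **kwargs: Any,
) -> tuple[Any, list[str]]:
    """Return (model, input column names), following the documented load_fn contract."""
    model = PRETRAINED[symbol]
    # The training window is passed in but unused here; a real loader could use
    # it to select a model version trained on that window.
    input_cols = kwargs.get("input_cols", ["rsi_14"])
    return model, input_cols


model, cols = load_fn("AAPL", datetime(2020, 1, 1), datetime(2021, 1, 1))
print(model.threshold, cols)  # 70.0 ['rsi_14']
```

Returning a bare model instead of a tuple is also allowed by the contract; the tuple form pins the exact feature columns.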
- __call__(symbol: str, train_start_date: datetime, train_end_date: datetime) → Any | tuple[Any, Iterable[str]][source]
Loads pre-trained model.
- Parameters:
symbol -- Ticker symbol for loading the pre-trained model.
train_start_date -- Start date of training window.
train_end_date -- End date of training window.
- Returns:
Pre-trained model.
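One way to picture how the constructor's kwargs relate to __call__ is the hypothetical LoaderSketch below. It stores load_fn and the extra kwargs, then forwards them together with the three call arguments. This is an assumed analogue for illustration, not ModelLoader's actual implementation.

```python
from datetime import datetime
from typing import Any, Callable


class LoaderSketch:
    """Hypothetical analogue of ModelLoader: stores load_fn and kwargs, forwards on call."""

    def __init__(self, name: str, load_fn: Callable[..., Any], **kwargs: Any) -> None:
        self.name = name
        self._load_fn = load_fn
        self._kwargs = kwargs

    def __call__(
        self, symbol: str, train_start_date: datetime, train_end_date: datetime
    ) -> Any:
        # Positional arguments come from the backtest; the rest are the stored kwargs.
        return self._load_fn(symbol, train_start_date, train_end_date, **self._kwargs)


def load_fn(
    symbol: str, train_start_date: datetime, train_end_date: datetime, version: str = "v1"
) -> str:
    # Toy "model": a string identifying which artifact would be loaded.
    return f"{symbol}-{version}-{train_end_date:%Y%m%d}"


loader = LoaderSketch("my_model", load_fn, version="v2")
print(loader("AAPL", datetime(2020, 1, 1), datetime(2021, 6, 30)))  # AAPL-v2-20210630
```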
- class ModelSource(name: str, indicator_names: Iterable[str], input_data_fn: Callable[[DataFrame], DataFrame] | None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None, kwargs: dict[str, Any])[source]
Bases: object
Base class of a model source. A model source provides a model instance either by training one or by loading a pre-trained model.
- Parameters:
name -- Name of model.
indicator_names -- Iterable of names of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn -- Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn -- Callable[[Model, DataFrame], ndarray] that overrides calling the model's default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
kwargs -- dict of additional kwargs.
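The sketch below shows functions matching the input_data_fn and predict_fn signatures. The column names rsi_14 and cmma_20 and the ToyClassifier class are hypothetical. The predict_fn illustrates a common reason to override predict: returning the positive-class probability from a scikit-learn style predict_proba instead of hard labels.

```python
from typing import Any

import numpy as np
import pandas as pd


def input_data_fn(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical preprocessing: keep two assumed feature columns, fill gaps with 0."""
    return df[["rsi_14", "cmma_20"]].fillna(0.0)


class ToyClassifier:
    """Stand-in for a classifier exposing a scikit-learn style predict_proba."""

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        # Treat RSI / 100 as the positive-class probability (toy logic only).
        p = np.clip(X["rsi_14"].to_numpy(dtype=float) / 100.0, 0.0, 1.0)
        return np.column_stack([1.0 - p, p])


def predict_fn(model: Any, df: pd.DataFrame) -> np.ndarray:
    """Return positive-class probabilities instead of calling model.predict()."""
    return model.predict_proba(df)[:, 1]


test_data = pd.DataFrame(
    {"close": [10.0, 11.0], "rsi_14": [80.0, None], "cmma_20": [0.1, 0.2]}
)
features = input_data_fn(test_data)
preds = predict_fn(ToyClassifier(), features)
print(preds)
```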
- prepare_input_data(df: DataFrame) → DataFrame[source]
Prepares a pandas.DataFrame of input data for passing to a model when making predictions. If input_data_fn is set, it is used to preprocess the input data. Otherwise, the indicator columns in df are used as input features.
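The selection rule described above can be restated as a small standalone function. prepare_input_data_sketch is hypothetical and only mirrors the documented behavior; it is not pybroker's code.

```python
from typing import Callable, Iterable, Optional

import pandas as pd


def prepare_input_data_sketch(
    df: pd.DataFrame,
    indicator_names: Iterable[str],
    input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
) -> pd.DataFrame:
    """Hypothetical restatement of the documented rule."""
    if input_data_fn is not None:
        # A user-supplied preprocessor takes precedence.
        return input_data_fn(df)
    # Otherwise only the indicator columns are used as features.
    return df[list(indicator_names)]


df = pd.DataFrame({"close": [1.0, 2.0], "rsi_14": [50.0, 60.0]})
print(list(prepare_input_data_sketch(df, ["rsi_14"]).columns))  # ['rsi_14']
print(list(prepare_input_data_sketch(df, ["rsi_14"], lambda d: d[["close"]]).columns))  # ['close']
```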
- class ModelTrainer(name: str, train_fn: Callable[[...], Any | tuple[Any, Iterable[str]]], indicator_names: Iterable[str], input_data_fn: Callable[[DataFrame], DataFrame] | None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None, kwargs: dict[str, Any])[source]
Bases: ModelSource
Trains a model.
- Parameters:
name -- Name of model.
train_fn -- Callable[[symbol: str, train_data: DataFrame, test_data: DataFrame, ...], Any | tuple[Any, Iterable[str]]] used to train and return a model. This is expected to return either a trained model instance, or a tuple containing a trained model instance and an Iterable of column names to be used as input for the model when making predictions.
indicator_names -- Iterable of names of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn -- Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn -- Callable[[Model, DataFrame], ndarray] that overrides calling the model's default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
kwargs -- dict of kwargs to pass to train_fn.
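A train_fn matching the documented signature could fit an ordinary least-squares model with NumPy. In the sketch below, LinearModel, the feature columns rsi_14 and cmma_20, and the target keyword (which would arrive through the trainer's kwargs) are all assumptions for illustration. The function returns the (model, input columns) tuple form.

```python
from typing import Any

import numpy as np
import pandas as pd


class LinearModel:
    """Toy least-squares model exposing a predict method."""

    def __init__(self, coef: np.ndarray) -> None:
        self.coef = coef

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        # Prepend a column of ones to match the intercept used during fitting.
        design = np.column_stack([np.ones(len(X)), X.to_numpy(dtype=float)])
        return design @ self.coef


def train_fn(
    symbol: str,
    train_data: pd.DataFrame,
    test_data: pd.DataFrame,
    target: str = "target",
    **kwargs: Any,
) -> tuple[LinearModel, list[str]]:
    """Fit on train_data only; test_data is available but unused to avoid lookahead."""
    features = ["rsi_14", "cmma_20"]  # assumed indicator columns
    X = train_data[features].to_numpy(dtype=float)
    y = train_data[target].to_numpy(dtype=float)
    design = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return LinearModel(coef), features


train = pd.DataFrame({"rsi_14": [1.0, 2.0, 3.0, 4.0], "cmma_20": [0.0, 1.0, 0.0, 2.0]})
train["target"] = 1.0 + 2.0 * train["rsi_14"] + 3.0 * train["cmma_20"]
test = pd.DataFrame({"rsi_14": [5.0], "cmma_20": [1.0]})
model, cols = train_fn("AAPL", train, test)
print(model.predict(test[cols]))
```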
- class ModelsMixin[source]
Bases: object
Mixin implementing model related functionality.
- train_models(model_syms: Iterable[ModelSymbol], train_data: DataFrame, test_data: DataFrame, indicator_data: Mapping[IndicatorSymbol, Series], cache_date_fields: CacheDateFields) → dict[ModelSymbol, TrainedModel][source]
Trains models for the provided pybroker.common.ModelSymbol pairs.
- Parameters:
model_syms -- Iterable of pybroker.common.ModelSymbol pairs of models to train.
train_data -- pandas.DataFrame of training data.
test_data -- pandas.DataFrame of test data.
indicator_data -- Mapping of pybroker.common.IndicatorSymbol pairs to pandas.Series of pybroker.indicator.Indicator values.
cache_date_fields -- Date fields used to key cache data.
- Returns:
dict mapping each pybroker.common.ModelSymbol pair to a pybroker.common.TrainedModel.
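The following is a simplified sketch of the "one trained model per (model name, symbol) pair" idea. ModelSymbolSketch and train_models_sketch are hypothetical. The per-symbol filtering on a symbol column is an assumption about the data layout. Caching and indicator data are deliberately omitted.

```python
from typing import Any, Callable, Iterable, NamedTuple

import pandas as pd


class ModelSymbolSketch(NamedTuple):
    """Hypothetical stand-in for pybroker.common.ModelSymbol: (model name, ticker)."""

    name: str
    symbol: str


def train_models_sketch(
    model_syms: Iterable[ModelSymbolSketch],
    train_data: pd.DataFrame,
    test_data: pd.DataFrame,
    train_fns: dict[str, Callable[..., Any]],
) -> dict[ModelSymbolSketch, Any]:
    """Train one model per pair; caching and indicator data are omitted."""
    trained: dict[ModelSymbolSketch, Any] = {}
    for model_sym in model_syms:
        fn = train_fns[model_sym.name]
        # Each pair sees only its own symbol's rows (assumes a 'symbol' column).
        sym_train = train_data[train_data["symbol"] == model_sym.symbol]
        sym_test = test_data[test_data["symbol"] == model_sym.symbol]
        trained[model_sym] = fn(model_sym.symbol, sym_train, sym_test)
    return trained


def mean_close_fn(symbol: str, train_data: pd.DataFrame, test_data: pd.DataFrame) -> float:
    # Toy "model": the mean training close for the symbol.
    return float(train_data["close"].mean())


data = pd.DataFrame({"symbol": ["AAPL", "AAPL", "MSFT"], "close": [10.0, 20.0, 30.0]})
pairs = [ModelSymbolSketch("mean", "AAPL"), ModelSymbolSketch("mean", "MSFT")]
# The same frame is passed as train and test data purely to keep the toy short.
result = train_models_sketch(pairs, data, data, {"mean": mean_close_fn})
print(result)
```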
- model(name: str, fn: Callable[[...], Any | tuple[Any, Iterable[str]]], indicators: Iterable[Indicator] | None = None, input_data_fn: Callable[[DataFrame], DataFrame] | None = None, predict_fn: Callable[[Any, DataFrame], ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None = None, pretrained: bool = False, **kwargs) → ModelSource[source]
Creates a ModelSource instance and registers it globally with name.
- Parameters:
name -- Name for referencing the model globally.
fn -- Callable used to either train or load a model instance. If used for training, fn has signature Callable[[symbol: str, train_data: DataFrame, test_data: DataFrame, ...], Any | tuple[Any, Iterable[str]]]. If used for loading, fn has signature Callable[[symbol: str, train_start_date: datetime, train_end_date: datetime, ...], Any | tuple[Any, Iterable[str]]]. In either case, fn is expected to return either a trained model instance, or a tuple containing a trained model instance and an Iterable of column names to be used as input for the model when making predictions.
indicators -- Iterable of pybroker.indicator.Indicator instances used as features of the model.
input_data_fn -- Callable[[DataFrame], DataFrame] for preprocessing input data passed to the model when making predictions. If set, input_data_fn will be called with a pandas.DataFrame containing all test data.
predict_fn -- Callable[[Model, DataFrame], ndarray] that overrides calling the model's default predict function. If set, predict_fn will be called with the trained model and a pandas.DataFrame containing all test data.
pretrained -- If True, fn is used to load and return a pre-trained model. If False, fn is used to train and return a new model. Defaults to False.
**kwargs -- Additional arguments to pass to fn.
- Returns:
ModelSourceinstance.