# Source code for pybroker.portfolio

"""Contains portfolio related functionality, such as portfolio metrics and
placing orders.
"""

"""Copyright (C) 2023 Edward West. All rights reserved.

This code is licensed under Apache 2.0 with Commons Clause license
(see LICENSE for details).
"""

import itertools
import numpy as np
import pandas as pd
from pybroker.common import (
    BarData,
    DataCol,
    FeeMode,
    PriceType,
    StopType,
    to_decimal,
)
from pybroker.scope import PriceScope, StaticScope
from collections import deque
from dataclasses import dataclass, field
from decimal import Decimal
from typing import (
    Callable,
    Iterable,
    Literal,
    NamedTuple,
    Optional,
    Union,
)


class Stop(NamedTuple):
    """Describes a stop attached to an :class:`.Entry`.

    Attributes:
        id: Unique identifier.
        symbol: Symbol of the stop.
        stop_type: :class:`pybroker.common.StopType` of the stop.
        pos_type: Type of :class:`.Position`, either ``long`` or ``short``.
        percent: Percent from entry price.
        points: Cash amount from entry price.
        bars: Number of bars after which to trigger the stop.
        fill_price: Price that the stop will be filled at.
        limit_price: Limit price to use for the stop.
        exit_price: Exit :class:`pybroker.common.PriceType` to use for the
            stop exit. If set, the stop is checked against the
            ``exit_price`` and exits at the ``exit_price`` when triggered.
    """

    id: int
    symbol: str
    stop_type: StopType
    pos_type: Literal["long", "short"]
    percent: Optional[Decimal]
    points: Optional[Decimal]
    bars: Optional[int]
    fill_price: Optional[
        Union[
            int,
            float,
            np.floating,
            Decimal,
            PriceType,
            Callable[[str, BarData], Union[int, float, Decimal]],
        ]
    ]
    limit_price: Optional[Decimal]
    exit_price: Optional[PriceType]
@dataclass
class Entry:
    """A single entry into a :class:`.Position`.

    Attributes:
        id: Unique identifier.
        date: Date of the entry.
        symbol: Symbol of the entry.
        shares: Number of shares.
        price: Share price of the entry.
        type: Type of :class:`.Position`, either ``long`` or ``short``.
        bars: Current number of bars since entry.
        stops: Stops set on the entry.
    """

    id: int
    date: np.datetime64
    symbol: str
    shares: Decimal
    price: Decimal
    type: Literal["long", "short"]
    bars: int = 0
    stops: set[Stop] = field(default_factory=set)
@dataclass
class _StopData:
    """Mutable bookkeeping record for a price-based stop.

    Attributes:
        value: Current price level associated with the stop.
        stop: The :class:`.Stop` being tracked.
        entry: The :class:`.Entry` the stop is attached to.
    """

    value: Decimal
    stop: Stop
    entry: Entry
@dataclass
class Position:
    r"""An open position held in ``symbol``.

    Attributes:
        symbol: Ticker symbol of the position.
        shares: Number of shares.
        type: Type of position, either ``long`` or ``short``.
        close: Last close price of ``symbol``.
        equity: Equity in the position.
        market_value: Market value of position.
        margin: Amount of margin in position.
        pnl: Unrealized profit and loss (PnL).
        entries: ``deque`` of position :class:`.Entry`\ s sorted in
            ascending chronological order.
        bars: Current number of bars since entry.
    """

    symbol: str
    shares: Decimal
    type: Literal["long", "short"]
    close: Decimal = field(default_factory=Decimal)
    equity: Decimal = field(default_factory=Decimal)
    market_value: Decimal = field(default_factory=Decimal)
    margin: Decimal = field(default_factory=Decimal)
    pnl: Decimal = field(default_factory=Decimal)
    entries: deque[Entry] = field(default_factory=deque)
    bars: int = 0
class Trade(NamedTuple):
    """Record of a completed trade (an entry paired with its exit).

    Attributes:
        id: Unique identifier.
        type: Type of trade, either ``long`` or ``short``.
        symbol: Ticker symbol of the trade.
        entry_date: Entry date.
        exit_date: Exit date.
        entry: Entry price.
        exit: Exit price.
        shares: Number of shares.
        pnl: Profit and loss (PnL).
        return_pct: Return measured in percentage.
        agg_pnl: Aggregate profit and loss (PnL) of the strategy after the
            trade.
        bars: Number of bars the trade was held.
        pnl_per_bar: Profit and loss (PnL) per bar held.
        stop: Type of stop that was triggered, if any.
    """

    id: int
    type: Literal["long", "short"]
    symbol: str
    entry_date: np.datetime64
    exit_date: np.datetime64
    entry: Decimal
    exit: Decimal
    shares: Decimal
    pnl: Decimal
    return_pct: Decimal
    agg_pnl: Decimal
    bars: int
    pnl_per_bar: Decimal
    stop: Optional[Literal["bar", "loss", "profit", "trailing"]]
class Order(NamedTuple):
    """Record of a filled order.

    Attributes:
        id: Unique identifier.
        type: Type of order, either ``buy`` or ``sell``.
        symbol: Ticker symbol of the order.
        date: Date the order was filled.
        shares: Number of shares bought or sold.
        limit_price: Limit price that was used for the order.
        fill_price: Price that the order was filled at.
        fees: Brokerage fees for order.
    """

    id: int
    type: Literal["buy", "sell"]
    symbol: str
    date: np.datetime64
    shares: Decimal
    limit_price: Optional[Decimal]
    fill_price: Decimal
    fees: Decimal
class PortfolioBar(NamedTuple):
    """Per-bar snapshot of :class:`.Portfolio` state.

    Attributes:
        date: Date of bar.
        cash: Amount of cash in :class:`.Portfolio`.
        equity: Amount of equity in :class:`.Portfolio`.
        margin: Amount of margin in :class:`.Portfolio`.
        market_value: Market value of :class:`.Portfolio`.
        pnl: Realized profit and loss (PnL) of :class:`.Portfolio`.
        unrealized_pnl: Unrealized profit and loss (PnL) of
            :class:`.Portfolio`.
        fees: Brokerage fees.
    """

    date: np.datetime64
    cash: Decimal
    equity: Decimal
    margin: Decimal
    market_value: Decimal
    pnl: Decimal
    unrealized_pnl: Decimal
    fees: Decimal
class PositionBar(NamedTuple):
    r"""Per-bar snapshot of an open :class:`.Position`\ 's state.

    Attributes:
        symbol: Ticker symbol of :class:`.Position`.
        date: Date of bar.
        long_shares: Number of shares long in :class:`.Position`.
        short_shares: Number of shares short in :class:`.Position`.
        close: Last close price of ``symbol``.
        equity: Amount of equity in :class:`.Position`.
        market_value: Market value of :class:`.Position`.
        margin: Amount of margin in :class:`.Position`.
        unrealized_pnl: Unrealized profit and loss (PnL) of
            :class:`.Position`.
    """

    symbol: str
    date: np.datetime64
    long_shares: Decimal
    short_shares: Decimal
    close: Decimal
    equity: Decimal
    market_value: Decimal
    margin: Decimal
    unrealized_pnl: Decimal
class _OrderResult(NamedTuple):
    """Outcome of applying an order against existing position entries.

    Attributes:
        filled_shares: Number of shares that were filled.
        rem_shares: Number of shares left over after the fill.
    """

    filled_shares: Decimal
    rem_shares: Decimal


def _calculate_pnl(
    price: Decimal,
    entries: Iterable[Entry],
    entry_type: Literal["short", "long"],
) -> Decimal:
    """Sums the profit and loss of ``entries`` valued at ``price``.

    Args:
        price: Price at which every entry is valued.
        entries: Entries to value.
        entry_type: ``long`` counts price increases as gains; ``short``
            counts price decreases as gains.

    Returns:
        Total PnL across all ``entries``.

    Raises:
        ValueError: If ``entry_type`` is neither ``long`` nor ``short``.
    """
    if entry_type == "long":

        def gain(entry: Entry) -> Decimal:
            return (price - entry.price) * entry.shares

    elif entry_type == "short":

        def gain(entry: Entry) -> Decimal:
            return (entry.price - price) * entry.shares

    else:
        raise ValueError(f"Unknown entry_type: {entry_type}")
    return Decimal(sum(map(gain, entries)))
class Portfolio:
    r"""Class representing a portfolio of holdings.

    The portfolio contains information about open positions and balances,
    and is also used to place buy and sell orders.

    Args:
        cash: Starting cash balance.
        fee_mode: Brokerage fee mode.
        fee_amount: Brokerage fee amount.
        enable_fractional_shares: Whether to enable trading fractional
            shares.
        max_long_positions: Maximum number of long :class:`.Position`\ s
            that can be held at a time. If ``None``, then unlimited.
        max_short_positions: Maximum number of short :class:`.Position`\ s
            that can be held at a time. If ``None``, then unlimited.

    Attributes:
        cash: Current cash balance.
        equity: Current amount of equity.
        market_value: Current market value. The market value is defined as
            the amount of equity held in cash and long positions added
            together with the unrealized PnL of all open short positions.
        fees: Current brokerage fees.
        orders: ``deque`` of all filled orders, sorted in ascending
            chronological order.
        trades: ``deque`` of all completed trades, in the order in which
            they were closed.
        margin: Current amount of margin held in open positions.
        pnl: Realized profit and loss (PnL).
        long_positions: ``dict`` mapping ticker symbols to open long
            :class:`.Position`\ s.
        short_positions: ``dict`` mapping ticker symbols to open short
            :class:`.Position`\ s.
        symbols: Ticker symbols of all currently open positions.
        bars: ``deque`` of snapshots of :class:`.Portfolio` state on every
            bar, sorted in ascending chronological order.
        position_bars: ``deque`` of snapshots of :class:`.Position` states
            on every bar, sorted in ascending chronological order.
        win_rate: Running win rate of trades.
        loss_rate: Running loss rate of trades.
    """

    def __init__(
        self,
        cash: float,
        fee_mode: Optional[FeeMode] = None,
        fee_amount: Optional[float] = None,
        enable_fractional_shares: bool = False,
        max_long_positions: Optional[int] = None,
        max_short_positions: Optional[int] = None,
    ):
        self.cash: Decimal = to_decimal(cash)
        self._initial_market_value = self.cash
        self._fee_mode = fee_mode
        self._fee_amount: Optional[Decimal] = (
            None if fee_amount is None else to_decimal(fee_amount)
        )
        self._enable_fractional_shares = enable_fractional_shares
        self.equity: Decimal = self.cash
        self.market_value: Decimal = self.cash
        self.fees = Decimal()
        self._max_long_positions = max_long_positions
        self._max_short_positions = max_short_positions
        self.orders: deque[Order] = deque()
        self.trades: deque[Trade] = deque()
        self.margin: Decimal = Decimal()
        self.pnl: Decimal = Decimal()
        self.long_positions: dict[str, Position] = {}
        self.short_positions: dict[str, Position] = {}
        self.symbols: set[str] = set()
        self.bars: deque[PortfolioBar] = deque()
        self.position_bars: deque[PositionBar] = deque()
        self.win_rate: Decimal = Decimal()
        self.loss_rate: Decimal = Decimal()
        self._wins: Decimal = Decimal()
        self._logger = StaticScope.instance().logger
        # Price-based stops keyed by stop ID. Bar stops are not stored here;
        # they live only in ``Entry.stops``.
        self._stop_data: dict[int, _StopData] = {}
        self._order_id: int = 0
        self._entry_id: int = 0
        self._trade_id: int = 0

    def _calculate_fees(self, fill_price: Decimal, shares: Decimal) -> Decimal:
        """Returns the brokerage fees for an order under the fee mode.

        Returns zero when either the fee mode or the fee amount is unset.

        Raises:
            ValueError: If the fee mode is not a recognized mode.
        """
        fees = Decimal()
        if self._fee_mode is None or self._fee_amount is None:
            return fees
        if self._fee_mode == FeeMode.ORDER_PERCENT:
            fees = self._fee_amount / Decimal(100) * fill_price * shares
        elif self._fee_mode == FeeMode.PER_ORDER:
            fees = self._fee_amount
        elif self._fee_mode == FeeMode.PER_SHARE:
            fees = self._fee_amount * shares
        else:
            raise ValueError(f"Unknown FeeMode: {self._fee_mode!r}")
        return fees

    def _verify_input(
        self,
        shares: Union[int, float, Decimal],
        fill_price: Decimal,
        limit_price: Optional[Decimal],
    ):
        """Validates order arguments.

        Raises:
            ValueError: If ``shares`` is negative, or if ``fill_price`` or
                a provided ``limit_price`` is not positive.
        """
        if shares < 0:
            raise ValueError(f"Shares cannot be negative: {shares}")
        if fill_price <= 0:
            raise ValueError(f"Fill price must be > 0: {fill_price}")
        if limit_price is not None and limit_price <= 0:
            raise ValueError(f"Limit price must be > 0: {limit_price}")

    def _add_entry(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        price: Decimal,
        type: Literal["long", "short"],
        pos: Position,
    ) -> Entry:
        """Creates a new :class:`.Entry` and appends it to ``pos``."""
        self._entry_id += 1
        entry = Entry(
            id=self._entry_id,
            symbol=symbol,
            shares=shares,
            price=price,
            date=date,
            type=type,
        )
        pos.entries.append(entry)
        return entry

    def _add_order(
        self,
        date: np.datetime64,
        symbol: str,
        type: Literal["buy", "sell"],
        limit_price: Optional[Decimal],
        fill_price: Decimal,
        shares: Decimal,
    ) -> Order:
        """Records a filled :class:`.Order` and accumulates its fees."""
        self._order_id += 1
        fees = self._calculate_fees(fill_price, shares)
        order = Order(
            id=self._order_id,
            date=date,
            symbol=symbol,
            type=type,
            limit_price=limit_price,
            fill_price=fill_price,
            shares=shares,
            fees=fees,
        )
        self.orders.append(order)
        self.fees += fees
        return order

    def _add_trade(
        self,
        type: Literal["long", "short"],
        symbol: str,
        entry_date: np.datetime64,
        exit_date: np.datetime64,
        entry_price: Decimal,
        exit_price: Decimal,
        shares: Decimal,
        pnl: Decimal,
        return_pct: Decimal,
        agg_pnl: Decimal,
        bars: int,
        pnl_per_bar: Decimal,
        stop_type: Optional[StopType],
    ):
        """Records a completed :class:`.Trade` and updates win/loss rates.

        A trade counts as a win only when its ``pnl`` is strictly positive.
        """
        self._trade_id += 1
        trade = Trade(
            id=self._trade_id,
            type=type,
            symbol=symbol,
            entry_date=entry_date,
            exit_date=exit_date,
            entry=entry_price,
            exit=exit_price,
            shares=shares,
            pnl=pnl,
            return_pct=return_pct,
            agg_pnl=agg_pnl,
            bars=bars,
            pnl_per_bar=pnl_per_bar,
            stop=None if stop_type is None else stop_type.value,
        )
        self.trades.append(trade)
        if pnl > 0:
            self._wins += 1
        self.win_rate = self._wins / len(self.trades)
        self.loss_rate = 1 - self.win_rate

    def _get_stop_amount(self, stop: Stop, price: Decimal) -> Decimal:
        """Returns the stop distance from ``price``.

        ``percent`` takes precedence over ``points`` when both are set.

        Raises:
            ValueError: If neither ``percent`` nor ``points`` is set.
        """
        if stop.percent is not None:
            return price * stop.percent / 100
        elif stop.points is not None:
            return stop.points
        else:
            raise ValueError("Stop amount not set.")

    def _add_stops(self, entry: Entry, stops: Iterable[Stop]):
        """Attaches ``stops`` to ``entry`` and registers their price levels.

        Bar stops are attached to the entry only. Every other stop also
        gets a :class:`._StopData` record holding its trigger value.

        Raises:
            ValueError: If a stop ID is already registered.
        """
        for stop in stops:
            if stop.id in self._stop_data:
                raise ValueError(f"Duplicate stop ID: {stop.id}")
            entry.stops.add(stop)
            if stop.stop_type == StopType.BAR:
                continue
            amount = self._get_stop_amount(stop, entry.price)
            # Long profit targets and short loss/trailing stops sit above
            # the entry price; all other stops sit below it.
            if (
                stop.pos_type == "long" and stop.stop_type == StopType.PROFIT
            ) or (
                stop.pos_type == "short"
                and (
                    stop.stop_type == StopType.LOSS
                    or stop.stop_type == StopType.TRAILING
                )
            ):
                stop_value = entry.price + amount
            else:
                stop_value = entry.price - amount
            self._stop_data[stop.id] = _StopData(
                value=stop_value, stop=stop, entry=entry
            )

    def _remove_stop_data(self, entry: Entry):
        """Unregisters the stop data of every stop attached to ``entry``."""
        for stop in entry.stops:
            if stop.id in self._stop_data:
                del self._stop_data[stop.id]

    def _clamp_shares(self, fill_price: Decimal, shares: Decimal) -> Decimal:
        """Caps ``shares`` at the amount that current cash can buy.

        Uses whole shares unless fractional shares are enabled.
        """
        max_shares = (
            Decimal(self.cash / fill_price)
            if self._enable_fractional_shares
            else Decimal(self.cash // fill_price)
        )
        return min(shares, max_shares)

    def buy(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        fill_price: Decimal,
        limit_price: Optional[Decimal] = None,
        stops: Optional[Iterable[Stop]] = None,
    ) -> Optional[Order]:
        r"""Places a buy order.

        Shares are first used to cover any open short position in
        ``symbol``; the remainder is bought long.

        Args:
            date: Date when the :class:`.Order` is placed.
            symbol: Ticker symbol to buy.
            shares: Number of shares to buy.
            fill_price: If filled, the price used to fill the
                :class:`.Order`.
            limit_price: Limit price of the :class:`.Order`.
            stops: :class:`.Stop`\ s to set on the :class:`.Entry` created
                from the :class:`.Order`, if filled.

        Returns:
            :class:`.Order` if the order was filled, otherwise ``None``.

        Raises:
            ValueError: If ``shares`` is negative, or if ``fill_price`` or
                ``limit_price`` is not positive.
        """
        self._verify_input(shares, fill_price, limit_price)
        self._logger.debug_place_buy_order(
            date=date,
            symbol=symbol,
            shares=shares,
            fill_price=fill_price,
            limit_price=limit_price,
        )
        if limit_price is not None and limit_price < fill_price:
            return None
        if shares == 0:
            return None
        covered = self._cover(date, symbol, shares, fill_price)
        bought_shares = self._buy(
            date, symbol, covered.rem_shares, fill_price, limit_price, stops
        )
        if not covered.filled_shares and not bought_shares:
            return None
        order = self._add_order(
            date=date,
            symbol=symbol,
            type="buy",
            limit_price=limit_price,
            fill_price=fill_price,
            shares=covered.filled_shares + bought_shares,
        )
        return order

    def _cover(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        fill_price: Decimal,
    ) -> _OrderResult:
        """Covers open short entries of ``symbol``, oldest first.

        Returns:
            The number of shares covered and the number left over.
        """
        if symbol not in self.short_positions:
            return _OrderResult(Decimal(), shares)
        rem_shares = shares
        if rem_shares <= 0:
            return _OrderResult(Decimal(), shares)
        pos = self.short_positions[symbol]
        # Stop once every share has been applied. Otherwise an order that
        # exactly consumes the oldest entries would fall through to the
        # partial-exit branch and record a zero-share trade.
        while pos.entries and rem_shares > 0:
            entry = pos.entries[0]
            if rem_shares >= entry.shares:
                rem_shares -= entry.shares
                self._exit_short(
                    date, pos, entry, entry.shares, fill_price, stop_type=None
                )
                self._remove_stop_data(entry)
                pos.entries.popleft()
            else:
                self._exit_short(
                    date, pos, entry, rem_shares, fill_price, stop_type=None
                )
                rem_shares = Decimal()
                break
        self._update_position(pos)
        return _OrderResult(shares - rem_shares, rem_shares)

    def _exit_short(
        self,
        date: np.datetime64,
        pos: Position,
        entry: Entry,
        shares: Decimal,
        fill_price: Decimal,
        stop_type: Optional[StopType],
    ):
        """Closes ``shares`` of a short ``entry`` and records the trade.

        Only the realized PnL is added to cash.
        """
        order_amount = shares * fill_price
        entry_amount = shares * entry.price
        entry_pnl = entry_amount - order_amount
        self.pnl += entry_pnl
        self.cash += entry_pnl
        pos.shares -= shares
        entry.shares -= shares
        # Entries held for zero bars report their full PnL per bar.
        pnl_per_bar = entry_pnl if not entry.bars else entry_pnl / entry.bars
        return_pct = ((entry.price / fill_price) - 1) * 100
        self._add_trade(
            type=entry.type,
            symbol=entry.symbol,
            entry_date=entry.date,
            exit_date=date,
            entry_price=entry.price,
            exit_price=fill_price,
            shares=shares,
            pnl=entry_pnl,
            return_pct=return_pct,
            agg_pnl=self.pnl,
            bars=entry.bars,
            pnl_per_bar=pnl_per_bar,
            stop_type=stop_type,
        )

    def _buy(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        fill_price: Decimal,
        limit_price: Optional[Decimal],
        stops: Optional[Iterable[Stop]],
    ) -> Decimal:
        """Opens or adds to a long position in ``symbol``.

        The share count is reduced to what cash can afford. Nothing is
        bought when a new position would exceed ``max_long_positions``.

        Returns:
            Number of shares actually bought.
        """
        clamped_shares = self._clamp_shares(fill_price, shares)
        if clamped_shares < shares:
            self._logger.debug_buy_shares_exceed_cash(
                date=date,
                symbol=symbol,
                shares=shares,
                fill_price=fill_price,
                limit_price=limit_price,
                cash=self.cash,
                clamped_shares=clamped_shares,
            )
            shares = clamped_shares
        if shares <= 0:
            return Decimal()
        if (
            self._max_long_positions is not None
            and symbol not in self.long_positions
            and len(self.long_positions) == self._max_long_positions
        ):
            return Decimal()
        order_amount = shares * fill_price
        self.cash -= order_amount
        if symbol not in self.long_positions:
            self.symbols.add(symbol)
            pos = Position(symbol=symbol, shares=shares, type="long")
            self.long_positions[symbol] = pos
        else:
            pos = self.long_positions[symbol]
            pos.shares += shares
        entry = self._add_entry(
            date=date,
            symbol=symbol,
            shares=shares,
            price=fill_price,
            type="long",
            pos=pos,
        )
        if stops is not None:
            self._add_stops(entry, stops)
        return shares

    def sell(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        fill_price: Decimal,
        limit_price: Optional[Decimal] = None,
        stops: Optional[Iterable[Stop]] = None,
    ) -> Optional[Order]:
        r"""Places a sell order.

        Shares are first sold from any open long position in ``symbol``;
        the remainder is sold short.

        Args:
            date: Date when the :class:`.Order` is placed.
            symbol: Ticker symbol to sell.
            shares: Number of shares to sell.
            fill_price: If filled, the price used to fill the
                :class:`.Order`.
            limit_price: Limit price of the :class:`.Order`.
            stops: :class:`.Stop`\ s to set on the :class:`.Entry` created
                from the :class:`.Order`, if filled.

        Returns:
            :class:`.Order` if the order was filled, otherwise ``None``.

        Raises:
            ValueError: If ``shares`` is negative, or if ``fill_price`` or
                ``limit_price`` is not positive.
        """
        self._verify_input(shares, fill_price, limit_price)
        self._logger.debug_place_sell_order(
            date=date,
            symbol=symbol,
            shares=shares,
            fill_price=fill_price,
            limit_price=limit_price,
        )
        if limit_price is not None and limit_price > fill_price:
            return None
        if shares == 0:
            return None
        sold = self._sell_existing(date, symbol, shares, fill_price)
        short_shares = self._short(
            date, symbol, sold.rem_shares, fill_price, stops
        )
        if not sold.filled_shares and not short_shares:
            return None
        order = self._add_order(
            date=date,
            symbol=symbol,
            type="sell",
            limit_price=limit_price,
            fill_price=fill_price,
            shares=sold.filled_shares + short_shares,
        )
        return order

    def _sell_existing(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        fill_price: Decimal,
    ) -> _OrderResult:
        """Sells open long entries of ``symbol``, oldest first.

        Returns:
            The number of shares sold and the number left over.
        """
        if symbol not in self.long_positions:
            return _OrderResult(Decimal(), shares)
        rem_shares = shares
        pos = self.long_positions[symbol]
        # Stop once every share has been applied. Otherwise an order that
        # exactly consumes the oldest entries would fall through to the
        # partial-exit branch and record a zero-share trade.
        while pos.entries and rem_shares > 0:
            entry = pos.entries[0]
            if rem_shares >= entry.shares:
                rem_shares -= entry.shares
                self._exit_long(
                    date, pos, entry, entry.shares, fill_price, stop_type=None
                )
                self._remove_stop_data(entry)
                pos.entries.popleft()
            else:
                self._exit_long(
                    date, pos, entry, rem_shares, fill_price, stop_type=None
                )
                rem_shares = Decimal()
                break
        self._update_position(pos)
        return _OrderResult(shares - rem_shares, rem_shares)

    def _exit_long(
        self,
        date: np.datetime64,
        pos: Position,
        entry: Entry,
        shares: Decimal,
        fill_price: Decimal,
        stop_type: Optional[StopType],
    ):
        """Closes ``shares`` of a long ``entry`` and records the trade.

        The full sale proceeds are added to cash.
        """
        order_amount = shares * fill_price
        entry_amount = shares * entry.price
        entry_pnl = order_amount - entry_amount
        self.pnl += entry_pnl
        self.cash += order_amount
        pos.shares -= shares
        entry.shares -= shares
        # Entries held for zero bars report their full PnL per bar.
        pnl_per_bar = entry_pnl if not entry.bars else entry_pnl / entry.bars
        return_pct = ((fill_price / entry.price) - 1) * 100
        self._add_trade(
            type=entry.type,
            symbol=entry.symbol,
            entry_date=entry.date,
            exit_date=date,
            entry_price=entry.price,
            exit_price=fill_price,
            shares=shares,
            pnl=entry_pnl,
            return_pct=return_pct,
            agg_pnl=self.pnl,
            bars=entry.bars,
            pnl_per_bar=pnl_per_bar,
            stop_type=stop_type,
        )

    def _update_position(self, pos: Position):
        """Drops ``pos`` from the portfolio once it has no entries left.

        The symbol is also removed from ``symbols`` when neither a long
        nor a short position remains for it.
        """
        if pos.entries:
            return
        if pos.type == "long":
            if pos.symbol in self.long_positions:
                del self.long_positions[pos.symbol]
        else:
            if pos.symbol in self.short_positions:
                del self.short_positions[pos.symbol]
        if (
            pos.symbol in self.symbols
            and pos.symbol not in self.long_positions
            and pos.symbol not in self.short_positions
        ):
            self.symbols.remove(pos.symbol)

    def _short(
        self,
        date: np.datetime64,
        symbol: str,
        shares: Decimal,
        fill_price: Decimal,
        stops: Optional[Iterable[Stop]],
    ) -> Decimal:
        """Opens or adds to a short position in ``symbol``.

        Nothing is shorted when a new position would exceed
        ``max_short_positions``.

        Returns:
            Number of shares actually sold short.
        """
        if shares <= 0:
            return Decimal()
        if (
            self._max_short_positions is not None
            and symbol not in self.short_positions
            and len(self.short_positions) == self._max_short_positions
        ):
            return Decimal()
        if symbol not in self.short_positions:
            self.symbols.add(symbol)
            pos = Position(symbol=symbol, shares=shares, type="short")
            self.short_positions[symbol] = pos
        else:
            pos = self.short_positions[symbol]
            pos.shares += shares
        entry = self._add_entry(
            date=date,
            symbol=symbol,
            shares=shares,
            price=fill_price,
            type="short",
            pos=pos,
        )
        if stops is not None:
            self._add_stops(entry, stops)
        return shares

    def exit_position(
        self,
        date: np.datetime64,
        symbol: str,
        buy_fill_price: Decimal,
        sell_fill_price: Decimal,
    ):
        """Exits any long and short positions for ``symbol`` at
        ``buy_fill_price`` and ``sell_fill_price``.

        Args:
            date: Date of the exit.
            symbol: Ticker symbol whose positions are exited.
            buy_fill_price: Fill price used to cover a short position.
            sell_fill_price: Fill price used to sell a long position.
        """
        if symbol in self.long_positions:
            self.sell(
                date=date,
                symbol=symbol,
                shares=self.long_positions[symbol].shares,
                fill_price=sell_fill_price,
            )
        if symbol in self.short_positions:
            self.buy(
                date=date,
                symbol=symbol,
                shares=self.short_positions[symbol].shares,
                fill_price=buy_fill_price,
            )

    def capture_bar(self, date: np.datetime64, df: pd.DataFrame):
        """Captures portfolio state of the current bar.

        Positions are revalued only for symbols that have a row for
        ``date`` in ``df``; otherwise their last computed values are used
        and no :class:`.PositionBar` is recorded for that symbol.

        Args:
            date: Date of current bar.
            df: :class:`pandas.DataFrame` containing close prices.
        """
        total_equity = self.cash
        total_market_value = total_equity
        total_margin = Decimal()
        for sym in self.symbols:
            index = (sym, date)
            close = None
            if index in df.index:
                close = to_decimal(df.loc[index][DataCol.CLOSE.value])
            pos_long_shares = Decimal()
            pos_short_shares = Decimal()
            pos_equity = Decimal()
            pos_market_value = Decimal()
            pos_margin = Decimal()
            pos_pnl = Decimal()
            if sym in self.long_positions:
                pos = self.long_positions[sym]
                if close is not None:
                    pos.equity = pos.shares * close
                    pos.market_value = pos.equity
                    pos.close = close
                    pos.pnl = _calculate_pnl(close, pos.entries, "long")
                pos_long_shares += pos.shares
                pos_equity += pos.equity
                pos_market_value += pos.market_value
                pos_pnl += pos.pnl
                total_equity += pos.equity
                total_market_value += pos.equity
            if sym in self.short_positions:
                pos = self.short_positions[sym]
                if close is not None:
                    pos.close = close
                    pos.pnl = _calculate_pnl(close, pos.entries, "short")
                    pos.margin = close * pos.shares
                    pos.market_value = pos.margin + pos.pnl
                pos_margin += pos.margin
                pos_short_shares += pos.shares
                pos_market_value += pos.market_value
                pos_pnl += pos.pnl
                total_margin += pos.margin
                # Shorts contribute only their unrealized PnL to the
                # portfolio's market value.
                total_market_value += pos.pnl
            if close is not None:
                self.position_bars.append(
                    PositionBar(
                        symbol=sym,
                        date=date,
                        long_shares=pos_long_shares,
                        short_shares=pos_short_shares,
                        close=close,
                        equity=pos_equity,
                        market_value=pos_market_value,
                        margin=pos_margin,
                        unrealized_pnl=pos_pnl,
                    )
                )
        self.equity = total_equity
        self.market_value = total_market_value
        self.margin = total_margin
        self.bars.append(
            PortfolioBar(
                date=date,
                cash=self.cash,
                equity=self.equity,
                market_value=self.market_value,
                margin=self.margin,
                pnl=self.equity - self._initial_market_value,
                unrealized_pnl=self.market_value - self.equity,
                fees=self.fees,
            )
        )

    def incr_bars(self):
        """Increments the number of bars held by every trade entry."""
        for pos in itertools.chain(
            self.long_positions.values(), self.short_positions.values()
        ):
            pos.bars += 1
            for entry in pos.entries:
                entry.bars += 1

    def remove_stop(self, stop_id: int) -> bool:
        """Removes a :class:`.Stop` with ``stop_id``.

        Returns:
            ``True`` if a stop was found and removed, otherwise ``False``.
        """
        stop_data = self._stop_data.pop(stop_id, None)
        if stop_data is not None:
            stop_data.entry.stops.discard(stop_data.stop)
            return True
        # Bar stops are never registered in ``_stop_data``, so look for
        # the stop on the open entries themselves.
        for pos in itertools.chain(
            self.long_positions.values(), self.short_positions.values()
        ):
            for entry in pos.entries:
                match = next(
                    (stop for stop in entry.stops if stop.id == stop_id),
                    None,
                )
                if match is not None:
                    entry.stops.remove(match)
                    return True
        return False

    def remove_stops(
        self,
        val: Union[str, Position, Entry],
        stop_type: Optional[StopType] = None,
    ):
        r"""Removes :class:`.Stop`\ s.

        Args:
            val: Ticker symbol, :class:`.Position`, or :class:`.Entry` for
                which to cancel stops.
            stop_type: :class:`pybroker.common.StopType`. If ``None``, all
                stops are removed.
        """
        if isinstance(val, str):
            if val in self.long_positions:
                self._remove_position_stops(
                    self.long_positions[val], stop_type
                )
            if val in self.short_positions:
                self._remove_position_stops(
                    self.short_positions[val], stop_type
                )
        elif isinstance(val, Position):
            self._remove_position_stops(val, stop_type)
        elif isinstance(val, Entry):
            self._remove_entry_stops(val, stop_type)

    def _remove_position_stops(
        self, pos: Position, stop_type: Optional[StopType]
    ):
        """Removes stops of ``stop_type`` from every entry of ``pos``."""
        for entry in pos.entries:
            self._remove_entry_stops(entry, stop_type)

    def _remove_entry_stops(self, entry: Entry, stop_type: Optional[StopType]):
        """Removes stops from ``entry``.

        Removes all stops when ``stop_type`` is ``None``; otherwise removes
        every stop whose type equals ``stop_type``.
        """
        if stop_type is None:
            self._remove_stop_data(entry)
            entry.stops.clear()
            return
        # Collect first so the set is not mutated while being iterated.
        matching = [
            stop for stop in entry.stops if stop.stop_type == stop_type
        ]
        for stop in matching:
            # Bar stops have no ``_stop_data`` record, hence the default.
            self._stop_data.pop(stop.id, None)
            entry.stops.remove(stop)

    def check_stops(self, date: np.datetime64, price_scope: PriceScope):
        """Checks whether stops are triggered.

        At most one stop is executed per entry. Entries whose stop
        executed are removed from their position afterwards.

        Args:
            date: Date of the current bar.
            price_scope: Scope used to fetch prices for the stop checks.
        """
        executed: deque[tuple[Position, Entry]] = deque()
        for pos in itertools.chain(
            self.long_positions.values(), self.short_positions.values()
        ):
            for entry in pos.entries:
                for stop in entry.stops:
                    if self._trigger_stop(date, price_scope, pos, entry, stop):
                        executed.append((pos, entry))
                        break
        # Positions and entries are mutated only after iteration is done.
        for pos, entry in executed:
            pos.entries.remove(entry)
            self._remove_stop_data(entry)
            self._update_position(pos)

    def _trigger_stop(
        self,
        date: np.datetime64,
        price_scope: PriceScope,
        pos: Position,
        entry: Entry,
        stop: Stop,
    ) -> bool:
        """Evaluates ``stop`` and, if triggered, exits the whole ``entry``.

        Returns:
            ``True`` if the stop was executed, otherwise ``False``.

        Raises:
            ValueError: If the stop type or position type is unknown.
        """
        if stop.stop_type == StopType.BAR:
            fill_price = self._trigger_bar_stop(stop, price_scope, entry)
        elif (
            stop.stop_type == StopType.LOSS
            or stop.stop_type == StopType.PROFIT
        ):
            fill_price = self._trigger_profit_or_loss_stop(stop, price_scope)
        elif stop.stop_type == StopType.TRAILING:
            fill_price = self._trigger_trailing_stop(stop, price_scope)
        else:
            raise ValueError(f"Unknown stop type: {stop.stop_type}")
        if fill_price is None:
            return False
        order_type: Literal["buy", "sell"]
        # Captured before the exit, which reduces ``entry.shares`` to zero.
        stop_shares = entry.shares
        if stop.pos_type == "long":
            if stop.limit_price is not None and fill_price < stop.limit_price:
                return False
            self._exit_long(
                date, pos, entry, entry.shares, fill_price, stop.stop_type
            )
            order_type = "sell"
        elif stop.pos_type == "short":
            if stop.limit_price is not None and fill_price > stop.limit_price:
                return False
            self._exit_short(
                date, pos, entry, entry.shares, fill_price, stop.stop_type
            )
            order_type = "buy"
        else:
            raise ValueError(f"Unknown pos_type: {stop.pos_type}")
        self._add_order(
            date=date,
            symbol=pos.symbol,
            type=order_type,
            limit_price=stop.limit_price,
            fill_price=fill_price,
            shares=stop_shares,
        )
        return True

    def _trigger_bar_stop(
        self, stop: Stop, price_scope: PriceScope, entry: Entry
    ) -> Optional[Decimal]:
        """Returns the fill price if ``entry`` has been held long enough.

        The fill price defaults to :attr:`PriceType.MIDDLE` when the stop
        does not set one.

        Raises:
            ValueError: If ``bars`` is not set on the stop.
        """
        if stop.bars is None:
            raise ValueError("Bars not set on bar stop.")
        if entry.bars >= stop.bars:
            return price_scope.fetch(
                stop.symbol,
                PriceType.MIDDLE
                if stop.fill_price is None
                else stop.fill_price,
            )
        return None

    def _trigger_profit_or_loss_stop(
        self, stop: Stop, price_scope: PriceScope
    ) -> Optional[Decimal]:
        """Returns the fill price if the stop level was reached.

        With ``exit_price`` set, that price is both the trigger check and
        the fill. Otherwise the bar's low/high is checked and the fill is
        the stop value bounded by the bar's range.
        """
        if (
            stop.pos_type == "long"
            and (
                stop.stop_type == StopType.LOSS
                or stop.stop_type == StopType.TRAILING
            )
        ) or (stop.pos_type == "short" and stop.stop_type == StopType.PROFIT):
            # Stops that trigger when price falls to the stop value.
            if stop.exit_price is not None:
                exit_price = price_scope.fetch(stop.symbol, stop.exit_price)
                if exit_price <= self._stop_data[stop.id].value:
                    return exit_price
            else:
                low = price_scope.fetch(stop.symbol, PriceType.LOW)
                high = price_scope.fetch(stop.symbol, PriceType.HIGH)
                if low <= self._stop_data[stop.id].value:
                    return min(self._stop_data[stop.id].value, high)
        elif (
            stop.pos_type == "long" and stop.stop_type == StopType.PROFIT
        ) or (
            stop.pos_type == "short"
            and (
                stop.stop_type == StopType.LOSS
                or stop.stop_type == StopType.TRAILING
            )
        ):
            # Stops that trigger when price rises to the stop value.
            if stop.exit_price is not None:
                exit_price = price_scope.fetch(stop.symbol, stop.exit_price)
                if exit_price >= self._stop_data[stop.id].value:
                    return exit_price
            else:
                low = price_scope.fetch(stop.symbol, PriceType.LOW)
                high = price_scope.fetch(stop.symbol, PriceType.HIGH)
                if high >= self._stop_data[stop.id].value:
                    return max(self._stop_data[stop.id].value, low)
        return None

    def _trigger_trailing_stop(
        self, stop: Stop, price_scope: PriceScope
    ) -> Optional[Decimal]:
        """Checks a trailing stop; if not triggered, ratchets its value.

        The stop value only moves in the position's favor: up for longs
        (from the bar high) and down for shorts (from the bar low).
        """
        fill_price = self._trigger_profit_or_loss_stop(stop, price_scope)
        if fill_price is not None:
            return fill_price
        if stop.pos_type == "long":
            high = price_scope.fetch(stop.symbol, PriceType.HIGH)
            amount = self._get_stop_amount(stop, high)
            self._stop_data[stop.id].value = max(
                high - amount, self._stop_data[stop.id].value
            )
        else:
            low = price_scope.fetch(stop.symbol, PriceType.LOW)
            amount = self._get_stop_amount(stop, low)
            self._stop_data[stop.id].value = min(
                low + amount, self._stop_data[stop.id].value
            )
        return None