Source code for project_x_py.indicators.base

"""
ProjectX Indicators - Base Classes

Author: @TexasCoding
Date: 2025-08-02

Overview:
    Provides abstract base classes and shared validation/utilities for all ProjectX
    indicator modules. Encapsulates error handling, cache logic, and call semantics
    for consistent and efficient indicator development. All custom indicators should
    inherit from these classes for uniformity and extensibility.

Key Features:
    - `BaseIndicator` with parameter validation, data checks, and result caching
    - Specialized subclasses: OverlapIndicator, MomentumIndicator, VolatilityIndicator,
      VolumeIndicator
    - Utility functions for safe division, rolling sums, and EMA alpha calculation
    - Standardized exception (`IndicatorError`) for all indicator errors

Example Usage:
    ```python
    from project_x_py.indicators.base import BaseIndicator


    class MyCustomIndicator(BaseIndicator):
        def calculate(self, data, period=10):
            self.validate_data(data, ["close"])
            self.validate_period(period)
            # ... custom calculation ...
    ```

See Also:
    - `project_x_py.indicators.momentum.MomentumIndicator`
    - `project_x_py.indicators.overlap.OverlapIndicator`
    - `project_x_py.indicators.volatility.VolatilityIndicator`
    - `project_x_py.indicators.volume.VolumeIndicator`
"""

import hashlib
from abc import ABC, abstractmethod
from typing import Any

import polars as pl


[docs] class IndicatorError(Exception): """Custom exception for indicator calculation errors."""
[docs] class BaseIndicator(ABC): """ Base class for all technical indicators. Provides common validation, error handling, caching, and utility methods that all indicators can inherit from. This abstract base class ensures consistent behavior across all indicators while providing performance optimizations through intelligent caching. Key Features: - Automatic parameter validation and data checking - Built-in caching system to avoid redundant calculations - Standardized error handling with IndicatorError exceptions - Support for both class-based and function-based usage - Memory-efficient operations with Polars DataFrames All custom indicators should inherit from this class or one of its specialized subclasses (OverlapIndicator, MomentumIndicator, etc.) for consistent behavior and optimal performance. """
[docs] def __init__(self, name: str, description: str = "") -> None: """ Initialize base indicator. Args: name: Indicator name description: Optional description """ self.name: str = name self.description: str = description # Cache for computed results to avoid recomputation self._cache: dict[str, pl.DataFrame] = {} self._cache_max_size: int = 100
[docs] def validate_data(self, data: pl.DataFrame, required_columns: list[str]) -> None: """ Validate input DataFrame and required columns. Args: data: Input DataFrame required_columns: List of required column names Raises: IndicatorError: If validation fails """ if data is None: raise IndicatorError("Data cannot be None") if data.is_empty(): raise IndicatorError("Data cannot be empty") for col in required_columns: if col not in data.columns: raise IndicatorError(f"Required column '{col}' not found in data")
[docs] def validate_period(self, period: int, min_period: int = 1) -> None: """ Validate period parameter. Args: period: Period value to validate min_period: Minimum allowed period Raises: IndicatorError: If period is invalid """ if not isinstance(period, int) or period < min_period: raise IndicatorError(f"Period must be an integer >= {min_period}")
[docs] def validate_data_length(self, data: pl.DataFrame, min_length: int) -> None: """ Validate that data has sufficient length for calculation. Args: data: Input DataFrame min_length: Minimum required data length Raises: IndicatorError: If data is too short """ if len(data) < min_length: raise IndicatorError( f"Insufficient data: need at least {min_length} rows, got {len(data)}" )
[docs] @abstractmethod def calculate(self, data: pl.DataFrame, **kwargs: Any) -> pl.DataFrame: """ Calculate the indicator values. This method must be implemented by all indicator subclasses. It should perform the core calculation logic for the specific indicator, including parameter validation, data processing, and result generation. The method should: 1. Validate input data and parameters using inherited validation methods 2. Perform the indicator-specific calculations 3. Return a DataFrame with the original data plus new indicator columns 4. Handle edge cases (insufficient data, invalid parameters, etc.) Args: data: Input DataFrame with OHLCV data (must contain required columns) **kwargs: Additional parameters specific to each indicator (period, thresholds, column names, etc.) Returns: pl.DataFrame: DataFrame with original data plus new indicator columns. The indicator values should be added as new columns with descriptive names (e.g., "rsi", "macd", "bb_upper"). Raises: IndicatorError: If data validation fails or calculation cannot proceed """
def _generate_cache_key(self, data: pl.DataFrame, **kwargs: Any) -> str: """ Generate a cache key for the given data and parameters. Args: data: Input DataFrame **kwargs: Additional parameters Returns: Cache key string """ # Create hash from DataFrame shape, column names, and last few rows data_bytes = data.tail(5).to_numpy().tobytes() data_str = f"{data.shape}{list(data.columns)}" data_hash = hashlib.md5(data_str.encode() + data_bytes).hexdigest() # Include parameters in the key params_str = "_".join(f"{k}={v}" for k, v in sorted(kwargs.items())) return f"{self.name}_{data_hash}_{params_str}" def _get_from_cache(self, cache_key: str) -> pl.DataFrame | None: """Get result from cache if available.""" return self._cache.get(cache_key) def _store_in_cache(self, cache_key: str, result: pl.DataFrame) -> None: """Store result in cache with size management.""" # Simple LRU cache management if len(self._cache) >= self._cache_max_size: # Remove oldest entry oldest_key = next(iter(self._cache)) del self._cache[oldest_key] self._cache[cache_key] = result
[docs] def __call__(self, data: pl.DataFrame, **kwargs: Any) -> pl.DataFrame: """ Allow indicator to be called directly with caching. Args: data: Input DataFrame **kwargs: Additional parameters Returns: DataFrame with indicator values """ # Check cache first cache_key = self._generate_cache_key(data, **kwargs) cached_result = self._get_from_cache(cache_key) if cached_result is not None: return cached_result # Calculate and cache result result = self.calculate(data, **kwargs) self._store_in_cache(cache_key, result) return result
[docs] class OverlapIndicator(BaseIndicator): """Base class for overlap study indicators (trend-following)."""
[docs] def __init__(self, name: str, description: str = "") -> None: super().__init__(name, description) self.category = "overlap"
[docs] class MomentumIndicator(BaseIndicator): """Base class for momentum indicators."""
[docs] def __init__(self, name: str, description: str = "") -> None: super().__init__(name, description) self.category = "momentum"
[docs] class VolatilityIndicator(BaseIndicator): """Base class for volatility indicators."""
[docs] def __init__(self, name: str, description: str = "") -> None: super().__init__(name, description) self.category = "volatility"
[docs] class VolumeIndicator(BaseIndicator): """Base class for volume indicators."""
[docs] def __init__(self, name: str, description: str = "") -> None: super().__init__(name, description) self.category = "volume"
# Utility functions for common calculations
[docs] def safe_division( numerator: pl.Expr, denominator: pl.Expr, default: float = 0.0 ) -> pl.Expr: """ Safe division that handles division by zero. This utility function creates a Polars expression that performs division while safely handling cases where the denominator is zero. It's commonly used in technical indicator calculations where division operations might encounter zero values. Args: numerator: Numerator expression (pl.Expr) denominator: Denominator expression (pl.Expr) default: Default value to return when denominator is zero (default: 0.0) Returns: pl.Expr: Polars expression that performs safe division, returning the default value when denominator is zero Example: >>> # Safe division in RSI calculation >>> gain = pl.col("close").diff().filter(pl.col("close").diff() > 0) >>> loss = -pl.col("close").diff().filter(pl.col("close").diff() < 0) >>> rs = safe_division(gain.rolling_mean(14), loss.rolling_mean(14)) """ return pl.when(denominator != 0).then(numerator / denominator).otherwise(default)
def rolling_sum_positive(expr: pl.Expr, window: int) -> pl.Expr: """ Calculate rolling sum of positive values only. Args: expr: Polars expression window: Rolling window size Returns: Polars expression for rolling sum of positive values """ return pl.when(expr > 0).then(expr).otherwise(0).rolling_sum(window_size=window) def rolling_sum_negative(expr: pl.Expr, window: int) -> pl.Expr: """ Calculate rolling sum of absolute negative values. Args: expr: Polars expression window: Rolling window size Returns: Polars expression for rolling sum of absolute negative values """ return pl.when(expr < 0).then(-expr).otherwise(0).rolling_sum(window_size=window)
[docs] def ema_alpha(period: int) -> float: """ Calculate EMA alpha (smoothing factor) from period. This utility function calculates the smoothing factor (alpha) used in Exponential Moving Average calculations. The alpha determines how much weight is given to recent prices versus older prices. Formula: alpha = 2 / (period + 1) Args: period: EMA period (number of periods for the moving average) Returns: float: Alpha value (smoothing factor) between 0 and 1 Example: >>> alpha = ema_alpha(14) # Returns 0.1333... >>> # Higher alpha = more weight to recent prices >>> alpha_short = ema_alpha(5) # 0.3333... >>> alpha_long = ema_alpha(50) # 0.0392... """ return 2.0 / (period + 1)