import copy
import datetime
import re
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Union

import onetick.py as otp
import pandas as pd

from onetick.ml.utils.paramsaver import BaseParameterSaver
from onetick.ml.utils.schema import DataSchema


class BasePipelineOperator(BaseParameterSaver, metaclass=ABCMeta):
    """Base class for all pipeline operators."""

    refit_on_predict = True
    deprocessable = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.schema = DataSchema()
self.columns = kwargs.get("columns", "__all__")
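        # Default suffix for generated column names, derived from the class
        # name; e.g. a (hypothetical) MinMaxScaler subclass yields the
        # suffix "_Minmaxscaler" because of str.capitalize().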
self.suffix = kwargs.get('suffix', "_" + self.__class__.__name__.capitalize())
self._columns_map = {}
def __str__(self):
return self.__class__.__name__ + f'({self.columns})'
def __repr__(self):
return self.__str__()
    def _set_column_map(self, src):
        # Maps each generated column name back to the column it was derived
        # from; used by inverse_transform() and update_target_columns().
        self._columns_map = dict(self.iter_column_pairs(src))
def get_new_column_name(self, column):
return f"{column}{self.suffix}"
def iter_column_pairs(self, src):
for column in self.column_names(src):
yield self.get_new_column_name(column), column
def _remove_utility_columns(self, columns):
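        """Drop internal (double-underscore), schema utility, and Time columns."""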
return [column for column in columns if not column.startswith("__")
and column not in self.schema.utility_columns
and column != "Time"]
def _get_columns_list(self, src) -> List[str]:
if isinstance(src, otp.Source):
return self._remove_utility_columns(list(src.schema))
if isinstance(src, pd.DataFrame):
return self._remove_utility_columns(list(src.columns))
raise TypeError(f"Unsupported type {type(src)} passed to _get_columns_list() method of {self}!")
def column_names(self, src) -> List[str]:
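        """Resolve ``self.columns`` into concrete column names from ``src``.

        The placeholders ``"__all__"``, ``"__features__"`` and
        ``"__targets__"`` select the data, feature, and target columns
        respectively; any other entry is matched against the data columns
        with ``re.fullmatch``, so plain names and regular expressions both
        work. For example (hypothetical columns), with data columns
        ``["PRICE", "ASK_PRICE", "SIZE"]`` the pattern ``".*PRICE"``
        selects ``["PRICE", "ASK_PRICE"]``.
        """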
data_columns = self._get_columns_list(src)
if self.columns == '__all__':
return data_columns
elif self.columns == "__features__":
return self.schema.features_columns
elif self.columns == "__targets__":
return self.schema.target_columns
elif isinstance(self.columns, str):
self.columns = [self.columns]
# match columns with regex/string
all_matched_columns = []
for column in self.columns:
matched_columns = []
for data_column in data_columns:
if re.fullmatch(column, data_column):
matched_columns.append(data_column)
if not matched_columns:
raise ValueError(f"No column found for '{column}' in {data_columns}! Requested by {self}.")
all_matched_columns += matched_columns
return all_matched_columns
def transform(self, src: Union[otp.Source, pd.DataFrame]) -> Union[otp.Source, pd.DataFrame]:
"""Transforms DataFrame or otp.Source.
Parameters
----------
src : DataFrame or otp.Source.
Data feed to be processed.
Returns
-------
DataFrame or otp.Source
Transformed data feed.
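
        Examples
        --------
        A minimal sketch (``op`` is a fitted operator and ``df`` a pandas
        frame with the expected columns; the names are illustrative):

        >>> processed = op.transform(df)  # doctest: +SKIP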
"""
if isinstance(src, otp.Source):
return self.transform_ot(src=src)
elif isinstance(src, pd.DataFrame):
return self.transform_pandas(df=src)
raise TypeError(f"Unsupported type {type(src)} passed to transform() method of {self}!")
def transform_ot(self, src: otp.Source):
raise NotImplementedError(f"transform_ot() is not implemented for {self.__class__.__name__}!")
def transform_pandas(self, df: pd.DataFrame):
raise NotImplementedError(f"transform_pandas() is not implemented for {self.__class__.__name__}!")
def fit(self, src: Union[pd.DataFrame, otp.Source]):
"""Fit on a given DataFrame or otp.Source.
Parameters
----------
src : DataFrame or otp.Source.
Data feed to be fitted on.
Returns
-------
DataFrame or otp.Source
Changed data feed (added columns, etc.)
"""
if isinstance(src, otp.Source):
return self.fit_ot(src)
elif isinstance(src, pd.DataFrame):
return self.fit_pandas(src)
raise TypeError(f"Unsupported type {type(src)}")
def fit_ot(self, src: otp.Source):
return src
def fit_pandas(self, df: pd.DataFrame):
return df
def fit_transform(self, src: Union[pd.DataFrame, otp.Source]):
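        """Fit on ``src``, then transform it and return the result."""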
src = self.fit(src)
src = self.transform(src)
return src
    def inverse_transform(self, prediction_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
        """Reverse-process a prediction dataframe.

        Copies values from each generated column back into the column it
        was derived from, according to the fitted column map.

        Parameters
        ----------
        prediction_df : pd.DataFrame, optional
            Prediction dataframe to be deprocessed, by default None.

        Returns
        -------
        pd.DataFrame
            Reverse-processed prediction dataframe (if deprocessable).
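
        Examples
        --------
        A minimal sketch (``op`` is a fitted, deprocessable operator and
        ``predictions`` a pandas frame of model output; the names are
        illustrative):

        >>> restored = op.inverse_transform(predictions)  # doctest: +SKIP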
"""
        if prediction_df is None:
            return prediction_df
        for column in prediction_df.columns:
            if column in self._columns_map:
                original_column = self._columns_map[column]
                prediction_df[original_column] = prediction_df[column]
        return prediction_df
    def update_target_columns(self):
        """Swap renamed target columns in the schema for their new names."""
        for new_column, column in self._columns_map.items():
if column in self.schema.target_columns:
# TODO preserve order
self.schema.target_columns.remove(column)
self.schema.target_columns.append(new_column)


class BaseDatafeed(BasePipelineOperator, metaclass=ABCMeta):
"""Base class for all datafeeds.
"""
@abstractmethod
def load(self, *args, **kwargs):
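        """Load the data feed; must be implemented by subclasses."""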
pass
@classmethod
def restore_instance(cls, params):
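        """Recreate an instance from saved parameters, rebuilding the
        ``start``/``end`` datetimes from the component lists produced by
        ``save_init_params``."""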
if 'start' in params:
params['start'] = otp.dt(*params['start'])
if 'end' in params:
params['end'] = otp.dt(*params['end'])
return cls(**params)
def save_init_params(self, params, no_class=False):
params = copy.deepcopy(params)
for arg in ['start', 'end']:
if arg in params and isinstance(params[arg], otp.types.datetime):
                # Serialize the datetime as a [year, month, day, hour, ...]
                # component list, dropping trailing zero components.
                a = params[arg]
                res = [a.year, a.month, a.day, a.hour, a.minute, a.second, a.microsecond, a.nanosecond]
                while res[-1] == 0:
                    res.pop(-1)
                params[arg] = res
return super().save_init_params(params, no_class=no_class)
    def merge_symbols(self, src):
        """Hook for merging per-symbol sources; the base implementation returns ``src`` unchanged."""
        return src

    def run(self, src):
        """Hook for running the feed; the base implementation returns ``src`` unchanged."""
        return src


class BasePreprocess(BasePipelineOperator, metaclass=ABCMeta):
"""Base class for all preprocessors.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.group_by_column = kwargs.pop('group_by_column', None)
self._proc_df = None
def __new__(cls, *args, **kwargs):
# We override this method in order to automatically create
# `GroupByColumn` classes instead when `group_by_column` is set.
# Based on https://github.com/encode/django-rest-framework/blob/master/rest_framework/serializers.py#L120
from onetick.ml.impl import GroupByColumn
group_by_column = kwargs.pop('group_by_column', False)
if group_by_column and cls != GroupByColumn:
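            # e.g. (illustrative) SomeScaler(group_by_column="SYMBOL") becomes
            # GroupByColumn(preprocessor=SomeScaler(...), columns="__all__").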
return GroupByColumn(
preprocessor=cls(*args, **kwargs),
columns=kwargs.get('columns', '__all__'))
instance = super().__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
    # Allow type checkers to treat preprocessor classes as generic.
    def __class_getitem__(cls, *args, **kwargs):
        return cls

    def fit(self, src: Union[pd.DataFrame, otp.Source]):
"""Fit on a given DataFrame or otp.Source.
Parameters
----------
src : DataFrame or otp.Source.
Data feed to be fitted on.
Returns
-------
DataFrame or otp.Source
Changed data feed (added columns, etc.)
"""
self._set_column_map(src)
return super().fit(src)


class BaseFeatures(BasePipelineOperator, metaclass=ABCMeta):
"""Base class for all feature extractors.
"""
def __init__(self, **kwargs):
self.added_features = []
self.schema = None
super().__init__(**kwargs)


class BaseSplitting(BasePipelineOperator, metaclass=ABCMeta):
"""Base class for all splitting methods.
"""


class BaseFilter(BasePipelineOperator, metaclass=ABCMeta):
"""Base class for filters.
"""