from typing import Literal, Optional, Union
import cloudpickle
import onetick.py as otp
import pandas as pd
from onetick.ml.interfaces import BaseDatafeed
from onetick.ml.interfaces.data_pipelines import (BaseFeatures, BasePipelineOperator, BasePreprocess)
class BaseOnetickLoader(BaseDatafeed):
def __init__(self, **kwargs):
        # make sure all kwargs are set as attributes of self
for key, value in kwargs.items():
if not hasattr(self, key):
setattr(self, key, value)
# set defaults
self.timezone = kwargs.get('timezone', 'EST5EDT')
self.symbols = kwargs.get('symbols', ['AAPL'])
if isinstance(self.symbols, str):
self.symbols = [self.symbols]
super().__init__(**kwargs)
    def get_source(self) -> otp.DataSource:
"""Generate otp.Source for further processing and loading."""
raise NotImplementedError
def run(self, src):
src = self.merge_symbols(src)
        # if the datafeed is not split, assume the whole dataset is train
if self.schema.set_name not in src.schema:
src[self.schema.set_name] = "TRAIN"
# src = src[self.schema.get_all_columns()]
run_kwargs = {}
if len(self.symbols) == 1:
run_kwargs = dict(symbols=self.symbols)
df = otp.run(src,
# apply_times_daily=self.apply_times_daily,
symbol_date=self.end,
timezone=self.timezone,
# the minute bar for 9:30-9:31 has the timestamp of 9:31
start=self.start,
end=self.end,
**run_kwargs)
return df
def merge_symbols(self, src):
if len(self.symbols) > 1:
src = otp.merge([src], symbols=self.symbols, identify_input_ts=True)
src.drop(columns=['TICK_TYPE'], inplace=True)
elif len(self.symbols) == 1:
src['SYMBOL_NAME'] = self.symbols[0]
return src
    def load(self):
        """
        Main method used to load data.

        Returns
        -------
        result: otp.Source
            Configured source with schema data set, ready to be executed
            with `run`.
        """
        # set schema data
        self.schema.symbols = self.symbols
        self.schema.db = self.db
        self.schema.tick_type = self.tick_type
        src = self.get_source()
        return src
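
# Usage sketch (illustrative): a concrete datafeed only has to implement
# ``get_source``; ``load`` fills in the schema (provided by BaseDatafeed)
# and ``run`` merges symbols and executes the query. The class below is
# hypothetical.
#
#     class MyTradesLoader(BaseOnetickLoader):
#         def get_source(self) -> otp.DataSource:
#             return otp.DataSource(db=self.db, tick_type=self.tick_type)
#
#     feed = MyTradesLoader(db='NYSE_TAQ', tick_type='TRD',
#                           symbols=['AAPL', 'MSFT'],
#                           start=otp.dt(2022, 3, 1, 9, 30),
#                           end=otp.dt(2022, 3, 1, 16, 0))
#     src = feed.load()   # builds the otp.Source with schema data set
#     df = feed.run(src)  # executes the query and returns a pandas DataFrame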
class OneTickBarsDatafeedOT(BaseOnetickLoader):
"""
OneTick datafeed with bars (Open, High, Low, Close, Volume, Trade Count).
Parameters
----------
db : str
Name for database to use.
Default: 'NYSE_TAQ_BARS'.
tick_type: str
Tick type to load.
Default: 'TRD_1M'.
symbols: List[str]
List of symbols to load.
Default: ['AAPL'].
start: otp.datetime
Start datetime.
Default: `datetime(2022, 3, 1, 9, 30)`
end: otp.datetime
End datetime.
Default: `datetime(2022, 3, 10, 16, 0)`
bucket: int
Bucket size used to aggregate data (timeframe).
Default: 600.
bucket_time: str
Bucket time to use: `start` or `end`.
Default: `start`.
timezone: str
Timezone to use.
Default: 'EST5EDT'.
columns: list
List of columns to load.
apply_times_daily: bool
Apply times daily to the data, skipping data outside of the specified times for all days.
Default: True.
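
    Examples
    --------
    A minimal usage sketch; assumes access to the 'NYSE_TAQ_BARS' database
    and a configured OneTick environment::

        feed = OneTickBarsDatafeedOT(symbols=['AAPL', 'MSFT'],
                                     start=otp.dt(2022, 3, 1, 9, 30),
                                     end=otp.dt(2022, 3, 3, 16, 0),
                                     bucket=300)
        src = feed.load()
        df = feed.run(src)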
"""
def __init__(self, **kwargs):
defaults = dict(db='NYSE_TAQ_BARS',
tick_type='TRD_1M',
symbols=['AAPL'],
start=otp.dt(2022, 3, 1, 9, 30),
end=otp.dt(2022, 3, 10, 16, 0),
bucket=600,
bucket_time="start",
timezone='EST5EDT',
apply_times_daily=True,
columns=['Time', 'SYMBOL_NAME',
'OPEN', 'HIGH', 'LOW', 'CLOSE', 'TRADE_COUNT',
'VOLUME'])
defaults.update(kwargs)
super().__init__(**defaults)
    def get_source(self):
        data = otp.DataSource(db=self.db,
                              tick_type=self.tick_type)
        data["VOLUME"] = data["VOLUME"].apply(float)
        # keep only bars that contain at least one trade
        data, _ = data[data['TRADE_TICK_COUNT'] > 0]
# aggregate data by bucket_interval
data = data.agg({'OPEN': otp.agg.first(data['FIRST']),
'HIGH': otp.agg.max(data['HIGH']),
'LOW': otp.agg.min(data['LOW']),
'CLOSE': otp.agg.last(data['LAST']),
'VOLUME': otp.agg.sum(data['VOLUME']),
'TRADE_COUNT': otp.agg.sum(data['TRADE_TICK_COUNT'])},
bucket_interval=self.bucket,
bucket_time=self.bucket_time,
)
# apply values adjustments (splits, dividends, etc.)
data = otp.functions.corp_actions(data,
adjustment_date=int(self.end.strftime('%Y%m%d')),
adjustment_date_tz="GMT",
adjust_rule='SIZE',
fields='VOLUME')
data = otp.functions.corp_actions(data,
adjustment_date=int(self.end.strftime('%Y%m%d')),
adjustment_date_tz="GMT",
adjust_rule='PRICE',
fields='OPEN,HIGH,LOW,CLOSE')
# filter out data outside of the specified times
if self.apply_times_daily:
data = data.time_filter(start_time=self.start.strftime('%H%M%S%f')[:-3],
end_time=self.end.strftime('%H%M%S%f')[:-3],
timezone=self.timezone)
# mark VOLUME as nan for empty bars (needed for filtering after lags)
# TODO Use holiday calendar: pip install exchange_calendars
empty, data = data[(data["VOLUME"] == 0) & (data["HIGH"] == otp.nan)]
empty["VOLUME"] = otp.nan
data = otp.merge([empty, data])
data, _ = data[data["VOLUME"] != otp.nan]
return data
class WindowFunction(BaseFeatures):
refit_on_predict = False
deprocessable = False
    def __init__(self,
                 columns: Optional[list] = None,
                 suffix: str = '_WINDOW_',
                 window_function: Literal['mean', 'std', 'min', 'max'] = 'mean',
                 window_size: int = 10):
self.columns = columns
self.suffix = suffix
self.window_function = window_function
self.window_size = window_size
super().__init__(columns=columns,
suffix=suffix,
window_function=window_function,
window_size=window_size)
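
# Usage sketch (illustrative): a 10-bar rolling mean over the CLOSE column.
# How the output column name is derived from ``suffix`` is handled by the
# BaseFeatures parent class.
#
#     rolling_mean = WindowFunction(columns=['CLOSE'],
#                                   window_function='mean',
#                                   window_size=10)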
class OIDSymbolOT(BasePreprocess):
"""
Adds OID column based on symbol name.
"""
def transform_ot(self, src: otp.Source):
        # is there another way to avoid putting the db name in symbols for otp.run()?
symbology = otp.SymbologyMapping(dest_symbology="OID",
tick_type=self.schema.db + "::ANY")
src = otp.join(src, symbology, on="all")
src.rename(columns={"MAPPED_SYMBOL_NAME": "OID"}, inplace=True)
src["OID"] = src["OID"].apply(int)
return src
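
# Usage sketch (illustrative; assumes OIDSymbolOT() takes no required
# constructor arguments and that ``self.schema.db`` has already been set,
# e.g. by BaseOnetickLoader.load):
#
#     oid = OIDSymbolOT()
#     src = oid.transform_ot(src)  # joins symbology mapping, adds int 'OID'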
class ExpressionOperator(BasePipelineOperator):
    def __init__(self,
                 expression,
                 new_column_name: str,
                 inverse_expression=None,
                 apply_kwargs: Optional[dict] = None):
super().__init__(expression=expression,
inverse_expression=inverse_expression,
apply_kwargs=apply_kwargs,
new_column_name=new_column_name)
self.expression = expression
self.inverse_expression = inverse_expression
self.new_column_name = new_column_name
self.apply_kwargs = apply_kwargs
if self.apply_kwargs is None:
self.apply_kwargs = {}
        self.columns = []  # keep the parent class from adding columns to the schema
def transform_ot(self, src: otp.Source):
src[self.new_column_name] = src.apply(self.expression, **self.apply_kwargs)
return src
def transform_pandas(self, df: pd.DataFrame):
kwargs = {'axis': 1}
kwargs.update(self.apply_kwargs)
df[self.new_column_name] = df.apply(self.expression, **kwargs)
return df
def transform(self, src: Union[pd.DataFrame, otp.Source]):
return super().transform(src)
def save_init_params(self, params, no_class=False):
params['expression'] = cloudpickle.dumps(params['expression'])
return super().save_init_params(params, no_class=no_class)
def inverse_transform(self, prediction_df: pd.DataFrame):
if not self.inverse_expression:
return prediction_df
kwargs = {'axis': 1}
kwargs.update(self.apply_kwargs)
prediction_df[self.new_column_name] = prediction_df.apply(self.inverse_expression, **kwargs)
return prediction_df
@classmethod
def restore_instance(cls, params):
params['expression'] = cloudpickle.loads(params['expression'])
return super().restore_instance(params)
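
# Usage sketch (illustrative): derive a bar-range column from HIGH and LOW.
# The same expression is applied tick-by-tick by ``transform_ot`` and
# row-wise (axis=1) by ``transform_pandas``.
#
#     bar_range = ExpressionOperator(
#         expression=lambda row: row['HIGH'] - row['LOW'],
#         new_column_name='BAR_RANGE')
#     df = bar_range.transform_pandas(df)   # pandas path
#     src = bar_range.transform_ot(src)     # OneTick path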
class ToPandas(BasePipelineOperator):
pass