Commit 64467572 authored by lukas leufen's avatar lukas leufen

AbstractDataClass and CustomDataClass are now replaced by the...

AbstractDataClass and CustomDataClass are now replaced by the AbstractDataPreparation and DefaultDataPreparation and DataPreparationNeighbors
parent 14e96e9d
Pipeline #41215 failed with stages
in 6 minutes and 39 seconds
...@@ -14,5 +14,5 @@ from .data_preparation_join import DataPrepJoin ...@@ -14,5 +14,5 @@ from .data_preparation_join import DataPrepJoin
from .data_generator import DataGenerator from .data_generator import DataGenerator
from .data_distributor import Distributor from .data_distributor import Distributor
from .iterator import KerasIterator, DataCollection from .iterator import KerasIterator, DataCollection
from .advanced_data_handling import DataPreparation from .advanced_data_handling import DefaultDataPreparation
from .data_preparation import StationPrep from .data_preparation import StationPrep
\ No newline at end of file
...@@ -4,7 +4,6 @@ __date__ = '2020-07-08' ...@@ -4,7 +4,6 @@ __date__ = '2020-07-08'
from src.helpers import to_list, remove_items from src.helpers import to_list, remove_items
from src.data_handling.data_preparation import StationPrep
import numpy as np import numpy as np
import xarray as xr import xarray as xr
import pickle import pickle
...@@ -12,10 +11,13 @@ import os ...@@ -12,10 +11,13 @@ import os
import pandas as pd import pandas as pd
import datetime as dt import datetime as dt
import shutil import shutil
import inspect
from typing import Union, List, Tuple from typing import Union, List, Tuple
import logging import logging
from functools import reduce from functools import reduce
from src.data_handling.data_preparation import StationPrep
number = Union[float, int] number = Union[float, int]
num_or_list = Union[number, List[number]] num_or_list = Union[number, List[number]]
...@@ -45,25 +47,68 @@ class DummyDataSingleStation: # pragma: no cover ...@@ -45,25 +47,68 @@ class DummyDataSingleStation: # pragma: no cover
return self.name return self.name
class AbstractDataPreparation:
    """Interface for data preparation classes.

    Subclasses implement ``get_X``/``get_Y`` and may extend ``_requirements``
    with the argument names they need in addition to their own ``__init__``
    parameters. ``build`` is the factory used by the preprocessing stage.
    """

    # extra argument names a subclass needs beyond its own __init__ signature
    _requirements = []

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def build(cls, *args, **kwargs):
        """Return initialised class."""
        return cls(*args, **kwargs)

    @classmethod
    def requirements(cls):
        """Return requirements and own arguments without duplicates."""
        return list(set(cls._requirements + cls.own_args()))

    @classmethod
    def own_args(cls, *args):
        """Return ``__init__`` argument names, excluding ``self`` and any names in *args."""
        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))

    def get_X(self, upsampling=False, as_numpy=False):
        """Return input data; must be implemented by subclasses."""
        raise NotImplementedError

    def get_Y(self, upsampling=False, as_numpy=False):
        """Return target data; must be implemented by subclasses."""
        raise NotImplementedError
class DefaultDataPreparation(AbstractDataPreparation):
_requirements = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])
def __init__(self, id_class, data_path, min_length=0,
extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
super().__init__()
self.id_class = id_class self.id_class = id_class
self.neighbors = to_list(neighbors) if neighbors is not None else [] self.interpolate_dim = "datetime"
self.interpolate_dim = interpolate_dim
self.min_length = min_length self.min_length = min_length
self._X = None self._X = None
self._Y = None self._Y = None
self._X_extreme = None self._X_extreme = None
self._Y_extreme = None self._Y_extreme = None
self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle") self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle")
self._collection = [] self._collection = self._create_collection()
self._create_collection()
self.harmonise_X() self.harmonise_X()
self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim) self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim)
self._store(fresh_store=True) self._store(fresh_store=True)
@classmethod
def build(cls, station, **kwargs):
sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
sp = StationPrep(station, **sp_keys)
dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
return cls(sp, **dp_args)
def _create_collection(self):
return [self.id_class]
@classmethod
def requirements(cls):
return remove_items(super().requirements(), "id_class")
def _reset_data(self): def _reset_data(self):
self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
...@@ -99,10 +144,6 @@ class DataPreparation: ...@@ -99,10 +144,6 @@ class DataPreparation:
self._reset_data() self._reset_data()
return X, Y return X, Y
def _create_collection(self):
for data_class in [self.id_class] + self.neighbors:
self._collection.append(data_class)
def __repr__(self): def __repr__(self):
return ";".join(list(map(lambda x: str(x), self._collection))) return ";".join(list(map(lambda x: str(x), self._collection)))
...@@ -221,7 +262,7 @@ def run_data_prep(): ...@@ -221,7 +262,7 @@ def run_data_prep():
data.get_Y() data.get_Y()
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
data_prep = DataPreparation(DummyDataSingleStation("main_class"), "datetime", path, data_prep = DataPreparation(DummyDataSingleStation("main_class"), path,
neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")], neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")],
extreme_values=[1., 1.2]) extreme_values=[1., 1.2])
data_prep.get_data(upsampling=False) data_prep.get_data(upsampling=False)
...@@ -238,54 +279,20 @@ def create_data_prep(): ...@@ -238,54 +279,20 @@ def create_data_prep():
interpolate_dim = 'datetime' interpolate_dim = 'datetime'
window_history_size = 7 window_history_size = 7
window_lead_time = 3 window_lead_time = 3
central_station = StationPrep(path, "DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim, central_station = StationPrep("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim,
target_var, interpolate_dim, window_history_size, window_lead_time) target_var, interpolate_dim, window_history_size, window_lead_time)
neighbor1 = StationPrep(path, "DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim, neighbor1 = StationPrep("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim,
target_var, interpolate_dim, window_history_size, window_lead_time) target_var, interpolate_dim, window_history_size, window_lead_time)
neighbor2 = StationPrep(path, "DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim, neighbor2 = StationPrep("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim,
target_var, interpolate_dim, window_history_size, window_lead_time) target_var, interpolate_dim, window_history_size, window_lead_time)
data_prep = [] data_prep = []
data_prep.append(DataPreparation(central_station, interpolate_dim, path, neighbors=[neighbor1, neighbor2])) data_prep.append(DataPreparation(central_station, path, neighbors=[neighbor1, neighbor2]))
data_prep.append(DataPreparation(neighbor1, interpolate_dim, path, neighbors=[central_station, neighbor2])) data_prep.append(DataPreparation(neighbor1, path, neighbors=[central_station, neighbor2]))
data_prep.append(DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station])) data_prep.append(DataPreparation(neighbor2, path, neighbors=[neighbor1, central_station]))
return data_prep return data_prep
class AbstractDataClass:
def __init__(self):
self._requires = []
def __call__(self, *args, **kwargs):
raise NotImplementedError
@property
def requirements(self):
return self._requires
@requirements.setter
def requirements(self, value):
self._requires = value
class CustomDataClass(AbstractDataClass):
def __init__(self):
import inspect
super().__init__()
self.sp_keys = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])
self.dp_keys = remove_items(inspect.getfullargspec(DataPreparation).args, ["self", "id_class"])
self.requirements = self.sp_keys + self.dp_keys
def __call__(self, station, **kwargs):
sp_keys = {k: kwargs[k] for k in self.sp_keys if k in kwargs}
sp_keys["station"] = station
sp = StationPrep(**sp_keys)
dp_args = {k: kwargs[k] for k in self.dp_keys if k in kwargs}
return DataPreparation(sp, **dp_args)
if __name__ == "__main__": if __name__ == "__main__":
from src.data_handling.data_preparation import StationPrep from src.data_handling.data_preparation import StationPrep
from src.data_handling.iterator import KerasIterator, DataCollection from src.data_handling.iterator import KerasIterator, DataCollection
......
...@@ -68,7 +68,7 @@ class AbstractStationPrep(): ...@@ -68,7 +68,7 @@ class AbstractStationPrep():
class StationPrep(AbstractStationPrep): class StationPrep(AbstractStationPrep):
def __init__(self, data_path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, def __init__(self, station, data_path, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs): interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs):
super().__init__() # path, station, statistics_per_var, transformation, **kwargs) super().__init__() # path, station, statistics_per_var, transformation, **kwargs)
self.station_type = station_type self.station_type = station_type
...@@ -93,7 +93,7 @@ class StationPrep(AbstractStationPrep): ...@@ -93,7 +93,7 @@ class StationPrep(AbstractStationPrep):
self.label = None self.label = None
self.observation = None self.observation = None
self.transformation = self.setup_transformation(transformation) self.transformation = None # self.setup_transformation(transformation)
self.kwargs = kwargs self.kwargs = kwargs
self.kwargs["overwrite_local_data"] = overwrite_local_data self.kwargs["overwrite_local_data"] = overwrite_local_data
......
__author__ = 'Lukas Leufen'
__date__ = '2020-07-17'
from src.helpers import to_list, remove_items
from src.data_handling.data_preparation import StationPrep
from src.data_handling.advanced_data_handling import AbstractDataPreparation, DefaultDataPreparation
import numpy as np
import xarray as xr
import pickle
import os
import shutil
import inspect
from typing import Union, List, Tuple
import logging
from functools import reduce
number = Union[float, int]
num_or_list = Union[number, List[number]]
class DataPreparationNeighbors(DefaultDataPreparation):
    """Data preparation that adds neighbouring stations to the main station's collection.

    Behaves like :class:`DefaultDataPreparation` but keeps a list of neighbour
    data classes that are included in the internal collection.
    """

    def __init__(self, id_class, data_path, neighbors=None, min_length=0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
        # neighbors must be set before super().__init__ because the parent
        # constructor calls _create_collection(), which reads self.neighbors
        self.neighbors = to_list(neighbors) if neighbors is not None else []
        super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values,
                         extremes_on_right_tail_only=extremes_on_right_tail_only)

    @classmethod
    def build(cls, station, **kwargs):
        """Create a StationPrep for *station* and each entry of kwargs["neighbors"], then wrap them.

        :param station: id of the main station
        :param kwargs: must contain all names from ``cls.requirements()``; may contain
            "neighbors" as a list of station ids
        :return: initialised DataPreparationNeighbors instance
        """
        sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
        sp = StationPrep(station, **sp_keys)
        # build neighbour StationPrep objects with the same keyword set
        # (fixed: original used a pointless for/else — the else branch always ran)
        n_list = [StationPrep(neighbor, **sp_keys) for neighbor in kwargs.get("neighbors", [])]
        kwargs["neighbors"] = n_list if len(n_list) > 0 else None
        dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        """Return main station plus all neighbours as the data collection."""
        return [self.id_class] + self.neighbors
if __name__ == "__main__":
    # Smoke test: build a DataPreparationNeighbors instance for station DEBW011
    # with one neighbour (DEBW034) from the bundled test data directory.
    a = DataPreparationNeighbors
    requirements = a.requirements()
    # NOTE(review): "path" and "data_path" both point at testdata — presumably one
    # feeds StationPrep and the other DataPreparationNeighbors; verify which is used.
    kwargs = {"path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"),
              "station_type": None,
              "network": 'UBA',
              "sampling": 'daily',
              "target_dim": 'variables',
              "target_var": 'o3',
              "interpolate_dim": 'datetime',
              "window_history_size": 7,
              "window_lead_time": 3,
              "neighbors": ["DEBW034"],
              "data_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"),
              "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'},
              "transformation": None,}
    a_inst = a.build("DEBW011", **kwargs)
    print(a_inst)
...@@ -29,7 +29,7 @@ def run(stations=None, ...@@ -29,7 +29,7 @@ def run(stations=None,
model=None, model=None,
batch_size=None, batch_size=None,
epochs=None, epochs=None,
data_preparation=None): data_preparation=None,):
params = inspect.getfullargspec(DefaultWorkflow).args params = inspect.getfullargspec(DefaultWorkflow).args
kwargs = {k: v for k, v in locals().items() if k in params and v is not None} kwargs = {k: v for k, v in locals().items() if k in params and v is not None}
...@@ -39,9 +39,4 @@ def run(stations=None, ...@@ -39,9 +39,4 @@ def run(stations=None,
if __name__ == "__main__": if __name__ == "__main__":
from src.data_handling.advanced_data_handling import CustomDataClass run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, create_new_model=True)
run(data_preparation=CustomDataClass, statistics_per_var={'o3': 'dma8eu'}, transformation={"scope": "data",
"method": "standardise",
"mean": 50,
"std": 50},
trainable=False, create_new_model=False)
...@@ -18,8 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D ...@@ -18,8 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D
DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
from src.data_handling import DataPrepJoin from src.data_handling.advanced_data_handling import DefaultDataPreparation
from src.data_handling.advanced_data_handling import CustomDataClass
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
from src.model_modules.model_class import MyLittleModel as VanillaModel from src.model_modules.model_class import MyLittleModel as VanillaModel
...@@ -301,9 +300,7 @@ class ExperimentSetup(RunEnvironment): ...@@ -301,9 +300,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("sampling", sampling) self._set_param("sampling", sampling)
self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
self._set_param("transformation", None, scope="preprocessing") self._set_param("transformation", None, scope="preprocessing")
self._set_param("data_preparation", data_preparation() if data_preparation is not None else None, self._set_param("data_preparation", data_preparation, default=DefaultDataPreparation)
default=CustomDataClass())
assert isinstance(getattr(self.data_store.get("data_preparation"), "requirements"), property) is False
# target # target
self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
...@@ -350,6 +347,7 @@ class ExperimentSetup(RunEnvironment): ...@@ -350,6 +347,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS, self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS,
scope="general.postprocessing") scope="general.postprocessing")
self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing
# check variables, statistics and target variable # check variables, statistics and target variable
self._check_target_var() self._check_target_var()
......
...@@ -11,8 +11,7 @@ import numpy as np ...@@ -11,8 +11,7 @@ import numpy as np
import pandas as pd import pandas as pd
from src.data_handling import DataGenerator from src.data_handling import DataGenerator
from src.data_handling import DataCollection, DataPreparation, StationPrep from src.data_handling import DataCollection
from src.data_handling.advanced_data_handling import CustomDataClass
from src.helpers import TimeTracking from src.helpers import TimeTracking
from src.configuration import path_config from src.configuration import path_config
from src.helpers.join import EmptyQueryResult from src.helpers.join import EmptyQueryResult
...@@ -260,10 +259,10 @@ class PreProcessing(RunEnvironment): ...@@ -260,10 +259,10 @@ class PreProcessing(RunEnvironment):
logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}") logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}")
collection = DataCollection() collection = DataCollection()
valid_stations = [] valid_stations = []
kwargs = self.data_store.create_args_dict(data_preparation.requirements, scope=set_name) kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
for station in set_stations: for station in set_stations:
try: try:
dp = data_preparation(station, **kwargs) dp = data_preparation.build(station, **kwargs)
collection.add(dp) collection.add(dp)
valid_stations.append(station) valid_stations.append(station)
except (AttributeError, EmptyQueryResult): except (AttributeError, EmptyQueryResult):
......
...@@ -26,4 +26,4 @@ class Workflow: ...@@ -26,4 +26,4 @@ class Workflow:
"""Run workflow embedded in a run environment and according to the stage's ordering.""" """Run workflow embedded in a run environment and according to the stage's ordering."""
with RunEnvironment(): with RunEnvironment():
for stage, kwargs in self._registry.items(): for stage, kwargs in self._registry.items():
stage(**kwargs) stage(**kwargs)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment