Commit 64467572 authored by lukas leufen's avatar lukas leufen

AbstractDataClass and CustomDataClass are now replaced by the...

AbstractDataClass and CustomDataClass are now replaced by the AbstractDataPreparation and DefaultDataPreparation and DataPreparationNeighbors
parent 14e96e9d
Pipeline #41215 failed with stages
in 6 minutes and 39 seconds
......@@ -14,5 +14,5 @@ from .data_preparation_join import DataPrepJoin
from .data_generator import DataGenerator
from .data_distributor import Distributor
from .iterator import KerasIterator, DataCollection
from .advanced_data_handling import DataPreparation
from .advanced_data_handling import DefaultDataPreparation
from .data_preparation import StationPrep
\ No newline at end of file
......@@ -4,7 +4,6 @@ __date__ = '2020-07-08'
from src.helpers import to_list, remove_items
from src.data_handling.data_preparation import StationPrep
import numpy as np
import xarray as xr
import pickle
......@@ -12,10 +11,13 @@ import os
import pandas as pd
import datetime as dt
import shutil
import inspect
from typing import Union, List, Tuple
import logging
from functools import reduce
from src.data_handling.data_preparation import StationPrep
number = Union[float, int]
num_or_list = Union[number, List[number]]
......@@ -45,25 +47,68 @@ class DummyDataSingleStation: # pragma: no cover
return self.name
class AbstractDataPreparation:
    """Abstract interface for data preparation classes.

    Subclasses wrap one or more station data objects and expose them as model
    inputs (X) and targets (Y).  The classmethods ``build``, ``requirements``
    and ``own_args`` let a factory create instances from a flat keyword
    dictionary.
    """

    # requirement names returned by `requirements`; subclasses extend this
    # with the argument names of the classes they wrap
    _requirements = []

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def build(cls, *args, **kwargs):
        """Return initialised class."""
        return cls(*args, **kwargs)

    @classmethod
    def requirements(cls):
        """Return requirements and own arguments without duplicates."""
        return list(set(cls._requirements + cls.own_args()))

    @classmethod
    def own_args(cls, *args):
        """Return this class's constructor argument names, excluding ``self`` and any names given in *args."""
        return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args))

    def get_X(self, upsampling=False, as_numpy=False):
        raise NotImplementedError

    def get_Y(self, upsampling=False, as_numpy=False):
        raise NotImplementedError
class DefaultDataPreparation(AbstractDataPreparation):
    """Default data preparation wrapping a single ``StationPrep`` instance.

    Builds its data collection from the wrapped class, harmonises the inputs,
    optionally multiplies extreme values, and pickles the result to
    ``data_path``.  Note: ``harmonise_X``, ``multiply_extremes`` and
    ``_store`` are defined further down in this class (outside this excerpt).
    """

    # every StationPrep constructor argument except self/station must be
    # supplied via kwargs when constructing through `build`
    _requirements = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])

    def __init__(self, id_class, data_path, min_length=0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
        """Wrap *id_class*, build the collection and persist it under *data_path*."""
        super().__init__()
        self.id_class = id_class
        self.interpolate_dim = "datetime"  # fixed time dimension name after the refactoring
        self.min_length = min_length
        # cached (extreme) input/target data; filled lazily
        self._X = None
        self._Y = None
        self._X_extreme = None
        self._Y_extreme = None
        self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle")
        self._collection = self._create_collection()
        self.harmonise_X()
        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim)
        self._store(fresh_store=True)

    @classmethod
    def build(cls, station, **kwargs):
        """Create a ``StationPrep`` for *station* from matching kwargs and wrap it."""
        sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
        sp = StationPrep(station, **sp_keys)
        dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        """Collection holds the wrapped class only; subclasses may add more members."""
        return [self.id_class]

    @classmethod
    def requirements(cls):
        """Parent requirements without ``id_class``, which ``build`` supplies itself."""
        return remove_items(super().requirements(), "id_class")

    def _reset_data(self):
        """Drop all cached (extreme) input and target data."""
        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
......@@ -99,10 +144,6 @@ class DataPreparation:
self._reset_data()
return X, Y
def _create_collection(self):
for data_class in [self.id_class] + self.neighbors:
self._collection.append(data_class)
def __repr__(self):
    """Return semicolon-joined string representations of all collection members."""
    return ";".join(str(member) for member in self._collection)
......@@ -221,7 +262,7 @@ def run_data_prep():
data.get_Y()
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
data_prep = DataPreparation(DummyDataSingleStation("main_class"), "datetime", path,
data_prep = DataPreparation(DummyDataSingleStation("main_class"), path,
neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")],
extreme_values=[1., 1.2])
data_prep.get_data(upsampling=False)
......@@ -238,54 +279,20 @@ def create_data_prep():
interpolate_dim = 'datetime'
window_history_size = 7
window_lead_time = 3
central_station = StationPrep(path, "DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim,
central_station = StationPrep("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim,
target_var, interpolate_dim, window_history_size, window_lead_time)
neighbor1 = StationPrep(path, "DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim,
neighbor1 = StationPrep("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim,
target_var, interpolate_dim, window_history_size, window_lead_time)
neighbor2 = StationPrep(path, "DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim,
neighbor2 = StationPrep("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim,
target_var, interpolate_dim, window_history_size, window_lead_time)
data_prep = []
data_prep.append(DataPreparation(central_station, interpolate_dim, path, neighbors=[neighbor1, neighbor2]))
data_prep.append(DataPreparation(neighbor1, interpolate_dim, path, neighbors=[central_station, neighbor2]))
data_prep.append(DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station]))
data_prep.append(DataPreparation(central_station, path, neighbors=[neighbor1, neighbor2]))
data_prep.append(DataPreparation(neighbor1, path, neighbors=[central_station, neighbor2]))
data_prep.append(DataPreparation(neighbor2, path, neighbors=[neighbor1, central_station]))
return data_prep
class AbstractDataClass:
    """Base for callable data factories exposing a mutable requirements list."""

    def __init__(self):
        # argument names this factory needs; populated by subclasses
        self._requires = list()

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    @property
    def requirements(self):
        """Names of the keyword arguments required to call this factory."""
        return self._requires

    @requirements.setter
    def requirements(self, value):
        self._requires = value
class CustomDataClass(AbstractDataClass):
    """Factory that builds a ``StationPrep`` and wraps it in a ``DataPreparation``.

    Requirements are derived from the constructor signatures of both wrapped
    classes (minus ``self``/``station``/``id_class``, which this factory
    supplies itself).
    """

    def __init__(self):
        # NOTE: the redundant function-local `import inspect` was removed;
        # `inspect` is already imported at module level.
        super().__init__()
        self.sp_keys = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])
        self.dp_keys = remove_items(inspect.getfullargspec(DataPreparation).args, ["self", "id_class"])
        self.requirements = self.sp_keys + self.dp_keys

    def __call__(self, station, **kwargs):
        """Return a ``DataPreparation`` for *station* built from matching kwargs."""
        sp_keys = {k: kwargs[k] for k in self.sp_keys if k in kwargs}
        sp_keys["station"] = station
        sp = StationPrep(**sp_keys)
        dp_args = {k: kwargs[k] for k in self.dp_keys if k in kwargs}
        return DataPreparation(sp, **dp_args)
if __name__ == "__main__":
from src.data_handling.data_preparation import StationPrep
from src.data_handling.iterator import KerasIterator, DataCollection
......
......@@ -68,7 +68,7 @@ class AbstractStationPrep():
class StationPrep(AbstractStationPrep):
def __init__(self, data_path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
def __init__(self, station, data_path, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var,
interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs):
super().__init__() # path, station, statistics_per_var, transformation, **kwargs)
self.station_type = station_type
......@@ -93,7 +93,7 @@ class StationPrep(AbstractStationPrep):
self.label = None
self.observation = None
self.transformation = self.setup_transformation(transformation)
self.transformation = None # self.setup_transformation(transformation)
self.kwargs = kwargs
self.kwargs["overwrite_local_data"] = overwrite_local_data
......
__author__ = 'Lukas Leufen'
__date__ = '2020-07-17'
from src.helpers import to_list, remove_items
from src.data_handling.data_preparation import StationPrep
from src.data_handling.advanced_data_handling import AbstractDataPreparation, DefaultDataPreparation
import numpy as np
import xarray as xr
import pickle
import os
import shutil
import inspect
from typing import Union, List, Tuple
import logging
from functools import reduce
number = Union[float, int]
num_or_list = Union[number, List[number]]
class DataPreparationNeighbors(DefaultDataPreparation):
    """Data preparation wrapping a main station plus optional neighbor stations."""

    def __init__(self, id_class, data_path, neighbors=None, min_length=0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
        # neighbors must be set before super().__init__, because the parent
        # constructor calls _create_collection(), which reads self.neighbors
        self.neighbors = to_list(neighbors) if neighbors is not None else []
        super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values,
                         extremes_on_right_tail_only=extremes_on_right_tail_only)

    @classmethod
    def build(cls, station, **kwargs):
        """Create ``StationPrep`` objects for *station* and all given neighbors, then wrap them."""
        sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
        sp = StationPrep(station, **sp_keys)
        # `or []` also tolerates an explicit neighbors=None in kwargs; the
        # original `for/else` had no `break`, so the else-branch always ran
        # and the assignment belongs plainly after the loop
        n_list = [StationPrep(neighbor, **sp_keys) for neighbor in kwargs.get("neighbors") or []]
        kwargs["neighbors"] = n_list if len(n_list) > 0 else None
        dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        """Collection is the main class followed by all neighbors."""
        return [self.id_class] + self.neighbors
if __name__ == "__main__":
    # Demo: build a DataPreparationNeighbors for station DEBW011 with one
    # neighbor station, using the bundled test data directory.
    a = DataPreparationNeighbors
    requirements = a.requirements()  # argument names build() expects in kwargs
    # NOTE(review): both "path" and "data_path" point at testdata — presumably
    # only one of them is consumed by StationPrep; confirm which key is needed.
    kwargs = {"path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"),
              "station_type": None,
              "network": 'UBA',
              "sampling": 'daily',
              "target_dim": 'variables',
              "target_var": 'o3',
              "interpolate_dim": 'datetime',
              "window_history_size": 7,
              "window_lead_time": 3,
              "neighbors": ["DEBW034"],
              "data_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"),
              "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'},
              "transformation": None,}
    a_inst = a.build("DEBW011", **kwargs)
    print(a_inst)
......@@ -29,7 +29,7 @@ def run(stations=None,
model=None,
batch_size=None,
epochs=None,
data_preparation=None):
data_preparation=None,):
params = inspect.getfullargspec(DefaultWorkflow).args
kwargs = {k: v for k, v in locals().items() if k in params and v is not None}
......@@ -39,9 +39,4 @@ def run(stations=None,
if __name__ == "__main__":
from src.data_handling.advanced_data_handling import CustomDataClass
run(data_preparation=CustomDataClass, statistics_per_var={'o3': 'dma8eu'}, transformation={"scope": "data",
"method": "standardise",
"mean": 50,
"std": 50},
trainable=False, create_new_model=False)
run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, create_new_model=True)
......@@ -18,8 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D
DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
from src.data_handling import DataPrepJoin
from src.data_handling.advanced_data_handling import CustomDataClass
from src.data_handling.advanced_data_handling import DefaultDataPreparation
from src.run_modules.run_environment import RunEnvironment
from src.model_modules.model_class import MyLittleModel as VanillaModel
......@@ -301,9 +300,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("sampling", sampling)
self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
self._set_param("transformation", None, scope="preprocessing")
self._set_param("data_preparation", data_preparation() if data_preparation is not None else None,
default=CustomDataClass())
assert isinstance(getattr(self.data_store.get("data_preparation"), "requirements"), property) is False
self._set_param("data_preparation", data_preparation, default=DefaultDataPreparation)
# target
self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
......@@ -350,6 +347,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS,
scope="general.postprocessing")
self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing
# check variables, statistics and target variable
self._check_target_var()
......
......@@ -11,8 +11,7 @@ import numpy as np
import pandas as pd
from src.data_handling import DataGenerator
from src.data_handling import DataCollection, DataPreparation, StationPrep
from src.data_handling.advanced_data_handling import CustomDataClass
from src.data_handling import DataCollection
from src.helpers import TimeTracking
from src.configuration import path_config
from src.helpers.join import EmptyQueryResult
......@@ -260,10 +259,10 @@ class PreProcessing(RunEnvironment):
logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}")
collection = DataCollection()
valid_stations = []
kwargs = self.data_store.create_args_dict(data_preparation.requirements, scope=set_name)
kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
for station in set_stations:
try:
dp = data_preparation(station, **kwargs)
dp = data_preparation.build(station, **kwargs)
collection.add(dp)
valid_stations.append(station)
except (AttributeError, EmptyQueryResult):
......
......@@ -26,4 +26,4 @@ class Workflow:
"""Run workflow embedded in a run environment and according to the stage's ordering."""
with RunEnvironment():
for stage, kwargs in self._registry.items():
stage(**kwargs)
\ No newline at end of file
stage(**kwargs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment