Commit 5c58d438 authored by lukas leufen

Merge branch 'lukas_issue197_feat_mixed-sampling' into 'develop'

Resolve "Mixed sampling types"

See merge request !182
parents c63cff8a fe40c9e1
Pipeline #51983 passed with stages
in 7 minutes and 51 seconds
......@@ -49,6 +49,7 @@ DEFAULT_NUMBER_OF_BOOTSTRAPS = 20
DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
"PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
"PlotAvailability"]
DEFAULT_SAMPLING = "daily"
def get_defaults():
......
......@@ -20,7 +20,7 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str:
:param create_new: Create new path if enabled
:param data_path: Parse your custom path (and therefore ignore preset paths fitting to known hosts)
:param sampling: sampling rate to separate data physically by temporal resolution
:param sampling: sampling rate to separate data physically by temporal resolution (deprecated)
:return: full path to data
"""
......@@ -32,17 +32,14 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str:
data_path = f"/home/{user}/Data/toar_{sampling}/"
elif hostname == "zam347":
data_path = f"/home/{user}/Data/toar_{sampling}/"
elif hostname == "linux-aa9b":
data_path = f"/home/{user}/mlair/data/toar_{sampling}/"
elif (len(hostname) > 2) and (hostname[:2] == "jr"):
data_path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/"
elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']):
data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/toar_{sampling}/"
data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/MLAIR/"
elif runner_regex.match(hostname) is not None:
data_path = f"/home/{user}/mlair/data/toar_{sampling}/"
data_path = f"/home/{user}/mlair/data/"
else:
data_path = os.path.join(os.getcwd(), "data", sampling)
# raise OSError(f"unknown host '{hostname}'")
data_path = os.path.join(os.getcwd(), "data")
if not os.path.exists(data_path):
try:
......@@ -97,7 +94,7 @@ def set_experiment_name(name: str = None, sampling: str = None) -> str:
return experiment_name
def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> str:
def set_bootstrap_path(bootstrap_path: str, data_path: str) -> str:
"""
Set path for bootstrap input data.
......@@ -105,12 +102,11 @@ def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> st
:param bootstrap_path: custom path to store bootstrap data
:param data_path: path of data for default bootstrap path
:param sampling: sampling rate to add, if path is set to default
:return: full bootstrap path
"""
if bootstrap_path is None:
bootstrap_path = os.path.join(data_path, "..", f"bootstrap_{sampling}")
bootstrap_path = os.path.join(data_path, "bootstrap")
check_path_and_create(bootstrap_path)
return os.path.abspath(bootstrap_path)
......
......@@ -24,7 +24,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
_requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs):
assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution
self._check_sampling(**kwargs)
kz_filter_length = to_list(kz_filter_length)
kz_filter_iter = to_list(kz_filter_iter)
# self.original_data = None # ToDo: implement here something to store unfiltered data
......@@ -34,12 +34,17 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
self.cutoff_period_days = None
super().__init__(*args, **kwargs)
def _check_sampling(self, **kwargs):
assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution
def setup_samples(self):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
"""
self.load_data()
self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
self.set_inputs_and_targets()
self.apply_kz_filter()
# this is just a code snippet to check the results of the kz filter
......
__author__ = 'Lukas Leufen'
__date__ = '2020-11-05'
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
from mlair.data_handler import DefaultDataHandler
from mlair.configuration import path_config
from mlair import helpers
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_SAMPLING
import logging
import os
import inspect
import pandas as pd
import xarray as xr
class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
    """Single-station data handler that loads inputs and targets with different sampling rates."""

    _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])

    def __init__(self, *args, sampling_inputs, **kwargs):
        # Fold input and output sampling into a single tuple (inputs, targets) so that
        # downstream code can index the sampling by position.
        sampling_outputs = kwargs.get("sampling", DEFAULT_SAMPLING)
        kwargs["sampling"] = (sampling_inputs, sampling_outputs)
        super().__init__(*args, **kwargs)

    def setup_samples(self):
        """
        Setup samples. This method prepares and creates samples X, and labels Y.
        """
        # position 0 holds the input data, position 1 the target data
        self._data = [self.load_and_interpolate(position) for position in (0, 1)]
        self.set_inputs_and_targets()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()

    def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
        """Load data for the given position (0=input, 1=target) and fill gaps by interpolation."""
        # load_data also returns the station meta data, which is stored on the instance
        raw, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var,
                                        self.sampling[ind], self.station_type, self.network,
                                        self.store_data_locally)
        return self.interpolate(raw, dim=self.time_dim, method=self.interpolation_method,
                                limit=self.interpolation_limit)

    def set_inputs_and_targets(self):
        """Select input variables from the first and target variables from the second data set."""
        self.input_data.data = self._data[0].sel({self.target_dim: helpers.to_list(self.variables)})
        self.target_data.data = self._data[1].sel({self.target_dim: self.target_var})

    def setup_data_path(self, data_path, sampling):
        """Sets two paths instead of single path. Expects sampling arg to be a list with two entries"""
        assert len(sampling) == 2
        # bind the parent implementation once and apply it to each sampling entry
        parent_setup = super().setup_data_path
        return [parent_setup(data_path, entry) for entry in sampling]
class DataHandlerMixedSampling(DefaultDataHandler):
    """Data handler using mixed sampling for input and target."""

    # station-level handler that loads inputs and targets with different sampling rates
    data_handler = DataHandlerMixedSamplingSingleStation
    # the same handler class is used to compute the transformation statistics
    data_handler_transformation = DataHandlerMixedSamplingSingleStation
    # constructor arguments the wrapped handler requires (forwarded by DefaultDataHandler)
    _requirements = data_handler.requirements()
class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation,
                                                      DataHandlerKzFilterSingleStation):
    """
    Single-station handler combining mixed sampling with a KZ filter on the inputs.

    Inputs are expected with hourly resolution, targets with daily resolution (enforced by
    :meth:`_check_sampling`).
    """

    # union of the requirements of both parent handlers
    _requirements1 = DataHandlerKzFilterSingleStation.requirements()
    _requirements2 = DataHandlerMixedSamplingSingleStation.requirements()
    _requirements = list(set(_requirements1 + _requirements2))

    # NOTE: no __init__ is defined here on purpose — the previous one only forwarded to
    # super().__init__, which is exactly what the MRO does without it.

    def _check_sampling(self, **kwargs):
        # this handler only supports hourly inputs combined with daily targets
        assert kwargs.get("sampling") == ("hourly", "daily")

    def setup_samples(self):
        """
        Setup samples. This method prepares and creates samples X, and labels Y.

        A KZ filter is applied on the input data that has hourly resolution. Labels Y are provided as
        aggregated values with daily resolution.
        """
        self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
        self.set_inputs_and_targets()
        self.apply_kz_filter()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()
class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
    """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""

    # station-level handler that combines mixed sampling with the KZ filter on inputs
    data_handler = DataHandlerMixedSamplingWithFilterSingleStation
    # the same handler class is used to compute the transformation statistics
    data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation
    # constructor arguments the wrapped handler requires (forwarded by DefaultDataHandler)
    _requirements = data_handler.requirements()
......@@ -52,7 +52,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
min_length: int = 0, start=None, end=None, variables=None, **kwargs):
super().__init__() # path, station, statistics_per_var, transformation, **kwargs)
self.station = helpers.to_list(station)
self.path = os.path.abspath(data_path)
self.path = self.setup_data_path(data_path, sampling)
self.statistics_per_var = statistics_per_var
self.do_transformation = transformation is not None
self.input_data, self.target_data = self.setup_transformation(transformation)
......@@ -141,8 +141,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
"""
self.load_data()
self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
self.set_inputs_and_targets()
if self.do_transformation is True:
self.call_transform()
......@@ -160,7 +162,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
self.make_observation(self.target_dim, self.target_var, self.time_dim)
self.remove_nan(self.time_dim)
def read_data_from_disk(self, source_name=""):
def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None,
store_data_locally=False):
"""
Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
......@@ -168,35 +171,42 @@ class DataHandlerSingleStation(AbstractDataHandler):
cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
set, it is assumed, that data should be saved locally.
"""
source_name = source_name if len(source_name) == 0 else f" from {source_name}"
check_path_and_create(self.path)
file_name = self._set_file_name()
meta_file = self._set_meta_file_name()
check_path_and_create(path)
file_name = self._set_file_name(path, station, statistics_per_var)
meta_file = self._set_meta_file_name(path, station, statistics_per_var)
if self.overwrite_local_data is True:
logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
logging.debug(f"overwrite_local_data is true, therefore reload {file_name}")
if os.path.exists(file_name):
os.remove(file_name)
if os.path.exists(meta_file):
os.remove(meta_file)
data, self.meta = self.download_data(file_name, meta_file)
logging.debug(f"loaded new data{source_name}")
data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
station_type=station_type, network=network,
store_data_locally=store_data_locally)
logging.debug(f"loaded new data")
else:
try:
logging.debug(f"try to load local data from: {file_name}")
data = xr.open_dataarray(file_name)
self.meta = pd.read_csv(meta_file, index_col=0)
self.check_station_meta()
meta = pd.read_csv(meta_file, index_col=0)
self.check_station_meta(meta, station, station_type, network)
logging.debug("loading finished")
except FileNotFoundError as e:
logging.debug(e)
logging.debug(f"load new data{source_name}")
data, self.meta = self.download_data(file_name, meta_file)
logging.debug(f"load new data")
data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
station_type=station_type, network=network,
store_data_locally=store_data_locally)
logging.debug("loading finished")
# create slices and check for negative concentration.
data = self._slice_prep(data)
self._data = self.check_for_negative_concentrations(data)
data = self.check_for_negative_concentrations(data)
return data, meta
def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
@staticmethod
def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling,
station_type=None, network=None, store_data_locally=True) -> [xr.DataArray,
pd.DataFrame]:
"""
Download data from TOAR database using the JOIN interface.
......@@ -209,36 +219,37 @@ class DataHandlerSingleStation(AbstractDataHandler):
:return: downloaded data and its meta data
"""
df_all = {}
df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
station_type=self.station_type, network_name=self.network, sampling=self.sampling)
df_all[self.station[0]] = df
df, meta = join.download_join(station_name=station, stat_var=statistics_per_var, station_type=station_type,
network_name=network, sampling=sampling)
df_all[station[0]] = df
# convert df_all to xarray
xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
xarr = xr.Dataset(xarr).to_array(dim='Stations')
if self.store_data_locally is True:
if store_data_locally is True:
# save locally as nc/csv file
xarr.to_netcdf(path=file_name)
meta.to_csv(meta_file)
return xarr, meta
def download_data(self, file_name, meta_file):
data, meta = self.download_data_from_join(file_name, meta_file)
def download_data(self, *args, **kwargs):
data, meta = self.download_data_from_join(*args, **kwargs)
return data, meta
def check_station_meta(self):
@staticmethod
def check_station_meta(meta, station, station_type, network):
"""
Search for the entries in meta data and compare the value with the requested values.
Will raise a FileNotFoundError if the values mismatch.
"""
if self.station_type is not None:
check_dict = {"station_type": self.station_type, "network_name": self.network}
if station_type is not None:
check_dict = {"station_type": station_type, "network_name": network}
for (k, v) in check_dict.items():
if v is None:
continue
if self.meta.at[k, self.station[0]] != v:
if meta.at[k, station[0]] != v:
logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
f"grapping from web.")
raise FileNotFoundError
......@@ -257,10 +268,14 @@ class DataHandlerSingleStation(AbstractDataHandler):
"""
chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
"propane", "so2", "toluene"]
# used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
used_chem_vars = list(set(chem_vars) & set(self.variables))
data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
return data
def setup_data_path(self, data_path: str, sampling: str):
return os.path.join(os.path.abspath(data_path), sampling)
def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
"""
Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
......@@ -303,15 +318,18 @@ class DataHandlerSingleStation(AbstractDataHandler):
res.name = index_name
return res
def _set_file_name(self):
all_vars = sorted(self.statistics_per_var.keys())
return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
@staticmethod
def _set_file_name(path, station, statistics_per_var):
all_vars = sorted(statistics_per_var.keys())
return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")
def _set_meta_file_name(self):
all_vars = sorted(self.statistics_per_var.keys())
return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
@staticmethod
def _set_meta_file_name(path, station, statistics_per_var):
all_vars = sorted(statistics_per_var.keys())
return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
@staticmethod
def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
**kwargs):
"""
Interpolate values according to different methods.
......@@ -349,8 +367,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
:return: xarray.DataArray
"""
self._data = self._data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
**kwargs)
return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)
def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
"""
......@@ -452,25 +469,6 @@ class DataHandlerSingleStation(AbstractDataHandler):
"""
return data.loc[{coord: slice(str(start), str(end))}]
def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
"""
Set all negative concentrations to zero.
Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
#2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
"o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
:param data: data array containing variables to check
:param minimum: minimum value, by default this should be 0
:return: corrected data
"""
chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
"propane", "so2", "toluene"]
used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
return data
@staticmethod
def setup_transformation(transformation: statistics.TransformationClass):
"""
......@@ -490,13 +488,6 @@ class DataHandlerSingleStation(AbstractDataHandler):
else:
raise NotImplementedError("Cannot handle this.")
def load_data(self):
try:
self.read_data_from_disk()
except FileNotFoundError:
self.download_data()
self.load_data()
def transform(self, data_class, dim: Union[str, int] = 0, transform_method: str = 'standardise',
inverse: bool = False, mean=None,
std=None, min=None, max=None) -> None:
......
......@@ -30,7 +30,7 @@ class DefaultDataHandler(AbstractDataHandler):
_requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"])
def __init__(self, id_class: data_handler, data_path: str, min_length: int = 0,
def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0,
extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
store_processed_data=True):
super().__init__()
......@@ -42,7 +42,7 @@ class DefaultDataHandler(AbstractDataHandler):
self._X_extreme = None
self._Y_extreme = None
_name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle")
self._collection = self._create_collection()
self.harmonise_X()
self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
......
......@@ -17,7 +17,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \
DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING
from mlair.data_handler import DefaultDataHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.model_modules.model_class import MyLittleModel as VanillaModel
......@@ -184,7 +184,7 @@ class ExperimentSetup(RunEnvironment):
training) set for a second time to the sample. If multiple valus are given, a sample is added for each
exceedence once. E.g. a sample with `value=2.5` occurs twice in the training set for given
`extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. For default,
upsampling of extreme values is disabled (`None`). Upsamling can be modified to manifold only values that are
upsampling of extreme values is disabled (`None`). Upsampling can be modified to manifold only values that are
actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by using
``extremes_on_right_tail_only``. This can be useful for positive skew variables.
:param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only``
......@@ -214,20 +214,25 @@ class ExperimentSetup(RunEnvironment):
dimensions=None,
time_dim=None,
interpolation_method=None,
interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None, fraction_of_train: float = None,
experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data = None, sampling: str = "daily",
create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None,
interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None,
test_start=None,
test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None,
fraction_of_train: float = None,
experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data=None,
sampling: str = None,
create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
number_of_bootstraps=None,
create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, **kwargs):
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None,
sampling_outputs=None, **kwargs):
# create run framework
super().__init__()
# experiment setup, hyperparameters
self._set_param("data_path", path_config.prepare_host(data_path=data_path, sampling=sampling))
self._set_param("data_path", path_config.prepare_host(data_path=data_path))
self._set_param("hostname", path_config.get_host())
self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST)
......@@ -235,7 +240,7 @@ class ExperimentSetup(RunEnvironment):
if self.data_store.get("create_new_model"):
train_model = True
data_path = self.data_store.get("data_path")
bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path, sampling)
bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path)
self._set_param("bootstrap_path", bootstrap_path)
self._set_param("train_model", train_model, default=DEFAULT_TRAIN_MODEL)
self._set_param("fraction_of_training", fraction_of_train, default=DEFAULT_FRACTION_OF_TRAINING)
......@@ -250,6 +255,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
# set experiment name
sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) # always related to output sampling
experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling)
experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path)
self._set_param("experiment_name", experiment_name)
......@@ -287,7 +293,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
scope="preprocessing")
self._set_param("sampling", sampling)
self._set_param("sampling_inputs", sampling_inputs, default=sampling)
self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
self._set_param("transformation", None, scope="preprocessing")
self._set_param("data_handler", data_handler, default=DefaultDataHandler)
......@@ -356,7 +362,7 @@ class ExperimentSetup(RunEnvironment):
f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}")
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general",
apply: Callable = None) -> None:
apply: Callable = None) -> Any:
"""Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform value
to a list use apply=helpers.to_list)."""
if value is None and default is not None:
......@@ -365,6 +371,7 @@ class ExperimentSetup(RunEnvironment):
value = apply(value)
self.data_store.set(param, value, scope)
logging.debug(f"set experiment attribute: {param}({scope})={value}")
return value
def _compare_variables_and_statistics(self):
"""
......
__author__ = "Lukas Leufen"
__date__ = '2019-11-14'
import argparse
from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilter
def main(parser_args):
    """Run a default MLAir workflow on hourly data using the KZ-filter data handler."""
    workflow_settings = dict(sampling="hourly",
                             window_history_size=24, **parser_args.__dict__,
                             data_handler=DataHandlerKzFilter,
                             kz_filter_length=[365 * 24, 20 * 24],  # 13,5# , 4 * 24, 12, 6],
                             kz_filter_iter=[3, 5],  # 3,4# , 3, 4, 4],
                             start="2006-01-01",
                             train_start="2006-01-01",
                             end="2011-12-31",
                             test_end="2011-12-31",
                             stations=["DEBW107", "DEBW013"])
    DefaultWorkflow(**workflow_settings).run()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
                        help="set experiment date as string")
    # NOTE(review): the experiment date is hard-coded to "testrun" here, so the CLI flag declared
    # above is effectively ignored — presumably intentional for this test script; confirm before reuse.
    args = parser.parse_args(["--experiment_date", "testrun"])
    main(args)
__author__ = "Lukas Leufen"
__date__ = '2019-11-14'
import argparse
from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter
def main(parser_args):
    """Run a default MLAir workflow using mixed sampling (hourly inputs, daily targets)."""
    workflow_settings = dict(sampling="daily",
                             sampling_inputs="hourly",
                             window_history_size=72,
                             **parser_args.__dict__,
                             data_handler=DataHandlerMixedSampling,  # WithFilter,
                             kz_filter_length=[365 * 24, 20 * 24],
                             kz_filter_iter=[3, 5],
                             start="2006-01-01",
                             train_start="2006-01-01",
                             end="2011-12-31",
                             test_end="2011-12-31",
                             stations=["DEBW107", "DEBW013"],
                             epochs=100)
    DefaultWorkflow(**workflow_settings).run()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
                        help="set experiment date as string")
    # NOTE(review): the experiment date is hard-coded to "testrun" here, so the CLI flag declared
    # above is effectively ignored — presumably intentional for this test script; confirm before reuse.
    args = parser.parse_args(["--experiment_date", "testrun"])
    main(args)
......@@ -11,22 +11,21 @@ from mlair.helpers import PyTestRegex
class TestPrepareHost: