Commit 99d89009 authored by Lukas Leufen

modified setup_samples for all data handlers, new run script for mixed sampling

parent 1fb90fd3
Pipeline #51137 passed with stages in 8 minutes and 7 seconds
@@ -38,8 +38,10 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
"""
Set up samples. This method prepares and creates samples X and labels Y.
"""
self.load_data()
self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
self.set_inputs_and_targets()
self.apply_kz_filter()
# this is just a code snippet to check the results of the kz filter
......
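For context, the Kolmogorov-Zurbenko (KZ) filter invoked by apply_kz_filter is an iterated moving average: a window of length m applied k times. A minimal sketch of the idea, not MLAir's implementation (the function name and the pandas-based smoothing are illustrative only):

import pandas as pd

def kz_filter(series: pd.Series, m: int, k: int) -> pd.Series:
    # Kolmogorov-Zurbenko filter: centred moving average of window m, iterated k times
    smoothed = series
    for _ in range(k):
        smoothed = smoothed.rolling(window=m, center=True, min_periods=1).mean()
    return smoothed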
@@ -2,67 +2,60 @@ __author__ = 'Lukas Leufen'
__date__ = '2020-11-05'
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.data_handler import DefaultDataHandler
from mlair.configuration import path_config
from mlair import helpers
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_SAMPLING
import logging
import os
import inspect
import pandas as pd
import xarray as xr
class DataHandlerMixedSampling(DataHandlerSingleStation):
class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
_requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
def __init__(self, *args, sampling_inputs, **kwargs):
sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING))
kwargs.update({"sampling": sampling})
super().__init__(*args, **kwargs)
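The constructor packs the input and target resolutions into a single tuple before delegating to the parent; a quick sketch of the value that arrives downstream (resolutions borrowed from the run script added in this commit):

sampling_inputs = "hourly"                                      # resolution of the inputs
kwargs = {"sampling": "daily"}                                  # resolution of the targets
sampling = (sampling_inputs, kwargs.get("sampling", "daily"))   # -> ("hourly", "daily")
# index 0 drives the inputs, index 1 the targets (see the [0, 1] map in setup_samples below)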
def setup_samples(self):
"""
Set up samples. This method prepares and creates samples X and labels Y.
"""
self.load_data()
self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data
self.set_inputs_and_targets()
if self.do_transformation is True:
self.call_transform()
self.make_samples()
def load_data(self):
try:
self.read_data_from_disk()
except FileNotFoundError:
self.download_data()
self.load_data()
def load_and_interpolate(self, ind) -> xr.DataArray:
data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind],
self.station_type, self.network, self.store_data_locally)
data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
return data
def read_data_from_disk(self, source_name=""):
"""
Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
def set_inputs_and_targets(self):
inputs = self._data[0].sel({self.target_dim: helpers.to_list(self.variables)})
targets = self._data[1].sel({self.target_dim: self.target_var})
self.input_data.data = inputs
self.target_data.data = targets
Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
set, it is assumed that data should be saved locally.
"""
source_name = source_name if len(source_name) == 0 else f" from {source_name}"
path_config.check_path_and_create(self.path)
file_name = self._set_file_name()
meta_file = self._set_meta_file_name()
if self.overwrite_local_data is True:
logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
if os.path.exists(file_name):
os.remove(file_name)
if os.path.exists(meta_file):
os.remove(meta_file)
data, self.meta = self.download_data(file_name, meta_file)
logging.debug(f"loaded new data{source_name}")
else:
try:
logging.debug(f"try to load local data from: {file_name}")
data = xr.open_dataarray(file_name)
self.meta = pd.read_csv(meta_file, index_col=0)
self.check_station_meta()
logging.debug("loading finished")
except FileNotFoundError as e:
logging.debug(e)
logging.debug(f"load new data{source_name}")
data, self.meta = self.download_data(file_name, meta_file)
logging.debug("loading finished")
# create slices and check for negative concentration.
data = self._slice_prep(data)
self._data = self.check_for_negative_concentrations(data)
def setup_data_path(self, data_path, sampling):
"""Sets two paths instead of single path. Expects sampling arg to be a list with two entries"""
assert len(sampling) == 2
return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling))
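Because sampling is now a two-element sequence, the overridden setup_data_path produces one directory per resolution via the parent implementation shown further below. A small sketch with assumed values:

import os

data_path, sampling = "/data", ("hourly", "daily")  # assumed values
paths = [os.path.join(os.path.abspath(data_path), s) for s in sampling]
# -> ["/data/hourly", "/data/daily"]; inputs are loaded from index 0, targets from index 1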
class DataHandlerMixedSampling(DefaultDataHandler):
"""Data handler using mixed sampling for input and target."""
data_handler = DataHandlerMixedSamplingSingleStation
data_handler_transformation = DataHandlerMixedSamplingSingleStation
_requirements = data_handler.requirements()
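A minimal sketch of wiring the new handler into a workflow; the keyword values mirror the run script added at the end of this commit:

from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling

workflow = DefaultWorkflow(data_handler=DataHandlerMixedSampling,
                           sampling="daily",          # output/target resolution
                           sampling_inputs="hourly",  # input resolution
                           stations=["DEBW107", "DEBW013"])
workflow.run()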
@@ -141,9 +141,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
"""
Set up samples. This method prepares and creates samples X and labels Y.
"""
self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, self.station_type, self.network,
self.store_data_locally)
self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
self.set_inputs_and_targets()
if self.do_transformation is True:
self.call_transform()
@@ -179,27 +180,28 @@ class DataHandlerSingleStation(AbstractDataHandler):
os.remove(file_name)
if os.path.exists(meta_file):
os.remove(meta_file)
data, self.meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
station_type=station_type, network=network,
store_data_locally=store_data_locally)
data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
station_type=station_type, network=network,
store_data_locally=store_data_locally)
logging.debug(f"loaded new data")
else:
try:
logging.debug(f"try to load local data from: {file_name}")
data = xr.open_dataarray(file_name)
self.meta = pd.read_csv(meta_file, index_col=0)
self.check_station_meta(station, station_type, network)
meta = pd.read_csv(meta_file, index_col=0)
self.check_station_meta(meta, station, station_type, network)
logging.debug("loading finished")
except FileNotFoundError as e:
logging.debug(e)
logging.debug(f"load new data")
data, self.meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
station_type=station_type, network=network,
store_data_locally=store_data_locally)
data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
station_type=station_type, network=network,
store_data_locally=store_data_locally)
logging.debug("loading finished")
# create slices and check for negative concentration.
data = self._slice_prep(data)
self._data = self.check_for_negative_concentrations(data)
data = self.check_for_negative_concentrations(data)
return data, meta
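The refactored load_data no longer mutates self._data and self.meta but returns (data, meta), so callers such as setup_samples assign the results explicitly. A hedged sketch of the call pattern (argument order follows the call in setup_samples above; all values are placeholders):

data, meta = self.load_data("/data/daily", ["DEBW107"], {"o3": "dma8eu"},
                            "daily", "background", "UBA", True)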
@staticmethod
def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling,
@@ -233,7 +235,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
data, meta = self.download_data_from_join(*args, **kwargs)
return data, meta
def check_station_meta(self, station, station_type, network):
@staticmethod
def check_station_meta(meta, station, station_type, network):
"""
Search for the entries in the metadata and compare the values with the requested ones.
@@ -244,9 +247,9 @@ class DataHandlerSingleStation(AbstractDataHandler):
for (k, v) in check_dict.items():
if v is None:
continue
if self.meta.at[k, station[0]] != v:
if meta.at[k, station[0]] != v:
logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
f"{self.meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
f"grapping from web.")
raise FileNotFoundError
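Making check_station_meta static removes the hidden dependency on self.meta; the frame to validate is now an explicit argument. A self-contained sketch of the comparison it performs (the meta layout and the check_dict keys are assumed for illustration):

import pandas as pd

meta = pd.DataFrame({"DEBW107": ["background", "UBA"]},
                    index=["station_type", "network_name"])         # assumed layout
check_dict = {"station_type": "background", "network_name": "UBA"}  # assumed keys
for k, v in check_dict.items():
    if v is not None and meta.at[k, "DEBW107"] != v:
        raise FileNotFoundError  # triggers a fresh download upstream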
@@ -270,8 +273,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
return data
@staticmethod
def setup_data_path(data_path, sampling):
def setup_data_path(self, data_path: str, sampling: str):
return os.path.join(os.path.abspath(data_path), sampling)
def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
@@ -326,7 +328,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
all_vars = sorted(statistics_per_var.keys())
return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
@staticmethod
def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
**kwargs):
"""
Interpolate values according to different methods.
@@ -364,8 +367,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
:return: xarray.DataArray
"""
self._data = self._data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
**kwargs)
return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)
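interpolate is now a pure static function over the passed array instead of a mutator of self._data. A self-contained sketch of the underlying xarray call (data values are made up):

import numpy as np
import pandas as pd
import xarray as xr

arr = xr.DataArray([1.0, np.nan, np.nan, 4.0], dims="datetime",
                   coords={"datetime": pd.date_range("2020-01-01", periods=4)})
filled = arr.interpolate_na(dim="datetime", method="linear", limit=1)
# with limit=1 only the first NaN of the two-NaN run is filled: [1.0, 2.0, nan, 4.0]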
def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
"""
......
@@ -184,7 +184,7 @@ class ExperimentSetup(RunEnvironment):
training) set for a second time to the sample. If multiple values are given, a sample is added once for each
exceedance. E.g. a sample with `value=2.5` occurs twice in the training set for given
`extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. By default,
upsampling of extreme values is disabled (`None`). Upsamling can be modified to manifold only values that are
upsampling of extreme values is disabled (`None`). Upsampling can be modified to manifold only values that are
actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by using
``extremes_on_right_tail_only``. This can be useful for positive skew variables.
:param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only``
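A small numeric sketch of the manifolding described for extreme_values (pure illustration, not MLAir's routine):

import pandas as pd

original = pd.Series([1.0, 2.5, 5.0])
extreme_values = [2, 3]
extra = [original[original.abs() > t] for t in extreme_values]
train = pd.concat([original, *extra])
# 1.0 appears once, 2.5 twice (exceeds 2), 5.0 three times (exceeds 2 and 3);
# with extremes_on_right_tail_only one would drop the abs() and duplicate only values > t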
@@ -255,7 +255,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
# set experiment name
sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)
sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) # always related to output sampling
experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling)
experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path)
self._set_param("experiment_name", experiment_name)
......
__author__ = "Lukas Leufen"
__date__ = '2019-11-14'
import argparse
from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling
def main(parser_args):
args = dict(sampling="daily",
sampling_inputs="hourly",
window_history_size=72,
**parser_args.__dict__,
data_handler=DataHandlerMixedSampling,
start="2006-01-01",
train_start="2006-01-01",
end="2011-12-31",
test_end="2011-12-31",
stations=["DEBW107", "DEBW013"],
epochs=100,
)
workflow = DefaultWorkflow(**args)
workflow.run()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string")
args = parser.parse_args(["--experiment_date", "testrun"])
main(args)
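Note that the script pins the experiment date for a test run; to read the actual command line one would drop the hard-coded argument list (sketch, script name assumed):

args = parser.parse_args()  # instead of parse_args(["--experiment_date", "testrun"])
main(args)
# invoked e.g. as: python run_mixed_sampling.py --experiment_date my_run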