Commit 45a48051 authored by lukas leufen's avatar lukas leufen

Merge branch 'lukas_issue196_feat_separation-of-scales' into 'develop'

Resolve "Separation of Scales"

See merge request !186
parents 48ddffd4 77cda5f8
Pipeline #52203 passed with stages
in 16 minutes and 1 second
......@@ -48,7 +48,7 @@ DEFAULT_CREATE_NEW_BOOTSTRAPS = False
DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
"PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
"PlotAvailability", "PlotSeparationOfScales"]
......@@ -25,11 +25,9 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs):
kz_filter_length = to_list(kz_filter_length)
kz_filter_iter = to_list(kz_filter_iter)
# self.original_data = None # ToDo: implement here something to store unfiltered data
self.kz_filter_length = kz_filter_length
self.kz_filter_iter = kz_filter_iter
self.kz_filter_length = to_list(kz_filter_length)
self.kz_filter_iter = to_list(kz_filter_iter)
self.cutoff_period = None
self.cutoff_period_days = None
super().__init__(*args, **kwargs)
......@@ -4,15 +4,15 @@ __date__ = '2020-11-05'
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
from mlair.data_handler import DefaultDataHandler
from mlair.configuration import path_config
from mlair import helpers
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_SAMPLING
import logging
import os
import inspect
from typing import Callable
import datetime as dt
import numpy as np
import pandas as pd
import xarray as xr
......@@ -37,7 +37,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind],
self.station_type,, self.store_data_locally)
self.station_type,, self.store_data_locally, self.start, self.end)
data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
return data
......@@ -88,6 +88,33 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
def estimate_filter_width(self):
f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
def _add_time_delta(date, delta):
new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta)
return new_date.strftime("%Y-%m-%d")
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
if ind == 0: # for inputs
estimated_filter_width = self.estimate_filter_width()
start = self._add_time_delta(self.start, -estimated_filter_width)
end = self._add_time_delta(self.end, estimated_filter_width)
else: # target
start, end = self.start, self.end
data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind],
self.station_type,, self.store_data_locally, start, end)
data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
return data
class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
"""Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
......@@ -95,3 +122,80 @@ class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
data_handler = DataHandlerMixedSamplingWithFilterSingleStation
data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation
_requirements = data_handler.requirements()
class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation):
Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the
separation frequency of a filtered time series the time step delta for input data is adjusted (see image below).
.. image:: ../../../../../_source/_plots/separation_of_scales.png
:width: 400
_requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
def __init__(self, *args, time_delta=np.sqrt, **kwargs):
assert isinstance(time_delta, Callable)
self.time_delta = time_delta
super().__init__(*args, **kwargs)
def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
Create a xr.DataArray containing history data.
Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
data. This is used to represent history in the data. Results are stored in history attribute.
:param dim_name_of_inputs: Name of dimension which contains the input variables
:param window: number of time steps to look back in history
Note: window will be treated as negative value. This should be in agreement with looking back on
a time line. Nonetheless positive values are allowed but they are converted to its negative
:param dim_name_of_shift: Dimension along shift will be applied
window = -abs(window)
data =
self.history = self.stride(data, dim_name_of_shift, window)
def stride(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
# this is just a code snippet to check the results of the kz filter
# import matplotlib
# matplotlib.use("TkAgg")
# import matplotlib.pyplot as plt
# xr.concat(res, dim="filter").sel({"variables":"temp", "Stations":"DEBW107", "datetime":"2010-01-01T00:00:00"}).plot.line(hue="filter")
time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int)
start, end = window, 1
res = []
window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim)
for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]):
res_filter = []
data_filter = data.sel({"filter": filter_name})
for w in range(start, end):
res_filter.append(data_filter.shift({dim: -w * delta}))
res_filter = xr.concat(res_filter, dim=window_array).chunk()
res = xr.concat(res, dim="filter")
return res
def estimate_filter_width(self):
Attention: this method returns the maximum value of
* either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or
* time delta method applied on the estimated filter width mupliplied by window_history_size
to provide a sufficiently wide filter width.
est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2
return int(max([self.time_delta(est) * self.window_history_size, est]))
class DataHandlerMixedSamplingSeparationOfScales(DefaultDataHandler):
"""Data handler using mixed sampling for input and target. Inputs are temporal filtered and different time step
sizes are applied in relation to frequencies."""
data_handler = DataHandlerMixedSamplingSeparationOfScalesSingleStation
data_handler_transformation = DataHandlerMixedSamplingSeparationOfScalesSingleStation
_requirements = data_handler.requirements()
......@@ -142,7 +142,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
Setup samples. This method prepares and creates samples X, and labels Y.
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type,, self.store_data_locally)
self.station_type,, self.store_data_locally, self.start, self.end)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
......@@ -163,7 +163,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None,
store_data_locally=False, start=None, end=None):
Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
......@@ -199,7 +199,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
logging.debug("loading finished")
# create slices and check for negative concentration.
data = self._slice_prep(data)
data = self._slice_prep(data, start=start, end=end)
data = self.check_for_negative_concentrations(data)
return data, meta
......@@ -442,7 +442,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
self.label = self.label.sel({dim: intersect})
self.observation = self.observation.sel({dim: intersect})
def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray:
def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
Set start and end date for slicing and execute self._slice().
......@@ -451,9 +451,9 @@ class DataHandlerSingleStation(AbstractDataHandler):
:return: sliced data
start = self.start if self.start is not None else data.coords[coord][0].values
end = self.end if self.end is not None else data.coords[coord][-1].values
return self._slice(data, start, end, coord)
start = start if start is not None else data.coords[self.time_dim][0].values
end = end if end is not None else data.coords[self.time_dim][-1].values
return self._slice(data, start, end, self.time_dim)
def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
......@@ -25,6 +25,11 @@ from mlair.helpers import TimeTrackingWrapper
# import matplotlib
# matplotlib.use("TkAgg")
# import matplotlib.pyplot as plt
class AbstractPlotClass:
Abstract class for all plotting routines to unify plot workflow.
......@@ -72,6 +77,9 @@ class AbstractPlotClass:
def __init__(self, plot_folder, plot_name, resolution=500):
"""Set up plot folder and name, and plot resolution (default 500dpi)."""
plot_folder = os.path.abspath(plot_folder)
if not os.path.exists(plot_folder):
self.plot_folder = plot_folder
self.plot_name = plot_name
self.resolution = resolution
......@@ -82,7 +90,7 @@ class AbstractPlotClass:
def _save(self, **kwargs):
"""Store plot locally. Name of and path to plot need to be set on initialisation."""
plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf")
plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=self.resolution, **kwargs)
......@@ -995,10 +1003,31 @@ class PlotAvailability(AbstractPlotClass):
return lgd
class PlotSeparationOfScales(AbstractPlotClass):
def __init__(self, collection: DataCollection, plot_folder: str = "."):
# create standard Gantt plot for all stations (currently in single pdf file with single page)
plot_folder = os.path.join(plot_folder, "separation_of_scales")
super().__init__(plot_folder, "separation_of_scales")
def _plot(self, collection: DataCollection):
orig_plot_name = self.plot_name
for dh in collection:
data = dh.get_X(as_numpy=False)[0]
station = dh.id_class.station[0]
data = data.sel(Stations=station)
# plt.subplots()
data.plot(x="datetime", y="window", col="filter", row="variables", robust=True)
self.plot_name = f"{orig_plot_name}_{station}"
if __name__ == "__main__":
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
path = "../../testrun_network/forecasts"
plt_path = "../../"
con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path)
......@@ -19,7 +19,8 @@ from mlair.helpers import TimeTracking, statistics, extract_value
from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
from mlair.model_modules.model_class import AbstractModelClass
from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles
PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles, \
from mlair.run_modules.run_environment import RunEnvironment
......@@ -262,7 +263,10 @@ class PostProcessing(RunEnvironment):
plot_list = self.data_store.get("plot_list", "postprocessing")
time_dimension = self.data_store.get("time_dim")
if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list:
if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ("PlotSeparationOfScales" in plot_list):
PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path)
if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN")
if "PlotConditionalQuantiles" in plot_list:
......@@ -207,6 +207,7 @@ class PreProcessing(RunEnvironment):"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}")
# calculate transformation using train data
if set_name == "train":"setup transformation using train data exclusively")
self.transformation(data_handler, set_stations)
# start station check
collection = DataCollection()
......@@ -4,17 +4,18 @@ __date__ = '2019-11-14'
import argparse
from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \
def main(parser_args):
args = dict(sampling="daily",
data_handler=DataHandlerMixedSampling, # WithFilter,
kz_filter_length=[365 * 24, 20 * 24],
kz_filter_iter=[3, 5],
kz_filter_length=[100 * 24, 15 * 24],
kz_filter_iter=[4, 5],
......@@ -70,4 +70,4 @@ class TestAllDefaults:
assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore",
"PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore",
"PlotConditionalQuantiles", "PlotAvailability"]
"PlotConditionalQuantiles", "PlotAvailability", "PlotSeparationOfScales"]
import pytest
import inspect
from mlair.data_handler.abstract_data_handler import AbstractDataHandler
class TestDefaultDataHandler:
def test_required_attributes(self):
dh = AbstractDataHandler
assert hasattr(dh, "_requirements")
assert hasattr(dh, "__init__")
assert hasattr(dh, "build")
assert hasattr(dh, "requirements")
assert hasattr(dh, "own_args")
assert hasattr(dh, "transformation")
assert hasattr(dh, "get_X")
assert hasattr(dh, "get_Y")
assert hasattr(dh, "get_data")
assert hasattr(dh, "get_coordinates")
def test_init(self):
assert isinstance(AbstractDataHandler(), AbstractDataHandler)
def test_build(self):
assert isinstance(, AbstractDataHandler)
def test_requirements(self):
dh = AbstractDataHandler()
assert isinstance(dh._requirements, list)
assert len(dh._requirements) == 0
assert isinstance(dh.requirements(), list)
assert len(dh.requirements()) == 0
def test_own_args(self):
dh = AbstractDataHandler()
assert isinstance(dh.own_args(), list)
assert len(dh.own_args()) == 0
assert "self" not in dh.own_args()
def test_transformation(self):
assert AbstractDataHandler.transformation() is None
def test_get_X(self):
dh = AbstractDataHandler()
with pytest.raises(NotImplementedError):
assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_X).args)
assert (False, False) == inspect.getfullargspec(dh.get_X).defaults
def test_get_Y(self):
dh = AbstractDataHandler()
with pytest.raises(NotImplementedError):
assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_Y).args)
assert (False, False) == inspect.getfullargspec(dh.get_Y).defaults
def test_get_data(self):
dh = AbstractDataHandler()
with pytest.raises(NotImplementedError):
assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_data).args)
assert (False, False) == inspect.getfullargspec(dh.get_data).defaults
def test_get_coordinates(self):
dh = AbstractDataHandler()
assert dh.get_coordinates() is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment