Commit 7c403660 authored by lukas leufen's avatar lukas leufen

kz is now only applied to input data, but there is an issue with the scaling.

parent 04b28e2b
Pipeline #50786 passed with stages
in 7 minutes and 54 seconds
......@@ -7,16 +7,19 @@ import inspect
import numpy as np
import pandas as pd
import xarray as xr
from typing import List
from typing import List, Union
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.data_handler import DefaultDataHandler
from mlair.helpers import remove_items, to_list, TimeTrackingWrapper
from mlair.helpers.statistics import KolmogorovZurbenkoFilterMovingWindow as KZFilter
# define a more general date type for type hinting
str_or_list = Union[str, List[str]]
class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
"""Data handler for a single station to be used by a superior data handler. Data is kz filtered."""
"""Data handler for a single station to be used by a superior data handler. Inputs are kz filtered."""
_requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
......@@ -29,6 +32,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
self.kz_filter_iter = kz_filter_iter
self.cutoff_period = None
self.cutoff_period_days = None
self.data_target: xr.DataArray = None
super().__init__(*args, **kwargs)
def setup_samples(self):
......@@ -50,6 +54,8 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
@TimeTrackingWrapper
def apply_kz_filter(self):
"""Apply kolmogorov zurbenko filter only on inputs."""
self.data_target = self.data.sel({self.target_dim: [self.target_var]})
kz = KZFilter(self.data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime")
filtered_data: List[xr.DataArray] = kz.run()
self.cutoff_period = kz.period_null()
......@@ -69,6 +75,36 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
index = list(map(lambda x: str(x) + "d", index)) + ["res"]
return pd.Index(index, name="filter")
def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
window: int) -> None:
"""
Create a xr.DataArray containing labels.
Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
attribute.
:param dim_name_of_target: Name of dimension which contains the target variable
:param target_var: Name of target variable in 'dimension'
:param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
:param window: lead time of label
"""
window = abs(window)
data = self.data_target.sel({dim_name_of_target: target_var})
self.label = self.shift(data, dim_name_of_shift, window)
def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
"""
Create a xr.DataArray containing observations.
Observations are defined as value of the current time step t. Set observation attribute.
:param dim_name_of_target: Name of dimension which contains the observation variable
:param target_var: Name of observation variable(s) in 'dimension'
:param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
"""
data = self.data_target.sel({dim_name_of_target: target_var})
self.observation = self.shift(data, dim_name_of_shift, 0)
def get_transposed_history(self) -> xr.DataArray:
"""Return history.
......@@ -76,13 +112,6 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
"""
return self.history.transpose("datetime", "window", "Stations", "variables", "filter").copy()
def get_transposed_label(self) -> xr.DataArray:
"""Return label.
:return: label with dimensions datetime*, window*, Stations, variables.
"""
return self.label.squeeze("Stations").transpose("datetime", "window", "filter").copy()
class DataHandlerKzFilter(DefaultDataHandler):
"""Data handler using kz filtered data."""
......
......@@ -74,7 +74,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
self.end = end
# internal
self.data = None
self.data: xr.DataArray = None
self.meta = None
self.variables = list(statistics_per_var.keys()) if variables is None else variables
self.history = None
......@@ -284,10 +284,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
return data
def shift(self, dim: str, window: int) -> xr.DataArray:
def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
"""
Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
:param data: data set to shift
:param dim: dimension along shift is applied
:param window: number of steps to shift (corresponds to the window length)
......@@ -301,7 +302,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
end = window + 1
res = []
for w in range(start, end):
res.append(self.data.shift({dim: -w}))
res.append(data.shift({dim: -w}))
window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim)
res = xr.concat(res, dim=window_array)
return res
......@@ -389,7 +390,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
:param dim_name_of_shift: Dimension along shift will be applied
"""
window = -abs(window)
self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})
data = self.data.sel({dim_name_of_inputs: self.variables})
self.history = self.shift(data, dim_name_of_shift, window)
def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
window: int) -> None:
......@@ -405,7 +407,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
:param window: lead time of label
"""
window = abs(window)
self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
data = self.data.sel({dim_name_of_target: target_var})
self.label = self.shift(data, dim_name_of_shift, window)
def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
"""
......@@ -417,7 +420,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
:param target_var: Name of observation variable(s) in 'dimension'
:param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
"""
self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})
data = self.data.sel({dim_name_of_target: target_var})
self.observation = self.shift(data, dim_name_of_shift, 0)
def remove_nan(self, dim: str) -> None:
"""
......
......@@ -456,7 +456,7 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
It create the variables associate with the KolmogorovZurbenkoFilterMovingWindow class.
Args:
df(pd.DataFrame): time series of a variable
df(pd.DataFrame, xr.DataArray): time series of a variable
wl: window length
itr: number of iteration
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment