Commit a8cc85f1 authored by lukas leufen

MLAir is now independent of the window_lead_time parameter (it is extracted from the shape of Y)

parent 312f4f18
......@@ -13,7 +13,7 @@ import datetime as dt
import shutil
import inspect
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict
import logging
from functools import reduce
from src.data_handling.station_preparation import StationPrep
......@@ -71,7 +71,7 @@ class AbstractDataPreparation:
def transformation(cls, *args, **kwargs):
raise NotImplementedError
return None
def get_X(self, upsampling=False, as_numpy=False):
raise NotImplementedError
......@@ -82,8 +82,8 @@ class AbstractDataPreparation:
def get_data(self, upsampling=False, as_numpy=False):
return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
def get_coordinates(self) -> Union[None, Dict]:
    """
    Return the coordinates belonging to this data preparation object.

    This base implementation has no coordinates to offer and always returns a single None (not a tuple of
    Nones). The return annotation allows implementations to hand back a dictionary instead.

    :return: always None in this implementation
    """
    return None
class DefaultDataPreparation(AbstractDataPreparation):
......@@ -17,7 +17,6 @@ import os
from collections import Iterator, Iterable
from itertools import chain
import dask.array as da
import numpy as np
import xarray as xr
......@@ -69,13 +68,12 @@ class BootstrapIterator(Iterator):
return d.values
def shuffle(data: np.ndarray) -> np.ndarray:
    """
    Shuffle randomly from given data (draw elements with replacement).

    Each element of the result is drawn independently from the flattened input, so individual values may
    appear several times or not at all. The result has the same shape as the input.

    :param data: data to shuffle
    :return: shuffled data as numpy array
    """
    # flatten first: np.random.choice only samples from 1-dimensional input
    return np.random.choice(data.reshape(-1), size=data.shape)
......@@ -131,28 +129,3 @@ class BootStraps(Iterable):
prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
vals = np.tile(, (self.number_of_bootstraps, 1))
return vals[~np.isnan(vals).any(axis=1), :]
if __name__ == "__main__":
from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.run_environment import RunEnvironment
from src.run_modules.pre_processing import PreProcessing
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.INFO)
with RunEnvironment() as run_env:
ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'],
station_type='background', trainable=True, window_history_size=9)
data = run_env.data_store.get("generator", "general.test")
number_bootstraps = 10
boots = BootStraps(data, number_bootstraps)
for b in boots.boot_strap_generator():
a, c = b"len is {len(boots.get_boot_strap_meta())}")
......@@ -3,4 +3,4 @@
from .testing import PyTestRegex, PyTestAllEqual
from .time_tracking import TimeTracking, TimeTrackingWrapper
from .logger import Logger
from .helpers import remove_items, float_round, dict_to_xarray, to_list
from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value
......@@ -92,3 +92,10 @@ def remove_items(obj: Union[List, Dict], items: Any):
return remove_from_dict(obj, items)
raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
def extract_value(encapsulated_value):
    """
    Unwrap a value that is nested inside one or more indexable containers.

    The first element (``value[0]``) is taken repeatedly until the object no longer supports indexing; this
    innermost object is returned. Strings are treated as plain values and returned unchanged.

    :param encapsulated_value: a plain value or a nested sequence such as ``[[3]]`` or ``((3,),)``
    :return: the innermost first element, or the input itself if it cannot be indexed
    :raises IndexError: if an empty sequence is encountered while unwrapping
    """
    # indexing a one-character string returns the same string again, so recursion would never terminate
    if isinstance(encapsulated_value, str):
        return encapsulated_value
    try:
        return extract_value(encapsulated_value[0])
    except TypeError:
        # object is not subscriptable: this is the innermost value
        return encapsulated_value
......@@ -19,7 +19,6 @@ import xarray as xr
from matplotlib.backends.backend_pdf import PdfPages
from src import helpers
from src.data_handling import DataGenerator
from src.data_handling.iterator import DataCollection
from src.helpers import TimeTrackingWrapper
......@@ -881,7 +880,7 @@ class PlotAvailability(AbstractPlotClass):
def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily",
def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
summary_name="data availability", time_dimension="datetime"):
# create standard Gantt plot for all stations (currently in single pdf file with single page)
......@@ -927,7 +926,7 @@ class PlotAvailability(AbstractPlotClass):
plt_dict[str(station)].update({subset: t2})
return plt_dict
def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str):
def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
plt_dict = {}
for subset, data_collection in generators.items():
all_data = None
......@@ -31,8 +31,6 @@ class ModelSetup(RunEnvironment):
* `trainable` [.]
* `create_new_model` [.]
* `generator` [train]
* `window_lead_time` [.]
* `window_history_size` [.]
* `model_class` [.]
Optional objects
......@@ -15,7 +15,7 @@ import xarray as xr
from src.data_handling import BootStraps, KerasIterator
from src.helpers.datastore import NameNotFoundInDataStore
from src.helpers import TimeTracking, statistics
from src.helpers import TimeTracking, statistics, extract_value
from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src.model_modules.model_class import AbstractModelClass
from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
......@@ -42,7 +42,7 @@ class PostProcessing(RunEnvironment):
* `model_path` [.]
* `target_var` [.]
* `sampling` [.]
* `window_lead_time` [.]
* `output_shape` [model]
* `evaluate_bootstraps` [postprocessing] and if enabled:
* `create_new_bootstraps` [postprocessing]
......@@ -74,6 +74,7 @@ class PostProcessing(RunEnvironment):
self.plot_path: str = self.data_store.get("plot_path")
self.target_var = self.data_store.get("target_var")
self._sampling = self.data_store.get("sampling")
self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
self.skill_scores = None
self.bootstrap_skill_scores = None
......@@ -182,7 +183,6 @@ class PostProcessing(RunEnvironment):
# extract all requirements from data store
bootstrap_path = self.data_store.get("bootstrap_path")
forecast_path = self.data_store.get("forecast_path")
window_lead_time = self.data_store.get("window_lead_time")
number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
forecast_file = f""
bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
......@@ -203,14 +203,14 @@ class PostProcessing(RunEnvironment):
orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
# calculate skill scores for each variable
skill = pd.DataFrame(columns=range(1, window_lead_time + 1))
skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
for boot_set in bootstraps:
boot_var = f"{boot_set[0]}_{boot_set[1]}"
file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc")
boot_data = xr.open_dataarray(file_name)
boot_data = boot_data.combine_first(labels).combine_first(orig)
boot_scores = []
for ahead in range(1, window_lead_time + 1):
for ahead in range(1, self.window_lead_time + 1):
data = boot_data.sel(ahead=ahead)
skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig"))
......@@ -429,8 +429,7 @@ class PostProcessing(RunEnvironment):
tmp_persi = data.copy()
if not normalised:
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time")
persistence_prediction.values = np.tile(tmp_persi, (window_lead_time, 1)).T
persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
return persistence_prediction
def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray,
......@@ -547,7 +546,6 @@ class PostProcessing(RunEnvironment):
:return: competitive and climatological skill scores
path = self.data_store.get("forecast_path")
window_lead_time = self.data_store.get("window_lead_time")
skill_score_competitive = {}
skill_score_climatological = {}
for station in self.test_data:
......@@ -555,7 +553,7 @@ class PostProcessing(RunEnvironment):
data = xr.open_dataarray(file)
skill_score = statistics.SkillScores(data)
external_data = self._get_external_data(station)
skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
return skill_score_competitive, skill_score_climatological
......@@ -10,7 +10,6 @@ from typing import Tuple, Dict, List
import numpy as np
import pandas as pd
from src.data_handling import DataGenerator
from src.data_handling import DataCollection
from src.helpers import TimeTracking
from src.configuration import path_config
......@@ -196,49 +195,6 @@ class PreProcessing(RunEnvironment):
self.data_store.set("stations", valid_stations, scope=set_name)
self.data_store.set("data_collection", collection, scope=set_name)
def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True,
Check if all given stations in `all_stations` are valid.
Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the
loading time are logged in debug mode.
:param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
`variables`, `interpolate_dim`, `target_dim`, `target_var`).
:param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
:param all_stations: All stations to check.
:param name: name to display in the logging info message
:return: Corrected list containing only valid station IDs.
t_outer = TimeTracking()
t_inner = TimeTracking(start=False)"check valid stations started{' (%s)' % name if name else ''}")
valid_stations = []
# all required arguments of the DataGenerator can be found in args, positional arguments in args and kwargs
data_gen = DataGenerator(**args, **kwargs)
for pos, station in enumerate(all_stations):"check station {station} ({pos + 1} / {len(all_stations)})")
data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
if data.history is None:
raise AttributeError
f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
logging.debug(f"{station}: loading time = {t_inner}")
except (AttributeError, EmptyQueryResult):
continue"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
f"{len(all_stations)} valid stations.")
return valid_stations
def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False):
Check if all given stations in `all_stations` are valid.
