Commit a433991f authored by lukas leufen's avatar lukas leufen

MLAir runs now until calculate_test_score

parent 3f7cfae0
Pipeline #41394 failed with stages
in 6 minutes and 37 seconds
......@@ -191,6 +191,12 @@ class DefaultDataPreparation(AbstractDataPreparation):
Y = Y_original.sel({dim: intersect})
self._X, self._Y = X, Y
def get_observation(self):
return self.id_class.observation.copy().squeeze()
def get_transformation_Y(self):
return self.id_class.get_transformation_information()
def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
"""
......@@ -265,12 +271,10 @@ class DefaultDataPreparation(AbstractDataPreparation):
transformation_dict = sp_keys.pop("transformation")
if transformation_dict is None:
return
scope = transformation_dict.pop("scope")
method = transformation_dict.pop("method")
if transformation_dict.pop("mean", None) is not None:
return
mean, std = None, None
for station in set_stations:
try:
......@@ -286,8 +290,6 @@ class DefaultDataPreparation(AbstractDataPreparation):
return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
def run_data_prep():
data = DummyDataSingleStation("main_class")
......
......@@ -603,6 +603,29 @@ class StationPrep(AbstractStationPrep):
self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
self._transform_method = None
def get_transformation_information(self, variable: str = None) -> Tuple[data_or_none, data_or_none, str]:
"""
Extract transformation statistics and method.
Get mean and standard deviation for given variable and the transformation method if set. If a transformation
depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are
returned with None as fill value.
:param variable: Variable for which the information on transformation is requested.
:return: mean, standard deviation and transformation method
"""
variable = self.target_var if variable is None else variable
try:
mean = self.mean.sel({'variables': variable}).values
except AttributeError:
mean = None
try:
std = self.std.sel({'variables': variable}).values
except AttributeError:
std = None
return mean, std, self._transform_method
class AbstractDataPrep(object):
"""
......
......@@ -55,7 +55,7 @@ class OrdinaryLeastSquaredModel:
def predict(self, data):
"""Apply OLS model on data."""
data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
data = sm.add_constant(np.concatenate(self.flatten(data), axis=1), has_constant="add")
return np.atleast_2d(self.model.predict(data))
@staticmethod
......
......@@ -312,12 +312,14 @@ class PostProcessing(RunEnvironment):
be found inside `forecast_path`.
"""
logging.debug("start make_prediction")
time_dimension = self.data_store.get("interpolate_dim")
for i, data in enumerate(self.test_data):
input_data = data.get_X()
target_data = data.get_Y()
target_data = data.get_Y(as_numpy=False)
observation_data = data.get_observation()
# get scaling parameters
# mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
mean, std, transformation_method = data.get_transformation_Y()
for normalised in [True, False]:
# create empty arrays
......@@ -329,7 +331,7 @@ class PostProcessing(RunEnvironment):
normalised)
# persistence
persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std,
persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, mean, std,
transformation_method, normalised)
# ols
......@@ -337,11 +339,12 @@ class PostProcessing(RunEnvironment):
normalised)
# observation
observation = self._create_observation(data, observation, mean, std, transformation_method, normalised)
observation = self._create_observation(target_data, observation, mean, std, transformation_method, normalised)
# merge all predictions
full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']),
full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
time_dimension,
CNN=nn_prediction,
persi=persistence_prediction,
obs=observation,
......@@ -350,7 +353,7 @@ class PostProcessing(RunEnvironment):
# save all forecasts locally
path = self.data_store.get("forecast_path")
prefix = "forecasts_norm" if normalised else "forecasts"
file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc")
file = os.path.join(path, f"{prefix}_{str(data)}_test.nc")
all_predictions.to_netcdf(file)
def _get_frequency(self) -> str:
......@@ -359,14 +362,14 @@ class PostProcessing(RunEnvironment):
return getter.get(self._sampling, None)
@staticmethod
def _create_observation(data: DataPrepJoin, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
def _create_observation(data, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
normalised: bool) -> xr.DataArray:
"""
Create observation as ground truth from given data.
Inverse transformation is applied to the ground truth to get the output in the original space.
:param data: transposed observation from DataPrep
:param data: observation
:param mean: mean of target value transformation
:param std: standard deviation of target value transformation
:param transformation_method: target values transformation method
......@@ -374,10 +377,9 @@ class PostProcessing(RunEnvironment):
:return: filled data array with observation
"""
obs = data.label.copy()
if not normalised:
obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
return obs
data = statistics.apply_inverse_transformation(data, mean, std, transformation_method)
return data
def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, mean: xr.DataArray,
std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
......@@ -398,12 +400,11 @@ class PostProcessing(RunEnvironment):
tmp_ols = self.ols_model.predict(input_data)
if not normalised:
tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method)
tmp_ols = np.expand_dims(tmp_ols, axis=1)
target_shape = ols_prediction.values.shape
ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
return ols_prediction
def _create_persistence_forecast(self, data: DataPrepJoin, persistence_prediction: xr.DataArray, mean: xr.DataArray,
def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, mean: xr.DataArray,
std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
"""
Create persistence forecast with given data.
......@@ -411,7 +412,7 @@ class PostProcessing(RunEnvironment):
Persistence is deviated from the value at t=0 and applied to all following time steps (t+1, ..., t+window).
Inverse transformation is applied to the forecast to get the output in the original space.
:param data: DataPrep
:param data: observation
:param persistence_prediction: empty array in right shape to fill with data
:param mean: mean of target value transformation
:param std: standard deviation of target value transformation
......@@ -420,12 +421,11 @@ class PostProcessing(RunEnvironment):
:return: filled data array with persistence predictions
"""
tmp_persi = data.observation.copy().sel({'window': 0})
tmp_persi = data.copy()
if not normalised:
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time")
persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)),
axis=1)
persistence_prediction.values = np.tile(tmp_persi, (window_lead_time, 1)).T
return persistence_prediction
def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray,
......@@ -450,17 +450,19 @@ class PostProcessing(RunEnvironment):
if not normalised:
tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
if isinstance(tmp_nn, list):
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1], axis=1), 2, 0)
nn_prediction.values = tmp_nn[-1]
elif tmp_nn.ndim == 3:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
nn_prediction.values = tmp_nn[-1, ...]
elif tmp_nn.ndim == 2:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
nn_prediction.values = tmp_nn
else:
raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
return nn_prediction
@staticmethod
def _create_empty_prediction_arrays(target_data, count=1):
"""
Create array to collect all predictions. Expand target data by a station dimension. """
return [target_data.copy() for _ in range(count)]
@staticmethod
......@@ -489,7 +491,7 @@ class PostProcessing(RunEnvironment):
return index
@staticmethod
def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], **kwargs):
def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs):
"""
Combine different forecast types into single xarray.
......@@ -504,12 +506,8 @@ class PostProcessing(RunEnvironment):
res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
for k, v in kwargs.items():
try:
match_index = np.stack(set(res.index.values) & set(v.index.values))
res.loc[match_index, :, k] = v.loc[match_index]
except AttributeError: # v is xarray type and has no attribute .index
match_index = np.stack(set(res.index.values) & set(v.indexes['datetime'].values))
res.loc[match_index, :, k] = v.sel({'datetime': match_index}).squeeze('Stations').transpose()
match_index = np.stack(set(res.index.values) & set(v.indexes[time_dimension].values))
res.loc[match_index, :, k] = v.loc[match_index]
return res
def _get_external_data(self, station: str) -> Union[xr.DataArray, None]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment