Commit 3d46db4d authored by lukas leufen

MLAir can now use a single competitor, but this one is hard coded and always...

MLAir can now use a single competitor, but it is hard-coded and always expected to be available!
parent 4e1252c2
Pipeline #53083 passed with stages
in 10 minutes and 22 seconds
......@@ -192,12 +192,13 @@ class SkillScores:
:return: skill score for each comparison and forecast step
"""
ahead_names = list(range(1, window_lead_time + 1))
skill_score = pd.DataFrame(index=['cnn-persi', 'ols-persi', 'cnn-ols'])
skill_score = pd.DataFrame(index=['cnn-persi', 'ols-persi', 'cnn-ols', 'cnn-competitor'])
for iahead in ahead_names:
data = self.internal_data.sel(ahead=iahead)
skill_score[iahead] = [self.general_skill_score(data, forecast_name="CNN", reference_name="persi"),
self.general_skill_score(data, forecast_name="OLS", reference_name="persi"),
self.general_skill_score(data, forecast_name="CNN", reference_name="OLS")]
self.general_skill_score(data, forecast_name="CNN", reference_name="OLS"),
self.general_skill_score(data, forecast_name="CNN", reference_name="competitor")]
return skill_score
def climatological_skill_scores(self, external_data: Data, window_lead_time: int) -> xr.DataArray:
......
......@@ -689,9 +689,10 @@ class PlotCompetitiveSkillScore(AbstractPlotClass):
def _plot(self):
"""Plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols."""
fig, ax = plt.subplots()
order = self._create_pseudo_order()
sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d",
showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
order=["cnn-persi", "ols-persi", "cnn-ols"])
order=order)
ax.axhline(y=0, color="grey", linewidth=.5)
ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations", ylim=self._ylim())
......@@ -699,6 +700,12 @@ class PlotCompetitiveSkillScore(AbstractPlotClass):
ax.legend(handles, self._labels)
plt.tight_layout()
def _create_pseudo_order(self):
    """
    Build the order of comparison labels.

    The three standard comparisons always come first in a fixed sequence. All other comparison labels found in the
    data follow in order of their first appearance.

    :return: array of unique comparison labels
    """
    leading = ["cnn-persi", "ols-persi", "cnn-ols"]
    candidates = leading + self._data.comparison.unique().tolist()
    # np.unique returns sorted labels; the indices of first occurrence are used to undo this sorting
    labels, first_seen = np.unique(candidates, return_index=True)
    restore = np.argsort(first_seen)
    return labels[restore]
def _ylim(self) -> Tuple[float, float]:
"""
Calculate y-axis limits from data.
......
......@@ -68,7 +68,8 @@ class PostProcessing(RunEnvironment):
self.batch_size: int = self.data_store.get_default("batch_size", "model", 64)
self.test_data = self.data_store.get("data_collection", "test")
batch_path = self.data_store.get("batch_path", scope="test")
self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test", batch_path=batch_path)
self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test",
batch_path=batch_path)
self.train_data = self.data_store.get("data_collection", "train")
self.val_data = self.data_store.get("data_collection", "val")
self.train_val_data = self.data_store.get("data_collection", "train_val")
......@@ -78,6 +79,14 @@ class PostProcessing(RunEnvironment):
self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
self.skill_scores = None
self.bootstrap_skill_scores = None
# ToDo: adjust this hard coded by a new entry in the data store setup in experiment setup phase
self.competitor_path = os.path.join(self.data_store.get("data_path"),
"competitors",
self.target_var
# ToDo: make sure this is a string, multiple vars are joined by underscore
)
self.competitor_name = "test_model"
self.competitor_forecast_name = "CNN" # ToDo: another refac, rename the CNN field to something like forecast to be more general
self._run()
def _run(self):
......@@ -87,6 +96,9 @@ class PostProcessing(RunEnvironment):
# forecasts
self.make_prediction()
# competitors
self.load_competitors()
# skill scores on test data
self.calculate_test_score()
......@@ -103,6 +115,10 @@ class PostProcessing(RunEnvironment):
# plotting
self.plot()
def load_competitors(self):
    """
    Load the competitor forecast for each station of the test data.

    The loaded forecasts are neither stored nor returned. In its current form this method only calls the loader
    once per station, so any error raised while loading a competitor forecast surfaces here.
    """
    for station in self.test_data:
        # the return value is not used yet, therefore it is not bound to a name
        self._create_competitor_forecast(str(station))
def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None:
"""
Calculate skill scores of bootstrapped data.
......@@ -297,7 +313,6 @@ class PostProcessing(RunEnvironment):
avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, )# time_dimension=time_dimension)
def calculate_test_score(self):
"""Evaluate test score of model and save locally."""
test_score = self.model.evaluate_generator(generator=self.test_data_distributed,
......@@ -370,6 +385,14 @@ class PostProcessing(RunEnvironment):
getter = {"daily": "1D", "hourly": "1H"}
return getter.get(self._sampling, None)
def _create_competitor_forecast(self, station_name):
    """
    Load the forecast of the competitor model for a single station.

    The file forecasts_<station_name>_test.nc is opened from the folder <competitor_path>/<competitor_name>. From
    this data only the entry named by competitor_forecast_name is selected along the type dimension, and its type
    label is replaced by "competitor".

    :param station_name: name of the station used to build the file name
    :return: selected competitor forecast with type coordinate set to "competitor"
    """
    file_name = os.path.join(self.competitor_path, self.competitor_name, f"forecasts_{station_name}_test.nc")
    competitor_data = xr.open_dataarray(file_name)
    # selecting with a list keeps the type dimension, so that it can be relabelled afterwards
    selection = competitor_data.sel(type=[self.competitor_forecast_name])
    selection.coords["type"] = ["competitor"]
    return selection
@staticmethod
def _create_observation(data, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
normalised: bool) -> xr.DataArray:
......@@ -556,7 +579,9 @@ class PostProcessing(RunEnvironment):
for station in self.test_data:
file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
data = xr.open_dataarray(file)
skill_score = statistics.SkillScores(data)
competitor = self._create_competitor_forecast(str(station))
combined = xr.concat([data, competitor], dim="type")
skill_score = statistics.SkillScores(combined)
external_data = self._get_external_data(station)
skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment