Commit 22ad918c authored by lukas leufen's avatar lukas leufen

first versino of separation of scales plot finished

parent 104c321f
Pipeline #52008 passed with stages
in 7 minutes and 36 seconds
......@@ -176,9 +176,9 @@ class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSa
data_filter = data.sel({"filter": filter_name})
for w in range(start, end):
res_filter.append(data_filter.shift({dim: -w * delta}))
res_filter = xr.concat(res_filter, dim=window_array)
res_filter = xr.concat(res_filter, dim=window_array).chunk()
res.append(res_filter)
res = xr.concat(res, dim="filter").chunk()
res = xr.concat(res, dim="filter")
return res
def estimate_filter_width(self):
......
......@@ -72,6 +72,9 @@ class AbstractPlotClass:
def __init__(self, plot_folder, plot_name, resolution=500):
"""Set up plot folder and name, and plot resolution (default 500dpi)."""
plot_folder = os.path.abspath(plot_folder)
if not os.path.exists(plot_folder):
os.makedirs(plot_folder)
self.plot_folder = plot_folder
self.plot_name = plot_name
self.resolution = resolution
......@@ -82,7 +85,7 @@ class AbstractPlotClass:
def _save(self, **kwargs):
"""Store plot locally. Name of and path to plot need to be set on initialisation."""
plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf")
plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=self.resolution, **kwargs)
plt.close('all')
......@@ -995,10 +998,31 @@ class PlotAvailability(AbstractPlotClass):
return lgd
@TimeTrackingWrapper
class PlotSeparationOfScales(AbstractPlotClass):
def __init__(self, collection: DataCollection, plot_folder: str = "."):
"""Initialise."""
# create standard Gantt plot for all stations (currently in single pdf file with single page)
plot_folder = os.path.join(plot_folder, "separation_of_scales")
super().__init__(plot_folder, "separation_of_scales")
self._plot(collection)
def _plot(self, collection: DataCollection):
orig_plot_name = self.plot_name
for dh in collection:
data = dh.get_X(as_numpy=False)[0]
station = dh.id_class.station[0]
data = data.sel(Stations=station)
# plt.subplots()
data.plot(x="datetime", y="window", col="filter", row="variables")
self.plot_name = f"{orig_plot_name}_{station}"
self._save()
if __name__ == "__main__":
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
path = "../../testrun_network/forecasts"
plt_path = "../../"
con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path)
......@@ -19,7 +19,8 @@ from mlair.helpers import TimeTracking, statistics, extract_value
from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
from mlair.model_modules.model_class import AbstractModelClass
from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles
PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles, \
PlotSeparationOfScales
from mlair.run_modules.run_environment import RunEnvironment
......@@ -262,6 +263,8 @@ class PostProcessing(RunEnvironment):
plot_list = self.data_store.get("plot_list", "postprocessing")
time_dimension = self.data_store.get("time_dim")
PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path)
if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list:
PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment