Commit 4ead7464 authored by lukas leufen's avatar lukas leufen

MLAir runs now again until the end, implemented get_coordinates method for data handlers

parent 3c241f2d
Pipeline #41618 failed with stages
in 2 minutes and 12 seconds
......@@ -82,13 +82,16 @@ class AbstractDataPreparation:
def get_data(self, upsampling=False, as_numpy=False):
    """Return the (features, targets) pair by delegating to get_X and get_Y.

    :param upsampling: forwarded to get_X / get_Y (include upsampled extremes)
    :param as_numpy: forwarded to get_X / get_Y (return numpy instead of xarray)
    :return: tuple of (X, Y)
    """
    features = self.get_X(upsampling, as_numpy)
    targets = self.get_Y(upsampling, as_numpy)
    return features, targets
def get_coordinates(self):
    """Default coordinate lookup for the abstract base: nothing is known here.

    Subclasses override this to return a real {"lon": ..., "lat": ...} mapping.
    :return: the pair (None, None)
    """
    return (None, None)
class DefaultDataPreparation(AbstractDataPreparation):
_requirements = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"])
def __init__(self, id_class, data_path, min_length=0,
extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None):
super().__init__()
self.id_class = id_class
self.interpolate_dim = "datetime"
......@@ -97,7 +100,8 @@ class DefaultDataPreparation(AbstractDataPreparation):
self._Y = None
self._X_extreme = None
self._Y_extreme = None
self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle")
_name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
self._collection = self._create_collection()
self.harmonise_X()
self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim)
......@@ -292,6 +296,9 @@ class DefaultDataPreparation(AbstractDataPreparation):
std_estimated = std.mean("Stations")
return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
def get_coordinates(self):
    """Delegate the coordinate lookup to the wrapped station handler (id_class).

    :return: whatever the underlying handler's get_coordinates returns
    """
    handler = self.id_class
    return handler.get_coordinates()
def run_data_prep():
......
This diff is collapsed.
......@@ -42,7 +42,8 @@ class StationPrep(AbstractStationPrep):
def __init__(self, station, data_path, statistics_per_var, station_type, network, sampling,
target_dim, target_var, interpolate_dim, window_history_size, window_lead_time,
overwrite_local_data: bool = False, transformation=None, **kwargs):
overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
min_length: int = 0, start=None, end=None, **kwargs):
super().__init__() # path, station, statistics_per_var, transformation, **kwargs)
self.station = helpers.to_list(station)
self.path = os.path.abspath(data_path)
......@@ -58,6 +59,10 @@ class StationPrep(AbstractStationPrep):
self.window_history_size = window_history_size
self.window_lead_time = window_lead_time
self.overwrite_local_data = overwrite_local_data
self.store_data_locally = store_data_locally
self.min_length = min_length
self.start = start
self.end = end
# internal
self.data = None
......@@ -120,6 +125,10 @@ class StationPrep(AbstractStationPrep):
def get_Y(self):
    """Return the label (target) data in transposed layout.

    :return: result of get_transposed_label()
    """
    labels = self.get_transposed_label()
    return labels
def get_coordinates(self):
    """Extract this station's longitude/latitude from the meta table.

    Reads the "station_lon" / "station_lat" rows of self.meta (a pandas frame
    keyed by station name), casts them to float, and returns them under the
    keys "lon" / "lat".
    :return: dict like {"lon": <float>, "lat": <float>} for this station
    """
    lon_lat = self.meta.loc[["station_lon", "station_lat"]].astype(float)
    renamed = lon_lat.rename(index={"station_lon": "lon", "station_lat": "lat"})
    return renamed.to_dict()[str(self)]
def call_transform(self, inverse=False):
self.transform(dim=self.interpolate_dim, method=self.transformation["method"],
mean=self.transformation['mean'], std=self.transformation["std"],
......@@ -158,7 +167,7 @@ class StationPrep(AbstractStationPrep):
check_path_and_create(self.path)
file_name = self._set_file_name()
meta_file = self._set_meta_file_name()
if self.kwargs.get('overwrite_local_data', False):
if self.overwrite_local_data is True:
logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
if os.path.exists(file_name):
os.remove(file_name)
......@@ -201,7 +210,7 @@ class StationPrep(AbstractStationPrep):
# convert df_all to xarray
xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
xarr = xr.Dataset(xarr).to_array(dim='Stations')
if self.kwargs.get('store_data_locally', True):
if self.store_data_locally is True:
# save locally as nc/csv file
xarr.to_netcdf(path=file_name)
meta.to_csv(meta_file)
......@@ -398,8 +407,7 @@ class StationPrep(AbstractStationPrep):
intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
non_nan_observation.coords[dim].values))
min_length = self.kwargs.get("min_length", 0)
if len(intersect) < max(min_length, 1):
if len(intersect) < max(self.min_length, 1):
self.history = None
self.label = None
self.observation = None
......@@ -417,8 +425,8 @@ class StationPrep(AbstractStationPrep):
:return: sliced data
"""
start = self.kwargs.get('start', data.coords[coord][0].values)
end = self.kwargs.get('end', data.coords[coord][-1].values)
start = self.start if self.start is not None else data.coords[coord][0].values
end = self.end if self.end is not None else data.coords[coord][-1].values
return self._slice(data, start, end, coord)
@staticmethod
......
......@@ -37,6 +37,10 @@ class DataPreparationNeighbors(DefaultDataPreparation):
def _create_collection(self):
return [self.id_class] + self.neighbors
def get_coordinates(self, include_neighbors=False):
    """Return the coordinates of this station and optionally of its neighbors.

    Bug fix: the original returned ``[own].append(neighbors)`` — ``list.append``
    mutates in place and returns None, so the method always returned None (and
    would have nested the neighbor list instead of extending). Now the own
    coordinates come first, followed by one entry per neighbor.

    :param include_neighbors: if True, append each neighbor's coordinates
    :return: list of coordinate dicts, own station's entry first
    """
    coords = [super(DataPreparationNeighbors, self).get_coordinates()]
    if include_neighbors is True:
        coords.extend(n.get_coordinates() for n in self.neighbors)
    return coords
if __name__ == "__main__":
......
......@@ -237,12 +237,10 @@ class PlotStationMap(AbstractPlotClass):
import cartopy.crs as ccrs
if generators is not None:
for color, gen in generators.items():
for k, v in enumerate(gen):
station_coords = gen.get_data_generator(k).meta.loc[['station_lon', 'station_lat']]
# station_names = gen.get_data_generator(k).meta.loc[['station_id']]
IDx, IDy = float(station_coords.loc['station_lon'].values), float(
station_coords.loc['station_lat'].values)
for color, data_collection in generators.items():
for station in data_collection:
coords = station.get_coordinates()
IDx, IDy = coords["lon"], coords["lat"]
self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
def _plot(self, generators: Dict):
......@@ -772,8 +770,8 @@ class PlotTimeSeries:
def _plot(self, plot_folder):
pdf_pages = self._create_pdf_pages(plot_folder)
start, end = self._get_time_range(self._load_data(self._stations[0]))
for pos, station in enumerate(self._stations):
start, end = self._get_time_range(self._load_data(self._stations[0]))
data = self._load_data(station)
fig, axes, factor = self._create_subplots(start, end)
nan_list = []
......
......@@ -13,8 +13,7 @@ import numpy as np
import pandas as pd
import xarray as xr
from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrepJoin, KerasIterator
from src.data_handling.bootstraps import BootStrapsNew
from src.data_handling import BootStraps, KerasIterator
from src.helpers.datastore import NameNotFoundInDataStore
from src.helpers import TimeTracking, statistics
from src.model_modules.linear_model import OrdinaryLeastSquaredModel
......@@ -147,7 +146,7 @@ class PostProcessing(RunEnvironment):
for station in self.test_data:
logging.info(str(station))
X, Y = None, None
bootstraps = BootStrapsNew(station, bootstrap_path, number_of_bootstraps)
bootstraps = BootStraps(station, number_of_bootstraps)
for boot in bootstraps:
X, Y, (index, dimension) = boot
# make bootstrap predictions
......@@ -185,9 +184,8 @@ class PostProcessing(RunEnvironment):
forecast_path = self.data_store.get("forecast_path")
window_lead_time = self.data_store.get("window_lead_time")
number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
# bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
forecast_file = f"forecasts_norm_%s_test.nc"
bootstraps = BootStrapsNew(self.test_data[0], bootstrap_path, number_of_bootstraps).bootstraps()
bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
skill_scores = statistics.SkillScores(None)
score = {}
for station in self.test_data:
......
......@@ -266,7 +266,7 @@ class PreProcessing(RunEnvironment):
kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
for station in set_stations:
try:
dp = data_preparation.build(station, **kwargs)
dp = data_preparation.build(station, name_affix=set_name, **kwargs)
collection.add(dp)
valid_stations.append(station)
except (AttributeError, EmptyQueryResult):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment