Commit 3c241f2d authored by lukas leufen's avatar lukas leufen

Can run all plots except StationMap; cleanup of unused code is still required

parent a433991f
Pipeline #41584 failed with stages
in 6 minutes and 21 seconds
......@@ -79,6 +79,9 @@ class AbstractDataPreparation:
def get_Y(self, upsampling=False, as_numpy=False):
    """
    Return the target data (Y); not implemented in this base version.

    :param upsampling: not used here; interpretation is left to the implementing subclass
    :param as_numpy: not used here; interpretation is left to the implementing subclass
    :raises NotImplementedError: always
    """
    raise NotImplementedError
def get_data(self, upsampling=False, as_numpy=False):
    """
    Return inputs and targets together.

    :param upsampling: forwarded unchanged to get_X and get_Y
    :param as_numpy: forwarded unchanged to get_X and get_Y
    :return: tuple (result of get_X, result of get_Y)
    """
    inputs = self.get_X(upsampling, as_numpy)
    targets = self.get_Y(upsampling, as_numpy)
    return inputs, targets
class DefaultDataPreparation(AbstractDataPreparation):
......
......@@ -93,7 +93,7 @@ class CreateShuffledData:
inside bootstrap_path.
"""
def __init__(self, data: DataGenerator, number_of_bootstraps: int, bootstrap_path: str):
def __init__(self, data, number_of_bootstraps: int, bootstrap_path: str):
"""
Shuffled data is automatically created in initialisation.
......@@ -115,15 +115,18 @@ class CreateShuffledData:
file will be created inside this function.
"""
logging.info("create / check shuffled bootstrap data")
variables_str = '_'.join(sorted(self.data.variables))
window = self.data.window_history_size
for station in self.data.stations:
valid, nboot = self.valid_bootstrap_file(station, variables_str, window)
variables = ["o3", "temp"]
# window = self.data.window_history_size
window = 3
for station in self.data:
variables = ["o3", "temp"]
window = 3
valid, nboot, variables, window = self.valid_bootstrap_file(str(station), variables, window)
if not valid:
logging.info(f'create bootstap data for {station}')
hist = self.data.get_data_generator(station).get_transposed_history()
file_path = self._set_file_path(station, variables_str, window, nboot)
hist = hist.expand_dims({'boots': range(nboot)}, axis=-1)
hist = station.get_X(as_numpy=False)
file_path = self._set_file_path(station, variables, window, nboot)
hist = list(map(lambda x: x.expand_dims({'boots': range(nboot)}, axis=-1), hist))
shuffled_variable = []
chunks = (100, *hist.shape[1:3], hist.shape[-1])
for i, var in enumerate(hist.coords['variables']):
......@@ -146,7 +149,7 @@ class CreateShuffledData:
:param nboots: number of boots
:return: full file path
"""
file_name = f"{station}_{variables}_hist{window}_nboots{nboots}_shuffled.nc"
file_name = f"{station}_{'_'.join(sorted(variables))}_hist{window}_nboots{nboots}_shuffled.nc"
return os.path.join(self.bootstrap_path, file_name)
def valid_bootstrap_file(self, station: str, variables: str, window: int) -> [bool, Union[None, int]]:
......@@ -167,19 +170,26 @@ class CreateShuffledData:
be used for the new boot creation (this is only relevant, if no valid file was found - otherwise the return
statement is anyway None).
"""
regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled")
regex = re.compile(rf"{station}_(.*)_hist(\d+)_nboots(\d+)_shuffled")
max_nboot = self.number_of_bootstraps
max_variables = set(variables)
max_window = window
for file in os.listdir(self.bootstrap_path):
match = regex.match(file)
if match:
window_file = int(match.group(1))
nboot_file = int(match.group(2))
variable_file = set(match.group(1).split("_"))
window_file = int(match.group(2))
nboot_file = int(match.group(3))
max_nboot = max([max_nboot, nboot_file])
if (window_file >= window) and (nboot_file >= self.number_of_bootstraps):
return True, None
max_variables = variable_file.union(variables)
max_window = max([max_window, window_file])
if (window_file >= window) \
and (nboot_file >= self.number_of_bootstraps) \
and variable_file >= set(variables):
return True, None, None, None
else:
os.remove(os.path.join(self.bootstrap_path, file))
return False, max_nboot
return False, max_nboot, max_variables, max_window
@staticmethod
def shuffle(data: da.array, chunks: Tuple) -> da.core.Array:
......@@ -220,7 +230,9 @@ class BootStraps:
self.data = data
self.number_of_bootstraps = number_of_bootstraps
self.bootstrap_path = bootstrap_path
CreateShuffledData(data, number_of_bootstraps, bootstrap_path)
CreateShuffledData(data, number_of_bootstraps, bootstrap_path) # Todo: think about how to create the bootstrapped
# data inside the datapreparation class and not on top. get_X(bootstrapped=True) or get_bootstrapped_X. If this
# method is not implemented, skip bootstrapping analysis
@property
def stations(self) -> List[str]:
......@@ -358,6 +370,125 @@ class BootStraps:
if (int(match.group(last - 1)) >= window) and (int(match.group(last)) >= nboot):
return f
from collections import Iterator, Iterable
from itertools import chain
class BootstrapIterator(Iterator):
    """
    Iterate over all bootstraps of a BootStrapsNew object.

    Each step takes one (input index, coordinate) pair from BootStrapsNew.bootstraps(), replaces the values of this
    coordinate by values drawn randomly with replacement, and returns the complete data with this single entry
    shuffled.
    """

    _position: int = None

    def __init__(self, data: "BootStrapsNew"):
        """
        Store the bootstrap container and start at the first bootstrap.

        :param data: bootstrap container providing data, bootstrap dimension, and the list of bootstraps
        """
        assert isinstance(data, BootStrapsNew)
        self._data = data
        self._dimension = data.bootstrap_dimension
        self._collection = self._data.bootstraps()
        self._position = 0

    def __next__(self):
        """
        Return next element or stop iteration.

        :return: reshaped inputs (list of arrays), reshaped targets, and the (index, coordinate) pair that was
            shuffled in this step
        :raises StopIteration: if all bootstraps have been returned
        """
        # Only the lookup signals exhaustion. An IndexError raised while processing the data is a real error and
        # must not be converted into a silent end of iteration.
        try:
            index, dimension = self._collection[self._position]
        except IndexError:
            raise StopIteration() from None

        nboot = self._data.number_of_bootstraps
        inputs, targets = self._data.data.get_data(as_numpy=False)
        # add a trailing 'boots' axis to inputs and targets
        inputs = [x.expand_dims({'boots': range(nboot)}, axis=-1) for x in inputs]
        targets = targets.expand_dims({"boots": range(nboot)}, axis=-1)
        # shuffle only the selected coordinate and merge it back into the untouched remainder
        single_variable = inputs[index].sel({self._dimension: [dimension]})
        shuffled_variable = self.shuffle(single_variable.values)
        shuffled_data = xr.DataArray(shuffled_variable, coords=single_variable.coords, dims=single_variable.dims)
        inputs[index] = shuffled_data.combine_first(inputs[index]).reindex_like(inputs[index])
        self._position += 1

        inputs, targets = self._to_numpy(inputs), self._to_numpy(targets)
        return self._reshape(inputs), self._reshape(targets), (index, dimension)

    @staticmethod
    def _reshape(d):
        """
        Move the last axis to the front and merge it with the first axis.

        An array of shape (n, ..., b) becomes (b * n, ...). A list is processed element by element.

        :param d: single array or list of arrays
        :return: reshaped array or list of reshaped arrays
        """
        def _merge(arr):
            shape = arr.shape
            return np.rollaxis(arr, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1])

        if isinstance(d, list):
            return [_merge(x) for x in d]
        return _merge(d)

    @staticmethod
    def _to_numpy(d):
        """
        Return the `.values` of a single object or of each element of a list.

        :param d: object with a `values` attribute, or list of such objects
        :return: the values, or a list of values
        """
        if isinstance(d, list):
            return [x.values for x in d]
        return d.values

    @staticmethod
    def shuffle(data: np.ndarray) -> np.ndarray:
        """
        Shuffle randomly from given data (draw elements with replacement).

        :param data: data to shuffle
        :return: array with the shape of data, filled with values drawn from data
        """
        return np.random.choice(data.reshape(-1), size=data.shape)
class BootStrapsNew(Iterable):
    """
    Main class to perform bootstrap operations.

    This class wraps a data preparation object and offers one bootstrap per coordinate of the bootstrap dimension for
    each input returned by data.get_X(). How to use BootStrapsNew:

    * iterate over the object to get the data of each bootstrap (iteration is done by BootstrapIterator)
    * call len(<object>) to get the total number of bootstraps
    * call .bootstraps() to get all (input index, coordinate) pairs
    * call .get_orig_prediction(<path>, <file_name>) to get the non-bootstrapped predictions (referred as original
      predictions) repeated by the number of boots
    """

    from src.data_handling.advanced_data_handling import AbstractDataPreparation

    def __init__(self, data: AbstractDataPreparation, bootstrap_path: str, number_of_bootstraps: int = 10,
                 bootstrap_dimension: str = "variables"):
        """
        Store all settings; no data is touched on initialisation.

        :param data: a data preparation object to get inputs and targets from
        :param bootstrap_path: path to find and store the bootstrap data
        :param number_of_bootstraps: the number of bootstrap realisations
        :param bootstrap_dimension: name of the dimension whose coordinates are shuffled one by one
        """
        self.data = data
        self.number_of_bootstraps = number_of_bootstraps
        self.bootstrap_path = bootstrap_path
        self.bootstrap_dimension = bootstrap_dimension

    def __iter__(self):
        """Return a fresh iterator over all bootstraps."""
        return BootstrapIterator(self)

    def __len__(self):
        """Return the total number of bootstraps."""
        return len(self.bootstraps())

    def bootstraps(self):
        """
        List all bootstraps as (input index, coordinate) pairs.

        The coordinates are read from the bootstrap dimension, i.e. the same dimension the iterator selects on.

        :return: list of tuples, ordered by input index and then by coordinate order of the input
        """
        pairs = []
        for i, x in enumerate(self.data.get_X(as_numpy=False)):
            pairs.extend((i, coord) for coord in x.indexes[self.bootstrap_dimension])
        return pairs

    def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray:
        """
        Repeat predictions from given file(_name) in path by the number of boots.

        Rows that contain at least one NaN are removed from the result.

        :param path: path to file
        :param file_name: file name
        :param prediction_name: name of the prediction to select from loaded file (default CNN)
        :return: repeated predictions
        """
        file = os.path.join(path, file_name)
        prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
        vals = np.tile(prediction.data, (self.number_of_bootstraps, 1))
        return vals[~np.isnan(vals).any(axis=1), :]
if __name__ == "__main__":
......
......@@ -38,6 +38,8 @@ class DataCollection(Iterable):
collection = []
assert isinstance(collection, list)
self._collection = collection
self._mapping = {}
self._set_mapping()
def __len__(self):
    """Return the number of elements stored in the collection."""
    return len(self._collection)
......@@ -46,10 +48,22 @@ class DataCollection(Iterable):
return StandardIterator(self._collection)
def __getitem__(self, index):
    """
    Return an element by position or by name.

    :param index: integer position in the collection, or any other object whose string representation is looked
        up in the name mapping
    :return: the matching element
    :raises IndexError: if an integer index is out of range
    :raises KeyError: if str(index) is not a registered name
    """
    # No unconditional return up front: that would make the name-based lookup below unreachable.
    if isinstance(index, int):
        return self._collection[index]
    return self._collection[self._mapping[str(index)]]
def add(self, element):
    """
    Append an element to the collection and register its name.

    :param element: object to store; its string representation is used as lookup key
    """
    self._collection.append(element)
    # The new element is the last one, so its 0-based position is len - 1 (same convention as _set_mapping).
    self._mapping[str(element)] = len(self._collection) - 1
def _set_mapping(self):
    """Register every element of the collection under its string representation, mapped to its position."""
    positions = {str(element): position for position, element in enumerate(self._collection)}
    self._mapping.update(positions)
def keys(self):
    """Return all registered names as a new list."""
    names = list(self._mapping)
    return names
class KerasIterator(keras.utils.Sequence):
......
......@@ -20,6 +20,7 @@ from matplotlib.backends.backend_pdf import PdfPages
from src import helpers
from src.data_handling import DataGenerator
from src.data_handling.iterator import DataCollection
from src.helpers import TimeTrackingWrapper
logging.getLogger('matplotlib').setLevel(logging.WARNING)
......@@ -883,10 +884,11 @@ class PlotAvailability(AbstractPlotClass):
"""
def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily",
summary_name="data availability"):
summary_name="data availability", time_dimension="datetime"):
"""Initialise."""
# create standard Gantt plot for all stations (currently in single pdf file with single page)
super().__init__(plot_folder, "data_availability")
self.dim = time_dimension
self.sampling = self._get_sampling(sampling)
plot_dict = self._prepare_data(generators)
lgd = self._plot(plot_dict)
......@@ -909,34 +911,30 @@ class PlotAvailability(AbstractPlotClass):
elif sampling == "hourly":
return "h"
def _prepare_data(self, generators: Dict[str, DataGenerator]):
def _prepare_data(self, generators: Dict[str, DataCollection]):
plt_dict = {}
for subset, generator in generators.items():
stations = generator.stations
for station in stations:
station_data = generator.get_data_generator(station)
labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
for subset, data_collection in generators.items():
for station in data_collection:
labels = station.get_Y(as_numpy=False).resample({self.dim: self.sampling}, skipna=True).mean()
labels_bool = labels.sel(window=1).notnull()
group = (labels_bool != labels_bool.shift(datetime=1)).cumsum()
group = (labels_bool != labels_bool.shift({self.dim: 1})).cumsum()
plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
index=labels.datetime.values)
index=labels.coords[self.dim].values)
t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
t2 = [i[1:] for i in t if i[0]]
if plt_dict.get(station) is None:
plt_dict[station] = {subset: t2}
if plt_dict.get(str(station)) is None:
plt_dict[str(station)] = {subset: t2}
else:
plt_dict[station].update({subset: t2})
plt_dict[str(station)].update({subset: t2})
return plt_dict
def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str):
plt_dict = {}
for subset, generator in generators.items():
for subset, data_collection in generators.items():
all_data = None
stations = generator.stations
for station in stations:
station_data = generator.get_data_generator(station)
labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean()
for station in data_collection:
labels = station.get_Y(as_numpy=False).resample({self.dim: self.sampling}, skipna=True).mean()
labels_bool = labels.sel(window=1).notnull()
if all_data is None:
all_data = labels_bool
......@@ -945,8 +943,9 @@ class PlotAvailability(AbstractPlotClass):
all_data = np.logical_or(tmp, labels_bool).combine_first(
all_data) # apply logical on merge and fill missing with all_data
group = (all_data != all_data.shift(datetime=1)).cumsum()
plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, index=all_data.datetime.values)
group = (all_data != all_data.shift({self.dim: 1})).cumsum()
plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values},
index=all_data.coords[self.dim].values)
t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
t2 = [i[1:] for i in t if i[0]]
if plt_dict.get(summary_name) is None:
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment