Commit a19c4fe6 authored by Lukas Leufen

Merge branch 'develop' into release_v0.9.0

parents 53546303 7553788c
Pipeline #34435 passed with stages in 6 minutes and 18 seconds
......@@ -45,9 +45,10 @@ Thumbs.db
/data/
/plots/
# tmp folder #
##############
# tmp and logging folder #
##########################
/tmp/
/logging/
# test related data #
#####################
......
......@@ -23,13 +23,71 @@ version:
paths:
- badges/
### Tests (from scratch) ###
tests (from scratch):
tags:
- base
- zam347
stage: test
only:
- master
- /^release.*$/
- develop
variables:
FAILURE_THRESHOLD: 100
TEST_TYPE: "scratch"
before_script:
- chmod +x ./CI/update_badge.sh
- ./CI/update_badge.sh > /dev/null
script:
- zypper --non-interactive install binutils libproj-devel gdal-devel
- zypper --non-interactive install proj geos-devel
- pip install -r requirements.txt
- chmod +x ./CI/run_pytest.sh
- ./CI/run_pytest.sh
after_script:
- ./CI/update_badge.sh > /dev/null
artifacts:
name: pages
when: always
paths:
- badges/
- test_results/
### Tests (on GPU) ###
tests (on GPU):
tags:
- gpu
- zam347
stage: test
only:
- master
- /^release.*$/
- develop
variables:
FAILURE_THRESHOLD: 100
TEST_TYPE: "gpu"
before_script:
- chmod +x ./CI/update_badge.sh
- ./CI/update_badge.sh > /dev/null
script:
- pip install -r requirements.txt
- chmod +x ./CI/run_pytest.sh
- ./CI/run_pytest.sh
after_script:
- ./CI/update_badge.sh > /dev/null
artifacts:
name: pages
when: always
paths:
- badges/
- test_results/
### Tests ###
tests:
tags:
- leap
- machinelearningtools
- zam347
- base
- django
stage: test
variables:
FAILURE_THRESHOLD: 100
......@@ -51,10 +109,8 @@ tests:
coverage:
tags:
- leap
- machinelearningtools
- zam347
- base
- django
stage: test
variables:
FAILURE_THRESHOLD: 50
......@@ -79,7 +135,6 @@ coverage:
pages:
stage: pages
tags:
- leap
- zam347
- base
script:
......
#!/bin/bash
# run pytest for all run_modules
python3 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out
python3.6 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out
IS_FAILED=$?
if [ -z ${TEST_TYPE+x} ]; then
TEST_TYPE=""; else
TEST_TYPE="_"${TEST_TYPE};
fi
# move html test report
mkdir test_results/
TEST_RESULTS_DIR="test_results${TEST_TYPE}/"
mkdir ${TEST_RESULTS_DIR}
BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
mkdir test_results/${BRANCH_NAME}
mkdir test_results/recent
cp report.html test_results/${BRANCH_NAME}/.
cp report.html test_results/recent/.
mkdir "${TEST_RESULTS_DIR}/${BRANCH_NAME}"
mkdir "${TEST_RESULTS_DIR}/recent"
cp report.html "${TEST_RESULTS_DIR}/${BRANCH_NAME}/."
cp report.html "${TEST_RESULTS_DIR}/recent/."
if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
cp -r report.html test_results/.
cp -r report.html ${TEST_RESULTS_DIR}/.
fi
# exit 0 if no tests implemented
RUN_NO_TESTS="$(grep -c 'no tests ran' test_results.out)"
if [[ ${RUN_NO_TESTS} > 0 ]]; then
if [[ ${RUN_NO_TESTS} -gt 0 ]]; then
echo "no test available"
echo "incomplete" > status.txt
echo "no tests avail" > incomplete.txt
......@@ -27,20 +33,19 @@ fi
# extract if tests passed or not
TEST_FAILED="$(grep -oP '(\d+\s{1}failed)' test_results.out)"
TEST_FAILED="$(echo ${TEST_FAILED} | (grep -oP '\d*'))"
TEST_FAILED="$(echo "${TEST_FAILED}" | (grep -oP '\d*'))"
TEST_PASSED="$(grep -oP '\d+\s{1}passed' test_results.out)"
TEST_PASSED="$(echo ${TEST_PASSED} | (grep -oP '\d*'))"
TEST_PASSED="$(echo "${TEST_PASSED}" | (grep -oP '\d*'))"
if [[ -z "$TEST_FAILED" ]]; then
TEST_FAILED=0
fi
let "TEST_PASSED=${TEST_PASSED}-${TEST_FAILED}"
(( TEST_PASSED=TEST_PASSED-TEST_FAILED ))
# calculate metrics
let "SUM=${TEST_FAILED}+${TEST_PASSED}"
let "TEST_PASSED_RATIO=${TEST_PASSED}*100/${SUM}"
(( SUM=TEST_FAILED+TEST_PASSED ))
(( TEST_PASSED_RATIO=TEST_PASSED*100/SUM ))
# report
if [[ ${IS_FAILED} == 0 ]]; then
if [[ ${IS_FAILED} -eq 0 ]]; then
if [[ ${TEST_PASSED_RATIO} -lt 100 ]]; then
echo "only ${TEST_PASSED_RATIO}% passed"
echo "incomplete" > status.txt
......
#!/usr/bin/env bash
# run coverage twice, 1) for html deploy 2) for success evaluation
python3 -m pytest --cov=src --cov-report html test/
python3 -m pytest --cov=src --cov-report term test/ | tee coverage_results.out
python3.6 -m pytest --cov=src --cov-report term --cov-report html test/ | tee coverage_results.out
IS_FAILED=$?
# move html coverage report
mkdir coverage/
BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
mkdir coverage/${BRANCH_NAME}
mkdir coverage/recent
cp -r htmlcov/* coverage/${BRANCH_NAME}/.
mkdir "coverage/${BRANCH_NAME}"
mkdir "coverage/recent"
cp -r htmlcov/* "coverage/${BRANCH_NAME}/."
cp -r htmlcov/* coverage/recent/.
if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
cp -r htmlcov/* coverage/.
......@@ -19,10 +18,10 @@ fi
# extract coverage information
COVERAGE_RATIO="$(grep -oP '\d+\%' coverage_results.out | tail -1)"
COVERAGE_RATIO="$(echo ${COVERAGE_RATIO} | (grep -oP '\d*'))"
COVERAGE_RATIO="$(echo "${COVERAGE_RATIO}" | (grep -oP '\d*'))"
# report
if [[ ${IS_FAILED} == 0 ]]; then
if [[ ${IS_FAILED} -eq 0 ]]; then
if [[ ${COVERAGE_RATIO} -lt ${COVERAGE_PASS_THRESHOLD} ]]; then
echo "only ${COVERAGE_RATIO}% covered"
echo "incomplete" > status.txt
......
["DENW094", "DEBW029", "DENI052", "DENI063", "DEBY109", "DEUB022", "DESN001", "DEUB013", "DETH016", "DEBY002", "DEBY005", "DEBY099", "DEUB038", "DEBE051", "DEBE056", "DEBE062", "DEBE032", "DEBE034", "DEBE010", "DEHE046", "DEST031", "DEBY122", "DERP022", "DEBY079", "DEBW102", "DEBW076", "DEBW045", "DESH016", "DESN004", "DEHE032", "DEBB050", "DEBW042", "DEBW046", "DENW067", "DESL019", "DEST014", "DENW062", "DEHE033", "DENW081", "DESH008", "DEBB055", "DENI011", "DEHB001", "DEHB004", "DEHB002", "DEHB003", "DEHB005", "DEST039", "DEUB003", "DEBW072", "DEST002", "DEBB001", "DEHE039", "DEBW035", "DESN005", "DEBW047", "DENW004", "DESN011", "DESN076", "DEBB064", "DEBB006", "DEHE001", "DESN012", "DEST030", "DESL003", "DEST104", "DENW050", "DENW008", "DETH026", "DESN085", "DESN014", "DESN092", "DENW071", "DEBW004", "DENI028", "DETH013", "DENI059", "DEBB007", "DEBW049", "DENI043", "DETH020", "DEBY017", "DEBY113", "DENW247", "DENW028", "DEBW025", "DEUB039", "DEBB009", "DEHE027", "DEBB042", "DEHE008", "DESN017", "DEBW084", "DEBW037", "DEHE058", "DEHE028", "DEBW112", "DEBY081", "DEBY082", "DEST032", "DETH009", "DEHE010", "DESN019", "DEHE023", "DETH036", "DETH040", "DEMV017", "DEBW028", "DENI042", "DEMV004", "DEMV019", "DEST044", "DEST050", "DEST072", "DEST022", "DEHH049", "DEHH047", "DEHH033", "DEHH050", "DEHH008", "DEHH021", "DENI054", "DEST070", "DEBB053", "DENW029", "DEBW050", "DEUB034", "DENW018", "DEST052", "DEBY020", "DENW063", "DESN050", "DETH061", "DERP014", "DETH024", "DEBW094", "DENI031", "DETH041", "DERP019", "DEBW081", "DEHE013", "DEBW021", "DEHE060", "DEBY031", "DESH021", "DESH033", "DEHE052", "DEBY004", "DESN024", "DEBW052", "DENW042", "DEBY032", "DENW053", "DENW059", "DEBB082", "DEBB031", "DEHE025", "DEBW053", "DEHE048", "DENW051", "DEBY034", "DEUB035", "DEUB032", "DESN028", "DESN059", "DEMV024", "DENW079", "DEHE044", "DEHE042", "DEBB043", "DEBB036", "DEBW024", "DERP001", "DEMV012", "DESH005", "DESH023", "DEUB031", "DENI062", "DENW006", "DEBB065", "DEST077", "DEST005", "DERP007", "DEBW006", "DEBW007", "DEHE030", "DENW015", "DEBY013", "DETH025", "DEUB033", "DEST025", "DEHE045", "DESN057", "DENW036", "DEBW044", "DEUB036", "DENW096", "DETH095", "DENW038", "DEBY089", "DEBY039", "DENW095", "DEBY047", "DEBB067", "DEBB040", "DEST078", "DENW065", "DENW066", "DEBY052", "DEUB030", "DETH027", "DEBB048", "DENW047", "DEBY049", "DERP021", "DEHE034", "DESN079", "DESL008", "DETH018", "DEBW103", "DEHE017", "DEBW111", "DENI016", "DENI038", "DENI058", "DENI029", "DEBY118", "DEBW032", "DEBW110", "DERP017", "DESN036", "DEBW026", "DETH042", "DEBB075", "DEBB052", "DEBB021", "DEBB038", "DESN051", "DEUB041", "DEBW020", "DEBW113", "DENW078", "DEHE018", "DEBW065", "DEBY062", "DEBW027", "DEBW041", "DEHE043", "DEMV007", "DEMV021", "DEBW054", "DETH005", "DESL012", "DESL011", "DEST069", "DEST071", "DEUB004", "DESH006", "DEUB029", "DEUB040", "DESN074", "DEBW031", "DENW013", "DENW179", "DEBW056", "DEBW087", "DEST061", "DEMV001", "DEBB024", "DEBW057", "DENW064", "DENW068", "DENW080", "DENI019", "DENI077", "DEHE026", "DEBB066", "DEBB083", "DEST063", "DEBW013", "DETH086", "DESL018", "DETH096", "DEBW059", "DEBY072", "DEBY088", "DEBW060", "DEBW107", "DEBW036", "DEUB026", "DEBW019", "DENW010", "DEST098", "DEHE019", "DEBW039", "DESL017", "DEBW034", "DEUB005", "DEBB051", "DEHE051", "DEBW023", "DEBY092", "DEBW008", "DEBW030", "DENI060", "DEST011", "DENW030", "DENI041", "DERP015", "DEUB001", "DERP016", "DERP028", "DERP013", "DEHE022", "DEUB021", "DEBW010", "DEST066", "DEBB063", "DEBB028", "DEHE024", "DENI020", "DENI051", "DERP025", 
"DEBY077", "DEMV018", "DEST089", "DEST028", "DETH060", "DEHE050", "DEUB028", "DESN045", "DEUB042"]
......@@ -43,17 +43,20 @@ pytest-cov==2.8.1
pytest-html==2.0.1
pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1
pytz==2019.3
PyYAML==5.3
requests==2.23.0
scipy==1.4.1
seaborn==0.10.0
Shapely==1.7.0
--no-binary shapely Shapely==1.7.0
six==1.11.0
statsmodels==0.11.1
tensorboard==1.12.2
tensorflow==1.12.0
tabulate
tensorboard==1.13.1
tensorflow-estimator==1.13.0
tensorflow==1.13.1
termcolor==1.1.0
toolz==0.10.0
urllib3==1.25.8
......
absl-py==0.9.0
astor==0.8.1
atomicwrites==1.3.0
attrs==19.3.0
Cartopy==0.17.0
certifi==2019.11.28
chardet==3.0.4
cloudpickle==1.3.0
coverage==5.0.3
cycler==0.10.0
Cython==0.29.15
dask==2.11.0
fsspec==0.6.2
gast==0.3.3
grpcio==1.27.2
h5py==2.10.0
idna==2.8
importlib-metadata==1.5.0
Keras==2.2.4
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
locket==0.2.0
Markdown==3.2.1
matplotlib==3.2.0
mock==4.0.1
more-itertools==8.2.0
numpy==1.18.1
packaging==20.3
pandas==1.0.1
partd==1.1.0
patsy==0.5.1
Pillow==7.0.0
pluggy==0.13.1
protobuf==3.11.3
py==1.8.1
pydot==1.4.1
pyparsing==2.4.6
pyproj==2.5.0
pyshp==2.1.0
pytest==5.3.5
pytest-cov==2.8.1
pytest-html==2.0.1
pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1
pytz==2019.3
PyYAML==5.3
requests==2.23.0
scipy==1.4.1
seaborn==0.10.0
--no-binary shapely Shapely==1.7.0
six==1.11.0
statsmodels==0.11.1
tabulate
tensorboard==1.13.1
tensorflow-estimator==1.13.0
tensorflow-gpu==1.13.1
termcolor==1.1.0
toolz==0.10.0
urllib3==1.25.8
wcwidth==0.1.8
Werkzeug==1.0.0
xarray==0.15.0
zipp==3.1.0
......@@ -3,7 +3,6 @@ __date__ = '2019-11-14'
import argparse
import logging
from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.model_setup import ModelSetup
......@@ -17,7 +16,8 @@ def main(parser_args):
with RunEnvironment():
ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
station_type='background', trainable=False, create_new_model=False)
station_type='background', trainable=False, create_new_model=False, window_history_size=6,
create_new_bootstraps=True)
PreProcessing()
ModelSetup()
......@@ -29,10 +29,6 @@ def main(parser_args):
if __name__ == "__main__":
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.INFO)
# logging.basicConfig(format=formatter, level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string")
......
......@@ -29,10 +29,6 @@ def main(parser_args):
if __name__ == "__main__":
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.INFO)
# logging.basicConfig(format=formatter, level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string")
......
__author__ = "Lukas Leufen"
__date__ = '2019-11-14'
import argparse
import json
import logging
from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.model_setup import ModelSetup
from src.run_modules.post_processing import PostProcessing
from src.run_modules.pre_processing import PreProcessing
from src.run_modules.run_environment import RunEnvironment
from src.run_modules.training import Training
def load_stations():
try:
filename = 'German_background_stations.json'
with open(filename, 'r') as jfile:
stations = json.load(jfile)
logging.info(f"load station IDs from file: {filename} ({len(stations)} stations)")
# stations.remove('DEUB042')
except FileNotFoundError:
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']
# stations = ['DEBB050', 'DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']
logging.info(f"Use stations from list: {stations}")
return stations
def main(parser_args):
with RunEnvironment():
ExperimentSetup(parser_args, stations=load_stations(), station_type='background', trainable=False,
create_new_model=True)
PreProcessing()
ModelSetup()
Training()
PostProcessing()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string")
args = parser.parse_args(["--experiment_date", "testrun"])
main(args)
......@@ -8,15 +8,18 @@ import math
import keras
import numpy as np
from src.data_handling.data_generator import DataGenerator
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
permute_data: bool = False):
def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256,
permute_data: bool = False, upsampling: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.do_data_permutation = permute_data
self.upsampling = upsampling
def _get_model_rank(self):
mod_out = self.model.output_shape
......@@ -31,7 +34,7 @@ class Distributor(keras.utils.Sequence):
return mod_rank
def _get_number_of_mini_batches(self, values):
return math.ceil(values[0].shape[0] / self.batch_size)
return math.ceil(values.shape[0] / self.batch_size)
def _permute_data(self, x, y):
"""
......@@ -48,10 +51,18 @@ class Distributor(keras.utils.Sequence):
for k, v in enumerate(self.generator):
# get rank of output
mod_rank = self._get_model_rank()
# get number of mini batches
num_mini_batches = self._get_number_of_mini_batches(v)
# get data
x_total = np.copy(v[0])
y_total = np.copy(v[1])
if self.upsampling:
try:
s = self.generator.get_data_generator(k)
x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0)
y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0)
except AttributeError: # no extremes history / labels available, copy will fail
pass
# get number of mini batches
num_mini_batches = self._get_number_of_mini_batches(x_total)
# permute order for mini-batches
x_total, y_total = self._permute_data(x_total, y_total)
for prev, curr in enumerate(range(1, num_mini_batches+1)):
......
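Note on the upsampling flag added to Distributor above (a hedged sketch, not part of the diff): with upsampling enabled, the pre-extracted extreme inputs and labels of each station are appended to the regular data before the mini-batches are cut, so extreme events are seen more often per epoch. The objects data_gen and model below are placeholders.

    # hypothetical usage; Distributor follows the signature shown in the diff above,
    # data_gen and model are assumed to exist already
    distributor = Distributor(data_gen, model, batch_size=256, upsampling=True)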
......@@ -14,6 +14,9 @@ from src import helpers
from src.data_handling.data_preparation import DataPrep
from src.join import EmptyQueryResult
number = Union[float, int]
num_or_list = Union[number, List[number]]
class DataGenerator(keras.utils.Sequence):
"""
......@@ -27,7 +30,7 @@ class DataGenerator(keras.utils.Sequence):
def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
window_lead_time: int = 4, transformation: Dict = None, **kwargs):
window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs):
self.data_path = os.path.abspath(data_path)
self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
if not os.path.exists(self.data_path_tmp):
......@@ -43,6 +46,7 @@ class DataGenerator(keras.utils.Sequence):
self.limit_nan_fill = limit_nan_fill
self.window_history_size = window_history_size
self.window_lead_time = window_lead_time
self.extreme_values = extreme_values
self.kwargs = kwargs
self.transformation = self.setup_transformation(transformation)
......@@ -178,7 +182,7 @@ class DataGenerator(keras.utils.Sequence):
raise FileNotFoundError
data = self._load_pickle_data(station, self.variables)
except FileNotFoundError:
logging.info(f"load not pickle data for {station}")
logging.debug(f"load not pickle data for {station}")
data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
**self.kwargs)
if self.transformation is not None:
......@@ -188,6 +192,9 @@ class DataGenerator(keras.utils.Sequence):
data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
data.remove_nan(self.interpolate_dim)
if self.extreme_values:
kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)}
data.multiply_extremes(self.extreme_values, **kwargs)
if save_local_tmp_storage:
self._save_pickle_data(data)
return data
......
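Note on the new extreme_values argument of DataGenerator (a hedged sketch; all concrete values below are placeholders): the list is passed through to DataPrep.multiply_extremes when a station is loaded, and extremes_on_right_tail_only is picked up from the remaining kwargs.

    # hypothetical construction; argument names follow the signature in the diff above,
    # the path, network, station and variable values are placeholders only
    gen = DataGenerator("data/", "UBA", ["DEBW107"], ["o3", "temp"],
                        interpolate_dim="datetime", target_dim="variables", target_var="o3",
                        extreme_values=[1., 2.], extremes_on_right_tail_only=False)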
......@@ -5,7 +5,7 @@ import datetime as dt
from functools import reduce
import logging
import os
from typing import Union, List, Iterable
from typing import Union, List, Iterable, Tuple
import numpy as np
import pandas as pd
......@@ -17,6 +17,8 @@ from src import statistics
# define a more general date type for type hinting
date = Union[dt.date, dt.datetime]
str_or_list = Union[str, List[str]]
number = Union[float, int]
num_or_list = Union[number, List[number]]
class DataPrep(object):
......@@ -58,6 +60,8 @@ class DataPrep(object):
self.history = None
self.label = None
self.observation = None
self.extremes_history = None
self.extremes_label = None
self.kwargs = kwargs
self.data = None
self.meta = None
......@@ -353,7 +357,8 @@ class DataPrep(object):
non_nan_observation = self.observation.dropna(dim=dim)
intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values))
if len(intersect) == 0:
min_length = self.kwargs.get("min_length", 0)
if len(intersect) < max(min_length, 1):
self.history = None
self.label = None
self.observation = None
......@@ -413,12 +418,79 @@ class DataPrep(object):
data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
return data
def get_transposed_history(self):
def get_transposed_history(self) -> xr.DataArray:
return self.history.transpose("datetime", "window", "Stations", "variables").copy()
def get_transposed_label(self):
def get_transposed_label(self) -> xr.DataArray:
return self.label.squeeze("Stations").transpose("datetime", "window").copy()
def get_extremes_history(self) -> xr.DataArray:
return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy()
def get_extremes_label(self) -> xr.DataArray:
return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy()
def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
timedelta: Tuple[int, str] = (1, 'm')):
"""
This method extracts extreme values from self.label as defined by the argument extreme_values. One can
also decide to only extract extremes on the right tail of the distribution. When extreme_values is a list of
floats/ints, all values larger than each entry (and smaller than its negative counterpart; extraction is
performed in standardised space) are extracted iteratively. If, for example, extreme_values = [1., 2.], a value
of 1.5 would be extracted once (for the 0th entry of the list), while a 2.5 would be extracted twice (once for
each entry). Timedelta is used to mark the extracted values by adding one minute to each timestamp. As TOAR
data are hourly, these "artificial" data points can easily be identified later. Extreme inputs and labels are
stored in self.extremes_history and self.extremes_label, respectively.
:param extreme_values: user definition of an extreme (threshold in standardised space)
:param extremes_on_right_tail_only: if False, also extract values smaller than -extreme_values;
if True, only extract values larger than extreme_values
:param timedelta: used as argument for np.timedelta64 in order to mark extreme values on the datetime axis
"""
# check if labels or history is None
if (self.label is None) or (self.history is None):
logging.debug(f"{self.station} has `None' labels, skip multiply extremes")
return
# check type of inputs
extreme_values = helpers.to_list(extreme_values)
for i in extreme_values:
if not isinstance(i, number.__args__):
raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
f"{i} is type {type(i)}")
for extr_val in sorted(extreme_values):
# check if some extreme values are already extracted
if (self.extremes_label is None) or (self.extremes_history is None):
# extract extremes based on occurrence in labels
if extremes_on_right_tail_only:
extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,)
else:
extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1),
(self.label > extr_val).any(axis=0).values.reshape(-1, 1)),
axis=1).any(axis=1)
extremes_label = self.label[..., extreme_label_idx]
extremes_history = self.history[..., extreme_label_idx, :]
extremes_label.datetime.values += np.timedelta64(*timedelta)
extremes_history.datetime.values += np.timedelta64(*timedelta)
self.extremes_label = extremes_label#.squeeze('Stations').transpose('datetime', 'window')
self.extremes_history = extremes_history#.transpose('datetime', 'window', 'Stations', 'variables')
else: # one extr value iteration is done already: self.extremes_label is NOT None...
if extremes_on_right_tail_only:
extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, )
else:
extreme_label_idx = np.concatenate(((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1),
(self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1)
), axis=1).any(axis=1)
# check on existing extracted extremes to minimise computational costs for comparison
extremes_label = self.extremes_label[..., extreme_label_idx]
extremes_history = self.extremes_history[..., extreme_label_idx, :]
extremes_label.datetime.values += np.timedelta64(*timedelta)
extremes_history.datetime.values += np.timedelta64(*timedelta)
self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime')
self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
if __name__ == "__main__":
dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
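A hedged continuation of the example above (illustrative only): it assumes history and labels have already been built on dp, which requires calls not shown in this diff; otherwise multiply_extremes logs a debug message and returns without extracting anything.

    # hypothetical: extract extremes once history and labels exist on dp
    dp.multiply_extremes([1., 2.], extremes_on_right_tail_only=False)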