Commit e945b365 authored by Lukas Leufen

release v0.9.0

Resolve "release branch / CI on gpu"

Closes #58, #60, #71, #73, #76, #77, #78, #80, #81, #82, #85, #86, #87, #88, #90, #94, #95, #103, and #111

See merge request toar/machinelearningtools!89
parents 53546303 a19c4fe6
Pipeline #39778 failed with stages
in 4 minutes and 2 seconds
...@@ -45,9 +45,10 @@ Thumbs.db ...@@ -45,9 +45,10 @@ Thumbs.db
/data/ /data/
/plots/ /plots/
# tmp folder # # tmp and logging folder #
############## ##########################
/tmp/ /tmp/
/logging/
# test related data # # test related data #
##################### #####################
......
...@@ -23,13 +23,71 @@ version: ...@@ -23,13 +23,71 @@ version:
paths: paths:
- badges/ - badges/
### Tests (from scratch) ###
tests (from scratch):
tags:
- base
- zam347
stage: test
only:
- master
- /^release.*$/
- develop
variables:
FAILURE_THRESHOLD: 100
TEST_TYPE: "scratch"
before_script:
- chmod +x ./CI/update_badge.sh
- ./CI/update_badge.sh > /dev/null
script:
- zypper --non-interactive install binutils libproj-devel gdal-devel
- zypper --non-interactive install proj geos-devel
- pip install -r requirements.txt
- chmod +x ./CI/run_pytest.sh
- ./CI/run_pytest.sh
after_script:
- ./CI/update_badge.sh > /dev/null
artifacts:
name: pages
when: always
paths:
- badges/
- test_results/
### Tests (on GPU) ###
tests (on GPU):
tags:
- gpu
- zam347
stage: test
only:
- master
- /^release.*$/
- develop
variables:
FAILURE_THRESHOLD: 100
TEST_TYPE: "gpu"
before_script:
- chmod +x ./CI/update_badge.sh
- ./CI/update_badge.sh > /dev/null
script:
- pip install -r requirements.txt
- chmod +x ./CI/run_pytest.sh
- ./CI/run_pytest.sh
after_script:
- ./CI/update_badge.sh > /dev/null
artifacts:
name: pages
when: always
paths:
- badges/
- test_results/
### Tests ### ### Tests ###
tests: tests:
tags: tags:
- leap - machinelearningtools
- zam347 - zam347
- base
- django
stage: test stage: test
variables: variables:
FAILURE_THRESHOLD: 100 FAILURE_THRESHOLD: 100
...@@ -51,10 +109,8 @@ tests: ...@@ -51,10 +109,8 @@ tests:
coverage: coverage:
tags: tags:
- leap - machinelearningtools
- zam347 - zam347
- base
- django
stage: test stage: test
variables: variables:
FAILURE_THRESHOLD: 50 FAILURE_THRESHOLD: 50
...@@ -79,7 +135,6 @@ coverage: ...@@ -79,7 +135,6 @@ coverage:
pages: pages:
stage: pages stage: pages
tags: tags:
- leap
- zam347 - zam347
- base - base
script: script:
......
#!/bin/bash #!/bin/bash
# run pytest for all run_modules # run pytest for all run_modules
python3 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out python3.6 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out
IS_FAILED=$? IS_FAILED=$?
if [ -z ${TEST_TYPE+x} ]; then
TEST_TYPE=""; else
TEST_TYPE="_"${TEST_TYPE};
fi
# move html test report # move html test report
mkdir test_results/ TEST_RESULTS_DIR="test_results${TEST_TYPE}/"
mkdir ${TEST_RESULTS_DIR}
BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}") BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
mkdir test_results/${BRANCH_NAME} mkdir "${TEST_RESULTS_DIR}/${BRANCH_NAME}"
mkdir test_results/recent mkdir "${TEST_RESULTS_DIR}/recent"
cp report.html test_results/${BRANCH_NAME}/. cp report.html "${TEST_RESULTS_DIR}/${BRANCH_NAME}/."
cp report.html test_results/recent/. cp report.html "${TEST_RESULTS_DIR}/recent/."
if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
cp -r report.html test_results/. cp -r report.html ${TEST_RESULTS_DIR}/.
fi fi
# exit 0 if no tests implemented # exit 0 if no tests implemented
RUN_NO_TESTS="$(grep -c 'no tests ran' test_results.out)" RUN_NO_TESTS="$(grep -c 'no tests ran' test_results.out)"
if [[ ${RUN_NO_TESTS} > 0 ]]; then if [[ ${RUN_NO_TESTS} -gt 0 ]]; then
echo "no test available" echo "no test available"
echo "incomplete" > status.txt echo "incomplete" > status.txt
echo "no tests avail" > incomplete.txt echo "no tests avail" > incomplete.txt
...@@ -27,20 +33,19 @@ fi ...@@ -27,20 +33,19 @@ fi
# extract if tests passed or not # extract if tests passed or not
TEST_FAILED="$(grep -oP '(\d+\s{1}failed)' test_results.out)" TEST_FAILED="$(grep -oP '(\d+\s{1}failed)' test_results.out)"
TEST_FAILED="$(echo ${TEST_FAILED} | (grep -oP '\d*'))" TEST_FAILED="$(echo "${TEST_FAILED}" | (grep -oP '\d*'))"
TEST_PASSED="$(grep -oP '\d+\s{1}passed' test_results.out)" TEST_PASSED="$(grep -oP '\d+\s{1}passed' test_results.out)"
TEST_PASSED="$(echo ${TEST_PASSED} | (grep -oP '\d*'))" TEST_PASSED="$(echo "${TEST_PASSED}" | (grep -oP '\d*'))"
if [[ -z "$TEST_FAILED" ]]; then if [[ -z "$TEST_FAILED" ]]; then
TEST_FAILED=0 TEST_FAILED=0
fi fi
let "TEST_PASSED=${TEST_PASSED}-${TEST_FAILED}" (( TEST_PASSED=TEST_PASSED-TEST_FAILED ))
# calculate metrics # calculate metrics
let "SUM=${TEST_FAILED}+${TEST_PASSED}" (( SUM=TEST_FAILED+TEST_PASSED ))
let "TEST_PASSED_RATIO=${TEST_PASSED}*100/${SUM}" (( TEST_PASSED_RATIO=TEST_PASSED*100/SUM ))
# report # report
if [[ ${IS_FAILED} == 0 ]]; then if [[ ${IS_FAILED} -eq 0 ]]; then
if [[ ${TEST_PASSED_RATIO} -lt 100 ]]; then if [[ ${TEST_PASSED_RATIO} -lt 100 ]]; then
echo "only ${TEST_PASSED_RATIO}% passed" echo "only ${TEST_PASSED_RATIO}% passed"
echo "incomplete" > status.txt echo "incomplete" > status.txt
......
#!/usr/bin/env bash #!/usr/bin/env bash
# run coverage twice, 1) for html deploy 2) for success evaluation # run coverage twice, 1) for html deploy 2) for success evaluation
python3 -m pytest --cov=src --cov-report html test/ python3.6 -m pytest --cov=src --cov-report term --cov-report html test/ | tee coverage_results.out
python3 -m pytest --cov=src --cov-report term test/ | tee coverage_results.out
IS_FAILED=$? IS_FAILED=$?
# move html coverage report # move html coverage report
mkdir coverage/ mkdir coverage/
BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}") BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
mkdir coverage/${BRANCH_NAME} mkdir "coverage/${BRANCH_NAME}"
mkdir coverage/recent mkdir "coverage/recent"
cp -r htmlcov/* coverage/${BRANCH_NAME}/. cp -r htmlcov/* "coverage/${BRANCH_NAME}/."
cp -r htmlcov/* coverage/recent/. cp -r htmlcov/* coverage/recent/.
if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
cp -r htmlcov/* coverage/. cp -r htmlcov/* coverage/.
...@@ -19,10 +18,10 @@ fi ...@@ -19,10 +18,10 @@ fi
# extract coverage information # extract coverage information
COVERAGE_RATIO="$(grep -oP '\d+\%' coverage_results.out | tail -1)" COVERAGE_RATIO="$(grep -oP '\d+\%' coverage_results.out | tail -1)"
COVERAGE_RATIO="$(echo ${COVERAGE_RATIO} | (grep -oP '\d*'))" COVERAGE_RATIO="$(echo "${COVERAGE_RATIO}" | (grep -oP '\d*'))"
# report # report
if [[ ${IS_FAILED} == 0 ]]; then if [[ ${IS_FAILED} -eq 0 ]]; then
if [[ ${COVERAGE_RATIO} -lt ${COVERAGE_PASS_THRESHOLD} ]]; then if [[ ${COVERAGE_RATIO} -lt ${COVERAGE_PASS_THRESHOLD} ]]; then
echo "only ${COVERAGE_RATIO}% covered" echo "only ${COVERAGE_RATIO}% covered"
echo "incomplete" > status.txt echo "incomplete" > status.txt
......
["DENW094", "DEBW029", "DENI052", "DENI063", "DEBY109", "DEUB022", "DESN001", "DEUB013", "DETH016", "DEBY002", "DEBY005", "DEBY099", "DEUB038", "DEBE051", "DEBE056", "DEBE062", "DEBE032", "DEBE034", "DEBE010", "DEHE046", "DEST031", "DEBY122", "DERP022", "DEBY079", "DEBW102", "DEBW076", "DEBW045", "DESH016", "DESN004", "DEHE032", "DEBB050", "DEBW042", "DEBW046", "DENW067", "DESL019", "DEST014", "DENW062", "DEHE033", "DENW081", "DESH008", "DEBB055", "DENI011", "DEHB001", "DEHB004", "DEHB002", "DEHB003", "DEHB005", "DEST039", "DEUB003", "DEBW072", "DEST002", "DEBB001", "DEHE039", "DEBW035", "DESN005", "DEBW047", "DENW004", "DESN011", "DESN076", "DEBB064", "DEBB006", "DEHE001", "DESN012", "DEST030", "DESL003", "DEST104", "DENW050", "DENW008", "DETH026", "DESN085", "DESN014", "DESN092", "DENW071", "DEBW004", "DENI028", "DETH013", "DENI059", "DEBB007", "DEBW049", "DENI043", "DETH020", "DEBY017", "DEBY113", "DENW247", "DENW028", "DEBW025", "DEUB039", "DEBB009", "DEHE027", "DEBB042", "DEHE008", "DESN017", "DEBW084", "DEBW037", "DEHE058", "DEHE028", "DEBW112", "DEBY081", "DEBY082", "DEST032", "DETH009", "DEHE010", "DESN019", "DEHE023", "DETH036", "DETH040", "DEMV017", "DEBW028", "DENI042", "DEMV004", "DEMV019", "DEST044", "DEST050", "DEST072", "DEST022", "DEHH049", "DEHH047", "DEHH033", "DEHH050", "DEHH008", "DEHH021", "DENI054", "DEST070", "DEBB053", "DENW029", "DEBW050", "DEUB034", "DENW018", "DEST052", "DEBY020", "DENW063", "DESN050", "DETH061", "DERP014", "DETH024", "DEBW094", "DENI031", "DETH041", "DERP019", "DEBW081", "DEHE013", "DEBW021", "DEHE060", "DEBY031", "DESH021", "DESH033", "DEHE052", "DEBY004", "DESN024", "DEBW052", "DENW042", "DEBY032", "DENW053", "DENW059", "DEBB082", "DEBB031", "DEHE025", "DEBW053", "DEHE048", "DENW051", "DEBY034", "DEUB035", "DEUB032", "DESN028", "DESN059", "DEMV024", "DENW079", "DEHE044", "DEHE042", "DEBB043", "DEBB036", "DEBW024", "DERP001", "DEMV012", "DESH005", "DESH023", "DEUB031", "DENI062", "DENW006", "DEBB065", "DEST077", 
"DEST005", "DERP007", "DEBW006", "DEBW007", "DEHE030", "DENW015", "DEBY013", "DETH025", "DEUB033", "DEST025", "DEHE045", "DESN057", "DENW036", "DEBW044", "DEUB036", "DENW096", "DETH095", "DENW038", "DEBY089", "DEBY039", "DENW095", "DEBY047", "DEBB067", "DEBB040", "DEST078", "DENW065", "DENW066", "DEBY052", "DEUB030", "DETH027", "DEBB048", "DENW047", "DEBY049", "DERP021", "DEHE034", "DESN079", "DESL008", "DETH018", "DEBW103", "DEHE017", "DEBW111", "DENI016", "DENI038", "DENI058", "DENI029", "DEBY118", "DEBW032", "DEBW110", "DERP017", "DESN036", "DEBW026", "DETH042", "DEBB075", "DEBB052", "DEBB021", "DEBB038", "DESN051", "DEUB041", "DEBW020", "DEBW113", "DENW078", "DEHE018", "DEBW065", "DEBY062", "DEBW027", "DEBW041", "DEHE043", "DEMV007", "DEMV021", "DEBW054", "DETH005", "DESL012", "DESL011", "DEST069", "DEST071", "DEUB004", "DESH006", "DEUB029", "DEUB040", "DESN074", "DEBW031", "DENW013", "DENW179", "DEBW056", "DEBW087", "DEST061", "DEMV001", "DEBB024", "DEBW057", "DENW064", "DENW068", "DENW080", "DENI019", "DENI077", "DEHE026", "DEBB066", "DEBB083", "DEST063", "DEBW013", "DETH086", "DESL018", "DETH096", "DEBW059", "DEBY072", "DEBY088", "DEBW060", "DEBW107", "DEBW036", "DEUB026", "DEBW019", "DENW010", "DEST098", "DEHE019", "DEBW039", "DESL017", "DEBW034", "DEUB005", "DEBB051", "DEHE051", "DEBW023", "DEBY092", "DEBW008", "DEBW030", "DENI060", "DEST011", "DENW030", "DENI041", "DERP015", "DEUB001", "DERP016", "DERP028", "DERP013", "DEHE022", "DEUB021", "DEBW010", "DEST066", "DEBB063", "DEBB028", "DEHE024", "DENI020", "DENI051", "DERP025", "DEBY077", "DEMV018", "DEST089", "DEST028", "DETH060", "DEHE050", "DEUB028", "DESN045", "DEUB042"]
...@@ -43,17 +43,20 @@ pytest-cov==2.8.1 ...@@ -43,17 +43,20 @@ pytest-cov==2.8.1
pytest-html==2.0.1 pytest-html==2.0.1
pytest-lazy-fixture==0.6.3 pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0 pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1 python-dateutil==2.8.1
pytz==2019.3 pytz==2019.3
PyYAML==5.3 PyYAML==5.3
requests==2.23.0 requests==2.23.0
scipy==1.4.1 scipy==1.4.1
seaborn==0.10.0 seaborn==0.10.0
Shapely==1.7.0 --no-binary shapely Shapely==1.7.0
six==1.11.0 six==1.11.0
statsmodels==0.11.1 statsmodels==0.11.1
tensorboard==1.12.2 tabulate
tensorflow==1.12.0 tensorboard==1.13.1
tensorflow-estimator==1.13.0
tensorflow==1.13.1
termcolor==1.1.0 termcolor==1.1.0
toolz==0.10.0 toolz==0.10.0
urllib3==1.25.8 urllib3==1.25.8
......
absl-py==0.9.0
astor==0.8.1
atomicwrites==1.3.0
attrs==19.3.0
Cartopy==0.17.0
certifi==2019.11.28
chardet==3.0.4
cloudpickle==1.3.0
coverage==5.0.3
cycler==0.10.0
Cython==0.29.15
dask==2.11.0
fsspec==0.6.2
gast==0.3.3
grpcio==1.27.2
h5py==2.10.0
idna==2.8
importlib-metadata==1.5.0
Keras==2.2.4
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
locket==0.2.0
Markdown==3.2.1
matplotlib==3.2.0
mock==4.0.1
more-itertools==8.2.0
numpy==1.18.1
packaging==20.3
pandas==1.0.1
partd==1.1.0
patsy==0.5.1
Pillow==7.0.0
pluggy==0.13.1
protobuf==3.11.3
py==1.8.1
pydot==1.4.1
pyparsing==2.4.6
pyproj==2.5.0
pyshp==2.1.0
pytest==5.3.5
pytest-cov==2.8.1
pytest-html==2.0.1
pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1
pytz==2019.3
PyYAML==5.3
requests==2.23.0
scipy==1.4.1
seaborn==0.10.0
--no-binary shapely Shapely==1.7.0
six==1.11.0
statsmodels==0.11.1
tabulate
tensorboard==1.13.1
tensorflow-estimator==1.13.0
tensorflow-gpu==1.13.1
termcolor==1.1.0
toolz==0.10.0
urllib3==1.25.8
wcwidth==0.1.8
Werkzeug==1.0.0
xarray==0.15.0
zipp==3.1.0
...@@ -3,7 +3,6 @@ __date__ = '2019-11-14' ...@@ -3,7 +3,6 @@ __date__ = '2019-11-14'
import argparse import argparse
import logging
from src.run_modules.experiment_setup import ExperimentSetup from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.model_setup import ModelSetup from src.run_modules.model_setup import ModelSetup
...@@ -17,7 +16,8 @@ def main(parser_args): ...@@ -17,7 +16,8 @@ def main(parser_args):
with RunEnvironment(): with RunEnvironment():
ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
station_type='background', trainable=False, create_new_model=False) station_type='background', trainable=False, create_new_model=False, window_history_size=6,
create_new_bootstraps=True)
PreProcessing() PreProcessing()
ModelSetup() ModelSetup()
...@@ -29,10 +29,6 @@ def main(parser_args): ...@@ -29,10 +29,6 @@ def main(parser_args):
if __name__ == "__main__": if __name__ == "__main__":
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.INFO)
# logging.basicConfig(format=formatter, level=logging.DEBUG)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string") help="set experiment date as string")
......
...@@ -29,10 +29,6 @@ def main(parser_args): ...@@ -29,10 +29,6 @@ def main(parser_args):
if __name__ == "__main__": if __name__ == "__main__":
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.INFO)
# logging.basicConfig(format=formatter, level=logging.DEBUG)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string") help="set experiment date as string")
......
__author__ = "Lukas Leufen"
__date__ = '2019-11-14'
import argparse
import json
import logging
from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.model_setup import ModelSetup
from src.run_modules.post_processing import PostProcessing
from src.run_modules.pre_processing import PreProcessing
from src.run_modules.run_environment import RunEnvironment
from src.run_modules.training import Training
def load_stations():
    """Load the list of station IDs used for the experiment.

    Tries to read the IDs from 'German_background_stations.json' in the
    current working directory. If that file is not present, falls back to a
    small hard-coded default list so a local run can still proceed.

    :return: list of station ID strings
    """
    filename = 'German_background_stations.json'
    try:
        with open(filename, 'r') as jfile:
            stations = json.load(jfile)
        # fixed: log message previously contained a broken placeholder instead of the filename
        logging.info(f"load station IDs from file: {filename} ({len(stations)} stations)")
    except FileNotFoundError:
        # fallback for environments without the full background-station file
        stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']
        logging.info(f"Use stations from list: {stations}")
    return stations
def main(parser_args):
    """Run the complete experiment pipeline inside a managed run environment.

    Executes, in order: experiment setup (with the station list from
    load_stations()), preprocessing, model setup, training, and
    post-processing. All stages run inside the RunEnvironment context
    manager, which presumably handles shared state and teardown — the
    implementation lives in src.run_modules and is not visible here.

    :param parser_args: parsed command line arguments (argparse.Namespace),
        forwarded to ExperimentSetup
    """
    with RunEnvironment():
        # create_new_model=True forces model creation rather than reuse of a stored one
        ExperimentSetup(parser_args, stations=load_stations(), station_type='background', trainable=False,
                        create_new_model=True)
        PreProcessing()
        ModelSetup()
        Training()
        PostProcessing()
if __name__ == "__main__":
    # Command line interface: only the experiment date is configurable.
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
                        help="set experiment date as string")
    # NOTE(review): arguments are hard-coded to "testrun" instead of reading
    # sys.argv — confirm this is intentional (looks like a fixed test/demo run).
    args = parser.parse_args(["--experiment_date", "testrun"])
    main(args)
...@@ -2,152 +2,108 @@ __author__ = 'Felix Kleinert, Lukas Leufen' ...@@ -2,152 +2,108 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2020-02-07' __date__ = '2020-02-07'
from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
import numpy as np import numpy as np
import logging import logging
import keras
import dask.array as da import dask.array as da
import xarray as xr import xarray as xr
import os import os
import re import re
from src import helpers from src import helpers
from typing import List, Union, Pattern, Tuple
class BootStrapGenerator: class BootStrapGenerator(keras.utils.Sequence):
"""
generator for bootstraps as keras sequence inheritance. Initialise with number of boots, the original history, the
shuffled data, all used variables and the current shuffled variable. While iterating over this generator, it returns
the bootstrapped history for given boot index (this is the iterator index) in the same format like the original
history ready to use. Note, that in some cases some samples can contain nan values (in these cases the entire data
row is null, not only single entries).
"""
def __init__(self, number_of_boots: int, history: xr.DataArray, shuffled: xr.DataArray, variables: List[str],
shuffled_variable: str):
self.number_of_boots = number_of_boots
self.variables = variables
self.history_orig = history
self.history = history.sel(variables=helpers.list_pop(self.variables, shuffled_variable))
self.shuffled = shuffled.sel(variables=shuffled_variable)
def __init__(self, orig_generator, boots, chunksize, bootstrap_path): def __len__(self) -> int:
self.orig_generator: DataGenerator = orig_generator return self.number_of_boots
self.stations = self.orig_generator.stations
self.variables = self.orig_generator.variables
self.boots = boots
self.chunksize = chunksize
self.bootstrap_path = bootstrap_path
self._iterator = 0
self.bootstrap_meta = []
def __len__(self):
"""
display the number of stations
"""
return len(self.orig_generator)*self.boots*len(self.variables)
def get_labels(self, key):
_, label = self.orig_generator[key]
for _ in range(self.boots):
yield label
def get_generator(self):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
while True:
for i, data in enumerate(self.orig_generator):
station = self.orig_generator.get_station_key(i)
logging.info(f"station: {station}")
hist, label = data
len_of_label = len(label)
shuffled_data = self.load_boot_data(station)
for var in self.variables:
logging.info(f" var: {var}")
for boot in range(self.boots):
logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist = boot_hist.sortby("variables")
self.bootstrap_meta.extend([[var, station]]*len_of_label)
yield boot_hist, label
return
def get_orig_prediction(self, path, file_name, prediction_name="CNN"):
file = os.path.join(path, file_name)
data = xr.open_dataarray(file)
for _ in range(self.boots):
yield data.sel(type=prediction_name).squeeze()
def load_boot_data(self, station):
files = os.listdir(self.bootstrap_path)
regex = re.compile(rf"{station}_\w*\.nc")
file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
shuffled_data = xr.open_dataarray(file_name, chunks=100)
return shuffled_data
def __getitem__(self, index: int) -> xr.DataArray:
"""
return bootstrapped history for given bootstrap index in same index structure like the original history object
:param index: boot index e [0, nboots-1]
:return: bootstrapped history ready to use
"""
logging.debug(f"boot: {index}")
boot_hist = self.history.copy()
boot_hist = boot_hist.combine_first(self.__get_shuffled(index))
return boot_hist.reindex_like(self.history_orig)
class BootStraps(RunEnvironment): def __get_shuffled(self, index: int) -> xr.DataArray:
"""
returns shuffled data for given boot index from shuffled attribute
:param index: boot index e [0, nboots-1]
:return: shuffled data
"""
shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots")
return shuffled_var.transpose("datetime", "window", "Stations", "variables")