Commit a19c4fe6 authored by lukas leufen

Merge branch 'develop' into release_v0.9.0

parents 53546303 7553788c
Pipeline #34435 passed with stages in 6 minutes and 18 seconds
@@ -45,9 +45,10 @@ Thumbs.db
 /data/
 /plots/
 
-# tmp folder #
-##############
+# tmp and logging folder #
+##########################
 /tmp/
+/logging/
 
 # test related data #
 #####################
...
@@ -23,13 +23,71 @@ version:
     paths:
       - badges/
 
+### Tests (from scratch) ###
+tests (from scratch):
+  tags:
+    - base
+    - zam347
+  stage: test
+  only:
+    - master
+    - /^release.*$/
+    - develop
+  variables:
+    FAILURE_THRESHOLD: 100
+    TEST_TYPE: "scratch"
+  before_script:
+    - chmod +x ./CI/update_badge.sh
+    - ./CI/update_badge.sh > /dev/null
+  script:
+    - zypper --non-interactive install binutils libproj-devel gdal-devel
+    - zypper --non-interactive install proj geos-devel
+    - pip install -r requirements.txt
+    - chmod +x ./CI/run_pytest.sh
+    - ./CI/run_pytest.sh
+  after_script:
+    - ./CI/update_badge.sh > /dev/null
+  artifacts:
+    name: pages
+    when: always
+    paths:
+      - badges/
+      - test_results/
+
+### Tests (on GPU) ###
+tests (on GPU):
+  tags:
+    - gpu
+    - zam347
+  stage: test
+  only:
+    - master
+    - /^release.*$/
+    - develop
+  variables:
+    FAILURE_THRESHOLD: 100
+    TEST_TYPE: "gpu"
+  before_script:
+    - chmod +x ./CI/update_badge.sh
+    - ./CI/update_badge.sh > /dev/null
+  script:
+    - pip install -r requirements.txt
+    - chmod +x ./CI/run_pytest.sh
+    - ./CI/run_pytest.sh
+  after_script:
+    - ./CI/update_badge.sh > /dev/null
+  artifacts:
+    name: pages
+    when: always
+    paths:
+      - badges/
+      - test_results/
+
 ### Tests ###
 tests:
   tags:
-    - leap
+    - machinelearningtools
     - zam347
-    - base
-    - django
   stage: test
   variables:
     FAILURE_THRESHOLD: 100
@@ -51,10 +109,8 @@ tests:
 
 coverage:
   tags:
-    - leap
+    - machinelearningtools
     - zam347
-    - base
-    - django
   stage: test
   variables:
     FAILURE_THRESHOLD: 50
@@ -79,7 +135,6 @@ coverage:
 pages:
   stage: pages
   tags:
-    - leap
     - zam347
     - base
   script:
...
 #!/bin/bash
 # run pytest for all run_modules
-python3 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out
+python3.6 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out
 IS_FAILED=$?
 
+if [ -z ${TEST_TYPE+x} ]; then
+    TEST_TYPE=""; else
+    TEST_TYPE="_"${TEST_TYPE};
+fi
+
 # move html test report
-mkdir test_results/
+TEST_RESULTS_DIR="test_results${TEST_TYPE}/"
+mkdir ${TEST_RESULTS_DIR}
 BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
-mkdir test_results/${BRANCH_NAME}
-mkdir test_results/recent
-cp report.html test_results/${BRANCH_NAME}/.
-cp report.html test_results/recent/.
+mkdir "${TEST_RESULTS_DIR}/${BRANCH_NAME}"
+mkdir "${TEST_RESULTS_DIR}/recent"
+cp report.html "${TEST_RESULTS_DIR}/${BRANCH_NAME}/."
+cp report.html "${TEST_RESULTS_DIR}/recent/."
 if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
-    cp -r report.html test_results/.
+    cp -r report.html ${TEST_RESULTS_DIR}/.
 fi
 # exit 0 if no tests implemented
 RUN_NO_TESTS="$(grep -c 'no tests ran' test_results.out)"
-if [[ ${RUN_NO_TESTS} > 0 ]]; then
+if [[ ${RUN_NO_TESTS} -gt 0 ]]; then
     echo "no test available"
     echo "incomplete" > status.txt
     echo "no tests avail" > incomplete.txt
@@ -27,20 +33,19 @@ fi
 # extract if tests passed or not
 TEST_FAILED="$(grep -oP '(\d+\s{1}failed)' test_results.out)"
-TEST_FAILED="$(echo ${TEST_FAILED} | (grep -oP '\d*'))"
+TEST_FAILED="$(echo "${TEST_FAILED}" | (grep -oP '\d*'))"
 TEST_PASSED="$(grep -oP '\d+\s{1}passed' test_results.out)"
-TEST_PASSED="$(echo ${TEST_PASSED} | (grep -oP '\d*'))"
+TEST_PASSED="$(echo "${TEST_PASSED}" | (grep -oP '\d*'))"
 if [[ -z "$TEST_FAILED" ]]; then
     TEST_FAILED=0
 fi
-let "TEST_PASSED=${TEST_PASSED}-${TEST_FAILED}"
+(( TEST_PASSED=TEST_PASSED-TEST_FAILED ))
 # calculate metrics
-let "SUM=${TEST_FAILED}+${TEST_PASSED}"
-let "TEST_PASSED_RATIO=${TEST_PASSED}*100/${SUM}"
+(( SUM=TEST_FAILED+TEST_PASSED ))
+(( TEST_PASSED_RATIO=TEST_PASSED*100/SUM ))
 # report
-if [[ ${IS_FAILED} == 0 ]]; then
+if [[ ${IS_FAILED} -eq 0 ]]; then
     if [[ ${TEST_PASSED_RATIO} -lt 100 ]]; then
         echo "only ${TEST_PASSED_RATIO}% passed"
         echo "incomplete" > status.txt
...
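The report section of this script boils down to integer arithmetic on pytest's summary line. A minimal Python mirror of that logic, for illustration only (not part of the commit; the example summary string is made up):

import re

def passed_ratio(summary: str) -> int:
    # mirrors grep -oP '(\d+\s{1}failed)' and '\d+\s{1}passed'
    failed = re.search(r"(\d+)\s+failed", summary)
    passed = re.search(r"(\d+)\s+passed", summary)
    n_failed = int(failed.group(1)) if failed else 0   # mirrors the -z "$TEST_FAILED" fallback
    n_passed = int(passed.group(1)) if passed else 0
    n_passed -= n_failed                               # (( TEST_PASSED=TEST_PASSED-TEST_FAILED ))
    total = n_failed + n_passed                        # (( SUM=TEST_FAILED+TEST_PASSED ))
    return n_passed * 100 // total                     # bash (( )) uses integer division

print(passed_ratio("3 failed, 97 passed in 60.00s"))   # -> 96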
 #!/usr/bin/env bash
 # run coverage twice, 1) for html deploy 2) for success evaluation
-python3 -m pytest --cov=src --cov-report html test/
-python3 -m pytest --cov=src --cov-report term test/ | tee coverage_results.out
+python3.6 -m pytest --cov=src --cov-report term --cov-report html test/ | tee coverage_results.out
 IS_FAILED=$?
 # move html coverage report
 mkdir coverage/
 BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
-mkdir coverage/${BRANCH_NAME}
-mkdir coverage/recent
-cp -r htmlcov/* coverage/${BRANCH_NAME}/.
+mkdir "coverage/${BRANCH_NAME}"
+mkdir "coverage/recent"
+cp -r htmlcov/* "coverage/${BRANCH_NAME}/."
 cp -r htmlcov/* coverage/recent/.
 if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
     cp -r htmlcov/* coverage/.
@@ -19,10 +18,10 @@ fi
 # extract coverage information
 COVERAGE_RATIO="$(grep -oP '\d+\%' coverage_results.out | tail -1)"
-COVERAGE_RATIO="$(echo ${COVERAGE_RATIO} | (grep -oP '\d*'))"
+COVERAGE_RATIO="$(echo "${COVERAGE_RATIO}" | (grep -oP '\d*'))"
 # report
-if [[ ${IS_FAILED} == 0 ]]; then
+if [[ ${IS_FAILED} -eq 0 ]]; then
     if [[ ${COVERAGE_RATIO} -lt ${COVERAGE_PASS_THRESHOLD} ]]; then
         echo "only ${COVERAGE_RATIO}% covered"
         echo "incomplete" > status.txt
...
["DENW094", "DEBW029", "DENI052", "DENI063", "DEBY109", "DEUB022", "DESN001", "DEUB013", "DETH016", "DEBY002", "DEBY005", "DEBY099", "DEUB038", "DEBE051", "DEBE056", "DEBE062", "DEBE032", "DEBE034", "DEBE010", "DEHE046", "DEST031", "DEBY122", "DERP022", "DEBY079", "DEBW102", "DEBW076", "DEBW045", "DESH016", "DESN004", "DEHE032", "DEBB050", "DEBW042", "DEBW046", "DENW067", "DESL019", "DEST014", "DENW062", "DEHE033", "DENW081", "DESH008", "DEBB055", "DENI011", "DEHB001", "DEHB004", "DEHB002", "DEHB003", "DEHB005", "DEST039", "DEUB003", "DEBW072", "DEST002", "DEBB001", "DEHE039", "DEBW035", "DESN005", "DEBW047", "DENW004", "DESN011", "DESN076", "DEBB064", "DEBB006", "DEHE001", "DESN012", "DEST030", "DESL003", "DEST104", "DENW050", "DENW008", "DETH026", "DESN085", "DESN014", "DESN092", "DENW071", "DEBW004", "DENI028", "DETH013", "DENI059", "DEBB007", "DEBW049", "DENI043", "DETH020", "DEBY017", "DEBY113", "DENW247", "DENW028", "DEBW025", "DEUB039", "DEBB009", "DEHE027", "DEBB042", "DEHE008", "DESN017", "DEBW084", "DEBW037", "DEHE058", "DEHE028", "DEBW112", "DEBY081", "DEBY082", "DEST032", "DETH009", "DEHE010", "DESN019", "DEHE023", "DETH036", "DETH040", "DEMV017", "DEBW028", "DENI042", "DEMV004", "DEMV019", "DEST044", "DEST050", "DEST072", "DEST022", "DEHH049", "DEHH047", "DEHH033", "DEHH050", "DEHH008", "DEHH021", "DENI054", "DEST070", "DEBB053", "DENW029", "DEBW050", "DEUB034", "DENW018", "DEST052", "DEBY020", "DENW063", "DESN050", "DETH061", "DERP014", "DETH024", "DEBW094", "DENI031", "DETH041", "DERP019", "DEBW081", "DEHE013", "DEBW021", "DEHE060", "DEBY031", "DESH021", "DESH033", "DEHE052", "DEBY004", "DESN024", "DEBW052", "DENW042", "DEBY032", "DENW053", "DENW059", "DEBB082", "DEBB031", "DEHE025", "DEBW053", "DEHE048", "DENW051", "DEBY034", "DEUB035", "DEUB032", "DESN028", "DESN059", "DEMV024", "DENW079", "DEHE044", "DEHE042", "DEBB043", "DEBB036", "DEBW024", "DERP001", "DEMV012", "DESH005", "DESH023", "DEUB031", "DENI062", "DENW006", "DEBB065", "DEST077", "DEST005", "DERP007", "DEBW006", "DEBW007", "DEHE030", "DENW015", "DEBY013", "DETH025", "DEUB033", "DEST025", "DEHE045", "DESN057", "DENW036", "DEBW044", "DEUB036", "DENW096", "DETH095", "DENW038", "DEBY089", "DEBY039", "DENW095", "DEBY047", "DEBB067", "DEBB040", "DEST078", "DENW065", "DENW066", "DEBY052", "DEUB030", "DETH027", "DEBB048", "DENW047", "DEBY049", "DERP021", "DEHE034", "DESN079", "DESL008", "DETH018", "DEBW103", "DEHE017", "DEBW111", "DENI016", "DENI038", "DENI058", "DENI029", "DEBY118", "DEBW032", "DEBW110", "DERP017", "DESN036", "DEBW026", "DETH042", "DEBB075", "DEBB052", "DEBB021", "DEBB038", "DESN051", "DEUB041", "DEBW020", "DEBW113", "DENW078", "DEHE018", "DEBW065", "DEBY062", "DEBW027", "DEBW041", "DEHE043", "DEMV007", "DEMV021", "DEBW054", "DETH005", "DESL012", "DESL011", "DEST069", "DEST071", "DEUB004", "DESH006", "DEUB029", "DEUB040", "DESN074", "DEBW031", "DENW013", "DENW179", "DEBW056", "DEBW087", "DEST061", "DEMV001", "DEBB024", "DEBW057", "DENW064", "DENW068", "DENW080", "DENI019", "DENI077", "DEHE026", "DEBB066", "DEBB083", "DEST063", "DEBW013", "DETH086", "DESL018", "DETH096", "DEBW059", "DEBY072", "DEBY088", "DEBW060", "DEBW107", "DEBW036", "DEUB026", "DEBW019", "DENW010", "DEST098", "DEHE019", "DEBW039", "DESL017", "DEBW034", "DEUB005", "DEBB051", "DEHE051", "DEBW023", "DEBY092", "DEBW008", "DEBW030", "DENI060", "DEST011", "DENW030", "DENI041", "DERP015", "DEUB001", "DERP016", "DERP028", "DERP013", "DEHE022", "DEUB021", "DEBW010", "DEST066", "DEBB063", "DEBB028", "DEHE024", "DENI020", "DENI051", "DERP025", 
"DEBY077", "DEMV018", "DEST089", "DEST028", "DETH060", "DEHE050", "DEUB028", "DESN045", "DEUB042"]
@@ -43,17 +43,20 @@ pytest-cov==2.8.1
 pytest-html==2.0.1
 pytest-lazy-fixture==0.6.3
 pytest-metadata==1.8.0
+pytest-sugar
 python-dateutil==2.8.1
 pytz==2019.3
 PyYAML==5.3
 requests==2.23.0
 scipy==1.4.1
 seaborn==0.10.0
-Shapely==1.7.0
+--no-binary shapely Shapely==1.7.0
 six==1.11.0
 statsmodels==0.11.1
-tensorboard==1.12.2
-tensorflow==1.12.0
+tabulate
+tensorboard==1.13.1
+tensorflow-estimator==1.13.0
+tensorflow==1.13.1
 termcolor==1.1.0
 toolz==0.10.0
 urllib3==1.25.8
...
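The TensorFlow pins move in lockstep here (tensorflow and tensorboard to 1.13.1, plus the matching tensorflow-estimator 1.13.0), and Shapely is now installed with --no-binary, so pip builds it from source and it can link against the GEOS/proj libraries the new CI job installs via zypper. A small sanity-check sketch for the resolved environment (illustrative, not part of the commit):

import pkg_resources

expected = {"tensorflow": "1.13.1", "tensorboard": "1.13.1", "tensorflow-estimator": "1.13.0"}
for pkg, version in expected.items():
    installed = pkg_resources.get_distribution(pkg).version
    print(pkg, installed, "OK" if installed == version else "MISMATCH")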
absl-py==0.9.0
astor==0.8.1
atomicwrites==1.3.0
attrs==19.3.0
Cartopy==0.17.0
certifi==2019.11.28
chardet==3.0.4
cloudpickle==1.3.0
coverage==5.0.3
cycler==0.10.0
Cython==0.29.15
dask==2.11.0
fsspec==0.6.2
gast==0.3.3
grpcio==1.27.2
h5py==2.10.0
idna==2.8
importlib-metadata==1.5.0
Keras==2.2.4
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
locket==0.2.0
Markdown==3.2.1
matplotlib==3.2.0
mock==4.0.1
more-itertools==8.2.0
numpy==1.18.1
packaging==20.3
pandas==1.0.1
partd==1.1.0
patsy==0.5.1
Pillow==7.0.0
pluggy==0.13.1
protobuf==3.11.3
py==1.8.1
pydot==1.4.1
pyparsing==2.4.6
pyproj==2.5.0
pyshp==2.1.0
pytest==5.3.5
pytest-cov==2.8.1
pytest-html==2.0.1
pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1
pytz==2019.3
PyYAML==5.3
requests==2.23.0
scipy==1.4.1
seaborn==0.10.0
--no-binary shapely Shapely==1.7.0
six==1.11.0
statsmodels==0.11.1
tabulate
tensorboard==1.13.1
tensorflow-estimator==1.13.0
tensorflow-gpu==1.13.1
termcolor==1.1.0
toolz==0.10.0
urllib3==1.25.8
wcwidth==0.1.8
Werkzeug==1.0.0
xarray==0.15.0
zipp==3.1.0
@@ -3,7 +3,6 @@ __date__ = '2019-11-14'
 
 import argparse
-import logging
 
 from src.run_modules.experiment_setup import ExperimentSetup
 from src.run_modules.model_setup import ModelSetup
@@ -17,7 +16,8 @@ def main(parser_args):
 
     with RunEnvironment():
         ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
-                        station_type='background', trainable=False, create_new_model=False)
+                        station_type='background', trainable=False, create_new_model=False, window_history_size=6,
+                        create_new_bootstraps=True)
         PreProcessing()
         ModelSetup()
@@ -29,10 +29,6 @@ def main(parser_args):
 
 if __name__ == "__main__":
-    formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    # logging.basicConfig(format=formatter, level=logging.DEBUG)
-
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
                         help="set experiment date as string")
...
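The run script now passes window_history_size=6 explicitly (DataGenerator's default, visible further below, is 7). Assuming the history window covers the current step plus the six preceding ones, each input sample is a length-7 slice of the series; a toy sketch of such a window (the helper name is illustrative, not from the commit):

import numpy as np

def make_history(series: np.ndarray, window_history_size: int) -> np.ndarray:
    # stack each time step with its `window_history_size` predecessors
    n = len(series) - window_history_size
    return np.stack([series[i:i + window_history_size + 1] for i in range(n)])

x = make_history(np.arange(10.0), window_history_size=6)
print(x.shape)  # (4, 7): each sample holds t-6 ... t0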
@@ -29,10 +29,6 @@ def main(parser_args):
 
 if __name__ == "__main__":
-    formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    # logging.basicConfig(format=formatter, level=logging.DEBUG)
-
     parser = argparse.ArgumentParser()
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
                         help="set experiment date as string")
...
__author__ = "Lukas Leufen"
__date__ = '2019-11-14'
import argparse
import json
import logging
from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.model_setup import ModelSetup
from src.run_modules.post_processing import PostProcessing
from src.run_modules.pre_processing import PreProcessing
from src.run_modules.run_environment import RunEnvironment
from src.run_modules.training import Training
def load_stations():
try:
filename = 'German_background_stations.json'
with open(filename, 'r') as jfile:
stations = json.load(jfile)
logging.info(f"load station IDs from file: {filename} ({len(stations)} stations)")
# stations.remove('DEUB042')
except FileNotFoundError:
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']
# stations = ['DEBB050', 'DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']
logging.info(f"Use stations from list: {stations}")
return stations
def main(parser_args):
with RunEnvironment():
ExperimentSetup(parser_args, stations=load_stations(), station_type='background', trainable=False,
create_new_model=True)
PreProcessing()
ModelSetup()
Training()
PostProcessing()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string")
args = parser.parse_args(["--experiment_date", "testrun"])
main(args)
This diff is collapsed.
@@ -8,15 +8,18 @@ import math
 import keras
 import numpy as np
 
+from src.data_handling.data_generator import DataGenerator
+
 
 class Distributor(keras.utils.Sequence):
-    def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 permute_data: bool = False):
+    def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256,
+                 permute_data: bool = False, upsampling: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
         self.do_data_permutation = permute_data
+        self.upsampling = upsampling
 
     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -31,7 +34,7 @@ class Distributor(keras.utils.Sequence):
         return mod_rank
 
     def _get_number_of_mini_batches(self, values):
-        return math.ceil(values[0].shape[0] / self.batch_size)
+        return math.ceil(values.shape[0] / self.batch_size)
 
     def _permute_data(self, x, y):
         """
@@ -48,10 +51,18 @@ class Distributor(keras.utils.Sequence):
         for k, v in enumerate(self.generator):
             # get rank of output
             mod_rank = self._get_model_rank()
-            # get number of mini batches
-            num_mini_batches = self._get_number_of_mini_batches(v)
+            # get data
             x_total = np.copy(v[0])
             y_total = np.copy(v[1])
+            if self.upsampling:
+                try:
+                    s = self.generator.get_data_generator(k)
+                    x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0)
+                    y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0)
+                except AttributeError:  # no extremes history / labels available, copy will fail
+                    pass
+            # get number of mini batches
+            num_mini_batches = self._get_number_of_mini_batches(x_total)
             # permute order for mini-batches
             x_total, y_total = self._permute_data(x_total, y_total)
             for prev, curr in enumerate(range(1, num_mini_batches+1)):
...
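The new upsampling branch appends a station's precomputed extreme samples to the regular data before batching, which is why the mini-batch count is now derived from x_total rather than from the raw generator item v. A self-contained numpy sketch of that flow (all shapes are made up, and the extremes arrays stand in for get_extremes_history()/get_extremes_label()):

import math
import numpy as np

batch_size = 256
x_total = np.random.rand(500, 7, 1, 5)     # regular history windows
y_total = np.random.rand(500, 4)           # regular labels
x_extremes = np.random.rand(120, 7, 1, 5)  # stand-in for s.get_extremes_history()
y_extremes = np.random.rand(120, 4)        # stand-in for s.get_extremes_label()

# upsampling: concatenate the extreme samples onto the regular data
x_total = np.concatenate([x_total, x_extremes], axis=0)
y_total = np.concatenate([y_total, y_extremes], axis=0)

# the batch count must follow the enlarged array
num_mini_batches = math.ceil(x_total.shape[0] / batch_size)
print(num_mini_batches)  # ceil(620 / 256) = 3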
@@ -14,6 +14,9 @@ from src import helpers
 from src.data_handling.data_preparation import DataPrep
 from src.join import EmptyQueryResult
 
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
+
 
 class DataGenerator(keras.utils.Sequence):
     """
@@ -27,7 +30,7 @@ class DataGenerator(keras.utils.Sequence):
     def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transformation: Dict = None, **kwargs):
+                 window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs):
         self.data_path = os.path.abspath(data_path)
         self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
         if not os.path.exists(self.data_path_tmp):
@@ -43,6 +46,7 @@ class DataGenerator(keras.utils.Sequence):
         self.limit_nan_fill = limit_nan_fill
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
+        self.extreme_values = extreme_values
         self.kwargs = kwargs
         self.transformation = self.setup_transformation(transformation)
@@ -178,7 +182,7 @@ class DataGenerator(keras.utils.Sequence):
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
-            logging.info(f"load not pickle data for {station}")
+            logging.debug(f"load not pickle data for {station}")
             data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                             **self.kwargs)
             if self.transformation is not None:
@@ -188,6 +192,9 @@ class DataGenerator(keras.utils.Sequence):
         data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+        data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
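The num_or_list alias lets extreme_values be a single threshold or a list of thresholds. A tiny sketch of consuming such an argument (the normalising helper is hypothetical, not taken from the commit):

from typing import List, Union

number = Union[float, int]
num_or_list = Union[number, List[number]]

def to_list(extreme_values: num_or_list = None) -> List[number]:
    # hypothetical helper: accept one number or a list, always return a list
    if extreme_values is None:
        return []
    if isinstance(extreme_values, (int, float)):
        return [extreme_values]
    return list(extreme_values)

print(to_list(2))         # [2]
print(to_list([2, 3.5]))  # [2, 3.5]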