Commit 3dc3b8c0 authored by lukas leufen

enabled tests for preprocessing and training

parent 37219af8
Pipeline #41906 passed with stages
in 5 minutes and 40 seconds
......@@ -60,7 +60,7 @@ Thumbs.db
htmlcov/
.pytest_cache
/test/data/
/test/test_modules/data/
/test/test_run_modules/data/
report.html
/TestExperiment/
/testrun_network*/
......
......@@ -38,9 +38,9 @@ pydot==1.4.1
pyparsing==2.4.6
pyproj==2.5.0
pyshp==2.1.0
pytest==5.3.5
pytest-cov==2.8.1
pytest-html==2.0.1
pytest==6.0.0
pytest-cov==2.10.0
pytest-html==2.1.1
pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0
pytest-sugar
......
......@@ -11,5 +11,5 @@ __date__ = '2020-04-17'
from .bootstraps import BootStraps
from .iterator import KerasIterator, DataCollection
from .advanced_data_handling import DefaultDataPreparation, AbstractDataPreparation
from .advanced_data_handler import DefaultDataPreparation, AbstractDataPreparation
from .data_preparation_neighbors import DataPreparationNeighbors
......@@ -12,6 +12,7 @@ import pandas as pd
import datetime as dt
import shutil
import inspect
import copy
from typing import Union, List, Tuple, Dict
import logging
......@@ -109,9 +110,9 @@ class DefaultDataPreparation(AbstractDataPreparation):
@classmethod
def build(cls, station, **kwargs):
sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
sp = StationPrep(station, **sp_keys)
dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs}
dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
return cls(sp, **dp_args)
def _create_collection(self):
......@@ -274,7 +275,7 @@ class DefaultDataPreparation(AbstractDataPreparation):
@classmethod
def transformation(cls, set_stations, **kwargs):
sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs}
sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
transformation_dict = sp_keys.pop("transformation")
if transformation_dict is None:
return
......
......@@ -12,7 +12,6 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2020-02-07'
import logging
import os
from collections import Iterator, Iterable
from itertools import chain
......@@ -20,7 +19,7 @@ from itertools import chain
import numpy as np
import xarray as xr
from src.data_handler.advanced_data_handling import AbstractDataPreparation
from src.data_handler.advanced_data_handler import AbstractDataPreparation
class BootstrapIterator(Iterator):
......
......@@ -5,7 +5,7 @@ __date__ = '2020-07-17'
from src.helpers import to_list
from src.data_handler.station_preparation import StationPrep
from src.data_handler.advanced_data_handling import DefaultDataPreparation
from src.data_handler.advanced_data_handler import DefaultDataPreparation
import os
from typing import Union, List
......
......@@ -39,4 +39,6 @@ def run(stations=None,
if __name__ == "__main__":
run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, create_new_model=True)
from src.model_modules.model_class import MyBranchedModel
run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True,
create_new_model=True, model=MyBranchedModel)
......@@ -18,7 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D
DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
from src.data_handler.advanced_data_handling import DefaultDataPreparation
from src.data_handler.advanced_data_handler import DefaultDataPreparation
from src.run_modules.run_environment import RunEnvironment
from src.model_modules.model_class import MyLittleModel as VanillaModel
......
......@@ -157,7 +157,7 @@ class PreProcessing(RunEnvironment):
raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset"
f"order was: {subset_names}.")
for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
self.create_set_split_new(ind, scope)
self.create_set_split(ind, scope)
@staticmethod
def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]:
......@@ -181,7 +181,7 @@ class PreProcessing(RunEnvironment):
train_val_index = slice(0, pos_test_split)
return train_index, val_index, test_index, train_val_index
def create_set_split_new(self, index_list: slice, set_name: str) -> None:
def create_set_split(self, index_list: slice, set_name: str) -> None:
# get set stations
stations = self.data_store.get("stations", scope=set_name)
if self.data_store.get("use_all_stations_on_all_data_sets"):
......@@ -212,7 +212,7 @@ class PreProcessing(RunEnvironment):
:return: Corrected list containing only valid station IDs.
"""
t_outer = TimeTracking()
logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}")
logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}")
# calculate transformation using train data
if set_name == "train":
self.transformation(data_preparation, set_stations)
......
......@@ -2,8 +2,7 @@ import logging
import pytest
from src.data_handler import DataPrepJoin
from src.data_handler.data_generator import DataGenerator
from src.data_handler import DefaultDataPreparation, DataCollection, AbstractDataPreparation
from src.helpers.datastore import NameNotFoundInScope
from src.helpers import PyTestRegex
from src.run_modules.experiment_setup import ExperimentSetup
......@@ -29,7 +28,7 @@ class TestPreProcessing:
def obj_with_exp_setup(self):
ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
data_preparation=DataPrepJoin)
data_preparation=DefaultDataPreparation)
pre = object.__new__(PreProcessing)
super(PreProcessing, pre).__init__()
yield pre
......@@ -48,19 +47,20 @@ class TestPreProcessing:
RunEnvironment().__del__()
def test_run(self, obj_with_exp_setup):
assert obj_with_exp_setup.data_store.search_name("generator") == []
assert obj_with_exp_setup.data_store.search_name("data_collection") == []
assert obj_with_exp_setup._run() is None
assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
"general.train_val", "general.test"])
assert obj_with_exp_setup.data_store.search_name("data_collection") == sorted(["general.train", "general.val",
"general.train_val",
"general.test"])
def test_split_train_val_test(self, obj_with_exp_setup):
assert obj_with_exp_setup.data_store.search_name("generator") == []
assert obj_with_exp_setup.data_store.search_name("data_collection") == []
obj_with_exp_setup.split_train_val_test()
data_store = obj_with_exp_setup.data_store
expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length", "extreme_values",
"extremes_on_right_tail_only", "upsampling"]
expected_params = ["data_collection", "start", "end", "stations", "permute_data", "min_length",
"extreme_values", "extremes_on_right_tail_only", "upsampling"]
assert data_store.search_scope("general.train") == sorted(expected_params)
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
assert data_store.search_name("data_collection") == sorted(["general.train", "general.val", "general.test",
"general.train_val"])
def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
......@@ -69,9 +69,9 @@ class TestPreProcessing:
obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples
data_store = obj_with_exp_setup.data_store
assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
with pytest.raises(NameNotFoundInScope):
data_store.get("generator", "general")
data_store.get("data_collection", "general")
assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"]
def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
......@@ -80,22 +80,22 @@ class TestPreProcessing:
message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']"
assert ('root', 10, message) in caplog.record_tuples
data_store = obj_with_exp_setup.data_store
assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
with pytest.raises(NameNotFoundInScope):
data_store.get("generator", "general")
data_store.get("data_collection", "general")
assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
@pytest.mark.parametrize("name", (None, "tester"))
def test_check_valid_stations(self, caplog, obj_with_exp_setup, name):
def test_validate_station(self, caplog, obj_with_exp_setup, name):
pre = obj_with_exp_setup
caplog.set_level(logging.INFO)
args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST)
kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST)
stations = pre.data_store.get("stations", "general")
valid_stations = pre.check_valid_stations(args, kwargs, stations, name=name)
data_preparation = pre.data_store.get("data_preparation")
collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name)
assert isinstance(collection, DataCollection)
assert len(valid_stations) < len(stations)
assert valid_stations == stations[:-1]
expected = 'check valid stations started (tester)' if name else 'check valid stations started'
expected = "check valid stations started" + ' (%s)' % (name if name else 'all')
assert caplog.record_tuples[0] == ('root', 20, expected)
assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
r'station\(s\). Found 5/6 valid stations.'))
......@@ -107,3 +107,11 @@ class TestPreProcessing:
assert dummy_list[val] == list(range(10, 13))
assert dummy_list[test] == list(range(13, 15))
assert dummy_list[train_val] == list(range(0, 13))
def test_transformation(self):
    """Assert that ``PreProcessing.transformation`` returns ``None`` for two handler classes.

    The method is called on a bare ``PreProcessing`` instance, first with
    ``AbstractDataPreparation`` and then with an empty placeholder class.
    """

    class data_preparation_no_trans:
        pass

    stations = ['DEBW107', 'DEBY081']
    # object.__new__ allocates the instance without running PreProcessing.__init__
    pre = object.__new__(PreProcessing)
    for handler in (AbstractDataPreparation, data_preparation_no_trans):
        assert pre.transformation(handler, stations) is None
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment