Commit a9a2a681 authored by lukas leufen's avatar lukas leufen

removed station_type and network from default arguments because they are only...

removed station_type and network from default arguments because they are only related to the data class. They can be set using the **kwargs parameter.
parent 2612676d
Pipeline #43664 failed with stages
in 6 minutes and 4 seconds
......@@ -303,19 +303,23 @@ class DefaultDataPreparation(AbstractDataPreparation):
def run_data_prep():
    """Smoke test: build a neighbour-aware data preparation and draw data from it.

    Instantiates a dummy main station plus two dummy neighbour stations, then
    calls ``get_data`` without upsampling.  The superseded pre-commit
    ``DataPreparation(...)`` call that the diff rendering retained has been
    dropped: its result was immediately overwritten and the class is no longer
    imported here.
    """
    from .data_preparation_neighbors import DataPreparationNeighbors

    data = DummyDataSingleStation("main_class")
    data.get_X()
    data.get_Y()

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    data_prep = DataPreparationNeighbors(DummyDataSingleStation("main_class"),
                                         path,
                                         neighbors=[DummyDataSingleStation("neighbor1"),
                                                    DummyDataSingleStation("neighbor2")],
                                         extreme_values=[1., 1.2])
    data_prep.get_data(upsampling=False)
def create_data_prep():
from .data_preparation_neighbors import DataPreparationNeighbors
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
station_type = None
network = 'UBA'
......@@ -333,15 +337,15 @@ def create_data_prep():
target_var, interpolate_dim, window_history_size, window_lead_time)
data_prep = []
data_prep.append(DataPreparation(central_station, path, neighbors=[neighbor1, neighbor2]))
data_prep.append(DataPreparation(neighbor1, path, neighbors=[central_station, neighbor2]))
data_prep.append(DataPreparation(neighbor2, path, neighbors=[neighbor1, central_station]))
data_prep.append(DataPreparationNeighbors(central_station, path, neighbors=[neighbor1, neighbor2]))
data_prep.append(DataPreparationNeighbors(neighbor1, path, neighbors=[central_station, neighbor2]))
data_prep.append(DataPreparationNeighbors(neighbor2, path, neighbors=[neighbor1, central_station]))
return data_prep
if __name__ == "__main__":
    # Import from the renamed package layout (src.* became mlair.*); the stale
    # src.data_handler imports duplicated by the diff rendering would raise
    # ImportError and were removed.
    from mlair.data_handler.station_preparation import StationPrep
    from mlair.data_handler.iterator import KerasIterator, DataCollection

    data_prep = create_data_prep()
    data_collection = DataCollection(data_prep)
for data in data_collection:
......
......@@ -24,6 +24,20 @@ number = Union[float, int]
num_or_list = Union[number, List[number]]
data_or_none = Union[xr.DataArray, None]
# defaults
# Fall-back values for the StationPrep constructor keyword arguments of the
# same (lower-cased) names defined below.  station_type and network moved here
# from ExperimentSetup because they only concern the data class — they can
# still be overridden per call via **kwargs.
DEFAULT_STATION_TYPE = "background"
DEFAULT_NETWORK = "AIRBASE"
# Statistic applied per variable; keys presumably are JOIN/TOAR variable short
# names — confirm against the surface-data parameter list.
DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                        'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                        'pblheight': 'maximum'}
DEFAULT_WINDOW_LEAD_TIME = 3
DEFAULT_WINDOW_HISTORY_SIZE = 13
DEFAULT_TIME_DIM = "datetime"
DEFAULT_TARGET_VAR = "o3"
DEFAULT_TARGET_DIM = "variables"
DEFAULT_SAMPLING = "daily"
DEFAULT_INTERPOLATION_METHOD = "linear"
class AbstractStationPrep(object):
def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs):
......@@ -38,9 +52,11 @@ class AbstractStationPrep(object):
class StationPrep(AbstractStationPrep):
def __init__(self, station, data_path, statistics_per_var, station_type, network, sampling,
target_dim, target_var, time_dim, window_history_size, window_lead_time,
interpolation_limit: int = 0, interpolation_method: str = 'linear',
def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
network=DEFAULT_NETWORK, sampling=DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM,
target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
interpolation_limit: int = 0, interpolation_method: str = DEFAULT_INTERPOLATION_METHOD,
overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
min_length: int = 0, start=None, end=None, **kwargs):
super().__init__() # path, station, statistics_per_var, transformation, **kwargs)
......
......@@ -50,8 +50,6 @@ class ExperimentSetup(RunEnvironment):
* `plot_path` [.]
* `forecast_path` [.]
* `stations` [.]
* `network` [.]
* `station_type` [.]
* `statistics_per_var` [.]
* `variables` [.]
* `start` [.]
......@@ -116,10 +114,6 @@ class ExperimentSetup(RunEnvironment):
investigations are stored outside this structure.
:param stations: list of stations or single station to use in experiment. If not provided, stations are set to
:py:const:`default stations <DEFAULT_STATIONS>`.
:param network: name of network to restrict to use only stations from this measurement network. Default is
`AIRBASE` .
:param station_type: restrict network type to one of TOAR's categories (background, traffic, industrial). Default is
`None` to use all categories.
:param variables: list of all variables to use. Valid names can be found in
`Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. If not provided, this
parameter is filled with keys from ``statistics_per_var``.
......@@ -209,8 +203,6 @@ class ExperimentSetup(RunEnvironment):
def __init__(self,
experiment_date=None,
stations: Union[str, List[str]] = None,
network: str = None,
station_type: str = None,
variables: Union[str, List[str]] = None,
statistics_per_var: Dict = None,
start: str = None,
......@@ -229,7 +221,7 @@ class ExperimentSetup(RunEnvironment):
train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None):
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None, **kwargs):
# create run framework
super().__init__()
......@@ -288,8 +280,8 @@ class ExperimentSetup(RunEnvironment):
# setup for data
self._set_param("stations", stations, default=DEFAULT_STATIONS)
# network and station_type are no longer explicit constructor arguments (they
# belong to the data class and may be passed through **kwargs); the superseded
# active _set_param lines referenced now-undefined names and were disabled.
# self._set_param("network", network, default=DEFAULT_NETWORK)
# self._set_param("station_type", station_type, default=DEFAULT_STATION_TYPE)
self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
self._set_param("start", start, default=DEFAULT_START)
......@@ -311,7 +303,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("dimensions", dimensions, default=DEFAULT_DIMENSIONS)
self._set_param("time_dim", time_dim, default=DEFAULT_TIME_DIM)
self._set_param("interpolation_method", interpolation_method, default=DEFAULT_INTERPOLATION_METHOD)
# Renamed data-store key: the limit_nan_fill value is now stored as
# interpolation_limit; the superseded store under the stale key was removed.
self._set_param("interpolation_limit", limit_nan_fill, default=DEFAULT_LIMIT_NAN_FILL)

# train set parameters
self._set_param("start", train_start, default=DEFAULT_TRAIN_START, scope="train")
......@@ -356,6 +348,15 @@ class ExperimentSetup(RunEnvironment):
# set model architecture class
self._set_param("model_class", model, VanillaModel)
# set remaining kwargs
# Store every extra keyword argument in the data store, but only when its name
# is not already taken; a clash raises instead of silently overwriting.
if len(kwargs) > 0:
    for k, v in kwargs.items():
        if len(self.data_store.search_name(k)) == 0:
            # BUG FIX: store under the argument's name, not the literal "k"
            # (the original stored every extra kwarg under the key "k").
            self._set_param(k, v)
        else:
            raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a "
                           f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}")
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
"""Set given parameter and log in debug."""
if value is None and default is not None:
......@@ -395,6 +396,7 @@ class ExperimentSetup(RunEnvironment):
if not set(target_var).issubset(stat.keys()):
raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
if __name__ == "__main__":
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.DEBUG)
......
......@@ -6,11 +6,9 @@ import inspect
def run(stations=None,
station_type=None,
trainable=None, create_new_model=None,
window_history_size=None,
experiment_date="testrun",
network=None,
variables=None, statistics_per_var=None,
start=None, end=None,
target_var=None, target_dim=None,
......@@ -29,16 +27,17 @@ def run(stations=None,
model=None,
batch_size=None,
epochs=None,
data_preparation=None,):
data_preparation=None,
**kwargs):
# Collect only named arguments of DefaultWorkflow that were explicitly given
# (non-None); arbitrary extras arrive via **kwargs and are forwarded as-is.
# The superseded pre-commit lines retained by the diff rendering were dropped:
# the old "kwargs = {...}" rebound the **kwargs parameter, and the old
# DefaultWorkflow(**kwargs) built a throwaway workflow instance.
params = inspect.getfullargspec(DefaultWorkflow).args
kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
workflow = DefaultWorkflow(**kwargs_default, **kwargs)
workflow.run()
if __name__ == "__main__":
    from mlair.model_modules.model_class import MyBranchedModel

    # Post-commit call: station_type is now forwarded through **kwargs; the
    # superseded duplicate closing line of the call was removed (it made the
    # statement syntactically invalid as rendered).
    run(statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True,
        create_new_model=True, model=MyBranchedModel, station_type="background")
......@@ -14,11 +14,9 @@ class DefaultWorkflow(Workflow):
the mentioned ordering."""
def __init__(self, stations=None,
station_type=None,
trainable=None, create_new_model=None,
window_history_size=None,
experiment_date="testrun",
network=None,
variables=None, statistics_per_var=None,
start=None, end=None,
target_var=None, target_dim=None,
......@@ -37,13 +35,14 @@ class DefaultWorkflow(Workflow):
model=None,
batch_size=None,
epochs=None,
data_preparation=None):
data_preparation=None,
**kwargs):
super().__init__()
# extract all given kwargs arguments
# Only named constructor arguments that were actually provided (non-None) are
# collected; arbitrary extras arrive via **kwargs and are forwarded untouched.
# The superseded pre-commit pair (kwargs = {...}; self._setup(**kwargs))
# retained by the diff rendering was dropped — the old assignment clobbered
# the **kwargs parameter.
params = remove_items(inspect.getfullargspec(self.__init__).args, "self")
kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
self._setup(**kwargs_default, **kwargs)
def _setup(self, **kwargs):
"""Set up default workflow."""
......@@ -59,11 +58,9 @@ class DefaultWorkflowHPC(Workflow):
Training and PostProcessing in exact the mentioned ordering."""
def __init__(self, stations=None,
station_type=None,
trainable=None, create_new_model=None,
window_history_size=None,
experiment_date="testrun",
network=None,
variables=None, statistics_per_var=None,
start=None, end=None,
target_var=None, target_dim=None,
......@@ -82,13 +79,13 @@ class DefaultWorkflowHPC(Workflow):
model=None,
batch_size=None,
epochs=None,
data_preparation=None):
data_preparation=None, **kwargs):
super().__init__()
# extract all given kwargs arguments
# Same pattern as DefaultWorkflow: keep explicitly-given named arguments,
# forward extra **kwargs untouched.  The superseded pre-commit pair
# (kwargs = {...}; self._setup(**kwargs)) was dropped — the old assignment
# clobbered the **kwargs parameter.
params = remove_items(inspect.getfullargspec(self.__init__).args, "self")
kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
self._setup(**kwargs_default, **kwargs)
def _setup(self, **kwargs):
"""Set up default workflow."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment