Commit c39978b0 authored by lukas leufen's avatar lukas leufen

Merge branch 'lukas_issue168_refac_simplify-mylittlemodel' into 'lukas_issue153_feat_advanced-docu'

refac simplify mylittlemodel, /close #168

See merge request !141
parents 5f47e92b 91feab0c
Pipeline #45968 passed with stages
in 6 minutes and 50 seconds
......@@ -351,9 +351,8 @@ class AbstractModelClass(ABC):
class MyLittleModel(AbstractModelClass):
"""
A customised model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the
output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first
Dense layer.
A customised model 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the output layer depending
on the window_lead_time parameter.
"""
def __init__(self, shape_inputs: list, shape_outputs: list):
......@@ -382,13 +381,8 @@ class MyLittleModel(AbstractModelClass):
"""
Build the model.
"""
# add 1 to window_size to include current time step t0
x_input = keras.layers.Input(shape=self.shape_inputs)
x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
x_in = keras.layers.Flatten(name='{}'.format("major"))(x_input)
x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
x_in = self.activation()(x_in)
x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
......
......@@ -786,8 +786,8 @@ class PlotTimeSeries:
def _plot(self, plot_folder):
pdf_pages = self._create_pdf_pages(plot_folder)
for pos, station in enumerate(self._stations):
start, end = self._get_time_range(self._load_data(self._stations[0]))
data = self._load_data(station)
start, end = self._get_time_range(data)
fig, axes, factor = self._create_subplots(start, end)
nan_list = []
for i_year in range(end - start + 1):
......
......@@ -81,16 +81,12 @@ class PostProcessing(RunEnvironment):
def _run(self):
# ols model
with TimeTracking():
self.train_ols_model()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip train_ols_model() whenever it is possible to save time.")
self.train_ols_model()
# forecasts
with TimeTracking():
self.make_prediction()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
self.make_prediction()
# skill scores on test data
self.calculate_test_score()
# bootstraps
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment