Commit b1fd6faa authored by liadomide's avatar liadomide

TVB-2598 Fix unit-tests to accept the new form with view_model. Wavelet...

TVB-2598 Fix unit-tests to accept the new form with view_model. Wavelet Analyzer is incomplete, mark it with a TODO
parent 926b4624
......@@ -39,21 +39,28 @@ ContinuousWaveletTransform Analyzer.
import uuid
import numpy
from tvb.adapters.datatypes.h5.time_series_h5 import TimeSeriesH5
from tvb.core.neotraits.view_model import ViewModel, DataTypeGidAttr
from tvb.analyzers.wavelet import ContinuousWaveletTransform
from tvb.basic.neotraits.api import Range
from tvb.datatypes.time_series import TimeSeries
from tvb.core.adapters.abcadapter import ABCAsynchronous, ABCAdapterForm
from tvb.core.entities.filters.chain import FilterChain
from tvb.basic.logger.builder import get_logger
from tvb.adapters.datatypes.h5.spectral_h5 import WaveletCoefficientsH5
from tvb.adapters.datatypes.db.spectral import WaveletCoefficientsIndex
from tvb.adapters.datatypes.db.time_series import TimeSeriesIndex
from tvb.core.neotraits.forms import ScalarField, FormField, Form, SimpleFloatField, \
TraitDataTypeSelectField
from tvb.core.neotraits.forms import ScalarField, FormField, Form, SimpleFloatField, TraitDataTypeSelectField
from tvb.core.neotraits.db import from_ndarray
from tvb.core.neocom import h5
LOG = get_logger(__name__)
class WaveletAdapterModel(ViewModel, ContinuousWaveletTransform):
time_series = DataTypeGidAttr(
linked_datatype=TimeSeries,
label="Time Series",
required=True,
doc="""The timeseries to which the wavelet is to be applied."""
)
class RangeForm(Form):
......@@ -67,12 +74,11 @@ class RangeForm(Form):
# default=ContinuousWaveletTransform.frequencies.hi)
# TODO: add all fields
class ContinuousWaveletTransformAdapterForm(ABCAdapterForm):
def __init__(self, prefix='', project_id=None):
super(ContinuousWaveletTransformAdapterForm, self).__init__(prefix, project_id)
self.time_series = TraitDataTypeSelectField(ContinuousWaveletTransform.time_series, self,
self.time_series = TraitDataTypeSelectField(WaveletAdapterModel.time_series, self,
name=self.get_input_name(), conditions=self.get_filters(),
has_all_option=True)
self.mother = ScalarField(ContinuousWaveletTransform.mother, self)
......@@ -83,6 +89,10 @@ class ContinuousWaveletTransformAdapterForm(ABCAdapterForm):
label=ContinuousWaveletTransform.frequencies.label,
doc=ContinuousWaveletTransform.frequencies.doc)
@staticmethod
def get_view_model():
return WaveletAdapterModel
@staticmethod
def get_required_datatype():
return TimeSeriesIndex
......@@ -114,12 +124,11 @@ class ContinuousWaveletTransformAdapter(ABCAsynchronous):
def get_output(self):
return [WaveletCoefficientsIndex]
def configure(self, time_series, mother=None, sample_period=None, normalisation=None, q_ratio=None,
frequencies='Range', frequencies_parameters=None):
def configure(self, view_model):
"""
Store the input shape to be later used to estimate memory usage. Also create the algorithm instance.
"""
self.input_time_series_index = time_series
self.input_time_series_index = self.load_entity_by_gid(view_model.time_series.hex)
input_shape = []
for length in [self.input_time_series_index.data_length_1d,
......@@ -130,28 +139,30 @@ class ContinuousWaveletTransformAdapter(ABCAsynchronous):
input_shape.append(length)
self.input_shape = tuple(input_shape)
LOG.debug("Time series shape is %s" % str(self.input_shape))
self.log.debug("Time series shape is %s" % str(self.input_shape))
# -------------------- Fill Algorithm for Analysis -------------------##
algorithm = ContinuousWaveletTransform()
if mother is not None:
algorithm.mother = mother
if view_model.mother is not None:
algorithm.mother = view_model.mother
if sample_period is not None:
algorithm.sample_period = sample_period
if view_model.sample_period is not None:
algorithm.sample_period = view_model.sample_period
if (frequencies_parameters is not None and 'lo' in frequencies_parameters
and 'hi' in frequencies_parameters and frequencies_parameters['hi'] != frequencies_parameters['lo']):
algorithm.frequencies = Range(**frequencies_parameters)
# TODO range form is not correctly populated, some work is still needed there
# if (view_model.frequencies is not None):
# and 'lo' in frequencies_parameters
# and 'hi' in frequencies_parameters and frequencies_parameters['hi'] != frequencies_parameters['lo']):
# algorithm.frequencies = Range(**frequencies_parameters)
if normalisation is not None:
algorithm.normalisation = normalisation
if view_model.normalisation is not None:
algorithm.normalisation = view_model.normalisation
if q_ratio is not None:
algorithm.q_ratio = q_ratio
if view_model.q_ratio is not None:
algorithm.q_ratio = view_model.q_ratio
self.algorithm = algorithm
def get_required_memory_size(self, **kwargs):
def get_required_memory_size(self, view_model):
"""
Return the required memory to run this algorithm.
"""
......@@ -163,7 +174,7 @@ class ContinuousWaveletTransformAdapter(ABCAsynchronous):
output_size = self.algorithm.result_size(used_shape, self.input_time_series_index.sample_period)
return input_size + output_size
def get_required_disk_size(self, **kwargs):
def get_required_disk_size(self, view_model):
"""
Returns the required disk size to be able to run the adapter.(in kB)
"""
......@@ -173,8 +184,7 @@ class ContinuousWaveletTransformAdapter(ABCAsynchronous):
self.input_shape[3])
return self.array_size2kb(self.algorithm.result_size(used_shape, self.input_time_series_index.sample_period))
def launch(self, time_series, mother=None, sample_period=None, normalisation=None, q_ratio=None,
frequencies='Range', frequencies_parameters=None):
def launch(self, view_model):
"""
Launch algorithm and build results.
"""
......@@ -183,14 +193,15 @@ class ContinuousWaveletTransformAdapter(ABCAsynchronous):
if self.algorithm.frequencies is not None:
frequencies_array = self.algorithm.frequencies.to_array()
time_series_h5 = h5.h5_file_for_index(time_series)
time_series_h5 = h5.h5_file_for_index(self.input_time_series_index)
assert isinstance(time_series_h5, TimeSeriesH5)
wavelet_index = WaveletCoefficientsIndex()
dest_path = h5.path_for(self.storage_path, WaveletCoefficientsH5, wavelet_index.gid)
wavelet_h5 = WaveletCoefficientsH5(path=dest_path)
wavelet_h5.gid.store(uuid.UUID(wavelet_index.gid))
wavelet_h5.source.store(time_series_h5.gid.load())
wavelet_h5.source.store(view_model.time)
wavelet_h5.mother.store(self.algorithm.mother)
wavelet_h5.q_ratio.store(self.algorithm.q_ratio)
wavelet_h5.sample_period.store(self.algorithm.sample_period)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment