Commit 9f4d7a34 authored by liadomide's avatar liadomide

TVB-2598 Introduce ABCAdapter.load_traited_by_gid method.

Review signatures for base methods in adapters.
Rename viewer TimeSeries into TimeSeriesDisplay
parent e9e6dbc0
......@@ -35,11 +35,11 @@ Few supplementary steps are done here:
* from submitted Monitor/Model... names, build transient entities
* after UI parameters submit, compose transient Cortex entity to be passed to the Simulator.
.. moduleauthor:: Paula Popa <paula.popa@codemart.ro>
.. moduleauthor:: Lia Domide <lia.domide@codemart.ro>
.. moduleauthor:: Stuart A. Knock <Stuart@tvb.invalid>
"""
import numpy
from tvb.core.neotraits.view_model import ViewModel, DataTypeGidAttr
from tvb.datatypes.connectivity import Connectivity
from tvb.datatypes.cortex import Cortex
......@@ -187,8 +187,7 @@ class SimulatorAdapter(ABCAsynchronous):
simulator = Simulator()
simulator.gid = view_model.gid
conn_index = self.load_entity_by_gid(view_model.connectivity.hex)
conn = h5.load_from_index(conn_index)
conn = self.load_traited_by_gid(view_model.connectivity)
simulator.connectivity = conn
simulator.conduction_speed = view_model.conduction_speed
......@@ -206,9 +205,8 @@ class SimulatorAdapter(ABCAsynchronous):
simulator.surface.region_mapping_data = rm
if simulator.surface.local_connectivity:
lc_index = self.load_entity_by_gid(view_model.surface.local_connectivity.hex)
lc = h5.load_from_index(lc_index)
assert lc_index.surface_gid == rm_index.surface_gid
lc = self.load_traited_by_gid(view_model.surface.local_connectivity)
assert lc.surface.gid == rm_index.surface_gid
lc.surface = rm_surface
simulator.surface.local_connectivity = lc
......
......@@ -87,8 +87,7 @@ class BRCOImporter(ABCUploader):
@transactional
def launch(self, view_model):
try:
connectivity_index = self.load_entity_by_gid(view_model.connectivity.hex)
conn = h5.load_from_index(connectivity_index)
conn = self.load_traited_by_gid(view_model.connectivity)
parser = XMLParser(view_model.data_file, conn.region_labels)
annotations = parser.read_annotation_terms()
......
......@@ -98,8 +98,8 @@ class ImaginaryCoherenceDisplay(ABCDisplayer):
required_memory = numpy.prod(input_data_h5.read_data_shape()) * 8
return required_memory
def generate_preview(self, view_model):
# type: (ImaginaryCoherenceDisplayModel) -> dict
def generate_preview(self, view_model, figure_size=None):
# type: (ImaginaryCoherenceDisplayModel, (int,int)) -> dict
return self.launch(view_model)
def launch(self, view_model):
......
......@@ -46,8 +46,6 @@ from tvb.core.adapters.abcdisplayer import ABCDisplayer
from tvb.core.adapters.exceptions import LaunchException
from tvb.core.entities.filters.chain import FilterChain
from tvb.adapters.datatypes.db.connectivity import ConnectivityIndex
from tvb.adapters.datatypes.db.graph import ConnectivityMeasureIndex
from tvb.adapters.datatypes.db.surface import SurfaceIndex
from tvb.core.neotraits.forms import TraitDataTypeSelectField, FloatField
from tvb.core.neocom import h5
from tvb.core.neotraits.view_model import ViewModel, DataTypeGidAttr
......@@ -180,18 +178,15 @@ class ConnectivityViewer(ABCSpaceDisplayer):
return -1
def _load_input_data(self, view_model):
connectivity_index = self.load_entity_by_gid(view_model.connectivity.hex)
connectivity = h5.load_from_index(connectivity_index)
connectivity = self.load_traited_by_gid(view_model.connectivity)
assert isinstance(connectivity, Connectivity)
if view_model.colors:
colors_index = self.load_entity_by_gid(view_model.colors.hex)
colors_dt = h5.load_from_index(colors_index)
colors_dt = self.load_traited_by_gid(view_model.colors)
else:
colors_dt = None
if view_model.rays:
rays_index = self.load_entity_by_gid(view_model.rays.hex)
rays_dt = h5.load_from_index(rays_index)
rays_dt = self.load_traited_by_gid(view_model.rays)
else:
rays_dt = None
......@@ -207,8 +202,7 @@ class ConnectivityViewer(ABCSpaceDisplayer):
global_params, global_pages = self._compute_connectivity_global_params(connectivity)
if view_model.surface_data is not None:
surface_index = self.load_entity_by_gid(view_model.surface_data.hex)
surface_h5 = h5.h5_file_for_index(surface_index)
surface_h5 = self.load_traited_by_gid(view_model.surface_data)
url_vertices, url_normals, _, url_triangles, _ = SurfaceURLGenerator.get_urls_for_rendering(surface_h5)
else:
url_vertices, url_normals, url_triangles = [], [], []
......@@ -229,7 +223,7 @@ class ConnectivityViewer(ABCSpaceDisplayer):
return self.build_display_result("connectivity/main_connectivity", result_params, result_pages)
def generate_preview(self, view_model, figure_size=None):
# type: (ConnectivityViewerModel, int) -> dict
# type: (ConnectivityViewerModel, (int,int)) -> dict
"""
Generate the preview for the BURST cockpit.
......
......@@ -39,7 +39,6 @@ import json
from tvb.core.adapters.abcadapter import ABCAdapterForm
from tvb.core.adapters.abcdisplayer import ABCDisplayer
from tvb.adapters.datatypes.db.connectivity import ConnectivityIndex
from tvb.core.neocom import h5
from tvb.core.neotraits.forms import TraitDataTypeSelectField
from tvb.core.neotraits.view_model import ViewModel, DataTypeGidAttr
from tvb.datatypes.connectivity import Connectivity
......@@ -92,11 +91,11 @@ class ConnectivityEdgeBundle(ABCDisplayer):
def launch(self, view_model):
"""Construct data for visualization and launch it."""
connectivity = self.load_entity_by_gid(view_model.connectivity.hex)
connectivity_dt = h5.load_from_index(connectivity)
connectivity = self.load_traited_by_gid(view_model.connectivity)
pars = {"labels": json.dumps(connectivity_dt.region_labels.tolist()),
"url_base": ABCDisplayer.paths2url(connectivity.gid, attribute_name="weights", flatten="True")
pars = {"labels": json.dumps(connectivity.region_labels.tolist()),
"url_base": ABCDisplayer.paths2url(view_model.connectivity.hex,
attribute_name="weights", flatten="True")
}
return self.build_display_result("connectivity_edge_bundle/view", pars)
......@@ -99,10 +99,10 @@ class FourierSpectrumDisplay(ABCDisplayer):
Return the required memory to run this algorithm.
"""
fs_input_index = self.load_entity_by_gid(view_model.input_data.hex)
return numpy.prod(fs_input_index.read_data_shape()) * 8
return numpy.prod(fs_input_index.get_data_shape()) * 8
def generate_preview(self, view_model):
# type: (FourierSpectrumModel) -> dict
def generate_preview(self, view_model, figure_size=None):
# type: (FourierSpectrumModel, (int,int)) -> dict
return self.launch(view_model)
def launch(self, view_model):
......
......@@ -106,7 +106,7 @@ class HistogramViewer(ABCDisplayer):
input_data = self.load_entity_by_gid(view_model.input_data.hex)
return numpy.prod(input_data.shape) * 2
def generate_preview(self, view_model, figure_size):
def generate_preview(self, view_model, figure_size=None):
"""
The preview for the burst page.
"""
......
......@@ -102,19 +102,21 @@ class PearsonCorrelationCoefficientVisualizer(MappedArrayVisualizer):
with datatype_h5_class(datatype_h5_path) as datatype_h5:
matrix_shape = datatype_h5.array_data.shape[0:2]
ts_gid = datatype_h5.source.load()
ts_index = self.load_entity_by_gid(ts_gid.hex)
state_list = ts_index.get_labels_for_dimension(1)
mode_list = list(range(ts_index.data_length_4d))
ts_h5_class, ts_h5_path = self._load_h5_of_gid(ts_index.gid)
ts_h5_class, ts_h5_path = self._load_h5_of_gid(ts_gid.hex)
with ts_h5_class(ts_h5_path) as ts_h5:
labels = ts_h5.get_space_labels()
state_list = ts_h5.labels_dimensions.load().get(ts_h5.labels_ordering.load()[1], [])
mode_list = list(range(ts_index.data_length_4d))
if not labels:
labels = None
pars = dict(matrix_labels=json.dumps([labels, labels]),
matrix_shape=json.dumps(matrix_shape),
viewer_title='Cross Corelation Matrix plot',
url_base=URLGenerator.build_h5_url(view_model.datatype, 'get_correlation_data', parameter=''),
url_base=URLGenerator.build_h5_url(view_model.datatype.hex, 'get_correlation_data', parameter=''),
state_variable=state_list[0],
mode=mode_list[0],
state_list=state_list,
......
......@@ -65,12 +65,12 @@ class PearsonEdgeBundle(ABCDisplayer):
matrix_shape = datatype_h5.array_data.shape[0:2]
ts_gid = datatype_h5.source.load()
ts_index = self.load_entity_by_gid(ts_gid.hex)
state_list = ts_index.get_labels_for_dimension(1)
mode_list = list(range(ts_index.data_length_4d))
ts_h5_class, ts_h5_path = self._load_h5_of_gid(ts_index.gid)
with ts_h5_class(ts_h5_path) as ts_h5:
labels = ts_h5.get_space_labels()
state_list = ts_h5.labels_dimensions.load().get(ts_h5.labels_ordering.load()[1], [])
mode_list = list(range(ts_index.data_length_4d))
if not labels:
labels = None
pars = dict(matrix_labels=json.dumps(labels),
......
......@@ -448,7 +448,7 @@ class SurfaceViewer(ABCSurfaceDisplayer):
return self.build_display_result("surface/surface_view", params,
pages={"controlPage": "surface/surface_viewer_controls"})
def get_required_memory_size(self):
def get_required_memory_size(self, view_model):
return -1
......
......@@ -127,8 +127,8 @@ class ABCSpaceDisplayer(ABCDisplayer):
"""
if isinstance(ts_h5, TimeSeriesRegionH5):
connectivity_gid = ts_h5.connectivity.load()
conn_idx = self.load_entity_by_gid(connectivity_gid.hex)
conn = h5.load_from_index(conn_idx)
conn = self.load_traited_by_gid(connectivity_gid)
assert isinstance(conn, Connectivity)
return self._connectivity_grouped_space_labels(conn)
ts_h5.get_grouped_space_labels()
......@@ -156,7 +156,7 @@ class ABCSpaceDisplayer(ABCDisplayer):
return ts_h5.get_space_labels()
class TimeSeries(ABCSpaceDisplayer):
class TimeSeriesDisplay(ABCSpaceDisplayer):
_ui_name = "Time Series Visualizer (SVG/d3)"
_ui_subsection = "timeseries"
......@@ -176,7 +176,7 @@ class TimeSeries(ABCSpaceDisplayer):
assert isinstance(h5_file, TimeSeriesH5)
shape = list(h5_file.read_data_shape())
ts = h5_file.storage_manager.get_data('time')
state_variables = h5_file.labels_dimensions.load().get(time_series_index.labels_ordering[1], [])
state_variables = time_series_index.get_labels_for_dimension(1)
labels = self.get_space_labels(h5_file)
# Assume that the first dimension is the time since that is the case so far
......@@ -206,6 +206,6 @@ class TimeSeries(ABCSpaceDisplayer):
"""Construct data for visualization and launch it."""
return self._launch(view_model, None)
def generate_preview(self, view_model, figure_size):
# type: (TimeSeriesModel, list) -> dict
def generate_preview(self, view_model, figure_size=None):
# type: (TimeSeriesModel, (int, int)) -> dict
return self._launch(view_model, figsize=figure_size, preview=True)
......@@ -37,6 +37,7 @@ Root classes for adding custom functionality to the code.
import os
import json
import uuid
import psutil
import numpy
import importlib
......@@ -46,6 +47,7 @@ from abc import ABCMeta, abstractmethod
from six import add_metaclass
from tvb.basic.profile import TvbProfile
from tvb.basic.logger.builder import get_logger
from tvb.basic.neotraits.api import HasTraits
from tvb.core.adapters import constants
from tvb.core.entities.generic_attributes import GenericAttributes
from tvb.core.entities.load import load_entity_by_gid
......@@ -190,16 +192,9 @@ class ABCAdapter(object):
KEY_DISABLED = "disabled"
KEY_FILTERABLE = "filterable"
# TODO: move everything related to parameters PRE + POST into parameters_factory
KEYWORD_PARAMS = "_parameters_"
INTERFACE_ATTRIBUTES_ONLY = "attributes-only"
INTERFACE_ATTRIBUTES = "attributes"
# model.Algorithm instance that will be set for each adapter created by in build_adapter method
stored_adapter = None
def __init__(self):
# It will be populate with key from DataTypeMetaData
self.meta_data = {DataTypeMetaData.KEY_SUBJECT: DataTypeMetaData.DEFAULT_SUBJECT}
......@@ -254,12 +249,6 @@ class ABCAdapter(object):
"""
return True
def get_input_tree(self):
"""
Describes inputs and outputs of the launch method.
"""
return None
def submit_form(self, form):
self.submitted_form = form
......@@ -282,14 +271,12 @@ class ABCAdapter(object):
Describes inputs and outputs of the launch method.
"""
def configure(self, view_model):
"""
To be implemented in each Adapter that requires any specific configurations
before the actual launch.
"""
@abstractmethod
def get_required_memory_size(self, view_model):
"""
......@@ -297,7 +284,6 @@ class ABCAdapter(object):
for launching the adapter.
"""
@abstractmethod
def get_required_disk_size(self, view_model):
"""
......@@ -305,7 +291,6 @@ class ABCAdapter(object):
for launching the adapter in kilo-Bytes.
"""
def get_execution_time_approximation(self, view_model):
"""
Method should approximate based on input arguments, the time it will take for the operation
......@@ -313,7 +298,6 @@ class ABCAdapter(object):
"""
return -1
@abstractmethod
def launch(self, view_model):
"""
......@@ -324,7 +308,6 @@ class ABCAdapter(object):
:param view_model: the data model corresponding to the current adapter
"""
def add_operation_additional_info(self, message):
"""
Adds additional info on the operation to be displayed in the UI. Usually a warning message.
......@@ -395,7 +378,6 @@ class ABCAdapter(object):
self.__check_integrity(result)
return self._capture_operation_results(result)
def _capture_operation_results(self, result):
"""
After an operation was finished, make sure the results are stored
......@@ -413,7 +395,7 @@ class ABCAdapter(object):
burst_reference = self.meta_data[DataTypeMetaData.KEY_BURST]
count_stored = 0
group_type = None # In case of a group, the first not-none type is sufficient to memorize here
group_type = None # In case of a group, the first not-none type is sufficient to memorize here
for res in result:
if res is None:
continue
......@@ -443,7 +425,6 @@ class ABCAdapter(object):
return 'Operation ' + str(self.operation_id) + ' has finished.', count_stored
def __check_integrity(self, result):
"""
Check that the returned parameters for LAUNCH operation
......@@ -456,7 +437,6 @@ class ABCAdapter(object):
msg = "Unexpected output DataType %s"
raise InvalidParameterException(msg % type(result_entity))
def __is_data_in_supported_types(self, data):
if data is None:
......@@ -467,7 +447,6 @@ class ABCAdapter(object):
# Data can't be mapped on any supported type !!
return False
def _is_group_launch(self):
"""
Return true if this adapter is launched from a group of operations
......@@ -475,7 +454,6 @@ class ABCAdapter(object):
operation = dao.get_operation_by_id(self.operation_id)
return operation.fk_operation_group is not None
@staticmethod
def load_entity_by_gid(data_gid):
"""
......@@ -483,6 +461,14 @@ class ABCAdapter(object):
"""
return load_entity_by_gid(data_gid)
@staticmethod
def load_traited_by_gid(data_gid):
# type: (uuid.UUID) -> HasTraits
"""
Load a generic HasTraits instance, specified by GID.
"""
index = load_entity_by_gid(data_gid.hex)
return h5.load_from_index(index)
@staticmethod
def build_adapter_from_class(adapter_class):
......@@ -501,7 +487,6 @@ class ABCAdapter(object):
LOGGER.exception(excep)
raise IntrospectionException(str(excep))
@staticmethod
def build_adapter(stored_adapter):
"""
......@@ -519,9 +504,6 @@ class ABCAdapter(object):
LOGGER.exception(msg)
raise IntrospectionException(msg)
# METHODS for PROCESSING PARAMETERS start here #############################
def review_operation_inputs(self, parameters):
# TODO: implement this for neoforms
"""
......@@ -550,5 +532,3 @@ class ABCSynchronous(ABCAdapter):
"""
Abstract class, for marking adapters that are prone to be NOT executed on Cluster.
"""
......@@ -39,6 +39,7 @@ from six import add_metaclass
from tvb.core.adapters.abcadapter import ABCSynchronous
from tvb.core.adapters.exceptions import LaunchException
from tvb.core.neocom import h5
from tvb.core.neotraits.view_model import ViewModel
LOCK_CREATE_FIGURE = Lock()
......@@ -64,7 +65,6 @@ class URLGenerator(object):
return url
@staticmethod
def build_h5_url(entity_gid, method_name, flatten=False, datatype_kwargs=None, parameter=None):
json_kwargs = json.dumps(datatype_kwargs)
......@@ -77,7 +77,6 @@ class URLGenerator(object):
return url
@staticmethod
def paths2url(datatype_gid, attribute_name, flatten=False, parameter=None):
"""
......@@ -102,18 +101,16 @@ class ABCDisplayer(ABCSynchronous, metaclass=ABCMeta):
VISUALIZERS_ROOT = ''
VISUALIZERS_URL_PREFIX = ''
def get_output(self):
return []
def generate_preview(self, view_model):
def generate_preview(self, view_model, figure_size=None):
# type: (ViewModel, (int,int)) -> dict
"""
Should be implemented by all visualizers that can be used by portlets.
"""
raise LaunchException("%s used as Portlet but doesn't implement 'generate_preview'" % self.__class__)
def _prelaunch(self, operation, uid=None, available_disk_space=0, view_model=None, **kwargs):
"""
Shortcut in case of visualization calls.
......@@ -121,9 +118,7 @@ class ABCDisplayer(ABCSynchronous, metaclass=ABCMeta):
self.current_project_id = operation.project.id
self.user_id = operation.fk_launched_by
self.storage_path = self.file_handler.get_project_folder(operation.project, str(operation.id))
return self.launch(view_model=view_model, **kwargs), 0
return self.launch(view_model=view_model), 0
def get_required_disk_size(self, view_model):
"""
......@@ -131,7 +126,6 @@ class ABCDisplayer(ABCSynchronous, metaclass=ABCMeta):
"""
return 0
def build_display_result(self, template, parameters, pages=None):
"""
Helper method for building the result of the ABCDisplayer.
......@@ -154,7 +148,6 @@ class ABCDisplayer(ABCSynchronous, metaclass=ABCMeta):
return parameters
@staticmethod
def get_one_dimensional_list(list_of_elements, expected_size, error_msg):
"""
......@@ -184,12 +177,11 @@ class ABCDisplayer(ABCSynchronous, metaclass=ABCMeta):
return url
def build_h5_url(self, entity_gid, method_name, parameter=None):
@staticmethod
def build_h5_url(entity_gid, method_name, parameter=None):
url = '/flow/read_from_h5_file/' + entity_gid + '/' + method_name
if parameter is not None:
url += "?" + str(parameter)
return url
@staticmethod
......
......@@ -49,8 +49,7 @@ class SimulatorSerializer(object):
sensors_gid = monitor_h5.sensors.load()
region_mapping_gid = monitor_h5.region_mapping.load()
sensors_index = ABCAdapter.load_entity_by_gid(sensors_gid.hex)
sensors = h5.load_from_index(sensors_index)
sensors = ABCAdapter.load_traited_by_gid(sensors_gid)
if isinstance(simulator_in.monitors[0], EEG):
sensors = SensorsEEG.build_sensors_subclass(sensors)
......@@ -60,8 +59,7 @@ class SimulatorSerializer(object):
sensors = SensorsInternal.build_sensors_subclass(sensors)
simulator_in.monitors[0].sensors = sensors
region_mapping_index = ABCAdapter.load_entity_by_gid(region_mapping_gid.hex)
region_mapping = h5.load_from_index(region_mapping_index)
region_mapping = ABCAdapter.load_traited_by_gid(region_mapping_gid)
simulator_in.monitors[0].region_mapping = region_mapping
if simulator_in.surface:
......
......@@ -35,7 +35,7 @@
from tvb.core.neocom import h5
from tvb.tests.framework.core.base_testcase import TransactionalTestCase
from tvb.core.entities.file.files_helper import FilesHelper
from tvb.adapters.visualizers.time_series import TimeSeries
from tvb.adapters.visualizers.time_series import TimeSeriesDisplay
from tvb.tests.framework.core.factory import TestFactory
......@@ -65,7 +65,7 @@ class TestTimeSeries(TransactionalTestCase):
"""
time_series_index = time_series_index_factory()
time_series = h5.load_from_index(time_series_index)
viewer = TimeSeries()
viewer = TimeSeriesDisplay()
view_model = viewer.get_view_model_class()()
view_model.time_series = time_series.gid
result = viewer.launch(view_model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment