Commit 3d5be124 authored by liadomide's avatar liadomide

TVB-2598 Fix BrainViewerTest after recent changes of mine

parent 77bae857
......@@ -116,10 +116,11 @@ class BrainViewer(ABCSurfaceDisplayer):
used_shape = (overall_shape[0] / (self.PAGE_SIZE * 2.0), overall_shape[1], overall_shape[2], overall_shape[3])
return numpy.prod(used_shape) * 8.0
def generate_preview(self, time_series, shell_surface=None, figure_size=None):
def generate_preview(self, view_model, figure_size=None):
"""
Generate the preview for the burst page
"""
time_series = self.load_entity_by_gid(view_model.time_series.hex)
self.populate_surface_fields(time_series)
url_vertices, url_normals, url_lines, url_triangles, url_region_map = \
......
......@@ -35,14 +35,13 @@
import os
from uuid import UUID
import tvb_data.surfaceData
import tvb_data.regionMapping as demo_data
import tvb_data.regionMapping
from tvb.core.neocom import h5
from tvb.tests.framework.core.base_testcase import TransactionalTestCase
from tvb.core.entities.file.files_helper import FilesHelper
from tvb.datatypes.surfaces import CORTICAL
from tvb.adapters.visualizers.brain import BrainViewer, DualBrainViewer, ConnectivityIndex
from tvb.tests.framework.core.factory import TestFactory
from tvb.tests.framework.conftest import time_series_region_index_factory, operation_factory
class TestBrainViewer(TransactionalTestCase):
......@@ -57,7 +56,7 @@ class TestBrainViewer(TransactionalTestCase):
'extended_view', 'legendLabels', 'labelsStateVar', 'labelsModes', 'title']
face = os.path.join(os.path.dirname(tvb_data.surfaceData.__file__), 'cortex_16384.zip')
region_mapping_path = os.path.join(os.path.dirname(demo_data.__file__), 'regionMapping_16k_76.txt')
region_mapping_path = os.path.join(os.path.dirname(tvb_data.regionMapping.__file__), 'regionMapping_16k_76.txt')
def transactional_setup_method(self):
"""
......@@ -65,7 +64,6 @@ class TestBrainViewer(TransactionalTestCase):
creates a test user, a test project, a connectivity, a cortical surface and a face surface;
imports a CFF data-set
"""
self.test_user = TestFactory.create_user('Brain_Viewer_User')
self.test_project = TestFactory.create_project(self.test_user, 'Brain_Viewer_Project')
......@@ -79,11 +77,8 @@ class TestBrainViewer(TransactionalTestCase):
region_mapping = TestFactory.import_region_mapping(self.test_user, self.test_project,
self.region_mapping_path, self.face_surface.gid,
connectivity_idx.gid)
conn = h5.load_from_index(connectivity_idx)
rm = h5.load_from_index(region_mapping)
self.time_series_index = time_series_region_index_factory(
operation_factory(None, None)(test_user=self.test_user, test_project=self.test_project))(conn, rm)
self.connectivity = h5.load_from_index(connectivity_idx)
self.region_mapping = h5.load_from_index(region_mapping)
def transactional_teardown_method(self):
"""
......@@ -91,14 +86,16 @@ class TestBrainViewer(TransactionalTestCase):
"""
FilesHelper().remove_project_structure(self.test_project.name)
def test_launch(self):
def test_launch(self, time_series_region_index_factory):
"""
Check that all required keys are present in output from BrainViewer launch.
"""
time_series_index = time_series_region_index_factory(self.connectivity, self.region_mapping,
self.test_user, self.test_project)
viewer = BrainViewer()
viewer.current_project_id = self.test_project.id
view_model = viewer.get_view_model_class()()
view_model.time_series = UUID(self.time_series_index.gid)
view_model.time_series = UUID(time_series_index.gid)
view_model.shell_surface = UUID(self.face_surface.gid)
result = viewer.launch(view_model)
......@@ -106,36 +103,42 @@ class TestBrainViewer(TransactionalTestCase):
assert key in result and result[key] is not None
assert not result['extended_view']
def test_get_required_memory(self):
def test_get_required_memory(self, time_series_region_index_factory):
"""
Brainviewer should know required memory so expect positive number and not -1.
"""
time_series_index = time_series_region_index_factory(self.connectivity, self.region_mapping,
self.test_user, self.test_project)
viewer = BrainViewer()
viewer.current_project_id = self.test_project.id
view_model = viewer.get_view_model_class()()
view_model.time_series = UUID(self.time_series_index.gid)
view_model.time_series = UUID(time_series_index.gid)
assert viewer.get_required_memory_size(view_model) > 0
def test_generate_preview(self):
def test_generate_preview(self, time_series_region_index_factory):
"""
Check that all required keys are present in preview generate by BrainViewer.
"""
time_series_index = time_series_region_index_factory(self.connectivity, self.region_mapping,
self.test_user, self.test_project)
viewer = BrainViewer()
viewer.current_project_id = self.test_project.id
view_model = viewer.get_view_model_class()()
view_model.time_series = UUID(self.time_series_index.gid)
view_model.time_series = UUID(time_series_index.gid)
result = viewer.generate_preview(view_model, figure_size=(500, 200))
for key in TestBrainViewer.EXPECTED_KEYS:
assert key in result and result[key] is not None, key
def test_launch_eeg(self):
def test_launch_eeg(self, time_series_region_index_factory):
"""
Tests successful launch of a BrainEEG and that all required keys are present in returned template dictionary
"""
time_series_index = time_series_region_index_factory(self.connectivity, self.region_mapping,
self.test_user, self.test_project)
viewer = DualBrainViewer()
viewer.current_project_id = self.test_project.id
view_model = viewer.get_view_model_class()()
view_model.time_series = UUID(self.time_series_index.gid)
view_model.time_series = UUID(time_series_index.gid)
view_model.shell_surface = UUID(self.face_surface.gid)
result = viewer.launch(view_model)
for key in TestBrainViewer.EXPECTED_KEYS + TestBrainViewer.EXPECTED_EXTRA_KEYS:
......
......@@ -274,6 +274,7 @@ def region_simulation_factory(connectivity_factory):
return build
@pytest.fixture()
def time_series_factory():
def build(data=None):
......@@ -283,12 +284,14 @@ def time_series_factory():
data = numpy.zeros((time.size, 1, 3, 1))
data[:, 0, 0, 0] = numpy.sin(2 * numpy.pi * time / 1000.0 * 40)
data[:, 0, 1, 0] = numpy.sin(2 * numpy.pi * time / 1000.0 * 200)
data[:, 0, 2, 0] = numpy.sin(2 * numpy.pi * time / 1000.0 * 100) + numpy.sin(2 * numpy.pi * time / 1000.0 * 300)
data[:, 0, 2, 0] = numpy.sin(2 * numpy.pi * time / 1000.0 * 100) + numpy.sin(
2 * numpy.pi * time / 1000.0 * 300)
return TimeSeries(time=time, data=data, sample_period=1.0 / 4000)
return build
@pytest.fixture()
def time_series_index_factory(time_series_factory, operation_factory):
def build(data=None, op=None):
......@@ -315,7 +318,7 @@ def time_series_index_factory(time_series_factory, operation_factory):
@pytest.fixture()
def time_series_region_index_factory(operation_factory):
def build(connectivity, region_mapping):
def build(connectivity, region_mapping, test_user=None, test_project=None):
time = numpy.linspace(0, 1000, 4000)
data = numpy.zeros((time.size, 1, 3, 1))
data[:, 0, 0, 0] = numpy.sin(2 * numpy.pi * time / 1000.0 * 40)
......@@ -323,9 +326,10 @@ def time_series_region_index_factory(operation_factory):
data[:, 0, 2, 0] = numpy.sin(2 * numpy.pi * time / 1000.0 * 100) + \
numpy.sin(2 * numpy.pi * time / 1000.0 * 300)
ts = TimeSeriesRegion(time=time, data=data, sample_period=1.0 / 4000, connectivity=connectivity, region_mapping=region_mapping)
ts = TimeSeriesRegion(time=time, data=data, sample_period=1.0 / 4000, connectivity=connectivity,
region_mapping=region_mapping)
op = operation_factory()
op = operation_factory(test_user=test_user, test_project=test_project)
ts_db = TimeSeriesRegionIndex()
ts_db.fk_from_operation = op.id
......@@ -339,6 +343,7 @@ def time_series_region_index_factory(operation_factory):
ts_db = dao.store_entity(ts_db)
return ts_db
return build
......@@ -346,6 +351,7 @@ def time_series_region_index_factory(operation_factory):
def dummy_datatype_factory():
def build():
return DummyDataType()
return build
......@@ -353,6 +359,7 @@ def dummy_datatype_factory():
def dummy_datatype2_index_factory():
def build(subject=None, state=None):
return DummyDataType2Index(subject=subject, state=state)
return build
......@@ -383,7 +390,6 @@ def dummy_datatype_index_factory(dummy_datatype_factory, operation_factory):
@pytest.fixture()
def datatype_measure_factory(operation_factory):
def build(analyzed_entity):
measure = DatatypeMeasureIndex()
measure.metrics = '{"v": 3}'
measure.source = analyzed_entity
......@@ -395,7 +401,8 @@ def datatype_measure_factory(operation_factory):
@pytest.fixture()
def datatype_group_factory(time_series_index_factory, datatype_measure_factory, project_factory, user_factory, operation_factory):
def datatype_group_factory(time_series_index_factory, datatype_measure_factory, project_factory, user_factory,
operation_factory):
def build(subject="Datatype Factory User", state="RAW_DATA", project=None):
range_1 = ["row1", [1, 2, 3]]
......@@ -496,4 +503,5 @@ def test_adapter_factory():
stored_adapter.id = inst_from_db.id
dao.store_entity(stored_adapter, inst_from_db is not None)
return build
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment