Example No. 1
def test_full_import(self):
    """
    Test that importing a CFF generates at least one DataType in DB.
    """
    all_dt = self.get_all_datatypes()
    self.assertEqual(0, len(all_dt))
    TestFactory.import_cff(cff_path=self.VALID_CFF,
                           test_user=self.test_user,
                           test_project=self.test_project)
    flow_service = FlowService()
    ### Check that exactly one Connectivity was persisted
    gid_list = flow_service.get_available_datatypes(
        self.test_project.id, 'tvb.datatypes.connectivity.Connectivity')
    self.assertEqual(len(gid_list), 1)
    ### Check that exactly one RegionMapping was persisted
    gid_list = flow_service.get_available_datatypes(
        self.test_project.id, 'tvb.datatypes.surfaces.RegionMapping')
    self.assertEqual(len(gid_list), 1)
    ### Check that exactly one LocalConnectivity was persisted
    gids = flow_service.get_available_datatypes(
        self.test_project.id, 'tvb.datatypes.surfaces.LocalConnectivity')
    self.assertEqual(len(gids), 1)
    local_connectivity = dao.get_datatype_by_gid(gids[0][2])
    metadata = local_connectivity.get_metadata()
    self.assertEqual(metadata['Cutoff'], '40.0')
    self.assertEqual(metadata['Equation'], 'null')
    self.assertFalse(metadata['Invalid'])
    self.assertFalse(metadata['Is_nan'])
    self.assertEqual(metadata['Type'], 'LocalConnectivity')
    ### Check that exactly 2 Surfaces were persisted
    gid_list = flow_service.get_available_datatypes(
        self.test_project.id, 'tvb.datatypes.surfaces_data.SurfaceData')
    self.assertEqual(len(gid_list), 2)
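
The assertions above index rows from get_available_datatypes() positionally: column 2 of each row is treated as the GID and fed back into dao.get_datatype_by_gid(). A minimal sketch that makes this assumed row layout explicit (first_gid is a hypothetical helper, not part of the TVB API):

def first_gid(flow_service, project_id, datatype_class):
    """Return the GID of the first available datatype of the given class.

    Assumes each row from get_available_datatypes() carries the GID at
    index 2, exactly as the test above reads gids[0][2].
    """
    rows = flow_service.get_available_datatypes(project_id, datatype_class)
    assert len(rows) > 0, "no datatype of %s found" % datatype_class
    return rows[0][2]
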
Example No. 2
    def test_reduce_dimension_component(self):
        """
        Tests the generation of the component which allows the user
        to select one dimension from a multi dimension array
        """
        flow_service = FlowService()
        inserted_data = flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(inserted_data), 0, "Expected to find no data")
        adapter_instance = NDimensionArrayAdapter()
        PARAMS = {}
        OperationService().initiate_prelaunch(self.operation, adapter_instance,
                                              {}, **PARAMS)
        inserted_data = flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(inserted_data), 1, "Problems when inserting data")

        algogroup = dao.find_group('tvb_test.adapters.ndimensionarrayadapter',
                                   'NDimensionArrayAdapter')
        _, interface = flow_service.prepare_adapter(self.test_project.id,
                                                    algogroup)
        self.template_specification['inputList'] = interface
        resulted_html = _template2string(self.template_specification)
        self.soup = BeautifulSoup(resulted_html)

        found_divs = self.soup.findAll(
            'p', attrs=dict(id="dimensionsDiv_input_data"))
        self.assertEqual(len(found_divs), 1, "Data generated incorrectly")

        gid = inserted_data[0][2]
        cherrypy.session = {'user': self.test_user}
        entity = dao.get_datatype_by_gid(gid)
        component_content = FlowController().gettemplatefordimensionselect(
            gid, "input_data")
        self.soup = BeautifulSoup(component_content)

        #check dimensions
        found_selects_0 = self.soup.findAll(
            'select', attrs=dict(id="dimId_input_data_dimensions_0"))
        found_selects_1 = self.soup.findAll(
            'select', attrs=dict(id="dimId_input_data_dimensions_1"))
        found_selects_2 = self.soup.findAll(
            'select', attrs=dict(id="dimId_input_data_dimensions_2"))
        self.assertEqual(len(found_selects_0), 1, "select not found")
        self.assertEqual(len(found_selects_1), 1, "select not found")
        self.assertEqual(len(found_selects_2), 1, "select not found")

        #check the aggregation functions selects
        agg_selects_0 = self.soup.findAll(
            'select', attrs=dict(id="funcId_input_data_dimensions_0"))
        agg_selects_1 = self.soup.findAll(
            'select', attrs=dict(id="funcId_input_data_dimensions_1"))
        agg_selects_2 = self.soup.findAll(
            'select', attrs=dict(id="funcId_input_data_dimensions_2"))
        self.assertEqual(len(agg_selects_0), 1, "incorrect first dim")
        self.assertEqual(len(agg_selects_1), 1, "incorrect second dim")
        self.assertEqual(len(agg_selects_2), 1, "incorrect third dim")

        data_shape = entity.shape
        self.assertEqual(len(data_shape), 3, "Shape of the array is incorrect")
        for i in range(data_shape[0]):
            options = self.soup.findAll('option',
                                        attrs=dict(value=gid + "_0_" + str(i)))
            self.assertEqual(len(options), 1, "Generated option is incorrect")
            self.assertEqual(options[0].text, "Time " + str(i),
                             "The label of the option is not correct")
            self.assertEqual(options[0].parent.attrMap["name"],
                             "input_data_dimensions_0")
        for i in range(data_shape[1]):
            options = self.soup.findAll('option',
                                        attrs=dict(value=gid + "_1_" + str(i)))
            self.assertEqual(len(options), 1, "Generated option is incorrect")
            self.assertEqual(options[0].text, "Channel " + str(i),
                             "Option's label incorrect")
            self.assertEqual(options[0].parent.attrMap["name"],
                             "input_data_dimensions_1", "incorrect parent")
        for i in range(data_shape[2]):
            options = self.soup.findAll('option',
                                        attrs=dict(value=gid + "_2_" + str(i)))
            self.assertEqual(len(options), 1, "Generated option is incorrect")
            self.assertEqual(options[0].text, "Line " + str(i),
                             "The label of the option is not correct")
            self.assertEqual(options[0].parent.attrMap["name"],
                             "input_data_dimensions_2")

        #check the expected hidden fields
        expected_shape = self.soup.findAll(
            'input', attrs=dict(id="input_data_expected_shape"))
        self.assertEqual(len(expected_shape), 1,
                         "The generated hidden input is not correct")
        self.assertEqual(expected_shape[0]["value"], "expected_shape_",
                         "The generated hidden input is not correct")
        input_hidden_op = self.soup.findAll(
            'input', attrs=dict(id="input_data_operations"))
        self.assertEqual(len(input_hidden_op), 1,
                         "The generated hidden input is not correct")
        self.assertEqual(input_hidden_op[0]["value"], "operations_",
                         "The generated hidden input is not correct")
        input_hidden_dim = self.soup.findAll(
            'input', attrs=dict(id="input_data_expected_dim"))
        self.assertEqual(len(input_hidden_dim), 1,
                         "The generated hidden input is not correct")
        self.assertEqual(input_hidden_dim[0]["value"], "requiredDim_1",
                         "The generated hidden input is not correct")
        input_hidden_shape = self.soup.findAll(
            'input', attrs=dict(id="input_data_array_shape"))
        self.assertEqual(len(input_hidden_shape), 1,
                         "The generated hidden input is not correct")
        self.assertEqual(input_hidden_shape[0]["value"], "[5, 1, 3]",
                         "The generated hidden input is not correct")

        #check only the first option from the aggregation functions selects
        options = self.soup.findAll('option', attrs=dict(value="func_none"))
        self.assertEqual(len(options), 3,
                         "The generated option is not correct")
Example No. 3
class ProjectStructureTest(TransactionalTestCase):
    """
    Test ProjectService methods.
    """
    def setUp(self):
        """
        Prepare before each test.
        """
        self.project_service = ProjectService()
        self.flow_service = FlowService()
        self.structure_helper = FilesHelper()

        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user,
                                                       "ProjectStructure")

        self.relevant_filter = StaticFiltersFactory.build_datatype_filters(
            single_filter=StaticFiltersFactory.RELEVANT_VIEW)
        self.full_filter = StaticFiltersFactory.build_datatype_filters(
            single_filter=StaticFiltersFactory.FULL_VIEW)

    def tearDown(self):
        self.delete_project_folders()

    def test_set_operation_visibility(self):
        """
        Check if the visibility for an operation is set correctly.
        """
        self.__init_algorithm()
        op1 = model.Operation(self.test_user.id, self.test_project.id,
                              self.algo_inst.id, "")
        op1 = dao.store_entity(op1)
        self.assertTrue(op1.visible, "The operation should be visible.")
        self.project_service.set_operation_and_group_visibility(op1.gid, False)
        updated_op = dao.get_operation_by_id(op1.id)
        self.assertFalse(updated_op.visible,
                         "The operation should not be visible.")

    def test_set_op_and_group_visibility(self):
        """
        When changing the visibility for an operation that belongs to an operation group, we
        should also change the visibility for the entire group of operations.
        """
        _, group_id = TestFactory.create_group(self.test_user,
                                               subject="test-subject-1")
        list_of_operations = dao.get_operations_in_group(group_id)
        for operation in list_of_operations:
            self.assertTrue(operation.visible,
                            "The operation should be visible.")
        self.project_service.set_operation_and_group_visibility(
            list_of_operations[0].gid, False)
        operations = dao.get_operations_in_group(group_id)
        for operation in operations:
            self.assertFalse(operation.visible,
                             "The operation should not be visible.")

    def test_set_op_group_visibility(self):
        """
        Tests if the visibility for an operation group is set correctly.
        """
        _, group_id = TestFactory.create_group(self.test_user,
                                               subject="test-subject-1")
        list_of_operations = dao.get_operations_in_group(group_id)
        for operation in list_of_operations:
            self.assertTrue(operation.visible,
                            "The operation should be visible.")
        op_group = dao.get_operationgroup_by_id(group_id)
        self.project_service.set_operation_and_group_visibility(
            op_group.gid, False, True)
        operations = dao.get_operations_in_group(group_id)
        for operation in operations:
            self.assertFalse(operation.visible,
                             "The operation should not be visible.")

    def test_is_upload_operation(self):
        self.__init_algorithm()
        upload_algo = self._create_algo_for_upload()
        op1 = model.Operation(self.test_user.id, self.test_project.id,
                              self.algo_inst.id, "")
        op2 = model.Operation(self.test_user.id, self.test_project.id,
                              upload_algo.id, "")
        operations = dao.store_entities([op1, op2])
        is_upload_operation = self.project_service.is_upload_operation(
            operations[0].gid)
        self.assertFalse(is_upload_operation,
                         "The operation should not be an upload operation.")
        is_upload_operation = self.project_service.is_upload_operation(
            operations[1].gid)
        self.assertTrue(is_upload_operation,
                        "The operation should be an upload operation.")

    def test_get_upload_operations(self):
        """
        Test get_all when filter is for Upload category.
        """
        self.__init_algorithm()
        upload_algo = self._create_algo_for_upload()

        project = model.Project("test_proj_2", self.test_user.id, "desc")
        project = dao.store_entity(project)

        op1 = model.Operation(self.test_user.id, self.test_project.id,
                              self.algo_inst.id, "")
        op2 = model.Operation(self.test_user.id,
                              project.id,
                              upload_algo.id,
                              "",
                              status=model.STATUS_FINISHED)
        op3 = model.Operation(self.test_user.id, self.test_project.id,
                              upload_algo.id, "")
        op4 = model.Operation(self.test_user.id,
                              self.test_project.id,
                              upload_algo.id,
                              "",
                              status=model.STATUS_FINISHED)
        op5 = model.Operation(self.test_user.id,
                              self.test_project.id,
                              upload_algo.id,
                              "",
                              status=model.STATUS_FINISHED)
        operations = dao.store_entities([op1, op2, op3, op4, op5])

        upload_operations = self.project_service.get_all_operations_for_uploaders(
            self.test_project.id)
        self.assertEqual(2, len(upload_operations),
                         "Wrong number of upload operations.")
        upload_ids = [operation.id for operation in upload_operations]
        for i in [3, 4]:
            self.assertTrue(operations[i].id in upload_ids,
                            "The operation should be an upload operation.")
        for i in [0, 1, 2]:
            self.assertFalse(
                operations[i].id in upload_ids,
                "The operation should not be an upload operation.")

    def test_is_datatype_group(self):
        """
        Tests if a datatype is a datatype group.
        """
        _, dt_group_id, first_dt, _ = self._create_datatype_group()
        dt_group = dao.get_generic_entity(model.DataTypeGroup, dt_group_id)[0]
        is_dt_group = self.project_service.is_datatype_group(dt_group.gid)
        self.assertTrue(is_dt_group,
                        "The datatype should be a datatype group.")
        is_dt_group = self.project_service.is_datatype_group(first_dt.gid)
        self.assertFalse(is_dt_group,
                         "The datatype should not be a datatype group.")

    def test_count_datatypes_in_group(self):
        """ Test that counting dataTypes is correct. Happy flow."""
        _, dt_group_id, first_dt, _ = self._create_datatype_group()
        count = dao.count_datatypes_in_group(dt_group_id)
        self.assertEqual(count, 2)
        count = dao.count_datatypes_in_group(first_dt.id)
        self.assertEqual(count, 0, "There should be no dataType.")

    def test_set_datatype_visibility(self):
        """
        Check if the visibility for a datatype is set correctly.
        """
        # _create_mapped_arrays returns a list of 3 mapped arrays
        mapped_arrays = self._create_mapped_arrays(self.test_project.id)
        for mapped_array in mapped_arrays:
            is_visible = dao.get_datatype_by_id(mapped_array[0]).visible
            self.assertTrue(is_visible, "The data type should be visible.")

        self.project_service.set_datatype_visibility(mapped_arrays[0][2],
                                                     False)
        for i in xrange(len(mapped_arrays)):
            is_visible = dao.get_datatype_by_id(mapped_arrays[i][0]).visible
            if not i:
                self.assertFalse(is_visible,
                                 "The data type should not be visible.")
            else:
                self.assertTrue(is_visible, "The data type should be visible.")

    def test_set_visibility_for_dt_in_group(self):
        """
        Check if the visibility for a datatype from a datatype group is set correctly.
        """
        _, dt_group_id, first_dt, second_dt = self._create_datatype_group()
        self.assertTrue(first_dt.visible, "The data type should be visible.")
        self.assertTrue(second_dt.visible, "The data type should be visible.")
        self.project_service.set_datatype_visibility(first_dt.gid, False)

        db_dt_group = self.project_service.get_datatype_by_id(dt_group_id)
        db_first_dt = self.project_service.get_datatype_by_id(first_dt.id)
        db_second_dt = self.project_service.get_datatype_by_id(second_dt.id)

        self.assertFalse(db_dt_group.visible,
                         "The data type group should not be visible.")
        self.assertFalse(db_first_dt.visible,
                         "The data type should not be visible.")
        self.assertFalse(db_second_dt.visible,
                         "The data type should not be visible.")

    def test_set_visibility_for_group(self):
        """
        Check if the visibility for a datatype group is set correctly.
        """
        _, dt_group_id, first_dt, second_dt = self._create_datatype_group()
        dt_group = dao.get_generic_entity(model.DataTypeGroup, dt_group_id)[0]

        self.assertTrue(dt_group.visible,
                        "The data type group should be visible.")
        self.assertTrue(first_dt.visible, "The data type should be visible.")
        self.assertTrue(second_dt.visible, "The data type should be visible.")
        self.project_service.set_datatype_visibility(dt_group.gid, False)

        updated_dt_group = self.project_service.get_datatype_by_id(dt_group_id)
        updated_first_dt = self.project_service.get_datatype_by_id(first_dt.id)
        updated_second_dt = self.project_service.get_datatype_by_id(
            second_dt.id)

        self.assertFalse(updated_dt_group.visible,
                         "The data type group should not be visible.")
        self.assertFalse(updated_first_dt.visible,
                         "The data type should not be visible.")
        self.assertFalse(updated_second_dt.visible,
                         "The data type should not be visible.")

    def test_getdatatypes_from_dtgroup(self):
        """
        Validate that we can retrieve all DTs from a DT_Group
        """
        _, dt_group_id, first_dt, second_dt = self._create_datatype_group()
        datatypes = self.project_service.get_datatypes_from_datatype_group(
            dt_group_id)
        self.assertEqual(
            len(datatypes), 2,
            "There should be 2 datatypes into the datatype group.")
        expected_dict = {first_dt.id: first_dt, second_dt.id: second_dt}
        actual_dict = {
            datatypes[0].id: datatypes[0],
            datatypes[1].id: datatypes[1]
        }

        for key in expected_dict.keys():
            expected = expected_dict[key]
            actual = actual_dict[key]
            self.assertEqual(expected.id, actual.id, "Not the same id.")
            self.assertEqual(expected.gid, actual.gid, "Not the same gid.")
            self.assertEqual(expected.type, actual.type, "Not the same type.")
            self.assertEqual(expected.subject, actual.subject,
                             "Not the same subject.")
            self.assertEqual(expected.state, actual.state,
                             "Not the same state.")
            self.assertEqual(expected.visible, actual.visible,
                             "The datatype visibility is not correct.")
            self.assertEqual(expected.module, actual.module,
                             "Not the same module.")
            self.assertEqual(expected.user_tag_1, actual.user_tag_1,
                             "Not the same user_tag_1.")
            self.assertEqual(expected.invalid, actual.invalid,
                             "The invalid field value is not correct.")
            self.assertEqual(expected.is_nan, actual.is_nan,
                             "The is_nan field value is not correct.")

    def test_get_operations_for_dt(self):

        created_ops, datatype_gid = self._create_operations_with_inputs()
        operations = self.project_service.get_operations_for_datatype(
            datatype_gid, self.relevant_filter)
        self.assertEqual(len(operations), 2)
        self.assertTrue(
            created_ops[0].id in [operations[0].id, operations[1].id],
            "Retrieved wrong operations.")
        self.assertTrue(
            created_ops[2].id in [operations[0].id, operations[1].id],
            "Retrieved wrong operations.")

        operations = self.project_service.get_operations_for_datatype(
            datatype_gid, self.full_filter)
        self.assertEqual(len(operations), 4)
        ids = [
            operations[0].id, operations[1].id, operations[2].id,
            operations[3].id
        ]
        for i in range(4):
            self.assertTrue(created_ops[i].id in ids,
                            "Retrieved wrong operations.")

        operations = self.project_service.get_operations_for_datatype(
            datatype_gid, self.relevant_filter, True)
        self.assertEqual(len(operations), 1)
        self.assertEqual(created_ops[4].id, operations[0].id,
                         "Retrieved wrong operation.")

        operations = self.project_service.get_operations_for_datatype(
            datatype_gid, self.full_filter, True)
        self.assertEqual(len(operations), 2)
        self.assertTrue(
            created_ops[4].id in [operations[0].id, operations[1].id],
            "Retrieved wrong operations.")
        self.assertTrue(
            created_ops[5].id in [operations[0].id, operations[1].id],
            "Retrieved wrong operations.")

    def test_get_operations_for_dt_group(self):

        created_ops, dt_group_id = self._create_operations_with_inputs(True)

        ops = self.project_service.get_operations_for_datatype_group(
            dt_group_id, self.relevant_filter)
        self.assertEqual(len(ops), 2)
        self.assertTrue(created_ops[0].id in [ops[0].id, ops[1].id],
                        "Retrieved wrong operations.")
        self.assertTrue(created_ops[2].id in [ops[0].id, ops[1].id],
                        "Retrieved wrong operations.")

        ops = self.project_service.get_operations_for_datatype_group(
            dt_group_id, self.full_filter)
        self.assertEqual(len(ops), 4, "Incorrect number of operations.")
        ids = [ops[0].id, ops[1].id, ops[2].id, ops[3].id]
        for i in range(4):
            self.assertTrue(created_ops[i].id in ids,
                            "Retrieved wrong operations.")

        ops = self.project_service.get_operations_for_datatype_group(
            dt_group_id, self.relevant_filter, True)
        self.assertEqual(len(ops), 1)
        self.assertEqual(created_ops[4].id, ops[0].id,
                         "Retrieved wrong operation.")

        ops = self.project_service.get_operations_for_datatype_group(
            dt_group_id, self.full_filter, True)
        self.assertEqual(len(ops), 2)
        self.assertTrue(created_ops[4].id in [ops[0].id, ops[1].id],
                        "Retrieved wrong operations.")
        self.assertTrue(created_ops[5].id in [ops[0].id, ops[1].id],
                        "Retrieved wrong operations.")

    def test_get_inputs_for_operation(self):

        algo_group = dao.find_group('tvb_test.adapters.testadapter3',
                                    'TestAdapter3')
        algo = dao.get_algorithm_by_group(algo_group.id)

        array_wrappers = self._create_mapped_arrays(self.test_project.id)
        ids = []
        for datatype in array_wrappers:
            ids.append(datatype[0])

        datatype = dao.get_datatype_by_id(ids[0])
        datatype.visible = False
        dao.store_entity(datatype)

        parameters = json.dumps({
            "param_5": "1",
            "param_1": array_wrappers[0][2],
            "param_2": array_wrappers[1][2],
            "param_3": array_wrappers[2][2],
            "param_6": "0"
        })
        operation = model.Operation(self.test_user.id, self.test_project.id,
                                    algo.id, parameters)
        operation = dao.store_entity(operation)

        inputs = self.project_service.get_datatype_and_datatypegroup_inputs_for_operation(
            operation.gid, self.relevant_filter)
        self.assertEqual(len(inputs), 2)
        self.assertTrue(ids[1] in [inputs[0].id, inputs[1].id],
                        "Retrieved wrong dataType.")
        self.assertTrue(ids[2] in [inputs[0].id, inputs[1].id],
                        "Retrieved wrong dataType.")
        self.assertFalse(ids[0] in [inputs[0].id, inputs[1].id],
                         "Retrieved wrong dataType.")

        inputs = self.project_service.get_datatype_and_datatypegroup_inputs_for_operation(
            operation.gid, self.full_filter)
        self.assertEqual(len(inputs), 3, "Incorrect number of dataTypes.")
        self.assertTrue(ids[0] in [inputs[0].id, inputs[1].id, inputs[2].id],
                        "Retrieved wrong dataType.")
        self.assertTrue(ids[1] in [inputs[0].id, inputs[1].id, inputs[2].id],
                        "Retrieved wrong dataType.")
        self.assertTrue(ids[2] in [inputs[0].id, inputs[1].id, inputs[2].id],
                        "Retrieved wrong dataType.")

        project, dt_group_id, first_dt, _ = self._create_datatype_group()
        first_dt.visible = False
        dao.store_entity(first_dt)
        parameters = json.dumps({"other_param": "_", "param_1": first_dt.gid})
        operation = model.Operation(self.test_user.id, project.id, algo.id,
                                    parameters)
        operation = dao.store_entity(operation)

        inputs = self.project_service.get_datatype_and_datatypegroup_inputs_for_operation(
            operation.gid, self.relevant_filter)
        self.assertEqual(len(inputs), 0, "Incorrect number of dataTypes.")
        inputs = self.project_service.get_datatype_and_datatypegroup_inputs_for_operation(
            operation.gid, self.full_filter)
        self.assertEqual(len(inputs), 1, "Incorrect number of dataTypes.")
        self.assertEqual(inputs[0].id, dt_group_id, "Wrong dataType.")
        self.assertTrue(inputs[0].id != first_dt.id, "Wrong dataType.")

    def test_get_inputs_for_op_group(self):
        """
        Tests method get_datatypes_inputs_for_operation_group.
        The DataType inputs will be from a DataType group.
        """
        project, dt_group_id, first_dt, second_dt = self._create_datatype_group()
        first_dt.visible = False
        dao.store_entity(first_dt)
        second_dt.visible = False
        dao.store_entity(second_dt)

        op_group = model.OperationGroup(project.id, "group", "range1[1..2]")
        op_group = dao.store_entity(op_group)
        params_1 = json.dumps({
            "param_5": "1",
            "param_1": first_dt.gid,
            "param_6": "2"
        })
        params_2 = json.dumps({
            "param_5": "1",
            "param_4": second_dt.gid,
            "param_6": "5"
        })

        algo_group = dao.find_group('tvb_test.adapters.testadapter3',
                                    'TestAdapter3')
        algo = dao.get_algorithm_by_group(algo_group.id)

        op1 = model.Operation(self.test_user.id,
                              project.id,
                              algo.id,
                              params_1,
                              op_group_id=op_group.id)
        op2 = model.Operation(self.test_user.id,
                              project.id,
                              algo.id,
                              params_2,
                              op_group_id=op_group.id)
        dao.store_entities([op1, op2])

        inputs = self.project_service.get_datatypes_inputs_for_operation_group(
            op_group.id, self.relevant_filter)
        self.assertEqual(len(inputs), 0)

        inputs = self.project_service.get_datatypes_inputs_for_operation_group(
            op_group.id, self.full_filter)
        self.assertEqual(len(inputs), 1, "Incorrect number of dataTypes.")
        self.assertFalse(first_dt.id == inputs[0].id,
                         "Retrieved wrong dataType.")
        self.assertFalse(second_dt.id == inputs[0].id,
                         "Retrieved wrong dataType.")
        self.assertTrue(dt_group_id == inputs[0].id,
                        "Retrieved wrong dataType.")

        first_dt.visible = True
        dao.store_entity(first_dt)

        inputs = self.project_service.get_datatypes_inputs_for_operation_group(
            op_group.id, self.relevant_filter)
        self.assertEqual(len(inputs), 1, "Incorrect number of dataTypes.")
        self.assertFalse(first_dt.id == inputs[0].id,
                         "Retrieved wrong dataType.")
        self.assertFalse(second_dt.id == inputs[0].id,
                         "Retrieved wrong dataType.")
        self.assertTrue(dt_group_id == inputs[0].id,
                        "Retrieved wrong dataType.")

        inputs = self.project_service.get_datatypes_inputs_for_operation_group(
            op_group.id, self.full_filter)
        self.assertEqual(len(inputs), 1, "Incorrect number of dataTypes.")
        self.assertFalse(first_dt.id == inputs[0].id,
                         "Retrieved wrong dataType.")
        self.assertFalse(second_dt.id == inputs[0].id,
                         "Retrieved wrong dataType.")
        self.assertTrue(dt_group_id == inputs[0].id,
                        "Retrieved wrong dataType.")

    def test_get_inputs_for_op_group_simple_inputs(self):
        """
        Tests method get_datatypes_inputs_for_operation_group.
        The dataType inputs will not be part of a dataType group.
        """
        # _create_mapped_arrays returns a list of 3 mapped arrays
        array_wrappers = self._create_mapped_arrays(self.test_project.id)
        array_wrapper_ids = []
        for datatype in array_wrappers:
            array_wrapper_ids.append(datatype[0])

        datatype = dao.get_datatype_by_id(array_wrapper_ids[0])
        datatype.visible = False
        dao.store_entity(datatype)

        op_group = model.OperationGroup(self.test_project.id, "group",
                                        "range1[1..2]")
        op_group = dao.store_entity(op_group)
        params_1 = json.dumps({
            "param_5": "2",
            "param_1": array_wrappers[0][2],
            "param_2": array_wrappers[1][2],
            "param_6": "7"
        })
        params_2 = json.dumps({
            "param_5": "5",
            "param_3": array_wrappers[2][2],
            "param_2": array_wrappers[1][2],
            "param_6": "6"
        })

        algo_group = dao.find_group('tvb_test.adapters.testadapter3',
                                    'TestAdapter3')
        algo = dao.get_algorithm_by_group(algo_group.id)

        op1 = model.Operation(self.test_user.id,
                              self.test_project.id,
                              algo.id,
                              params_1,
                              op_group_id=op_group.id)
        op2 = model.Operation(self.test_user.id,
                              self.test_project.id,
                              algo.id,
                              params_2,
                              op_group_id=op_group.id)
        dao.store_entities([op1, op2])

        inputs = self.project_service.get_datatypes_inputs_for_operation_group(
            op_group.id, self.relevant_filter)
        self.assertEqual(len(inputs), 2)
        self.assertFalse(array_wrapper_ids[0] in [inputs[0].id, inputs[1].id],
                         "Retrieved wrong dataType.")
        self.assertTrue(array_wrapper_ids[1] in [inputs[0].id, inputs[1].id],
                        "Retrieved wrong dataType.")
        self.assertTrue(array_wrapper_ids[2] in [inputs[0].id, inputs[1].id],
                        "Retrieved wrong dataType.")

        inputs = self.project_service.get_datatypes_inputs_for_operation_group(
            op_group.id, self.full_filter)
        self.assertEqual(len(inputs), 3, "Incorrect number of dataTypes.")
        self.assertTrue(
            array_wrapper_ids[0] in [inputs[0].id, inputs[1].id, inputs[2].id])
        self.assertTrue(
            array_wrapper_ids[1] in [inputs[0].id, inputs[1].id, inputs[2].id])
        self.assertTrue(
            array_wrapper_ids[2] in [inputs[0].id, inputs[1].id, inputs[2].id])

    def test_remove_datatype(self):
        """
        Tests the deletion of a datatype.
        """
        # _create_mapped_arrays returns a list of 3 mapped arrays
        array_wrappers = self._create_mapped_arrays(self.test_project.id)
        dt_list = []
        for array_wrapper in array_wrappers:
            dt_list.append(dao.get_datatype_by_id(array_wrapper[0]))

        self.project_service.remove_datatype(self.test_project.id,
                                             dt_list[0].gid)
        self._check_if_datatype_was_removed(dt_list[0])

    def test_remove_datatype_from_group(self):
        """
        Tests the deletion of a datatype from a datatype group.
        """
        project, dt_group_id, first_dt, second_dt = self._create_datatype_group()
        datatype_group = dao.get_generic_entity(model.DataTypeGroup,
                                                dt_group_id)[0]

        self.project_service.remove_datatype(project.id, first_dt.gid)
        self._check_if_datatype_was_removed(first_dt)
        self._check_if_datatype_was_removed(second_dt)
        self._check_if_datatype_was_removed(datatype_group)
        self._check_datatype_group_removed(dt_group_id,
                                           datatype_group.fk_operation_group)

    def test_remove_datatype_group(self):
        """
        Tests the deletion of a datatype group.
        """
        project, dt_group_id, first_dt, second_dt = self._create_datatype_group()
        datatype_group = dao.get_generic_entity(model.DataTypeGroup,
                                                dt_group_id)[0]

        self.project_service.remove_datatype(project.id, datatype_group.gid)
        self._check_if_datatype_was_removed(first_dt)
        self._check_if_datatype_was_removed(second_dt)
        self._check_if_datatype_was_removed(datatype_group)
        self._check_datatype_group_removed(dt_group_id,
                                           datatype_group.fk_operation_group)

    def _create_mapped_arrays(self, project_id):

        array_wrappers = self.flow_service.get_available_datatypes(
            project_id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(array_wrappers), 0)

        algo_group = dao.find_group('tvb_test.adapters.ndimensionarrayadapter',
                                    'NDimensionArrayAdapter')
        group, _ = self.flow_service.prepare_adapter(project_id, algo_group)

        adapter_instance = self.flow_service.build_adapter_instance(group)
        data = {'param_1': 'some value'}
        #create 3 data types
        self.flow_service.fire_operation(adapter_instance, self.test_user,
                                         project_id, **data)
        array_wrappers = self.flow_service.get_available_datatypes(
            project_id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(array_wrappers), 1)

        self.flow_service.fire_operation(adapter_instance, self.test_user,
                                         project_id, **data)
        array_wrappers = self.flow_service.get_available_datatypes(
            project_id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(array_wrappers), 2)

        self.flow_service.fire_operation(adapter_instance, self.test_user,
                                         project_id, **data)
        array_wrappers = self.flow_service.get_available_datatypes(
            project_id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(array_wrappers), 3)

        return array_wrappers

    def _create_operation(self, project_id, algorithm_id):
        algorithm = dao.get_algorithm_by_id(algorithm_id)
        meta = {
            DataTypeMetaData.KEY_SUBJECT: "John Doe",
            DataTypeMetaData.KEY_STATE: "RAW"
        }
        operation = model.Operation(self.test_user.id,
                                    project_id,
                                    algorithm.id,
                                    'test params',
                                    meta=json.dumps(meta),
                                    status="FINISHED",
                                    method_name=ABCAdapter.LAUNCH_METHOD)
        return dao.store_entity(operation)

    def _create_datatype_group(self):
        """
        Creates a project and one DataTypeGroup containing 2 DataTypes.
        """
        test_project = TestFactory.create_project(self.test_user, "NewProject")

        all_operations = dao.get_filtered_operations(test_project.id, None)
        self.assertEqual(len(all_operations), 0,
                         "There should be no operation.")

        datatypes, op_group_id = TestFactory.create_group(
            self.test_user, test_project)
        dt_group = dao.get_datatypegroup_by_op_group_id(op_group_id)

        return test_project, dt_group.id, datatypes[0], datatypes[1]

    def _create_operations_with_inputs(self, is_group_parent=False):
        """
        Method used for creating a complex tree of operations.

        If 'is_group_parent' is True, a new group will be created and one of its entries will be used as
        input for the returned operations.
        """
        group_dts, root_op_group_id = TestFactory.create_group(
            self.test_user, self.test_project)
        if is_group_parent:
            datatype_gid = group_dts[0].gid
        else:
            datatype_gid = ProjectServiceTest._create_value_wrapper(
                self.test_user, self.test_project)[1]

        parameters = json.dumps({"param_name": datatype_gid})

        ops = []
        for i in range(4):
            ops.append(
                TestFactory.create_operation(test_user=self.test_user,
                                             test_project=self.test_project))
            if i in [1, 3]:
                ops[i].visible = False
            ops[i].parameters = parameters
            ops[i] = dao.store_entity(ops[i])

        #groups
        _, ops_group = TestFactory.create_group(self.test_user,
                                                self.test_project)
        ops_group = dao.get_operations_in_group(ops_group)
        self.assertEqual(2, len(ops_group))
        ops_group[0].parameters = parameters
        ops_group[0] = dao.store_entity(ops_group[0])
        ops_group[1].visible = False
        ops_group[1].parameters = parameters
        ops_group[1] = dao.store_entity(ops_group[1])

        ops.extend(ops_group)
        if is_group_parent:
            dt_group = dao.get_datatypegroup_by_op_group_id(root_op_group_id)
            return ops, dt_group.id
        return ops, datatype_gid

    def _check_if_datatype_was_removed(self, datatype):
        """
        Check if a certain datatype was removed.
        """
        try:
            dao.get_datatype_by_id(datatype.id)
            self.fail("The datatype was not deleted.")
        except Exception:
            pass
        try:
            dao.get_operation_by_id(datatype.fk_from_operation)
            self.fail("The operation was not deleted.")
        except Exception:
            pass

    def _check_datatype_group_removed(self, datatype_group_id,
                                      operation_group_id):
        """
        Checks if the DataTypeGroup and OperationGroup were removed.
        """
        try:
            dao.get_generic_entity(model.DataTypeGroup, datatype_group_id)
            self.fail("The DataTypeGroup entity was not removed.")
        except Exception:
            pass

        try:
            dao.get_operationgroup_by_id(operation_group_id)
            self.fail("The OperationGroup entity was not removed.")
        except Exception:
            pass

    def __init_algorithm(self):
        """
        Insert some starting data in the database.
        """
        categ1 = model.AlgorithmCategory('one', True)
        self.categ1 = dao.store_entity(categ1)
        algo = model.AlgorithmGroup("tvb_test.core.services.flowservice_test",
                                    "ValidTestAdapter", categ1.id)
        adapter = dao.store_entity(algo)
        algo = model.Algorithm(adapter.id,
                               'ident',
                               name='',
                               req_data='',
                               param_name='',
                               output='')
        self.algo_inst = dao.store_entity(algo)

    @staticmethod
    def _create_algo_for_upload():
        """ Creates a fake algorithm for an upload category. """
        category = dao.store_entity(
            model.AlgorithmCategory("upload_category", rawinput=True))
        algo_group = dao.store_entity(
            model.AlgorithmGroup("module", "classname", category.id))
        return dao.store_entity(model.Algorithm(algo_group.id, "algo"))
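
The removal checks in _check_if_datatype_was_removed and _check_datatype_group_removed above catch every Exception, which can hide unrelated failures behind a passing test. A tighter sketch, under the assumption that the dao layer raises SQLAlchemy's NoResultFound for missing rows (that exception type is an assumption, not confirmed TVB behavior):

from sqlalchemy.orm.exc import NoResultFound  # assumed dao exception type

def check_datatype_removed(test_case, datatype):
    # Fail fast, and only for the expected "row not found" condition.
    test_case.assertRaises(NoResultFound,
                           dao.get_datatype_by_id, datatype.id)
    test_case.assertRaises(NoResultFound,
                           dao.get_operation_by_id,
                           datatype.fk_from_operation)
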
Example No. 4
class ContextModelParametersTest(TransactionalTestCase):
    """
    Test class for the context_model_parameters module.
    """
    START = 100.55
    INCREMENT = 122.32

    def setUp(self):
        """
        Reset the database before each test;
        creates a test user, a test project, a connectivity;
        sets context model parameters and a Generic2dOscillator as a default model
        """
        self.flow_service = FlowService()

        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user)
        TestFactory.import_cff(test_user=self.test_user,
                               test_project=self.test_project)
        self.default_model = models_module.Generic2dOscillator()

        all_connectivities = self.flow_service.get_available_datatypes(
            self.test_project.id, Connectivity)
        self.connectivity = ABCAdapter.load_entity_by_gid(
            all_connectivities[0][2])
        self.connectivity.number_of_regions = 74
        self.context_model_param = ContextModelParameters(
            self.connectivity, self.default_model)

    def tearDown(self):
        """
        Reset the database when test is done.
        """
        self.delete_project_folders()

    def test_load_model_for_connectivity_node(self):
        """
        Tests default parameters are loaded in BURST region model interface
        """
        self.context_model_param.load_model_for_connectivity_node(0)
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_default_values(model_0)
        self._check_model_params_for_default_values(
            self.context_model_param._phase_plane.model)

        self.context_model_param.load_model_for_connectivity_node(1)
        model_1 = self.context_model_param._get_model_for_region(1)
        self._check_model_params_for_default_values(model_1)
        self._check_model_params_for_default_values(
            self.context_model_param._phase_plane.model)

        self._update_all_model_params(1)
        model_1 = self.context_model_param._get_model_for_region(1)
        self._check_model_params_for_updated_values(model_1)
        self._check_model_params_for_updated_values(
            self.context_model_param._phase_plane.model)

        self.context_model_param.load_model_for_connectivity_node(0)
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_default_values(model_0)
        self._check_model_params_for_default_values(
            self.context_model_param._phase_plane.model)

    def test_update_model_parameter(self):
        """
        Tests parameters update correctly in BURST region model interface
        """
        self.context_model_param.load_model_for_connectivity_node(0)
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_default_values(model_0)
        self._check_model_params_for_default_values(
            self.context_model_param._phase_plane.model)

        self._update_all_model_params(0)
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_updated_values(model_0)
        self._check_model_params_for_updated_values(
            self.context_model_param._phase_plane.model)

    def test_reset_model_parameters_for_node(self):
        """
        Tests parameters are reset correctly in BURST region model interface
        """
        self.context_model_param.load_model_for_connectivity_node(0)
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_default_values(model_0)
        self._check_model_params_for_default_values(
            self.context_model_param._phase_plane.model)

        self._update_all_model_params(0)
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_updated_values(model_0)
        self._check_model_params_for_updated_values(
            self.context_model_param._phase_plane.model)

        self.context_model_param.reset_model_parameters_for_nodes([0])
        model_0 = self.context_model_param._get_model_for_region(0)
        self._check_model_params_for_default_values(model_0)
        # Because we reset a list of nodes, the phase plane is not updated.
        self._check_model_params_for_updated_values(
            self.context_model_param._phase_plane.model)

    def test_get_values_for_parameter(self):
        """
        Tests method `ContextModelParameters.get_values_for_parameter(...)` works as expected
        """
        model_params = self.context_model_param.model_parameter_names
        for param in model_params:
            self.assertEqual(
                str(getattr(self.default_model, param).tolist()),
                self.context_model_param.get_values_for_parameter(param))

        self._update_all_model_params(0)
        for param in model_params:
            value = self.START + self.INCREMENT
            expected_list = [
                float(getattr(self.default_model, param)[0])
                for i in range(self.connectivity.number_of_regions)
            ]
            expected_list[0] = value
            self.assertEqual(
                str(expected_list),
                self.context_model_param.get_values_for_parameter(param))

    #############  Methods below are helper methods for testing #############

    def _update_all_model_params(self, connectivity_node_index):
        self.context_model_param.load_model_for_connectivity_node(
            connectivity_node_index)
        model_params = self.context_model_param.model_parameter_names
        for param in model_params:
            value = self.START + self.INCREMENT
            self.context_model_param.update_model_parameter(
                connectivity_node_index, param, value)

    def _check_model_params_for_default_values(self, model_to_check):
        model_params = self.context_model_param.model_parameter_names
        for param in model_params:
            self.assertEqual(getattr(self.default_model, param),
                             getattr(model_to_check, param),
                             "The parameters should be equal.")

    def _check_model_params_for_updated_values(self, model_to_check):
        model_params = self.context_model_param.model_parameter_names
        for param in model_params:
            value = self.START + self.INCREMENT
            self.assertEqual(numpy.array([value]),
                             getattr(model_to_check, param),
                             "The parameters should be equal.")
Example No. 5
class FlowServiceTest(TransactionalTestCase):
    """
    This class contains tests for the tvb.core.services.flowservice module.
    """


    def setUp(self):
        """
        Prepare the required data before each test.
        """
        self.flow_service = FlowService()
        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(admin=self.test_user)
        ### Insert some starting data in the database.
        categ1 = model.AlgorithmCategory('one', True)
        self.categ1 = dao.store_entity(categ1)
        categ2 = model.AlgorithmCategory('two', rawinput=True)
        self.categ2 = dao.store_entity(categ2)

        algo = model.AlgorithmGroup("test_module1", "classname1", categ1.id)
        self.algo1 = dao.store_entity(algo)
        algo = model.AlgorithmGroup("test_module2", "classname2", categ2.id)
        dao.store_entity(algo)
        algo = model.AlgorithmGroup("tvb_test.core.services.flowservice_test", "ValidTestAdapter", categ2.id)
        adapter = dao.store_entity(algo)

        algo = model.Algorithm(adapter.id, 'ident', name='', req_data='', param_name='', output='')
        self.algo_inst = dao.store_entity(algo)
        algo = model.AlgorithmGroup("test_module3", "classname3", categ1.id)
        dao.store_entity(algo)
        algo = model.Algorithm(self.algo1.id, 'id', name='', req_data='', param_name='', output='')
        self.algo_inst = dao.store_entity(algo)


    def test_read_algorithm_categories(self):
        """
        Read algorithm categories when they exist in the database.
        """
        categories = self.flow_service.read_algorithm_categories()
        self.assertEqual(len(categories), 8)
        self.assertTrue(self.categ1 in categories, "Missing category")
        self.assertTrue(self.categ2 in categories, "Missing category")


    def test_groups_for_categories(self):
        """
        Test getting algorithms for specific categories.
        """
        category1 = self.flow_service.get_groups_for_categories([self.categ1])
        category2 = self.flow_service.get_groups_for_categories([self.categ2])
        dummy = model.AlgorithmCategory('dummy', rawinput=True)
        dummy.id = 999
        nonexistent_cat = self.flow_service.get_groups_for_categories([dummy])
        self.assertEqual(len(category1), 2)
        for algorithm in category1:
            if algorithm.module not in ["test_module1", "test_module3"]:
                self.fail("Some invalid data retrieved")
        for algorithm in category2:
            if algorithm.module not in ["test_module2", "tvb_test.core.services.flowservice_test"]:
                self.fail("Some invalid data retrieved")
        self.assertEqual(len(category2), 2)
        self.assertEqual(len(nonexistent_cat), 0)


    def test_get_group_by_identifier(self):
        """
        Test for the get_algo_group_by_identifier method.
        """
        algo_ret = self.flow_service.get_algo_group_by_identifier(self.algo1.id)
        self.assertEqual(algo_ret.id, self.algo1.id, "IDs are different!")
        self.assertEqual(algo_ret.module, self.algo1.module, "Modules are different!")
        self.assertEqual(algo_ret.fk_category, self.algo1.fk_category, "Categories are different!")
        self.assertEqual(algo_ret.classname, self.algo1.classname, "Class names are different!")


    def test_build_adapter_instance(self):
        """
        Test standard flow for building an adapter instance.
        """
        module = "tvb_test.core.services.flowservice_test"
        class_name = "ValidTestAdapter"
        algo_group = dao.find_group(module, class_name)
        adapter = ABCAdapter.build_adapter(algo_group)
        self.assertTrue(isinstance(adapter, ABCSynchronous), "Something went wrong with valid data!")


    def test_build_adapter_invalid(self):
        """
        Test flow for trying to build an adapter that does not inherit from ABCAdapter.
        """
        module = "tvb_test.core.services.flowservice_test"
        class_name = "InvalidTestAdapter"
        group = dao.find_group(module, class_name)
        self.assertRaises(OperationException, self.flow_service.build_adapter_instance, group)


    def test_prepare_adapter(self):
        """
        Test preparation of an adapter.
        """
        module = "tvb_test.core.services.flowservice_test"
        class_name = "ValidTestAdapter"
        algo_group = dao.find_group(module, class_name)
        group, interface = self.flow_service.prepare_adapter(self.test_project.id, algo_group)
        self.assertTrue(isinstance(group, model.AlgorithmGroup), "Something went wrong with valid data!")
        self.assertTrue("name" in interface[0], "Bad interface created!")
        self.assertEquals(interface[0]["name"], "test", "Bad interface!")
        self.assertTrue("type" in interface[0], "Bad interface created!")
        self.assertEquals(interface[0]["type"], "int", "Bad interface!")
        self.assertTrue("default" in interface[0], "Bad interface created!")
        self.assertEquals(interface[0]["default"], "0", "Bad interface!")


    def test_fire_operation(self):
        """
        Test preparation of an adapter and launch mechanism.
        """
        module = "tvb_test.core.services.flowservice_test"
        class_name = "ValidTestAdapter"
        algo_group = dao.find_group(module, class_name)
        adapter = self.flow_service.build_adapter_instance(algo_group)
        data = {"test": 5}
        result = self.flow_service.fire_operation(adapter, self.test_user, self.test_project.id,
                                                  ABCAdapter.LAUNCH_METHOD, **data)
        self.assertTrue(result.endswith("has finished."), "Operation fail")


    def test_get_filtered_by_column(self):
        """
        Test the filter function when retrieving dataTypes with a filter
        after a column from a class specific table (e.g. DATA_arraywrapper).
        """
        operation_1 = TestFactory.create_operation(test_user=self.test_user, test_project=self.test_project)
        operation_2 = TestFactory.create_operation(test_user=self.test_user, test_project=self.test_project)

        one_dim_array = numpy.arange(5)
        two_dim_array = numpy.array([[1, 2], [2, 3], [1, 4]])
        self._store_float_array(one_dim_array, "John Doe 1", operation_1.id)
        self._store_float_array(one_dim_array, "John Doe 2", operation_1.id)
        self._store_float_array(two_dim_array, "John Doe 3", operation_2.id)

        inserted_data = self.flow_service.get_available_datatypes(self.test_project.id,
                                                                  "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(inserted_data), 3, "Problems with inserting data")
        first_filter = FilterChain(fields=[FilterChain.datatype + '._nr_dimensions'], operations=["=="], values=[1])
        filtered_data = self.flow_service.get_available_datatypes(self.test_project.id,
                                                                  "tvb.datatypes.arrays.MappedArray", first_filter)
        self.assertEqual(len(filtered_data), 2, "Data was not filtered")

        second_filter = FilterChain(fields=[FilterChain.datatype + '._nr_dimensions'], operations=["=="], values=[2])
        filtered_data = self.flow_service.get_available_datatypes(self.test_project.id,
                                                                  "tvb.datatypes.arrays.MappedArray", second_filter)
        self.assertEqual(len(filtered_data), 1, "Data was not filtered")
        self.assertEqual(filtered_data[0][3], "John Doe 3")

        third_filter = FilterChain(fields=[FilterChain.datatype + '._length_1d'], operations=["=="], values=[3])
        filtered_data = self.flow_service.get_available_datatypes(self.test_project.id,
                                                                  "tvb.datatypes.arrays.MappedArray", third_filter)
        self.assertEqual(len(filtered_data), 1, "Data was not filtered correct")
        self.assertEqual(filtered_data[0][3], "John Doe 3")
        try:
            if os.path.exists('One_dim.txt'):
                os.remove('One_dim.txt')
            if os.path.exists('Two_dim.txt'):
                os.remove('Two_dim.txt')
            if os.path.exists('One_dim-1.txt'):
                os.remove('One_dim-1.txt')
        except Exception:
            pass


    @staticmethod
    def _store_float_array(array_data, subject_name, operation_id):
        """Create Float Array and DB persist it"""
        datatype_inst = MappedArray(user_tag_1=subject_name)
        datatype_inst.set_operation_id(operation_id)
        datatype_inst.array_data = array_data
        datatype_inst.type = "MappedArray"
        datatype_inst.module = "tvb.datatypes.arrays"
        datatype_inst.subject = subject_name
        datatype_inst.state = "RAW"
        dao.store_entity(datatype_inst)


    def test_get_filtered_datatypes(self):
        """
        Test the filter function when retrieving dataTypes.
        """
        #Create some test operations
        start_dates = [datetime.now(),
                       datetime.strptime("08-06-2010", "%m-%d-%Y"),
                       datetime.strptime("07-21-2010", "%m-%d-%Y"),
                       datetime.strptime("05-06-2010", "%m-%d-%Y"),
                       datetime.strptime("07-21-2011", "%m-%d-%Y")]
        end_dates = [datetime.now(),
                     datetime.strptime("08-12-2010", "%m-%d-%Y"),
                     datetime.strptime("08-12-2010", "%m-%d-%Y"),
                     datetime.strptime("08-12-2011", "%m-%d-%Y"),
                     datetime.strptime("08-12-2011", "%m-%d-%Y")]
        for i in range(5):
            operation = model.Operation(self.test_user.id, self.test_project.id, self.algo_inst.id, 'test params',
                                        status="FINISHED", start_date=start_dates[i], completion_date=end_dates[i])
            operation = dao.store_entity(operation)
            storage_path = FilesHelper().get_project_folder(self.test_project, str(operation.id))
            if i < 4:
                datatype_inst = Datatype1()
                datatype_inst.type = "Datatype1"
                datatype_inst.subject = "John Doe" + str(i)
                datatype_inst.state = "RAW"
                datatype_inst.set_operation_id(operation.id)
                dao.store_entity(datatype_inst)
            else:
                for _ in range(2):
                    datatype_inst = Datatype2()
                    datatype_inst.storage_path = storage_path
                    datatype_inst.type = "Datatype2"
                    datatype_inst.subject = "John Doe" + str(i)
                    datatype_inst.state = "RAW"
                    datatype_inst.string_data = ["data"]
                    datatype_inst.set_operation_id(operation.id)
                    dao.store_entity(datatype_inst)

        returned_data = self.flow_service.get_available_datatypes(self.test_project.id,
                                                                  "tvb_test.datatypes.datatype1.Datatype1")
        for row in returned_data:
            if row[1] != 'Datatype1':
                self.fail("Some invalid data was returned!")
        self.assertEqual(4, len(returned_data), "Invalid length of result")

        filter_op = FilterChain(fields=[FilterChain.datatype + ".state", FilterChain.operation + ".start_date"],
                                values=["RAW", datetime.strptime("08-01-2010", "%m-%d-%Y")], operations=["==", ">"])
        returned_data = self.flow_service.get_available_datatypes(self.test_project.id,
                                                                  "tvb_test.datatypes.datatype1.Datatype1", filter_op)
        returned_subjects = [one_data[3] for one_data in returned_data]

        if "John Doe0" not in returned_subjects or "John Doe1" not in returned_subjects or len(returned_subjects) != 2:
            self.fail("DataTypes were not filtered properly!")
Example No. 6
0
class ImportServiceTest(TransactionalTestCase):
    """
    This class contains tests for the project import/export mechanism (ImportService and ExportManager).
    """
    def setUp(self):
        """
        Reset the database before each test.
        """
        self.import_service = ImportService()
        self.flow_service = FlowService()
        self.project_service = ProjectService()

        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user,
                                                       name="GeneratedProject",
                                                       description="test_desc")
        self.operation = TestFactory.create_operation(
            test_user=self.test_user, test_project=self.test_project)
        self.adapter_instance = TestFactory.create_adapter(
            test_project=self.test_project)
        TestFactory.import_cff(test_user=self.test_user,
                               test_project=self.test_project)
        self.zip_path = None

    def tearDown(self):
        """
        Reset the database when test is done.
        """
        ### Delete TEMP folder
        if os.path.exists(cfg.TVB_TEMP_FOLDER):
            shutil.rmtree(cfg.TVB_TEMP_FOLDER)

        ### Delete folder where data was exported, if an export actually ran
        if self.zip_path is not None and os.path.exists(self.zip_path):
            shutil.rmtree(os.path.split(self.zip_path)[0])

        self.delete_project_folders()

    def test_import_export(self):
        """
        Test the import/export mechanism for a project structure.
        The project contains the following data types: Connectivity, Surface, MappedArray and ValueWrapper.
        """
        result = self.get_all_datatypes()
        expected_results = {}
        for one_data in result:
            expected_results[one_data.gid] = (one_data.module, one_data.type)

        #create an array mapped in DB
        data = {'param_1': 'some value'}
        OperationService().initiate_prelaunch(self.operation,
                                              self.adapter_instance, {},
                                              **data)
        inserted = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(inserted), 2, "Problems when inserting data")

        #create a value wrapper
        value_wrapper = self._create_value_wrapper()
        result = dao.get_filtered_operations(self.test_project.id, None)
        self.assertEqual(
            len(result), 2, "Expected two operations before export, got " +
            str(len(result)) + "!")
        self.zip_path = ExportManager().export_project(self.test_project)
        self.assertTrue(self.zip_path is not None, "Exported file is None")

        # Now remove the original project
        self.project_service.remove_project(self.test_project.id)
        result, lng_ = self.project_service.retrieve_projects_for_user(
            self.test_user.id)
        self.assertEqual(0, len(result), "Project Not removed!")
        self.assertEqual(0, lng_, "Project Not removed!")

        # Now try to import again project
        self.import_service.import_project_structure(self.zip_path,
                                                     self.test_user.id)
        result = self.project_service.retrieve_projects_for_user(
            self.test_user.id)[0]
        self.assertEqual(len(result), 1, "There should be only one project.")
        self.assertEqual(result[0].name, "GeneratedProject",
                         "The project name is not correct.")
        self.assertEqual(result[0].description, "test_desc",
                         "The project description is not correct.")
        self.test_project = result[0]

        result = dao.get_filtered_operations(self.test_project.id, None)

        #1 op. - import project; 1 op. - save the array wrapper
        self.assertEqual(
            len(result), 2, "Expected two operations after import, got " +
            str(len(result)) + "!")
        for gid in expected_results:
            datatype = dao.get_datatype_by_gid(gid)
            self.assertEqual(datatype.module, expected_results[gid][0],
                             'DataTypes not imported correctly')
            self.assertEqual(datatype.type, expected_results[gid][1],
                             'DataTypes not imported correctly')
        #check the value wrapper
        new_val = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.mapped_values.ValueWrapper")
        self.assertEqual(len(new_val), 1, "Expected one ValueWrapper, got " + str(len(new_val)))
        new_val = ABCAdapter.load_entity_by_gid(new_val[0][2])
        self.assertEqual(value_wrapper.data_value, new_val.data_value,
                         "Data value incorrect")
        self.assertEqual(value_wrapper.data_type, new_val.data_type,
                         "Data type incorrect")
        self.assertEqual(value_wrapper.data_name, new_val.data_name,
                         "Data name incorrect")

    def test_import_export_existing(self):
        """
        Test the import/export mechanism for a project structure.
        The project contains the following data types: Connectivity, Surface, MappedArray and ValueWrapper.
        """
        result = self.get_all_datatypes()
        expected_results = {}
        for one_data in result:
            expected_results[one_data.gid] = (one_data.module, one_data.type)

        #create an array mapped in DB
        data = {'param_1': 'some value'}
        OperationService().initiate_prelaunch(self.operation,
                                              self.adapter_instance, {},
                                              **data)
        inserted = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(inserted), 2, "Problems when inserting data")

        #create a value wrapper
        self._create_value_wrapper()
        result = dao.get_filtered_operations(self.test_project.id, None)
        self.assertEqual(
            len(result), 2, "Expected two operations before export, got " +
            str(len(result)) + "!")
        self.zip_path = ExportManager().export_project(self.test_project)
        self.assertTrue(self.zip_path is not None, "Exported file is None")

        try:
            self.import_service.import_project_structure(
                self.zip_path, self.test_user.id)
            self.fail("Invalid import as the project already exists!")
        except ProjectImportException:
            #OK, do nothing. The project already exists.
            pass
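
        # The same expectation reads more directly with unittest's
        # assertRaises used as a context manager (available since Python 2.7):
        #
        #     with self.assertRaises(ProjectImportException):
        #         self.import_service.import_project_structure(self.zip_path,
        #                                                      self.test_user.id)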

    def _create_timeseries(self):
        """Launch adapter to persist a TimeSeries entity"""
        activity_data = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9],
                                     [10, 11, 12]])
        time_data = numpy.array([1, 2, 3])
        storage_path = FilesHelper().get_project_folder(self.test_project)
        time_series = TimeSeries(time_files=None,
                                 activity_files=None,
                                 max_chunk=10,
                                 maxes=None,
                                 mins=None,
                                 data_shape=numpy.shape(activity_data),
                                 storage_path=storage_path,
                                 label_y="Time",
                                 time_data=time_data,
                                 data_name='TestSeries',
                                 activity_data=activity_data,
                                 sample_period=10.0)
        self._store_entity(time_series, "TimeSeries",
                           "tvb.datatypes.time_series")
        timeseries = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.time_series.TimeSeries")
        self.assertEqual(len(timeseries), 1, "Should be only one TimeSeries")

    def _create_value_wrapper(self):
        """Persist ValueWrapper"""
        value_ = ValueWrapper(data_value=5.0, data_name="my_value")
        self._store_entity(value_, "ValueWrapper",
                           "tvb.datatypes.mapped_values")
        valuew = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.mapped_values.ValueWrapper")
        self.assertEqual(len(valuew), 1, "Should be only one value wrapper")
        return ABCAdapter.load_entity_by_gid(valuew[0][2])

    def _store_entity(self, entity, type_, module):
        """Launch adapter to store a create a persistent DataType."""
        entity.type = type_
        entity.module = module
        entity.subject = "John Doe"
        entity.state = "RAW_STATE"
        entity.set_operation_id(self.operation.id)
        adapter_instance = StoreAdapter([entity])
        OperationService().initiate_prelaunch(self.operation, adapter_instance,
                                              {})
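
    # Usage sketch: the two helpers above combine into a short
    # persist-and-reload round trip inside a test (illustration only):
    #
    #     wrapper = self._create_value_wrapper()
    #     self.assertEqual(5.0, wrapper.data_value)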
Example No. 7
0
class RemoveTest(TransactionalTestCase):
    """
    This class contains tests for removing DataTypes through the ProjectService.
    """
    def setUp(self):
        """
        Prepare the database before each test.
        """
        self.import_service = ImportService()
        self.flow_service = FlowService()
        self.project_service = ProjectService()

        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user)
        self.operation = TestFactory.create_operation(
            test_user=self.test_user, test_project=self.test_project)
        self.adapter_instance = TestFactory.create_adapter(
            test_project=self.test_project)

        result = self.get_all_datatypes()
        self.assertEqual(len(result), 0, "There should be no data type in DB")
        TestFactory.import_cff(test_user=self.test_user,
                               test_project=self.test_project)

    def tearDown(self):
        """
        Reset the database when test is done.
        """
        self.delete_project_folders()

    def test_remove_used_connectivity(self):
        """
        Tests removing a connectivity that is used by other data types.
        """
        connectivities = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.connectivity.Connectivity")
        self.assertEqual(len(connectivities), 1,
                         "Problems when inserting data")
        gid = connectivities[0][2]
        try:
            self.project_service.remove_datatype(self.test_project.id, gid)
            self.fail(
                "The connectivity is still used. It should not be possible to remove it."
            )
        except RemoveDataTypeException:
            #OK, do nothing
            pass
        res = dao.get_datatype_by_gid(gid)
        self.assertEqual(connectivities[0][0], res.id,
                         "Used connectivity removed")

    def test_remove_used_surface(self):
        """
        Tries to remove a surface that is still in use.
        """
        mapping = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.surfaces.RegionMapping")
        self.assertEqual(len(mapping), 1, "There should be one Mapping.")
        mapping_gid = mapping[0][2]
        mapping = ABCAdapter.load_entity_by_gid(mapping_gid)
        #delete surface
        surfaces = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.surfaces.CorticalSurface")
        self.assertTrue(len(surfaces) > 0, "At least one Cortex expected")
        surface = dao.get_datatype_by_gid(mapping.surface.gid)
        self.assertEqual(surface.gid, mapping.surface.gid,
                         "The surfaces should have the same GID")
        try:
            self.project_service.remove_datatype(self.test_project.id,
                                                 surface.gid)
            self.fail(
                "The surface is still used by a RegionMapping. It should not be possible to remove it."
            )
        except RemoveDataTypeException:
            #OK, do nothing
            pass
        res = dao.get_datatype_by_gid(surface.gid)
        self.assertEqual(surface.id, res.id, "A used surface was deleted")

    def _remove_entity(self, data_name, before_number):
        """
        Remove every entity of the given type and check that each one is gone.
        """
        gid_list = self.flow_service.get_available_datatypes(
            self.test_project.id, data_name)
        self.assertEqual(len(gid_list), before_number)
        for row in gid_list:
            data_gid = row[2]
            self.project_service.remove_datatype(self.test_project.id,
                                                 data_gid)
            res = dao.get_datatype_by_gid(data_gid)
            self.assertEqual(None, res, "The entity was not deleted")

    def test_happyflow_removedatatypes(self):
        """
        Tests the happy flow for the deletion of multiple entities.
        They are tested together because they depend on each other and they
        have to be removed in a certain order.
        """
        self._remove_entity("tvb.datatypes.surfaces.LocalConnectivity", 1)
        self._remove_entity("tvb.datatypes.surfaces.RegionMapping", 1)
        ### Remove Surfaces
        # SQLAlchemy has no uniform way to retrieve Surface subclasses through the base type (no wildcard for polymorphic_identity)
        self._remove_entity("tvb.datatypes.surfaces_data.SurfaceData", 2)
        ### Remove a Connectivity
        self._remove_entity("tvb.datatypes.connectivity.Connectivity", 1)

    def test_remove_time_series(self):
        """
        Tests the happy flow for the deletion of a time series.
        """
        datatypes = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.time_series.TimeSeries")
        self.assertEqual(len(datatypes), 0, "There should be no time series")
        self._create_timeseries()
        series = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.time_series.TimeSeries")
        self.assertEqual(len(series), 1,
                         "There should be only one time series")
        self.project_service.remove_datatype(self.test_project.id,
                                             series[0][2])
        res = dao.get_datatype_by_gid(series[0][2])
        self.assertEqual(None, res, "The time series was not deleted.")

    def test_remove_array_wrapper(self):
        """
        Tests the happy flow for the deletion of an array wrapper.
        """
        array_wrappers = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(array_wrappers), 1, "There should be one array wrapper")
        data = {'param_1': 'some value'}
        OperationService().initiate_prelaunch(self.operation,
                                              self.adapter_instance, {},
                                              **data)
        array_wrappers = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.arrays.MappedArray")
        self.assertEqual(len(array_wrappers), 2, "There should be two array wrappers")
        arraygid = array_wrappers[0][2]
        self.project_service.remove_datatype(self.test_project.id, arraygid)
        res = dao.get_datatype_by_gid(arraygid)
        self.assertEqual(None, res, "The array wrapper was not deleted.")

    def test_remove_value_wrapper(self):
        """
        Test the deletion of a value wrapper dataType
        """
        wrappers = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.mapped_values.ValueWrapper")
        self.assertEqual(len(wrappers), 0, "There should be no value wrapper")
        value_wrapper = self._create_value_wrapper()
        self.project_service.remove_datatype(self.test_project.id,
                                             value_wrapper.gid)
        res = dao.get_datatype_by_gid(value_wrapper.gid)
        self.assertEqual(None, res, "The value wrapper was not deleted.")

    def _create_timeseries(self):
        """Launch adapter to persist a TimeSeries entity"""
        storage_path = FilesHelper().get_project_folder(
            self.test_project, str(self.operation.id))

        time_series = TimeSeries()
        time_series.sample_period = 10.0
        time_series.start_time = 0.0
        time_series.storage_path = storage_path
        time_series.write_data_slice(numpy.array([1.0, 2.0, 3.0]))
        time_series.close_file()
        time_series.sample_period_unit = 'ms'

        self._store_entity(time_series, "TimeSeries",
                           "tvb.datatypes.time_series")
        timeseries = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.time_series.TimeSeries")
        self.assertEqual(len(timeseries), 1, "Should be only one TimeSeries")

    def _create_value_wrapper(self):
        """Persist ValueWrapper"""
        value_ = ValueWrapper(data_value=5.0, data_name="my_value")
        self._store_entity(value_, "ValueWrapper",
                           "tvb.datatypes.mapped_values")
        valuew = self.flow_service.get_available_datatypes(
            self.test_project.id, "tvb.datatypes.mapped_values.ValueWrapper")
        self.assertEqual(len(valuew), 1, "Should be only one value wrapper")
        return ABCAdapter.load_entity_by_gid(valuew[0][2])

    def _store_entity(self, entity, type_, module):
        """Launch adapter to store a create a persistent DataType."""
        entity.type = type_
        entity.module = module
        entity.subject = "John Doe"
        entity.state = "RAW_STATE"
        entity.set_operation_id(self.operation.id)
        adapter_instance = StoreAdapter([entity])
        OperationService().initiate_prelaunch(self.operation, adapter_instance,
                                              {})
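
    # StoreAdapter takes a list of already-built entities, so persistence runs
    # through the regular operation pipeline; a hypothetical multi-entity
    # variant of the helper above (illustration only):
    #
    #     adapter_instance = StoreAdapter([entity_a, entity_b])
    #     OperationService().initiate_prelaunch(self.operation,
    #                                           adapter_instance, {})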