class TestSimulationResource(RestResourceTest):
    """
    REST API tests for FireSimulationResource: launching a simulation from an
    uploaded zip archive.
    """

    def transactional_setup_method(self):
        # Fresh user/project pair plus the resource under test for every test method.
        self.test_user = TestFactory.create_user('Rest_User')
        self.test_project = TestFactory.create_project(
            self.test_user, 'Rest_Project', users=[self.test_user.id])
        self.simulation_resource = FireSimulationResource()
        self.storage_interface = StorageInterface()

    def test_server_fire_simulation_inexistent_gid(self, mocker):
        """Posting against a project gid that does not exist must be rejected."""
        self._mock_user(mocker)
        project_gid = "inexistent-gid"
        dummy_file = FileStorage(BytesIO(b"test"), 'test.zip')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {RequestFileKey.SIMULATION_FILE_KEY.value: dummy_file}

        with pytest.raises(InvalidIdentifierException):
            self.simulation_resource.post(project_gid=project_gid)

    def test_server_fire_simulation_no_file(self, mocker):
        """A request carrying no simulation file part must be rejected."""
        self._mock_user(mocker)
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {}

        with pytest.raises(InvalidIdentifierException):
            self.simulation_resource.post(project_gid='')

    def test_server_fire_simulation_bad_extension(self, mocker):
        """Only zip archives are accepted as simulation input."""
        self._mock_user(mocker)
        dummy_file = FileStorage(BytesIO(b"test"), 'test.txt')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {RequestFileKey.SIMULATION_FILE_KEY.value: dummy_file}

        with pytest.raises(InvalidIdentifierException):
            self.simulation_resource.post(project_gid='')

    def test_server_fire_simulation(self, mocker, connectivity_factory):
        """Happy flow: a valid zipped simulation folder yields a 201 and an operation gid."""
        self._mock_user(mocker)
        input_folder = self.storage_interface.get_project_folder(self.test_project.name)
        sim_dir = os.path.join(input_folder, 'test_sim')
        if not os.path.isdir(sim_dir):
            os.makedirs(sim_dir)

        simulator = SimulatorAdapterModel()
        simulator.connectivity = connectivity_factory().gid
        h5.store_view_model(simulator, sim_dir)
        zip_filename = os.path.join(input_folder, RequestFileKey.SIMULATION_FILE_NAME.value)
        self.storage_interface.zip_folder(zip_filename, sim_dir)

        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        # FIX: open via context manager so the zip handle is closed even when post() raises;
        # the original used a bare open()/close() pair that leaked on failure.
        with open(zip_filename, 'rb') as fp:
            request_mock.files = {
                RequestFileKey.SIMULATION_FILE_KEY.value: FileStorage(fp, os.path.basename(zip_filename))}

            def launch_sim(self, user_id, project, algorithm, zip_folder_path, simulator_file):
                # Dummy stand-in: report a successful launch without touching any backend.
                return Operation('', '', '', {})

            # Mock simulation launch and current user
            mocker.patch.object(SimulatorService, 'prepare_simulation_on_server', launch_sim)
            operation_gid, status = self.simulation_resource.post(project_gid=self.test_project.gid)

        # isinstance is the idiomatic type check (was: type(...) is str)
        assert isinstance(operation_gid, str)
        assert status == 201

    def transactional_teardown_method(self):
        self.storage_interface.remove_project_structure(self.test_project.name)
def transactional_teardown_method(self):
    """Delete the storage folders of both test projects created during setup."""
    storage_interface = StorageInterface()
    # Remove the data-bearing project first, then the empty one (original order).
    for project in (self.test_project_with_data, self.test_project_without_data):
        storage_interface.remove_project_structure(project.name)
class TestProjectService(TransactionalTestCase):
    """
    This class contains tests for the tvb.core.services.project_service module.
    """

    def transactional_setup_method(self):
        """
        Reset the database before each test.
        """
        self.project_service = ProjectService()
        self.storage_interface = StorageInterface()
        self.test_user = TestFactory.create_user()

    def transactional_teardown_method(self):
        """
        Remove project folders and clean up database.
        """
        created_projects = dao.get_projects_for_user(self.test_user.id)
        for project in created_projects:
            self.storage_interface.remove_project_structure(project.name)
        self.delete_project_folders()

    def test_create_project_happy_flow(self):
        """
        Standard flow for creating a project: DB entry, member list and disk folder.
        """
        user1 = TestFactory.create_user('test_user1')
        user2 = TestFactory.create_user('test_user2')
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        TestFactory.create_project(self.test_user, 'test_project', "description", users=[user1.id, user2.id])
        resulting_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(resulting_projects) == 1, "Project with valid data not inserted!"
        project = resulting_projects[0]
        assert project.name == "test_project", "Invalid retrieved project name"
        assert project.description == "description", "Description do no match"
        users_for_project = dao.get_members_of_project(project.id)
        for user in users_for_project:
            assert user.id in [user1.id, user2.id, self.test_user.id], "Users not stored properly."
        # The project must also get a dedicated folder under the TVB storage root
        assert os.path.exists(os.path.join(TvbProfile.current.TVB_STORAGE, StorageInterface.PROJECTS_FOLDER,
                                           "test_project")), "Folder for project was not created"

    def test_create_project_empty_name(self):
        """
        Creating a project with an empty name.
        """
        data = dict(name="", description="test_description", users=[])
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        with pytest.raises(ProjectServiceException):
            self.project_service.store_project(self.test_user, True, None, **data)

    def test_edit_project_happy_flow(self):
        """
        Standard flow for editing an existing project.
        """
        selected_project = TestFactory.create_project(self.test_user, 'test_proj')
        proj_root = self.storage_interface.get_project_folder(selected_project.name)
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 1, "Database initialization probably failed!"
        edited_data = dict(name="test_project", description="test_description", users=[])
        edited_project = self.project_service.store_project(self.test_user, False, selected_project.id, **edited_data)
        # Renaming a project is expected to move its folder on disk
        assert not os.path.exists(proj_root), "Previous folder not deleted"
        proj_root = self.storage_interface.get_project_folder(edited_project.name)
        assert os.path.exists(proj_root), "New folder not created!"
        assert selected_project.name != edited_project.name, "Project was no changed!"

    def test_edit_project_unexisting(self):
        """
        Trying to edit an un-existing project.
        """
        selected_project = TestFactory.create_project(self.test_user, 'test_proj')
        self.storage_interface.get_project_folder(selected_project.name)
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 1, "Database initialization probably failed!"
        data = dict(name="test_project", description="test_description", users=[])
        # id 99 does not belong to any stored project
        with pytest.raises(ProjectServiceException):
            self.project_service.store_project(self.test_user, False, 99, **data)

    def test_find_project_happy_flow(self):
        """
        Standard flow for finding a project by it's id.
        """
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        inserted_project = TestFactory.create_project(self.test_user, 'test_project')
        assert self.project_service.find_project(inserted_project.id) is not None, "Project not found !"
        # The service must return the same entity the DAO returns
        dao_returned_project = dao.get_project_by_id(inserted_project.id)
        service_returned_project = self.project_service.find_project(inserted_project.id)
        assert dao_returned_project.id == service_returned_project.id, \
            "Data returned from service is different from data returned by DAO."
        assert dao_returned_project.name == service_returned_project.name, \
            "Data returned from service is different than data returned by DAO."
        assert dao_returned_project.description == service_returned_project.description, \
            "Data returned from service is different from data returned by DAO."
        assert dao_returned_project.members == service_returned_project.members, \
            "Data returned from service is different from data returned by DAO."

    def test_find_project_unexisting(self):
        """
        Searching for an un-existing project.
        """
        data = dict(name="test_project", description="test_description", users=[])
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        project = self.project_service.store_project(self.test_user, True, None, **data)
        # fetch a likely non-existing project. Previous project id plus a 'big' offset
        with pytest.raises(ProjectServiceException):
            self.project_service.find_project(project.id + 1033)

    def test_retrieve_projects_for_user(self):
        """
        Test for retrieving the projects for a given user. One page only.
        """
        initial_projects = self.project_service.retrieve_projects_for_user(self.test_user.id)[0]
        assert len(initial_projects) == 0, "Database was not reset properly!"
        TestFactory.create_project(self.test_user, 'test_proj')
        TestFactory.create_project(self.test_user, 'test_proj1')
        TestFactory.create_project(self.test_user, 'test_proj2')
        # A project owned by another user must not be returned
        user1 = TestFactory.create_user('another_user')
        TestFactory.create_project(user1, 'test_proj3')
        projects = self.project_service.retrieve_projects_for_user(self.test_user.id)[0]
        assert len(projects) == 3, "Projects not retrieved properly!"
        for project in projects:
            assert project.name != "test_project3", "This project should not have been retrieved"

    def test_retrieve_1project_3usr(self):
        """
        One user as admin, two users as members, getting projects for admin and for any of the members
        should return one.
        """
        member1 = TestFactory.create_user("member1")
        member2 = TestFactory.create_user("member2")
        TestFactory.create_project(self.test_user, 'Testproject', users=[member1.id, member2.id])
        projects = self.project_service.retrieve_projects_for_user(self.test_user.id, 1)[0]
        assert len(projects) == 1, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member1.id, 1)[0]
        assert len(projects) == 1, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member2.id, 1)[0]
        assert len(projects) == 1, "Projects not retrieved properly!"

    def test_retrieve_3projects_3usr(self):
        """
        Three users, 3 projects. Structure of db:
        proj1: {admin: user1, members: [user2, user3]}
        proj2: {admin: user2, members: [user1]}
        proj3: {admin: user3, members: [user1, user2]}
        Check valid project returns for all the users.
        """
        member1 = TestFactory.create_user("member1")
        member2 = TestFactory.create_user("member2")
        member3 = TestFactory.create_user("member3")
        TestFactory.create_project(member1, 'TestProject1', users=[member2.id, member3.id])
        TestFactory.create_project(member2, 'TestProject2', users=[member1.id])
        TestFactory.create_project(member3, 'TestProject3', users=[member1.id, member2.id])
        projects = self.project_service.retrieve_projects_for_user(member1.id, 1)[0]
        assert len(projects) == 3, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member2.id, 1)[0]
        assert len(projects) == 3, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member3.id, 1)[0]
        assert len(projects) == 2, "Projects not retrieved properly!"
def test_retrieve_projects_random(self): """ Generate a large number of users/projects, and validate the results. """ ExtremeTestFactory.generate_users(NR_USERS, MAX_PROJ_PER_USER) for i in range(NR_USERS): current_user = dao.get_user_by_name("gen" + str(i)) expected_projects = ExtremeTestFactory.VALIDATION_DICT[ current_user.id] if expected_projects % PROJECTS_PAGE_SIZE == 0: expected_pages = expected_projects / PROJECTS_PAGE_SIZE exp_proj_per_page = PROJECTS_PAGE_SIZE else: expected_pages = expected_projects // PROJECTS_PAGE_SIZE + 1 exp_proj_per_page = expected_projects % PROJECTS_PAGE_SIZE if expected_projects == 0: expected_pages = 0 exp_proj_per_page = 0 projects, pages = self.project_service.retrieve_projects_for_user( current_user.id, expected_pages) assert len(projects) == exp_proj_per_page, "Projects not retrieved properly! Expected:" + \ str(exp_proj_per_page) + "but got:" + str(len(projects)) assert pages == expected_pages, "Pages not retrieved properly!" for folder in os.listdir(TvbProfile.current.TVB_STORAGE): full_path = os.path.join(TvbProfile.current.TVB_STORAGE, folder) if os.path.isdir(full_path) and folder.startswith('Generated'): shutil.rmtree(full_path) def test_retrieve_projects_page2(self): """ Test for retrieving the second page projects for a given user. """ for i in range(PROJECTS_PAGE_SIZE + 3): TestFactory.create_project(self.test_user, 'test_proj' + str(i)) projects, pages = self.project_service.retrieve_projects_for_user( self.test_user.id, 2) assert len(projects) == (PROJECTS_PAGE_SIZE + 3 ) % PROJECTS_PAGE_SIZE, "Pagination inproper." assert pages == 2, 'Wrong number of pages retrieved.' def test_retrieve_projects_and_del(self): """ Test for retrieving the second page projects for a given user. 
""" created_projects = [] for i in range(PROJECTS_PAGE_SIZE + 1): created_projects.append( TestFactory.create_project(self.test_user, 'test_proj' + str(i))) projects, pages = self.project_service.retrieve_projects_for_user( self.test_user.id, 2) assert len(projects) == (PROJECTS_PAGE_SIZE + 1 ) % PROJECTS_PAGE_SIZE, "Pagination improper." assert pages == (PROJECTS_PAGE_SIZE + 1) // PROJECTS_PAGE_SIZE + 1, 'Wrong number of pages' self.project_service.remove_project(created_projects[1].id) projects, pages = self.project_service.retrieve_projects_for_user( self.test_user.id, 2) assert len(projects) == 0, "Pagination improper." assert pages == 1, 'Wrong number of pages retrieved.' projects, pages = self.project_service.retrieve_projects_for_user( self.test_user.id, 1) assert len(projects) == PROJECTS_PAGE_SIZE, "Pagination improper." assert pages == 1, 'Wrong number of pages retrieved.' def test_empty_project_has_zero_disk_size(self): TestFactory.create_project(self.test_user, 'test_proj') projects, pages = self.project_service.retrieve_projects_for_user( self.test_user.id) assert 0 == projects[0].disk_size assert '0.0 KiB' == projects[0].disk_size_human def test_project_disk_size(self): project1 = TestFactory.create_project(self.test_user, 'test_proj1') zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_66.zip') TestFactory.import_zip_connectivity(self.test_user, project1, zip_path, 'testSubject') project2 = TestFactory.create_project(self.test_user, 'test_proj2') zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_76.zip') TestFactory.import_zip_connectivity(self.test_user, project2, zip_path, 'testSubject') projects = self.project_service.retrieve_projects_for_user( self.test_user.id)[0] assert projects[0].disk_size != projects[ 1].disk_size, "projects should have different size" for project in projects: assert 0 != project.disk_size assert '0.0 KiB' != project.disk_size_human prj_folder = 
self.storage_interface.get_project_folder( project.name) actual_disk_size = self.storage_interface.compute_recursive_h5_disk_usage( prj_folder) ratio = float(actual_disk_size) / project.disk_size msg = "Real disk usage: %s The one recorded in the db : %s" % ( actual_disk_size, project.disk_size) assert ratio < 1.1, msg def test_get_linkable_projects(self): """ Test for retrieving the projects for a given user. """ initial_projects = self.project_service.retrieve_projects_for_user( self.test_user.id)[0] assert len(initial_projects) == 0, "Database was not reset!" test_proj = [] user1 = TestFactory.create_user("another_user") for i in range(4): test_proj.append( TestFactory.create_project(self.test_user if i < 3 else user1, 'test_proj' + str(i))) operation = TestFactory.create_operation(test_user=self.test_user, test_project=test_proj[0]) datatype = dao.store_entity( model_datatype.DataType(module="test_data", subject="subj1", state="test_state", operation_id=operation.id)) linkable = self.project_service.get_linkable_projects_for_user( self.test_user.id, str(datatype.id))[0] assert len(linkable) == 2, "Wrong count of link-able projects!" proj_names = [project.name for project in linkable] assert test_proj[1].name in proj_names assert test_proj[2].name in proj_names assert not test_proj[3].name in proj_names def test_remove_project_happy_flow(self): """ Standard flow for deleting a project. """ inserted_project = TestFactory.create_project(self.test_user, 'test_proj') project_root = self.storage_interface.get_project_folder( inserted_project.name) projects = dao.get_projects_for_user(self.test_user.id) assert len(projects) == 1, "Initializations failed!" assert os.path.exists(project_root), "Something failed at insert time!" self.project_service.remove_project(inserted_project.id) projects = dao.get_projects_for_user(self.test_user.id) assert len(projects) == 0, "Project was not deleted!" assert not os.path.exists(project_root), "Root folder not deleted!" 
def test_remove_project_wrong_id(self): """ Flow for deleting a project giving an un-existing id. """ TestFactory.create_project(self.test_user, 'test_proj') projects = dao.get_projects_for_user(self.test_user.id) assert len(projects) == 1, "Initializations failed!" with pytest.raises(ProjectServiceException): self.project_service.remove_project(99) def __check_meta_data(self, expected_meta_data, new_datatype): """Validate Meta-Data""" mapp_keys = { DataTypeOverlayDetails.DATA_SUBJECT: "subject", DataTypeOverlayDetails.DATA_STATE: "state" } for key, value in expected_meta_data.items(): if key in mapp_keys: assert value == getattr(new_datatype, mapp_keys[key]) elif key == DataTypeMetaData.KEY_OPERATION_TAG: if DataTypeMetaData.KEY_OP_GROUP_ID in expected_meta_data: # We have a Group to check op_group = new_datatype.parent_operation.fk_operation_group op_group = dao.get_generic_entity( model_operation.OperationGroup, op_group)[0] assert value == op_group.name else: assert value == new_datatype.parent_operation.user_group def test_remove_project_node(self): """ Test removing of a node from a project. """ inserted_project, gid, op = TestFactory.create_value_wrapper( self.test_user) project_to_link = model_project.Project("Link", self.test_user.id, "descript") project_to_link = dao.store_entity(project_to_link) exact_data = dao.get_datatype_by_gid(gid) assert exact_data is not None, "Initialization problem!" 
dao.store_entity( model_datatype.Links(exact_data.id, project_to_link.id)) vw_h5_path = h5.path_for_stored_index(exact_data) assert os.path.exists(vw_h5_path) if dao.get_system_user() is None: dao.store_entity( model_operation.User( TvbProfile.current.web.admin.SYSTEM_USER_NAME, TvbProfile.current.web.admin.SYSTEM_USER_NAME, None, None, True, None)) self.project_service._remove_project_node_files( inserted_project.id, gid) assert not os.path.exists(vw_h5_path) exact_data = dao.get_datatype_by_gid(gid) assert exact_data is not None, "Data should still be in DB, because of links" vw_h5_path_new = h5.path_for_stored_index(exact_data) assert os.path.exists(vw_h5_path_new) assert vw_h5_path_new != vw_h5_path self.project_service._remove_project_node_files( project_to_link.id, gid) assert dao.get_datatype_by_gid(gid) is None def test_update_meta_data_simple(self): """ Test the new update metaData for a simple data that is not part of a group. """ inserted_project, gid, _ = TestFactory.create_value_wrapper( self.test_user) new_meta_data = { DataTypeOverlayDetails.DATA_SUBJECT: "new subject", DataTypeOverlayDetails.DATA_STATE: "second_state", DataTypeOverlayDetails.CODE_GID: gid, DataTypeOverlayDetails.CODE_OPERATION_TAG: 'new user group' } self.project_service.update_metadata(new_meta_data) new_datatype = dao.get_datatype_by_gid(gid) self.__check_meta_data(new_meta_data, new_datatype) new_datatype_h5 = h5.h5_file_for_index(new_datatype) assert new_datatype_h5.subject.load( ) == 'new subject', 'UserGroup not updated!' def test_update_meta_data_group(self, test_adapter_factory, datatype_group_factory): """ Test the new update metaData for a group of dataTypes. 
""" test_adapter_factory(adapter_class=DummyAdapter3) group = datatype_group_factory() op_group_id = group.fk_operation_group new_meta_data = { DataTypeOverlayDetails.DATA_SUBJECT: "new subject", DataTypeOverlayDetails.DATA_STATE: "updated_state", DataTypeOverlayDetails.CODE_OPERATION_GROUP_ID: op_group_id, DataTypeOverlayDetails.CODE_OPERATION_TAG: 'newGroupName' } self.project_service.update_metadata(new_meta_data) datatypes = dao.get_datatype_in_group(op_group_id) for datatype in datatypes: new_datatype = dao.get_datatype_by_id(datatype.id) assert op_group_id == new_datatype.parent_operation.fk_operation_group new_group = dao.get_generic_entity(model_operation.OperationGroup, op_group_id)[0] assert new_group.name == "newGroupName" self.__check_meta_data(new_meta_data, new_datatype) def test_retrieve_project_full(self, dummy_datatype_index_factory): """ Tests full project information is retrieved by method `ProjectService.retrieve_project_full(...)` """ project = TestFactory.create_project(self.test_user) operation = TestFactory.create_operation(test_user=self.test_user, test_project=project) dummy_datatype_index_factory(project=project, operation=operation) dummy_datatype_index_factory(project=project, operation=operation) dummy_datatype_index_factory(project=project, operation=operation) _, ops_nr, operations, pages_no = self.project_service.retrieve_project_full( project.id) assert ops_nr == 1, "DataType Factory should only use one operation to store all it's datatypes." assert pages_no == 1, "DataType Factory should only use one operation to store all it's datatypes." resulted_dts = operations[0]['results'] assert len(resulted_dts) == 3, "3 datatypes should be created." 
    def test_get_project_structure(self, datatype_group_factory, dummy_datatype_index_factory,
                                   project_factory, user_factory):
        """
        Tests project structure is as expected and contains all datatypes and created links
        """
        user = user_factory()
        project1 = project_factory(user, name="TestPS1")
        project2 = project_factory(user, name="TestPS2")

        # Project 1 holds a DT group plus one simple DT
        dt_group = datatype_group_factory(project=project1)
        dt_simple = dummy_datatype_index_factory(state="RAW_DATA", project=project1)
        # Create 3 DTs directly in Project 2
        dummy_datatype_index_factory(state="RAW_DATA", project=project2)
        dummy_datatype_index_factory(state="RAW_DATA", project=project2)
        dummy_datatype_index_factory(state="RAW_DATA", project=project2)

        # Create Links from Project 1 into Project 2
        link_ids, expected_links = [], []
        link_ids.append(dt_simple.id)
        expected_links.append(dt_simple.gid)

        # Prepare links towards a full DT Group, but expecting only the DT_Group in the final tree
        dts = dao.get_datatype_in_group(datatype_group_id=dt_group.id)
        link_ids.extend([dt_to_link.id for dt_to_link in dts])
        link_ids.append(dt_group.id)
        expected_links.append(dt_group.gid)

        # Actually create the links from Prj1 into Prj2
        AlgorithmService().create_link(link_ids, project2.id)

        # Retrieve the raw data used to compose the tree (for easy parsing)
        dts_in_tree = dao.get_data_in_project(project2.id)
        dts_in_tree = [dt.gid for dt in dts_in_tree]
        # Retrieve the tree json (for trivial validations only, as we can not decode)
        node_json = self.project_service.get_project_structure(project2, None, DataTypeMetaData.KEY_STATE,
                                                               DataTypeMetaData.KEY_SUBJECT, None)

        # 2 links + 3 directly created DTs are expected in the raw tree data
        assert len(expected_links) + 3 == len(dts_in_tree), "invalid number of nodes in tree"
        assert dt_group.gid in dts_in_tree, "DT_Group should be in the Project Tree!"
        assert dt_group.gid in node_json, "DT_Group should be in the Project Tree JSON!"

        # Group members must be folded into their group node; standalone DTs appear individually
        project_dts = dao.get_datatypes_in_project(project2.id)
        for dt in project_dts:
            if dt.fk_datatype_group is not None:
                assert not dt.gid in node_json, "DTs part of a group should not be"
                assert not dt.gid in dts_in_tree, "DTs part of a group should not be"
            else:
                assert dt.gid in node_json, "Simple DTs and DT_Groups should be"
                assert dt.gid in dts_in_tree, "Simple DTs and DT_Groups should be"

        for link_gid in expected_links:
            assert link_gid in node_json, "Expected Link not present"
            assert link_gid in dts_in_tree, "Expected Link not present"
class TestHPCSchedulerClient(BaseTestCase):
    """
    Tests for the HPC scheduler client: encryption of inputs, launching an
    operation the way the HPC environment would, and staging results back.
    """

    def setup_method(self):
        # Fresh storage + clean DB with one user/project pair per test
        self.storage_interface = StorageInterface()
        self.dir_gid = '123'
        self.clean_database()
        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user)

    def _prepare_dummy_files(self, tmpdir):
        # Two empty files that stand in for real job inputs
        dummy_file1 = os.path.join(str(tmpdir), 'dummy1.txt')
        open(dummy_file1, 'a').close()
        dummy_file2 = os.path.join(str(tmpdir), 'dummy2.txt')
        open(dummy_file2, 'a').close()
        job_inputs = [dummy_file1, dummy_file2]
        return job_inputs

    def test_encrypt_inputs(self, tmpdir):
        job_inputs = self._prepare_dummy_files(tmpdir)
        job_encrypted_inputs = self.storage_interface.encrypt_inputs(self.dir_gid, job_inputs)
        # One encrypted file is expected per plain input file
        # (original comment claimed "2 more files", which contradicts this assertion)
        assert len(job_encrypted_inputs) == len(job_inputs)

    def test_decrypt_results(self, tmpdir):
        # Prepare encrypted dir
        job_inputs = self._prepare_dummy_files(tmpdir)
        self.storage_interface.encrypt_inputs(self.dir_gid, job_inputs)
        encrypted_dir = self.storage_interface.get_encrypted_dir(self.dir_gid)

        # Unencrypt data
        out_dir = os.path.join(str(tmpdir), 'output')
        os.mkdir(out_dir)
        self.storage_interface.decrypt_results_to_dir(self.dir_gid, out_dir)
        list_plain_dir = os.listdir(out_dir)
        assert len(list_plain_dir) == len(os.listdir(encrypted_dir))
        assert 'dummy1.txt' in list_plain_dir
        assert 'dummy2.txt' in list_plain_dir

    def test_decrypt_files(self, tmpdir):
        # Prepare encrypted dir
        job_inputs = self._prepare_dummy_files(tmpdir)
        enc_files = self.storage_interface.encrypt_inputs(self.dir_gid, job_inputs)

        # Unencrypt data: only the second file is requested
        out_dir = os.path.join(str(tmpdir), 'output')
        os.mkdir(out_dir)
        self.storage_interface.decrypt_files_to_dir(self.dir_gid, [enc_files[1]], out_dir)
        list_plain_dir = os.listdir(out_dir)
        assert len(list_plain_dir) == 1
        assert os.path.basename(enc_files[0]).replace('.aes', '') not in list_plain_dir
        assert os.path.basename(enc_files[1]).replace('.aes', '') in list_plain_dir

    def test_do_operation_launch(self, simulator_factory, operation_factory, mocker):
        # Prepare encrypted dir
        op = operation_factory(test_user=self.test_user, test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        self._do_operation_launch(op, sim_gid, mocker)

    def _do_operation_launch(self, op, sim_gid, mocker, is_pse=False):
        """Encrypt the operation inputs, run do_operation_launch as the HPC env would,
        and verify the expected output files appear in the encrypted dir."""
        # Prepare encrypted dir
        self.storage_interface = StorageInterface()
        self.dir_gid = sim_gid
        job_encrypted_inputs = HPCSchedulerClient()._prepare_input(op, sim_gid)
        self.storage_interface.encrypt_inputs(sim_gid, job_encrypted_inputs)
        encrypted_dir = self.storage_interface.get_encrypted_dir(self.dir_gid)

        mocker.patch('tvb.core.operation_hpc_launcher._request_passfile', _request_passfile_dummy)
        mocker.patch('tvb.core.operation_hpc_launcher._update_operation_status', _update_operation_status)

        # Call do_operation_launch similarly to CSCS env
        plain_dir = self.storage_interface.get_project_folder(self.test_project.name, 'plain')
        do_operation_launch(sim_gid.hex, 1000, is_pse, '', op.id, plain_dir)

        assert len(os.listdir(encrypted_dir)) == 7
        output_path = os.path.join(encrypted_dir, HPCSchedulerClient.OUTPUT_FOLDER)
        assert os.path.exists(output_path)
        # PSE runs produce one extra output file (metric data)
        expected_files = 2
        if is_pse:
            expected_files = 3
        assert len(os.listdir(output_path)) == expected_files
        return output_path

    def test_do_operation_launch_pse(self, simulator_factory, operation_factory, mocker):
        op = operation_factory(test_user=self.test_user, test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        self._do_operation_launch(op, sim_gid, mocker, is_pse=True)

    def test_prepare_inputs(self, operation_factory, simulator_factory):
        op = operation_factory(test_user=self.test_user, test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 6

    def test_prepare_inputs_with_surface(self, operation_factory, simulator_factory):
        op = operation_factory(test_user=self.test_user, test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op, with_surface=True)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        # Surface-based simulations carry 3 extra input files
        assert len(input_files) == 9

    def test_prepare_inputs_with_eeg_monitor(self, operation_factory, simulator_factory, surface_index_factory,
                                             sensors_index_factory, region_mapping_index_factory,
                                             connectivity_index_factory):
        surface_idx, surface = surface_index_factory(cortical=True)
        sensors_idx, sensors = sensors_index_factory()
        proj = ProjectionSurfaceEEG(sensors=sensors, sources=surface, projection_data=numpy.ones(3))

        op = operation_factory()
        storage_path = StorageInterface().get_project_folder(op.project.name, str(op.id))
        prj_db_db = h5.store_complete(proj, storage_path)
        prj_db_db.fk_from_operation = op.id
        dao.store_entity(prj_db_db)

        connectivity = connectivity_index_factory(76, op)
        rm_index = region_mapping_index_factory(conn_gid=connectivity.gid, surface_gid=surface_idx.gid)

        eeg_monitor = EEGViewModel(projection=proj.gid, sensors=sensors.gid)
        eeg_monitor.region_mapping = rm_index.gid

        sim_folder, sim_gid = simulator_factory(op=op, monitor=eeg_monitor, conn_gid=connectivity.gid)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 11

    def test_stage_out_to_operation_folder(self, mocker, operation_factory, simulator_factory,
                                           pse_burst_configuration_factory):
        burst = pse_burst_configuration_factory(self.test_project)
        op = operation_factory(test_user=self.test_user, test_project=self.test_project)
        op.fk_operation_group = burst.fk_operation_group
        dao.store_entity(op)

        sim_folder, sim_gid = simulator_factory(op=op)
        burst.simulator_gid = sim_gid.hex
        dao.store_entity(burst)
        output_path = self._do_operation_launch(op, sim_gid, mocker, is_pse=True)

        def _stage_out_dummy(dir, sim_gid):
            # Stand-in for the remote stage-out: return the local output files
            return [os.path.join(output_path, enc_file) for enc_file in os.listdir(output_path)]

        mocker.patch.object(HPCSchedulerClient, '_stage_out_results', _stage_out_dummy)
        sim_results_files, metric_op, metric_file = HPCSchedulerClient.stage_out_to_operation_folder(None, op, sim_gid)

        # The metric goes into its own (new) operation
        assert op.id != metric_op.id
        assert os.path.exists(metric_file)
        assert len(sim_results_files) == 1
        assert os.path.exists(sim_results_files[0])

    def teardown_method(self):
        # Remove encrypted artifacts, password file, project folder and DB rows
        encrypted_dir = self.storage_interface.get_encrypted_dir(self.dir_gid)
        if os.path.isdir(encrypted_dir):
            shutil.rmtree(encrypted_dir)
        passfile = self.storage_interface.get_password_file(self.dir_gid)
        if os.path.exists(passfile):
            os.remove(passfile)
        self.storage_interface.remove_project_structure(self.test_project.name)
        self.clean_database()
class TestCSVConnectivityImporter(BaseTestCase):
    """
    Unit-tests for csv connectivity importer.
    """

    def setup_method(self):
        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user)
        self.storage_interface = StorageInterface()

    def teardown_method(self):
        """
        Clean-up tests data
        """
        self.clean_database()
        self.storage_interface.remove_project_structure(self.test_project.name)

    def _import_csv_test_connectivity(self, reference_connectivity_gid, subject):
        """Launch the CSV importer against the Toronto demo matrices,
        using `reference_connectivity_gid` for centres/labels."""
        ### First prepare input data:
        data_dir = path.abspath(path.dirname(tvb_data.__file__))
        toronto_dir = path.join(data_dir, 'dti_pipeline_toronto')
        weights = path.join(toronto_dir, 'output_ConnectionCapacityMatrix.csv')
        tracts = path.join(toronto_dir, 'output_ConnectionDistanceMatrix.csv')
        # Work on copies, so the packaged demo files stay untouched
        weights_tmp = weights + '.tmp'
        tracts_tmp = tracts + '.tmp'
        self.storage_interface.copy_file(weights, weights_tmp)
        self.storage_interface.copy_file(tracts, tracts_tmp)

        view_model = CSVConnectivityImporterModel()
        view_model.weights = weights_tmp
        view_model.tracts = tracts_tmp
        view_model.data_subject = subject
        view_model.input_data = reference_connectivity_gid
        TestFactory.launch_importer(CSVConnectivityImporter, view_model, self.test_user, self.test_project, False)

    def test_happy_flow_import(self):
        """
        Test that importing a CSV connectivity generates at least one DataType in DB.
        """
        zip_path = path.join(path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_96.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path, subject=TEST_SUBJECT_A)

        field = FilterChain.datatype + '.subject'
        filters = FilterChain('', [field], [TEST_SUBJECT_A], ['=='])
        reference_connectivity_index = TestFactory.get_entity(self.test_project, ConnectivityIndex, filters)

        dt_count_before = TestFactory.get_entity_count(self.test_project, ConnectivityIndex)

        self._import_csv_test_connectivity(reference_connectivity_index.gid, TEST_SUBJECT_B)

        dt_count_after = TestFactory.get_entity_count(self.test_project, ConnectivityIndex)
        assert dt_count_before + 1 == dt_count_after

        filters = FilterChain('', [field], [TEST_SUBJECT_B], ['like'])
        imported_connectivity_index = TestFactory.get_entity(self.test_project, ConnectivityIndex, filters)

        # check relationship between the imported connectivity and the reference
        assert reference_connectivity_index.number_of_regions == imported_connectivity_index.number_of_regions
        assert not reference_connectivity_index.number_of_connections == imported_connectivity_index.number_of_connections

        # Weights/tracts come from the CSVs, while geometry is copied from the reference
        reference_connectivity = h5.load_from_index(reference_connectivity_index)
        imported_connectivity = h5.load_from_index(imported_connectivity_index)
        assert not (reference_connectivity.weights == imported_connectivity.weights).all()
        assert not (reference_connectivity.tract_lengths == imported_connectivity.tract_lengths).all()
        assert (reference_connectivity.centres == imported_connectivity.centres).all()
        assert (reference_connectivity.orientations == imported_connectivity.orientations).all()
        assert (reference_connectivity.region_labels == imported_connectivity.region_labels).all()

    def test_bad_reference(self):
        """A reference connectivity with a mismatched region count must fail the import."""
        zip_path = path.join(path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_66.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)
        field = FilterChain.datatype + '.subject'
        filters = FilterChain('', [field], [TEST_SUBJECT_A], ['!='])
        bad_reference_connectivity = TestFactory.get_entity(self.test_project, ConnectivityIndex, filters)

        with pytest.raises(OperationException):
            self._import_csv_test_connectivity(bad_reference_connectivity.gid, TEST_SUBJECT_A)
class TestStimulusCreator(TransactionalTestCase):
    """
    Transactional tests for the region- and surface-stimulus creators,
    exercised both via direct launch() calls and via fired operations.
    """

    def transactional_setup_method(self):
        """
        Reset the database before each test.
        """
        self.test_user = TestFactory.create_user('Stim_User')
        self.test_project = TestFactory.create_project(self.test_user, "Stim_Project")
        self.storage_interface = StorageInterface()

        # Import a 66-region connectivity to serve as the region-stimulus target.
        connectivity_zip = os.path.join(os.path.dirname(tvb_data.__file__),
                                        'connectivity', 'connectivity_66.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, connectivity_zip)
        self.connectivity = TestFactory.get_entity(self.test_project, ConnectivityIndex)

        # Import a cortical surface to serve as the surface-stimulus target.
        cortex_zip = os.path.join(os.path.dirname(tvb_data.surfaceData.__file__), 'cortex_16384.zip')
        self.surface = TestFactory.import_surface_zip(self.test_user, self.test_project,
                                                      cortex_zip, CORTICAL)

    def transactional_teardown_method(self):
        """
        Remove project folders and clean up database.
        """
        self.storage_interface.remove_project_structure(self.test_project.name)

    def test_create_stimulus_region(self):
        """Launch the region-stimulus creator directly and check the stored index."""
        creator = RegionStimulusCreator()
        creator.storage_path = self.storage_interface.get_project_folder(self.test_project.name, "42")

        model = creator.get_view_model_class()()
        model.connectivity = self.connectivity.gid
        model.weight = numpy.zeros(self.connectivity.number_of_regions)
        model.temporal = TemporalApplicableEquation()
        model.temporal.parameters['a'] = 1.0
        model.temporal.parameters['b'] = 2.0

        stimulus_index = creator.launch(model)

        assert stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert json.loads(stimulus_index.temporal_parameters) == {'a': 1.0, 'b': 2.0}
        assert stimulus_index.fk_connectivity_gid == self.connectivity.gid

    def test_create_stimulus_region_with_operation(self):
        """Fire the region-stimulus creator as an operation and check the resulting index."""
        creator = RegionStimulusCreator()

        model = creator.get_view_model_class()()
        model.connectivity = self.connectivity.gid
        model.weight = numpy.zeros(self.connectivity.number_of_regions)
        model.temporal = TemporalApplicableEquation()
        model.temporal.parameters['a'] = 1.0
        model.temporal.parameters['b'] = 2.0

        OperationService().fire_operation(creator, self.test_user, self.test_project.id,
                                          view_model=model)
        stimulus_index = TestFactory.get_entity(self.test_project, StimuliRegionIndex)

        assert stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert json.loads(stimulus_index.temporal_parameters) == {'a': 1.0, 'b': 2.0}
        assert stimulus_index.fk_connectivity_gid == self.connectivity.gid

    def test_create_stimulus_surface(self):
        """Launch the surface-stimulus creator directly and check the stored index."""
        creator = SurfaceStimulusCreator()
        creator.storage_path = self.storage_interface.get_project_folder(self.test_project.name, "42")

        model = creator.get_view_model_class()()
        model.surface = self.surface.gid
        model.focal_points_triangles = numpy.array([1, 2, 3])
        model.spatial = FiniteSupportEquation()
        model.spatial_amp = 1.0
        model.spatial_sigma = 1.0
        model.spatial_offset = 0.0
        model.temporal = TemporalApplicableEquation()
        model.temporal.parameters['a'] = 1.0
        model.temporal.parameters['b'] = 0.0

        stimulus_index = creator.launch(model)

        assert stimulus_index.spatial_equation == 'FiniteSupportEquation'
        assert stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert stimulus_index.fk_surface_gid == self.surface.gid

    def test_create_stimulus_surface_with_operation(self):
        """Fire the surface-stimulus creator as an operation and check the resulting index."""
        creator = SurfaceStimulusCreator()

        model = creator.get_view_model_class()()
        model.surface = self.surface.gid
        model.focal_points_triangles = numpy.array([1, 2, 3])
        model.spatial = FiniteSupportEquation()
        model.spatial_amp = 1.0
        model.spatial_sigma = 1.0
        model.spatial_offset = 0.0
        model.temporal = TemporalApplicableEquation()
        model.temporal.parameters['a'] = 1.0
        model.temporal.parameters['b'] = 0.0

        OperationService().fire_operation(creator, self.test_user, self.test_project.id,
                                          view_model=model)
        stimulus_index = TestFactory.get_entity(self.test_project, StimuliSurfaceIndex)

        assert stimulus_index.spatial_equation == 'FiniteSupportEquation'
        assert stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert stimulus_index.fk_surface_gid == self.surface.gid
class TestGIFTISurfaceImporter(BaseTestCase):
    """
    Unit-tests for GIFTI Surface importer.
    """

    GIFTI_SURFACE_FILE = os.path.join(os.path.dirname(demo_data.__file__), 'sample.cortex.gii')
    GIFTI_TIME_SERIES_FILE = os.path.join(os.path.dirname(demo_data.__file__), 'sample.time_series.gii')
    # Any non-GIFTI file works as an invalid input; this test module itself is handy.
    WRONG_GII_FILE = os.path.abspath(__file__)

    def setup_method(self):
        self.test_user = TestFactory.create_user('Gifti_User')
        self.test_project = TestFactory.create_project(self.test_user, "Gifti_Project")
        self.storage_interface = StorageInterface()

    def teardown_method(self):
        """
        Clean-up tests data
        """
        self.clean_database()
        self.storage_interface.remove_project_structure(self.test_project.name)

    def test_import_surface_gifti_data(self, operation_factory):
        """
        This method tests import of a surface from GIFTI file.
        !!! Important: We changed this test to execute only GIFTI parse
        because storing surface it takes too long (~ 9min) since
        normals needs to be calculated.
        """
        operation_id = operation_factory().id
        storage_path = self.storage_interface.get_project_folder(self.test_project.name,
                                                                 str(operation_id))

        parser = GIFTIParser(storage_path, operation_id)
        surface = parser.parse(self.GIFTI_SURFACE_FILE)

        assert 131342 == len(surface.vertices)
        assert 262680 == len(surface.triangles)

    def test_import_timeseries_gifti_data(self, operation_factory):
        """
        This method tests import of a time series from GIFTI file.
        !!! Important: We changed this test to execute only GIFTI parse
        because storing surface it takes too long (~ 9min) since
        normals needs to be calculated.
        """
        operation_id = operation_factory().id
        storage_path = self.storage_interface.get_project_folder(self.test_project.name,
                                                                 str(operation_id))

        parser = GIFTIParser(storage_path, operation_id)
        time_series = parser.parse(self.GIFTI_TIME_SERIES_FILE)
        data_shape = time_series[1]

        assert 135 == len(data_shape)
        assert 143479 == data_shape[0].dims[0]

    def test_import_wrong_gii_file(self):
        """
        This method tests import of a file in a wrong format
        """
        # pytest.raises replaces the original try / raise-AssertionError /
        # except-pass pattern: it is the idiom already used throughout this
        # module and fails with a clear message when no exception is raised.
        with pytest.raises(OperationException):
            TestFactory.import_surface_gifti(self.test_user, self.test_project,
                                             self.WRONG_GII_FILE)
class TestOperationResource(RestResourceTest):
    """
    Tests for the REST operation resources: status, results and launch.
    """

    def transactional_setup_method(self):
        self.test_user = TestFactory.create_user('Rest_User')
        self.test_project = TestFactory.create_project(self.test_user, 'Rest_Project',
                                                       users=[self.test_user.id])
        self.operations_resource = GetOperationsInProjectResource()
        self.status_resource = GetOperationStatusResource()
        self.results_resource = GetOperationResultsResource()
        self.launch_resource = LaunchOperationResource()
        self.storage_interface = StorageInterface()

    def test_server_get_operation_status_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        operation_gid = "inexistent-gid"
        with pytest.raises(InvalidIdentifierException):
            self.status_resource.get(operation_gid=operation_gid)

    def test_server_get_operation_status(self, mocker):
        self._mock_user(mocker)
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__),
                                'connectivity', 'connectivity_96.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)

        request_mock = mocker.patch.object(flask, 'request')
        request_mock.args = {Strings.PAGE_NUMBER: '1'}

        operations_and_pages = self.operations_resource.get(project_gid=self.test_project.gid)
        result = self.status_resource.get(operation_gid=operations_and_pages['operations'][0].gid)

        assert type(result) is str
        assert result in OperationPossibleStatus

    def test_server_get_operation_results_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        operation_gid = "inexistent-gid"
        with pytest.raises(InvalidIdentifierException):
            self.results_resource.get(operation_gid=operation_gid)

    def test_server_get_operation_results(self, mocker):
        self._mock_user(mocker)
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__),
                                'connectivity', 'connectivity_96.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)

        request_mock = mocker.patch.object(flask, 'request')
        request_mock.args = {Strings.PAGE_NUMBER: '1'}

        operations_and_pages = self.operations_resource.get(project_gid=self.test_project.gid)
        result = self.results_resource.get(operation_gid=operations_and_pages['operations'][0].gid)

        assert type(result) is list
        assert len(result) == 1

    def test_server_get_operation_results_failed_operation(self, mocker):
        self._mock_user(mocker)
        # connectivity_90.zip is expected to fail the import operation.
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__),
                                'connectivity', 'connectivity_90.zip')
        with pytest.raises(TVBException):
            TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)

        request_mock = mocker.patch.object(flask, 'request')
        request_mock.args = {Strings.PAGE_NUMBER: '1'}

        operations_and_pages = self.operations_resource.get(project_gid=self.test_project.gid)
        result = self.results_resource.get(operation_gid=operations_and_pages['operations'][0].gid)

        # A failed operation has no results.
        assert type(result) is list
        assert len(result) == 0

    def test_server_launch_operation_no_file(self, mocker):
        self._mock_user(mocker)
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {}
        with pytest.raises(InvalidIdentifierException):
            self.launch_resource.post(project_gid='', algorithm_module='', algorithm_classname='')

    def test_server_launch_operation_wrong_file_extension(self, mocker):
        self._mock_user(mocker)
        dummy_file = FileStorage(BytesIO(b"test"), 'test.txt')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: dummy_file}
        with pytest.raises(InvalidIdentifierException):
            self.launch_resource.post(project_gid='', algorithm_module='', algorithm_classname='')

    def test_server_launch_operation_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        project_gid = "inexistent-gid"
        dummy_file = FileStorage(BytesIO(b"test"), 'test.h5')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: dummy_file}
        with pytest.raises(InvalidIdentifierException):
            self.launch_resource.post(project_gid=project_gid, algorithm_module='',
                                      algorithm_classname='')

    def test_server_launch_operation_inexistent_algorithm(self, mocker):
        self._mock_user(mocker)
        inexistent_algorithm = "inexistent-algorithm"
        dummy_file = FileStorage(BytesIO(b"test"), 'test.h5')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: dummy_file}
        with pytest.raises(ServiceException):
            self.launch_resource.post(project_gid=self.test_project.gid,
                                      algorithm_module=inexistent_algorithm,
                                      algorithm_classname='')

    def test_server_launch_operation(self, mocker, time_series_index_factory):
        self._mock_user(mocker)
        algorithm_module = "tvb.adapters.analyzers.fourier_adapter"
        algorithm_class = "FourierAdapter"

        input_ts_index = time_series_index_factory()
        fft_model = FFTAdapterModel()
        fft_model.time_series = UUID(input_ts_index.gid)
        fft_model.window_function = list(SUPPORTED_WINDOWING_FUNCTIONS)[0]

        input_folder = self.storage_interface.get_project_folder(self.test_project.name)
        view_model_h5_path = h5.store_view_model(fft_model, input_folder)

        # Mock flask.request.files to return a dictionary.
        request_mock = mocker.patch.object(flask, 'request')
        # The 'with' block guarantees the handle is closed even when post()
        # raises (the original closed it only on the success path).
        with open(view_model_h5_path, 'rb') as fp:
            request_mock.files = {
                RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value:
                    FileStorage(fp, os.path.basename(view_model_h5_path))}

            # Mock launch_operation() call and current_user
            mocker.patch.object(OperationService, 'launch_operation')
            operation_gid, status = self.launch_resource.post(project_gid=self.test_project.gid,
                                                              algorithm_module=algorithm_module,
                                                              algorithm_classname=algorithm_class)

        assert type(operation_gid) is str
        assert len(operation_gid) > 0

    def transactional_teardown_method(self):
        self.storage_interface.remove_project_structure(self.test_project.name)