Example #1
    def setup_method(self):
        """
        Prepare the usage of a different config file for this class only.
        """
        StorageInterface.remove_files([TEST_CONFIG_FILE, TestSQLiteProfile.DEFAULT_STORAGE])

        self.old_config_file = TvbProfile.current.TVB_CONFIG_FILE
        TvbProfile.current.__class__.TVB_CONFIG_FILE = TEST_CONFIG_FILE
        TvbProfile._build_profile_class(TvbProfile.CURRENT_PROFILE_NAME)
        self.settings_service = SettingsService()
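
A matching teardown is implied by the saved self.old_config_file. A minimal sketch of what it could look like (illustrative only; the original class's actual cleanup is not shown in this listing):

    def teardown_method(self):
        # Restore the config file path saved in setup_method and rebuild the
        # profile, so that other test classes run against the original settings.
        TvbProfile.current.__class__.TVB_CONFIG_FILE = self.old_config_file
        TvbProfile._build_profile_class(TvbProfile.CURRENT_PROFILE_NAME)
        StorageInterface.remove_files([TEST_CONFIG_FILE])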
Example #2
    def test_update_settings(self):
        """
        Test update of settings: correct flags should be returned, and check storage folder renamed
        """
        # 1. save on empty config-file:
        to_store_data = {key: value['value'] for key, value in self.settings_service.configurable_keys.items()}
        for key, value in self.TEST_SETTINGS.items():
            to_store_data[key] = value

        is_changed, should_reset = self.settings_service.save_settings(**to_store_data)
        assert should_reset and is_changed

        # 2. Reload and save with the same values (is_changed expected to be False)
        TvbProfile._build_profile_class(TvbProfile.CURRENT_PROFILE_NAME)
        self.settings_service = SettingsService()
        to_store_data = {key: value['value'] for key, value in self.settings_service.configurable_keys.items()}

        is_changed, should_reset = self.settings_service.save_settings(**to_store_data)
        assert not is_changed
        assert not should_reset

        # 3. Reload and check that changing TVB_STORAGE is done correctly
        TvbProfile._build_profile_class(TvbProfile.CURRENT_PROFILE_NAME)
        self.settings_service = SettingsService()
        to_store_data = {key: value['value'] for key, value in self.settings_service.configurable_keys.items()}
        to_store_data[SettingsService.KEY_STORAGE] = os.path.join(TvbProfile.current.TVB_STORAGE, 'RENAMED')

        # Write a test-file and check that it is moved
        with open(os.path.join(TvbProfile.current.TVB_STORAGE, "test_rename-xxx43"), 'w') as file_writer:
            file_writer.write('test-content')

        is_changed, should_reset = self.settings_service.save_settings(**to_store_data)
        assert is_changed
        assert not should_reset
        # Check that the file was correctly moved:
        with open(os.path.join(TvbProfile.current.TVB_STORAGE, 'RENAMED', "test_rename-xxx43"), 'r') as file_reader:
            data = file_reader.read()
        assert data == 'test-content'

        StorageInterface.remove_files([os.path.join(TvbProfile.current.TVB_STORAGE, 'RENAMED'),
                                       os.path.join(TvbProfile.current.TVB_STORAGE, "test_rename-xxx43")])
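
The (is_changed, should_reset) pair returned by save_settings is the contract this test exercises. A hypothetical caller sketch (apply_settings is not a TVB function), showing how the two flags might typically be consumed:

def apply_settings(settings_service, data):
    # save_settings reports whether anything changed and whether the change
    # requires a reset/restart to take effect.
    is_changed, should_reset = settings_service.save_settings(**data)
    if should_reset:
        print("Settings changed; a restart is required.")
    elif is_changed:
        print("Settings changed and applied in place.")
    return is_changed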
Example #3
class TestHPCSchedulerClient(BaseTestCase):
    def setup_method(self):
        self.storage_interface = StorageInterface()
        self.dir_gid = '123'
        self.encryption_handler = self.storage_interface.get_encryption_handler(
            self.dir_gid)
        self.clean_database()
        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user)

    def _prepare_dummy_files(self, tmpdir):
        dummy_file1 = os.path.join(str(tmpdir), 'dummy1.txt')
        open(dummy_file1, 'a').close()
        dummy_file2 = os.path.join(str(tmpdir), 'dummy2.txt')
        open(dummy_file2, 'a').close()
        job_inputs = [dummy_file1, dummy_file2]
        return job_inputs

    def test_encrypt_inputs(self, tmpdir):
        job_inputs = self._prepare_dummy_files(tmpdir)
        job_encrypted_inputs = self.encryption_handler.encrypt_inputs(
            job_inputs)
        # One encrypted file is returned for each plain input file
        assert len(job_encrypted_inputs) == len(job_inputs)

    def test_decrypt_results(self, tmpdir):
        # Prepare encrypted dir
        job_inputs = self._prepare_dummy_files(tmpdir)
        self.encryption_handler.encrypt_inputs(job_inputs)
        encrypted_dir = self.encryption_handler.get_encrypted_dir()

        # Decrypt data
        out_dir = os.path.join(str(tmpdir), 'output')
        os.mkdir(out_dir)
        self.encryption_handler.decrypt_results_to_dir(out_dir)
        list_plain_dir = os.listdir(out_dir)
        assert len(list_plain_dir) == len(os.listdir(encrypted_dir))
        assert 'dummy1.txt' in list_plain_dir
        assert 'dummy2.txt' in list_plain_dir

    def test_decrypt_files(self, tmpdir):
        # Prepare encrypted dir
        job_inputs = self._prepare_dummy_files(tmpdir)
        enc_files = self.encryption_handler.encrypt_inputs(job_inputs)

        # Decrypt data
        out_dir = os.path.join(str(tmpdir), 'output')
        os.mkdir(out_dir)
        self.encryption_handler.decrypt_files_to_dir([enc_files[1]], out_dir)
        list_plain_dir = os.listdir(out_dir)
        assert len(list_plain_dir) == 1
        assert os.path.basename(enc_files[0]).replace('.aes',
                                                      '') not in list_plain_dir
        assert os.path.basename(enc_files[1]).replace('.aes',
                                                      '') in list_plain_dir

    def test_do_operation_launch(self, simulator_factory, operation_factory,
                                 mocker):
        # Prepare operation and simulator data
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)

        self._do_operation_launch(op, sim_gid, mocker)

    def _do_operation_launch(self, op, sim_gid, mocker, is_pse=False):
        # Prepare encrypted dir
        self.dir_gid = sim_gid
        self.encryption_handler = StorageInterface.get_encryption_handler(
            self.dir_gid)
        job_encrypted_inputs = HPCSchedulerClient()._prepare_input(
            op, self.dir_gid)
        self.encryption_handler.encrypt_inputs(job_encrypted_inputs)
        encrypted_dir = self.encryption_handler.get_encrypted_dir()

        mocker.patch('tvb.core.operation_hpc_launcher._request_passfile',
                     _request_passfile_dummy)
        mocker.patch(
            'tvb.core.operation_hpc_launcher._update_operation_status',
            _update_operation_status)
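        # _request_passfile_dummy and _update_operation_status are test doubles,
        # presumably defined elsewhere in this test module; patching them keeps
        # the launch local instead of performing the HPC round-trips.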

        # Call do_operation_launch similarly to CSCS env
        plain_dir = self.storage_interface.get_project_folder(
            self.test_project.name, 'plain')
        do_operation_launch(self.dir_gid, 1000, is_pse, '', op.id, plain_dir)
        assert len(os.listdir(encrypted_dir)) == 7
        output_path = os.path.join(encrypted_dir,
                                   HPCSchedulerClient.OUTPUT_FOLDER)
        assert os.path.exists(output_path)
        expected_files = 2
        if is_pse:
            expected_files = 3
        assert len(os.listdir(output_path)) == expected_files
        return output_path

    def test_do_operation_launch_pse(self, simulator_factory,
                                     operation_factory, mocker):
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        self._do_operation_launch(op, sim_gid, mocker, is_pse=True)

    def test_prepare_inputs(self, operation_factory, simulator_factory):
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 6

    def test_prepare_inputs_with_surface(self, operation_factory,
                                         simulator_factory):
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op, with_surface=True)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 9

    def test_prepare_inputs_with_eeg_monitor(self, operation_factory,
                                             simulator_factory,
                                             surface_index_factory,
                                             sensors_index_factory,
                                             region_mapping_index_factory,
                                             connectivity_index_factory):
        surface_idx, surface = surface_index_factory(cortical=True)
        sensors_idx, sensors = sensors_index_factory()
        proj = ProjectionSurfaceEEG(sensors=sensors,
                                    sources=surface,
                                    projection_data=numpy.ones(3))

        op = operation_factory()
        prj_db_db = h5.store_complete(proj, op.id, op.project.name)
        prj_db_db.fk_from_operation = op.id
        dao.store_entity(prj_db_db)

        connectivity = connectivity_index_factory(76, op)
        rm_index = region_mapping_index_factory(conn_gid=connectivity.gid,
                                                surface_gid=surface_idx.gid)

        eeg_monitor = EEGViewModel(projection=proj.gid, sensors=sensors.gid)
        eeg_monitor.region_mapping = rm_index.gid

        sim_folder, sim_gid = simulator_factory(op=op,
                                                monitor=eeg_monitor,
                                                conn_gid=connectivity.gid)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 11

    def test_stage_out_to_operation_folder(self, mocker, operation_factory,
                                           simulator_factory,
                                           pse_burst_configuration_factory):
        burst = pse_burst_configuration_factory(self.test_project)
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        op.fk_operation_group = burst.fk_operation_group
        dao.store_entity(op)

        sim_folder, sim_gid = simulator_factory(op=op)
        burst.simulator_gid = sim_gid.hex
        dao.store_entity(burst)

        output_path = self._do_operation_launch(op,
                                                sim_gid,
                                                mocker,
                                                is_pse=True)

        def _stage_out_dummy(dir, sim_gid):
            return [
                os.path.join(output_path, enc_file)
                for enc_file in os.listdir(output_path)
            ]

        mocker.patch.object(HPCSchedulerClient, '_stage_out_results',
                            _stage_out_dummy)
        sim_results_files, metric_op, metric_file = HPCSchedulerClient.stage_out_to_operation_folder(
            None, op, sim_gid)
        assert op.id != metric_op.id
        assert os.path.exists(metric_file)
        assert len(sim_results_files) == 1
        assert os.path.exists(sim_results_files[0])

    def teardown_method(self):
        encrypted_dir = self.encryption_handler.get_encrypted_dir()
        passfile = self.encryption_handler.get_password_file()
        self.storage_interface.remove_files([encrypted_dir, passfile])
        self.clean_database()
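
These tests rely on pytest-mock's mocker fixture: attributes patched through it are restored automatically when each test finishes. A minimal self-contained sketch of the pattern (assumes pytest and pytest-mock are installed):

import os

def test_patch_is_scoped_to_the_test(mocker):
    # Replace a module attribute for the duration of this test only;
    # pytest-mock undoes the patch on teardown.
    mocker.patch('os.getcwd', return_value='/tmp/fake')
    assert os.getcwd() == '/tmp/fake'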
Example #4
class AlgorithmService(object):
    """
    Service Layer for Algorithms manipulation (e.g. find all Uploaders, Filter algo by category, etc)
    """
    def __init__(self):
        self.logger = get_logger(self.__class__.__module__)
        self.storage_interface = StorageInterface()

    @staticmethod
    def get_category_by_id(identifier):
        """ Pass to DAO the retrieve of category by ID operation."""
        return dao.get_category_by_id(identifier)

    @staticmethod
    def get_raw_categories():
        """:returns: AlgorithmCategory list of entities that have results in RAW state (Creators/Uploaders)"""
        return dao.get_raw_categories()

    @staticmethod
    def get_visualisers_category():
        """Retrieve all Algorithm categories, with display capability"""
        result = dao.get_visualisers_categories()
        if not result:
            raise ValueError("View Category not found!!!")
        return result[0]

    @staticmethod
    def get_algorithm_by_identifier(ident):
        """
        Retrieve Algorithm entity by ID.
        Return None, if ID is not found in DB.
        """
        return dao.get_algorithm_by_id(ident)

    @staticmethod
    def get_operation_numbers(proj_id):
        """ Count total number of operations started for current project. """
        return dao.get_operation_numbers(proj_id)

    def _prepare_dt_display_name(self, dt_index, dt):
        # dt is a result of the get_values_of_datatype function
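        # (Tuple layout assumed from the usage below: dt[0] = id, dt[2] = gid,
        # dt[3] = subject, dt[4] = creation date, dt[5] and dt[6] = optional
        # extra details; not verified against get_values_of_datatype.)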
        db_dt = dao.get_generic_entity(dt_index, dt[2], "gid")
        display_name = db_dt[0].display_name
        display_name += ' - ' + (dt[3] or "None ")  # Subject
        if dt[5]:
            display_name += ' - From: ' + str(dt[5])
        else:
            display_name += date2string(dt[4])
        if dt[6]:
            display_name += ' - ' + str(dt[6])
        display_name += ' - ID:' + str(dt[0])

        return display_name

    def fill_selectfield_with_datatypes(self,
                                        field,
                                        project_id,
                                        extra_conditions=None):
        # type: (TraitDataTypeSelectField, int, list) -> None
        filtering_conditions = FilterChain()
        filtering_conditions += field.conditions
        filtering_conditions += extra_conditions
        datatypes, _ = dao.get_values_of_datatype(project_id,
                                                  field.datatype_index,
                                                  filtering_conditions)
        datatype_options = []
        for datatype in datatypes:
            display_name = self._prepare_dt_display_name(
                field.datatype_index, datatype)
            datatype_options.append((datatype, display_name))
        field.datatype_options = datatype_options

    def _fill_form_with_datatypes(self,
                                  form,
                                  project_id,
                                  extra_conditions=None):
        for form_field in form.trait_fields:
            if isinstance(form_field, TraitDataTypeSelectField):
                self.fill_selectfield_with_datatypes(form_field, project_id,
                                                     extra_conditions)
        return form

    def prepare_adapter_form(self,
                             adapter_instance=None,
                             form_instance=None,
                             project_id=None,
                             extra_conditions=None):
        # type: (ABCAdapter, ABCAdapterForm, int, list) -> ABCAdapterForm
        form = None
        if form_instance is not None:
            form = form_instance
        elif adapter_instance is not None:
            form = adapter_instance.get_form()()

        if form is None:
            raise OperationException("Cannot prepare None form")

        form = self._fill_form_with_datatypes(form, project_id,
                                              extra_conditions)
        return form

    def _prepare_upload_post_data(self, form, post_data, project_id):
        for form_field in form.trait_fields:
            if isinstance(form_field,
                          TraitUploadField) and form_field.name in post_data:
                field = post_data[form_field.name]
                file_name = None
                if hasattr(field, 'file') and field.file is not None:
                    project = dao.get_project_by_id(project_id)
                    temporary_storage = self.storage_interface.get_temp_folder(
                        project.name)
                    try:
                        uq_name = date2string(datetime.now(),
                                              True) + '_' + str(0)
                        file_name = TEMPORARY_PREFIX + uq_name + '_' + field.filename
                        file_name = os.path.join(temporary_storage, file_name)

                        with open(file_name, 'wb') as file_obj:
                            file_obj.write(field.file.read())
                    except Exception as excep:
                        # TODO: is this handled properly?
                        self.storage_interface.remove_files([file_name])
                        excep.message = 'Could not continue: Invalid input files'
                        raise excep
                post_data[form_field.name] = file_name

    def fill_adapter_form(self, adapter_instance, post_data, project_id):
        # type: (ABCAdapter, dict, int) -> ABCAdapterForm
        form = self.prepare_adapter_form(adapter_instance=adapter_instance,
                                         project_id=project_id)
        if isinstance(form, ABCUploaderForm):
            self._prepare_upload_post_data(form, post_data, project_id)

        if 'fill_defaults' in post_data:
            form.fill_from_post_plus_defaults(post_data)
        else:
            form.fill_from_post(post_data)

        return form

    def prepare_adapter(self, stored_adapter):

        adapter_module = stored_adapter.module
        adapter_name = stored_adapter.classname
        try:
            # Prepare Adapter Interface, by populating with existent data,
            # in case of a parameter of type DataType.
            adapter_instance = ABCAdapter.build_adapter(stored_adapter)
            return adapter_instance
        except Exception:
            self.logger.exception('Not found:' + adapter_name + ' in:' +
                                  adapter_module)
            raise OperationException("Could not prepare " + adapter_name)

    @staticmethod
    def get_algorithm_by_module_and_class(module, classname):
        """
        Get the db entry from the algorithm table for the given module and 
        class.
        """
        return dao.get_algorithm_by_module(module, classname)

    @staticmethod
    def create_link(data_id, project_id):
        """
        For a dataType ID and a project ID, create the required link.
        """
        link = Links(data_id, project_id)
        dao.store_entity(link)

    @staticmethod
    def remove_link(dt_id, project_id):
        """
        Remove the link from the datatype given by dt_id to project given by project_id.
        """
        link = dao.get_link(dt_id, project_id)
        if link is not None:
            dao.remove_entity(Links, link.id)

    @staticmethod
    def get_upload_algorithms():
        """
        :return: List of StoredAdapter entities
        """
        categories = dao.get_uploader_categories()
        categories_ids = [categ.id for categ in categories]
        return dao.get_adapters_from_categories(categories_ids)

    @staticmethod
    def get_analyze_groups():
        """
        :return: list of AlgorithmTransientGroup entities
        """
        categories = dao.get_launchable_categories(elimin_viewers=True)
        categories_ids = [categ.id for categ in categories]
        stored_adapters = dao.get_adapters_from_categories(categories_ids)

        groups_list = []
        for adapter in stored_adapters:
            # For adapters without a group, fall back to the adapter's own name and description
            group = AlgorithmTransientGroup(
                adapter.group_name or adapter.displayname,
                adapter.group_description or adapter.description)
            group = AlgorithmService._find_group(groups_list, group)
            group.children.append(adapter)
        return categories[0], groups_list

    @staticmethod
    def _find_group(groups_list, new_group):
        for i in range(len(groups_list) - 1, -1, -1):
            current_group = groups_list[i]
            if current_group.name == new_group.name and current_group.description == new_group.description:
                return current_group
        # Not found in list
        groups_list.append(new_group)
        return new_group

    def get_visualizers_for_group(self, dt_group_gid):

        categories = dao.get_visualisers_categories()
        return self._get_launchable_algorithms(dt_group_gid, categories)[1]

    def get_launchable_algorithms(self, datatype_gid):
        """
        :param datatype_gid: Filter only algorithms compatible with this GUID
        :return: dict(category_name: List AlgorithmTransientGroup)
        """
        categories = dao.get_launchable_categories()
        datatype_instance, filtered_adapters, has_operations_warning = self._get_launchable_algorithms(
            datatype_gid, categories)

        categories_dict = dict()
        for c in categories:
            categories_dict[c.id] = c.displayname

        return self._group_adapters_by_category(
            filtered_adapters, categories_dict), has_operations_warning

    def _get_launchable_algorithms(self, datatype_gid, categories):
        datatype_instance = dao.get_datatype_by_gid(datatype_gid)
        return self.get_launchable_algorithms_for_datatype(
            datatype_instance, categories)

    def get_launchable_algorithms_for_datatype(self, datatype, categories):
        data_class = datatype.__class__
        all_compatible_classes = [data_class.__name__]
        for one_class in getmro(data_class):
            # from tvb.basic.traits.types_mapped import MappedType

            if issubclass(
                    one_class, DataType
            ) and one_class.__name__ not in all_compatible_classes:
                all_compatible_classes.append(one_class.__name__)

        self.logger.debug("Searching in categories: " + str(categories) +
                          " for classes " + str(all_compatible_classes))
        categories_ids = [categ.id for categ in categories]
        launchable_adapters = dao.get_applicable_adapters(
            all_compatible_classes, categories_ids)

        filtered_adapters = []
        has_operations_warning = False
        for stored_adapter in launchable_adapters:
            filter_chain = FilterChain.from_json(
                stored_adapter.datatype_filter)
            try:
                if not filter_chain or filter_chain.get_python_filter_equivalent(
                        datatype):
                    filtered_adapters.append(stored_adapter)
            except (TypeError, InvalidFilterChainInput):
                self.logger.exception("Could not evaluate filter on " +
                                      str(stored_adapter))
                has_operations_warning = True

        return datatype, filtered_adapters, has_operations_warning

    def _group_adapters_by_category(self, stored_adapters, categories):
        """
        :param stored_adapters: list StoredAdapter
        :return: dict(category_name: List AlgorithmTransientGroup), empty groups all in the same AlgorithmTransientGroup
        """
        categories_dict = dict()
        for adapter in stored_adapters:
            category_name = categories.get(adapter.fk_category)
            if category_name in categories_dict:
                groups_list = categories_dict.get(category_name)
            else:
                groups_list = []
                categories_dict[category_name] = groups_list
            group = AlgorithmTransientGroup(adapter.group_name,
                                            adapter.group_description)
            group = self._find_group(groups_list, group)
            group.children.append(adapter)
        return categories_dict

    @staticmethod
    def get_generic_entity(entity_type, filter_value, select_field):
        return dao.get_generic_entity(entity_type, filter_value, select_field)

    ##########################################################################
    ######## Methods below are for MeasurePoint selections ###################
    ##########################################################################

    @staticmethod
    def get_selections_for_project(project_id, datatype_gid):
        """
        Retrieve from DB the saved selections for the current project. A selection
        that does not have all its labels among the labels of the given connectivity
        will not be returned.
        :returns: List of ConnectivitySelection entities.
        """
        return dao.get_selections_for_project(project_id, datatype_gid)

    @staticmethod
    def save_measure_points_selection(ui_name, selected_nodes, datatype_gid,
                                      project_id):
        """
        Store in DB a ConnectivitySelection.
        """
        select_entities = dao.get_selections_for_project(
            project_id, datatype_gid, ui_name)

        if select_entities:
            # when the new selection's name already exists, update that selection:
            select_entity = select_entities[0]
            select_entity.selected_nodes = selected_nodes
        else:
            select_entity = MeasurePointsSelection(ui_name, selected_nodes,
                                                   datatype_gid, project_id)

        dao.store_entity(select_entity)

    ##########################################################################
    ##########    Below are PSE Filter-specific methods    ##################
    ##########################################################################

    @staticmethod
    def get_stored_pse_filters(datatype_group_gid):
        return dao.get_stored_pse_filters(datatype_group_gid)

    @staticmethod
    def save_pse_filter(ui_name, datatype_group_gid, threshold_value,
                        applied_on):
        """
        Store in DB a PSE filter.
        """
        select_entities = dao.get_stored_pse_filters(datatype_group_gid,
                                                     ui_name)

        if select_entities:
            # when the UI name is already in DB, update the existing entity
            select_entity = select_entities[0]
            select_entity.threshold_value = threshold_value
            select_entity.applied_on = applied_on  # this is the type, as in applied on size or color
        else:
            select_entity = StoredPSEFilter(ui_name, datatype_group_gid,
                                            threshold_value, applied_on)

        dao.store_entity(select_entity)
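
The de-duplication contract of _find_group above: two groups are merged only when both name and description match. A standalone sketch, using a dummy _Group class for illustration only:

class _Group:
    def __init__(self, name, description):
        self.name, self.description = name, description
        self.children = []

groups = []
g1 = AlgorithmService._find_group(groups, _Group('stats', 'metrics'))
g2 = AlgorithmService._find_group(groups, _Group('stats', 'metrics'))
assert g1 is g2 and len(groups) == 1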
Example #5
    def test_encrypt_decrypt(self, dir_name, file_name):

        # Generate a private key and public key
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )

        public_key = private_key.public_key()

        # Convert private key to bytes and save it so ABCUploader can use it
        pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )

        private_key_path = os.path.join(TvbProfile.current.TVB_TEMP_FOLDER, 'private_key.pem')
        with open(private_key_path, 'wb') as f:
            f.write(pem)

        path_to_file = os.path.join(os.path.dirname(tvb_data.__file__), dir_name, file_name)

        # Create model for ABCUploader
        connectivity_model = ZIPConnectivityImporterModel()

        # Generate password
        pass_size = TvbProfile.current.hpc.CRYPT_PASS_SIZE
        password = StorageInterface.generate_random_password(pass_size)

        # Encrypt files using an AES symmetric key
        encrypted_file_path = ABCUploader.get_path_to_encrypt(path_to_file)
        buffer_size = TvbProfile.current.hpc.CRYPT_BUFFER_SIZE
        pyAesCrypt.encryptFile(path_to_file, encrypted_file_path, password, buffer_size)

        # Asymmetrically encrypt the password used in the previous step for the symmetric encryption
        password = str.encode(password)
        encrypted_password = ABCUploader.encrypt_password(public_key, password)

        # Save encrypted password
        ABCUploader.save_encrypted_password(encrypted_password, TvbProfile.current.TVB_TEMP_FOLDER)
        path_to_encrypted_password = os.path.join(TvbProfile.current.TVB_TEMP_FOLDER, ENCRYPTED_PASSWORD_NAME)

        # Prepare model for decrypting
        connectivity_model.uploaded = encrypted_file_path
        connectivity_model.encrypted_aes_key = path_to_encrypted_password
        TvbProfile.current.UPLOAD_KEY_PATH = os.path.join(TvbProfile.current.TVB_TEMP_FOLDER, 'private_key.pem')

        # Decrypting
        ABCUploader._decrypt_content(connectivity_model, 'uploaded')

        decrypted_file_path = connectivity_model.uploaded.replace(ENCRYPTED_DATA_SUFFIX, DECRYPTED_DATA_SUFFIX)
        with open(path_to_file, 'rb') as f_original:
            with open(decrypted_file_path, 'rb') as f_decrypted:
                while True:
                    original_content_chunk = f_original.read(buffer_size)
                    decrypted_content_chunk = f_decrypted.read(buffer_size)

                    assert original_content_chunk == decrypted_content_chunk, \
                        "Original and Decrypted chunks are not equal, so decryption hasn't been done correctly!"

                    # check if EOF was reached
                    if len(original_content_chunk) < buffer_size:
                        break

        # Clean-up
        StorageInterface.remove_files(
            [encrypted_file_path, decrypted_file_path, private_key_path, path_to_encrypted_password])
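
This example demonstrates hybrid encryption: the file is encrypted with a symmetric AES password (pyAesCrypt), and that password is then encrypted with an RSA public key. A self-contained sketch of the asymmetric step, using the same cryptography library (OAEP padding is an assumption here; the padding inside TVB's encrypt_password is not shown in this listing):

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048,
                                       backend=default_backend())
oaep = padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()),
                    algorithm=hashes.SHA256(), label=None)
# Encrypt the symmetric password with the public key, decrypt with the private one.
encrypted = private_key.public_key().encrypt(b'example-password', oaep)
assert private_key.decrypt(encrypted, oaep) == b'example-password'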
Example #6
class ImportService(object):
    """
    Service for importing TVB entities into the system.
    It supports TVB-exported H5 files as input, but it should also handle H5 files
    generated outside of TVB, as long as they respect the same structure.
    """
    def __init__(self):
        self.logger = get_logger(__name__)
        self.user_id = None
        self.storage_interface = StorageInterface()
        self.created_projects = []
        self.view_model2adapter = self._populate_view_model2adapter()

    def _download_and_unpack_project_zip(self, uploaded, uq_file_name,
                                         temp_folder):

        if isinstance(uploaded, (FieldStorage, Part)):
            if not uploaded.file:
                raise ImportException(
                    "Please select the archive which contains the project structure."
                )
            with open(uq_file_name, 'wb') as file_obj:
                self.storage_interface.copy_file(uploaded.file, file_obj)
        else:
            shutil.copy2(uploaded, uq_file_name)

        try:
            self.storage_interface.unpack_zip(uq_file_name, temp_folder)
        except FileStructureException as excep:
            self.logger.exception(excep)
            raise ImportException(
                "Bad ZIP archive provided. A TVB exported project is expected!"
            )

    @staticmethod
    def _compute_unpack_path():
        """
        :return: the name of the folder where the uploaded ZIP will be expanded
        """
        now = datetime.now()
        date_str = "%d-%d-%d_%d-%d-%d_%d" % (now.year, now.month, now.day,
                                             now.hour, now.minute, now.second,
                                             now.microsecond)
        uq_name = "%s-ImportProject" % date_str
        return os.path.join(TvbProfile.current.TVB_TEMP_FOLDER, uq_name)
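    # Editorial note: the manual "%d-%d-%d_..." formatting above is roughly
    # equivalent to now.strftime("%Y-%m-%d_%H-%M-%S_%f") + "-ImportProject"
    # (strftime zero-pads the fields); sketch only, not the project's code.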

    @transactional
    def import_project_structure(self, uploaded, user_id):
        """
        Execute import operations:

        1. check if ZIP or folder
        2. find all project nodes
        3. for each project node:
            - create project
            - create all operations and groups
            - import all images
            - create all dataTypes
        """

        self.user_id = user_id
        self.created_projects = []

        # Now compute the name of the folder where the uploaded ZIP file will be unpacked
        temp_folder = self._compute_unpack_path()
        uq_file_name = temp_folder + ".zip"

        try:
            self._download_and_unpack_project_zip(uploaded, uq_file_name,
                                                  temp_folder)
            self._import_project_from_folder(temp_folder)

        except Exception as excep:
            self.logger.exception(
                "Error encountered during import. Deleting projects created during this operation."
            )
            # Remove project folders created so far.
            # Note that using the project service to remove the projects will not work,
            # because we do not have support for nested transactions.
            # Removing from DB is not necessary: in a transactional environment,
            # raising an exception rolls back everything that was to be inserted.
            for project in self.created_projects:
                self.storage_interface.remove_project(project)
            raise ImportException(str(excep))

        finally:
            # Now delete the uploaded file and the temporary folder where the ZIP was unpacked.
            self.storage_interface.remove_files([uq_file_name, temp_folder])

    def _import_project_from_folder(self, temp_folder):
        """
        Process each project from the uploaded pack, to extract names.
        """
        temp_project_path = None
        for root, _, files in os.walk(temp_folder):
            if StorageInterface.TVB_PROJECT_FILE in files:
                temp_project_path = root
                break

        if temp_project_path is not None:
            update_manager = ProjectUpdateManager(temp_project_path)

            if update_manager.checked_version < 3:
                raise ImportException(
                    'Importing projects with versions older than 3 is not supported in TVB 2! '
                    'Please import the project in TVB 1.5.8 and then launch the current version of '
                    'TVB in order to upgrade this project!')

            update_manager.run_all_updates()
            project = self.__populate_project(temp_project_path)
            # Populate the internal list of projects created so far, for cleaning up folders in case of failure
            self.created_projects.append(project)
            # Ensure project final folder exists on disk
            project_path = self.storage_interface.get_project_folder(
                project.name)
            shutil.move(
                os.path.join(temp_project_path,
                             StorageInterface.TVB_PROJECT_FILE), project_path)
            # Now import project operations with their results
            self.import_list_of_operations(project, temp_project_path)
            # Import images and move them from temp into target
            self._store_imported_images(project, temp_project_path,
                                        project.name)
            if StorageInterface.encryption_enabled():
                self.storage_interface.remove_project(project, True)

    def _load_datatypes_from_operation_folder(self, src_op_path,
                                              operation_entity,
                                              datatype_group):
        """
        Loads datatypes from operation folder
        :returns: Datatype entities as dict {original_path: Dt instance}
        """
        all_datatypes = {}
        for file_name in os.listdir(src_op_path):
            if self.storage_interface.ends_with_tvb_storage_file_extension(
                    file_name):
                h5_file = os.path.join(src_op_path, file_name)
                try:
                    file_update_manager = FilesUpdateManager()
                    file_update_manager.upgrade_file(h5_file)
                    datatype = self.load_datatype_from_file(
                        h5_file, operation_entity.id, datatype_group,
                        operation_entity.fk_launched_in)
                    all_datatypes[h5_file] = datatype

                except IncompatibleFileManagerException:
                    os.remove(h5_file)
                    self.logger.warning(
                        "Incompatible H5 file will be ignored: %s" % h5_file)
                    self.logger.exception("Incompatibility details ...")
        return all_datatypes

    @staticmethod
    def check_import_references(file_path, datatype):
        h5_class = H5File.h5_class_from_file(file_path)
        reference_list = h5_class(file_path).gather_references()

        for _, reference_gid in reference_list:
            if not reference_gid:
                continue

            ref_index = load.load_entity_by_gid(reference_gid)
            if ref_index is None:
                os.remove(file_path)
                dao.remove_entity(datatype.__class__, datatype.id)
                raise MissingReferenceException(
                    'Imported file depends on datatypes that do not exist. Please upload '
                    'those first!')

    def _store_or_link_burst_config(self, burst_config, bc_path, project_id):
        bc_already_in_tvb = dao.get_generic_entity(BurstConfiguration,
                                                   burst_config.gid, 'gid')
        if len(bc_already_in_tvb) == 0:
            self.store_datatype(burst_config, bc_path)
            return 1
        return 0

    def store_or_link_datatype(self, datatype, dt_path, project_id):
        self.check_import_references(dt_path, datatype)
        stored_dt_count = 0
        datatype_already_in_tvb = load.load_entity_by_gid(datatype.gid)
        if not datatype_already_in_tvb:
            self.store_datatype(datatype, dt_path)
            stored_dt_count = 1
        elif datatype_already_in_tvb.parent_operation.project.id != project_id:
            AlgorithmService.create_link(datatype_already_in_tvb.id,
                                         project_id)
            if datatype_already_in_tvb.fk_datatype_group:
                AlgorithmService.create_link(
                    datatype_already_in_tvb.fk_datatype_group, project_id)
        return stored_dt_count

    def _store_imported_datatypes_in_db(self, project, all_datatypes):
        # type: (Project, dict) -> int
        sorted_dts = sorted(
            all_datatypes.items(),
            key=lambda dt_item: dt_item[1].create_date or datetime.now())
        count = 0
        for dt_path, datatype in sorted_dts:
            count += self.store_or_link_datatype(datatype, dt_path, project.id)
        return count

    def _store_imported_images(self, project, temp_project_path, project_name):
        """
        Import all images from project
        """
        images_root = os.path.join(temp_project_path,
                                   StorageInterface.IMAGES_FOLDER)
        target_images_path = self.storage_interface.get_images_folder(
            project_name)
        for root, _, files in os.walk(images_root):
            for metadata_file in files:
                if self.storage_interface.ends_with_tvb_file_extension(
                        metadata_file):
                    self._import_image(root, metadata_file, project.id,
                                       target_images_path)

    @staticmethod
    def _populate_view_model2adapter():
        if len(VIEW_MODEL2ADAPTER) > 0:
            return VIEW_MODEL2ADAPTER
        view_model2adapter = {}
        algos = dao.get_all_algorithms()
        for algo in algos:
            adapter = ABCAdapter.build_adapter(algo)
            view_model_class = adapter.get_view_model_class()
            view_model2adapter[view_model_class] = algo
        return view_model2adapter

    def _retrieve_operations_in_order(self,
                                      project,
                                      import_path,
                                      importer_operation_id=None):
        # type: (Project, str, int) -> list[Operation2ImportData]
        retrieved_operations = []

        for root, _, files in os.walk(import_path):
            if OPERATION_XML in files:
                # Old Operation XML format, used by previous versions of exported projects
                operation_file_path = os.path.join(root, OPERATION_XML)
                operation, operation_xml_parameters, _ = self.build_operation_from_file(
                    project, operation_file_path)
                operation.import_file = operation_file_path
                self.logger.debug("Found operation in old XML format: " +
                                  str(operation))
                retrieved_operations.append(
                    Operation2ImportData(
                        operation,
                        root,
                        info_from_xml=operation_xml_parameters))

            else:
                # Otherwise we expect the new format, with ViewModelH5 files
                main_view_model = None
                dt_paths = []
                all_view_model_files = []
                for file in files:
                    if self.storage_interface.ends_with_tvb_storage_file_extension(
                            file):
                        h5_file = os.path.join(root, file)
                        try:
                            h5_class = H5File.h5_class_from_file(h5_file)
                            if h5_class is ViewModelH5:
                                all_view_model_files.append(h5_file)
                                if not main_view_model:
                                    view_model = h5.load_view_model_from_file(
                                        h5_file)
                                    if type(view_model) in self.view_model2adapter:
                                        main_view_model = view_model
                            else:
                                file_update_manager = FilesUpdateManager()
                                file_update_manager.upgrade_file(h5_file)
                                dt_paths.append(h5_file)
                        except Exception:
                            self.logger.warning(
                                "Unreadable H5 file will be ignored: %s" %
                                h5_file)

                if main_view_model is not None:
                    alg = self.view_model2adapter[type(main_view_model)]
                    op_group_id = None
                    if main_view_model.operation_group_gid:
                        op_group = dao.get_operationgroup_by_gid(
                            main_view_model.operation_group_gid.hex)
                        if not op_group:
                            op_group = OperationGroup(
                                project.id,
                                ranges=json.loads(main_view_model.ranges),
                                gid=main_view_model.operation_group_gid.hex)
                            op_group = dao.store_entity(op_group)
                        op_group_id = op_group.id
                    operation = Operation(
                        main_view_model.gid.hex,
                        project.fk_admin,
                        project.id,
                        alg.id,
                        status=STATUS_FINISHED,
                        user_group=main_view_model.generic_attributes.operation_tag,
                        start_date=datetime.now(),
                        completion_date=datetime.now(),
                        op_group_id=op_group_id,
                        range_values=main_view_model.range_values)
                    operation.create_date = main_view_model.create_date
                    operation.visible = main_view_model.generic_attributes.visible
                    self.logger.debug(
                        "Found main ViewModel to create operation for it: " +
                        str(operation))

                    retrieved_operations.append(
                        Operation2ImportData(operation, root, main_view_model,
                                             dt_paths, all_view_model_files))

                elif len(dt_paths) > 0:
                    alg = dao.get_algorithm_by_module(TVB_IMPORTER_MODULE,
                                                      TVB_IMPORTER_CLASS)
                    default_adapter = ABCAdapter.build_adapter(alg)
                    view_model = default_adapter.get_view_model_class()()
                    view_model.data_file = dt_paths[0]
                    vm_path = h5.store_view_model(view_model, root)
                    all_view_model_files.append(vm_path)
                    operation = Operation(view_model.gid.hex,
                                          project.fk_admin,
                                          project.id,
                                          alg.id,
                                          status=STATUS_FINISHED,
                                          start_date=datetime.now(),
                                          completion_date=datetime.now())
                    operation.create_date = datetime.min
                    self.logger.debug(
                        "Found no ViewModel in folder, so we default to " +
                        str(operation))

                    if importer_operation_id:
                        operation.id = importer_operation_id

                    retrieved_operations.append(
                        Operation2ImportData(operation, root, view_model,
                                             dt_paths, all_view_model_files,
                                             True))

        return sorted(retrieved_operations,
                      key=lambda op_data: op_data.order_field)

    def create_view_model(self,
                          operation_entity,
                          operation_data,
                          new_op_folder,
                          generic_attributes=None,
                          add_params=None):
        view_model = self._get_new_form_view_model(
            operation_entity, operation_data.info_from_xml)
        if add_params is not None:
            for element in add_params:
                key_attr = getattr(view_model, element[0])
                setattr(key_attr, element[1], element[2])

        view_model.range_values = operation_entity.range_values
        op_group = dao.get_operationgroup_by_id(
            operation_entity.fk_operation_group)
        if op_group:
            view_model.operation_group_gid = uuid.UUID(op_group.gid)
            view_model.ranges = json.dumps(op_group.range_references)
            view_model.is_metric_operation = 'DatatypeMeasure' in op_group.name

        if generic_attributes is not None:
            view_model.generic_attributes = generic_attributes
        view_model.generic_attributes.operation_tag = operation_entity.user_group

        h5.store_view_model(view_model, new_op_folder)
        view_model_disk_size = StorageInterface.compute_recursive_h5_disk_usage(
            new_op_folder)
        operation_entity.view_model_disk_size = view_model_disk_size
        operation_entity.view_model_gid = view_model.gid.hex
        dao.store_entity(operation_entity)
        return view_model

    def import_list_of_operations(self,
                                  project,
                                  import_path,
                                  is_group=False,
                                  importer_operation_id=None):
        """
        This method scans the provided folder and identifies all operations that need to be imported
        """
        all_dts_count = 0
        all_stored_dts_count = 0
        imported_operations = []
        ordered_operations = self._retrieve_operations_in_order(
            project, import_path, None if is_group else importer_operation_id)

        if is_group and len(ordered_operations) > 0:
            first_op = dao.get_operation_by_id(importer_operation_id)
            vm_path = h5.determine_filepath(first_op.view_model_gid,
                                            os.path.dirname(import_path))
            os.remove(vm_path)

            ordered_operations[0].operation.id = importer_operation_id

        for operation_data in ordered_operations:
            if operation_data.is_old_form:
                operation_entity, datatype_group = self.import_operation(
                    operation_data.operation)
                new_op_folder = self.storage_interface.get_project_folder(
                    project.name, str(operation_entity.id))

                try:
                    operation_datatypes = self._load_datatypes_from_operation_folder(
                        operation_data.operation_folder, operation_entity,
                        datatype_group)
                    # Create and store view_model from operation
                    self.create_view_model(operation_entity, operation_data,
                                           new_op_folder)

                    self._store_imported_datatypes_in_db(
                        project, operation_datatypes)
                    imported_operations.append(operation_entity)
                except MissingReferenceException:
                    operation_entity.status = STATUS_ERROR
                    dao.store_entity(operation_entity)

            elif operation_data.main_view_model is not None:
                operation_data.operation.create_date = datetime.now()
                operation_data.operation.start_date = datetime.now()
                operation_data.operation.completion_date = datetime.now()

                do_merge = False
                if importer_operation_id:
                    do_merge = True
                operation_entity = dao.store_entity(operation_data.operation,
                                                    merge=do_merge)
                dt_group = None
                op_group = dao.get_operationgroup_by_id(
                    operation_entity.fk_operation_group)
                if op_group:
                    dt_group = dao.get_datatypegroup_by_op_group_id(
                        op_group.id)
                    if not dt_group:
                        first_op = dao.get_operations_in_group(
                            op_group.id, only_first_operation=True)
                        dt_group = DataTypeGroup(
                            op_group,
                            operation_id=first_op.id,
                            state=DEFAULTDATASTATE_INTERMEDIATE)
                        dt_group = dao.store_entity(dt_group)
                # Store the DataTypes in db
                dts = {}
                all_dts_count += len(operation_data.dt_paths)
                for dt_path in operation_data.dt_paths:
                    dt = self.load_datatype_from_file(dt_path,
                                                      operation_entity.id,
                                                      dt_group, project.id)
                    if isinstance(dt, BurstConfiguration):
                        if op_group:
                            dt.fk_operation_group = op_group.id
                        all_stored_dts_count += self._store_or_link_burst_config(
                            dt, dt_path, project.id)
                    else:
                        dts[dt_path] = dt
                        if op_group:
                            op_group.fill_operationgroup_name(dt.type)
                            dao.store_entity(op_group)
                try:
                    stored_dts_count = self._store_imported_datatypes_in_db(
                        project, dts)
                    all_stored_dts_count += stored_dts_count

                    if operation_data.main_view_model.is_metric_operation:
                        self._update_burst_metric(operation_entity)

                    imported_operations.append(operation_entity)
                    new_op_folder = self.storage_interface.get_project_folder(
                        project.name, str(operation_entity.id))
                    view_model_disk_size = 0
                    for h5_file in operation_data.all_view_model_files:
                        view_model_disk_size += StorageInterface.compute_size_on_disk(
                            h5_file)
                        shutil.move(h5_file, new_op_folder)
                    operation_entity.view_model_disk_size = view_model_disk_size
                    dao.store_entity(operation_entity)
                except MissingReferenceException as excep:
                    self.storage_interface.remove_operation_data(
                        project.name, operation_entity.id)
                    operation_entity.fk_operation_group = None
                    dao.store_entity(operation_entity)
                    dao.remove_entity(DataTypeGroup, dt_group.id)
                    raise excep
            else:
                self.logger.warning(
                    "Folder %s will be ignored, as we could not find a serialized "
                    "operation or DTs inside!" %
                    operation_data.operation_folder)

            # We want importer_operation_id to be kept just for the first operation (the first iteration)
            if is_group:
                importer_operation_id = None

        self._update_dt_groups(project.id)
        self._update_burst_configurations(project.id)
        return imported_operations, all_dts_count, all_stored_dts_count

    @staticmethod
    def _get_new_form_view_model(operation, xml_parameters):
        # type: (Operation) -> ViewModel
        algo = dao.get_algorithm_by_id(operation.fk_from_algo)
        ad = ABCAdapter.build_adapter(algo)
        view_model = ad.get_view_model_class()()

        if xml_parameters:
            declarative_attrs = type(view_model).declarative_attrs

            if isinstance(xml_parameters, str):
                xml_parameters = json.loads(xml_parameters)
            for param in xml_parameters:
                new_param_name = param
                if param != '' and param[0] == "_":
                    new_param_name = param[1:]
                new_param_name = new_param_name.lower()
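                # e.g. a legacy XML key "_Monitors" would be normalized to
                # "monitors" before the setattr below (hypothetical key name,
                # for illustration only).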
                if new_param_name in declarative_attrs:
                    try:
                        setattr(view_model, new_param_name,
                                xml_parameters[param])
                    except (TraitTypeError, TraitAttributeError):
                        pass
        return view_model

    def _import_image(self, src_folder, metadata_file, project_id,
                      target_images_path):
        """
        Create and store an image entity.
        """
        figure_dict = StorageInterface().read_metadata_from_xml(
            os.path.join(src_folder, metadata_file))
        actual_figure = os.path.join(
            src_folder,
            os.path.split(figure_dict['file_path'])[1])
        if not os.path.exists(actual_figure):
            self.logger.warning("Expected to find image path %s. Skipping" %
                                actual_figure)
            return
        figure_dict['fk_user_id'] = self.user_id
        figure_dict['fk_project_id'] = project_id
        figure_entity = manager_of_class(ResultFigure).new_instance()
        figure_entity = figure_entity.from_dict(figure_dict)
        stored_entity = dao.store_entity(figure_entity)

        # Update image meta-data with the new details after import
        figure = dao.load_figure(stored_entity.id)
        shutil.move(actual_figure, target_images_path)
        self.logger.debug("Store imported figure")
        _, meta_data = figure.to_dict()
        self.storage_interface.write_image_metadata(figure, meta_data)

    def load_datatype_from_file(self,
                                current_file,
                                op_id,
                                datatype_group=None,
                                current_project_id=None):
        # type: (str, int, DataTypeGroup, int) -> HasTraitsIndex
        """
        Creates an instance of datatype from storage / H5 file
        :returns: DatatypeIndex
        """
        self.logger.debug("Loading DataType from file: %s" % current_file)
        h5_class = H5File.h5_class_from_file(current_file)

        if h5_class is BurstConfigurationH5:
            if current_project_id is None:
                op_entity = dao.get_operationgroup_by_id(op_id)
                current_project_id = op_entity.fk_launched_in
            h5_file = BurstConfigurationH5(current_file)
            burst = BurstConfiguration(current_project_id)
            burst.fk_simulation = op_id
            h5_file.load_into(burst)
            result = burst
        else:
            datatype, generic_attributes = h5.load_with_links(current_file)

            already_existing_datatype = h5.load_entity_by_gid(datatype.gid)
            if datatype_group is not None and already_existing_datatype is not None:
                raise DatatypeGroupImportException(
                    "The datatype group that you are trying to import"
                    " already exists!")
            index_class = h5.REGISTRY.get_index_for_datatype(
                datatype.__class__)
            datatype_index = index_class()
            datatype_index.fill_from_has_traits(datatype)
            datatype_index.fill_from_generic_attributes(generic_attributes)

            if datatype_group is not None and hasattr(datatype_index, 'fk_source_gid') and \
                    datatype_index.fk_source_gid is not None:
                ts = h5.load_entity_by_gid(datatype_index.fk_source_gid)

                if ts is None:
                    op = dao.get_operations_in_group(
                        datatype_group.fk_operation_group,
                        only_first_operation=True)
                    op.fk_operation_group = None
                    dao.store_entity(op)
                    dao.remove_entity(OperationGroup,
                                      datatype_group.fk_operation_group)
                    dao.remove_entity(DataTypeGroup, datatype_group.id)
                    raise DatatypeGroupImportException(
                        "Please import the time series group before importing the"
                        " datatype measure group!")

            # Add all the required attributes
            if datatype_group:
                datatype_index.fk_datatype_group = datatype_group.id
                if len(datatype_group.subject) == 0:
                    datatype_group.subject = datatype_index.subject
                    dao.store_entity(datatype_group)
            datatype_index.fk_from_operation = op_id

            associated_file = h5.path_for_stored_index(datatype_index)
            if os.path.exists(associated_file):
                datatype_index.disk_size = StorageInterface.compute_size_on_disk(
                    associated_file)
            result = datatype_index

        return result
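
    # Hypothetical pairing of the loader above with store_datatype below: a
    # regular H5 file yields an index row ready for the DB, while a
    # BurstConfigurationH5 yields a BurstConfiguration instead.
    #
    #   idx = service.load_datatype_from_file(h5_path, op_id)
    #   service.store_datatype(idx, current_file=h5_path)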

    def store_datatype(self, datatype, current_file=None):
        """This method stores data type into DB"""
        try:
            self.logger.debug("Store datatype: %s with Gid: %s" %
                              (datatype.__class__.__name__, datatype.gid))
            # Now move storage file into correct folder if necessary
            if current_file is not None:
                final_path = h5.path_for_stored_index(datatype)
                if final_path != current_file:
                    shutil.move(current_file, final_path)
            stored_entry = load.load_entity_by_gid(datatype.gid)
            if not stored_entry:
                stored_entry = dao.store_entity(datatype)

            return stored_entry
        except MissingDataSetException as e:
            self.logger.exception(e)
            error_msg = "Datatype %s has missing data and could not be imported properly." % (
                datatype, )
            raise ImportException(error_msg)
        except IntegrityError as excep:
            self.logger.exception(excep)
            error_msg = "Could not import data with gid: %s. There is already a one with " \
                        "the same name or gid." % datatype.gid
            raise ImportException(error_msg)

    def __populate_project(self, project_path):
        """
        Create and store a Project entity.
        """
        self.logger.debug("Creating project from path: %s" % project_path)
        project_dict = self.storage_interface.read_project_metadata(
            project_path)

        project_entity = manager_of_class(Project).new_instance()
        project_entity = project_entity.from_dict(project_dict, self.user_id)

        try:
            self.logger.debug("Storing imported project")
            return dao.store_entity(project_entity)
        except IntegrityError as excep:
            self.logger.exception(excep)
            error_msg = (
                "Could not import project: %s with gid: %s. There is already a "
                "project with the same name or gid.") % (project_entity.name,
                                                         project_entity.gid)
            raise ImportException(error_msg)

    def build_operation_from_file(self, project, operation_file):
        """
        Create Operation entity from metadata file.
        """
        operation_dict = StorageInterface().read_metadata_from_xml(
            operation_file)
        operation_entity = manager_of_class(Operation).new_instance()
        return operation_entity.from_dict(operation_dict, dao, self.user_id,
                                          project.gid)

    @staticmethod
    def import_operation(operation_entity, migration=False):
        """
        Store an Operation entity.
        """
        do_merge = False
        if operation_entity.id:
            do_merge = True
        operation_entity = dao.store_entity(operation_entity, merge=do_merge)
        operation_group_id = operation_entity.fk_operation_group
        datatype_group = None

        if operation_group_id is not None:
            datatype_group = dao.get_datatypegroup_by_op_group_id(
                operation_group_id)

            if datatype_group is None and migration is False:
                # If no DataTypeGroup is present for the current operation group, create one.
                operation_group = dao.get_operationgroup_by_id(
                    operation_group_id)
                datatype_group = DataTypeGroup(
                    operation_group, operation_id=operation_entity.id)
                datatype_group.state = UploadAlgorithmCategoryConfig.defaultdatastate
                datatype_group = dao.store_entity(datatype_group)

        return operation_entity, datatype_group
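
    # Sketch of the merge rule above (hypothetical flow; the enclosing service
    # class is referred to as ImportService here): an Operation arriving with a
    # DB id is stored with merge=True, so a re-import updates the existing row
    # instead of inserting a duplicate.
    #
    #   op_entity, dt_group = ImportService.import_operation(existing_op)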

    def import_simulator_configuration_zip(self, zip_file):
        # Compute the folder where the uploaded ZIP file will be unpacked
        temp_folder = self._compute_unpack_path()
        uq_file_name = temp_folder + ".zip"

        if isinstance(zip_file, (FieldStorage, Part)):
            if not zip_file.file:
                raise ServicesBaseException(
                    "Could not process the given ZIP file...")

            with open(uq_file_name, 'wb') as file_obj:
                self.storage_interface.copy_file(zip_file.file, file_obj)
        else:
            shutil.copy2(zip_file, uq_file_name)

        try:
            self.storage_interface.unpack_zip(uq_file_name, temp_folder)
            return temp_folder
        except FileStructureException as excep:
            raise ServicesBaseException(
                "Could not process the given ZIP file: " + str(excep))

    @staticmethod
    def _update_burst_metric(operation_entity):
        burst_config = dao.get_burst_for_operation_id(operation_entity.id)
        if burst_config and burst_config.ranges:
            if burst_config.fk_metric_operation_group is None:
                burst_config.fk_metric_operation_group = operation_entity.fk_operation_group
            dao.store_entity(burst_config)

    @staticmethod
    def _update_dt_groups(project_id):
        dt_groups = dao.get_datatypegroup_for_project(project_id)
        for dt_group in dt_groups:
            dt_group.count_results = dao.count_datatypes_in_group(dt_group.id)
            dts_in_group = dao.get_datatypes_from_datatype_group(dt_group.id)
            if dts_in_group:
                dt_group.fk_parent_burst = dts_in_group[0].fk_parent_burst
            dao.store_entity(dt_group)

    @staticmethod
    def _update_burst_configurations(project_id):
        burst_configs = dao.get_bursts_for_project(project_id)
        for burst_config in burst_configs:
            burst_config.datatypes_number = dao.count_datatypes_in_burst(
                burst_config.gid)
            dao.store_entity(burst_config)
Beispiel #7
0
    def test_encrypt_decrypt(self, dir_name, file_name):
        import_export_encryption_handler = StorageInterface.get_import_export_encryption_handler()

        # Generate a private key and public key
        private_key = rsa.generate_private_key(public_exponent=65537,
                                               key_size=2048,
                                               backend=default_backend())

        public_key = private_key.public_key()

        # Convert private key to bytes and save it so ABCUploader can use it
        pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption())

        private_key_path = os.path.join(
            TvbProfile.current.TVB_TEMP_FOLDER,
            import_export_encryption_handler.PRIVATE_KEY_NAME)
        with open(private_key_path, 'wb') as f:
            f.write(pem)

        path_to_file = os.path.join(os.path.dirname(tvb_data.__file__),
                                    dir_name, file_name)

        # Create model for ABCUploader
        connectivity_model = ZIPConnectivityImporterModel()

        # Generate password
        pass_size = TvbProfile.current.hpc.CRYPT_PASS_SIZE
        password = EncryptionHandler.generate_random_password(pass_size)

        # Encrypt the file using an AES symmetric key
        encrypted_file_path = import_export_encryption_handler.get_path_to_encrypt(
            path_to_file)
        buffer_size = TvbProfile.current.hpc.CRYPT_BUFFER_SIZE
        pyAesCrypt.encryptFile(path_to_file, encrypted_file_path, password,
                               buffer_size)

        # Asymmetrically encrypt (with the RSA public key) the password used at the previous step for the symmetric encryption
        password = str.encode(password)
        encrypted_password = import_export_encryption_handler.encrypt_password(
            public_key, password)

        # Save encrypted password
        import_export_encryption_handler.save_encrypted_password(
            encrypted_password, TvbProfile.current.TVB_TEMP_FOLDER)
        path_to_encrypted_password = os.path.join(
            TvbProfile.current.TVB_TEMP_FOLDER,
            import_export_encryption_handler.ENCRYPTED_PASSWORD_NAME)

        # Prepare model for decrypting
        connectivity_model.uploaded = encrypted_file_path
        connectivity_model.encrypted_aes_key = path_to_encrypted_password
        TvbProfile.current.UPLOAD_KEY_PATH = os.path.join(
            TvbProfile.current.TVB_TEMP_FOLDER,
            import_export_encryption_handler.PRIVATE_KEY_NAME)

        # Decrypt the file using the private key and the encrypted AES password
        decrypted_download_path = import_export_encryption_handler.decrypt_content(
            connectivity_model.encrypted_aes_key,
            [connectivity_model.uploaded],
            TvbProfile.current.UPLOAD_KEY_PATH)[0]
        connectivity_model.uploaded = decrypted_download_path

        storage_interface = StorageInterface()
        decrypted_file_path = connectivity_model.uploaded.replace(
            import_export_encryption_handler.ENCRYPTED_DATA_SUFFIX,
            import_export_encryption_handler.DECRYPTED_DATA_SUFFIX)

        TestExporters.compare_files(path_to_file, decrypted_file_path)

        # Clean-up
        storage_interface.remove_files([
            encrypted_file_path, decrypted_file_path, private_key_path,
            path_to_encrypted_password
        ])
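
# A self-contained sketch (not TVB code; file names, password length and the
# OAEP/SHA-256 padding are assumptions) of the hybrid scheme the test above
# exercises: the payload is encrypted with a symmetric AES password via
# pyAesCrypt, and that password is then encrypted asymmetrically with RSA.
import secrets

import pyAesCrypt
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

BUFFER_SIZE = 64 * 1024
OAEP = padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()),
                    algorithm=hashes.SHA256(), label=None)

# Key pair: the public half protects the AES password, the private half
# recovers it on the receiving side.
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()

# Hypothetical payload, created here so the sketch runs end-to-end.
with open('payload.bin', 'wb') as f:
    f.write(b'example-content')

# Sender side: AES-encrypt the payload, then RSA-encrypt the AES password.
password = secrets.token_urlsafe(16)
pyAesCrypt.encryptFile('payload.bin', 'payload.bin.aes', password, BUFFER_SIZE)
encrypted_password = public_key.encrypt(password.encode(), OAEP)

# Receiver side: decrypt the password first, then the payload.
recovered_password = private_key.decrypt(encrypted_password, OAEP).decode()
pyAesCrypt.decryptFile('payload.bin.aes', 'payload.bin.dec',
                       recovered_password, BUFFER_SIZE)
assert open('payload.bin.dec', 'rb').read() == b'example-content'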