Example #1
    def run(self):
        """
        Get the required data from the operation queue and launch the operation.
        """
        operation_id = self.operation_id
        run_params = [TvbProfile.current.PYTHON_INTERPRETER_PATH, '-m', 'tvb.core.operation_async_launcher',
                      str(operation_id), TvbProfile.CURRENT_PROFILE_NAME]

        current_operation = dao.get_operation_by_id(operation_id)
        storage_interface = StorageInterface()
        project_folder = storage_interface.get_project_folder(current_operation.project.name)
        in_usage = storage_interface.is_in_usage(project_folder)
        storage_interface.inc_running_op_count(project_folder)
        if not in_usage:
            storage_interface.sync_folders(project_folder)
        # In the exceptional case where the user pressed stop while the thread was still starting up,
        # we should no longer launch the operation.
        if self.stopped() is False:

            env = os.environ.copy()
            env['PYTHONPATH'] = os.pathsep.join(sys.path)
            # anything that was already in $PYTHONPATH should have been reproduced in sys.path

            launched_process = Popen(run_params, stdout=PIPE, stderr=PIPE, env=env)

            LOGGER.debug("Storing pid=%s for operation id=%s launched on local machine." % (operation_id,
                                                                                            launched_process.pid))
            op_ident = OperationProcessIdentifier(operation_id, pid=launched_process.pid)
            dao.store_entity(op_ident)

            if self.stopped():
                # In the exceptional case where the user pressed stop right after the process was launched,
                # while stop_operation may concurrently be querying the OperationProcessIdentifier.
                self.stop_pid(launched_process.pid)

            subprocess_result = launched_process.communicate()
            LOGGER.info("Finished with launch of operation %s" % operation_id)
            returned = launched_process.wait()

            LOGGER.info("Return code: {}. Stopped: {}".format(returned, self.stopped()))
            LOGGER.info("Thread: {}".format(self))
            if returned != 0 and not self.stopped():
                # Process did not end as expected. (e.g. Segmentation fault)
                burst_service = BurstService()
                operation = dao.get_operation_by_id(self.operation_id)
                LOGGER.error("Operation suffered fatal failure! Exit code: %s Exit message: %s" % (returned,
                                                                                                   subprocess_result))
                burst_service.persist_operation_state(operation, STATUS_ERROR,
                                                      "Operation failed unexpectedly! Please check the log files.")

            del launched_process

        storage_interface.check_and_delete(project_folder)

        # Give back the empty slot now that this operation has finished
        CURRENT_ACTIVE_THREADS.remove(self)
        LOCKS_QUEUE.put(1)
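
The PYTHONPATH handling above (copying os.environ and serializing sys.path into it) is what allows the child interpreter started via Popen to import the same packages as the parent process. A minimal standalone sketch of that pattern, with no TVB dependencies and a placeholder -c payload:

import os
import sys
from subprocess import PIPE, Popen

# Reproduce the parent's import paths for the child interpreter.
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)

# Launch a child Python process and let it report the paths it inherited.
child = Popen([sys.executable, '-c', 'import sys; print(len(sys.path))'],
              stdout=PIPE, stderr=PIPE, env=env)
out, _ = child.communicate()
print("child exit code:", child.returncode)
print("entries on child sys.path:", out.decode().strip())
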
Example #2
class SimulatorService(object):
    def __init__(self):
        self.logger = get_logger(self.__class__.__module__)
        self.burst_service = BurstService()
        self.operation_service = OperationService()
        self.algorithm_service = AlgorithmService()
        self.storage_interface = StorageInterface()

    @staticmethod
    def _reset_model(session_stored_simulator):
        session_stored_simulator.model = type(session_stored_simulator.model)()
        vi_indexes = session_stored_simulator.determine_indexes_for_chosen_vars_of_interest()
        vi_indexes = numpy.array(list(vi_indexes.values()))
        for monitor in session_stored_simulator.monitors:
            monitor.variables_of_interest = vi_indexes

    def reset_at_connectivity_change(self, is_simulator_copy, form,
                                     session_stored_simulator):
        """
        In case the user copies a simulation and changes the Connectivity, we want to reset the Model and Noise
        parameters because they might not fit the new Connectivity's number of regions.
        """
        if is_simulator_copy and form.connectivity.value != session_stored_simulator.connectivity:
            self._reset_model(session_stored_simulator)
            if issubclass(type(session_stored_simulator.integrator),
                          IntegratorStochastic):
                session_stored_simulator.integrator.noise = type(
                    session_stored_simulator.integrator.noise)()

    def reset_at_surface_change(self, is_simulator_copy, form,
                                session_stored_simulator):
        """
        In case the user copies a surface-simulation and changes the Surface, we want to reset the Model
        parameters because they might not fit the new Surface's number of vertices.
        """
        if is_simulator_copy and (
                session_stored_simulator.surface is None and form.surface.value
                or session_stored_simulator.surface and form.surface.value !=
                session_stored_simulator.surface.surface_gid):
            self._reset_model(session_stored_simulator)

    @staticmethod
    def _set_simulator_range_parameter(simulator, range_parameter_name,
                                       range_parameter_value):
        range_param_name_list = range_parameter_name.split('.')
        current_attr = simulator
        for param_name in range_param_name_list[:-1]:
            current_attr = getattr(current_attr, param_name)
        setattr(current_attr, range_param_name_list[-1], range_parameter_value)

    def async_launch_and_prepare_simulation(self, burst_config, user, project,
                                            simulator_algo, simulator):
        try:
            operation = self.operation_service.prepare_operation(
                user.id,
                project,
                simulator_algo,
                view_model=simulator,
                burst_gid=burst_config.gid,
                op_group_id=burst_config.fk_operation_group)
            burst_config = self.burst_service.update_simulation_fields(
                burst_config, operation.id, simulator.gid)
            storage_path = self.storage_interface.get_project_folder(
                project.name, str(operation.id))
            self.burst_service.store_burst_configuration(
                burst_config, storage_path)

            wf_errs = 0
            try:
                OperationService().launch_operation(operation.id, True)
                return operation
            except Exception as excep:
                self.logger.error(excep)
                wf_errs += 1
                if burst_config:
                    self.burst_service.mark_burst_finished(
                        burst_config, error_message=str(excep))

            self.logger.debug("Finished launching the simulation operation; %d error(s) occurred "
                              "on pre-launch steps." % wf_errs)

        except Exception as excep:
            self.logger.error(excep)
            if burst_config:
                self.burst_service.mark_burst_finished(
                    burst_config, error_message=str(excep))

    def prepare_simulation_on_server(self, user_id, project, algorithm,
                                     zip_folder_path, simulator_file):
        simulator_vm = h5.load_view_model_from_file(simulator_file)
        operation = self.operation_service.prepare_operation(
            user_id, project, algorithm, view_model=simulator_vm)
        self.async_launch_simulation_on_server(operation, zip_folder_path)

        return operation

    def async_launch_simulation_on_server(self, operation, zip_folder_path):
        try:
            OperationService().launch_operation(operation.id, True)
            return operation
        except Exception as excep:
            self.logger.error(excep)
        finally:
            shutil.rmtree(zip_folder_path)

    @staticmethod
    def _set_range_param_in_dict(param_value):
        if type(param_value) is numpy.ndarray:
            return param_value[0]
        elif isinstance(param_value, uuid.UUID):
            return param_value.hex
        else:
            return param_value

    def async_launch_and_prepare_pse(self, burst_config, user, project,
                                     simulator_algo, range_param1,
                                     range_param2, session_stored_simulator):
        try:
            algo_category = simulator_algo.algorithm_category
            operation_group = burst_config.operation_group
            metric_operation_group = burst_config.metric_operation_group
            range_param2_values = [None]
            if range_param2:
                range_param2_values = range_param2.get_range_values()
            GROUP_BURST_PENDING[burst_config.id] = True
            operations, pse_canceled = self._prepare_operations(
                algo_category, burst_config, metric_operation_group,
                operation_group, project, range_param1, range_param2,
                range_param2_values, session_stored_simulator, simulator_algo,
                user)

            GROUP_BURST_PENDING[burst_config.id] = False
            if pse_canceled:
                return

            wf_errs = self._launch_operations(operations, burst_config)
            self.logger.debug("Finished launching workflows. " +
                              str(len(operations) - wf_errs) +
                              " were launched successfully, " + str(wf_errs) +
                              " had error on pre-launch steps")
            return operations[0] if len(operations) > 0 else None

        except Exception as excep:
            self.logger.error(excep)
            self.burst_service.mark_burst_finished(burst_config,
                                                   error_message=str(excep))

    def _launch_operations(self, operations, burst_config):
        wf_errs = 0
        for operation in operations:
            try:
                burst_config = dao.get_burst_by_id(burst_config.id)
                if burst_config is None or burst_config.status in [
                        BurstConfiguration.BURST_CANCELED,
                        BurstConfiguration.BURST_ERROR
                ]:
                    self.logger.debug(
                        "Preparing operations cannot continue. Burst config {}"
                        .format(burst_config))
                    return
                OperationService().launch_operation(operation.id, True)
            except Exception as excep:
                self.logger.error(excep)
                wf_errs += 1
                self.burst_service.mark_burst_finished(
                    burst_config, error_message=str(excep))
        return wf_errs

    def _prepare_operations(self, algo_category, burst_config,
                            metric_operation_group, operation_group, project,
                            range_param1, range_param2, range_param2_values,
                            session_stored_simulator, simulator_algo, user):
        first_simulator = None
        pse_canceled = False
        operations = []
        for param1_value in range_param1.get_range_values():
            for param2_value in range_param2_values:
                burst_config = dao.get_burst_by_id(burst_config.id)
                if burst_config is None:
                    self.logger.debug("Burst config was deleted")
                    pse_canceled = True
                    break

                if burst_config.status in [
                        BurstConfiguration.BURST_CANCELED,
                        BurstConfiguration.BURST_ERROR
                ]:
                    self.logger.debug(
                        "Current burst status is {}. Preparing operations cannot continue."
                        .format(burst_config.status))
                    pse_canceled = True
                    break
                # Copy, but generate a new GUID for every Simulator in PSE
                simulator = copy.deepcopy(session_stored_simulator)
                simulator.gid = uuid.uuid4()
                self._set_simulator_range_parameter(simulator,
                                                    range_param1.name,
                                                    param1_value)

                ranges = {
                    range_param1.name:
                    self._set_range_param_in_dict(param1_value)
                }

                if param2_value is not None:
                    self._set_simulator_range_parameter(
                        simulator, range_param2.name, param2_value)
                    ranges[range_param2.name] = self._set_range_param_in_dict(
                        param2_value)

                ranges = json.dumps(ranges)

                operation = self.operation_service.prepare_operation(
                    user.id,
                    project,
                    simulator_algo,
                    view_model=simulator,
                    ranges=ranges,
                    burst_gid=burst_config.gid,
                    op_group_id=burst_config.fk_operation_group)
                simulator.range_values = ranges
                operations.append(operation)
                if first_simulator is None:
                    first_simulator = simulator
                    storage_path = self.storage_interface.get_project_folder(
                        project.name, str(operation.id))
                    burst_config = self.burst_service.update_simulation_fields(
                        burst_config, operation.id, first_simulator.gid)
                    self.burst_service.store_burst_configuration(
                        burst_config, storage_path)
                    datatype_group = DataTypeGroup(
                        operation_group,
                        operation_id=operation.id,
                        fk_parent_burst=burst_config.gid,
                        state=algo_category.defaultdatastate)
                    dao.store_entity(datatype_group)

                    metrics_datatype_group = DataTypeGroup(
                        metric_operation_group,
                        fk_parent_burst=burst_config.gid,
                        state=algo_category.defaultdatastate)
                    dao.store_entity(metrics_datatype_group)
        return operations, pse_canceled

    @staticmethod
    def compute_conn_branch_conditions(is_branch, simulator):
        if not is_branch:
            return None

        conn = load.load_entity_by_gid(simulator.connectivity)
        if conn.number_of_regions:
            return FilterChain(
                fields=[FilterChain.datatype + '.number_of_regions'],
                operations=["=="],
                values=[conn.number_of_regions])

    @staticmethod
    def validate_first_fragment(form, project_id, conn_idx):
        conn_count = dao.count_datatypes(project_id, conn_idx)
        if conn_count == 0:
            form.connectivity.errors.append(
                "No connectivity in the project! Simulation cannot be started without "
                "a connectivity!")

    def get_simulation_state_index(self, burst_config,
                                   simulation_history_class):
        parent_burst = burst_config.parent_burst_object
        simulation_state_index = dao.get_generic_entity(
            simulation_history_class, parent_burst.gid, "fk_parent_burst")

        if simulation_state_index is None or len(simulation_state_index) < 1:
            exc = BurstServiceException(
                "Simulation State not found for %s, thus we are unable to branch from "
                "it!" % burst_config.name)
            self.logger.error(exc)
            raise exc

        return simulation_state_index
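
_set_simulator_range_parameter above walks a dotted parameter name down the object graph with getattr and assigns the leaf attribute, while _set_range_param_in_dict normalizes range values before they are serialized to JSON. A small self-contained sketch of both ideas, using plain Python stand-in classes instead of TVB simulator objects (the dotted name 'integrator.noise.nsig' is purely illustrative):

import uuid

import numpy


class Noise:
    def __init__(self):
        self.nsig = 0.0


class Integrator:
    def __init__(self):
        self.noise = Noise()


class Simulator:
    def __init__(self):
        self.integrator = Integrator()


def set_dotted_parameter(obj, dotted_name, value):
    # Traverse every component except the last, then assign the leaf attribute.
    *path, leaf = dotted_name.split('.')
    for name in path:
        obj = getattr(obj, name)
    setattr(obj, leaf, value)


def normalize_range_value(value):
    # Same normalization as _set_range_param_in_dict: unwrap 1-element arrays, hex-encode UUIDs.
    if isinstance(value, numpy.ndarray):
        return value[0]
    if isinstance(value, uuid.UUID):
        return value.hex
    return value


sim = Simulator()
set_dotted_parameter(sim, 'integrator.noise.nsig', 0.5)
assert sim.integrator.noise.nsig == 0.5
assert normalize_range_value(numpy.array([1.5])) == 1.5
assert isinstance(normalize_range_value(uuid.uuid4()), str)
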
Example #3
class ExportManager(object):
    """
    This class provides basic methods for exporting data types of projects in different formats.
    """
    all_exporters = {}  # Dictionary containing all available exporters
    export_folder = None
    EXPORT_FOLDER_NAME = "EXPORT_TMP"
    EXPORTED_SIMULATION_NAME = "exported_simulation"
    EXPORTED_SIMULATION_DTS_DIR = "datatypes"
    logger = get_logger(__name__)

    def __init__(self):
        # Here we register all available data type exporters.
        # If new exporters are supported, they should be added here.
        self._register_exporter(TVBExporter())
        self._register_exporter(TVBLinkedExporter())
        self.export_folder = os.path.join(TvbProfile.current.TVB_STORAGE, self.EXPORT_FOLDER_NAME)
        self.storage_interface = StorageInterface()

    def _register_exporter(self, exporter):
        """
        Register an available exporter into the internal exporters dictionary.
        :param exporter: Instance of a data type exporter (extends ABCExporter)
        """
        if exporter is not None:
            self.all_exporters[exporter.__class__.__name__] = exporter

    def get_exporters_for_data(self, data):
        """
        Get available exporters for current data type.
        :returns: a dictionary with the {exporter_id : label}
        """
        if data is None:
            raise InvalidExportDataException("Could not detect exporters for null data")

        self.logger.debug("Trying to determine exporters valid for %s" % data.type)
        results = {}

        for exporter_id, exporter in self.all_exporters.items():
            if exporter.accepts(data):
                results[exporter_id] = exporter.get_label()

        return results

    def export_data(self, data, exporter_id, project):
        """
        Export provided data using given exporter
        :param data: data type to be exported
        :param exporter_id: identifier of the exporter to be used
        :param project: project that contains data to be exported

        :returns: a tuple with the following elements
            1. name of the file to be shown to user
            2. full path of the export file (available for download)
            3. boolean which specify if file can be deleted after download
        """
        if data is None:
            raise InvalidExportDataException("Could not export null data. Please select data to be exported")

        if exporter_id is None:
            raise ExportException("Please select the exporter to be used for this operation")

        if exporter_id not in self.all_exporters:
            raise ExportException("Provided exporter identifier is not a valid one")

        exporter = self.all_exporters[exporter_id]

        if project is None:
            raise ExportException("Please provide the project where data files are stored")

        # Now we start the real export
        if not exporter.accepts(data):
            raise InvalidExportDataException("Current data can not be exported by specified exporter")

        # Now compute and create the folder where exported data will be stored.
        # This implies generating a folder that is unique for each export.
        data_export_folder = None
        try:
            data_export_folder = self.storage_interface.build_data_export_folder(data, self.export_folder)
            self.logger.debug("Start export of data: %s" % data.type)
            export_data = exporter.export(data, data_export_folder, project)
        finally:
            # If the export did not generate any file, delete the folder
            if data_export_folder is not None and len(os.listdir(data_export_folder)) == 0:
                os.rmdir(data_export_folder)

        return export_data

    def _export_linked_datatypes(self, project):
        linked_paths = ProjectService().get_linked_datatypes_storage_path(project)

        if not linked_paths:
            # do not export an empty operation
            return None, None

        # Make an import operation which will contain links to other projects
        algo = dao.get_algorithm_by_module(TVB_IMPORTER_MODULE, TVB_IMPORTER_CLASS)
        op = model_operation.Operation(None, None, project.id, algo.id)
        op.project = project
        op.algorithm = algo
        op.id = 'links-to-external-projects'
        op.start_now()
        op.mark_complete(model_operation.STATUS_FINISHED)

        return linked_paths, op

    def export_project(self, project):
        """
        Given a project root and the TVB storage_path, create a ZIP
        ready for export.
        :param project: project object which identifies project to be exported
        """
        if project is None:
            raise ExportException("Please provide project to be exported")

        folders_to_exclude = self._get_op_with_errors(project.id)
        linked_paths, op = self._export_linked_datatypes(project)

        result_path = self.storage_interface.export_project(project, folders_to_exclude,
                                                            self.export_folder, linked_paths, op)

        return result_path

    @staticmethod
    def _get_op_with_errors(project_id):
        """
        Return the ids of operations with errors; these are the base names of folders excluded from export.
        """
        operations = dao.get_operations_with_error_in_project(project_id)
        return [op.id for op in operations]

    def export_simulator_configuration(self, burst_id):
        burst = dao.get_burst_by_id(burst_id)
        if burst is None:
            raise InvalidExportDataException("Could not find burst with ID " + str(burst_id))

        op_folder = self.storage_interface.get_project_folder(burst.project.name, str(burst.fk_simulation))
        tmp_export_folder = self.storage_interface.build_data_export_folder(burst, self.export_folder)
        tmp_sim_folder = os.path.join(tmp_export_folder, self.EXPORTED_SIMULATION_NAME)

        if not os.path.exists(tmp_sim_folder):
            os.makedirs(tmp_sim_folder)

        all_view_model_paths, all_datatype_paths = h5.gather_references_of_view_model(burst.simulator_gid, op_folder)

        burst_path = h5.determine_filepath(burst.gid, op_folder)
        all_view_model_paths.append(burst_path)

        for vm_path in all_view_model_paths:
            dest = os.path.join(tmp_sim_folder, os.path.basename(vm_path))
            self.storage_interface.copy_file(vm_path, dest)

        for dt_path in all_datatype_paths:
            dest = os.path.join(tmp_sim_folder, self.EXPORTED_SIMULATION_DTS_DIR, os.path.basename(dt_path))
            self.storage_interface.copy_file(dt_path, dest)

        main_vm_path = h5.determine_filepath(burst.simulator_gid, tmp_sim_folder)
        H5File.remove_metadata_param(main_vm_path, 'history_gid')

        now = datetime.now()
        date_str = now.strftime("%Y-%m-%d_%H-%M")
        zip_file_name = "%s_%s.%s" % (date_str, str(burst_id), StorageInterface.ZIP_FILE_EXTENSION)

        result_path = os.path.join(tmp_export_folder, zip_file_name)
        self.storage_interface.write_zip_folder(result_path, tmp_sim_folder)

        self.storage_interface.remove_folder(tmp_sim_folder)
        return result_path
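
ExportManager above keeps its exporters in a dictionary keyed by class name and offers only those whose accepts() returns True for a given datatype; the exported simulation is then archived under a timestamped zip name. A minimal sketch of the same registry idea with stand-in exporter classes (CsvExporter and JsonExporter are illustrative, not TVB classes):

from datetime import datetime


class CsvExporter:  # stand-in exporter, not a TVB class
    def accepts(self, data):
        return isinstance(data, list)

    def get_label(self):
        return "CSV export"


class JsonExporter:  # stand-in exporter, not a TVB class
    def accepts(self, data):
        return isinstance(data, dict)

    def get_label(self):
        return "JSON export"


class ExporterRegistry:
    def __init__(self):
        self.all_exporters = {}

    def register(self, exporter):
        # Key by class name, mirroring ExportManager._register_exporter.
        self.all_exporters[exporter.__class__.__name__] = exporter

    def exporters_for(self, data):
        return {name: exporter.get_label()
                for name, exporter in self.all_exporters.items() if exporter.accepts(data)}


registry = ExporterRegistry()
registry.register(CsvExporter())
registry.register(JsonExporter())
print(registry.exporters_for([1, 2, 3]))  # {'CsvExporter': 'CSV export'}

# Timestamped archive name in the same shape as export_simulator_configuration builds it.
print("%s_%s.%s" % (datetime.now().strftime("%Y-%m-%d_%H-%M"), 42, "zip"))
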
Example #4
class TestSimulationResource(RestResourceTest):
    def transactional_setup_method(self):
        self.test_user = TestFactory.create_user('Rest_User')
        self.test_project = TestFactory.create_project(
            self.test_user, 'Rest_Project', users=[self.test_user.id])
        self.simulation_resource = FireSimulationResource()
        self.storage_interface = StorageInterface()

    def test_server_fire_simulation_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        project_gid = "inexistent-gid"
        dummy_file = FileStorage(BytesIO(b"test"), 'test.zip')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {
            RequestFileKey.SIMULATION_FILE_KEY.value: dummy_file
        }

        with pytest.raises(InvalidIdentifierException):
            self.simulation_resource.post(project_gid=project_gid)

    def test_server_fire_simulation_no_file(self, mocker):
        self._mock_user(mocker)
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {}

        with pytest.raises(InvalidIdentifierException):
            self.simulation_resource.post(project_gid='')

    def test_server_fire_simulation_bad_extension(self, mocker):
        self._mock_user(mocker)
        dummy_file = FileStorage(BytesIO(b"test"), 'test.txt')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        request_mock.files = {
            RequestFileKey.SIMULATION_FILE_KEY.value: dummy_file
        }

        with pytest.raises(InvalidIdentifierException):
            self.simulation_resource.post(project_gid='')

    def test_server_fire_simulation(self, mocker, connectivity_factory):
        self._mock_user(mocker)
        input_folder = self.storage_interface.get_project_folder(
            self.test_project.name)
        sim_dir = os.path.join(input_folder, 'test_sim')
        if not os.path.isdir(sim_dir):
            os.makedirs(sim_dir)

        simulator = SimulatorAdapterModel()
        simulator.connectivity = connectivity_factory().gid
        h5.store_view_model(simulator, sim_dir)

        zip_filename = os.path.join(input_folder,
                                    RequestFileKey.SIMULATION_FILE_NAME.value)
        self.storage_interface.zip_folder(zip_filename, sim_dir)

        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request')
        fp = open(zip_filename, 'rb')
        request_mock.files = {
            RequestFileKey.SIMULATION_FILE_KEY.value:
            FileStorage(fp, os.path.basename(zip_filename))
        }

        def launch_sim(self, user_id, project, algorithm, zip_folder_path,
                       simulator_file):
            return Operation('', '', '', {})

        # Mock simulation launch and current user
        mocker.patch.object(SimulatorService, 'prepare_simulation_on_server',
                            launch_sim)

        operation_gid, status = self.simulation_resource.post(
            project_gid=self.test_project.gid)
        fp.close()

        assert type(operation_gid) is str
        assert status == 201

    def transactional_teardown_method(self):
        self.storage_interface.remove_project_structure(self.test_project.name)
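
All the REST tests above rely on the same trick: mocker.patch.object(flask, 'request') replaces the request proxy with a mock for the duration of the test, and request_mock.files is set to whatever upload dictionary the endpoint should see, wrapped in werkzeug's FileStorage. A minimal sketch of that setup under pytest-mock, with a hypothetical read_upload helper standing in for the resource under test:

from io import BytesIO

import flask
from werkzeug.datastructures import FileStorage


def read_upload(key):
    # Hypothetical code under test: read an uploaded file from flask.request.
    return flask.request.files[key].read()


def test_read_upload(mocker):
    # Replace flask.request with a mock and hand it a fake uploaded file.
    request_mock = mocker.patch.object(flask, 'request')
    request_mock.files = {'simulation_file': FileStorage(BytesIO(b"payload"), 'sim.zip')}

    assert read_upload('simulation_file') == b"payload"
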
Example #5
class TestProjectService(TransactionalTestCase):
    """
    This class contains tests for the tvb.core.services.project_service module.
    """

    def transactional_setup_method(self):
        """
        Reset the database before each test.
        """
        self.project_service = ProjectService()
        self.storage_interface = StorageInterface()
        self.test_user = TestFactory.create_user()

    def transactional_teardown_method(self):
        """
        Remove project folders.
        """
        self.delete_project_folders()

    def test_create_project_happy_flow(self):

        user1 = TestFactory.create_user('test_user1')
        user2 = TestFactory.create_user('test_user2')
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"

        TestFactory.create_project(self.test_user, 'test_project', "description", users=[user1.id, user2.id])

        resulting_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(resulting_projects) == 1, "Project with valid data not inserted!"
        project = resulting_projects[0]
        assert project.name == "test_project", "Invalid retrieved project name"
        assert project.description == "description", "Description does not match"

        users_for_project = dao.get_members_of_project(project.id)
        for user in users_for_project:
            assert user.id in [user1.id, user2.id, self.test_user.id], "Users not stored properly."
        assert os.path.exists(os.path.join(TvbProfile.current.TVB_STORAGE,
                                           StorageInterface.PROJECTS_FOLDER,
                                           "test_project")), "Folder for project was not created"

    def test_create_project_empty_name(self):
        """
        Creating a project with an empty name.
        """
        data = dict(name="", description="test_description", users=[])
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        with pytest.raises(ProjectServiceException):
            self.project_service.store_project(self.test_user, True, None, **data)

    def test_edit_project_happy_flow(self):
        """
        Standard flow for editing an existing project.
        """
        selected_project = TestFactory.create_project(self.test_user, 'test_proj')
        proj_root = self.storage_interface.get_project_folder(selected_project.name)
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 1, "Database initialization probably failed!"

        edited_data = dict(name="test_project", description="test_description", users=[])
        edited_project = self.project_service.store_project(self.test_user, False, selected_project.id, **edited_data)
        assert not os.path.exists(proj_root), "Previous folder not deleted"
        proj_root = self.storage_interface.get_project_folder(edited_project.name)
        assert os.path.exists(proj_root), "New folder not created!"
        assert selected_project.name != edited_project.name, "Project name was not changed!"

    def test_edit_project_unexisting(self):
        """
        Trying to edit a non-existing project.
        """
        selected_project = TestFactory.create_project(self.test_user, 'test_proj')
        self.storage_interface.get_project_folder(selected_project.name)
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 1, "Database initialization probably failed!"
        data = dict(name="test_project", description="test_description", users=[])
        with pytest.raises(ProjectServiceException):
            self.project_service.store_project(self.test_user, False, 99, **data)

    def test_find_project_happy_flow(self):
        """
        Standard flow for finding a project by its id.
        """
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        inserted_project = TestFactory.create_project(self.test_user, 'test_project')
        assert self.project_service.find_project(inserted_project.id) is not None, "Project not found !"
        dao_returned_project = dao.get_project_by_id(inserted_project.id)
        service_returned_project = self.project_service.find_project(inserted_project.id)
        assert dao_returned_project.id == service_returned_project.id, \
            "Data returned from service is different from data returned by DAO."
        assert dao_returned_project.name == service_returned_project.name, \
            "Data returned from service is different from data returned by DAO."
        assert dao_returned_project.description == service_returned_project.description, \
            "Data returned from service is different from data returned by DAO."
        assert dao_returned_project.members == service_returned_project.members, \
            "Data returned from service is different from data returned by DAO."

    def test_find_project_unexisting(self):
        """
        Searching for a non-existing project.
        """
        data = dict(name="test_project", description="test_description", users=[])
        initial_projects = dao.get_projects_for_user(self.test_user.id)
        assert len(initial_projects) == 0, "Database reset probably failed!"
        project = self.project_service.store_project(self.test_user, True, None, **data)
        # fetch a likely non-existing project. Previous project id plus a 'big' offset
        with pytest.raises(ProjectServiceException):
            self.project_service.find_project(project.id + 1033)

    def test_retrieve_projects_for_user(self):
        """
        Test for retrieving the projects for a given user. One page only.
        """
        initial_projects = self.project_service.retrieve_projects_for_user(self.test_user.id)[0]
        assert len(initial_projects) == 0, "Database was not reset properly!"
        TestFactory.create_project(self.test_user, 'test_proj')
        TestFactory.create_project(self.test_user, 'test_proj1')
        TestFactory.create_project(self.test_user, 'test_proj2')
        user1 = TestFactory.create_user('another_user')
        TestFactory.create_project(user1, 'test_proj3')
        projects = self.project_service.retrieve_projects_for_user(self.test_user.id)[0]
        assert len(projects) == 3, "Projects not retrieved properly!"
        for project in projects:
            assert project.name != "test_project3", "This project should not have been retrieved"

    def test_retrieve_1project_3usr(self):
        """
        One user as admin, two users as members, getting projects for admin and for any of
        the members should return one.
        """
        member1 = TestFactory.create_user("member1")
        member2 = TestFactory.create_user("member2")
        TestFactory.create_project(self.test_user, 'Testproject', users=[member1.id, member2.id])
        projects = self.project_service.retrieve_projects_for_user(self.test_user.id, 1)[0]
        assert len(projects) == 1, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member1.id, 1)[0]
        assert len(projects) == 1, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member2.id, 1)[0]
        assert len(projects) == 1, "Projects not retrieved properly!"

    def test_retrieve_3projects_3usr(self):
        """
        Three users, 3 projects. Structure of db:
        proj1: {admin: user1, members: [user2, user3]}
        proj2: {admin: user2, members: [user1]}
        proj3: {admin: user3, members: [user1, user2]}
        Check valid project returns for all the users.
        """
        member1 = TestFactory.create_user("member1")
        member2 = TestFactory.create_user("member2")
        member3 = TestFactory.create_user("member3")
        TestFactory.create_project(member1, 'TestProject1', users=[member2.id, member3.id])
        TestFactory.create_project(member2, 'TestProject2', users=[member1.id])
        TestFactory.create_project(member3, 'TestProject3', users=[member1.id, member2.id])
        projects = self.project_service.retrieve_projects_for_user(member1.id, 1)[0]
        assert len(projects) == 3, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member2.id, 1)[0]
        assert len(projects) == 3, "Projects not retrieved properly!"
        projects = self.project_service.retrieve_projects_for_user(member3.id, 1)[0]
        assert len(projects) == 2, "Projects not retrieved properly!"

    def test_retrieve_projects_random(self):
        """
        Generate a large number of users/projects, and validate the results.
        """
        ExtremeTestFactory.generate_users(NR_USERS, MAX_PROJ_PER_USER)
        for i in range(NR_USERS):
            current_user = dao.get_user_by_name("gen" + str(i))
            expected_projects = ExtremeTestFactory.VALIDATION_DICT[current_user.id]
            if expected_projects % PROJECTS_PAGE_SIZE == 0:
                expected_pages = expected_projects // PROJECTS_PAGE_SIZE
                exp_proj_per_page = PROJECTS_PAGE_SIZE
            else:
                expected_pages = expected_projects // PROJECTS_PAGE_SIZE + 1
                exp_proj_per_page = expected_projects % PROJECTS_PAGE_SIZE
            if expected_projects == 0:
                expected_pages = 0
                exp_proj_per_page = 0
            projects, pages = self.project_service.retrieve_projects_for_user(current_user.id, expected_pages)
            assert len(projects) == exp_proj_per_page, "Projects not retrieved properly! Expected: " + \
                                                       str(exp_proj_per_page) + " but got: " + str(len(projects))
            assert pages == expected_pages, "Pages not retrieved properly!"

        for folder in os.listdir(TvbProfile.current.TVB_STORAGE):
            full_path = os.path.join(TvbProfile.current.TVB_STORAGE, folder)
            if folder.startswith('Generated'):
                self.storage_interface.remove_folder(full_path)

    def test_retrieve_projects_page2(self):
        """
        Test for retrieving the second page projects for a given user.
        """
        for i in range(PROJECTS_PAGE_SIZE + 3):
            TestFactory.create_project(self.test_user, 'test_proj' + str(i))
        projects, pages = self.project_service.retrieve_projects_for_user(self.test_user.id, 2)
        assert len(projects) == (PROJECTS_PAGE_SIZE + 3) % PROJECTS_PAGE_SIZE, "Pagination improper."
        assert pages == 2, 'Wrong number of pages retrieved.'

    def test_retrieve_projects_and_del(self):
        """
        Test for retrieving the second page projects for a given user.
        """
        created_projects = []
        for i in range(PROJECTS_PAGE_SIZE + 1):
            created_projects.append(TestFactory.create_project(self.test_user, 'test_proj' + str(i)))
        projects, pages = self.project_service.retrieve_projects_for_user(self.test_user.id, 2)
        assert len(projects) == (PROJECTS_PAGE_SIZE + 1) % PROJECTS_PAGE_SIZE, "Pagination improper."
        assert pages == (PROJECTS_PAGE_SIZE + 1) // PROJECTS_PAGE_SIZE + 1, 'Wrong number of pages'
        self.project_service.remove_project(created_projects[1].id)
        projects, pages = self.project_service.retrieve_projects_for_user(self.test_user.id, 2)
        assert len(projects) == 0, "Pagination improper."
        assert pages == 1, 'Wrong number of pages retrieved.'
        projects, pages = self.project_service.retrieve_projects_for_user(self.test_user.id, 1)
        assert len(projects) == PROJECTS_PAGE_SIZE, "Pagination improper."
        assert pages == 1, 'Wrong number of pages retrieved.'

    def test_empty_project_has_zero_disk_size(self):
        TestFactory.create_project(self.test_user, 'test_proj')
        projects, pages = self.project_service.retrieve_projects_for_user(self.test_user.id)
        assert 0 == projects[0].disk_size
        assert '0.0 KiB' == projects[0].disk_size_human

    def test_project_disk_size(self):
        project1 = TestFactory.create_project(self.test_user, 'test_proj1')
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_66.zip')
        TestFactory.import_zip_connectivity(self.test_user, project1, zip_path, 'testSubject')

        project2 = TestFactory.create_project(self.test_user, 'test_proj2')
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_76.zip')
        TestFactory.import_zip_connectivity(self.test_user, project2, zip_path, 'testSubject')

        projects = self.project_service.retrieve_projects_for_user(self.test_user.id)[0]
        assert projects[0].disk_size != projects[1].disk_size, "projects should have different sizes"

        for project in projects:
            assert 0 != project.disk_size
            assert '0.0 KiB' != project.disk_size_human

            prj_folder = self.storage_interface.get_project_folder(project.name)
            actual_disk_size = self.storage_interface.compute_recursive_h5_disk_usage(prj_folder)

            ratio = float(actual_disk_size) / project.disk_size
            msg = "Real disk usage: %s The one recorded in the db : %s" % (actual_disk_size, project.disk_size)
            assert ratio < 1.1, msg

    def test_get_linkable_projects(self):
        """
        Test for retrieving the projects for a given user.
        """
        initial_projects = self.project_service.retrieve_projects_for_user(self.test_user.id)[0]
        assert len(initial_projects) == 0, "Database was not reset!"
        test_proj = []
        user1 = TestFactory.create_user("another_user")
        for i in range(4):
            test_proj.append(TestFactory.create_project(self.test_user if i < 3 else user1, 'test_proj' + str(i)))
        operation = TestFactory.create_operation(test_user=self.test_user, test_project=test_proj[0])
        datatype = dao.store_entity(model_datatype.DataType(module="test_data", subject="subj1",
                                                            state="test_state", operation_id=operation.id))

        linkable = self.project_service.get_linkable_projects_for_user(self.test_user.id, str(datatype.id))[0]

        assert len(linkable) == 2, "Wrong count of link-able projects!"
        proj_names = [project.name for project in linkable]
        assert test_proj[1].name in proj_names
        assert test_proj[2].name in proj_names
        assert not test_proj[3].name in proj_names

    def test_remove_project_happy_flow(self):
        """
        Standard flow for deleting a project.
        """
        inserted_project = TestFactory.create_project(self.test_user, 'test_proj')
        project_root = self.storage_interface.get_project_folder(inserted_project.name)
        projects = dao.get_projects_for_user(self.test_user.id)
        assert len(projects) == 1, "Initializations failed!"
        assert os.path.exists(project_root), "Something failed at insert time!"
        self.project_service.remove_project(inserted_project.id)
        projects = dao.get_projects_for_user(self.test_user.id)
        assert len(projects) == 0, "Project was not deleted!"
        assert not os.path.exists(project_root), "Root folder not deleted!"

    def test_remove_project_wrong_id(self):
        """
        Flow for deleting a project given a non-existing id.
        """
        TestFactory.create_project(self.test_user, 'test_proj')
        projects = dao.get_projects_for_user(self.test_user.id)
        assert len(projects) == 1, "Initializations failed!"
        with pytest.raises(ProjectServiceException):
            self.project_service.remove_project(99)

    def __check_meta_data(self, expected_meta_data, new_datatype):
        """Validate Meta-Data"""
        mapp_keys = {DataTypeOverlayDetails.DATA_SUBJECT: "subject", DataTypeOverlayDetails.DATA_STATE: "state"}
        for key, value in expected_meta_data.items():
            if key in mapp_keys:
                assert value == getattr(new_datatype, mapp_keys[key])
            elif key == DataTypeMetaData.KEY_OPERATION_TAG:
                if DataTypeMetaData.KEY_OP_GROUP_ID in expected_meta_data:
                    # We have a Group to check
                    op_group = new_datatype.parent_operation.fk_operation_group
                    op_group = dao.get_generic_entity(model_operation.OperationGroup, op_group)[0]
                    assert value == op_group.name
                else:
                    assert value == new_datatype.parent_operation.user_group

    def test_remove_project_node(self):
        """
        Test removing of a node from a project.
        """
        inserted_project, gid, op = TestFactory.create_value_wrapper(self.test_user)
        project_to_link = model_project.Project("Link", self.test_user.id, "descript")
        project_to_link = dao.store_entity(project_to_link)
        exact_data = dao.get_datatype_by_gid(gid)
        assert exact_data is not None, "Initialization problem!"
        link = dao.store_entity(model_datatype.Links(exact_data.id, project_to_link.id))

        vw_h5_path = h5.path_for_stored_index(exact_data)
        assert os.path.exists(vw_h5_path)

        if dao.get_system_user() is None:
            dao.store_entity(model_operation.User(TvbProfile.current.web.admin.SYSTEM_USER_NAME,
                                                  TvbProfile.current.web.admin.SYSTEM_USER_NAME, None, None, True,
                                                  None))

        self.project_service._remove_project_node_files(inserted_project.id, gid, [link])

        assert not os.path.exists(vw_h5_path)
        exact_data = dao.get_datatype_by_gid(gid)
        assert exact_data is not None, "Data should still be in DB, because of links"
        vw_h5_path_new = h5.path_for_stored_index(exact_data)
        assert os.path.exists(vw_h5_path_new)
        assert vw_h5_path_new != vw_h5_path

        self.project_service._remove_project_node_files(project_to_link.id, gid, [])
        assert dao.get_datatype_by_gid(gid) is None

    def test_update_meta_data_simple(self):
        """
        Test the update of metadata for a simple datatype that is not part of a group.
        """
        inserted_project, gid, _ = TestFactory.create_value_wrapper(self.test_user)
        new_meta_data = {DataTypeOverlayDetails.DATA_SUBJECT: "new subject",
                         DataTypeOverlayDetails.DATA_STATE: "second_state",
                         DataTypeOverlayDetails.CODE_GID: gid,
                         DataTypeOverlayDetails.CODE_OPERATION_TAG: 'new user group'}
        self.project_service.update_metadata(new_meta_data)

        new_datatype = dao.get_datatype_by_gid(gid)
        self.__check_meta_data(new_meta_data, new_datatype)

        new_datatype_h5 = h5.h5_file_for_index(new_datatype)
        assert new_datatype_h5.subject.load() == 'new subject', 'Subject not updated!'

    def test_update_meta_data_group(self, test_adapter_factory, datatype_group_factory):
        """
        Test the update of metadata for a group of datatypes.
        """
        test_adapter_factory(adapter_class=DummyAdapter3)
        group, _ = datatype_group_factory()
        op_group_id = group.fk_operation_group

        new_meta_data = {DataTypeOverlayDetails.DATA_SUBJECT: "new subject",
                         DataTypeOverlayDetails.DATA_STATE: "updated_state",
                         DataTypeOverlayDetails.CODE_OPERATION_GROUP_ID: op_group_id,
                         DataTypeOverlayDetails.CODE_OPERATION_TAG: 'newGroupName'}
        self.project_service.update_metadata(new_meta_data)
        datatypes = dao.get_datatype_in_group(op_group_id)
        for datatype in datatypes:
            new_datatype = dao.get_datatype_by_id(datatype.id)
            assert op_group_id == new_datatype.parent_operation.fk_operation_group
            new_group = dao.get_generic_entity(model_operation.OperationGroup, op_group_id)[0]
            assert new_group.name == "newGroupName"
            self.__check_meta_data(new_meta_data, new_datatype)

    def test_retrieve_project_full(self, dummy_datatype_index_factory):
        """
        Tests full project information is retrieved by method `ProjectService.retrieve_project_full(...)`
        """

        project = TestFactory.create_project(self.test_user)
        operation = TestFactory.create_operation(test_user=self.test_user, test_project=project)

        dummy_datatype_index_factory(project=project, operation=operation)
        dummy_datatype_index_factory(project=project, operation=operation)
        dummy_datatype_index_factory(project=project, operation=operation)

        _, ops_nr, operations, pages_no = self.project_service.retrieve_project_full(project.id)
        assert ops_nr == 1, "DataType Factory should only use one operation to store all its datatypes."
        assert pages_no == 1, "A single operation should fit on one page."
        resulted_dts = operations[0]['results']
        assert len(resulted_dts) == 3, "3 datatypes should be created."

    def test_get_project_structure(self, datatype_group_factory, dummy_datatype_index_factory,
                                   project_factory, user_factory):
        """
        Tests project structure is as expected and contains all datatypes and created links
        """
        user = user_factory()
        project1 = project_factory(user, name="TestPS1")
        project2 = project_factory(user, name="TestPS2")

        dt_group, _ = datatype_group_factory(project=project1)
        dt_simple = dummy_datatype_index_factory(state="RAW_DATA", project=project1)
        # Create 3 DTs directly in Project 2
        dummy_datatype_index_factory(state="RAW_DATA", project=project2)
        dummy_datatype_index_factory(state="RAW_DATA", project=project2)
        dummy_datatype_index_factory(state="RAW_DATA", project=project2)

        # Create Links from Project 1 into Project 2
        link_ids, expected_links = [], []
        link_ids.append(dt_simple.id)
        expected_links.append(dt_simple.gid)

        # Prepare links towards a full DT Group, but expecting only the DT_Group in the final tree
        dts = dao.get_datatype_in_group(datatype_group_id=dt_group.id)
        link_ids.extend([dt_to_link.id for dt_to_link in dts])
        link_ids.append(dt_group.id)
        expected_links.append(dt_group.gid)

        # Actually create the links from Prj1 into Prj2
        for link_id in link_ids:
            AlgorithmService().create_link(link_id, project2.id)

        # Retrieve the raw data used to compose the tree (for easy parsing)
        dts_in_tree = dao.get_data_in_project(project2.id)
        dts_in_tree = [dt.gid for dt in dts_in_tree]
        # Retrieve the tree json (for trivial validations only, as we can not decode)
        node_json = self.project_service.get_project_structure(project2, None, DataTypeMetaData.KEY_STATE,
                                                               DataTypeMetaData.KEY_SUBJECT, None)

        assert len(expected_links) + 3 == len(dts_in_tree), "invalid number of nodes in tree"
        assert dt_group.gid in dts_in_tree, "DT_Group should be in the Project Tree!"
        assert dt_group.gid in node_json, "DT_Group should be in the Project Tree JSON!"

        project_dts = dao.get_datatypes_in_project(project2.id)
        for dt in project_dts:
            if dt.fk_datatype_group is not None:
                assert dt.gid not in node_json, "DTs part of a group should not be in the tree JSON"
                assert dt.gid not in dts_in_tree, "DTs part of a group should not be in the project tree"
            else:
                assert dt.gid in node_json, "Simple DTs and DT_Groups should be in the tree JSON"
                assert dt.gid in dts_in_tree, "Simple DTs and DT_Groups should be in the project tree"

        for link_gid in expected_links:
            assert link_gid in node_json, "Expected Link not present"
            assert link_gid in dts_in_tree, "Expected Link not present"
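
test_retrieve_projects_random above predicts, from the total project count and PROJECTS_PAGE_SIZE, how many pages the service should report and how many projects the requested page should contain. The same arithmetic as a small standalone helper (the page size of 20 in the asserts is an assumption for illustration only):

def expected_pagination(total_items, page_size):
    # Return (number_of_pages, items_on_last_page), as computed in the test above.
    if total_items == 0:
        return 0, 0
    if total_items % page_size == 0:
        return total_items // page_size, page_size
    return total_items // page_size + 1, total_items % page_size


assert expected_pagination(0, 20) == (0, 0)
assert expected_pagination(40, 20) == (2, 20)
assert expected_pagination(43, 20) == (3, 3)
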
Example #6
class TestOperationResource(RestResourceTest):

    def transactional_setup_method(self):
        self.test_user = TestFactory.create_user('Rest_User')
        self.test_project = TestFactory.create_project(self.test_user, 'Rest_Project', users=[self.test_user.id])
        self.operations_resource = GetOperationsInProjectResource()
        self.status_resource = GetOperationStatusResource()
        self.results_resource = GetOperationResultsResource()
        self.launch_resource = LaunchOperationResource()
        self.storage_interface = StorageInterface()

    def test_server_get_operation_status_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        operation_gid = "inexistent-gid"
        with pytest.raises(InvalidIdentifierException):
            self.status_resource.get(operation_gid=operation_gid)

    def test_server_get_operation_status(self, mocker):
        self._mock_user(mocker)
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_96.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)

        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.args = {Strings.PAGE_NUMBER: '1'}

        operations_and_pages = self.operations_resource.get(project_gid=self.test_project.gid)

        result = self.status_resource.get(operation_gid=operations_and_pages['operations'][0].gid)
        assert type(result) is str
        assert result in OperationPossibleStatus

    def test_server_get_operation_results_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        operation_gid = "inexistent-gid"
        with pytest.raises(InvalidIdentifierException):
            self.results_resource.get(operation_gid=operation_gid)

    def test_server_get_operation_results(self, mocker):
        self._mock_user(mocker)
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_96.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)

        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.args = {Strings.PAGE_NUMBER: '1'}

        operations_and_pages = self.operations_resource.get(project_gid=self.test_project.gid)

        result = self.results_resource.get(operation_gid=operations_and_pages['operations'][0].gid)
        assert type(result) is list
        assert len(result) == 1

    def test_server_get_operation_results_failed_operation(self, mocker):
        self._mock_user(mocker)
        zip_path = os.path.join(os.path.dirname(tvb_data.__file__), 'connectivity', 'connectivity_90.zip')
        with pytest.raises(TVBException):
            TestFactory.import_zip_connectivity(self.test_user, self.test_project, zip_path)

        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.args = {Strings.PAGE_NUMBER: '1'}

        operations_and_pages = self.operations_resource.get(project_gid=self.test_project.gid)

        result = self.results_resource.get(operation_gid=operations_and_pages['operations'][0].gid)
        assert type(result) is list
        assert len(result) == 0

    def test_server_launch_operation_no_file(self, mocker):
        self._mock_user(mocker)
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.files = {}

        with pytest.raises(InvalidIdentifierException):
            self.launch_resource.post(project_gid='', algorithm_module='', algorithm_classname='')

    def test_server_launch_operation_wrong_file_extension(self, mocker):
        self._mock_user(mocker)
        dummy_file = FileStorage(BytesIO(b"test"), 'test.txt')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.files = {RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: dummy_file}

        with pytest.raises(InvalidIdentifierException):
            self.launch_resource.post(project_gid='', algorithm_module='', algorithm_classname='')

    def test_server_launch_operation_inexistent_gid(self, mocker):
        self._mock_user(mocker)
        project_gid = "inexistent-gid"
        dummy_file = FileStorage(BytesIO(b"test"), 'test.h5')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.files = {RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: dummy_file}

        with pytest.raises(InvalidIdentifierException):
            self.launch_resource.post(project_gid=project_gid, algorithm_module='', algorithm_classname='')

    def test_server_launch_operation_inexistent_algorithm(self, mocker):
        self._mock_user(mocker)
        inexistent_algorithm = "inexistent-algorithm"

        dummy_file = FileStorage(BytesIO(b"test"), 'test.h5')
        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request', spec={})
        request_mock.files = {RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: dummy_file}

        with pytest.raises(ServiceException):
            self.launch_resource.post(project_gid=self.test_project.gid,
                                      algorithm_module=inexistent_algorithm,
                                      algorithm_classname='')

    def test_server_launch_operation(self, mocker, time_series_index_factory):
        self._mock_user(mocker)
        algorithm_module = "tvb.adapters.analyzers.fourier_adapter"
        algorithm_class = "FourierAdapter"

        input_ts_index = time_series_index_factory()

        fft_model = FFTAdapterModel()
        fft_model.time_series = UUID(input_ts_index.gid)
        fft_model.window_function = list(SUPPORTED_WINDOWING_FUNCTIONS)[0]

        input_folder = self.storage_interface.get_project_folder(self.test_project.name)
        view_model_h5_path = h5.store_view_model(fft_model, input_folder)

        # Mock flask.request.files to return a dictionary
        request_mock = mocker.patch.object(flask, 'request', spec={})
        fp = open(view_model_h5_path, 'rb')
        request_mock.files = {
            RequestFileKey.LAUNCH_ANALYZERS_MODEL_FILE.value: FileStorage(fp, os.path.basename(view_model_h5_path))}

        # Mock launch_operation() call and current_user
        mocker.patch.object(OperationService, 'launch_operation')

        operation_gid, status = self.launch_resource.post(project_gid=self.test_project.gid,
                                                          algorithm_module=algorithm_module,
                                                          algorithm_classname=algorithm_class)

        fp.close()

        assert type(operation_gid) is str
        assert len(operation_gid) > 0
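The launch_operation mock above is never inspected; if the test should also prove that the REST layer delegated to OperationService, the MagicMock returned by pytest-mock can be asserted on. A small sketch follows (launch_mock is an illustrative name, not part of the original test):

        # Sketch: keep the handle returned by mocker.patch.object() and assert the delegation.
        launch_mock = mocker.patch.object(OperationService, 'launch_operation')

        operation_gid, status = self.launch_resource.post(project_gid=self.test_project.gid,
                                                          algorithm_module=algorithm_module,
                                                          algorithm_classname=algorithm_class)

        launch_mock.assert_called_once()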
Example #7
0
class ABCAdapter(object):
    """
    Root Abstract class for all TVB Adapters. 
    """
    # model.Algorithm instance that will be set for each adapter class created in the build_adapter method
    stored_adapter = None
    launch_mode = AdapterLaunchModeEnum.ASYNC_DIFF_MEM

    def __init__(self):
        self.generic_attributes = GenericAttributes()
        self.generic_attributes.subject = DataTypeMetaData.DEFAULT_SUBJECT
        self.storage_interface = StorageInterface()
        # Will be populated with the currently running operation's identifier
        self.operation_id = None
        self.user_id = None
        self.submitted_form = None
        self.log = get_logger(self.__class__.__module__)

    @classmethod
    def get_group_name(cls):
        if hasattr(cls, "_ui_group") and hasattr(cls._ui_group, "name"):
            return cls._ui_group.name
        return None

    @classmethod
    def get_group_description(cls):
        if hasattr(cls, "_ui_group") and hasattr(cls._ui_group, "description"):
            return cls._ui_group.description
        return None

    @classmethod
    def get_ui_name(cls):
        if hasattr(cls, "_ui_name"):
            return cls._ui_name
        else:
            return cls.__name__

    @classmethod
    def get_ui_description(cls):
        if hasattr(cls, "_ui_description"):
            return cls._ui_description

    @classmethod
    def get_ui_subsection(cls):
        if hasattr(cls, "_ui_subsection"):
            return cls._ui_subsection

        if hasattr(cls, "_ui_group") and hasattr(cls._ui_group, "subsection"):
            return cls._ui_group.subsection

    @staticmethod
    def can_be_active():
        """
        To be overridden where needed (e.g. Matlab-dependent adapters).
        :return: True by default; False when the current Adapter cannot be executed in the current
        environment for various reasons (e.g. no Matlab or Octave installed).
        """
        return True

    def submit_form(self, form):
        self.submitted_form = form

    # TODO separate usage of get_form_class (returning a class) and return of a submitted instance
    def get_form(self):
        if self.submitted_form is not None:
            return self.submitted_form
        return self.get_form_class()

    @abstractmethod
    def get_form_class(self):
        return None

    def get_adapter_fragments(self, view_model):
        """
        The result will be used for introspecting which operation input params were changed
        from their defaults, in order to show them in the web GUI.
        :return: a list of ABCAdapterForm classes, in case the current Adapter GUI
        is composed of multiple sub-forms.
        """
        return {}

    def get_view_model_class(self):
        return self.get_form_class().get_view_model()

    @abstractmethod
    def get_output(self):
        """
        Describes inputs and outputs of the launch method.
        """

    def configure(self, view_model):
        """
        To be implemented in each Adapter that requires any specific configurations
        before the actual launch.
        """

    @abstractmethod
    def get_required_memory_size(self, view_model):
        """
        Abstract method to be implemented in each adapter. Should return the required memory
        for launching the adapter.
        """

    @abstractmethod
    def get_required_disk_size(self, view_model):
        """
        Abstract method to be implemented in each adapter. Should return the required disk space
        for launching the adapter, in kilobytes (kB).
        """

    def get_execution_time_approximation(self, view_model):
        """
        Should approximate, based on the input arguments, the time it will take for the operation
        to finish (in seconds).
        """
        return -1

    @abstractmethod
    def launch(self, view_model):
        """
         To be implemented in each Adapter.
         Will contain the logic of the Adapter.
         Takes a ViewModel with data, dependency direction is: Adapter -> Form -> ViewModel
         Any returned DataType will be stored in DB, by the Framework.
        :param view_model: the data model corresponding to the current adapter
        """

    def add_operation_additional_info(self, message):
        """
        Adds additional info on the operation to be displayed in the UI. Usually a warning message.
        """
        current_op = dao.get_operation_by_id(self.operation_id)
        current_op.additional_info = message
        dao.store_entity(current_op)

    def extract_operation_data(self, operation):
        operation = dao.get_operation_by_id(operation.id)
        self.operation_id = operation.id
        self.current_project_id = operation.project.id
        self.user_id = operation.fk_launched_by

    def _ensure_enough_resources(self, available_disk_space, view_model):
        # Compare the amount of memory the current algorithm states it needs
        # with the average between the total RAM available on the OS and the memory free at this moment.
        # We do not consider only the free memory, because some OSs free memory late and only on demand.
        total_free_memory = psutil.virtual_memory().free + psutil.swap_memory().free
        total_existent_memory = psutil.virtual_memory().total + psutil.swap_memory().total
        memory_reference = (total_free_memory + total_existent_memory) / 2
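        # For example, with 8 GiB free and 16 GiB installed in total (RAM + swap combined),
        # the reference becomes (8 + 16) / 2 = 12 GiB, so an adapter that declares it needs
        # more than 12 GiB is rejected below.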
        adapter_required_memory = self.get_required_memory_size(view_model)

        if adapter_required_memory > memory_reference:
            msg = "Machine does not have enough RAM memory for the operation (expected %.2g GB, but found %.2g GB)."
            raise NoMemoryAvailableException(msg % (adapter_required_memory / 2 ** 30, memory_reference / 2 ** 30))

        # Compare the expected size of the operation results with the HDD space currently available for the user
        # TVB defines a quota per user.
        required_disk_space = self.get_required_disk_size(view_model)
        if available_disk_space < 0:
            msg = "You have exceeded you HDD space quota by %.2f MB Stopping execution."
            raise NoMemoryAvailableException(msg % (- available_disk_space / 2 ** 10))
        if available_disk_space < required_disk_space:
            msg = ("You only have %.2f GB of disk space available but the operation you "
                   "launched might require %.2f Stopping execution...")
            raise NoMemoryAvailableException(msg % (available_disk_space / 2 ** 20, required_disk_space / 2 ** 20))
        return required_disk_space

    def _update_operation_entity(self, operation, required_disk_space):
        operation.start_now()
        operation.estimated_disk_size = required_disk_space
        dao.store_entity(operation)

    @nan_not_allowed()
    def _prelaunch(self, operation, view_model, available_disk_space=0):
        """
        Method to wrap LAUNCH.
        Will prepare data, and store results on return.
        """
        self.extract_operation_data(operation)
        self.generic_attributes.fill_from(view_model.generic_attributes)
        self.configure(view_model)
        required_disk_size = self._ensure_enough_resources(available_disk_space, view_model)
        self._update_operation_entity(operation, required_disk_size)

        result = self.launch(view_model)

        if not isinstance(result, (list, tuple)):
            result = [result, ]
        self.__check_integrity(result)
        return self._capture_operation_results(result)

    def _capture_operation_results(self, result):
        """
        After an operation has finished, make sure the results are stored
        in the DB and that the correct meta-data and IDs are set.
        """
        data_type_group_id = None
        operation = dao.get_operation_by_id(self.operation_id)
        if operation.user_group is None or len(operation.user_group) == 0:
            operation.user_group = date2string(datetime.now(), date_format=LESS_COMPLEX_TIME_FORMAT)
            operation = dao.store_entity(operation)
        if self._is_group_launch():
            data_type_group_id = dao.get_datatypegroup_by_op_group_id(operation.fk_operation_group).id

        count_stored = 0
        if result is None:
            return "", count_stored

        group_type = None  # In case of a group, the first non-None type is enough to remember here
        for res in result:
            if res is None:
                continue
            if not res.fixed_generic_attributes:
                res.fill_from_generic_attributes(self.generic_attributes)
            res.fk_from_operation = self.operation_id
            res.fk_datatype_group = data_type_group_id

            associated_file = h5.path_for_stored_index(res)
            if os.path.exists(associated_file):
                if not res.fixed_generic_attributes:
                    with H5File.from_file(associated_file) as f:
                        f.store_generic_attributes(self.generic_attributes)
                # Compute size on disk, in case file storage is used
                res.disk_size = self.storage_interface.compute_size_on_disk(associated_file)

            dao.store_entity(res)
            res.after_store()
            group_type = res.type
            count_stored += 1

        if count_stored > 0 and self._is_group_launch():
            # Update the operation group name
            operation_group = dao.get_operationgroup_by_id(operation.fk_operation_group)
            operation_group.fill_operationgroup_name(group_type)
            dao.store_entity(operation_group)

        return 'Operation ' + str(self.operation_id) + ' has finished.', count_stored

    def __check_integrity(self, result):
        """
        Check that the returned parameters for LAUNCH operation
        are of the type specified in the adapter's interface.
        """
        for result_entity in result:
            if result_entity is None:
                continue
            if not self.__is_data_in_supported_types(result_entity):
                msg = "Unexpected output DataType %s"
                raise InvalidParameterException(msg % type(result_entity))

    def __is_data_in_supported_types(self, data):

        if data is None:
            return True
        for supported_type in self.get_output():
            if isinstance(data, supported_type):
                return True
        # Data can't be mapped on any supported type !!
        return False

    def _is_group_launch(self):
        """
        Return true if this adapter is launched from a group of operations
        """
        operation = dao.get_operation_by_id(self.operation_id)
        return operation.fk_operation_group is not None

    def load_entity_by_gid(self, data_gid):
        # type: (typing.Union[uuid.UUID, str]) -> DataType
        """
        Load a generic DataType, specified by GID.
        """
        idx = load_entity_by_gid(data_gid)
        if idx and self.generic_attributes.parent_burst is None:
            # Only in case the BurstConfiguration reference hasn't been set already, take it from the current DT
            self.generic_attributes.parent_burst = idx.fk_parent_burst
        return idx

    def load_traited_by_gid(self, data_gid):
        # type: (typing.Union[uuid.UUID, str]) -> HasTraits
        """
        Load a generic HasTraits instance, specified by GID.
        """
        index = self.load_entity_by_gid(data_gid)
        return h5.load_from_index(index)

    def load_with_references(self, dt_gid):
        # type: (typing.Union[uuid.UUID, str]) -> HasTraits
        dt_index = self.load_entity_by_gid(dt_gid)
        h5_path = h5.path_for_stored_index(dt_index)
        dt, _ = h5.load_with_references(h5_path)
        return dt

    def view_model_to_has_traits(self, view_model):
        # type: (ViewModel) -> HasTraits
        has_traits_class = view_model.linked_has_traits
        if not has_traits_class:
            raise Exception("There is no linked HasTraits for this ViewModel {}".format(type(view_model)))
        has_traits = has_traits_class()
        view_model_class = type(view_model)
        for attr_name in has_traits_class.declarative_attrs:
            view_model_class_attr = getattr(view_model_class, attr_name)
            view_model_attr = getattr(view_model, attr_name)
            if isinstance(view_model_class_attr, DataTypeGidAttr) and view_model_attr:
                attr_value = self.load_with_references(view_model_attr)
            elif isinstance(view_model_class_attr, Attr) and isinstance(view_model_attr, ViewModel):
                attr_value = self.view_model_to_has_traits(view_model_attr)
            elif isinstance(view_model_class_attr, List) and len(view_model_attr) > 0 and isinstance(view_model_attr[0],
                                                                                                     ViewModel):
                attr_value = list()
                for view_model_elem in view_model_attr:
                    elem = self.view_model_to_has_traits(view_model_elem)
                    attr_value.append(elem)
            else:
                attr_value = view_model_attr
            setattr(has_traits, attr_name, attr_value)
        return has_traits

    @staticmethod
    def build_adapter_from_class(adapter_class):
        """
        Having a subclass of ABCAdapter, prepare an instance for launching an operation with it.
        """
        if not issubclass(adapter_class, ABCAdapter):
            raise IntrospectionException("Invalid data type: It should extend adapters.ABCAdapter!")
        try:
            stored_adapter = dao.get_algorithm_by_module(adapter_class.__module__, adapter_class.__name__)

            adapter_instance = adapter_class()
            adapter_instance.stored_adapter = stored_adapter
            return adapter_instance
        except Exception as excep:
            LOGGER.exception(excep)
            raise IntrospectionException(str(excep))

    @staticmethod
    def determine_adapter_class(stored_adapter):
        # type: (Algorithm) -> ABCAdapter
        """
        Determine the class of an adapter based on module and classname strings from stored_adapter
        :param stored_adapter: Algorithm or AlgorithmDTO type
        :return: a subclass of ABCAdapter
        """
        ad_module = importlib.import_module(stored_adapter.module)
        adapter_class = getattr(ad_module, stored_adapter.classname)
        return adapter_class

    @staticmethod
    def build_adapter(stored_adapter):
        # type: (Algorithm) -> ABCAdapter
        """
        Having a module and a class name, create an instance of ABCAdapter.
        """
        try:
            adapter_class = ABCAdapter.determine_adapter_class(stored_adapter)
            adapter_instance = adapter_class()
            adapter_instance.stored_adapter = stored_adapter
            return adapter_instance

        except Exception:
            msg = "Could not load Adapter Instance for Stored row %s" % stored_adapter
            LOGGER.exception(msg)
            raise IntrospectionException(msg)

    def load_view_model(self, operation):
        storage_path = self.storage_interface.get_project_folder(operation.project.name, str(operation.id))
        input_gid = operation.view_model_gid
        return h5.load_view_model(input_gid, storage_path)

    @staticmethod
    def array_size2kb(size):
        """
        :param size: size in bytes
        :return: size in kB
        """
        return size * TvbProfile.current.MAGIC_NUMBER / 8 / 2 ** 10

    @staticmethod
    def fill_index_from_h5(analyzer_index, analyzer_h5):
        """
        Method used only by analyzers that write slices of data.
        As they never have the whole array_data in memory, the metadata related to array_data
        (min, max, etc.) that they store on the index is not correct, so we need to update it here.
        """
        metadata = analyzer_h5.array_data.get_cached_metadata()

        if not metadata.has_complex:
            analyzer_index.array_data_max = float(metadata.max)
            analyzer_index.array_data_min = float(metadata.min)
            analyzer_index.array_data_mean = float(metadata.mean)

        analyzer_index.array_has_complex = metadata.has_complex
        analyzer_index.array_is_finite = metadata.is_finite
        analyzer_index.shape = json.dumps(analyzer_h5.array_data.shape)
        analyzer_index.ndim = len(analyzer_h5.array_data.shape)

    def path_for(self, h5_file_class, gid, dt_class=None):
        project = dao.get_project_by_id(self.current_project_id)
        return h5.path_for(self.operation_id, h5_file_class, gid, project.name, dt_class)

    def store_complete(self, datatype, generic_attributes=GenericAttributes()):
        project = dao.get_project_by_id(self.current_project_id)
        return h5.store_complete(datatype, self.operation_id, project.name, generic_attributes)

    def get_storage_path(self):
        project = dao.get_project_by_id(self.current_project_id)
        return self.storage_interface.get_project_folder(project.name, str(self.operation_id))
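ABCAdapter is only useful when subclassed: the framework builds an instance via build_adapter(), then calls _prelaunch(), which configures the adapter, checks memory and disk resources, runs launch() and persists the returned indexes. A minimal sketch of a concrete subclass follows; DummyForm, DummyIndex and the returned size estimates are illustrative assumptions, not part of the snippet above.

# A minimal sketch of a concrete adapter; DummyForm and DummyIndex are assumed placeholder
# classes, and the size estimates are illustrative only.
class DummyAdapter(ABCAdapter):

    def get_form_class(self):
        # The form class describing this adapter's UI fragment (assumed to exist).
        return DummyForm

    def get_output(self):
        # Index types this adapter is allowed to return from launch().
        return [DummyIndex]

    def get_required_memory_size(self, view_model):
        # Rough RAM estimate in bytes, checked in _ensure_enough_resources().
        return 2 ** 20

    def get_required_disk_size(self, view_model):
        # Disk estimate in kB, checked against the user's quota before launch.
        return 16

    def launch(self, view_model):
        # Build and return the result index; _capture_operation_results() stores it.
        return DummyIndex()

An instance would then be obtained through ABCAdapter.build_adapter_from_class(DummyAdapter) rather than constructed directly, so that the stored Algorithm row is attached to it.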
Example #8
0
class ProjectService:
    """
    Services layer for Project entities.
    """
    def __init__(self):
        self.logger = get_logger(__name__)
        self.storage_interface = StorageInterface()

    def store_project(self, current_user, is_create, selected_id, **data):
        """
        We want to create/update a project entity.
        """
        # Validate Unique Name
        new_name = data["name"]
        if len(new_name) < 1:
            raise ProjectServiceException("Invalid project name!")
        projects_no = dao.count_projects_for_name(new_name, selected_id)
        if projects_no > 0:
            err = {'name': 'Please choose another name, this one is used!'}
            raise formencode.Invalid("Duplicate Name Error", {},
                                     None,
                                     error_dict=err)
        started_operations = dao.get_operation_numbers(selected_id)[1]
        if started_operations > 0:
            raise ProjectServiceException(
                "A project can not be renamed while operations are still running!"
            )
        if is_create:
            current_proj = Project(new_name, current_user.id,
                                   data["description"])
            self.storage_interface.get_project_folder(current_proj.name)
        else:
            try:
                current_proj = dao.get_project_by_id(selected_id)
            except Exception as excep:
                self.logger.exception("An error has occurred!")
                raise ProjectServiceException(str(excep))
            if current_proj.name != new_name:
                self.storage_interface.rename_project(current_proj.name,
                                                      new_name)
            current_proj.name = new_name
            current_proj.description = data["description"]
        # Commit to make sure we have a valid ID
        current_proj.refresh_update_date()
        _, metadata_proj = current_proj.to_dict()
        self.storage_interface.write_project_metadata(metadata_proj)
        current_proj = dao.store_entity(current_proj)

        # Retrieve, to initialize lazy attributes
        current_proj = dao.get_project_by_id(current_proj.id)
        # Update share settings on current Project entity
        visited_pages = []
        prj_admin = current_proj.administrator.username
        if 'visited_pages' in data and data['visited_pages']:
            visited_pages = data['visited_pages'].split(',')
        for page in visited_pages:
            members = UserService.retrieve_users_except([prj_admin], int(page),
                                                        MEMBERS_PAGE_SIZE)[0]
            members = [m.id for m in members]
            dao.delete_members_for_project(current_proj.id, members)

        selected_user_ids = data["users"]
        if is_create and current_user.id not in selected_user_ids:
            # Make the project admin also member of the current project
            selected_user_ids.append(current_user.id)
        dao.add_members_to_project(current_proj.id, selected_user_ids)
        # Finish operation
        self.logger.debug("Edit/Save OK for project:" + str(current_proj.id) +
                          ' by user:'******'-'
                result["count"] = one_op[2]
                result["gid"] = one_op[13]
                operation_group_id = one_op[3]
                if operation_group_id is not None and operation_group_id:
                    try:
                        operation_group = dao.get_generic_entity(
                            OperationGroup, operation_group_id)[0]
                        result["group"] = operation_group.name
                        result["group"] = result["group"].replace("_", " ")
                        result["operation_group_id"] = operation_group.id
                        datatype_group = dao.get_datatypegroup_by_op_group_id(
                            operation_group_id)
                        result[
                            "datatype_group_gid"] = datatype_group.gid if datatype_group is not None else None
                        result["gid"] = operation_group.gid
                        # Filter only viewers for current DataTypeGroup entity:

                        if datatype_group is None:
                            view_groups = None
                        else:
                            view_groups = AlgorithmService(
                            ).get_visualizers_for_group(datatype_group.gid)
                        result["view_groups"] = view_groups
                    except Exception:
                        self.logger.exception(
                            "We will ignore group on entity:" + str(one_op))
                        result["datatype_group_gid"] = None
                else:
                    result['group'] = None
                    result['datatype_group_gid'] = None
                result["algorithm"] = dao.get_algorithm_by_id(one_op[4])
                result["user"] = dao.get_user_by_id(one_op[5])
                if type(one_op[6]) is str:
                    result["create"] = string2date(str(one_op[6]))
                else:
                    result["create"] = one_op[6]
                if type(one_op[7]) is str:
                    result["start"] = string2date(str(one_op[7]))
                else:
                    result["start"] = one_op[7]
                if type(one_op[8]) is str:
                    result["complete"] = string2date(str(one_op[8]))
                else:
                    result["complete"] = one_op[8]

                if result["complete"] is not None and result[
                        "start"] is not None:
                    result["duration"] = format_timedelta(result["complete"] -
                                                          result["start"])
                result["status"] = one_op[9]
                result["additional"] = one_op[10]
                result["visible"] = True if one_op[11] > 0 else False
                result['operation_tag'] = one_op[12]
                if not result['group']:
                    result['results'] = dao.get_results_for_operation(
                        result['id'])
                else:
                    result['results'] = None
                operations.append(result)
            except Exception:
                # We got an exception when processing one Operation Row. We will continue with the rest of the rows.
                self.logger.exception(
                    "Could not prepare operation for display:" + str(one_op))
        return selected_project, total_ops_nr, operations, pages_no

    def retrieve_projects_for_user(self, user_id, current_page=1):
        """
        Return a list with all Projects visible for current user.
        """
        start_idx = PROJECTS_PAGE_SIZE * (current_page - 1)
        total = dao.get_projects_for_user(user_id, is_count=True)
        available_projects = dao.get_projects_for_user(user_id, start_idx,
                                                       PROJECTS_PAGE_SIZE)
        pages_no = total // PROJECTS_PAGE_SIZE + (1 if total %
                                                  PROJECTS_PAGE_SIZE else 0)
        for prj in available_projects:
            fns, sta, err, canceled, pending = dao.get_operation_numbers(
                prj.id)
            prj.operations_finished = fns
            prj.operations_started = sta
            prj.operations_error = err
            prj.operations_canceled = canceled
            prj.operations_pending = pending
            prj.disk_size = dao.get_project_disk_size(prj.id)
            prj.disk_size_human = format_bytes_human(prj.disk_size)
        self.logger.debug("Displaying " + str(len(available_projects)) +
                          " projects in UI for user " + str(user_id))
        return available_projects, pages_no

    @staticmethod
    def retrieve_all_user_projects(user_id,
                                   page_start=0,
                                   page_size=PROJECTS_PAGE_SIZE):
        """
        Return a list with all projects visible for current user, without pagination.
        """
        return dao.get_projects_for_user(user_id,
                                         page_start=page_start,
                                         page_size=page_size)

    @staticmethod
    def get_linkable_projects_for_user(user_id, data_id):
        """
        Find projects which are visible for the current user, and in which the current datatype hasn't been linked yet.
        """
        return dao.get_linkable_projects_for_user(user_id, data_id)

    @transactional
    def remove_project(self, project_id):
        """
        Remove Project from DB and File Storage.
        """
        try:
            project2delete = dao.get_project_by_id(project_id)

            self.logger.debug("Deleting project: id=" + str(project_id) +
                              ' name=' + project2delete.name)
            project_datatypes = dao.get_datatypes_in_project(project_id)
            project_datatypes.sort(key=lambda dt: dt.create_date, reverse=True)
            for one_data in project_datatypes:
                self.remove_datatype(project_id, one_data.gid, True)

            links = dao.get_links_for_project(project_id)
            for one_link in links:
                dao.remove_entity(Links, one_link.id)
            project_bursts = dao.get_bursts_for_project(project_id)
            for burst in project_bursts:
                dao.remove_entity(burst.__class__, burst.id)

            self.storage_interface.remove_project(project2delete)
            dao.delete_project(project_id)
            self.logger.debug("Deleted project: id=" + str(project_id) +
                              ' name=' + project2delete.name)

        except RemoveDataTypeException as excep:
            self.logger.exception("Could not execute operation Node Remove!")
            raise ProjectServiceException(str(excep))
        except FileStructureException as excep:
            self.logger.exception("Could not delete because of rights!")
            raise ProjectServiceException(str(excep))
        except Exception as excep:
            self.logger.exception(str(excep))
            raise ProjectServiceException(str(excep))

    # ----------------- Methods for populating Data-Structure Page ---------------

    @staticmethod
    def get_datatype_in_group(group):
        """
        Return all dataTypes that are the result of the same DTgroup.
        """
        return dao.get_datatype_in_group(datatype_group_id=group)

    @staticmethod
    def get_datatypes_from_datatype_group(datatype_group_id):
        """
        Retrieve all dataType which are part from the given dataType group.
        """
        return dao.get_datatypes_from_datatype_group(datatype_group_id)

    @staticmethod
    def load_operation_by_gid(operation_gid):
        """ Retrieve loaded Operation from DB"""
        return dao.get_operation_by_gid(operation_gid)

    @staticmethod
    def load_operation_lazy_by_gid(operation_gid):
        """ Retrieve lazy Operation from DB"""
        return dao.get_operation_lazy_by_gid(operation_gid)

    @staticmethod
    def get_operation_group_by_id(operation_group_id):
        """ Loads OperationGroup from DB"""
        return dao.get_operationgroup_by_id(operation_group_id)

    @staticmethod
    def get_operation_group_by_gid(operation_group_gid):
        """ Loads OperationGroup from DB"""
        return dao.get_operationgroup_by_gid(operation_group_gid)

    @staticmethod
    def get_operations_in_group(operation_group):
        """ Return all the operations from an operation group. """
        return dao.get_operations_in_group(operation_group.id)

    @staticmethod
    def is_upload_operation(operation_gid):
        """ Returns True only if the operation with the given GID is an upload operation. """
        return dao.is_upload_operation(operation_gid)

    @staticmethod
    def get_all_operations_for_uploaders(project_id):
        """ Returns all finished upload operations. """
        return dao.get_all_operations_for_uploaders(project_id)

    def set_operation_and_group_visibility(self,
                                           entity_gid,
                                           is_visible,
                                           is_operation_group=False):
        """
        Sets the operation visibility.

        If 'is_operation_group' is True then this method will change the visibility for all
        the operations from the OperationGroup with the GID field equal to 'entity_gid'.
        """
        def set_visibility(op):
            # workaround:
            # 'reload' the operation so that it has the project property set.
            # get_operations_in_group does not eager load it and now we're out of a sqlalchemy session
            # write_operation_metadata requires that property
            op = dao.get_operation_by_id(op.id)
            # end hack
            op.visible = is_visible
            dao.store_entity(op)

        def set_group_descendants_visibility(operation_group_id):
            ops_in_group = dao.get_operations_in_group(operation_group_id)
            for group_op in ops_in_group:
                set_visibility(group_op)

        if is_operation_group:
            op_group_id = dao.get_operationgroup_by_gid(entity_gid).id
            set_group_descendants_visibility(op_group_id)
        else:
            operation = dao.get_operation_by_gid(entity_gid)
            # we ensure that if the operation belongs to a group then the visibility will be changed for the entire group
            if operation.fk_operation_group is not None:
                set_group_descendants_visibility(operation.fk_operation_group)
            else:
                set_visibility(operation)

    def get_operation_details(self, operation_gid, is_group):
        """
        :returns: an entity OperationOverlayDetails filled with all information for current operation details.
        """

        if is_group:
            operation_group = self.get_operation_group_by_gid(operation_gid)
            operation = dao.get_operations_in_group(operation_group.id, False,
                                                    True)
            # Reload, to make sure all lazy attributes are populated as well.
            operation = dao.get_operation_by_gid(operation.gid)
            no_of_op_in_group = dao.get_operations_in_group(operation_group.id,
                                                            is_count=True)
            datatype_group = self.get_datatypegroup_by_op_group_id(
                operation_group.id)
            count_result = dao.count_datatypes_in_group(datatype_group.id)

        else:
            operation = dao.get_operation_by_gid(operation_gid)
            if operation is None:
                return None
            no_of_op_in_group = 1
            count_result = dao.count_resulted_datatypes(operation.id)

        user_display_name = dao.get_user_by_id(
            operation.fk_launched_by).display_name
        burst = dao.get_burst_for_operation_id(operation.id)
        datatypes_param, all_special_params = self._review_operation_inputs(
            operation.gid)

        op_pid = dao.get_operation_process_for_operation(operation.id)
        op_details = OperationOverlayDetails(operation, user_display_name,
                                             len(datatypes_param),
                                             count_result, burst,
                                             no_of_op_in_group, op_pid)

        # Add all parameter which are set differently by the user on this Operation.
        if all_special_params is not None:
            op_details.add_scientific_fields(all_special_params)
        return op_details

    @staticmethod
    def get_filterable_meta():
        """
        Contains all the attributes by which
        the user can structure the tree of DataTypes
        """
        return DataTypeMetaData.get_filterable_meta()

    def get_project_structure(self, project, visibility_filter, first_level,
                              second_level, filter_value):
        """
        Find all DataTypes (including the linked ones and the groups) relevant for the current project.
        In case of a problem, will return an empty list.
        """
        metadata_list = []
        dt_list = dao.get_data_in_project(project.id, visibility_filter,
                                          filter_value)

        for dt in dt_list:
            # Prepare the DT results from DB, for usage in controller, by converting into DataTypeMetaData objects
            data = {}
            is_group = False
            group_op = None

            #  Filter by dt.type, otherwise Links to individual DT inside a group will be mistaken
            if dt.type == "DataTypeGroup" and dt.parent_operation.operation_group is not None:
                is_group = True
                group_op = dt.parent_operation.operation_group

            # All these fields are necessary here for dynamic Tree levels.
            data[DataTypeMetaData.KEY_DATATYPE_ID] = dt.id
            data[DataTypeMetaData.KEY_GID] = dt.gid
            data[DataTypeMetaData.KEY_NODE_TYPE] = dt.display_type
            data[DataTypeMetaData.KEY_STATE] = dt.state
            data[DataTypeMetaData.KEY_SUBJECT] = str(dt.subject)
            data[DataTypeMetaData.KEY_TITLE] = dt.display_name
            data[DataTypeMetaData.KEY_RELEVANCY] = dt.visible
            data[DataTypeMetaData.
                 KEY_LINK] = dt.parent_operation.fk_launched_in != project.id

            data[DataTypeMetaData.
                 KEY_TAG_1] = dt.user_tag_1 if dt.user_tag_1 else ''
            data[DataTypeMetaData.
                 KEY_TAG_2] = dt.user_tag_2 if dt.user_tag_2 else ''
            data[DataTypeMetaData.
                 KEY_TAG_3] = dt.user_tag_3 if dt.user_tag_3 else ''
            data[DataTypeMetaData.
                 KEY_TAG_4] = dt.user_tag_4 if dt.user_tag_4 else ''
            data[DataTypeMetaData.
                 KEY_TAG_5] = dt.user_tag_5 if dt.user_tag_5 else ''

            # Operation related fields:
            operation_name = CommonDetails.compute_operation_name(
                dt.parent_operation.algorithm.algorithm_category.displayname,
                dt.parent_operation.algorithm.displayname)
            data[DataTypeMetaData.KEY_OPERATION_TYPE] = operation_name
            data[
                DataTypeMetaData.
                KEY_OPERATION_ALGORITHM] = dt.parent_operation.algorithm.displayname
            data[DataTypeMetaData.
                 KEY_AUTHOR] = dt.parent_operation.user.username
            data[
                DataTypeMetaData.
                KEY_OPERATION_TAG] = group_op.name if is_group else dt.parent_operation.user_group
            data[DataTypeMetaData.
                 KEY_OP_GROUP_ID] = group_op.id if is_group else None

            completion_date = dt.parent_operation.completion_date
            string_year = completion_date.strftime(
                MONTH_YEAR_FORMAT) if completion_date is not None else ""
            string_month = completion_date.strftime(
                DAY_MONTH_YEAR_FORMAT) if completion_date is not None else ""
            data[DataTypeMetaData.KEY_DATE] = date2string(completion_date) if (
                completion_date is not None) else ''
            data[DataTypeMetaData.KEY_CREATE_DATA_MONTH] = string_year
            data[DataTypeMetaData.KEY_CREATE_DATA_DAY] = string_month

            data[
                DataTypeMetaData.
                KEY_BURST] = dt._parent_burst.name if dt._parent_burst is not None else '-None-'

            metadata_list.append(DataTypeMetaData(data, dt.invalid))

        return StructureNode.metadata2tree(metadata_list, first_level,
                                           second_level, project.id,
                                           project.name)

    @staticmethod
    def get_datatype_details(datatype_gid):
        """
        :returns: an array. First entry in array is an instance of DataTypeOverlayDetails\
            The second one contains all the possible states for the specified dataType.
        """
        meta_atts = DataTypeOverlayDetails()
        states = DataTypeMetaData.STATES
        try:
            datatype_result = dao.get_datatype_details(datatype_gid)
            meta_atts.fill_from_datatype(datatype_result,
                                         datatype_result._parent_burst)
            return meta_atts, states, datatype_result
        except Exception:
            # We ignore exception here (it was logged above, and we want to return no details).
            return meta_atts, states, None

    def _remove_project_node_files(self,
                                   project_id,
                                   gid,
                                   skip_validation=False):
        """
        Delegate removal of a node in the structure of the project.
        In case of a problem will THROW StructureException.
        """
        try:
            project = self.find_project(project_id)
            datatype = dao.get_datatype_by_gid(gid)
            links = dao.get_links_for_datatype(datatype.id)

            op = dao.get_operation_by_id(datatype.fk_from_operation)
            if links:
                was_link = False
                for link in links:
                    # This means it's only a link and we need to remove it
                    if link.fk_from_datatype == datatype.id and link.fk_to_project == project.id:
                        dao.remove_entity(Links, link.id)
                        was_link = True
                if not was_link:
                    # Create a clone of the operation
                    # There is no view_model so the view_model_gid is None

                    new_op = Operation(
                        op.view_model_gid,
                        dao.get_system_user().id, links[0].fk_to_project,
                        datatype.parent_operation.fk_from_algo,
                        datatype.parent_operation.status,
                        datatype.parent_operation.start_date,
                        datatype.parent_operation.completion_date,
                        datatype.parent_operation.fk_operation_group,
                        datatype.parent_operation.additional_info,
                        datatype.parent_operation.user_group,
                        datatype.parent_operation.range_values)
                    new_op = dao.store_entity(new_op)
                    to_project = self.find_project(links[0].fk_to_project)
                    to_project_path = self.storage_interface.get_project_folder(
                        to_project.name)

                    full_path = h5.path_for_stored_index(datatype)
                    old_folder = self.storage_interface.get_project_folder(
                        project.name, str(op.id))
                    vm_full_path = h5.determine_filepath(
                        op.view_model_gid, old_folder)

                    self.storage_interface.move_datatype_with_sync(
                        to_project, to_project_path, new_op.id, full_path,
                        vm_full_path)

                    datatype.fk_from_operation = new_op.id
                    datatype.parent_operation = new_op
                    dao.store_entity(datatype)
                    dao.remove_entity(Links, links[0].id)
            else:
                specific_remover = get_remover(datatype.type)(datatype)
                specific_remover.remove_datatype(skip_validation)

        except RemoveDataTypeException:
            self.logger.exception("Could not execute operation Node Remove!")
            raise
        except FileStructureException:
            self.logger.exception("Remove operation failed")
            raise StructureException(
                "Remove operation failed for unknown reasons.Please contact system administrator."
            )

    def remove_operation(self, operation_id):
        """
        Remove a given operation
        """
        operation = dao.try_get_operation_by_id(operation_id)
        if operation is not None:
            self.logger.debug("Deleting operation %s " % operation)
            datatypes_for_op = dao.get_results_for_operation(operation_id)
            for dt in reversed(datatypes_for_op):
                self.remove_datatype(operation.project.id, dt.gid, False)
            # Here the Operation is most probably already removed - in case DTs were found inside -
            # but we still remove it, for the case when no DTs exist
            dao.remove_entity(Operation, operation.id)
            self.storage_interface.remove_operation_data(
                operation.project.name, operation_id)
            self.storage_interface.push_folder_to_sync(operation.project.name)
            self.logger.debug("Finished deleting operation %s " % operation)
        else:
            self.logger.warning(
                "Attempt to delete operation with id=%s which no longer exists."
                % operation_id)

    def remove_datatype(self, project_id, datatype_gid, skip_validation=False):
        """
        Method used for removing a dataType. If the given dataType is a DatatypeGroup
        or a dataType from a DataTypeGroup then this method will remove the entire group.
        The operation(s) used for creating the dataType(s) will also be removed.
        """
        datatype = dao.get_datatype_by_gid(datatype_gid)
        if datatype is None:
            self.logger.warning(
                "Attempt to delete DT[%s] which no longer exists." %
                datatype_gid)
            return

        is_datatype_group = False
        datatype_group = None
        if dao.is_datatype_group(datatype_gid):
            is_datatype_group = True
            datatype_group = datatype
        elif datatype.fk_datatype_group is not None:
            is_datatype_group = True
            datatype_group = dao.get_datatype_by_id(datatype.fk_datatype_group)

        operations_set = [datatype.fk_from_operation]
        correct = True

        if is_datatype_group:
            operations_set = [datatype_group.fk_from_operation]
            self.logger.debug("Removing datatype group %s" % datatype_group)
            if datatype_group.fk_parent_burst:
                burst = dao.get_generic_entity(BurstConfiguration,
                                               datatype_group.fk_parent_burst,
                                               'gid')[0]
                dao.remove_entity(BurstConfiguration, burst.id)
                if burst.fk_metric_operation_group:
                    correct = correct and self._remove_operation_group(
                        burst.fk_metric_operation_group, project_id,
                        skip_validation, operations_set)

                if burst.fk_operation_group:
                    correct = correct and self._remove_operation_group(
                        burst.fk_operation_group, project_id, skip_validation,
                        operations_set)

            else:
                self._remove_datatype_group_dts(project_id, datatype_group.id,
                                                skip_validation,
                                                operations_set)

                datatype_group = dao.get_datatype_group_by_gid(
                    datatype_group.gid)
                dao.remove_entity(DataTypeGroup, datatype.id)
                correct = correct and dao.remove_entity(
                    OperationGroup, datatype_group.fk_operation_group)
        else:
            self.logger.debug("Removing datatype %s" % datatype)
            self._remove_project_node_files(project_id, datatype.gid,
                                            skip_validation)

        # Remove Operation entity in case no other DataType needs them.
        project = dao.get_project_by_id(project_id)
        for operation_id in operations_set:
            dependent_dt = dao.get_generic_entity(DataType, operation_id,
                                                  "fk_from_operation")
            if len(dependent_dt) > 0:
                # Do not remove Operation in case DataType still exist referring it.
                continue
            op_burst = dao.get_burst_for_operation_id(operation_id)
            if op_burst:
                correct = correct and dao.remove_entity(
                    BurstConfiguration, op_burst.id)
            correct = correct and dao.remove_entity(Operation, operation_id)
            # Make sure Operation folder is removed
            self.storage_interface.remove_operation_data(
                project.name, operation_id)

        self.storage_interface.push_folder_to_sync(project.name)
        if not correct:
            raise RemoveDataTypeException("Could not remove DataType " +
                                          str(datatype_gid))

    def _remove_operation_group(self, operation_group_id, project_id,
                                skip_validation, operations_set):
        metrics_groups = dao.get_generic_entity(DataTypeGroup,
                                                operation_group_id,
                                                'fk_operation_group')
        if len(metrics_groups) > 0:
            metric_datatype_group_id = metrics_groups[0].id
            self._remove_datatype_group_dts(project_id,
                                            metric_datatype_group_id,
                                            skip_validation, operations_set)
            dao.remove_entity(DataTypeGroup, metric_datatype_group_id)
        return dao.remove_entity(OperationGroup, operation_group_id)

    def _remove_datatype_group_dts(self, project_id, dt_group_id,
                                   skip_validation, operations_set):
        data_list = dao.get_datatypes_from_datatype_group(dt_group_id)
        for adata in data_list:
            self._remove_project_node_files(project_id, adata.gid,
                                            skip_validation)
            if adata.fk_from_operation not in operations_set:
                operations_set.append(adata.fk_from_operation)

    def update_metadata(self, submit_data):
        """
        Update DataType/ DataTypeGroup metadata
        THROW StructureException when input data is invalid.
        """
        new_data = dict()
        for key in DataTypeOverlayDetails().meta_attributes_list:
            if key in submit_data:
                value = submit_data[key]
                if value == "None":
                    value = None
                if value == "" and key in [
                        CommonDetails.CODE_OPERATION_TAG,
                        CommonDetails.CODE_OPERATION_GROUP_ID
                ]:
                    value = None
                new_data[key] = value

        try:
            if (CommonDetails.CODE_OPERATION_GROUP_ID in new_data
                    and new_data[CommonDetails.CODE_OPERATION_GROUP_ID]):
                # We need to edit a group
                all_data_in_group = dao.get_datatype_in_group(
                    operation_group_id=new_data[
                        CommonDetails.CODE_OPERATION_GROUP_ID])
                if len(all_data_in_group) < 1:
                    raise StructureException(
                        "Inconsistent group, can not be updated!")
                # datatype_group = dao.get_generic_entity(DataTypeGroup, all_data_in_group[0].fk_datatype_group)[0]
                # all_data_in_group.append(datatype_group)
                for datatype in all_data_in_group:
                    self._edit_data(datatype, new_data, True)
            else:
                # Get the required DataType and operation from DB to store changes that will be done in XML.
                gid = new_data[CommonDetails.CODE_GID]
                datatype = dao.get_datatype_by_gid(gid)
                self._edit_data(datatype, new_data)
        except Exception as excep:
            self.logger.exception(excep)
            raise StructureException(str(excep))

    def _edit_data(self, datatype, new_data, from_group=False):
        # type: (DataType, dict, bool) -> None
        """
        Private method, used for editing a meta-data XML file and a DataType row
        for a given custom DataType entity with new dictionary of data from UI.
        """
        # 1. First update Operation fields:
        #    Update group field if possible
        new_group_name = new_data[CommonDetails.CODE_OPERATION_TAG]
        empty_group_value = (new_group_name is None or new_group_name == "")
        if from_group:
            if empty_group_value:
                raise StructureException("Empty group is not allowed!")

            group = dao.get_generic_entity(
                OperationGroup,
                new_data[CommonDetails.CODE_OPERATION_GROUP_ID])
            if group and len(group) > 0 and new_group_name != group[0].name:
                group = group[0]
                exists_group = dao.get_generic_entity(OperationGroup,
                                                      new_group_name, 'name')
                if exists_group:
                    raise StructureException("Group '" + new_group_name +
                                             "' already exists.")
                group.name = new_group_name
                dao.store_entity(group)
        else:
            operation = dao.get_operation_by_id(datatype.fk_from_operation)
            operation.user_group = new_group_name
            dao.store_entity(operation)
            op_folder = self.storage_interface.get_project_folder(
                operation.project.name, str(operation.id))
            vm_gid = operation.view_model_gid
            view_model_file = h5.determine_filepath(vm_gid, op_folder)
            if view_model_file:
                view_model_class = H5File.determine_type(view_model_file)
                view_model = view_model_class()
                with ViewModelH5(view_model_file, view_model) as f:
                    ga = f.load_generic_attributes()
                    ga.operation_tag = new_group_name
                    f.store_generic_attributes(ga, False)
            else:
                self.logger.warning(
                    "Could not find ViewModel H5 file for op: {}".format(
                        operation))

        # 2. Update GenericAttributes in the associated H5 files:
        h5_path = h5.path_for_stored_index(datatype)
        with H5File.from_file(h5_path) as f:
            ga = f.load_generic_attributes()

            ga.subject = new_data[DataTypeOverlayDetails.DATA_SUBJECT]
            ga.state = new_data[DataTypeOverlayDetails.DATA_STATE]
            ga.operation_tag = new_group_name
            if DataTypeOverlayDetails.DATA_TAG_1 in new_data:
                ga.user_tag_1 = new_data[DataTypeOverlayDetails.DATA_TAG_1]
            if DataTypeOverlayDetails.DATA_TAG_2 in new_data:
                ga.user_tag_2 = new_data[DataTypeOverlayDetails.DATA_TAG_2]
            if DataTypeOverlayDetails.DATA_TAG_3 in new_data:
                ga.user_tag_3 = new_data[DataTypeOverlayDetails.DATA_TAG_3]
            if DataTypeOverlayDetails.DATA_TAG_4 in new_data:
                ga.user_tag_4 = new_data[DataTypeOverlayDetails.DATA_TAG_4]
            if DataTypeOverlayDetails.DATA_TAG_5 in new_data:
                ga.user_tag_5 = new_data[DataTypeOverlayDetails.DATA_TAG_5]

            f.store_generic_attributes(ga, False)

        # 3. Update MetaData in DT Index DB as well.
        datatype.fill_from_generic_attributes(ga)
        dao.store_entity(datatype)

    def _review_operation_inputs(self, operation_gid):
        """
        :returns: A list of DataTypes that are used as input parameters for the specified operation.
                 And a dictionary with all operation parameters that differ from the default ones.
        """
        operation = dao.get_operation_by_gid(operation_gid)
        try:
            adapter = ABCAdapter.build_adapter(operation.algorithm)
            return review_operation_inputs_from_adapter(adapter, operation)

        except Exception:
            self.logger.exception("Could not load details for operation %s" %
                                  operation_gid)
            if operation.view_model_gid:
                changed_parameters = dict(
                    Warning=
                    "Algorithm changed dramatically. We can not offer more details"
                )
            else:
                changed_parameters = dict(
                    Warning=
                    "GID parameter is missing. Old implementation of the operation."
                )
            return [], changed_parameters

    @staticmethod
    def get_results_for_operation(operation_id, selected_filter=None):
        """
        Retrieve the DataTypes entities resulted after the execution of the given operation.
        """
        return dao.get_results_for_operation(operation_id, selected_filter)

    @staticmethod
    def get_datatype_by_id(datatype_id):
        """Retrieve a DataType DB reference by its id."""
        return dao.get_datatype_by_id(datatype_id)

    @staticmethod
    def get_datatypegroup_by_gid(datatypegroup_gid):
        """ Returns the DataTypeGroup with the specified gid. """
        return dao.get_datatype_group_by_gid(datatypegroup_gid)

    @staticmethod
    def get_datatypegroup_by_op_group_id(operation_group_id):
        """ Returns the DataTypeGroup with the specified id. """
        return dao.get_datatypegroup_by_op_group_id(operation_group_id)

    @staticmethod
    def get_datatypes_in_project(project_id, only_visible=False):
        return dao.get_data_in_project(project_id, only_visible)

    @staticmethod
    def set_datatype_visibility(datatype_gid, is_visible):
        """
        Sets the dataType visibility. If the given dataType is a dataType group, or is part of a
        dataType group, then this method will set the visibility for each dataType in that group.
        """
        def set_visibility(dt):
            """ set visibility flag, persist in db and h5"""
            dt.visible = is_visible
            dt = dao.store_entity(dt)

            h5_path = h5.path_for_stored_index(dt)
            with H5File.from_file(h5_path) as f:
                f.visible.store(is_visible)

        def set_group_descendants_visibility(datatype_group_id):
            datatypes_in_group = dao.get_datatypes_from_datatype_group(
                datatype_group_id)
            for group_dt in datatypes_in_group:
                set_visibility(group_dt)

        datatype = dao.get_datatype_by_gid(datatype_gid)

        if isinstance(datatype, DataTypeGroup):  # datatype is a group
            set_group_descendants_visibility(datatype.id)
            datatype.visible = is_visible
            dao.store_entity(datatype)
        elif datatype.fk_datatype_group is not None:  # datatype is member of a group
            set_group_descendants_visibility(datatype.fk_datatype_group)
            # the datatype to be updated is the parent datatype group
            parent = dao.get_datatype_by_id(datatype.fk_datatype_group)
            parent.visible = is_visible
            dao.store_entity(parent)
        else:
            # update the single datatype.
            set_visibility(datatype)

    @staticmethod
    def is_datatype_group(datatype_gid):
        """ Used to check if the dataType with the specified GID is a DataTypeGroup. """
        return dao.is_datatype_group(datatype_gid)

    def get_linked_datatypes_storage_path(self, project):
        """
        :return: the file paths to the datatypes that are linked in `project`
        """
        paths = []
        for lnk_dt in dao.get_linked_datatypes_in_project(project.id):
            # get datatype as a mapped type
            path = h5.path_for_stored_index(lnk_dt)
            if path is not None:
                paths.append(path)
            else:
                self.logger.warning(
                    "Problem when trying to retrieve path on %s:%s!" %
                    (lnk_dt.type, lnk_dt.gid))
        return paths
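
The visibility rule implemented by set_datatype_visibility above is easiest to see in isolation. Below is a minimal, standalone sketch with made-up stand-in classes (not TVB code) showing that flipping the flag on a group member propagates to the whole group and to the group entity itself.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeGroup:
    gid: str
    visible: bool = True
    members: List["FakeDataType"] = field(default_factory=list)


@dataclass
class FakeDataType:
    gid: str
    visible: bool = True
    group: Optional[FakeGroup] = None  # None means a standalone datatype


def set_visibility(target, is_visible):
    # Same rule as set_datatype_visibility: a group, or any member of a group,
    # propagates the flag to every datatype in the group and to the group itself.
    if isinstance(target, FakeGroup):
        group = target
    elif target.group is not None:
        group = target.group
    else:
        target.visible = is_visible
        return
    group.visible = is_visible
    for member in group.members:
        member.visible = is_visible


group = FakeGroup("group-1")
group.members = [FakeDataType("dt-1", group=group), FakeDataType("dt-2", group=group)]
set_visibility(group.members[0], False)
assert group.visible is False and group.members[1].visible is False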
Пример #9
0
class DiagnoseDiskUsage(object):
    FORMAT_DT = '    {:14} {:20} {:>12} {:>12} {:>12} {:>12}'
    HEADER_DT = FORMAT_DT.format('', '', 'disk_size(kib)', 'db_size(kib)',
                                 'delta(kib)', 'ratio(%)')

    def __init__(self, prj_id):
        self.storage_interface = StorageInterface()
        self.expected_files = set()
        self.prj_disk_size, self.prj_db_size = 0, 0

        try:
            dao.session.open_session()
            # We do not fetch the project using dao because dao will detach it from the session.
            # We want to query on the fly on attribute access and this requires attached objects.
            # This code is doing a tree traversal of the db.
            # The query on attribute access style fits better than aggregating queries.
            self.project = dao.session.query(Project).filter(
                Project.id == prj_id).one()
            self.expected_files.add(
                self.storage_interface.get_project_meta_file_path(
                    self.project.name))
            root_path = self.storage_interface.get_project_folder(
                self.project.name)

            print()
            print('Reporting disk for project {} in {}'.format(
                self.project.name, root_path))
            print()
            print(self.HEADER_DT)

            for op in self.project.OPERATIONS:
                self.analyse_operation(op)

            print(self.HEADER_DT)
            self.print_usage_line('Project', 'total', self.prj_disk_size,
                                  self.prj_db_size)

            print()
            self.list_unexpected_project_files(root_path)
            print()
        finally:
            dao.session.close_session()

    @staticmethod
    def get_h5_by_gid(root, gid):
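        # Returns the first file in `root` whose name contains `gid`;
        # falls through to an implicit None when no such file exists.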
        for f in os.listdir(root):
            fp = os.path.join(root, f)
            if gid in f and os.path.isfile(fp):
                return fp

    @staticmethod
    def get_file_kib_size(fp):
        return int(round((os.path.getsize(fp) / 1024.)))

    @staticmethod
    def print_usage_line(col1, col2, actual, expected):
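        # Prints one report row; suspicious values are flagged with a leading '!':
        # a ratio above 200% of the DB-recorded size, or a delta above 100 KiB.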
        if expected != 0:
            ratio = int(100.0 * actual / expected)
            if ratio > 200:
                ratio = "! %s" % ratio
            else:
                ratio = str(ratio)
        else:
            ratio = 'inf'

        delta = actual - expected
        if delta > 100:
            delta = "! %s" % delta
        else:
            delta = str(delta)

        print(
            DiagnoseDiskUsage.FORMAT_DT.format(col1, col2,
                                               '{:,}'.format(actual),
                                               '{:,}'.format(expected), delta,
                                               ratio))

    def analyse_operation(self, op):
        op_disk_size, op_db_size = 0, 0

        print('Operation {} : {}'.format(op.id, op.algorithm.name))

        for dt in op.DATA_TYPES:
            if dt.type == 'DataTypeGroup':
                # these have no h5
                continue
            op_pth = self.storage_interface.get_project_folder(
                self.project.name, str(op.id))
            dt_pth = self.get_h5_by_gid(op_pth, dt.gid)

            dt_actual_disk_size = self.get_file_kib_size(dt_pth)

            db_disk_size = dt.disk_size or 0

            op_disk_size += dt_actual_disk_size
            op_db_size += db_disk_size

            self.print_usage_line(dt.gid[:12], dt.type, dt_actual_disk_size,
                                  db_disk_size)
            self.expected_files.add(dt_pth)

        self.prj_disk_size += op_disk_size
        self.prj_db_size += op_db_size
        self.print_usage_line('', 'total :', op_disk_size, op_db_size)
        print()

    def list_unexpected_project_files(self, root_path):
        unexpected = []
        for r, d, files in os.walk(root_path):
            for f in files:
                pth = os.path.join(r, f)
                if pth not in self.expected_files:
                    unexpected.append(pth)

        print('Unexpected project files :')
        for f in unexpected:
            print(f)

        if not unexpected:
            print('yey! none found')
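
For orientation, here is a self-contained sketch of the row format produced by the report above; the sizes are invented, and the '!' markers follow the same thresholds as print_usage_line.

FORMAT_DT = '    {:14} {:20} {:>12} {:>12} {:>12} {:>12}'


def usage_row(col1, col2, actual_kib, db_kib):
    # Mirror of print_usage_line: ratio of actual vs. DB-recorded size, plus delta.
    if db_kib:
        ratio = int(100.0 * actual_kib / db_kib)
        ratio_s = '! %s' % ratio if ratio > 200 else str(ratio)
    else:
        ratio_s = 'inf'
    delta = actual_kib - db_kib
    delta_s = '! %s' % delta if delta > 100 else str(delta)
    return FORMAT_DT.format(col1, col2, '{:,}'.format(actual_kib),
                            '{:,}'.format(db_kib), delta_s, ratio_s)


print(FORMAT_DT.format('', '', 'disk_size(kib)', 'db_size(kib)', 'delta(kib)', 'ratio(%)'))
print(usage_row('abcdef123456', 'TimeSeriesIndex', 10240, 10200))     # healthy row
print(usage_row('fedcba654321', 'ConnectivityIndex', 4096, 1024))     # flagged with '!'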
Пример #10
0
class BurstService(object):
    LAUNCH_NEW = 'new'
    LAUNCH_BRANCH = 'branch'

    def __init__(self):
        self.logger = get_logger(self.__class__.__module__)
        self.storage_interface = StorageInterface()

    def mark_burst_finished(self,
                            burst_entity,
                            burst_status=None,
                            error_message=None,
                            store_h5_file=True):
        """
        Mark Burst status field.
        Also compute the 'weight' of the current burst: number of operations inside, estimated size on disk, etc.

        :param burst_entity: BurstConfiguration to be updated, at finish time.
        :param burst_status: BurstConfiguration status. By default BURST_FINISHED
        :param error_message: If given, set the status to error and persist this message.
        """
        if burst_status is None:
            burst_status = BurstConfiguration.BURST_FINISHED
        if error_message is not None:
            burst_status = BurstConfiguration.BURST_ERROR

        try:
            # If there are any DataType Groups in current Burst, update their counter.
            burst_dt_groups = dao.get_generic_entity(DataTypeGroup,
                                                     burst_entity.gid,
                                                     "fk_parent_burst")
            for dt_group in burst_dt_groups:
                dt_group.count_results = dao.count_datatypes_in_group(
                    dt_group.id)
                dt_group.disk_size, dt_group.subject = dao.get_summary_for_group(
                    dt_group.id)
                dao.store_entity(dt_group)

            # Update actual Burst entity fields
            burst_entity.datatypes_number = dao.count_datatypes_in_burst(
                burst_entity.gid)

            burst_entity.status = burst_status
            burst_entity.error_message = error_message
            burst_entity.finish_time = datetime.now()
            dao.store_entity(burst_entity)
            if store_h5_file:
                self.store_burst_configuration(burst_entity)
        except Exception:
            self.logger.exception(
                "Could not correctly update Burst status and meta-data!")
            burst_entity.status = burst_status
            burst_entity.error_message = "Error when updating Burst Status"
            burst_entity.finish_time = datetime.now()
            dao.store_entity(burst_entity)
            if store_h5_file:
                self.store_burst_configuration(burst_entity)

    def persist_operation_state(self,
                                operation,
                                operation_status,
                                message=None):
        """
        Update Operation instance state. Store it in the DB and on disk.
        :param operation: Operation instance
        :param operation_status: new status
        :param message: message in case of error
        :return: operation instance changed
        """
        operation.mark_complete(operation_status, message)
        operation.queue_full = False
        operation = dao.store_entity(operation)
        # update burst also
        burst_config = self.get_burst_for_operation_id(operation.id)
        if burst_config is not None:
            burst_status = STATUS_FOR_OPERATION.get(operation_status)
            self.mark_burst_finished(burst_config, burst_status, message)
        return operation

    @staticmethod
    def get_burst_for_operation_id(operation_id, is_group=False):
        return dao.get_burst_for_operation_id(operation_id, is_group)

    def rename_burst(self, burst_id, new_name):
        """
        Rename the burst given by burst_id, setting its name to new_name.
        """
        burst = dao.get_burst_by_id(burst_id)
        burst.name = new_name
        dao.store_entity(burst)
        self.store_burst_configuration(burst)

    @staticmethod
    def get_available_bursts(project_id):
        """
        Return all the bursts for the given project.
        """
        bursts = dao.get_bursts_for_project(
            project_id, page_size=MAX_BURSTS_DISPLAYED) or []
        return bursts

    @staticmethod
    def populate_burst_disk_usage(bursts):
        """
        Adds a disk_usage field to each burst object.
        The disk usage is computed as the sum of the sizes of the datatypes generated by that burst.
        """
        sizes = dao.compute_bursts_disk_size([b.gid for b in bursts])
        for b in bursts:
            b.disk_size = format_bytes_human(sizes[b.gid])

    def update_history_status(self, id_list):
        """
        For each burst_id received in id_list, read the new status from the DB and return a list of
        [id, new_status, is_group, message, running_time] entries.
        """
        result = []
        for b_id in id_list:
            burst = dao.get_burst_by_id(b_id)
            if burst is not None:
                if burst.status == burst.BURST_RUNNING:
                    running_time = datetime.now() - burst.start_time
                else:
                    running_time = burst.finish_time - burst.start_time
                running_time = format_timedelta(running_time,
                                                most_significant2=False)

                if burst.status == burst.BURST_ERROR:
                    msg = 'Check Operations page for error Message'
                else:
                    msg = ''
                result.append([
                    burst.id, burst.status, burst.is_group, msg, running_time
                ])
            else:
                self.logger.debug("Could not find burst with id=" + str(b_id) +
                                  ". Might have been deleted by user!!")
        return result

    @staticmethod
    def update_simulation_fields(burst, op_simulation_id, simulation_gid):
        burst.fk_simulation = op_simulation_id
        burst.simulator_gid = simulation_gid.hex
        burst = dao.store_entity(burst)
        return burst

    @staticmethod
    def load_burst_configuration(burst_config_id):
        # type: (int) -> BurstConfiguration
        burst_config = dao.get_burst_by_id(burst_config_id)
        return burst_config

    @staticmethod
    def remove_burst_configuration(burst_config_id):
        # type: (int) -> None
        dao.remove_entity(BurstConfiguration, burst_config_id)

    @staticmethod
    def prepare_burst_for_pse(burst_config):
        # type: (BurstConfiguration) -> (BurstConfiguration)
        operation_group = OperationGroup(burst_config.fk_project,
                                         ranges=burst_config.ranges)
        operation_group = dao.store_entity(operation_group)

        metric_operation_group = OperationGroup(burst_config.fk_project,
                                                ranges=burst_config.ranges)
        metric_operation_group = dao.store_entity(metric_operation_group)

        burst_config.operation_group = operation_group
        burst_config.fk_operation_group = operation_group.id
        burst_config.metric_operation_group = metric_operation_group
        burst_config.fk_metric_operation_group = metric_operation_group.id
        return dao.store_entity(burst_config)

    @staticmethod
    def store_burst_configuration(burst_config):
        project = dao.get_project_by_id(burst_config.fk_project)
        bc_path = h5.path_for(burst_config.fk_simulation, BurstConfigurationH5,
                              burst_config.gid, project.name)
        with BurstConfigurationH5(bc_path) as bc_h5:
            bc_h5.store(burst_config)

    @staticmethod
    def load_burst_configuration_from_folder(simulator_folder, project):
        bc_h5_filename = DirLoader(
            simulator_folder,
            None).find_file_for_has_traits_type(BurstConfiguration)
        burst_config = BurstConfiguration(project.id)
        with BurstConfigurationH5(
                os.path.join(simulator_folder, bc_h5_filename)) as bc_h5:
            bc_h5.load_into(burst_config)
        return burst_config

    @staticmethod
    def prepare_simulation_name(burst, project_id):
        simulation_number = dao.get_number_of_bursts(project_id) + 1

        if burst.name is None:
            simulation_name = 'simulation_' + str(simulation_number)
        else:
            simulation_name = burst.name

        return simulation_name, simulation_number

    def prepare_indexes_for_simulation_results(self, operation,
                                               result_filenames, burst):
        indexes = list()
        self.logger.debug(
            "Preparing indexes for simulation results in operation {}...".
            format(operation.id))
        for filename in result_filenames:
            try:
                self.logger.debug(
                    "Preparing index for filename: {}".format(filename))
                index = h5.index_for_h5_file(filename)()
                h5_class = h5.REGISTRY.get_h5file_for_index(type(index))

                with h5_class(filename) as index_h5:
                    index.fill_from_h5(index_h5)
                    index.fill_from_generic_attributes(
                        index_h5.load_generic_attributes())

                index.fk_parent_burst = burst.gid
                index.fk_from_operation = operation.id
                if operation.fk_operation_group:
                    datatype_group = dao.get_datatypegroup_by_op_group_id(
                        operation.fk_operation_group)
                    self.logger.debug(
                        "Found DatatypeGroup with id {} for operation {}".
                        format(datatype_group.id, operation.id))
                    index.fk_datatype_group = datatype_group.id

                    # Update the operation group name
                    operation_group = dao.get_operationgroup_by_id(
                        operation.fk_operation_group)
                    operation_group.fill_operationgroup_name(
                        "TimeSeriesRegionIndex")
                    dao.store_entity(operation_group)
                self.logger.debug(
                    "Prepared index {} for file {} in operation {}".format(
                        index.summary_info, filename, operation.id))
                indexes.append(index)
            except Exception as e:
                self.logger.debug(
                    "Skip preparing index {} because there was an error.".
                    format(filename))
                self.logger.error(e)
        self.logger.debug(
            "Prepared {} indexes for results in operation {}...".format(
                len(indexes), operation.id))
        return indexes

    def prepare_index_for_metric_result(self, operation, result_filename,
                                        burst):
        self.logger.debug(
            "Preparing index for metric result in operation {}...".format(
                operation.id))
        index = h5.index_for_h5_file(result_filename)()
        with DatatypeMeasureH5(result_filename) as dti_h5:
            index.gid = dti_h5.gid.load().hex
            index.metrics = json.dumps(dti_h5.metrics.load())
            index.fk_source_gid = dti_h5.analyzed_datatype.load().hex
        index.fk_from_operation = operation.id
        index.fk_parent_burst = burst.gid
        datatype_group = dao.get_datatypegroup_by_op_group_id(
            operation.fk_operation_group)
        self.logger.debug(
            "Found DatatypeGroup with id {} for operation {}".format(
                datatype_group.id, operation.id))
        index.fk_datatype_group = datatype_group.id
        self.logger.debug(
            "Prepared index {} for results in operation {}...".format(
                index.summary_info, operation.id))
        return index

    def _update_pse_burst_status(self, burst_config):
        operations_in_group = dao.get_operations_in_group(
            burst_config.fk_operation_group)
        if burst_config.fk_metric_operation_group:
            operations_in_group.extend(
                dao.get_operations_in_group(
                    burst_config.fk_metric_operation_group))
        operation_statuses = list()
        for operation in operations_in_group:
            if not has_finished(operation.status):
                self.logger.debug(
                    'Operation {} in group {} is not finished, burst status will not be updated'
                    .format(operation.id, operation.fk_operation_group))
                return
            operation_statuses.append(operation.status)
        self.logger.debug(
            'All operations in burst {} have finished. Will update burst status'
            .format(burst_config.id))
        if STATUS_ERROR in operation_statuses:
            self.mark_burst_finished(
                burst_config, BurstConfiguration.BURST_ERROR,
                'Some operations in PSE have finished with errors')
        elif STATUS_CANCELED in operation_statuses:
            self.mark_burst_finished(burst_config,
                                     BurstConfiguration.BURST_CANCELED)
        else:
            self.mark_burst_finished(burst_config)

    def update_burst_status(self, burst_config):
        if burst_config.fk_operation_group:
            self._update_pse_burst_status(burst_config)
        else:
            operation = dao.get_operation_by_id(burst_config.fk_simulation)
            message = operation.additional_info
            if len(message) == 0:
                message = None
            self.mark_burst_finished(burst_config,
                                     STATUS_FOR_OPERATION[operation.status],
                                     message)

    @staticmethod
    def prepare_metrics_operation(operation):
        # TODO reuse from OperationService and do not duplicate logic here
        parent_burst = dao.get_generic_entity(BurstConfiguration,
                                              operation.fk_operation_group,
                                              'fk_operation_group')[0]
        metric_operation_group_id = parent_burst.fk_metric_operation_group
        range_values = operation.range_values
        metric_algo = dao.get_algorithm_by_module(MEASURE_METRICS_MODULE,
                                                  MEASURE_METRICS_CLASS)

        metric_operation = Operation(None,
                                     operation.fk_launched_by,
                                     operation.fk_launched_in,
                                     metric_algo.id,
                                     status=STATUS_FINISHED,
                                     op_group_id=metric_operation_group_id,
                                     range_values=range_values)
        metric_operation.visible = False
        metric_operation = dao.store_entity(metric_operation)
        op_dir = StorageInterface().get_project_folder(
            operation.project.name, str(metric_operation.id))
        return op_dir, metric_operation

    @staticmethod
    def get_range_param_by_name(param_name, all_range_parameters):
        for range_param in all_range_parameters:
            if param_name == range_param.name:
                return range_param

        return None

    @staticmethod
    def handle_range_params_at_loading(burst_config, all_range_parameters):
        param1, param2 = None, None
        if burst_config.range1:
            param1 = RangeParameter.from_json(burst_config.range1)
            param1.fill_from_default(
                BurstService.get_range_param_by_name(param1.name,
                                                     all_range_parameters))
            if burst_config.range2 is not None:
                param2 = RangeParameter.from_json(burst_config.range2)
                param2.fill_from_default(
                    BurstService.get_range_param_by_name(
                        param2.name, all_range_parameters))

        return param1, param2

    def prepare_data_for_burst_copy(self, burst_config_id, burst_name_format,
                                    project):
        burst_config = self.load_burst_configuration(burst_config_id)
        burst_config_copy = burst_config.clone()
        count = dao.count_bursts_with_name(burst_config.name,
                                           burst_config.fk_project)
        burst_config_copy.name = burst_name_format.format(
            burst_config.name, count + 1)

        storage_path = self.storage_interface.get_project_folder(
            project.name, str(burst_config.fk_simulation))
        simulator = h5.load_view_model(burst_config.simulator_gid,
                                       storage_path)
        simulator.generic_attributes = GenericAttributes()
        return simulator, burst_config_copy

    @staticmethod
    def store_burst(burst_config):
        return dao.store_entity(burst_config)

    def load_simulation_from_zip(self, zip_file, project):
        import_service = ImportService()
        simulator_folder = import_service.import_simulator_configuration_zip(
            zip_file)

        simulator_h5_filename = DirLoader(
            simulator_folder,
            None).find_file_for_has_traits_type(SimulatorAdapterModel)
        simulator_h5_filepath = os.path.join(simulator_folder,
                                             simulator_h5_filename)
        simulator = h5.load_view_model_from_file(simulator_h5_filepath)

        burst_config = self.load_burst_configuration_from_folder(
            simulator_folder, project)
        burst_config_copy = burst_config.clone()
        simulator.generic_attributes.parent_burst = burst_config_copy.gid

        return simulator, burst_config_copy, simulator_folder
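
The status precedence used by _update_pse_burst_status and update_burst_status boils down to a simple rule, sketched below with plain strings standing in for TVB's status constants: any error wins over a cancellation, which wins over a clean finish.

STATUS_FINISHED, STATUS_CANCELED, STATUS_ERROR = 'finished', 'canceled', 'error'
BURST_FINISHED, BURST_CANCELED, BURST_ERROR = 'Finished', 'Canceled', 'Error'


def burst_status_for(operation_statuses):
    # Decided only once every operation has finished (see _update_pse_burst_status).
    if STATUS_ERROR in operation_statuses:
        return BURST_ERROR
    if STATUS_CANCELED in operation_statuses:
        return BURST_CANCELED
    return BURST_FINISHED


assert burst_status_for([STATUS_FINISHED, STATUS_ERROR]) == BURST_ERROR
assert burst_status_for([STATUS_FINISHED, STATUS_CANCELED]) == BURST_CANCELED
assert burst_status_for([STATUS_FINISHED, STATUS_FINISHED]) == BURST_FINISHED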
Пример #11
0
class TestHPCSchedulerClient(BaseTestCase):
    def setup_method(self):
        self.storage_interface = StorageInterface()
        self.dir_gid = '123'
        self.encryption_handler = self.storage_interface.get_encryption_handler(
            self.dir_gid)
        self.clean_database()
        self.test_user = TestFactory.create_user()
        self.test_project = TestFactory.create_project(self.test_user)

    def _prepare_dummy_files(self, tmpdir):
        dummy_file1 = os.path.join(str(tmpdir), 'dummy1.txt')
        open(dummy_file1, 'a').close()
        dummy_file2 = os.path.join(str(tmpdir), 'dummy2.txt')
        open(dummy_file2, 'a').close()
        job_inputs = [dummy_file1, dummy_file2]
        return job_inputs

    def test_encrypt_inputs(self, tmpdir):
        job_inputs = self._prepare_dummy_files(tmpdir)
        job_encrypted_inputs = self.encryption_handler.encrypt_inputs(
            job_inputs)
        # The encrypted folder holds a few more files than the plain one,
        # but each plain input gets exactly one encrypted counterpart
        assert len(job_encrypted_inputs) == len(job_inputs)

    def test_decrypt_results(self, tmpdir):
        # Prepare encrypted dir
        job_inputs = self._prepare_dummy_files(tmpdir)
        self.encryption_handler.encrypt_inputs(job_inputs)
        encrypted_dir = self.encryption_handler.get_encrypted_dir()

        # Decrypt data
        out_dir = os.path.join(str(tmpdir), 'output')
        os.mkdir(out_dir)
        self.encryption_handler.decrypt_results_to_dir(out_dir)
        list_plain_dir = os.listdir(out_dir)
        assert len(list_plain_dir) == len(os.listdir(encrypted_dir))
        assert 'dummy1.txt' in list_plain_dir
        assert 'dummy2.txt' in list_plain_dir

    def test_decrypt_files(self, tmpdir):
        # Prepare encrypted dir
        job_inputs = self._prepare_dummy_files(tmpdir)
        enc_files = self.encryption_handler.encrypt_inputs(job_inputs)

        # Decrypt data
        out_dir = os.path.join(str(tmpdir), 'output')
        os.mkdir(out_dir)
        self.encryption_handler.decrypt_files_to_dir([enc_files[1]], out_dir)
        list_plain_dir = os.listdir(out_dir)
        assert len(list_plain_dir) == 1
        assert os.path.basename(enc_files[0]).replace('.aes',
                                                      '') not in list_plain_dir
        assert os.path.basename(enc_files[1]).replace('.aes',
                                                      '') in list_plain_dir

    def test_do_operation_launch(self, simulator_factory, operation_factory,
                                 mocker):
        # Prepare encrypted dir
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)

        self._do_operation_launch(op, sim_gid, mocker)

    def _do_operation_launch(self, op, sim_gid, mocker, is_pse=False):
        # Prepare encrypted dir
        self.dir_gid = sim_gid
        self.encryption_handler = StorageInterface.get_encryption_handler(
            self.dir_gid)
        job_encrypted_inputs = HPCSchedulerClient()._prepare_input(
            op, self.dir_gid)
        self.encryption_handler.encrypt_inputs(job_encrypted_inputs)
        encrypted_dir = self.encryption_handler.get_encrypted_dir()

        mocker.patch('tvb.core.operation_hpc_launcher._request_passfile',
                     _request_passfile_dummy)
        mocker.patch(
            'tvb.core.operation_hpc_launcher._update_operation_status',
            _update_operation_status)

        # Call do_operation_launch similarly to CSCS env
        plain_dir = self.storage_interface.get_project_folder(
            self.test_project.name, 'plain')
        do_operation_launch(self.dir_gid, 1000, is_pse, '', op.id, plain_dir)
        assert len(os.listdir(encrypted_dir)) == 7
        output_path = os.path.join(encrypted_dir,
                                   HPCSchedulerClient.OUTPUT_FOLDER)
        assert os.path.exists(output_path)
        expected_files = 2
        if is_pse:
            expected_files = 3
        assert len(os.listdir(output_path)) == expected_files
        return output_path

    def test_do_operation_launch_pse(self, simulator_factory,
                                     operation_factory, mocker):
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        self._do_operation_launch(op, sim_gid, mocker, is_pse=True)

    def test_prepare_inputs(self, operation_factory, simulator_factory):
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 6

    def test_prepare_inputs_with_surface(self, operation_factory,
                                         simulator_factory):
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        sim_folder, sim_gid = simulator_factory(op=op, with_surface=True)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 9

    def test_prepare_inputs_with_eeg_monitor(self, operation_factory,
                                             simulator_factory,
                                             surface_index_factory,
                                             sensors_index_factory,
                                             region_mapping_index_factory,
                                             connectivity_index_factory):
        surface_idx, surface = surface_index_factory(cortical=True)
        sensors_idx, sensors = sensors_index_factory()
        proj = ProjectionSurfaceEEG(sensors=sensors,
                                    sources=surface,
                                    projection_data=numpy.ones(3))

        op = operation_factory()
        prj_db_db = h5.store_complete(proj, op.id, op.project.name)
        prj_db_db.fk_from_operation = op.id
        dao.store_entity(prj_db_db)

        connectivity = connectivity_index_factory(76, op)
        rm_index = region_mapping_index_factory(conn_gid=connectivity.gid,
                                                surface_gid=surface_idx.gid)

        eeg_monitor = EEGViewModel(projection=proj.gid, sensors=sensors.gid)
        eeg_monitor.region_mapping = rm_index.gid

        sim_folder, sim_gid = simulator_factory(op=op,
                                                monitor=eeg_monitor,
                                                conn_gid=connectivity.gid)
        hpc_client = HPCSchedulerClient()
        input_files = hpc_client._prepare_input(op, sim_gid)
        assert len(input_files) == 11

    def test_stage_out_to_operation_folder(self, mocker, operation_factory,
                                           simulator_factory,
                                           pse_burst_configuration_factory):
        burst = pse_burst_configuration_factory(self.test_project)
        op = operation_factory(test_user=self.test_user,
                               test_project=self.test_project)
        op.fk_operation_group = burst.fk_operation_group
        dao.store_entity(op)

        sim_folder, sim_gid = simulator_factory(op=op)
        burst.simulator_gid = sim_gid.hex
        dao.store_entity(burst)

        output_path = self._do_operation_launch(op,
                                                sim_gid,
                                                mocker,
                                                is_pse=True)

        def _stage_out_dummy(dir, sim_gid):
            return [
                os.path.join(output_path, enc_file)
                for enc_file in os.listdir(output_path)
            ]

        mocker.patch.object(HPCSchedulerClient, '_stage_out_results',
                            _stage_out_dummy)
        sim_results_files, metric_op, metric_file = HPCSchedulerClient.stage_out_to_operation_folder(
            None, op, sim_gid)
        assert op.id != metric_op.id
        assert os.path.exists(metric_file)
        assert len(sim_results_files) == 1
        assert os.path.exists(sim_results_files[0])

    def teardown_method(self):
        encrypted_dir = self.encryption_handler.get_encrypted_dir()
        passfile = self.encryption_handler.get_password_file()
        self.storage_interface.remove_files([encrypted_dir, passfile])
        self.clean_database()
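
These tests rely on pytest-mock's mocker fixture; the same patching idea can be shown standalone with unittest.mock, which mocker wraps. The Client class and dummy function below are illustrative only, not TVB code.

from unittest import mock


class Client:
    @staticmethod
    def _stage_out_results(directory, sim_gid):
        raise RuntimeError("would contact a remote HPC site")


def _stage_out_dummy(directory, sim_gid):
    return ['/tmp/result.h5']


# Analogue of mocker.patch.object(HPCSchedulerClient, '_stage_out_results', _stage_out_dummy),
# but scoped to a context manager instead of the test's lifetime.
with mock.patch.object(Client, '_stage_out_results', _stage_out_dummy):
    assert Client._stage_out_results('some/dir', 'gid-1') == ['/tmp/result.h5']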
Пример #12
0
class TVBLoader(object):

    def __init__(self, registry):
        self.storage_interface = StorageInterface()
        self.registry = registry

    def path_for_stored_index(self, dt_index_instance):
        # type: (DataType) -> str
        """ Given a Datatype(HasTraitsIndex) instance, build where the corresponding H5 should be or is stored"""
        if hasattr(dt_index_instance, 'fk_simulation'):
            # In case of BurstConfiguration the operation id is on fk_simulation
            op_id = dt_index_instance.fk_simulation
        else:
            op_id = dt_index_instance.fk_from_operation
        operation = dao.get_operation_by_id(op_id)
        operation_folder = self.storage_interface.get_project_folder(operation.project.name, str(operation.id))

        gid = uuid.UUID(dt_index_instance.gid)
        h5_file_class = self.registry.get_h5file_for_index(dt_index_instance.__class__)
        fname = self.storage_interface.get_filename(h5_file_class.file_name_base(), gid)

        return os.path.join(operation_folder, fname)

    def path_for(self, op_id, h5_file_class, gid, project_name, dt_class):
        return self.storage_interface.path_for(op_id, h5_file_class, gid, project_name, dt_class)

    def path_by_dir(self, base_dir, h5_file_class, gid, dt_class):
        return self.storage_interface.path_by_dir(base_dir, h5_file_class, gid, dt_class)

    def load_from_index(self, dt_index):
        # type: (DataType) -> HasTraits
        h5_path = self.path_for_stored_index(dt_index)
        h5_file_class = self.registry.get_h5file_for_index(dt_index.__class__)
        traits_class = self.registry.get_datatype_for_index(dt_index)
        with h5_file_class(h5_path) as f:
            result_dt = traits_class()
            f.load_into(result_dt)
        return result_dt

    def load_complete_by_function(self, file_path, load_ht_function):
        # type: (str, callable) -> (HasTraits, GenericAttributes)
        with H5File.from_file(file_path) as f:
            try:
                datatype_cls = self.registry.get_datatype_for_h5file(f)
            except KeyError:
                datatype_cls = f.determine_datatype_from_file()
            datatype = datatype_cls()
            f.load_into(datatype)
            ga = f.load_generic_attributes()
            sub_dt_refs = f.gather_references(datatype_cls)

        for traited_attr, sub_gid in sub_dt_refs:
            if sub_gid is None:
                continue
            is_monitor = False
            if isinstance(sub_gid, list):
                sub_gid = sub_gid[0]
                is_monitor = True
            ref_ht = load_ht_function(sub_gid, traited_attr)
            if is_monitor:
                ref_ht = [ref_ht]
            setattr(datatype, traited_attr.field_name, ref_ht)

        return datatype, ga

    def load_with_references(self, file_path):
        def load_ht_function(sub_gid, traited_attr):
            ref_idx = dao.get_datatype_by_gid(sub_gid.hex, load_lazy=False)
            ref_ht = self.load_from_index(ref_idx)
            return ref_ht

        return self.load_complete_by_function(file_path, load_ht_function)

    def load_with_links(self, file_path):
        def load_ht_function(sub_gid, traited_attr):
            # Use traited_attr.default for cases like ProjectionMonitor, whose obsnoise attribute has the abstract
            # type Noise and cannot be instantiated directly, while its default is Additive()
            ref_ht = traited_attr.default or traited_attr.field_type()
            ref_ht.gid = sub_gid
            return ref_ht

        return self.load_complete_by_function(file_path, load_ht_function)
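
load_with_references and load_with_links differ only in the callback handed to load_complete_by_function: one resolves each referenced GID to a fully loaded datatype, the other returns a lightweight stub carrying just the GID. A generic, non-TVB sketch of that pattern:

def resolve_references(references, load_one):
    # Apply the supplied loading strategy to every (field_name, gid) pair,
    # skipping empty references just like load_complete_by_function does.
    return {name: load_one(gid) for name, gid in references if gid is not None}


refs = [('connectivity', 'gid-1'), ('surface', None)]

# Strategy 1: fully load the referenced object (analogue of load_with_references)
fully_loaded = resolve_references(refs, lambda gid: {'gid': gid, 'data': 'loaded'})

# Strategy 2: return a stub that only knows its gid (analogue of load_with_links)
stubs_only = resolve_references(refs, lambda gid: {'gid': gid})

assert fully_loaded == {'connectivity': {'gid': 'gid-1', 'data': 'loaded'}}
assert stubs_only == {'connectivity': {'gid': 'gid-1'}}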
Пример #13
0
class TestStimulusCreator(TransactionalTestCase):
    def transactional_setup_method(self):
        """
        Reset the database before each test.
        """
        self.test_user = TestFactory.create_user('Stim_User')
        self.test_project = TestFactory.create_project(self.test_user,
                                                       "Stim_Project")
        self.storage_interface = StorageInterface()

        zip_path = os.path.join(os.path.dirname(tvb_data.__file__),
                                'connectivity', 'connectivity_66.zip')
        TestFactory.import_zip_connectivity(self.test_user, self.test_project,
                                            zip_path)
        self.connectivity = TestFactory.get_entity(self.test_project,
                                                   ConnectivityIndex)

        cortex = os.path.join(os.path.dirname(tvb_data.surfaceData.__file__),
                              'cortex_16384.zip')
        self.surface = TestFactory.import_surface_zip(self.test_user,
                                                      self.test_project,
                                                      cortex, CORTICAL)

    def transactional_teardown_method(self):
        """
        Remove project folders and clean up database.
        """
        self.storage_interface.remove_project_structure(self.test_project.name)

    def test_create_stimulus_region(self):
        weight_array = numpy.zeros(self.connectivity.number_of_regions)
        region_stimulus_creator = RegionStimulusCreator()
        region_stimulus_creator.storage_path = self.storage_interface.get_project_folder(
            self.test_project.name, "42")

        view_model = region_stimulus_creator.get_view_model_class()()
        view_model.connectivity = self.connectivity.gid
        view_model.weight = weight_array
        view_model.temporal = TemporalApplicableEquation()
        view_model.temporal.parameters['a'] = 1.0
        view_model.temporal.parameters['b'] = 2.0

        region_stimulus_index = region_stimulus_creator.launch(view_model)

        assert region_stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert json.loads(region_stimulus_index.temporal_parameters) == {
            'a': 1.0,
            'b': 2.0
        }
        assert region_stimulus_index.fk_connectivity_gid == self.connectivity.gid

    def test_create_stimulus_region_with_operation(self):
        weight_array = numpy.zeros(self.connectivity.number_of_regions)
        region_stimulus_creator = RegionStimulusCreator()

        view_model = region_stimulus_creator.get_view_model_class()()
        view_model.connectivity = self.connectivity.gid
        view_model.weight = weight_array
        view_model.temporal = TemporalApplicableEquation()
        view_model.temporal.parameters['a'] = 1.0
        view_model.temporal.parameters['b'] = 2.0

        OperationService().fire_operation(region_stimulus_creator,
                                          self.test_user,
                                          self.test_project.id,
                                          view_model=view_model)
        region_stimulus_index = TestFactory.get_entity(self.test_project,
                                                       StimuliRegionIndex)

        assert region_stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert json.loads(region_stimulus_index.temporal_parameters) == {
            'a': 1.0,
            'b': 2.0
        }
        assert region_stimulus_index.fk_connectivity_gid == self.connectivity.gid

    def test_create_stimulus_surface(self):
        surface_stimulus_creator = SurfaceStimulusCreator()
        surface_stimulus_creator.storage_path = self.storage_interface.get_project_folder(
            self.test_project.name, "42")

        view_model = surface_stimulus_creator.get_view_model_class()()
        view_model.surface = self.surface.gid
        view_model.focal_points_triangles = numpy.array([1, 2, 3])
        view_model.spatial = FiniteSupportEquation()
        view_model.spatial_amp = 1.0
        view_model.spatial_sigma = 1.0
        view_model.spatial_offset = 0.0
        view_model.temporal = TemporalApplicableEquation()
        view_model.temporal.parameters['a'] = 1.0
        view_model.temporal.parameters['b'] = 0.0

        surface_stimulus_index = surface_stimulus_creator.launch(view_model)

        assert surface_stimulus_index.spatial_equation == 'FiniteSupportEquation'
        assert surface_stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert surface_stimulus_index.fk_surface_gid == self.surface.gid

    def test_create_stimulus_surface_with_operation(self):
        surface_stimulus_creator = SurfaceStimulusCreator()

        view_model = surface_stimulus_creator.get_view_model_class()()
        view_model.surface = self.surface.gid
        view_model.focal_points_triangles = numpy.array([1, 2, 3])
        view_model.spatial = FiniteSupportEquation()
        view_model.spatial_amp = 1.0
        view_model.spatial_sigma = 1.0
        view_model.spatial_offset = 0.0
        view_model.temporal = TemporalApplicableEquation()
        view_model.temporal.parameters['a'] = 1.0
        view_model.temporal.parameters['b'] = 0.0

        OperationService().fire_operation(surface_stimulus_creator,
                                          self.test_user,
                                          self.test_project.id,
                                          view_model=view_model)
        surface_stimulus_index = TestFactory.get_entity(
            self.test_project, StimuliSurfaceIndex)

        assert surface_stimulus_index.spatial_equation == 'FiniteSupportEquation'
        assert surface_stimulus_index.temporal_equation == 'TemporalApplicableEquation'
        assert surface_stimulus_index.fk_surface_gid == self.surface.gid
Пример #14
0
class TestGIFTISurfaceImporter(BaseTestCase):
    """
    Unit-tests for GIFTI Surface importer.
    """

    GIFTI_SURFACE_FILE = os.path.join(os.path.dirname(demo_data.__file__),
                                      'sample.cortex.gii')
    GIFTI_TIME_SERIES_FILE = os.path.join(os.path.dirname(demo_data.__file__),
                                          'sample.time_series.gii')
    WRONG_GII_FILE = os.path.abspath(__file__)

    def setup_method(self):
        self.test_user = TestFactory.create_user('Gifti_User')
        self.test_project = TestFactory.create_project(self.test_user,
                                                       "Gifti_Project")
        self.storage_interface = StorageInterface()

    def teardown_method(self):
        """
        Clean-up tests data
        """
        self.clean_database()
        self.storage_interface.remove_project_structure(self.test_project.name)

    def test_import_surface_gifti_data(self, operation_factory):
        """
            This method tests the import of a surface from a GIFTI file.
            !!! Important: We changed this test to execute only the GIFTI parse,
                because storing the surface takes too long (~9 min) since
                the normals need to be calculated.
        """
        operation_id = operation_factory().id
        storage_path = self.storage_interface.get_project_folder(
            self.test_project.name, str(operation_id))

        parser = GIFTIParser(storage_path, operation_id)
        surface = parser.parse(self.GIFTI_SURFACE_FILE)

        assert 131342 == len(surface.vertices)
        assert 262680 == len(surface.triangles)

    def test_import_timeseries_gifti_data(self, operation_factory):
        """
        This method tests the import of a time series from a GIFTI file.
        !!! Important: We changed this test to execute only the GIFTI parse,
            because storing the surface takes too long (~9 min) since
            the normals need to be calculated.
        """
        operation_id = operation_factory().id
        storage_path = self.storage_interface.get_project_folder(
            self.test_project.name, str(operation_id))

        parser = GIFTIParser(storage_path, operation_id)
        time_series = parser.parse(self.GIFTI_TIME_SERIES_FILE)

        data_shape = time_series[1]

        assert 135 == len(data_shape)
        assert 143479 == data_shape[0].dims[0]

    def test_import_wrong_gii_file(self):
        """
        This method tests import of a file in a wrong format
        """
        try:
            TestFactory.import_surface_gifti(self.test_user, self.test_project,
                                             self.WRONG_GII_FILE)
            raise AssertionError(
                "Import should fail in case of a wrong GIFTI format.")
        except OperationException:
            # Expected exception
            pass
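
Assuming pytest is the runner here (as in the neighbouring suites), the wrong-format check could also be expressed with pytest.raises, which fails automatically when no exception is thrown. This is only a hypothetical variant of the test above, reusing the suite's existing names.

    def test_import_wrong_gii_file_with_raises(self):
        # Hypothetical alternative to test_import_wrong_gii_file above;
        # OperationException and TestFactory come from the suite's existing imports.
        import pytest
        with pytest.raises(OperationException):
            TestFactory.import_surface_gifti(self.test_user, self.test_project,
                                             self.WRONG_GII_FILE)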
Пример #15
0
class ImportService(object):
    """
    Service for importing TVB entities into the system.
    It supports TVB-exported H5 files as input, but it should also handle H5 files
    generated outside of TVB, as long as they respect the same structure.
    """
    def __init__(self):
        self.logger = get_logger(__name__)
        self.user_id = None
        self.storage_interface = StorageInterface()
        self.created_projects = []
        self.view_model2adapter = self._populate_view_model2adapter()

    def _download_and_unpack_project_zip(self, uploaded, uq_file_name,
                                         temp_folder):

        if isinstance(uploaded, (FieldStorage, Part)):
            if not uploaded.file:
                raise ImportException(
                    "Please select the archive which contains the project structure."
                )
            with open(uq_file_name, 'wb') as file_obj:
                self.storage_interface.copy_file(uploaded.file, file_obj)
        else:
            shutil.copy2(uploaded, uq_file_name)

        try:
            self.storage_interface.unpack_zip(uq_file_name, temp_folder)
        except FileStructureException as excep:
            self.logger.exception(excep)
            raise ImportException(
                "Bad ZIP archive provided. A TVB exported project is expected!"
            )

    @staticmethod
    def _compute_unpack_path():
        """
        :return: the path of the folder where the uploaded ZIP will be expanded
        """
        now = datetime.now()
        date_str = "%d-%d-%d_%d-%d-%d_%d" % (now.year, now.month, now.day,
                                             now.hour, now.minute, now.second,
                                             now.microsecond)
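        # e.g. date_str == "2024-3-7_14-5-9_123456" - plain %d formatting, so no zero padding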
        uq_name = "%s-ImportProject" % date_str
        return os.path.join(TvbProfile.current.TVB_TEMP_FOLDER, uq_name)

    @transactional
    def import_project_structure(self, uploaded, user_id):
        """
        Execute import operations:

        1. check if ZIP or folder
        2. find all project nodes
        3. for each project node:
            - create project
            - create all operations and groups
            - import all images
            - create all dataTypes
        """

        self.user_id = user_id
        self.created_projects = []

        # Compute the path of the folder where the uploaded ZIP file will be extracted
        temp_folder = self._compute_unpack_path()
        uq_file_name = temp_folder + ".zip"

        try:
            self._download_and_unpack_project_zip(uploaded, uq_file_name,
                                                  temp_folder)
            self._import_project_from_folder(temp_folder)

        except Exception as excep:
            self.logger.exception(
                "Error encountered during import. Deleting projects created during this operation."
            )
            # Remove project folders created so far.
            # Note that using the project service to remove the projects will not work,
            # because we do not have support for nested transactions.
            # Removing from the DB is not necessary because, in a transactional environment,
            # raising an exception rolls back everything that was about to be inserted.
            for project in self.created_projects:
                self.storage_interface.remove_project(project)
            raise ImportException(str(excep))

        finally:
            # Now delete the uploaded file and the temporary folder where the ZIP was extracted.
            self.storage_interface.remove_files([uq_file_name, temp_folder])

    def _import_project_from_folder(self, temp_folder):
        """
        Locate the project folder inside the uploaded pack and import it.
        """
        temp_project_path = None
        for root, _, files in os.walk(temp_folder):
            if StorageInterface.TVB_PROJECT_FILE in files:
                temp_project_path = root
                break

        if temp_project_path is not None:
            update_manager = ProjectUpdateManager(temp_project_path)

            if update_manager.checked_version < 3:
                raise ImportException(
                    'Importing projects with versions older than 3 is not supported in TVB 2! '
                    'Please import the project in TVB 1.5.8 and then launch the current version of '
                    'TVB in order to upgrade this project!')

            update_manager.run_all_updates()
            project = self.__populate_project(temp_project_path)
            # Track the projects created so far, so their folders can be cleaned up in case of failure
            self.created_projects.append(project)
            # Ensure project final folder exists on disk
            project_path = self.storage_interface.get_project_folder(
                project.name)
            shutil.move(
                os.path.join(temp_project_path,
                             StorageInterface.TVB_PROJECT_FILE), project_path)
            # Now import project operations with their results
            self.import_list_of_operations(project, temp_project_path)
            # Import images and move them from temp into target
            self._store_imported_images(project, temp_project_path,
                                        project.name)
            if StorageInterface.encryption_enabled():
                self.storage_interface.remove_project(project, True)

    def _load_datatypes_from_operation_folder(self, src_op_path,
                                              operation_entity,
                                              datatype_group):
        """
        Loads datatypes from operation folder
        :returns: Datatype entities as dict {original_path: Dt instance}
        """
        all_datatypes = {}
        for file_name in os.listdir(src_op_path):
            if self.storage_interface.ends_with_tvb_storage_file_extension(
                    file_name):
                h5_file = os.path.join(src_op_path, file_name)
                try:
                    file_update_manager = FilesUpdateManager()
                    file_update_manager.upgrade_file(h5_file)
                    datatype = self.load_datatype_from_file(
                        h5_file, operation_entity.id, datatype_group,
                        operation_entity.fk_launched_in)
                    all_datatypes[h5_file] = datatype

                except IncompatibleFileManagerException:
                    os.remove(h5_file)
                    self.logger.warning(
                        "Incompatible H5 file will be ignored: %s" % h5_file)
                    self.logger.exception("Incompatibility details ...")
        return all_datatypes

    @staticmethod
    def check_import_references(file_path, datatype):
        h5_class = H5File.h5_class_from_file(file_path)
        reference_list = h5_class(file_path).gather_references()

        for _, reference_gid in reference_list:
            if not reference_gid:
                continue

            ref_index = load.load_entity_by_gid(reference_gid)
            if ref_index is None:
                os.remove(file_path)
                dao.remove_entity(datatype.__class__, datatype.id)
                raise MissingReferenceException(
                    'Imported file depends on datatypes that do not exist. Please upload '
                    'those first!')

    def _store_or_link_burst_config(self, burst_config, bc_path, project_id):
        bc_already_in_tvb = dao.get_generic_entity(BurstConfiguration,
                                                   burst_config.gid, 'gid')
        if len(bc_already_in_tvb) == 0:
            self.store_datatype(burst_config, bc_path)
            return 1
        return 0

    def store_or_link_datatype(self, datatype, dt_path, project_id):
        self.check_import_references(dt_path, datatype)
        stored_dt_count = 0
        datatype_already_in_tvb = load.load_entity_by_gid(datatype.gid)
        if not datatype_already_in_tvb:
            self.store_datatype(datatype, dt_path)
            stored_dt_count = 1
        elif datatype_already_in_tvb.parent_operation.project.id != project_id:
            AlgorithmService.create_link(datatype_already_in_tvb.id,
                                         project_id)
            if datatype_already_in_tvb.fk_datatype_group:
                AlgorithmService.create_link(
                    datatype_already_in_tvb.fk_datatype_group, project_id)
        return stored_dt_count

    def _store_imported_datatypes_in_db(self, project, all_datatypes):
        # type: (Project, dict) -> int
        sorted_dts = sorted(
            all_datatypes.items(),
            key=lambda dt_item: dt_item[1].create_date or datetime.now())
        count = 0
        for dt_path, datatype in sorted_dts:
            count += self.store_or_link_datatype(datatype, dt_path, project.id)
        return count

    def _store_imported_images(self, project, temp_project_path, project_name):
        """
        Import all images from project
        """
        images_root = os.path.join(temp_project_path,
                                   StorageInterface.IMAGES_FOLDER)
        target_images_path = self.storage_interface.get_images_folder(
            project_name)
        for root, _, files in os.walk(images_root):
            for metadata_file in files:
                if self.storage_interface.ends_with_tvb_file_extension(
                        metadata_file):
                    self._import_image(root, metadata_file, project.id,
                                       target_images_path)

    @staticmethod
    def _populate_view_model2adapter():
        if len(VIEW_MODEL2ADAPTER) > 0:
            return VIEW_MODEL2ADAPTER
        view_model2adapter = {}
        algos = dao.get_all_algorithms()
        for algo in algos:
            adapter = ABCAdapter.build_adapter(algo)
            view_model_class = adapter.get_view_model_class()
            view_model2adapter[view_model_class] = algo
        return view_model2adapter

    def _retrieve_operations_in_order(self,
                                      project,
                                      import_path,
                                      importer_operation_id=None):
        # type: (Project, str, int) -> list[Operation2ImportData]
        retrieved_operations = []

        for root, _, files in os.walk(import_path):
            if OPERATION_XML in files:
                # Old Operation XML format, used when importing projects exported by previous TVB versions
                operation_file_path = os.path.join(root, OPERATION_XML)
                operation, operation_xml_parameters, _ = self.build_operation_from_file(
                    project, operation_file_path)
                operation.import_file = operation_file_path
                self.logger.debug("Found operation in old XML format: " +
                                  str(operation))
                retrieved_operations.append(
                    Operation2ImportData(
                        operation,
                        root,
                        info_from_xml=operation_xml_parameters))

            else:
                # Otherwise, expect the new format based on ViewModelH5
                main_view_model = None
                dt_paths = []
                all_view_model_files = []
                for file in files:
                    if self.storage_interface.ends_with_tvb_storage_file_extension(
                            file):
                        h5_file = os.path.join(root, file)
                        try:
                            h5_class = H5File.h5_class_from_file(h5_file)
                            if h5_class is ViewModelH5:
                                all_view_model_files.append(h5_file)
                                if not main_view_model:
                                    view_model = h5.load_view_model_from_file(
                                        h5_file)
                                    if type(view_model) in self.view_model2adapter:
                                        main_view_model = view_model
                            else:
                                file_update_manager = FilesUpdateManager()
                                file_update_manager.upgrade_file(h5_file)
                                dt_paths.append(h5_file)
                        except Exception:
                            self.logger.warning(
                                "Unreadable H5 file will be ignored: %s" %
                                h5_file)

                if main_view_model is not None:
                    alg = self.view_model2adapter[type(main_view_model)]
                    op_group_id = None
                    if main_view_model.operation_group_gid:
                        op_group = dao.get_operationgroup_by_gid(
                            main_view_model.operation_group_gid.hex)
                        if not op_group:
                            op_group = OperationGroup(
                                project.id,
                                ranges=json.loads(main_view_model.ranges),
                                gid=main_view_model.operation_group_gid.hex)
                            op_group = dao.store_entity(op_group)
                        op_group_id = op_group.id
                    operation = Operation(
                        main_view_model.gid.hex,
                        project.fk_admin,
                        project.id,
                        alg.id,
                        status=STATUS_FINISHED,
                        user_group=main_view_model.generic_attributes.operation_tag,
                        start_date=datetime.now(),
                        completion_date=datetime.now(),
                        op_group_id=op_group_id,
                        range_values=main_view_model.range_values)
                    operation.create_date = main_view_model.create_date
                    operation.visible = main_view_model.generic_attributes.visible
                    self.logger.debug(
                        "Found main ViewModel to create operation for it: " +
                        str(operation))

                    retrieved_operations.append(
                        Operation2ImportData(operation, root, main_view_model,
                                             dt_paths, all_view_model_files))

                elif len(dt_paths) > 0:
                    alg = dao.get_algorithm_by_module(TVB_IMPORTER_MODULE,
                                                      TVB_IMPORTER_CLASS)
                    default_adapter = ABCAdapter.build_adapter(alg)
                    view_model = default_adapter.get_view_model_class()()
                    view_model.data_file = dt_paths[0]
                    vm_path = h5.store_view_model(view_model, root)
                    all_view_model_files.append(vm_path)
                    operation = Operation(view_model.gid.hex,
                                          project.fk_admin,
                                          project.id,
                                          alg.id,
                                          status=STATUS_FINISHED,
                                          start_date=datetime.now(),
                                          completion_date=datetime.now())
                    operation.create_date = datetime.min
                    self.logger.debug(
                        "Found no ViewModel in folder, so we default to " +
                        str(operation))

                    if importer_operation_id:
                        operation.id = importer_operation_id

                    retrieved_operations.append(
                        Operation2ImportData(operation, root, view_model,
                                             dt_paths, all_view_model_files,
                                             True))

        return sorted(retrieved_operations,
                      key=lambda op_data: op_data.order_field)

    def create_view_model(self,
                          operation_entity,
                          operation_data,
                          new_op_folder,
                          generic_attributes=None,
                          add_params=None):
        view_model = self._get_new_form_view_model(
            operation_entity, operation_data.info_from_xml)
        if add_params is not None:
            for element in add_params:
                key_attr = getattr(view_model, element[0])
                setattr(key_attr, element[1], element[2])

        view_model.range_values = operation_entity.range_values
        op_group = dao.get_operationgroup_by_id(
            operation_entity.fk_operation_group)
        if op_group:
            view_model.operation_group_gid = uuid.UUID(op_group.gid)
            view_model.ranges = json.dumps(op_group.range_references)
            view_model.is_metric_operation = 'DatatypeMeasure' in op_group.name

        if generic_attributes is not None:
            view_model.generic_attributes = generic_attributes
        view_model.generic_attributes.operation_tag = operation_entity.user_group

        h5.store_view_model(view_model, new_op_folder)
        view_model_disk_size = StorageInterface.compute_recursive_h5_disk_usage(
            new_op_folder)
        operation_entity.view_model_disk_size = view_model_disk_size
        operation_entity.view_model_gid = view_model.gid.hex
        dao.store_entity(operation_entity)
        return view_model

    def import_list_of_operations(self,
                                  project,
                                  import_path,
                                  is_group=False,
                                  importer_operation_id=None):
        """
        Scan the provided folder and identify all operations that need to be imported.
        """
        all_dts_count = 0
        all_stored_dts_count = 0
        imported_operations = []
        ordered_operations = self._retrieve_operations_in_order(
            project, import_path, None if is_group else importer_operation_id)

        if is_group and len(ordered_operations) > 0:
            first_op = dao.get_operation_by_id(importer_operation_id)
            vm_path = h5.determine_filepath(first_op.view_model_gid,
                                            os.path.dirname(import_path))
            os.remove(vm_path)

            ordered_operations[0].operation.id = importer_operation_id

        for operation_data in ordered_operations:
            if operation_data.is_old_form:
                operation_entity, datatype_group = self.import_operation(
                    operation_data.operation)
                new_op_folder = self.storage_interface.get_project_folder(
                    project.name, str(operation_entity.id))

                try:
                    operation_datatypes = self._load_datatypes_from_operation_folder(
                        operation_data.operation_folder, operation_entity,
                        datatype_group)
                    # Create and store view_model from operation
                    self.create_view_model(operation_entity, operation_data,
                                           new_op_folder)

                    self._store_imported_datatypes_in_db(
                        project, operation_datatypes)
                    imported_operations.append(operation_entity)
                except MissingReferenceException:
                    operation_entity.status = STATUS_ERROR
                    dao.store_entity(operation_entity)

            elif operation_data.main_view_model is not None:
                operation_data.operation.create_date = datetime.now()
                operation_data.operation.start_date = datetime.now()
                operation_data.operation.completion_date = datetime.now()

                do_merge = False
                if importer_operation_id:
                    do_merge = True
                operation_entity = dao.store_entity(operation_data.operation,
                                                    merge=do_merge)
                dt_group = None
                op_group = dao.get_operationgroup_by_id(
                    operation_entity.fk_operation_group)
                if op_group:
                    dt_group = dao.get_datatypegroup_by_op_group_id(
                        op_group.id)
                    if not dt_group:
                        first_op = dao.get_operations_in_group(
                            op_group.id, only_first_operation=True)
                        dt_group = DataTypeGroup(
                            op_group,
                            operation_id=first_op.id,
                            state=DEFAULTDATASTATE_INTERMEDIATE)
                        dt_group = dao.store_entity(dt_group)
                # Store the DataTypes in db
                dts = {}
                all_dts_count += len(operation_data.dt_paths)
                for dt_path in operation_data.dt_paths:
                    dt = self.load_datatype_from_file(dt_path,
                                                      operation_entity.id,
                                                      dt_group, project.id)
                    if isinstance(dt, BurstConfiguration):
                        if op_group:
                            dt.fk_operation_group = op_group.id
                        all_stored_dts_count += self._store_or_link_burst_config(
                            dt, dt_path, project.id)
                    else:
                        dts[dt_path] = dt
                        if op_group:
                            op_group.fill_operationgroup_name(dt.type)
                            dao.store_entity(op_group)
                try:
                    stored_dts_count = self._store_imported_datatypes_in_db(
                        project, dts)
                    all_stored_dts_count += stored_dts_count

                    if operation_data.main_view_model.is_metric_operation:
                        self._update_burst_metric(operation_entity)

                    imported_operations.append(operation_entity)
                    new_op_folder = self.storage_interface.get_project_folder(
                        project.name, str(operation_entity.id))
                    view_model_disk_size = 0
                    for h5_file in operation_data.all_view_model_files:
                        view_model_disk_size += StorageInterface.compute_size_on_disk(
                            h5_file)
                        shutil.move(h5_file, new_op_folder)
                    operation_entity.view_model_disk_size = view_model_disk_size
                    dao.store_entity(operation_entity)
                except MissingReferenceException as excep:
                    self.storage_interface.remove_operation_data(
                        project.name, operation_entity.id)
                    operation_entity.fk_operation_group = None
                    dao.store_entity(operation_entity)
                    dao.remove_entity(DataTypeGroup, dt_group.id)
                    raise excep
            else:
                self.logger.warning(
                    "Folder %s will be ignored, as we could not find a serialized "
                    "operation or DTs inside!" %
                    operation_data.operation_folder)

            # We want importer_operation_id to be kept just for the first operation (the first iteration)
            if is_group:
                importer_operation_id = None

        self._update_dt_groups(project.id)
        self._update_burst_configurations(project.id)
        return imported_operations, all_dts_count, all_stored_dts_count
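
    # A minimal usage sketch for import_list_of_operations above, assuming
    # `service` is an instance of this import service and `project` is an
    # already stored Project entity (the folder path is illustrative only):
    #
    #     ops, all_dts, stored_dts = service.import_list_of_operations(
    #         project, "/tmp/unpacked_project")
    #
    # where the return values are the imported Operation entities, the number of
    # datatypes found and the number actually stored or linked.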

    @staticmethod
    def _get_new_form_view_model(operation, xml_parameters):
        # type: (Operation, dict) -> ViewModel
        algo = dao.get_algorithm_by_id(operation.fk_from_algo)
        ad = ABCAdapter.build_adapter(algo)
        view_model = ad.get_view_model_class()()

        if xml_parameters:
            declarative_attrs = type(view_model).declarative_attrs

            if isinstance(xml_parameters, str):
                xml_parameters = json.loads(xml_parameters)
            for param in xml_parameters:
                new_param_name = param
                if param != '' and param[0] == "_":
                    new_param_name = param[1:]
                new_param_name = new_param_name.lower()
                if new_param_name in declarative_attrs:
                    try:
                        setattr(view_model, new_param_name,
                                xml_parameters[param])
                    except (TraitTypeError, TraitAttributeError):
                        pass
        return view_model

    def _import_image(self, src_folder, metadata_file, project_id,
                      target_images_path):
        """
        Create and store an image (ResultFigure) entity.
        """
        figure_dict = StorageInterface().read_metadata_from_xml(
            os.path.join(src_folder, metadata_file))
        actual_figure = os.path.join(
            src_folder,
            os.path.split(figure_dict['file_path'])[1])
        if not os.path.exists(actual_figure):
            self.logger.warning("Expected to find image path %s .Skipping" %
                                actual_figure)
            return
        figure_dict['fk_user_id'] = self.user_id
        figure_dict['fk_project_id'] = project_id
        figure_entity = manager_of_class(ResultFigure).new_instance()
        figure_entity = figure_entity.from_dict(figure_dict)
        stored_entity = dao.store_entity(figure_entity)

        # Update image meta-data with the new details after import
        figure = dao.load_figure(stored_entity.id)
        shutil.move(actual_figure, target_images_path)
        self.logger.debug("Store imported figure")
        _, meta_data = figure.to_dict()
        self.storage_interface.write_image_metadata(figure, meta_data)

    def load_datatype_from_file(self,
                                current_file,
                                op_id,
                                datatype_group=None,
                                current_project_id=None):
        # type: (str, int, DataTypeGroup, int) -> HasTraitsIndex
        """
        Creates a datatype instance from a storage / H5 file.
        :returns: the corresponding DataType index, or a BurstConfiguration for burst H5 files
        """
        self.logger.debug("Loading DataType from file: %s" % current_file)
        h5_class = H5File.h5_class_from_file(current_file)

        if h5_class is BurstConfigurationH5:
            if current_project_id is None:
                op_entity = dao.get_operationgroup_by_id(op_id)
                current_project_id = op_entity.fk_launched_in
            h5_file = BurstConfigurationH5(current_file)
            burst = BurstConfiguration(current_project_id)
            burst.fk_simulation = op_id
            h5_file.load_into(burst)
            result = burst
        else:
            datatype, generic_attributes = h5.load_with_links(current_file)

            already_existing_datatype = h5.load_entity_by_gid(datatype.gid)
            if datatype_group is not None and already_existing_datatype is not None:
                raise DatatypeGroupImportException(
                    "The datatype group that you are trying to import"
                    " already exists!")
            index_class = h5.REGISTRY.get_index_for_datatype(
                datatype.__class__)
            datatype_index = index_class()
            datatype_index.fill_from_has_traits(datatype)
            datatype_index.fill_from_generic_attributes(generic_attributes)

            if datatype_group is not None and hasattr(datatype_index, 'fk_source_gid') and \
                    datatype_index.fk_source_gid is not None:
                ts = h5.load_entity_by_gid(datatype_index.fk_source_gid)

                if ts is None:
                    op = dao.get_operations_in_group(
                        datatype_group.fk_operation_group,
                        only_first_operation=True)
                    op.fk_operation_group = None
                    dao.store_entity(op)
                    dao.remove_entity(OperationGroup,
                                      datatype_group.fk_operation_group)
                    dao.remove_entity(DataTypeGroup, datatype_group.id)
                    raise DatatypeGroupImportException(
                        "Please import the time series group before importing the"
                        " datatype measure group!")

            # Add all the required attributes
            if datatype_group:
                datatype_index.fk_datatype_group = datatype_group.id
                if len(datatype_group.subject) == 0:
                    datatype_group.subject = datatype_index.subject
                    dao.store_entity(datatype_group)
            datatype_index.fk_from_operation = op_id

            associated_file = h5.path_for_stored_index(datatype_index)
            if os.path.exists(associated_file):
                datatype_index.disk_size = StorageInterface.compute_size_on_disk(
                    associated_file)
            result = datatype_index

        return result
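
    # A minimal sketch tying load_datatype_from_file above to store_datatype below,
    # mirroring what import_list_of_operations does for a single H5 file; `service`,
    # `h5_path` and `operation` are assumed to exist already:
    #
    #     dt = service.load_datatype_from_file(h5_path, operation.id)
    #     if not isinstance(dt, BurstConfiguration):
    #         service.store_datatype(dt, h5_path)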

    def store_datatype(self, datatype, current_file=None):
        """This method stores data type into DB"""
        try:
            self.logger.debug("Store datatype: %s with Gid: %s" %
                              (datatype.__class__.__name__, datatype.gid))
            # Now move storage file into correct folder if necessary
            if current_file is not None:
                final_path = h5.path_for_stored_index(datatype)
                if final_path != current_file:
                    shutil.move(current_file, final_path)
            stored_entry = load.load_entity_by_gid(datatype.gid)
            if not stored_entry:
                stored_entry = dao.store_entity(datatype)

            return stored_entry
        except MissingDataSetException as e:
            self.logger.exception(e)
            error_msg = "Datatype %s has missing data and could not be imported properly." % (
                datatype, )
            raise ImportException(error_msg)
        except IntegrityError as excep:
            self.logger.exception(excep)
            error_msg = "Could not import data with gid: %s. There is already a one with " \
                        "the same name or gid." % datatype.gid
            raise ImportException(error_msg)

    def __populate_project(self, project_path):
        """
        Create and store a Project entity.
        """
        self.logger.debug("Creating project from path: %s" % project_path)
        project_dict = self.storage_interface.read_project_metadata(
            project_path)

        project_entity = manager_of_class(Project).new_instance()
        project_entity = project_entity.from_dict(project_dict, self.user_id)

        try:
            self.logger.debug("Storing imported project")
            return dao.store_entity(project_entity)
        except IntegrityError as excep:
            self.logger.exception(excep)
            error_msg = (
                "Could not import project: %s with gid: %s. There is already a "
                "project with the same name or gid.") % (project_entity.name,
                                                         project_entity.gid)
            raise ImportException(error_msg)

    def build_operation_from_file(self, project, operation_file):
        """
        Create Operation entity from metadata file.
        """
        operation_dict = StorageInterface().read_metadata_from_xml(
            operation_file)
        operation_entity = manager_of_class(Operation).new_instance()
        return operation_entity.from_dict(operation_dict, dao, self.user_id,
                                          project.gid)

    @staticmethod
    def import_operation(operation_entity, migration=False):
        """
        Store an Operation entity.
        """
        do_merge = False
        if operation_entity.id:
            do_merge = True
        operation_entity = dao.store_entity(operation_entity, merge=do_merge)
        operation_group_id = operation_entity.fk_operation_group
        datatype_group = None

        if operation_group_id is not None:
            datatype_group = dao.get_datatypegroup_by_op_group_id(
                operation_group_id)

            if datatype_group is None and migration is False:
                # If no DataTypeGroup is present for the current operation group, create it.
                operation_group = dao.get_operationgroup_by_id(
                    operation_group_id)
                datatype_group = DataTypeGroup(
                    operation_group, operation_id=operation_entity.id)
                datatype_group.state = UploadAlgorithmCategoryConfig.defaultdatastate
                datatype_group = dao.store_entity(datatype_group)

        return operation_entity, datatype_group

    def import_simulator_configuration_zip(self, zip_file):
        # Compute the name of the folder where the uploaded ZIP file will be extracted
        temp_folder = self._compute_unpack_path()
        uq_file_name = temp_folder + ".zip"

        if isinstance(zip_file, (FieldStorage, Part)):
            if not zip_file.file:
                raise ServicesBaseException(
                    "Could not process the given ZIP file...")

            with open(uq_file_name, 'wb') as file_obj:
                self.storage_interface.copy_file(zip_file.file, file_obj)
        else:
            shutil.copy2(zip_file, uq_file_name)

        try:
            self.storage_interface.unpack_zip(uq_file_name, temp_folder)
            return temp_folder
        except FileStructureException as excep:
            raise ServicesBaseException(
                "Could not process the given ZIP file..." + str(excep))

    @staticmethod
    def _update_burst_metric(operation_entity):
        burst_config = dao.get_burst_for_operation_id(operation_entity.id)
        if burst_config and burst_config.ranges:
            if burst_config.fk_metric_operation_group is None:
                burst_config.fk_metric_operation_group = operation_entity.fk_operation_group
            dao.store_entity(burst_config)

    @staticmethod
    def _update_dt_groups(project_id):
        dt_groups = dao.get_datatypegroup_for_project(project_id)
        for dt_group in dt_groups:
            dt_group.count_results = dao.count_datatypes_in_group(dt_group.id)
            dts_in_group = dao.get_datatypes_from_datatype_group(dt_group.id)
            if dts_in_group:
                dt_group.fk_parent_burst = dts_in_group[0].fk_parent_burst
            dao.store_entity(dt_group)

    @staticmethod
    def _update_burst_configurations(project_id):
        burst_configs = dao.get_bursts_for_project(project_id)
        for burst_config in burst_configs:
            burst_config.datatypes_number = dao.count_datatypes_in_burst(
                burst_config.gid)
            dao.store_entity(burst_config)
Example #16
0
class ExportManager(object):
    """
    This class provides basic methods for exporting data types of projects in different formats.
    """
    all_exporters = {}  # Dictionary containing all available exporters
    logger = get_logger(__name__)

    def __init__(self):
        # Here we register all available data type exporters
        # If new exporters are added, they should be registered here
        self._register_exporter(TVBExporter())
        self._register_exporter(TVBLinkedExporter())
        self.storage_interface = StorageInterface()

    def _register_exporter(self, exporter):
        """
        Register a data type exporter into the internal registry of available exporters.
        :param exporter: Instance of a data type exporter (extends ABCExporter)
        """
        if exporter is not None:
            self.all_exporters[exporter.__class__.__name__] = exporter

    def get_exporters_for_data(self, data):
        """
        Get available exporters for current data type.
        :returns: a dictionary of the form {exporter_id: label}
        """
        if data is None:
            raise InvalidExportDataException("Could not detect exporters for null data")

        self.logger.debug("Trying to determine exporters valid for %s" % data.type)
        results = {}

        for exporter_id, exporter in self.all_exporters.items():
            if exporter.accepts(data):
                results[exporter_id] = exporter.get_label()

        return results

    def export_data(self, data, exporter_id, project, user_public_key=None):
        """
        Export provided data using given exporter
        :param data: data type to be exported
        :param exporter_id: identifier of the exporter to be used
        :param project: project that contains data to be exported
        :param user_public_key: public key file used for encrypting data before exporting

        :returns: a tuple with the following elements
            1. name of the file to be shown to user
            2. full path of the export file (available for download)
            3. boolean which specify if file can be deleted after download
        """
        if data is None:
            raise InvalidExportDataException("Could not export null data. Please select data to be exported")

        if exporter_id is None:
            raise ExportException("Please select the exporter to be used for this operation")

        if exporter_id not in self.all_exporters:
            raise ExportException("Provided exporter identifier is not a valid one")

        if project is None:
            raise ExportException("Please provide the project where data files are stored")

        exporter = self.all_exporters[exporter_id]

        if user_public_key is not None:
            public_key_path, encryption_password = self.storage_interface.prepare_encryption(project.name)
            if isinstance(user_public_key, (FieldStorage, Part)):
                with open(public_key_path, 'wb') as file_obj:
                    self.storage_interface.copy_file(user_public_key.file, file_obj)
            else:
                shutil.copy2(user_public_key, public_key_path)

        else:
            public_key_path, encryption_password = None, None

        # Now we start the real export
        if not exporter.accepts(data):
            raise InvalidExportDataException("Current data can not be exported by specified exporter")

        # Now compute and create folder where to store exported data
        # This will imply to generate a folder which is unique for each export
        export_data = None
        try:
            self.logger.debug("Start export of data: %s" % data.type)
            export_data = exporter.export(data, project, public_key_path, encryption_password)
        except Exception:
            self.logger.exception("Export failed for data %s" % data.type)

        return export_data
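
    # A minimal usage sketch for the two methods above, assuming `datatype` is a
    # stored DataType index and `project` its parent Project; the exporter picked
    # here is simply the first one that accepts the data:
    #
    #     manager = ExportManager()
    #     exporters = manager.get_exporters_for_data(datatype)  # {exporter_id: label}
    #     exporter_id = next(iter(exporters))
    #     file_name, file_path, delete_after = manager.export_data(
    #         datatype, exporter_id, project)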

    @staticmethod
    def _get_paths_of_linked_datatypes(project):
        linked_paths = ProjectService().get_linked_datatypes_storage_path(project)

        if not linked_paths:
            # do not export an empty operation
            return None, None

        # Make an import operation which will contain links to other projects
        algo = dao.get_algorithm_by_module(TVB_IMPORTER_MODULE, TVB_IMPORTER_CLASS)
        op = model_operation.Operation(None, None, project.id, algo.id)
        op.project = project
        op.algorithm = algo
        op.id = 'links-to-external-projects'
        op.start_now()
        op.mark_complete(model_operation.STATUS_FINISHED)

        return linked_paths, op

    def export_project(self, project):
        """
        Create a ZIP archive with the whole project content, ready for export.
        :param project: project object which identifies project to be exported
        """
        if project is None:
            raise ExportException("Please provide project to be exported")

        folders_to_exclude = self._get_op_with_errors(project.id)
        linked_paths, op = self._get_paths_of_linked_datatypes(project)

        result_path = self.storage_interface.export_project(project, folders_to_exclude, linked_paths, op)

        return result_path

    @staticmethod
    def _get_op_with_errors(project_id):
        """
        Return the IDs of operations that ended with errors, so their folders can be excluded from export.
        """
        operations = dao.get_operations_with_error_in_project(project_id)
        return [op.id for op in operations]

    def export_simulator_configuration(self, burst_id):
        burst = dao.get_burst_by_id(burst_id)
        if burst is None:
            raise InvalidExportDataException("Could not find burst with ID " + str(burst_id))

        op_folder = self.storage_interface.get_project_folder(burst.project.name, str(burst.fk_simulation))

        all_view_model_paths, all_datatype_paths = h5.gather_references_of_view_model(burst.simulator_gid, op_folder)

        burst_path = h5.determine_filepath(burst.gid, op_folder)
        all_view_model_paths.append(burst_path)

        zip_filename = ABCExporter.get_export_file_name(burst, self.storage_interface.TVB_ZIP_FILE_EXTENSION)
        result_path = self.storage_interface.export_simulator_configuration(burst, all_view_model_paths,
                                                                            all_datatype_paths, zip_filename)
        return result_path
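
# A minimal usage sketch for the project-level exports above, assuming `project`
# is a stored Project entity and `burst_id` references an existing simulation:
#
#     manager = ExportManager()
#     project_zip_path = manager.export_project(project)
#     simulator_zip_path = manager.export_simulator_configuration(burst_id)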