示例#1
0
    def run(self):
        """Fit the composite model and return the per-map ROI results."""
        with per_model_logging_context(self._output_path):
            self._model.set_problem_data(self._problem_data)

            if self.recalculate:
                # A refit was requested: wipe any previously written NIFTI maps.
                if os.path.exists(self._output_path):
                    for nifti_path in glob.glob(
                            os.path.join(self._output_path, '*.nii*')):
                        os.remove(nifti_path)
            elif model_output_exists(self._model, self._output_folder):
                # Cached output exists and reuse is allowed: load it instead of fitting.
                maps = get_all_image_data(self._output_path)
                self._logger.info('Not recalculating {} model'.format(
                    self._model.name))
                return create_roi(maps, self._problem_data.mask)

            if not os.path.exists(self._output_path):
                os.makedirs(self._output_path)

            with self._logging():
                worker_generator = SimpleModelProcessingWorkerGenerator(
                    lambda *args: FittingProcessingWorker(self._optimizer, *args))
                results = self._processing_strategy.run(
                    self._model, self._problem_data, self._output_path,
                    self.recalculate, worker_generator)
                self._write_protocol()

        return results
示例#2
0
        def __call__(self, subject_info):
            """Fit every configured model for a single subject of the batch."""
            subject_id = subject_info.subject_id

            logger.info('Going to process subject {}, ({} of {}, we are at {:.2%})'.format(
                subject_id, self._index_counter + 1, total_nmr_subjects,
                self._index_counter / total_nmr_subjects))
            self._index_counter += 1

            output_dir = os.path.join(output_folder, subject_id)

            # Skip entirely when every model already has output and no refit was requested.
            outputs_present = all(model_output_exists(model, output_dir)
                                  for model in models_to_fit)
            if outputs_present and not recalculate:
                logger.info('Skipping subject {0}, output exists'.format(subject_id))
                return

            logger.info('Loading the data (DWI, mask and protocol) of subject {0}'.format(subject_id))
            input_data = subject_info.get_input_data(use_gradient_deviations)

            with timer(subject_id):
                for model in models_to_fit:
                    logger.info('Going to fit model {0} on subject {1}'.format(model, subject_id))

                    try:
                        ModelFit(model,
                                 input_data,
                                 output_dir,
                                 recalculate=recalculate,
                                 only_recalculate_last=True,
                                 cl_device_ind=cl_device_ind,
                                 double_precision=double_precision,
                                 tmp_results_dir=tmp_results_dir).run()
                    except InsufficientProtocolError as ex:
                        logger.info('Could not fit model {0} on subject {1} '
                                    'due to protocol problems. {2}'.format(model, subject_id, ex))
                    else:
                        logger.info('Done fitting model {0} on subject {1}'.format(model, subject_id))
示例#3
0
def fit_composite_model(model, input_data, output_folder, method, tmp_results_dir,
                        recalculate=False, cascade_names=None, optimizer_options=None):
    """Fits the composite model and returns the results as ROI lists per map.

    Args:
        model (:class:`~mdt.models.composite.DMRICompositeModel`): An implementation of an composite model
            that contains the model we want to optimize.
        input_data (:class:`~mdt.utils.MRIInputData`): The input data object for the model.
        output_folder (string): The path to the folder where to place the output.
            The resulting maps are placed in a subdirectory (named after the model name) in this output folder.
        method (str): The optimization routine to use.
        tmp_results_dir (str): the main directory to use for the temporary results
        recalculate (boolean): If we want to recalculate the results if they are already present.
        cascade_names (list): the list of cascade names, meant for logging
        optimizer_options (dict): the additional optimization options

    Returns:
        The processing strategy's result (ROI lists per map), either freshly computed
        or loaded from the existing output when ``recalculate`` is False.

    Raises:
        InsufficientProtocolError: if the input data does not suffice for this model.
    """
    logger = logging.getLogger(__name__)
    output_path = os.path.join(output_folder, model.name)

    if not model.is_input_data_sufficient(input_data):
        raise InsufficientProtocolError(
            'The given protocol is insufficient for this model. '
            'The reported errors where: {}'.format(model.get_input_data_problems(input_data)))

    # Reuse existing results when allowed, avoiding a costly refit.
    if not recalculate and model_output_exists(model, output_folder):
        maps = get_all_nifti_data(output_path)
        logger.info('Not recalculating {} model'.format(model.name))
        return create_roi(maps, input_data.mask)

    with per_model_logging_context(output_path):
        logger.info('Using MDT version {}'.format(__version__))
        logger.info('Preparing for model {0}'.format(model.name))
        logger.info('Current cascade: {0}'.format(cascade_names))

        model.set_input_data(input_data)

        if recalculate:
            if os.path.exists(output_path):
                # Remove stale NIFTI volumes from the previous run.
                list(map(os.remove, glob.glob(os.path.join(output_path, '*.nii*'))))
                # BUG FIX: this used ``os.path.join(output_path + 'covariances')``,
                # which concatenates without a path separator (a single-argument
                # join is a no-op), so the covariances subdirectory was never
                # found nor removed.
                covariances_dir = os.path.join(output_path, 'covariances')
                if os.path.exists(covariances_dir):
                    shutil.rmtree(covariances_dir)

        if not os.path.exists(output_path):
            os.makedirs(output_path)

        with _model_fit_logging(logger, model.name, model.get_free_param_names()):
            tmp_dir = get_full_tmp_results_path(output_path, tmp_results_dir)
            logger.info('Saving temporary results in {}.'.format(tmp_dir))

            worker = FittingProcessor(method, model, input_data.mask,
                                      input_data.nifti_header, output_path,
                                      tmp_dir, recalculate, optimizer_options=optimizer_options)

            processing_strategy = get_processing_strategy('optimization')
            return processing_strategy.process(worker)
示例#4
0
    def __call__(self, subject_info):
        """Run the batch fitting on the given subject.

        This is a module level function to allow for python multiprocessing to work.

        Args:
            subject_info (SubjectInfo): the subject information
        """
        subject_id = subject_info.subject_id
        output_dir = subject_info.output_dir

        # Nothing to do when every model's output exists and no refit was requested.
        have_all_outputs = all(model_output_exists(model, output_dir)
                               for model in self._models_to_fit)
        if have_all_outputs and not self._recalculate:
            self._logger.info('Skipping subject {0}, output exists'.format(subject_id))
            return

        self._logger.info(
            'Loading the data (DWI, mask and protocol) of subject {0}'.format(subject_id))
        problem_data = subject_info.get_problem_data()

        with self._timer(subject_id):
            for model in self._models_to_fit:
                self._logger.info(
                    'Going to fit model {0} on subject {1}'.format(model, subject_id))
                try:
                    ModelFit(model,
                             problem_data,
                             output_dir,
                             recalculate=self._recalculate,
                             only_recalculate_last=True,
                             cascade_subdir=self._cascade_subdir,
                             cl_device_ind=self._cl_device_ind,
                             double_precision=self._double_precision,
                             tmp_results_dir=self._tmp_results_dir).run()
                except InsufficientProtocolError as ex:
                    self._logger.info('Could not fit model {0} on subject {1} '
                                      'due to protocol problems. {2}'.format(
                                          model, subject_id, ex))
                else:
                    self._logger.info(
                        'Done fitting model {0} on subject {1}'.format(model, subject_id))
示例#5
0
    def run(self):
        """Fits the composite model and returns the results as ROI lists per map."""
        # Reuse existing results when allowed, avoiding a costly refit.
        if not self.recalculate and model_output_exists(self._model, self._output_folder):
            maps = get_all_nifti_data(self._output_path)
            self._logger.info('Not recalculating {} model'.format(self._model.name))
            return create_roi(maps, self._input_data.mask)

        with per_model_logging_context(self._output_path):
            self._logger.info('Using MDT version {}'.format(__version__))
            self._logger.info('Preparing for model {0}'.format(self._model.name))
            self._logger.info('Current cascade: {0}'.format(self._cascade_names))

            self._model.set_input_data(self._input_data)

            if self.recalculate and os.path.exists(self._output_path):
                # Remove stale NIFTI volumes so the new fit starts clean.
                for stale_map in glob.glob(os.path.join(self._output_path, '*.nii*')):
                    os.remove(stale_map)

            if not os.path.exists(self._output_path):
                os.makedirs(self._output_path)

            with self._logging():
                tmp_dir = get_full_tmp_results_path(self._output_path,
                                                    self._tmp_results_dir)
                self._logger.info('Saving temporary results in {}.'.format(tmp_dir))

                worker = FittingProcessor(self._optimizer, self._model,
                                          self._input_data.mask,
                                          self._input_data.nifti_header,
                                          self._output_path, tmp_dir,
                                          self.recalculate)

                results = get_processing_strategy('optimization').process(worker)

                self._write_protocol(self._model.get_input_data().protocol)

        return results
示例#6
0
        def __call__(self, subject_info):
            """Process one subject of the batch: log progress, then fit all models.

            Skips the subject when all outputs exist and no recalculation was
            requested. Fitting errors caused by an insufficient protocol are
            logged and do not abort the remaining models.
            """
            logger.info(
                'Going to process subject {}, ({} of {}, we are at {:.2%})'.
                format(subject_info.subject_id, self._index_counter + 1,
                       total_nmr_subjects,
                       self._index_counter / total_nmr_subjects))
            self._index_counter += 1

            output_dir = os.path.join(output_folder, subject_info.subject_id)

            # Skip when every model's output already exists and no refit was requested.
            if all(
                    model_output_exists(model, output_dir)
                    for model in models_to_fit) and not recalculate:
                logger.info('Skipping subject {0}, output exists'.format(
                    subject_info.subject_id))
                return

            logger.info(
                'Loading the data (DWI, mask and protocol) of subject {0}'.
                format(subject_info.subject_id))
            input_data = subject_info.get_input_data(use_gradient_deviations)

            with timer(subject_info.subject_id):
                for model_name in models_to_fit:
                    # A list entry may be a model name (string) or an already
                    # constructed model instance.
                    if isinstance(model_name, str):
                        model_instance = get_model(model_name)()
                    else:
                        model_instance = model_name
                    # Cascade models are deprecated; warn so users migrate to
                    # plain model names (MDT finds initialization points itself).
                    if isinstance(model_instance, DMRICascadeModelInterface):
                        warnings.warn(
                            dedent('''
                        
                            Fitting cascade models has been deprecated, MDT now by default tries to find a suitable initialization point for your specified model. 
                        
                            As an example, instead of specifying 'NODDI (Cascade)', now just specify 'NODDI' and MDT will do its best to get the best model fit possible.
                        '''), FutureWarning)

                    logger.info('Going to fit model {0} on subject {1}'.format(
                        model_name, subject_info.subject_id))

                    try:
                        # Non-cascade models get an automatic initialization
                        # point; cascade models handle initialization internally.
                        if not isinstance(model_instance,
                                          DMRICascadeModelInterface):
                            inits = get_optimization_inits(
                                model_name,
                                input_data,
                                output_dir,
                                cl_device_ind=cl_device_ind)
                        else:
                            inits = {}

                        model_fit = ModelFit(
                            model_name,
                            input_data,
                            output_dir,
                            recalculate=recalculate,
                            cl_device_ind=cl_device_ind,
                            double_precision=double_precision,
                            tmp_results_dir=tmp_results_dir,
                            initialization_data={'inits': inits})
                        model_fit.run()
                    except InsufficientProtocolError as ex:
                        # The protocol lacks data required by this model; log
                        # and continue with the next model.
                        logger.info('Could not fit model {0} on subject {1} '
                                    'due to protocol problems. {2}'.format(
                                        model_name, subject_info.subject_id,
                                        ex))
                    else:
                        logger.info(
                            'Done fitting model {0} on subject {1}'.format(
                                model_name, subject_info.subject_id))