Example #1
0
 class AutoProtocolParameter(method_binding_meta(template, ProtocolParameter)):
     """Protocol parameter whose defaults are taken from the enclosing ``template``.

     The data type comes from the enclosing ``data_type`` and the value from
     ``template.value``.
     """

     def __init__(self, nickname=None):
         """Initialize the parameter.

         Args:
             nickname: optional name used instead of ``template.name``; any
                 falsy value falls back to ``template.name``.
         """
         parameter_name = nickname or template.name
         super(AutoProtocolParameter, self).__init__(
             data_type, parameter_name, value=template.value)
Example #2
0
        class AutoCreatedCascadeModel(
                method_binding_meta(template, SimpleCascadeModel)):
            """Cascade model assembled from the enclosing ``template``.

            The cascade's name and model list come from the template. Before each
            model in the cascade is prepared, the template's ``inits``, ``fixes``,
            ``lower_bounds`` and ``upper_bounds`` entries relevant to that model
            are applied to it.
            """

            def __init__(self, *args):
                """Build the cascade from ``template.models``.

                Args:
                    *args: optional positional overrides for the constructor
                        arguments ``(name, models)``; the i-th positional value
                        replaces the i-th template-derived default.
                """
                models = []
                for model_def in template.models:
                    if isinstance(model_def, six.string_types):
                        # A bare model name: instantiate it without arguments.
                        models.append(mdt.get_model(model_def)())
                    else:
                        # A (model name, constructor argument) pair.
                        models.append(
                            mdt.get_model(model_def[0])(model_def[1]))

                new_args = [deepcopy(template.name), models]
                # ``args`` is a flat tuple of values; pair each with its position
                # (iterating it directly would try to unpack every value).
                for ind, arg in enumerate(args):
                    new_args[ind] = arg
                super(AutoCreatedCascadeModel, self).__init__(*new_args)

            def _prepare_model(self, iteration_position, model,
                               output_previous, output_all_previous):
                """Configure ``model`` before it is run in the cascade.

                Template entries keyed by the model's name are applied first and
                are then overridden by entries keyed by ``iteration_position``.
                Finally ``self._prepare_model_cb`` is called with the same
                arguments.

                Args:
                    iteration_position: position of ``model`` in the cascade, used
                        as a lookup key in the template dictionaries.
                    model: the model about to be run.
                    output_previous: results of the previous model; string values
                        in the template are looked up in it.
                    output_all_previous: results of all previous models.
                """
                super(AutoCreatedCascadeModel, self)._prepare_model(
                    iteration_position, model, output_previous, output_all_previous)

                def parse_value(v):
                    """Resolve a template value against the previous outputs.

                    Strings are looked up in ``output_previous``, callables are
                    called with ``(output_previous, output_all_previous)`` and any
                    other value is returned unchanged.
                    """
                    if isinstance(v, six.string_types):
                        return output_previous[v]
                    elif callable(v):
                        return v(output_previous, output_all_previous)
                    return v

                def apply_func(template_element, cb):
                    """Call ``cb(key, resolved_value)`` for every entry that applies to ``model``."""
                    items_to_apply = dict(template_element.get(model.name, {}))
                    # Position-specific settings take precedence over name-based ones.
                    items_to_apply.update(
                        dict(template_element.get(iteration_position, {})))

                    for key, value in items_to_apply.items():
                        cb(key, parse_value(value))

                apply_func(template.inits, model.init)
                apply_func(template.fixes, model.fix)
                apply_func(template.lower_bounds, model.set_lower_bound)
                apply_func(template.upper_bounds, model.set_upper_bound)

                self._prepare_model_cb(iteration_position, model,
                                       output_previous, output_all_previous)
Example #3
0
        class AutoCreatedBatchProfile(method_binding_meta(template, SimpleBatchProfile)):
            """Batch profile whose per-subject file layout is described by the enclosing ``template``."""

            def _get_subjects(self, data_folder):
                """Collect the subjects found in ``data_folder``.

                Every entry of ``data_folder`` is treated as a candidate subject id.
                The template's path patterns are formatted with ``data_folder``,
                ``subject_id`` and ``subject_base_folder`` and then globbed.
                Candidates without a data file are skipped; the scan continues with
                the remaining entries.

                Args:
                    data_folder (str): folder containing one entry per subject.

                Returns:
                    list: one ``SimpleSubjectInfo`` per subject that has a data file.
                """
                def prepare_path(template_path, path_fields):
                    """Format ``template_path`` with ``path_fields``; ``None`` passes through."""
                    if template_path is None:
                        return None
                    return template_path.format(**path_fields)

                def first_match(template_path, path_fields):
                    """Return the first file matching the formatted pattern, or ``None``."""
                    pattern = prepare_path(template_path, path_fields)
                    if pattern is None:
                        return None
                    matches = glob.glob(pattern)
                    return matches[0] if matches else None

                dirs = sorted(os.path.basename(f) for f in glob.glob(os.path.join(data_folder, '*')))

                subjects = []
                for subject_id in dirs:
                    subject_base_folder = os.path.join(
                        data_folder, template.subject_base_folder.format(subject_id=subject_id))
                    path_fields = {'data_folder': data_folder,
                                   'subject_id': subject_id,
                                   'subject_base_folder': subject_base_folder}

                    data_glob = glob.glob(prepare_path(template.data_fname, path_fields))
                    if not data_glob:
                        # Not a subject (e.g. a stray file): skip it but keep
                        # scanning the remaining entries.
                        continue

                    noise_std = self._autoload_noise_std(data_folder, subject_id, file_pattern=template.noise_std_fname)

                    protocol_loader = BatchFitProtocolLoader(
                        prepare_path(template.protocol_auto_dir, path_fields),
                        protocol_fname=prepare_path(template.protocol_fname, path_fields),
                        bvec_fname=prepare_path(template.bvec_fname, path_fields),
                        bval_fname=prepare_path(template.bval_fname, path_fields),
                        protocol_columns=template.protocol_columns)

                    mask_fname = first_match(template.mask_fname, path_fields)
                    grad_dev = first_match(template.gradient_deviations_fname, path_fields)

                    # The mask and gradient-deviation files may also match the data
                    # pattern; exclude them so neither is picked as the data file.
                    data_files = [path for path in data_glob if path not in (mask_fname, grad_dev)]
                    if not data_files:
                        continue

                    subjects.append(SimpleSubjectInfo(
                        subject_base_folder, subject_id, data_files[0],
                        protocol_loader, mask_fname, noise_std=noise_std,
                        gradient_deviations=grad_dev))

                return subjects

            def __str__(self):
                """Return the template's name."""
                return template.name
Example #4
0
 class AutoFreeParameter(method_binding_meta(template, FreeParameter)):
     """Free parameter whose settings are copied from the enclosing ``template``."""

     def __init__(self, nickname=None):
         """Initialize the parameter from the template's settings.

         Args:
             nickname: optional name overriding ``template.name``; any falsy
                 value falls back to ``template.name``.
         """
         transform = _resolve_parameter_transform(template.parameter_transform)
         super(AutoFreeParameter, self).__init__(
             data_type,
             nickname or template.name,
             template.fixed,
             template.init_value,
             template.lower_bound,
             template.upper_bound,
             parameter_transform=transform,
             sampling_proposal_std=template.sampling_proposal_std,
             sampling_prior=template.sampling_prior,
             numdiff_info=numdiff_info)
         # Not a constructor argument of the parent, so set it afterwards.
         self.sampling_proposal_modulus = template.sampling_proposal_modulus
Example #5
0
        class AutoCreatedDMRICompartmentModel(method_binding_meta(template, DMRICompartmentModelFunction)):
            """Compartment model function generated from the enclosing ``template``."""

            def __init__(self, *args, **kwargs):
                """Construct the compartment from the template definition.

                Parameters and dependencies come from ``template.parameters`` and
                ``template.dependencies``. The deprecated ``parameter_list`` and
                ``dependency_list`` attributes are consulted (with a warning) only
                when the corresponding new-style attribute is empty.

                Args:
                    *args: positional values that replace, by position, the defaults
                        ``[name, name, parameters, cl_code, dependencies, return_type]``.
                    **kwargs: keyword values that override the template-derived
                        keyword arguments.
                """
                if len(template.parameters):
                    parameters = _resolve_parameters(template.parameters)
                elif len(template.parameter_list):
                    # TODO: drop support for the deprecated ``parameter_list`` attribute.
                    warnings.warn('"parameter_list" is deprecated and will be removed in future versions, '
                                  'please replace with "parameters".')
                    parameters = _resolve_parameters(template.parameter_list)
                else:
                    parameters = []

                if len(template.dependencies):
                    dependencies = _resolve_dependencies(template.dependencies)
                elif len(template.dependency_list):
                    # TODO: drop support for the deprecated ``dependency_list`` attribute.
                    warnings.warn('"dependency_list" is deprecated and will be removed in future versions, '
                                  'please replace with "dependencies".')
                    dependencies = _resolve_dependencies(template.dependency_list)
                else:
                    dependencies = []

                positional = [template.name, template.name, parameters,
                              template.cl_code, dependencies, template.return_type]
                for position, override in enumerate(args):
                    positional[position] = override

                parameter_names = [p.name for p in parameters]
                keyword = dict(
                    model_function_priors=_resolve_prior(template.extra_prior, template.name, parameter_names),
                    post_optimization_modifiers=template.post_optimization_modifiers,
                    extra_optimization_maps_funcs=builder._get_extra_optimization_map_funcs(template, parameters),
                    extra_sampling_maps_funcs=copy(template.extra_sampling_maps),
                    cl_extra=template.cl_extra)
                keyword.update(kwargs)

                super(AutoCreatedDMRICompartmentModel, self).__init__(*positional, **keyword)

                # Optional template hook that can post-process the new instance.
                if hasattr(template, 'init'):
                    template.init(self)
Example #6
0
        class AutoCreatedLibraryFunction(method_binding_meta(template, SimpleCLLibrary)):
            """CL library function generated from the enclosing ``template``."""

            def __init__(self, *args, **kwargs):
                """Construct the library function from the template.

                Args:
                    *args: positional values that replace, by position, the defaults
                        ``[return_type, name, parameters, cl_code]``.
                    **kwargs: keyword values overriding the template-derived
                        ``dependencies`` and ``cl_extra``.
                """
                positional = [
                    template.return_type,
                    template.name,
                    _resolve_parameters(template.parameters),
                    template.cl_code,
                ]
                for position, override in enumerate(args):
                    positional[position] = override

                keyword = {
                    'dependencies': _resolve_dependencies(template.dependencies),
                    'cl_extra': template.cl_extra,
                }
                keyword.update(kwargs)

                super(AutoCreatedLibraryFunction, self).__init__(*positional, **keyword)

                # Optional template hook that can post-process the new instance.
                if hasattr(template, 'init'):
                    template.init(self)
Example #7
0
        class AutoCreatedWeightModel(method_binding_meta(template, WeightType)):
            """Weight model generated from the enclosing ``template``."""

            def __init__(self, *args, **kwargs):
                """Construct the weight model from the template definition.

                ``template.return_type`` is always passed as the first constructor
                argument. Parameters and dependencies come from
                ``template.parameters`` / ``template.dependencies``, falling back
                (with a deprecation warning) to the legacy ``parameter_list`` /
                ``dependency_list`` attributes when the new ones are empty.

                Args:
                    *args: positional values that replace, by position, the defaults
                        ``[name, name, parameters, cl_code]``.
                    **kwargs: keyword values overriding the template-derived
                        ``dependencies`` and ``cl_extra``.
                """
                if len(template.parameters):
                    parameters = _resolve_parameters(template.parameters)
                elif len(template.parameter_list):
                    # TODO: drop support for the deprecated ``parameter_list`` attribute.
                    warnings.warn('"parameter_list" is deprecated and will be removed in future versions, '
                                  'please replace with "parameters".')
                    parameters = _resolve_parameters(template.parameter_list)
                else:
                    parameters = []

                if len(template.dependencies):
                    dependencies = _resolve_dependencies(template.dependencies)
                elif len(template.dependency_list):
                    # TODO: drop support for the deprecated ``dependency_list`` attribute.
                    warnings.warn('"dependency_list" is deprecated and will be removed in future versions, '
                                  'please replace with "dependencies".')
                    dependencies = _resolve_dependencies(template.dependency_list)
                else:
                    dependencies = []

                positional = [template.name, template.name, parameters, template.cl_code]
                for position, override in enumerate(args):
                    positional[position] = override

                keyword = dict(dependencies=dependencies, cl_extra=template.cl_extra)
                keyword.update(kwargs)

                super(AutoCreatedWeightModel, self).__init__(template.return_type, *positional, **keyword)

                # Optional template hook that can post-process the new instance.
                if hasattr(template, 'init'):
                    template.init(self)
Example #8
0
        class AutoCreatedDMRICompositeModel(
                method_binding_meta(template, DMRICompositeModel)):
            """Composite model assembled from the enclosing ``template``."""

            def __init__(self, model_name=None):
                """Build the composite model described by the template.

                After the parent constructor has run, the following are applied to
                this instance: the template's map sorting, post-optimization
                modifiers, extra optimization/sampling maps, parameter initial
                values, fixes and bounds, and extra priors.

                Args:
                    model_name: optional model name; a falsy value falls back to a
                        copy of ``template.name``.
                """
                model_name = model_name or deepcopy(template.name)

                super(AutoCreatedDMRICompositeModel, self).__init__(
                    model_name,
                    CompartmentModelTree(parse(template.model_expression)),
                    deepcopy(_resolve_likelihood_function(template.likelihood_function)),
                    signal_noise_model=deepcopy(template.signal_noise_model),
                    enforce_weights_sum_to_one=template.enforce_weights_sum_to_one,
                )

                if template.sort_maps:
                    self._post_optimization_modifiers.append(
                        _get_map_sorting_modifier(
                            template.sort_maps,
                            self._model_functions_info.get_model_parameter_list()))

                # Entries derived from the sub-models are added first, followed by
                # deep copies of the template's own entries.
                self._post_optimization_modifiers.extend(
                    _get_model_post_optimization_modifiers(
                        self._model_functions_info.get_model_list()))
                self._post_optimization_modifiers.extend(
                    deepcopy(template.post_optimization_modifiers))

                self._extra_optimization_maps_funcs.extend(
                    _get_model_extra_optimization_maps_funcs(
                        self._model_functions_info.get_model_list()))
                self._extra_optimization_maps_funcs.extend(
                    deepcopy(template.extra_optimization_maps))

                self._extra_sampling_maps_funcs.extend(
                    _get_model_extra_sampling_maps_funcs(
                        self._model_functions_info.get_model_list()))
                self._extra_sampling_maps_funcs.extend(
                    deepcopy(template.extra_sampling_maps))

                # Values are deep-copied so this instance does not share mutable
                # values with the template.
                for full_param_name, value in template.inits.items():
                    self.init(full_param_name, deepcopy(value))

                for full_param_name, value in template.fixes.items():
                    self.fix(full_param_name, deepcopy(value))

                for full_param_name, value in template.lower_bounds.items():
                    self.set_lower_bound(full_param_name, deepcopy(value))

                for full_param_name, value in template.upper_bounds.items():
                    self.set_upper_bound(full_param_name, deepcopy(value))

                # Taken after the fixes above have been applied.
                self.nmr_parameters_for_bic_calculation = self.get_nmr_parameters()

                self._model_priors.extend(
                    _resolve_model_prior(
                        template.extra_prior,
                        self._model_functions_info.get_model_parameter_list()))

            def _get_suitable_volume_indices(self, input_data):
                """Select the volume indices this model should use.

                Without a ``template.volume_selection`` the parent class decides.
                Otherwise the selection dictionary is read for:

                * ``use_weighted`` (default True): include weighted volumes. If both
                  ``min_bval`` and ``max_bval`` are given, these are the volumes
                  returned by ``protocol.get_indices_bval_in_range``; otherwise
                  they are those returned by ``protocol.get_weighted_indices``.
                * ``use_unweighted`` (default True): also include the volumes
                  returned by ``protocol.get_unweighted_indices``.
                * ``unweighted_threshold`` (default 25e6): threshold passed to the
                  weighted/unweighted index lookups.

                If the protocol lacks a ``g`` or ``b`` column, all volumes are used.

                Args:
                    input_data: object exposing the acquisition ``protocol``.

                Returns:
                    Sorted unique volume indices as an integer numpy array, or a
                    list of all volume indices when the protocol has no ``g``/``b``
                    columns.
                """
                volume_selection = template.volume_selection

                if not volume_selection:
                    return super(AutoCreatedDMRICompositeModel,
                                 self)._get_suitable_volume_indices(input_data)

                use_unweighted = volume_selection.get('use_unweighted', True)
                use_weighted = volume_selection.get('use_weighted', True)
                unweighted_threshold = volume_selection.get(
                    'unweighted_threshold', 25e6)

                protocol = input_data.protocol

                if not (protocol.has_column('g') and protocol.has_column('b')):
                    return list(range(protocol.length))

                protocol_indices = []
                if use_weighted:
                    if 'min_bval' in volume_selection and 'max_bval' in volume_selection:
                        protocol_indices = protocol.get_indices_bval_in_range(
                            start=volume_selection['min_bval'],
                            end=volume_selection['max_bval'])
                    else:
                        protocol_indices = protocol.get_weighted_indices(
                            unweighted_threshold)

                if use_unweighted:
                    protocol_indices = (list(protocol_indices)
                                        + list(protocol.get_unweighted_indices(unweighted_threshold)))

                # Force an integer dtype: when both selections are disabled the list
                # is empty and np.unique would otherwise return a float64 array,
                # which numpy rejects as an index array.
                return np.unique(np.asarray(protocol_indices, dtype=int))