class AutoProtocolParameter(method_binding_meta(template, ProtocolParameter)):
    """Protocol parameter generated from a parameter ``template``."""

    def __init__(self, nickname=None):
        """Create the parameter with the template's data type and value.

        Args:
            nickname (str): name to use instead of ``template.name``; any falsy value
                falls back to the template name.
        """
        name = nickname if nickname else template.name
        super(AutoProtocolParameter, self).__init__(data_type, name, value=template.value)
class AutoCreatedCascadeModel(method_binding_meta(template, SimpleCascadeModel)):
    """Cascade model generated from a cascade ``template``.

    The cascade's models are built from ``template.models``. Before each model in the
    cascade is prepared, the template's ``inits``, ``fixes``, ``lower_bounds`` and
    ``upper_bounds`` entries that apply to that model are set on it.
    """

    def __init__(self, *args):
        """Instantiate the cascade's models and initialize the base cascade.

        Args:
            *args: optional positional overrides for the base-class arguments, in the
                order ``(name, models)``. By default the name is a deep copy of
                ``template.name`` and the models are built from ``template.models``.
        """
        models = []
        for model_def in template.models:
            # A model definition is either a model name, or a (name, constructor argument) pair.
            if isinstance(model_def, six.string_types):
                models.append(mdt.get_model(model_def)())
            else:
                models.append(mdt.get_model(model_def[0])(model_def[1]))

        new_args = [deepcopy(template.name), models]
        # Caller-supplied positional arguments override the defaults above by position;
        # enumerate() pairs each override with the index it replaces.
        for ind, arg in enumerate(args):
            new_args[ind] = arg
        super(AutoCreatedCascadeModel, self).__init__(*new_args)

    def _prepare_model(self, iteration_position, model, output_previous, output_all_previous):
        """Apply the template's per-model settings to ``model`` before it is run.

        For each of the template's ``inits``, ``fixes``, ``lower_bounds`` and
        ``upper_bounds`` dictionaries, entries keyed by ``model.name`` are applied. Entries
        keyed by ``iteration_position`` are merged over them, so position-keyed values win.
        Afterwards ``self._prepare_model_cb`` is called with the same arguments. That method
        is not defined in this block; presumably it is bound from the template by
        ``method_binding_meta`` (TODO confirm).

        Args:
            iteration_position: position of ``model`` in the cascade.
            model: the model about to be run.
            output_previous: mapping of map names to results of the previous model.
            output_all_previous: results of all previous models; passed through to
                callable template values.
        """
        super(AutoCreatedCascadeModel, self)._prepare_model(iteration_position, model,
                                                            output_previous, output_all_previous)

        def parse_value(v):
            # A string names a map in the previous model's output, a callable is evaluated
            # against the previous outputs, anything else is used as-is.
            if isinstance(v, six.string_types):
                return output_previous[v]
            elif callable(v):
                return v(output_previous, output_all_previous)
            return v

        def apply_func(template_element, cb):
            items_to_apply = dict(template_element.get(model.name, {}))
            items_to_apply.update(dict(template_element.get(iteration_position, {})))
            for key, value in items_to_apply.items():
                cb(key, parse_value(value))

        apply_func(template.inits, model.init)
        apply_func(template.fixes, model.fix)
        apply_func(template.lower_bounds, model.set_lower_bound)
        apply_func(template.upper_bounds, model.set_upper_bound)

        self._prepare_model_cb(iteration_position, model, output_previous, output_all_previous)
class AutoCreatedBatchProfile(method_binding_meta(template, SimpleBatchProfile)):
    """Batch profile generated from a batch profile ``template``.

    The template's path patterns may use the placeholders ``{data_folder}``,
    ``{subject_id}`` and ``{subject_base_folder}``; they are filled in per subject.
    """

    def _get_subjects(self, data_folder):
        """Find the subjects in ``data_folder`` according to the template's file patterns.

        Every entry directly inside ``data_folder`` (files included, sorted by name) is
        tried as a subject id. Entries whose data pattern matches no data file are skipped.
        The remaining entries become subjects.

        Args:
            data_folder (str): the folder to search for subjects.

        Returns:
            list: one ``SimpleSubjectInfo`` per entry that has data.
        """
        dirs = sorted(os.path.basename(f) for f in glob.glob(os.path.join(data_folder, '*')))
        subjects = []
        for subject_id in dirs:
            subject_base_folder = os.path.join(
                data_folder, template.subject_base_folder.format(subject_id=subject_id))

            def _prepare_path(template_path):
                # Fill in the placeholders for the current subject; None means "not configured".
                if template_path is None:
                    return None
                return template_path.format(data_folder=data_folder, subject_id=subject_id,
                                            subject_base_folder=subject_base_folder)

            def _first_match(template_path):
                # First file matching the optional pattern, or None if unset or nothing matches.
                pattern = _prepare_path(template_path)
                if pattern is None:
                    return None
                matches = glob.glob(pattern)
                return matches[0] if matches else None

            data_glob = glob.glob(_prepare_path(template.data_fname))
            if not data_glob:
                # Not a subject (e.g. a stray file or folder): skip it, keep scanning the rest.
                continue

            noise_std = self._autoload_noise_std(data_folder, subject_id,
                                                 file_pattern=template.noise_std_fname)
            protocol_loader = BatchFitProtocolLoader(
                _prepare_path(template.protocol_auto_dir),
                protocol_fname=_prepare_path(template.protocol_fname),
                bvec_fname=_prepare_path(template.bvec_fname),
                bval_fname=_prepare_path(template.bval_fname),
                protocol_columns=template.protocol_columns)

            mask_fname = _first_match(template.mask_fname)
            grad_dev = _first_match(template.gradient_deviations_fname)

            # The data pattern may also match the mask or gradient deviations file; those
            # are not diffusion data, so drop them from the candidates.
            data_glob = [fname for fname in data_glob if fname not in (mask_fname, grad_dev)]
            if not data_glob:
                continue

            subjects.append(SimpleSubjectInfo(
                subject_base_folder, subject_id, data_glob[0], protocol_loader, mask_fname,
                noise_std=noise_std, gradient_deviations=grad_dev))
        return subjects

    def __str__(self):
        return template.name
class AutoFreeParameter(method_binding_meta(template, FreeParameter)):
    """Free (optimizable) parameter generated from a parameter ``template``."""

    def __init__(self, nickname=None):
        """Create the parameter from the template's defaults, bounds and sampling settings.

        Args:
            nickname (str): name to use instead of ``template.name``; any falsy value
                falls back to the template name.
        """
        name = nickname if nickname else template.name
        transform = _resolve_parameter_transform(template.parameter_transform)
        super(AutoFreeParameter, self).__init__(
            data_type,
            name,
            template.fixed,
            template.init_value,
            template.lower_bound,
            template.upper_bound,
            parameter_transform=transform,
            sampling_proposal_std=template.sampling_proposal_std,
            sampling_prior=template.sampling_prior,
            numdiff_info=numdiff_info,
        )
        self.sampling_proposal_modulus = template.sampling_proposal_modulus
class AutoCreatedDMRICompartmentModel(method_binding_meta(template, DMRICompartmentModelFunction)):
    """Compartment model generated from a compartment ``template``."""

    def __init__(self, *args, **kwargs):
        """Resolve the template's parameters and dependencies and build the compartment.

        Args:
            *args: positional overrides for the base-class arguments, in the order
                ``(name, name, parameters, cl_code, dependencies, return_type)``.
            **kwargs: keyword overrides merged over the template-derived keyword arguments.
        """
        # "parameters" takes precedence; the deprecated "parameter_list" is only a fallback.
        if len(template.parameters):
            parameters = _resolve_parameters(template.parameters)
        elif len(template.parameter_list):
            # TODO: remove the parameter_list attribute in future versions
            warnings.warn('"parameter_list" is deprecated and will be removed in future versions, '
                          'please replace with "parameters".')
            parameters = _resolve_parameters(template.parameter_list)
        else:
            parameters = []

        # Same precedence rule for "dependencies" versus the deprecated "dependency_list".
        if len(template.dependencies):
            dependencies = _resolve_dependencies(template.dependencies)
        elif len(template.dependency_list):
            # TODO: remove the dependency_list attribute in future versions
            warnings.warn('"dependency_list" is deprecated and will be removed in future versions, '
                          'please replace with "dependencies".')
            dependencies = _resolve_dependencies(template.dependency_list)
        else:
            dependencies = []

        positional = [template.name, template.name, parameters, template.cl_code,
                      dependencies, template.return_type]
        for position, override in enumerate(args):
            positional[position] = override

        parameter_names = [p.name for p in parameters]
        options = dict(
            model_function_priors=_resolve_prior(template.extra_prior, template.name, parameter_names),
            post_optimization_modifiers=template.post_optimization_modifiers,
            extra_optimization_maps_funcs=builder._get_extra_optimization_map_funcs(template, parameters),
            extra_sampling_maps_funcs=copy(template.extra_sampling_maps),
            cl_extra=template.cl_extra,
        )
        options.update(kwargs)

        super(AutoCreatedDMRICompartmentModel, self).__init__(*positional, **options)

        # Optional template hook for further customization of the new instance.
        if hasattr(template, 'init'):
            template.init(self)
class AutoCreatedLibraryFunction(method_binding_meta(template, SimpleCLLibrary)):
    """CL library function generated from a library ``template``."""

    def __init__(self, *args, **kwargs):
        """Build the library function from the template.

        Args:
            *args: positional overrides for the base-class arguments, in the order
                ``(return_type, name, parameters, cl_code)``.
            **kwargs: keyword overrides merged over ``dependencies`` and ``cl_extra``.
        """
        positional = [
            template.return_type,
            template.name,
            _resolve_parameters(template.parameters),
            template.cl_code,
        ]
        for position, override in enumerate(args):
            positional[position] = override

        options = {
            'dependencies': _resolve_dependencies(template.dependencies),
            'cl_extra': template.cl_extra,
        }
        options.update(kwargs)

        super(AutoCreatedLibraryFunction, self).__init__(*positional, **options)

        # Optional template hook for further customization of the new instance.
        if hasattr(template, 'init'):
            template.init(self)
class AutoCreatedWeightModel(method_binding_meta(template, WeightType)):
    """Weight model generated from a weight ``template``."""

    def __init__(self, *args, **kwargs):
        """Resolve the template's parameters and dependencies and build the weight model.

        Args:
            *args: positional overrides for the base-class arguments following the return
                type, in the order ``(name, name, parameters, cl_code)``.
            **kwargs: keyword overrides merged over ``dependencies`` and ``cl_extra``.
        """
        # "parameters" takes precedence; the deprecated "parameter_list" is only a fallback.
        if len(template.parameters):
            parameters = _resolve_parameters(template.parameters)
        elif len(template.parameter_list):
            # TODO: remove the parameter_list attribute in future versions
            warnings.warn('"parameter_list" is deprecated and will be removed in future versions, '
                          'please replace with "parameters".')
            parameters = _resolve_parameters(template.parameter_list)
        else:
            parameters = []

        # Same precedence rule for "dependencies" versus the deprecated "dependency_list".
        if len(template.dependencies):
            dependencies = _resolve_dependencies(template.dependencies)
        elif len(template.dependency_list):
            # TODO: remove the dependency_list attribute in future versions
            warnings.warn('"dependency_list" is deprecated and will be removed in future versions, '
                          'please replace with "dependencies".')
            dependencies = _resolve_dependencies(template.dependency_list)
        else:
            dependencies = []

        positional = [template.name, template.name, parameters, template.cl_code]
        for position, override in enumerate(args):
            positional[position] = override

        options = dict(dependencies=dependencies, cl_extra=template.cl_extra)
        options.update(kwargs)

        # The return type always comes first and cannot be overridden through *args.
        super(AutoCreatedWeightModel, self).__init__(template.return_type, *positional, **options)

        # Optional template hook for further customization of the new instance.
        if hasattr(template, 'init'):
            template.init(self)
class AutoCreatedDMRICompositeModel(method_binding_meta(template, DMRICompositeModel)):
    """Composite diffusion MRI model generated from a composite model ``template``."""

    def __init__(self, model_name=None):
        """Build the composite model from the template's model expression and settings.

        Args:
            model_name (str): name for the model; any falsy value falls back to a deep copy
                of ``template.name``.
        """
        if not model_name:
            model_name = deepcopy(template.name)

        model_tree = CompartmentModelTree(parse(template.model_expression))
        likelihood_function = deepcopy(_resolve_likelihood_function(template.likelihood_function))
        super(AutoCreatedDMRICompositeModel, self).__init__(
            model_name,
            model_tree,
            likelihood_function,
            signal_noise_model=deepcopy(template.signal_noise_model),
            enforce_weights_sum_to_one=template.enforce_weights_sum_to_one,
        )

        # Post-optimization modifiers: optional map sorting first, then the compartments'
        # own modifiers, then the template's extra modifiers.
        if template.sort_maps:
            sorting_modifier = _get_map_sorting_modifier(
                template.sort_maps, self._model_functions_info.get_model_parameter_list())
            self._post_optimization_modifiers.append(sorting_modifier)
        self._post_optimization_modifiers.extend(
            _get_model_post_optimization_modifiers(self._model_functions_info.get_model_list()))
        self._post_optimization_modifiers.extend(deepcopy(template.post_optimization_modifiers))

        # Extra output maps: the compartments' own functions, followed by the template's.
        self._extra_optimization_maps_funcs.extend(
            _get_model_extra_optimization_maps_funcs(self._model_functions_info.get_model_list()))
        self._extra_optimization_maps_funcs.extend(deepcopy(template.extra_optimization_maps))

        self._extra_sampling_maps_funcs.extend(
            _get_model_extra_sampling_maps_funcs(self._model_functions_info.get_model_list()))
        self._extra_sampling_maps_funcs.extend(deepcopy(template.extra_sampling_maps))

        # Apply per-parameter template settings in a fixed order: inits, fixes, lower
        # bounds, upper bounds. Values are deep-copied so the template stays untouched.
        per_parameter_settings = (
            (template.inits, self.init),
            (template.fixes, self.fix),
            (template.lower_bounds, self.set_lower_bound),
            (template.upper_bounds, self.set_upper_bound),
        )
        for settings, apply_setting in per_parameter_settings:
            for full_param_name, value in settings.items():
                apply_setting(full_param_name, deepcopy(value))

        # Computed only after the fixes above have been applied.
        self.nmr_parameters_for_bic_calculation = self.get_nmr_parameters()

        self._model_priors.extend(_resolve_model_prior(
            template.extra_prior, self._model_functions_info.get_model_parameter_list()))

    def _get_suitable_volume_indices(self, input_data):
        """Select the protocol volumes to use, following ``template.volume_selection``.

        Without a volume selection, the base-class implementation is used. If the protocol
        lacks a ``g`` or ``b`` column, all volumes are returned. Otherwise the result
        combines two optional parts:

        * Weighted volumes, when ``use_weighted`` is true (the default). If both
          ``min_bval`` and ``max_bval`` are given, these are the volumes with a b-value in
          that range; otherwise the weighted indices for ``unweighted_threshold``.
        * Unweighted volumes, when ``use_unweighted`` is true (the default).

        Args:
            input_data: object exposing a ``protocol`` attribute.

        Returns:
            The sorted unique volume indices (via ``np.unique``), or ``list(range(n))``
            when the protocol lacks the ``g``/``b`` columns.
        """
        volume_selection = template.volume_selection
        if not volume_selection:
            return super(AutoCreatedDMRICompositeModel, self)._get_suitable_volume_indices(input_data)

        use_unweighted = volume_selection.get('use_unweighted', True)
        use_weighted = volume_selection.get('use_weighted', True)
        unweighted_threshold = volume_selection.get('unweighted_threshold', 25e6)

        protocol = input_data.protocol
        if not (protocol.has_column('g') and protocol.has_column('b')):
            return list(range(protocol.length))

        selected = []
        if use_weighted:
            has_bval_range = 'min_bval' in volume_selection and 'max_bval' in volume_selection
            if has_bval_range:
                selected = protocol.get_indices_bval_in_range(start=volume_selection['min_bval'],
                                                              end=volume_selection['max_bval'])
            else:
                selected = protocol.get_weighted_indices(unweighted_threshold)

        if use_unweighted:
            selected = list(selected) + list(protocol.get_unweighted_indices(unweighted_threshold))

        return np.unique(selected)