Exemple #1
0
    def _get_proposal_callbacks(self, template, parameter_list):
        """Get a list of proposal callback functions.

        These functions are (indirectly) called by a MCMC sample routine to finalize the new proposal vector.

        Returns:
            List[(Tuple, mot.lib.cl_function.CLFunction)]: a list of proposal callback functions coupled with
                references to the compartment parameters used in the function.
        """
        callbacks = []

        def existing_parameters(param_list):
            param_names = [p.name for p in parameter_list]
            return all(p in param_names for p in param_list)

        def get_corresponding_param(param_name):
            for p in parameter_list:
                if p.name == param_name:
                    return p

        if template.spherical_parameters is not None and existing_parameters(
                template.spherical_parameters):
            corresponding_params = [
                get_corresponding_param(p)
                for p in template.spherical_parameters
            ]

            func = SimpleCLFunction(
                'void', 'proposal_callback_spherical_{}'.format(template.name),
                ['mot_float_type* theta', 'mot_float_type* phi'], '''
                    if(*phi > M_PI_F){
                        *phi -= M_PI_F;
                        *theta = M_PI_F - *theta;
                    }
                    else if(*phi < 0){
                        *phi += M_PI_F;
                        *theta = M_PI_F - *theta;
                    }
            ''')
            callbacks.append((corresponding_params, func))

        for p in parameter_list:
            if hasattr(p, 'sampling_proposal_modulus'
                       ) and p.sampling_proposal_modulus is not None:
                func = SimpleCLFunction(
                    'void',
                    'proposal_callback_{}_{}'.format(template.name, p.name),
                    ['mot_float_type* ' + p.name],
                    '*{0} = *{0} - floor(*{0} / {1}) * {1};'.format(
                        p.name, p.sampling_proposal_modulus))

                callbacks.append(([p], func))

        return callbacks
Exemple #2
0
def _resolve_prior(prior, compartment_name, compartment_parameters):
    """Create a proper prior out of the given prior information.

    Args:
        prior (str or mot.lib.cl_function.CLFunction or None):
            The prior from which to construct a prior.
        compartment_name (str): the name of the compartment
        compartment_parameters (list of str): the list of parameters of this compartment, used
            for looking up the used parameters in a string prior

    Returns:
        List[mdt.models.compartments.CompartmentPrior]: the list of extra priors for this compartment
    """
    if prior is None:
        return []

    if isinstance(prior, CLFunction):
        return [prior]

    parameters = [
        'mot_float_type ' + p for p in compartment_parameters if p in prior
    ]
    return [
        SimpleCLFunction('mot_float_type', 'prior_' + compartment_name,
                         parameters, prior)
    ]
Exemple #3
0
def _resolve_model_prior(prior, model_parameters):
    """Resolve the model priors.

    Args:
        prior (None or str or mot.lib.cl_function.CLFunction): the prior defined in the composite model template.
        model_parameters (str): the (model, parameter) tuple for all the parameters in the model

    Returns:
        list of mdt.model_building.utils.ModelPrior: list of model priors
    """
    if prior is None:
        return []

    if isinstance(prior, CLFunction):
        return [prior]

    dotted_names = ['{}.{}'.format(m.name, p.name) for m, p in model_parameters]
    dotted_names.sort(key=len, reverse=True)

    parameters = []
    remaining_prior = prior
    for dotted_name in dotted_names:
        bar_name = dotted_name.replace('.', '_')

        if dotted_name in remaining_prior:
            prior = prior.replace(dotted_name, bar_name)
            remaining_prior = remaining_prior.replace(dotted_name, '')
            parameters.append('mot_float_type ' + dotted_name)
        elif bar_name in remaining_prior:
            remaining_prior = remaining_prior.replace(bar_name, '')
            parameters.append('mot_float_type ' + dotted_name)

    return [SimpleCLFunction('mot_float_type', 'model_prior', parameters, prior)]
Exemple #4
0
    def evaluate(self, *args, **kwargs):
        """Evaluate this function, supplying the private data cache when it is needed.

        When none of the parameters is a ``DataCacheParameter``, this defers to the parent
        ``evaluate``. Otherwise the private cache struct of this function is added to the inputs given
        as ``args[0]``. A wrapper CL function is then evaluated instead; it first calls the cache
        initialization function and then returns the result of calling this function.

        Args:
            *args: positional arguments for ``evaluate``. ``args[0]`` holds the inputs. A tuple or list
                gets the cache struct appended. Any other value is copied into a dict, which gets the
                cache struct under the cache parameter's name. The caller's object is not modified.
            **kwargs: forwarded unchanged.

        Returns:
            the result of the underlying ``evaluate`` call.
        """
        if not any(isinstance(p, DataCacheParameter) for p in self._parameter_list):
            return super().evaluate(*args, **kwargs)

        cache_struct = self.get_cache_struct('private')[self.name]
        cache_param_name = self._get_cache_parameter().name

        # Build new inputs instead of writing into the caller's object. Otherwise a mapping passed by
        # the caller would keep the extra cache entry after this call returns.
        args = list(args)
        if isinstance(args[0], (tuple, list)):
            args[0] = tuple(args[0]) + (cache_struct, )
        else:
            inputs = dict(args[0])
            inputs[cache_param_name] = cache_struct
            args[0] = inputs

        cache_init_func = self.get_cache_init_function()

        with_cache_func = SimpleCLFunction(
            self._return_type,
            '_{}'.format(self._function_name),
            self._parameter_list,
            '''
                {cache_init_func_name}({cache_params});
                return {parent_func_name}({parent_func_params});
            '''.format(
                cache_init_func_name=cache_init_func.get_cl_function_name(),
                parent_func_name=self.get_cl_function_name(),
                cache_params=', '.join(
                    [p.name for p in cache_init_func.get_parameters()]),
                parent_func_params=', '.join(
                    [p.name for p in self.get_parameters()])),
            dependencies=[cache_init_func, self])

        return with_cache_func.evaluate(*args, **kwargs)
Exemple #5
0
    def get_cache_init_function(self):
        """Build the CL function that initializes the private cache of this model.

        Returns:
            mot.lib.cl_function.CLFunction or None: None when ``get_cache_struct('private')`` is empty.
                Otherwise the result is a ``void`` CL function named ``<function name>_init_cache``. It
                takes this model's free, data cache and noise std input parameters. Its body is the
                model's cache code followed by calls to the cache init functions of all
                ``CompartmentModel`` dependencies that have one. Data cache arguments to those calls are
                passed as ``<param>-><dependency function name>``, presumably the dependency's part of
                the cache struct.
        """
        if not self.get_cache_struct('private'):
            return None

        dependency_calls = []
        cache_init_funcs = []
        for dependency in self.get_dependencies():
            if not isinstance(dependency, CompartmentModel):
                continue
            cache_init_func = dependency.get_cache_init_function()
            if not cache_init_func:
                continue
            call_args = [
                '{}->{}'.format(p.name, dependency.get_cl_function_name())
                if isinstance(p, DataCacheParameter) else p.name
                for p in cache_init_func.get_parameters()
            ]
            dependency_calls.append('{}({});'.format(
                cache_init_func.get_cl_function_name(), ', '.join(call_args)))
            cache_init_funcs.append(cache_init_func)

        init_parameters = [
            p for p in self.get_parameters()
            if isinstance(p, (FreeParameter, DataCacheParameter, NoiseStdInputParameter))
        ]

        # Join with newlines so the first dependency call always starts on its own line. Plain
        # concatenation would glue it onto the last line of the cache code (e.g. into a trailing
        # // comment or a preprocessor directive).
        body = '\n'.join([self._cache_info.cl_code] + dependency_calls)

        return SimpleCLFunction(
            'void',
            '{}_init_cache'.format(self._function_name),
            init_parameters,
            body,
            dependencies=cache_init_funcs + self.get_dependencies())
Exemple #6
0
    def get_cache_init_function(self):
        """Get the CL function for initializing the cache struct of this compartment.

        The body is this compartment's cache code (``self.cache_info.cl_code``) followed by one call per
        dependency that is a ``DMRICompartmentModelFunction``. Each call passes this compartment's
        parameters by name. Data cache parameters are the exception and are passed as
        ``<param>-><dependency function name>``.

        Returns:
            mot.lib.cl_function.CLFunction: the CL function for initializing the cache. This has the same
                signature as the compartment model function.
        """
        dependency_calls = []
        for dependency in self.get_dependencies():
            if not isinstance(dependency, DMRICompartmentModelFunction):
                continue
            dependency_name = dependency.get_cl_function_name()
            # NOTE(review): this calls the dependency's main function (not its cache init function) and
            # passes this compartment's own parameter list; confirm that is intended.
            call_args = [
                '{}->{}'.format(p.name, dependency_name)
                if isinstance(p, DataCacheParameter) else p.name
                for p in self._parameter_list
            ]
            dependency_calls.append('{}({});'.format(dependency_name, ', '.join(call_args)))

        # Join with newlines so the first dependency call always starts on its own line. Plain
        # concatenation would glue it onto the last line of the cache code (e.g. into a trailing
        # // comment or a preprocessor directive).
        body = '\n'.join([self.cache_info.cl_code] + dependency_calls)

        return SimpleCLFunction(
            self._return_type, '{}_init_cache'.format(self._function_name),
            self._parameter_list,
            body)