Пример #1
0
            def __init__(self, nickname=None):
                """Construct the CL function described by the ``template`` in scope.

                The template's parameters and dependencies are resolved first. When the
                template carries ``cl_extra`` code, that code is wrapped in an
                ``#ifndef``/``#define`` guard and appended as an extra dependency. The
                result is forwarded to the parent constructor. The forwarded values are
                the return type, name, parameters, CL code, constraints, priors, extra
                optimization/sampling map functions and cache info.

                Args:
                    nickname: optional nickname, handed to the parent constructor as-is.
                """
                parameters = _resolve_parameters(template.parameters, template.name)
                dependencies = _resolve_dependencies(template.dependencies)

                if template.cl_extra:
                    # Guard the extra code with a per-template macro so that duplicate
                    # copies are skipped by the CL preprocessor.
                    guard_name = 'INCLUDE_GUARD_CL_EXTRA_{}'.format(template.name)
                    extra_code = '''
                        #ifndef {inclusion_guard_name}
                        #define {inclusion_guard_name}
                        {cl_extra}
                        #endif // {inclusion_guard_name}
                    '''.format(inclusion_guard_name=guard_name,
                               cl_extra=template.cl_extra)
                    dependencies.append(SimpleCLCodeObject(extra_code))

                # Helpers are called in the same order the original keyword
                # arguments were evaluated; constraints see the extended dependencies.
                constraints_func = _resolve_constraints(
                    template.constraints, template.name, parameters, dependencies)
                priors = _resolve_prior(
                    template.extra_prior, template.name,
                    [param.name for param in parameters])
                optimization_maps = builder._get_extra_optimization_map_funcs(
                    template, parameters)
                sampling_maps = builder._get_extra_sampling_map_funcs(
                    template, parameters)
                cache_info = builder._get_cache_info(template)

                super().__init__(
                    template.return_type, template.name, parameters, template.cl_code,
                    dependencies=dependencies,
                    constraints_func=constraints_func,
                    model_function_priors=priors,
                    extra_optimization_maps_funcs=optimization_maps,
                    extra_sampling_maps_funcs=sampling_maps,
                    nickname=nickname,
                    cache_info=cache_info)
Пример #2
0
                def __init__(self):
                    """Construct the CL function described by the ``template`` in scope.

                    The template's dependencies are resolved. If the template has
                    ``cl_extra`` code, it is wrapped in an ``#ifndef``/``#define`` guard
                    and added as one more dependency. The parent constructor then
                    receives the return type, name, resolved parameters and CL code.
                    """
                    deps = _resolve_dependencies(template.dependencies)

                    if template.cl_extra:
                        guard = 'INCLUDE_GUARD_{}_EXTRA'.format(template.name)
                        extra_code = '''
                            #ifndef {inclusion_guard_name}
                            #define {inclusion_guard_name}
                            {cl_extra}
                            #endif // {inclusion_guard_name}
                        '''.format(inclusion_guard_name=guard,
                                   cl_extra=template.cl_extra)
                        deps.append(SimpleCLCodeObject(extra_code))

                    # Parameters are resolved only after the dependency handling, as before.
                    params = _resolve_parameters(template.parameters)
                    super().__init__(template.return_type, template.name, params,
                                     template.cl_code, dependencies=deps)
Пример #3
0
Файл: base.py Проект: 42n4/MOT
    def _get_compute_func(self, nmr_samples: int, thinning: int, return_output: bool):
        """Get the MCMC algorithm as a computable function.

        Generates the OpenCL source of a ``compute`` function and wraps it in a
        CL function object. The function's loop and output behavior:

        * It loops ``nmr_iterations`` times, calling ``_advanceSampler`` with the
          global iteration index ``i + iteration_offset``.
        * When ``return_output`` is set, the first work item (local id 0) stores
          the current chain position, log-likelihood and log-prior. It does this
          on every iteration where ``i % thinning == 0``, at sample index
          ``i / thinning``.

        The RNG is seeded from the first six words of ``rng_state``. After the
        loop, the first work item writes the updated six words back.

        Args:
            nmr_samples (int): the number of samples we will draw. It is also the
                stride between parameters in the ``samples`` buffer: sample ``s`` of
                parameter ``j`` is written to ``samples[s + j * nmr_samples]``.
            thinning (int): the thinning factor we want to use. It is used as a
                divisor in the generated code; presumably the caller guarantees
                ``thinning >= 1`` (TODO confirm).
            return_output (boolean): if the kernel should return output. When
                False, the ``samples``, ``log_likelihoods`` and ``log_priors``
                arguments are left out of the signature and nothing is stored.

        Returns:
            mot.lib.cl_function.CLFunction: the compute function
        """
        # Function header and RNG setup. The three output-buffer arguments are
        # spliced into the signature only when return_output is set.
        cl_func = '''
            void compute(global uint* rng_state, 
                         global mot_float_type* current_chain_position,
                         global mot_float_type* current_log_likelihood,
                         global mot_float_type* current_log_prior,
                         ulong iteration_offset, 
                         ulong nmr_iterations, 
                         ''' + ('''global mot_float_type* samples, 
                                   global mot_float_type* log_likelihoods,
                                   global mot_float_type* log_priors,'''
                                if return_output else '') + '''
                         void* method_data, 
                         void* data){
                         
                bool is_first_work_item = get_local_id(0) == 0;
    
                rand123_data rand123_rng_data = rand123_initialize_data((uint[]){
                    rng_state[0], rng_state[1], rng_state[2], rng_state[3], 
                    rng_state[4], rng_state[5], 0, 0});
                void* rng_data = (void*)&rand123_rng_data;
        
                for(ulong i = 0; i < nmr_iterations; i++){
        '''
        if return_output:
            # Store the current state before advancing, on every thinning-th
            # iteration, from the first work item only. The samples buffer is
            # parameter-major (stride nmr_samples between parameters).
            # NOTE(review): thinning is applied to the local loop index i, not
            # i + iteration_offset; this presumably assumes each call starts on a
            # multiple of thinning — confirm with the caller.
            cl_func += '''
                    if(is_first_work_item){
                        if(i % ''' + str(thinning) + ''' == 0){
                            log_likelihoods[i / ''' + str(
                thinning) + '''] = *current_log_likelihood;
                            log_priors[i / ''' + str(
                    thinning) + '''] = *current_log_prior;
    
                            for(uint j = 0; j < ''' + str(
                        self._nmr_params) + '''; j++){
                                samples[(ulong)(i / ''' + str(
                            thinning) + ''') // remove the interval
                                        + j * ''' + str(
                                nmr_samples) + '''  // parameter index
                                ] = current_chain_position[j];
                            }
                        }
                    }
        '''
        # Advance the chain each iteration, then close the loop. Afterwards the
        # first work item copies the first six words of the updated RNG state
        # back into rng_state.
        cl_func += '''
                    _advanceSampler(method_data, data, i + iteration_offset, rng_data, 
                                    current_chain_position, current_log_likelihood, current_log_prior);
                }

                if(is_first_work_item){
                    uint state[8];
                    rand123_data_to_array(rand123_rng_data, state);
                    for(uint i = 0; i < 6; i++){
                        rng_state[i] = state[i];
                    }
                }
            }
        '''
        # Presumably Rand123 provides the rand123_* helpers and the state-update
        # code provides _advanceSampler; neither is defined in cl_func itself.
        return SimpleCLFunction.from_string(
            cl_func,
            dependencies=[
                Rand123(),
                self._get_log_prior_cl_func(),
                self._get_log_likelihood_cl_func(),
                SimpleCLCodeObject(
                    self._get_state_update_cl_func(nmr_samples, thinning,
                                                   return_output))
            ])