def initialise(self, solver, stream=None):
    """Compile the EKBSqrt Jones kernel and cache its handles on ``self``.

    Builds the substitution dictionary for ``KERNEL_TEMPLATE`` from
    ``solver``, compiles the substituted source with ``SourceModule`` and
    stores the compiled module, its ``C`` global, the kernel function and
    the launch parameters.

    Args:
        solver: Solver supplying the local ``ntime``/``na``/``npolchan``
            sizes, the template property dictionary, the constant data
            struct definition and the float/double precision flag.
        stream: Accepted for interface compatibility; not used here.

    Side effects:
        Sets ``self.mod``, ``self.rime_const_data``, ``self.kernel`` and
        ``self.launch_params``.
    """
    slvr = solver
    ntime, na, npolchan = slvr.dim_local_size('ntime', 'na', 'npolchan')

    # Evaluate the precision once and use its truthiness for every
    # precision-dependent choice (kernel parameters, register cap and
    # kernel name). Testing ``is True`` for the kernel name alone would
    # pair float parameters with the double kernel whenever is_float()
    # returns a truthy value that is not the ``True`` singleton
    # (e.g. a numpy bool).
    use_float = bool(slvr.is_float())
    params = FLOAT_PARAMS if use_float else DOUBLE_PARAMS

    # Get a property dictionary off the solver and overlay our kernel
    # parameters on it.
    D = slvr.template_dict()
    D.update(params)
    D['rime_const_data_struct'] = slvr.const_data().string_def()

    # Presumably fits the thread-block dimensions to the
    # (npolchan, na, ntime) extents -- see mbu.redistribute_threads.
    D['BLOCKDIMX'], D['BLOCKDIMY'], D['BLOCKDIMZ'] = \
        mbu.redistribute_threads(
            D['BLOCKDIMX'], D['BLOCKDIMY'], D['BLOCKDIMZ'],
            npolchan, na, ntime)

    # Register limit handed to the compiler through -maxrregcount.
    regs = str(params['maxregs'])

    kname = ('rime_jones_EKBSqrt_float' if use_float
             else 'rime_jones_EKBSqrt_double')

    kernel_string = KERNEL_TEMPLATE.substitute(**D)
    # no_extern_c=True: the source is not wrapped in ``extern "C"``, so the
    # template presumably declares its own unmangled entry points -- confirm
    # against the kernel source.
    self.mod = SourceModule(kernel_string,
        options=['-lineinfo', '-maxrregcount', regs],
        include_dirs=[montblanc.get_source_path()],
        no_extern_c=True)

    self.rime_const_data = self.mod.get_global('C')
    self.kernel = self.mod.get_function(kname)
    self.launch_params = self.get_launch_params(slvr, D)
def initialise(self, solver, stream=None):
    """Compile the ``rime_sum_coherencies`` kernel and cache its handles.

    The substitution dictionary for ``KERNEL_TEMPLATE`` is assembled from:
    the solver's template properties, the precision-specific kernel
    parameters, the constant data struct definition, the redistributed
    block dimensions and the text of the
    ``stamp_sum_coherencies_fn(...)`` macro call.

    Args:
        solver: Solver supplying the local ``ntime``/``nbl``/``npolchan``
            sizes, template properties, constant data layout, precision,
            weight-vector and residual-output flags.
        stream: Accepted for interface compatibility; not used here.

    Side effects:
        Sets ``self.mod``, ``self.rime_const_data``, ``self.kernel`` and
        ``self.launch_params``.
    """
    slvr = solver
    ntime, nbl, npolchan = slvr.dim_local_size('ntime', 'nbl', 'npolchan')

    use_float = slvr.is_float()
    precision_params = FLOAT_PARAMS if use_float else DOUBLE_PARAMS

    template_vars = slvr.template_dict()
    template_vars.update(precision_params)
    template_vars['rime_const_data_struct'] = \
        slvr.const_data().string_def()

    block_x, block_y, block_z = mbu.redistribute_threads(
        template_vars['BLOCKDIMX'],
        template_vars['BLOCKDIMY'],
        template_vars['BLOCKDIMZ'],
        npolchan, nbl, ntime)
    template_vars['BLOCKDIMX'] = block_x
    template_vars['BLOCKDIMY'] = block_y
    template_vars['BLOCKDIMZ'] = block_z

    max_regs = str(precision_params['maxregs'])

    # Arguments of the stamping macro call: three precision-specific type
    # names, then the weight-vector flag and the residual-output flag.
    if use_float:
        macro_args = ['float', 'float2', 'float3']
    else:
        macro_args = ['double', 'double2', 'double3']
    macro_args.append('true' if slvr.use_weight_vector() else 'false')
    macro_args.append('1' if slvr.outputs_residuals() else '0')
    template_vars['stamp_function'] = (
        'stamp_sum_coherencies_fn(' + ', '.join(macro_args) + ')')

    source = KERNEL_TEMPLATE.substitute(**template_vars)
    compile_options = ['-lineinfo', '-maxrregcount', max_regs]
    self.mod = SourceModule(source,
        options=compile_options,
        include_dirs=[montblanc.get_source_path()],
        no_extern_c=True)

    self.rime_const_data = self.mod.get_global('C')
    self.kernel = self.mod.get_function('rime_sum_coherencies')
    self.launch_params = self.get_launch_params(slvr, template_vars)
def initialise(self, solver, stream=None):
    """Prepare and compile the sum-coherencies CUDA kernel for ``solver``.

    Fills in ``KERNEL_TEMPLATE`` with the solver's template properties, the
    float or double kernel parameters, the constant data struct definition,
    the redistributed block dimensions and a
    ``stamp_sum_coherencies_fn(...)`` call. It then compiles the result and
    records the module, the ``C`` global, the ``rime_sum_coherencies``
    function and the launch parameters on ``self``.

    Args:
        solver: Solver providing dimension sizes, template properties,
            constant data, precision and the weight/residual flags.
        stream: Accepted for interface compatibility; not used here.
    """
    slvr = solver

    # Resolve every precision-dependent choice in one place.
    if slvr.is_float():
        kernel_params = FLOAT_PARAMS
        real_t, real2_t, real3_t = 'float', 'float2', 'float3'
    else:
        kernel_params = DOUBLE_PARAMS
        real_t, real2_t, real3_t = 'double', 'double2', 'double3'

    ntime, nbl, npolchan = slvr.dim_local_size('ntime', 'nbl', 'npolchan')

    tmpl = slvr.template_dict()
    tmpl.update(kernel_params)
    tmpl['rime_const_data_struct'] = slvr.const_data().string_def()

    dims = mbu.redistribute_threads(
        tmpl['BLOCKDIMX'], tmpl['BLOCKDIMY'], tmpl['BLOCKDIMZ'],
        npolchan, nbl, ntime)
    tmpl['BLOCKDIMX'], tmpl['BLOCKDIMY'], tmpl['BLOCKDIMZ'] = dims

    max_regs = str(kernel_params['maxregs'])

    # Text of the call to the function-stamping macro.
    weighted = 'true' if slvr.use_weight_vector() else 'false'
    residuals = '1' if slvr.outputs_residuals() else '0'
    stamp_args = ', '.join([real_t, real2_t, real3_t, weighted, residuals])
    tmpl['stamp_function'] = ''.join(
        ['stamp_sum_coherencies_fn(', stamp_args, ')'])

    self.mod = SourceModule(
        KERNEL_TEMPLATE.substitute(**tmpl),
        options=['-lineinfo', '-maxrregcount', max_regs],
        include_dirs=[montblanc.get_source_path()],
        no_extern_c=True)

    self.rime_const_data = self.mod.get_global('C')
    self.kernel = self.mod.get_function('rime_sum_coherencies')
    self.launch_params = self.get_launch_params(slvr, tmpl)