def matrix_funptr(form, state):
    """Compile ``form`` into per-integral PyOP2 callbacks for local assembly.

    Every integral of ``form`` is compiled with TSFC and wrapped in a
    sequential PyOP2 ``JITModule`` whose iteration set is
    ``op2.Subset(iterset, [0])``.  The first kernel argument is a
    ``LocalMat`` incremented through a zero-valued entity->node map; the
    mesh coordinates, optional cell orientations / cell sizes, the form
    coefficients and (for interior facets) the local facet numbers follow.

    :arg form: a form whose test and trial function spaces are equal.
    :arg state: a coefficient of ``form`` to compile unsplit
        (``make_builder(dont_split=(state,))``) and pass through a
        ``LocalDat`` argument instead of its own dats, or ``None``.
    :returns: a 2-tuple ``(cell_kernels, int_facet_kernels)`` of lists of
        ``CompiledKernel(mod._fun, kinfo)`` objects.
    :raises NotImplementedError: if the test and trial spaces differ, or if
        an integral is not a full-domain ("otherwise") cell or interior
        facet integral.
    """
    from firedrake.tsfc_interface import compile_form

    def append_arg(args, arg):
        """Give ``arg`` the next kernel argument position and append it."""
        arg.position = len(args)
        args.append(arg)

    test, trial = map(operator.methodcaller("function_space"), form.arguments())
    if test != trial:
        raise NotImplementedError("Only for matching test and trial spaces")

    if state is not None:
        interface = make_builder(dont_split=(state, ))
    else:
        interface = None

    compiled = compile_form(form, "subspace_form", split=False, interface=interface)

    cell_kernels = []
    int_facet_kernels = []
    for kernel in compiled:
        kinfo = kernel.kinfo

        if kinfo.subdomain_id != "otherwise":
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet"}:
            raise NotImplementedError(
                "Only for cell or interior facet integrals")

        # OK, now we've validated the kernel, let's build the callback.
        # After the checks above only these two integral types remain.
        if kinfo.integral_type == "cell":
            get_map = operator.methodcaller("cell_node_map")
            target = cell_kernels
        else:
            get_map = operator.methodcaller("interior_facet_node_map")
            target = int_facet_kernels

        args = []
        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        arity = sum(m.arity * s.cdim
                    for m, s in zip(get_map(test), test.dof_dset))
        iterset = get_map(test).iterset
        entity_node_map = op2.Map(iterset, toset, arity,
                                  values=numpy.zeros(iterset.total_size * arity,
                                                     dtype=IntType))
        mat = LocalMat(dofset)
        # Position 0: the local matrix being incremented.
        append_arg(args, mat(op2.INC, (entity_node_map, entity_node_map)))

        # Built up front; only appended if ``state`` is among the coefficients.
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset, toset, arity,
                                        values=numpy.zeros(iterset.total_size * arity,
                                                           dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)

        mesh = form.ufl_domains()[kinfo.domain_number]
        # Position 1: coordinates of the integral's own domain.
        append_arg(args, mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)))
        # Use the same domain as the coordinates (form.ufl_domain() would
        # raise for multi-domain forms and is this mesh otherwise).
        if kinfo.oriented:
            c = mesh.cell_orientations()
            append_arg(args, c.dat(op2.READ, get_map(c)))
        if kinfo.needs_cell_sizes:
            c = mesh.cell_sizes
            append_arg(args, c.dat(op2.READ, get_map(c)))
        for n in kinfo.coefficient_map:
            c = form.coefficients()[n]
            if c is state:
                append_arg(args, statearg)
                continue
            for c_ in c.split():
                append_arg(args, c_.dat(op2.READ, get_map(c_)))

        if kinfo.integral_type == "interior_facet":
            append_arg(args, test.ufl_domain().interior_facets.local_facet_dat(op2.READ))
        iterset = op2.Subset(iterset, [0])
        mod = seq.JITModule(kinfo.kernel, iterset, *args)
        target.append(CompiledKernel(mod._fun, kinfo))
    return cell_kernels, int_facet_kernels
def generate_code(self):
    """Assemble the code dictionary for the tiled executor wrapper.

    The returned dict has the keys ``wrapper_name``, ``executor_arg``,
    ``wrapper_args``, ``wrapper_decs``, ``rank``, ``region_flag``,
    ``user_code``, ``ssinds_arg`` and ``executor_code``.  For each fused
    loop (zipped from ``self._all_kernels``, ``self._all_itspaces`` and
    ``self._all_args``), the per-loop dictionary produced by
    ``sequential.JITModule(..., delay=True).generate_code()`` is extended and
    rendered through ``TilingJITModule._kernel_wrapper``; the concatenated
    loop bodies are then passed to ``self._executor.c_code``.

    Side effect: each loop argument whose ``data`` and ``map`` are identical
    (``is``) to those of an argument in ``self._args`` gets ``ref_arg`` set to
    that executor argument.
    """
    def indent(text, level):
        """Rejoin the lines of ``text`` with ``' ' * level`` after each newline."""
        return ('\n' + ' ' * level).join(text.split('\n'))

    def prefetch(addr):
        """Return a C ``_mm_prefetch`` call (hint ``_MM_HINT_T0``) for ``addr``."""
        return '_mm_prefetch ((char*)(%s), _MM_HINT_T0)' % addr

    meta = slope.Executor.meta

    # 1) Construct the wrapper arguments
    code_dict = {}
    code_dict['wrapper_name'] = 'wrap_executor'
    code_dict['executor_arg'] = "%s %s" % (meta['ctype_exec'], meta['name_param_exec'])
    code_dict['wrapper_args'] = ', '.join([arg.c_wrapper_arg() for arg in self._args])
    code_dict['wrapper_decs'] = indent(';\n'.join([arg.c_wrapper_dec() for arg in self._args]), 1)
    code_dict['rank'] = ", %s %s" % (meta['ctype_rank'], meta['rank'])
    code_dict['region_flag'] = ", %s %s" % (meta['ctype_region_flag'], meta['region_flag'])

    # 2) Construct the kernel invocations
    loop_bodies, user_codes, ssinds_decls = [], [], []
    loops = zip(self._all_kernels, self._all_itspaces, self._all_args)
    # For each kernel ...
    for i, (kernel, it_space, args) in enumerate(loops):
        # ... bind the Executor's arguments to this kernel's arguments.
        # ref_arg must be set before c_arg_bindto() is called.
        # NOTE(review): if no executor argument matches, ref_arg is left as-is.
        bindings = []
        for a1 in args:
            for a2 in self._args:
                if a1.data is a2.data and a1.map is a2.map:
                    a1.ref_arg = a2
                    break
            bindings.append(a1.c_arg_bindto())

        # ... obtain the /code_dict/ as if it were not part of an Executor,
        # since bits of code generation can be reused
        loop_code_dict = sequential.JITModule(kernel, it_space, *args, delay=True).generate_code()

        # ... does the scatter use global or local maps ?
        gtl_direct = self._executor.gtl_maps[i]['DIRECT']
        if self._use_glb_maps:
            loop_code_dict['index_expr'] = '%s[n]' % gtl_direct
            prefetch_var = 'int p = %s[n + %d]' % (gtl_direct, self._use_prefetch)
        else:
            prefetch_var = 'int p = n + %d' % self._use_prefetch

        # ... add prefetch intrinsics, if requested
        prefetch_maps, prefetch_vecs = '', ''
        if self._use_prefetch:
            indirect = [a for a in args if a._is_indirect]
            # can save some instructions since prefetching targets chunks of
            # 32 bytes: keep every other map entry (positions 0, 2, 4, ...).
            # Slicing replaces a list.index() filter that was quadratic and
            # mis-selected entries when a map entry string repeated.
            map_entries = flatten(a.c_map_entry('p')[::2] for a in indirect)
            map_entries = list(OrderedDict.fromkeys(map_entries))  # dedupe, keep order
            prefetch_maps = ';\n'.join([prefetch_var] +
                                       [prefetch('&(%s)' % e) for e in map_entries])
            vec_entries = flatten(a.c_vec_entry('p', True) for a in indirect)
            prefetch_vecs = ';\n'.join([prefetch(v) for v in vec_entries])
        loop_code_dict['prefetch_maps'] = prefetch_maps
        loop_code_dict['prefetch_vecs'] = prefetch_vecs

        # ... build the subset indirection array, if necessary
        ssind_arg, ssind_decl = '', ''
        if loop_code_dict['ssinds_arg']:
            ssind_arg = 'ssinds_%d' % i
            ssind_decl = 'int* %s' % ssind_arg
            loop_code_dict['index_expr'] = '%s[n]' % ssind_arg

        # ... use the proper function name (the function name of the kernel
        # within *this* specific loop chain)
        loop_code_dict['kernel_name'] = kernel._function_names[self._kernel.cache_key]

        # ... finish building up the /code_dict/
        loop_code_dict['args_binding'] = ";\n".join(bindings)
        loop_code_dict['tile_init'] = self._executor.c_loop_init[i]
        loop_code_dict['tile_finish'] = self._executor.c_loop_end[i]
        loop_code_dict['tile_start'] = meta['tile_start']
        loop_code_dict['tile_end'] = meta['tile_end']
        tile_iter = '%s[n]' % gtl_direct
        if ssind_arg:
            # Compose the subset indirection with the tile iteration index.
            tile_iter = '%s[%s]' % (ssind_arg, tile_iter)
        loop_code_dict['tile_iter'] = tile_iter

        # ... concatenate the rest, i.e., body, user code, ...
        loop_bodies.append(strip(TilingJITModule._kernel_wrapper % loop_code_dict))
        user_codes.append(kernel._user_code)
        ssinds_decls.append(ssind_decl)

    loop_chain_body = indent("\n\n".join(loop_bodies), 2)
    code_dict['user_code'] = indent("\n".join(user_codes), 1)
    code_dict['ssinds_arg'] = "".join(["%s," % s for s in ssinds_decls if s])
    code_dict['executor_code'] = indent(self._executor.c_code(loop_chain_body), 1)
    return code_dict