def matrix_funptr(form, state):
    """Compile *form* into per-entity local matrix assembly callbacks.

    The form is compiled with TSFC (kernel prefix ``"subspace_form"``,
    ``split=False``). Each resulting kernel is wrapped in a PyOP2 sequential
    JIT module, and the compiled function (``mod._fun``) is kept together
    with the kernel's ``kinfo`` in a :class:`CompiledKernel`.

    Kernel argument order is:

    - 0: the local matrix (``op2.INC``)
    - 1: the mesh coordinates
    - then cell orientations and cell sizes, if the kernel needs them
    - then the coefficients, in ``kinfo.coefficient_map`` order, where
      *state* becomes a single local vector argument
    - finally, for interior facet integrals, the local facet numbers

    :arg form: a UFL form whose test and trial function spaces are equal.
    :arg state: a coefficient of *form* that is passed to the kernel as one
        unsplit local vector, or ``None``.
    :returns: a 2-tuple ``(cell_kernels, int_facet_kernels)`` of lists of
        :class:`CompiledKernel`.
    :raises NotImplementedError: if test and trial spaces differ, if an
        integral is restricted to a subdomain, or if an integral is neither
        a cell nor an interior facet integral.
    """
    from firedrake.tsfc_interface import compile_form

    test, trial = map(operator.methodcaller("function_space"), form.arguments())
    if test != trial:
        raise NotImplementedError("Only for matching test and trial spaces")

    # Ask the kernel builder not to split the state coefficient, so that it
    # can be replaced below by a single local vector argument.
    if state is not None:
        interface = make_builder(dont_split=(state, ))
    else:
        interface = None

    compiled = compile_form(form, "subspace_form", split=False, interface=interface)

    cell_kernels = []
    int_facet_kernels = []
    # integral type -> (entity-node map method name, output list)
    dispatch = {
        "cell": ("cell_node_map", cell_kernels),
        "interior_facet": ("interior_facet_node_map", int_facet_kernels),
    }
    coefficients = form.coefficients()

    def append_arg(args, arg):
        """Append *arg* to *args*, recording its kernel argument index."""
        arg.position = len(args)
        args.append(arg)

    for kernel in compiled:
        kinfo = kernel.kinfo

        if kinfo.subdomain_id != "otherwise":
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet"}:
            raise NotImplementedError("Only for cell or interior facet integrals")

        # OK, now we've validated the kernel, let's build the callback
        map_name, target = dispatch[kinfo.integral_type]
        get_map = operator.methodcaller(map_name)

        # One-point target set with zero-filled maps: these only give the
        # generated wrapper the right argument shapes.
        # NOTE(review): presumably the caller supplies the real indirection
        # when invoking the compiled function -- confirm.
        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        test_map = get_map(test)
        arity = sum(m.arity*s.cdim for m, s in zip(test_map, test.dof_dset))
        iterset = test_map.iterset
        entity_node_map = op2.Map(iterset, toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))

        args = []
        mat = LocalMat(dofset)
        append_arg(args, mat(op2.INC, (entity_node_map, entity_node_map)))  # position 0

        # Built up front; only inserted if *state* appears among the coefficients.
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset, toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)

        mesh = form.ufl_domains()[kinfo.domain_number]
        append_arg(args, mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)))  # position 1
        if kinfo.oriented:
            c = form.ufl_domain().cell_orientations()
            append_arg(args, c.dat(op2.READ, get_map(c)))
        if kinfo.needs_cell_sizes:
            c = form.ufl_domain().cell_sizes
            append_arg(args, c.dat(op2.READ, get_map(c)))
        for n in kinfo.coefficient_map:
            c = coefficients[n]
            if c is state:
                append_arg(args, statearg)
                continue
            for c_ in c.split():
                append_arg(args, c_.dat(op2.READ, get_map(c_)))

        if kinfo.integral_type == "interior_facet":
            append_arg(args, test.ufl_domain().interior_facets.local_facet_dat(op2.READ))

        # NOTE(review): the wrapper is generated over a one-entity subset;
        # presumably only the per-entity compiled function is used -- confirm.
        mod = seq.JITModule(kinfo.kernel, op2.Subset(iterset, [0]), *args)
        target.append(CompiledKernel(mod._fun, kinfo))
    return cell_kernels, int_facet_kernels
def matrix_funptr(form, state):
    """Compile *form* into per-entity local matrix assembly callbacks.

    The form is compiled with TSFC (kernel prefix ``"subspace_form"``,
    ``split=False``). Each resulting kernel is wrapped in a PyOP2 sequential
    JIT module, and the compiled function (``mod._fun``) is kept together
    with the kernel's ``kinfo`` in a :class:`CompiledKernel`.

    Kernel argument order is:

    - 0: the local matrix (``op2.INC``)
    - 1: the mesh coordinates
    - then cell orientations and cell sizes, if the kernel needs them
    - then the coefficients, in ``kinfo.coefficient_map`` order, where
      *state* becomes a single local vector argument
    - finally, for interior facet integrals, the local facet numbers

    :arg form: a UFL form whose test and trial function spaces are equal.
    :arg state: a coefficient of *form* that is passed to the kernel as one
        unsplit local vector, or ``None``.
    :returns: a 2-tuple ``(cell_kernels, int_facet_kernels)`` of lists of
        :class:`CompiledKernel`.
    :raises NotImplementedError: if test and trial spaces differ, if an
        integral is restricted to a subdomain, or if an integral is neither
        a cell nor an interior facet integral.
    """
    from firedrake.tsfc_interface import compile_form

    test, trial = map(operator.methodcaller("function_space"), form.arguments())
    if test != trial:
        raise NotImplementedError("Only for matching test and trial spaces")

    # Ask the kernel builder not to split the state coefficient, so that it
    # can be replaced below by a single local vector argument.
    if state is not None:
        interface = make_builder(dont_split=(state, ))
    else:
        interface = None

    compiled = compile_form(form, "subspace_form", split=False, interface=interface)

    cell_kernels = []
    int_facet_kernels = []
    # integral type -> (entity-node map method name, output list)
    dispatch = {
        "cell": ("cell_node_map", cell_kernels),
        "interior_facet": ("interior_facet_node_map", int_facet_kernels),
    }
    coefficients = form.coefficients()

    def append_arg(args, arg):
        """Append *arg* to *args*, recording its kernel argument index."""
        arg.position = len(args)
        args.append(arg)

    for kernel in compiled:
        kinfo = kernel.kinfo

        if kinfo.subdomain_id != "otherwise":
            raise NotImplementedError("Only for full domain integrals")
        if kinfo.integral_type not in {"cell", "interior_facet"}:
            raise NotImplementedError("Only for cell or interior facet integrals")

        # OK, now we've validated the kernel, let's build the callback
        map_name, target = dispatch[kinfo.integral_type]
        get_map = operator.methodcaller(map_name)

        # One-point target set with zero-filled maps: these only give the
        # generated wrapper the right argument shapes.
        # NOTE(review): presumably the caller supplies the real indirection
        # when invoking the compiled function -- confirm.
        toset = op2.Set(1, comm=test.comm)
        dofset = op2.DataSet(toset, 1)
        test_map = get_map(test)
        arity = sum(m.arity*s.cdim for m, s in zip(test_map, test.dof_dset))
        iterset = test_map.iterset
        entity_node_map = op2.Map(iterset, toset, arity,
                                  values=numpy.zeros(iterset.total_size*arity, dtype=IntType))

        args = []
        mat = LocalMat(dofset)
        append_arg(args, mat(op2.INC, (entity_node_map, entity_node_map)))  # position 0

        # Built up front; only inserted if *state* appears among the coefficients.
        statedat = LocalDat(dofset)
        state_entity_node_map = op2.Map(iterset, toset, arity,
                                        values=numpy.zeros(iterset.total_size*arity, dtype=IntType))
        statearg = statedat(op2.READ, state_entity_node_map)

        mesh = form.ufl_domains()[kinfo.domain_number]
        append_arg(args, mesh.coordinates.dat(op2.READ, get_map(mesh.coordinates)))  # position 1
        if kinfo.oriented:
            c = form.ufl_domain().cell_orientations()
            append_arg(args, c.dat(op2.READ, get_map(c)))
        # Kernels that need cell sizes take them right after the orientations.
        # Omitting this argument would shift every following argument.
        # getattr keeps older kinfo objects without this field working.
        if getattr(kinfo, "needs_cell_sizes", False):
            c = form.ufl_domain().cell_sizes
            append_arg(args, c.dat(op2.READ, get_map(c)))
        for n in kinfo.coefficient_map:
            c = coefficients[n]
            if c is state:
                append_arg(args, statearg)
                continue
            for c_ in c.split():
                append_arg(args, c_.dat(op2.READ, get_map(c_)))

        if kinfo.integral_type == "interior_facet":
            append_arg(args, test.ufl_domain().interior_facets.local_facet_dat(op2.READ))

        # NOTE(review): the wrapper is generated over a one-entity subset;
        # presumably only the per-entity compiled function is used -- confirm.
        mod = seq.JITModule(kinfo.kernel, op2.Subset(iterset, [0]), *args)
        target.append(CompiledKernel(mod._fun, kinfo))
    return cell_kernels, int_facet_kernels