def generate_loopy_kernel(slate_expr, tsfc_parameters=None):
    """Compile a Slate expression into a PyOP2 kernel via GEM and loopy.

    The expression is translated to GEM, lowered to loopy, merged with the
    loopy kernels produced by the local builder, and the resulting loopy
    program is rendered to C source which is wrapped in an
    :class:`op2.Kernel`.

    :arg slate_expr: a Slate tensor expression.
    :arg tsfc_parameters: a `dict` of form compiler parameters; it must
        contain a ``"scalar_type"`` entry. Despite the ``None`` default a
        parameter dict is required.

    Returns: a one-element `tuple` containing a `SplitKernel(idx, kinfo)`.

    :raises NotImplementedError: if the expression spans several domains.
    :raises TypeError: if ``tsfc_parameters`` is ``None``.
    """
    cpu_time = time.time()
    if len(slate_expr.ufl_domains()) > 1:
        raise NotImplementedError("Multiple domains not implemented.")
    if tsfc_parameters is None:
        # The scalar type is read from the parameters below; fail with a
        # clear message instead of "'NoneType' object is not subscriptable".
        raise TypeError("generate_loopy_kernel requires tsfc_parameters "
                        "containing a 'scalar_type' entry")

    Citations().register("Gibson2018")
    # Create a loopy builder for the Slate expression,
    # e.g. contains the loopy kernels coming from TSFC
    gem_expr, var2terminal = slate_to_gem(slate_expr)

    scalar_type = tsfc_parameters["scalar_type"]
    slate_loopy, output_arg = gem_to_loopy(gem_expr, var2terminal, scalar_type)

    builder = LocalLoopyKernelBuilder(expression=slate_expr,
                                      tsfc_parameters=tsfc_parameters)

    loopy_merged = merge_loopy(slate_loopy, output_arg, builder, var2terminal)
    # Make the inverse and solve callables resolvable inside the kernel.
    loopy_merged = loopy.register_function_id_to_in_knl_callable_mapper(
        loopy_merged, inv_fn_lookup)
    loopy_merged = loopy.register_function_id_to_in_knl_callable_mapper(
        loopy_merged, solve_fn_lookup)

    # WORKAROUND: Generate code directly from the loopy kernel here,
    # then attach code as a c-string to the op2kernel.
    # The top-level function is additionally declared `static`.
    code = loopy.generate_code_v2(loopy_merged).device_code()
    code = code.replace(f'void {loopy_merged.name}',
                        f'static void {loopy_merged.name}')
    loopykernel = op2.Kernel(code, loopy_merged.name,
                             include_dirs=BLASLAPACK_INCLUDE.split(),
                             ldargs=BLASLAPACK_LIB.split())

    kinfo = KernelInfo(
        kernel=loopykernel,
        # slate can only do things as contributions to the cell integrals
        integral_type="cell",
        oriented=builder.bag.needs_cell_orientations,
        subdomain_id="otherwise",
        domain_number=0,
        coefficient_map=tuple(range(len(slate_expr.coefficients()))),
        needs_cell_facets=builder.bag.needs_cell_facets,
        pass_layer_arg=builder.bag.needs_mesh_layers,
        needs_cell_sizes=builder.bag.needs_cell_sizes)

    # Slate kernels are never split, so indicate that with None in the index slot.
    idx = tuple([None] * slate_expr.rank)
    logger.info(GREEN % "compile_slate_expression finished in %g seconds.",
                time.time() - cpu_time)
    return (SplitKernel(idx, kinfo), )
def generate_loopy_kernel(slate_expr, tsfc_parameters=None):
    """Compile a Slate expression into a PyOP2 kernel via GEM and loopy.

    The expression is translated to GEM, lowered to loopy, merged with the
    loopy kernels produced by the local builder into a wrapper named
    ``"slate_wrapper"``, and the loopy program is handed directly to
    :class:`op2.Kernel`.

    :arg slate_expr: a Slate tensor expression.
    :arg tsfc_parameters: a `dict` of form compiler parameters; it must
        contain a ``"scalar_type"`` entry. Despite the ``None`` default a
        parameter dict is required.

    Returns: a one-element `tuple` containing a `SplitKernel(idx, kinfo)`.

    :raises NotImplementedError: if the expression spans several domains.
    :raises TypeError: if ``tsfc_parameters`` is ``None``.
    """
    cpu_time = time.time()
    if len(slate_expr.ufl_domains()) > 1:
        raise NotImplementedError("Multiple domains not implemented.")
    if tsfc_parameters is None:
        # The scalar type is read from the parameters below; fail with a
        # clear message instead of "'NoneType' object is not subscriptable".
        raise TypeError("generate_loopy_kernel requires tsfc_parameters "
                        "containing a 'scalar_type' entry")

    Citations().register("Gibson2018")
    # Create a loopy builder for the Slate expression,
    # e.g. contains the loopy kernels coming from TSFC
    gem_expr, var2terminal = slate_to_gem(slate_expr)

    scalar_type = tsfc_parameters["scalar_type"]
    slate_loopy, output_arg = gem_to_loopy(gem_expr, var2terminal, scalar_type)

    builder = LocalLoopyKernelBuilder(expression=slate_expr,
                                      tsfc_parameters=tsfc_parameters)

    name = "slate_wrapper"
    loopy_merged = merge_loopy(slate_loopy, output_arg, builder, var2terminal, name)
    # Make the inverse and solve callables resolvable inside the kernel.
    loopy_merged = loopy.register_callable(loopy_merged, INVCallable.name, INVCallable())
    loopy_merged = loopy.register_callable(loopy_merged, SolveCallable.name, SolveCallable())

    loopykernel = op2.Kernel(loopy_merged, name,
                             include_dirs=BLASLAPACK_INCLUDE.split(),
                             ldargs=BLASLAPACK_LIB.split())

    kinfo = KernelInfo(
        kernel=loopykernel,
        # slate can only do things as contributions to the cell integrals
        integral_type="cell",
        oriented=builder.bag.needs_cell_orientations,
        subdomain_id="otherwise",
        domain_number=0,
        coefficient_map=tuple(range(len(slate_expr.coefficients()))),
        needs_cell_facets=builder.bag.needs_cell_facets,
        pass_layer_arg=builder.bag.needs_mesh_layers,
        needs_cell_sizes=builder.bag.needs_cell_sizes)

    # Slate kernels are never split, so indicate that with None in the index slot.
    idx = tuple([None] * slate_expr.rank)
    logger.info(GREEN % "compile_slate_expression finished in %g seconds.",
                time.time() - cpu_time)
    return (SplitKernel(idx, kinfo), )
def generate_kernel_ast(builder, statements, declared_temps):
    """Glue together the complete AST for the Slate expression contained
    in the :class:`LocalKernelBuilder` and wrap it as a PyOP2 kernel.

    The result tensor is exposed to Eigen through an ``Eigen::Map`` over the
    kernel's output pointer, and the Eigen C++ string for the whole Slate
    expression is added (``+=``) into it.

    :arg builder: The :class:`LocalKernelBuilder` containing
        all relevant expression information.
    :arg statements: A list of COFFEE objects containing all
        assembly calls and temporary declarations. This list is mutated in
        place: statements are appended, and when cell facets are needed a
        pointer cast is inserted at position 0.
    :arg declared_temps: A `dict` containing all previously
        declared temporaries. Its length is used to number the result
        symbols (``T<n>``/``A<n>``).

    Return: A `KernelInfo` object describing the complete AST.
    """
    slate_expr = builder.expression
    if slate_expr.rank == 0:
        # Scalars are treated as 1x1 MatrixBase objects
        shape = (1, )
    else:
        shape = slate_expr.shape

    # Now we create the result statement by declaring its eigen type and
    # using Eigen::Map to move between Eigen and C data structs.
    statements.append(ast.FlatBlock("/* Map eigen tensor into C struct */\n"))
    result_sym = ast.Symbol("T%d" % len(declared_temps))
    result_data_sym = ast.Symbol("A%d" % len(declared_temps))
    result_type = "Eigen::Map<%s >" % eigen_matrixbase_type(shape)
    # The output is received as a restrict-qualified scalar pointer.
    result = ast.Decl(ScalarType_c, ast.Symbol(result_data_sym),
                      pointers=[("restrict", )])
    result_statement = ast.FlatBlock(
        "%s %s((%s *)%s);\n" % (result_type, result_sym, ScalarType_c,
                                result_data_sym))
    statements.append(result_statement)

    # Generate the complete c++ string performing the linear algebra operations
    # on Eigen matrices/vectors
    statements.append(ast.FlatBlock("/* Linear algebra expression */\n"))
    cpp_string = ast.FlatBlock(slate_to_cpp(slate_expr, declared_temps))
    # Accumulate (+=) into the mapped result rather than assigning to it.
    statements.append(ast.Incr(result_sym, cpp_string))

    # Generate arguments for the macro kernel.
    # NOTE: the order in which arguments are appended below defines the C
    # signature, so it must match the order the caller passes them in.
    args = [
        result,
        ast.Decl(ScalarType_c, builder.coord_sym, pointers=[("restrict", )],
                 qualifiers=["const"])
    ]

    # Orientation information
    if builder.oriented:
        args.append(
            ast.Decl("int", builder.cell_orientations_sym,
                     pointers=[("restrict", )], qualifiers=["const"]))

    # Coefficient information: one const pointer per symbol the builder
    # reports for each coefficient of the expression.
    expr_coeffs = slate_expr.coefficients()
    for c in expr_coeffs:
        args.extend([
            ast.Decl(ScalarType_c, csym, pointers=[("restrict", )],
                     qualifiers=["const"])
            for csym in builder.coefficient(c)
        ])

    # Facet information
    if builder.needs_cell_facets:
        f_sym = builder.cell_facet_sym
        f_arg = ast.Symbol("arg_cell_facets")
        f_dtype = as_cstr(cell_to_facets_dtype)

        # cell_facets is locally a flattened 2-D array. We typecast here so we
        # can access its entries using standard array notation.
        # The cast is inserted at the very start of the statement list so it
        # precedes every statement that uses f_sym.
        cast = "%s (*%s)[2] = (%s (*)[2])%s;\n" % (f_dtype, f_sym, f_dtype, f_arg)
        statements.insert(0, ast.FlatBlock(cast))
        args.append(
            ast.Decl(f_dtype, f_arg, pointers=[("restrict", )],
                     qualifiers=["const"]))

    # NOTE: We need to be careful about the ordering here. Mesh layers are
    # added as the final argument to the kernel
    # and the amount of layers before that.
    # NOTE(review): when needs_cell_sizes is set, the cell-size argument is
    # appended *after* the layer arguments below, so the layer is then not
    # literally the final argument — confirm this matches the caller.
    if builder.needs_mesh_layers:
        args.append(
            ast.Decl("int", builder.mesh_layer_count_sym,
                     pointers=[("restrict", )], qualifiers=["const"]))
        args.append(ast.Decl("int", builder.mesh_layer_sym))

    # Cell size information
    if builder.needs_cell_sizes:
        args.append(
            ast.Decl(ScalarType_c, builder.cell_size_sym,
                     pointers=[("restrict", )], qualifiers=["const"]))

    # Macro kernel
    macro_kernel_name = "pyop2_kernel_compile_slate"
    stmts = ast.Block(statements)
    macro_kernel = ast.FunDecl("void", macro_kernel_name, args, stmts,
                               pred=["static", "inline"])

    # Construct the final ast: the builder's templated subkernels followed by
    # the macro kernel that calls them.
    kernel_ast = ast.Node(builder.templated_subkernels + [macro_kernel])

    # Now we wrap up the kernel ast as a PyOP2 kernel and include the
    # Eigen header files. The builder's include dirs are copied, not mutated.
    include_dirs = list(builder.include_dirs)
    include_dirs.append(EIGEN_INCLUDE_DIR)
    op2kernel = op2.Kernel(
        kernel_ast, macro_kernel_name, cpp=True, include_dirs=include_dirs,
        headers=['#include <Eigen/Dense>', '#define restrict __restrict'])

    # Flop count is the sum of the builder's expression and terminal flops.
    op2kernel.num_flops = builder.expression_flops + builder.terminal_flops
    # Send back a "TSFC-like" SplitKernel object with an
    # index and KernelInfo
    # NOTE(review): coefficient_map comes from slate_expr.coeff_map; its
    # exact structure is defined elsewhere — confirm against KernelInfo users.
    kinfo = KernelInfo(kernel=op2kernel, integral_type=builder.integral_type,
                       oriented=builder.oriented, subdomain_id="otherwise",
                       domain_number=0,
                       coefficient_map=slate_expr.coeff_map,
                       needs_cell_facets=builder.needs_cell_facets,
                       pass_layer_arg=builder.needs_mesh_layers,
                       needs_cell_sizes=builder.needs_cell_sizes)
    return kinfo
def generate_kernel_ast(builder, statements, declared_temps):
    """Glue together the complete AST for the Slate expression contained
    in the :class:`LocalKernelBuilder` and wrap it as a PyOP2 kernel.

    :arg builder: The :class:`LocalKernelBuilder` containing
        all relevant expression information.
    :arg statements: A list of COFFEE objects containing all
        assembly calls and temporary declarations. Result-mapping and
        linear-algebra statements are appended to it in place.
    :arg declared_temps: A `dict` containing all previously
        declared temporaries. Its length is used to number the result
        symbols (``T<n>``/``A<n>``).

    Return: A `KernelInfo` object describing the complete AST.
    """
    slate_expr = builder.expression
    if slate_expr.rank == 0:
        # Scalars are treated as 1x1 MatrixBase objects
        shape = (1, )
    else:
        shape = slate_expr.shape

    # Now we create the result statement by declaring its eigen type and
    # using Eigen::Map to move between Eigen and C data structs.
    statements.append(ast.FlatBlock("/* Map eigen tensor into C struct */\n"))
    result_sym = ast.Symbol("T%d" % len(declared_temps))
    result_data_sym = ast.Symbol("A%d" % len(declared_temps))
    result_type = "Eigen::Map<%s >" % eigen_matrixbase_type(shape)
    result = ast.Decl(SCALAR_TYPE, ast.Symbol(result_data_sym, shape))
    result_statement = ast.FlatBlock(
        "%s %s((%s *)%s);\n" % (result_type, result_sym, SCALAR_TYPE,
                                result_data_sym))
    statements.append(result_statement)

    # Generate the complete c++ string performing the linear algebra operations
    # on Eigen matrices/vectors
    statements.append(ast.FlatBlock("/* Linear algebra expression */\n"))
    cpp_string = ast.FlatBlock(
        metaphrase_slate_to_cpp(slate_expr, declared_temps))
    statements.append(ast.Incr(result_sym, cpp_string))

    # Generate arguments for the macro kernel. The append order below
    # defines the C signature and must match the caller's argument order.
    args = [result, ast.Decl("%s **" % SCALAR_TYPE, builder.coord_sym)]

    # Orientation information
    if builder.oriented:
        args.append(ast.Decl("int **", builder.cell_orientations_sym))

    # Coefficient information: Constants are passed as a single pointer,
    # everything else as a pointer-to-pointer.
    expr_coeffs = slate_expr.coefficients()
    for c in expr_coeffs:
        if isinstance(c, Constant):
            ctype = "%s *" % SCALAR_TYPE
        else:
            ctype = "%s **" % SCALAR_TYPE
        args.extend([ast.Decl(ctype, csym) for csym in builder.coefficient(c)])

    # Facet information
    if builder.needs_cell_facets:
        args.append(
            ast.Decl("%s *" % as_cstr(cell_to_facets_dtype),
                     builder.cell_facet_sym))

    # NOTE: We need to be careful about the ordering here. Mesh layers are
    # added as the final argument to the kernel.
    if builder.needs_mesh_layers:
        args.append(ast.Decl("int", builder.mesh_layer_sym))

    # Macro kernel
    macro_kernel_name = "compile_slate"
    stmts = ast.Block(statements)
    macro_kernel = ast.FunDecl("void", macro_kernel_name, args, stmts,
                               pred=["static", "inline"])

    # Construct the final ast
    kernel_ast = ast.Node(builder.templated_subkernels + [macro_kernel])

    # Now we wrap up the kernel ast as a PyOP2 kernel and include the
    # Eigen header files. Copy the builder's list first: extending it in
    # place would mutate builder state and re-append the Eigen dirs on
    # every call.
    include_dirs = list(builder.include_dirs)
    include_dirs.extend(["%s/include/eigen3/" % d for d in PETSC_DIR])
    op2kernel = op2.Kernel(
        kernel_ast, macro_kernel_name, cpp=True, include_dirs=include_dirs,
        headers=['#include <Eigen/Dense>', '#define restrict __restrict'])

    # Send back a "TSFC-like" SplitKernel object with an
    # index and KernelInfo
    kinfo = KernelInfo(kernel=op2kernel, integral_type=builder.integral_type,
                       oriented=builder.oriented, subdomain_id="otherwise",
                       domain_number=0,
                       coefficient_map=tuple(range(len(expr_coeffs))),
                       needs_cell_facets=builder.needs_cell_facets,
                       pass_layer_arg=builder.needs_mesh_layers)
    return kinfo
def compile_expression(slate_expr, tsfc_parameters=None):
    """Takes a SLATE expression `slate_expr` and returns the appropriate
    :class:`firedrake.op2.Kernel` object representing the SLATE expression.

    :arg slate_expr: a :class:'TensorBase' expression.
    :arg tsfc_parameters: an optional `dict` of form compiler parameters to
        be passed onto TSFC during the compilation of ufl forms.

    Returns: a one-element `tuple` containing a `SplitKernel(idx, kinfo)`.

    Raises: `ValueError` if `slate_expr` is not a `TensorBase`;
        `NotImplementedError` for mixed arguments, expressions not defined
        on exactly one domain, or unsupported integral types.
    """
    if not isinstance(slate_expr, TensorBase):
        raise ValueError(
            "Expecting a `slate.TensorBase` expression, not a %r" % slate_expr)

    # TODO: Get PyOP2 to write into mixed dats
    if any(len(a.function_space()) > 1 for a in slate_expr.arguments()):
        raise NotImplementedError("Compiling mixed slate expressions")

    # Validate the domain count up front instead of with a bare assert at
    # the end (asserts are stripped under ``python -O``).
    if len(slate_expr.ufl_domains()) != 1:
        raise NotImplementedError("No support for multiple domains yet!")

    # Scalars are treated as 1x1 MatrixBase objects; with the raw shape ``()``
    # the result would be declared as a plain scalar, yet it is cast and
    # mapped as a pointer below.
    if slate_expr.rank == 0:
        shape = (1, )
    else:
        shape = slate_expr.shape
    statements = []

    # Create a builder for the SLATE expression
    builder = KernelBuilder(expression=slate_expr,
                            tsfc_parameters=tsfc_parameters)

    # Initialize coordinate and facet symbols
    coordsym = ast.Symbol("coords")
    coords = None
    cellfacetsym = ast.Symbol("cell_facets")
    inc = []

    # Now we construct the list of statements to provide to the builder
    context_temps = builder.temps.copy()
    for exp, t in context_temps.items():
        # Declare and zero-initialize the temporary for this tensor.
        statements.append(ast.Decl(eigen_matrixbase_type(exp.shape), t))
        statements.append(ast.FlatBlock("%s.setZero();\n" % t))
        exp_coeffs = exp.coefficients()

        for splitkernel in builder.kernel_exprs[exp]:
            clist = []
            index = splitkernel.indices
            kinfo = splitkernel.kinfo
            integral_type = kinfo.integral_type

            if integral_type not in ["cell", "interior_facet", "exterior_facet"]:
                raise NotImplementedError(
                    "Integral type %s not currently supported." % integral_type)

            # All subkernels must share the same coordinate field.
            coordinates = exp.ufl_domain().coordinates
            if coords is not None:
                assert coordinates == coords
            else:
                coords = coordinates

            for cindex in kinfo.coefficient_map:
                c = exp_coeffs[cindex]
                # Handles both mixed and non-mixed coefficient cases
                clist.extend(builder.extract_coefficient(c))

            # Record include dirs once each, keeping first-seen order.
            for include_dir in kinfo.kernel._include_dirs:
                if include_dir not in inc:
                    inc.append(include_dir)
            tensor = eigen_tensor(exp, t, index)

            if integral_type in ["interior_facet", "exterior_facet"]:
                builder.require_cell_facets()
                itsym = ast.Symbol("i0")
                # The facet index is passed to the subkernel by address.
                clist.append(ast.FlatBlock("&%s" % itsym))
                nfacet = exp.ufl_domain().ufl_cell().num_facets()

                # Only call the subkernel on facets whose cell_facets entry
                # matches the integral type (1 for exterior, 0 for interior).
                checker = 1 if integral_type == "exterior_facet" else 0
                loop_body = [
                    ast.If(
                        ast.Eq(ast.Symbol(cellfacetsym, rank=(itsym, )), checker),
                        [ast.Block([ast.FunCall(kinfo.kernel.name, tensor,
                                                coordsym, *clist)],
                                   open_scope=True)])
                ]
                loop = ast.For(ast.Decl("unsigned int", itsym, init=0),
                               ast.Less(itsym, nfacet),
                               ast.Incr(itsym, 1), loop_body)
                statements.append(loop)
            else:
                statements.append(
                    ast.FunCall(kinfo.kernel.name, tensor, coordsym, *clist))

    # Now we handle any terms that require auxiliary data (if any)
    if builder.aux_exprs:
        aux_temps, aux_statements = auxiliary_information(builder)
        context_temps.update(aux_temps)
        statements.extend(aux_statements)

    # Map the kernel's output pointer into an Eigen object of the right shape.
    result_sym = ast.Symbol("T%d" % len(builder.temps))
    result_data_sym = ast.Symbol("A%d" % len(builder.temps))
    result_type = "Eigen::Map<%s >" % eigen_matrixbase_type(shape)
    result = ast.Decl(SCALAR_TYPE, ast.Symbol(result_data_sym, shape))
    result_statement = ast.FlatBlock(
        "%s %s((%s *)%s);\n" % (result_type, result_sym, SCALAR_TYPE,
                                result_data_sym))
    statements.append(result_statement)

    # Accumulate (+=) into the output rather than overwriting it.
    # NOTE(review): assumes the output is assembled with INC access, where an
    # overwrite would discard contributions already present (e.g. in a
    # rank-0 Global) — confirm against the caller.
    cpp_string = ast.FlatBlock(
        metaphrase_slate_to_cpp(slate_expr, context_temps))
    statements.append(ast.Incr(result_sym, cpp_string))

    # Generate arguments for the macro kernel. Constants are passed as a
    # single pointer, other coefficients as a pointer-to-pointer.
    args = [result, ast.Decl("%s **" % SCALAR_TYPE, coordsym)]
    for c in slate_expr.coefficients():
        if isinstance(c, Constant):
            ctype = "%s *" % SCALAR_TYPE
        else:
            ctype = "%s **" % SCALAR_TYPE
        args.extend([
            ast.Decl(ctype, sym_c) for sym_c in builder.extract_coefficient(c)
        ])

    if builder.needs_cell_facets:
        args.append(ast.Decl("char *", cellfacetsym))

    macro_kernel_name = "compile_slate"
    kernel_ast, oriented = builder.construct_ast(
        name=macro_kernel_name, args=args, statements=ast.Block(statements))

    inc.extend(["%s/include/eigen3/" % d for d in PETSC_DIR])
    op2kernel = op2.Kernel(
        kernel_ast, macro_kernel_name, cpp=True, include_dirs=inc,
        headers=['#include <Eigen/Dense>', '#define restrict __restrict'])

    kinfo = KernelInfo(kernel=op2kernel, integral_type="cell",
                       oriented=oriented, subdomain_id="otherwise",
                       domain_number=0,
                       coefficient_map=tuple(range(len(slate_expr.coefficients()))),
                       needs_cell_facets=builder.needs_cell_facets)
    idx = tuple([0] * slate_expr.rank)
    return (SplitKernel(idx, kinfo), )
def compile_expression(slate_expr, tsfc_parameters=None):
    """Takes a Slate expression `slate_expr` and returns the appropriate
    :class:`firedrake.op2.Kernel` object representing the Slate expression.

    :arg slate_expr: a :class:'TensorBase' expression.
    :arg tsfc_parameters: an optional `dict` of form compiler parameters to
        be passed onto TSFC during the compilation of ufl forms.

    Returns: A `tuple` containing a `SplitKernel(idx, kinfo)`. The tuple is
        also stored on ``slate_expr._metakernel_cache`` and returned directly
        on subsequent calls.

    Raises: `ValueError` if `slate_expr` is not a `TensorBase`;
        `NotImplementedError` for mixed arguments, expressions not defined on
        exactly one domain, or unsupported integral types.
    """
    if not isinstance(slate_expr, TensorBase):
        raise ValueError("Expecting a `TensorBase` expression, not %s" % type(slate_expr))

    # TODO: Get PyOP2 to write into mixed dats
    if any(len(a.function_space()) > 1 for a in slate_expr.arguments()):
        raise NotImplementedError("Compiling mixed slate expressions")

    # Validate up front instead of with a bare assert after all the work
    # (asserts are stripped under ``python -O``).
    if len(slate_expr.ufl_domains()) != 1:
        raise NotImplementedError("No support for multiple domains yet!")

    # If the expression has already been symbolically compiled, then
    # simply reuse the produced kernel.
    # NOTE(review): the cache is not keyed on tsfc_parameters, so a later
    # call with different parameters receives the kernel built by the first.
    if slate_expr._metakernel_cache is not None:
        return slate_expr._metakernel_cache

    # Initialize coefficients, shape and statements list
    expr_coeffs = slate_expr.coefficients()

    # We treat scalars as 1x1 MatrixBase objects, so we give
    # the right shape to do so and everything just falls out.
    # This bit here ensures the return result has the right
    # shape
    if slate_expr.rank == 0:
        shape = (1, )
    else:
        shape = slate_expr.shape

    statements = []

    # Create a builder for the Slate expression
    builder = KernelBuilder(expression=slate_expr,
                            tsfc_parameters=tsfc_parameters)

    # Initialize coordinate, cell orientations and facet/layer
    # symbols
    coordsym = ast.Symbol("coords")
    coords = None
    cell_orientations = ast.Symbol("cell_orientations")
    cellfacetsym = ast.Symbol("cell_facets")
    mesh_layer_sym = ast.Symbol("layer")
    inc = []

    # We keep track of temporaries that have been declared
    declared_temps = {}
    for cxt_kernel in builder.context_kernels:
        exp = cxt_kernel.tensor
        t = builder.temps[exp]

        if exp not in declared_temps:
            # Declare and initialize the temporary
            statements.append(ast.Decl(eigen_matrixbase_type(exp.shape), t))
            statements.append(ast.FlatBlock("%s.setZero();\n" % t))
            declared_temps[exp] = t

        it_type = cxt_kernel.original_integral_type

        if it_type not in supported_integral_types:
            raise NotImplementedError("Type %s not supported." % it_type)

        # Explicit checking of coordinates
        coordinates = exp.ufl_domain().coordinates
        if coords is not None:
            assert coordinates == coords
        else:
            coords = coordinates

        if it_type == "cell":
            # Nothing difficult about cellwise integrals. Just need
            # to get coefficient info, include_dirs and append
            # function calls to the appropriate subkernels.

            # If tensor is mixed, there will be more than one SplitKernel
            incl = []
            exp_coeffs = exp.coefficients()
            for splitkernel in cxt_kernel.tsfc_kernels:
                index = splitkernel.indices
                kinfo = splitkernel.kinfo

                # Generate an iterable of coefficients to pass to the subkernel
                # if any are required
                clist = [
                    c for ci in kinfo.coefficient_map
                    for c in builder.coefficient(exp_coeffs[ci])
                ]

                # Orientations, when needed, precede the coefficients.
                if kinfo.oriented:
                    clist.insert(0, cell_orientations)

                incl.extend(kinfo.kernel._include_dirs)
                tensor = eigen_tensor(exp, t, index)
                statements.append(
                    ast.FunCall(kinfo.kernel.name, tensor, coordsym, *clist))

        elif it_type in ["interior_facet", "exterior_facet",
                         "interior_facet_vert", "exterior_facet_vert"]:
            # These integral types will require accessing local facet
            # information and looping over facet indices.
            builder.require_cell_facets()
            loop_stmt, incl = facet_integral_loop(cxt_kernel, builder,
                                                  coordsym, cellfacetsym,
                                                  cell_orientations)
            statements.append(loop_stmt)

        elif it_type == "interior_facet_horiz":
            # The infamous interior horizontal facet
            # will have two SplitKernels: one top,
            # one bottom. The mesh layer will determine
            # which kernels we call.
            builder.require_mesh_layers()
            top_sks = [k for k in cxt_kernel.tsfc_kernels
                       if k.kinfo.integral_type == "exterior_facet_top"]
            bottom_sks = [k for k in cxt_kernel.tsfc_kernels
                          if k.kinfo.integral_type == "exterior_facet_bottom"]
            assert len(top_sks) == len(bottom_sks), (
                "Number of top and bottom kernels should be equal")
            # Top and bottom kernels need to be sorted by kinfo.indices
            # if the space is mixed to ensure indices match.
            top_sks = sorted(top_sks, key=lambda x: x.indices)
            bottom_sks = sorted(bottom_sks, key=lambda x: x.indices)
            stmt, incl = extruded_int_horiz_facet(exp, builder, top_sks,
                                                  bottom_sks, coordsym,
                                                  mesh_layer_sym,
                                                  cell_orientations)
            statements.append(stmt)

        elif it_type in ["exterior_facet_bottom", "exterior_facet_top"]:
            # These kernels will only be called if we are on
            # the top or bottom layers of the extruded mesh.
            builder.require_mesh_layers()
            stmt, incl = extruded_top_bottom_facet(cxt_kernel, builder,
                                                   coordsym, mesh_layer_sym,
                                                   cell_orientations)
            statements.append(stmt)

        else:
            raise ValueError("Kernel type not recognized: %s" % it_type)

        # Don't duplicate include lines. Keep first-seen order: a set
        # difference would add them in string-hash order, which can differ
        # between processes and make the compiler flags nondeterministic.
        for include_dir in incl:
            if include_dir not in inc:
                inc.append(include_dir)

    # Now we handle any terms that require auxiliary temporaries,
    # such as inverses, transposes and actions of a tensor on a
    # coefficient
    if builder.aux_exprs:
        # The declared temps will be updated within this method
        aux_statements = auxiliary_temporaries(builder, declared_temps)
        statements.extend(aux_statements)

    # Now we create the result statement by declaring its eigen type and
    # using Eigen::Map to move between Eigen and C data structs.
    result_sym = ast.Symbol("T%d" % len(builder.temps))
    result_data_sym = ast.Symbol("A%d" % len(builder.temps))
    result_type = "Eigen::Map<%s >" % eigen_matrixbase_type(shape)
    result = ast.Decl(SCALAR_TYPE, ast.Symbol(result_data_sym, shape))
    result_statement = ast.FlatBlock(
        "%s %s((%s *)%s);\n" % (result_type, result_sym, SCALAR_TYPE,
                                result_data_sym))
    statements.append(result_statement)

    # Generate the complete c++ string performing the linear algebra operations
    # on Eigen matrices/vectors
    cpp_string = ast.FlatBlock(
        metaphrase_slate_to_cpp(slate_expr, declared_temps))
    statements.append(ast.Incr(result_sym, cpp_string))

    # Finalize AST for macro kernel construction
    builder._finalize_kernels_and_update()

    # Generate arguments for the macro kernel. The append order below
    # defines the C signature and must match the caller's argument order.
    args = [result, ast.Decl("%s **" % SCALAR_TYPE, coordsym)]

    # Orientation information
    if builder.oriented:
        args.append(ast.Decl("int **", cell_orientations))

    # Coefficient information
    for c in expr_coeffs:
        if isinstance(c, Constant):
            ctype = "%s *" % SCALAR_TYPE
        else:
            ctype = "%s **" % SCALAR_TYPE
        args.extend([ast.Decl(ctype, csym) for csym in builder.coefficient(c)])

    # Facet information
    if builder.needs_cell_facets:
        args.append(
            ast.Decl("%s *" % as_cstr(cell_to_facets_dtype), cellfacetsym))

    # NOTE: We need to be careful about the ordering here. Mesh layers are
    # added as the final argument to the kernel.
    if builder.needs_mesh_layers:
        args.append(ast.Decl("int", mesh_layer_sym))

    # NOTE: In the future we may want to have more than one "macro_kernel"
    macro_kernel_name = "compile_slate"
    stmt = ast.Block(statements)
    macro_kernel = builder.construct_macro_kernel(name=macro_kernel_name,
                                                  args=args,
                                                  statements=stmt)

    # Tell the builder to construct the final ast
    kernel_ast = builder.construct_ast([macro_kernel])

    # Now we wrap up the kernel ast as a PyOP2 kernel.
    # Include the Eigen header files
    inc.extend(["%s/include/eigen3/" % d for d in PETSC_DIR])
    op2kernel = op2.Kernel(
        kernel_ast, macro_kernel_name, cpp=True, include_dirs=inc,
        headers=['#include <Eigen/Dense>', '#define restrict __restrict'])

    # Send back a "TSFC-like" SplitKernel object with an
    # index and KernelInfo
    kinfo = KernelInfo(kernel=op2kernel, integral_type=builder.integral_type,
                       oriented=builder.oriented, subdomain_id="otherwise",
                       domain_number=0,
                       coefficient_map=tuple(range(len(expr_coeffs))),
                       needs_cell_facets=builder.needs_cell_facets,
                       pass_layer_arg=builder.needs_mesh_layers)

    idx = tuple([0] * slate_expr.rank)
    kernels = (SplitKernel(idx, kinfo), )

    # Store the resulting kernel for reuse
    slate_expr._metakernel_cache = kernels

    return kernels