def _select_and_scatter_lower(
    ctx, operand, source, init_value, *, select_jaxpr, select_consts,
    scatter_jaxpr, scatter_consts, window_dimensions, window_strides,
    padding):
  """Lowers `select_and_scatter` to an MHLO `SelectAndScatterOp`.

  Args:
    ctx: lowering rule context. `ctx.avals_in` must hold exactly three avals
      (operand, source, init_value) and `ctx.avals_out` exactly one.
    operand: IR value of the operand.
    source: IR value of the source.
    init_value: IR value of the initial value.
    select_jaxpr: jaxpr lowered into the op's `select` region.
    select_consts: constants for `select_jaxpr`.
    scatter_jaxpr: jaxpr lowered into the op's `scatter` region.
    scatter_consts: constants for `scatter_jaxpr`.
    window_dimensions: per-dimension window sizes.
    window_strides: per-dimension window strides.
    padding: sequence of padding pairs, one per dimension (presumably
      (low, high) -- confirm against the primitive definition).

  Returns:
    The results of the emitted `SelectAndScatterOp`.
  """
  # Only the operand aval is needed; unpacking still checks the arity.
  operand_aval, _, _ = ctx.avals_in
  aval_out, = ctx.avals_out
  # Both reducer regions take scalars of the operand's element type.
  scalar_type = mlir.aval_to_ir_type(operand_aval.update(shape=()))
  # Always build an (n, 2) attribute. Without the reshape, an empty padding
  # list would become a 1-D array of shape (0,) instead of (0, 2).
  padding_attr = ir.DenseIntElementsAttr.get(
      np.asarray(padding, np.int64).reshape(len(padding), 2))
  op = mhlo.SelectAndScatterOp(
      mlir.aval_to_ir_type(aval_out), operand, source, init_value,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      padding_attr)

  def lower_reducer(region, jaxpr, consts):
    # Add a two-scalar block to `region` and lower `jaxpr` into it.
    block = region.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(block):
      out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
                                     *([a] for a in block.arguments))
      mhlo.ReturnOp(util.flatten(out_nodes))

  lower_reducer(op.select, select_jaxpr, select_consts)
  lower_reducer(op.scatter, scatter_jaxpr, scatter_consts)
  return op.results
# NOTE(review): SOURCE defines `_generic_reduce_window_lower` twice, with
# different signatures. If both live in one module, the later definition
# rebinds this name. Confirm which lowering-rule API is intended.
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
                                 window_dimensions, window_strides, padding,
                                 base_dilation, window_dilation):
  """Lowers a generic `reduce_window` to an MHLO `ReduceWindowOp`.

  Args:
    ctx: lowering rule context. `ctx.avals_in` lists the operand avals
      followed by the init-value avals; `ctx.avals_out` gives result avals.
    *args: IR values: the operands followed by an equal number of init
      values.
    jaxpr: reducer jaxpr, lowered into the op's single region.
    consts: constants for `jaxpr`.
    window_dimensions: per-dimension window sizes.
    window_strides: per-dimension window strides.
    padding: sequence of padding pairs, one per dimension.
    base_dilation: per-dimension base dilation.
    window_dilation: per-dimension window dilation.

  Returns:
    The results of the emitted `ReduceWindowOp`.

  Raises:
    NotImplementedError: if `jaxpr` has effects.
  """
  # Reject effectful reducers before any IR is emitted. Raising after the op
  # is built would leave a ReduceWindowOp whose reducer block has no
  # terminator.
  if jaxpr.effects:
    raise NotImplementedError(
        'Cannot lower effectful `reduce_window`.')
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  rw = mhlo.ReduceWindowOp(
      result_types, operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      # Explicit (n, 2) shape so an empty padding list still yields a 2-D
      # attribute.
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  # The reducer block takes two scalars per operand: the first len(operands)
  # block arguments, then another len(operands) of the same types.
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    # Effects were rejected above, so an empty token set is sufficient.
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                      mlir.TokenSet(), consts,
                                      *([a] for a in reducer.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
# NOTE(review): this redefines `_generic_reduce_window_lower` with the older
# (ctx, avals_in, avals_out, ...) lowering signature. If both definitions live
# in one module, this one wins. Confirm which API is intended.
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
  """Lowers a generic `reduce_window` to an MHLO `ReduceWindowOp`.

  Args:
    ctx: lowering context, passed directly to `mlir.jaxpr_subcomp`.
    avals_in: operand avals followed by init-value avals.
    avals_out: result avals.
    *args: IR values: the operands followed by an equal number of init
      values.
    jaxpr: reducer jaxpr, lowered into the op's single region.
    consts: constants for `jaxpr`.
    window_dimensions: per-dimension window sizes.
    window_strides: per-dimension window strides.
    padding: sequence of padding pairs, one per dimension.
    base_dilation: per-dimension base dilation.
    window_dilation: per-dimension window dilation.

  Returns:
    The results of the emitted `ReduceWindowOp`.
  """
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  result_types = [mlir.aval_to_ir_type(aval) for aval in avals_out]
  # Always build an (n, 2) attribute. Without the reshape, an empty padding
  # list would become a 1-D array of shape (0,) instead of (0, 2).
  padding_attr = ir.DenseIntElementsAttr.get(
      np.asarray(padding, np.int64).reshape(len(padding), 2))
  rw = mhlo.ReduceWindowOp(
      result_types, operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      padding_attr)
  # The reducer block takes two scalars per operand.
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
                                   *([a] for a in reducer.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
def _cond_lowering(ctx, index, *args, branches, linear):
  """Lowers `cond` to an MHLO `CaseOp` with one region per branch.

  Ordered effects used by any branch are threaded through the case op as
  extra leading token results.

  Args:
    ctx: lowering rule context. It supplies `tokens_in`, `avals_out` and
      `module_context`, and receives the output tokens.
    index: IR value that selects which branch runs.
    *args: operands that the branches reference by implicit capture.
    branches: closed jaxprs, one per branch. Each exposes `.jaxpr`,
      `.consts` and `.effects`.
    linear: unused.

  Returns:
    The non-token results of the case op, grouped per output aval.
  """
  del linear  # Unused.
  all_effects = core.join_effects(*(b.effects for b in branches))
  effects_with_order = [e for e in all_effects if e in core.ordered_effects]
  branch_tokens_in = ctx.tokens_in.subset(effects_with_order)

  # Result layout: one token group per ordered effect, followed by the IR
  # types of each output aval.
  result_type_groups = [mlir.token_type() for _ in effects_with_order]
  result_type_groups.extend(_map(mlir.aval_to_ir_types, ctx.avals_out))

  # CaseOp's only operand is `index`. Its region blocks take no arguments and
  # reach `args` through implicit capture.
  case_op = mhlo.CaseOp(util.flatten(result_type_groups), index=index,
                        num_branches=len(branches))
  cond_name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')

  for branch_index, branch_jaxpr in enumerate(branches):
    block = case_op.regions[branch_index].blocks.append()
    with ir.InsertionPoint(block):
      branch_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(cond_name_stack,
                                           f'branch_{branch_index}_fun'))
      branch_consts = _map(mlir.ir_constants, branch_jaxpr.consts)
      branch_args = _map(mlir.wrap_singleton_ir_values, args)
      results, branch_tokens_out = mlir.jaxpr_subcomp(
          branch_ctx, branch_jaxpr.jaxpr, branch_tokens_in, branch_consts,
          *branch_args)
      # Tokens come first, matching the case op's result layout.
      returned = [branch_tokens_out.get(e) for e in effects_with_order]
      returned.extend(results)
      mhlo.ReturnOp(util.flatten(returned))

  grouped = util.unflatten(case_op.results,
                           [len(group) for group in result_type_groups])
  token_results, outputs = util.split_list(grouped,
                                           [len(effects_with_order)])
  ctx.set_tokens_out(mlir.TokenSet(zip(effects_with_order, token_results)))
  return outputs