def _select_and_scatter_lower(
    ctx, operand, source, init_value, *, select_jaxpr,
    select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
    window_strides, padding):
  """Lower `select_and_scatter` to an `mhlo.SelectAndScatterOp`.

  The op's `select` and `scatter` regions each get one block taking two
  scalar arguments of the operand's element type. The block bodies are
  produced by lowering `select_jaxpr` / `scatter_jaxpr` (with their consts)
  on those arguments.

  Args:
    ctx: lowering context; `avals_in` is (operand, source, init_value) and
      `avals_out` has exactly one entry.
    operand, source, init_value: MLIR values passed straight to the op.
    select_jaxpr, select_consts: jaxpr and consts lowered into `select`.
    scatter_jaxpr, scatter_consts: jaxpr and consts lowered into `scatter`.
    window_dimensions, window_strides: integer sequences for the window.
    padding: sequence of (low, high) pairs; the attribute is emitted with
      shape (len(padding), 2).

  Returns:
    The results of the emitted `SelectAndScatterOp`.
  """
  operand_aval, _, _ = ctx.avals_in
  aval_out, = ctx.avals_out
  scalar_type = mlir.aval_to_ir_type(operand_aval.update(shape=()))
  op = mhlo.SelectAndScatterOp(
      mlir.aval_to_ir_type(aval_out), operand, source, init_value,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      # Pin the (N, 2) shape explicitly: for an empty `padding`,
      # np.asarray would otherwise yield shape (0,) instead of (0, 2).
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                  shape=(len(padding), 2)))

  def _lower_region(region, jaxpr, consts):
    # Append a (scalar, scalar) block to `region` and fill it from `jaxpr`.
    block = region.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(block):
      out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
                                     *([a] for a in block.arguments))
      mhlo.ReturnOp(util.flatten(out_nodes))

  _lower_region(op.select, select_jaxpr, select_consts)
  _lower_region(op.scatter, scatter_jaxpr, scatter_consts)
  return op.results
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
  """Lower an FFT primitive to a single `mhlo.FftOp`.

  Args:
    ctx: lowering context whose `avals_out` has exactly one entry, used as
      the result type.
    x: MLIR value to transform.
    fft_type: enum-like value whose `.name` selects the `FftTypeAttr`.
    fft_lengths: integer sequence emitted as a dense int attribute.

  Returns:
    A one-element list holding the op's result.
  """
  out_aval, = ctx.avals_out
  result_type = mlir.aval_to_ir_type(out_aval)
  kind_attr = mhlo.FftTypeAttr.get(fft_type.name)
  lengths_attr = mlir.dense_int_elements(fft_lengths)
  fft_op = mhlo.FftOp(result_type, x, kind_attr, lengths_attr)
  return [fft_op.result]
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
                                 window_dimensions, window_strides, padding,
                                 base_dilation, window_dilation):
  """Lower a variadic `reduce_window` whose reducer is an arbitrary jaxpr.

  Args:
    ctx: lowering context; `avals_in` lists the operand avals followed by
      the init-value avals, and `avals_out` gives the result types.
    *args: the operands followed by their init values, in two equal halves.
    jaxpr, consts: the reducer, lowered into the op's reducer block.
    window_dimensions, window_strides, base_dilation, window_dilation:
      integer sequences emitted as dense int attributes.
    padding: sequence of (low, high) pairs; emitted with shape
      (len(padding), 2) so an empty padding still has the right rank.

  Returns:
    The results of the emitted `ReduceWindowOp`.

  Raises:
    NotImplementedError: if `jaxpr` has effects.
  """
  # Reject effectful reducers before emitting any IR, so a caller that
  # handles this error is not left with a half-built ReduceWindowOp (whose
  # reducer block would have no terminator) in the module.
  if jaxpr.effects:
    raise NotImplementedError(
        'Cannot lower effectful `reduce_window`.')
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  # Materialize as a list so the op builder never sees a lazy iterator.
  result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  rw = mhlo.ReduceWindowOp(
      result_types, operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  # The reducer block takes two copies of the per-operand scalar types.
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                      mlir.TokenSet(), consts,
                                      *([a] for a in reducer.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value]
             ) -> Sequence[ir.Value]:
  """Emit a call to the XLA computation `xla_comp` inside `ctx.module`.

  NOTE(review): `captured_inputs`, `xla_comp`, `function_flat_tf`,
  `result_shape`, `result_shapes` and `result_avals` are free variables —
  presumably bound by an enclosing function not visible here; confirm.

  The computation is converted to an MHLO module and merged into
  `ctx.module` under a `call_tf_<name>` symbol. The captured inputs are
  emitted as constants and appended after `args_op` as call operands.
  A tuple-shaped result is unpacked element by element. Any result whose
  XLA dtype differs from the corresponding aval's dtype is converted to
  the aval's IR type.

  Returns:
    One `ir.Value` per (result, aval, shape) triple.
  """
  # Captured inputs are materialized first, preserving the original order
  # in which IR is emitted.
  captured_consts = tuple(
      mlir.ir_constant(np.asarray(value), canonicalize_types=False)
      for value in captured_inputs)
  callee_module = mlir.xla_computation_to_mhlo_module(xla_comp)
  symbols = ir.SymbolTable(callee_module.operation)
  result_types = symbols["main"].type.results
  callee_name = mlir.merge_mhlo_modules(
      ctx.module, f"call_tf_{function_flat_tf.name}", callee_module)
  call = func_dialect.CallOp(result_types,
                             ir.FlatSymbolRefAttr.get(callee_name),
                             tuple(args_op) + captured_consts)

  if result_shape.is_tuple():
    raw_results = [
        mhlo.GetTupleElementOp(call, mlir.i32_attr(index)).result
        for index in range(len(result_shapes))]
  else:
    raw_results = call.results

  def _to_aval_dtype(value, aval, xla_shape):
    # Insert a convert only when the XLA dtype differs from the aval's.
    if aval.dtype != xla_shape.numpy_dtype():
      value = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), value).result
    return value

  return [_to_aval_dtype(value, aval, xla_shape)
          for value, aval, xla_shape
          in zip(raw_results, result_avals, result_shapes)]
def _conv_general_dilated_lower(
    ctx, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, preferred_element_type,
    expand_complex_convolutions=False, **unused_kwargs):
  """Lower `conv_general_dilated` to an `mhlo.ConvOp`.

  When `expand_complex_convolutions` is true and the lhs dtype is complex,
  the op is instead lowered via `mlir.lower_fun` on `_complex_mul` applied
  to a partially-applied `conv_general_dilated`. In that path a non-None
  `preferred_element_type` is mapped to its real counterpart with
  `_real_dtype`.

  Args:
    ctx: lowering context; `avals_in` is (lhs, rhs), `avals_out` has one
      entry used as the result type.
    lhs, rhs: MLIR values for the input and the kernel.
    dimension_numbers: a `ConvDimensionNumbers` (lhs_spec, rhs_spec,
      out_spec). lhs_spec is (batch, feature, *spatial). rhs_spec is
      (output feature, input feature, *spatial). out_spec is
      (batch, feature, *spatial).
    window_strides, padding, lhs_dilation, rhs_dilation,
    feature_group_count, batch_group_count, precision: emitted as the
      corresponding `ConvOp` attributes.
    preferred_element_type: only consulted on the complex-expansion path.
    unused_kwargs: ignored.

  Returns:
    A one-element list holding the convolution result.
  """
  lhs_aval, rhs_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  assert isinstance(dimension_numbers, ConvDimensionNumbers)

  if (expand_complex_convolutions
      and np.issubdtype(lhs_aval.dtype, np.complexfloating)):
    part_dtype = preferred_element_type
    if part_dtype is not None:
      # The real and imaginary parts use the real dtype matching the
      # requested complex one.
      assert np.issubdtype(part_dtype, np.complexfloating)
      part_dtype = _real_dtype(part_dtype)
    real_conv = partial(
        conv_general_dilated,
        window_strides=window_strides, padding=padding,
        lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count, precision=precision,
        preferred_element_type=part_dtype)
    lowered = mlir.lower_fun(partial(_complex_mul, real_conv),
                             multiple_results=False)
    return lowered(ctx, lhs, rhs)

  lhs_spec, rhs_spec, out_spec = dimension_numbers
  lhs_batch, lhs_feature, *lhs_spatial = lhs_spec
  rhs_out_feature, rhs_in_feature, *rhs_spatial = rhs_spec
  out_batch, out_feature, *out_spatial = out_spec
  dnums = mhlo.ConvDimensionNumbers.get(
      input_batch_dimension=lhs_batch,
      input_feature_dimension=lhs_feature,
      input_spatial_dimensions=lhs_spatial,
      kernel_output_feature_dimension=rhs_out_feature,
      kernel_input_feature_dimension=rhs_in_feature,
      kernel_spatial_dimensions=rhs_spatial,
      output_batch_dimension=out_batch,
      output_feature_dimension=out_feature,
      output_spatial_dimensions=out_spatial)
  # One `False` per kernel spatial dimension: no window is reversed.
  window_reversal = mlir.dense_bool_elements([False] * len(rhs_spatial))
  conv = mhlo.ConvOp(
      mlir.aval_to_ir_type(aval_out), lhs, rhs,
      dimension_numbers=dnums,
      feature_group_count=mlir.i64_attr(feature_group_count),
      batch_group_count=mlir.i64_attr(batch_group_count),
      window_strides=mlir.dense_int_elements(window_strides),
      padding=mlir.dense_int_elements(padding),
      lhs_dilation=mlir.dense_int_elements(lhs_dilation),
      rhs_dilation=mlir.dense_int_elements(rhs_dilation),
      window_reversal=window_reversal,
      precision_config=lax.precision_attr(precision))
  return [conv.result]
def _reduce_window_lower(reduce_op, init_value, ctx, operand, *,
                         window_dimensions, window_strides, padding,
                         base_dilation, window_dilation):
  """Lower a single-operand `reduce_window` with a fixed scalar reducer.

  Args:
    reduce_op: callable applied to the reducer block's two scalar arguments;
      its result is passed to the block's `mhlo.ReturnOp`.
    init_value: callable mapping a dtype to the initial value. The value is
      materialized with `mlir.full_like_aval` as a scalar of the operand's
      dtype.
    ctx: lowering context; `avals_in` and `avals_out` each have one entry.
    operand: the MLIR value being reduced.
    window_dimensions, window_strides, base_dilation, window_dilation:
      integer sequences emitted as dense int attributes.
    padding: sequence of (low, high) pairs; emitted with shape
      (len(padding), 2) so an empty padding still has the right rank.

  Returns:
    The results of the emitted `ReduceWindowOp`.
  """
  aval_out, = ctx.avals_out
  operand_aval, = ctx.avals_in
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  rw = mhlo.ReduceWindowOp(
      mlir.aval_to_ir_types(aval_out), [operand],
      [mlir.full_like_aval(init_value(scalar_aval.dtype), scalar_aval)],
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      # Pin the (N, 2) shape explicitly: for an empty `padding`,
      # np.asarray would otherwise yield shape (0,) instead of (0, 2).
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                  shape=(len(padding), 2)))
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    mhlo.ReturnOp(reduce_op(*reducer.arguments))
  return rw.results
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
  """Lower a variadic `reduce_window` whose reducer is an arbitrary jaxpr.

  Args:
    ctx: passed directly to `mlir.jaxpr_subcomp` (presumably a module-level
      lowering context in this API version — confirm).
    avals_in: operand avals followed by init-value avals.
    avals_out: result avals, used for the op's result types.
    *args: the operands followed by their init values, in two equal halves.
    jaxpr, consts: the reducer, lowered into the op's reducer block.
    window_dimensions, window_strides, base_dilation, window_dilation:
      integer sequences emitted as dense int attributes.
    padding: sequence of (low, high) pairs; emitted with shape
      (len(padding), 2) so an empty padding still has the right rank.

  Returns:
    The results of the emitted `ReduceWindowOp`.
  """
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  # Materialize as a list so the op builder never sees a lazy iterator.
  result_types = [mlir.aval_to_ir_type(aval) for aval in avals_out]
  rw = mhlo.ReduceWindowOp(
      result_types, operands, init_values,
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      # Pin the (N, 2) shape explicitly: for an empty `padding`,
      # np.asarray would otherwise yield shape (0,) instead of (0, 2).
      ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                  shape=(len(padding), 2)))
  # The reducer block takes two copies of the per-operand scalar types.
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
                                   *([a] for a in reducer.arguments))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
def _broadcast(x, aval):
  """Broadcast `x` (whose aval is `aval`) to the type of `aval_out`.

  The operand's dimensions are mapped onto the trailing `len(aval.shape)`
  dimensions of the rank-`rank` result.

  NOTE(review): `aval_out` and `rank` are free variables — presumably bound
  by an enclosing lowering function not visible here; confirm.
  """
  out_type = mlir.aval_to_ir_type(aval_out)
  first_dim = rank - len(aval.shape)
  dims_attr = mlir.dense_int_elements(range(first_dim, rank))
  return mhlo.BroadcastInDimOp(out_type, x, dims_attr).result