Example no. 1
def _select_and_scatter_lower(ctx, operand, source, init_value, *,
                              select_jaxpr, select_consts, scatter_jaxpr,
                              scatter_consts, window_dimensions,
                              window_strides, padding):
    operand_aval, source_aval, init_value_aval = ctx.avals_in
    aval_out, = ctx.avals_out
    scalar_aval = operand_aval.update(shape=())
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    op = mhlo.SelectAndScatterOp(
        mlir.aval_to_ir_type(aval_out), operand, source, init_value,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    # The select region is a binary predicate over two scalars (e.g. a
    # "greater-equal" comparison for max-pooling); it is populated by
    # lowering select_jaxpr into the new block.
    select = op.select.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(select):
        # jaxpr_subcomp expects a sequence of IR values per jaxpr variable,
        # hence each block argument is wrapped in a singleton list.
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                       select_consts,
                                       *([a] for a in select.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    # The scatter region combines a value from `source` with the value
    # already stored at the selected window position.
    scatter = op.scatter.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(scatter):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                       scatter_consts,
                                       *([a] for a in scatter.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return op.results
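A rule like this is attached to its primitive through JAX's registration hook. A minimal sketch, assuming `select_and_scatter_p` is the corresponding primitive in `jax._src.lax`:

# Hedged sketch: select_and_scatter_p is assumed to be the primitive this
# rule lowers, mirroring how JAX wires lowering rules to primitives.
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)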
Example no. 2
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
    out_aval, = ctx.avals_out
    return [
        mhlo.FftOp(mlir.aval_to_ir_type(out_aval), x,
                   mhlo.FftTypeAttr.get(fft_type.name),
                   mlir.dense_int_elements(fft_lengths)).result
    ]
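For reference, this rule is registered against the FFT primitive the same way (sketch modeled on jax._src.lax.fft, where fft_p is the FFT primitive):

mlir.register_lowering(fft_p, _fft_lowering)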
Example no. 3
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        # In the original module `map` is JAX's list-returning
        # util.safe_map; wrap in list() so the snippet is self-contained.
        list(map(mlir.aval_to_ir_type, ctx.avals_out)),
        operands,
        init_values,
        mlir.dense_int_elements(window_dimensions),
        window_strides=mlir.dense_int_elements(window_strides),
        base_dilations=mlir.dense_int_elements(base_dilation),
        window_dilations=mlir.dense_int_elements(window_dilation),
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=(len(padding), 2)))
    # The reducer block takes the accumulator scalars followed by the
    # operand scalars, hence scalar_types appears twice.
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        if jaxpr.effects:
            raise NotImplementedError(
                'Cannot lower effectful `reduce_window`.')
        out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
                                          mlir.TokenSet(), consts,
                                          *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
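Registration sketch (assuming `reduce_window_p` is the variadic reduce-window primitive, as in `jax._src.lax`):

mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)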
Example no. 4
  # code_gen is a closure: captured_inputs, xla_comp, function_flat_tf,
  # result_shape, result_shapes and result_avals all come from the
  # enclosing call_tf lowering rule.
  def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value]
              ) -> Sequence[ir.Value]:
    captured_ops = tuple(mlir.ir_constant(np.asarray(inp),
                                          canonicalize_types=False)
                         for inp in captured_inputs)
    submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
    symtab = ir.SymbolTable(submodule.operation)
    callee_result_types = symtab["main"].type.results
    fn = mlir.merge_mhlo_modules(ctx.module, f"call_tf_{function_flat_tf.name}",
                                 submodule)
    call = func_dialect.CallOp(callee_result_types,
                               ir.FlatSymbolRefAttr.get(fn),
                               tuple(args_op) + captured_ops)
    if result_shape.is_tuple():
      flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
                      for i in range(len(result_shapes))]
    else:
      flat_results = call.results

    outputs = []
    for op, res_aval, res_shape in zip(flat_results, result_avals,
                                       result_shapes):
      if res_aval.dtype != res_shape.numpy_dtype():
        op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
      outputs.append(op)
    return outputs
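The overall flow: convert the already-compiled XLA computation to an MHLO submodule, merge it into the module being built, emit a func.call to it, unpack a tuple result if necessary, and finally insert mhlo.ConvertOps wherever the callee's result dtype differs from the expected JAX aval.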
Example no. 5
def _conv_general_dilated_lower(
    ctx, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, preferred_element_type,
    expand_complex_convolutions=False, **unused_kwargs):
  lhs_aval, rhs_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  assert isinstance(dimension_numbers, ConvDimensionNumbers)
  dtype = lhs_aval.dtype
  if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
    if preferred_element_type is not None:
      # Convert complex dtype to types used for real and imaginary parts
      assert np.issubdtype(preferred_element_type, np.complexfloating)
      preferred_element_type = _real_dtype(preferred_element_type)
    # Lower the complex convolution as real convolutions over the real and
    # imaginary parts, combined by _complex_mul.
    complex_conv = mlir.lower_fun(
      partial(
        _complex_mul,
        partial(conv_general_dilated, window_strides=window_strides,
                padding=padding, lhs_dilation=lhs_dilation,
                rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
                feature_group_count=feature_group_count,
                batch_group_count=batch_group_count, precision=precision,
                preferred_element_type=preferred_element_type)),
      multiple_results=False)
    return complex_conv(ctx, lhs, rhs)

  lhs_spec, rhs_spec, out_spec = dimension_numbers
  dnums = mhlo.ConvDimensionNumbers.get(
    input_batch_dimension=lhs_spec[0],
    input_feature_dimension=lhs_spec[1],
    input_spatial_dimensions=list(lhs_spec[2:]),
    kernel_output_feature_dimension=rhs_spec[0],
    kernel_input_feature_dimension=rhs_spec[1],
    kernel_spatial_dimensions=list(rhs_spec[2:]),
    output_batch_dimension=out_spec[0],
    output_feature_dimension=out_spec[1],
    output_spatial_dimensions=list(out_spec[2:]))
  num_spatial_dims = len(rhs_spec) - 2
  window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
  return [
      mhlo.ConvOp(
          mlir.aval_to_ir_type(aval_out),
          lhs,
          rhs,
          dimension_numbers=dnums,
          feature_group_count=mlir.i64_attr(feature_group_count),
          batch_group_count=mlir.i64_attr(batch_group_count),
          window_strides=mlir.dense_int_elements(window_strides),
          padding=mlir.dense_int_elements(padding),
          lhs_dilation=mlir.dense_int_elements(lhs_dilation),
          rhs_dilation=mlir.dense_int_elements(rhs_dilation),
          window_reversal=window_reversal,
          precision_config=lax.precision_attr(precision)).result
  ]
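Registration typically happens once generically and again per backend that wants the complex expansion. A sketch, assuming `conv_general_dilated_p` is the convolution primitive and that the CPU backend expands complex convolutions:

mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower)
mlir.register_lowering(
    conv_general_dilated_p,
    partial(_conv_general_dilated_lower, expand_complex_convolutions=True),
    platform='cpu')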
Example no. 6
def _reduce_window_lower(reduce_op, init_value, ctx, operand, *,
                         window_dimensions, window_strides, padding,
                         base_dilation, window_dilation):
    aval_out, = ctx.avals_out
    operand_aval, = ctx.avals_in
    scalar_aval = operand_aval.update(shape=())
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    rw = mhlo.ReduceWindowOp(
        mlir.aval_to_ir_types(aval_out), [operand],
        # init_value maps a dtype to the reduction identity (0 for add,
        # -inf for max, ...), splatted to a scalar constant.
        [mlir.full_like_aval(init_value(scalar_aval.dtype), scalar_aval)],
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer):
        # reduce_op is an MHLO op constructor (e.g. mhlo.AddOp or
        # mhlo.MaxOp) applied to the block's two scalar arguments.
        mhlo.ReturnOp(reduce_op(*reducer.arguments))
    return rw.results
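The first two parameters are bound with functools.partial at registration time. A hedged sketch; the primitive and identity-helper names here are assumptions modeled on JAX's internals:

# Assumed names: reduce_window_max_p / reduce_window_min_p primitives and
# _get_max_identity / _get_min_identity dtype-to-identity helpers.
mlir.register_lowering(reduce_window_max_p, partial(
    _reduce_window_lower, mhlo.MaxOp, _get_max_identity))
mlir.register_lowering(reduce_window_min_p, partial(
    _reduce_window_lower, mhlo.MinOp, _get_min_identity))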
Example no. 7
def _generic_reduce_window_lower(ctx, avals_in, avals_out, *args, jaxpr,
                                 consts, window_dimensions, window_strides,
                                 padding, base_dilation, window_dilation):
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        # list() for self-containment; the source module rebinds `map` to
        # the list-returning util.safe_map.
        list(map(mlir.aval_to_ir_type, avals_out)), operands, init_values,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        out_nodes = mlir.jaxpr_subcomp(ctx, jaxpr, consts,
                                       *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
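This is an earlier revision of the rule from Example no. 3: the input and output avals are still passed explicitly instead of being carried on ctx, jaxpr_subcomp takes the module context directly, and there is no effect-token plumbing (mlir.TokenSet) or effects check yet.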
Example no. 8
def _broadcast(x, aval):
    # Nested helper: `aval_out` and `rank` are free variables from the
    # enclosing lowering rule. Broadcasts x so that its dimensions align
    # with the trailing dimensions of the rank-`rank` output.
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(aval_out), x,
        mlir.dense_int_elements(range(rank - len(aval.shape),
                                      rank))).result
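A self-contained variant for illustration; the extra parameters are mine, making the closure variables explicit:

def _broadcast_explicit(x, aval, aval_out, rank):
    # Each dimension of x maps onto the trailing dimensions of the output.
    bcast_dims = range(rank - len(aval.shape), rank)
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(aval_out), x,
        mlir.dense_int_elements(bcast_dims)).result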