Exemplo n.º 1
0
def compare_mhlo(x, y, direction, type):
    """Build an ``mhlo.CompareOp`` comparing ``x`` against ``y``.

    The constructor form is picked from ``jax._src.lib.mlir_api_version``:

    * ``>= 5``: no explicit result type; enum attributes for the direction
      and comparison type.
    * ``>= 3``: explicit i1 result type plus enum attributes.
    * otherwise: explicit i1 result type plus plain string attributes.

    Args:
      x: left-hand operand; its ``type`` must be a ranked tensor type when the
        result type has to be spelled out (API version below 5).
      y: right-hand operand.
      direction: comparison direction passed to the attribute constructor.
      type: comparison type passed to the attribute constructor.

    Returns:
      The constructed ``mhlo.CompareOp`` (the op, not its result value).
    """
    api_version = jax._src.lib.mlir_api_version

    if api_version >= 5:
        direction_attr = mhlo.ComparisonDirectionAttr.get(direction)
        type_attr = mhlo.ComparisonTypeAttr.get(type)
        return mhlo.CompareOp(x, y, direction_attr, type_attr)

    # Older bindings need the result type: signless i1 with the shape of x.
    operand_shape = ir.RankedTensorType(x.type).shape
    result_type = ir.RankedTensorType.get(operand_shape,
                                          ir.IntegerType.get_signless(1))

    if api_version >= 3:
        direction_attr = mhlo.ComparisonDirectionAttr.get(direction)
        type_attr = mhlo.ComparisonTypeAttr.get(type)
    else:
        direction_attr = ir.StringAttr.get(direction)
        type_attr = ir.StringAttr.get(type)
    return mhlo.CompareOp(result_type, x, y, direction_attr, type_attr)
Exemplo n.º 2
0
Arquivo: mlir.py Projeto: rsepassi/jax
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
    """Lowers a remat call by wrapping ``call_jaxpr`` in an ``mhlo.WhileOp``.

    The loop carries a single tuple made of, in order: an int32 counter
    (initialised to 0), the flattened inputs, and one placeholder per output
    (built with ``_dummy_like_aval``). The body evaluates ``call_jaxpr`` on the
    inputs, increments the counter, and writes the results into the output
    slots of the carry.

    Args:
      ctx: lowering context; its ``name_stack`` is extended for the body.
      avals_in: abstract values describing ``args``.
      avals_out: abstract values describing the results of ``call_jaxpr``.
      *args: lowered IR values for the inputs, one entry per input aval.
      name: name wrapped with ``'remat'`` and pushed onto the name stack.
      call_jaxpr: the jaxpr lowered inside the loop body.

    Returns:
      The output IR values read back from the while result, regrouped so that
      each entry corresponds to one element of ``avals_out``.
    """
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    # Carry layout: (counter,) + inputs + outputs, grouped per aval.
    loop_carry_types = [(int32_scalar_type, )] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)
    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args((counter_init, ) + args + tuple(
        _dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    # Created outside both regions; used as the RNG lower bound in the
    # condition and as the counter increment in the body.
    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition: continue while counter < rng, where rng comes from
    # RngUniformOp with bounds (one, two) and an empty shape operand.
    # NOTE(review): presumably rng is always 1 here, so the body runs exactly
    # once while the trip count is not a compile-time constant -- confirm.
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        bool_scalar_type = aval_to_ir_type(
            core.ShapedArray((), np.dtype(np.bool_)))
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                                   i32_attr(0))
        cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                             ir.StringAttr.get("SIGNED")).result
        mhlo.ReturnOp([cmp])

    # Loop body: unpack the carry, run the jaxpr, repack with counter + 1.
    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        flat_body_args = [
            mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                                   i32_attr(i)).result
            for i, input_type in enumerate(flat_loop_carry_types)
        ]
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        # Split into the counter, the inputs, and the (ignored) old outputs.
        ((i, ), ), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        i_next = mhlo.AddOp(i, one).result
        new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                                 [i_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    # The outputs are the trailing entries of the *flattened* carry tuple, so
    # the offset must be counted in flattened IR values. Counting avals
    # (1 + len(avals_in)) is only right when every input aval lowers to
    # exactly one IR type.
    first_output_index = len(flat_loop_carry_types) - len(flat_output_types)
    outputs = [
        mhlo.GetTupleElementOp(output_type, while_op.result,
                               i32_attr(first_output_index + i)).result
        for i, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
Exemplo n.º 3
0
Arquivo: mlir.py Projeto: GJBoth/jax
def _minmax_mhlo(op, cmp, x, y):
  """Min/max lowering that orders complex values lexicographically.

  Non-complex operands are handed straight to ``op``. Complex operands are
  compared as (real, imag) pairs: the imaginary parts decide only when the
  real parts are equal.

  Args:
    op: callable applied as ``op(x, y)`` for non-complex element types.
    cmp: comparison direction string used for both the real and imaginary
      comparisons.
    x: first operand; ``x.type`` must be a ranked tensor type.
    y: second operand.

  Returns:
    ``op(x, y)`` for non-complex inputs, otherwise the ``mhlo.SelectOp`` that
    picks between ``x`` and ``y`` (the op itself, not its result value).
  """
  operand_type = ir.RankedTensorType(x.type)
  if not ir.ComplexType.isinstance(operand_type.element_type):
    return op(x, y)

  real_x = mhlo.RealOp(x).result
  real_y = mhlo.RealOp(y).result
  shape = [operand_type.get_dim_size(axis)
           for axis in range(operand_type.rank)]
  pred_type = ir.RankedTensorType.get(shape, ir.IntegerType.get_signless(1))

  def float_compare(lhs, rhs, direction):
    # Element-wise floating-point comparison yielding an i1 tensor.
    return mhlo.CompareOp(pred_type, lhs, rhs,
                          ir.StringAttr.get(direction),
                          ir.StringAttr.get("FLOAT"))

  reals_equal = float_compare(real_x, real_y, "EQ")
  reals_ordered = float_compare(real_x, real_y, cmp)
  imag_x = mhlo.ImagOp(x).result
  imag_y = mhlo.ImagOp(y).result
  imags_ordered = float_compare(imag_x, imag_y, cmp)
  # Fall back to the imaginary comparison only where the real parts tie.
  pick_x = mhlo.SelectOp(reals_equal, imags_ordered, reals_ordered).result
  return mhlo.SelectOp(pick_x, x, y)
Exemplo n.º 4
0
Arquivo: mlir.py Projeto: GJBoth/jax
def convert_mhlo(x, aval_in, aval_out):
  """Type conversion that follows XLA HLO semantics.

  A cast to boolean is lowered as ``x != 0`` instead of a truncating integer
  conversion (b/209440332). Every other cast is a plain ``mhlo.ConvertOp``.

  Args:
    x: IR value to convert.
    aval_in: abstract value of ``x``; its dtype selects the comparison type.
    aval_out: abstract value of the result; its dtype selects the lowering.

  Returns:
    The result value of the emitted compare or convert op.
  """
  casting_to_bool = aval_out.dtype == np.dtype(np.bool_)
  if not casting_to_bool:
    return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result

  source_dtype = aval_in.dtype
  compare_type = (
      "FLOAT" if dtypes.issubdtype(source_dtype, np.inexact)
      else "SIGNED" if dtypes.issubdtype(source_dtype, np.signedinteger)
      else "UNSIGNED")

  result_type = aval_to_ir_type(aval_out)
  zeros = full_like_aval(0, aval_in)
  not_equal = mhlo.CompareOp(result_type, x, zeros,
                             ir.StringAttr.get("NE"),
                             ir.StringAttr.get(compare_type))
  return not_equal.result