Example #1
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
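  # Lower a multi-argument psum to XLA AllReduce. CPU and TPU take the
  # non-tuple path below; other backends group same-dtype arguments into a
  # single tuple all-reduce each.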
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support mixing dtypes in a single
  # all-reduce, so we perform one all-reduce per argument dtype.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    n = len(dtype_args)
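    # Complex values are summed as separate real and imaginary parts and
    # recombined after the all-reduce.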
    if is_complex:
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i)) for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
Example #2
File: mlir.py, Project: rsepassi/jax
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext, avals_in,
                          avals_out, *args, **params):
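    # Lower the primitive via the classic XLA path, convert the resulting
    # computation to an MLIR module, splice its functions into the module
    # currently being built, and call the imported main function below.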
    xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                   prim, *avals_in, **params)
    submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
        xla_computation)
    submodule = ir.Module.parse(submodule_str)
    callee_name = None
    for op in submodule.body.operations:
        ctx.module.body.append(op)
        if op.name.value == "main":
            callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
            op.attributes["sym_visibility"] = ir.StringAttr.get("private")
        else:
            ctx.symbol_table.insert(op)

    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    output_type = (ir.TupleType.get_tuple(flat_output_types)
                   if prim.multiple_results else flat_output_types[0])

    call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                      flatten_lowering_ir_args(args)).result
    if not prim.multiple_results:
        return [call]
    flat_results = [
        mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
        for i, typ in enumerate(flat_output_types)
    ]
    return util.unflatten(flat_results, map(len, output_types))
Example #3
def all_reduce(x):
    replica_groups_protos = xc.make_replica_groups(
        _replica_groups(axis_env, axis_name, axis_index_groups))
    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    return xops.AllReduce(x, computation, replica_groups_protos, None,
                          None)
Example #4
def _nonzero_translation_rule(c, dims, avals, operands):
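  # Count the nonzero entries along the last axis, then stable-sort the 0/1
  # indicators together with an iota of positions so that the indices of the
  # nonzero elements come first.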
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  i = core.ShapedArray((), np.dtype('int32'))
  out_dim = xops.Reduce(c, [nonzero_indicators],
                        [xb.constant(c, np.array(0, np.dtype('int32')))],
                        xla.primitive_subcomputation(lax.add_p, i, i),
                        (last_axis,))
  c.get_shape(out_dim)  # xla type checking

  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
Example #5
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)
Example #6
File: lax_parallel.py, Project: wig-l/jax
def _allreduce_translation_rule(prim, c, val, replica_groups, backend=None):
    dtype = c.GetShape(val).numpy_dtype()
    scalar = ShapedArray((), dtype)
    computation = xla.primitive_subcomputation(prim,
                                               scalar,
                                               scalar,
                                               backend=backend)
    return c.AllReduce(val, computation, replica_groups=replica_groups)
Example #7
def _select_and_scatter_add_translation(ctx, avals_in, avals_out, source,
                                        operand, *, select_prim,
                                        window_dimensions, window_strides,
                                        padding, expand_padding):
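    # Lowers select_and_scatter_add using XLA's
    # SelectAndScatterWithGeneralPadding; when `expand_padding` is set, the
    # padding is applied explicitly with Pad and undone with Slice (see the
    # TODO below).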
    source_aval, operand_aval = avals_in
    c = ctx.builder
    dtype = operand_aval.dtype
    scalar = ShapedArray((), dtype)
    select = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                          select_prim, scalar, scalar)
    scatter = xla.primitive_subcomputation(
        ctx.platform, ctx.axis_env,
        lax.or_p if dtype == np.bool_ else lax.add_p, scalar, scalar)
    zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
    # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
    expand_padding = (expand_padding
                      and not all(lo == 0 and hi == 0 for (lo, hi) in padding))
    if expand_padding:
        original_padding = padding
        identity = (lax._get_max_identity
                    if select_prim is lax.ge_p else lax._get_min_identity)
        pads = [(lo, hi, 0) for (lo, hi) in padding]
        operand = xops.Pad(operand,
                           xla.pyval_to_ir_constant(c, identity(dtype)),
                           xc.make_padding_config(pads))
        padding = [(0, 0) for _ in padding]
    output = xops.SelectAndScatterWithGeneralPadding(operand, select,
                                                     window_dimensions,
                                                     window_strides, padding,
                                                     source, zero, scatter)
    if expand_padding:
        start_indices = [lo for (lo, hi) in original_padding]
        stop_indices = [
            lo + d
            for ((lo, hi), d) in zip(original_padding, operand_aval.shape)
        ]
        output = xops.Slice(output, start_indices, stop_indices,
                            [1] * len(start_indices))
    return [output]
Example #8
def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
                                        window_dimensions, window_strides,
                                        padding, base_dilation,
                                        window_dilation):
    operand_aval, = avals_in
    scalar = ShapedArray((), operand_aval.dtype)
    return [
        xops.ReduceWindowWithGeneralPadding(
            operand,
            xla.pyval_to_ir_constant(ctx.builder,
                                     np.array(0, operand_aval.dtype)),
            xla.primitive_subcomputation(ctx.platform, ctx.axis_env, lax.add_p,
                                         scalar, scalar), window_dimensions,
            window_strides, base_dilation, window_dilation, padding)
    ]
Example #9
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
    if platform in ("cpu", "tpu"):
        return _notuple_allreduce_translation_rule(
            prim,
            c,
            *args,
            axis_name=axis_name,
            axis_index_groups=axis_index_groups,
            axis_env=axis_env,
            platform=platform)

    # XLA's tuple all-reduce doesn't support mixing dtypes in a single
    # all-reduce, so we perform one all-reduce per argument dtype.
    args_by_type = collections.defaultdict(lambda: ([], []))
    for i, arg in enumerate(args):
        indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
        indices.append(i)
        dtype_args.append(arg)

    # The outputs, in the original argument order.
    out = [None] * len(args)
    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        n = len(dtype_args)
        if is_complex and prim is lax.add_p:
            # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
            # special case because it's not currently handled by XLA:GPU
            dtype_args = ([xops.Real(x) for x in dtype_args] +
                          [xops.Imag(x) for x in dtype_args])
        scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
        computation = xla.primitive_subcomputation(prim, scalar, scalar)
        all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                    replica_groups_protos, None, None)
        if is_complex and prim is lax.add_p:
            xs = [
                xops.Complex(xops.GetTupleElement(all_reduce, i),
                             xops.GetTupleElement(all_reduce, n + i))
                for i in range(n)
            ]
        else:
            xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
        for i, x in zip(indices, xs):
            out[i] = x
    return xops.Tuple(c, out)
Example #10
def _reduce_sum_translation_rule(c, dims, avals, operands, *, axes):
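  # Reduce-sum over `axes`, masking out elements that lie beyond the dynamic
  # extent of any Var-shaped dimension before reducing.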
  (x,), = operands
  shape = c.get_shape(x)
  dtype = shape.numpy_dtype()
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  if dims:
    aval, = avals
    masks = [xops.Lt(xops.Iota(c, iota_shape, i), dims[v][0])
             for i, v in enumerate(aval.shape) if isinstance(v, Var)
             and i in axes]
    map(c.get_shape, masks)  # xla type checking
    x = xops.Select(reduce(xops.And, masks), x,
                    xops.Broadcast(xb.constant(c, np.zeros((), dtype)),
                                   shape.dimensions()))
  scalar = core.ShapedArray((), dtype)
  out = xops.Reduce(c, [x], [xb.constant(c, np.array(0, dtype))],
                    xla.primitive_subcomputation(lax.add_p, scalar, scalar), axes)
  return [[out]]
Example #11
    def fallback(ctx: LoweringRuleContext, *args, **params):
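        # Fallback lowering: build the primitive with the classic XLA
        # translation path, import the resulting MLIR module, and emit a call
        # to its main function.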
        module_ctx = ctx.module_context
        xla_computation = xla.primitive_subcomputation(module_ctx.platform,
                                                       module_ctx.axis_env,
                                                       prim, *ctx.avals_in,
                                                       **params)
        submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
            xla_computation)
        submodule = ir.Module.parse(submodule_str)
        callee_name = None
        for op in submodule.body.operations:
            op = typing.cast(FuncOpType, op)
            module_ctx.module.body.append(op)
            if op.name.value == "main":
                op.attributes["sym_name"] = ir.StringAttr.get(
                    f"xla_fallback_{prim.name}")
                callee_name = ir.StringAttr(
                    module_ctx.symbol_table.insert(op)).value
                op.attributes["sym_visibility"] = ir.StringAttr.get("private")
            else:
                module_ctx.symbol_table.insert(op)

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        output_type = (ir.TupleType.get_tuple(flat_output_types)
                       if prim.multiple_results else flat_output_types[0])

        call = func_dialect.CallOp([output_type],
                                   ir.FlatSymbolRefAttr.get(callee_name),
                                   flatten_lowering_ir_args(args)).result
        if not prim.multiple_results:
            return [call]
        if jax._src.lib.mlir_api_version < 6:
            flat_results = [
                mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                for i, typ in enumerate(flat_output_types)
            ]
        else:
            flat_results = [
                mhlo.GetTupleElementOp(call, i32_attr(i)).result
                for i in range(len(flat_output_types))
            ]

        return util.unflatten(flat_results, map(len, output_types))
Example #12
def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
    dtype = c.GetShape(val).numpy_dtype()
    scalar = ShapedArray((), dtype)
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    return xops.AllReduce(val, computation, replica_groups_protos, None, None)