def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args,
                                          replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce per argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    n = len(dtype_args)
    if is_complex:
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
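# --- Hedged usage sketch (not part of the rule above) ---
# A minimal user-level example of what this tuple-psum path handles: a single
# psum over arguments of different dtypes, which get grouped per dtype before
# the all-reduce. The pmap call is commented out because it assumes at least
# two local devices.
import jax
import jax.numpy as jnp

def two_dtype_psum(x, y):
  return jax.lax.psum((x, y), axis_name='i')

# sums = jax.pmap(two_dtype_psum, axis_name='i')(
#     jnp.arange(2, dtype=jnp.float32), jnp.arange(2, dtype=jnp.int32))
# # -> (array([1., 1.], dtype=float32), array([1, 1], dtype=int32))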
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
                          avals_in, avals_out, *args, **params):
  xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                 prim, *avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)
  callee_name = None
  for op in submodule.body.operations:
    ctx.module.body.append(op)
    if op.name.value == "main":
      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])
  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                    flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  flat_results = [mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                  for i, typ in enumerate(flat_output_types)]
  return util.unflatten(flat_results, map(len, output_types))
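# --- Hedged illustration (not part of the function above) ---
# The fallback's core move is to build the primitive's XLA computation and
# import it as an MLIR module. A rough user-visible analogue using the older
# public jax.xla_computation API (the HLO text below is what would be handed
# to the HLO-to-MLIR importer):
import jax
import jax.numpy as jnp

comp = jax.xla_computation(jnp.sin)(jnp.float32(0.))
print(comp.as_hlo_text())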
def all_reduce(x):
  replica_groups_protos = xc.make_replica_groups(
      _replica_groups(axis_env, axis_name, axis_index_groups))
  scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  return xops.AllReduce(x, computation, replica_groups_protos, None, None)
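# --- Hedged usage sketch ---
# How axis_index_groups reaches the _replica_groups call above: a psum over
# disjoint subgroups of the mapped axis. Commented out since it assumes four
# local devices.
# grouped = jax.pmap(
#     lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]),
#     axis_name='i')(jnp.arange(4.))
# # -> [1., 1., 5., 5.]  (device pairs reduce among themselves)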
def _nonzero_translation_rule(c, dims, avals, operands):
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  i = core.ShapedArray((), np.dtype('int32'))
  out_dim = xops.Reduce(c, [nonzero_indicators],
                        [xb.constant(c, np.array(0, np.dtype('int32')))],
                        xla.primitive_subcomputation(lax.add_p, i, i),
                        (last_axis,))
  c.get_shape(out_dim)  # xla type checking

  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
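# --- Self-contained sketch of the trick above (plain jax.numpy; the helper
# name is made up) ---
# Count the nonzeros, then stable-sort an iota by the indicators, descending,
# so the indices of nonzero entries come first. (count, order) here are the
# analogues of (out_dim, out_val) in the XLA rule.
import jax.numpy as jnp

def nonzero_sketch(vals):
  indicators = (vals != 0).astype(jnp.int32)
  count = indicators.sum(-1)          # dynamic output size
  order = jnp.argsort(-indicators)    # jnp.argsort is stable: nonzeros first
  return count, order

# nonzero_sketch(jnp.array([0, 3, 0, 5]))  # -> (2, [1, 3, 0, 2])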
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)
def _allreduce_translation_rule(prim, c, val, replica_groups, backend=None):
  dtype = c.GetShape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar,
                                             backend=backend)
  return c.AllReduce(val, computation, replica_groups=replica_groups)
def _select_and_scatter_add_translation(
    ctx, avals_in, avals_out, source, operand, *, select_prim,
    window_dimensions, window_strides, padding, expand_padding):
  source_aval, operand_aval = avals_in
  c = ctx.builder
  dtype = operand_aval.dtype
  scalar = ShapedArray((), dtype)
  select = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                        select_prim, scalar, scalar)
  scatter = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env,
      lax.or_p if dtype == np.bool_ else lax.add_p, scalar, scalar)
  zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
  # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
  expand_padding = (expand_padding and
                    not all(lo == 0 and hi == 0 for (lo, hi) in padding))
  if expand_padding:
    original_padding = padding
    identity = (lax._get_max_identity if select_prim is lax.ge_p
                else lax._get_min_identity)
    pads = [(lo, hi, 0) for (lo, hi) in padding]
    operand = xops.Pad(operand, xla.pyval_to_ir_constant(c, identity(dtype)),
                       xc.make_padding_config(pads))
    padding = [(0, 0) for _ in padding]
  output = xops.SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      zero, scatter)
  if expand_padding:
    start_indices = [lo for (lo, hi) in original_padding]
    stop_indices = [lo + d for ((lo, hi), d)
                    in zip(original_padding, operand_aval.shape)]
    output = xops.Slice(output, start_indices, stop_indices,
                        [1] * len(start_indices))
  return [output]
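# --- Hedged usage sketch ---
# select_and_scatter_add is what the transpose of a max-pooling reduce_window
# lowers to, so differentiating a max pool exercises the rule above: each
# window's max location receives the incoming cotangent.
import jax
import jax.numpy as jnp
from jax import lax

def max_pool_1d(x):
  return lax.reduce_window(x, -jnp.inf, lax.max, window_dimensions=(2,),
                           window_strides=(2,), padding='VALID')

g = jax.grad(lambda x: max_pool_1d(x).sum())(jnp.array([1., 3., 2., 0.]))
# -> [0., 1., 1., 0.]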
def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
                                        window_dimensions, window_strides,
                                        padding, base_dilation,
                                        window_dilation):
  operand_aval, = avals_in
  scalar = ShapedArray((), operand_aval.dtype)
  return [xops.ReduceWindowWithGeneralPadding(
      operand,
      xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype)),
      xla.primitive_subcomputation(ctx.platform, ctx.axis_env, lax.add_p,
                                   scalar, scalar),
      window_dimensions, window_strides, base_dilation, window_dilation,
      padding)]
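# --- Hedged usage sketch ---
# The user-level op this rule lowers: a moving-window sum over an array.
from jax import lax
import jax.numpy as jnp

sums = lax.reduce_window(jnp.arange(5.), 0., lax.add, window_dimensions=(3,),
                         window_strides=(1,), padding='SAME')
# -> [1., 3., 6., 9., 7.]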
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
  if platform in ("cpu", "tpu"):
    return _notuple_allreduce_translation_rule(
        prim, c, *args, axis_name=axis_name,
        axis_index_groups=axis_index_groups, axis_env=axis_env,
        platform=platform)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce per argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex and prim is lax.add_p:
      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
      # special case because it's not currently handled by XLA:GPU.
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex and prim is lax.add_p:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
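# --- Hedged usage sketch ---
# The same rule serves non-add reductions, e.g. pmax (prim = lax.max_p), where
# the complex special case above does not fire. Commented out: needs at least
# two local devices.
# maxes = jax.pmap(lambda x: jax.lax.pmax(x, 'i'),
#                  axis_name='i')(jnp.arange(2.))
# # -> [1., 1.]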
def _reduce_sum_translation_rule(c, dims, avals, operands, *, axes):
  (x,), = operands
  shape = c.get_shape(x)
  dtype = shape.numpy_dtype()
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  if dims:
    aval, = avals
    masks = [xops.Lt(xops.Iota(c, iota_shape, i), dims[v][0])
             for i, v in enumerate(aval.shape)
             if isinstance(v, Var) and i in axes]
    map(c.get_shape, masks)
    x = xops.Select(reduce(xops.And, masks), x,
                    xops.Broadcast(xb.constant(c, np.zeros((), dtype)),
                                   shape.dimensions()))
  scalar = core.ShapedArray((), dtype)
  out = xops.Reduce(c, [x], [xb.constant(c, np.array(0, dtype))],
                    xla.primitive_subcomputation(lax.add_p, scalar, scalar),
                    axes)
  return [[out]]
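# --- Self-contained sketch of the masking step above (plain jax.numpy; the
# helper name is made up) ---
# For a dynamically sized axis, positions at or beyond the actual size are
# zeroed before a full-width sum, so the padding never contributes.
import jax.numpy as jnp

def masked_sum(x, size):
  mask = jnp.arange(x.shape[-1]) < size
  return jnp.where(mask, x, 0).sum(-1)

# masked_sum(jnp.array([1., 2., 9., 9.]), 2)  # -> 3.0 (padding ignored)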
def fallback(ctx: LoweringRuleContext, *args, **params):
  module_ctx = ctx.module_context
  xla_computation = xla.primitive_subcomputation(
      module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)
  callee_name = None
  for op in submodule.body.operations:
    op = typing.cast(FuncOpType, op)
    module_ctx.module.body.append(op)
    if op.name.value == "main":
      op.attributes["sym_name"] = ir.StringAttr.get(
          f"xla_fallback_{prim.name}")
      callee_name = ir.StringAttr(module_ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      module_ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])
  call = func_dialect.CallOp([output_type],
                             ir.FlatSymbolRefAttr.get(callee_name),
                             flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  if jax._src.lib.mlir_api_version < 6:
    flat_results = [mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                    for i, typ in enumerate(flat_output_types)]
  else:
    flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result
                    for i in range(len(flat_output_types))]
  return util.unflatten(flat_results, map(len, output_types))
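# --- Hedged registration sketch ---
# `fallback` closes over `prim`; a plausible wiring (the exact call site is an
# assumption, not shown above) registers it as the primitive's MLIR lowering:
# from jax.interpreters import mlir
# mlir.register_lowering(prim, fallback)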
def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
  dtype = c.GetShape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)