def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
                                 axis_env, platform):
  """Lowers an all_to_all to XLA.

  Args:
    c: the XLA builder/computation context.
    x: the operand to shuffle across replicas.
    split_axis: positional axis of ``x`` to split across replicas.
    concat_axis: positional axis along which received chunks are concatenated.
    axis_name: the mapped axis name being exchanged over.
    axis_env: axis environment used to derive replica groups.
    platform: lowering target platform string (e.g. 'tpu', 'cpu', 'gpu').

  Returns:
    An XLA op computing the all-to-all (or ``x`` unchanged when each replica
    group has a single member, making the exchange a no-op).

  Raises:
    ValueError: if the replica groups are not all the same size.
  """
  # Workaround for AllToAll not being implemented on CPU.
  replica_groups = _replica_groups(axis_env, axis_name, None)
  if len(replica_groups[0]) == 1:
    # A one-member group exchanges only with itself: identity.
    return x
  elif platform != 'tpu':
    warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other "
                  "backends emulate it using a very slow and memory intensive algorithm, so expect "
                  "significant slowdowns.")
    # Emulate via an all-gather-based lowering on non-TPU backends.
    lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True)
    return lowering(c, x, split_axis=split_axis, concat_axis=concat_axis,
                    axis_name=axis_name, axis_env=axis_env, platform=platform)
  else:
    split_count = len(replica_groups[0])
    if not all(split_count == len(g) for g in replica_groups):
      raise ValueError('Replica groups must be equally sized')
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    if concat_axis == split_axis:
      # Same axis: XLA's AllToAll handles this directly.
      return xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos)
    else:
      # Different axes: insert a unit dimension at concat_axis, AllToAll,
      # then squeeze out the (now unit-sized) split_axis. Inserting the new
      # dimension shifts whichever of the two axes lies after the other,
      # hence the compensating += 1 below.
      if concat_axis < split_axis:
        split_axis += 1
      elif split_axis < concat_axis:
        concat_axis += 1
      x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)(c, x)
      x = xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos)
      x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x)
      return x
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  """Lowers a (possibly multi-operand) psum to XLA AllReduce ops.

  On CPU/TPU this defers to the non-tuple lowering. Otherwise it groups the
  arguments by dtype and emits one tuple all-reduce per dtype, with complex
  inputs decomposed into real/imaginary parts before reduction.

  Args:
    c: the XLA builder/computation context.
    *args: XLA operands to sum across replicas.
    replica_groups: replica groups over which to reduce.
    platform: lowering target platform string.

  Returns:
    An XLA tuple op holding the reduced results, one per input argument, in
    the original argument order.
  """
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  # sorted() makes the per-dtype iteration order deterministic.
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    n = len(dtype_args)
    if is_complex:
      # Reduce real and imaginary parts separately; first n entries are the
      # real parts, the next n the imaginary parts.
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      # Recombine the separately-reduced real/imag halves.
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    # Scatter the per-dtype results back to their original positions.
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
def all_reduce(x):
  """Emits one XLA AllReduce of `x` using the enclosing scope's `prim`,
  builder `c`, and axis information (closure variables)."""
  groups_protos = xc.make_replica_groups(
      _replica_groups(axis_env, axis_name, axis_index_groups))
  operand_dtype = c.get_shape(x).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  return xops.AllReduce(x, reducer, groups_protos, None, None)
def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name,
                                 axis_index_groups, axis_size, axis_env,
                                 platform):
  """Lowers an all_gather to XLA.

  On GPU with gather dimension 0, a unit dimension is broadcast in and a
  native XLA AllGather is emitted; every other case falls back to an
  emulation built from psum.
  """
  # TODO(cjfj): Enable this for TPU also?
  use_native_all_gather = (platform == 'gpu') and (all_gather_dimension == 0)
  if not use_native_all_gather:
    fallback = xla.lower_fun(_all_gather_via_psum, multiple_results=False,
                             parallel=True)
    return fallback(c, x, all_gather_dimension=all_gather_dimension,
                    axis_name=axis_name, axis_index_groups=axis_index_groups,
                    axis_size=axis_size, axis_env=axis_env, platform=platform)

  # Insert a unit dimension at the gather position so AllGather has an axis
  # to concatenate shards along.
  dims = list(c.get_shape(x).dimensions())
  dims.insert(all_gather_dimension, 1)
  broadcast_dims = [d for d in range(len(dims)) if d != all_gather_dimension]
  x = xops.BroadcastInDim(x, dims, broadcast_dims)
  groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  return xops.AllGather(
      x,
      all_gather_dimension=all_gather_dimension,
      shard_count=axis_size,
      replica_groups=xc.make_replica_groups(groups))
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
  """Lowers a single-operand all-reduce of `val` with reducer `prim` to an
  XLA AllReduce over the replica groups derived from the axis environment."""
  scalar_aval = ShapedArray((), c.get_shape(val).numpy_dtype())
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  return xops.AllReduce(val, reducer, xc.make_replica_groups(groups),
                        None, None)
def _all_to_all_translation_rule(c, x, split_axis, concat_axis, replica_groups,
                                 platform=None):
  """Lowers an all_to_all over `replica_groups` to an XLA AllToAll.

  Raises:
    ValueError: if the replica groups are not all the same size.
  """
  group_size = len(replica_groups[0])
  # Workaround for AllToAll not being implemented on CPU: a one-member
  # group only exchanges with itself, so the op is the identity.
  if group_size == 1:
    return x
  if any(len(g) != group_size for g in replica_groups):
    raise ValueError('Replica groups must be equally sized')
  groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllToAll(x, split_axis, concat_axis, group_size, groups_protos)
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
  """Lowers a (possibly multi-operand) all-reduce with reducer `prim` to XLA.

  On CPU/TPU this defers to the non-tuple lowering. Otherwise it groups the
  arguments by dtype and emits one tuple all-reduce per dtype; complex-dtype
  sums are split into real/imaginary parts first (see the TODO below).

  Args:
    prim: the binary reduction primitive (e.g. lax.add_p).
    c: the XLA builder/computation context.
    *args: XLA operands to reduce across replicas.
    axis_name: the mapped axis name being reduced over.
    axis_index_groups: optional index groups restricting the reduction.
    axis_env: axis environment used to derive replica groups.
    platform: lowering target platform string.

  Returns:
    An XLA tuple op holding the reduced results, one per input argument, in
    the original argument order.
  """
  if platform in ("cpu", "tpu"):
    return _notuple_allreduce_translation_rule(
        prim, c, *args, axis_name=axis_name,
        axis_index_groups=axis_index_groups, axis_env=axis_env,
        platform=platform)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  # sorted() makes the per-dtype iteration order deterministic.
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex and prim is lax.add_p:
      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
      # special case because it's not currently handled by XLA:GPU
      # First n entries are the real parts, the next n the imaginary parts.
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex and prim is lax.add_p:
      # Recombine the separately-reduced real/imag halves.
      xs = [
          xops.Complex(xops.GetTupleElement(all_reduce, i),
                       xops.GetTupleElement(all_reduce, n + i))
          for i in range(n)
      ]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    # Scatter the per-dtype results back to their original positions.
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
  """Lowers a single-operand all-reduce of `val` with reducer `prim` to an
  XLA AllReduce over the given replica groups.

  Args:
    prim: the binary reduction primitive (e.g. lax.add_p).
    c: the XLA builder/computation context.
    val: the XLA operand to reduce across replicas.
    replica_groups: replica groups over which to reduce.
    platform: lowering target platform string (unused here).

  Returns:
    An XLA AllReduce op.
  """
  # Use the snake_case builder accessor (get_shape) for consistency with the
  # other translation rules in this file, which all call c.get_shape.
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)