Example No. 1
def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
                                 axis_env, platform):
  # Workaround for AllToAll not being implemented on CPU.
  replica_groups = _replica_groups(axis_env, axis_name, None)
  if len(replica_groups[0]) == 1:
    return x
  elif platform != 'tpu':
    warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other "
                  "backends emulate it using a very slow and memory intensive algorithm, so expect "
                  "significant slowdowns.")
    lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True)
    return lowering(c, x,
                    split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name,
                    axis_env=axis_env, platform=platform)
  else:
    split_count = len(replica_groups[0])
    if not all(split_count == len(g) for g in replica_groups):
      raise ValueError('Replica groups must be equally sized')
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    if concat_axis == split_axis:
      return xops.AllToAll(x, split_axis, concat_axis, split_count,
                           replica_groups_protos)
    else:
      if concat_axis < split_axis:
        split_axis += 1
      elif split_axis < concat_axis:
        concat_axis += 1
      x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)(c, x)
      x = xops.AllToAll(x, split_axis, concat_axis, split_count, replica_groups_protos)
      x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x)
      return x
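For context, a minimal usage sketch, assuming the current public jax.lax.all_to_all and jax.pmap APIs: the rule above is what lowers lax.all_to_all when it appears inside a pmap-ed function.

# Usage sketch, not part of the original example; assumes the public JAX API.
# Each device holds an (n, 2) block; all_to_all splits axis 0 across the 'i'
# axis and concatenates the received pieces along axis 1, giving (1, 2 * n).
import jax
import jax.numpy as jnp
from jax import lax

def exchange(x):
  return lax.all_to_all(x, axis_name='i', split_axis=0, concat_axis=1)

n = jax.local_device_count()
xs = jnp.arange(float(n * n * 2)).reshape(n, n, 2)  # one (n, 2) block per device
ys = jax.pmap(exchange, axis_name='i')(xs)          # per-device result: (1, 2 * n)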
Example No. 2
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce per input dtype.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    n = len(dtype_args)
    if is_complex:
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i)) for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
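A minimal usage sketch of the situation the per-dtype grouping handles, assuming the public jax.lax.psum API: psum over a pytree whose leaves have different dtypes, which the rule above lowers as one AllReduce per dtype.

# Usage sketch, not part of the original example; assumes the public JAX API.
# The float32 and int32 leaves are reduced in separate AllReduce ops by the
# per-dtype grouping shown above.
import jax
import jax.numpy as jnp
from jax import lax

def summed(pair):
  return lax.psum(pair, axis_name='i')

n = jax.local_device_count()
floats = jnp.ones((n, 3), jnp.float32)
ints = jnp.ones((n, 3), jnp.int32)
out_f, out_i = jax.pmap(summed, axis_name='i')((floats, ints))
# Every entry of out_f and out_i equals n, the number of participating devices.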
Example No. 3
def all_reduce(x):
    replica_groups_protos = xc.make_replica_groups(
        _replica_groups(axis_env, axis_name, axis_index_groups))
    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    return xops.AllReduce(x, computation, replica_groups_protos, None, None)
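This nested helper closes over c, prim, and the axis arguments of the enclosing rule; the axis_index_groups it forwards to _replica_groups restricts the reduction to subsets of devices. A minimal sketch of that behaviour, assuming the public jax.lax.psum API and an even local device count of at least two:

# Usage sketch, not part of the original example; assumes the public JAX API
# and an even local device count of at least two (a hypothetical setup).
# axis_index_groups splits the 'i' axis into two halves, and each half is
# summed independently; these groups become the replica groups built above.
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
groups = [list(range(0, n // 2)), list(range(n // 2, n))]  # two disjoint halves

def grouped_sum(x):
  return lax.psum(x, axis_name='i', axis_index_groups=groups)

xs = jnp.arange(float(n)).reshape(n, 1)
ys = jax.pmap(grouped_sum, axis_name='i')(xs)  # each half sums only its own members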
Example No. 4
def _all_gather_translation_rule(c, x, *, all_gather_dimension, axis_name,
                                 axis_index_groups, axis_size, axis_env,
                                 platform):
    # TODO(cjfj): Enable this for TPU also?
    if (platform == 'gpu') and (all_gather_dimension == 0):
        new_shape = list(c.get_shape(x).dimensions())
        new_shape.insert(all_gather_dimension, 1)
        broadcast_dimensions = [
            i for i in range(len(new_shape)) if i != all_gather_dimension
        ]
        x = xops.BroadcastInDim(x, new_shape, broadcast_dimensions)
        replica_groups = _replica_groups(axis_env, axis_name,
                                         axis_index_groups)
        return xops.AllGather(
            x,
            all_gather_dimension=all_gather_dimension,
            shard_count=axis_size,
            replica_groups=xc.make_replica_groups(replica_groups))
    else:
        lowering = xla.lower_fun(_all_gather_via_psum,
                                 multiple_results=False,
                                 parallel=True)
        return lowering(c,
                        x,
                        all_gather_dimension=all_gather_dimension,
                        axis_name=axis_name,
                        axis_index_groups=axis_index_groups,
                        axis_size=axis_size,
                        axis_env=axis_env,
                        platform=platform)
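A minimal usage sketch of the gather itself, assuming the public jax.lax.all_gather API: each device contributes its slice and receives the full stacked array.

# Usage sketch, not part of the original example; assumes the public JAX API.
# Each device holds a (2,) slice; all_gather stacks the slices along a new
# leading axis so every device ends up with the full (n, 2) array.
import jax
import jax.numpy as jnp
from jax import lax

def gather(x):
  return lax.all_gather(x, axis_name='i')

n = jax.local_device_count()
xs = jnp.arange(float(n * 2)).reshape(n, 2)
ys = jax.pmap(gather, axis_name='i')(xs)  # result shape: (n, n, 2)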
Example No. 5
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)
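Because the rule takes the reduction primitive prim as a parameter, the same lowering can serve psum, pmax, and pmin. A minimal sketch of those front-end calls, assuming the public JAX API:

# Usage sketch, not part of the original example; assumes the public JAX API.
# psum, pmax and pmin all reduce across the named axis; the rule above emits
# the corresponding AllReduce with an add, max or min subcomputation.
import jax
import jax.numpy as jnp
from jax import lax

def stats(x):
  return lax.psum(x, 'i'), lax.pmax(x, 'i'), lax.pmin(x, 'i')

n = jax.local_device_count()
xs = jnp.arange(float(n))
total, mx, mn = jax.pmap(stats, axis_name='i')(xs)
# total is n * (n - 1) / 2 everywhere, mx is n - 1, mn is 0.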
Example No. 6
def _all_to_all_translation_rule(c, x, split_axis, concat_axis, replica_groups,
                                 platform=None):
  # Workaround for AllToAll not being implemented on CPU.
  if len(replica_groups[0]) == 1:
    return x
  else:
    split_count = len(replica_groups[0])
    if not all(split_count == len(g) for g in replica_groups):
      raise ValueError('Replica groups must be equally sized')
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    return xops.AllToAll(x, split_axis, concat_axis, split_count,
                         replica_groups_protos)
Example No. 7
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
    if platform in ("cpu", "tpu"):
        return _notuple_allreduce_translation_rule(
            prim,
            c,
            *args,
            axis_name=axis_name,
            axis_index_groups=axis_index_groups,
            axis_env=axis_env,
            platform=platform)

    # XLA's tuple all-reduce doesn't support different dtypes in the same
    # allreduce. Instead, we perform one all-reduce per input dtype.
    args_by_type = collections.defaultdict(lambda: ([], []))
    for i, arg in enumerate(args):
        indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
        indices.append(i)
        dtype_args.append(arg)

    # The outputs, in the original argument order.
    out = [None] * len(args)
    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        n = len(dtype_args)
        if is_complex and prim is lax.add_p:
            # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
            # special case because it's not currently handled by XLA:GPU
            dtype_args = ([xops.Real(x) for x in dtype_args] +
                          [xops.Imag(x) for x in dtype_args])
        scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
        computation = xla.primitive_subcomputation(prim, scalar, scalar)
        all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                    replica_groups_protos, None, None)
        if is_complex and prim is lax.add_p:
            xs = [
                xops.Complex(xops.GetTupleElement(all_reduce, i),
                             xops.GetTupleElement(all_reduce, n + i))
                for i in range(n)
            ]
        else:
            xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
        for i, x in zip(indices, xs):
            out[i] = x
    return xops.Tuple(c, out)
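A minimal sketch of the case the complex special-casing above targets, assuming the public jax.lax.psum API: on GPU, a complex sum-reduction is split into real and imaginary all-reduces and recombined afterwards.

# Usage sketch, not part of the original example; assumes the public JAX API.
# For a complex64 psum, the rule above all-reduces the real and imaginary
# parts separately and rebuilds the complex result with xops.Complex.
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
zs = (jnp.ones((n, 2)) + 1j * jnp.arange(float(n)).reshape(n, 1)).astype(jnp.complex64)
sums = jax.pmap(lambda z: lax.psum(z, 'i'), axis_name='i')(zs)
# Real parts sum to n; imaginary parts sum to 0 + 1 + ... + (n - 1).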
Example No. 8
def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
    dtype = c.GetShape(val).numpy_dtype()
    scalar = ShapedArray((), dtype)
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    return xops.AllReduce(val, computation, replica_groups_protos, None, None)