Example #1
0
                   lax.add_p,
                   c,
                   replica_groups=replica_groups)
    dtype = c.GetShape(val).numpy_dtype()
    if dtypes.issubdtype(dtype, onp.complexfloating):
        return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))
    else:
        return psum(val)


# Register the 'psum' (all-reduce sum) collective primitive and its rules.
psum_p = standard_pmap_primitive('psum')
# Splitting a mapped axis for psum reduces with a sum over the split axis.
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
# XLA lowering of the cross-replica sum.
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure (single-device) evaluation: the value is multiplied by the product of
# the mapped shape, i.e. scaled by the axis size.
pxla.parallel_pure_rules[psum_p] = lambda x, shape: x * prod(shape)
# psum is linear and its transpose is another psum over the same axis.
ad.deflinear(psum_p, lambda t, axis_name: [psum(t, axis_name)])
# psum participates in multi-host collectives.
pxla.multi_host_supported_collectives.add(psum_p)

# Register 'pmax' (all-reduce max): lowered via the max reduction primitive;
# axis splitting reduces with max over the split axis.
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# Register 'pmin' (all-reduce min), mirroring pmax with min reductions.
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)

Example #2
0
# 'tie_all' primitive: an n-ary identity whose outputs are its inputs,
# used to tie values together in a trace.
tie_all_p = jax_core.Primitive('tie_all')
tie_all_p.multiple_results = True
# Concrete implementation: the identity on its arguments.
tie_all_p.def_impl(lambda *args: args)
# Abstract eval: pass each input aval through, raised to a shaped aval.
tie_all_p.def_abstract_eval(lambda *args: safe_map(  # pylint: disable=g-long-lambda
    abstract_arrays.raise_to_shaped, args))
# XLA lowering: pack the operands into a tuple (identity at the XLA level).
xla.translations[tie_all_p] = lambda c, *args: xc.ops.Tuple(c, args)


def _tie_all_batch_rule(batched_args, batch_dims):
  return batched_args, batch_dims


def _tie_all_transpose(cts_in, *args, **params):
  del args, params
  return cts_in
# tie_all is linear; its transpose and its batching rule are both identities.
ad.deflinear(tie_all_p, _tie_all_transpose)
batching.primitive_batchers[tie_all_p] = _tie_all_batch_rule


def tie_all(*args):
  """An identity function that ties arguments together in a JAX trace."""
  leaves, treedef = tree_util.tree_flatten(args)
  # With at most one leaf there is nothing to tie; return the inputs as-is.
  if len(leaves) <= 1:
    return args
  tied = tie_all_p.bind(*leaves)
  return tree_util.tree_unflatten(treedef, tied)


def tie_in(x, y):
  """A reimplementation of `jax.tie_in` that handles pytrees."""
  _, tied_y = tie_all(x, y)
  return tied_y
Example #3
0
      return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  """Transpose of psum: psum the (flattened) cotangents over the same axis."""
  flat_cts, treedef = tree_util.tree_flatten(cts)
  summed = psum_p.bind(*flat_cts, axis_name=axis_name,
                       axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, summed)

# Register the 'psum' primitive (a multiple-results all-reduce sum).
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
# Abstract eval passes each input aval through, raised to a shaped aval.
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
# soft_pmap handling: reduce with a sum over the soft-mapped axis.
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# psum is linear; its transpose is another psum (see _psum_transpose_rule).
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
# Collective batching: reduce by summing over the batched dimension; a value
# constant over the axis is scaled by the axis size instead.
batching.collective_rules[psum_p] = \
  partial(_batched_reduction_collective,
          psum_p,
          lambda v, d: v.sum(d),
          lambda v, axis_size: axis_size * v)

# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
  if all(not isinstance(x, core.Tracer) for x in args):
    if axis_index_groups is not None:
Example #4
0
            return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
        else:
            return psum(val)

    return xops.Tuple(c, list(map(_translate, args)))


# Register 'psum' as a multiple-results pmap primitive.
psum_p = standard_pmap_primitive('psum', multiple_results=True)
# Abstract eval passes each input aval through, raised to a shaped aval.
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure evaluation scales each argument by the product of the mapped shape.
# NOTE(review): this lambda returns a generator, which can be consumed only
# once — confirm callers materialize it; a tuple may be safer.
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape)
                                                         for x in args)
# psum is linear; its transpose re-binds psum with the same axis parameters.
ad.deflinear(
    psum_p, lambda ts, axis_name, axis_index_groups: psum_p.bind(
        *ts, axis_name=axis_name, axis_index_groups=axis_index_groups))
pxla.multi_host_supported_collectives.add(psum_p)

# Register 'pmax' (all-reduce max): lowered via the max reduction primitive.
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# Register 'pmin' (all-reduce min), mirroring pmax with min reductions.
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
Example #5
0
        if dtypes.issubdtype(dtype, onp.complexfloating):
            return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
        else:
            return psum(val)

    return xops.Tuple(c, list(map(_translate, args)))


# Register 'psum' as a multiple-results pmap primitive.
psum_p = standard_pmap_primitive('psum', multiple_results=True)
# Abstract eval passes each input aval through, raised to a shaped aval.
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure evaluation scales each argument by the product of the mapped shape.
# NOTE(review): this lambda returns a generator, which can be consumed only
# once — confirm callers materialize it; a tuple may be safer.
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape)
                                                         for x in args)
# psum is linear; its transpose psums the cotangents over the same axis.
# NOTE(review): `ts` (the whole tuple) is passed to psum directly — confirm
# psum accepts a pytree of cotangents here.
ad.deflinear(psum_p, lambda ts, axis_name: psum(ts, axis_name=axis_name))
pxla.multi_host_supported_collectives.add(psum_p)

# Register 'pmax' (all-reduce max): lowered via the max reduction primitive.
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# Register 'pmin' (all-reduce min), mirroring pmax with min reductions.
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)

Example #6
0
# Concrete implementation of the 'sow' primitive.
sow_p.def_impl(_sow_impl)


def _sow_abstract_eval(*avals, **_):
    return avals


# Abstract evaluation of 'sow' is the identity on input avals.
sow_p.def_abstract_eval(_sow_abstract_eval)


def _sow_transpose(cts_in, *_, **__):
    return cts_in


# sow is linear; its transpose is the identity on the cotangents.
ad.deflinear(sow_p, _sow_transpose)


def _sow_batch_rule(batched_args, batch_dims, **params):
    """Batching rule for `sow_p`: rebind on the batched args, keep the dims."""
    results = sow_p.bind(*batched_args, **params)
    return results, batch_dims


# Register batching and XLA lowering for sow; at the XLA level sow is just a
# tuple of its inputs (an identity).
batching.primitive_batchers[sow_p] = _sow_batch_rule
xla.translations[sow_p] = lambda c, *args, **params: xc.ops.Tuple(c, args)

# 'nest' is a call primitive: it wraps the execution of a traced function.
nest_p = jax_core.CallPrimitive('nest')


def _nest_impl(f, *args, **_):
    return f.call_wrapped(*args)
Example #7
0
    out = scale * mask * x
    assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
    return out


def fft_transpose_rule(t, fft_type, fft_lengths):
    """Transpose rule for the linear `fft` primitive.

    RFFT and IRFFT use bespoke transpose helpers; every other FFT type
    transposes to an FFT of the same type.
    """
    if fft_type == xla_client.FftType.RFFT:
        return _rfft_transpose(t, fft_lengths),
    if fft_type == xla_client.FftType.IRFFT:
        return _irfft_transpose(t, fft_lengths),
    return fft(t, fft_type, fft_lengths),


def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
    """Batching rule for `fft`: move the batch axis to the front and FFT."""
    (operand,), (dim,) = batched_args, batch_dims
    operand = batching.moveaxis(operand, dim, 0)
    # The batched output carries its batch dimension at axis 0.
    return fft(operand, fft_type, fft_lengths), 0


# Assemble the 'fft' primitive: impl, abstract eval, XLA lowering, linearity
# (transpose rule), and batching.
fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
# Prefer the pocketfft-backed CPU kernel when pocketfft is available.
if pocketfft:
    xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
Example #8
0
    return [t]


def _psum_parallel_translation_rule(c, val, device_groups):
    if len(device_groups) > 1:
        return c.CrossReplicaSum(val, device_groups)
    else:
        return c.CrossReplicaSum(val)


# Register the 'psum' pmap primitive with impl, abstract eval, serial-pmap,
# parallel-translation, and transpose rules.
psum_p = PmapPrimitive('psum')
# Impl defers to _unbound_name_error — presumably psum requires a bound axis
# name and errors otherwise; confirm against the helper's definition.
psum_p.def_impl(partial(_unbound_name_error, 'psum'))
# Abstract eval passes the first operand's aval through unchanged.
psum_p.def_abstract_eval(lambda x, *args, **kwargs: x)
parallel.serial_pmap_primitive_rules[psum_p] = _psum_serial_pmap_rule
pxla.parallel_translation_rules[psum_p] = _psum_parallel_translation_rule
ad.deflinear(psum_p, _psum_transpose_rule)
# Allow a reduce_sum over a mapped axis to be rewritten as a psum.
parallel.defreducer(lax.reduce_sum_p, psum_p)


def _pmax_serial_pmap_rule(vals, axes):
    """Serial-pmap rule for pmax: reduce-max over the single mapped axis."""
    (operand,), (mapped_axis,) = vals, axes
    return lax._reduce_max(operand, [mapped_axis]), None


# Register the 'pmax' pmap primitive, mirroring psum's registration above.
pmax_p = PmapPrimitive('pmax')
# Impl defers to _unbound_name_error — presumably pmax requires a bound axis
# name and errors otherwise; confirm against the helper's definition.
pmax_p.def_impl(partial(_unbound_name_error, 'pmax'))
# Abstract eval passes the first operand's aval through unchanged.
pmax_p.def_abstract_eval(lambda x, *args, **kwargs: x)
parallel.serial_pmap_primitive_rules[pmax_p] = _pmax_serial_pmap_rule
# Allow a reduce_max over a mapped axis to be rewritten as a pmax.
parallel.defreducer(lax.reduce_max_p, pmax_p)
Example #9
0
    dtype = c.GetShape(val).numpy_dtype()
    scalar = xla_bridge.Shape.array_shape(dtype, ())
    computation = xla.primitive_computation(prim, scalar, scalar)
    return c.AllReduce(val, computation, replica_groups=device_groups)


# Register the 'psum' pmap primitive (all-reduce sum).
psum_p = PmapPrimitive('psum')
# Allow a reduce_sum over a mapped axis to be rewritten as a psum.
parallel.defreducer(lax.reduce_sum_p, psum_p)
parallel.serial_pmap_primitive_rules[psum_p] = \
    partial(_allreduce_serial_pmap_rule, lax._reduce_sum)
# TODO(mattjj): replace translation rule when we update jaxlib
# pxla.parallel_translation_rules[psum_p] = \
#     partial(_allreduce_translation_rule, lax.add_p)
pxla.parallel_translation_rules[psum_p] = \
    lambda c, val, device_groups: c.CrossReplicaSum(val, device_groups)
# psum is linear; in this version the transpose is the identity on the
# cotangent (the cotangent is returned unchanged, not re-summed).
ad.deflinear(psum_p, lambda t, axis_name: [t])

# Register the 'pmax' pmap primitive (all-reduce max).
pmax_p = PmapPrimitive('pmax')
parallel.defreducer(lax.reduce_max_p, pmax_p)
parallel.serial_pmap_primitive_rules[pmax_p] = \
    partial(_allreduce_serial_pmap_rule, lax._reduce_max)
pxla.parallel_translation_rules[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)

# Register the 'pmin' pmap primitive (all-reduce min), mirroring pmax.
pmin_p = PmapPrimitive('pmin')
parallel.defreducer(lax.reduce_min_p, pmin_p)
parallel.serial_pmap_primitive_rules[pmin_p] = \
    partial(_allreduce_serial_pmap_rule, lax._reduce_min)
pxla.parallel_translation_rules[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)