# NOTE(review): this chunk was pasted with its newlines stripped; formatting
# below is reconstructed.  The first line is the tail of a call that begins
# before this chunk (presumably `psum = partial(_allreduce_translation_rule,`
# inside `_psum_translation_rule`) — confirm against the full file.
                 lax.add_p, c, replica_groups=replica_groups)
  dtype = c.GetShape(val).numpy_dtype()
  if dtypes.issubdtype(dtype, onp.complexfloating):
    # Complex sums are done as two real all-reduces, one per component,
    # recombined with c.Complex.
    return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))
  else:
    return psum(val)

# Register psum (all-reduce sum) and its rules with the JAX machinery.
psum_p = standard_pmap_primitive('psum')
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure-evaluation rule: summing a value replicated across `shape` devices is
# multiplication by the replica count.
pxla.parallel_pure_rules[psum_p] = lambda x, shape: x * prod(shape)
# psum is linear; its transpose applies psum to the cotangent.
ad.deflinear(psum_p, lambda t, axis_name: [psum(t, axis_name)])
pxla.multi_host_supported_collectives.add(psum_p)

# Register pmax (all-reduce max).
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# Register pmin (all-reduce min).
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
# NOTE(review): newlines in this chunk were stripped; formatting reconstructed.
# `tie_all_p` is an identity primitive: its impl returns its arguments
# unchanged, which (presumably) forces a data dependence between them in a
# JAX trace without affecting values.
tie_all_p = jax_core.Primitive('tie_all')
tie_all_p.multiple_results = True
tie_all_p.def_impl(lambda *args: args)
tie_all_p.def_abstract_eval(lambda *args: safe_map(  # pylint: disable=g-long-lambda
    abstract_arrays.raise_to_shaped, args))
# Lowers to an XLA tuple of the untouched operands.
xla.translations[tie_all_p] = lambda c, *args: xc.ops.Tuple(c, args)


def _tie_all_batch_rule(batched_args, batch_dims):
  # Identity primitive: batched values and their batch dims pass through.
  return batched_args, batch_dims


def _tie_all_transpose(cts_in, *args, **params):
  # Linear identity: cotangents pass through unchanged.
  del args, params
  return cts_in
ad.deflinear(tie_all_p, _tie_all_transpose)
batching.primitive_batchers[tie_all_p] = _tie_all_batch_rule


def tie_all(*args):
  """An identity function that ties arguments together in a JAX trace."""
  flat_args, in_tree = tree_util.tree_flatten(args)
  # With zero or one leaf there is nothing to tie; return inputs as-is.
  if len(flat_args) <= 1:
    return args
  out = tie_all_p.bind(*flat_args)
  return tree_util.tree_unflatten(in_tree, out)


def tie_in(x, y):
  """A reimplementation of `jax.tie_in` that handles pytrees."""
  # Ties y to x and returns y's slot of the tied pair.
  return tie_all(x, y)[1]
# NOTE(review): newlines stripped; formatting reconstructed.  The first two
# lines are the tail of a function whose header is outside this chunk
# (presumably `_psum_translation_rule` with a nested `_translate` helper),
# and the final `psum_bind` is truncated mid-body — confirm against the
# full file.
    return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  # psum is linear: transpose by re-applying psum to the (flattened)
  # cotangents, preserving the original tree structure.
  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
                               axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, nonzero_in_cts)

# Register psum and all its interpreter rules.
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
# Collective batching: reduce with a positional sum; a value constant over
# the axis sums to axis_size * value.
batching.collective_rules[psum_p] = \
    partial(_batched_reduction_collective, psum_p,
            lambda v, d: v.sum(d),
            lambda v, axis_size: axis_size * v)

# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
  # Fast path: if no argument is a tracer the reduction can be computed
  # eagerly at trace time.
  if all(not isinstance(x, core.Tracer) for x in args):
    if axis_index_groups is not None:
      # NOTE(review): chunk is truncated here; the remainder of psum_bind is
      # not visible.
# NOTE(review): newlines stripped; formatting reconstructed.  The first four
# lines are the tail of a translation rule whose header is outside this chunk
# (complex values are reduced as separate real/imaginary all-reduces) —
# confirm against the full file.
      return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
    else:
      return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

# Register psum (multiple-results variant) and its rules.
psum_p = standard_pmap_primitive('psum', multiple_results=True)
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure-evaluation rule: scale each operand by the replica count.
# NOTE(review): this returns a generator, not a list — verify downstream
# consumers accept a lazy iterable.
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
# psum is linear; transpose re-binds psum on the cotangents, forwarding
# axis_index_groups.
ad.deflinear(
    psum_p,
    lambda ts, axis_name, axis_index_groups: psum_p.bind(
        *ts, axis_name=axis_name, axis_index_groups=axis_index_groups))
pxla.multi_host_supported_collectives.add(psum_p)

# Register pmax (all-reduce max).
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# Register pmin (all-reduce min).
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
# NOTE(review): newlines stripped; formatting reconstructed.  The first five
# lines are the tail of a translation rule whose header is outside this
# chunk — confirm against the full file.
    if dtypes.issubdtype(dtype, onp.complexfloating):
      # Complex values are all-reduced component-wise and recombined.
      return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
    else:
      return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

# Register psum (multiple-results variant) and its rules.
psum_p = standard_pmap_primitive('psum', multiple_results=True)
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure-evaluation rule: scale each operand by the replica count.
# NOTE(review): this returns a generator, not a list — verify downstream
# consumers accept a lazy iterable.
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
# psum is linear; transpose applies psum to the cotangents.
# NOTE(review): unlike the bind-based transpose elsewhere in this history,
# this form does not forward axis_index_groups — confirm intended.
ad.deflinear(psum_p, lambda ts, axis_name: psum(ts, axis_name=axis_name))
pxla.multi_host_supported_collectives.add(psum_p)

# Register pmax (all-reduce max).
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# Register pmin (all-reduce min).
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
# NOTE(review): newlines stripped; formatting reconstructed.  `sow_p` and
# `_sow_impl` are defined before this chunk — confirm against the full file.
sow_p.def_impl(_sow_impl)


def _sow_abstract_eval(*avals, **_):
  # sow is an identity on abstract values.
  return avals
sow_p.def_abstract_eval(_sow_abstract_eval)


def _sow_transpose(cts_in, *_, **__):
  # Linear identity: cotangents pass through unchanged.
  return cts_in
ad.deflinear(sow_p, _sow_transpose)


def _sow_batch_rule(batched_args, batch_dims, **params):
  # Re-bind under the batch; output batch dims equal input batch dims since
  # sow does not change values.
  outs = sow_p.bind(*batched_args, **params)
  return outs, batch_dims
batching.primitive_batchers[sow_p] = _sow_batch_rule
# Lowers to an XLA tuple of the untouched operands.
xla.translations[sow_p] = lambda c, *args, **params: xc.ops.Tuple(c, args)

# nest is a call primitive; its impl just invokes the wrapped function.
nest_p = jax_core.CallPrimitive('nest')


def _nest_impl(f, *args, **_):
  return f.call_wrapped(*args)
# NOTE(review): newlines stripped; formatting reconstructed.  The first three
# lines are the tail of a function whose header is outside this chunk
# (presumably `_irfft_transpose`) — confirm against the full file.
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  return out

def fft_transpose_rule(t, fft_type, fft_lengths):
  # RFFT and IRFFT are not self-transpose and need dedicated helpers; the
  # remaining (complex-to-complex) FFT types transpose to themselves.
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  # Move the batch dimension to the front; fft then treats it as a leading
  # batch axis, so the output batch dim is 0.
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

# Register the fft primitive and its rules.
fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
# Prefer the pocketfft CPU kernel when the optional module is available.
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
# NOTE(review): newlines stripped; formatting reconstructed.  The first line
# is the tail of a function whose header is outside this chunk — confirm
# against the full file.
  return [t]

def _psum_parallel_translation_rule(c, val, device_groups):
  # With more than one group, pass the groups to CrossReplicaSum; otherwise
  # reduce across all replicas.
  if len(device_groups) > 1:
    return c.CrossReplicaSum(val, device_groups)
  else:
    return c.CrossReplicaSum(val)

# Register psum.  The impl raises on use outside a pmap (unbound axis name).
psum_p = PmapPrimitive('psum')
psum_p.def_impl(partial(_unbound_name_error, 'psum'))
psum_p.def_abstract_eval(lambda x, *args, **kwargs: x)
parallel.serial_pmap_primitive_rules[psum_p] = _psum_serial_pmap_rule
pxla.parallel_translation_rules[psum_p] = _psum_parallel_translation_rule
ad.deflinear(psum_p, _psum_transpose_rule)
parallel.defreducer(lax.reduce_sum_p, psum_p)


def _pmax_serial_pmap_rule(vals, axes):
  # Serial pmap: a pmax over the mapped axis is a plain reduce-max along
  # that axis; the result no longer carries a mapped axis (None).
  val, = vals
  axis, = axes
  return lax._reduce_max(val, [axis]), None

# Register pmax.
pmax_p = PmapPrimitive('pmax')
pmax_p.def_impl(partial(_unbound_name_error, 'pmax'))
pmax_p.def_abstract_eval(lambda x, *args, **kwargs: x)
parallel.serial_pmap_primitive_rules[pmax_p] = _pmax_serial_pmap_rule
parallel.defreducer(lax.reduce_max_p, pmax_p)
# NOTE(review): newlines stripped; formatting reconstructed.  The first four
# lines are the tail of a translation rule whose header is outside this chunk
# (builds a scalar reduction computation for AllReduce) — confirm against
# the full file.
  dtype = c.GetShape(val).numpy_dtype()
  scalar = xla_bridge.Shape.array_shape(dtype, ())
  computation = xla.primitive_computation(prim, scalar, scalar)
  return c.AllReduce(val, computation, replica_groups=device_groups)

# Register psum.
psum_p = PmapPrimitive('psum')
parallel.defreducer(lax.reduce_sum_p, psum_p)
parallel.serial_pmap_primitive_rules[psum_p] = \
    partial(_allreduce_serial_pmap_rule, lax._reduce_sum)
# TODO(mattjj): replace translation rule when we update jaxlib
# pxla.parallel_translation_rules[psum_p] = \
#     partial(_allreduce_translation_rule, lax.add_p)
pxla.parallel_translation_rules[psum_p] = \
    lambda c, val, device_groups: c.CrossReplicaSum(val, device_groups)
# NOTE(review): transpose here passes the cotangent through unchanged rather
# than re-applying psum — verify this matches the intended semantics for
# this version of the API.
ad.deflinear(psum_p, lambda t, axis_name: [t])

# Register pmax.
pmax_p = PmapPrimitive('pmax')
parallel.defreducer(lax.reduce_max_p, pmax_p)
parallel.serial_pmap_primitive_rules[pmax_p] = \
    partial(_allreduce_serial_pmap_rule, lax._reduce_max)
pxla.parallel_translation_rules[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)

# Register pmin.
pmin_p = PmapPrimitive('pmin')
parallel.defreducer(lax.reduce_min_p, pmin_p)
parallel.serial_pmap_primitive_rules[pmin_p] = \
    partial(_allreduce_serial_pmap_rule, lax._reduce_min)
pxla.parallel_translation_rules[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)