Example #1
    return wrapped


def _sharding_constraint_impl(x, partitions):
    # TODO(skye): can we also prevent this from being called in other
    # non-sharded_jit contexts? (e.g. pmap, control flow)
    raise NotImplementedError(
        "with_sharding_constraint() should only be called inside sharded_jit()"
    )


sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
ad.deflinear2(
    sharding_constraint_p, lambda ct, _, partitions:
    (with_sharding_constraint(ct, partitions), ))


def _sharding_constraint_lowering(ctx, x_node, partitions):
    return [
        mlir.wrap_with_sharding_op(x_node, xla.sharding_to_proto(partitions))
    ]


mlir.register_lowering(sharding_constraint_p, _sharding_constraint_lowering)


def with_sharding_constraint(x, partitions: Optional[pxla.PartitionSpec]):
    """Identity-like function that specifies how ``x`` should be sharded.
Example #2
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
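
The pattern shared by the examples here is ad.deflinear2: for a primitive whose output is linear in its inputs, the JVP is just the primitive applied to the tangents, so only a transpose rule mapping the output cotangent back to input cotangents needs to be supplied. Below is a minimal sketch, not taken from the excerpts, using a hypothetical scale2 primitive and the same JAX-internal modules these snippets rely on:

import jax
from jax import core
from jax.interpreters import ad

scale2_p = core.Primitive('scale2')        # hypothetical primitive: y = 2 * x
scale2_p.def_impl(lambda x: 2 * x)         # concrete evaluation
scale2_p.def_abstract_eval(lambda x: x)    # shape/dtype are unchanged
# y = 2 * x is linear in x, so the input cotangent is simply 2 * ct.
ad.deflinear2(scale2_p, lambda ct, _: (2 * ct,))

print(jax.grad(lambda x: scale2_p.bind(x))(3.0))  # 2.0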
Example #3
def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
    nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
    nonzero_in_cts = psum_p.bind(*nonzero_out_cts,
                                 axis_name=axis_name,
                                 axis_index_groups=axis_index_groups)
    return tree_util.tree_unflatten(treedef, nonzero_in_cts)


psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule,
                                            lax.add_p)  # type: ignore
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.collective_rules[psum_p] = \
  partial(_batched_reduction_collective,
          psum_p,
          lambda v, d: v.sum(d),
          lambda v, axis_size: axis_size * v)


# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
    if all(not isinstance(x, core.Tracer) for x in args):
        if axis_index_groups is not None:
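
As an aside (illustrative code, not part of the snippet): the comment above about psum(1, 'i') refers to the common idiom of reading the axis size inside a pmap. Thanks to the custom bind rule, a psum of trace-time constants is folded into a multiplication by the axis size instead of emitting a collective:

import jax
import jax.numpy as jnp

def device_mean(x):
    axis_size = jax.lax.psum(1, 'i')           # constant, evaluated at trace time
    return jax.lax.psum(x, 'i') / axis_size    # mean over the mapped axis

xs = jnp.arange(float(jax.local_device_count()))
means = jax.pmap(device_mean, axis_name='i')(xs)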
Example #4
File: dispatch.py  Project: jbampton/jax
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return device_array.make_device_array(x.aval, device, moved_buf)


def _device_put_impl(x, device: Optional[Device] = None):
  if device_array.type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  try:
    a = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(*device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
xla.translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)

def _device_put_lowering(ctx, x, *, device):
  return [x]


mlir.register_lowering(device_put_p, _device_put_lowering)
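
For context (a sketch, not from the excerpt): the deflinear2 registration above treats device_put as the identity under autodiff, so the cotangent passes straight through a device transfer placed inside a differentiated function:

import jax
import jax.numpy as jnp

dev = jax.devices()[0]
f = lambda x: jnp.sum(jax.device_put(x, dev) ** 2)
print(jax.grad(f)(jnp.arange(3.0)))  # [0. 2. 4.], unaffected by the transfer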
Example #5
            window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
        window_strides = window_strides[:bdim] + (1, ) + window_strides[bdim:]
        padding = padding[:bdim] + ((0, 0), ) + padding[bdim:]
        base_dilation = base_dilation[:bdim] + (1, ) + base_dilation[bdim:]
        window_dilation = window_dilation[:bdim] + (
            1, ) + window_dilation[bdim:]

    operand = reduce_window(operand, window_dimensions, window_strides,
                            padding, base_dilation, window_dilation)
    return operand, bdim


reduce_window_sum_p = lax.standard_primitive(
    _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum',
    _reduce_window_sum_translation_rule)
ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = partial(
    _reduce_window_batch_rule, _reduce_window_sum)


def _reduce_window_chooser_translation_rule(prim, identity, ctx, avals_in,
                                            avals_out, operand, *,
                                            window_dimensions, window_strides,
                                            padding, base_dilation,
                                            window_dilation):
    operand_aval, = avals_in
    scalar = ShapedArray((), operand_aval.dtype)
    return [
        xops.ReduceWindowWithGeneralPadding(
            operand,
            xla.pyval_to_ir_constant(ctx.builder,
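
(The excerpt is cut off above.) On the user side, reduce_window_sum_p is normally reached through the public lax.reduce_window with an add computation; a hedged sketch of a 2x2 sum pooling that exercises this primitive:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(4, 4)
pooled = lax.reduce_window(x, 0.0, lax.add,
                           window_dimensions=(2, 2),
                           window_strides=(2, 2),
                           padding='VALID')
# pooled has shape (2, 2); each entry is the sum of one 2x2 block of x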