Пример #1
0
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return device_array.make_device_array(x.aval, device, moved_buf)


def _device_put_impl(x, device: Optional[Device] = None):
  """Eager implementation of the `device_put` primitive.

  Values recognised by `device_array.type_is_device_array` are handed to
  `_copy_device_array_to_device`. Anything else is abstractified, transferred
  with `device_put`, and wrapped by the result handler selected for its
  abstract value.

  Args:
    x: the value to place.
    device: optional target device; forwarded unchanged to the helpers.

  Returns:
    Whatever the copy helper or the result handler returns.

  Raises:
    TypeError: if `xla.abstractify` rejects `x`.
  """
  if not device_array.type_is_device_array(x):
    # Only the abstractification is guarded, so TypeErrors raised by the
    # transfer itself propagate untouched.
    try:
      aval = xla.abstractify(x)
    except TypeError as err:
      raise TypeError(
          f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
    handler = aval_to_result_handler(device, aval)
    buffers = device_put(x, device)
    return handler(*buffers)
  return _copy_device_array_to_device(x, device)

# Create the `device_put` primitive and attach its rules. The statements below
# mutate module-level registries, so the primitive must be created first.
device_put_p = core.Primitive('device_put')
# Eager evaluation is handled by `_device_put_impl`.
device_put_p.def_impl(_device_put_impl)
# Abstract evaluation returns its first argument unchanged.
device_put_p.def_abstract_eval(lambda x, device=None: x)
# The XLA translation returns the operand `x` as-is.
xla.translations[device_put_p] = lambda c, x, device=None: x
# Registered through `ad.deflinear2`; the supplied rule returns the cotangent
# unchanged inside a one-element list.
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
# Registered with `defvectorized` for both the masking and batching modules.
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)

def _device_put_lowering(ctx, x, *, device):
  """Lowering rule for `device_put`: pass the operand through unchanged.

  Args:
    ctx: lowering context; not used.
    x: the operand.
    device: keyword-only target device; not used.

  Returns:
    A one-element list containing `x` itself.
  """
  del ctx, device  # Accepted to match the rule's call signature only.
  outputs = [x]
  return outputs


# Register `_device_put_lowering` as the MLIR lowering rule for `device_put_p`.
mlir.register_lowering(device_put_p, _device_put_lowering)
Пример #2
0
def _standard_gamma_grad(sample, alpha):
    """Apply `_standard_gamma_grad_one` elementwise to `sample` and `alpha`.

    Both inputs are flattened to 1-D, the per-element function is mapped over
    the flattened pairs with `vmap`, and the result is reshaped to
    `alpha.shape`.
    """
    flat_sample = np.reshape(sample, -1)
    flat_alpha = np.reshape(alpha, -1)
    elementwise_grad = vmap(_standard_gamma_grad_one)
    return elementwise_grad(flat_sample, flat_alpha).reshape(alpha.shape)


@custom_transforms
def _standard_gamma_p(key, alpha):
    """Forward `key` and `alpha` to `_standard_gamma_impl`.

    The `custom_transforms` wrapper exposes a `.primitive` attribute, which
    the module uses to attach differentiation and batching rules.
    """
    samples = _standard_gamma_impl(key, alpha)
    return samples


# JVP registration for the gamma primitive. `None` fills the slot before the
# `alpha` rule (presumably "no rule for `key`" -- confirm against
# `ad.defjvp2`). The `alpha` rule multiplies the incoming tangent by
# `_standard_gamma_grad(sample, alpha)`.
ad.defjvp2(
    _standard_gamma_p.primitive, None, lambda tangent, sample, key, alpha, **
    kwargs: tangent * _standard_gamma_grad(sample, alpha))
# Register the primitive with `batching.defvectorized`.
batching.defvectorized(_standard_gamma_p.primitive)


@partial(jit, static_argnums=(2, 3))
def _standard_gamma(key, alpha, shape=(), dtype=np.float32):
    """Jitted core of `standard_gamma`; `shape` and `dtype` are static.

    A falsy `shape` falls back to the shape of `alpha`. `alpha` is converted
    to `dtype`, broadcast to the target shape when its shape differs, and
    passed with `key` to `_standard_gamma_p`.
    """
    target_shape = shape or np.shape(alpha)
    cast_alpha = lax.convert_element_type(alpha, dtype)
    needs_broadcast = np.shape(cast_alpha) != target_shape
    if needs_broadcast:
        cast_alpha = np.broadcast_to(cast_alpha, target_shape)
    return _standard_gamma_p(key, cast_alpha)


def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    """Public entry point; delegates to the jitted `_standard_gamma`.

    Args:
        key: forwarded unchanged.
        alpha: forwarded unchanged.
        shape: requested output shape; `()` lets the helper use alpha's shape.
        dtype: element type `alpha` is converted to by the helper.
    """
    # Passed positionally: `shape` and `dtype` sit at the helper's static
    # argument positions (2 and 3).
    samples = _standard_gamma(key, alpha, shape, dtype)
    return samples

Пример #3
0
# XXX work around the issue: batching rule for 'reduce_window' not implemented
# when using @custom_transforms decorator
def _cumprod_impl(x):
    """Return the cumulative product of `x` along its last axis."""
    last_axis = -1
    return np.cumprod(x, axis=last_axis)


# Hand-built `cumprod` primitive. The statements below mutate module-level
# registries, so the primitive must be created before its rules are attached.
cumprod_p = core.Primitive('cumprod')
# Eager evaluation calls `_cumprod_impl` directly.
cumprod_p.def_impl(_cumprod_impl)
# Abstract evaluation is built from `_cumprod_impl` via
# `partial_eval.abstract_eval_fun`.
cumprod_p.def_abstract_eval(
    partial(partial_eval.abstract_eval_fun, _cumprod_impl))
# The XLA translation is built from `_cumprod_impl` via `xla.lower_fun`.
xla.translations[cumprod_p] = partial(xla.lower_fun, _cumprod_impl)
# XXX: this JVP divides the tangent by `x`, so it does not handle x == 0; the
# result in that case will be nan.
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
ad.defjvp2(cumprod_p, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
# Register the primitive with `batching.defvectorized`.
batching.defvectorized(cumprod_p)


def cumprod(x):
    """Bind `x` to the `cumprod_p` primitive and return the result."""
    result = cumprod_p.bind(x)
    return result


def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [np.shape(arg) for arg in args]
        num_dims = len(lax.broadcast_shapes(shape, *shapes))
        return [
            lax.reshape(arg, (1, ) * (num_dims - len(s)) +
Пример #4
0
def _standard_gamma_grad(sample, alpha):
    """Map `_standard_gamma_grad_one` over matching elements of the inputs.

    `sample` and `alpha` are each flattened to 1-D before `vmap` applies the
    per-element function; the mapped output is reshaped to `alpha.shape`.
    """
    sample_1d = np.reshape(sample, -1)
    alpha_1d = np.reshape(alpha, -1)
    per_element = vmap(_standard_gamma_grad_one)
    flat_grads = per_element(sample_1d, alpha_1d)
    return flat_grads.reshape(alpha.shape)


@custom_transforms
def _standard_gamma_p(key, alpha):
    """Delegate to `_standard_gamma_impl` with the same `key` and `alpha`.

    Decorated with `custom_transforms`; the module attaches its JVP and
    batching rules to the resulting `.primitive` attribute.
    """
    drawn = _standard_gamma_impl(key, alpha)
    return drawn


# JVP registration for the gamma primitive. `None` fills the slot before the
# `alpha` rule (presumably "no rule for `key`" -- confirm against
# `ad.defjvp2`). The `alpha` rule multiplies the incoming tangent by
# `_standard_gamma_grad(sample, alpha)`.
ad.defjvp2(_standard_gamma_p.primitive, None,
           lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))
# Register the primitive with `batching.defvectorized`.
batching.defvectorized(_standard_gamma_p.primitive)


@partial(jit, static_argnums=(2, 3))
def _standard_gamma(key, alpha, shape, dtype):
    """Jitted sampler core; `shape` and `dtype` are static arguments.

    A falsy `shape` is replaced by the shape of `alpha`. `alpha` is converted
    to `dtype` and broadcast to the target shape if its shape differs, then
    passed together with `key` to `_standard_gamma_p`.
    """
    out_shape = shape or np.shape(alpha)
    typed_alpha = lax.convert_element_type(alpha, dtype)
    shape_differs = np.shape(typed_alpha) != out_shape
    if shape_differs:
        typed_alpha = np.broadcast_to(typed_alpha, out_shape)
    return _standard_gamma_p(key, typed_alpha)


def standard_gamma(key, alpha, shape=(), dtype=np.float64):
    """Public entry point; canonicalizes `dtype`, then calls `_standard_gamma`.

    Args:
        key: forwarded unchanged.
        alpha: forwarded unchanged.
        shape: requested output shape; `()` lets the helper use alpha's shape.
        dtype: requested element type, passed through
            `xla_bridge.canonicalize_dtype` before use.
    """
    canonical_dtype = xla_bridge.canonicalize_dtype(dtype)
    # Passed positionally: `shape` and `dtype` sit at the helper's static
    # argument positions (2 and 3).
    return _standard_gamma(key, alpha, shape, canonical_dtype)