Example #1

def _standard_gamma_grad(sample, alpha):
    samples = np.reshape(sample, -1)
    alphas = np.reshape(alpha, -1)
    grads = vmap(_standard_gamma_grad_one)(samples, alphas)
    return grads.reshape(alpha.shape)


@custom_transforms
def _standard_gamma_p(key, alpha):
    return _standard_gamma_impl(key, alpha)


defjvp(
    _standard_gamma_p, None, lambda tangent, sample, key, alpha, **kwargs:
    tangent * _standard_gamma_grad(sample, alpha))


@partial(jit, static_argnums=(2, 3))
def _standard_gamma(key, alpha, shape, dtype):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    return _standard_gamma_p(key, alpha)


def standard_gamma(key, alpha, shape=(), dtype=np.float64):
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _standard_gamma(key, alpha, shape, dtype)
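

# A minimal usage sketch, assuming the older jax custom_transforms/defjvp API
# used above and the aliases np = jax.numpy, random = jax.random:
from jax import grad, random

key = random.PRNGKey(0)
alpha = np.array([0.5, 1.0, 2.0])
samples = standard_gamma(key, alpha, shape=(3,), dtype=np.float32)
# The defjvp rule on _standard_gamma_p is what lets gradients flow through the
# sampler with respect to alpha (pathwise gradients via _standard_gamma_grad):
alpha_grad = grad(lambda a: np.sum(standard_gamma(key, a, (3,), np.float32)))(alpha)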
Example #2
    return solve_triangular(tril_inv, identity, lower=True)


# TODO: move upstream to jax.nn
def binary_cross_entropy_with_logits(x, y):
    # compute -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))
    # Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y
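

# A small numeric sketch of why the clipped form is used: it matches the naive
# -y*log(sigmoid(x)) - (1-y)*log(1-sigmoid(x)) for moderate logits, but never
# takes the log of a quantity that can underflow to 0 (assumes np = jax.numpy):
_x = np.array([-3.0, 0.0, 4.0])
_y = np.array([0.0, 1.0, 1.0])
_sig = 1.0 / (1.0 + np.exp(-_x))
_naive = -_y * np.log(_sig) - (1.0 - _y) * np.log(1.0 - _sig)
_stable = binary_cross_entropy_with_logits(_x, _y)
# np.allclose(_naive, _stable) -> True; at x = 100.0 the naive form produces
# inf/nan (log of an underflowed 0) while the stable form stays finite.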


@custom_transforms
def cumsum(x):
    return np.cumsum(x, axis=-1)


defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))


@custom_transforms
def cumprod(x):
    return np.cumprod(x, axis=-1)


# XXX this implementation does not address the case x=0, hence the result in that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
defjvp(cumprod, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
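

# Illustration of the caveat above: the rule divides the tangent by x, so a
# zero entry poisons the result with inf/nan even though the true derivative
# of cumprod is finite there (assumes np = jax.numpy):
_x = np.array([2.0, 0.0, 3.0])
_ans = np.cumprod(_x, axis=-1)                               # [2., 0., 0.]
_tangent = np.cumsum(np.ones_like(_x) / _x, axis=-1) * _ans  # inf * 0 -> nan entries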


def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
Example #3
    The problem with using np.clip() is that if it modifies a value, the
    gradient of that value becomes zero. If we then take a second derivative,
    that is zero as well, even though the second derivative of the
    squared-distance operation is a positive constant. This is a substantial
    problem in practice and needs a workaround.

    In PyTorch, you could do this using detach, e.g.
    >>> y = x - (torch.clamp(x, max=0.0)).detach()
    But we need something for JAX.
    """

    return np.clip(x, 0.0, None)


jax.defjvp(clip_up, lambda g, ans, x: g)
jax.defvjp(clip_up, lambda g, ans, x: g)
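

# A quick check of the effect (assumes clip_up was wrapped with
# @jax.custom_transforms above, which the defjvp/defvjp registrations
# require, and that np = jax.numpy):
_g_custom = jax.grad(clip_up)(-1.0)                           # 1.0: gradient passes through
_g_plain = jax.grad(lambda x: np.clip(x, 0.0, None))(-1.0)    # 0.0: the clip kills the gradient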


class Dataloader(object):
    def __init__(self, *tensors, batch_size=None):
        self.tensors = tensors
        self.batch_size = self.num_data if batch_size is None else batch_size
        self._shuffled_tensors = None
        self._batches_served = None

    @property
    def num_data(self):
        return len(self.tensors[0])

    @property
Example #4
# activations


@custom_transforms
def relu(x):
    r"""Rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)
  """
    return np.maximum(x, 0)


defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
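

# Note that the rule above uses a strict x > 0, so the derivative at x == 0 is
# taken to be 0. A standalone mirror of that convention (hypothetical helper,
# not part of the API):
def _relu_subgradient(x):
    return np.where(x > 0, 1.0, 0.0)  # 0. at x == 0, matching lax.select(x > 0, ...)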


def softplus(x):
    r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)
  """
    return np.logaddexp(x, 0)
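

# Why logaddexp rather than a literal log(1 + exp(x)): under jax's default
# float32, exp(x) overflows for large x while logaddexp(x, 0) stays finite
# (assumes np = jax.numpy):
_big = 100.0
_naive = np.log(1.0 + np.exp(_big))  # inf: exp(100) overflows in float32
_stable = np.logaddexp(_big, 0.0)    # ~100.0, i.e. softplus(x) -> x for large x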


def soft_sign(x):
    r"""Soft-sign activation function.
Example #5
def _xlogy_jvp_rhs(g, ans, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g = np.broadcast_to(g, shape)
    x = np.broadcast_to(x, shape)
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    return g * lax._safe_mul(x, np.reciprocal(y))


@custom_transforms
def xlogy(x, y):
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    return lax._safe_mul(x, np.log(y))


defjvp(xlogy, _xlogy_jvp_lhs, _xlogy_jvp_rhs)
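

# The point of _safe_mul, in one line: xlogy follows the convention
# 0 * log(0) == 0, whereas a plain product gives 0 * (-inf) = nan
# (assumes np = jax.numpy):
_x = np.array([0.0, 2.0])
_y = np.array([0.0, 3.0])
_naive = _x * np.log(_y)                  # [nan, 2*log(3)]
_safe = np.where(_x == 0.0, 0.0, _naive)  # [0., 2*log(3)], matching xlogy's convention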


def _xlog1py_jvp_lhs(g, ans, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(y))
    g = np.broadcast_to(g, shape)
    y = np.broadcast_to(y, shape)
    g, y = _promote_args_like(osp_special.xlog1py, g, y)
    return lax._safe_mul(g, np.log1p(y))


def _xlog1py_jvp_rhs(g, ans, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g = np.broadcast_to(g, shape)
    x = np.broadcast_to(x, shape)
    x, y = _promote_args_like(osp_special.xlog1py, x, y)
Example #6
    Transform is written such that it acts as the identity during gradient
    backpropagation.

    Args:
        T: Transformation; ndarray(shape=[spatial_dim, spatial_dim]).
        v: Collection of vectors; ndarray(shape=[..., spatial_dim]).

    Returns:
        Transformed vectors; ndarray(shape=[..., spatial_dim]).
    """
    _check_transform_shapes(T, v)
    return np.dot(v, T)


jax.defjvp(_transform, None, lambda g, ans, T, v: g)
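

# A sketch of the "identity during backpropagation" behaviour (assumes the
# older jax custom_transforms/defjvp API used above and np = jax.numpy):
_T = 2.0 * np.eye(2)
_v = np.array([[1.0, 0.0], [0.0, 1.0]])
_out, _tan = jax.jvp(lambda v: _transform(_T, v), (_v,), (np.ones_like(_v),))
# _out is np.dot(_v, _T), but _tan is the input tangent of ones, untouched:
# derivatives propagate through _transform as if it were the identity map.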


def pairwise_displacement(Ra, Rb):
    """Compute a matrix of pairwise displacements given two sets of positions.

    Args:
        Ra: Vector of positions; ndarray(shape=[spatial_dim]).
        Rb: Vector of positions; ndarray(shape=[spatial_dim]).

    Returns:
        Matrix of displacements; ndarray(shape=[spatial_dim]).
    """
    if len(Ra.shape) != 1:
        msg = ('Can only compute displacements between vectors. To compute '
               'displacements between sets of vectors use vmap or TODO.')
Example #7
    return 1 / x + 1 / (2 * x**2) + 1 / (6 * x**3) - 1 / (30 * x**5) + 1 / (
        42 * x**7) - 1 / (30 * x**9) + 5 / (66 * x**11) - 691 / (
            2730 * x**13) + 7 / (6 * x**15)


@jax.custom_transforms
def digamma(x):
    return spec.digamma(x)


@jax.custom_transforms
def gammaln(x):
    return spec.gammaln(x)


jax.defjvp(digamma, lambda g, y, x: lax.mul(g, trigamma(x)))
jax.defjvp(gammaln, lambda g, y, x: lax.mul(g, digamma(x)))
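

# A finite-difference sanity check of the derivative identities encoded above
# (d/dx lgamma = digamma, d/dx digamma = trigamma), using SciPy as a reference:
import scipy.special as osp_special

_x0, _eps = 3.7, 1e-5
_d_gammaln = (osp_special.gammaln(_x0 + _eps) - osp_special.gammaln(_x0 - _eps)) / (2 * _eps)
_d_digamma = (osp_special.digamma(_x0 + _eps) - osp_special.digamma(_x0 - _eps)) / (2 * _eps)
# _d_gammaln ~= osp_special.digamma(_x0); _d_digamma ~= osp_special.polygamma(1, _x0)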


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


# link functions
links = {
    'identity': lambda x: x,
    'exponential': lambda x: np.exp(x),
    'logit': lambda x: 1 / (1 + np.exp(-x))
}
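# Note: the 'logit' entry is the logistic function (the inverse of the logit
# link), i.e. links['logit'] computes the same quantity as sigmoid above.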

# loss functions