def _standard_gamma_grad(sample, alpha):
    samples = np.reshape(sample, -1)
    alphas = np.reshape(alpha, -1)
    grads = vmap(_standard_gamma_grad_one)(samples, alphas)
    return grads.reshape(alpha.shape)


@custom_transforms
def _standard_gamma_p(key, alpha):
    return _standard_gamma_impl(key, alpha)


defjvp(_standard_gamma_p, None,
       lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))


@partial(jit, static_argnums=(2, 3))
def _standard_gamma(key, alpha, shape, dtype):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    return _standard_gamma_p(key, alpha)


def standard_gamma(key, alpha, shape=(), dtype=np.float64):
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _standard_gamma(key, alpha, shape, dtype)
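# Usage sketch (illustrative aside, not part of the original module; assumes
# `jax` and `jax.random` are importable alongside the imports above): the
# custom JVP lets gradients flow through Gamma(alpha, 1) samples, so pathwise
# derivatives of Monte Carlo estimates work.
#
# >>> import jax
# >>> from jax import random
# >>> key = random.PRNGKey(0)
# >>> samples = standard_gamma(key, np.ones(3), shape=(100, 3))
# >>> # For Gamma(alpha, 1) the mean is alpha, so this pathwise derivative of
# >>> # the sample mean should be close to 1 in each coordinate.
# >>> jax.grad(lambda a: np.mean(standard_gamma(key, a, shape=(100, 3))))(np.ones(3))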
    return solve_triangular(tril_inv, identity, lower=True)


# TODO: move upstream to jax.nn
def binary_cross_entropy_with_logits(x, y):
    # compute -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))
    # Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y


@custom_transforms
def cumsum(x):
    return np.cumsum(x, axis=-1)


defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))


@custom_transforms
def cumprod(x):
    return np.cumprod(x, axis=-1)


# XXX this implementation does not address the case x=0, hence the result in that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
defjvp(cumprod, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)


def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
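# Numerical check (illustrative aside, not part of the original module): the
# stable form in binary_cross_entropy_with_logits above equals softplus(x) - x*y,
# since max(x, 0) + log1p(exp(-|x|)) == log(1 + exp(x)). Comparing it with the
# naive expression on moderate inputs, where the naive form is still finite:
#
# >>> x = np.array([-3., 0.5, 4.])
# >>> y = np.array([0., 1., 1.])
# >>> sig = 1 / (1 + np.exp(-x))
# >>> naive = -y * np.log(sig) - (1 - y) * np.log(1 - sig)
# >>> bool(np.allclose(binary_cross_entropy_with_logits(x, y), naive))
# True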
    The problem with using np.clip() is that if it modifies any value, the
    gradient of that value turns into zero. If we want to take a second
    derivative, then that will be zero too, even though the second derivative
    of the squared-distance operation is a positive constant! This is a
    substantial problem in practice and needs a workaround. In PyTorch, you
    could do this using detach, e.g.

    >>> y = x - (torch.clamp(x, max=0.0)).detach()

    But we need something for JAX.
    """
    return np.clip(x, 0.0, None)


jax.defjvp(clip_up, lambda g, ans, x: g)
jax.defvjp(clip_up, lambda g, ans, x: g)


class Dataloader(object):

    def __init__(self, *tensors, batch_size=None):
        self.tensors = tensors
        self.batch_size = self.num_data if batch_size is None else batch_size
        self._shuffled_tensors = None
        self._batches_served = None

    @property
    def num_data(self):
        return len(self.tensors[0])

    @property
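# Illustration (sketch, not in the original source; relies on the old
# jax.custom_transforms rules defined for clip_up above supporting nested
# differentiation): because clip_up backpropagates the identity, second
# derivatives survive where plain np.clip would zero them out.
#
# >>> f_clip = lambda x: np.clip(x, 0.0, None) ** 2
# >>> f_up = lambda x: clip_up(x) ** 2
# >>> jax.grad(jax.grad(f_clip))(-1.0)   # 0.0 on the clipped branch
# >>> jax.grad(jax.grad(f_up))(-1.0)     # 2.0, the curvature of the square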
# activations

@custom_transforms
def relu(x):
    r"""Rectified linear unit activation function.

    Computes the element-wise function:

    .. math::
        \mathrm{relu}(x) = \max(x, 0)
    """
    return np.maximum(x, 0)


defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))


def softplus(x):
    r"""Softplus activation function.

    Computes the element-wise function

    .. math::
        \mathrm{softplus}(x) = \log(1 + e^x)
    """
    return np.logaddexp(x, 0)


def soft_sign(x):
    r"""Soft-sign activation function.
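# Aside (illustrative, not part of the original file): the JVP rule selects
# the tangent only where x > 0, which pins the relu derivative at exactly
# x == 0 to 0, and np.logaddexp(x, 0) is an overflow-safe log(1 + e**x).
#
# >>> jax.grad(relu)(0.0)                    # 0.0 by the x > 0 convention
# >>> jax.grad(relu)(2.0)                    # 1.0
# >>> softplus(np.array([-100., 0., 100.]))  # ~[0., log 2, 100.], no overflow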
def _xlogy_jvp_rhs(g, ans, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g = np.broadcast_to(g, shape)
    x = np.broadcast_to(x, shape)
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    return g * lax._safe_mul(x, np.reciprocal(y))


@custom_transforms
def xlogy(x, y):
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    return lax._safe_mul(x, np.log(y))


defjvp(xlogy, _xlogy_jvp_lhs, _xlogy_jvp_rhs)


def _xlog1py_jvp_lhs(g, ans, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(y))
    g = np.broadcast_to(g, shape)
    y = np.broadcast_to(y, shape)
    g, y = _promote_args_like(osp_special.xlog1py, g, y)
    return lax._safe_mul(g, np.log1p(y))


def _xlog1py_jvp_rhs(g, ans, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g = np.broadcast_to(g, shape)
    x = np.broadcast_to(x, shape)
    x, y = _promote_args_like(osp_special.xlog1py, x, y)
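# Behavior sketch (illustrative, not from the original source): lax._safe_mul
# treats a zero operand as dominating, so xlogy follows the
# scipy.special.xlogy convention that the result is 0 when x == 0, instead
# of evaluating 0 * log(0) = nan.
#
# >>> xlogy(np.array([0., 2.]), np.array([0., 3.]))   # [0., 2 * log(3)]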
    Transform is written such that it acts as the identity during gradient
    backpropagation.

    Args:
        T: Transformation; ndarray(shape=[spatial_dim, spatial_dim]).
        v: Collection of vectors; ndarray(shape=[..., spatial_dim]).

    Returns:
        Transformed vectors; ndarray(shape=[..., spatial_dim]).
    """
    _check_transform_shapes(T, v)
    return np.dot(v, T)


jax.defjvp(_transform, None, lambda g, ans, T, v: g)


def pairwise_displacement(Ra, Rb):
    """Compute the displacement between a single pair of positions.

    To compute a matrix of pairwise displacements between two sets of
    positions, apply vmap to this function.

    Args:
        Ra: Vector of positions; ndarray(shape=[spatial_dim]).
        Rb: Vector of positions; ndarray(shape=[spatial_dim]).

    Returns:
        Displacement; ndarray(shape=[spatial_dim]).
    """
    if len(Ra.shape) != 1:
        msg = ('Can only compute displacements between vectors. To compute '
               'displacements between sets of vectors use vmap or TODO.')
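# Aside (illustrative, not in the original file): the JVP rule above passes
# the tangent of v straight through (and marks T nondifferentiable), so
# gradients ignore the transformation even though the forward pass applies it.
#
# >>> T = 2.0 * np.eye(2)
# >>> v = np.ones((2,))
# >>> _transform(T, v)                                   # [2., 2.]
# >>> jax.grad(lambda v: np.sum(_transform(T, v)))(v)    # [1., 1.], identity backprop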
    return (1 / x + 1 / (2 * x**2) + 1 / (6 * x**3) - 1 / (30 * x**5)
            + 1 / (42 * x**7) - 1 / (30 * x**9) + 5 / (66 * x**11)
            - 691 / (2730 * x**13) + 7 / (6 * x**15))


@jax.custom_transforms
def digamma(x):
    return spec.digamma(x)


@jax.custom_transforms
def gammaln(x):
    return spec.gammaln(x)


jax.defjvp(digamma, lambda g, y, x: lax.mul(g, trigamma(x)))
jax.defjvp(gammaln, lambda g, y, x: lax.mul(g, digamma(x)))


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


# link functions
# note: each entry maps the linear predictor to the mean, i.e. it is the
# inverse of the named link; 'logit' therefore applies the sigmoid.
links = {
    'identity': lambda x: x,
    'exponential': lambda x: np.exp(x),
    'logit': lambda x: 1 / (1 + np.exp(-x))
}


# loss functions
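# Sanity check (illustrative aside, not part of the original module): the
# custom JVPs wire up d/dx gammaln(x) = digamma(x) and d/dx digamma(x) =
# trigamma(x), so first derivatives can be compared against known values.
#
# >>> jax.grad(gammaln)(3.0)   # digamma(3) = 1.5 - euler_gamma ≈ 0.9228
# >>> jax.grad(digamma)(3.0)   # trigamma(3) = pi**2/6 - 5/4 ≈ 0.3949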