Example #1
 def log_prob(self, value: Array) -> Array:
     """See `Distribution.log_prob`."""
     # Log multinomial coefficient: log n! - sum_i log x_i!.
     total_permutations = lax.lgamma(self._total_count + 1.)
     counts_factorial = lax.lgamma(value + 1.)
     redundant_permutations = jnp.sum(counts_factorial, axis=-1)
     log_combinations = total_permutations - redundant_permutations
     # multiply_no_nan treats 0 * (-inf) as 0, so zero counts on
     # zero-probability categories do not poison the sum with NaNs.
     return log_combinations + jnp.sum(
         math.multiply_no_nan(self.log_of_probs, value), axis=-1)
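The same computation can be sketched standalone, outside the distribution class. This is a minimal sketch assuming dense count and log-probability arrays; it drops the `multiply_no_nan` guard, so it is not the library implementation:

import jax.numpy as jnp
from jax import lax

def multinomial_log_prob(value, total_count, log_probs):
    # log n! - sum_i log x_i! : the log multinomial coefficient.
    log_combinations = (lax.lgamma(total_count + 1.) -
                        jnp.sum(lax.lgamma(value + 1.), axis=-1))
    return log_combinations + jnp.sum(value * log_probs, axis=-1)

counts = jnp.array([2., 2., 1.])
log_p = jnp.log(jnp.array([0.5, 0.3, 0.2]))
print(multinomial_log_prob(counts, 5., log_p))  # ≈ -2.00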
Example #2
def logpdf(x, df, loc=0, scale=1):
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  scaled_x = lax.div(lax.sub(x, loc), scale)
  df_over_two = lax.div(df, two)
  df_plus_one_over_two = lax.add(df_over_two, _lax_const(x, 0.5))
  normalize_term_const = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)), two)
  normalize_term = lax.sub(lax.add(lax.lgamma(df_over_two), normalize_term_tmp),
                           lax.lgamma(df_plus_one_over_two))
  quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
  return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
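This is the Student's t log-density. Assuming it is exposed as jax.scipy.stats.t.logpdf (as the "t.logpdf" promotion tag suggests), one quick sanity check: with df=1 the distribution is standard Cauchy, whose density at 0 is 1/pi.

from jax.scipy.stats import t
print(t.logpdf(0.0, df=1.0))  # ≈ log(1/pi) ≈ -1.1447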
Example #4
    def _entropy_scalar_with_lax(
            total_count: int, probs: Array,
            log_of_probs: Array) -> Union[jnp.float32, jnp.float64]:
        """Like `_entropy_scalar`, but uses a lax while loop."""

        dtype = probs.dtype
        log_n_factorial = lax.lgamma(jnp.asarray(total_count + 1, dtype=dtype))

        def cond_func(args):
            xi, _ = args
            return jnp.less_equal(xi, total_count)

        def body_func(args):
            xi, accumulated_sum = args
            xi_float = jnp.asarray(xi, dtype=dtype)
            log_xi_factorial = lax.lgamma(xi_float + 1.)
            log_comb_n_xi = (log_n_factorial - log_xi_factorial -
                             lax.lgamma(total_count - xi_float + 1.))
            comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
            likelihood1 = math.power_no_nan(probs, xi)
            likelihood2 = math.power_no_nan(1. - probs, total_count - xi)
            likelihood = likelihood1 * likelihood2
            comb_term = comb_n_xi * log_xi_factorial * likelihood  # [K]
            chex.assert_shape(comb_term, (probs.shape[-1], ))
            return xi + 1, accumulated_sum + comb_term

        comb_term = jnp.sum(lax.while_loop(cond_func, body_func,
                                           (0, jnp.zeros_like(probs)))[1],
                            axis=-1)

        n_probs_factor = jnp.sum(total_count *
                                 math.multiply_no_nan(log_of_probs, probs),
                                 axis=-1)

        return -log_n_factorial - n_probs_factor + comb_term
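The `lax.while_loop(cond_func, body_func, init)` pattern above threads a (counter, accumulator) tuple through the loop until the condition fails. A minimal self-contained sketch of the same pattern, summing log-factorials:

import jax.numpy as jnp
from jax import lax

def sum_log_factorials(n):
    # Accumulates lgamma(i + 1) = log(i!) for i = 0..n.
    def cond(args):
        i, _ = args
        return i <= n

    def body(args):
        i, acc = args
        return i + 1, acc + lax.lgamma(jnp.asarray(i, jnp.float32) + 1.)

    return lax.while_loop(cond, body, (0, jnp.float32(0.)))[1]

print(sum_log_factorials(3))  # log 0! + log 1! + log 2! + log 3! ≈ 2.485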
Example #5
    def norm_const(self):
        corr = self.correlation.reshape(1, -1) + 1e-8
        conc = jnp.stack((self.phi_concentration, self.psi_concentration),
                         axis=-1).reshape(-1, 2)
        m = jnp.arange(50).reshape(-1, 1)
        num = lax.lgamma(2 * m + 1.0)
        den = lax.lgamma(m + 1.0)
        lbinoms = num - 2 * den

        fs = (lbinoms.reshape(-1, 1) + 2 * m * jnp.log(corr) -
              m * jnp.log(4 * jnp.prod(conc, axis=-1)))
        fs += log_I1(49, conc, terms=51).sum(-1)
        mfs = fs.max()
        norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(
            fs - mfs, 0)
        return norm_const.reshape(self.phi_loc.shape)
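The `num - 2 * den` term above is the log central binomial coefficient, log C(2m, m) = lgamma(2m + 1) - 2 * lgamma(m + 1); a quick standalone check:

import jax.numpy as jnp
from jax import lax

m = jnp.arange(5.0)
lbinoms = lax.lgamma(2 * m + 1.0) - 2 * lax.lgamma(m + 1.0)
print(jnp.exp(lbinoms))  # [1., 2., 6., 20., 70.] = C(0,0), C(2,1), C(4,2), ...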
Example #6
def logpdf(x, df, loc=0, scale=1):
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    df_on_two = lax.div(df, two)

    kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)),
                     lax.div(y, two))

    nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),
                                lax.div(lax.mul(lax.log(two), df), two)))

    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
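Assuming this is exposed as jax.scipy.stats.chi2.logpdf (as the "chi2.logpdf" tag suggests), a quick check against a known value: with df=2 the chi-squared density is 0.5 * exp(-x/2).

from jax.scipy.stats import chi2
print(chi2.logpdf(2.0, df=2.0))  # ≈ log(0.5 * exp(-1)) ≈ -1.6931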
Example #7
def poisson_log_likelihood(x, log_rate):
  """Compute the log likelihood under Poisson distribution.

    log poisson(k, r) = log(r^k * e^(-r) / k!)
                      = k log(r) - r - log k!
    log poisson(k, r=exp(l)) = k * l - exp(l) - lgamma(k + 1)

  Args:
    x: binned spike count data.
    log_rate: The (log) rate that defines the likelihood of the data
      under the LFADS model.
  Returns:
    The log-likelihood of the data under the model (up to a constant factor).
  """
  return x * log_rate - np.exp(log_rate) - lax.lgamma(x + 1.0)
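Plugging in k = 3 and log-rate l = 1 gives 3 - e - log 6 ≈ -1.5100, a quick check of the identity in the docstring:

print(poisson_log_likelihood(3.0, 1.0))  # ≈ -1.5100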
Example #8
def log_I1(orders: int, value, terms=250):
    r"""Compute first n log modified bessel function of first kind
    .. math ::

        \log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk)
        - \lgamma(v + k + 1)\right])

    :param orders: orders of the log modified bessel function.
    :param value: values to compute modified bessel function for
    :param terms: truncation of summation
    :return: 0 to orders modified bessel function
    """
    orders = orders + 1
    if value.ndim == 0:
        vshape = jnp.shape([1])
    else:
        vshape = value.shape
    value = value.reshape(-1, 1)
    flat_vshape = _numel(vshape)

    k = jnp.arange(terms)
    lgammas_all = lax.lgamma(jnp.arange(1.0, terms + orders + 1))
    assert lgammas_all.shape == (orders + terms,
                                 )  # lgamma(0) = inf => start from 1

    lvalues = lax.log(value / 2) * k.reshape(1, -1)
    assert lvalues.shape == (flat_vshape, terms)

    lfactorials = lgammas_all[:terms]
    assert lfactorials.shape == (terms, )

    lgammas = lgammas_all.tile(orders).reshape((orders, -1))
    assert lgammas.shape == (orders, terms + orders
                             )  # lgamma(0) = inf => start from 1

    indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1)
    assert indices.shape == (orders, terms)

    seqs = logsumexp(
        2 * lvalues[None, :, :] - lfactorials[None, None, :] -
        jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :],
        -1,
    )
    assert seqs.shape == (orders, flat_vshape)

    i1s = lvalues[..., :orders].T + seqs
    assert i1s.shape == (orders, flat_vshape)
    return i1s.reshape(-1, *vshape)
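A sanity check against the tabulated value I_0(1) ≈ 1.26607, so log I_0(1) ≈ 0.2359 (this assumes log_I1's helpers `_numel` and `logsumexp` are in scope):

import jax.numpy as jnp
print(log_I1(0, jnp.array([1.0])))  # ≈ [[0.2359]]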
Example #9
    def _entropy_scalar(
            total_count: int, probs: Array,
            log_of_probs: Array) -> Union[jnp.float32, jnp.float64]:
        """Calculates the entropy for a Multinomial with integer `total_count`."""
        # Constant factors in the entropy.
        xi = jnp.arange(total_count + 1, dtype=probs.dtype)
        log_xi_factorial = lax.lgamma(xi + 1)
        log_n_minus_xi_factorial = jnp.flip(log_xi_factorial, axis=-1)
        log_n_factorial = log_xi_factorial[..., -1]
        log_comb_n_xi = (log_n_factorial[..., None] - log_xi_factorial -
                         log_n_minus_xi_factorial)
        comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
        chex.assert_shape(comb_n_xi, (total_count + 1, ))

        likelihood1 = math.power_no_nan(probs[..., None], xi)
        likelihood2 = math.power_no_nan(1. - probs[..., None],
                                        total_count - xi)
        chex.assert_shape(likelihood1, (
            probs.shape[-1],
            total_count + 1,
        ))
        chex.assert_shape(likelihood2, (
            probs.shape[-1],
            total_count + 1,
        ))
        likelihood = jnp.sum(likelihood1 * likelihood2, axis=-2)
        chex.assert_shape(likelihood, (total_count + 1, ))
        comb_term = jnp.sum(comb_n_xi * log_xi_factorial * likelihood, axis=-1)
        chex.assert_shape(comb_term, ())

        # Probs factors in the entropy.
        n_probs_factor = jnp.sum(total_count *
                                 math.multiply_no_nan(log_of_probs, probs),
                                 axis=-1)

        return -log_n_factorial - n_probs_factor + comb_term
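For reference, my reading of the two entropy helpers (not a quoted source) is that they evaluate the closed form

    H = -log n! - n * sum_i p_i log(p_i)
        + sum_i sum_{x=0}^{n} C(n, x) * log(x!) * p_i^x * (1 - p_i)^(n - x)

where `n_probs_factor` is the middle term and `comb_term` is the double sum.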
Example #10
def betaln(x, y):
    x, y = _promote_args_inexact("betaln", x, y)
    return lax.lgamma(x) + lax.lgamma(y) - lax.lgamma(x + y)
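Since B(x, y) = Γ(x)Γ(y) / Γ(x + y), a quick check (this helper matches jax.scipy.special.betaln):

print(betaln(2.0, 3.0))  # log B(2, 3) = log(1/12) ≈ -2.4849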
Example #11
def gammaln(x):
    x, = _promote_args_inexact("gammaln", x)
    return lax.lgamma(x)
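A quick check (this matches jax.scipy.special.gammaln):

print(gammaln(5.0))  # log Γ(5) = log 4! = log 24 ≈ 3.1781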
Example #12
def fact(n):
    return lax.exp(lax.lgamma(n + 1.))
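Recovering n! as exp(lgamma(n + 1)) is exact only up to floating-point error, and overflows for large n (35! already exceeds the float32 maximum):

print(fact(5.))   # ≈ 120.0
print(fact(40.))  # inf under float32 (JAX's default): 40! ≈ 8.2e47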
Example #13
def logpdf(x, p):
    x, p = _promote_args_inexact("gennorm.logpdf", x, p)
    return lax.log(.5 * p) - lax.lgamma(1 / p) - lax.abs(x)**p
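With shape p, the generalized normal density at zero is p / (2 Γ(1/p)); for p = 2 that is 1/sqrt(pi). Assuming this is exposed as jax.scipy.stats.gennorm.logpdf:

from jax.scipy.stats import gennorm
print(gennorm.logpdf(0.0, 2.0))  # ≈ log(1/sqrt(pi)) ≈ -0.5724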
Example #14
 def logpdf(self, x):
     x = x * 1.0
     # Poisson log pmf: x * log(lambda) - lambda - log(x!),
     # with xlogy(x, lmbda) = x * log(lmbda) and lgamma(x + 1) = log(x!).
     return lax.add(
         xlogy(x, self.lmbda),
         lax.add(lax.neg(self.lmbda), lax.neg(lax.lgamma(x + 1.))),
     )
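A standalone check of the same formula, for lambda = 2 and x = 3:

import jax.numpy as jnp
from jax import lax
lmbda, x = 2.0, 3.0
print(x * jnp.log(lmbda) - lmbda - lax.lgamma(x + 1.0))  # ≈ -1.7123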
Example #15
def _gamma(n):
    # Gamma function via exp(lgamma); assumes lgamma is jax.lax.lgamma.
    return jnp.exp(lgamma(n))
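Since Γ(n) = (n - 1)! for positive integers:

print(_gamma(5.0))  # Γ(5) = 4! = 24.0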