Code example #1
def progressive_biased_sampling(rng_key, proposal, new_proposal):
    """Baised proposal sampling.

    Unlike uniform sampling, biased sampling favors new proposals. It thus
    biases the transition away from the trajectory's initial state.

    """
    p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight),
                        a_max=1)
    do_accept = jax.random.bernoulli(rng_key, p_accept)
    new_weight = jnp.logaddexp(proposal.weight, new_proposal.weight)
    new_sum_log_p_accept = jnp.logaddexp(proposal.sum_log_p_accept,
                                         new_proposal.sum_log_p_accept)

    return jax.lax.cond(
        do_accept,
        lambda _: Proposal(
            new_proposal.state,
            new_proposal.energy,
            new_weight,
            new_sum_log_p_accept,
        ),
        lambda _: Proposal(
            proposal.state,
            proposal.energy,
            new_weight,
            new_sum_log_p_accept,
        ),
        operand=None,
    )
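The acceptance probability above is min(1, exp(w_new - w_old)), a Metropolis-style ratio of log weights; the Proposal container itself is project-specific. A minimal numeric sketch of just the clipping, with made-up scalar log-weights:

import jax.numpy as jnp

w_old, w_new = jnp.log(0.2), jnp.log(0.5)
p_accept = jnp.clip(jnp.exp(w_new - w_old), a_max=1)  # 0.5 / 0.2 = 2.5, clipped to 1.0
assert p_accept == 1.0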
Code example #2
File: ibp_test.py Project: zeta1999/jax_verify
    def test_softplus_ibp(self):
        def softplus_model(inp):
            return jax.nn.softplus(inp)

        z = jnp.array([[-2., 3.]])

        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.interval_bound_propagation(
            softplus_model, input_bounds)

        self.assertArrayAlmostEqual(jnp.logaddexp(z - 1., 0),
                                    output_bounds.lower)
        self.assertArrayAlmostEqual(jnp.logaddexp(z + 1., 0),
                                    output_bounds.upper)
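The assertions encode the fact that softplus is monotone increasing, so interval propagation through it is exact: the output interval is [softplus(lo), softplus(hi)], and softplus(x) equals jnp.logaddexp(x, 0). A standalone check of that identity:

import jax
import jax.numpy as jnp

z = jnp.array([[-2., 3.]])
lo, hi = z - 1., z + 1.
assert jnp.allclose(jax.nn.softplus(lo), jnp.logaddexp(lo, 0))
assert jnp.allclose(jax.nn.softplus(hi), jnp.logaddexp(hi, 0))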
Code example #3
def update_pn_forward(carry, i):
    logpmf_ytest, logpmf_yn, y_samp, pdiff, y, x, x_test, rho, rho_x, ind_new, vT, logpmf_init = carry

    #Sample new x
    x_new = x[ind_new[i]]

    #Sample new y based on unif rv
    y_new = jnp.where((jnp.log(vT[i]) <= logpmf_yn[ind_new[i]]), x=1, y=0)
    log_vn = logpmf_yn[ind_new[i]]
    y_samp = index_update(y_samp, index[i], y_new)

    #Update pmf_yn
    #compute x rhos/alphas
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)
    logk_xx = mvcc.calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)  # alpha*k_xx / (1 - alpha + alpha*k_xx)
    #clip for numerical stability to prevent NaNs
    eps = 1e-4  #1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    logpmf_yn = mvcc.update_copula(logpmf_yn, log_vn, y_new, logalpha_x, rho)

    #Compute pdiff
    pdiff = index_update(
        pdiff, index[i, :],
        jnp.abs(jnp.exp(logpmf_yn[:, 0]) - jnp.exp(logpmf_init[:, 0])))

    #Update pmf_ytest
    #compute x rhos/alphas
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)
    logk_xx = mvcc.calc_logkxx(x_test, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)  # alpha*k_xx / (1 - alpha + alpha*k_xx)
    #clip for numerical stability to prevent NaNs
    eps = 1e-4  #1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    #Update pmf_ytest
    logpmf_ytest = mvcc.update_copula(logpmf_ytest, log_vn, y_new, logalpha_x,
                                      rho)

    carry = logpmf_ytest, logpmf_yn, y_samp, pdiff, y, x, x_test, rho, rho_x, ind_new, vT, logpmf_init
    return carry, i
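The recurring logalpha_x pattern evaluates log(alpha*k_xx / (1 - alpha + alpha*k_xx)) entirely in log space. A scalar sketch of the identity with made-up values:

import jax.numpy as jnp

logalpha, logk = jnp.log(0.3), jnp.log(0.8)
log1alpha = jnp.log1p(-jnp.exp(logalpha))  # log(1 - alpha)
logalpha_x = (logalpha + logk) - jnp.logaddexp(log1alpha, logalpha + logk)
# direct computation: 0.3*0.8 / (0.7 + 0.3*0.8)
assert jnp.isclose(jnp.exp(logalpha_x), 0.24 / 0.94)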
Code example #4
def update_pn(carry, i):
    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, x, rho, rho_x = carry

    # Compute new x
    x_new = x[i]
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)

    # compute x rhos/alphas
    logk_xx = calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)  # alpha*k_xx / (1 - alpha + alpha*k_xx)

    # clip for numerical stability to prevent NaNs
    eps = 1e-5  # 1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    u = jnp.exp(logcdf_conditionals_yn)
    v = jnp.exp(logcdf_conditionals_yn[i])

    vn = index_update(vn, i, v)  # remember history of vn

    preq_loglik = index_update(preq_loglik, i, logpdf_joints_yn[i, -1])
    logcdf_conditionals_yn, logpdf_joints_yn = update_copula(
        logcdf_conditionals_yn, logpdf_joints_yn, u, v, logalpha_x, rho)
    carry = vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, x, rho, rho_x
    return carry, i
Code example #5
def test_cluster_split():
    import pylab as plt
    from jax import disable_jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.zeros(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0], jnp.bool_))
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)
    with disable_jit():
        cluster_id, log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split = \
            cluster_split(random.PRNGKey(0), points, mask, log_VS, log_ellipsoid_volume(radii), kmeans_init=True)
        print(jnp.logaddexp(log_ellipsoid_volume(radii1), log_ellipsoid_volume(radii2)), log_ellipsoid_volume(radii))
        print(log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split)
        print(cluster_id)

    y = mu1[:, None] + rotation1 @ jnp.diag(radii1) @ x
    plt.plot(y[0, :], y[1, :])

    y = mu2[:, None] + rotation2 @ jnp.diag(radii2) @ x
    plt.plot(y[0, :], y[1, :])

    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])

    plt.show()
Code example #6
 def gibbs_loop(i, rng_particles_log_posteriors):
     rng, particles, log_posteriors = rng_particles_log_posteriors
     i = i % num_patients
     # flip values at index i
     particles_flipped = jax.ops.index_update(
         particles, jax.ops.index[:, i], np.logical_not(particles[:, i]))
     # compute log_posterior of flipped particles
     log_posteriors_flipped_at_i = bayes.tempered_logpos_logbase(
         particles_flipped, log_posterior_params, log_base_measure_params,
         rho)
     # compute acceptance probability, depending on whether we use the Liu modification
     if liu_modification:
         log_proposal_ratio = log_posteriors_flipped_at_i - log_posteriors
     else:
         log_proposal_ratio = log_posteriors_flipped_at_i - np.logaddexp(
             log_posteriors_flipped_at_i, log_posteriors)
     # here the MH thresholding is implicitly done.
     rng, rng_unif = jax.random.split(rng, 2)
     random_values = jax.random.uniform(rng_unif, particles.shape[:1])
     flipped_at_i = np.log(random_values) < log_proposal_ratio
     selected_at_i = np.logical_xor(flipped_at_i, particles[:, i])
     particles = jax.ops.index_update(particles, jax.ops.index[:, i],
                                      selected_at_i)
     log_posteriors = np.where(flipped_at_i, log_posteriors_flipped_at_i,
                               log_posteriors)
     return [rng, particles, log_posteriors]
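In the non-Liu branch, exp(l_flip - logaddexp(l_flip, l_cur)) equals p_flip / (p_flip + p_cur), a Barker-style acceptance probability. A scalar check of that identity:

import jax.numpy as jnp

l_flip, l_cur = jnp.log(0.3), jnp.log(0.1)
p = jnp.exp(l_flip - jnp.logaddexp(l_flip, l_cur))
assert jnp.isclose(p, 0.3 / (0.3 + 0.1))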
Code example #7
File: pixelcnn.py Project: tokusumi/flax
def discretized_logistic_logpmf(images, means, inv_scales):
    """Compute log-probabilities for each mixture component, pixel and channel."""
    # Compute the difference between the logistic cdf half a level above and half
    # a level below the image value.
    centered = images - means

    # Where images == 1 we use log(1 - cdf(images - 1 / 255))
    top = -jnp.logaddexp(0, (centered - 1 / 255) * inv_scales)

    # Where images == -1 we use log(cdf(images + 1 / 255))
    bottom = -jnp.logaddexp(0, -(centered + 1 / 255) * inv_scales)

    # Elsewhere we use log(cdf(images + 1 / 255) - cdf(images - 1 / 255))
    mid = log1mexp(inv_scales / 127.5) + top + bottom

    return jnp.where(images == 1, top, jnp.where(images == -1, bottom, mid))
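Each branch is a stable log of a logistic-CDF expression, using log sigma(x) = -logaddexp(0, -x). A sketch verifying the mid branch against a direct (less stable) computation, assuming log1mexp(x) = log(1 - exp(-x)) as the helper defined elsewhere in the module:

import jax.numpy as jnp

def log1mexp(x):
    return jnp.log1p(-jnp.exp(-x))  # assumed form of the module's helper

def log_logistic_cdf(x):
    return -jnp.logaddexp(0, -x)  # log sigma(x)

centered, inv_scale = 0.3, 2.0
top = -jnp.logaddexp(0, (centered - 1 / 255) * inv_scale)
bottom = -jnp.logaddexp(0, -(centered + 1 / 255) * inv_scale)
mid = log1mexp(inv_scale / 127.5) + top + bottom
direct = jnp.log(jnp.exp(log_logistic_cdf((centered + 1 / 255) * inv_scale))
                 - jnp.exp(log_logistic_cdf((centered - 1 / 255) * inv_scale)))
assert jnp.isclose(mid, direct, atol=1e-4)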
Code example #8
File: mnist_vae.py Project: zhouj/jax
def image_sample(rng, params, nrow, ncol):
  """Sample images from the generative model."""
  _, dec_params = params
  code_rng, img_rng = random.split(rng)
  logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
  sampled_images = random.bernoulli(img_rng, np.logaddexp(0., logits))
  return image_grid(nrow, ncol, sampled_images, (28, 28))
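For reference, a Bernoulli mean can be recovered from a logit with logaddexp as well, since sigmoid(x) = exp(-softplus(-x)) = exp(-logaddexp(0, -x)); a standalone identity check:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3., 3., 7)
assert jnp.allclose(jax.nn.sigmoid(x), jnp.exp(-jnp.logaddexp(0., -x)))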
Code example #9
def logp_unordered_subset(theta, zs):
    # last dimension of the zs indicates the selected elements
    # sparse index representation
    #
    # Wouter et al. use the Gumbel representation to compute p(Sk) in
    # exponential time rather than factorial.
    # We do it in factorial time.
    Sz, Sx, N, K = zs.shape
    # Is there a better syntax for gather
    logp_z = theta[
        np.arange(Sz)[:,None,None,None],
        np.arange(Sx)[:,None,None],
        np.arange(N)[:,None],
        zs,
    ]

    # get denominator orderings
    perms = all_perms(K)
    logp = logp_z[..., perms]

    # cumlogsumexp would be more stable? but there are only two elements here...
    # sum_i p(b_i)
    #a = logp.max(-1, keepdims=True)
    #p = np.exp(logp - a)
    #sbi0 = a + np.log(p.cumsum(-1) - p)

    # slow implementation, the above seems wrong
    sbis = [np.log(np.zeros(logp[..., 0].shape))]
    for i in range(K-1):
        sbis.append(np.logaddexp(sbis[-1], logp[..., i]))
    sbi = np.stack(sbis, -1)

    logp_bs = logp.sum(-1) - log1mexp(sbi).sum(-1)
    logp_b = lse(logp_bs, -1)
    return logp_b
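The loop accumulates a cumulative log-sum-exp by repeated logaddexp. Since logaddexp is associative, an inclusive cumulative version can be sketched with jax.lax.associative_scan (the loop above builds the exclusive variant, seeded with log 0):

import jax
import jax.numpy as jnp

logp = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
cum = jax.lax.associative_scan(jnp.logaddexp, logp)  # inclusive cumulative logsumexp
assert jnp.allclose(jnp.exp(cum), jnp.cumsum(jnp.exp(logp)))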
Code example #10
def update_pn(carry, i):
    log_vn, logpmf1_yn, preq_loglik, y, x, rho, rho_x = carry

    #Compute new x
    y_new = y[i]
    x_new = x[i]
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)

    #compute x rhos/alphas
    eps = 5e-5
    logk_xx = calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(jnp.clip(-jnp.exp(logalpha), -1 + eps, jnp.inf))
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)  # alpha*k_xx / (1 - alpha + alpha*k_xx)

    #clip for numerical stability to prevent NaNs
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    #add p1 or (1-p1) depending on what y_new is
    temp = y_new * logpmf1_yn[i, -1] + (1 - y_new) * jnp.log1p(
        jnp.clip(-jnp.exp(logpmf1_yn[i, -1]), -1 + eps, jnp.inf))
    preq_loglik = index_update(preq_loglik, i, temp)

    #
    log_v = logpmf1_yn[i]
    log_vn = index_update(log_vn, i, log_v)

    logpmf1_yn = update_copula(logpmf1_yn, log_v, y_new, logalpha_x, rho)
    carry = log_vn, logpmf1_yn, preq_loglik, y, x, rho, rho_x
    return carry, i
Code example #11
File: multinomial.py Project: NeilGirdhar/efax
 def to_exp(self) -> MultinomialEP:
     max_q = jnp.maximum(0.0, jnp.amax(self.log_odds, axis=-1))
     q_minus_max_q = self.log_odds - max_q[..., np.newaxis]
     log_scaled_A = jnp.logaddexp(-max_q,
                                  jss.logsumexp(q_minus_max_q, axis=-1))
     return MultinomialEP(
         jnp.exp(q_minus_max_q - log_scaled_A[..., np.newaxis]))
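to_exp normalizes the log-odds against an implicit final logit of 0; the max_q shift is only for stability. A direct, unshifted check on hypothetical log-odds:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

q = jnp.array([1.0, -0.5])                 # log-odds vs. the last category
full = jnp.concatenate([q, jnp.zeros(1)])  # append the implicit 0 logit
p_direct = jnp.exp(full - logsumexp(full))[:-1]
max_q = jnp.maximum(0.0, jnp.max(q))
log_scaled_A = jnp.logaddexp(-max_q, logsumexp(q - max_q))
assert jnp.allclose(jnp.exp(q - max_q - log_scaled_A), p_direct)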
Code example #12
 def body(state):
     (i, done, old_cluster_id, _, _, _, _, _, _, _, _, min_loss,
      delay) = state
     mask1 = mask & (old_cluster_id == 0)
     mask2 = mask & (old_cluster_id == 1)
     # estimate volumes of current clustering
     n1 = jnp.sum(mask1)
     n2 = jnp.sum(mask2)
     log_VS1 = log_VS + jnp.log(n1) - jnp.log(n_S)
     log_VS2 = log_VS + jnp.log(n2) - jnp.log(n_S)
     # construct E_1, E_2 and compute volumes
     mu1, C1 = bounding_ellipsoid(points, mask1)
     radii1, rotation1 = ellipsoid_params(C1)
     log_VE1 = log_ellipsoid_volume(radii1)
     mu2, C2 = bounding_ellipsoid(points, mask2)
     radii2, rotation2 = ellipsoid_params(C2)
     log_VE2 = log_ellipsoid_volume(radii2)
     # enlarge to at least cover V(S1) and V(S2)
     log_scale1 = log_coverage_scale(log_VE1, log_VS1, D)
     log_scale2 = log_coverage_scale(log_VE2, log_VS2, D)
     C1 = C1 / jnp.exp(log_scale1)
     radii1 = jnp.exp(jnp.log(radii1) + log_scale1)
     C2 = C2 / jnp.exp(log_scale2)
     radii2 = jnp.exp(jnp.log(radii2) + log_scale2)
     log_VE1 = log_VE1 + log_scale1 * D
     log_VE2 = log_VE2 + log_scale2 * D
     # compute reassignment metrics
     maha1 = vmap(lambda point: (point - mu1) @ C1 @ (point - mu1))(points)
     maha2 = vmap(lambda point: (point - mu2) @ C2 @ (point - mu2))(points)
     log_h1 = log_VE1 - log_VS1 + jnp.log(maha1)
     log_h2 = log_VE2 - log_VS2 + jnp.log(maha2)
     # reassign
     delta_F = jnp.exp(log_h1) - jnp.exp(log_h2)
     reassign_idx = jnp.argmax(jnp.abs(delta_F))
     new_cluster_id = dynamic_update_slice(
         cluster_id, (delta_F[reassign_idx, None] > 0).astype(jnp.int_),
         reassign_idx[None])
     # new_cluster_id = jnp.where(log_h1 < log_h2, 0, 1)
     log_V_sum = jnp.logaddexp(log_VE1, log_VE2)
     new_loss = jnp.exp(log_V_sum - log_VS)
     loss_decreased = new_loss < min_loss
     delay = jnp.where(loss_decreased, 0, delay + 1)
     min_loss = jnp.where(loss_decreased, new_loss, min_loss)
     ###
     # i / delay / loss_decreased / new_loss / min_loss
     # 0 / 0 / True / a / a
     # 1 / 1 / False / b / a
     # 2 / 2 / False / a / a
     # 3 / 3 / False / b / a
     # 4 / 4 / False / a / a
     done = jnp.all(new_cluster_id == old_cluster_id) \
            | (delay >= 10) \
            | (n1 < D + 1) \
            | (n2 < D + 1) \
            | jnp.isnan(log_V_sum)
     # print(i, "reassignments", jnp.sum(new_cluster_id != old_cluster_id), 'F', log_V_sum)
     # print(i, done, jnp.abs(delta_F).max())
     return (i + 1, done, new_cluster_id, log_VS1, mu1, radii1, rotation1,
             log_VS2, mu2, radii2, rotation2, min_loss, delay)
Code example #13
 def log_likelihood(theta, **kwargs):
     def log_circ(theta, c, r, w):
         return (-0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2
                 - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2)))
     w1 = w2 = jnp.array(0.1)
     r1 = r2 = jnp.array(2.)
     c1 = jnp.array([0., -4.])
     c2 = jnp.array([0., 4.])
     return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))
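Note that logaddexp of two log-densities is the log of their sum; an equally weighted, normalized two-component mixture would subtract log 2:

import jax.numpy as jnp

log_p1, log_p2 = jnp.log(0.2), jnp.log(0.6)
log_mix = jnp.logaddexp(log_p1, log_p2) - jnp.log(2.)  # log(0.5*p1 + 0.5*p2)
assert jnp.isclose(jnp.exp(log_mix), 0.4)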
Code example #14
File: multinomial.py Project: NeilGirdhar/efax
 def nat_to_probability(self) -> RealArray:
     max_q = jnp.maximum(0.0, jnp.amax(self.log_odds, axis=-1))
     q_minus_max_q = self.log_odds - max_q[..., np.newaxis]
     log_scaled_A = jnp.logaddexp(-max_q,
                                  jss.logsumexp(q_minus_max_q, axis=-1))
     p = jnp.exp(q_minus_max_q - log_scaled_A[..., np.newaxis])
     final_p = 1.0 - jnp.sum(p, axis=-1, keepdims=True)
     return jnp.append(p, final_p, axis=-1)
Code example #15
 def update_sum_log_p_accept(inputs):
     _, proposal, new_proposal = inputs
     return Proposal(
         proposal.state,
         proposal.energy,
         proposal.weight,
         jnp.logaddexp(proposal.sum_log_p_accept,
                       new_proposal.sum_log_p_accept),
     )
Code example #16
File: functions.py Project: sts-sadr/jax
def softplus(x):
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)
  """
  return jnp.logaddexp(x, 0)
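Writing softplus as logaddexp(x, 0) avoids the overflow of the naive formula for large inputs:

import jax.numpy as jnp

x = jnp.array([-100., 0., 100.])
print(jnp.logaddexp(x, 0))      # [~0.  0.6931  100.]
print(jnp.log(1 + jnp.exp(x)))  # [ 0.  0.6931  inf] -- exp(100.) overflows in float32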
Code example #17
File: utils.py Project: fehiepsi/jaxns
def logaddexp(x1, x2):
    if is_complex(x1) or is_complex(x2):
        select1 = x1.real > x2.real
        amax = jnp.where(select1, x1, x2)
        delta = jnp.where(select1, x2 - x1, x1 - x2)
        return jnp.where(jnp.isnan(delta),
                         x1 + x2,  # NaNs or infinities of the same sign.
                         amax + jnp.log1p(jnp.exp(delta)))
    else:
        return jnp.logaddexp(x1, x2)
Code example #18
File: cubes_utils.py Project: fehiepsi/jaxns
 def inner_body(inner_state):
     (
         key,
         _done,
         _completed,
         log_t_M,
     ) = inner_state
     completed = jnp.logaddexp(log_t_cum, log_t_M) >= log_T
     # if not completed then increment t_M and test if point in cube
     log_t_M = jnp.where(completed, log_t_M, jnp.logaddexp(log_t_M, 0.))
     key, inner_choose_key = random.split(key, 2)
     j = random.randint(inner_choose_key,
                        shape=(),
                        minval=0,
                        maxval=N + 1)
     # point query
     in_j = points_in_box(x, points_lower[j, :], points_upper[j, :])
     done = in_j | completed
     return key, done, completed, log_t_M
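The jnp.logaddexp(log_t_M, 0.) step increments a counter kept in log space, since log(t + 1) = logaddexp(log t, log 1):

import jax.numpy as jnp

log_t = jnp.log(4.)
assert jnp.isclose(jnp.logaddexp(log_t, 0.), jnp.log(5.))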
Code example #19
def sigmoid_forward(params, x):
    """Maps x in (-infinity, infinity)^d to (0, 1)^d"""
    beta, log_scale, shift = params
    x = (x + shift) * np.exp(log_scale)  # To get closer to unit square when mapped out
    sigmoid_log_prob = np.sum(log_scale + np.log(beta) - beta * x -
                              2 * np.logaddexp(0, -beta * x),
                              axis=1)
    x = 1.0 / (1 + np.exp(-beta * x))
    return x, sigmoid_log_prob
Code example #20
def bernoulli_log_density(b, unnormalized_logprob):
    """
    Args:
      unnormalized_logprob: log(mu / (1 - mu)) <- the "logit"
      b: 0 or 1 (i.e. a binarized digit/image)
    Returns:
      log Ber(b | mu)
    """
    s = b * 2 - 1
    return -np.logaddexp(0., -s * unnormalized_logprob)
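The sign trick folds both Bernoulli cases into one expression: with s = 2b - 1, log Ber(b | sigmoid(l)) = -logaddexp(0, -s*l). A check against the two cases written out:

import jax
import jax.numpy as jnp

logit = jnp.array(1.3)
mu = jax.nn.sigmoid(logit)
assert jnp.isclose(-jnp.logaddexp(0., -logit), jnp.log(mu))     # b = 1, s = +1
assert jnp.isclose(-jnp.logaddexp(0., logit), jnp.log(1 - mu))  # b = 0, s = -1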
Code example #21
    def loop_body(prev, x):
        prev_phi, prev_emit = prev
        # emit-to-phi epsilon transition
        prev_phi = prev_phi.at[:, 1:].set(
            jnp.logaddexp(prev_phi[:, 1:], prev_emit[:, :-1]))

        logprob_emit, logprob_phi, pad = x

        # phi-to-emit transition
        next_emit = jnp.logaddexp(prev_phi + logprob_emit,
                                  prev_emit + logprob_emit)
        # self-loop transition
        next_phi = prev_phi + logprob_phi.reshape((batchsize, 1))

        pad = pad.reshape((batchsize, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)
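Throughout the recursion, logaddexp acts as the addition of ordinary probabilities kept in log space:

import jax.numpy as jnp

p1, p2 = 0.25, 0.5
assert jnp.isclose(jnp.logaddexp(jnp.log(p1), jnp.log(p2)), jnp.log(p1 + p2))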
Code example #22
def binary_crossentropy(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    from_logits: bool = False,
) -> jnp.ndarray:

    if from_logits:
        return -jnp.mean(y_true * y_pred - jnp.logaddexp(0.0, y_pred), axis=-1)

    y_pred = jnp.clip(y_pred, utils.EPSILON, 1.0 - utils.EPSILON)
    return -jnp.mean(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1)
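The from_logits branch relies on the stable identity logaddexp(0, l) - y*l = -[y*log(p) + (1 - y)*log(1 - p)] with p = sigmoid(l). A scalar check:

import jax
import jax.numpy as jnp

y, l = jnp.array(1.0), jnp.array(0.7)
p = jax.nn.sigmoid(l)
stable = jnp.logaddexp(0.0, l) - y * l
direct = -(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))
assert jnp.isclose(stable, direct)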
Code example #23
File: mv_copula_density_t.py Project: edfong/MP
def update_copula_single(logcdf_conditionals, logpdf_joints, u, v, alpha, rho):
    d = jnp.shape(logpdf_joints)[0]

    logcop_distribution, logcop_dens = t1_copula_logdistribution_logdensity(u, v, rho)

    # Calculate product copulas
    logcop_dens_prod = jnp.cumsum(logcop_dens)

    # staggered 1 step to calculate conditional cdfs
    logcop_dens_prod_staggered = jnp.concatenate((jnp.zeros(1), logcop_dens_prod[0:d - 1]))

    logalpha = jnp.log(alpha)
    log1alpha = jnp.log(1 - alpha)

    logcdf_conditionals = jnp.logaddexp(log1alpha + logcdf_conditionals,
                                        logalpha + logcop_dens_prod_staggered + logcop_distribution) \
                          - jnp.logaddexp(log1alpha, logalpha + logcop_dens_prod_staggered)

    logpdf_joints = jnp.logaddexp(log1alpha, logalpha + logcop_dens_prod) + logpdf_joints

    return logcdf_conditionals, logpdf_joints
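In probability space the conditional-cdf update above reads new_cdf = ((1 - a)*cdf + a*prod*cop) / ((1 - a) + a*prod). A scalar sketch with made-up values:

import jax.numpy as jnp

a, cdf, prod, cop = 0.2, 0.4, 0.9, 0.7
log_new = (jnp.logaddexp(jnp.log(1 - a) + jnp.log(cdf),
                         jnp.log(a) + jnp.log(prod) + jnp.log(cop))
           - jnp.logaddexp(jnp.log(1 - a), jnp.log(a) + jnp.log(prod)))
direct = ((1 - a) * cdf + a * prod * cop) / ((1 - a) + a * prod)
assert jnp.isclose(jnp.exp(log_new), direct)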
Code example #24
def softplus(x: Array) -> Array:
    r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x : input array
  """
    return jnp.logaddexp(x, 0)
Code example #25
File: binary_crossentropy.py Project: chjort/elegy
def binary_crossentropy(y_true: jnp.ndarray,
                        y_pred: jnp.ndarray,
                        from_logits: bool = False) -> jnp.ndarray:
    assert abs(y_pred.ndim - y_true.ndim) <= 1

    y_true, y_pred = utils.maybe_expand_dims(y_true, y_pred)

    if from_logits:
        return -jnp.mean(y_true * y_pred - jnp.logaddexp(0.0, y_pred), axis=-1)

    y_pred = jnp.clip(y_pred, utils.EPSILON, 1.0 - utils.EPSILON)
    return -jnp.mean(
        y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1)
Code example #26
def smooth_leaky_relu(x, alpha=1.0):
    """Calculate smooth leaky ReLU on an input.

    Source: https://stats.stackexchange.com/questions/329776/approximating-leaky-relu-with-a-differentiable-function

    Args:
        x (float): input value.
        alpha (float): controls level of nonlinearity via slope.

    Returns:
        Value transformed by the smooth leaky ReLU.
    """
    return alpha * x + (1 - alpha) * jnp.logaddexp(x, 0)
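The slopes follow from d/dx softplus(x) = sigmoid(x): the derivative is alpha + (1 - alpha)*sigmoid(x), approaching alpha for very negative x and 1 for very positive x. A quick check with jax.grad:

import jax
import jax.numpy as jnp

f = lambda x: 0.3 * x + (1 - 0.3) * jnp.logaddexp(x, 0)  # alpha = 0.3
g = jax.grad(f)
print(g(-20.), g(20.))  # ~0.3, ~1.0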
Code example #27
def progressive_uniform_sampling(rng_key, proposal, new_proposal):
    p_accept = jax.scipy.special.expit(new_proposal.weight - proposal.weight)
    do_accept = jax.random.bernoulli(rng_key, p_accept)

    updated_proposal = Proposal(
        new_proposal.state,
        new_proposal.energy,
        jnp.logaddexp(proposal.weight, new_proposal.weight),
    )

    return jax.lax.cond(
        do_accept, lambda _: updated_proposal, lambda _: proposal, operand=None
    )
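Here expit of the log-weight difference equals w_new / (w_old + w_new) in weight space, i.e. the new proposal is kept with probability proportional to its weight, while logaddexp accumulates the total weight:

import jax.numpy as jnp
from jax.scipy.special import expit

w_old, w_new = 0.2, 0.6
p = expit(jnp.log(w_new) - jnp.log(w_old))
assert jnp.isclose(p, w_new / (w_old + w_new))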
Code example #28
File: cubes_utils.py Project: fehiepsi/jaxns
    def body(state):
        (key, _done, log_t_cum, log_M) = state
        key, choose_key, sample_key = random.split(key, 3)
        i = random.categorical(choose_key, logits=log_Vp_i - log_Vp)
        # sample query
        x = random.uniform(sample_key,
                           shape=(D, ),
                           minval=points_lower[i, :],
                           maxval=points_upper[i, :])

        def inner_body(inner_state):
            (
                key,
                _done,
                _completed,
                log_t_M,
            ) = inner_state
            completed = jnp.logaddexp(log_t_cum, log_t_M) >= log_T
            # if not completed then increment t_M and test if point in cube
            log_t_M = jnp.where(completed, log_t_M, jnp.logaddexp(log_t_M, 0.))
            key, inner_choose_key = random.split(key, 2)
            j = random.randint(inner_choose_key,
                               shape=(),
                               minval=0,
                               maxval=N + 1)
            # point query
            in_j = points_in_box(x, points_lower[j, :], points_upper[j, :])
            done = in_j | completed
            return key, done, completed, log_t_M

        (key, _done, completed, log_t_M) = while_loop(
            lambda inner_state: ~inner_state[1], inner_body,
            (key, log_t_cum >= log_T, log_t_cum >= log_T, -jnp.inf))
        done = completed
        log_t_cum = jnp.logaddexp(log_t_cum, log_t_M)
        return (key, done, log_t_cum, jnp.logaddexp(log_M, 0.))
Code example #29
File: ctc_objectives.py Project: tensorflow/lingvo
  def loop_body(prev, x):
    prev_phi, prev_emit = prev
    # emit-to-phi epsilon transition, except if the next label is repetition
    prev_phi_orig = prev_phi
    prev_phi = prev_phi.at[:, 1:].set(
        jnp.logaddexp(prev_phi[:, 1:], prev_emit + _LOGEPSILON * repeat))

    logprob_emit, logprob_phi, pad = x

    # phi-to-emit transition
    next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit,
                              prev_emit + logprob_emit)
    # self-loop transition
    next_phi = prev_phi + logprob_phi
    # emit-to-phi blank transition only when the next label is repetition
    next_phi = next_phi.at[:, 1:].set(
        jnp.logaddexp(next_phi[:, 1:],
                      prev_emit + logprob_phi + _LOGEPSILON * (1.0 - repeat)))

    pad = pad.reshape((batchsize, 1))
    next_emit = pad * prev_emit + (1.0 - pad) * next_emit
    next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

    return (next_phi, next_emit), (next_phi, next_emit)
Code example #30
File: activations.py Project: PKU-NIP-Lab/BrainPy
def softplus(x):
    r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Parameters
  ----------
  x: JaxArray, jnp.ndarray
    The input array.
  """
    x = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.logaddexp(x, 0))