def progressive_biased_sampling(rng_key, proposal, new_proposal):
    """Biased proposal sampling.

    Unlike uniform sampling, biased sampling favors new proposals. It thus
    biases the transition away from the trajectory's initial state.
    """
    p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1)
    do_accept = jax.random.bernoulli(rng_key, p_accept)
    new_weight = jnp.logaddexp(proposal.weight, new_proposal.weight)
    new_sum_log_p_accept = jnp.logaddexp(
        proposal.sum_log_p_accept, new_proposal.sum_log_p_accept
    )

    return jax.lax.cond(
        do_accept,
        lambda _: Proposal(
            new_proposal.state,
            new_proposal.energy,
            new_weight,
            new_sum_log_p_accept,
        ),
        lambda _: Proposal(
            proposal.state,
            proposal.energy,
            new_weight,
            new_sum_log_p_accept,
        ),
        operand=None,
    )
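# A minimal usage sketch for progressive_biased_sampling, assuming a Proposal
# NamedTuple with fields (state, energy, weight, sum_log_p_accept) as used
# above. The Proposal defined here and the field values are illustrative
# stand-ins, not the original library's API.
import jax
import jax.numpy as jnp
from typing import NamedTuple


class Proposal(NamedTuple):
    state: jnp.ndarray
    energy: float
    weight: float
    sum_log_p_accept: float


rng_key = jax.random.PRNGKey(0)
current = Proposal(jnp.array([0.0]), 1.0, jnp.log(0.3), jnp.log(0.9))
candidate = Proposal(jnp.array([1.0]), 0.5, jnp.log(0.7), jnp.log(0.8))
# The candidate's higher weight gives p_accept = min(0.7/0.3, 1) = 1, so the
# biased scheme deterministically moves away from the current state here.
sampled = progressive_biased_sampling(rng_key, current, candidate)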
def test_softplus_ibp(self):
    def softplus_model(inp):
        return jax.nn.softplus(inp)

    z = jnp.array([[-2., 3.]])
    input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
    output_bounds = jax_verify.interval_bound_propagation(
        softplus_model, input_bounds)

    self.assertArrayAlmostEqual(jnp.logaddexp(z - 1., 0), output_bounds.lower)
    self.assertArrayAlmostEqual(jnp.logaddexp(z + 1., 0), output_bounds.upper)
def update_pn_forward(carry, i):
    (logpmf_ytest, logpmf_yn, y_samp, pdiff, y, x, x_test, rho, rho_x,
     ind_new, vT, logpmf_init) = carry

    # Sample new x
    x_new = x[ind_new[i]]

    # Sample new y based on uniform rv
    y_new = jnp.where((jnp.log(vT[i]) <= logpmf_yn[ind_new[i]]), x=1, y=0)
    log_vn = logpmf_yn[ind_new[i]]
    y_samp = index_update(y_samp, index[i], y_new)

    # Update pmf_yn: compute x rhos/alphas
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)
    logk_xx = mvcc.calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    # alpha*k_xx / (1 - alpha + alpha*k_xx)
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)

    # Clip for numerical stability to prevent NaNs
    eps = 1e-4  # 1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))
    logpmf_yn = mvcc.update_copula(logpmf_yn, log_vn, y_new, logalpha_x, rho)

    # Compute pdiff
    pdiff = index_update(
        pdiff, index[i, :],
        jnp.abs(jnp.exp(logpmf_yn[:, 0]) - jnp.exp(logpmf_init[:, 0])))

    # Update pmf_ytest: compute x rhos/alphas
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)
    logk_xx = mvcc.calc_logkxx(x_test, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    # alpha*k_xx / (1 - alpha + alpha*k_xx)
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)

    # Clip for numerical stability to prevent NaNs
    eps = 1e-4  # 1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    # Update pmf_ytest
    logpmf_ytest = mvcc.update_copula(logpmf_ytest, log_vn, y_new,
                                      logalpha_x, rho)

    carry = (logpmf_ytest, logpmf_yn, y_samp, pdiff, y, x, x_test, rho, rho_x,
             ind_new, vT, logpmf_init)
    return carry, i
def update_pn(carry, i):
    vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, x, rho, rho_x = carry

    # Compute new x
    x_new = x[i]
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)

    # Compute x rhos/alphas
    logk_xx = calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(-jnp.exp(logalpha))
    # alpha*k_xx / (1 - alpha + alpha*k_xx)
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)

    # Clip for numerical stability to prevent NaNs
    eps = 1e-5  # 1e-6 causes optimization to fail
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    u = jnp.exp(logcdf_conditionals_yn)
    v = jnp.exp(logcdf_conditionals_yn[i])
    vn = index_update(vn, i, v)  # remember history of vn
    preq_loglik = index_update(preq_loglik, i, logpdf_joints_yn[i, -1])

    logcdf_conditionals_yn, logpdf_joints_yn = update_copula(
        logcdf_conditionals_yn, logpdf_joints_yn, u, v, logalpha_x, rho)

    carry = vn, logcdf_conditionals_yn, logpdf_joints_yn, preq_loglik, x, rho, rho_x
    return carry, i
def test_cluster_split():
    import pylab as plt
    from jax import disable_jit

    points = jnp.concatenate(
        [random.uniform(random.PRNGKey(0), shape=(30, 2)),
         1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
        axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    mask = jnp.zeros(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0], jnp.bool_))
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    plt.plot(y[0, :], y[1, :])

    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)
    with disable_jit():
        (cluster_id, log_VS1, mu1, radii1, rotation1,
         log_VS2, mu2, radii2, rotation2, do_split) = cluster_split(
            random.PRNGKey(0), points, mask, log_VS,
            log_ellipsoid_volume(radii), kmeans_init=True)
        print(jnp.logaddexp(log_ellipsoid_volume(radii1),
                            log_ellipsoid_volume(radii2)),
              log_ellipsoid_volume(radii))
        print(log_VS1, mu1, radii1, rotation1,
              log_VS2, mu2, radii2, rotation2, do_split)
        print(cluster_id)

    y = mu1[:, None] + rotation1 @ jnp.diag(radii1) @ x
    plt.plot(y[0, :], y[1, :])
    y = mu2[:, None] + rotation2 @ jnp.diag(radii2) @ x
    plt.plot(y[0, :], y[1, :])

    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])
    plt.show()
def gibbs_loop(i, rng_particles_log_posteriors):
    rng, particles, log_posteriors = rng_particles_log_posteriors
    i = i % num_patients
    # flip values at index i
    particles_flipped = jax.ops.index_update(
        particles, jax.ops.index[:, i], np.logical_not(particles[:, i]))
    # compute log_posterior of flipped particles
    log_posteriors_flipped_at_i = bayes.tempered_logpos_logbase(
        particles_flipped, log_posterior_params, log_base_measure_params, rho)
    # compute acceptance probability, depending on whether we use Liu mod.
    if liu_modification:
        log_proposal_ratio = log_posteriors_flipped_at_i - log_posteriors
    else:
        log_proposal_ratio = log_posteriors_flipped_at_i - np.logaddexp(
            log_posteriors_flipped_at_i, log_posteriors)
    # here the MH thresholding is implicitly done.
    rng, rng_unif = jax.random.split(rng, 2)
    random_values = jax.random.uniform(rng_unif, particles.shape[:1])
    flipped_at_i = np.log(random_values) < log_proposal_ratio
    selected_at_i = np.logical_xor(flipped_at_i, particles[:, i])
    particles = jax.ops.index_update(particles, jax.ops.index[:, i],
                                     selected_at_i)
    log_posteriors = np.where(flipped_at_i, log_posteriors_flipped_at_i,
                              log_posteriors)
    return [rng, particles, log_posteriors]
def discretized_logistic_logpmf(images, means, inv_scales):
    """Compute log-probabilities for each mixture component, pixel and channel."""
    # Compute the difference between the logistic cdf half a level above and
    # half a level below the image value.
    centered = images - means

    # Where images == 1 we use log(1 - cdf(images - 1 / 255))
    top = -jnp.logaddexp(0, (centered - 1 / 255) * inv_scales)
    # Where images == -1 we use log(cdf(images + 1 / 255))
    bottom = -jnp.logaddexp(0, -(centered + 1 / 255) * inv_scales)
    # Elsewhere we use log(cdf(images + 1 / 255) - cdf(images - 1 / 255))
    mid = log1mexp(inv_scales / 127.5) + top + bottom

    return jnp.where(images == 1, top, jnp.where(images == -1, bottom, mid))
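# The snippet above relies on a log1mexp helper that is not shown. A common
# numerically stable definition, computing log(1 - exp(-x)) for x > 0, is the
# sketch below (following Maechler's log1mexp note); the name and the
# switch point at log(2) are assumptions, not necessarily what the original
# code used.
import jax.numpy as jnp


def log1mexp(x):
    # For small x, log(-expm1(-x)) is accurate; for larger x, log1p(-exp(-x)).
    return jnp.where(x < jnp.log(2.0),
                     jnp.log(-jnp.expm1(-x)),
                     jnp.log1p(-jnp.exp(-x)))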
def image_sample(rng, params, nrow, ncol):
    """Sample images from the generative model."""
    _, dec_params = params
    code_rng, img_rng = random.split(rng)
    logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
    # Convert logits to Bernoulli probabilities with a numerically stable
    # sigmoid: sigmoid(x) = exp(-logaddexp(0, -x)). Note that
    # np.logaddexp(0., logits) alone is softplus, which is not a probability.
    sampled_images = random.bernoulli(img_rng, np.exp(-np.logaddexp(0., -logits)))
    return image_grid(nrow, ncol, sampled_images, (28, 28))
def logp_unordered_subset(theta, zs):
    # The last dimension of zs indicates the selected elements
    # (sparse index representation).
    #
    # Wouter et al. use the Gumbel representation to compute p(S_k) in
    # exponential time rather than factorial. We do it in factorial time.
    Sz, Sx, N, K = zs.shape

    # Is there a better syntax for this gather?
    logp_z = theta[
        np.arange(Sz)[:, None, None, None],
        np.arange(Sx)[:, None, None],
        np.arange(N)[:, None],
        zs,
    ]

    # get denominator orderings
    perms = all_perms(K)
    logp = logp_z[..., perms]

    # A cumulative logsumexp would be more stable, but there are only two
    # elements here.
    # sum_i p(b_i)
    #a = logp.max(-1, keepdims=True)
    #p = np.exp(logp - a)
    #sbi0 = a + np.log(p.cumsum(-1) - p)

    # slow implementation; the above seems wrong
    sbis = [np.log(np.zeros(logp[..., 0].shape))]
    for i in range(K - 1):
        sbis.append(np.logaddexp(sbis[-1], logp[..., i]))
    sbi = np.stack(sbis, -1)

    logp_bs = logp.sum(-1) - log1mexp(sbi).sum(-1)
    logp_b = lse(logp_bs, -1)
    return logp_b
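# The function above references an all_perms helper that is not shown. A
# minimal sketch, assuming it returns all K! orderings of the indices 0..K-1
# as an integer array of shape (K!, K) — consistent with the
# logp_z[..., perms] gather above, but the exact original is an assumption:
import itertools

import jax.numpy as jnp


def all_perms(K):
    # Enumerate every ordering of range(K); used to sum p(S_k) over all
    # orderings of the selected subset.
    return jnp.array(list(itertools.permutations(range(K))))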
def update_pn(carry, i):
    log_vn, logpmf1_yn, preq_loglik, y, x, rho, rho_x = carry

    # Compute new x
    y_new = y[i]
    x_new = x[i]
    logalpha = jnp.log(2. - (1 / (i + 1))) - jnp.log(i + 2)

    # Compute x rhos/alphas
    eps = 5e-5
    logk_xx = calc_logkxx(x, x_new, rho_x)
    logalphak_xx = logalpha + logk_xx
    log1alpha = jnp.log1p(jnp.clip(-jnp.exp(logalpha), -1 + eps, jnp.inf))
    # alpha*k_xx / (1 - alpha + alpha*k_xx)
    logalpha_x = logalphak_xx - jnp.logaddexp(log1alpha, logalphak_xx)

    # Clip for numerical stability to prevent NaNs
    logalpha_x = jnp.clip(logalpha_x, jnp.log(eps), jnp.log(1 - eps))

    # Add p1 or (1 - p1) depending on what y_new is
    temp = y_new * logpmf1_yn[i, -1] + (1 - y_new) * jnp.log1p(
        jnp.clip(-jnp.exp(logpmf1_yn[i, -1]), -1 + eps, jnp.inf))
    preq_loglik = index_update(preq_loglik, i, temp)

    log_v = logpmf1_yn[i]
    log_vn = index_update(log_vn, i, log_v)
    logpmf1_yn = update_copula(logpmf1_yn, log_v, y_new, logalpha_x, rho)

    carry = log_vn, logpmf1_yn, preq_loglik, y, x, rho, rho_x
    return carry, i
def to_exp(self) -> MultinomialEP:
    max_q = jnp.maximum(0.0, jnp.amax(self.log_odds, axis=-1))
    q_minus_max_q = self.log_odds - max_q[..., np.newaxis]
    log_scaled_A = jnp.logaddexp(-max_q, jss.logsumexp(q_minus_max_q, axis=-1))
    return MultinomialEP(
        jnp.exp(q_minus_max_q - log_scaled_A[..., np.newaxis]))
def body(state):
    (i, done, old_cluster_id, _, _, _, _, _, _, _, _, min_loss, delay) = state
    mask1 = mask & (old_cluster_id == 0)
    mask2 = mask & (old_cluster_id == 1)

    # estimate volumes of current clustering
    n1 = jnp.sum(mask1)
    n2 = jnp.sum(mask2)
    log_VS1 = log_VS + jnp.log(n1) - jnp.log(n_S)
    log_VS2 = log_VS + jnp.log(n2) - jnp.log(n_S)

    # construct E_1, E_2 and compute volumes
    mu1, C1 = bounding_ellipsoid(points, mask1)
    radii1, rotation1 = ellipsoid_params(C1)
    log_VE1 = log_ellipsoid_volume(radii1)
    mu2, C2 = bounding_ellipsoid(points, mask2)
    radii2, rotation2 = ellipsoid_params(C2)
    log_VE2 = log_ellipsoid_volume(radii2)

    # enlarge to at least cover V(S1) and V(S2)
    log_scale1 = log_coverage_scale(log_VE1, log_VS1, D)
    log_scale2 = log_coverage_scale(log_VE2, log_VS2, D)
    C1 = C1 / jnp.exp(log_scale1)
    radii1 = jnp.exp(jnp.log(radii1) + log_scale1)
    C2 = C2 / jnp.exp(log_scale2)
    radii2 = jnp.exp(jnp.log(radii2) + log_scale2)
    log_VE1 = log_VE1 + log_scale1 * D
    log_VE2 = log_VE2 + log_scale2 * D

    # compute reassignment metrics
    maha1 = vmap(lambda point: (point - mu1) @ C1 @ (point - mu1))(points)
    maha2 = vmap(lambda point: (point - mu2) @ C2 @ (point - mu2))(points)
    log_h1 = log_VE1 - log_VS1 + jnp.log(maha1)
    log_h2 = log_VE2 - log_VS2 + jnp.log(maha2)

    # reassign
    delta_F = jnp.exp(log_h1) - jnp.exp(log_h2)
    reassign_idx = jnp.argmax(jnp.abs(delta_F))
    new_cluster_id = dynamic_update_slice(
        cluster_id, (delta_F[reassign_idx, None] > 0).astype(jnp.int_),
        reassign_idx[None])
    # new_cluster_id = jnp.where(log_h1 < log_h2, 0, 1)

    log_V_sum = jnp.logaddexp(log_VE1, log_VE2)
    new_loss = jnp.exp(log_V_sum - log_VS)
    loss_decreased = new_loss < min_loss
    delay = jnp.where(loss_decreased, 0, delay + 1)
    min_loss = jnp.where(loss_decreased, new_loss, min_loss)
    ###
    # i / delay / loss_decreased / new_loss / min_loss
    # 0 / 0     / True           / a        / a
    # 1 / 1     / False          / b        / a
    # 2 / 2     / False          / a        / a
    # 3 / 3     / False          / b        / a
    # 4 / 4     / False          / a        / a
    done = jnp.all(new_cluster_id == old_cluster_id) \
           | (delay >= 10) \
           | (n1 < D + 1) \
           | (n2 < D + 1) \
           | jnp.isnan(log_V_sum)
    # print(i, "reassignments", jnp.sum(new_cluster_id != old_cluster_id), 'F', log_V_sum)
    # print(i, done, jnp.abs(delta_F).max())
    return (i + 1, done, new_cluster_id, log_VS1, mu1, radii1, rotation1,
            log_VS2, mu2, radii2, rotation2, min_loss, delay)
def log_likelihood(theta, **kwargs):
    def log_circ(theta, c, r, w):
        return (-0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2
                - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2)))

    w1 = w2 = jnp.array(0.1)
    r1 = r2 = jnp.array(2.)
    c1 = jnp.array([0., -4.])
    c2 = jnp.array([0., 4.])
    return jnp.logaddexp(log_circ(theta, c1, r1, w1),
                         log_circ(theta, c2, r2, w2))
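# A quick check of the two-ring likelihood above: points on either circle of
# radius 2 (around c1 or c2) should score much higher than points far from
# both. The test points below are illustrative, not from the original code.
import jax.numpy as jnp

on_ring = jnp.array([0., -2.])   # distance 2 from c1 = [0, -4]
off_ring = jnp.array([5., 0.])   # far from both rings
print(log_likelihood(on_ring))   # ~ -log(sqrt(2*pi*0.01)), about +1.38
print(log_likelihood(off_ring))  # strongly negative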
def nat_to_probability(self) -> RealArray:
    max_q = jnp.maximum(0.0, jnp.amax(self.log_odds, axis=-1))
    q_minus_max_q = self.log_odds - max_q[..., np.newaxis]
    log_scaled_A = jnp.logaddexp(-max_q, jss.logsumexp(q_minus_max_q, axis=-1))
    p = jnp.exp(q_minus_max_q - log_scaled_A[..., np.newaxis])
    final_p = 1.0 - jnp.sum(p, axis=-1, keepdims=True)
    return jnp.append(p, final_p, axis=-1)
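# The max-shift in nat_to_probability implements a stable softmax over the K
# log-odds plus an implicit final logit pinned at 0 (the reference category).
# A standalone sketch of the same computation; the variable names here are
# illustrative, not the class's API:
import jax.nn
import jax.numpy as jnp
from jax.scipy.special import logsumexp

log_odds = jnp.array([1.0, -0.5, 2.0])
# Reference: append the implicit zero logit, then softmax.
reference = jax.nn.softmax(jnp.append(log_odds, 0.0))

max_q = jnp.maximum(0.0, jnp.max(log_odds))
log_scaled_A = jnp.logaddexp(-max_q, logsumexp(log_odds - max_q))
p = jnp.exp(log_odds - max_q - log_scaled_A)
stable = jnp.append(p, 1.0 - jnp.sum(p))
# `stable` and `reference` agree to numerical precision.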
def update_sum_log_p_accept(inputs):
    _, proposal, new_proposal = inputs
    return Proposal(
        proposal.state,
        proposal.energy,
        proposal.weight,
        jnp.logaddexp(proposal.sum_log_p_accept,
                      new_proposal.sum_log_p_accept),
    )
def softplus(x):
    r"""Softplus activation function.

    Computes the element-wise function

    .. math::
        \mathrm{softplus}(x) = \log(1 + e^x)
    """
    return jnp.logaddexp(x, 0)
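# Why logaddexp(x, 0) rather than the literal log(1 + exp(x)): the naive form
# overflows for large x, while logaddexp shifts by the max first. A small
# illustrative check:
import jax.numpy as jnp

x = jnp.array([-100.0, 0.0, 100.0])
print(jnp.logaddexp(x, 0))      # [~0., 0.6931, 100.] -- finite everywhere
print(jnp.log(1 + jnp.exp(x)))  # exp(100) overflows, giving inf at x = 100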
def logaddexp(x1, x2):
    if is_complex(x1) or is_complex(x2):
        select1 = x1.real > x2.real
        amax = jnp.where(select1, x1, x2)
        delta = jnp.where(select1, x2 - x1, x1 - x2)
        return jnp.where(jnp.isnan(delta),
                         x1 + x2,  # NaNs or infinities of the same sign.
                         amax + jnp.log1p(jnp.exp(delta)))
    else:
        return jnp.logaddexp(x1, x2)
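# The complex branch mirrors jnp.logaddexp's max-shift trick so that
# exp(logaddexp(a, b)) == exp(a) + exp(b) also holds for complex logarithms.
# Illustrative check; is_complex is undefined above, so we assume a stand-in
# along the lines of jnp.iscomplexobj:
import jax.numpy as jnp

is_complex = jnp.iscomplexobj  # assumed stand-in for the missing helper

a = jnp.log(jnp.array(3.0 + 4.0j))
b = jnp.log(jnp.array(1.0 - 2.0j))
print(jnp.exp(logaddexp(a, b)))  # ~ (4+2j), the sum of the two numbers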
def inner_body(inner_state):
    (key, _done, _completed, log_t_M) = inner_state
    completed = jnp.logaddexp(log_t_cum, log_t_M) >= log_T
    # if not completed then increment t_M and test if point is in the cube
    log_t_M = jnp.where(completed, log_t_M, jnp.logaddexp(log_t_M, 0.))
    key, inner_choose_key = random.split(key, 2)
    j = random.randint(inner_choose_key, shape=(), minval=0, maxval=N + 1)
    # point query
    in_j = points_in_box(x, points_lower[j, :], points_upper[j, :])
    done = in_j | completed
    return key, done, completed, log_t_M
def sigmoid_forward(params, x):
    """Maps x in (-infinity, infinity)^d to (0, 1)^d."""
    beta, log_scale, shift = params
    # Rescale to get closer to the unit square when mapped out.
    x = (x + shift) * np.exp(log_scale)
    sigmoid_log_prob = np.sum(
        log_scale + np.log(beta) - beta * x - 2 * np.logaddexp(0, -beta * x),
        axis=1)
    x = 1.0 / (1 + np.exp(-beta * x))
    return x, sigmoid_log_prob
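# The returned sigmoid_log_prob is the log |det Jacobian| of the forward map.
# A sanity check against autodiff for d = 1, assuming the np alias above is
# jax.numpy (the parameter values here are illustrative):
import jax
import jax.numpy as np

params = (np.array([2.0]), np.array([0.3]), np.array([-0.1]))  # beta, log_scale, shift


def forward_only(x_scalar):
    y, _ = sigmoid_forward(params, x_scalar.reshape(1, 1))
    return y[0, 0]


x0 = np.array(0.7)
_, logdet = sigmoid_forward(params, x0.reshape(1, 1))
print(logdet[0], np.log(jax.grad(forward_only)(x0)))  # the two should match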
def bernoulli_log_density(b, unnormalized_logprob):
    """Bernoulli log-density from an unnormalized log-probability.

    Args:
        b: 0 or 1 (i.e. binarized digit/image)
        unnormalized_logprob: log(mu / (1 - mu)), i.e. the "logit"

    Returns:
        log Ber(b | mu)
    """
    s = b * 2 - 1
    return -np.logaddexp(0., -s * unnormalized_logprob)
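# Quick check of the identity used above: with logit = log(mu / (1 - mu)),
# -logaddexp(0, -logit) = log(mu) and -logaddexp(0, logit) = log(1 - mu).
import jax.numpy as np

mu = 0.8
logit = np.log(mu / (1 - mu))
print(bernoulli_log_density(1, logit), np.log(mu))      # both ~ -0.2231
print(bernoulli_log_density(0, logit), np.log(1 - mu))  # both ~ -1.6094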
def loop_body(prev, x):
    prev_phi, prev_emit = prev
    # emit-to-phi epsilon transition
    prev_phi = prev_phi.at[:, 1:].set(
        jnp.logaddexp(prev_phi[:, 1:], prev_emit[:, :-1]))
    logprob_emit, logprob_phi, pad = x

    # phi-to-emit transition
    next_emit = jnp.logaddexp(prev_phi + logprob_emit,
                              prev_emit + logprob_emit)
    # self-loop transition
    next_phi = prev_phi + logprob_phi.reshape((batchsize, 1))

    pad = pad.reshape((batchsize, 1))
    next_emit = pad * prev_emit + (1.0 - pad) * next_emit
    next_phi = pad * prev_phi + (1.0 - pad) * next_phi
    return (next_phi, next_emit), (next_phi, next_emit)
def binary_crossentropy(
    y_true: jnp.ndarray, y_pred: jnp.ndarray, from_logits: bool = False
) -> jnp.ndarray:
    if from_logits:
        return -jnp.mean(y_true * y_pred - jnp.logaddexp(0.0, y_pred), axis=-1)

    y_pred = jnp.clip(y_pred, utils.EPSILON, 1.0 - utils.EPSILON)
    return -jnp.mean(y_true * jnp.log(y_pred) +
                     (1 - y_true) * jnp.log(1 - y_pred), axis=-1)
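# The from_logits branch rewrites -[y log p + (1-y) log(1-p)] with
# p = sigmoid(z) as logaddexp(0, z) - y*z, which never materializes p and so
# stays stable for extreme logits. An illustrative check (this path bypasses
# the utils.EPSILON clipping entirely):
import jax
import jax.numpy as jnp

y_true = jnp.array([[1.0, 0.0, 1.0]])
logits = jnp.array([[2.0, -1.0, 0.5]])
probs = jax.nn.sigmoid(logits)

from_logits_loss = binary_crossentropy(y_true, logits, from_logits=True)
manual_loss = -jnp.mean(y_true * jnp.log(probs) +
                        (1 - y_true) * jnp.log(1 - probs), axis=-1)
# from_logits_loss and manual_loss agree to numerical precision.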
def update_copula_single(logcdf_conditionals, logpdf_joints, u, v, alpha, rho):
    d = jnp.shape(logpdf_joints)[0]
    logcop_distribution, logcop_dens = t1_copula_logdistribution_logdensity(
        u, v, rho)

    # Calculate product copulas
    logcop_dens_prod = jnp.cumsum(logcop_dens)
    # staggered 1 step to calculate conditional cdfs
    logcop_dens_prod_staggered = jnp.concatenate(
        (jnp.zeros(1), logcop_dens_prod[0:d - 1]))

    logalpha = jnp.log(alpha)
    log1alpha = jnp.log(1 - alpha)

    logcdf_conditionals = (
        jnp.logaddexp(log1alpha + logcdf_conditionals,
                      logalpha + logcop_dens_prod_staggered + logcop_distribution)
        - jnp.logaddexp(log1alpha, logalpha + logcop_dens_prod_staggered))
    logpdf_joints = (jnp.logaddexp(log1alpha, logalpha + logcop_dens_prod)
                     + logpdf_joints)
    return logcdf_conditionals, logpdf_joints
def softplus(x: Array) -> Array:
    r"""Softplus activation function.

    Computes the element-wise function

    .. math::
        \mathrm{softplus}(x) = \log(1 + e^x)

    Args:
        x : input array
    """
    return jnp.logaddexp(x, 0)
def binary_crossentropy(y_true: jnp.ndarray,
                        y_pred: jnp.ndarray,
                        from_logits: bool = False) -> jnp.ndarray:
    assert abs(y_pred.ndim - y_true.ndim) <= 1
    y_true, y_pred = utils.maybe_expand_dims(y_true, y_pred)

    if from_logits:
        return -jnp.mean(y_true * y_pred - jnp.logaddexp(0.0, y_pred), axis=-1)

    y_pred = jnp.clip(y_pred, utils.EPSILON, 1.0 - utils.EPSILON)
    return -jnp.mean(
        y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1)
def smooth_leaky_relu(x, alpha=1.0):
    """Calculate smooth leaky ReLU on an input.

    Source: https://stats.stackexchange.com/questions/329776/approximating-leaky-relu-with-a-differentiable-function

    Args:
        x (float): input value.
        alpha (float): controls level of nonlinearity via slope.

    Returns:
        Value transformed by the smooth leaky ReLU.
    """
    return alpha * x + (1 - alpha) * jnp.logaddexp(x, 0)
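# The two extremes of smooth_leaky_relu: alpha = 1 recovers the identity and
# alpha = 0 recovers softplus; intermediate alphas give a smooth function
# whose slope tends to alpha for x << 0 and to 1 for x >> 0, like leaky ReLU.
import jax.numpy as jnp

x = jnp.linspace(-5.0, 5.0, 5)
print(smooth_leaky_relu(x, alpha=1.0))  # == x
print(smooth_leaky_relu(x, alpha=0.0))  # == jnp.logaddexp(x, 0), i.e. softplus
print(smooth_leaky_relu(x, alpha=0.1))  # smooth leaky ReLU with slope 0.1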
def progressive_uniform_sampling(rng_key, proposal, new_proposal):
    p_accept = jax.scipy.special.expit(new_proposal.weight - proposal.weight)
    do_accept = jax.random.bernoulli(rng_key, p_accept)
    updated_proposal = Proposal(
        new_proposal.state,
        new_proposal.energy,
        jnp.logaddexp(proposal.weight, new_proposal.weight),
    )
    return jax.lax.cond(
        do_accept, lambda _: updated_proposal, lambda _: proposal,
        operand=None)
def body(state):
    (key, _done, log_t_cum, log_M) = state
    key, choose_key, sample_key = random.split(key, 3)
    i = random.categorical(choose_key, logits=log_Vp_i - log_Vp)

    # sample query
    x = random.uniform(sample_key, shape=(D,),
                       minval=points_lower[i, :], maxval=points_upper[i, :])

    def inner_body(inner_state):
        (key, _done, _completed, log_t_M) = inner_state
        completed = jnp.logaddexp(log_t_cum, log_t_M) >= log_T
        # if not completed then increment t_M and test if point is in the cube
        log_t_M = jnp.where(completed, log_t_M, jnp.logaddexp(log_t_M, 0.))
        key, inner_choose_key = random.split(key, 2)
        j = random.randint(inner_choose_key, shape=(), minval=0, maxval=N + 1)
        # point query
        in_j = points_in_box(x, points_lower[j, :], points_upper[j, :])
        done = in_j | completed
        return key, done, completed, log_t_M

    (key, _done, completed, log_t_M) = while_loop(
        lambda inner_state: ~inner_state[1],
        inner_body,
        (key, log_t_cum >= log_T, log_t_cum >= log_T, -jnp.inf))

    done = completed
    log_t_cum = jnp.logaddexp(log_t_cum, log_t_M)
    return (key, done, log_t_cum, jnp.logaddexp(log_M, 0.))
def loop_body(prev, x):
    prev_phi, prev_emit = prev
    # emit-to-phi epsilon transition, except if the next label is a repetition
    prev_phi_orig = prev_phi
    prev_phi = prev_phi.at[:, 1:].set(
        jnp.logaddexp(prev_phi[:, 1:], prev_emit + _LOGEPSILON * repeat))
    logprob_emit, logprob_phi, pad = x

    # phi-to-emit transition
    next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit,
                              prev_emit + logprob_emit)
    # self-loop transition
    next_phi = prev_phi + logprob_phi
    # emit-to-phi blank transition only when the next label is a repetition
    next_phi = next_phi.at[:, 1:].set(
        jnp.logaddexp(next_phi[:, 1:],
                      prev_emit + logprob_phi + _LOGEPSILON * (1.0 - repeat)))

    pad = pad.reshape((batchsize, 1))
    next_emit = pad * prev_emit + (1.0 - pad) * next_emit
    next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
    return (next_phi, next_emit), (next_phi, next_emit)
def softplus(x):
    r"""Softplus activation function.

    Computes the element-wise function

    .. math::
        \mathrm{softplus}(x) = \log(1 + e^x)

    Parameters
    ----------
    x: JaxArray, jnp.ndarray
        The input array.
    """
    x = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.logaddexp(x, 0))