def run_hmc_steps(theta, eps, Lmax, key, log_posterior,
                  log_posterior_grad_theta, diagonal_mass_matrix):
    """Run one Hamiltonian Monte Carlo transition starting from ``theta``.

    A momentum ``phi`` is drawn from a zero-mean normal whose variances are
    the entries of ``diagonal_mass_matrix``. A leapfrog trajectory with step
    size ``eps`` and a random number of steps ``L`` (uniform on
    ``1..Lmax``, both ends included) is then simulated, and the end point is
    passed to ``acceptance_step`` together with the log acceptance ratio.

    Args:
        theta: Current position; a 1-D array (only ``theta.shape[0]`` is
            used to size the momentum draw).
        eps: Leapfrog step size.
        Lmax: Largest number of leapfrog steps that may be taken (>= 1).
        key: JAX PRNG key. It is split twice here; the remaining key is
            handed to ``acceptance_step``.
        log_posterior: Callable returning the log posterior at a position.
        log_posterior_grad_theta: Callable returning the gradient of the log
            posterior with respect to the position.
        diagonal_mass_matrix: Vector holding the diagonal entries of the
            mass matrix M.

    Returns:
        A tuple ``(was_accepted, log_r, new_theta)``. When the proposed log
        posterior is not finite the proposal is rejected outright and
        ``(False, -10, theta)`` is returned; otherwise ``was_accepted`` and
        ``new_theta`` are whatever ``acceptance_step`` returns.

    Raises:
        ValueError: If ``Lmax`` is smaller than 1.
    """
    if Lmax < 1:
        raise ValueError(f'Lmax must be at least 1, got {Lmax}.')

    # Diagonal mass matrix: diagonal entries of M (a vector)
    inverse_diag_mass = 1. / diagonal_mass_matrix
    momentum_scale = np.sqrt(diagonal_mass_matrix)

    key, subkey = random.split(key)
    # Location-scale transform: a standard normal times sqrt(M) has variance
    # M, which matches the `scale` used for the momentum log-densities below.
    phi = random.normal(subkey, shape=(theta.shape[0], )) * momentum_scale

    start_theta = theta
    start_phi = phi

    cur_grad = log_posterior_grad_theta(theta)

    # Draw the trajectory length from the JAX key so the whole transition is
    # reproducible from `key`. `maxval` is exclusive, hence `Lmax + 1`; this
    # also makes `Lmax == 1` valid (always exactly one step).
    key, subkey = random.split(key)
    L = int(random.randint(subkey, (), 1, Lmax + 1))

    for _ in range(L):
        # Leapfrog: half step in momentum, full step in position, then
        # another half step in momentum using the refreshed gradient.
        phi = phi + 0.5 * eps * cur_grad
        theta = theta + eps * inverse_diag_mass * phi
        cur_grad = log_posterior_grad_theta(theta)
        phi = phi + 0.5 * eps * cur_grad

    # Compute (log) acceptance probability
    proposed_log_post = log_posterior(theta)
    previous_log_post = log_posterior(start_theta)

    print(f'Proposed log posterior is: {proposed_log_post}.'
          f'Previous was {previous_log_post}.')

    if not np.isfinite(proposed_log_post):
        # Reject: the trajectory ended somewhere with +/-inf or NaN density.
        # NOTE: -10 is a placeholder sentinel for the log acceptance ratio
        # kept for backward compatibility with existing callers.
        was_accepted = False
        new_theta = start_theta
        log_r = -10
        return was_accepted, log_r, new_theta

    proposed_log_phi = np.sum(norm.logpdf(phi, scale=momentum_scale))
    previous_log_phi = np.sum(norm.logpdf(start_phi, scale=momentum_scale))

    log_r = (proposed_log_post + proposed_log_phi - previous_log_post -
             previous_log_phi)

    was_accepted, new_theta = acceptance_step(log_r, theta, start_theta, key)

    return was_accepted, log_r, new_theta
def test_forced_bos_token_logits_processor(self):
    """Check that the forced-BOS processor only acts when ``cur_len`` is 1."""
    vocab_size = 20
    batch_size = 4
    bos_token_id = 0

    processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)

    # With cur_len == 1, every score other than the bos_token_id score
    # must be -inf.
    input_ids = ids_tensor((batch_size, 1), vocab_size=vocab_size)
    uniform_scores = self._get_uniform_logits(batch_size, vocab_size)
    forced_scores = processor(input_ids, uniform_scores, cur_len=1)

    non_bos_scores = forced_scores[:, bos_token_id + 1 :]
    self.assertTrue(jnp.isneginf(non_bos_scores).all())
    # The score for bos_token_id should be zero.
    bos_scores = forced_scores[:, bos_token_id].tolist()
    self.assertListEqual(bos_scores, [0] * batch_size)

    # With a current length greater than 1, bos_token_id is not forced, so
    # no score may be infinite.
    uniform_scores = self._get_uniform_logits(batch_size, vocab_size)
    untouched_scores = processor(input_ids, uniform_scores, cur_len=3)
    self.assertFalse(jnp.isinf(untouched_scores).any())
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis`` with a max-shift.

    The maximum of ``a`` along the reduced dimensions is subtracted before
    exponentiating and added back after the log. Non-finite maxima are
    replaced by zero, and the shift is wrapped in ``stop_gradient``.

    Args:
        a: Input array.
        axis: Axis or axes to reduce over; forwarded to ``_reduction_dims``.
        b: Optional scaling factors for ``exp(a)``. Entries of ``a`` where
            ``b == 0`` are replaced by ``-inf`` before the reduction.
        keepdims: Whether reduced dimensions are kept in the result.
        return_sign: If true, return ``(out, sign)`` instead of ``out``.

    Returns:
        ``out`` or, when ``return_sign`` is true, the tuple ``(out, sign)``.
        When ``b`` is given, ``return_sign`` is false and the result is not
        complex, entries whose sum was negative are set to NaN.
    """
    weighted = b is not None
    if weighted:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)

    pos_dims, dims = _reduction_dims(a, axis)

    # Shift by the (finite) maximum; no gradient flows through the shift.
    shift = jnp.max(a, axis=dims, keepdims=keepdims)
    finite_shift = lax.select(
        jnp.isfinite(shift), shift, lax.full_like(shift, 0))
    shift = lax.stop_gradient(finite_shift)
    shift_with_dims = shift if keepdims else lax.expand_dims(shift, pos_dims)

    if not weighted and not np.issubdtype(a.dtype, np.complexfloating):
        # Fast path: the sum of exponentials cannot be negative.
        shifted_exp = lax.exp(lax.sub(a, shift_with_dims))
        total = jnp.sum(shifted_exp, axis=dims, keepdims=keepdims)
        out = lax.add(lax.log(total), shift)
        # Sign is NaN where out is NaN, 0 where out is -inf, and 1 elsewhere.
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        shifted_exp = lax.exp(lax.sub(a, shift_with_dims))
        if weighted:
            shifted_exp = lax.mul(shifted_exp, b)
        total = jnp.sum(shifted_exp, axis=dims, keepdims=keepdims)
        sign = lax.stop_gradient(jnp.sign(total))
        if np.issubdtype(total.dtype, np.complexfloating):
            # NOTE(review): multiplying by `sign` does not obviously yield
            # the magnitude for complex sums -- confirm this is intended.
            if return_sign:
                total = sign * total
            out = lax.add(lax.log(total), shift)
        else:
            out = lax.add(lax.log(lax.abs(total)), shift)

    if return_sign:
        return (out, sign)

    if weighted and not np.issubdtype(out.dtype, np.complexfloating):
        # A negative weighted sum has no real log: report NaN, with NaN
        # debugging switched off around the assignment.
        with jax.debug_nans(False):
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
def test_forced_eos_token_logits_processor(self):
    """Check that the forced-EOS processor only acts at the tested length."""
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0
    max_length = 5

    processor = FlaxForcedEOSTokenLogitsProcessor(
        max_length=max_length, eos_token_id=eos_token_id)

    # With cur_len == 4 (max_length is reached), every score other than the
    # eos_token_id score must be -inf.
    input_ids = ids_tensor((batch_size, 4), vocab_size=vocab_size)
    uniform_scores = self._get_uniform_logits(batch_size, vocab_size)
    forced_scores = processor(input_ids, uniform_scores, cur_len=4)

    non_eos_scores = forced_scores[:, eos_token_id + 1 :]
    self.assertTrue(jnp.isneginf(non_eos_scores).all())
    # The score for eos_token_id should be zero.
    eos_scores = forced_scores[:, eos_token_id].tolist()
    self.assertListEqual(eos_scores, [0] * batch_size)

    # With cur_len == 3 (max_length not reached), eos_token_id is not
    # forced, so no score may be infinite.
    uniform_scores = self._get_uniform_logits(batch_size, vocab_size)
    untouched_scores = processor(input_ids, uniform_scores, cur_len=3)
    self.assertFalse(jnp.isinf(untouched_scores).any())