Example #1
0
def run_hmc_steps(theta, eps, Lmax, key, log_posterior,
                  log_posterior_grad_theta, diagonal_mass_matrix):
    """Run one Hamiltonian Monte Carlo transition starting at ``theta``.

    A momentum vector is drawn as standard-normal noise scaled by
    ``sqrt(diagonal_mass_matrix)``, a randomly chosen number of leapfrog
    steps is integrated, and the resulting proposal is handed to
    ``acceptance_step`` together with the log acceptance ratio.

    Args:
        theta: Current position; ``theta.shape[0]`` sets the momentum length.
        eps: Leapfrog step size.
        Lmax: Exclusive upper bound on the number of leapfrog steps. The
            count is drawn uniformly from ``[1, Lmax)``; ``Lmax <= 1`` is
            treated as "exactly one step" instead of raising.
        key: JAX PRNG key. All randomness of the transition (momentum, step
            count, and the key passed to ``acceptance_step``) derives from it.
        log_posterior: Callable mapping a position to its log posterior.
        log_posterior_grad_theta: Callable mapping a position to the
            gradient of the log posterior with respect to the position.
        diagonal_mass_matrix: Diagonal entries of the mass matrix M, given
            as a vector.

    Returns:
        A tuple ``(was_accepted, log_r, new_theta)``. If the proposed log
        posterior is not finite the proposal is rejected outright and
        ``(False, -10, theta)`` is returned with the original ``theta``.
    """
    inverse_diag_mass = 1. / diagonal_mass_matrix
    momentum_scale = np.sqrt(diagonal_mass_matrix)

    key, momentum_key = random.split(key)

    # Location-scale transform: N(0, 1) noise scaled to variance M.
    phi = random.normal(
        momentum_key, shape=(theta.shape[0], )) * momentum_scale

    start_theta = theta
    start_phi = phi

    cur_grad = log_posterior_grad_theta(theta)

    key, steps_key = random.split(key)

    # Draw the step count from the JAX key (not NumPy's global RNG) so the
    # whole transition is reproducible from `key`. The upper bound is
    # clamped to 2 so that Lmax <= 1 yields a single step.
    steps_upper = max(int(Lmax), 2)
    num_steps = int(
        random.randint(steps_key, shape=(), minval=1, maxval=steps_upper))

    # Leapfrog: half momentum step, full position step, half momentum step.
    for _ in range(num_steps):
        phi = phi + 0.5 * eps * cur_grad
        theta = theta + eps * inverse_diag_mass * phi
        cur_grad = log_posterior_grad_theta(theta)
        phi = phi + 0.5 * eps * cur_grad

    # Compute (log) acceptance probability
    proposed_log_post = log_posterior(theta)
    previous_log_post = log_posterior(start_theta)

    print(f'Proposed log posterior is: {proposed_log_post}.'
          f'Previous was {previous_log_post}.')

    if not np.isfinite(proposed_log_post):
        # Reject: the proposal left the region where the posterior is
        # defined. -10 is a placeholder log ratio kept for backward
        # compatibility with existing callers.
        return False, -10, start_theta

    proposed_log_phi = np.sum(norm.logpdf(phi, scale=momentum_scale))
    previous_log_phi = np.sum(norm.logpdf(start_phi, scale=momentum_scale))

    log_r = (proposed_log_post + proposed_log_phi - previous_log_post -
             previous_log_phi)

    was_accepted, new_theta = acceptance_step(log_r, theta, start_theta, key)

    return was_accepted, log_r, new_theta
    def test_forced_bos_token_logits_processor(self):
        """Exercise the forced-BOS processor at ``cur_len`` 1 and 3.

        At ``cur_len == 1`` every column after ``bos_token_id`` must be
        ``-inf`` and the ``bos_token_id`` column must be zero; at
        ``cur_len == 3`` no score may be infinite.
        """
        bos_token_id = 0
        batch_size = 4
        vocab_size = 20

        processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
        input_ids = ids_tensor((batch_size, 1), vocab_size=20)

        # First generation step: only the BOS column may survive.
        uniform = self._get_uniform_logits(batch_size, vocab_size)
        forced = processor(input_ids, uniform, cur_len=1)
        after_bos = forced[:, bos_token_id + 1 :]
        self.assertTrue(jnp.isneginf(after_bos).all())
        # The score for bos_token_id itself should be zero.
        self.assertListEqual(forced[:, bos_token_id].tolist(), [0] * batch_size)

        # Later step: the processor must not force BOS, so nothing is infinite.
        uniform = self._get_uniform_logits(batch_size, vocab_size)
        passthrough = processor(input_ids, uniform, cur_len=3)
        self.assertFalse(jnp.isinf(passthrough).any())
Example #3
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` along ``axis`` with a max shift.

    The inputs are shifted by their (gradient-stopped) maximum before
    exponentiation and the shift is added back after the logarithm.

    Args:
        a: Input values.
        axis: Axis or axes to reduce over; forwarded to ``_reduction_dims``.
        b: Optional weights multiplied into ``exp(a)``. Entries of ``a``
            where ``b == 0`` are replaced by ``-inf`` before the reduction.
        keepdims: Whether the reduced axes are kept in the result.
        return_sign: If true, return ``(out, sign)`` instead of ``out``.

    Returns:
        ``out``, or the tuple ``(out, sign)`` when ``return_sign`` is true.
        Without ``return_sign``, when ``b`` is given and the result is not
        complex, positions where the sign is negative are set to NaN.
    """
    weighted = b is not None
    if weighted:
        a, b = _promote_args_inexact("logsumexp", a, b)
        # Zero-weight entries must not influence the maximum or the sum.
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        (a,) = _promote_args_inexact("logsumexp", a)

    pos_dims, dims = _reduction_dims(a, axis)

    # Shift value: the maximum over the reduced axes, with non-finite
    # maxima replaced by 0 and excluded from differentiation.
    raw_max = jnp.max(a, axis=dims, keepdims=keepdims)
    finite_max = lax.select(
        jnp.isfinite(raw_max), raw_max, lax.full_like(raw_max, 0))
    amax = lax.stop_gradient(finite_max)
    shift = amax if keepdims else lax.expand_dims(amax, pos_dims)
    shifted_exp = lax.exp(lax.sub(a, shift))

    input_is_complex = np.issubdtype(a.dtype, np.complexfloating)
    if not weighted and not input_is_complex:
        # Fast path: the sum of exponentials cannot be negative here.
        total = jnp.sum(shifted_exp, axis=dims, keepdims=keepdims)
        out = lax.add(lax.log(total), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        terms = lax.mul(shifted_exp, b) if weighted else shifted_exp
        sumexp = jnp.sum(terms, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            log_arg = sumexp
        else:
            log_arg = lax.abs(sumexp)
        out = lax.add(lax.log(log_arg), amax)

    if return_sign:
        return (out, sign)

    if weighted and not np.issubdtype(out.dtype, np.complexfloating):
        # A negative weighted sum has no real logarithm: report NaN, with
        # NaN debugging switched off so the deliberate NaN does not raise.
        with jax.debug_nans(False):
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype),
                            out)
    return out
    def test_forced_eos_token_logits_processor(self):
        """Exercise the forced-EOS processor at ``cur_len`` 4 and 3.

        With ``max_length == 5``, at ``cur_len == 4`` every column after
        ``eos_token_id`` must be ``-inf`` and the ``eos_token_id`` column
        must be zero; at ``cur_len == 3`` no score may be infinite.
        """
        eos_token_id = 0
        max_length = 5
        batch_size = 4
        vocab_size = 20

        processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
        input_ids = ids_tensor((batch_size, 4), vocab_size=20)

        # max_length is reached: only the EOS column may survive.
        uniform = self._get_uniform_logits(batch_size, vocab_size)
        forced = processor(input_ids, uniform, cur_len=4)
        after_eos = forced[:, eos_token_id + 1 :]
        self.assertTrue(jnp.isneginf(after_eos).all())
        # The score for eos_token_id itself should be zero.
        self.assertListEqual(forced[:, eos_token_id].tolist(), [0] * batch_size)

        # max_length not reached: EOS must not be forced, so nothing is infinite.
        uniform = self._get_uniform_logits(batch_size, vocab_size)
        passthrough = processor(input_ids, uniform, cur_len=3)
        self.assertFalse(jnp.isinf(passthrough).any())