Example #1
from jax.scipy.special import logsumexp  # assumed backend; scipy.special.logsumexp has the same signature

def _logsumexp(x, dim):
    return logsumexp(x, axis=dim)
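
A brief usage sketch of the wrapper above (the input values are illustrative):

import jax.numpy as jnp

x = jnp.log(jnp.array([[1., 2.],
                       [3., 4.]]))
# Reduce over the last dimension: log(1 + 2) and log(3 + 4).
print(_logsumexp(x, dim=-1))  # approximately [1.0986 1.9459]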
Example #2
    def _finalise_results(self, state: NestedSamplerState, collect_samples: bool, stochastic_uncertainty: bool,
                          max_samples: int):
        """Packages the tracked expectations and the dead/live points into a NestedSamplerResults namedtuple."""
        collect = ['logZ',
                   'logZerr',
                   'ESS',
                   'ESS_err',
                   'H',
                   'H_err',
                   'num_likelihood_evaluations',
                   'efficiency',
                   'marginalised',
                   'marginalised_uncert',
                   'log_L_samples',
                   'n_per_sample',
                   'log_p',
                   'log_X',
                   'sampler_efficiency',
                   'num_samples'
                   ]

        if collect_samples:
            collect.append('samples')

        NestedSamplerResults = namedtuple('NestedSamplerResults', collect)
        tracked_expectations = TrackedExpectation(self.marginalised, self.marginalised_shapes,
                                                  state=state.tracked_expectations_state)
        live_update_results = tracked_expectations.update_from_live_points(state.live_points, state.log_L_live)
        if self.marginalised is not None:
            marginalised = tracked_expectations.marg_mean()
            marginalised_uncert = None  # tracked_expectations.marg_variance()
        else:
            marginalised = None
            marginalised_uncert = None

        num_live_points = state.log_L_live.shape[0]
        n_per_sample = jnp.where(jnp.arange(max_samples) < state.num_dead, num_live_points, jnp.inf)
        n_per_sample = dynamic_update_slice(n_per_sample,
                                            num_live_points - jnp.arange(num_live_points, dtype=n_per_sample.dtype),
                                            [state.num_dead])
        sampler_efficiency = dynamic_update_slice(state.sampler_efficiency,
                                                  jnp.ones(num_live_points),
                                                  [state.num_dead])
        log_w = dynamic_update_slice(state.log_w,
                                     live_update_results[3],
                                     [state.num_dead])
        log_p = log_w - logsumexp(log_w)
        log_X = dynamic_update_slice(state.log_X,
                                     live_update_results[2],
                                     [state.num_dead])
        log_L_samples = dynamic_update_slice(state.log_L_dead,
                                             live_update_results[1],
                                             [state.num_dead])
        num_samples = state.num_dead + num_live_points

        data = dict(
            logZ=tracked_expectations.evidence_mean(),
            logZerr=jnp.sqrt(tracked_expectations.evidence_variance()),
            ESS=tracked_expectations.effective_sample_size(),
            ESS_err=None,
            H=tracked_expectations.information_gain_mean(),
            H_err=jnp.sqrt(tracked_expectations.information_gain_variance()),
            num_likelihood_evaluations=state.num_likelihood_evaluations,
            efficiency=(state.num_dead + state.log_L_live.shape[0]) / state.num_likelihood_evaluations,
            marginalised=marginalised,
            marginalised_uncert=marginalised_uncert,
            n_per_sample=n_per_sample,
            log_p=log_p,
            log_X=log_X,
            log_L_samples=log_L_samples,
            num_samples=num_samples,
            sampler_efficiency=sampler_efficiency
        )

        if collect_samples:

            # log_t = jnp.where(jnp.isinf(n_per_sample), 0., jnp.log(n_per_sample) - jnp.log(n_per_sample + 1.))
            # log_X = jnp.cumsum(log_t)
            ar = jnp.argsort(state.log_L_live)
            samples = dict_multimap(lambda dead_points, live_points:
                                    dynamic_update_slice(dead_points,
                                                         live_points.astype(dead_points.dtype)[ar, ...],
                                                         [state.num_dead] + [0] * (len(dead_points.shape) - 1)),
                                    state.dead_points, state.live_points)
            # log_L_samples = dynamic_update_slice(state.log_L_dead, state.log_L_live[ar], [state.num_dead])

            # sampler_efficiency = dynamic_update_slice(state.sampler_efficiency,
            #                                           jnp.ones(num_live_points),
            #                                           [state.num_dead])

            # num_samples = state.num_dead + num_live_points
            data['samples'] = samples
            # data['log_L_samples'] = log_L_samples
            # data['n_per_sample'] = n_per_sample
            # data['log_X'] = log_X
            #
            # data['sampler_efficiency'] = sampler_efficiency
            # data['num_samples'] = num_samples

            if stochastic_uncertainty:
                S = 200
                logZ, m, cov, ESS, H = vmap(lambda key: stochastic_result_computation(n_per_sample,
                                                                                      key, samples, log_L_samples))(
                    random.split(state.key, S))
                data['logZ'] = jnp.mean(logZ, axis=0)
                data['logZerr'] = jnp.std(logZ, axis=0)
                data['H'] = jnp.mean(H, axis=0)
                data['H_err'] = jnp.std(H, axis=0)
                data['ESS'] = jnp.mean(ESS, axis=0)
                data['ESS_err'] = jnp.std(ESS, axis=0)

            # build mean weights
            # log_L_samples = jnp.concatenate([jnp.array([-jnp.inf]), log_L_samples])
            # log_X = jnp.concatenate([jnp.array([0.]), log_X])
            # log(dX_i) = log(X[i-1] - X[i]) = log((1-t_i)*X[i-1]) = log(1-t_i) + log(X[i-1])
            # log_dX = - jnp.log(n_per_sample + 1.) + log_X[:-1]
            # log_dX = jnp.log(-jnp.diff(jnp.exp(log_X)))
            # log_avg_L = jnp.logaddexp(log_L_samples[:-1], log_L_samples[1:]) - jnp.log(2.)
            # w_i = dX_i avg_L_i
            # log_w = log_dX + log_avg_L
            # log_p = log_w - logsumexp(log_w)
            # data['log_p'] = log_p

            # if self.marginalise is not None:
            #     def single_marginalise(marginalise):
            #         return jnp.sum(vmap(lambda p, sample: p * marginalise(**sample))(jnp.exp(log_p), samples), axis=0)
            #
            #     data['marginalised'] = dict_multimap(single_marginalise, self.marginalise)

        return NestedSamplerResults(**data)
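
The splicing above relies on jax.lax.dynamic_update_slice with a traced start index to write live-point quantities into the tail of pre-allocated, fixed-size arrays. A minimal standalone sketch of that pattern (array sizes and values are made up):

import jax.numpy as jnp
from jax.lax import dynamic_update_slice

max_samples = 8
num_dead = jnp.asarray(5)  # traced integer in the real sampler
dead = jnp.arange(max_samples, dtype=jnp.float32)  # pre-allocated buffer
live = jnp.array([100., 101., 102.])  # contributions from the live points

# Overwrite the buffer starting at index `num_dead`.
merged = dynamic_update_slice(dead, live, [num_dead])
print(merged)  # [0. 1. 2. 3. 4. 100. 101. 102.]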
Example #3
def lsh_attention_single_head(query,
                              value,
                              n_buckets,
                              n_hashes,
                              causal_mask=True,
                              length_norm=False):
    """Applies LSH attention on a single head and a single batch.

  Args:
    query: query tensor of shape [qlength, dims].
    value: value tensor of shape [vlength, dims].
    n_buckets: integer, number of buckets.
    n_hashes: integer, number of hashes.
    causal_mask: boolean, to use causal mask or not.
    length_norm: boolean, to normalize k or not.
  Returns:
    output tensor of shape [qlength, dims]
  """

    qdim, vdim = query.shape[-1], value.shape[-1]
    chunk_size = n_hashes * n_buckets

    seqlen = query.shape[0]

    with nn.stochastic(jax.random.PRNGKey(0)):
        rng = nn.make_rng()

    buckets = hash_vectors(query,
                           rng,
                           num_buckets=n_buckets,
                           num_hashes=n_hashes)
    # buckets should have shape (n_hashes * seqlen,)
    assert buckets.shape[-1] == n_hashes * seqlen

    total_hashes = n_hashes

    # create sort and unsort
    ticker = jax.lax.tie_in(query, jnp.arange(n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
    # ticker = jnp.tile(jnp.reshape(ticker, [1, -1]), [batch_size, 1])
    sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                   ticker,
                                                   dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)

    sqk = jnp.take(query, st, axis=0)
    sv = jnp.take(value, st, axis=0)

    bkv_t = jnp.reshape(st, (chunk_size, -1))
    bqk = jnp.reshape(sqk, (chunk_size, -1, qdim))
    bv = jnp.reshape(sv, (chunk_size, -1, vdim))
    bq = bqk
    bk = bqk

    if length_norm:
        bk = length_normalized(bk)

    # get previous chunks
    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)

    # compute dot product attention
    # Scaled dot-product attention: scale the scores by 1/sqrt(d).
    dots = jnp.einsum('hie,hje->hij', bq, bk) * (qdim ** -0.5)

    if causal_mask:
        # apply causal mask
        # TODO(yitay): This is not working yet
        # We don't need causal reformer for any task YET.
        pass

    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    slogits = jnp.reshape(dots_logsumexp, [-1])
    dots = jnp.exp(dots - dots_logsumexp)

    x = jnp.matmul(dots, bv)
    x = jnp.reshape(x, [-1, qdim])

    # Unsort
    o = permute_via_gather(x, undo_sort, sticker, axis=0)
    logits = permute_via_sort(slogits, sticker, undo_sort, axis=0)
    logits = jnp.reshape(logits, [total_hashes, seqlen, 1])
    probs = jnp.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = jnp.reshape(o, [n_hashes, seqlen, qdim])
    out = jnp.sum(o * probs, axis=0)
    out = jnp.reshape(out, [seqlen, qdim])

    return out
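
The sort/unsort bookkeeping above (jax.lax.sort_key_val on the bucket keys, then a second sort to build undo_sort) is the standard trick for returning results to the original token order after bucket-sorted attention. A minimal round-trip sketch with made-up bucket ids:

import jax
import jax.numpy as jnp

seqlen = 6
buckets = jnp.array([2, 0, 1, 0, 2, 1])  # made-up bucket assignments
ticker = jnp.arange(seqlen)

# Sort positions by bucket id; sorting the sorted positions back
# against the ticker yields the inverse permutation.
_, sticker = jax.lax.sort_key_val(buckets, ticker, dimension=-1)
_, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)

x = 10. * jnp.arange(seqlen)  # stand-in for per-position outputs
sorted_x = jnp.take(x, sticker, axis=0)  # gather into bucket-sorted order
restored = jnp.take(sorted_x, undo_sort, axis=0)  # scatter back to original order
print(jnp.allclose(restored, x))  # True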
Example #4
def predict(theta, x_val):
    # log(sigmoid(theta . x)) = -log(1 + exp(-theta . x)) = -logsumexp([0, -theta . x])
    return -logsumexp(jnp.array([0., -jnp.dot(theta, x_val)]))
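
The expression above is the log of the logistic sigmoid, since -logsumexp([0, -z]) = -log(1 + exp(-z)) = log(sigmoid(z)). A quick check against jax.nn.log_sigmoid, assuming predict as defined above and the JAX imports used throughout these examples (the theta and x_val values are made up):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp  # used by predict above

theta = jnp.array([0.5, -1.0, 2.0])
x_val = jnp.array([1.0, 0.5, -0.25])
z = jnp.dot(theta, x_val)
print(predict(theta, x_val))  # matches jax.nn.log_sigmoid(z)
print(jax.nn.log_sigmoid(z))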