import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _logsumexp(x, dim):
    """Thin wrapper exposing `logsumexp` with a `dim` (rather than `axis`) keyword."""
    return logsumexp(x, axis=dim)
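# Quick sanity check (illustrative, not from the original source): the log-sum-exp
# of N equal entries is that entry plus log(N).
assert jnp.allclose(_logsumexp(jnp.ones(4), dim=0), 1.0 + jnp.log(4.0))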
from collections import namedtuple

from jax import random, vmap
from jax.lax import dynamic_update_slice

# `NestedSamplerState`, `TrackedExpectation`, `dict_multimap` and
# `stochastic_result_computation` are defined elsewhere in the sampler module.


def _finalise_results(self, state: NestedSamplerState, collect_samples: bool,
                      stochastic_uncertainty: bool, max_samples: int):
    """Consumes the remaining live points and packages the sampler state into a
    `NestedSamplerResults` namedtuple (evidence, information gain, per-sample
    weights and, optionally, the collected samples)."""
    collect = ['logZ', 'logZerr', 'ESS', 'ESS_err', 'H', 'H_err',
               'num_likelihood_evaluations', 'efficiency',
               'marginalised', 'marginalised_uncert',
               'log_L_samples', 'n_per_sample', 'log_p', 'log_X',
               'sampler_efficiency', 'num_samples']
    if collect_samples:
        collect.append('samples')
    NestedSamplerResults = namedtuple('NestedSamplerResults', collect)

    tracked_expectations = TrackedExpectation(self.marginalised, self.marginalised_shapes,
                                              state=state.tracked_expectations_state)
    live_update_results = tracked_expectations.update_from_live_points(state.live_points,
                                                                       state.log_L_live)
    if self.marginalised is not None:
        marginalised = tracked_expectations.marg_mean()
        marginalised_uncert = None  # tracked_expectations.marg_variance()
    else:
        marginalised = None
        marginalised_uncert = None

    # Append the consumed live points to the dead-point arrays. Dead samples were
    # produced with `num_live_points` live points; the i-th consumed live point
    # with `num_live_points - i`.
    num_live_points = state.log_L_live.shape[0]
    n_per_sample = jnp.where(jnp.arange(max_samples) < state.num_dead, num_live_points, jnp.inf)
    n_per_sample = dynamic_update_slice(n_per_sample,
                                        num_live_points - jnp.arange(num_live_points,
                                                                     dtype=n_per_sample.dtype),
                                        [state.num_dead])
    sampler_efficiency = dynamic_update_slice(state.sampler_efficiency,
                                              jnp.ones(num_live_points),
                                              [state.num_dead])
    log_w = dynamic_update_slice(state.log_w, live_update_results[3], [state.num_dead])
    log_p = log_w - logsumexp(log_w)
    log_X = dynamic_update_slice(state.log_X, live_update_results[2], [state.num_dead])
    log_L_samples = dynamic_update_slice(state.log_L_dead, live_update_results[1], [state.num_dead])
    num_samples = state.num_dead + num_live_points

    data = dict(
        logZ=tracked_expectations.evidence_mean(),
        logZerr=jnp.sqrt(tracked_expectations.evidence_variance()),
        ESS=tracked_expectations.effective_sample_size(),
        ESS_err=None,
        H=tracked_expectations.information_gain_mean(),
        H_err=jnp.sqrt(tracked_expectations.information_gain_variance()),
        num_likelihood_evaluations=state.num_likelihood_evaluations,
        efficiency=(state.num_dead + num_live_points) / state.num_likelihood_evaluations,
        marginalised=marginalised,
        marginalised_uncert=marginalised_uncert,
        n_per_sample=n_per_sample,
        log_p=log_p,
        log_X=log_X,
        log_L_samples=log_L_samples,
        num_samples=num_samples,
        sampler_efficiency=sampler_efficiency
    )

    if collect_samples:
        # Insert the live points, sorted by likelihood, after the dead points.
        ar = jnp.argsort(state.log_L_live)
        samples = dict_multimap(
            lambda dead_points, live_points: dynamic_update_slice(
                dead_points,
                live_points.astype(dead_points.dtype)[ar, ...],
                [state.num_dead] + [0] * (len(dead_points.shape) - 1)),
            state.dead_points, state.live_points)
        data['samples'] = samples

        if stochastic_uncertainty:
            # Re-simulate the prior-volume shrinkage S times to get sampling
            # uncertainties on the evidence, ESS and information gain.
            S = 200
            logZ, m, cov, ESS, H = vmap(
                lambda key: stochastic_result_computation(n_per_sample, key,
                                                          samples, log_L_samples))(
                random.split(state.key, S))
            data['logZ'] = jnp.mean(logZ, axis=0)
            data['logZerr'] = jnp.std(logZ, axis=0)
            data['H'] = jnp.mean(H, axis=0)
            data['H_err'] = jnp.std(H, axis=0)
            data['ESS'] = jnp.mean(ESS, axis=0)
            data['ESS_err'] = jnp.std(ESS, axis=0)

    return NestedSamplerResults(**data)
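# Minimal sketch (illustrative, not part of the sampler): how posterior weights and
# the evidence can be recovered from the per-sample live-point counts `n_per_sample`
# and the sorted log-likelihoods, using the expected shrinkage t_i ~ n_i / (n_i + 1)
# and a trapezoidal average of the likelihood over each shell. The method above takes
# these quantities from the tracked-expectations state instead.
def _weights_from_shrinkage_sketch(n_per_sample, log_L_samples):
    log_t = jnp.where(jnp.isinf(n_per_sample), 0.,
                      jnp.log(n_per_sample) - jnp.log(n_per_sample + 1.))
    log_X = jnp.concatenate([jnp.array([0.]), jnp.cumsum(log_t)])       # X_0 = 1
    log_L = jnp.concatenate([jnp.array([-jnp.inf]), log_L_samples])     # L_0 = 0
    log_dX = jnp.log(-jnp.diff(jnp.exp(log_X)))                         # dX_i = X_{i-1} - X_i
    log_avg_L = jnp.logaddexp(log_L[:-1], log_L[1:]) - jnp.log(2.)      # trapezoid rule
    log_w = log_dX + log_avg_L                                          # w_i = dX_i * avg_L_i
    log_Z = logsumexp(log_w)
    return log_w - log_Z, log_Z                                         # (log_p, logZ)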
import jax
from flax import nn  # legacy flax.nn API, as used by the surrounding model code

# `hash_vectors`, `length_normalized`, `look_one_back`, `permute_via_gather` and
# `permute_via_sort` are helper functions defined alongside this function.


def lsh_attention_single_head(query, value, n_buckets, n_hashes,
                              causal_mask=True, length_norm=False):
    """Applies LSH attention on a single head and a single batch.

    Args:
      query: query tensor of shape [qlength, dims].
      value: value tensor of shape [vlength, dims].
      n_buckets: integer, number of buckets.
      n_hashes: integer, number of hashes.
      causal_mask: boolean, whether to apply a causal mask.
      length_norm: boolean, whether to length-normalize the keys.

    Returns:
      output tensor of shape [qlength, dims]
    """
    qdim, vdim = query.shape[-1], value.shape[-1]
    chunk_size = n_hashes * n_buckets
    seqlen = query.shape[0]

    with nn.stochastic(jax.random.PRNGKey(0)):
        rng = nn.make_rng()

    buckets = hash_vectors(query, rng, num_buckets=n_buckets, num_hashes=n_hashes)
    # buckets should be of shape [n_hashes * seqlen]
    assert buckets.shape[-1] == n_hashes * seqlen

    total_hashes = n_hashes

    # Create the sort and unsort permutations: sorting by (bucket, position) makes
    # members of the same bucket contiguous so they can be attended to in chunks.
    ticker = jax.lax.tie_in(query, jnp.arange(n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
    sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t, ticker, dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)
    sqk = jnp.take(query, st, axis=0)
    sv = jnp.take(value, st, axis=0)

    # Split the sorted sequence into chunks.
    bkv_t = jnp.reshape(st, (chunk_size, -1))
    bqk = jnp.reshape(sqk, (chunk_size, -1, qdim))
    bv = jnp.reshape(sv, (chunk_size, -1, vdim))
    bq = bqk
    bk = bqk

    if length_norm:
        bk = length_normalized(bk)

    # Allow each chunk to also attend to the previous chunk.
    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)  # positions, kept for the (unimplemented) causal mask

    # Scaled dot-product attention within each chunk.
    dots = jnp.einsum('hie,hje->hij', bq, bk) * (qdim ** -0.5)  # scale by 1/sqrt(d)

    if causal_mask:
        # Apply causal mask.
        # TODO(yitay): This is not working yet.
        # We don't need causal reformer for any task YET.
        pass

    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    slogits = jnp.reshape(dots_logsumexp, [-1])
    dots = jnp.exp(dots - dots_logsumexp)

    x = jnp.matmul(dots, bv)
    x = jnp.reshape(x, [-1, qdim])

    # Unsort back to the original sequence order and combine the hash rounds,
    # weighting each round by its softmax normalizer.
    o = permute_via_gather(x, undo_sort, sticker, axis=0)
    logits = permute_via_sort(slogits, sticker, undo_sort, axis=0)
    logits = jnp.reshape(logits, [total_hashes, seqlen, 1])
    probs = jnp.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = jnp.reshape(o, [n_hashes, seqlen, qdim])
    out = jnp.sum(o * probs, axis=0)
    out = jnp.reshape(out, [seqlen, qdim])
    return out
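# Standalone sketch (illustrative) of the sort/unsort trick used above: positions are
# sorted by bucket id so that same-bucket entries become contiguous chunks, and the
# inverse permutation restores the original order. Names here are made up for the demo.
def _sort_unsort_demo(bucket_ids, values):
    ticker = jnp.arange(bucket_ids.shape[0])
    _, sort_idx = jax.lax.sort_key_val(bucket_ids, ticker, dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sort_idx, ticker, dimension=-1)
    sorted_values = jnp.take(values, sort_idx, axis=0)   # bucket-contiguous layout
    # ... chunked attention would operate on `sorted_values` here ...
    return jnp.take(sorted_values, undo_sort, axis=0)    # identical to `values`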
def predict(theta, x_val):
    # log sigmoid(theta . x) = -log(1 + exp(-theta . x)) = -logsumexp([0, -theta . x])
    return -logsumexp(jnp.array([0., -jnp.dot(theta, x_val)]))
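# Illustrative usage (assumed shapes and toy data, not from the original source): with
# labels y in {-1, +1}, log p(y | x) = log sigmoid(y * theta . x), so a small
# logistic-regression log-likelihood and its gradient follow directly from `predict`.
from jax import grad, vmap

def log_likelihood(theta, xs, ys):
    return jnp.sum(vmap(lambda x, y: predict(theta, y * x))(xs, ys))

theta0 = jnp.zeros(3)
xs = jnp.array([[1.0, 2.0, -1.0], [0.5, -0.3, 2.0]])
ys = jnp.array([1.0, -1.0])
grad_theta = grad(log_likelihood)(theta0, xs, ys)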