def loss(self, x, targets, z_loss=1):
    x = f_psum(x)
    x = self.norm(x)
    logits = self.proj(x)

    # each shard holds a slice of the vocabulary; compute a numerically stable
    # softmax by subtracting the global (cross-shard) max
    shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard
    global_max = jax.lax.pmax(jax.lax.stop_gradient(logits.max(-1, keepdims=True)), "shard")
    logits -= jax.lax.stop_gradient(global_max)

    # pick out the target token's logit; targets outside this shard's vocabulary
    # slice produce an all-zero one-hot and contribute nothing
    gt_onehot = jax.nn.one_hot(targets - shard_start_index, self.dim_per_shard)
    predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1)
    predicted_logits = g_psum(predicted_logits)

    exp_logits = jnp.exp(logits)
    sum_exp_logits = exp_logits.sum(axis=-1)
    sum_exp_logits = g_psum(sum_exp_logits)

    # cross-entropy plus z-loss, which keeps log Z (the softmax normalizer) near zero
    loss = jnp.log(sum_exp_logits) - predicted_logits
    loss += (1e-4 * jnp.square(jnp.log(sum_exp_logits)) * z_loss).mean()

    # logits were shifted by the global max, so the target logit is exactly 0.0
    # iff the target token is the argmax (a greedy-correct prediction)
    correct = (0.0 == predicted_logits)

    return loss, correct
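# For reference, a minimal single-device sketch of the same objective: standard
# softmax cross-entropy plus the z-loss regularizer on log Z. This is an
# illustrative re-derivation, not part of the sharded codebase; `reference_loss`
# and its argument shapes are assumptions, and it relies on the module-level
# jax / jnp imports used above.
def reference_loss(logits, targets, z_loss=1):
    # stabilize: shifting by the max leaves the softmax unchanged
    logits = logits - jax.lax.stop_gradient(logits.max(-1, keepdims=True))

    logsumexp = jnp.log(jnp.exp(logits).sum(axis=-1))
    target_logits = jnp.take_along_axis(logits, targets[..., None], axis=-1)[..., 0]

    loss = logsumexp - target_logits                 # -log p(target)
    loss += 1e-4 * z_loss * jnp.square(logsumexp)    # z-loss term
    return loss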
def decode_once(self, decode_state, x, attn_bias):
    x = f_psum(x)
    x = self.norm(x)

    assert x.shape[0] == 1

    q, v, k = self.qvk_proj(x)

    # rolling kv cache: append the new k/v to the end and drop the oldest slot,
    # so the cache keeps a fixed length
    v = jnp.concatenate((decode_state["v"], v), axis=0)[1:]
    k = jnp.concatenate((decode_state["k"], k), axis=0)[1:]

    tokens_decoded = decode_state["tokens_decoded"] + 1
    length = v.shape[0]

    # slots that do not yet hold real tokens get a large negative bias
    masked_tokens = length - tokens_decoded
    attention_mask = jnp.arange(0, length) < masked_tokens
    bias = (-1e10 * attention_mask)
    bias += attn_bias

    attn_out = self.self_attn(q, v, k, bias)
    dense_out = self.ff(x)

    return g_psum(attn_out + dense_out), {
        "tokens_decoded": tokens_decoded,
        "k": k,
        "v": v
    }
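# A toy, standalone illustration (shapes and values are assumptions, not repo
# code) of the rolling cache update and mask above: the cache window has a
# fixed length, the newest entry is appended while the oldest is dropped, and
# positions that are still "empty" are masked away with a large negative bias.
def _rolling_cache_demo():
    length, d = 8, 4
    k_cache = jnp.zeros((length, d))                          # pre-allocated window
    new_k = jnp.ones((1, d))                                  # k for the token just decoded

    k_cache = jnp.concatenate((k_cache, new_k), axis=0)[1:]   # still (length, d)

    tokens_decoded = 3                                        # real tokens sit at the end
    masked_tokens = length - tokens_decoded
    attention_mask = jnp.arange(0, length) < masked_tokens    # True on empty slots
    bias = -1e10 * attention_mask                             # kills attention to empty slots
    return k_cache, bias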
def get_init_decode_state(self, x, given_length, attn_bias):
    x = f_psum(x)
    x = self.norm(x)

    q, v, k = self.qvk_proj(x)

    full_length = x.shape[0]
    masked_tokens = full_length - given_length

    seq_len = x.shape[0]
    causal_mask = np.tril(np.ones((seq_len, seq_len)))

    bias = -1e10 * (1. - causal_mask)  # regular AR masking
    bias -= 1e10 * (jnp.arange(0, full_length) < masked_tokens)  # mask out zero tokens before context starts
    bias += attn_bias  # finally add attn bias for rpe

    attn_out = self.self_attn(q, v, k, bias)
    dense_out = self.ff(x)

    return g_psum(attn_out + dense_out), {
        "k": k,
        "v": v,
        "tokens_decoded": given_length.astype(jnp.uint32)
    }
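# A small standalone example (values are assumptions) of how the bias above is
# assembled: entries near 0 allow attention, entries around -1e10 forbid it.
# The causal term blocks attending to future positions, and the padding term
# blocks the left-padding slots before the given context starts.
def _bias_demo():
    seq_len, given_length = 4, 2
    masked_tokens = seq_len - given_length

    causal_mask = np.tril(np.ones((seq_len, seq_len)))           # lower-triangular AR mask
    bias = -1e10 * (1. - causal_mask)                            # no attending to the future
    bias -= 1e10 * (jnp.arange(0, seq_len) < masked_tokens)      # no attending to left padding
    # bias[i, j] is ~0 only when j <= i and position j lies inside the given context
    return bias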
def __call__(self, x, attn_bias):
    x = f_psum(x)
    x = self.norm(x)

    q, v, k = self.qvk_proj(x)

    seq_len = x.shape[0]
    causal_mask = np.tril(np.ones((seq_len, seq_len)))

    bias = -1e10 * (1. - causal_mask)
    bias += attn_bias

    attn_out = self.self_attn(q, v, k, bias)
    dense_out = self.ff(x)

    return g_psum(attn_out + dense_out)
def get_init_decode_state(self, x, given_length, attn_bias):
    x = f_psum(x)
    x = self.norm(x)

    q, v, k, ff = self.input(x)

    full_length = x.shape[1]
    masked_tokens = full_length - given_length

    causal_mask = np.tril(np.ones((full_length, full_length)))

    bias = -1e10 * (1. - causal_mask)  # regular AR masking
    bias -= 1e10 * (jnp.arange(0, full_length) < masked_tokens)  # mask out zero tokens before context starts
    bias += attn_bias  # finally add attn bias for rpe

    attn_out = self.self_attn(q, v, k, bias)
    ff_out = jax.nn.gelu(ff)

    return self.output(attn_out, ff_out), {
        "k": k,
        "v": v,
        "tokens_decoded": given_length.astype(jnp.uint32)
    }
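# This variant computes q, v, k and the feed-forward pre-activation from a
# single fused input projection and runs the attention and gelu branches in
# parallel off the same normed input, merging both in one output projection.
# A minimal sketch of that parallel-residual structure; the function and weight
# names here are illustrative assumptions, not attributes of the layer class.
def _parallel_residual_sketch(x, attn_fn, w_ff_in, w_ff_out, w_attn_out):
    attn_out = attn_fn(x) @ w_attn_out            # attention branch
    ff_out = jax.nn.gelu(x @ w_ff_in) @ w_ff_out  # feed-forward branch
    return x + attn_out + ff_out                  # both branches added to the residual in one step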