def output(self, *x):
    out = jnp.concatenate(x, axis=-1)
    out = maybe_shard(out, P("dp", None, "mp", None))

    out = out.reshape(x[0].shape[:-2] + (-1,))
    out_shard = maybe_shard(out, P("dp", None, "mp"))

    return self.output_proj(out_shard)
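# These functions lean on a `maybe_shard` helper to annotate intermediate tensors with
# partition specs. Below is a minimal sketch of such a helper, assuming the code runs
# inside a pjit/xmap mesh context; the real utility in the codebase may guard
# differently (e.g. skipping the constraint when there is no model-parallel axis).
from jax.experimental.pjit import with_sharding_constraint


def maybe_shard(x, partition_spec):
    # inside pjit/xmap this pins x to the requested ("dp", "mp", ...) layout;
    # with no mesh in scope one would simply return x unchanged instead
    return with_sharding_constraint(x, partition_spec)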
def __call__(self, x, dtype=jnp.bfloat16):
    input_onehot = jax.nn.one_hot(x, self.in_dim)
    input_onehot = maybe_shard(input_onehot, P("dp", None, "mp"))

    proj_out = self.proj(input_onehot)

    return proj_out
def head_split(self, x):
    # [batch, seq, mp, local_dim] -> [batch, seq, n_head, d_head]
    reshaped = x.reshape(x.shape[:-1] + (self.n_head // self.mp_num, self.d_head))
    # merge the mp and heads-per-shard axes so the last dim stays d_head, which the
    # per-head rotary slicing and sqrt(d_head) scaling in self_attn rely on
    reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])

    # return reshaped
    return maybe_shard(reshaped, P("dp", None, "mp", None))
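# Illustrative shape check for head_split (a standalone sketch, not part of the layer
# class): with hypothetical GPT-J-like sizes n_head=16, d_head=256, mp_num=8, the
# per-shard [batch, seq, mp, local_dim] input is regrouped into a per-head
# [batch, seq, n_head, d_head] tensor, which is what self_attn expects.
import jax.numpy as jnp


def _head_split_shape_demo(batch=2, seq=4, n_head=16, d_head=256, mp_num=8):
    local_dim = d_head * n_head // mp_num                      # 512 channels per shard
    x = jnp.zeros((batch, seq, mp_num, local_dim))
    split = x.reshape(x.shape[:-1] + (n_head // mp_num, d_head))
    merged = split.reshape(x.shape[:-2] + (-1,) + split.shape[-1:])
    assert merged.shape == (batch, seq, n_head, d_head)
    return merged.shape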
def eval_apply_fn(params, x, y, mask):
    embed_apply_fn, transformer_apply_fn = apply_fns()

    if early_collect:
        bf16_params = maybe_shard(to_bf16(params), mp_shard_strategy)
    else:
        bf16_params = to_bf16(params)

    def eval_loss(x, y):
        loss, correct = Projection(config).loss(x, y)
        return {
            "loss": loss.mean(axis=-1),
            "last_loss": loss[:, -1],
            "all_loss": loss,
            "correct": correct
        }

    projection_apply_fn = hk.without_apply_rng(hk.transform(eval_loss)).apply

    x = embed_apply_fn(bf16_params["embed"], x)

    def apply_scan_fn(layer_in, layer_state):
        x, mask = layer_in
        return (to_bf16(transformer_apply_fn(layer_state, x, mask)), mask), None

    x = jax.lax.scan(apply_scan_fn,
                     (to_bf16(x), mask),
                     xs=bf16_params["transformer"])[0][0]

    return projection_apply_fn(bf16_params["proj"], x, y)
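# eval_apply_fn and train assume to_bf16 / to_f32 tree-casting helpers. A minimal
# sketch under that assumption (the real helpers may special-case other leaf dtypes):
import jax
import jax.numpy as jnp


def to_bf16(tree):
    # downcast every float32 leaf to bfloat16, leave other dtypes untouched
    return jax.tree_map(
        lambda t: t.astype(jnp.bfloat16) if t.dtype == jnp.float32 else t, tree)


def to_f32(tree):
    # upcast every bfloat16 leaf back to float32
    return jax.tree_map(
        lambda t: t.astype(jnp.float32) if t.dtype == jnp.bfloat16 else t, tree)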
def self_attn(self, q, v, k, attn_bias):
    k_rot = k[:, :, :, :self.d_rotary]
    k_pass = k[:, :, :, self.d_rotary:]

    q_rot = q[:, :, :, :self.d_rotary]
    q_pass = q[:, :, :, self.d_rotary:]

    sincos = fixed_pos_embedding(k_rot, seq_dim=1)
    q_rot = apply_rotary_pos_emb_v2(q_rot, sincos)
    k_rot = apply_rotary_pos_emb_v2(k_rot, sincos)
    q_rot = maybe_shard(q_rot, P("dp", None, "mp", None))
    k_rot = maybe_shard(k_rot, P("dp", None, "mp", None))

    k = jnp.concatenate([k_rot, k_pass], axis=-1)
    q = jnp.concatenate([q_rot, q_pass], axis=-1)

    k = maybe_shard(k, P("dp", None, "mp", None))
    q = maybe_shard(q, P("dp", None, "mp", None))

    attention_logits = jnp.einsum("bthd,bThd->bhtT", q, k)
    attention_logits = maybe_shard(attention_logits, P("dp", "mp", None, None))

    sqrt_key_size = np.sqrt(self.d_head).astype(k.dtype)
    attention_logits = attention_logits / sqrt_key_size
    attention_logits += attn_bias
    attention_logits = maybe_shard(attention_logits, P("dp", "mp", None, None))

    attention_weights = jax.nn.softmax(attention_logits)
    attention_weights = maybe_shard(attention_weights, P("dp", "mp", None, None))

    attention_vec = jnp.einsum("bhtT,bThd->bthd", attention_weights, v)
    attention_vec = maybe_shard(attention_vec, P("dp", None, "mp", None))

    sharded_attn_vec = attention_vec.reshape(
        attention_vec.shape[:2] + (self.mp_num, self.n_head // self.mp_num, -1))
    sharded_attn_vec = maybe_shard(sharded_attn_vec, P("dp", None, "mp", None, None))

    attention_vec = attention_vec.reshape(sharded_attn_vec.shape[:2] + (self.mp_num, -1))
    return maybe_shard(attention_vec, P("dp", None, "mp", None))
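# self_attn relies on rotary position embedding helpers that are not shown here. The
# sketch below is one plausible GPT-J-style implementation, assuming sincos tables of
# shape [seq, d_rotary // 2] and inputs of shape [batch, seq, heads, d_rotary]; the
# actual fixed_pos_embedding / apply_rotary_pos_emb_v2 may differ in broadcasting.
import numpy as np
import jax.numpy as jnp


def fixed_pos_embedding(x, seq_dim=1):
    dim = x.shape[-1]
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    sinusoid = np.einsum("i,j->ij", np.arange(x.shape[seq_dim]), inv_freq)
    return np.sin(sinusoid), np.cos(sinusoid)


def rotate_every_two(x):
    # pair adjacent channels (x0, x1) -> (-x1, x0)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return jnp.stack((-x2, x1), axis=-1).reshape(x.shape)


def apply_rotary_pos_emb_v2(x, sincos):
    # repeat each frequency twice to match the interleaved channel layout,
    # then broadcast the tables over the batch and head dimensions
    sin, cos = map(lambda t: jnp.repeat(t, 2, axis=-1)[None, :, None, :], sincos)
    return (x * cos) + (rotate_every_two(x) * sin)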
def input(self, x):
    # [batch, seq, dim]
    projected = self.input_proj(x)

    # [batch, seq, mp, dim//mp]
    projected = maybe_shard(projected, P("dp", None, "mp"))
    mp_split = jnp.reshape(projected, projected.shape[:-1] + (self.mp_num, -1))
    mp_split = maybe_shard(mp_split, P("dp", None, "mp", None))

    local_dim = self.d_head * self.n_head // self.mp_num

    q, v, k, ff = jnp.split(mp_split, [local_dim, local_dim * 2, local_dim * 3], axis=-1)

    q = self.head_split(q)
    v = self.head_split(v)
    k = self.head_split(k)

    return q, v, k, ff
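# Worked split arithmetic for input() above (standalone sketch with assumed GPT-J-like
# sizes n_head=16, d_head=256, mp_num=8 and a 16384-wide feed-forward): every
# model-parallel shard carries 512 channels each for q, v and k plus 2048 for the
# feed-forward branch of the fused projection.
import jax.numpy as jnp


def _qkv_ff_split_demo(n_head=16, d_head=256, mp_num=8, d_ff=16384):
    local_dim = d_head * n_head // mp_num                       # 512
    fused = jnp.zeros((1, 1, mp_num, 3 * local_dim + d_ff // mp_num))
    q, v, k, ff = jnp.split(fused, [local_dim, local_dim * 2, local_dim * 3], axis=-1)
    return q.shape[-1], v.shape[-1], k.shape[-1], ff.shape[-1]  # (512, 512, 512, 2048)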
def train(state, ctx, tgt):
    if early_collect:
        bf16_params = maybe_shard(to_bf16(state["params"]), mp_shard_strategy)
    else:
        bf16_params = to_bf16(state["params"])

    def microbatch(old_grad, batch):
        ctx, tgt = batch

        val_grad_fn = jax.value_and_grad(train_apply_fn, has_aux=True, allow_int=True)
        (loss, last_loss), grad = val_grad_fn(bf16_params, ctx, tgt)

        new_grad = jax.tree_multimap(lambda a, b: a + b, old_grad, grad)
        return new_grad, (loss, last_loss)

    if ctx.shape[0] == 1:
        val_grad_fn = jax.value_and_grad(train_apply_fn, has_aux=True, allow_int=True)
        (loss, last_loss), grad = val_grad_fn(bf16_params, ctx[0], tgt[0])
    else:
        grad, (loss, last_loss) = jax.lax.scan(
            microbatch,
            jax.tree_map(lambda x: jnp.zeros_like(x).astype(jnp.bfloat16), bf16_params),
            (ctx, tgt))

    updates, new_opt_state = optimizer.update(grad, state["opt_state"], state["params"])

    return to_f32(loss), to_f32(last_loss), {
        "params": optax.apply_updates(state["params"], to_f32(updates)),
        "step": state["step"] + 1,
        "opt_state": new_opt_state,
    }
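# train() expects a pre-built `optimizer` whose update() takes the gradient, the
# optimizer state and the float32 parameters. An illustrative optax setup that
# satisfies that contract; the actual transforms and hyperparameters are assumptions.
import optax


def make_optimizer():
    return optax.chain(
        optax.clip_by_global_norm(1.0),                      # assumed clipping threshold
        optax.adamw(learning_rate=1e-4, weight_decay=0.1),   # assumed hyperparameters
    )

# usage sketch: state["opt_state"] = make_optimizer().init(state["params"])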
def residual(x, mask):
    out = x + TransformerLayerShardV2(config, init_scale=2. / config["layers"])(x, mask)
    return maybe_shard(out, P("dp", None, "mp"))
def embedding(x):
    x = maybe_shard(x, P("dp", None))
    return EmbeddingShardV2(config)(x)
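# Hedged sketch of how an apply_fns() builder (used by eval_apply_fn above) could wrap
# embedding() and residual() into pure Haiku apply functions, mirroring how eval_loss
# is transformed; the real builder and its init counterpart may be organised differently.
import haiku as hk


def apply_fns():
    embed_apply_fn = hk.without_apply_rng(hk.transform(embedding)).apply
    transformer_apply_fn = hk.without_apply_rng(hk.transform(residual)).apply
    return embed_apply_fn, transformer_apply_fn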