Example #1
0
    def output(self, *x):
        """Merge the branch outputs and apply the output projection.

        Args:
            *x: arrays concatenated along their last axis. The first one's
                shape (minus its last two axes) fixes the leading dimensions of
                the flattened result.

        Returns:
            ``self.output_proj`` applied to the merged tensor. All trailing axes
            are folded into one, and the result is constrained to
            ``P("dp", None, "mp")``.
        """
        merged = maybe_shard(jnp.concatenate(x, axis=-1),
                             P("dp", None, "mp", None))

        # Collapse everything past x[0].shape[:-2] into a single feature axis.
        flat_shape = x[0].shape[:-2] + (-1, )
        flattened = maybe_shard(merged.reshape(flat_shape), P("dp", None, "mp"))

        return self.output_proj(flattened)
Example #2
0
    def __call__(self, x, dtype=jnp.bfloat16):
        """One-hot encode ``x`` over ``self.in_dim`` classes and project it.

        Args:
            x: integer indices to encode (presumably token ids; confirm
                against the caller).
            dtype: accepted for interface compatibility. NOTE(review): it is
                not used by this body.

        Returns:
            ``self.proj`` applied to the sharded one-hot encoding.
        """
        encoded = maybe_shard(jax.nn.one_hot(x, self.in_dim),
                              P("dp", None, "mp"))
        return self.proj(encoded)
Example #3
0
    def head_split(self, x):
        """Reshape a per-shard projection and constrain its sharding.

        The last axis of ``x`` must hold ``(n_head // mp_num) * d_head``
        features, otherwise the first reshape raises.

        Returns:
            An array with the same shape as ``x``, constrained to
            ``P("dp", None, "mp", None)``.
        """
        heads_per_shard = self.n_head // self.mp_num
        per_head = x.reshape(x.shape[:-1] + (heads_per_shard, self.d_head))

        # NOTE(review): the target shape ``x.shape[:-2] + (-1,) + x.shape[-1:]``
        # reproduces x's own shape, so the per-head split above only acts as a
        # divisibility check. If the intent was [..., n_head, d_head], the
        # trailing term would be ``per_head.shape[-1:]``. Confirm before changing,
        # because that would alter the attention numerics.
        regrouped = per_head.reshape(x.shape[:-2] + (-1, ) + x.shape[-1:])

        return maybe_shard(regrouped, P("dp", None, "mp", None))
        def eval_apply_fn(params, x, y, mask):
            """Run a forward pass on ``(x, y)`` and return the loss metrics.

            Steps:
                1. Cast the params to bf16 and, if ``early_collect`` is set,
                   shard them with ``mp_shard_strategy``.
                2. Embed ``x``.
                3. Scan ``transformer_apply_fn`` over the leading axis of
                   ``bf16_params["transformer"]``.
                4. Compute ``Projection(config).loss`` on the result.

            Returns:
                dict with keys:
                    "loss": the loss averaged over its last axis.
                    "last_loss": ``loss[:, -1]``.
                    "all_loss": the full loss array.
                    "correct": as returned by ``Projection.loss``.
            """
            embed_apply_fn, transformer_apply_fn = apply_fns()

            bf16_params = to_bf16(params)
            if early_collect:
                bf16_params = maybe_shard(bf16_params, mp_shard_strategy)

            def eval_loss(x, y):
                loss, correct = Projection(config).loss(x, y)
                metrics = {
                    "loss": loss.mean(axis=-1),
                    "last_loss": loss[:, -1],
                    "all_loss": loss,
                    "correct": correct,
                }
                return metrics

            projection_apply_fn = hk.without_apply_rng(
                hk.transform(eval_loss)).apply

            hidden = embed_apply_fn(bf16_params["embed"], x)

            def apply_scan_fn(carry, layer_params):
                # The mask rides along in the carry unchanged.
                act, attn_mask = carry
                act = to_bf16(transformer_apply_fn(layer_params, act, attn_mask))
                return (act, attn_mask), None

            (hidden, _), _ = jax.lax.scan(apply_scan_fn,
                                          (to_bf16(hidden), mask),
                                          xs=bf16_params["transformer"])

            return projection_apply_fn(bf16_params["proj"], hidden, y)
Example #5
0
    def self_attn(self, q, v, k, attn_bias):
        """Rotary-embedded scaled dot-product attention with sharding hints.

        Args:
            q, v, k: query, value and key tensors, indexed as ``[b, t, h, d]``
                by the einsums below. Note the ``q, v, k`` argument order.
            attn_bias: added to the ``[b, h, t, T]`` logits before the softmax
                (presumably a causal/padding mask; confirm against the caller).

        Returns:
            The attention output reshaped to ``[b, t, mp_num, -1]`` and
            constrained to ``P("dp", None, "mp", None)``.
        """
        # Rotary embeddings touch only the first ``d_rotary`` features of the
        # last axis; the remainder passes through unchanged.
        k_rot = k[:, :, :, :self.d_rotary]
        k_pass = k[:, :, :, self.d_rotary:]

        q_rot = q[:, :, :, :self.d_rotary]
        q_pass = q[:, :, :, self.d_rotary:]

        sincos = fixed_pos_embedding(k_rot, seq_dim=1)
        q_rot = apply_rotary_pos_emb_v2(q_rot, sincos)
        k_rot = apply_rotary_pos_emb_v2(k_rot, sincos)
        q_rot = maybe_shard(q_rot, P("dp", None, "mp", None))
        k_rot = maybe_shard(k_rot, P("dp", None, "mp", None))

        k = jnp.concatenate([k_rot, k_pass], axis=-1)
        q = jnp.concatenate([q_rot, q_pass], axis=-1)

        k = maybe_shard(k, P("dp", None, "mp", None))
        q = maybe_shard(q, P("dp", None, "mp", None))

        # [b, t, h, d] . [b, T, h, d] -> [b, h, t, T]
        attention_logits = jnp.einsum("bthd,bThd->bhtT", q, k)

        attention_logits = maybe_shard(attention_logits,
                                       P("dp", "mp", None, None))

        # Scale by sqrt(d_head), cast to the key dtype to avoid upcasting.
        sqrt_key_size = np.sqrt(self.d_head).astype(k.dtype)
        attention_logits = attention_logits / sqrt_key_size

        attention_logits += attn_bias
        attention_logits = maybe_shard(attention_logits,
                                       P("dp", "mp", None, None))

        # jax.nn.softmax defaults to axis=-1, i.e. the key axis T.
        attention_weights = jax.nn.softmax(attention_logits)
        attention_weights = maybe_shard(attention_weights,
                                        P("dp", "mp", None, None))

        # [b, h, t, T] . [b, T, h, d] -> [b, t, h, d]
        attention_vec = jnp.einsum("bhtT,bThd->bthd", attention_weights, v)

        attention_vec = maybe_shard(attention_vec, P("dp", None, "mp", None))

        # Expose the (mp_num, heads-per-shard) split so the sharding hint can
        # pin the mp axis.
        sharded_attn_vec = attention_vec.reshape(attention_vec.shape[:2] +
                                                 (self.mp_num, self.n_head //
                                                  self.mp_num, -1))
        sharded_attn_vec = maybe_shard(sharded_attn_vec,
                                       P("dp", None, "mp", None, None))

        # Fold back from the constrained tensor, not from ``attention_vec``.
        # Otherwise the maybe_shard result above is discarded and its hint
        # never reaches the data path. The values are identical either way,
        # since both are reshapes of the same array.
        attention_vec = sharded_attn_vec.reshape(sharded_attn_vec.shape[:2] +
                                                 (self.mp_num, -1))
        return maybe_shard(attention_vec, P("dp", None, "mp", None))
Example #6
0
    def input(self, x):
        """Project ``x`` and split it into per-shard q, v, k and ff parts.

        Returns:
            ``(q, v, k, ff)``. Each of q, v and k is a ``local_dim``-wide slice
            of the last axis, passed through ``head_split``. ff is whatever
            follows the first ``3 * local_dim`` features.
        """
        # Assumed [batch, seq, dim] after the input projection.
        projected = maybe_shard(self.input_proj(x), P("dp", None, "mp"))

        # Expose the model-parallel axis: [batch, seq, mp, dim // mp].
        per_shard = maybe_shard(
            jnp.reshape(projected, projected.shape[:-1] + (self.mp_num, -1)),
            P("dp", None, "mp", None))

        local_dim = self.d_head * self.n_head // self.mp_num
        boundaries = [local_dim * i for i in (1, 2, 3)]
        q, v, k, ff = jnp.split(per_shard, boundaries, axis=-1)

        q, v, k = (self.head_split(part) for part in (q, v, k))
        return q, v, k, ff
        def train(state, ctx, tgt):
            """Run one optimizer step, accumulating gradients over microbatches.

            Args:
                state: dict with "params", "opt_state" and "step" entries.
                ctx, tgt: inputs and targets with a leading microbatch axis.
                    When that axis has size 1, the single slice is used
                    directly. Otherwise gradients are summed with
                    ``jax.lax.scan`` over the slices.

            Returns:
                ``(loss, last_loss, new_state)``. Both losses are passed through
                ``to_f32``. ``new_state`` holds the updated params, the
                incremented step and the new optimizer state.
            """
            if early_collect:
                bf16_params = maybe_shard(to_bf16(state["params"]),
                                          mp_shard_strategy)
            else:
                bf16_params = to_bf16(state["params"])

            # ``jax.tree_multimap`` and the top-level ``jax.tree_map`` alias are
            # deprecated and have been removed from JAX. Prefer the legacy
            # multi-tree function while it still exists (older JAX, whose
            # ``tree_util.tree_map`` may accept only one tree), and otherwise use
            # ``tree_util.tree_map``, which maps over several trees.
            tree_map_multi = getattr(jax.tree_util, "tree_multimap",
                                     jax.tree_util.tree_map)

            # The same gradient function serves both paths, so build it once.
            val_grad_fn = jax.value_and_grad(train_apply_fn,
                                             has_aux=True,
                                             allow_int=True)

            def microbatch(old_grad, batch):
                ctx, tgt = batch

                (loss, last_loss), grad = val_grad_fn(bf16_params, ctx, tgt)

                new_grad = tree_map_multi(lambda a, b: a + b, old_grad, grad)
                return new_grad, (loss, last_loss)

            if ctx.shape[0] == 1:
                (loss, last_loss), grad = val_grad_fn(bf16_params, ctx[0],
                                                      tgt[0])
            else:
                # Zero-initialised bf16 accumulator with the params' structure.
                grad, (loss, last_loss) = jax.lax.scan(
                    microbatch,
                    jax.tree_util.tree_map(
                        lambda x: jnp.zeros_like(x).astype(jnp.bfloat16),
                        bf16_params), (ctx, tgt))

            updates, new_opt_state = optimizer.update(grad, state["opt_state"],
                                                      state["params"])

            return to_f32(loss), to_f32(last_loss), {
                "params": optax.apply_updates(state["params"],
                                              to_f32(updates)),
                "step": state["step"] + 1,
                "opt_state": new_opt_state,
            }
 def residual(x, mask):
     """Return ``x`` plus the output of a ``TransformerLayerShardV2`` on ``(x, mask)``.

     The layer is built with ``init_scale=2. / config["layers"]``, and the sum
     is constrained to ``P("dp", None, "mp")``.
     """
     layer = TransformerLayerShardV2(config, init_scale=2. / config["layers"])
     summed = x + layer(x, mask)
     return maybe_shard(summed, P("dp", None, "mp"))
 def embedding(x):
     """Constrain ``x`` to ``P("dp", None)`` and embed it with ``EmbeddingShardV2``."""
     sharded_inputs = maybe_shard(x, P("dp", None))
     embed = EmbeddingShardV2(config)
     return embed(sharded_inputs)