Example #7
    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    # NOTE: opening a gs:// path like this assumes smart_open's open() (or a similar wrapper) is in scope
    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)

    ckpt_step = meta["checkpoints"][-1]
    print(f"using checkpoint {ckpt_step}")

    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        network = CausalTransformer(params)

        start = time.time()
        network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
        print(f"network loaded in {time.time() - start:.06}s")

        start = time.time()
        del network.state["opt_state"]

        network.state["params"] = to_bf16(network.state["params"])
        print(f"network converted in {time.time() - start:.06}s")

        for i in range(cores_per_replica):
            write_ckpt(network.state, f"gs://{bucket}/{model_dir}_slim/step_{ckpt_step}/", i)
            print(f"written shard {i}")
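The examples on this page cast parameter pytrees with to_bf16/to_f32 but never show those helpers. A minimal sketch of what they could look like, assuming the usual jax tree utilities (the actual helpers in the source repository may differ):

import jax
import jax.numpy as jnp

def to_bf16(t):
    # cast every float32 leaf of a pytree to bfloat16, leaving other dtypes untouched
    return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)

def to_f32(t):
    # promote bfloat16 leaves back to float32
    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)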
    def __init__(self, config):
        self.config = config
        optimizer = config["optimizer"]

        bf16_optimizer = config.get("bf16_optimizer", False)
        early_cast = config.get("early_cast", False)
        early_collect = config.get("early_collect", True)

        def embedding(x):
            x = maybe_shard(x, P("dp", None))
            return EmbeddingShardV2(config)(x)

        def residual(x, mask):
            out = x + TransformerLayerShardV2(
                config, init_scale=2. / config["layers"])(x, mask)
            return maybe_shard(out, P("dp", None, "mp"))

        def transformer(x, mask):
            return hk.remat(residual)(x, mask)

        def projection(x):
            return Projection(config)(x)

        def init_fns():
            embed_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(embedding)).init
            transformer_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(transformer)).init
            projection_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(projection)).init

            return embed_init_fn, transformer_init_fn, projection_init_fn

        def shard_strategy(shape_dtype, parallel):
            if shape_dtype.ndim <= 1:
                return P()
            # embedding/projection layers
            elif shape_dtype.shape == (config["n_vocab"], config["d_model"]):
                return P(parallel, None)
            elif shape_dtype.shape == (config["d_model"], config["n_vocab"]):
                return P(None, parallel)

            # a transformer layer
            elif shape_dtype.shape[0] == config["layers"]:
                if shape_dtype.ndim == 2:
                    # a channel wise variable (e.g. layernorm parameters)
                    # replicate it for speed
                    return P(None)
                elif shape_dtype.ndim == 3:
                    # a weight matrix
                    matrix_size = shape_dtype.shape[1:]

                    # this case is ambiguous
                    assert matrix_size[0] != matrix_size[1]

                    if matrix_size[0] == config["d_model"]:
                        # shard along the axis which is _not_ the model dimension
                        return P(None, None, parallel)
                    elif matrix_size[1] == config["d_model"]:
                        return P(None, parallel, None)
                else:
                    raise NotImplementedError("borked")

            else:
                raise NotImplementedError("borked")
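
        # For illustration: with a hypothetical config of layers=28, d_model=4096, n_vocab=50400,
        # shard_strategy maps parameter shapes to PartitionSpecs roughly as follows (writing
        # `parallel` for the axis list passed in, e.g. ["mp"] or ["mp", "dp"]):
        #   (4096,)           -> P()                      1-D params: replicated
        #   (50400, 4096)     -> P(parallel, None)        embedding matrix
        #   (4096, 50400)     -> P(None, parallel)        projection matrix
        #   (28, 4096)        -> P(None)                  stacked per-channel params: replicated
        #   (28, 4096, 16384) -> P(None, None, parallel)  stacked weights: shard the non-d_model axis
        #   (28, 16384, 4096) -> P(None, parallel, None)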

        def init(key, x):
            embed_init_fn, transformer_init_fn, projection_init_fn = init_fns()

            def init_scan_fn(key, x):
                new_key, key = jax.random.split(key)

                return new_key, transformer_init_fn(key, x, 0)

            e_key, t_key, p_key = jax.random.split(key, 3)

            input_shape = (config["layers"], ) + x.shape + (
                config["d_model"], )

            params = {
                "embed":
                embed_init_fn(e_key, x),
                "transformer":
                jax.lax.scan(init_scan_fn,
                             t_key,
                             xs=jax.random.uniform(t_key,
                                                   input_shape,
                                                   dtype=jnp.float32))[1],
                "proj":
                projection_init_fn(
                    p_key,
                    jax.random.uniform(t_key,
                                       input_shape[1:],
                                       dtype=jnp.float32)),
            }

            return {
                "params": (to_bf16 if early_cast else to_f32)(params),
                "step":
                np.array(0),
                "opt_state":
                optimizer.init((to_bf16 if bf16_optimizer else to_f32)(params))
            }

        assert thread_resources.env.shape['mp'] == config["cores_per_replica"]

        dp = thread_resources.env.shape['dp']
        mp = thread_resources.env.shape['mp']

        key = hk.PRNGSequence(42)
        x = jax.random.uniform(next(key), (mp * dp, 16), minval=0,
                               maxval=1).astype(jnp.uint32)  # batch, seq

        head_print("starting shape evaluation")

        param_shapes = jax.eval_shape(init, jax.random.PRNGKey(42), x)

        state_shard = {
            "step":
            P(),

            # zero level 1: shard optimizer states over both MP and DP
            "opt_state":
            jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]),
                         param_shapes["opt_state"]),

            # fp32 params are also sharded (so this is like a weird mix between zero-1 and zero-3...)
            "params":
            jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]),
                         param_shapes["params"]),
        }

        head_print("sharding strategy:")
        jax.tree_multimap(head_print, state_shard, param_shapes)

        self.init_pjit = pjit(init,
                              in_axis_resources=(None, P("dp")),
                              out_axis_resources=state_shard)

        def apply_fns():
            embed_apply_fn = hk.without_apply_rng(
                hk.transform(embedding)).apply
            transformer_apply_fn = hk.without_apply_rng(
                hk.transform(transformer)).apply

            return embed_apply_fn, transformer_apply_fn

        def train_apply_fn(params, x, y):
            embed_apply_fn, transformer_apply_fn = apply_fns()

            def train_loss(x, y):
                loss, _ = Projection(config).loss(x, y, z_loss=1.0)
                return loss.mean(), loss[:, -1].mean()

            projection_apply_fn = hk.without_apply_rng(
                hk.transform(train_loss)).apply

            x = embed_apply_fn(params["embed"], x)
            x = to_bf16(x)

            def apply_scan_fn(x, layer_state):
                return to_bf16(transformer_apply_fn(layer_state, x, 0)), None

            x = jax.lax.scan(apply_scan_fn, x, xs=params["transformer"])[0]

            return projection_apply_fn(params["proj"], x, y)

        mp_shard_strategy = jax.tree_map(
            partial(shard_strategy, parallel=["mp"]), param_shapes["params"])

        def train(state, ctx, tgt):
            if early_collect:
                bf16_params = maybe_shard(to_bf16(state["params"]),
                                          mp_shard_strategy)
            else:
                bf16_params = to_bf16(state["params"])

            def microbatch(old_grad, batch):
                ctx, tgt = batch

                val_grad_fn = jax.value_and_grad(train_apply_fn,
                                                 has_aux=True,
                                                 allow_int=True)
                (loss, last_loss), grad = val_grad_fn(bf16_params, ctx, tgt)

                new_grad = jax.tree_multimap(lambda a, b: a + b, old_grad,
                                             grad)
                return new_grad, (loss, last_loss)

            if ctx.shape[0] == 1:
                val_grad_fn = jax.value_and_grad(train_apply_fn,
                                                 has_aux=True,
                                                 allow_int=True)
                (loss, last_loss), grad = val_grad_fn(bf16_params, ctx[0],
                                                      tgt[0])
            else:
                grad, (loss, last_loss) = jax.lax.scan(
                    microbatch,
                    jax.tree_map(
                        lambda x: jnp.zeros_like(x).astype(jnp.bfloat16),
                        bf16_params), (ctx, tgt))

            updates, new_opt_state = optimizer.update(grad, state["opt_state"],
                                                      state["params"])

            return to_f32(loss), to_f32(last_loss), {
                "params": optax.apply_updates(state["params"],
                                              to_f32(updates)),
                "step": state["step"] + 1,
                "opt_state": new_opt_state,
            }

        self.train_pjit = pjit(train,
                               in_axis_resources=(state_shard, P(None, "dp"),
                                                  P(None, "dp")),
                               out_axis_resources=(None, None, state_shard),
                               donate_argnums=(0, ))

        def eval_apply_fn(params, x, y, mask):
            embed_apply_fn, transformer_apply_fn = apply_fns()

            if early_collect:
                bf16_params = maybe_shard(to_bf16(params), mp_shard_strategy)
            else:
                bf16_params = to_bf16(params)

            def eval_loss(x, y):
                loss, correct = Projection(config).loss(x, y)
                return {
                    "loss": loss.mean(axis=-1),
                    "last_loss": loss[:, -1],
                    "all_loss": loss,
                    "correct": correct
                }

            projection_apply_fn = hk.without_apply_rng(
                hk.transform(eval_loss)).apply

            x = embed_apply_fn(bf16_params["embed"], x)

            def apply_scan_fn(layer_in, layer_state):
                x, mask = layer_in
                return (to_bf16(transformer_apply_fn(layer_state, x,
                                                     mask)), mask), None

            x = jax.lax.scan(apply_scan_fn, (to_bf16(x), mask),
                             xs=bf16_params["transformer"])[0][0]

            return projection_apply_fn(bf16_params["proj"], x, y)

        def eval(params, ctx, tgt, ctx_length):
            mask = (jnp.arange(0, ctx.shape[1])[None, :] >
                    ctx_length[:, None]) * -1e10

            # head_print("mask.shape", mask.shape)
            # head_print("ctx.shape", ctx.shape)
            # head_print("ctx_length.shape", ctx_length.shape)

            return eval_apply_fn(params, ctx, tgt, mask[:, None, None, :])

        self.eval_pjit = pjit(
            eval,
            in_axis_resources=(mp_shard_strategy
                               if early_collect else state_shard["params"],
                               P("dp"), P("dp"), P("dp")),
            out_axis_resources=P("dp"))

        self.move_weights_pjit = pjit(
            lambda x: to_bf16(x),
            in_axis_resources=(state_shard["params"], ),
            out_axis_resources=mp_shard_strategy
            if early_collect else state_shard["params"])
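        # move_weights_pjit casts the f32 training params to bf16 and, when early_collect is set,
        # re-lays them out with the MP-only sharding that eval_apply_fn expects (matching the
        # eval_pjit in_axis_resources above); the result presumably populates self.eval_weights,
        # which is initialised to None further below.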

        seq = config["seq"]
        vocab = config["n_vocab"]

        example_shape = (
            max(dp // jax.host_count(), 1),
            seq,
        )
        x = jax.random.uniform(next(key),
                               example_shape,
                               minval=0,
                               maxval=vocab).astype(jnp.uint32)  # batch, len

        head_print("in shape", x.shape)

        head_print("dp", dp)
        head_print("mp", mp)

        self.state = self.init_pjit(next(key), x)
        self.state_shard = state_shard
        self.eval_weights = None

        param_count = hk.data_structures.tree_size(self.state['params'])
        head_print(f"Total parameters: {param_count * dp}")
    def __init__(self, config):
        self.config = config
        optimizer = config["optimizer"]

        def eval(state, ctx, tgt, ctx_length):
            def eval_loss(x, y, mask):
                transformer = CausalTransformerShard(config)
                return transformer.loss(x, y, mask=mask)

            eval_loss_fn = hk.without_apply_rng(hk.transform(eval_loss)).apply

            mask = (jnp.arange(0, len(ctx)) > ctx_length) * -1e10

            return eval_loss_fn(to_bf16(state["params"]), ctx, tgt, mask)

        def train(state, ctx, tgt):
            def train_loss(x, y):
                transformer = CausalTransformerShard(config)
                out = transformer.loss(x, y, z_loss=True)

                return out["loss"], out["last_loss"]

            train_loss_fn = hk.without_apply_rng(
                hk.transform(train_loss)).apply

            def microbatch(old_grad, batch):
                ctx, tgt = batch

                val_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
                (loss, last_loss), grad = val_grad_fn(to_bf16(state["params"]),
                                                      ctx, tgt)

                new_grad = jax.tree_multimap(lambda a, b: a + b, old_grad,
                                             grad)
                return new_grad, (loss, last_loss)

            if ctx.shape[0] == 1:
                val_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
                (loss, last_loss), grad = val_grad_fn(to_bf16(state["params"]),
                                                      ctx[0], tgt[0])
            else:
                grad, (loss, last_loss) = jax.lax.scan(
                    microbatch,
                    jax.tree_map(
                        lambda x: jnp.zeros_like(x).astype(jnp.bfloat16),
                        state["params"]), (ctx, tgt))

            grad = jax.lax.pmean(grad, "batch")
            updates, new_opt_state = optimizer.update(grad, state["opt_state"],
                                                      state["params"])

            return to_f32(loss), to_f32(last_loss), {
                "params": optax.apply_updates(state["params"],
                                              to_f32(updates)),
                "step": state["step"] + 1,
                "opt_state": new_opt_state,
            }

        def init(key, x):
            def train_loss(x, y):
                transformer = CausalTransformerShard(config)
                return transformer.loss(x, y)

            param_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(train_loss)).init

            params = param_init_fn(key, x, x)

            return {
                "params": (to_bf16 if "early_cast" in config else to_f32)(params),
                "step": np.array(0),
                "opt_state": optimizer.init(params)
            }

        def generate(state, key, ctx, ctx_length, aux, sampler_options):
            sampler = config["sampler"]
            gen_length = self.gen_length

            def generate_sample(context, ctx_length, aux):
                transformer = CausalTransformerShard(config)
                _, initial_state = transformer.generate_initial(
                    context, ctx_length)

                def generate_scan_fn(carry, sampler_input):
                    next_token, decode_state, sample_key = carry
                    sample_key, new_key = jax.random.split(sample_key)

                    logits, new_state = transformer.generate_once(
                        next_token, decode_state)
                    next_token, sample_info = sampler(sample_key, logits,
                                                      sampler_input,
                                                      **sampler_options)

                    if self.return_logits:
                        output = (next_token, sample_info, logits)
                    else:
                        output = (next_token, sample_info)
                    new_carry = (next_token, new_state, new_key)
                    return new_carry, output

                final_state, outputs = jax.lax.scan(generate_scan_fn,
                                                    initial_state,
                                                    xs=aux,
                                                    length=gen_length)
                return final_state, outputs

            generate_fn = hk.transform(generate_sample).apply
            return generate_fn(state["params"], key, ctx, ctx_length, aux)

        self.init_xmap = jax.experimental.maps.xmap(fun=init,
                                                    in_axes=(["shard", ...],
                                                             ["batch", ...]),
                                                    out_axes=["shard", ...],
                                                    axis_resources={
                                                        'shard': 'mp',
                                                        'batch': 'dp'
                                                    })

        self.eval_xmap = jax.experimental.maps.xmap(
            fun=eval,
            in_axes=(["shard", ...], ["batch", ...], ["batch",
                                                      ...], ["batch", ...]),
            out_axes=["batch", ...],
            axis_resources={
                'shard': 'mp',
                'batch': 'dp'
            })

        self.train_xmap = jax.experimental.maps.xmap(
            fun=train,
            in_axes=(["shard", ...], ["batch", ...], ["batch", ...]),
            out_axes=(["batch", ...], ["batch", ...], ["shard", ...]),
            donate_argnums=(0, ),
            axis_resources={
                'shard': 'mp',
                'batch': 'dp'
            })

        self.generate_xmap = jax.experimental.maps.xmap(
            fun=generate,
            in_axes=(["shard", ...], ["batch", ...], ["batch", ...],
                     ["batch", ...], ["batch", ...], ["batch", ...]),
            out_axes=["batch", ...],
            axis_resources={
                'shard': 'mp',
                'batch': 'dp'
            })

        self.move_xmap = jax.experimental.maps.xmap(
            fun=lambda x, _: to_bf16(x),
            in_axes=(["shard", ...], ["batch", ...]),
            out_axes=["shard", ...],
            axis_resources={
                'shard': 'mp',
                'batch': 'dp'
            })
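        # All of the xmaps above share one convention: a leading "shard" axis is mapped onto the
        # model-parallel "mp" mesh axis and a leading "batch" axis onto the data-parallel "dp"
        # mesh axis, so params/opt_state carry a per-shard leading dimension while ctx, tgt and
        # rng keys carry a per-replica batch dimension.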

        key = hk.PRNGSequence(42)

        assert thread_resources.env.shape['mp'] == config["cores_per_replica"]

        dp = thread_resources.env.shape['dp']
        mp = thread_resources.env.shape['mp']

        mp_per_host = min(mp, 8)

        seq = config["seq"]
        vocab = config["n_vocab"]

        example_shape = (
            max(dp // jax.host_count(), 1),
            seq,
        )
        x = jax.random.uniform(next(key),
                               example_shape,
                               minval=0,
                               maxval=vocab).astype(jnp.uint32)  # batch, len

        head_print("key shape", jnp.array(key.take(mp_per_host)).shape)
        head_print("in shape", x.shape)

        head_print("dp", dp)
        head_print("mp", mp)

        self.gen_length = 1
        self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)

        param_count = hk.data_structures.tree_size(self.state['params'])
        head_print(f"Total parameters: {param_count}")