def test_sharded_train_step(self):
        num_steps = 2

        rng = jax.random.PRNGKey(0)
        rngs = jax.random.split(rng, jax.device_count())

        config = frozen_config(sharded_params=True)
        sharded_match_fn = core_utils.match_fn(r".*expert.*")

        train_state = create_flax_train_state(rng, config, num_steps)
        p_train_state = jax_utils.replicate(train_state)

        p_train_step = jax.pmap(
            functools.partial(train_utils.pmap_train_step,
                              loss_and_metrics_fn=dummy_loss_and_metrics,
                              axis_name="batch",
                              sharded_match_fn=sharded_match_fn),
            axis_name="batch")

        batch = dummy_batch(rng, config.train_batch_size,
                            config.max_seq_length)
        batch = common_utils.shard(batch)

        expected_metrics = ClassificationStats(batch_loss=0.1,
                                               num_labels=2,
                                               correct_predictions=1,
                                               grad_l2_sum=0.)

        for _ in range(num_steps):
            p_train_state, metrics, rngs = p_train_step(
                train_state=p_train_state, batch=batch, rng=rngs)
            self.assertEqual(metrics, expected_metrics)
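For orientation, here is a minimal standalone sketch (not from the source) of the device layout this test relies on: jax_utils.replicate stacks a pytree across local devices, and common_utils.shard reshapes a per-host batch so each local device receives an equal slice. Both are standard Flax utilities; the shapes below are illustrative only.

import jax
import numpy as np
from flax import jax_utils
from flax.training import common_utils

n = jax.local_device_count()

# A per-host batch whose leading dimension is divisible by the device count.
batch = {"input_ids": np.zeros((8 * n, 128), np.int32)}
sharded = common_utils.shard(batch)
assert sharded["input_ids"].shape == (n, 8, 128)  # one slice per device

# Replicating a (tiny) train-state-like pytree adds a leading device axis.
replicated = jax_utils.replicate({"w": np.ones(4)})
assert replicated["w"].shape == (n, 4)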
Example #2
def _clear_pretrained_output_layer(state_cpu, ckpt_state):
    """Clear ("classification") output layer weights.

  We use a fresh output layer because the classification tasks differ from the
  MLM and NSP pre-training tasks.

  Args:
    state_cpu: CPU-initialized train state, containing shape initialized
      parameters.
    ckpt_state: Initialized model state (parameters) from restored checkpoint.

  Returns:
    Input parameters, but with the output layer cleared.
  """
    ckpt_state["params"]["classification"] = state_cpu.params["classification"]
    ckpt_state["opt_state"] = core_utils.tree_map_with_names(
        jnp.zeros_like,
        ckpt_state["opt_state"],
        filter_fn=core_utils.match_fn(r".*classification.*"))
    return ckpt_state
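As context (not shown in the source), a hedged sketch of how _clear_pretrained_output_layer would typically be used when warm-starting fine-tuning from a pre-trained checkpoint. The restored_state variable and the surrounding restore path are hypothetical; only the call to the helper itself comes from this file, and state_cpu is the CPU-initialized train state created further below.

# Hypothetical warm-start flow: keep the pre-trained encoder weights but discard
# the pre-training output head so fine-tuning starts from a fresh classifier.
from flax import serialization

ckpt_state = serialization.to_state_dict(restored_state)      # restored_state is assumed
ckpt_state = _clear_pretrained_output_layer(state_cpu, ckpt_state)
state = serialization.from_state_dict(state_cpu, ckpt_state)  # back onto the train state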
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of hosts and
      devices, or config is underspecified.
  """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    tokenizer.SetEncodeExtraOptions("")
    # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer.

    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()

    config = ml_collections.FrozenConfigDict(config)

    model = models.PreTrainingModel(config=config)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        decay_steps=config.num_train_steps - config.num_warmup_steps,
    )

    tx = optax.adamw(learning_rate_fn,
                     b1=0.9,
                     b2=0.999,
                     eps=1e-6,
                     weight_decay=0.01)
    if config.clipped_grad_norm:
        tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm),
                         tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(FlaxTrainState.create,
                          apply_fn=model.apply,
                          params=params,
                          tx=tx))()

    # We access model params only via state.params
    del params

    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)
    train_ds, eval_ds = _init_train_and_eval_ds(tokenizer, config)
    train_iter = iter(train_ds)

    # We initialize the first set of dropout PRNG keys here, but update them
    # inside the pmap'd train step for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")

    eval_step = functools.partial(_compute_eval_stats, model=model)
    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    seconds = 0.
    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, config.num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            tick = time.time()
            state, train_step_stats, rngs = p_train_step(state,
                                                         train_batch,
                                                         rng=rngs)
            if config.measure_step_speed:
                jax.tree_map(lambda opt: opt.block_until_ready(), state)
                tock = time.time()
                seconds += tock - tick

            train_stats.append(train_step_stats)

        if (step > 0 and config.save_checkpoints_steps
                and step % config.save_checkpoints_steps == 0):
            # We allow all hosts to potentially save checkpoints because some model
            # parameters are sharded across devices. Parameters replicated across
            # devices (i.e. not sharded) will only be checkpointed by host 0.
            unreplicated_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(workdir,
                                        unreplicated_state,
                                        sharded_match_fn,
                                        step,
                                        keep=config.checkpoints_to_keep)
            del unreplicated_state  # Only used for checkpointing.

        # Periodic metric handling.
        if step % config.eval_frequency != 0 and step > 0:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_pretraining_metrics(train_metrics)
        train_summary["learning_rate"] = learning_rate_fn(step)
        if config.measure_step_speed:
            train_summary["steps_per_sec"] = (step - start_step + 1) / seconds

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering evaluation metrics at step: %d", step)

        eval_stats = []
        for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
            eval_batch = common_utils.shard(eval_batch)
            eval_stats.append(p_eval_step(state.params, eval_batch))
        eval_metrics = train_utils.collect_metrics(eval_stats)
        eval_summary = train_utils.compute_pretraining_metrics(
            eval_metrics, record_grad_norm=False)

        if jax.process_index() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
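The factor string "constant * linear_warmup * linear_decay" above composes three multiplicative schedule factors. As a rough, self-contained sketch (an assumption about train_utils.create_learning_rate_scheduler, whose implementation is not shown here), the resulting schedule behaves like this:

import jax.numpy as jnp

def sketch_learning_rate_fn(base_learning_rate, warmup_steps, decay_steps):
    # Learning rate ramps linearly from 0 to base_learning_rate over warmup_steps,
    # then decays linearly back to 0 over the following decay_steps.
    def schedule(step):
        step = jnp.asarray(step, jnp.float32)
        warmup = jnp.minimum(1.0, step / jnp.maximum(1.0, warmup_steps))
        decay = jnp.clip(1.0 - (step - warmup_steps) / decay_steps, 0.0, 1.0)
        return base_learning_rate * warmup * decay
    return schedule

# e.g. with 10k warmup steps out of 100k total training steps:
lr_fn = sketch_learning_rate_fn(1e-4, warmup_steps=10_000, decay_steps=90_000)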
Example #4
def test_empty_prefix(self):
    match_fn = core_utils.match_fn(None)
    self.assertFalse(match_fn("/test/something"))
    self.assertFalse(match_fn("to/test/or/not/"))
Example #5
def test_regex_prefix(self):
    match_fn = core_utils.match_fn(r".*test.*")
    self.assertTrue(match_fn("/test/something"))
    self.assertTrue(match_fn("to/test/or/not/"))
    self.assertFalse(match_fn("no/match"))
Example #6
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of hosts and
      devices, or config is underspecified.
  """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    per_host_train_batch_size = config.train_batch_size // jax.process_count()
    per_host_eval_batch_size = config.eval_batch_size // jax.process_count()

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round up evaluation frequency to the nearest multiple of 10.
    eval_frequency = int(
        math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

    # STSB is a regression task. COPA and ReCoRD are treated as scalar/regression
    # tasks during training.
    is_regression_task = (config.dataset_name == "glue/stsb"
                          or config.dataset_name == "super_glue/copa"
                          or config.dataset_name == "super_glue/record")
    if is_regression_task:
        num_classes = 1
    else:
        num_classes = ds_info.features["label"].num_classes

    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()

    config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config, num_classes)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    tx = optax.adamw(learning_rate_fn,
                     b1=0.9,
                     b2=0.999,
                     eps=1e-6,
                     weight_decay=0.01)
    if config.clipped_grad_norm:
        tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm),
                         tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(FlaxTrainState.create,
                          apply_fn=model.apply,
                          params=params,
                          tx=tx))()

    # We access model params only via state.params
    del params

    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)

    if is_regression_task:
        scoring_fn = lambda y: y[Ellipsis, 0]
    else:
        scoring_fn = lambda y: y.argmax(-1)
    compute_stats = functools.partial(_compute_stats,
                                      model=model,
                                      scoring_fn=scoring_fn)

    classification_inputs = functools.partial(
        input_pipeline.classification_inputs,
        dataset_name=config.dataset_name,
        max_seq_length=config.max_seq_length,
        tokenizer=tokenizer)
    train_ds = classification_inputs(split=tfds.Split.TRAIN,
                                     batch_size=per_host_train_batch_size,
                                     training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI has two validation (and test) sets: matched and mismatched.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We initialize the first set of dropout PRNG keys here, but update them
    # inside the pmap'd train step for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")
    p_eval_step = jax.pmap(compute_stats, axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name)

    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            state, train_step_stats, rngs = p_train_step(state,
                                                         train_batch,
                                                         rng=rngs)

            train_stats.append(train_step_stats)

        if ((step > 0 and config.save_checkpoints_steps
             and step % config.save_checkpoints_steps == 0)
                or step == num_train_steps - 1):
            # We allow all hosts to potentially save checkpoints because some model
            # parameters are sharded across devices. Parameters replicated across
            # devices (i.e. not sharded) will only be checkpointed by host 0.
            unreplicated_train_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(workdir,
                                        unreplicated_train_state,
                                        sharded_match_fn,
                                        step,
                                        keep=config.checkpoints_to_keep)
            del unreplicated_train_state  # Only used for checkpointing.

        # Periodic metric handling.
        if step % eval_frequency != 0 and step < num_train_steps - 1:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_classification_metrics(
            train_metrics, is_regression_task)
        train_summary["learning_rate"] = learning_rate_fn(step)

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            eval_ds = classification_inputs(
                split=tfds.Split.VALIDATION + split_suffix,
                batch_size=per_host_eval_batch_size,
                training=False)

            eval_stats = []
            for _, eval_batch in zip(range(config.max_num_eval_steps),
                                     eval_ds):
                eval_stats.append(
                    _evaluate(p_eval_step, state.params, eval_batch))
            eval_metrics = {}
            # All batches of output stats share the same keys.
            for k in eval_stats[0]:
                eval_metrics[k] = np.concatenate(
                    [stat[k] for stat in eval_stats], axis=0)
            eval_summary = eval_metrics_fn(eval_metrics)

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()
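For reference, the per-split evaluation above concatenates per-batch stats along the example axis and hands the result to eval_metrics_fn. Below is a hedged illustration of what such a metrics function might compute for a plain classification task; the key names "prediction" and "label" are assumptions for illustration, not taken from _create_eval_metrics_fn.

import numpy as np

def example_accuracy_metrics_fn(eval_metrics):
    # eval_metrics maps stat names to arrays concatenated over all eval batches.
    predictions = eval_metrics["prediction"]  # assumed key name
    labels = eval_metrics["label"]            # assumed key name
    return {"accuracy": float(np.mean(predictions == labels))}

# Usage: example_accuracy_metrics_fn({"prediction": np.array([1, 0, 1]),
#                                     "label": np.array([1, 1, 1])})
# returns {"accuracy": 0.666...}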