Example 1
    def test_learning_rate_scheduler(self):
        num_steps = 6
        warmup_steps = 2
        learning_rate_fn = train_utils.create_learning_rate_scheduler(
            factors="constant * linear_warmup * linear_decay",
            base_learning_rate=1,
            warmup_steps=warmup_steps,
            decay_steps=num_steps - warmup_steps,
        )

        _ = learning_rate_fn(0)

        for step, expected_rate in zip(range(num_steps),
                                       [0, 0.5, 1, 0.75, 0.5, 0.25]):
            self.assertAlmostEqual(learning_rate_fn(step),
                                   expected_rate,
                                   delta=1e-7)
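
For context, a minimal sketch of a factor-based scheduler that reproduces the expected values asserted in this test. The function name and structure below are assumptions for illustration, not the actual train_utils implementation.

import jax.numpy as jnp


def make_scheduler(factors, base_learning_rate, warmup_steps, decay_steps):
    """Hypothetical factor-product scheduler consistent with the test above."""
    names = [name.strip() for name in factors.split("*")]

    def step_fn(step):
        rate = 1.0
        for name in names:
            if name == "constant":
                rate *= base_learning_rate
            elif name == "linear_warmup":
                # Ramp linearly from 0 to 1 over the first `warmup_steps` steps.
                rate *= jnp.minimum(1.0, step / warmup_steps)
            elif name == "linear_decay":
                # Decay linearly from 1 to 0 over `decay_steps` steps after warmup.
                rate *= jnp.maximum(
                    0.0, 1.0 - jnp.maximum(0.0, step - warmup_steps) / decay_steps)
            else:
                raise ValueError(f"Unknown factor: {name}")
        return jnp.asarray(rate, dtype=jnp.float32)

    return step_fn

With factors="constant * linear_warmup * linear_decay", base_learning_rate=1, warmup_steps=2 and decay_steps=4, this sketch yields the sequence [0, 0.5, 1, 0.75, 0.5, 0.25] asserted above.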
Example 2
    def test_train_step(self):
        num_steps = 2
        learning_rate_fn = train_utils.create_learning_rate_scheduler(
            factors="constant * linear_decay",
            base_learning_rate=1,
            warmup_steps=0,
            decay_steps=num_steps - 1,
        )

        rng = jax.random.PRNGKey(0)
        rngs = jax.random.split(rng, jax.device_count())

        config = dummy_frozen_config()
        optimizer = create_optimizer(rng, config)
        p_optimizer = jax_utils.replicate(optimizer)

        p_train_step = jax.pmap(functools.partial(
            train_utils.train_step,
            loss_and_metrics_fn=dummy_loss_and_metrics,
            learning_rate_fn=learning_rate_fn,
            clipped_grad_norm=1.0),
                                axis_name="batch")

        batch = jax.random.randint(
            rng, (config.train_batch_size, config.max_seq_length),
            minval=0,
            maxval=10)
        batch = common_utils.shard(batch)

        for _ in range(num_steps):
            p_optimizer, metrics, rngs = p_train_step(optimizer=p_optimizer,
                                                      batch=batch,
                                                      rng=rngs)
            self.assertSetEqual(
                set(metrics.keys()), {
                    "very_helpful_metric", "clipped_grad_l2_sum",
                    "unclipped_grad_l2_sum"
                })
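
The "clipped_grad_l2_sum" and "unclipped_grad_l2_sum" keys checked here suggest that train_step clips gradients by global norm and reports the squared L2 sums before and after clipping ("very_helpful_metric" comes from the dummy loss function). Below is an illustrative sketch of such clipping, assuming this behavior; it is not the actual train_utils code.

import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Illustrative global-norm clipping that also reports squared L2 sums."""
    leaves = jax.tree_util.tree_leaves(grads)
    unclipped_sq = sum(jnp.sum(jnp.square(g)) for g in leaves)
    # Scale all gradients down so the global norm does not exceed max_norm.
    scale = jnp.minimum(1.0, max_norm / (jnp.sqrt(unclipped_sq) + 1e-6))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    clipped_sq = sum(
        jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(clipped))
    return clipped, {
        "unclipped_grad_l2_sum": unclipped_sq,
        "clipped_grad_l2_sum": clipped_sq,
    }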
Example 3
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If the training or eval batch size is not divisible by the
      total number of devices, or the config is underspecified.
  """
    n_processes = jax.process_count()  # Number of processes
    n_devices = jax.local_device_count()  # Number of local devices per process

    if config.train_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Training batch size must be divisible by the total number of devices, "
            "but training batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.train_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    if config.eval_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Eval batch size must be divisible by the total number of devices, "
            "but eval batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.eval_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    per_process_train_batch_size = config.train_batch_size // n_processes
    per_process_eval_batch_size = config.eval_batch_size // n_processes

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round the evaluation frequency up to a multiple of 10.
    eval_frequency = int(
        math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

    is_regression_task = config.dataset_name == "glue/stsb"

    num_classes = (1 if is_regression_task else
                   ds_info.features["label"].num_classes)

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()

    frozen_config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config=frozen_config,
                                               n_classes=num_classes)

    params = _init_params(model, init_rng, config)

    optimizer = _create_adam_optimizer(config.learning_rate, params)

    # In case the current job restarts, ensure that we continue from where we
    # left off.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    start_step = int(optimizer.state.step)

    # Otherwise, try to restore the optimizer and model state from the
    # checkpoint specified in the config.
    if (start_step == 0 and "init_checkpoint_dir" in config
            and config.init_checkpoint_dir):
        optimizer = _restore_pretrained_model(optimizer, params, config)

    # We access model state only from optimizer via optimizer.target.
    del params

    optimizer = jax_utils.replicate(optimizer)

    if is_regression_task:
        compute_stats = functools.partial(_compute_regression_stats,
                                          model=model,
                                          pad_id=tokenizer.pad_id())
    else:
        compute_stats = functools.partial(_compute_classification_stats,
                                          model=model,
                                          pad_id=tokenizer.pad_id())

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    glue_inputs = functools.partial(input_pipeline.glue_inputs,
                                    dataset_name=config.dataset_name,
                                    max_seq_length=config.max_seq_length,
                                    tokenizer=tokenizer)
    train_ds = glue_inputs(split=tfds.Split.TRAIN,
                           batch_size=per_process_train_batch_size,
                           training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI has matched and mismatched validation/test splits.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, n_devices)

    loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics,
                                            model=model,
                                            pad_id=tokenizer.pad_id())
    p_train_step = jax.pmap(functools.partial(
        train_utils.train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        learning_rate_fn=learning_rate_fn),
                            axis_name="batch")
    p_eval_step = jax.pmap(functools.partial(train_utils.eval_step,
                                             metric_fn=compute_stats),
                           axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                              is_regression_task)

    train_metrics = []

    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            optimizer, train_step_metrics, rngs = p_train_step(optimizer,
                                                               train_batch,
                                                               rng=rngs)
            train_metrics.append(train_step_metrics)

        if ((step > 0 and config.save_checkpoints_steps
             and step % config.save_checkpoints_steps == 0)
                or step == num_train_steps - 1) and jax.process_index() == 0:
            # Save un-replicated optimizer and model state.
            checkpoints.save_checkpoint(workdir,
                                        jax_utils.unreplicate(optimizer),
                                        step,
                                        keep=2)

        # Periodic metric handling.
        if step % eval_frequency != 0 and step < num_train_steps - 1:
            continue

        logging.info("Gathering training metrics at step: %d", step)

        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = {
            "loss":
            jnp.sum(train_metrics["loss"]) /
            jnp.sum(train_metrics["num_labels"]),
            "learning_rate":
            learning_rate_fn(step)
        }
        if not is_regression_task:
            train_summary["accuracy"] = jnp.sum(
                train_metrics["correct_predictions"]) / jnp.sum(
                    train_metrics["num_labels"])

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        train_metrics = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            eval_ds = glue_inputs(split=tfds.Split.VALIDATION + split_suffix,
                                  batch_size=per_process_eval_batch_size,
                                  training=False)

            all_stats = []
            for _, eval_batch in zip(range(config.max_num_eval_steps),
                                     eval_ds):
                all_stats.append(
                    _evaluate(p_eval_step, optimizer.target, eval_batch,
                              n_devices))
            flat_stats = {}
            # All batches of output stats share the same keys.
            for k in all_stats[0]:
                flat_stats[k] = np.concatenate([stat[k] for stat in all_stats],
                                               axis=0)
            eval_summary = eval_metrics_fn(flat_stats)

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()
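
A hypothetical invocation of this fine-tuning entry point is sketched below. The config field names mirror the attributes read in the function body; the concrete values and paths are placeholders, and the model constructor will typically require additional architecture fields not listed here.

import ml_collections

config = ml_collections.ConfigDict()
config.dataset_name = "glue/sst2"       # Any GLUE task; "glue/stsb" is regression.
config.train_batch_size = 32
config.eval_batch_size = 32
config.max_seq_length = 128
config.num_train_epochs = 3
config.warmup_proportion = 0.1
config.eval_proportion = 0.05
config.learning_rate = 1e-4
config.save_checkpoints_steps = 1000
config.max_num_eval_steps = 100
config.init_checkpoint_dir = ""         # Optional pretrained checkpoint directory.
config.seed = 0

train_and_evaluate(config,
                   workdir="/tmp/finetune_workdir",
                   vocab_filepath="/tmp/sentencepiece.model")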
Example 4
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If the training or eval batch size is not divisible by the
      total number of devices, or the config is underspecified.
  """
    n_processes = jax.process_count()  # Number of processes
    n_devices = jax.local_device_count()  # Number of local devices per process

    if config.train_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Training batch size must be divisible by the total number of devices, "
            "but training batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.train_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    if config.eval_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Eval batch size must be divisible by the total number of devices, "
            "but eval batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.eval_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    per_process_train_batch_size = config.train_batch_size // n_processes
    per_process_eval_batch_size = config.eval_batch_size // n_processes

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    tokenizer.SetEncodeExtraOptions("")
    # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer.

    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
    frozen_config = ml_collections.FrozenConfigDict(config)
    model = models.PreTrainingModel(config=frozen_config,
                                    random_seed=config.seed)

    params = _init_params(model, init_rng, frozen_config)

    optimizer = _create_adam_optimizer(config.learning_rate, params)
    # We access model state only from optimizer via optimizer.target.
    del params

    # In case the current job restarts, ensure that we continue from where we
    # left off.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    start_step = int(optimizer.state.step)

    # Otherwise, try to restore the optimizer and model state from the
    # checkpoint specified in the config.
    if start_step == 0 and "init_checkpoint_dir" in config and config.init_checkpoint_dir:
        optimizer = checkpoints.restore_checkpoint(config.init_checkpoint_dir,
                                                   optimizer)

    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        decay_steps=config.num_train_steps - config.num_warmup_steps,
    )

    c4_masked_lm_inputs = functools.partial(
        input_pipeline.c4_masked_lm_inputs,
        tokenizer=tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
        masking_rate=config.masking_rate,
        mask_token_proportion=config.mask_token_proportion,
        random_token_proportion=config.random_token_proportion)
    train_ds = c4_masked_lm_inputs(batch_size=per_process_train_batch_size)
    train_iter = iter(train_ds)
    eval_ds = c4_masked_lm_inputs(batch_size=per_process_eval_batch_size)

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, n_devices)

    loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics,
                                            model=model,
                                            pad_id=tokenizer.pad_id())
    p_train_step = jax.pmap(functools.partial(
        train_utils.train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        learning_rate_fn=learning_rate_fn,
        clipped_grad_norm=config.clipped_grad_norm),
                            axis_name="batch")

    metric_fn = functools.partial(_compute_eval_stats,
                                  model=model,
                                  pad_id=tokenizer.pad_id())
    p_eval_step = jax.pmap(functools.partial(train_utils.eval_step,
                                             metric_fn=metric_fn),
                           axis_name="batch")

    train_metrics = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, config.num_train_steps):
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            optimizer, train_step_metrics, rngs = p_train_step(optimizer,
                                                               train_batch,
                                                               rng=rngs)
            train_metrics.append(train_step_metrics)

        if (step > 0 and config.save_checkpoints_steps
                and step % config.save_checkpoints_steps == 0
                and jax.process_index() == 0):
            # Save un-replicated optimizer + model state.
            checkpoints.save_checkpoint(workdir,
                                        jax_utils.unreplicate(optimizer),
                                        step,
                                        keep=2)

        # Periodic metric handling.
        if step % config.eval_frequency != 0 and step > 0:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = _compute_loss_and_accuracy_metrics(train_metrics)
        # Add training specific metrics.
        train_summary["unclipped_grad_l2_norm"] = jnp.sqrt(
            jnp.sum(train_metrics["unclipped_grad_l2_sum"]))
        train_summary["clipped_grad_l2_norm"] = jnp.sqrt(
            jnp.sum(train_metrics["clipped_grad_l2_sum"]))
        train_summary["learning_rate"] = learning_rate_fn(step)

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_metrics = []

        logging.info("Gathering evaluation metrics at step: %d", step)

        all_stats = []
        for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
            eval_batch = common_utils.shard(eval_batch)
            all_stats.append(p_eval_step(optimizer.target, eval_batch))
        flat_stats = {}
        for k in all_stats[0]:
            flat_stats[k] = np.concatenate([stats[k] for stats in all_stats],
                                           axis=0)
        eval_summary = _compute_loss_and_accuracy_metrics(flat_stats)

        if jax.process_index() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
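
Both training loops rely on common_utils.shard to add a leading device axis to each per-process batch before calling the pmapped step functions. As a point of reference, the behavior amounts to a per-host reshape along the batch dimension; the sketch below illustrates that behavior and is not necessarily the exact Flax source.

import jax


def shard_for_pmap(xs):
    """Reshape each array from [batch, ...] to [local_devices, batch // local_devices, ...]."""
    local_device_count = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)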