Ejemplo n.º 1
0
    def test_regression_model(self):
        """STS-B regression head: single-output logits and loss-only metrics."""
        num_outputs = 1  # Regression predicts one value per example.

        cfg = dummy_config(model_arch=ModelArchitecture.F_NET)
        with cfg.unlocked():
            cfg.dataset_name = "glue/stsb"  # regression task dataset
        model = models.SequenceClassificationModel(
            config=ml_collections.FrozenConfigDict(cfg),
            n_classes=num_outputs)

        key = jax.random.PRNGKey(0)
        variables = {
            "params": init_model_params(key, model, init_encoder_batch(cfg))
        }
        self.assertEqual(variables["params"].keys(),
                         {"encoder", "classification"})
        batch_size = cfg.train_batch_size

        # Eval / prediction path: deterministic pass with no labels supplied.
        logits = model.apply(variables, **{
            **dummy_inputs(key, cfg), "deterministic": True
        })
        self.assertEqual(jnp.shape(logits), (batch_size, num_outputs))

        # Training path: real-valued labels in [0, 1) are supplied, and the
        # result is checked for its metric keys.
        _, label_key = jax.random.split(key)
        train_kwargs = dummy_inputs(key, cfg)
        train_kwargs["labels"] = jax.random.uniform(
            label_key, (batch_size,), minval=0., maxval=1.)
        metrics = model.apply(variables, rngs={"dropout": key}, **train_kwargs)
        self.assertEqual(metrics.keys(), {"loss", "num_labels"})
Ejemplo n.º 2
0
    def test_classification_model(self):
        """BERT binary classifier: logits shape and training metric keys."""
        num_classes = 2

        cfg = dummy_config(model_arch=ModelArchitecture.BERT)
        with cfg.unlocked():
            cfg.dataset_name = "dummy/classification_dataset"
        model = models.SequenceClassificationModel(
            config=ml_collections.FrozenConfigDict(cfg),
            n_classes=num_classes)

        key = jax.random.PRNGKey(0)
        variables = {
            "params": init_model_params(key, model, init_encoder_batch(cfg))
        }
        self.assertEqual(variables["params"].keys(),
                         {"encoder", "classification"})
        batch_size = cfg.train_batch_size

        # Eval / prediction path: deterministic pass with no labels supplied;
        # expect one row of class logits per example.
        logits = model.apply(variables, **{
            **dummy_inputs(key, cfg), "deterministic": True
        })
        self.assertEqual(jnp.shape(logits), (batch_size, num_classes))

        # Training path: all-ones integer labels are supplied, and the result
        # is checked for its metric keys.
        train_kwargs = dummy_inputs(key, cfg)
        train_kwargs["labels"] = jnp.ones(batch_size, jnp.int32)
        metrics = model.apply(variables, rngs={"dropout": key}, **train_kwargs)
        self.assertEqual(metrics.keys(),
                         {"loss", "correct_predictions", "num_labels"})
Ejemplo n.º 3
0
def _check_batch_size_divisible(batch_size, kind, n_processes, n_devices):
    """Raises ValueError unless `batch_size` splits evenly over all devices.

    Args:
      batch_size: Global batch size to validate.
      kind: Lower-case label used in the error message, e.g. "training" or
        "eval"; its capitalized form starts the message.
      n_processes: Number of JAX processes.
      n_devices: Number of local devices per process.

    Raises:
      ValueError: If `batch_size` is not divisible by
        `n_processes * n_devices`.
    """
    total_devices = n_processes * n_devices
    if batch_size % total_devices > 0:
        raise ValueError(
            "%s batch size must be divisible by the total number of devices, "
            "but %s batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (kind.capitalize(), kind, batch_size, total_devices, n_processes,
             n_devices))


def _summarize_train_metrics(step_metrics, learning_rate, is_regression_task):
    """Reduces accumulated per-step training metrics to scalar summaries.

    Args:
      step_metrics: List of per-step metric dicts returned by the pmapped
        training step; each carries "loss" and "num_labels", and also
        "correct_predictions" for classification tasks.
      learning_rate: Learning rate to report alongside the metrics.
      is_regression_task: If False, an "accuracy" entry is also computed.

    Returns:
      Dict with "loss" (summed loss divided by summed label count),
      "learning_rate" and, for classification, "accuracy".
    """
    stacked = common_utils.get_metrics(step_metrics)
    num_labels = jnp.sum(stacked["num_labels"])
    summary = {
        "loss": jnp.sum(stacked["loss"]) / num_labels,
        "learning_rate": learning_rate,
    }
    if not is_regression_task:
        summary["accuracy"] = (jnp.sum(stacked["correct_predictions"]) /
                               num_labels)
    return summary


def _concatenate_eval_stats(all_stats):
    """Concatenates per-batch evaluation stats along axis 0.

    Args:
      all_stats: List of dicts of arrays, one per evaluated batch, all sharing
        the key set of the first entry.

    Returns:
      Dict mapping each stat name to the concatenation of that stat over all
      batches; an empty dict if `all_stats` is empty.
    """
    if not all_stats:
        return {}
    return {
        name: np.concatenate([stats[name] for stats in all_stats], axis=0)
        for name in all_stats[0]
    }


def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

    Args:
      config: Model and training configuration. Note that `config.vocab_size`
        is overwritten in place with the size of the SentencePiece vocab.
      workdir: Working directory for checkpoints and Tensorboard summaries. If
        this contains a checkpoint, training will be resumed from the latest
        checkpoint.
      vocab_filepath: Absolute path to SentencePiece vocab model.

    Raises:
      ValueError: If training or eval batch sizes won't fit number of
        processes and devices, if `config.eval_proportion` does not yield a
        positive evaluation interval, or config is underspecified.
    """
    n_processes = jax.process_count()  # Number of processes
    n_devices = jax.local_device_count()  # Number of local devices per process

    _check_batch_size_divisible(config.train_batch_size, "training",
                                n_processes, n_devices)
    _check_batch_size_divisible(config.eval_batch_size, "eval", n_processes,
                                n_devices)

    per_process_train_batch_size = config.train_batch_size // n_processes
    per_process_eval_batch_size = config.eval_batch_size // n_processes

    # Only process 0 writes Tensorboard summaries.
    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round the evaluation frequency up to a multiple of 10 steps.
    eval_frequency = int(
        math.ceil(config.eval_proportion * num_train_steps / 10)) * 10
    if eval_frequency <= 0 and num_train_steps > 0:
        # A non-positive interval is meaningless; zero would otherwise crash
        # with ZeroDivisionError in the `step % eval_frequency` check below.
        raise ValueError(
            "eval_proportion must be positive so that evaluation runs at a "
            "positive step interval, but eval_proportion = %s" %
            (config.eval_proportion,))

    is_regression_task = config.dataset_name == "glue/stsb"

    num_classes = (1 if is_regression_task else
                   ds_info.features["label"].num_classes)

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()

    frozen_config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config=frozen_config,
                                               n_classes=num_classes)

    params = _init_params(model, init_rng, config)

    optimizer = _create_adam_optimizer(config.learning_rate, params)

    # In case current job restarts, ensure that we continue from where we left
    # off.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    start_step = int(optimizer.state.step)

    # Otherwise, try to restore optimizer and model state from config checkpoint.
    if (start_step == 0 and "init_checkpoint_dir" in config
            and config.init_checkpoint_dir):
        optimizer = _restore_pretrained_model(optimizer, params, config)

    # We access model state only from optimizer via optimizer.target.
    del params

    optimizer = jax_utils.replicate(optimizer)

    stats_fn = (_compute_regression_stats
                if is_regression_task else _compute_classification_stats)
    compute_stats = functools.partial(stats_fn,
                                      model=model,
                                      pad_id=tokenizer.pad_id())

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    glue_inputs = functools.partial(input_pipeline.glue_inputs,
                                    dataset_name=config.dataset_name,
                                    max_seq_length=config.max_seq_length,
                                    tokenizer=tokenizer)
    train_ds = glue_inputs(split=tfds.Split.TRAIN,
                           batch_size=per_process_train_batch_size,
                           training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI contains two validation and test datasets.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, n_devices)

    loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics,
                                            model=model,
                                            pad_id=tokenizer.pad_id())
    p_train_step = jax.pmap(functools.partial(
        train_utils.train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        learning_rate_fn=learning_rate_fn),
                            axis_name="batch")
    p_eval_step = jax.pmap(functools.partial(train_utils.eval_step,
                                             metric_fn=compute_stats),
                           axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                              is_regression_task)

    train_metrics = []

    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            train_batch = common_utils.shard(next(train_iter))

            optimizer, train_step_metrics, rngs = p_train_step(optimizer,
                                                               train_batch,
                                                               rng=rngs)
            train_metrics.append(train_step_metrics)

        is_last_step = step == num_train_steps - 1

        is_checkpoint_step = (step > 0 and config.save_checkpoints_steps
                              and step % config.save_checkpoints_steps == 0)
        if (is_checkpoint_step or is_last_step) and jax.process_index() == 0:
            # Save un-replicated optimizer and model state.
            checkpoints.save_checkpoint(workdir,
                                        jax_utils.unreplicate(optimizer),
                                        step,
                                        keep=2)

        # Periodic metric handling; the final step is always evaluated.
        if step % eval_frequency != 0 and not is_last_step:
            continue

        logging.info("Gathering training metrics at step: %d", step)

        train_summary = _summarize_train_metrics(train_metrics,
                                                 learning_rate_fn(step),
                                                 is_regression_task)
        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        train_metrics = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            split_name = tfds.Split.VALIDATION + split_suffix
            eval_ds = glue_inputs(split=split_name,
                                  batch_size=per_process_eval_batch_size,
                                  training=False)

            # `range` comes first in zip so no batch beyond the cap is pulled.
            all_stats = [
                _evaluate(p_eval_step, optimizer.target, eval_batch,
                          n_devices)
                for _, eval_batch in zip(range(config.max_num_eval_steps),
                                         eval_ds)
            ]
            if not all_stats:
                # Previously an IndexError; skip splits with nothing to score.
                logging.warning(
                    "No evaluation batches for split %s at step %d; skipping.",
                    split_name, step)
                continue
            eval_summary = eval_metrics_fn(_concatenate_eval_stats(all_stats))

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()