Example #1
    def test_total_num_batches(self):
        batch_config = BatchConfig(micro_batch_size=2,
                                   num_replicas=2,
                                   gradient_accumulation_count=2,
                                   dataset_size=16,
                                   global_batches_per_log=1)
        assert batch_config.num_micro_batches_per_epoch == 8
        assert batch_config.steps_per_execution == 4
        assert batch_config.epochs == 28803072
        assert batch_config.total_num_micro_batches == 230424576

        batch_config = BatchConfig(micro_batch_size=2,
                                   num_replicas=2,
                                   gradient_accumulation_count=2,
                                   dataset_size=16,
                                   global_batches_per_log=2)
        assert batch_config.num_micro_batches_per_epoch == 8
        assert batch_config.steps_per_execution == 8
        assert batch_config.epochs == 28803072
        assert batch_config.total_num_micro_batches == 230424576

        batch_config = BatchConfig(micro_batch_size=2,
                                   num_replicas=2,
                                   gradient_accumulation_count=2,
                                   dataset_size=18,
                                   global_batches_per_log=2)
        assert batch_config.num_micro_batches_per_epoch == 8
        assert batch_config.steps_per_execution == 8
        assert batch_config.total_num_micro_batches == 204821840
Example #2
    def test_get_num_micro_batches_per_epoch(self):
        batch_config = BatchConfig(micro_batch_size=3,
                                   num_replicas=2,
                                   gradient_accumulation_count=3,
                                   dataset_size=18)
        assert batch_config.num_micro_batches_per_epoch == 6

        batch_config = BatchConfig(micro_batch_size=3,
                                   num_replicas=2,
                                   gradient_accumulation_count=3,
                                   dataset_size=20)
        assert batch_config.num_micro_batches_per_epoch == 6
Example #3
    def test_raise_error_for_no_phase(self):
        with pytest.raises(ValueError):
            BatchConfig(micro_batch_size=2,
                        num_replicas=2,
                        gradient_accumulation_count=2,
                        dataset_size=40,
                        task=Task.OTHER)
Example #4
    def test_custom_num_train_steps(self):
        batch_config = BatchConfig(micro_batch_size=2,
                                   num_replicas=2,
                                   gradient_accumulation_count=2,
                                   dataset_size=8,
                                   total_num_train_samples=80)
        assert batch_config.num_train_steps == 10
        assert batch_config.epochs == 10
        assert batch_config.total_num_micro_batches == 40
Example #5
    def test_num_train_steps(self):
        batch_config = BatchConfig(micro_batch_size=2,
                                   num_replicas=2,
                                   gradient_accumulation_count=2,
                                   dataset_size=16)
        assert batch_config.num_train_steps == 57606144
        assert batch_config.epochs == 28803072
        assert batch_config.num_micro_batches_per_epoch == 8
        assert batch_config.total_num_micro_batches == 230424576
Example #6
    def test_micro_batch_size(self):
        batch_config = BatchConfig(micro_batch_size=3,
                                   num_replicas=1,
                                   gradient_accumulation_count=1,
                                   dataset_size=18)
        assert batch_config.micro_batch_size == 3
Example #7
    def test_calc_global_batch_size(self):
        batch_config = BatchConfig(micro_batch_size=3,
                                   num_replicas=4,
                                   gradient_accumulation_count=2,
                                   dataset_size=40)
        assert batch_config.global_batch_size == 24
Example #8
    def test_gradient_accumulation_count(self):
        batch_config = BatchConfig(micro_batch_size=1,
                                   num_replicas=1,
                                   gradient_accumulation_count=2,
                                   dataset_size=40)
        assert batch_config.gradient_accumulation_count == 2
Example #9
    def test_num_micro_batches_per_weight_update(self):
        batch_config = BatchConfig(micro_batch_size=1,
                                   num_replicas=4,
                                   gradient_accumulation_count=4,
                                   dataset_size=40)
        assert batch_config.num_micro_batches_per_weight_update == 4 * 4
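The BatchConfig class exercised by these tests is not shown on this page. Below is a minimal sketch consistent with the arithmetic the asserts imply; the property names come from the tests, while the phase-one default sample count (inferred from epochs == 28803072 at dataset_size=16) and the Task handling are assumptions, not the library's actual implementation.

# A minimal sketch of BatchConfig, reconstructed from the asserts above.
# The default sample count is inferred from the tests (28803072 * 16);
# the real library may use a different constant and more validation.
from enum import Enum


class Task(Enum):
    PRETRAIN_PHASE_ONE = 1
    PRETRAIN_PHASE_TWO = 2
    OTHER = 3


class BatchConfig:
    def __init__(self,
                 micro_batch_size,
                 num_replicas,
                 gradient_accumulation_count,
                 dataset_size,
                 global_batches_per_log=1,
                 total_num_train_samples=None,
                 task=Task.PRETRAIN_PHASE_ONE):
        if total_num_train_samples is None:
            if task == Task.OTHER:
                # Non-pretraining tasks must state how many samples to train on.
                raise ValueError("total_num_train_samples is required for Task.OTHER")
            total_num_train_samples = 460_849_152  # assumed phase-one default
        self.micro_batch_size = micro_batch_size
        self.num_replicas = num_replicas
        self.gradient_accumulation_count = gradient_accumulation_count
        # One weight update consumes one micro batch per replica per
        # gradient-accumulation step.
        self.num_micro_batches_per_weight_update = num_replicas * gradient_accumulation_count
        self.global_batch_size = micro_batch_size * self.num_micro_batches_per_weight_update
        # Partial global batches at the end of an epoch are dropped.
        weight_updates_per_epoch = dataset_size // self.global_batch_size
        self.num_micro_batches_per_epoch = (weight_updates_per_epoch *
                                            self.num_micro_batches_per_weight_update)
        # One "execution" spans global_batches_per_log weight updates.
        self.steps_per_execution = (self.num_micro_batches_per_weight_update *
                                    global_batches_per_log)
        self.epochs = total_num_train_samples // dataset_size
        self.num_train_steps = total_num_train_samples // self.global_batch_size
        self.total_num_micro_batches = self.epochs * self.num_micro_batches_per_epoch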
Example #10
def pretrain(**config):
    # Get required options
    micro_batch_size = config["micro_batch_size"]
    replicas = config["replicas"]
    grad_acc_steps_per_replica = config["grad_acc_steps_per_replica"]
    optimizer_opts = config["optimizer_opts"]
    use_outlining = config["use_outlining"]
    replicated_tensor_sharding = config["replicated_tensor_sharding"]
    fp_exceptions = config["fp_exceptions"]
    bert_config = config["bert_config"]
    wandb_opts = config["wandb_opts"]
    pipeline_stages = config["pipeline_stages"]
    device_mapping = config["device_mapping"]

    # Get optional options
    total_num_train_samples = config.get("total_num_train_samples", None)
    save_ckpt_path = config.get("save_ckpt_path",
                                Path(__file__).parent.joinpath("checkpoints").absolute())
    pretrained_ckpt_path = config.get("pretrained_ckpt_path", None)
    ckpt_every_n_steps_per_execution = config.get("ckpt_every_n_steps_per_execution",
                                                  2000)

    universal_run_name = config.get(
        "name", f"{Path(config['config']).stem}-{wandb_opts['init']['name']}")
    universal_run_name += f"-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    print(f"Universal name for run: {universal_run_name}")

    set_random_seeds(config["seed"])
    num_pipeline_stages = max(device_mapping) + 1
    num_ipus = replicas * num_pipeline_stages

    bert_config = BertConfig(**bert_config, hidden_act=ipu.nn_ops.gelu)

    dataset, filenames = get_pretraining_dataset(micro_batch_size=micro_batch_size,
                                                 dataset_dir=config["dataset_dir"],
                                                 max_seq_length=bert_config.max_seq_length,
                                                 max_predictions_per_seq=bert_config.max_predictions_per_seq,
                                                 distributed_worker_count=1,
                                                 seed=config["seed"],
                                                 data_type=tf.float16)
    num_samples = get_dataset_files_count(filenames)
    if bert_config.max_seq_length == 128:
        task = Task.PRETRAIN_PHASE_ONE
    elif bert_config.max_seq_length == 384:
        task = Task.PRETRAIN_PHASE_TWO
    else:
        raise ValueError("Sequence length must be 128 or 384")
    batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                               num_replicas=replicas,
                               gradient_accumulation_count=grad_acc_steps_per_replica,
                               dataset_size=num_samples,
                               global_batches_per_log=config["global_batches_per_log"],
                               total_num_train_samples=total_num_train_samples,
                               task=task)

    policy = tf.keras.mixed_precision.Policy("float16")
    tf.keras.mixed_precision.set_global_policy(policy)

    strategy = create_ipu_strategy(num_ipus,
                                   fp_exceptions=fp_exceptions,
                                   enable_recomputation=config["enable_recomputation"],
                                   min_remote_tensor_size=config["min_remote_tensor_size"])
    with strategy.scope():
        model = IpuTFBertForPreTraining(config=bert_config)

        # Convert subclass model to functional, expand main layers to enable pipelining, and replace some layers to
        # optimise performance.
        model = convert_tf_bert_model(
            model,
            dataset,
            post_process_bert_input_layer,
            replace_layers=config["replace_layers"],
            use_outlining=use_outlining,
            embedding_serialization_factor=config["embedding_serialization_factor"]
        )

        # Load from pretrained checkpoint if requested.
        if pretrained_ckpt_path:
            print("Attempting to load pretrained checkpoint from"
                  f" path {pretrained_ckpt_path}")
            load_checkpoint_into_model(model, pretrained_ckpt_path)
        else:
            if task == Task.PRETRAIN_PHASE_TWO:
                print("WARNING: Phase 2 pre-training should be done from a completed Phase 1 checkpoint. "
                      "Please specify the path to the Phase 1 checkpoint with 'pretrained_ckpt_path' in the config or "
                      "as a command line argument.")

        if num_pipeline_stages > 1:
            # Configure pipeline stages.
            pipeline_assigner = PipelineStagesAssigner(PIPELINE_ALLOCATE_PREVIOUS,
                                                       PIPELINE_NAMES)
            assignments = model.get_pipeline_stage_assignment()
            assignments = pipeline_assigner.assign_pipeline_stages(assignments,
                                                                   pipeline_stages)
            model.set_pipeline_stage_assignment(assignments)
            model.print_pipeline_stage_assignment_summary()
            poplar_options_per_pipeline_stage = get_poplar_options_per_pipeline_stage(
                num_pipeline_stages,
                device_mapping,
                config["matmul_available_memory_proportion_per_pipeline_stage"],
                config["matmul_partials_type"])
            model.set_pipelining_options(
                gradient_accumulation_steps_per_replica=batch_config.gradient_accumulation_count,
                pipeline_schedule=ipu.ops.pipelining_ops.PipelineSchedule.Grouped,
                device_mapping=device_mapping,
                offload_weight_update_variables=config["optimizer_state_offchip"],
                replicated_optimizer_state_sharding=replicated_tensor_sharding,
                forward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                backward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                recomputation_mode=ipu.pipelining_ops.RecomputationMode.RecomputeAndBackpropagateInterleaved,
            )

        # Prepare losses and wrap them in an out-feed queue.
        nsp_loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        nsp_loss = wrap_loss_in_enqueuer(NSPLossFunction,
                                         nsp_loss_outfeed_queue,
                                         ["nsp_loss_average"])()
        mlm_loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        mlm_loss = wrap_loss_in_enqueuer(MLMLossFunction,
                                         mlm_loss_outfeed_queue,
                                         ["mlm_loss_average"])()
        # Prepare learning rate and wrap it in an out-feed queue.
        config['learning_rate']['lr_schedule_params']['total_steps'] = batch_config.num_train_steps
        lr_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue(outfeed_mode=ipu.ipu_outfeed_queue.IPUOutfeedMode.LAST)
        lr_scheduler = get_lr_scheduler(scheduler_name=config["learning_rate"]["lr_schedule"],
                                        schedule_params=config["learning_rate"]["lr_schedule_params"],
                                        queue=lr_outfeed_queue)
        # Prepare optimizer.
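        # Don't outline the optimizer's apply_gradients when replicated
        # tensor sharding is enabled; otherwise follow the global setting.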
        outline_optimizer_apply_gradients = False if replicated_tensor_sharding else use_outlining
        optimizer = get_optimizer(
            optimizer_opts["name"],
            grad_acc_steps_per_replica,
            replicas,
            lr_scheduler,
            outline_optimizer_apply_gradients,
            loss_scaling=config["loss_scaling"],
            weight_decay_rate=optimizer_opts["params"]["weight_decay_rate"],
        )
        # Compile the model.
        model.compile(
            optimizer=optimizer,
            loss={"nsp___cls": nsp_loss,
                  "mlm___cls": mlm_loss},
            steps_per_execution=batch_config.steps_per_execution,
        )
        # Set up callbacks
        callbacks = CallbackFactory.get_callbacks(
            universal_run_name=universal_run_name,
            batch_config=batch_config,
            model=model,
            checkpoint_path=save_ckpt_path,
            ckpt_every_n_steps_per_execution=ckpt_every_n_steps_per_execution,
            outfeed_queues=[lr_outfeed_queue,
                            nsp_loss_outfeed_queue,
                            mlm_loss_outfeed_queue],
            config=config,
        )
        # Print configs to be logged in wandb's terminal.
        print(config)
        print(f"Training batch config:\n{batch_config}")
        # Train the model
        # In order to achieve a specific number of steps, we set the number of
        # epochs to 1 and the steps per epoch to the number of steps we require.
        print("Forcing `model.fit` to run a particular number of steps by"
              " running a single 'epoch' with the number of steps we"
              " require. This allows running a fraction of actual"
              " epochs.")
        history = model.fit(dataset,
                            steps_per_epoch=batch_config.total_num_micro_batches,
                            epochs=1,
                            callbacks=callbacks)
        return history
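A hypothetical driver for pretrain, shown only to illustrate the expected call shape. The YAML file name and the use of PyYAML are assumptions; the "config" key, however, is read above when building the universal run name.

# Hypothetical usage sketch; the config path and PyYAML parsing are assumptions.
import yaml

if __name__ == "__main__":
    config_path = "configs/pretrain_base_128.yaml"
    with open(config_path) as f:
        config = yaml.safe_load(f)
    # pretrain() reads config["config"] when it builds the universal run name.
    config["config"] = config_path
    history = pretrain(**config)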
Example #11
def fine_tune_squad(**config):
    # Get the pretrained model from Hugging Face's model hub or from one of our own checkpoints.
    bert_model_name = config.get("bert_model_name", None)
    pretrained_ckpt_path = config.get("pretrained_ckpt_path", None)  # Can be passed via the CLI or the config file.
    assert bert_model_name is not None or pretrained_ckpt_path is not None, \
        "SQuAD requires a pretrained model, either `bert_model_name` (via config file) or `pretrained_ckpt_path` " \
        "(via config file or `--pretrained-ckpt-path` command line argument), but neither was provided."
    assert (bert_model_name is not None and pretrained_ckpt_path is None) \
        or (bert_model_name is None and pretrained_ckpt_path is not None), \
        f"Only one checkpoint is accepted, but two were provided: `bert_model_name`={bert_model_name} " \
        f"and `pretrained_ckpt_path`={pretrained_ckpt_path}."
    if pretrained_ckpt_path is not None:
        bert_config_params = config["bert_config"]

    # Get required options
    micro_batch_size = config["micro_batch_size"]
    num_epochs = config["num_epochs"]
    optimizer_opts = config["optimizer_opts"]
    learning_rate = config['learning_rate']
    replicas = config["replicas"]
    grad_acc_steps_per_replica = config["grad_acc_steps_per_replica"]
    wandb_opts = config["wandb_opts"]
    use_outlining = config["use_outlining"]
    replace_layers = config["replace_layers"]
    enable_recomputation = config["enable_recomputation"]
    embedding_serialization_factor = config["embedding_serialization_factor"]
    optimizer_state_offchip = config["optimizer_state_offchip"]
    matmul_available_memory_proportion_per_pipeline_stage = config[
        "matmul_available_memory_proportion_per_pipeline_stage"]
    matmul_partials_type = config["matmul_partials_type"]
    pipeline_stages = config["pipeline_stages"]
    device_mapping = config["device_mapping"]
    global_batches_per_log = config["global_batches_per_log"]
    seed = config["seed"]
    cache_dir = config["cache_dir"]
    output_dir = config["output_dir"]

    # Get optional options
    save_ckpt_path = config.get("save_ckpt_path", Path(__file__).parent.joinpath("checkpoints").absolute())
    ckpt_every_n_steps_per_execution = config.get("ckpt_every_n_steps_per_execution", 2000)

    universal_run_name = config.get("name", f"{Path(config['config']).stem}-{wandb_opts['init']['name']}")
    universal_run_name += f"-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    print(f"Universal name for run: {universal_run_name}")
    set_random_seeds(seed)
    num_pipeline_stages = len(device_mapping)
    num_ipus_per_replicas = max(device_mapping) + 1
    num_ipus = replicas * num_ipus_per_replicas

    # Load training and validation data
    # =================================
    train_dataset, eval_dataset, num_train_samples, num_eval_samples, raw_datasets = get_squad_data(
        micro_batch_size,
        cache_dir
    )
    train_batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                                     total_num_train_samples=num_epochs * num_train_samples.numpy(),
                                     num_replicas=replicas,
                                     gradient_accumulation_count=grad_acc_steps_per_replica,
                                     dataset_size=num_train_samples.numpy(),
                                     global_batches_per_log=global_batches_per_log,
                                     task=Task.OTHER)

    # Create model
    # ============
    policy = tf.keras.mixed_precision.Policy("float16")
    tf.keras.mixed_precision.set_global_policy(policy)
    strategy = create_ipu_strategy(num_ipus, enable_recomputation=enable_recomputation)
    with strategy.scope():
        # Instantiate the pretrained model given in the config.
        if bert_model_name is not None:
            model = TFBertForQuestionAnswering.from_pretrained(bert_model_name)
        else:
            bert_config = BertConfig(**bert_config_params, hidden_act=ipu.nn_ops.gelu)
            model = TFBertForQuestionAnswering(config=bert_config)

        # Convert subclass model to functional, expand main layers to enable pipelining, and replace some layers to
        # optimise performance.
        model = convert_tf_bert_model(
            model,
            train_dataset,
            post_process_bert_input_layer,
            replace_layers=replace_layers,
            use_outlining=use_outlining,
            embedding_serialization_factor=embedding_serialization_factor,
            rename_outputs={'tf.compat.v1.squeeze': 'start_positions', 'tf.compat.v1.squeeze_1': 'end_positions'}
        )
        # Load from pretrained checkpoint if requested.
        if pretrained_ckpt_path is not None:
            print(f"Attempting to load pretrained checkpoint from path {pretrained_ckpt_path}. "
                  f"This will overwrite the current weights")
            load_checkpoint_into_model(model, pretrained_ckpt_path)

        # Configure pipeline stages
        # =========================
        if num_pipeline_stages > 1:
            pipeline_assigner = PipelineStagesAssigner(PIPELINE_ALLOCATE_PREVIOUS, PIPELINE_NAMES)
            assignments = model.get_pipeline_stage_assignment()
            assignments = pipeline_assigner.assign_pipeline_stages(assignments, pipeline_stages)
            model.set_pipeline_stage_assignment(assignments)
            model.print_pipeline_stage_assignment_summary()
            poplar_options_per_pipeline_stage = get_poplar_options_per_pipeline_stage(
                num_ipus_per_replicas,
                device_mapping,
                matmul_available_memory_proportion_per_pipeline_stage,
                matmul_partials_type
            )
            model.set_pipelining_options(
                gradient_accumulation_steps_per_replica=grad_acc_steps_per_replica,
                pipeline_schedule=ipu.ops.pipelining_ops.PipelineSchedule.Grouped,
                device_mapping=device_mapping,
                offload_weight_update_variables=optimizer_state_offchip,
                forward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                backward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                recomputation_mode=ipu.pipelining_ops.RecomputationMode.RecomputeAndBackpropagateInterleaved,
            )

        # Compile the model for training
        # ==============================
        # Wrap loss in an out-feed queue.
        loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        qa_loss = wrap_loss_in_enqueuer(QuestionAnsweringLossFunction,
                                        loss_outfeed_queue,
                                        ["end_positions_loss", "start_positions_loss"])()
        # Define optimiser with polynomial decay learning rate.
        learning_rate['lr_schedule_params']['total_steps'] = train_batch_config.num_train_steps
        lr_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue(outfeed_mode=ipu.ipu_outfeed_queue.IPUOutfeedMode.LAST)
        lr_scheduler = get_lr_scheduler(scheduler_name=learning_rate["lr_schedule"],
                                        schedule_params=learning_rate["lr_schedule_params"],
                                        queue=lr_outfeed_queue)
        # Prepare optimizer.
        outline_optimizer_apply_gradients = use_outlining
        optimizer = get_optimizer(
            optimizer_opts["name"],
            grad_acc_steps_per_replica,
            replicas,
            lr_scheduler,
            outline_optimizer_apply_gradients,
            weight_decay_rate=optimizer_opts["params"]["weight_decay_rate"],
        )
        # Compile the model.
        model.compile(
            optimizer=optimizer,
            loss={"end_positions": qa_loss, "start_positions": qa_loss},
            metrics='accuracy',
            steps_per_execution=train_batch_config.steps_per_execution
        )

        # Train the model
        # ===============
        # Set up callbacks
        callbacks = CallbackFactory.get_callbacks(
            universal_run_name=universal_run_name,
            batch_config=train_batch_config,
            model=model,
            checkpoint_path=save_ckpt_path,
            ckpt_every_n_steps_per_execution=ckpt_every_n_steps_per_execution,
            outfeed_queues=[lr_outfeed_queue, loss_outfeed_queue],
            config=config,
        )
        # Print configs to be logged in wandb's terminal.
        print(config)
        print(f"Training batch config:\n{train_batch_config}")
        # Train the model
        history = model.fit(
            train_dataset,
            steps_per_epoch=train_batch_config.num_micro_batches_per_epoch,
            epochs=num_epochs,
            callbacks=callbacks
        )

    # Evaluate the model on the validation set
    # ========================================
    # Prepare the dataset to be evaluated in the IPU.
    eval_batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                                    total_num_train_samples=num_eval_samples.numpy(),
                                    num_replicas=replicas,
                                    gradient_accumulation_count=grad_acc_steps_per_replica,
                                    dataset_size=num_eval_samples.numpy(),
                                    task=Task.OTHER)
    max_eval_samples = eval_batch_config.micro_batch_size * eval_batch_config.num_micro_batches_per_epoch
    eval_pred_dataset = get_prediction_dataset(eval_dataset, max_eval_samples)
    with strategy.scope():
        # Re-compile the model for prediction if needed.
        if train_batch_config.steps_per_execution != eval_batch_config.steps_per_execution:
            model.compile(steps_per_execution=eval_batch_config.steps_per_execution)
        # Get predictions for the validation data.
        print(f"Running inference:\nGenerating predictions on the validation data...")
        predictions = model.predict(
            eval_pred_dataset,
            batch_size=eval_batch_config.micro_batch_size
        )
    # The predictions for the end positions go first in the model outputs tuple (see the output of model.summary()).
    end_predictions, start_predictions = predictions
    # Match the predictions to answers in the original context.
    # This will also write out the predictions to a json file in the directory given by `output_dir`.
    final_predictions = postprocess_qa_predictions(
        list(raw_datasets["validation"].as_numpy_iterator()),
        list(eval_dataset.unbatch().as_numpy_iterator()),
        (start_predictions, end_predictions),
        output_dir=output_dir
    )
    # Format the predictions and the actual labels as expected by the metric.
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    formatted_labels = format_raw_data_for_metric(raw_datasets["validation"])
    metric = load_metric("squad")
    metrics = metric.compute(predictions=formatted_predictions, references=formatted_labels)
    print("Evaluation metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value:.3f}")

    return history
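For reference, the squad metric loaded above reports exact-match and F1 scores. A toy call with made-up ids and answers, assuming the datasets library's standard SQuAD prediction/reference format:

# Toy usage of the squad metric; the id and answer values are invented.
from datasets import load_metric

metric = load_metric("squad")
predictions = [{"id": "0001", "prediction_text": "Denver Broncos"}]
references = [{"id": "0001",
               "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]
print(metric.compute(predictions=predictions, references=references))
# -> {'exact_match': 100.0, 'f1': 100.0}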