示例#1
0
def load_qa_from_pretrained(
    model: Optional[tf.keras.Model] = None,
    name: Optional[str] = None,
    path: Optional[str] = None,  # path to checkpoint from TF...ForPreTraining
    config: Optional[PretrainedConfig] = None,
) -> tf.keras.Model:
    """
    Load a TF...QuestionAnswering model by taking the main layer of a pretrained model.
    Preserves the model.config attribute.

    Exactly one source must be supplied:
      * ``name``: forwarded to ``TFAutoModelForQuestionAnswering.from_pretrained``;
      * ``model``: an already-instantiated pretrained model whose main layer is reused;
      * ``path`` together with ``config``: a ``TFAutoModelForPreTraining`` is built from
        ``config`` and its weights are loaded from ``path``.

    Returns:
        The question-answering model. For the ``model`` and ``path`` sources its main
        layer (the attribute named by ``base_model_prefix``) is the pretrained one.

    Raises:
        AssertionError: if zero or more than one source is supplied, or if the
            pretrained model has no main layer under ``base_model_prefix``.
    """
    use_name = bool(name)
    use_model = model is not None
    use_checkpoint = bool(path) and config is not None

    # Count the sources instead of XOR-ing them: a three-way XOR is also true when
    # all three are given. Raised explicitly so the check survives `python -O`.
    if use_name + use_model + use_checkpoint != 1:
        raise AssertionError("Pass either name, model, or (path and config)")

    if use_name:
        return TFAutoModelForQuestionAnswering.from_pretrained(name)

    if use_model:
        pretrained_model = model
    else:
        pretrained_model = TFAutoModelForPreTraining.from_config(config)
        pretrained_model.load_weights(path)

    qa_model = TFAutoModelForQuestionAnswering.from_config(pretrained_model.config)
    prefix = qa_model.base_model_prefix
    # Default to None so a missing main layer reaches the descriptive error below.
    pretrained_main_layer = getattr(pretrained_model, prefix, None)
    if pretrained_main_layer is None:
        # Use the local prefix: `model` is None when loading from (path, config).
        raise AssertionError(f"{pretrained_model} has no attribute '{prefix}'")
    # Generalized way of saying `model.albert = pretrained_model.albert`
    setattr(qa_model, prefix, pretrained_main_layer)
    return qa_model
示例#2
0
def run_squad_and_get_results(
    run_name: str,
    fsx_prefix: str,
    pre_layer_norm: bool,
    model_size: str,
    load_from: Union[str, tf.keras.Model],
    load_step: int,
    batch_size: int,
    checkpoint_frequency: Optional[int],
    validate_frequency: Optional[int],
    learning_rate: float,
    warmup_steps: int,
    total_steps: int,
    dataset: str,
    dummy_eval: bool = False,
    config: Optional[PretrainedConfig] = None,
) -> Dict:
    """
    Finetune a question-answering model on SQuAD and return the evaluation metrics.

    The model is built from ``load_from``, trained for ``total_steps`` steps with a
    warmup/decay learning-rate schedule, and validated, evaluated, checkpointed and
    logged to TensorBoard on Horovod rank 0 only.

    Args:
        run_name: Identifier used in the checkpoint and TensorBoard log paths.
        fsx_prefix: Root directory; data is read from ``{fsx_prefix}/squad_data`` and
            checkpoints/logs are written beneath it.
        pre_layer_norm: Not used inside this function.
        model_size: Inserted into ``albert-{model_size}-v2`` when ``load_from`` is
            ``"huggingface"``.
        load_from: A ``TFPreTrainedModel`` instance, ``"scratch"`` or ``"huggingface"``.
        load_step: Not used inside this function.
        batch_size: Batch size for the training and validation datasets.
        checkpoint_frequency: Steps between checkpoints; a falsy value means 1000000.
        validate_frequency: Steps between validations; a falsy value means 1000000.
        learning_rate: Maximum learning rate of the schedule.
        warmup_steps: Warmup steps of the schedule.
        total_steps: Total number of training steps.
        dataset: One of ``"squadv1"``, ``"squadv2"`` or ``"debug"``.
        dummy_eval: If True, use hard-coded metrics instead of running the evaluation.
        config: Model config; replaced by ``load_from.config`` when ``load_from`` is a
            ``tf.keras.Model``.

    Returns:
        On rank 0, the metrics dict of the most recent evaluation. On every other rank,
        an empty dict (those ranks never run evaluation).

    Raises:
        AssertionError: if no config is available or ``dataset`` is not recognized.
        ValueError: if ``load_from`` is not a supported value.
    """
    checkpoint_frequency = checkpoint_frequency or 1000000
    validate_frequency = validate_frequency or 1000000

    if isinstance(load_from, tf.keras.Model):
        config = load_from.config
    assert config is not None, "config may not be None"

    # Instantiate QuestionAnswering model
    if isinstance(load_from, TFPreTrainedModel):
        model = load_qa_from_pretrained(model=load_from)
    elif load_from == "scratch":
        model = TFAutoModelForQuestionAnswering.from_config(config)
    elif load_from == "huggingface":
        model = load_qa_from_pretrained(name=f"albert-{model_size}-v2")
    else:
        raise ValueError(
            f"'load_from' is '{load_from}'; must be in ['scratch', 'huggingface', 'amazon']"
        )

    tokenizer = get_tokenizer()

    schedule = LinearWarmupPolyDecaySchedule(
        max_learning_rate=learning_rate,
        end_learning_rate=0,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )
    optimizer = tfa.optimizers.AdamW(weight_decay=0.0, learning_rate=schedule)
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic"
    )  # AMP

    model.call = wrap_tf_function_idempotent(model.call)

    if dataset == "squadv1":
        train_filename = "train-v1.1.json"
        val_filename = "dev-v1.1.json"
        processor = SquadV1Processor()
    elif dataset == "squadv2":
        train_filename = "train-v2.0.json"
        val_filename = "dev-v2.0.json"
        processor = SquadV2Processor()
    elif dataset == "debug":
        # Debug mode trains and validates on the same dev file
        train_filename = "dev-v2.0.json"
        val_filename = "dev-v2.0.json"
        processor = SquadV2Processor()
    else:
        # Raised explicitly: `assert False` is stripped under `python -O`, which would
        # leave the filenames and processor unbound.
        raise AssertionError("--dataset must be one of ['squadv1', 'squadv2', 'debug']")

    data_dir = f"{fsx_prefix}/squad_data"

    train_dataset = get_dataset(
        tokenizer=tokenizer,
        processor=processor,
        data_dir=data_dir,
        filename=train_filename,
        batch_size=batch_size,
        shard=True,
        shuffle=True,
        repeat=True,
        drop_remainder=True,
    )

    # Stays empty on ranks other than 0, which never run evaluation
    results: Dict = {}

    if hvd.rank() == 0:
        print("Starting finetuning")
        # `total` must be passed by keyword; tqdm's first positional argument is the iterable
        pbar = tqdm.tqdm(total=total_steps)
        summary_writer = None  # Only create a writer if we make it through a successful step
        val_dataset = get_dataset(
            tokenizer=tokenizer,
            processor=processor,
            data_dir=data_dir,
            filename=val_filename,
            batch_size=batch_size,
            shard=False,
            shuffle=True,
            drop_remainder=False,
        )

    # Need to re-wrap every time this function is called
    # Wrapping train_step gives an error with optimizer initialization on the second pass
    # of run_squad_and_get_results(). Bug report at https://github.com/tensorflow/tensorflow/issues/38875
    # Discussion at https://github.com/tensorflow/tensorflow/issues/27120
    wrapped_train_step = tf.function(train_step)
    for step, batch in enumerate(train_dataset):
        learning_rate = schedule(step=tf.constant(step, dtype=tf.float32))
        loss, acc, exact_match, f1, precision, recall = wrapped_train_step(
            model=model, optimizer=optimizer, batch=batch
        )

        # Broadcast model after the first step so parameters and optimizer are initialized
        if step == 0:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        is_final_step = step >= total_steps - 1
        if hvd.rank() == 0:
            # Step 0 satisfies both modulo checks, so validation (and therefore
            # `results`) and a checkpoint are always produced on the first step.
            do_checkpoint = (step % checkpoint_frequency == 0) or is_final_step
            do_validate = (step % validate_frequency == 0) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, Acc: {acc:.3f}, EM: {exact_match:.3f}, F1: {f1:.3f}"
            pbar.set_description(description)

            if do_validate:
                print("Running validation")
                (
                    val_loss,
                    val_acc,
                    val_exact_match,
                    val_f1,
                    val_precision,
                    val_recall,
                ) = run_validation(model=model, val_dataset=val_dataset)
                description = (
                    f"Step {step} validation - Loss: {val_loss:.3f}, Acc: {val_acc:.3f}, "
                    f"EM: {val_exact_match:.3f}, F1: {val_f1:.3f}"
                )
                print(description)
                print("Running evaluation")
                if dummy_eval:
                    results = {
                        "exact": 0.8169797018445212,
                        "f1": 4.4469722448269335,
                        "total": 11873,
                        "HasAns_exact": 0.15182186234817813,
                        "HasAns_f1": 7.422216845956518,
                        "HasAns_total": 5928,
                        "NoAns_exact": 1.4802354920100924,
                        "NoAns_f1": 1.4802354920100924,
                        "NoAns_total": 5945,
                        "best_exact": 50.07159100480081,
                        "best_exact_thresh": 0.0,
                        "best_f1": 50.0772059855695,
                        "best_f1_thresh": 0.0,
                    }
                else:
                    results = get_evaluation_metrics(
                        model=model, data_dir=data_dir, filename=val_filename, batch_size=32,
                    )
                print_eval_metrics(results=results, step=step)

            if do_checkpoint:
                checkpoint_path = (
                    f"{fsx_prefix}/checkpoints/albert-squad/{run_name}-step{step}.ckpt"
                )
                print(f"Saving checkpoint at {checkpoint_path}")
                model.save_weights(checkpoint_path)

            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{fsx_prefix}/logs/albert-squad/{run_name}"
                )
            with summary_writer.as_default():
                tf.summary.scalar("learning_rate", learning_rate, step=step)
                tf.summary.scalar("train_loss", loss, step=step)
                tf.summary.scalar("train_acc", acc, step=step)
                tf.summary.scalar("train_exact", exact_match, step=step)
                tf.summary.scalar("train_f1", f1, step=step)
                tf.summary.scalar("train_precision", precision, step=step)
                tf.summary.scalar("train_recall", recall, step=step)
                if do_validate:
                    tf.summary.scalar("val_loss", val_loss, step=step)
                    tf.summary.scalar("val_acc", val_acc, step=step)
                    tf.summary.scalar("val_exact", val_exact_match, step=step)
                    tf.summary.scalar("val_f1", val_f1, step=step)
                    tf.summary.scalar("val_precision", val_precision, step=step)
                    tf.summary.scalar("val_recall", val_recall, step=step)
                    # And the eval metrics
                    tensorboard_eval_metrics(
                        summary_writer=summary_writer, results=results, step=step
                    )

        if is_final_step:
            break

    # Only rank 0 owns a progress bar; the other ranks must not touch `pbar`
    if hvd.rank() == 0:
        pbar.close()
        print(f"Finished finetuning, job name {run_name}")

    return results


if __name__ == "__main__":
    # Standalone entry point: build an albert-base-v2 question-answering model,
    # optionally restore finetuned weights, and print its metrics on the SQuAD v2 dev set.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--pre_layer_norm", type=str, choices=["true"])
    args = parser.parse_args()

    # Load finetuned model from checkpoint
    config = AutoConfig.from_pretrained("albert-base-v2")
    # Omitting --pre_layer_norm leaves args.pre_layer_norm as None, i.e. False here
    config.pre_layer_norm = args.pre_layer_norm == "true"
    model = TFAutoModelForQuestionAnswering.from_config(config)
    # Without this the --checkpoint flag is ignored and the evaluated model keeps
    # the weights it was initialized with by from_config.
    if args.checkpoint is not None:
        model.load_weights(args.checkpoint)

    # XLA, AMP, tf.function
    tf.config.optimizer.set_jit(True)
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
    model.call = tf.function(model.call)

    # Get validation dataset
    data_dir = "/fsx/squad_data"
    val_filename = "dev-v2.0.json"

    results = get_evaluation_metrics(
        model=model, data_dir=data_dir, filename=val_filename, batch_size=args.batch_size
    )
    print(dict(results))