debug_bert_squad_base = deepcopy(finetuning_bert_100k_glue_get_info)
debug_bert_squad_base.update(
    # Model Args
    model_name_or_path="bert-base-cased",
    finetuning=True,
    task_names=None,
    task_name="squad",
    dataset_name="squad",
    dataset_config_name="plain_text",
    trainer_class=QuestionAnsweringTrainer,
    max_seq_length=128,
    do_train=True,
    do_eval=True,
    do_predict=False,
    trainer_callbacks=[
        TrackEvalMetrics(),
    ],
    max_steps=100,
    eval_steps=20,
    rm_checkpoints=True,
    load_best_model_at_end=True,
    warmup_ratio=0.)
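
# A minimal sketch (not this repo's actual run harness) of how a config dict
# like the one above can be split into HuggingFace ``TrainingArguments`` and
# the remaining experiment-specific keys. The helper name and default
# ``output_dir`` are illustrative.
def to_training_arguments(config, output_dir="/tmp/debug_run"):
    from dataclasses import fields

    from transformers import TrainingArguments

    hf_keys = {f.name for f in fields(TrainingArguments)}
    hf_kwargs = {k: v for k, v in config.items() if k in hf_keys}
    hf_kwargs.setdefault("output_dir", output_dir)  # required by TrainingArguments
    extra = {k: v for k, v in config.items() if k not in hf_keys}
    return TrainingArguments(**hf_kwargs), extra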

# Reported to train in about 24 minutes, but takes roughly an hour in practice.
# Expect an f1 score of 88.52 and an exact_match of 81.22.
bert_squad_replication = deepcopy(debug_bert_squad_base)
bert_squad_replication.update(
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    num_train_epochs=2,
    max_seq_length=384,
    max_steps=-1,  # train for the full 2 epochs rather than the 100-step debug cap
)
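
# For reference, a small sketch of how the f1 / exact_match numbers quoted
# above are computed for SQuAD with the HuggingFace ``evaluate`` package. The
# helper name is illustrative; predictions/references follow the "squad"
# metric's expected format.
def compute_squad_metrics(predictions, references):
    import evaluate

    # predictions: [{"id": ..., "prediction_text": ...}, ...]
    # references:  [{"id": ..., "answers": {"text": [...], "answer_start": [...]}}, ...]
    squad_metric = evaluate.load("squad")
    # Returns a dict with "exact_match" and "f1" keys.
    return squad_metric.compute(predictions=predictions, references=references)
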
# ---------
# Tiny BERT
# ---------

# This combines KD + RigL + OneCycle LR on Tiny BERT.
tiny_bert_trifecta_300k = deepcopy(tiny_bert_sparse_100k)
tiny_bert_trifecta_300k.update(
    max_steps=300000,
    model_type="fully_static_sparse_bert",
    overwrite_output_dir=True,

    # Sparsity callback
    trainer_callbacks=[
        RezeroWeightsCallback(),
        TrackEvalMetrics(),
    ],
    fp16=True,

    trainer_class=KDRigLOneCycleLRTrainer,
    trainer_mixin_args=dict(

        # One cycle lr
        max_lr=0.0075,
        pct_start=0.3,
        anneal_strategy="linear",
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25,
        final_div_factor=1e4,
    ),
)
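
# The ``trainer_mixin_args`` above match the constructor arguments of
# ``torch.optim.lr_scheduler.OneCycleLR``. A minimal sketch of the schedule
# they describe, assuming a plain SGD optimizer and the 300k-step budget;
# the helper name is illustrative.
def make_one_cycle_scheduler(optimizer, total_steps=300000):
    from torch.optim.lr_scheduler import OneCycleLR

    return OneCycleLR(
        optimizer,
        max_lr=0.0075,
        total_steps=total_steps,
        pct_start=0.3,
        anneal_strategy="linear",
        cycle_momentum=True,     # also cycles momentum between base and max
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25,           # initial_lr = max_lr / div_factor
        final_div_factor=1e4,    # min_lr = initial_lr / final_div_factor
    )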

# Short finetuning run with frequent evals, used for fast debugging.
debug_finetuning = deepcopy(finetuning_bert_100k_glue_get_info)
debug_finetuning.update(
    do_train=True,
    do_eval=True,
    do_predict=True,
    eval_steps=15,
    evaluation_strategy="steps",
    load_best_model_at_end=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    max_steps=45,  # made very short for fast debugging
    metric_for_best_model="eval_accuracy",
    num_runs=3,
    rm_checkpoints=True,
    trainer_callbacks=[
        TrackEvalMetrics(),
    ],
)

debug_finetuning_mnli = deepcopy(debug_finetuning)
debug_finetuning_mnli.update(
    task_names=["mnli", "wnli", "rte"],
    trainer_callbacks=[TrackEvalMetrics()],
    num_runs=2,
    task_hyperparams=dict(
        mnli=dict(
            trainer_class=MultiEvalSetTrainer,
            # deliberately incorrect metric - the code should fix this for you.
            metric_for_best_model="mm_accuracy",
            trainer_mixin_args=dict(
                eval_sets=["validation_matched", "validation_mismatched"],
            ),
        ),
    ),
)
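
# The ``eval_sets`` above are the two MNLI validation splits shipped with
# GLUE. A small sketch, independent of the trainer classes used here, showing
# how both splits can be loaded with ``datasets``; the helper name is
# illustrative.
def load_mnli_eval_sets():
    from datasets import load_dataset

    mnli = load_dataset("glue", "mnli")
    return mnli["validation_matched"], mnli["validation_mismatched"]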

hp_search_finetuning_trifecta_90_100k_small_tasks = deepcopy(
    hp_search_finetuning_trifecta_85_100k_small_tasks)
hp_search_finetuning_trifecta_90_100k_small_tasks.update(
    model_name_or_path=
    "/mnt/efs/results/pretrained-models/transformers-local/bert_sparse_90%_trifecta_100k"  # noqa
)
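
# The ``hp_search_*`` configs select checkpoints to run hyperparameter search
# over. Independent of this repo's own search harness, a minimal sketch of the
# same idea using HuggingFace's built-in ``Trainer.hyperparameter_search`` with
# the optuna backend; the search space and helper name below are illustrative,
# not the values used in these experiments.
def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 5),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [16, 32]),
    }

# Usage, given a ``trainer`` built from one of these configs:
#     best_run = trainer.hyperparameter_search(
#         hp_space=optuna_hp_space, n_trials=10, direction="maximize",
#         backend="optuna")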

hp_search_finetuning_bert_100k_small_tasks = deepcopy(
    hp_search_finetuning_trifecta_85_100k_small_tasks)
hp_search_finetuning_bert_100k_small_tasks.update(
    model_type="bert",
    model_name_or_path=
    "/mnt/efs/results/pretrained-models/transformers-local/bert_100k",  # noqa
    trainer_callbacks=[TrackEvalMetrics()],
)

hp_search_finetuning_trifecta_2x_small_tasks = deepcopy(
    hp_search_finetuning_trifecta_80_100k_small_tasks)
hp_search_finetuning_trifecta_2x_small_tasks.update(
    model_name_or_path=
    "/mnt/efs/results/pretrained-models/transformers-local/bert_sparse_2x_trifecta_100k"  # noqa
)

hp_search_finetuning_bert_100k_big_tasks = deepcopy(
    hp_search_finetuning_trifecta_85_100k_big_tasks)
hp_search_finetuning_bert_100k_big_tasks.update(
    model_type="bert",
    model_name_or_path=
    "/mnt/efs/results/pretrained-models/transformers-local/bert_100k",  # noqa