Example #1
    def test_training_load_best_model_at_end_adapter(self):
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")

        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        model.add_adapter("adapter")
        model.train_adapter("adapter")

        training_args = TrainingArguments(
            output_dir="./examples",
            do_train=True,
            learning_rate=0.001,
            max_steps=1,
            save_steps=1,
            remove_unused_columns=False,
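            # with load_best_model_at_end, the best adapter weights are reloaded at the end of
            # training (asserted via the "Loading best adapter(s) from" log message below)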
            load_best_model_at_end=True,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            num_train_epochs=2,
        )
        trainer = AdapterTrainer(
            model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset
        )
        with self.assertLogs(logger) as cm:
            trainer.train()
            self.assertTrue(any("Loading best adapter(s) from" in line for line in cm.output))
        self.assertEqual(Stack("adapter"), trainer.model.active_adapters)
Example #2
def train_transformer(config, checkpoint_dir=None):
    data_args = DataTrainingArguments(task_name=config["task_name"],
                                      data_dir=config["data_dir"])
    tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
    train_dataset = GlueDataset(data_args,
                                tokenizer=tokenizer,
                                mode="train",
                                cache_dir=config["data_dir"])
    eval_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="dev",
                               cache_dir=config["data_dir"])
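    # slicing a GlueDataset yields a plain list of feature objects; keep only the
    # first half of the dev set here to speed up evaluation during tuning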
    eval_dataset = eval_dataset[:len(eval_dataset) // 2]
    training_args = TrainingArguments(
        output_dir=tune.get_trial_dir(),
        learning_rate=config["learning_rate"],
        do_train=True,
        do_eval=True,
        evaluate_during_training=True,
        eval_steps=(len(train_dataset) // config["per_gpu_train_batch_size"]) + 1,
        # We explicitly set save to 0, and do saving in evaluate instead
        save_steps=0,
        num_train_epochs=config["num_epochs"],
        max_steps=config["max_steps"],
        per_device_train_batch_size=config["per_gpu_train_batch_size"],
        per_device_eval_batch_size=config["per_gpu_val_batch_size"],
        warmup_steps=0,
        weight_decay=config["weight_decay"],
        logging_dir="./logs",
    )

    # Arguments for W&B.
    name = tune.get_trial_name()
    wandb_args = {
        "project_name": "transformers_pbt",
        "watch": "false",  # Either set to gradient, false, or all
        "run_name": name,
    }

    tune_trainer = get_trainer(recover_checkpoint(checkpoint_dir,
                                                  config["model_name"]),
                               train_dataset,
                               eval_dataset,
                               config["task_name"],
                               training_args,
                               wandb_args=wandb_args)
    tune_trainer.train(recover_checkpoint(checkpoint_dir,
                                          config["model_name"]))
Example #3
def test_best_model(analysis, model_name, task_name, data_dir):
    data_args = DataTrainingArguments(task_name=task_name, data_dir=data_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    best_config = analysis.get_best_config(metric="eval_acc", mode="max")
    print(best_config)
    best_checkpoint = recover_checkpoint(
        analysis.get_best_trial(metric="eval_acc",
                                mode="max").checkpoint.value)
    print(best_checkpoint)
    best_model = AutoModelForSequenceClassification.from_pretrained(
        best_checkpoint).to("cuda")

    test_args = TrainingArguments(output_dir="./best_model_results", )
    test_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="dev",
                               cache_dir=data_dir)
    test_dataset = test_dataset[len(test_dataset) // 2:]

    test_trainer = Trainer(best_model,
                           test_args,
                           compute_metrics=build_compute_metrics_fn(task_name))

    metrics = test_trainer.evaluate(test_dataset)
    print(metrics)
Example #4
    def test_train_single_adapter(self):
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name,
                                                  use_fast=False)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelWithHeads.from_config(self.config())

        # add two adapters: one will be trained and the other should be frozen
        model.add_adapter("mrpc")
        model.add_adapter("dummy")
        model.add_classification_head("mrpc")

        self.assertIn("mrpc", model.config.adapters.adapters)
        self.assertIn("dummy", model.config.adapters.adapters)

        # train the mrpc adapter -> should be activated & unfrozen
        model.train_adapter("mrpc")
        self.assertEqual(set(["mrpc"]), model.active_adapters.flatten())

        # all weights of the adapter should be activated
        for k, v in filter_parameters(model, "adapters.mrpc.").items():
            self.assertTrue(v.requires_grad, k)
        # all weights of the adapter not used for training should be frozen
        for k, v in filter_parameters(model, "adapters.dummy.").items():
            self.assertFalse(v.requires_grad, k)
        # weights of the model should be frozen (check on some examples)
        for k, v in filter_parameters(model,
                                      "encoder.layer.0.attention").items():
            self.assertFalse(v.requires_grad, k)

        state_dict_pre = copy.deepcopy(model.state_dict())

        # setup dataset
        data_args = GlueDataTrainingArguments(
            task_name="mrpc",
            data_dir="./tests/fixtures/tests_samples/MRPC",
            overwrite_cache=True)
        train_dataset = GlueDataset(data_args,
                                    tokenizer=tokenizer,
                                    mode="train")
        training_args = TrainingArguments(output_dir="./examples",
                                          do_train=True,
                                          learning_rate=0.1,
                                          max_steps=7,
                                          no_cuda=True)

        # train
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
        )
        trainer.train()

        for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(),
                                        model.state_dict().items()):
            if "mrpc" in k1:
                self.assertFalse(torch.equal(v1, v2))
            else:
                self.assertTrue(torch.equal(v1, v2))
Example #5
    def test_resume_training(self):

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        data_args = GlueDataTrainingArguments(
            task_name="mrpc",
            data_dir="./tests/fixtures/tests_samples/MRPC",
            overwrite_cache=True)
        train_dataset = GlueDataset(data_args,
                                    tokenizer=tokenizer,
                                    mode="train")

        model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased")
        model.add_adapter("adapter")
        model.add_adapter("additional_adapter")
        model.set_active_adapters("adapter")

        training_args = TrainingArguments(
            output_dir="./examples",
            do_train=True,
            learning_rate=0.1,
            logging_steps=1,
            max_steps=1,
            save_steps=1,
            remove_unused_columns=False,
        )
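        # do_save_adapters / do_save_full_model below make checkpoints store only the
        # adapter weights, not the full model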
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            do_save_adapters=True,
            do_save_full_model=False,
        )

        trainer.train()
        # create second model that should resume the training of the first
        model_resume = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased")
        model_resume.add_adapter("adapter")
        model_resume.add_adapter("additional_adapter")
        model_resume.set_active_adapters("adapter")
        trainer_resume = Trainer(
            model=model_resume,
            args=TrainingArguments(do_train=True,
                                   max_steps=1,
                                   output_dir="./examples"),
            train_dataset=train_dataset,
        )
        trainer_resume.train(resume_from_checkpoint=True)

        self.assertEqual(model.config.adapters.adapters,
                         model_resume.config.adapters.adapters)

        for (k1, v1), (k2, v2) in zip(trainer.model.state_dict().items(),
                                      trainer_resume.model.state_dict().items()):
            self.assertEqual(k1, k2)
            if "adapter" in k1:
                self.assertTrue(torch.equal(v1, v2), k1)
Example #6
 def test_default_classification(self):
     MODEL_ID = "bert-base-cased-finetuned-mrpc"
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     data_args = GlueDataTrainingArguments(
         task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
     )
     dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
     data_collator = DefaultDataCollator()
     batch = data_collator.collate_batch(dataset.features)
     self.assertEqual(batch["labels"].dtype, torch.long)
Example #7
 def test_default_regression(self):
     MODEL_ID = "distilroberta-base"
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     data_args = GlueDataTrainingArguments(
         task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
     )
     dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
     data_collator = DefaultDataCollator()
     batch = data_collator.collate_batch(dataset.features)
     self.assertEqual(batch["labels"].dtype, torch.float)
Example #8
 def setUpClass(self):
     self.MODEL_ID = "albert-base-v2"
     self.data_args = DataTrainingArguments(
         task_name="mrpc",
         data_dir="./tests/fixtures/tests_samples/MRPC",
         overwrite_cache=True,
     )
     self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
     self.dataset = GlueDataset(self.data_args, self.tokenizer, mode="dev")
     self.config = AutoConfig.from_pretrained(
         self.MODEL_ID, num_labels=3, finetuning_task="mrpc")
     self.dataloader = DataLoader(self.dataset, batch_size=2, collate_fn=default_data_collator)
Example #9
def load_datasets(data_args, model_args):
    # set tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=None,
    )
    # Get datasets
    train_dataset = GlueDataset(data_args,
                                tokenizer=tokenizer,
                                cache_dir=model_args.cache_dir)
    eval_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="dev",
                               cache_dir=model_args.cache_dir)
    test_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="test",
                               cache_dir=model_args.cache_dir)

    return train_dataset, eval_dataset, test_dataset
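A minimal, hypothetical usage sketch for load_datasets; the HfArgumentParser wiring and the ModelArguments / DataTrainingArguments dataclasses are assumed to match the other examples on this page:

# hypothetical invocation (argument names assumed, not part of the original example)
parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
model_args, data_args = parser.parse_args_into_dataclasses()
train_dataset, eval_dataset, test_dataset = load_datasets(data_args, model_args)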
Example #10
    def test_load_task_adapter_from_hub(self):
        """This test checks if an adapter is loaded from the Hub correctly by evaluating it on some MRPC samples
        and comparing with the expected result.
        """
        for config in ["pfeiffer", "houlsby"]:
            with self.subTest(config=config):
                tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                model = BertForSequenceClassification.from_pretrained(
                    "bert-base-uncased")

                loading_info = {}
                adapter_name = model.load_adapter("sts/mrpc@ukp",
                                                  config=config,
                                                  version="1",
                                                  loading_info=loading_info)
                model.train_adapter(adapter_name)

                self.assertEqual(0, len(loading_info["missing_keys"]))
                self.assertEqual(0, len(loading_info["unexpected_keys"]))

                self.assertIn(adapter_name, model.config.adapters.adapters)
                self.assertNotIn(adapter_name,
                                 model.base_model.invertible_adapters)

                # check if config is valid
                expected_hash = get_adapter_config_hash(
                    AdapterConfig.load(config))
                real_hash = get_adapter_config_hash(
                    model.config.adapters.get(adapter_name))
                self.assertEqual(expected_hash, real_hash)

                # setup dataset
                data_args = GlueDataTrainingArguments(
                    task_name="mrpc",
                    data_dir="./tests/fixtures/tests_samples/MRPC",
                    overwrite_cache=True)
                eval_dataset = GlueDataset(data_args,
                                           tokenizer=tokenizer,
                                           mode="dev")
                training_args = TrainingArguments(output_dir="./examples",
                                                  no_cuda=True)

                # evaluate
                trainer = Trainer(
                    model=model,
                    args=training_args,
                    eval_dataset=eval_dataset,
                    compute_metrics=self._compute_glue_metrics("mrpc"),
                    adapter_names=["mrpc"],
                )
                result = trainer.evaluate()
                self.assertGreater(result["eval_acc"], 0.9)
Example #11
    def test_train_adapter_fusion(self):
        for model_name in self.model_names:
            with self.subTest(model_name=model_name):
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForSequenceClassification.from_pretrained(model_name)

                # load the adapters to be fused
                model.load_adapter("sts/mrpc@ukp", with_head=False)
                model.load_adapter("sts/qqp@ukp", with_head=False)
                model.load_adapter("sts/sts-b@ukp", with_head=False)

                self.assertIn("mrpc", model.config.adapters.adapters)
                self.assertIn("qqp", model.config.adapters.adapters)
                self.assertIn("sts-b", model.config.adapters.adapters)

                # setup fusion
                adapter_setup = [["mrpc", "qqp", "sts-b"]]
                model.add_fusion(adapter_setup[0])
                model.train_fusion(adapter_setup[0])
                model.set_active_adapters(adapter_setup)
                self.assertEqual(adapter_setup, model.active_adapters)

                # all weights of the adapters should be frozen (test for one)
                for k, v in filter_parameters(model, "text_task_adapters.mrpc").items():
                    self.assertFalse(v.requires_grad, k)
                # all weights of the fusion layer should be activated
                for k, v in filter_parameters(model, "adapter_fusion_layer").items():
                    self.assertTrue(v.requires_grad, k)
                # weights of the model should be frozen (check on some examples)
                for k, v in filter_parameters(model, "encoder.layer.0.attention").items():
                    self.assertFalse(v.requires_grad, k)

                state_dict_pre = copy.deepcopy(model.state_dict())

                # setup dataset
                data_args = GlueDataTrainingArguments(
                    task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
                )
                train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
                training_args = TrainingArguments(
                    output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=5, no_cuda=True
                )

                # train
                trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,)
                trainer.train()

                for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
                    if "adapter_fusion_layer" in k1 or "classifier" in k1:
                        self.assertFalse(torch.equal(v1, v2), k1)
                    else:
                        self.assertTrue(torch.equal(v1, v2), k1)
Example #12
    def test_trainer_eval_mrpc(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")

        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
        trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
        result = trainer.evaluate()
        self.assertLess(result["eval_loss"], 0.2)
Example #13
 def test_meta_dataset(self):
     data_args = DataTrainingArguments(
         task_name="mrpc",
         data_dir="./tests/fixtures/tests_samples/MRPC",
         overwrite_cache=True,
     )
     train_dataset = GlueDataset(data_args, tokenizer=self.tokenizer)
     meta_dataset = MetaDataset(train_dataset)
     self.assertEqual(len(meta_dataset[1000]), 2)
     self.assertEqual(meta_dataset[1000][0]["input_ids"].shape, torch.Size([128]))
     self.assertEqual(
         meta_dataset[1000][0]["attention_mask"].shape, torch.Size([128])
     )
     self.assertEqual(meta_dataset[1000][0]["labels"].item(), 0)
     self.assertEqual(meta_dataset[1000][1]["labels"].item(), 1)
Example #14
    def test_general(self):
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")

        model = AutoModelWithHeads.from_pretrained("bert-base-uncased")

        model.add_classification_head("task", num_labels=3)

        # add the adapters (only the "task" adapter will be trained)
        model.add_adapter("task")
        model.add_adapter("additional_adapter")

        model.train_adapter("task")
        self.assertEqual("task", model.active_head)
        self.assertEqual(Stack("task"), model.active_adapters)
        with TemporaryDirectory() as tempdir:
            training_args = TrainingArguments(
                output_dir=tempdir,
                do_train=True,
                learning_rate=0.1,
                logging_steps=1,
                max_steps=1,
                save_steps=1,
                remove_unused_columns=False,
            )
            trainer = AdapterTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
            )

            trainer.train()

            # Check that adapters are actually saved but the full model is not
            files_dir_checkpoint = [file_or_dir for file_or_dir in os.listdir(os.path.join(tempdir, "checkpoint-1"))]
            self.assertTrue("task" in files_dir_checkpoint)
            self.assertTrue("additional_adapter" in files_dir_checkpoint)
            # Check that full model weights are not stored
            self.assertFalse("pytorch_model.bin" in files_dir_checkpoint)

            # this should always be false in the adapter trainer
            self.assertFalse(trainer.args.remove_unused_columns)
            self.assertEqual("task", model.active_head)
            self.assertEqual(Stack("task"), model.active_adapters)
Example #15
def train(X_train, y_train, y_column_name, model_name=None):
    eval_dataset = y_train[y_column_name]

    model_args = ModelArguments(model_name_or_path="distilbert-base-cased", )
    global data_args
    data_args = DataTrainingArguments(task_name="mnli",
                                      data_dir="../../datasets/Newswire")
    num_labels = glue_tasks_num_labels[data_args.task_name]
    training_args = TrainingArguments(
        output_dir=model_name,
        overwrite_output_dir=True,
        do_train=True,
        do_eval=True,
        per_gpu_train_batch_size=32,
        per_gpu_eval_batch_size=128,
        num_train_epochs=1,
        logging_steps=500,
        logging_first_step=True,
        save_steps=1000,
        evaluate_during_training=True,
    )

    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
    )

    train_dataset = GlueDataset(data_args,
                                tokenizer=tokenizer,
                                limit_length=100_000)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
    trainer.train()
Example #16
    def test_cluster_indices(self):
        clustering_args = Clustering_Arguments(
            batch_size=32,
            num_clusters_elements=32,
            embedding_path=self.embedding_path,
            num_clusters=8,
            cluster_output_path=self.cluster_output_path,
        )
        cluster_indices = self.clustering_proc.get_cluster_indices_by_num(
            clustering_args.num_clusters_elements
        )
        self.assertTrue(len(cluster_indices) > 10000)

        # Testing with Pytorch Dataset
        data_args = DataTrainingArguments(
            task_name="MRPC", data_dir=self.data_dir, overwrite_cache=True
        )
        tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        train_dataset = GlueDataset(data_args, tokenizer)
        train_dataset = torch.utils.data.Subset(train_dataset, cluster_indices)
        self.assertEqual(len(train_dataset[0].input_ids), 128)
Example #17
def train(EXP: str, MODEL_NAME: str, TASK_NAME: str, N_LABELS: int, DELTA: float, WEIGHT_DECAY: float, DEVICE: str) -> float:
    EPOCHS         = 5
    BATCH_SIZE     = 8
    SAMPLES        = 10
    FREEZE         = True
    LOGS           = "logs"
    MAX_SEQ_LENGTH = 128
    LOADER_OPTIONS = { "num_workers": 6, "pin_memory": True }
    LR             = 2e-5
    ADAM_EPSILON   = 1e-8
    N_WARMUP_STEPS = 0
    MAX_GRAD_NORM  = 1
    DATA_DIR       = os.path.join("./dataset/glue/data", TASK_NAME)

    os.makedirs(LOGS, exist_ok=True)
    writer_path = os.path.join(LOGS, f"bayeformers_bert_glue.{EXP}")
    writer_suff = f".DELTA_{DELTA}.WEIGHT_DECAY_{WEIGHT_DECAY}"
    writer      = SummaryWriter(writer_path + writer_suff)
    
    o_model, tokenizer = setup_model(MODEL_NAME, TASK_NAME, N_LABELS)
    o_model            = o_model.to(DEVICE)

    glue          = GlueDataTrainingArguments(TASK_NAME, data_dir=DATA_DIR, max_seq_length=MAX_SEQ_LENGTH)
    train_dataset = GlueDataset(glue, tokenizer=tokenizer)
    test_dataset  = GlueDataset(glue, tokenizer=tokenizer, mode="dev")
    train_loader  = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate, **LOADER_OPTIONS)
    test_loader   = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate, **LOADER_OPTIONS)
    
    no_decay_names  = ["bias", "LayerNorm.weight"]
    decay           = [param for name, param in o_model.named_parameters() if not any(nd in name for nd in no_decay_names)]
    no_decay        = [param for name, param in o_model.named_parameters() if     any(nd in name for nd in no_decay_names)]
    params_decay    = { "params": decay,    "weight_decay": WEIGHT_DECAY }
    params_no_decay = { "params": no_decay, "weight_decay": 0.0 }
    parameters      = [params_decay, params_no_decay]

    criterion = nn.CrossEntropyLoss().to(DEVICE)
    optim     = AdamW(parameters, lr=LR, eps=ADAM_EPSILON)
    scheduler = get_linear_schedule_with_warmup(optim, N_WARMUP_STEPS, EPOCHS)

    report = Report()
    for epoch in tqdm(range(EPOCHS), desc="Epoch"):

        # ============================ TRAIN ======================================
        o_model.train()
        report.reset()
        
        pbar = tqdm(train_loader, desc="Train")
        for inputs in pbar:
            inputs = dic2cuda(inputs, DEVICE)
            labels = inputs["labels"]

            optim.zero_grad()
            logits = o_model(**inputs)[1]
            loss   = criterion(logits.view(-1, N_LABELS), labels.view(-1))
            acc    = (torch.argmax(logits, dim=1) == labels).float().sum()

            loss.backward()
            nn.utils.clip_grad_norm_(o_model.parameters(), MAX_GRAD_NORM)
            optim.step()

            report.total += loss.item()      / len(train_loader)
            report.acc   += acc.item() * 100 / len(train_dataset)

            pbar.set_postfix(total=report.total, acc=report.acc)

        scheduler.step()
        writer.add_scalar("train_nll", report.total, epoch)
        writer.add_scalar("train_acc", report.acc,   epoch)

        # ============================ TEST =======================================
        o_model.eval()
        report.reset()
        
        with torch.no_grad():
            pbar = tqdm(test_loader, desc="Test")
            for inputs in pbar:
                inputs = dic2cuda(inputs, DEVICE)
                labels = inputs["labels"]

                logits = o_model(**inputs)[1]
                loss   = criterion(logits.view(-1, N_LABELS), labels.view(-1))
                acc    = (torch.argmax(logits, dim=1) == labels).float().sum()

                report.total += loss.item()       / len(test_loader)
                report.acc   += acc.item() * 100  / len(test_dataset)

                pbar.set_postfix(total=report.total, acc=report.acc)

        writer.add_scalar("test_nll", report.total, epoch)
        writer.add_scalar("test_acc", report.acc,   epoch)

    # ============================ EVALUATION =====================================
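    # convert the fine-tuned deterministic model into its Bayesian counterpart
    # (bayeformers' to_bayesian) and evaluate it before the Bayesian fine-tuning below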
    b_model                  = to_bayesian(o_model, delta=DELTA, freeze=FREEZE)
    b_model                  = b_model.to(DEVICE)

    b_model.eval()
    report.reset()

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Bayesian Eval")
        for inputs in pbar:
            inputs = dic2cuda(inputs, DEVICE)
            labels = inputs["labels"]
            B      = inputs["input_ids"].size(0)

            samples = sample_bayesian(b_model, inputs, SAMPLES, B, N_LABELS, DEVICE)
            raw_logits, logits, log_prior, log_variational_posterior = samples

            nll     = criterion(logits, labels.view(-1))            
            loss    = (log_variational_posterior - log_prior) / len(test_loader) + nll
            acc     = (torch.argmax(logits, dim=1) == labels).float().sum()
            acc_std = np.std([(torch.argmax(logits, dim=1) == labels).float().sum().item() for logits in raw_logits])

            report.total                     += loss.item()                      / len(test_loader)
            report.nll                       += nll.item()                       / len(test_loader)
            report.log_prior                 += log_prior.item()                 / len(test_loader)
            report.log_variational_posterior += log_variational_posterior.item() / len(test_loader)
            report.acc                       += acc.item() * 100                 / len(test_dataset)
            report.acc_std                   += acc_std                          / len(test_loader)

            pbar.set_postfix(
                total=report.total,
                nll=report.nll,
                log_prior=report.log_prior,
                log_variational_posterior=report.log_variational_posterior,
                acc=report.acc,
                acc_std=report.acc_std,
            )

    writer.add_scalar("bayesian_eval_nll",     report.nll,     epoch)
    writer.add_scalar("bayesian_eval_acc",     report.acc,     epoch)
    writer.add_scalar("bayesian_eval_acc_std", report.acc_std, epoch)

    no_decay_names  = ["bias", "LayerNorm.weight"]
    decay           = [param for name, param in b_model.named_parameters() if not any(nd in name for nd in no_decay_names)]
    no_decay        = [param for name, param in b_model.named_parameters() if     any(nd in name for nd in no_decay_names)]
    params_decay    = { "params": decay,    "weight_decay": WEIGHT_DECAY }
    params_no_decay = { "params": no_decay, "weight_decay": 0.0 }
    parameters      = [params_decay, params_no_decay]

    criterion = nn.CrossEntropyLoss().to(DEVICE)
    optim     = AdamW(parameters, lr=LR, eps=ADAM_EPSILON)
    scheduler = get_linear_schedule_with_warmup(optim, N_WARMUP_STEPS, EPOCHS)

    for epoch in tqdm(range(EPOCHS), desc="Bayesian Epoch"):

        # ============================ TRAIN ======================================
        b_model.train()
        report.reset()
        
        pbar = tqdm(train_loader, desc="Bayesian Train")
        for inputs in pbar:
            inputs = dic2cuda(inputs, DEVICE)
            labels = inputs["labels"]
            B      = inputs["input_ids"].size(0)

            optim.zero_grad()
            samples = sample_bayesian(b_model, inputs, SAMPLES, B, N_LABELS, DEVICE)
            raw_logits, logits, log_prior, log_variational_posterior = samples

            nll     = criterion(logits, labels.view(-1))            
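            # Bayes-by-Backprop style objective: the KL term (log q - log p) is scaled by
            # 1/len(train_loader) so that, summed over all minibatches in an epoch, it is
            # counted once against the data NLL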
            loss    = (log_variational_posterior - log_prior) / len(train_loader) + nll
            acc     = (torch.argmax(logits, dim=1) == labels).float().sum()
            acc_std = np.std([(torch.argmax(logits, dim=1) == labels).float().sum().item() for logits in raw_logits])

            loss.backward()
            nn.utils.clip_grad_norm_(b_model.parameters(), MAX_GRAD_NORM)
            optim.step()

            report.total                     += loss.item()                      / len(train_loader)
            report.nll                       += nll.item()                       / len(train_loader)
            report.log_prior                 += log_prior.item()                 / len(train_loader)
            report.log_variational_posterior += log_variational_posterior.item() / len(train_loader)
            report.acc                       += acc.item() * 100                 / len(train_dataset)
            report.acc_std                   += acc_std                          / len(train_loader)

            pbar.set_postfix(
                total=report.total,
                nll=report.nll,
                log_prior=report.log_prior,
                log_variational_posterior=report.log_variational_posterior,
                acc=report.acc,
                acc_std=report.acc_std,
            )

        scheduler.step()
        writer.add_scalar("bayesian_train_nll",     report.nll,     epoch)
        writer.add_scalar("bayesian_train_acc",     report.acc,     epoch)
        writer.add_scalar("bayesian_train_acc_std", report.acc_std, epoch)

        # ============================ TEST =======================================
        b_model.eval()
        report.reset()
        
        with torch.no_grad():
            pbar = tqdm(test_loader, desc="Bayesian Test")
            for inputs in pbar:
                inputs = dic2cuda(inputs, DEVICE)
                labels = inputs["labels"]
                B      = inputs["input_ids"].size(0)

                samples = sample_bayesian(b_model, inputs, SAMPLES, B, N_LABELS, DEVICE)
                raw_logits, logits, log_prior, log_variational_posterior = samples

                nll     = criterion(logits, labels.view(-1))
                loss    = (log_variational_posterior - log_prior) / len(test_loader) + nll
                acc     = (torch.argmax(logits, dim=1) == labels).float().sum()
                acc_std = np.std([(torch.argmax(logits, dim=1) == labels).float().sum().item() for logits in raw_logits])

                report.total                     += loss.item()                      / len(test_loader)
                report.nll                       += nll.item()                       / len(test_loader)
                report.log_prior                 += log_prior.item()                 / len(test_loader)
                report.log_variational_posterior += log_variational_posterior.item() / len(test_loader)
                report.acc                       += acc.item() * 100                 / len(test_dataset)
                report.acc_std                   += acc_std                          / len(test_loader)

                pbar.set_postfix(
                    total=report.total,
                    nll=report.nll,
                    log_prior=report.log_prior,
                    log_variational_posterior=report.log_variational_posterior,
                    acc=report.acc,
                    acc_std=report.acc_std,
                )

        writer.add_scalar("bayesian_test_nll",     report.nll,     epoch)
        writer.add_scalar("bayesian_test_acc",     report.acc,     epoch)
        writer.add_scalar("bayesian_test_acc_std", report.acc_std, epoch)

    torch.save({
        "weight_decay": WEIGHT_DECAY,
        "delta"       : DELTA,
        "acc"         : report.acc,
        "acc_std"     : report.acc_std,
        "model"       : b_model.state_dict()
    }, f"{writer_path + writer_suff}.pth")

    return report.acc
Example #18
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    def convert_gate_to_mask(gates, num_of_heads=None):
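        # keep the num_of_heads largest gate values as 1s when a count is given;
        # otherwise threshold the gates at 0.5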
        if num_of_heads is not None:
            head_mask = torch.zeros_like(gates)
            current_heads_to_keep = gates.view(-1).sort(descending = True)[1]
            current_heads_to_keep = current_heads_to_keep[:num_of_heads]
            head_mask = head_mask.view(-1)
            head_mask[current_heads_to_keep] = 1.0
            head_mask = head_mask.view_as(gates)
        else:
            head_mask = (gates > 0.5).float()
        return head_mask

    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
    )
    if data_args.task_name == "mnli":
        data_args.task_name = "mnli-mm"
        eval_dataset = (
            GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
            if training_args.do_eval
            else None
        )
        data_args.task_name = "mnli"
    else:
        eval_dataset = (
            GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
            if training_args.do_eval
            else None
        )


    if data_args.task_name == "mnli":
        metric = "eval_mnli/acc"
    else:
        metric = "eval_acc"

    torch.manual_seed(42)
    model = BertForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
    )

    # Initialize our Trainer
    training_args.max_steps = -1
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(data_args.task_name),
    )

    start = time.time()
    trainer.evaluate(eval_dataset=eval_dataset)
    end = time.time()
    time_original = end - start
    total_original = 0
    for parameter in model.parameters():
        total_original += parameter.numel()
    print("Before pruning: time: {}, num: {}".format(time_original, total_original))

    for num_of_heads in [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
        total_pruned = []
        time_pruned = []
        for i in range(10):
            torch.manual_seed(i)
            config = AutoConfig.from_pretrained(
                model_args.config_name if model_args.config_name else model_args.model_name_or_path,
                num_labels=num_labels,
                finetuning_task=data_args.task_name,
                cache_dir=model_args.cache_dir,
            )
            model = BertForSequenceClassification.from_pretrained(
                model_args.model_name_or_path,
                config=config,
            )

            # Initialize our Trainer
            training_args.max_steps = -1
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                compute_metrics=build_compute_metrics_fn(data_args.task_name),
            )

            head_mask = torch.zeros(12, 12)
            indices = torch.randint(144, (num_of_heads,))
            head_mask.view(-1)[indices] = 1

            heads_to_prune = {}
            for layer, layer_mask in enumerate(head_mask):
                heads_to_prune[layer] = list((layer_mask == 0).nonzero().view(-1).numpy())
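            # model.prune_heads expects a {layer_index: [head indices to remove]} mapping;
            # heads whose mask entry is 0 are pruned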
            
            model.prune_heads(heads_to_prune)

            start = time.time()
            trainer.evaluate(eval_dataset=eval_dataset)
            end = time.time()
            time_pruned.append(end-start)
            total = 0
            for parameter in model.parameters():
                total += parameter.numel()
            total_pruned.append(total)
        speedup = (time_original - np.mean(time_pruned)) / time_original * 100
        shrinkage = (total_original - np.mean(total_pruned)) / total_original * 100
        print("After pruning (num of heads: {}): speedup: {}, shrinkage: {}".format(num_of_heads, speedup, shrinkage))
Example #19
def tune_transformer(num_samples=8, gpus_per_trial=0, smoke_test=False):
    data_dir_name = "./data" if not smoke_test else "./test_data"
    data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))
    if not os.path.exists(data_dir):
        os.mkdir(data_dir, 0o755)

    # Change these as needed.
    model_name = "bert-base-uncased" if not smoke_test \
        else "sshleifer/tiny-distilroberta-base"
    task_name = "rte"

    task_data_dir = os.path.join(data_dir, task_name.upper())

    num_labels = glue_tasks_num_labels[task_name]

    config = AutoConfig.from_pretrained(model_name,
                                        num_labels=num_labels,
                                        finetuning_task=task_name)

    # Download and cache tokenizer, model, and features
    print("Downloading and caching Tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Triggers tokenizer download to cache
    print("Downloading and caching pre-trained model")
    AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
    )

    def get_model():
        return AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=config,
        )

    # Download data.
    download_data(task_name, data_dir)

    data_args = GlueDataTrainingArguments(task_name=task_name,
                                          data_dir=task_data_dir)

    train_dataset = GlueDataset(data_args,
                                tokenizer=tokenizer,
                                mode="train",
                                cache_dir=task_data_dir)
    eval_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="dev",
                               cache_dir=task_data_dir)

    training_args = TrainingArguments(
        output_dir=".",
        learning_rate=1e-5,  # config
        do_train=True,
        do_eval=True,
        no_cuda=gpus_per_trial <= 0,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        num_train_epochs=2,  # config
        max_steps=-1,
        per_device_train_batch_size=16,  # config
        per_device_eval_batch_size=16,  # config
        warmup_steps=0,
        weight_decay=0.1,  # config
        logging_dir="./logs",
        skip_memory_metrics=True,
        report_to="none")

    trainer = Trainer(model_init=get_model,
                      args=training_args,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset,
                      compute_metrics=build_compute_metrics_fn(task_name))

    tune_config = {
        "per_device_train_batch_size": 32,
        "per_device_eval_batch_size": 32,
        "num_train_epochs": tune.choice([2, 3, 4, 5]),
        "max_steps": 1 if smoke_test else -1,  # Used for smoke test.
    }
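    # PBT perturbs the keys listed in hyperparam_mutations below; the remaining
    # tune_config entries stay fixed for every trial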

    scheduler = PopulationBasedTraining(time_attr="training_iteration",
                                        metric="eval_acc",
                                        mode="max",
                                        perturbation_interval=1,
                                        hyperparam_mutations={
                                            "weight_decay":
                                            tune.uniform(0.0, 0.3),
                                            "learning_rate":
                                            tune.uniform(1e-5, 5e-5),
                                            "per_device_train_batch_size":
                                            [16, 32, 64],
                                        })

    reporter = CLIReporter(parameter_columns={
        "weight_decay": "w_decay",
        "learning_rate": "lr",
        "per_device_train_batch_size": "train_bs/gpu",
        "num_train_epochs": "num_epochs"
    },
                           metric_columns=[
                               "eval_acc", "eval_loss", "epoch",
                               "training_iteration"
                           ])

    trainer.hyperparameter_search(
        hp_space=lambda _: tune_config,
        backend="ray",
        n_trials=num_samples,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        scheduler=scheduler,
        keep_checkpoints_num=1,
        checkpoint_score_attr="training_iteration",
        stop={"training_iteration": 1} if smoke_test else None,
        progress_reporter=reporter,
        local_dir="~/ray_results/",
        name="tune_transformer_pbt",
        log_to_file=True)
Example #20
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (training_args.do_train and os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir)

    # Setup logging
    logging.basicConfig(
        format="%(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
        # filename=f'{training_args.output_dir}/log',
        # filemode='w',
    )
    # logger.addHandler(logging.StreamHandler(sys.stdout))
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(filename=f'{training_args.output_dir}/log', mode='w' if training_args.do_train else 'a'))
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    mnli_config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=3,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )

    mnli_model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=mnli_config,
        cache_dir=model_args.cache_dir,
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )
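    # warm-start from the MNLI-finetuned weights, dropping its classifier head since the
    # target task may use a different number of labels (hence strict=False below)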

    pretrained_state_dict = {k: v for k, v in mnli_model.state_dict().items() if k != 'classifier.out_proj.bias' and k != 'classifier.out_proj.weight'}
    model.load_state_dict(pretrained_state_dict, strict=False)

    # Get datasets
    train_dataset_class = GlueDataset
    eval_dataset_class = GlueDataset
    if training_args.do_aug and training_args.aug_type:
        if training_args.aug_type in {'back_trans', 'cbert'}:
            if data_args.train_aug_file:
                train_dataset_class = GlueAugDataset
                data_args.aug_type = training_args.aug_type
            if data_args.dev_aug_file:
                eval_dataset_class = GlueAugDataset
                data_args.aug_type = training_args.aug_type
    train_dataset = train_dataset_class(data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = (
        eval_dataset_class(data_args, tokenizer=tokenizer, evaluate=True)
        if training_args.do_eval or training_args.do_eval_all
        else None
    )

    def compute_metrics(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

    if training_args.do_debug:
        eval_dataset = eval_dataset[:100]

    # training_args.do_aug = model_args.do_aug
    # training_args.aug_type = data_args.aug_type
    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None,
        )
        # if not training_args.evaluate_during_training:
            # trainer.save_model()
            # For convenience, we also re-save the tokenizer to the same directory,
            # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    elif training_args.do_eval:
        # Evaluation
        results = {}
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))

        for eval_dataset in eval_datasets:
            eval_output = trainer.evaluate(eval_dataset=eval_dataset)
            results[f'{eval_dataset.args.task_name}_acc'] = eval_output['eval_acc']
            results[f'{eval_dataset.args.task_name}_loss'] = eval_output['eval_loss']

        return results

    elif training_args.do_eval_all:
        results = []
        logger.info('*** Evaluate all checkpoints ***')

        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))

        all_checkpoints = glob.glob(f'{training_args.output_dir}/checkpoint-*')
        all_checkpoints = sorted(all_checkpoints, key=lambda x: int(x.split('-')[-1]))

        for checkpoint in all_checkpoints:
            step = int(checkpoint.split('-')[-1])
            model.load_pretrained(checkpoint)
            model.to(training_args.device)
            step_result = [step]
            for eval_dataset in eval_datasets:
                trainer.global_step = step
                result = trainer.evaluate(eval_dataset=eval_dataset)
                # result['step'] = step
                step_result += [result['eval_acc'], result['eval_loss']]
            results.append(step_result)

        header = ['step']
        for eval_dataset in eval_datasets:
            header += [f'{eval_dataset.args.task_name}_acc', f'{eval_dataset.args.task_name}_loss']

        logger.info("***** Eval results *****")
        report_results(header, results, axis=1)
Example #21
    def test_reloading_prediction_head(self):
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")

        model = AutoModelWithHeads.from_pretrained("bert-base-uncased")

        model.add_classification_head("adapter", num_labels=3)
        model.add_classification_head("dummy", num_labels=2)

        # add the adapters to be fused
        model.add_adapter("adapter")
        model.add_adapter("additional_adapter")

        # setup fusion
        adapter_setup = Fuse("adapter", "additional_adapter")
        model.add_adapter_fusion(adapter_setup)
        model.train_adapter_fusion(adapter_setup)
        model.set_active_adapters(adapter_setup)
        self.assertEqual(adapter_setup, model.active_adapters)
        self.assertEqual("dummy", model.active_head)
        with TemporaryDirectory() as tempdir:
            training_args = TrainingArguments(
                output_dir=tempdir,
                do_train=True,
                learning_rate=0.1,
                logging_steps=1,
                max_steps=1,
                save_steps=1,
                remove_unused_columns=False,
            )
            trainer = AdapterTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
            )

            trainer.train()
            # create second model that should resume the training of the first
            model_resume = AutoModelWithHeads.from_pretrained("bert-base-uncased")

            model_resume.add_classification_head("adapter", num_labels=3)
            model_resume.add_classification_head("dummy", num_labels=2)
            model_resume.add_adapter("adapter")
            model_resume.add_adapter("additional_adapter")
            # setup fusion
            adapter_setup = Fuse("adapter", "additional_adapter")
            model_resume.add_adapter_fusion(adapter_setup)
            model_resume.train_adapter_fusion(adapter_setup)
            model_resume.set_active_adapters(adapter_setup)
            trainer_resume = AdapterTrainer(
                model=model_resume,
                args=TrainingArguments(do_train=True, max_steps=1, output_dir=tempdir),
                train_dataset=train_dataset,
            )
            trainer_resume.train(resume_from_checkpoint=True)

            self.assertEqual("dummy", model.active_head)
            self.assertEqual(model.config.adapters.adapters, model_resume.config.adapters.adapters)

            for ((k1, v1), (k2, v2)) in zip(
                trainer.model.state_dict().items(), trainer_resume.model.state_dict().items()
            ):
                self.assertEqual(k1, k2)
                if "adapter" in k1 or "dummy" in k1:
                    self.assertTrue(torch.equal(v1, v2), k1)
Example #22
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = BertConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = BertTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = BertForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        dropoutnd_minheaddim=model_args.dropoutnd_minheaddim,
        dropoutnd_maxrate=model_args.dropoutnd_maxrate,
    )

    if model_args.vanilla is True:
        print("Using vanilla weights.")
        model.apply(model._init_weights)
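        # Note: model.apply() runs _init_weights on every submodule, so the
        # pretrained weights are discarded and training starts from a randomly
        # re-initialized ("vanilla") model.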

    # Get datasets
    train_dataset = GlueDataset(
        data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = GlueDataset(
        data_args, tokenizer=tokenizer,
        evaluate=True) if training_args.do_eval else None

    def compute_metrics(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

    # Initialize our Trainer
    trainer_args = SomeArgs()
    for k in training_args.__dict__:
        setattr(trainer_args, k, getattr(training_args, k))
    for k in model_args.__dict__:
        if hasattr(trainer_args, k):
            raise Exception("trainer_args already contains this argument")
        setattr(trainer_args, k, getattr(model_args, k))
    for k in data_args.__dict__:
        if hasattr(trainer_args, k):
            raise Exception("trainer_args already contains this argument")
        setattr(trainer_args, k, getattr(data_args, k))
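
    # trainer_args now mirrors every field of the training, model and data
    # dataclasses in one flat object; below it is passed as wandb_args so the
    # full configuration is logged to W&B together.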

    trainer = Trainer(
        model=model,
        wandb_name="glue_nd",
        wandb_args=trainer_args,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path if os.path.
                      isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # TODO: evaluate multiple times, for different budgets
    mindim = config.hidden_size / config.num_attention_heads * model_args.dropoutnd_minheaddim
    steps = 10
    budgets = np.arange(steps) / (steps - 1) * (config.hidden_size -
                                                mindim) + mindim
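    # budgets is an evenly spaced grid from mindim up to config.hidden_size
    # (inclusive), i.e. the same values as np.linspace(mindim, config.hidden_size, steps).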
    for budget in budgets:
        dropoutnd = budget / config.hidden_size
        model.test_dropout_nd_rate = dropoutnd
        print("dropout nd rate of bert: ", model.bert.get_dropout_nd_rate())
        # Evaluation
        results = {}
        if training_args.do_eval and training_args.local_rank in [-1, 0]:
            logger.info(f"*** Evaluate with dropout-ND {dropoutnd}***")

            # Loop to handle MNLI double evaluation (matched, mis-matched)
            eval_datasets = [eval_dataset]
            if data_args.task_name == "mnli":
                mnli_mm_data_args = dataclasses.replace(data_args,
                                                        task_name="mnli-mm")
                eval_datasets.append(
                    GlueDataset(mnli_mm_data_args,
                                tokenizer=tokenizer,
                                evaluate=True))

            for eval_dataset in eval_datasets:
                result = trainer.evaluate(eval_dataset=eval_dataset)

                output_eval_file = os.path.join(
                    training_args.output_dir,
                    f"eval_results_{eval_dataset.args.task_name}.txt")
                with open(output_eval_file, "w") as writer:
                    logger.info("***** Eval results {} *****".format(
                        eval_dataset.args.task_name))
                    for key, value in result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))

                results.update(result)

    return results


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    def convert_gate_to_mask(gates, num_of_heads=None):
        if num_of_heads is not None:
            head_mask = torch.zeros_like(gates)
            current_heads_to_keep = gates.view(-1).sort(descending=True)[1]
            current_heads_to_keep = current_heads_to_keep[:num_of_heads]
            head_mask = head_mask.view(-1)
            head_mask[current_heads_to_keep] = 1.0
            head_mask = head_mask.view_as(gates)
        else:
            head_mask = (gates > 0.5).float()
        return head_mask
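
    # For intuition: with gates = torch.tensor([[0.9, 0.1], [0.3, 0.7]]),
    # convert_gate_to_mask(gates, num_of_heads=2) keeps the two largest gates and
    # returns [[1., 0.], [0., 1.]]; convert_gate_to_mask(gates) instead thresholds
    # at 0.5, which here yields the same mask.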

    def build_compute_metrics_fn(
            task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (GlueDataset(
        data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir)
                     if training_args.do_train else None)
    # train_dataset = Subset(train_dataset, list(range(int(0.1 * len(train_dataset)))))
    if data_args.task_name == "mnli":
        data_args.task_name = "mnli-mm"
        eval_dataset = (GlueDataset(data_args,
                                    tokenizer=tokenizer,
                                    mode="dev",
                                    cache_dir=model_args.cache_dir)
                        if training_args.do_eval else None)
        data_args.task_name = "mnli"
    else:
        eval_dataset = (GlueDataset(data_args,
                                    tokenizer=tokenizer,
                                    mode="dev",
                                    cache_dir=model_args.cache_dir)
                        if training_args.do_eval else None)

    if data_args.task_name == "mnli":
        metric = "eval_mnli/acc"
    else:
        metric = "eval_acc"

    annealing = True
    reducing_heads = False
    for temperature in [1e-8]:
        for num_of_heads in [8]:
            for cooldown_steps in [25000]:
                for starting_temperature in [1000]:
                    for starting_num_of_heads in [144]:
                        for lr in [0.5]:
                            logger.info(
                                "cooldown_steps: {}, starting_temperature: {}, starting_num_of_heads: {}, learning_rate: {}," \
                                " temperature: {}".format(
                                    cooldown_steps if annealing or reducing_heads else "N.A.",
                                    starting_temperature if annealing else "N.A.",
                                    starting_num_of_heads if reducing_heads else "N.A.",
                                    lr,
                                    temperature,
                            ))
                            torch.manual_seed(42)
                            model = BertForSequenceClassificationConcrete.from_pretrained(
                                model_args.model_name_or_path,
                                config=config,
                            )

                            # for n, p in model.named_parameters():
                            #     if n != "w":
                            #         p.requires_grad = False

                            optimizer_grouped_parameters = [
                                {
                                    "params": [
                                        p for n, p in model.named_parameters()
                                        if n != "w"
                                    ],
                                    "lr":
                                    training_args.learning_rate,
                                },
                                {
                                    "params": [
                                        p for n, p in model.named_parameters()
                                        if n == "w"
                                    ],
                                    "lr":
                                    lr,
                                },
                            ]
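                            # Two parameter groups: the concrete-gate parameter
                            # "w" is trained with the swept learning rate `lr`,
                            # while every other parameter keeps
                            # training_args.learning_rate.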
                            optimizer = AdamW(
                                optimizer_grouped_parameters,
                                betas=(0.9, 0.999),
                                eps=1e-8,
                            )

                            # Initialize our Trainer
                            training_args.max_steps = -1
                            trainer = DropoutTrainer(
                                model=model,
                                args=training_args,
                                train_dataset=train_dataset,
                                eval_dataset=eval_dataset,
                                compute_metrics=build_compute_metrics_fn(
                                    data_args.task_name),
                                num_of_heads=num_of_heads,
                                reducing_heads=reducing_heads,
                                temperature=temperature,
                                cooldown_steps=cooldown_steps,
                                annealing=annealing,
                                starting_temperature=starting_temperature,
                                starting_num_of_heads=starting_num_of_heads,
                                optimizers=(optimizer, None),
                                intermediate_masks=True,
                                # ste=True,
                            )

                            # Training
                            trainer.train()
                            trainer.save_model()
                            # score = trainer.evaluate(eval_dataset=eval_dataset)[metric]
                            # print_2d_tensor(model.get_w())
                            # logger.info("temperature: {}, num of heads: {}, accuracy: {}".format(temperature, num_of_heads, score * 100))

                            model._apply_dropout = False
                            head_mask = convert_gate_to_mask(
                                model.get_w(), num_of_heads)
                            # torch.save(head_mask, os.path.join(training_args.output_dir, "mask" + str(num_of_heads) + ".pt"))
                            # print_2d_tensor(head_mask)
                            model.apply_masks(head_mask)
                            score = trainer.evaluate(
                                eval_dataset=eval_dataset)[metric]
                            sparsity = 100 - head_mask.sum() / head_mask.numel(
                            ) * 100
                            logger.info(
                                "Masking: current score: %f, remaining heads %d (%.1f percents)",
                                score,
                                head_mask.sum(),
                                100 - sparsity,
                            )
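                            # e.g. keeping 8 of BERT-base's 144 attention heads
                            # gives sparsity = 100 - 8 / 144 * 100 ≈ 94.4, so the
                            # log reports 8 remaining heads (≈5.6%).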
Example #24
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(glue_processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help=
        "Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help=
        "Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset",
        type=int,
        default=-1,
        help="If > 0: limit the data to a subset of data_subset instances.")
    parser.add_argument("--overwrite_output_dir",
                        action="store_true",
                        help="Whether to overwrite data in output directory")
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets")

    parser.add_argument("--exact_pruning",
                        action="store_true",
                        help="Compute head importance for each step")
    parser.add_argument("--dont_normalize_importance_by_layer",
                        action="store_true",
                        help="Don't normalize importance score by layers")
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )
    parser.add_argument("--dont_use_abs",
                        action="store_true",
                        help="Don't apply abs on first order derivative")
    parser.add_argument("--use_squared",
                        action="store_true",
                        help="Use squared derivative as quality")
    parser.add_argument("--use_second",
                        action="store_true",
                        help="Use second order derivative as quality")
    parser.add_argument(
        "--use_contexts",
        action="store_true",
        help="Use context vectors instead of attentions weights")

    parser.add_argument(
        "--masking_amount",
        default=0.1,
        type=float,
        help="Amount to heads to masking at each masking step.")
    parser.add_argument("--metric_name",
                        default="acc",
                        type=str,
                        help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip",
                        type=str,
                        default="",
                        help="Can be used for distant debugging.")
    parser.add_argument("--server_port",
                        type=str,
                        default="",
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available()
                                   and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(
            backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(
        args.device, args.n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    set_seed(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in glue_processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = glue_processors[args.task_name]()
    args.output_mode = glue_output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir,
    )
    model = BertForSequenceClassificationConcrete.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )
    model.to(args.device)
    model.eval()
    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    train_dataset = GlueDataset(args, tokenizer=tokenizer)
    train_sampler = SequentialSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=default_data_collator)

    if args.task_name == "mnli":
        args.task_name = "mnli-mm"
        eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
        args.task_name = "mnli"
        val_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
        args.metric_name = "mnli/acc"
    else:
        eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
        # Outside MNLI there is no mismatched split, so reuse the dev set as the
        # validation split to keep val_dataloader below well-defined.
        val_dataset = eval_dataset
    if args.data_subset > 0:
        eval_dataset = Subset(
            eval_dataset, list(range(min(args.data_subset,
                                         len(eval_dataset)))))
    eval_sampler = SequentialSampler(
        eval_dataset) if args.local_rank == -1 else DistributedSampler(
            eval_dataset)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size,
                                 collate_fn=default_data_collator)
    val_sampler = SequentialSampler(
        val_dataset) if args.local_rank == -1 else DistributedSampler(
            val_dataset)
    val_dataloader = DataLoader(val_dataset,
                                sampler=val_sampler,
                                batch_size=args.batch_size,
                                collate_fn=default_data_collator)

    # p_value = test(args, model, train_dataloader, eval_dataset)
    # logger.info("p_value is: %f", p_value)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    # head_importance = compute_heads_importance(args, model, train_dataloader)
    # head_importance = torch.Tensor(np.load(os.path.join(args.output_dir, "head_importance.npy"))).to(args.device)
    # args.exact_pruning = True
    # args.dont_normalize_importance_by_layer = True
    # args.use_second = True
    # scores, sparsities, all_head_masks = mask_heads(
    #     args, model, train_dataloader, eval_dataloader
    # )
    # logger.info("Area under curve: %.2f", auc(sparsities, scores))

    # scores, sparsities, all_head_masks = unmask_heads(
    #     args, model, train_dataloader, eval_dataloader
    # )
    # logger.info("Area under curve: %.2f", auc(sparsities, scores))

    for k in [2]:
        score, sparisity, head_mask = gibbs_sampling(
            args,
            model,
            train_dataloader,
            eval_dataloader,
            val_dataloader=val_dataloader,
            early_stop_step=12,
            K=k,
            n_groups=1)
Example #25
    def test_resume_training_with_fusion(self):
        def encode_batch(batch):
            """Encodes a batch of input data using the model tokenizer."""
            return tokenizer(batch["sentence1"],
                             batch["sentence2"],
                             max_length=80,
                             truncation=True,
                             padding="max_length")

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        data_args = GlueDataTrainingArguments(
            task_name="mrpc",
            data_dir="./tests/fixtures/tests_samples/MRPC",
            overwrite_cache=True)
        train_dataset = GlueDataset(data_args,
                                    tokenizer=tokenizer,
                                    mode="train")

        model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased")
        model.add_adapter("adapter")
        model.add_adapter("additional_adapter")
        model.add_fusion(Fuse("adapter", "additional_adapter"))
        model.set_active_adapters(Fuse("adapter", "additional_adapter"))

        training_args = TrainingArguments(
            output_dir="./examples",
            do_train=True,
            learning_rate=0.1,
            logging_steps=1,
            max_steps=1,
            save_steps=1,
            remove_unused_columns=False,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            do_save_adapters=True,
            do_save_full_model=False,
            do_save_adapter_fusion=True,
        )

        trainer.train()
        model_resume = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased")
        model_resume.add_adapter("adapter")
        model_resume.add_adapter("additional_adapter")
        model_resume.add_fusion(Fuse("adapter", "additional_adapter"))
        model_resume.set_active_adapters(Fuse("adapter", "additional_adapter"))
        trainer_resume = Trainer(
            model=model_resume,
            args=TrainingArguments(do_train=True,
                                   max_steps=1,
                                   output_dir="./examples"),
            train_dataset=train_dataset,
        )
        trainer_resume.train(resume_from_checkpoint=True)

        self.assertEqual(model.config.adapters.adapters,
                         model_resume.config.adapters.adapters)

        for ((k1, v1), (k2,
                        v2)) in zip(trainer.model.state_dict().items(),
                                    trainer_resume.model.state_dict().items()):
            self.assertEqual(k1, k2)
            if "adapter" in k1:
                self.assertTrue(torch.equal(v1, v2), k1)
Example #26
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (GlueDataset(
        data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir)
                     if training_args.do_train else None)
    eval_dataset = (GlueDataset(data_args,
                                tokenizer=tokenizer,
                                mode="dev",
                                cache_dir=model_args.cache_dir)
                    if training_args.do_eval else None)
    test_dataset = (GlueDataset(data_args,
                                tokenizer=tokenizer,
                                mode="test",
                                cache_dir=model_args.cache_dir)
                    if training_args.do_predict else None)

    def build_compute_metrics_fn(
            task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(data_args.task_name),
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path if os.path.
                      isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args,
                                                    task_name="mnli-mm")
            eval_datasets.append(
                GlueDataset(mnli_mm_data_args,
                            tokenizer=tokenizer,
                            mode="dev",
                            cache_dir=model_args.cache_dir))

        for eval_dataset in eval_datasets:
            trainer.compute_metrics = build_compute_metrics_fn(
                eval_dataset.args.task_name)
            eval_result = trainer.evaluate(eval_dataset=eval_dataset)

            output_eval_file = os.path.join(
                training_args.output_dir,
                f"eval_results_{eval_dataset.args.task_name}.txt")
            if trainer.is_world_master():
                with open(output_eval_file, "w") as writer:
                    logger.info("***** Eval results {} *****".format(
                        eval_dataset.args.task_name))
                    for key, value in eval_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))

            eval_results.update(eval_result)

    if training_args.do_predict:
        logging.info("*** Test ***")
        test_datasets = [test_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args,
                                                    task_name="mnli-mm")
            test_datasets.append(
                GlueDataset(mnli_mm_data_args,
                            tokenizer=tokenizer,
                            mode="test",
                            cache_dir=model_args.cache_dir))

        for test_dataset in test_datasets:
            predictions = trainer.predict(
                test_dataset=test_dataset).predictions
            if output_mode == "classification":
                predictions = np.argmax(predictions, axis=1)

            output_test_file = os.path.join(
                training_args.output_dir,
                f"test_results_{test_dataset.args.task_name}.txt")
            if trainer.is_world_master():
                with open(output_test_file, "w") as writer:
                    logger.info("***** Test results {} *****".format(
                        test_dataset.args.task_name))
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if output_mode == "regression":
                            writer.write("%d\t%3.3f\n" % (index, item))
                        else:
                            item = test_dataset.get_labels()[item]
                            writer.write("%d\t%s\n" % (index, item))
    return eval_results
Example #27
device_ids = list(range(torch.cuda.device_count()))
print(f"GPU list: {device_ids}")

print(json.dumps([model_config, pretraining_config], indent=4))

########################### Loading Datasets ###########################

tokenizer = utils.get_tokenizer(model_config["max_seq_len"])
model_config["vocab_size"] = len(tokenizer.get_vocab())

data_args = GlueDataTrainingArguments(
    task_name=args.task,
    data_dir=os.path.join(glue_dataset_folder, args.task),
    max_seq_length=model_config["max_seq_len"],
    overwrite_cache=True)
train_dataset = GlueDataset(data_args, tokenizer=tokenizer)
data_loader = DataLoader(train_dataset,
                         batch_size=args.batch_size,
                         shuffle=True,
                         collate_fn=default_data_collator)
num_steps_per_epoch = len(data_loader)
print(f"num_steps_per_epoch: {num_steps_per_epoch}", flush=True)

dev_datasets = {"dev": GlueDataset(data_args, tokenizer=tokenizer, mode="dev")}
if args.task.lower() == "mnli":
    data_args = GlueDataTrainingArguments(
        task_name="mnli-mm",
        data_dir=os.path.join(glue_dataset_folder, args.task),
        max_seq_length=model_config["max_seq_len"],
        overwrite_cache=True)
    dev_datasets["dev-mm"] = GlueDataset(data_args,
Example #28
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    set_seed(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in glue_processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = glue_processors[args.task_name]()
    args.output_mode = glue_output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
    if args.data_subset > 0:
        eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch
    )

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)

    def run_glue(self, model_name, task_name, fp16):
        model_args = ModelArguments(model_name_or_path=model_name,
                                    cache_dir=self.cache_dir)
        data_args = GlueDataTrainingArguments(
            task_name=task_name,
            data_dir=self.data_dir + "/" + task_name,
            max_seq_length=self.max_seq_length)

        training_args = TrainingArguments(
            output_dir=self.output_dir + "/" + task_name,
            do_train=True,
            do_eval=True,
            per_gpu_train_batch_size=self.train_batch_size,
            learning_rate=self.learning_rate,
            num_train_epochs=self.num_train_epochs,
            local_rank=self.local_rank,
            overwrite_output_dir=self.overwrite_output_dir,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            fp16=fp16,
            logging_steps=self.logging_steps)

        # Setup logging
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO
            if training_args.local_rank in [-1, 0] else logging.WARN,
        )
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            bool(training_args.local_rank != -1),
            training_args.fp16,
        )
        logger.info("Training/evaluation parameters %s", training_args)

        set_seed(training_args.seed)
        onnxruntime.set_seed(training_args.seed)

        try:
            num_labels = glue_tasks_num_labels[data_args.task_name]
            output_mode = glue_output_modes[data_args.task_name]
        except KeyError:
            raise ValueError("Task not found: %s" % (data_args.task_name))

        config = AutoConfig.from_pretrained(
            model_args.config_name
            if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
            cache_dir=model_args.cache_dir,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name
            if model_args.tokenizer_name else model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )

        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )

        train_dataset = (GlueDataset(data_args, tokenizer=tokenizer)
                         if training_args.do_train else None)

        eval_dataset = (GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
                        if training_args.do_eval else None)

        def compute_metrics(p: EvalPrediction) -> Dict:
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(data_args.task_name, preds,
                                        p.label_ids)

        model_desc = self.model_to_desc(model_name, model)
        # Initialize the ORTTrainer within ORTTransformerTrainer
        trainer = ORTTransformerTrainer(
            model=model,
            model_desc=model_desc,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )

        # Training
        if training_args.do_train:
            trainer.train()
            trainer.save_model()

        # Evaluation
        results = {}
        if training_args.do_eval and training_args.local_rank in [-1, 0]:
            logger.info("*** Evaluate ***")

            result = trainer.evaluate()

            logger.info("***** Eval results {} *****".format(
                data_args.task_name))
            for key, value in result.items():
                logger.info("  %s = %s", key, value)

            results.update(result)

        return results


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Comment out the wandb lines below if you don't have a wandb login.
    wandb.init(project="mixing-heads-finetuning")
    wandb.config.update(model_args)
    wandb.config.update(data_args)
    wandb.config.update(training_args)

    # HERE MODIFY THE CONFIG PATHS ...
    def extract_last(path):
        return [s for s in path.split("/") if s][-1]
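
    # e.g. extract_last("checkpoints/bert-base-uncased/") == "bert-base-uncased"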

    restricted_prefix = ""
    if model_args.restricted_attention:
        if model_args.context_attention_only == 1:
            restricted_prefix = "context_only-"
        elif model_args.context_attention_only == 0:
            restricted_prefix = "content_only-"
        else:
            raise ValueError("Should set context_attention_only to 0 or 1")

    output_model_name = (
        (model_args.model_output_prefix or "") +
        ("finetuned-" if training_args.do_train else "") +
        ("mix{}-".format(model_args.mix_size) if model_args.mix_size else "") +
        restricted_prefix + extract_last(model_args.model_name_or_path))
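    # e.g. with do_train=True, mix_size=384, no restricted attention and
    # model_name_or_path="bert-base-uncased" (illustrative values), this becomes
    # "finetuned-mix384-bert-base-uncased".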

    training_args.output_dir = os.path.join(
        training_args.output_dir,
        output_model_name,
        data_args.task_name,
        str(model_args.repeat_id),
    )

    data_args.data_dir = os.path.join(data_args.data_dir, data_args.task_name)

    if training_args.num_train_epochs == 3.0 and data_args.task_name.lower(
    ) in [
            "sst-2",
            "rte",
    ]:
        training_args.num_train_epochs = 10.0
        print("OVERIDE NUMBER OF EPOCH FOR TASK {} TO {}".format(
            data_args.task_name, training_args.num_train_epochs))

    if os.path.exists(model_args.model_name_or_path):
        model_args.model_name_or_path = os.path.join(
            model_args.model_name_or_path,
            data_args.task_name,
            str(model_args.repeat_id),
        )

    # DONE MODIFYING THE CONFIG

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(model_args.repeat_id if model_args.
             repeat_id is not None else training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = GlueDataset(
        data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = (GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
                    if training_args.do_eval else None)
    test_dataset = (GlueDataset(data_args, tokenizer=tokenizer, mode="test")
                    if training_args.do_predict else None)

    def compute_metrics(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

    if model_args.restricted_attention and model_args.mix_heads:
        raise ValueError(
            "Context/content attention and mix heads not implemented together correctly"
        )

    if model_args.restricted_attention:
        if hasattr(model, "bert"):
            layers = model.bert.encoder.layer
        elif hasattr(model, "electra"):
            layers = model.electra.encoder.layer
        else:
            raise Exception(
                'Does not support transforming model "{}" to mixed self-attention.'
                .format(type(model)))

        print("Make {}-only self-attention layers...".format(
            "context" if model_args.context_attention_only ==
            1 else "content"))
        for i in tqdm.trange(len(layers)):
            # set b_K = 0
            layers[i].attention.self.key.bias.requires_grad = False
            layers[i].attention.self.key.bias.zero_()

            if model_args.context_attention_only == 1:
                # set b_Q = 0
                layers[i].attention.self.query.bias.requires_grad = False
                layers[i].attention.self.query.bias.zero_()
            else:  # content attention only
                # set W_Q = 0
                layers[i].attention.self.query.weight.requires_grad = False
                layers[i].attention.self.query.weight.zero_()

    if torch.cuda.is_available():
        model = model.to("cuda:0")

    if model_args.mix_heads:
        start = time.time()

        adapter = BERTCollaborativeAdapter
        if "albert" in model_args.model_name_or_path.lower():
            adapter = ALBERTCollaborativeAdapter
        if "distilbert" in model_args.model_name_or_path.lower():
            adapter = DistilBERTCollaborativeAdapter
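        # swap_to_collaborative re-parameterizes the pretrained heads as collaborative
        # heads sharing a query/key space of dimension mix_size, presumably recovered
        # from the original projection weights by a decomposition run to tolerance
        # mix_decomposition_tol (hence the timing below).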

        swap_to_collaborative(
            model,
            adapter,
            dim_shared_query_key=model_args.mix_size,
            tol=model_args.mix_decomposition_tol,
        )

        elapsed = time.time() - start
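        # Record how long the head re-parameterization took in the W&B run summary.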
        wandb.run.summary["decomposition_time"] = elapsed

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path
            if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args,
                                                    task_name="mnli-mm")
            eval_datasets.append(
                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer,
                            mode="dev"))

        for eval_dataset in eval_datasets:
            eval_result = trainer.evaluate(eval_dataset=eval_dataset)

            output_eval_file = os.path.join(
                training_args.output_dir,
                f"eval_results_{eval_dataset.args.task_name}.txt")
            if trainer.is_world_master():
                with open(output_eval_file, "w") as writer:
                    logger.info("***** Eval results {} *****".format(
                        eval_dataset.args.task_name))
                    for key, value in eval_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))

            eval_results.update(eval_result)

    if training_args.do_predict:
        logging.info("*** Test ***")
        test_datasets = [test_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args,
                                                    task_name="mnli-mm")
            test_datasets.append(
                GlueDataset(mnli_mm_data_args,
                            tokenizer=tokenizer,
                            mode="test"))

        for test_dataset in test_datasets:
            predictions = trainer.predict(
                test_dataset=test_dataset).predictions
            if output_mode == "classification":
                predictions = np.argmax(predictions, axis=1)

            output_test_file = os.path.join(
                training_args.output_dir,
                f"test_results_{test_dataset.args.task_name}.txt")
            if trainer.is_world_master():
                with open(output_test_file, "w") as writer:
                    logger.info("***** Test results {} *****".format(
                        test_dataset.args.task_name))
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if output_mode == "regression":
                            writer.write("%d\t%3.3f\n" % (index, item))
                        else:
                            item = test_dataset.get_labels()[item]
                            writer.write("%d\t%s\n" % (index, item))
    return eval_results