def create_and_check_albert_for_pretraining(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    model = AlbertForPreTraining(config=config)
    model.to(torch_device)
    model.eval()
    loss, prediction_scores, sop_scores = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
        sentence_order_label=sequence_labels,
    )
    result = {
        "loss": loss,
        "prediction_scores": prediction_scores,
        "sop_scores": sop_scores,
    }
    self.parent.assertListEqual(
        list(result["prediction_scores"].size()),
        [self.batch_size, self.seq_length, self.vocab_size],
    )
    self.parent.assertListEqual(
        list(result["sop_scores"].size()),
        [self.batch_size, config.num_labels],
    )
    self.check_loss_output(result)

def create_and_check_for_pretraining(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    model = AlbertForPreTraining(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
        sentence_order_label=sequence_labels,
    )
    self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
    self.parent.assertEqual(result.sop_logits.shape, (self.batch_size, config.num_labels))
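Both helpers above exercise AlbertForPreTraining, whose output exposes the masked-LM head as prediction_logits and the sentence-order-prediction head as sop_logits. The snippet below is a minimal standalone sketch of the same call outside the test harness; it assumes the Hugging Face transformers library and the public albert-base-v2 checkpoint, which are not part of the example above.

import torch
from transformers import AlbertForPreTraining, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForPreTraining.from_pretrained("albert-base-v2")
model.eval()

inputs = tokenizer("The cat sat on the mat.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.prediction_logits.shape)  # (batch, seq_len, vocab_size) -- masked-LM head
print(outputs.sop_logits.shape)         # (batch, 2) -- sentence-order-prediction head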
Example #3
def main():
    # my dice shows 777 only. period.
    random.seed(EXPCONF.seed)
    np.random.seed(EXPCONF.seed)
    torch.manual_seed(EXPCONF.seed)
    torch.cuda.manual_seed_all(EXPCONF.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tempconf = EXPCONF.copy()
    tempconf.datamode = 'test'

    testloader, ___, _____ = get_loader(tempconf)
    trainloader, __, _trainds = get_loader(EXPCONF, getdev=False)
    devloader, _, _devds = get_loader(EXPCONF, getdev=True)

    assert len(trainloader) > 0, "trainloader is empty!"
    assert len(devloader) > 0, "devloader is empty!"

    # this is disgraceful.... but just specify things below
    model_weight, vocab, trained_condition = loadmodel_info(EXPCONF)

    albertconf = retrieve_conf(trained_condition, vocab)
    albert = AlbertForPreTraining(albertconf)
    albert.load_state_dict(model_weight)
    albert = albert.to(device)

    global_step = 0
    L = len(trainloader)
    bsz = len(trainloader[0])

    if not EXPCONF.infer_now:
        albert = albert.albert  # keep only the base AlbertModel; the pretraining heads are not needed here
        albert.eval()  # eval mode; ALBERT's weights are not passed to the optimizer, so only the MLP head is trained

        cls = MLP(EXPCONF, albertconf.hidden_size, 2).to(device)
        cls.train()
        for p in cls.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # huggingface example is doing this for language modeling...
        # https://github.com/huggingface/transformers/blob/v2.6.0/examples/run_language_modeling.py
        optimizer = AdamW(cls.parameters(),
                          lr=EXPCONF.cls_lr)  # otherwise, use default
        getsch = get_cosine_schedule_with_warmup if EXPCONF.cls_sch == 'cosine' else get_linear_schedule_with_warmup
        scheduler = getsch(optimizer, EXPCONF.cls_warmups,
                           EXPCONF.cls_numsteps)

        ## train cls only!
        while global_step < EXPCONF.cls_numsteps:
            lossep_pp = 0
            accep_pp = 0
            cls.train()
            for i, (b, l, datasetids) in enumerate(
                    tqdm(trainloader, desc="iterations progress"), 1):
                outputs = albert(**b, return_dict=True)
                global_step += 1

                logits = cls(outputs.pooler_output)
                losspp = F.cross_entropy(logits, l)

                lossppval = losspp.item()
                acc = accuracy(logits.clone().detach(), l)

                wandb.log({
                    'step': global_step,
                    'cls.train_step/learning_rate': get_lr_from_optim(optimizer),
                    'cls.train_step/pp_loss': lossppval,
                    'cls.train_step/pp_acc': acc,
                })

                losspp.backward()  # backpropagate so the classifier head actually receives gradients
                optimizer.step()
                scheduler.step()
                cls.zero_grad()

                lossep_pp += lossppval
                accep_pp += acc
                if global_step % EXPCONF.logevery == 0:
                    lossep_pp /= L
                    accep_pp /= L

                    wandb.log({
                        'cls.train_ep/pp_loss': lossep_pp,
                        'cls.train_ep/pp_acc': accep_pp,
                    })
                    devpp_loss, devpp_acc = evaldev(EXPCONF, albert, cls,
                                                    devloader, global_step)
                    if devpp_acc > EXPCONF.savethld:
                        savemodel(EXPCONF,
                                  albert,
                                  cls,
                                  vocab,
                                  global_step,
                                  acc=devpp_acc)
                        write_sub(EXPCONF,
                                  albert,
                                  cls,
                                  global_step,
                                  acc=devpp_acc,
                                  testloader=testloader)

    else:  # infer now
        cls = None
        devpp_loss, devpp_acc = evaldev(EXPCONF,
                                        albert,
                                        cls,
                                        devloader,
                                        global_step,
                                        infernow=EXPCONF.infer_now)
        write_sub(EXPCONF,
                  albert,
                  cls,
                  global_step,
                  acc=devpp_acc,
                  testloader=testloader,
                  infernow=EXPCONF.infer_now)

    return None
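The loop above trains only the small MLP head on top of an ALBERT encoder whose weights are never updated. The following is a minimal, self-contained sketch of that head-only pattern, assuming only torch and transformers; the project-specific EXPCONF, get_loader, MLP, and logging helpers are not reproduced here, and nn.Linear stands in for the MLP head.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AlbertModel, AlbertTokenizer, get_linear_schedule_with_warmup

encoder = AlbertModel.from_pretrained("albert-base-v2").eval()  # frozen feature extractor
head = nn.Linear(encoder.config.hidden_size, 2)                 # stand-in for the MLP head

optimizer = AdamW(head.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
batch = tokenizer(["first example", "second example"], return_tensors="pt", padding=True)
labels = torch.tensor([0, 1])

with torch.no_grad():                  # the encoder gets no gradients
    pooled = encoder(**batch).pooler_output

loss = F.cross_entropy(head(pooled), labels)
loss.backward()                        # gradients reach the head only
optimizer.step()
scheduler.step()
optimizer.zero_grad()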