class RobertaQA(BertPreTrainedModel):
    """Extractive question-answering head on top of a RoBERTa encoder.

    A single linear layer maps each token's encoder representation to two
    scores: a start-of-answer logit and an end-of-answer logit.

    Args:
        config: Model configuration; ``hidden_size`` and
            ``initializer_range`` are read here.
        model_name_or_path: If truthy, the encoder is created with
            ``RobertaModel.from_pretrained(model_name_or_path, config=config)``;
            otherwise it is built from ``config`` alone.
        pretrained_weights: Optional state dict loaded into the encoder after
            it has been constructed (skipped when falsy, e.g. ``None``).
    """

    def __init__(self, config, model_name_or_path=None, pretrained_weights=None):
        super().__init__(config)
        if model_name_or_path:
            self.roberta = RobertaModel.from_pretrained(model_name_or_path, config=config)
        else:
            self.roberta = RobertaModel(config=config)
        if pretrained_weights:
            self.roberta.load_state_dict(pretrained_weights)
        # Two outputs per token: start logit and end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        torch.nn.init.normal_(self.qa_outputs.weight, mean=0.0, std=self.config.initializer_range)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
    ):
        """Score every token position as a possible answer start / end.

        Args:
            input_ids, attention_mask, token_type_ids: Passed straight to the
                RoBERTa encoder.
            start_positions, end_positions: Optional gold span boundaries. The
                loss is computed only when both are given. The caller's tensors
                are never modified.

        Returns:
            ``(start_logits, end_logits)`` without gold positions, otherwise
            ``(total_loss, start_logits, end_logits)``. Here ``total_loss`` is
            the mean of the start and end cross-entropy losses.
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First encoder output; presumably the per-token hidden states
        # (batch, seq, hidden) -- confirm against the installed transformers.
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits)
        if start_positions is not None and end_positions is not None:
            # Accept (batch, 1) label tensors as well as (batch,).
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Gold positions can fall outside the model input. Clamp them to
            # ``ignored_index`` (the size of the token dimension, i.e. one past
            # the last valid position) and let the loss ignore that index.
            # Out-of-place clamp: the in-place ``clamp_`` would silently
            # rewrite the caller's label tensors.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits


# Example #2

def _str2bool(v):
    """Parse a command-line boolean (``yes``/``no``, ``true``/``false``, ``1``/``0`` ...).

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised spelling.
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval_mode",
        default=False,
        type=_str2bool,
        required=False,
        help="Test or train the model",
    )
    parser.add_argument(
        "--baseline",
        default=False,
        type=_str2bool,
        required=False,
        help="use the baseline or the transformers model",
    )
    parser.add_argument(
        "--load_weights",
        default=True,
        type=_str2bool,
        required=False,
        help="Load the pretrained weights or randomly initialize the model",
    )
    parser.add_argument(
        "--iter_per",
        default=4,
        type=int,
        required=False,
        help="cumulative gradient iteration cycle",
    )
    args = parser.parse_args()
    # The per-step batch size is 64 // iter_per. Outside 1..64 that is zero
    # (DataLoader error) or a division by zero, so reject such values up front.
    if not 1 <= args.iter_per <= 64:
        parser.error("--iter_per must be between 1 and 64")
    return args


def _checkpoint_identifier(args):
    """Return the file-name stem shared by this run's checkpoint and log files.

    The stem is ``str(args)`` with spaces removed and the ``iter_per`` entry
    dropped, so runs that differ only in accumulation steps share a stem.
    ``eval_mode`` is also forced to ``False``, so an evaluation run resolves
    to the checkpoint written by the matching training run.
    """
    import re

    ident = str(args).replace(" ", "")
    # argparse.Namespace prints fields sorted (Python < 3.9) or in definition
    # order (3.9+). iter_per is therefore either followed by a comma or is the
    # last field before ")".
    ident = re.sub(r"iter_per=-?\d+,|,iter_per=-?\d+(?=\))", "", ident)
    return ident.replace("eval_mode=True", "eval_mode=False")


def _load_datasets(tokenizer, suffix):
    """Return ``(dataset, dev_dataset, test_dataset)``, building and caching them if needed.

    The datasets are cached in ``checkpoints/dataset-<suffix>.pyc``. The cache
    is rebuilt from ``data/train.csv`` / ``data/test.csv`` whenever it cannot
    be loaded. ``dev_dataset`` is a second, independently loaded copy of the
    cached training dataset with ``.eval()`` called on it. What ``eval()``
    changes is defined by ``DisasterTweetsClassificationDataset`` (not shown
    here).
    """
    cache_path = "checkpoints/dataset-%s.pyc" % suffix
    # NOTE(review): torch.load unpickles arbitrary objects. Only point it at
    # cache files this script wrote. On PyTorch versions whose torch.load
    # defaults to weights_only=True this load may fail -- confirm the version.
    try:
        with open(cache_path, "rb") as f:
            dataset, test_dataset, _ = torch.load(f)
    except Exception:
        dataset = DisasterTweetsClassificationDataset(tokenizer, "data/train.csv", "train")
        test_dataset = DisasterTweetsClassificationDataset(tokenizer, "data/test.csv", "test")
        with open(cache_path, "wb") as f:
            torch.save((dataset, test_dataset, tokenizer), f)
    with open(cache_path, "rb") as f:
        dev_dataset, _, _ = torch.load(f)
    dev_dataset.eval()
    return dataset, dev_dataset, test_dataset


def _build_model(args):
    """Create the ``(encoder, model)`` pair selected by ``args`` and move both to the GPU."""
    if args.baseline:
        encoder = nn.LSTM(256, 256, 1, batch_first=True)
        model = NaiveLSTMBaselineClassifier()
    elif args.load_weights:
        encoder = RobertaModel.from_pretrained("roberta-base")
        model = RoBERTaClassifierHead(encoder.config)
    else:
        config = RobertaConfig.from_pretrained("roberta-base")
        encoder = RobertaModel(config=config)
        model = RoBERTaClassifierHead(config)
    encoder.cuda()
    model.cuda()
    return encoder, model


def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in the DataLoader configuration used throughout this script."""
    return datautils.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=16, drop_last=False, pin_memory=True
    )


def _batch_loss(encoder, model, ids, mask, label):
    """Return the mean negative of ``model``'s output at each row's gold label.

    Assumes ``model(encoder, ids, mask)`` returns per-class log-probabilities
    of shape (batch, classes) -- TODO confirm against the model classes.
    """
    log_prediction = model(encoder, ids, mask)
    rows = torch.arange(ids.size(0), device=label.device)
    return -log_prediction[rows, label].mean()


def _train_epoch(encoder, model, opt, loader, iter_per, iter_num, flog=None):
    """Run one pass over ``loader`` with gradient accumulation; return the new ``iter_num``.

    Gradients are zeroed at the first batch of every group of ``iter_per``
    batches, and the optimizer steps after the group's last batch. Each loss
    is divided by ``iter_per`` before ``backward()``. ``iter_num`` is a global
    batch counter carried across epochs, so a group may span an epoch
    boundary. About every 11 batches the mean recent loss is printed, and it
    is also written to ``flog`` when one is given.
    """
    encoder.train()
    model.train()
    recent_losses = []
    iterator = tqdm.tqdm(loader)
    for ids, mask, label in iterator:
        ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
        loss = _batch_loss(encoder, model, ids, mask, label)
        if iter_num % iter_per == 0:
            opt.zero_grad()
        (loss / iter_per).backward()
        recent_losses.append(loss.item())
        if len(recent_losses) > 10:
            mean_loss = np.mean(recent_losses)
            iterator.write("loss=%f" % mean_loss)
            if flog is not None:
                print("%f" % mean_loss, file=flog)
            recent_losses = []
        if iter_num % iter_per == iter_per - 1:
            opt.step()
        iter_num += 1
    return iter_num


def _accuracy(encoder, model, loader):
    """Return the fraction of examples in ``loader`` whose argmax prediction equals the label."""
    encoder.eval()
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for ids, mask, label in loader:
            ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
            prediction = model(encoder, ids, mask).argmax(dim=-1)
            correct += (prediction == label).to(torch.long).sum().item()
            total += mask.shape[0]
    return correct / total


def _eval_loss(encoder, model, loader):
    """Return the mean of the per-batch losses over ``loader``, computed without gradients."""
    encoder.eval()
    model.eval()
    losses = []
    with torch.no_grad():
        for ids, mask, label in tqdm.tqdm(loader):
            ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
            losses.append(_batch_loss(encoder, model, ids, mask, label).item())
    return np.mean(losses)


def _write_submission(encoder, model, loader, path="submission.csv"):
    """Write an ``id,target`` CSV with one argmax prediction per example in ``loader``."""
    encoder.eval()
    model.eval()
    with torch.no_grad(), open(path, "w") as fout:
        print("id,target", file=fout)
        for row_ids, ids, mask in loader:
            ids, mask = ids.cuda(), mask.cuda()
            prediction = model(encoder, ids, mask).argmax(dim=-1)
            for row_id, pred in zip(row_ids.tolist(), prediction.tolist()):
                print("%d,%d" % (row_id, pred), file=fout)


def _run_eval_mode(args, encoder, model, dev_dataloader, test_dataloader, directory_identifier):
    """Load the trained checkpoint, fine-tune on the dev loader, and write ``submission.csv``.

    Dev accuracy is printed before and after the 5 fine-tuning epochs. Each
    figure is computed from its own fresh counts.
    """
    encoder_state, model_state = torch.load("checkpoints/%s" % directory_identifier)
    encoder.load_state_dict(encoder_state)
    model.load_state_dict(model_state)
    print("dev acc:", _accuracy(encoder, model, dev_dataloader))

    opt = AdamW(lr=1e-6, weight_decay=0.05, params=list(encoder.parameters()) + list(model.parameters()))
    iter_num = 0
    for _ in range(5):
        iter_num = _train_epoch(encoder, model, opt, dev_dataloader, args.iter_per, iter_num)
    print("dev acc rectified:", _accuracy(encoder, model, dev_dataloader))

    _write_submission(encoder, model, test_dataloader)


def main():
    """Command-line entry point: train the classifier, or evaluate it with ``--eval_mode``.

    In training mode the model is trained for 5 epochs (baseline) or 10
    epochs (RoBERTa). Training losses go to ``checkpoints/log-<id>.txt`` and
    per-epoch dev losses to ``checkpoints/evallog-<id>.txt``. The weights are
    saved to ``checkpoints/<id>`` after every epoch. In eval mode the matching
    checkpoint is loaded and scored on the dev loader, fine-tuned on that
    loader for 5 epochs, re-scored, and used to write ``submission.csv``.
    """
    os.makedirs("checkpoints", exist_ok=True)

    args = _parse_args()
    directory_identifier = _checkpoint_identifier(args)

    if args.baseline:
        suffix = "naive"
        tokenizer = NaiveTokenizer()
    else:
        suffix = "roberta"
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    dataset, dev_dataset, test_dataset = _load_datasets(tokenizer, suffix)

    encoder, model = _build_model(args)

    # With iter_per accumulation steps per optimizer update, a per-step batch
    # of 64 // iter_per gives 64 examples per update when iter_per divides 64.
    batch_size = 64 // args.iter_per
    dataloader = _make_loader(dataset, batch_size, shuffle=True)
    dev_dataloader = _make_loader(dev_dataset, batch_size, shuffle=True)
    test_dataloader = _make_loader(test_dataset, batch_size, shuffle=False)

    if args.eval_mode:
        _run_eval_mode(args, encoder, model, dev_dataloader, test_dataloader, directory_identifier)
        return

    if args.baseline:
        lr = 5e-4
    elif args.load_weights:
        lr = 1e-6
    else:
        lr = 5e-6
    weight_decay = 0.10 if args.baseline else 0.05
    opt = AdamW(lr=lr, weight_decay=weight_decay, params=list(encoder.parameters()) + list(model.parameters()))

    log_path = "checkpoints/log-%s.txt" % directory_identifier
    evallog_path = "checkpoints/evallog-%s.txt" % directory_identifier
    # Truncate both logs once. Each epoch then reopens them in append mode and
    # closes them, so every epoch's lines are flushed when its epoch finishes.
    with open(log_path, "w"):
        pass
    with open(evallog_path, "w"):
        pass

    iter_num = 0
    for epoch_idx in range(5 if args.baseline else 10):
        with open(log_path, "a") as flog:
            iter_num = _train_epoch(encoder, model, opt, dataloader, args.iter_per, iter_num, flog=flog)
        eval_loss = _eval_loss(encoder, model, dev_dataloader)
        tqdm.tqdm.write("evalloss-%d=%f" % (epoch_idx, eval_loss))
        with open(evallog_path, "a") as flogeval:
            print("%f" % eval_loss, file=flogeval)
        torch.save((encoder.state_dict(), model.state_dict()), "checkpoints/%s" % directory_identifier)