Example #1
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaConfig, RobertaModel

# compute_loss is a project-level helper (it applies the loss named by config.loss_method)
# and is assumed to be importable from the surrounding repository.


class roBerta(nn.Module):
    def __init__(self, config, num=0):
        super(roBerta, self).__init__()
        model_config = RobertaConfig()
        model_config.vocab_size = config.vocab_size
        model_config.hidden_size = config.hidden_size[0]
        model_config.num_attention_heads = 16
        # how the loss is computed
        self.loss_method = config.loss_method
        self.multi_drop = config.multi_drop

        self.roberta = RobertaModel(model_config)
        if config.requires_grad:
            for param in self.roberta.parameters():
                param.requires_grad = True

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.hidden_size = config.hidden_size[num]
        if self.loss_method in ['binary', 'focal_loss', 'ghmc']:
            self.classifier = nn.Linear(self.hidden_size, 1)
        else:
            self.classifier = nn.Linear(self.hidden_size, config.num_labels)
        self.text_linear = nn.Linear(config.embeding_size,
                                     config.hidden_size[0])
        self.vocab_layer = nn.Linear(config.hidden_size[0], config.vocab_size)

        self.classifier.apply(self._init_weights)
        self.roberta.apply(self._init_weights)
        self.text_linear.apply(self._init_weights)
        self.vocab_layer.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self,
                inputs=None,
                attention_mask=None,
                output_id=None,
                labels=None):
        inputs = torch.relu(self.text_linear(inputs))
        bert_outputs = self.roberta(inputs_embeds=inputs,
                                    attention_mask=attention_mask)

        # calculate the MLM loss over the masked positions (output_id == -100 marks unmasked tokens)
        last_hidden_state = bert_outputs[0]
        output_id_tmp = output_id[output_id.ne(-100)]
        output_id_emb = last_hidden_state[output_id.ne(-100)]
        pre_score = self.vocab_layer(output_id_emb)
        loss_cro = CrossEntropyLoss()
        # CrossEntropyLoss expects raw logits, so pre_score is passed without a sigmoid
        mlm_loss = loss_cro(pre_score, output_id_tmp)

        labels_bool = labels.ne(-1)
        if labels_bool.sum().item() == 0:
            return mlm_loss, torch.tensor([])

        # calculate the classification loss on the samples that carry a label (label == -1 means unlabeled)
        pooled_output = bert_outputs[1]
        out = self.classifier(pooled_output)
        out = out[labels_bool]
        labels_tmp = labels[labels_bool]
        label_loss = compute_loss(out, labels_tmp)
        out = torch.sigmoid(out).flatten()
        return mlm_loss + label_loss, out

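A minimal usage sketch for the class above, assuming a hypothetical config object whose field names are inferred from the attributes read in __init__ (vocab_size, hidden_size, embeding_size, loss_method, and so on); it only exercises the MLM path (all labels set to -1) so the external compute_loss helper is not needed:

import torch
from types import SimpleNamespace

# Hypothetical config; every field name is taken from the attributes accessed in __init__ above.
config = SimpleNamespace(
    vocab_size=1000, hidden_size=[256], embeding_size=128,
    loss_method='binary', multi_drop=1, requires_grad=True,
    hidden_dropout_prob=0.1, num_labels=2,
)
model = roBerta(config)

batch, seq_len = 2, 16
inputs = torch.randn(batch, seq_len, config.embeding_size)   # pre-computed token embeddings
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
output_id = torch.full((batch, seq_len), -100)               # -100 marks positions not masked for MLM
output_id[:, 3] = 7                                          # pretend token id 7 was masked out
labels = torch.full((batch,), -1)                            # -1 = no classification label

mlm_loss, probs = model(inputs, attention_mask, output_id, labels)
print(mlm_loss.item(), probs.shape)                          # probs is empty because every label is -1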
Example #2
import argparse
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as datautils
import tqdm
from torch.optim import AdamW  # the original may instead use transformers' AdamW
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer

# NaiveTokenizer, NaiveLSTMBaselineClassifier, RoBERTaClassifierHead and
# DisasterTweetsClassificationDataset are project-specific classes assumed to be
# importable from the surrounding repository.


def main():
    if not os.path.exists("./checkpoints"):
        os.mkdir("checkpoints")

    parser = argparse.ArgumentParser()
    def str2bool(v):
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Unsupported value encountered.')
    parser.add_argument(
        "--eval_mode",
        default=False,
        type=str2bool,
        required=False,
        help="Test or train the model",
    )
    parser.add_argument(
        "--baseline",
        default=False,
        type=str2bool,
        required=False,
        help="use the baseline or the transformers model",
    )
    parser.add_argument(
        "--load_weights",
        default=True,
        type=str2bool,
        required=False,
        help="Load the pretrained weights or randomly initialize the model",
    )
    parser.add_argument(
        "--iter_per",
        default=4,
        type=int,
        required=False,
        help="cumulative gradient iteration cycle",
    )

    args = parser.parse_args()
    # Build a run identifier from the parsed args; strip the iter_per setting and
    # force eval_mode=False so training and evaluation runs share one checkpoint name.
    directory_identifier = re.sub(r"iter_per=\d+,", "", str(args).replace(" ", ""))
    directory_identifier = directory_identifier.replace("eval_mode=True", "eval_mode=False")


    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    suffix = "roberta"
    if args.baseline:
        suffix = "naive"
        tokenizer = NaiveTokenizer()
    try:
        dataset, test_dataset, tokenizer = torch.load("checkpoints/dataset-%s.pyc" % suffix)
        dev_dataset, _, _ = torch.load("checkpoints/dataset-%s.pyc" % suffix)
    except Exception:
        dataset = DisasterTweetsClassificationDataset(tokenizer, "data/train.csv", "train")
        test_dataset = DisasterTweetsClassificationDataset(tokenizer, "data/test.csv", "test")
        torch.save((dataset, test_dataset, tokenizer), "checkpoints/dataset-%s.pyc" % suffix)
        # reload so dev_dataset is an independent copy of the training dataset
        dev_dataset, _, _ = torch.load("checkpoints/dataset-%s.pyc" % suffix)
    dev_dataset.eval()

    if args.baseline:
        encoder = nn.LSTM(256, 256, 1, batch_first=True)
        model = NaiveLSTMBaselineClassifier()
    else:
        if args.load_weights:
            encoder = RobertaModel.from_pretrained("roberta-base")
            model = RoBERTaClassifierHead(encoder.config)
        else:
            config = RobertaConfig.from_pretrained("roberta-base")
            encoder = RobertaModel(config=config)
            model = RoBERTaClassifierHead(config)
    encoder.cuda()
    model.cuda()
    dataloader = datautils.DataLoader(
        dataset, batch_size=64 // args.iter_per, shuffle=True,
        num_workers=16, drop_last=False, pin_memory=True
    )
    dev_dataloader = datautils.DataLoader(
        dev_dataset, batch_size=64 // args.iter_per, shuffle=True,
        num_workers=16, drop_last=False, pin_memory=True
    )
    test_dataloader = datautils.DataLoader(
        test_dataset, batch_size=64 // args.iter_per, shuffle=False,
        num_workers=16, drop_last=False, pin_memory=True
    )

    if args.eval_mode:
        correct = 0
        total = 0
        encoder_, model_ = torch.load("checkpoints/%s" % directory_identifier)
        encoder.load_state_dict(encoder_)
        model.load_state_dict(model_)
        encoder.eval()
        with torch.no_grad():
            for ids, mask, label in dev_dataloader:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                prediction = model(encoder, ids, mask).argmax(dim=-1)
                correct += (prediction == label).to(torch.long).sum().item()
                total += mask.shape[0]
        print("dev acc:", correct / total)
        opt = AdamW(lr=1e-6, weight_decay=0.05, params=list(encoder.parameters()) + list(model.parameters()))
        encoder.train()
        iter_num = 0
        LOSS = []
        for _ in range(5):
            iterator = tqdm.tqdm(dev_dataloader)
            for ids, mask, label in iterator:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                log_prediction = model(encoder, ids, mask)
                loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
                if iter_num % args.iter_per == 0:
                    opt.zero_grad()
                (loss / args.iter_per).backward()
                LOSS.append(loss.item())
                if len(LOSS) > 10:
                    iterator.write("loss=%f" % np.mean(LOSS))
                    LOSS = []
                if iter_num % args.iter_per == args.iter_per - 1:
                    opt.step()
                iter_num += 1
        encoder.eval()
        # reset the counters so the rectified accuracy is not mixed with the first pass
        correct = 0
        total = 0
        with torch.no_grad():
            for ids, mask, label in dev_dataloader:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                prediction = model(encoder, ids, mask).argmax(dim=-1)
                correct += (prediction == label).to(torch.long).sum().item()
                total += mask.shape[0]
        print("dev acc rectified:", correct / total)

        with torch.no_grad():
            with open("submission.csv", "w") as fout:
                print("id,target", file=fout)
                for id, ids, mask in test_dataloader:
                    ids, mask = ids.cuda(), mask.cuda()
                    prediction = model(encoder, ids, mask).argmax(dim=-1)
                    for i in range(id.size(0)):
                        print("%d,%d" % (id[i], prediction[i]), file=fout)
        return
    if args.baseline:
        lr = 5e-4
    elif args.load_weights:
        lr = 1e-6
    else:
        lr = 5e-6
    opt = AdamW(lr=lr, weight_decay=0.10 if args.baseline else 0.05, params=list(encoder.parameters())+list(model.parameters()))
    flog = open("checkpoints/log-%s.txt" % directory_identifier, "w")
    flog.close()
    flogeval = open("checkpoints/evallog-%s.txt" % directory_identifier, "w")
    flogeval.close()
    iter_num = 0
    for epoch_idx in range(5 if args.baseline else 10):
        flog = open("checkpoints/log-%s.txt" % directory_identifier, "a")
        flogeval = open("checkpoints/evallog-%s.txt" % directory_identifier, "a")
        LOSS = []
        encoder.train()
        iterator = tqdm.tqdm(dataloader)
        for ids, mask, label in iterator:
            ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
            log_prediction = model(encoder, ids, mask)
            loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
            # gradient accumulation: zero grads at the start of each iter_per cycle,
            # scale the loss, and only step the optimizer at the end of the cycle
            if iter_num % args.iter_per == 0:
                opt.zero_grad()
            (loss / args.iter_per).backward()
            LOSS.append(loss.item())
            if len(LOSS) > 10:
                iterator.write("loss=%f" % np.mean(LOSS))
                print("%f" % np.mean(LOSS), file=flog)
                LOSS = []
            if iter_num % args.iter_per == args.iter_per - 1:
                opt.step()
            iter_num += 1
        EVALLOSS = []
        encoder.eval()
        iterator = tqdm.tqdm(dev_dataloader)
        with torch.no_grad():
            for ids, mask, label in iterator:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                log_prediction = model(encoder, ids, mask)
                loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
                EVALLOSS.append(loss.item())
        iterator.write("evalloss-%d=%f" % (epoch_idx, np.mean(EVALLOSS)))
        print("%f" % np.mean(EVALLOSS), file=flogeval)

        flog.close()
        flogeval.close()
        torch.save((encoder.state_dict(), model.state_dict()), "checkpoints/%s" % directory_identifier)
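The listing shows only main(); presumably the script is invoked through the usual entry-point guard. A minimal sketch follows (the file name train.py and the flag values are illustrative, not taken from the original):

if __name__ == "__main__":
    main()

# Illustrative invocations using the flags defined by the parser above:
#   python train.py --baseline false --load_weights true --iter_per 4   # fine-tune roberta-base
#   python train.py --eval_mode true                                    # evaluate and write submission.csv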