import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaConfig, RobertaModel

# compute_loss (binary / focal_loss / ghmc variants) is assumed to be defined elsewhere in the project.


class roBerta(nn.Module):
    def __init__(self, config, num=0):
        super(roBerta, self).__init__()
        model_config = RobertaConfig()
        model_config.vocab_size = config.vocab_size
        model_config.hidden_size = config.hidden_size[0]
        model_config.num_attention_heads = 16
        # how the classification loss is computed
        self.loss_method = config.loss_method
        self.multi_drop = config.multi_drop
        self.roberta = RobertaModel(model_config)
        if config.requires_grad:
            for param in self.roberta.parameters():
                param.requires_grad = True
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.hidden_size = config.hidden_size[num]
        self.num_labels = config.num_labels
        if self.loss_method in ['binary', 'focal_loss', 'ghmc']:
            self.classifier = nn.Linear(self.hidden_size, 1)
        else:
            self.classifier = nn.Linear(self.hidden_size, self.num_labels)
        self.text_linear = nn.Linear(config.embeding_size, config.hidden_size[0])
        self.vocab_layer = nn.Linear(config.hidden_size[0], config.vocab_size)
        self.classifier.apply(self._init_weights)
        self.roberta.apply(self._init_weights)
        self.text_linear.apply(self._init_weights)
        self.vocab_layer.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, inputs=None, attention_mask=None, output_id=None, labels=None):
        inputs = torch.relu(self.text_linear(inputs))
        bert_outputs = self.roberta(inputs_embeds=inputs, attention_mask=attention_mask)

        # Calculate the MLM loss on the masked positions (marked by output_id != -100).
        last_hidden_state = bert_outputs[0]
        output_id_tmp = output_id[output_id.ne(-100)]
        output_id_emb = last_hidden_state[output_id.ne(-100)]
        pre_score = self.vocab_layer(output_id_emb)
        loss_cro = CrossEntropyLoss()
        # CrossEntropyLoss expects raw logits, so pre_score is passed without a sigmoid.
        mlm_loss = loss_cro(pre_score, output_id_tmp)

        labels_bool = labels.ne(-1)
        if labels_bool.sum().item() == 0:
            return mlm_loss, torch.tensor([])

        # Calculate the classification loss on the labelled examples only.
        pooled_output = bert_outputs[1]
        out = self.classifier(pooled_output)
        out = out[labels_bool]
        labels_tmp = labels[labels_bool]
        label_loss = compute_loss(out, labels_tmp)
        out = torch.sigmoid(out).flatten()
        return mlm_loss + label_loss, out
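
def _demo_roberta_forward():
    """Minimal smoke test for the roBerta module above (illustrative only).

    Assumes a hypothetical config namespace whose field names mirror the
    attributes read in __init__/forward; the values are placeholders, not the
    project's real configuration. All labels are -1 so the labelled-loss
    branch (and its compute_loss dependency) is skipped.
    """
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        vocab_size=1000,
        hidden_size=[768],
        loss_method='binary',
        multi_drop=1,
        requires_grad=True,
        hidden_dropout_prob=0.1,
        num_labels=2,
        embeding_size=300,
    )
    model = roBerta(demo_config)
    batch, seq_len = 2, 16
    dummy_inputs = torch.randn(batch, seq_len, demo_config.embeding_size)
    dummy_mask = torch.ones(batch, seq_len, dtype=torch.long)
    dummy_output_id = torch.full((batch, seq_len), -100, dtype=torch.long)
    dummy_output_id[:, 0] = 5                    # pretend one masked token per sequence
    dummy_labels = torch.full((batch,), -1)      # no labelled examples -> MLM loss only
    mlm_loss, out = model(dummy_inputs, dummy_mask, dummy_output_id, dummy_labels)
    print(mlm_loss.item(), out.shape)
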
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as datautils
import tqdm
from torch.optim import AdamW
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer

# NaiveTokenizer, DisasterTweetsClassificationDataset, NaiveLSTMBaselineClassifier and
# RoBERTaClassifierHead are assumed to be defined in the project's own modules.


def main():
    if not os.path.exists("./checkpoints"):
        os.mkdir("checkpoints")

    parser = argparse.ArgumentParser()

    def str2bool(v):
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Unsupported value encountered.')

    parser.add_argument(
        "--eval_mode",
        default=False,
        type=str2bool,
        required=False,
        help="Test or train the model",
    )
    parser.add_argument(
        "--baseline",
        default=False,
        type=str2bool,
        required=False,
        help="Use the baseline or the transformers model",
    )
    parser.add_argument(
        "--load_weights",
        default=True,
        type=str2bool,
        required=False,
        help="Load the pretrained weights or randomly initialize the model",
    )
    parser.add_argument(
        "--iter_per",
        default=4,
        type=int,
        required=False,
        help="Cumulative gradient iteration cycle",
    )
    args = parser.parse_args()

    # The checkpoint/log identifier ignores iter_per and eval_mode so that an
    # evaluation run resolves to the checkpoint written during training.
    directory_identifier = args.__str__().replace(" ", "") \
        .replace("iter_per=1,", "") \
        .replace("iter_per=2,", "") \
        .replace("iter_per=4,", "") \
        .replace("iter_per=8,", "") \
        .replace("iter_per=16,", "") \
        .replace("iter_per=32,", "") \
        .replace("iter_per=64,", "") \
        .replace("iter_per=128,", "") \
        .replace("eval_mode=True", "eval_mode=False")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    suffix = "roberta"
    if args.baseline:
        suffix = "naive"
        tokenizer = NaiveTokenizer()

    # Build the datasets once and cache them; later runs reuse the cache.
    try:
        dataset, test_dataset, tokenizer = torch.load(open("checkpoints/dataset-%s.pyc" % suffix, "rb"))
        dev_dataset, _, _ = torch.load(open("checkpoints/dataset-%s.pyc" % suffix, "rb"))
    except Exception:
        dataset = DisasterTweetsClassificationDataset(tokenizer, "data/train.csv", "train")
        test_dataset = DisasterTweetsClassificationDataset(tokenizer, "data/test.csv", "test")
        torch.save((dataset, test_dataset, tokenizer), open("checkpoints/dataset-%s.pyc" % suffix, "wb"))
        dev_dataset, _, _ = torch.load(open("checkpoints/dataset-%s.pyc" % suffix, "rb"))
    dev_dataset.eval()

    if args.baseline:
        encoder = nn.LSTM(256, 256, 1, batch_first=True)
        model = NaiveLSTMBaselineClassifier()
    else:
        if args.load_weights:
            encoder = RobertaModel.from_pretrained("roberta-base")
            model = RoBERTaClassifierHead(encoder.config)
        else:
            config = RobertaConfig.from_pretrained("roberta-base")
            encoder = RobertaModel(config=config)
            model = RoBERTaClassifierHead(config)
    encoder.cuda()
    model.cuda()

    dataloader = datautils.DataLoader(
        dataset, batch_size=64 // args.iter_per, shuffle=True,
        num_workers=16, drop_last=False, pin_memory=True
    )
    dev_dataloader = datautils.DataLoader(
        dev_dataset, batch_size=64 // args.iter_per, shuffle=True,
        num_workers=16, drop_last=False, pin_memory=True
    )
    test_dataloader = datautils.DataLoader(
        test_dataset, batch_size=64 // args.iter_per, shuffle=False,
        num_workers=16, drop_last=False, pin_memory=True
    )

    if args.eval_mode:
        # Load the trained weights and measure accuracy on the dev split.
        correct = 0
        total = 0
        encoder_, model_ = torch.load("checkpoints/%s" % directory_identifier)
        encoder.load_state_dict(encoder_)
        model.load_state_dict(model_)
        encoder.eval()
        with torch.no_grad():
            for ids, mask, label in dev_dataloader:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                prediction = model(encoder, ids, mask).argmax(dim=-1)
                correct += (prediction == label).to(torch.long).sum().item()
                total += mask.shape[0]
        print("dev acc:", correct / total)

        # Briefly fine-tune on the dev split before predicting on the test set.
        opt = AdamW(lr=1e-6, weight_decay=0.05,
                    params=list(encoder.parameters()) + list(model.parameters()))
        encoder.train()
        iter_num = 0
        LOSS = []
        for _ in range(5):
            iterator = tqdm.tqdm(dev_dataloader)
            for ids, mask, label in iterator:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                log_prediction = model(encoder, ids, mask)
                loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
                if iter_num % args.iter_per == 0:
                    opt.zero_grad()
                (loss / args.iter_per).backward()
                LOSS.append(loss.item())
                if len(LOSS) > 10:
                    iterator.write("loss=%f" % np.mean(LOSS))
                    LOSS = []
                if iter_num % args.iter_per == args.iter_per - 1:
                    opt.step()
                iter_num += 1

        # Re-measure dev accuracy after the extra fine-tuning pass.
        correct = 0
        total = 0
        encoder.eval()
        with torch.no_grad():
            for ids, mask, label in dev_dataloader:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                prediction = model(encoder, ids, mask).argmax(dim=-1)
                correct += (prediction == label).to(torch.long).sum().item()
                total += mask.shape[0]
        print("dev acc rectified:", correct / total)

        # Write test-set predictions in the Kaggle submission format.
        with torch.no_grad():
            with open("submission.csv", "w") as fout:
                print("id,target", file=fout)
                for id, ids, mask in test_dataloader:
                    ids, mask = ids.cuda(), mask.cuda()
                    prediction = model(encoder, ids, mask).argmax(dim=-1)
                    for i in range(id.size(0)):
                        print("%d,%d" % (id[i], prediction[i]), file=fout)
        exit()

    if args.baseline:
        lr = 5e-4
    elif args.load_weights:
        lr = 1e-6
    else:
        lr = 5e-6
    opt = AdamW(lr=lr, weight_decay=0.10 if args.baseline else 0.05,
                params=list(encoder.parameters()) + list(model.parameters()))

    # Truncate the log files once, then append to them epoch by epoch.
    flog = open("checkpoints/log-%s.txt" % directory_identifier, "w")
    flog.close()
    flogeval = open("checkpoints/evallog-%s.txt" % directory_identifier, "w")
    flogeval.close()

    iter_num = 0
    for epoch_idx in range(5 if args.baseline else 10):
        flog = open("checkpoints/log-%s.txt" % directory_identifier, "a")
        flogeval = open("checkpoints/evallog-%s.txt" % directory_identifier, "a")
        LOSS = []
        encoder.train()
        iterator = tqdm.tqdm(dataloader)
        for ids, mask, label in iterator:
            ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
            log_prediction = model(encoder, ids, mask)
            # Negative log-likelihood of the gold label (the head returns log-probabilities).
            loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
            if iter_num % args.iter_per == 0:
                opt.zero_grad()
            (loss / args.iter_per).backward()
            LOSS.append(loss.item())
            if len(LOSS) > 10:
                iterator.write("loss=%f" % np.mean(LOSS))
                print("%f" % np.mean(LOSS), file=flog)
                LOSS = []
            if iter_num % args.iter_per == args.iter_per - 1:
                opt.step()
            iter_num += 1

        EVALLOSS = []
        encoder.eval()
        iterator = tqdm.tqdm(dev_dataloader)
        with torch.no_grad():
            for ids, mask, label in iterator:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                log_prediction = model(encoder, ids, mask)
                loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
                EVALLOSS.append(loss.item())
        iterator.write("evalloss-%d=%f" % (epoch_idx, np.mean(EVALLOSS)))
        print("%f" % np.mean(EVALLOSS), file=flogeval)
        flog.close()
        flogeval.close()
        torch.save((encoder.state_dict(), model.state_dict()),
                   "checkpoints/%s" % directory_identifier)
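
# Entry point (assumed; the original snippet defines main() but does not show how it is invoked).
# Example invocations, assuming this file is saved as e.g. train.py:
#   python train.py                      # fine-tune roberta-base, write checkpoints and logs
#   python train.py --baseline true      # train the LSTM baseline instead
#   python train.py --eval_mode true     # evaluate a saved checkpoint and write submission.csv
if __name__ == "__main__":
    main()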