"""Fine-tune a RoBERTa encoder (or a naive LSTM baseline) on the Disaster Tweets
classification task, with gradient accumulation, per-epoch dev evaluation, and
Kaggle-style submission writing."""

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as datautils
import tqdm
from torch.nn import CrossEntropyLoss
# AdamW is assumed to come from torch.optim; transformers' AdamW would work equivalently.
from torch.optim import AdamW
from transformers import (
    BertPreTrainedModel,
    RobertaConfig,
    RobertaModel,
    RobertaTokenizer,
)

# NaiveTokenizer, DisasterTweetsClassificationDataset, NaiveLSTMBaselineClassifier and
# RoBERTaClassifierHead are project-local classes assumed to be defined or imported
# elsewhere in this repository.


class RobertaQA(BertPreTrainedModel):
    """Span-prediction (QA) head on top of RobertaModel. Defined here but not used by main()."""

    def __init__(self, config, model_name_or_path=None, pretrained_weights=None):
        super(RobertaQA, self).__init__(config)
        if model_name_or_path:
            self.roberta = RobertaModel.from_pretrained(model_name_or_path, config=config)
        else:
            self.roberta = RobertaModel(config=config)
        if pretrained_weights:
            self.roberta.load_state_dict(pretrained_weights)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        torch.nn.init.normal_(self.qa_outputs.weight, mean=0.0, std=self.config.initializer_range)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = outputs[0]

        # Project each token representation to a (start, end) logit pair.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,)
        if start_positions is not None and end_positions is not None:
            # Sometimes the start/end positions fall outside the model inputs; ignore those terms.
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits
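def _demo_roberta_qa():
    """Minimal usage sketch for RobertaQA (not called by main() below).

    The question/context strings are illustrative assumptions, not data from this
    project; a transformers version with callable tokenizers is assumed.
    """
    config = RobertaConfig.from_pretrained("roberta-base")
    qa_model = RobertaQA(config, model_name_or_path="roberta-base")
    qa_model.eval()
    tok = RobertaTokenizer.from_pretrained("roberta-base")
    enc = tok("Who wrote it?", "It was written by the author.", return_tensors="pt")
    with torch.no_grad():
        start_logits, end_logits = qa_model(
            enc["input_ids"], attention_mask=enc["attention_mask"]
        )
    # Greedy span decode: most likely start and end token positions.
    return start_logits.argmax(dim=-1).item(), end_logits.argmax(dim=-1).item()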
def main():
    if not os.path.exists("./checkpoints"):
        os.mkdir("checkpoints")

    parser = argparse.ArgumentParser()

    def str2bool(v):
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Unsupported value encountered.')

    parser.add_argument(
        "--eval_mode",
        default=False,
        type=str2bool,
        required=False,
        help="Test or train the model",
    )
    parser.add_argument(
        "--baseline",
        default=False,
        type=str2bool,
        required=False,
        help="Use the baseline or the transformers model",
    )
    parser.add_argument(
        "--load_weights",
        default=True,
        type=str2bool,
        required=False,
        help="Load the pretrained weights or randomly initialize the model",
    )
    parser.add_argument(
        "--iter_per",
        default=4,
        type=int,
        required=False,
        help="Gradient accumulation cycle (iterations per optimizer step)",
    )
    args = parser.parse_args()

    # Derive a checkpoint/log identifier from the arguments, stripping iter_per and
    # forcing eval_mode=False so evaluation runs resolve to the matching training run.
    directory_identifier = args.__str__().replace(" ", "") \
        .replace("iter_per=1,", "") \
        .replace("iter_per=2,", "") \
        .replace("iter_per=4,", "") \
        .replace("iter_per=8,", "") \
        .replace("iter_per=16,", "") \
        .replace("iter_per=32,", "") \
        .replace("iter_per=64,", "") \
        .replace("iter_per=128,", "") \
        .replace("eval_mode=True", "eval_mode=False")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    suffix = "roberta"
    if args.baseline:
        suffix = "naive"
        tokenizer = NaiveTokenizer()

    # Load the cached datasets if present; otherwise build them from the raw CSVs.
    try:
        dataset, test_dataset, tokenizer = torch.load("checkpoints/dataset-%s.pyc" % suffix)
        dev_dataset, _, _ = torch.load("checkpoints/dataset-%s.pyc" % suffix)
    except Exception:
        dataset = DisasterTweetsClassificationDataset(tokenizer, "data/train.csv", "train")
        test_dataset = DisasterTweetsClassificationDataset(tokenizer, "data/test.csv", "test")
        torch.save((dataset, test_dataset, tokenizer), "checkpoints/dataset-%s.pyc" % suffix)
        dev_dataset, _, _ = torch.load("checkpoints/dataset-%s.pyc" % suffix)
    # dev_dataset is an independent copy of the training data switched to its dev split.
    dev_dataset.eval()

    if args.baseline:
        encoder = nn.LSTM(256, 256, 1, batch_first=True)
        model = NaiveLSTMBaselineClassifier()
    else:
        if args.load_weights:
            encoder = RobertaModel.from_pretrained("roberta-base")
            model = RoBERTaClassifierHead(encoder.config)
        else:
            config = RobertaConfig.from_pretrained("roberta-base")
            encoder = RobertaModel(config=config)
            model = RoBERTaClassifierHead(config)
    encoder.cuda()
    model.cuda()

    # Keep the effective batch size at 64 regardless of the accumulation cycle.
    dataloader = datautils.DataLoader(
        dataset, batch_size=64 // args.iter_per, shuffle=True,
        num_workers=16, drop_last=False, pin_memory=True,
    )
    dev_dataloader = datautils.DataLoader(
        dev_dataset, batch_size=64 // args.iter_per, shuffle=True,
        num_workers=16, drop_last=False, pin_memory=True,
    )
    test_dataloader = datautils.DataLoader(
        test_dataset, batch_size=64 // args.iter_per, shuffle=False,
        num_workers=16, drop_last=False, pin_memory=True,
    )

    if args.eval_mode:
        # Evaluate the saved checkpoint on the dev split.
        correct = 0
        total = 0
        encoder_, model_ = torch.load("checkpoints/%s" % directory_identifier)
        encoder.load_state_dict(encoder_)
        model.load_state_dict(model_)
        encoder.eval()
        with torch.no_grad():
            for ids, mask, label in dev_dataloader:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                prediction = model(encoder, ids, mask).argmax(dim=-1)
                correct += (prediction == label).to(torch.long).sum().item()
                total += mask.shape[0]
        print("dev acc:", correct / total)

        # "Rectification": fine-tune briefly on the dev split with a small learning
        # rate, then re-evaluate on it.
        opt = AdamW(lr=1e-6, weight_decay=0.05,
                    params=list(encoder.parameters()) + list(model.parameters()))
        encoder.train()
        iter_num = 0
        LOSS = []
        for _ in range(5):
            iterator = tqdm.tqdm(dev_dataloader)
            for ids, mask, label in iterator:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                log_prediction = model(encoder, ids, mask)
                # Negative log-likelihood of the gold label (the head is assumed to
                # return log-probabilities).
                loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
                if iter_num % args.iter_per == 0:
                    opt.zero_grad()
                (loss / args.iter_per).backward()
                LOSS.append(loss.item())
                if len(LOSS) > 10:
                    iterator.write("loss=%f" % np.mean(LOSS))
                    LOSS = []
                if iter_num % args.iter_per == args.iter_per - 1:
                    opt.step()
                iter_num += 1

        encoder.eval()
        # Reset the counters so the rectified accuracy is measured on its own.
        correct = 0
        total = 0
        with torch.no_grad():
            for ids, mask, label in dev_dataloader:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                prediction = model(encoder, ids, mask).argmax(dim=-1)
                correct += (prediction == label).to(torch.long).sum().item()
                total += mask.shape[0]
        print("dev acc rectified:", correct / total)

        # Write the Kaggle-style submission file from the test split.
        with torch.no_grad():
            with open("submission.csv", "w") as fout:
                print("id,target", file=fout)
                for tweet_id, ids, mask in test_dataloader:
                    ids, mask = ids.cuda(), mask.cuda()
                    prediction = model(encoder, ids, mask).argmax(dim=-1)
                    for i in range(tweet_id.size(0)):
                        print("%d,%d" % (tweet_id[i], prediction[i]), file=fout)
        return

    # Training mode: pick the learning rate by model type.
    if args.baseline:
        lr = 5e-4
    elif args.load_weights:
        lr = 1e-6
    else:
        lr = 5e-6
    opt = AdamW(lr=lr, weight_decay=0.10 if args.baseline else 0.05,
                params=list(encoder.parameters()) + list(model.parameters()))

    # Truncate the training and evaluation logs for this run.
    flog = open("checkpoints/log-%s.txt" % directory_identifier, "w")
    flog.close()
    flogeval = open("checkpoints/evallog-%s.txt" % directory_identifier, "w")
    flogeval.close()

    iter_num = 0
    for epoch_idx in range(5 if args.baseline else 10):
        flog = open("checkpoints/log-%s.txt" % directory_identifier, "a")
        flogeval = open("checkpoints/evallog-%s.txt" % directory_identifier, "a")
        LOSS = []
        encoder.train()
        iterator = tqdm.tqdm(dataloader)
        for ids, mask, label in iterator:
            ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
            log_prediction = model(encoder, ids, mask)
            loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
            if iter_num % args.iter_per == 0:
                opt.zero_grad()
            (loss / args.iter_per).backward()
            LOSS.append(loss.item())
            if len(LOSS) > 10:
                iterator.write("loss=%f" % np.mean(LOSS))
                print("%f" % np.mean(LOSS), file=flog)
                LOSS = []
            if iter_num % args.iter_per == args.iter_per - 1:
                opt.step()
            iter_num += 1

        # Per-epoch evaluation loss on the dev split.
        EVALLOSS = []
        encoder.eval()
        iterator = tqdm.tqdm(dev_dataloader)
        with torch.no_grad():
            for ids, mask, label in iterator:
                ids, mask, label = ids.cuda(), mask.cuda(), label.cuda()
                log_prediction = model(encoder, ids, mask)
                loss = -log_prediction[torch.arange(ids.size(0)).cuda(), label].mean()
                EVALLOSS.append(loss.item())
        iterator.write("evalloss-%d=%f" % (epoch_idx, np.mean(EVALLOSS)))
        print("%f" % np.mean(EVALLOSS), file=flogeval)
        flog.close()
        flogeval.close()
        torch.save((encoder.state_dict(), model.state_dict()),
                   "checkpoints/%s" % directory_identifier)
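
if __name__ == "__main__":
    main()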