import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import transformers
from transformers import RobertaForMaskedLM, RobertaTokenizer
# Older torchmetrics API; newer releases rename F1 to F1Score.
from torchmetrics import Accuracy, F1


class LightningLongformerCLS(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.train_config = config

        # Frozen RoBERTa masked-LM backbone: it is used only for inference,
        # so it stays in eval mode with gradients disabled.
        self.roberta = RobertaForMaskedLM.from_pretrained('roberta-base')
        _ = self.roberta.eval()
        for param in self.roberta.parameters():
            param.requires_grad = False

        # pred_model produces contextual MLM predictions; enc_model is the
        # static word-embedding lookup table of the same model.
        self.pred_model = self.roberta.roberta
        self.enc_model = self.pred_model.embeddings.word_embeddings

        # Only the projection head is trained.
        # self.proj_head = DVProjectionHead()
        # self.proj_head = DVProjectionHead_ActiFirst()
        self.proj_head = DVProjectionHead_EmbActi()

        self.tkz = RobertaTokenizer.from_pretrained("roberta-base")
        self.collator = TokenizerCollate(self.tkz)

        self.lossfunc = nn.BCEWithLogitsLoss()
        self.acc = Accuracy(threshold=0.0)
        self.f1 = F1(threshold=0.0)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.train_config["learning_rate"])
        scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=10,
            num_training_steps=5000,
            num_cycles=10)
        schedulers = [{
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }]
        return [optimizer], schedulers

    def train_dataloader(self):
        # self.dataset_train = PANDataset('./data_pickle_cutcombo/pan_all_cls/train_kucombo_only.pickle')
        # self.dataset_train = PANDataset('./data_pickle_cutcombo/pan_14e_cls/train_essays.pickle')
        self.dataset_train = PANDataset(
            './data_pickle_cutcombo/pan_14n_cls/train_novels_kucombo_only.pickle')
        self.loader_train = DataLoader(self.dataset_train,
                                       batch_size=self.train_config["batch_size"],
                                       collate_fn=self.collator,
                                       num_workers=4,
                                       pin_memory=True,
                                       drop_last=False,
                                       shuffle=False)
        return self.loader_train

    def val_dataloader(self):
        # self.dataset_val = PANDataset('./data_pickle_cutcombo/pan_14e_cls/test02_essays_onecut.pickle')
        self.dataset_val = PANDataset(
            './data_pickle_cutcombo/pan_14n_cls/test02_novels_onecut.pickle')
        self.loader_val = DataLoader(self.dataset_val,
                                     batch_size=self.train_config["batch_size"],
                                     collate_fn=self.collator,
                                     num_workers=4,
                                     pin_memory=True,
                                     drop_last=False,
                                     shuffle=False)
        return self.loader_val

    def test_dataloader(self):
        # self.dataset_test = PANDataset('./data_pickle_cutcombo/pan_14e_cls/test02_essays_onecut.pickle')
        self.dataset_test = PANDataset(
            './data_pickle_cutcombo/pan_14n_cls/test02_novels_onecut.pickle')
        self.loader_test = DataLoader(self.dataset_test,
                                      batch_size=self.train_config["batch_size"],
                                      collate_fn=self.collator,
                                      num_workers=4,
                                      pin_memory=True,
                                      drop_last=False,
                                      shuffle=False)
        return self.loader_test

    @autocast()
    def forward(self, inputs, onedoc_enc=False):

        def one_doc_embed(input_ids, input_mask, mask_n=1):
            # Deduplicate identical rows so each unique document is encoded
            # only once; inverse_indices maps rows back to the original batch.
            uniq_mask = []
            uniq_input, inverse_indices = torch.unique(input_ids,
                                                       return_inverse=True,
                                                       dim=0)
            invi = inverse_indices.detach().cpu().numpy()
            for i in range(uniq_input.shape[0]):
                first_index = np.where(invi == i)[0][0]
                uniq_mask.append(input_mask[first_index, :])
            input_ids = uniq_input
            input_mask = torch.stack(uniq_mask, dim=0)

            # Static (non-contextual) embeddings of the unmasked tokens.
            embed = self.enc_model(input_ids)

            result_embed = []
            result_pred = []
            # Slide a mask of width mask_n over every position, skipping the
            # start and end symbols, and collect the MLM prediction at each
            # masked position alongside the static embedding there. The mask
            # is undone at the end of every iteration.
            masked_ids = input_ids.clone()
            for i in range(1, input_ids.shape[1] - mask_n):
                masked_ids[:, i:(i + mask_n)] = self.tkz.mask_token_id
                output = self.pred_model(input_ids=masked_ids,
                                         attention_mask=input_mask,
                                         return_dict=False)[0]
                result_embed.append(embed[:, i:(i + mask_n), :])
                result_pred.append(output[:, i:(i + mask_n), :])
                masked_ids[:, i:(i + mask_n)] = input_ids[:, i:(i + mask_n)]

            # Stack along the document-length dimension.
            result_embed = torch.cat(result_embed, dim=1)
            result_pred = torch.cat(result_pred, dim=1)

            # Undo the deduplication: expand the unique results back to the
            # original batch order.
            rec_embed = []
            rec_pred = []
            for i in invi:
                rec_embed.append(result_embed[i, :, :])
                rec_pred.append(result_pred[i, :, :])
            rec_embed = torch.stack(rec_embed, dim=0)
            rec_pred = torch.stack(rec_pred, dim=0)

            return rec_embed, rec_pred

        if onedoc_enc:
            # Single-document encoding path: return the contextual
            # predictions, static embeddings, and their difference (the
            # "deviation vector", DV).
            doc_ids, doc_mask = inputs
            doc_embed, doc_pred = one_doc_embed(input_ids=doc_ids,
                                                input_mask=doc_mask)
            doc_dv = doc_pred - doc_embed
            return doc_pred, doc_embed, doc_dv
        else:
            labels, kno_ids, kno_mask, unk_ids, unk_mask = inputs
            kno_embed, kno_pred = one_doc_embed(input_ids=kno_ids,
                                                input_mask=kno_mask)
            unk_embed, unk_pred = one_doc_embed(input_ids=unk_ids,
                                                input_mask=unk_mask)

            kno_dv = kno_pred - kno_embed
            unk_dv = unk_pred - unk_embed

            # Masks are trimmed by one position on each side to match the
            # start/end symbols skipped in one_doc_embed.
            # logits = self.proj_head(kno_dv, kno_mask[:, 1:-1], unk_dv, unk_mask[:, 1:-1])
            logits = self.proj_head(kno_embed, kno_dv, kno_mask[:, 1:-1],
                                    unk_embed, unk_dv, unk_mask[:, 1:-1])
            logits = torch.squeeze(logits)

            labels = labels.float()
            loss = self.lossfunc(logits, labels)

            return (loss, logits, (kno_embed, kno_pred, unk_embed, unk_pred))

    def training_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, unk_ids, unk_mask = batch
        loss, logits, outputs = self((labels, kno_ids, kno_mask, unk_ids, unk_mask))
        self.log("train_loss", loss)
        self.log("logits mean", logits.mean())
        self.log("LR", self.trainer.optimizers[0].param_groups[0]['lr'])
        return loss

    def validation_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, unk_ids, unk_mask = batch
        loss, logits, outputs = self((labels, kno_ids, kno_mask, unk_ids, unk_mask))
        self.acc(logits, labels.float())
        self.f1(logits, labels.float())
        return {"val_loss": loss}

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack([x['val_loss'] for x in validation_step_outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log('eval accuracy', self.acc.compute())
        self.log('eval F1', self.f1.compute())
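
# Hypothetical precomputation sketch: the variant below consumes per-token
# embeddings and MLM predictions loaded from ".pt" files ("KUEP" =
# known/unknown embedding + prediction). Tensors of that shape can be produced
# with this class's onedoc_enc=True path. The tokenizer call, save path, and
# export format here are illustrative assumptions, not the project's actual
# export script; pairing documents into known/unknown combos happens elsewhere.
tkz = RobertaTokenizer.from_pretrained("roberta-base")
model = LightningLongformerCLS({"learning_rate": 1e-4, "batch_size": 8})
model = model.cuda().eval()

enc = tkz("An example document.", return_tensors="pt",
          padding="max_length", truncation=True, max_length=128)
with torch.no_grad():
    doc_pred, doc_embed, doc_dv = model(
        (enc["input_ids"].cuda(), enc["attention_mask"].cuda()),
        onedoc_enc=True)

# doc_embed/doc_pred cover positions 1..L-2 (start/end symbols are skipped),
# which is why the projection head receives mask[:, 1:-1].
torch.save({"embed": doc_embed.cpu(), "pred": doc_pred.cpu()},
           "example_KUEP.pt")  # hypothetical filename
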
# Variant of the classifier that consumes precomputed embeddings and MLM
# predictions (KUEP batches) instead of running the frozen RoBERTa at
# training time. Defining it redefines the class name above.
class LightningLongformerCLS(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.train_config = config

        # The frozen backbone is kept for parity with the variant above, but
        # this version never calls it in forward().
        self.roberta = RobertaForMaskedLM.from_pretrained('roberta-base')
        _ = self.roberta.eval()
        for param in self.roberta.parameters():
            param.requires_grad = False
        self.pred_model = self.roberta.roberta
        self.enc_model = self.pred_model.embeddings.word_embeddings

        self.proj_head = DVProjectionHead_EmbActi()

        self.lossfunc = nn.BCEWithLogitsLoss()
        self.acc = Accuracy(threshold=0.0)
        self.f1 = F1(threshold=0.0)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.train_config["learning_rate"])
        return optimizer
        # scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
        #     optimizer,
        #     num_warmup_steps=20,
        #     num_training_steps=10000,
        #     num_cycles=10)
        # schedulers = [{
        #     'scheduler': scheduler,
        #     'interval': 'step',
        #     'frequency': 1
        # }]
        # return [optimizer], schedulers

    def train_dataloader(self):
        self.dataset_train = PANDatasetKUEP(
            "data_pickle_cutcombo/pan_13_cls/train_KUEP_combo.pt")
        # self.dataset_train = PANDatasetKUEP("data_pickle_cutcombo/pan_14e_cls/train_KUEP_combo.pt")
        # self.dataset_train = PANDatasetKUEP("data_pickle_cutcombo/pan_14n_cls/train_KUEP_combo110k.pt")
        # self.dataset_train = PANDatasetKUEP("data_pickle_cutcombo/pan_15_cls/train_KUEP_combo.pt")
        self.loader_train = DataLoader(self.dataset_train,
                                       batch_size=self.train_config["batch_size"],
                                       collate_fn=TokenizerCollateKUEP(),
                                       num_workers=1,
                                       pin_memory=True,
                                       drop_last=False,
                                       shuffle=True)
        return self.loader_train

    def val_dataloader(self):
        self.dataset_val = PANDatasetKUEP(
            "data_pickle_cutcombo/pan_13_cls/test02_KUEP.pt",
            # self.dataset_val = PANDatasetKUEP("data_pickle_cutcombo/pan_14e_cls/test02_essays_KUEP.pt",
            # self.dataset_val = PANDatasetKUEP("data_pickle_cutcombo/pan_14n_cls/test02_KUEP.pt",
            # self.dataset_val = PANDatasetKUEP("data_pickle_cutcombo/pan_15_cls/test_KUEP.pt",
            test_1cut=True)
        self.loader_val = DataLoader(self.dataset_val,
                                     batch_size=self.train_config["batch_size"],
                                     collate_fn=TokenizerCollateKUEP(),
                                     num_workers=1,
                                     pin_memory=True,
                                     drop_last=False,
                                     shuffle=False)
        return self.loader_val

    @autocast()
    def forward(self, inputs):
        # Embeddings and MLM predictions arrive precomputed in the batch, so
        # the forward pass only forms the deviation vectors and scores them
        # with the projection head.
        labels, kno_ids, kno_mask, kno_embed, kno_pred, \
            unk_ids, unk_mask, unk_embed, unk_pred = inputs

        kno_dv = kno_pred - kno_embed
        unk_dv = unk_pred - unk_embed

        logits = self.proj_head(kno_embed, kno_dv, kno_mask[:, 1:-1],
                                unk_embed, unk_dv, unk_mask[:, 1:-1])
        logits = torch.squeeze(logits)

        labels = labels.float()
        loss = self.lossfunc(logits, labels)

        return (loss, logits, (kno_embed, kno_pred, unk_embed, unk_pred))

    def training_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, kno_embed, kno_pred, \
            unk_ids, unk_mask, unk_embed, unk_pred = batch
        loss, logits, outputs = self(batch)
        self.log("train_loss", loss)
        self.log("logits mean", logits.mean())
        self.log("LR", self.trainer.optimizers[0].param_groups[0]['lr'])
        return loss

    def validation_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, kno_embed, kno_pred, \
            unk_ids, unk_mask, unk_embed, unk_pred = batch
        loss, logits, outputs = self(batch)
        self.acc(logits, labels.float())
        self.f1(logits, labels.float())
        return {"val_loss": loss}

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack([x['val_loss'] for x in validation_step_outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log('eval accuracy', self.acc.compute())
        self.log('eval F1', self.f1.compute())
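
# Minimal training sketch (hypothetical hyperparameter values): both variants
# build their own dataloaders inside the LightningModule, so a bare Trainer is
# enough to run them. precision=16 matches the @autocast() forward pass; the
# gpus/precision arguments follow the older pytorch_lightning API that the
# validation_epoch_end hook above implies.
train_config = {"learning_rate": 1e-4, "batch_size": 8}
model = LightningLongformerCLS(train_config)
trainer = pl.Trainer(gpus=1, precision=16, max_steps=10000)
trainer.fit(model)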