import pytorch_lightning as pl
import torch
from sklearn.metrics import precision_score, recall_score
from torch.optim import Adam
from transformers import BertForTokenClassification

# `CRF` must come from a sequence-labeling CRF implementation whose forward
# pass decodes tag sequences and which exposes a .loss(emissions, tags) method.


class BertCRFClassifier(pl.LightningModule):
    """Version 1: feeds BERT token-classification logits into a CRF layer."""

    def __init__(self, num_classes: int, bert_weights: str):
        super().__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            bert_weights, num_labels=num_classes, output_hidden_states=True)
        self.crf = CRF(num_tags=num_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        # bert_out: per-token logits of shape (batch, seq_len, num_classes).
        bert_out, _ = self.bert(input_ids, attention_mask=attention_mask)
        crf_out = self.crf(bert_out)
        return crf_out, bert_out

    def crf_loss(self, pred_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self.crf.loss(pred_logits, labels)

    def training_step(self, batch, batch_idx):
        input_ids, mask, labels = batch
        preds, logits = self.forward(input_ids, attention_mask=mask)
        loss = self.crf_loss(logits, labels)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, mask, labels = batch
        preds, logits = self.forward(input_ids, attention_mask=mask)
        loss = self.crf_loss(logits, labels)
        labels = labels.detach().cpu().numpy().flatten()
        preds = preds.detach().cpu().numpy().flatten()
        recall = torch.tensor(recall_score(labels, preds, average="macro"))
        precision = torch.tensor(precision_score(labels, preds, average="macro"))
        return {'val_loss': loss, 'recall': recall, 'precision': precision}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_recall = torch.stack([x['recall'] for x in outputs]).mean()
        avg_precision = torch.stack([x['precision'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss,
                            'avg_val_recall': avg_recall,
                            'avg_val_precision': avg_precision}
        return {'avg_val_loss': avg_loss,
                'avg_val_recall': avg_recall,
                'avg_val_precision': avg_precision,
                'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        # No weight decay on bias/LayerNorm-style parameters; use
        # self.named_parameters() so the CRF parameters are covered too.
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.named_parameters()
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        return Adam(optimizer_grouped_parameters, lr=2e-5)

    # Older Lightning API: dataloader hooks return module-level loaders.
    @pl.data_loader
    def train_dataloader(self):
        return train_dataloader_

    @pl.data_loader
    def val_dataloader(self):
        return val_dataloader_
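# The second version of the classifier below uses a `SelfAttention` pooling
# module that is not defined in this file. What follows is a minimal sketch of
# such a module, assuming it takes (batch, seq_len, hidden) inputs plus the
# per-sentence token counts and returns a (batch, hidden) attention-weighted
# summary; the original implementation may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Additive self-attention pooling over a padded batch of sequences."""

    def __init__(self, hidden_size: int, batch_first: bool = True):
        super().__init__()
        self.batch_first = batch_first
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, hidden: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        if not self.batch_first:
            hidden = hidden.transpose(0, 1)       # -> (batch, seq_len, hidden)
        scores = self.scorer(hidden).squeeze(-1)  # (batch, seq_len)
        # Mask out positions past each sentence's length before the softmax.
        positions = torch.arange(hidden.size(1), device=hidden.device)
        pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        scores = scores.masked_fill(pad_mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)       # (batch, seq_len)
        # Attention-weighted sum of the hidden states: (batch, hidden).
        return torch.bmm(weights.unsqueeze(1), hidden).squeeze(1)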
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score
from transformers import AdamW, BertModel


class BertCRFClassifier(pl.LightningModule):
    """Version 2: mostly-frozen BERT with two heads -- a per-token span
    tagger decoded by a CRF and an attention-pooled binary sentence
    classifier -- trained with a combined loss."""

    def __init__(self, num_classes: int, bert_weights: str, dropout: float = 0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_weights)
        # Freeze everything except the last five parameter tensors of BERT.
        for param in list(self.bert.parameters())[:-5]:
            param.requires_grad = False
        hidden_size = self.bert.config.hidden_size
        self.span_clf_head = nn.Linear(hidden_size, num_classes)
        self.binary_clf_head = nn.Linear(hidden_size, 2)
        self.attention = SelfAttention(hidden_size, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        self.crf = CRF(num_tags=num_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                sent_lens: torch.Tensor):
        # sequence_output: (batch, seq_len, hidden); the pooled output is unused.
        sequence_output, _ = self.bert(input_ids, attention_mask=attention_mask)
        # Per-token emissions for the CRF span tagger.
        span_logits = self.dropout(self.span_clf_head(sequence_output))
        crf_out = self.crf(span_logits)
        # Attention-pooled sentence vector for the binary head.
        pooled = self.attention(sequence_output, sent_lens)
        bin_logits = self.dropout(self.binary_clf_head(pooled))
        return crf_out, span_logits, bin_logits

    def crf_loss(self, pred_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self.crf.loss(pred_logits, labels)

    def training_step(self, batch, batch_idx):
        input_ids, mask, labels = batch
        # Token counts come from the attention mask; tensors derived from the
        # batch already live on the right device under Lightning.
        sent_lengths = torch.sum(mask, dim=1).long()
        # A sentence is a positive example for the binary head if it
        # contains at least one non-O (non-zero) tag.
        bin_labels = (torch.sum(labels, dim=1) > 0).long()
        preds, span_logits, bin_logits = self.forward(
            input_ids, attention_mask=mask, sent_lens=sent_lengths)
        span_loss = self.crf_loss(span_logits, labels)
        bin_loss = F.cross_entropy(bin_logits, bin_labels)
        combined_loss = span_loss + bin_loss
        tensorboard_logs = {'train_loss': combined_loss}
        return {'loss': combined_loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, mask, labels = batch
        sent_lengths = torch.sum(mask, dim=1).long()
        bin_labels = (torch.sum(labels, dim=1) > 0).long()
        preds, span_logits, bin_logits = self.forward(
            input_ids, attention_mask=mask, sent_lens=sent_lengths)
        span_loss = self.crf_loss(span_logits, labels)
        bin_loss = F.cross_entropy(bin_logits, bin_labels)
        combined_loss = span_loss + bin_loss
        labels = labels.detach().cpu().numpy().flatten()
        bin_labels = bin_labels.detach().cpu().numpy().flatten()
        span_preds = preds.detach().cpu().numpy().flatten()
        bin_preds = torch.argmax(bin_logits, dim=1).detach().cpu().numpy().flatten()
        span_recall = torch.tensor(recall_score(labels, span_preds, average="macro"))
        bin_recall = torch.tensor(recall_score(bin_labels, bin_preds, average="macro"))
        span_precision = torch.tensor(precision_score(labels, span_preds, average="macro"))
        bin_precision = torch.tensor(precision_score(bin_labels, bin_preds, average="macro"))
        return {'val_loss': combined_loss,
                'span_recall': span_recall, 'bin_recall': bin_recall,
                'span_precision': span_precision, 'bin_precision': bin_precision}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_span_recall = torch.stack([x['span_recall'] for x in outputs]).mean()
        avg_bin_recall = torch.stack([x['bin_recall'] for x in outputs]).mean()
        avg_span_precision = torch.stack([x['span_precision'] for x in outputs]).mean()
        avg_bin_precision = torch.stack([x['bin_precision'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss,
                'avg_val_span_recall': avg_span_recall,
                'avg_val_bin_recall': avg_bin_recall,
                'avg_val_span_precision': avg_span_precision,
                'avg_val_bin_precision': avg_bin_precision,
                'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        # No weight decay on bias and LayerNorm parameters; use
        # self.named_parameters() so the heads and CRF are covered too.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.named_parameters()
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        return AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)

    @pl.data_loader
    def train_dataloader(self):
        return train_dataloader_

    @pl.data_loader
    def val_dataloader(self):
        return val_dataloader_
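# A minimal end-to-end sketch, assuming `train_dataloader_` and
# `val_dataloader_` exist at module scope (the dataloader hooks above return
# them directly) and that a compatible CRF implementation has been imported.
# The checkpoint name and hyperparameters are illustrative only.
if __name__ == '__main__':
    model = BertCRFClassifier(num_classes=4, bert_weights='bert-base-uncased')
    trainer = pl.Trainer(gpus=1)  # older Lightning API, matching @pl.data_loader
    trainer.fit(model)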