from typing import Any, List

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import Adam
from torchmetrics import MetricCollection

# Project-specific components defined elsewhere in the repository:
# LstmTagger, Precision, Recall, TopkAccuracy, SequenceCrossEntropyLoss,
# DeepAnalyzeAttention, CRF, get_label_scores.


class TrainingModule(pl.LightningModule):
    def __init__(self, tagger: LstmTagger):
        super().__init__()
        self.tagger = tagger
        self.train_metrics = MetricCollection([Precision(), Recall(), TopkAccuracy(1)])
        self.val_metrics = MetricCollection([Precision(), Recall(), TopkAccuracy(1)])
        self.softmax = torch.nn.Softmax(dim=-1)
        self.celoss = SequenceCrossEntropyLoss(reduction="batch-mean", pad_idx=2)

    def training_step(self, batch, batch_idx):
        reports, target, masks = batch
        mask = torch.cat(masks, dim=1)

        if self.tagger.with_crf:
            # CRF branch: the loss is the negative sequence log-likelihood of the emissions.
            emissions = torch.cat(
                [self.tagger.calc_emissions(report, mask) for report, mask in zip(reports, masks)],
                dim=1,
            )
            loss = -self.tagger.crf(emissions, target, mask)
        else:
            scores = self.tagger.forward(reports, masks)
            loss = self.celoss(scores, target)

        with torch.no_grad():
            # Metrics are computed on a separate forward pass without tracking gradients.
            scores = self.tagger.forward(reports, masks)
            preds = scores.argmax(dim=-1)
            scores = self.softmax(scores)
            self.train_metrics.update(preds, target, mask, scores=scores)

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, *args):
        reports, target, masks = batch
        mask = torch.cat(masks, dim=1)

        if self.tagger.with_crf:
            emissions = torch.cat(
                [self.tagger.calc_emissions(report, mask) for report, mask in zip(reports, masks)],
                dim=1,
            )
            loss = -self.tagger.crf(emissions, target, mask)
        else:
            scores = self.tagger.forward(reports, masks)
            loss = self.celoss(scores, target)

        with torch.no_grad():
            scores = self.tagger.forward(reports, masks)
            preds = scores.argmax(dim=-1)
            scores = self.softmax(scores)
            self.val_metrics.update(preds, target, mask, scores=scores)

        return loss

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        super().validation_epoch_end(outputs)
        # self.log() does not accept dictionaries, so the metric dict is logged via log_dict.
        metrics = self.val_metrics.compute()
        self.log_dict({f"val_{name}": value for name, value in metrics.items()})
        print(metrics)
        self.val_metrics.reset()

    def training_epoch_end(self, outputs: List[Any]) -> None:
        super().training_epoch_end(outputs)
        metrics = self.train_metrics.compute()
        self.log_dict({f"train_{name}": value for name, value in metrics.items()})
        self.train_metrics.reset()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-4, weight_decay=1e-5)
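# Minimal usage sketch (not part of the original source): wiring TrainingModule into a
# pytorch_lightning Trainer. `tagger_config` and `make_dataloaders()` are hypothetical
# placeholders; the real data pipeline must yield (reports, target, masks) batches in the
# layout expected by training_step / validation_step above.
tagger = LstmTagger(**tagger_config)                # project-specific model, configured elsewhere
module = TrainingModule(tagger)
train_loader, val_loader = make_dataloaders()       # hypothetical helper
trainer = pl.Trainer(max_epochs=20, gradient_clip_val=1.0)
trainer.fit(module, train_loader, val_loader)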
class DeepAnalyze(pl.LightningModule):
    def __init__(self, feature_size, lstm_hidden_size, lstm_num_layers, n_tags, max_len):
        super().__init__()
        self.padding = 0
        self.bi_lstm = nn.LSTM(feature_size, lstm_hidden_size,
                               num_layers=lstm_num_layers, bidirectional=True)
        self.attention = DeepAnalyzeAttention(lstm_hidden_size * 2, n_tags, max_len)
        self.crf = CRF(n_tags)
        self.lstm_dropout = nn.Dropout(0.25)
        self.train_metrics = MetricCollection([Precision(), Recall(), TopkAccuracy(3)])
        self.val_metrics = MetricCollection([Precision(), Recall(), TopkAccuracy(3)])

    def forward(self, inputs, mask):
        seq_len, batch_size = mask.shape
        x, _ = self.bi_lstm(inputs)
        x = self.lstm_dropout(x)
        x = self.attention(x, mask)
        # Viterbi decoding returns variable-length label lists; pad them back to seq_len
        # and stack into a (seq_len, batch_size) tensor.
        preds = self.crf.decode(x, mask)
        preds = [pred + [0] * (seq_len - len(pred)) for pred in preds]
        preds = torch.tensor(preds).transpose(0, 1).to(inputs.device)
        return preds

    def training_step(self, batch, batch_idx):
        inputs, labels, mask = batch
        x, _ = self.bi_lstm(inputs)
        x = self.lstm_dropout(x)
        emissions = self.attention(x, mask)
        loss = -self.crf(emissions, labels, mask)

        with torch.no_grad():
            preds = self.forward(inputs, mask)
            self.train_metrics.update(preds, labels, mask,
                                      scores=get_label_scores(self.crf, emissions, preds, mask))

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, *args):
        inputs, labels, mask = batch
        x, _ = self.bi_lstm(inputs)
        x = self.lstm_dropout(x)
        emissions = self.attention(x, mask)
        loss = -self.crf(emissions, labels, mask)

        with torch.no_grad():
            preds = self.forward(inputs, mask)
            self.val_metrics.update(preds, labels, mask,
                                    scores=get_label_scores(self.crf, emissions, preds, mask))

        return loss

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        super().validation_epoch_end(outputs)
        metrics = self.val_metrics.compute()
        self.log_dict({f"val_{name}": value for name, value in metrics.items()})
        self.val_metrics.reset()

    def training_epoch_end(self, outputs: List[Any]) -> None:
        super().training_epoch_end(outputs)
        metrics = self.train_metrics.compute()
        self.log_dict({f"train_{name}": value for name, value in metrics.items()})
        self.train_metrics.reset()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)
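# Self-contained sketch of the CRF pattern DeepAnalyze relies on, assuming `CRF` is the
# class from the pytorch-crf package (an assumption; the original import is not shown).
# Calling the CRF returns the sequence log-likelihood, so the training loss is its
# negation, and decode() performs Viterbi decoding into per-sequence label lists whose
# lengths follow the mask, which is why forward() above pads them back to seq_len.
from torchcrf import CRF as TorchCRF

seq_len, batch_size, n_tags = 7, 2, 5
crf = TorchCRF(n_tags)                               # (seq_len, batch) layout by default
emissions = torch.randn(seq_len, batch_size, n_tags)
tags = torch.randint(0, n_tags, (seq_len, batch_size))
mask = torch.ones(seq_len, batch_size, dtype=torch.bool)
mask[5:, 1] = False                                  # second sequence is two steps shorter

loss = -crf(emissions, tags, mask)                   # negative log-likelihood, as in training_step
decoded = crf.decode(emissions, mask)                # list of per-sequence label lists
print(loss.item(), [len(seq) for seq in decoded])    # the masked sequence yields fewer labels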