Example #1
# Imports reconstructed for completeness. The project-specific classes
# (LstmTagger, Precision, Recall, TopkAccuracy, SequenceCrossEntropyLoss)
# are assumed to be defined elsewhere in this codebase; their update()/call
# signatures do not match the torchmetrics built-ins.
from typing import Any, List

import torch
import pytorch_lightning as pl
from torch.optim import Adam
from torchmetrics import MetricCollection  # assumed source

class TrainingModule(pl.LightningModule):
    def __init__(self, tagger: LstmTagger):
        super().__init__()
        self.tagger = tagger

        self.train_metrics = MetricCollection([Precision(), Recall(), TopkAccuracy(1)])
        self.val_metrics = MetricCollection([Precision(), Recall(), TopkAccuracy(1)])

        self.softmax = torch.nn.Softmax(dim=-1)
        self.celoss = SequenceCrossEntropyLoss(reduction="batch-mean", pad_idx=2)

    def training_step(self, batch, batch_idx):
        reports, target, masks = batch
        mask = torch.cat(masks, dim=1)

        if self.tagger.with_crf:
            # Concatenate per-report emissions along the batch dimension and
            # use the negated CRF log-likelihood as the loss.
            emissions = torch.cat(
                [self.tagger.calc_emissions(report, m) for report, m in zip(reports, masks)],
                dim=1,
            )
            loss = -self.tagger.crf(emissions, target, mask)
        else:
            scores = self.tagger.forward(reports, masks)
            loss = self.celoss(scores, target)

        # Recompute scores without gradient tracking for metric updates.
        with torch.no_grad():
            scores = self.tagger.forward(reports, masks)
            preds = scores.argmax(dim=-1)

        scores = self.softmax(scores)
        self.train_metrics.update(preds, target, mask, scores=scores)

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, *args):
        reports, target, masks = batch
        mask = torch.cat(masks, dim=1)
        if self.tagger.with_crf:
            emissions = torch.cat(
                [self.tagger.calc_emissions(report, m) for report, m in zip(reports, masks)],
                dim=1,
            )
            loss = -self.tagger.crf(emissions, target, mask)
        else:
            scores = self.tagger.forward(reports, masks)
            loss = self.celoss(scores, target)

        # Recompute scores without gradient tracking for metric updates.
        with torch.no_grad():
            scores = self.tagger.forward(reports, masks)
            preds = scores.argmax(dim=-1)

        scores = self.softmax(scores)
        self.val_metrics.update(preds, target, mask, scores=scores)

        return loss

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        # *_epoch_end hooks follow the pre-2.0 PyTorch Lightning API.
        super().validation_epoch_end(outputs)
        metrics = self.val_metrics.compute()
        # compute() returns a dict of metric values, so log_dict is needed here.
        self.log_dict(metrics)
        print(metrics)
        self.val_metrics.reset()

    def training_epoch_end(self, outputs: List[Any]) -> None:
        super().training_epoch_end(outputs)
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-4, weight_decay=1e-5)
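
A minimal wiring sketch, not part of the original example: the LstmTagger
constructor, the dataloaders, and the Trainer settings below are hypothetical
placeholders; only TrainingModule comes from the code above.

# Hypothetical usage sketch; every name below except TrainingModule is a placeholder.
tagger = LstmTagger(...)  # project-specific constructor; arguments unknown
module = TrainingModule(tagger)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(module, train_dataloader, val_dataloader)  # placeholder dataloaders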
Example #2
# Imports reconstructed for completeness. DeepAnalyzeAttention, the custom
# metrics (Precision, Recall, TopkAccuracy), and get_label_scores are assumed
# to be defined elsewhere in this codebase. CRF is assumed to come from the
# pytorch-crf package, whose default sequence-first layout matches the usage
# below.
from typing import Any, List

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import Adam
from torchcrf import CRF  # assumed source
from torchmetrics import MetricCollection  # assumed source

class DeepAnalyze(pl.LightningModule):
    def __init__(self, feature_size, lstm_hidden_size, lstm_num_layers, n_tags,
                 max_len):
        super().__init__()
        self.padding = 0
        self.bi_lstm = nn.LSTM(feature_size,
                               lstm_hidden_size,
                               num_layers=lstm_num_layers,
                               bidirectional=True)
        self.attention = DeepAnalyzeAttention(lstm_hidden_size * 2, n_tags,
                                              max_len)
        self.crf = CRF(n_tags)
        self.lstm_dropout = nn.Dropout(0.25)

        self.train_metrics = MetricCollection(
            [Precision(), Recall(), TopkAccuracy(3)])
        self.val_metrics = MetricCollection(
            [Precision(), Recall(), TopkAccuracy(3)])

    def forward(self, inputs, mask):
        seq_len, _ = mask.shape

        x, _ = self.bi_lstm(inputs)
        x = self.lstm_dropout(x)
        x = self.attention(x, mask)
        preds = self.crf.decode(x, mask)

        # CRF decoding returns variable-length label lists; pad them back to
        # seq_len and restore the (seq_len, batch) layout used elsewhere.
        preds = [pred + [self.padding] * (seq_len - len(pred)) for pred in preds]
        preds = torch.tensor(preds, device=inputs.device).transpose(0, 1)

        return preds

    def training_step(self, batch, batch_idx):
        inputs, labels, mask = batch
        x, _ = self.bi_lstm(inputs)
        x = self.lstm_dropout(x)
        emissions = self.attention(x, mask)

        # The CRF forward pass returns a log-likelihood, so negate it for a loss.
        loss = -self.crf(emissions, labels, mask)

        # Decode predictions without gradient tracking for metric updates.
        with torch.no_grad():
            preds = self.forward(inputs, mask)

        self.train_metrics.update(preds,
                                  labels,
                                  mask,
                                  scores=get_label_scores(
                                      self.crf, emissions, preds, mask))

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, *args):
        inputs, labels, mask = batch
        x, _ = self.bi_lstm(inputs)
        x = self.lstm_dropout(x)
        emissions = self.attention(x, mask)
        loss = -self.crf(emissions, labels, mask)

        # Decode predictions without gradient tracking for metric updates.
        with torch.no_grad():
            preds = self.forward(inputs, mask)

        self.val_metrics.update(preds,
                                labels,
                                mask,
                                scores=get_label_scores(
                                    self.crf, emissions, preds, mask))
        return loss

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        super().validation_epoch_end(outputs)
        # compute() returns a dict of metric values, so log_dict is needed here.
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def training_epoch_end(self, outputs: List[Any]) -> None:
        super().training_epoch_end(outputs)
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)
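
A minimal smoke-test sketch, not part of the original example: the
hyperparameter values and random tensors are hypothetical, chosen only to
illustrate the sequence-first (seq_len, batch, feature) layout the module
expects.

# Hypothetical smoke test; hyperparameters are illustrative, not from the source.
model = DeepAnalyze(feature_size=320, lstm_hidden_size=128,
                    lstm_num_layers=2, n_tags=2, max_len=80)
inputs = torch.randn(80, 4, 320)            # (seq_len, batch, feature_size)
mask = torch.ones(80, 4, dtype=torch.bool)  # no padding in this dummy batch
preds = model(inputs, mask)                 # (seq_len, batch) predicted tag ids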