def _create_statistic_with_values(tp: int, fp: int,
                                   fn: int) -> PredictionStatistic:
    """Return a ``PredictionStatistic`` whose raw counters are preset.

    The statistic is constructed as ``PredictionStatistic(False)``. Its
    private true-positive, false-positive and false-negative counters are
    then overwritten directly. This lets tests exercise merging of
    statistics without going through ``update_statistic``.

    NOTE(review): this is called as ``self._create_statistic_with_values``
    from a test method, so it is presumably a ``@staticmethod`` of the test
    class in the original source. Confirm.

    Args:
        tp: value stored in ``_true_positive``.
        fp: value stored in ``_false_positive``.
        fn: value stored in ``_false_negative``.

    Returns:
        The newly created, populated ``PredictionStatistic``.
    """
    preset = PredictionStatistic(False)
    preset._true_positive, preset._false_positive, preset._false_negative = (
        tp, fp, fn)
    return preset
    def test_calculate_statistic_with_masking_long_sequence(self):
        """Ground truth is longer than the prediction and ends in ``0``s.

        Both sequences are single-column tensors: 9 ground-truth tokens
        (the trailing ``0``s are presumably padding) against 8 predicted
        tokens. The statistic is built as ``PredictionStatistic(True, 0,
        [0])``, i.e. pad id ``0`` and ``0`` as a skip token. The test
        expects 6 true positives, 2 false positives and no false negatives.
        """
        target = torch.tensor([1, 2, 3, 6, 7, 8, 0, 0, 0]).unsqueeze(1)
        predicted = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).unsqueeze(1)

        masked_statistic = PredictionStatistic(True, 0, [0])
        masked_statistic.update_statistic(target, predicted)

        expected_counts = {
            "_true_positive": 6,
            "_false_positive": 2,
            "_false_negative": 0,
        }
        for attribute, count in expected_counts.items():
            self.assertEqual(getattr(masked_statistic, attribute), count)
    def test_calculate_statistic_equal_tensors(self):
        """Identical ground truth and prediction give only true positives.

        ``-1`` and ``0`` are passed as ``skip_tokens``. Of the seven
        positions, the test therefore expects exactly five true positives
        and no false positives or false negatives.
        """
        tokens = [1, 2, 3, 4, 5, 0, -1]
        gt_column = torch.tensor(tokens).unsqueeze(1)
        pred_column = torch.tensor(tokens).unsqueeze(1)

        statistic = PredictionStatistic(False, skip_tokens=[-1, 0])
        statistic.update_statistic(gt_column, pred_column)

        self.assertEqual(statistic._true_positive, 5)
        self.assertEqual(statistic._false_positive, 0)
        self.assertEqual(statistic._false_negative, 0)
Exemple #4
0
    def validation_step(self, batch: PathContextBatch,
                        batch_idx: int) -> Dict:  # type: ignore
        """Evaluate one validation batch and collect its loss and statistic.

        The forward call receives the contexts, ``contexts_per_label`` and
        the number of labels (``batch.labels.shape[0]``). The labels
        themselves are not passed in.

        Returns:
            ``{"loss": loss, "statistic": statistic}``, where ``statistic``
            is a ``PredictionStatistic`` updated with the batch labels and
            the argmax over the last axis of the logits.
        """
        # [seq length; batch size; vocab size]
        logits = self(batch.contexts, batch.contexts_per_label,
                      batch.labels.shape[0])
        loss = self._calculate_loss(logits, batch.labels)

        batch_statistic = PredictionStatistic(
            True, self._label_pad_id, self._metric_skip_tokens)
        batch_statistic.update_statistic(batch.labels, logits.argmax(-1))

        return {"loss": loss, "statistic": batch_statistic}
    def test_calculate_statistic(self):
        """Check TP/FP/FN counts for a 4x4 ground-truth/prediction pair.

        The statistic is built with ``False`` as its first argument
        (presumably "no masking after the pad id"; confirm) and with ``-1``
        and ``0`` as skip tokens. Both tokens appear in the inputs.
        """
        ground_truth = torch.tensor([
            [1, 1, 1, 0],
            [2, 2, 0, -1],
            [3, 3, -1, -1],
            [-1, -1, -1, -1],
        ])
        prediction = torch.tensor([
            [2, 4, 1, 0],
            [4, 5, 2, 0],
            [1, 6, 3, 0],
            [5, -1, -1, -1],
        ])

        statistic = PredictionStatistic(False, skip_tokens=[-1, 0])
        statistic.update_statistic(ground_truth, prediction)

        for attribute, expected in (("_true_positive", 3),
                                    ("_false_positive", 7),
                                    ("_false_negative", 4)):
            self.assertEqual(getattr(statistic, attribute), expected)
    def test_update(self):
        """``create_from_list`` yields the element-wise sum of the counters."""
        parts = [
            self._create_statistic_with_values(1, 2, 3),
            self._create_statistic_with_values(4, 5, 6),
        ]
        merged = PredictionStatistic.create_from_list(parts)

        self.assertEqual(merged._true_positive, 1 + 4)
        self.assertEqual(merged._false_positive, 2 + 5)
        self.assertEqual(merged._false_negative, 3 + 6)
Exemple #7
0
    def training_step(self, batch: PathContextBatch,
                      batch_idx: int) -> Dict:  # type: ignore
        """Run one training batch, log its metrics, and return loss/statistic.

        Unlike evaluation, the labels are passed to the forward call as
        well. The metrics returned by ``update_statistic`` are logged under
        ``"train/<name>"`` next to ``"train/loss"``. ``"f1"`` is also shown
        on the progress bar without being sent to the logger.

        Returns:
            ``{"loss": loss, "statistic": statistic}`` for epoch-level
            aggregation.
        """
        # [seq length; batch size; vocab size]
        logits = self(batch.contexts, batch.contexts_per_label,
                      batch.labels.shape[0], batch.labels)
        loss = self._calculate_loss(logits, batch.labels)

        statistic = PredictionStatistic(
            True, self._label_pad_id, self._metric_skip_tokens)
        batch_metric = statistic.update_statistic(batch.labels,
                                                  logits.argmax(-1))

        to_log: Dict[str, Union[float, torch.Tensor]] = {"train/loss": loss}
        to_log.update((f"train/{name}", metric_value)
                      for name, metric_value in batch_metric.items())
        self.log_dict(to_log)
        self.log("f1", batch_metric["f1"], prog_bar=True, logger=False)

        return {"loss": loss, "statistic": statistic}
Exemple #8
0
 def _shared_epoch_end(self, outputs: List[Dict], group: str):
     """Aggregate one epoch's step outputs and log them under ``group``.

     Args:
         outputs: step results, each a dict holding a ``"loss"`` tensor and
             a ``"statistic"`` (merged via
             ``PredictionStatistic.create_from_list``).
         group: prefix for the logged keys: ``"<group>/loss"``,
             ``"<group>/<metric>"`` and ``"<group>_loss"``.
     """
     with torch.no_grad():
         # NOTE(review): torch.stack raises on an empty list, so this assumes
         # at least one step output per epoch.
         step_losses = [step["loss"] for step in outputs]
         mean_loss = torch.stack(step_losses).mean().item()

         epoch_statistic = PredictionStatistic.create_from_list(
             [step["statistic"] for step in outputs])

         to_log: Dict[str, Union[float, torch.Tensor]] = {
             f"{group}/loss": mean_loss
         }
         to_log.update(
             (f"{group}/{name}", metric_value)
             for name, metric_value in epoch_statistic.get_metric().items())
         self.log_dict(to_log)
         # The mean loss is also logged under a flat key. This is presumably
         # the name monitored by checkpoint/early-stopping callbacks; confirm.
         self.log(f"{group}_loss", mean_loss)