示例#1
0
    def test_update(self):
        """Check that ``update`` accumulates the counters of another statistic.

        Starting from (1, 2, 3) and merging (4, 5, 6), each counter is
        expected to hold the element-wise sum.
        """
        accumulated = SubtokenStatistic(1, 2, 3)
        accumulated.update(SubtokenStatistic(4, 5, 6))

        expected_counters = (
            ("true_positive", 5),
            ("false_positive", 7),
            ("false_negative", 9),
        )
        for attribute, expected in expected_counters:
            self.assertEqual(getattr(accumulated, attribute), expected)
示例#2
0
    def test_calculate_metrics_with_group(self):
        """Check that metric names get a ``"<group>/"`` prefix when a group is given.

        The statistic is built from the counters (3, 7, 2); the expected values
        are consistent with precision = 3 / (3 + 7), recall = 3 / (3 + 2) and
        f1 being their harmonic mean.
        """
        st_stat = SubtokenStatistic(3, 7, 2)
        metrics = st_stat.calculate_metrics(group="train")
        true_metrics = {
            "train/precision": 0.3,
            "train/recall": 0.6,
            "train/f1": 0.4,
        }

        # The set of reported metrics must match exactly.
        self.assertIsInstance(metrics, dict)
        self.assertEqual(set(metrics), set(true_metrics))
        # The values are computed floats: compare with a tolerance rather than
        # with ==, so the test does not depend on floating-point rounding.
        for name, expected in true_metrics.items():
            self.assertAlmostEqual(metrics[name], expected, msg=name)
示例#3
0
    def training_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
        """Run one training step and collect the loss and subtoken statistic.

        :param batch: batch providing ``context``, ``contexts_per_label`` and ``labels``
        :param batch_idx: index of the batch (not used here)
        :return: dict with the loss under "loss", the logged values under "log",
            the F1 score for the progress bar and the raw statistic of this batch
        """
        # [seq length; batch size; vocab size]
        logits = self(
            batch.context,
            batch.contexts_per_label,
            batch.labels.shape[0],
            batch.labels,
        )
        train_loss = self._calculate_loss(logits, batch.labels)
        log = {"train/loss": train_loss}

        # The statistic is built from the arg-max predictions; no gradients are needed for it.
        with torch.no_grad():
            predicted = logits.argmax(-1)
            statistic = SubtokenStatistic().calculate_statistic(batch.labels, predicted)

        group_metrics = statistic.calculate_metrics(group="train")
        log.update(group_metrics)

        return {
            "loss": train_loss,
            "log": log,
            "progress_bar": {"train/f1": log["train/f1"]},
            "statistic": statistic,
        }
示例#4
0
    def validation_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
        """Run one validation step; unlike training, the labels are not passed to the model.

        :param batch: batch providing ``context``, ``contexts_per_label`` and ``labels``
        :param batch_idx: index of the batch (not used here)
        :return: dict with the loss under "val_loss" and the statistic of this batch
        """
        # [seq length; batch size; vocab size]
        logits = self(
            batch.context,
            batch.contexts_per_label,
            batch.labels.shape[0],
        )
        val_loss = self._calculate_loss(logits, batch.labels)

        # The statistic is built from the arg-max predictions; no gradients are needed for it.
        with torch.no_grad():
            predicted = logits.argmax(-1)
            statistic = SubtokenStatistic().calculate_statistic(batch.labels, predicted)

        return {"val_loss": val_loss, "statistic": statistic}
示例#5
0
 def _general_epoch_end(self, outputs: List[Dict], loss_key: str, group: str) -> Dict:
     """Aggregate the per-step outputs of one epoch into epoch-level logs.

     :param outputs: step results; each must hold a loss under ``loss_key``
         and a statistic under "statistic"
     :param loss_key: key of the loss inside every step result
     :param group: prefix used for the names of the logged values
     :return: dict with the mean loss under "<group>_loss", all logged values
         under "log" and the loss / F1 entries under "progress_bar"
     """
     loss_name = f"{group}/loss"
     with torch.no_grad():
         step_losses = [step_output[loss_key] for step_output in outputs]
         logs = {loss_name: torch.stack(step_losses).mean()}

         step_statistics = [step_output["statistic"] for step_output in outputs]
         epoch_statistic = SubtokenStatistic.union_statistics(step_statistics)
         logs.update(epoch_statistic.calculate_metrics(group))

     # Only the loss and the F1 score are shown in the progress bar.
     displayed = [loss_name, f"{group}/f1"]
     progress_bar = {name: value for name, value in logs.items() if name in displayed}
     return {f"{group}_loss": logs[loss_name], "log": logs, "progress_bar": progress_bar}
    def test_calculate_statistic_equal_tensors(self):
        """Check the statistic when the prediction is identical to the ground truth.

        The five values that are not in the skip list are expected to be
        counted as true positives, with no false positives or negatives.
        """
        subtokens = [[1, 2, 3, 4, 5, 0, -1]]
        ground_truth = torch.tensor(subtokens)
        predicted = torch.tensor(subtokens)

        statistic = SubtokenStatistic.calculate_statistic(ground_truth, predicted, [-1, 0])

        expected_counters = (
            ("true_positive", 5),
            ("false_positive", 0),
            ("false_negative", 0),
        )
        for attribute, expected in expected_counters:
            self.assertEqual(getattr(statistic, attribute), expected)
示例#7
0
    def test_calculate_statistic(self):
        """Check the statistic for a prediction that only partially matches.

        NOTE(review): the expected counts (3 / 7 / 4) are consistent with
        treating every column as one sequence and ignoring the skipped
        values -1 and 0 -- confirm against ``calculate_statistic``.
        """
        ground_truth = torch.tensor([
            [1, 1, 1, 0],
            [2, 2, 0, -1],
            [3, 3, -1, -1],
            [-1, -1, -1, -1],
        ])
        predicted = torch.tensor([
            [2, 4, 1, 0],
            [4, 5, 2, 0],
            [1, 6, 3, 0],
            [5, -1, -1, -1],
        ])
        skipped_values = [-1, 0]

        statistic = SubtokenStatistic.calculate_statistic(ground_truth, predicted, skipped_values)

        expected_counters = (
            ("true_positive", 3),
            ("false_positive", 7),
            ("false_negative", 4),
        )
        for attribute, expected in expected_counters:
            self.assertEqual(getattr(statistic, attribute), expected)
 def _calculate_metric(self, logits: torch.Tensor, labels: torch.Tensor) -> SubtokenStatistic:
     """Build the subtoken statistic of the arg-max predictions against the labels.

     Every position of a predicted sequence from its PAD token onwards is
     overwritten with PAD before comparing; PAD and UNK ids are passed to
     ``calculate_statistic`` as the values to skip.

     :param logits: model scores; the arg-max is taken over the last dimension
     :param labels: ground truth ids
     :return: statistic computed by ``SubtokenStatistic.calculate_statistic``
     """
     with torch.no_grad():
         # [seq length; batch size]
         predicted = logits.argmax(-1)
         seq_length = predicted.shape[0]
         pad_id = self.vocab.label_to_id[PAD]

         # Per column: does it contain PAD at all, and at which step.
         # NOTE(review): relies on torch.max returning the index of the first
         # maximal value along dim 0 -- confirm for the torch version in use.
         contains_pad, pad_position = torch.max(predicted == pad_id, dim=0)
         # Columns without PAD get an out-of-range position, so nothing is masked in them.
         pad_position[~contains_pad] = seq_length

         steps = torch.arange(seq_length, device=self.device).view(-1, 1)
         after_pad = steps >= pad_position
         predicted[after_pad] = pad_id

         return SubtokenStatistic().calculate_statistic(
             labels,
             predicted,
             [self.vocab.label_to_id[token] for token in [PAD, UNK]],
         )
示例#9
0
    def test_calculate_zero_metrics(self):
        """Check that a statistic with all-zero counters reports zero for every metric."""
        empty_statistic = SubtokenStatistic(0, 0, 0)
        expected = dict.fromkeys(("precision", "recall", "f1"), 0)

        self.assertDictEqual(empty_statistic.calculate_metrics(), expected)
示例#10
0
    def test_calculate_metrics(self):
        """Check the un-prefixed metrics of a statistic built from (3, 7, 2).

        The expected values are consistent with precision = 3 / (3 + 7),
        recall = 3 / (3 + 2) and f1 being their harmonic mean.
        """
        st_stat = SubtokenStatistic(3, 7, 2)
        metrics = st_stat.calculate_metrics()
        true_metrics = {"precision": 0.3, "recall": 0.6, "f1": 0.4}

        # The set of reported metrics must match exactly.
        self.assertIsInstance(metrics, dict)
        self.assertEqual(set(metrics), set(true_metrics))
        # The values are computed floats: compare with a tolerance rather than
        # with ==, so the test does not depend on floating-point rounding.
        for name, expected in true_metrics.items():
            self.assertAlmostEqual(metrics[name], expected, msg=name)