def test_update(self):
    st_stat = SubtokenStatistic(1, 2, 3)
    st_stat_other = SubtokenStatistic(4, 5, 6)
    st_stat.update(st_stat_other)
    self.assertEqual(st_stat.true_positive, 5)
    self.assertEqual(st_stat.false_positive, 7)
    self.assertEqual(st_stat.false_negative, 9)
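# A minimal sketch of the counter container these tests exercise,
# reconstructed from the assertions above (illustrative only: the real
# SubtokenStatistic lives elsewhere in the repo and may differ in detail):
from dataclasses import dataclass

@dataclass
class SubtokenStatisticSketch:
    true_positive: int = 0
    false_positive: int = 0
    false_negative: int = 0

    def update(self, other: "SubtokenStatisticSketch") -> None:
        # accumulate counts element-wise, as test_update expects
        self.true_positive += other.true_positive
        self.false_positive += other.false_positive
        self.false_negative += other.false_negative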
def test_calculate_metrics_with_group(self):
    st_stat = SubtokenStatistic(3, 7, 2)
    metrics = st_stat.calculate_metrics(group="train")
    true_metrics = {"train/precision": 0.3, "train/recall": 0.6, "train/f1": 0.4}
    self.assertDictEqual(metrics, true_metrics)
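# The expected values follow directly from the standard definitions with
# TP = 3, FP = 7, FN = 2:
#   precision = TP / (TP + FP) = 3 / 10 = 0.3
#   recall    = TP / (TP + FN) = 3 / 5  = 0.6
#   f1        = 2 * P * R / (P + R) = 2 * 0.3 * 0.6 / 0.9 = 0.4
# The group argument only prefixes each metric key with "train/".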
def training_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
    # [seq length; batch size; vocab size]
    # labels are also passed to forward, presumably for teacher forcing
    logits = self(batch.context, batch.contexts_per_label, batch.labels.shape[0], batch.labels)
    loss = self._calculate_loss(logits, batch.labels)
    log = {"train/loss": loss}
    with torch.no_grad():
        # call on the class, matching how the tests invoke calculate_statistic
        statistic = SubtokenStatistic.calculate_statistic(batch.labels, logits.argmax(-1))
        log.update(statistic.calculate_metrics(group="train"))
    progress_bar = {"train/f1": log["train/f1"]}
    return {"loss": loss, "log": log, "progress_bar": progress_bar, "statistic": statistic}
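# _calculate_loss is not shown in this section; a plausible sketch, assuming
# a token-level cross-entropy that ignores PAD positions (the real method may
# differ, e.g. by also skipping the leading SOS position):
import torch
import torch.nn.functional as F

def _calculate_loss_sketch(logits: torch.Tensor, labels: torch.Tensor, pad_id: int) -> torch.Tensor:
    # logits: [seq length; batch size; vocab size], labels: [seq length; batch size]
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1), ignore_index=pad_id)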
def validation_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
    # [seq length; batch size; vocab size]
    # unlike training_step, labels are not passed to forward here
    logits = self(batch.context, batch.contexts_per_label, batch.labels.shape[0])
    loss = self._calculate_loss(logits, batch.labels)
    with torch.no_grad():
        statistic = SubtokenStatistic.calculate_statistic(batch.labels, logits.argmax(-1))
    return {"val_loss": loss, "statistic": statistic}
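# The test loop is not shown in this section; a minimal sketch, assuming it
# mirrors validation and only renames the loss key for _general_epoch_end:
def test_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
    result = self.validation_step(batch, batch_idx)
    return {"test_loss": result["val_loss"], "statistic": result["statistic"]}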
def _general_epoch_end(self, outputs: List[Dict], loss_key: str, group: str) -> Dict:
    with torch.no_grad():
        logs = {f"{group}/loss": torch.stack([out[loss_key] for out in outputs]).mean()}
        logs.update(
            SubtokenStatistic.union_statistics([out["statistic"] for out in outputs]).calculate_metrics(group)
        )
    progress_bar = {k: v for k, v in logs.items() if k in [f"{group}/loss", f"{group}/f1"]}
    return {f"{group}_loss": logs[f"{group}/loss"], "log": logs, "progress_bar": progress_bar}
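# _general_epoch_end is presumably invoked from the Lightning epoch hooks; a
# sketch of that wiring, assuming the pre-1.0 Lightning API in which
# *_epoch_end returns a dict of logs:
def validation_epoch_end(self, outputs: List[Dict]) -> Dict:
    return self._general_epoch_end(outputs, loss_key="val_loss", group="val")

def test_epoch_end(self, outputs: List[Dict]) -> Dict:
    return self._general_epoch_end(outputs, loss_key="test_loss", group="test")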
def test_calculate_statistic_equal_tensors(self):
    gt_subtokens = torch.tensor([[1, 2, 3, 4, 5, 0, -1]])
    pred_subtokens = torch.tensor([[1, 2, 3, 4, 5, 0, -1]])
    skip = [-1, 0]
    st_stat = SubtokenStatistic.calculate_statistic(gt_subtokens, pred_subtokens, skip)
    self.assertEqual(st_stat.true_positive, 5)
    self.assertEqual(st_stat.false_positive, 0)
    self.assertEqual(st_stat.false_negative, 0)
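# A reference sketch of the counting these two tests pin down: tensors are
# read as [seq length; batch size], each batch column is reduced to the set
# of its subtoken ids, and ids listed in `skip` are discarded before the set
# comparison (illustrative and loop-based; the real classmethod may be
# vectorized; reuses SubtokenStatisticSketch from above):
from typing import List
import torch

def calculate_statistic_sketch(gt: torch.Tensor, pred: torch.Tensor, skip: List[int]) -> SubtokenStatisticSketch:
    skip_set = set(skip)
    stat = SubtokenStatisticSketch()
    for col in range(gt.shape[1]):
        gt_tokens = {t.item() for t in gt[:, col]} - skip_set
        pred_tokens = {t.item() for t in pred[:, col]} - skip_set
        stat.true_positive += len(gt_tokens & pred_tokens)
        stat.false_positive += len(pred_tokens - gt_tokens)
        stat.false_negative += len(gt_tokens - pred_tokens)
    return stat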
def test_calculate_statistic(self):
    gt_subtokens = torch.tensor([
        [1, 1, 1, 0],
        [2, 2, 0, -1],
        [3, 3, -1, -1],
        [-1, -1, -1, -1],
    ])
    pred_subtokens = torch.tensor([
        [2, 4, 1, 0],
        [4, 5, 2, 0],
        [1, 6, 3, 0],
        [5, -1, -1, -1],
    ])
    skip = [-1, 0]
    st_stat = SubtokenStatistic.calculate_statistic(gt_subtokens, pred_subtokens, skip)
    self.assertEqual(st_stat.true_positive, 3)
    self.assertEqual(st_stat.false_positive, 7)
    self.assertEqual(st_stat.false_negative, 4)
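# Tracing the expected counts, reading the tensors as [seq length; batch size]
# (token sets per batch column, with 0 and -1 skipped):
#   column 0: gt {1, 2, 3}, pred {1, 2, 4, 5} -> TP 2, FP 2, FN 1
#   column 1: gt {1, 2, 3}, pred {4, 5, 6}    -> TP 0, FP 3, FN 3
#   column 2: gt {1},       pred {1, 2, 3}    -> TP 1, FP 2, FN 0
#   column 3: gt {},        pred {}           -> TP 0, FP 0, FN 0
# Totals: TP = 3, FP = 7, FN = 4, matching the assertions above.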
def _calculate_metric(self, logits: torch.Tensor, labels: torch.Tensor) -> SubtokenStatistic:
    with torch.no_grad():
        # [seq length; batch size]
        prediction = logits.argmax(-1)
        # find the first PAD position in each batch column; max over a bool
        # tensor returns True if PAD occurs and the index of its first occurrence
        mask_max_value, mask_max_indices = torch.max(prediction == self.vocab.label_to_id[PAD], dim=0)
        # columns without any PAD keep their full sequence length
        mask_max_indices[~mask_max_value] = prediction.shape[0]
        # force every position from the first PAD onward back to PAD
        mask = torch.arange(prediction.shape[0], device=self.device).view(-1, 1) >= mask_max_indices
        prediction[mask] = self.vocab.label_to_id[PAD]
        statistic = SubtokenStatistic.calculate_statistic(
            labels,
            prediction,
            [self.vocab.label_to_id[t] for t in [PAD, UNK]],
        )
    return statistic
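# Example of the masking above, assuming PAD id 0 and a single batch column
# whose argmax prediction is [5, 0, 7] (seq length 3):
#   prediction == PAD       -> [False, True, False]
#   torch.max(..., dim=0)   -> value True, index 1 (the first PAD position)
#   mask (position >= 1)    -> [False, True, True]
# so every position from the first PAD onward is forced back to PAD before
# counting, and PAD/UNK ids are then excluded via the skip list.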
def test_calculate_zero_metrics(self):
    st_stat = SubtokenStatistic(0, 0, 0)
    metrics = st_stat.calculate_metrics()
    true_metrics = {"precision": 0, "recall": 0, "f1": 0}
    self.assertDictEqual(metrics, true_metrics)
def test_calculate_metrics(self):
    st_stat = SubtokenStatistic(3, 7, 2)
    metrics = st_stat.calculate_metrics()
    true_metrics = {"precision": 0.3, "recall": 0.6, "f1": 0.4}
    self.assertDictEqual(metrics, true_metrics)
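# A sketch of calculate_metrics consistent with the three metric tests,
# including the zero-division guard that test_calculate_zero_metrics implies
# (illustrative; f1 is computed as 2*TP / (2*TP + FP + FN), which equals the
# harmonic mean of precision and recall and stays float-exact here):
from typing import Dict, Optional

def calculate_metrics_sketch(tp: int, fp: int, fn: int, group: Optional[str] = None) -> Dict[str, float]:
    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn > 0 else 0
    metrics = {"precision": precision, "recall": recall, "f1": f1}
    if group is not None:
        metrics = {f"{group}/{k}": v for k, v in metrics.items()}
    return metrics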