@staticmethod
def _create_statistic_with_values(tp: int, fp: int, fn: int) -> PredictionStatistic:
    statistic = PredictionStatistic(False)
    statistic._true_positive = tp
    statistic._false_positive = fp
    statistic._false_negative = fn
    return statistic
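# A minimal sanity check for the derived metrics built on the helper above.
# This is a sketch, not part of the original suite: it assumes get_metric()
# (used in _shared_epoch_end below) returns a dict with "precision", "recall",
# and "f1" keys — "f1" is confirmed by training_step, the other two names are
# an assumption — computed as precision = tp / (tp + fp),
# recall = tp / (tp + fn), f1 = 2 * precision * recall / (precision + recall).
def test_get_metric_from_counts(self):
    statistic = self._create_statistic_with_values(3, 1, 1)
    metrics = statistic.get_metric()

    # tp=3, fp=1, fn=1 -> precision = recall = f1 = 0.75
    self.assertAlmostEqual(metrics["precision"], 0.75)
    self.assertAlmostEqual(metrics["recall"], 0.75)
    self.assertAlmostEqual(metrics["f1"], 0.75)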
def test_calculate_statistic_with_masking_long_sequence(self):
    gt_subtokens = torch.tensor([1, 2, 3, 6, 7, 8, 0, 0, 0]).view(-1, 1)
    pred_subtokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(-1, 1)
    statistic = PredictionStatistic(True, 0, [0])
    statistic.update_statistic(gt_subtokens, pred_subtokens)

    self.assertEqual(statistic._true_positive, 6)
    self.assertEqual(statistic._false_positive, 2)
    self.assertEqual(statistic._false_negative, 0)
def test_calculate_statistic_equal_tensors(self):
    gt_subtokens = torch.tensor([1, 2, 3, 4, 5, 0, -1]).view(-1, 1)
    pred_subtokens = torch.tensor([1, 2, 3, 4, 5, 0, -1]).view(-1, 1)
    skip = [-1, 0]
    statistic = PredictionStatistic(False, skip_tokens=skip)
    statistic.update_statistic(gt_subtokens, pred_subtokens)

    self.assertEqual(statistic._true_positive, 5)
    self.assertEqual(statistic._false_positive, 0)
    self.assertEqual(statistic._false_negative, 0)
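# update_statistic also returns the metrics for the processed batch — this is
# how training_step below reads "f1" from it. A small sketch on top of the
# equal-tensors case, assuming the returned dict exposes the "f1" key:
# identical tensors (tp=5, fp=0, fn=0) should score a perfect 1.0.
def test_batch_metric_for_equal_tensors(self):
    gt_subtokens = torch.tensor([1, 2, 3, 4, 5, 0, -1]).view(-1, 1)
    pred_subtokens = torch.tensor([1, 2, 3, 4, 5, 0, -1]).view(-1, 1)
    statistic = PredictionStatistic(False, skip_tokens=[-1, 0])
    batch_metric = statistic.update_statistic(gt_subtokens, pred_subtokens)

    self.assertAlmostEqual(batch_metric["f1"], 1.0)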
def validation_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:  # type: ignore
    # [seq length; batch size; vocab size]
    # unlike training_step, labels are not passed to forward: no teacher forcing
    logits = self(batch.contexts, batch.contexts_per_label, batch.labels.shape[0])
    loss = self._calculate_loss(logits, batch.labels)
    prediction = logits.argmax(-1)

    statistic = PredictionStatistic(True, self._label_pad_id, self._metric_skip_tokens)
    statistic.update_statistic(batch.labels, prediction)

    return {"loss": loss, "statistic": statistic}
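# A common Lightning pattern, assumed here rather than shown in the source:
# since validation_step does not use teacher forcing, the test step can simply
# delegate to it.
def test_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:  # type: ignore
    return self.validation_step(batch, batch_idx)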
def test_calculate_statistic(self):
    # [seq length; batch size] layout: each column is one example, e.g. the
    # first example is gt [1, 2, 3] vs pred [2, 4, 1, 5] -> 2 TP, 2 FP, 1 FN
    gt_subtokens = torch.tensor([[1, 1, 1, 0], [2, 2, 0, -1], [3, 3, -1, -1], [-1, -1, -1, -1]])
    pred_subtokens = torch.tensor([[2, 4, 1, 0], [4, 5, 2, 0], [1, 6, 3, 0], [5, -1, -1, -1]])
    skip = [-1, 0]
    statistic = PredictionStatistic(False, skip_tokens=skip)
    statistic.update_statistic(gt_subtokens, pred_subtokens)

    self.assertEqual(statistic._true_positive, 3)
    self.assertEqual(statistic._false_positive, 7)
    self.assertEqual(statistic._false_negative, 4)
def test_update(self):
    stat1 = self._create_statistic_with_values(1, 2, 3)
    stat2 = self._create_statistic_with_values(4, 5, 6)
    union = PredictionStatistic.create_from_list([stat1, stat2])

    self.assertEqual(union._true_positive, 5)
    self.assertEqual(union._false_positive, 7)
    self.assertEqual(union._false_negative, 9)
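# The expected values above show the merge sums the three counters, so the
# order of the list should not matter. A quick order-insensitivity check
# under that assumption:
def test_update_is_order_insensitive(self):
    stat1 = self._create_statistic_with_values(1, 2, 3)
    stat2 = self._create_statistic_with_values(4, 5, 6)
    union = PredictionStatistic.create_from_list([stat2, stat1])

    self.assertEqual(union._true_positive, 5)
    self.assertEqual(union._false_positive, 7)
    self.assertEqual(union._false_negative, 9)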
def training_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:  # type: ignore
    # [seq length; batch size; vocab size]
    # labels are passed to forward for teacher forcing during training
    logits = self(batch.contexts, batch.contexts_per_label, batch.labels.shape[0], batch.labels)
    loss = self._calculate_loss(logits, batch.labels)
    prediction = logits.argmax(-1)

    statistic = PredictionStatistic(True, self._label_pad_id, self._metric_skip_tokens)
    batch_metric = statistic.update_statistic(batch.labels, prediction)

    log: Dict[str, Union[float, torch.Tensor]] = {"train/loss": loss}
    for key, value in batch_metric.items():
        log[f"train/{key}"] = value
    self.log_dict(log)
    self.log("f1", batch_metric["f1"], prog_bar=True, logger=False)

    return {"loss": loss, "statistic": statistic}
def _shared_epoch_end(self, outputs: List[Dict], group: str):
    with torch.no_grad():
        mean_loss = torch.stack([out["loss"] for out in outputs]).mean().item()
        statistic = PredictionStatistic.create_from_list([out["statistic"] for out in outputs])
        epoch_metrics = statistic.get_metric()

        log: Dict[str, Union[float, torch.Tensor]] = {f"{group}/loss": mean_loss}
        for key, value in epoch_metrics.items():
            log[f"{group}/{key}"] = value
        self.log_dict(log)
        self.log(f"{group}_loss", mean_loss)
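# Under the pre-2.0 PyTorch Lightning epoch-end hooks, the per-split epoch
# ends would delegate to _shared_epoch_end. A sketch assuming "val" and "test"
# as the logging group names (the group names are an assumption; the hook
# signatures are standard Lightning):
def validation_epoch_end(self, outputs: List[Dict]):  # type: ignore
    self._shared_epoch_end(outputs, "val")

def test_epoch_end(self, outputs: List[Dict]):  # type: ignore
    self._shared_epoch_end(outputs, "test")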