def test_update(self):
    st_stat = SubtokenStatistic(1, 2, 3)
    st_stat_other = SubtokenStatistic(4, 5, 6)
    st_stat.update(st_stat_other)
    self.assertEqual(st_stat.true_positive, 5)
    self.assertEqual(st_stat.false_positive, 7)
    self.assertEqual(st_stat.false_negative, 9)

def validation_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
    # [seq length; batch size; vocab size]
    logits = self(batch.context, batch.contexts_per_label, batch.labels.shape[0])
    loss = self._calculate_loss(logits, batch.labels)
    with torch.no_grad():
        statistic = SubtokenStatistic().calculate_statistic(batch.labels, logits.argmax(-1))
    return {"val_loss": loss, "statistic": statistic}

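# A minimal sketch of how the per-batch statistics returned above could be
# folded into epoch-level metrics. The dict-based Lightning interface used in
# these snippets supports a validation_epoch_end hook; the exact aggregation
# in the original module is assumed here, not taken from this section.
def validation_epoch_end(self, outputs: List[Dict]) -> Dict:
    with torch.no_grad():
        mean_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
    epoch_statistic = SubtokenStatistic()
    for out in outputs:
        epoch_statistic.update(out["statistic"])
    log = {"val/loss": mean_loss}
    log.update(epoch_statistic.calculate_metrics(group="val"))
    return {"val_loss": mean_loss, "log": log}
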
def test_calculate_metrics_with_group(self):
    st_stat = SubtokenStatistic(3, 7, 2)
    metrics = st_stat.calculate_metrics(group="train")
    true_metrics = {"train/precision": 0.3, "train/recall": 0.6, "train/f1": 0.4}
    self.assertDictEqual(metrics, true_metrics)

def training_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:
    # [seq length; batch size; vocab size]
    logits = self(batch.context, batch.contexts_per_label, batch.labels.shape[0], batch.labels)
    loss = self._calculate_loss(logits, batch.labels)
    log = {"train/loss": loss}
    with torch.no_grad():
        statistic = SubtokenStatistic().calculate_statistic(batch.labels, logits.argmax(-1))
        log.update(statistic.calculate_metrics(group="train"))
    progress_bar = {"train/f1": log["train/f1"]}
    return {"loss": loss, "log": log, "progress_bar": progress_bar, "statistic": statistic}

def _calculate_metric(self, logits: torch.Tensor, labels: torch.Tensor) -> SubtokenStatistic:
    with torch.no_grad():
        # [seq length; batch size]
        prediction = logits.argmax(-1)
        # Find the position of the first predicted PAD token in each sequence;
        # sequences without any PAD get an index equal to the sequence length.
        mask_max_value, mask_max_indices = torch.max(prediction == self.vocab.label_to_id[PAD], dim=0)
        mask_max_indices[~mask_max_value] = prediction.shape[0]
        # Overwrite everything after the first PAD with PAD so that trailing
        # predictions do not affect the statistic.
        mask = torch.arange(prediction.shape[0], device=self.device).view(-1, 1) >= mask_max_indices
        prediction[mask] = self.vocab.label_to_id[PAD]
        statistic = SubtokenStatistic().calculate_statistic(
            labels,
            prediction,
            [self.vocab.label_to_id[t] for t in [PAD, UNK]],
        )
    return statistic

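# A small standalone illustration of the masking trick used above, with an
# assumed PAD id of 0 and toy prediction values: every position after the
# first predicted PAD in each column (sequence) is overwritten with PAD.
import torch

PAD_ID = 0
# [seq length; batch size] -- two predicted sequences of length 4
prediction = torch.tensor([[5, 7],
                           [0, 8],
                           [3, 9],
                           [4, 0]])
mask_max_value, mask_max_indices = torch.max(prediction == PAD_ID, dim=0)
mask_max_indices[~mask_max_value] = prediction.shape[0]
mask = torch.arange(prediction.shape[0]).view(-1, 1) >= mask_max_indices
prediction[mask] = PAD_ID
# prediction is now [[5, 7], [0, 8], [0, 9], [0, 0]]
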
def test_calculate_zero_metrics(self):
    st_stat = SubtokenStatistic(0, 0, 0)
    metrics = st_stat.calculate_metrics()
    true_metrics = {"precision": 0, "recall": 0, "f1": 0}
    self.assertDictEqual(metrics, true_metrics)

def test_calculate_metrics(self):
    st_stat = SubtokenStatistic(3, 7, 2)
    metrics = st_stat.calculate_metrics()
    true_metrics = {"precision": 0.3, "recall": 0.6, "f1": 0.4}
    self.assertDictEqual(metrics, true_metrics)

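# A minimal sketch of the SubtokenStatistic interface exercised by the tests
# above, assuming the standard precision/recall/F1 definitions; the actual
# class also provides calculate_statistic, which is not reconstructed here.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SubtokenStatistic:
    true_positive: int = 0
    false_positive: int = 0
    false_negative: int = 0

    def update(self, other: "SubtokenStatistic") -> None:
        # Accumulate counts from another statistic, e.g. when folding
        # per-batch statistics into an epoch-level one.
        self.true_positive += other.true_positive
        self.false_positive += other.false_positive
        self.false_negative += other.false_negative

    def calculate_metrics(self, group: Optional[str] = None) -> Dict[str, float]:
        # Precision, recall and F1 over subtokens; each metric falls back to 0
        # when its denominator is 0 (see test_calculate_zero_metrics).
        predicted = self.true_positive + self.false_positive
        actual = self.true_positive + self.false_negative
        precision = self.true_positive / predicted if predicted > 0 else 0
        recall = self.true_positive / actual if actual > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
        metrics = {"precision": precision, "recall": recall, "f1": f1}
        if group is not None:
            metrics = {f"{group}/{name}": value for name, value in metrics.items()}
        return metrics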