Example No. 1
    def _general_epoch_end(self, outputs: List[Dict], group: str) -> None:
        with torch.no_grad():
            # Average the per-step losses collected over the epoch
            mean_loss = torch.stack([out["loss"]
                                     for out in outputs]).mean().item()
            logs = {f"{group}/loss": mean_loss}
            # Merge the per-step statistics, then compute epoch-level metrics
            logs.update(
                Statistic.union_statistics([
                    out["statistic"] for out in outputs
                ]).calculate_metrics(group))
            self.log_dict(logs)
            # Also log the loss under a flat key, e.g. for checkpoint monitors
            self.log(f"{group}_loss", mean_loss)
Example No. 2
    def training_step(self, batch: TokensBatch, batch_idx: int) -> Dict:
        # (batch size, output size)
        logits = self(batch.tokens, batch.tokens_per_label)
        loss = F.nll_loss(logits, batch.labels)
        log: Dict[str, Union[float, torch.Tensor]] = {"train/loss": loss}
        with torch.no_grad():
            _, preds = logits.max(dim=1)
            statistic = Statistic().calculate_statistic(
                batch.labels,
                preds,
                2,
            )
            batch_metric = statistic.calculate_metrics(group="train")
            log.update(batch_metric)
            self.log_dict(log)
            self.log("f1",
                     batch_metric["train/f1"],
                     prog_bar=True,
                     logger=False)

        return {"loss": loss, "statistic": statistic}
Example No. 3
    def training_step(self, batch: PathContextBatch,
                      batch_idx: int) -> Dict:  # type: ignore
        # [batch size; num_classes]
        logits = self(batch.contexts, batch.contexts_per_label)
        loss = F.cross_entropy(logits, batch.labels)
        log = {"train/loss": loss}
        with torch.no_grad():
            _, preds = logits.max(dim=1)
            statistic = Statistic().calculate_statistic(
                batch.labels,
                preds,
                2,
            )
            batch_metric = statistic.calculate_metrics(group="train")
            log.update(batch_metric)
            self.log_dict(log)
            self.log("f1",
                     batch_metric["train/f1"],
                     prog_bar=True,
                     logger=False)

        return {"loss": loss, "statistic": statistic}
Example No. 4
    def validation_step(self, batch: TokensBatch, batch_idx: int) -> Dict:
        # (batch size, output size)
        logits = self(batch.tokens, batch.tokens_per_label)
        loss = F.nll_loss(logits, batch.labels)
        with torch.no_grad():
            _, preds = logits.max(dim=1)
            statistic = Statistic().calculate_statistic(
                batch.labels,
                preds,
                2,
            )

        return {"loss": loss, "statistic": statistic}
Example No. 5
    def validation_step(self, batch: PathContextBatch,
                        batch_idx: int) -> Dict:  # type: ignore
        # [batch size; num_classes]
        logits = self(batch.contexts, batch.contexts_per_label)
        loss = F.cross_entropy(logits, batch.labels)
        with torch.no_grad():
            _, preds = logits.max(dim=1)
            statistic = Statistic().calculate_statistic(
                batch.labels,
                preds,
                2,
            )

        return {"loss": loss, "statistic": statistic}