Ejemplo n.º 1
0
    def test_step(self, batch_data, batch_index):
        """Evaluate one test batch and log loss, accuracy and weighted F1.

        Args:
            batch_data: ``(x, y)`` pair. ``y`` is reduced with
                ``argmax(dim=1)``, so it is presumably one-hot or per-class
                scores (TODO confirm against the data module).
            batch_index: index of the batch within the test epoch (unused).

        Returns:
            dict with ``test_loss``, ``test_accuracy``, ``f_score`` and the
            per-sample class indices ``predictions`` and ``targets``.
        """
        x, y = batch_data
        logits = self(x)
        criterion = nn.CrossEntropyLoss()
        probs = torch.softmax(logits, dim=1)

        # Class indices are computed once and shared by every metric below.
        predictions = torch.argmax(probs, dim=1)
        targets = torch.argmax(y, dim=1)

        # test metrics
        acc = accuracy(predictions, targets)
        # CrossEntropyLoss is fed the raw logits together with class indices.
        loss = criterion(logits, targets)
        # NOTE(review): num_classes is hard-coded to 4 — confirm it matches the model head.
        f_score = f1(predictions,
                     targets,
                     average='weighted',
                     num_classes=4)
        self.log('test/f1', f_score, prog_bar=True)
        self.log('test/loss', loss, prog_bar=True)
        self.log('test/accuracy', acc, prog_bar=True)
        return {
            "test_loss": loss,
            "test_accuracy": acc,
            "f_score": f_score,
            "predictions": predictions,
            "targets": targets
        }
 def validation_epoch_end(self, outputs):
     """Aggregate the per-batch validation outputs into epoch-level metrics.

     Args:
         outputs: list of dicts produced by the validation step, each holding
             ``"val_loss"`` plus ``"pred"`` and ``"true"`` tensors that are
             concatenated along dim 0.

     Returns:
         dict with ``val_loss`` (mean of the per-batch losses), ``val_f_score``
         and ``val_accuracy``, with the same values repeated under ``"log"``.
     """
     batch_losses = [x["val_loss"] for x in outputs]
     # Average (not sum) the per-batch losses so the epoch value does not grow
     # with the number of validation batches.
     val_loss = sum(batch_losses) / len(batch_losses)
     pred = torch.cat([x["pred"] for x in outputs])
     true = torch.cat([x["true"] for x in outputs])
     # NOTE(review): binary task assumed (num_classes=2) — confirm.
     f_score = metrics.f1(pred, true, num_classes=2)
     accuracy = metrics.accuracy(pred, true)
     out = {"val_loss": val_loss, "val_f_score": f_score, "val_accuracy": accuracy}
     self.log_dict({"val_loss": val_loss, "val_f_score": f_score, "val_accuracy": accuracy})
     return {**out, "log": out}
Ejemplo n.º 3
0
 def test_step(self, batch: tdict, batch_idx: int) -> Tensor:
     """Run the shared ``step`` on a test batch and log its loss and F1.

     Returns:
         The ``"loss"`` entry produced by ``self.step``.
     """
     outputs = self.step(batch, batch_idx)
     batch_loss = outputs["loss"]
     score = f1(outputs["probs"], batch["labels"], num_classes=Label("REL").count)
     logged = {"test_loss": batch_loss, "test_f1": score}
     self.log_dict(logged)
     return batch_loss
Ejemplo n.º 4
0
    def validation_step(self, batch_data, batch_index):
        """Compute validation loss, accuracy and weighted F1 for one batch.

        Args:
            batch_data: ``(inputs, targets)`` pair; ``targets`` is reduced with
                ``argmax(dim=1)`` (presumably one-hot — TODO confirm).
            batch_index: index of the batch (unused).

        Returns:
            dict with ``val_loss``, ``val_accuracy`` and ``f_score``.
        """
        inputs, target_scores = batch_data
        logits = self(inputs)
        loss_fn = nn.CrossEntropyLoss()
        class_probs = torch.softmax(logits, dim=1)

        predicted = torch.argmax(class_probs, dim=1)
        expected = torch.argmax(target_scores, dim=1)

        # validation metrics
        acc = accuracy(predicted, expected)
        loss = loss_fn(logits, expected)
        f_score = f1(predicted, expected, average='weighted', num_classes=4)
        return {"val_loss": loss, "val_accuracy": acc, "f_score": f_score}
Ejemplo n.º 5
0
 def validation_step(self, batch: tdict, batch_idx: int) -> Tensor:
     """Run the shared ``step`` on a validation batch; log epoch-level loss and F1.

     Returns:
         The ``"loss"`` entry produced by ``self.step``.
     """
     outputs = self.step(batch, batch_idx)
     batch_loss = outputs["loss"]
     score = f1(outputs["probs"], batch["labels"], num_classes=Label("REL").count)
     # Aggregate over the epoch only and surface the values in the progress bar.
     log_options = dict(on_step=False, on_epoch=True, prog_bar=True)
     self.log_dict({"val_loss": batch_loss, "val_f1": score}, **log_options)
     return batch_loss
Ejemplo n.º 6
0
 def training_step(self, batch: tdict, batch_idx: int) -> Tensor:
     """Run the shared ``step`` on a training batch; log loss and F1.

     Returns:
         The ``"loss"`` entry produced by ``self.step`` (used for backprop).
     """
     outputs = self.step(batch, batch_idx)
     batch_loss = outputs["loss"]
     score = f1(outputs["probs"], batch["labels"], num_classes=Label("REL").count)
     # Log both per-step and epoch-aggregated values, shown in the progress bar.
     log_options = dict(on_step=True, on_epoch=True, prog_bar=True)
     self.log_dict({"train_loss": batch_loss, "train_f1": score}, **log_options)
     return batch_loss
Ejemplo n.º 7
0
    def validation_step(self, batch, batch_idx):
        """Log loss, accuracy and micro-F1 over the labelled tokens of a batch.

        Positions whose label is ``-1`` are excluded from every metric
        (presumably padding — TODO confirm against the data module).
        ``self.forward`` output is fed to ``F.nll_loss`` and exponentiated for
        the metrics, so it is assumed to be log-probabilities.

        Args:
            batch: mapping with ``"tokens"`` and ``"labels"`` tensors.
            batch_idx: index of the batch (unused).
        """
        tokens = batch["tokens"]
        labels = batch["labels"]
        y_hat = self.forward(tokens)

        # Build the mask once and reuse it for labels and predictions.
        valid = labels != -1
        flattened_labels = labels[valid]
        flattened_y_hat = y_hat[valid]
        loss = F.nll_loss(flattened_y_hat, flattened_labels)

        # Metrics receive exp(log-probs); compute it once for both.
        flattened_probs = flattened_y_hat.exp()
        acc = accuracy(flattened_probs, flattened_labels)
        f1_value = f1(flattened_probs, flattened_labels,
            average='micro',
            ignore_index=streaming_punctuator.data.PUNCTUATION2ID[""],
            num_classes=len(streaming_punctuator.data.PUNCTUATIONS))

        self.log('val_loss', loss, prog_bar=True, logger=True)
        self.log('val_acc', acc, prog_bar=True, logger=True)
        self.log('val_f1', f1_value, prog_bar=True, logger=True)
Ejemplo n.º 8
0
    def batch_step(self, batch):
        """Compute the loss and classification metrics for one batch.

        Shared by the training and validation steps. When ``self.training`` is
        set and ``hparams.cutmix > 0``, CutMix is applied with probability
        ``hparams.cutmix_prob``. A random box taken from a shuffled copy of the
        batch is pasted into ``data`` (in place). The loss is then the
        area-weighted mix of the losses against the original and the shuffled
        targets.

        Returns:
            dict with ``loss``, ``accuracy``, ``average_precision``,
            ``average_recall`` and ``weighted_f1``.
        """
        data, target = batch
        apply_cutmix = (self.training
                        and self.hparams.cutmix > 0
                        and torch.rand(1) < self.hparams.cutmix_prob)

        if not apply_cutmix:
            logits = self.model(data)
            loss = self.criterion(logits, target)
        else:
            mix_ratio = np.random.beta(self.hparams.cutmix, self.hparams.cutmix)
            perm = torch.randperm(data.size()[0]).to(data.device)
            shuffled_target = target[perm]

            # NOTE(review): if data is NCHW, these two sizes are really (H, W);
            # the slicing below indexes dims 2 and 3 consistently with them.
            _, _, w, h = data.size()
            scale = np.sqrt(1.0 - mix_ratio)
            half_w = int(w * scale) // 2
            half_h = int(h * scale) // 2
            center_x = np.random.randint(w)
            center_y = np.random.randint(h)
            x0 = np.clip(center_x - half_w, 0, w)
            x1 = np.clip(center_x + half_w, 0, w)
            y0 = np.clip(center_y - half_h, 0, h)
            y1 = np.clip(center_y + half_h, 0, h)
            data[:, :, x0:x1, y0:y1] = data[perm, :, x0:x1, y0:y1]

            # Re-derive the weight from the pasted box's real area (clipping
            # and the // 2 rounding can shrink it).
            mix_ratio = 1 - ((x1 - x0) * (y1 - y0) / (w*h))
            logits = self.model(data)
            loss = (self.criterion(logits, target) * mix_ratio
                    + self.criterion(logits, shuffled_target) * (1.0 - mix_ratio))

        pred = torch.argmax(logits, dim=1)
        n_classes = self.hparams.num_classes
        acc = accuracy(pred, target)
        avg_precision, avg_recall = precision_recall(
            pred, target, num_classes=n_classes, average="macro", mdmc_average="global"
        )
        weighted_f1 = f1(pred, target, num_classes=n_classes, threshold=0.5, average="weighted")
        return {
            # Still attached to the autograd graph; not needed during validation.
            "loss": loss,
            "accuracy": acc,
            "average_precision": avg_precision,
            "average_recall": avg_recall,
            "weighted_f1": weighted_f1,
        }
Ejemplo n.º 9
0
def test_f1_score(pred, target, exp_score):
    """Check single-class ``f1(..., average='none')`` against the expected score."""
    pred_t, target_t = tensor(pred), tensor(target)
    score = f1(pred_t, target_t, num_classes=1, average='none')
    expected = tensor(exp_score)
    assert torch.allclose(score, expected)