Example #1
0
    def _handle_batch(self, batch):
        """
        Process one batch: noisy forward pass, metrics, optional optimizer step.

        The inputs are perturbed with uniform noise from ``torch.rand_like``
        and clamped to ``[0, 1]`` before being fed to ``self.model``, which
        returns two outputs. The first is scored against the targets with
        cross-entropy and top-1/3/5 accuracy; the second is compared to the
        clean inputs with ``metrics.iou`` (presumably a reconstruction of the
        input -- confirm against the model definition).

        All losses and metrics are stored in ``self.state.batch_metrics``.
        Backpropagation and the optimizer update run only when
        ``self.state.is_train_loader`` is truthy.

        Args:
            batch: a pair ``(x, y)`` of inputs and classification targets.
        """
        inputs, targets = batch

        # Perturb the inputs, then clamp the fresh tensor in place so the
        # original ``inputs`` stay untouched for the IoU comparison below.
        noisy_inputs = inputs + torch.rand_like(inputs)
        noisy_inputs.clamp_(0, 1)

        logits, restored = self.model(noisy_inputs)

        # Classification term.
        loss_clf = F.cross_entropy(logits, targets)

        # Overlap term: measured against the clean inputs, not the noisy ones.
        iou = metrics.iou(restored, inputs)
        loss_iou = 1 - iou

        loss = loss_clf + loss_iou

        top1, top3, top5 = metrics.accuracy(logits, targets, topk=(1, 3, 5))

        batch_metrics = {
            "loss_clf": loss_clf,
            "loss_iou": loss_iou,
            "loss": loss,
            "iou": iou,
            "accuracy01": top1,
            "accuracy03": top3,
            "accuracy05": top5,
        }
        self.state.batch_metrics = batch_metrics

        # Outside the training loader only the metrics are recorded.
        if not self.state.is_train_loader:
            return

        optimizer = self.state.optimizer
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Example #2
0
def test_iou():
    """Check ``metrics.iou`` on hand-built 4x4 masks with known overlaps.

    Covers disjoint masks (expected 0), identical masks including the
    all-empty case (expected 1), a half overlap (expected 0.5), and a
    six-channel stack of those cases evaluated with ``classes`` set.
    Every call passes ``activation="none"``.
    """
    side = 4
    mid = side // 2
    mask_shape = (1, 1, side, side)

    def iou_no_act(outputs, targets, **kwargs):
        # All checks in this test call the metric with activation="none".
        return metrics.iou(outputs, targets, activation="none", **kwargs)

    # check 0: one empty
    empty = torch.zeros(mask_shape)
    full = torch.ones(mask_shape)
    assert iou_no_act(empty, full).item() == 0

    # check 0: no overlap (left half vs. right half of the last axis)
    left = torch.ones(mask_shape)
    left[..., mid:] = 0
    right = torch.ones(mask_shape)
    right[..., :mid] = 0
    assert iou_no_act(left, right).item() == 0

    # check 1: both empty, both full, complete overlap
    assert iou_no_act(empty, empty) == 1
    assert iou_no_act(full, full) == 1
    assert iou_no_act(left, left) == 1

    # check 0.5: half overlap (top-left quadrant lies inside the left half)
    top_left = torch.zeros(mask_shape)
    top_left[..., :mid, :mid] = 1
    assert iou_no_act(top_left, left).item() == 0.5

    # check multiclass: 0, 0, 1, 1, 1, 0.5
    outputs = torch.cat([empty, left, empty, full, left, top_left], dim=1)
    targets = torch.cat([full, right, empty, full, left, left], dim=1)
    expected = torch.Tensor([0, 0, 1, 1, 1, 0.5])
    per_class = iou_no_act(outputs, targets, classes=["dummy"])
    assert torch.all(per_class == expected)