Пример #1
0
    def _test():
        """Compare binary ``Accuracy`` on shape ``(N,)`` inputs with scikit-learn.

        Runs one single-shot update and one batched run. After each, the
        metric must report the "binary" type, return a ``float`` and agree
        with ``accuracy_score`` computed on the same flattened data.
        """
        acc = Accuracy()

        def _check(y_pred, y):
            # Reference arrays for scikit-learn are the flattened tensors.
            np_y = y.numpy().ravel()
            np_y_pred = y_pred.numpy().ravel()
            assert acc._type == "binary"
            assert isinstance(acc.compute(), float)
            assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Single update with the whole sample.
        y_pred = torch.randint(0, 2, size=(10,)).long()
        y = torch.randint(0, 2, size=(10,)).long()
        acc.update((y_pred, y))
        _check(y_pred, y)

        # Batched Updates
        acc.reset()
        y_pred = torch.randint(0, 2, size=(100,)).long()
        y = torch.randint(0, 2, size=(100,)).long()

        n_iters = 16
        batch_size = y.shape[0] // n_iters + 1

        # Step through the data by batch_size so the loop ends on the last
        # non-empty slice. Looping a fixed n_iters times would finish with an
        # empty batch, because n_iters * batch_size exceeds the sample count.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _check(y_pred, y)
Пример #2
0
def test_binary_input_NL():
    """Check binary accuracy on float predictions of shape (10, 5) and (10, 1, 5).

    Predictions are drawn with ``torch.rand``; the scikit-learn reference is
    computed on those predictions thresholded at 0.5.
    """
    acc = Accuracy()

    def _update_and_verify(y_pred, y):
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        # The reference uses hard labels obtained by thresholding at 0.5.
        np_y_pred = (y_pred.numpy().ravel() > 0.5).astype('int')
        assert acc._type == 'binary'
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # TODO: y_pred should be binary after 0.1.2 release, i.e. drawn with
    # torch.randint(0, 2, ...) and compared without thresholding.
    pred_2d = torch.rand(10, 5)
    target_2d = torch.randint(0, 2, size=(10, 5)).type(torch.LongTensor)
    _update_and_verify(pred_2d, target_2d)

    acc.reset()
    # TODO: same as above for the (10, 1, 5) case.
    pred_3d = torch.rand(10, 1, 5)
    target_3d = torch.randint(0, 2, size=(10, 1, 5)).type(torch.LongTensor)
    _update_and_verify(pred_3d, target_3d)
Пример #3
0
def test_compute_batch_images():
    """Check ``Accuracy`` on image-shaped sigmoid outputs with all-ones targets.

    Three shape combinations are exercised, with the metric reset between
    them. The scikit-learn reference labels are the argmax over the stacked
    pair ``(1 - p, p)`` along dim 1.
    """
    acc = Accuracy()

    def _run_case(y_pred, y):
        stacked = torch.cat([1.0 - y_pred, y_pred], dim=1)
        indices = torch.max(stacked, dim=1)[1]
        acc.update((y_pred, y))
        assert isinstance(acc.compute(), float)
        assert accuracy_score(y.view(-1).data.numpy(), indices.view(-1).data.numpy()) == pytest.approx(acc.compute())

    # Prediction (1, 2, 2) gets a singleton dim inserted at position 1.
    probs = torch.sigmoid(torch.rand(1, 2, 2))
    target = torch.ones(1, 2, 2).type(torch.LongTensor)
    _run_case(probs.unsqueeze(1), target)

    # Prediction (2, 1, 2, 2) against target (2, 2, 2).
    acc.reset()
    probs = torch.sigmoid(torch.rand(2, 1, 2, 2))
    target = torch.ones(2, 2, 2).type(torch.LongTensor)
    _run_case(probs, target)

    # Prediction and target both of shape (2, 1, 2, 2).
    acc.reset()
    probs = torch.sigmoid(torch.rand(2, 1, 2, 2))
    target = torch.ones(2, 1, 2, 2).type(torch.LongTensor)
    _run_case(probs, target)
Пример #4
0
def test_multiclass_input_N():
    """Check multiclass accuracy for scores (N, C) with targets (N,) or (N, 1).

    The scikit-learn reference uses the argmax of the scores over axis 1.
    """
    acc = Accuracy()

    def _run_case(scores, target):
        acc.update((scores, target))
        np_y_pred = scores.numpy().argmax(axis=1).ravel()
        np_y = target.numpy().ravel()
        assert acc._type == 'multiclass'
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # 4 classes, flat targets.
    scores = torch.rand(10, 4)
    target = torch.randint(0, 4, size=(10,)).type(torch.LongTensor)
    _run_case(scores, target)

    # 10 classes, targets carrying a trailing singleton dimension.
    acc.reset()
    scores = torch.rand(4, 10)
    target = torch.randint(0, 10, size=(4, 1)).type(torch.LongTensor)
    _run_case(scores, target)

    # 2-classes
    acc.reset()
    scores = torch.rand(4, 2)
    target = torch.randint(0, 2, size=(4, 1)).type(torch.LongTensor)
    _run_case(scores, target)
Пример #5
0
def test_categorical_compute():
    """Check ``Accuracy`` on softmax outputs against all-ones targets.

    Runs a 4x4 and a 2x2 case; the scikit-learn reference labels are the
    argmax of the softmax output over dim 1.
    """
    acc = Accuracy()

    def _run_case(size):
        probs = torch.softmax(torch.rand(size, size), dim=1)
        target = torch.ones(size).type(torch.LongTensor)
        predicted = torch.max(probs, dim=1)[1]
        acc.update((probs, target))
        assert isinstance(acc.compute(), float)
        assert accuracy_score(target.view(-1).data.numpy(), predicted.view(-1).data.numpy()) == pytest.approx(acc.compute())

    _run_case(4)

    acc.reset()
    _run_case(2)
Пример #6
0
    def _test():
        """Check multilabel ``Accuracy`` on 4-D 0/1 inputs against scikit-learn.

        Covers two single-shot updates and one batched run. Both tensors are
        passed through ``to_numpy_multilabel`` before being handed to
        ``accuracy_score``.
        """
        acc = Accuracy(is_multilabel=True)

        def _verify(y_pred, y):
            # (N, C, H, W, ...) -> (N * H * W ..., C)
            np_y_pred = to_numpy_multilabel(y_pred)
            np_y = to_numpy_multilabel(y)
            assert acc._type == 'multilabel'
            assert isinstance(acc.compute(), float)
            assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Single update, 5 labels.
        y_pred = torch.randint(0, 2, size=(4, 5, 12, 10))
        y = torch.randint(0, 2, size=(4, 5, 12, 10)).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Single update, 10 labels.
        acc.reset()
        y_pred = torch.randint(0, 2, size=(4, 10, 12, 8)).type(torch.LongTensor)
        y = torch.randint(0, 2, size=(4, 10, 12, 8)).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Batched Updates
        acc.reset()
        y_pred = torch.randint(0, 2, size=(100, 5, 12, 10))
        y = torch.randint(0, 2, size=(100, 5, 12, 10)).type(torch.LongTensor)

        batch_size = 16
        # 100 samples in steps of 16: seven slices, the last one shorter.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _verify(y_pred, y)
Пример #7
0
def test_binary_compute():
    """Check ``Accuracy`` on sigmoid outputs of shape (4, 1) against all-ones targets.

    The second case starts from a flat (4,) prediction that is unsqueezed to
    (4, 1) before use. Reference labels are the argmax over ``(1 - p, p)``.
    """
    acc = Accuracy()

    def _run_case(y_pred, y):
        stacked = torch.cat([1.0 - y_pred, y_pred], dim=1)
        indices = torch.max(stacked, dim=1)[1]
        acc.update((y_pred, y))
        assert isinstance(acc.compute(), float)
        assert accuracy_score(y.data.numpy(), indices.data.numpy()) == pytest.approx(acc.compute())

    column_probs = torch.sigmoid(torch.rand(4, 1))
    target = torch.ones(4).type(torch.LongTensor)
    _run_case(column_probs, target)

    acc.reset()
    flat_probs = torch.sigmoid(torch.rand(4))
    target = torch.ones(4).type(torch.LongTensor)
    _run_case(flat_probs.unsqueeze(1), target)
Пример #8
0
def test_binary_input(n_times, test_data_binary):
    """Compare binary ``Accuracy`` with scikit-learn on fixture-provided data.

    Args:
        n_times: not used in the body; presumably a fixture that repeats the
            test — TODO confirm against the test configuration.
        test_data_binary: tuple ``(y_pred, y, batch_size)``. When
            ``batch_size`` is greater than 1 the data is fed to the metric in
            consecutive slices of that size, otherwise in a single update.
    """
    acc = Accuracy()

    y_pred, y, batch_size = test_data_binary
    acc.reset()
    if batch_size > 1:
        # Stepping with range() stops at the last non-empty slice. Using a
        # precomputed `N // batch_size + 1` iteration count would append an
        # empty batch whenever N is an exact multiple of batch_size.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))
    else:
        acc.update((y_pred, y))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    assert acc._type == "binary"
    assert isinstance(acc.compute(), float)
    assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
Пример #9
0
    def _test():
        """Check binary ``Accuracy`` when ``y_pred`` has an extra singleton dim 1.

        Predictions have shape (N, 1, ...) and targets (N, ...). Both are
        flattened before being compared through scikit-learn.
        """
        acc = Accuracy()

        def _verify(y_pred, y):
            np_y = y.numpy().ravel()
            np_y_pred = y_pred.numpy().ravel()
            assert acc._type == 'binary'
            assert isinstance(acc.compute(), float)
            assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # (4, 1) predictions against (4,) targets.
        y_pred = torch.randint(0, 2, size=(4, 1)).type(torch.LongTensor)
        y = torch.randint(0, 2, size=(4, )).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # (4, 1, 12) predictions against (4, 12) targets.
        acc.reset()
        y_pred = torch.randint(0, 2, size=(4, 1, 12)).type(torch.LongTensor)
        y = torch.randint(0, 2, size=(4, 12)).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Batched Updates
        acc.reset()
        y_pred = torch.randint(0, 2, size=(100, 1, 8, 8)).type(torch.LongTensor)
        y = torch.randint(0, 2, size=(100, 8, 8)).type(torch.LongTensor)

        batch_size = 16
        # 100 samples in steps of 16: seven slices, the last one shorter.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _verify(y_pred, y)
Пример #10
0
    def _test():
        """Check multiclass ``Accuracy`` on scores (N, C, ...) with targets (N, ...).

        The scikit-learn reference uses the argmax of the scores over axis 1,
        flattened, against the flattened targets.
        """
        acc = Accuracy()

        def _verify(scores, target):
            np_y_pred = scores.numpy().argmax(axis=1).ravel()
            np_y = target.numpy().ravel()
            assert acc._type == "multiclass"
            assert isinstance(acc.compute(), float)
            assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # 5 classes, two trailing dimensions.
        y_pred = torch.rand(4, 5, 12, 10)
        y = torch.randint(0, 5, size=(4, 12, 10)).long()
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # 5 classes, three trailing dimensions.
        acc.reset()
        y_pred = torch.rand(4, 5, 10, 12, 8)
        y = torch.randint(0, 5, size=(4, 10, 12, 8)).long()
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Batched Updates
        acc.reset()
        y_pred = torch.rand(100, 3, 8, 8)
        y = torch.randint(0, 3, size=(100, 8, 8)).long()

        batch_size = 16
        # 100 samples in steps of 16: seven slices, the last one shorter.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _verify(y_pred, y)
Пример #11
0
def test(epoch, model, test_loader, writer, embeddings=None):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    The model is switched to eval mode and run without gradient tracking.
    Outputs are passed through a sigmoid and thresholded at 0.5 before being
    fed to the ``Accuracy`` metric. Loss and accuracy are written to
    ``writer`` at step ``epoch + 1`` and printed.

    Relies on the module-level names ``device`` and ``criterion``.

    Args:
        epoch: zero-based epoch index; logging uses ``epoch + 1`` as the step.
        model: callable module producing raw (pre-sigmoid) outputs.
        test_loader: iterable of ``(data, targets)`` batches exposing a
            ``sampler`` attribute with a length.
        writer: object with an ``add_scalar(tag, value, step)`` method.
        embeddings: accepted for interface compatibility; not used here.
    """
    model.eval()
    test_loss = 0

    acc = Accuracy()
    acc.reset()

    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)

            # perform prediction
            output = model(data)

            test_loss += criterion(output, targets).item()

            # Since during training sigmoid is applied in BCEWithLogitsLoss
            # we also need to apply it here
            probabilities = torch.sigmoid(output)

            # Make a hard decision threshold at 0.5: values above it become 1,
            # everything else 0, keeping the floating dtype of the output.
            decisions = (probabilities > 0.5).to(probabilities.dtype)

            acc.update((decisions, targets))

    acc_value = acc.compute()
    # NOTE(review): the accumulated per-batch losses are divided by the number
    # of samples in the sampler; whether that yields a per-sample mean depends
    # on the reduction used by `criterion` — confirm.
    test_loss /= len(test_loader.sampler)

    step = int((epoch + 1))
    writer.add_scalar("Test Loss", test_loss, step)
    writer.add_scalar("Test Acc", acc_value, step)

    print(
        "Test set: Average loss: {:.4f}, Accuracy: ({:.0f}%)\n".format(
            test_loss, acc_value * 100
        )
    )
Пример #12
0
    def _test():
        """Check multilabel ``Accuracy`` on 2-D 0/1 inputs of shape (N, C).

        The tensors are converted to NumPy unchanged (no flattening) and
        compared through scikit-learn's ``accuracy_score``.
        """
        acc = Accuracy(is_multilabel=True)

        def _verify(y_pred, y):
            np_y_pred = y_pred.numpy()
            np_y = y.numpy()
            assert acc._type == "multilabel"
            assert isinstance(acc.compute(), float)
            assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # 10 samples, 4 labels.
        y_pred = torch.randint(0, 2, size=(10, 4))
        y = torch.randint(0, 2, size=(10, 4)).long()
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # 50 samples, 7 labels.
        acc.reset()
        y_pred = torch.randint(0, 2, size=(50, 7)).long()
        y = torch.randint(0, 2, size=(50, 7)).long()
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Batched Updates
        acc.reset()
        y_pred = torch.randint(0, 2, size=(100, 4))
        y = torch.randint(0, 2, size=(100, 4)).long()

        batch_size = 16
        # 100 samples in steps of 16: seven slices, the last one shorter.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _verify(y_pred, y)
Пример #13
0
        return_attention_mask=False,
    ),
    batched=True,
    batch_size=8192,
)
# Request torch-formatted output for the two columns consumed by the loop below.
dataset.set_format(type="torch", columns=["input_ids", "label"])
dataloader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=12)

# Running accuracy metric; it is reset at the start of every pass over the data.
accuracy = Accuracy()
best_acc = 0.0

# NOTE(review): the loop variable `iter` shadows the builtin of the same name.
for iter in range(args.num_iter):
    accuracy.reset()

    for batch in tqdm(dataloader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)
        # Exactly four values are unpacked from the model output, so the
        # output mapping must hold four entries in this order — TODO confirm
        # against the model configuration.
        loss, logits, hidden_states, attentions = model(
            input_ids, labels=labels).values()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The metric is fed the logits produced before this optimizer step.
        accuracy.update((logits, labels))

    # Accuracy accumulated over all batches of this pass.
    acc = accuracy.compute()
Пример #14
0
    def _test(metric_device):
        """Check distributed multilabel ``Accuracy`` against scikit-learn.

        Each process seeds its generator with ``10 + rank``, updates the
        metric with its local data, then gathers predictions and targets from
        all processes to build the scikit-learn reference. Uses ``device`` and
        ``rank`` from the enclosing scope.

        Args:
            metric_device: device (name or ``torch.device``) on which the
                metric is expected to keep its ``_num_correct`` accumulator.
        """
        metric_device = torch.device(metric_device)
        acc = Accuracy(is_multilabel=True, device=metric_device)

        def _assert_on_metric_device():
            # The accumulator must live on the metric device regardless of
            # where the input tensors were created.
            assert (
                acc._num_correct.device == metric_device
            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        def _gather_and_check(y_pred, y, n_computes=1):
            # gather y_pred, y
            y_pred = idist.all_gather(y_pred)
            y = idist.all_gather(y)

            # (N, C, H, W, ...) -> (N * H * W ..., C)
            np_y_pred = to_numpy_multilabel(y_pred.cpu())
            np_y = to_numpy_multilabel(y.cpu())

            assert acc._type == "multilabel"
            # Local example count, captured before compute() is called.
            n = acc._num_examples
            # Repeated compute() calls must leave the result unchanged.
            for _ in range(n_computes):
                res = acc.compute()
                assert n * idist.get_world_size() == acc._num_examples
                assert isinstance(res, float)
                assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

        # Single update, 5 labels.
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
        y = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
        acc.update((y_pred, y))
        _assert_on_metric_device()
        _gather_and_check(y_pred, y)

        # Single update, 7 labels; also check that the result is not changed
        # by a second compute().
        acc.reset()
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
        y = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
        acc.update((y_pred, y))
        _assert_on_metric_device()
        _gather_and_check(y_pred, y, n_computes=2)

        # Batched Updates
        acc.reset()
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
        y = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()

        batch_size = 16
        # Step through the data by batch_size so the loop ends on the last
        # non-empty slice. 80 is an exact multiple of 16, so a
        # `N // batch_size + 1` iteration count would finish with an empty
        # batch.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _assert_on_metric_device()
        _gather_and_check(y_pred, y)
Пример #15
0
class ClassificationTask(pl.LightningModule, TFLogger):
    """Standard interface for the trainer to interact with the model."""
    def __init__(self, params):
        """Store hyperparameters and build the model, loss and validation metric.

        Args:
            params: configuration passed to ``save_hyperparameters``,
                ``get_model`` and ``get_loss_fn``.
        """
        super().__init__()
        self.save_hyperparameters(params)
        self.model = get_model(params)
        self.loss = get_loss_fn(params)
        self.val_acc = Accuracy()

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_nb):
        """
        Compute the training loss for one batch.

        Args:
            batch: pair ``(x, y)`` of inputs and targets.
            batch_nb: index of the batch (unused).
        Returns:
            A dictionary with:
                loss: loss used to calculate the gradient
                log: ``{'train_loss': loss}``
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.loss(logits.view(-1), y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_nb):
        """
        Compute the validation loss for one batch and update the accuracy metric.

        Args:
            batch: pair ``(x, y)`` of inputs and targets.
            batch_nb: index of the batch (unused).
        Returns:
            The loss for this batch.
        """
        x, y = batch
        logits = self.forward(x)
        flat_logits = logits.view(-1)
        loss = self.loss(flat_logits, y)
        # Threshold the same flattened logits the loss is computed on, so the
        # hard predictions have the same shape as the targets.
        y_hat = (flat_logits > 0).float()
        self.val_acc.update((y_hat, y))
        return loss

    def validation_epoch_end(self, outputs):
        """
        Aggregate and log the validation metrics.

        Logs the mean of the per-batch losses as ``val_loss`` and the
        accumulated accuracy as ``val_acc``, then resets the accuracy metric
        for the next epoch.

        Args:
            outputs: A list of loss tensors returned by `validation_step()`
        Returns: None
        """

        avg_loss = torch.stack(outputs).mean()
        avg_acc = self.val_acc.compute()
        self.val_acc.reset()
        self.log("val_loss", avg_loss)
        # The accuracy is computed above; log it so it is not silently dropped.
        self.log("val_acc", avg_acc)

    def test_step(self, batch, batch_nb):
        """
        Compute the test loss for one batch.

        Returns:
            A dictionary with ``test_loss`` and a ``log`` entry holding the
            same value.
        """
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat.view(-1), y)
        return {'test_loss': loss, 'log': {'test_loss': loss}}

    def test_epoch_end(self, outputs):
        """Return the mean of the per-batch ``test_loss`` values as ``avg_test_loss``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'avg_test_loss': avg_loss}

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters (lr=0.02)."""
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    def train_dataloader(self):
        """Return a shuffled loader over the demo dataset with batch size 2."""
        dataset = ImageClassificationDemoDataset()
        return DataLoader(dataset, shuffle=True, batch_size=2, num_workers=8)

    def val_dataloader(self):
        """Return an unshuffled loader over the demo dataset with batch size 1."""
        dataset = ImageClassificationDemoDataset()
        return DataLoader(dataset, shuffle=False, batch_size=1, num_workers=8)

    def test_dataloader(self):
        """Return an unshuffled loader over the demo dataset with batch size 1."""
        dataset = ImageClassificationDemoDataset()
        return DataLoader(dataset, shuffle=False, batch_size=1, num_workers=8)
Пример #16
0
    def _test():
        """Check multiclass ``Accuracy`` on score tensors against scikit-learn.

        Covers several class counts, one case with a trailing singleton
        dimension, and a batched run. The reference uses the argmax of the
        scores over axis 1.
        """
        acc = Accuracy()

        def _verify(scores, target):
            np_y_pred = scores.numpy().argmax(axis=1).ravel()
            np_y = target.numpy().ravel()
            assert acc._type == 'multiclass'
            assert isinstance(acc.compute(), float)
            assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # 4 classes.
        y_pred = torch.rand(10, 4)
        y = torch.randint(0, 4, size=(10, )).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Scores (10, 10, 1) with targets (10, 1).
        # NOTE(review): labels are drawn from [0, 18) while the scores have
        # only 10 entries along dim 1 — confirm this is intended.
        acc.reset()
        y_pred = torch.rand(10, 10, 1)
        y = torch.randint(0, 18, size=(10, 1)).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # 18 classes.
        acc.reset()
        y_pred = torch.rand(10, 18)
        y = torch.randint(0, 18, size=(10, )).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # 10 classes.
        acc.reset()
        y_pred = torch.rand(4, 10)
        y = torch.randint(0, 10, size=(4, )).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # 2-classes
        acc.reset()
        y_pred = torch.rand(4, 2)
        y = torch.randint(0, 2, size=(4, )).type(torch.LongTensor)
        acc.update((y_pred, y))
        _verify(y_pred, y)

        # Batched Updates
        acc.reset()
        y_pred = torch.rand(100, 5)
        y = torch.randint(0, 5, size=(100, )).type(torch.LongTensor)

        batch_size = 16
        # 100 samples in steps of 16: seven slices, the last one shorter.
        for start in range(0, y.shape[0], batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

        _verify(y_pred, y)