Ejemplo n.º 1
0
    def _compute_batch():
        """Compare a Horovod-synchronised ``Accuracy`` against ``sk_metric``.

        Closure helper: ``hvd``, ``preds``, ``target``, ``num_batches``,
        ``threshold`` and ``sk_metric`` come from the enclosing scope.
        Rank ``r`` feeds batches ``r, r + size, r + 2 * size, ...`` into the
        metric. On rank 0 every step result is checked against ``sk_metric``
        applied to the ``size`` consecutive batches handled in that step;
        afterwards every rank checks ``metric.compute()`` against
        ``sk_metric`` over all ``num_batches`` batches.
        """
        trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False)
        assert isinstance(trainer.accelerator, CPUAccelerator)
        # TODO: test that we selected the correct training_type_plugin based on horovod flags

        metric = Accuracy(
            compute_on_step=True,
            dist_sync_on_step=True,
            # per-step state is synchronised through the plugin's all_gather
            dist_sync_fn=trainer.training_type_plugin.all_gather,
            threshold=threshold,
        )

        def _stack_step(source, first):
            # Batches first .. first + size - 1: one batch per rank for this step.
            return torch.stack([source[first + offset] for offset in range(hvd.size())])

        rank = hvd.rank()
        for first in range(rank, num_batches, hvd.size()):
            step_result = metric(preds[first], target[first])
            if rank != 0:
                continue
            # NOTE(review): if preds/target hold exactly num_batches entries, this
            # assumes num_batches is a multiple of hvd.size(); otherwise
            # first + offset can index past the last batch — confirm.
            expected_step = sk_metric(_stack_step(preds, first), _stack_step(target, first))
            assert np.allclose(step_result.numpy(), expected_step)

        # check on all batches on all ranks
        result = metric.compute()
        assert isinstance(result, torch.Tensor)

        every_batch = range(num_batches)
        expected_total = sk_metric(
            torch.stack([preds[k] for k in every_batch]),
            torch.stack([target[k] for k in every_batch]),
        )
        assert np.allclose(result.numpy(), expected_total)
Ejemplo n.º 2
0
    def _compute_batch():
        """Compare a Horovod-synchronised ``Accuracy`` against ``sk_metric``.

        Closure helper: ``hvd``, ``preds``, ``target``, ``num_batches``,
        ``threshold`` and ``sk_metric`` come from the enclosing scope.
        The trainer's selected accelerator must be a ``HorovodAccelerator``;
        its ``gather_all_tensors`` is used as the metric's sync function.
        Rank ``r`` feeds batches ``r, r + size, ...`` into the metric; rank 0
        checks each step result against ``sk_metric`` on that step's
        ``size`` consecutive batches, and every rank finally checks
        ``metric.compute()`` against ``sk_metric`` over all batches.
        """
        trainer = Trainer(fast_dev_run=True, distributed_backend='horovod')

        backend = trainer.accelerator_connector.select_accelerator()
        assert isinstance(backend, HorovodAccelerator)

        metric = Accuracy(
            compute_on_step=True,
            dist_sync_on_step=True,
            dist_sync_fn=backend.gather_all_tensors,
            threshold=threshold,
        )

        def _stack_step(source, first):
            # Batches first .. first + size - 1: one batch per rank for this step.
            return torch.stack([source[first + offset] for offset in range(hvd.size())])

        rank = hvd.rank()
        for first in range(rank, num_batches, hvd.size()):
            step_result = metric(preds[first], target[first])
            if rank != 0:
                continue
            # NOTE(review): if preds/target hold exactly num_batches entries, this
            # assumes num_batches is a multiple of hvd.size(); otherwise
            # first + offset can index past the last batch — confirm.
            expected_step = sk_metric(_stack_step(preds, first), _stack_step(target, first))
            assert np.allclose(step_result.numpy(), expected_step)

        # check on all batches on all ranks
        result = metric.compute()
        assert isinstance(result, torch.Tensor)

        every_batch = range(num_batches)
        expected_total = sk_metric(
            torch.stack([preds[k] for k in every_batch]),
            torch.stack([target[k] for k in every_batch]),
        )
        assert np.allclose(result.numpy(), expected_total)
Ejemplo n.º 3
0
    def __init__(self, n_classes: int):
        """Build a classification head on top of a pretrained BERT encoder.

        Args:
            n_classes: number of output classes, i.e. the output size of the
                final linear layer ``self.out``.
        """
        super().__init__()
        # Presumably a LightningModule: stores the __init__ arguments
        # (n_classes) as hyperparameters — confirm against the class header.
        self.save_hyperparameters()

        # Per the original author's note: this raises an error on GPU, so it is disabled.
        # self.example_input_array = {"input_ids": torch.LongTensor(1, 100),
        #                             "attention_mask": torch.LongTensor(1, 100)}
        self.acc = Accuracy()
        self.loss_fn = nn.CrossEntropyLoss()

        # Encoder loaded from the module-level PRE_TRAINED_MODEL_NAME checkpoint.
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        # Maps the encoder's hidden size to one score per class.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)
Ejemplo n.º 4
0
def attack(model, dataset, attack_alg='fgsm', eps=0.1, device='cuda'):
    """Measure ``model`` accuracy on adversarially perturbed inputs.

    Every ``(x, y)`` batch from ``dataset`` is perturbed with
    ``fast_gradient_method`` (budget ``eps``, norm ``np.inf``, clipped to
    ``[0, 1]``). The model's predictions on the perturbed batch are then
    accumulated into an ``Accuracy`` metric.

    Args:
        model: classifier whose output holds per-class scores along dim 1
            (predictions are taken with ``.max(1)``). It is moved to
            ``device`` and switched to eval mode in place.
        dataset: iterable of ``(x, y)`` batches.
        attack_alg: attack to run; only ``'fgsm'`` is implemented.
        eps: perturbation budget passed to ``fast_gradient_method``.
        device: device holding the model, the metric and each batch.

    Returns:
        The value of ``Accuracy.compute()`` over all batches (also printed).

    Raises:
        ValueError: if ``attack_alg`` is not ``'fgsm'``.
    """
    # Previously any attack_alg silently ran FGSM; fail fast instead.
    if attack_alg != 'fgsm':
        raise ValueError(
            f"unsupported attack_alg {attack_alg!r}; only 'fgsm' is implemented")

    import torch

    accuracy = Accuracy().to(device)
    model = model.to(device)
    model.eval()
    for x, y in dataset:
        x, y = x.to(device), y.to(device)
        # NOTE: an earlier version optionally binarised x at 0.5 when
        # model.binary was set; that code was disabled.
        # The attack itself needs gradients, so it runs outside no_grad.
        x_fgsm = fast_gradient_method(model,
                                      x,
                                      eps,
                                      np.inf,
                                      clip_min=0,
                                      clip_max=1)
        # Evaluating the perturbed batch needs no autograd graph.
        with torch.no_grad():
            _, y_pred_fgsm = model(x_fgsm).max(1)
        accuracy(y_pred_fgsm, y)
    acc = accuracy.compute()
    print(acc)
    return acc
Ejemplo n.º 5
0
def test_accuracy_invalid_shape():
    """``Accuracy.update`` must reject preds/target with incompatible shapes."""
    # Build the metric outside ``pytest.raises`` so that a ValueError from the
    # constructor cannot make this test pass by accident.
    acc = Accuracy()
    with pytest.raises(ValueError):
        # Integer target so the failure comes from the shape mismatch
        # (preds is 1-D, target is 3-D), not from a float-target check.
        acc.update(preds=torch.rand(1), target=torch.randint(0, 2, (1, 2, 3)))