def _compute_batch():
    """Check step-wise and accumulated ``Accuracy`` on one Horovod worker.

    NOTE(review): ``threshold``, ``num_batches``, ``preds``, ``target``,
    ``sk_metric`` and ``hvd`` are free names here; presumably they are
    provided by the enclosing test scope -- confirm.

    This rank feeds batches ``rank, rank + size, rank + 2*size, ...`` into
    a metric configured with ``dist_sync_on_step=True`` and the training
    type plugin's ``all_gather`` as ``dist_sync_fn``. On rank 0, each
    step's result is compared with ``sk_metric`` evaluated on the
    ``hvd.size()`` consecutive batches starting at that step. Finally, on
    every rank, ``metric.compute()`` is compared with ``sk_metric`` over
    all ``num_batches`` batches.
    """
    trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False)
    assert isinstance(trainer.accelerator, CPUAccelerator)
    # TODO: test that we selected the correct training_type_plugin based on horovod flags

    metric = Accuracy(
        compute_on_step=True,
        dist_sync_on_step=True,
        dist_sync_fn=trainer.training_type_plugin.all_gather,
        threshold=threshold,
    )

    # Horovod is queried only after the Trainer and metric exist, keeping
    # the same ordering relative to Trainer setup as before.
    rank = hvd.rank()
    world_size = hvd.size()

    def _stack_step(source, start):
        # Rows ``start .. start + world_size - 1``: one batch per rank for this step.
        return torch.stack([source[start + offset] for offset in range(world_size)])

    for step_start in range(rank, num_batches, world_size):
        step_result = metric(preds[step_start], target[step_start])
        if rank != 0:
            continue
        expected_step = sk_metric(
            _stack_step(preds, step_start),
            _stack_step(target, step_start),
        )
        assert np.allclose(step_result.numpy(), expected_step)

    # check on all batches on all ranks
    final_result = metric.compute()
    assert isinstance(final_result, torch.Tensor)

    every_pred = torch.stack([preds[k] for k in range(num_batches)])
    every_target = torch.stack([target[k] for k in range(num_batches)])
    expected_total = sk_metric(every_pred, every_target)
    assert np.allclose(final_result.numpy(), expected_total)
def _compute_batch():
    """Check step-wise and accumulated ``Accuracy`` on one Horovod worker.

    NOTE(review): ``threshold``, ``num_batches``, ``preds``, ``target``,
    ``sk_metric`` and ``hvd`` are free names here; presumably they are
    provided by the enclosing test scope -- confirm.

    The accelerator chosen by the Trainer's accelerator connector must be
    a ``HorovodAccelerator``. Its ``gather_all_tensors`` is used as the
    metric's ``dist_sync_fn`` with ``dist_sync_on_step=True``.

    This rank feeds batches ``rank, rank + size, ...`` into the metric. On
    rank 0, each step's result is compared with ``sk_metric`` evaluated on
    the ``hvd.size()`` consecutive batches starting at that step. Finally,
    on every rank, ``metric.compute()`` is compared with ``sk_metric``
    over all ``num_batches`` batches.
    """
    trainer = Trainer(
        fast_dev_run=True,
        distributed_backend='horovod',
    )
    backend = trainer.accelerator_connector.select_accelerator()
    assert isinstance(backend, HorovodAccelerator)

    metric = Accuracy(
        compute_on_step=True,
        dist_sync_on_step=True,
        dist_sync_fn=backend.gather_all_tensors,
        threshold=threshold,
    )

    # Horovod is queried only after the Trainer, accelerator and metric
    # exist, keeping the same ordering relative to setup as before.
    rank = hvd.rank()
    world_size = hvd.size()

    def _stack_step(source, start):
        # Rows ``start .. start + world_size - 1``: one batch per rank for this step.
        return torch.stack([source[start + offset] for offset in range(world_size)])

    for step_start in range(rank, num_batches, world_size):
        step_result = metric(preds[step_start], target[step_start])
        if rank != 0:
            continue
        expected_step = sk_metric(
            _stack_step(preds, step_start),
            _stack_step(target, step_start),
        )
        assert np.allclose(step_result.numpy(), expected_step)

    # check on all batches on all ranks
    final_result = metric.compute()
    assert isinstance(final_result, torch.Tensor)

    every_pred = torch.stack([preds[k] for k in range(num_batches)])
    every_target = torch.stack([target[k] for k in range(num_batches)])
    expected_total = sk_metric(every_pred, every_target)
    assert np.allclose(final_result.numpy(), expected_total)