def _test_gather(self, output_device):
    inputs = (
        torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
        torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
    )
    result = dp.gather(inputs, output_device)
    self.assertEqual(result.size(), torch.Size([4, 4]))
    self.assertEqual(result[:2], inputs[0])
    self.assertEqual(result[2:], inputs[1])
    if output_device != -1:
        self.assertEqual(result.get_device(), output_device)
    else:
        self.assertFalse(result.is_cuda)
    grad = torch.randn((4, 4), dtype=torch.double)
    if output_device != -1:
        grad = grad.cuda(output_device)
    result.backward(grad)
    self.assertEqual(inputs[0].grad.data, grad[:2])
    self.assertEqual(inputs[1].grad.data, grad[2:])
    _assertGradAndGradgradChecks(
        self, lambda x, y: dp.gather((x, y), output_device), inputs)

    # test scalar inputs, should stack into a vector in this case
    inputs = (
        torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
        torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
    )
    result = dp.gather(inputs, output_device)
    self.assertEqual(result.size(), torch.Size([2]))
    self.assertEqual(result[0], inputs[0])
    self.assertEqual(result[1], inputs[1])
    if output_device != -1:
        self.assertEqual(result.get_device(), output_device)
    else:
        self.assertFalse(result.is_cuda)
    grad = torch.randn(2, dtype=torch.double)
    if output_device != -1:
        grad = grad.cuda(output_device)
    result.backward(grad)
    self.assertEqual(inputs[0].grad, grad[0])
    self.assertEqual(inputs[1].grad, grad[1])
    _assertGradAndGradgradChecks(
        self, lambda x, y: dp.gather((x, y), output_device), inputs)
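
# Illustrative driver tests (a sketch, not necessarily the wrappers used in this
# file): _test_gather is meant to be called once per target, with -1 gathering
# to CPU and a non-negative index gathering to that CUDA device. The method
# names and the TEST_MULTIGPU guard below are assumptions for illustration.
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_cpu(self):
    self._test_gather(-1)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_gpu(self):
    self._test_gather(0)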
def _test_scatter(self, tensor):
    x = tensor.detach().requires_grad_()
    result = dp.scatter(x, (0, 1))
    # scatter splits x along dim 0 into one chunk per target device
    self.assertEqual(len(result), 2)
    self.assertEqual(result[0], x[:2])
    self.assertEqual(result[0].get_device(), 0)
    self.assertEqual(result[1], x[2:])
    self.assertEqual(result[1].get_device(), 1)
    # backprop through the first chunk only; the rest of x.grad must stay zero
    grad = result[0].detach().clone().fill_(2)
    result[0].backward(grad)
    self.assertEqual(x.grad[:2], grad)
    self.assertEqual(x.grad[2:], grad.clone().zero_())
    _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))
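
# Illustrative driver tests (a sketch under the same assumptions as above):
# _test_scatter expects a 4x4 source tensor so the (0, 1) split yields two 2x4
# chunks; CPU and CUDA sources are exercised separately. Double precision keeps
# the gradcheck/gradgradcheck assertions numerically stable.
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_scatter_cpu(self):
    self._test_scatter(torch.randn(4, 4, dtype=torch.double))

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_scatter_gpu(self):
    self._test_scatter(torch.randn(4, 4, dtype=torch.double).cuda())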