Esempio n. 1
0
    def _test_gather(self, output_device):
        """Check ``dp.gather`` forward and backward for two input ranks.

        Two double-precision leaf tensors, one on ``cuda:0`` and one on
        ``cuda:1``, are gathered onto ``output_device``.

        * For 2-D ``(2, 4)`` inputs, the result must be ``(4, 4)``: the
          first input's rows followed by the second input's rows.
        * For 0-dim inputs, the result must be a length-2 vector.

        In both cases the test checks:

        * the gathered values and the result's device;
        * that ``backward`` routes each part of the output gradient back to
          the matching input;
        * gradient and double-gradient checks on the gather itself.

        Args:
            output_device: GPU index the result should land on, or ``-1``
                when the result is expected to be on the CPU.
        """
        gathered_on_gpu = output_device != -1

        def make_inputs(shape):
            # One leaf tensor per GPU, drawn in a fixed device order
            # (cuda:0 first, then cuda:1).
            return tuple(
                torch.randn(shape, device=device, requires_grad=True,
                            dtype=torch.double)
                for device in ('cuda:0', 'cuda:1'))

        def check(inputs, expected_size, first, second, grad_shape):
            # Gather `inputs`, then verify the forward values/device, the
            # backward routing, and the grad/gradgrad checks.
            # `first`/`second` index the part of the gathered result (and of
            # its gradient) that belongs to inputs[0] / inputs[1].
            result = dp.gather(inputs, output_device)
            self.assertEqual(result.size(), torch.Size(expected_size))
            self.assertEqual(result[first], inputs[0])
            self.assertEqual(result[second], inputs[1])
            if gathered_on_gpu:
                self.assertEqual(result.get_device(), output_device)
            else:
                self.assertFalse(result.is_cuda)

            grad = torch.randn(grad_shape, dtype=torch.double)
            if gathered_on_gpu:
                grad = grad.cuda(output_device)
            result.backward(grad)
            # Compare `.grad` directly (as the scalar case always did);
            # going through `.data` adds nothing and bypasses autograd.
            self.assertEqual(inputs[0].grad, grad[first])
            self.assertEqual(inputs[1].grad, grad[second])
            _assertGradAndGradgradChecks(
                self, lambda x, y: dp.gather((x, y), output_device), inputs)

        # 2-D inputs are concatenated along dim 0 into a (4, 4) result.
        check(make_inputs((2, 4)), [4, 4], slice(None, 2), slice(2, None),
              (4, 4))
        # test scalar inputs, should stack into a vector in this case
        check(make_inputs(()), [2], 0, 1, (2,))
Esempio n. 2
0
 def _test_scatter(self, tensor):
     """Check that ``dp.scatter`` splits ``tensor`` across GPUs 0 and 1.

     A detached leaf copy of ``tensor`` is scattered to devices ``(0, 1)``.
     Exactly two chunks are expected: ``leaf[:2]`` on GPU 0 and ``leaf[2:]``
     on GPU 1.

     NOTE(review): this relies on ``tensor``'s leading dimension splitting
     so that the first chunk has exactly two rows; confirm against callers.

     A gradient filled with 2s is back-propagated through the first chunk
     only. Afterwards, the first two rows of ``leaf.grad`` must equal that
     gradient and the remaining rows must be zero. Finally, gradient and
     double-gradient checks are run on the scatter.
     """
     leaf = tensor.detach().requires_grad_()
     chunks = dp.scatter(leaf, (0, 1))
     self.assertEqual(len(chunks), 2)

     # Chunk i must equal its expected slice and live on GPU i.
     for device_idx, expected in enumerate((leaf[:2], leaf[2:])):
         self.assertEqual(chunks[device_idx], expected)
         self.assertEqual(chunks[device_idx].get_device(), device_idx)

     first_chunk = chunks[0]
     grad_of_twos = first_chunk.detach().clone().fill_(2)
     first_chunk.backward(grad_of_twos)
     self.assertEqual(leaf.grad[:2], grad_of_twos)
     # Only the first chunk received a gradient, so the other rows stay 0.
     self.assertEqual(leaf.grad[2:],
                      grad_of_twos.new_zeros(grad_of_twos.size()))

     _assertGradAndGradgradChecks(
         self, lambda t: dp.scatter(t, (0, 1)), (leaf,))