Ejemplo n.º 1
0
    def test_contiguity_exception(self, backend_device):
        """Check that ``cdgmm`` refuses non-contiguous operands.

        Slicing the last axis with ``[..., :2]`` produces strided
        (non-contiguous) views. Each operand is made contiguous in turn while
        the other stays strided; every call must raise ``RuntimeError`` whose
        message mentions contiguity.
        """
        backend, device = backend_device

        strided_input = torch.empty(3, 4, 5, 3).to(device)[..., :2]
        strided_filter = torch.empty(4, 5, 3).to(device)[..., :2]

        # Input made contiguous, filter left strided.
        with pytest.raises(RuntimeError) as err:
            backend.cdgmm(strided_input.contiguous(), strided_filter)
        assert 'be contiguous' in err.value.args[0]

        # Filter made contiguous, input left strided.
        with pytest.raises(RuntimeError) as err:
            backend.cdgmm(strided_input, strided_filter.contiguous())
        assert 'be contiguous' in err.value.args[0]
Ejemplo n.º 2
0
    def test_device_mismatch(self, backend_device):
        """Check that ``cdgmm`` rejects an input and filter on different GPUs.

        The input is placed on ``cuda:0`` and the filter on ``cuda:1``; the
        backend must raise ``TypeError`` mentioning the same-GPU requirement.

        The test needs a CUDA ``device`` parametrization and at least two
        visible GPUs. Otherwise it is skipped, so it is reported as skipped
        rather than silently counted as passed.
        """
        backend, device = backend_device

        if device == 'cpu':
            pytest.skip('device mismatch test requires a CUDA device')

        if torch.cuda.device_count() < 2:
            pytest.skip('device mismatch test requires at least two GPUs')

        x = torch.empty(3, 4, 5, 2).to('cuda:0')
        y = torch.empty(4, 5, 1).to('cuda:1')

        with pytest.raises(TypeError) as exc:
            backend.cdgmm(x, y)
        assert 'must be on the same GPU' in exc.value.args[0]
Ejemplo n.º 3
0
    def test_gpu_only(self, data, backend):
        """Check that skcuda backends reject CPU tensors.

        Only backends whose name ends in ``'_skcuda'`` are checked. For any
        other backend the test is skipped explicitly instead of passing
        without asserting anything.
        """
        x, filt, _ = data

        if not backend.name.endswith('_skcuda'):
            pytest.skip('only skcuda backends are restricted to CUDA tensors')

        x = x.cpu()
        filt = filt.cpu()

        with pytest.raises(TypeError) as exc:
            backend.cdgmm(x, filt)
        assert 'must be CUDA' in exc.value.args[0]
Ejemplo n.º 4
0
    def test_cdgmm_forward(self, data, backend_device, inplace):
        backend, device = backend_device

        x, filt, y = data
        x, filt, y = x.to(device), filt.to(device), y.to(device)

        z = backend.cdgmm(x, filt, inplace=inplace)

        Warning('Tolerance has been slightly lowered here...')
        # There is a very small meaningless difference for skcuda+GPU
        assert torch.allclose(y, z, atol=1e-7, rtol=1e-6)
Ejemplo n.º 5
0
    def test_cdgmm_exceptions(self, backend):
        """Check the argument validation performed by ``backend.cdgmm``.

        Covers:
        - incompatible input/filter shapes;
        - a non-complex input or filter;
        - an input/filter dtype mismatch;
        - CPU/GPU placement errors, but only when ``'cuda'`` is in ``devices``.

        ``devices`` is a module-level name defined elsewhere in the file
        (presumably the list of devices under test; verify).
        """
        # Filter of shape (4, 3, 2) is not compatible with input (3, 4, 5, 2).
        with pytest.raises(RuntimeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2), torch.empty(4, 3, 2))
        assert 'not compatible' in exc.value.args[0]

        # An input whose last axis has size 1 is rejected as non-complex.
        with pytest.raises(TypeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 1), torch.empty(4, 5, 1))
        assert 'input should be complex' in exc.value.args[0]

        # A filter whose last axis has size 3 is rejected as non-complex.
        with pytest.raises(TypeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2), torch.empty(4, 5, 3))
        assert 'should be complex' in exc.value.args[0]

        # float32 input combined with a float64 filter.
        with pytest.raises(TypeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2),
                          torch.empty(4, 5, 1).double())
        assert 'must be of the same dtype' in exc.value.args[0]

        if 'cuda' in devices:
            if backend.name.endswith('_skcuda'):
                # skcuda backends reject a CPU input outright.
                with pytest.raises(TypeError) as exc:
                    backend.cdgmm(torch.empty(3, 4, 5, 2),
                                  torch.empty(4, 5, 1).cuda())
                assert 'must be cuda tensors' in exc.value.args[0].lower()
            else:
                # Other backends require input and filter on the same device
                # type: CPU input with a GPU filter...
                with pytest.raises(TypeError) as exc:
                    backend.cdgmm(torch.empty(3, 4, 5, 2),
                                  torch.empty(4, 5, 1).cuda())
                assert 'input must be on gpu' in exc.value.args[0].lower()

                # ...and GPU input with a CPU filter.
                with pytest.raises(TypeError) as exc:
                    backend.cdgmm(
                        torch.empty(3, 4, 5, 2).cuda(), torch.empty(4, 5, 1))
                assert 'input must be on cpu' in exc.value.args[0].lower()