def test_cdgmm_forward(self, data, backend, device, inplace): if device == 'cpu' and backend.NAME == 'skcuda': pytest.skip("skcuda backend can only run on gpu") x, filt, y = data # move to device x, filt, y = x.to(device), filt.to(device), y.to(device) # call cdgmm if inplace: x = x.clone() z = backend.cdgmm(x, filt, inplace=inplace) if inplace: z = x # compare assert (y - z).abs().max() < 1e-6
def test_Cublas():
    """Check every backend's ``cdgmm`` against a reference complex product.

    For each entry of ``devices`` that is ``'gpu'`` or ``'cpu'`` and each
    entry of ``backends``, a random input and a random filter (imaginary
    part zeroed) are multiplied with plain tensor arithmetic, and the result
    is compared with ``backend.cdgmm(x, filt)`` to within ``1e-6``.

    On the ``'gpu'`` pass the tensors are moved with ``.cuda()``. On the
    ``'cpu'`` pass the ``skcuda`` backend is skipped.
    """
    for device in devices:
        if device not in ('gpu', 'cpu'):
            continue
        on_gpu = device == 'gpu'

        for backend in backends:
            if not on_gpu and backend.NAME == 'skcuda':
                continue

            x = torch.rand(100, 128, 128, 2)
            filt = torch.rand(128, 128, 2)
            if on_gpu:
                x, filt = x.cuda(), filt.cuda()
            filt[..., 1] = 0

            # Reference product; the last axis holds (real, imaginary).
            # Broadcasting ``filt`` over the leading axis performs the same
            # element-wise operations as looping over the 100 samples.
            real = x[..., 0] * filt[..., 0] - x[..., 1] * filt[..., 1]
            imag = x[..., 1] * filt[..., 0] + x[..., 0] * filt[..., 1]
            y = torch.stack((real, imag), dim=-1)

            z = backend.cdgmm(x, filt)

            assert (y - z).abs().max() < 1e-6
def test_cdgmm_exceptions(self, backend):
    """Verify that ``backend.cdgmm`` rejects malformed arguments.

    Each case calls ``cdgmm`` with a pair of empty tensors and checks both
    the exception type and a fragment of the first exception argument.
    """
    # (expected exception, message fragment, input shape, filter factory).
    # Filters are built lazily so each one is created inside its own
    # ``pytest.raises`` context.
    cases = (
        (RuntimeError, "not compatible",
         (3, 4, 5, 2), lambda: torch.empty(4, 3, 2)),
        (TypeError, "input must be complex",
         (3, 4, 5, 1), lambda: torch.empty(4, 5, 1)),
        (TypeError, "filter must be complex or real",
         (3, 4, 5, 2), lambda: torch.empty(4, 5, 3)),
        (RuntimeError, "filter must be a 3-tensor",
         (3, 4, 5, 2), lambda: torch.empty(3, 4, 5, 2)),
        (RuntimeError, "must be of the same dtype",
         (3, 4, 5, 2), lambda: torch.empty(4, 5, 1).double()),
    )

    for error_type, fragment, input_shape, make_filter in cases:
        with pytest.raises(error_type) as caught:
            backend.cdgmm(torch.empty(*input_shape), make_filter())
        assert fragment in caught.value.args[0]

    # The device-mismatch case needs a CUDA tensor, so it only runs when the
    # suite was configured with a gpu device.
    if 'gpu' in devices:
        with pytest.raises(RuntimeError) as caught:
            backend.cdgmm(torch.empty(3, 4, 5, 2),
                          torch.empty(4, 5, 1).cuda())
        assert "must be on the same device" in caught.value.args[0]
class TestCDGMM:
    """Tests for the ``cdgmm`` routine exposed by each entry of ``backends``."""

    @pytest.fixture(params=(False, True))
    def data(self, request):
        """Build a random input, a filter and their expected product.

        The fixture is parametrized over ``real_filter``. When it is true,
        the filter's imaginary channel is zeroed before the reference is
        computed, and the filter is returned with only its first channel.

        Returns
        -------
        tuple
            ``(x, filt, y)`` with ``y`` the expected product of ``x`` and
            ``filt``.
        """
        real_filter = request.param

        x = torch.rand(100, 128, 128, 2)
        filt = torch.rand(128, 128, 2)
        y = torch.ones(100, 128, 128, 2)

        if real_filter:
            filt[..., 1] = 0

        # Reference product; the last axis holds (real, imaginary).
        y[..., 0] = x[..., 0] * filt[..., 0] - x[..., 1] * filt[..., 1]
        y[..., 1] = x[..., 1] * filt[..., 0] + x[..., 0] * filt[..., 1]

        if real_filter:
            # Hand the backend a single-channel (real) filter.
            filt = filt[..., :1]

        return x, filt, y

    @pytest.mark.parametrize("backend", backends)
    @pytest.mark.parametrize("device", devices)
    @pytest.mark.parametrize("inplace", (False, True))
    def test_cdgmm_forward(self, data, backend, device, inplace):
        """Compare ``backend.cdgmm`` with the reference from ``data``.

        Runs for every combination of backend, device and ``inplace`` flag;
        the ``skcuda`` backend is skipped on ``'cpu'``.
        """
        if device == 'cpu' and backend.NAME == 'skcuda':
            pytest.skip("skcuda backend can only run on gpu")

        x, filt, y = data

        # The suite labels CUDA runs 'gpu'; torch expects 'cuda'.
        device = 'cuda' if device == 'gpu' else device
        x, filt, y = x.to(device), filt.to(device), y.to(device)

        if inplace:
            # Work on a copy so the in-place call cannot overwrite the
            # tensor supplied by the fixture.
            x = x.clone()
        z = backend.cdgmm(x, filt, inplace=inplace)
        if inplace:
            # The in-place variant is expected to leave its result in ``x``.
            z = x

        assert (y - z).abs().max() < 1e-6

    @pytest.mark.parametrize("backend", backends)
    def test_cdgmm_exceptions(self, backend):
        """Verify that ``backend.cdgmm`` rejects malformed arguments.

        Each check asserts the exception type and a fragment of the first
        exception argument.
        """
        with pytest.raises(RuntimeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2), torch.empty(4, 3, 2))
        assert "not compatible" in exc.value.args[0]

        with pytest.raises(TypeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 1), torch.empty(4, 5, 1))
        assert "input must be complex" in exc.value.args[0]

        with pytest.raises(TypeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2), torch.empty(4, 5, 3))
        assert "filter must be complex or real" in exc.value.args[0]

        with pytest.raises(RuntimeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2), torch.empty(3, 4, 5, 2))
        assert "filter must be a 3-tensor" in exc.value.args[0]

        with pytest.raises(RuntimeError) as exc:
            backend.cdgmm(torch.empty(3, 4, 5, 2),
                          torch.empty(4, 5, 1).double())
        assert "must be of the same dtype" in exc.value.args[0]

        # The device-mismatch case needs a CUDA tensor, so it only runs
        # when the suite was configured with a gpu device.
        if 'gpu' in devices:
            with pytest.raises(RuntimeError) as exc:
                backend.cdgmm(torch.empty(3, 4, 5, 2),
                              torch.empty(4, 5, 1).cuda())
            assert "must be on the same device" in exc.value.args[0]