import torch
import pytest
from pytest import approx

import conv_cuda


def test_conv_values():
    """Compare to pytorch convolution.

    Check that the convolution agrees with a pytorch convolution on the
    output area that is not reflected.
    """
    dtype = torch.double
    intensive = False
    if intensive:
        test_params = [(b, c_in, c_out, dil, size, implementation)
                       for b in [1, 3, 5]
                       for c_in in [1, 3, 4]
                       for c_out in [1, 2, 3]
                       for dil in range(1, 10)
                       for size in [dil * 2 + 1, 50, 1023]
                       for implementation in range(2, 6)]
    else:
        test_params = [(b, c_in, c_out, dil, size, implementation)
                       for b in [2]
                       for c_in in [3]
                       for c_out in [5]
                       for dil in [1, 3, 10]
                       for size in [dil * 2 + 1, 29, 50]
                       for implementation in range(2, 6)]

    for (B, C_in, C_out, dilation, size, impl) in test_params:
        shape = (size, 2 * size)

        # Execute my own implementation.
        x = torch.randn(B, C_in, *shape, dtype=dtype).cuda()
        k = torch.randn(C_out, C_in, 3, 3, dtype=dtype).cuda()
        bias = torch.randn(C_out, dtype=dtype).cuda()
        y = torch.zeros(B, C_out, *shape, dtype=dtype).cuda()
        conv_cuda.conv_forward(x, k, bias, y, dilation, impl)

        # Execute pytorch convolution with the same kernel and bias.
        conv_torch = torch.nn.Conv2d(C_in, C_out, 3, padding=dilation,
                                     dilation=dilation).cuda()
        conv_torch.weight.data = k
        conv_torch.bias.data = bias
        y1 = conv_torch(x)
        assert y1.shape == y.shape  # check shapes

        # Check the center of the output, where the two should be equal.
        d = dilation
        y_ = y[:, :, d:-d, d:-d]
        y1_ = y1[:, :, d:-d, d:-d]
        assert torch_equal(y1_, y_), (
            f"for shape {shape} and dilation {dilation} "
            f"and bias {bias} "
            f"and implementation {impl}\n"
            f"Your implementation:\n{y}"
            f"\nPyTorch:\n{y1}")
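# The torch_equal helper used above is not defined in this file. A minimal
# sketch, assuming it is a tolerance-based comparison built on
# torch.allclose (the name matches the usage above, but the tolerances are
# assumptions, not the project's actual values):
def torch_equal(a, b, rtol=1e-5, atol=1e-8):
    # Elementwise comparison with tolerances; exact equality is usually too
    # strict for floating-point convolution results.
    return torch.allclose(a, b, rtol=rtol, atol=atol)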
def test_dtype_check():
    """Test if dtype checks are performed correctly."""
    d0 = torch.float
    d1 = torch.double
    dilation = 1
    x = torch.zeros(1, 1, 5, 5, dtype=d0).cuda()
    y = torch.zeros(1, 1, 5, 5, dtype=d0).cuda()
    bias = torch.zeros(1, dtype=d1).cuda()
    k = torch.zeros(1, 1, 3, 3, dtype=d1).cuda()
    # Mixed float/double arguments must be rejected by the extension.
    with pytest.raises(RuntimeError):
        conv_cuda.conv_forward(x, k, bias, y, dilation)
def test_conv():
    """Test convolution.

    Check if an all-zero convolution runs without runtime errors.
    """
    dtype = torch.float  # or torch.double
    dilation = 1
    for implementation in range(2, 6):
        x = torch.zeros(1, 1, 5, 5, dtype=dtype).cuda()
        y = torch.ones(1, 1, 5, 5, dtype=dtype).cuda()
        bias = torch.zeros(1, dtype=dtype).cuda()
        k = torch.zeros(1, 1, 3, 3, dtype=dtype).cuda()
        # y is pre-filled with ones, so a zero kernel with zero bias must
        # overwrite it with zeros.
        conv_cuda.conv_forward(x, k, bias, y, dilation, implementation)
        assert y.sum().item() == approx(0.0)
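# For reference, the conv_cuda.conv_forward signature as the tests exercise
# it, collected into a hypothetical checked wrapper (the wrapper itself is a
# sketch, not part of the extension):
def conv_forward_checked(x, k, bias, y, dilation, implementation=None):
    # Shapes as used throughout the tests:
    #   x:    (B, C_in, H, W)      input
    #   k:    (C_out, C_in, 3, 3)  kernel
    #   bias: (C_out,)             bias
    #   y:    (B, C_out, H, W)     output, written in place
    assert x.dtype == k.dtype == bias.dtype == y.dtype
    if implementation is None:
        conv_cuda.conv_forward(x, k, bias, y, dilation)
    else:
        conv_cuda.conv_forward(x, k, bias, y, dilation, implementation)
    return y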
def forward(ctx, input, weight, bias, output, stride, dilation):
    # Save the tensors needed by the backward pass; dilation is stored as a
    # plain attribute because save_for_backward only accepts tensors.
    ctx.save_for_backward(input, weight, bias)
    ctx.dilation = dilation
    conv_cuda.conv_forward(input, weight, bias, output.data, dilation)
    return output
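# A sketch of the torch.autograd.Function wrapper this forward likely
# belongs to; the class name ConvFunction and the call site are assumptions,
# and the backward pass is omitted:
class ConvFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, output, stride, dilation):
        ctx.save_for_backward(input, weight, bias)
        ctx.dilation = dilation
        conv_cuda.conv_forward(input, weight, bias, output.data, dilation)
        return output

# Example call (x, weight, bias, y prepared as in the tests; note that
# stride is accepted but unused by this forward):
#   out = ConvFunction.apply(x, weight, bias, y, 1, dilation)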
def A(k, x):
    # Apply the convolution operator with kernel k to x, using a zero bias.
    # B, dtype and dilation come from the enclosing scope.
    y = torch.zeros(B, k.size(0), *x.shape[2:], dtype=dtype).cuda()
    bias = torch.zeros(k.size(0), dtype=dtype).cuda()
    assert x.size(0) == y.size(0), f"{x.shape}, {y.shape}"
    conv_cuda.conv_forward(x, k, bias, y, dilation)
    return y
def A(x, k):
    # A second variant of the same operator, presumably from a different
    # scope, with the argument order swapped; again a zero bias, with B,
    # dtype and dilation taken from the enclosing scope.
    y = torch.zeros(B, k.size(0), *x.shape[2:], dtype=dtype).cuda()
    bias = torch.zeros(k.size(0), dtype=dtype).cuda()
    conv_cuda.conv_forward(x, k, bias, y, dilation)
    return y
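# A minimal sketch of the closure variables both A variants depend on, and
# an example call of the second variant; B, dtype, dilation and the tensor
# shapes here are assumptions consistent with the tests above:
if __name__ == "__main__":
    B, dtype, dilation = 2, torch.double, 3
    x = torch.randn(B, 3, 16, 16, dtype=dtype).cuda()
    k = torch.randn(5, 3, 3, 3, dtype=dtype).cuda()
    # Output channels follow k.size(0), so y has shape (B, 5, 16, 16).
    y = A(x, k)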