from functools import partial

import torch
from pytest import approx
from torch.autograd import gradcheck

# Assumed import: CL is used below as the package exposing capsule_cov2d.
import capsule_layer as CL


def test_function(batch_size, height, width, in_channels, out_channels, kernel_size_h, kernel_size_w, in_length,
                  out_length, stride, padding, dilation, routing_type, kwargs, num_iterations):
    # Double-precision CPU tensors, plus detached GPU copies that track their own gradients.
    x_cpu = torch.randn(batch_size, in_channels, height, width, dtype=torch.double, requires_grad=True)
    w_cpu = torch.randn(out_channels // out_length, out_length, in_length, kernel_size_h, kernel_size_w,
                        dtype=torch.double, requires_grad=True)
    x_gpu = x_cpu.detach().to('cuda').requires_grad_()
    w_gpu = w_cpu.detach().to('cuda').requires_grad_()

    # The fast GPU path and the CPU reference path must produce the same output.
    y_fast = CL.capsule_cov2d(x_gpu, w_gpu, stride, padding, dilation, routing_type, num_iterations, **kwargs)
    y_ref = CL.capsule_cov2d(x_cpu, w_cpu, stride, padding, dilation, routing_type, num_iterations, **kwargs)
    assert y_fast.view(-1).tolist() == approx(y_ref.view(-1).tolist())

    # Backpropagate the same upstream gradient on both devices.
    go_cpu = torch.randn(y_ref.size(), dtype=torch.double)
    go_gpu = go_cpu.detach().to('cuda')
    y_fast.backward(go_gpu)
    gx_fast = x_gpu.grad.clone()
    gw_fast = w_gpu.grad.clone()
    assert gradcheck(partial(CL.capsule_cov2d, stride=stride, padding=padding, dilation=dilation,
                             routing_type=routing_type, num_iterations=num_iterations, **kwargs), (x_gpu, w_gpu))

    y_ref.backward(go_cpu)
    gx_ref = x_cpu.grad.clone()
    gw_ref = w_cpu.grad.clone()
    assert gradcheck(partial(CL.capsule_cov2d, stride=stride, padding=padding, dilation=dilation,
                             routing_type=routing_type, num_iterations=num_iterations, **kwargs), (x_cpu, w_cpu))

    # Input and weight gradients must agree across devices as well.
    assert gx_fast.view(-1).tolist() == approx(gx_ref.view(-1).tolist())
    assert gw_fast.view(-1).tolist() == approx(gw_ref.view(-1).tolist())
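# A minimal sketch of invoking the test above by hand; the parameter values and
# the routing_type/kwargs choices are illustrative assumptions (the suite would
# normally drive them through pytest parametrization).
if torch.cuda.is_available():
    test_function(batch_size=1, height=5, width=5, in_channels=4, out_channels=8, kernel_size_h=3, kernel_size_w=3,
                  in_length=2, out_length=4, stride=1, padding=1, dilation=1, routing_type='k_means', kwargs={},
                  num_iterations=3)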
def test_multigpu(batch_size, height, width, in_channels, out_channels, kernel_size_h, kernel_size_w, in_length,
                  out_length, stride, padding, dilation, routing_type, kwargs, num_iterations):
    # Independent inputs and weights on two CUDA devices.
    a0 = torch.randn(batch_size, in_channels, height, width, device='cuda:0', requires_grad=True)
    a1 = torch.randn(batch_size, in_channels, height, width, device='cuda:1', requires_grad=True)
    w0 = torch.randn(out_channels // out_length, out_length, in_length, kernel_size_h, kernel_size_w,
                     device='cuda:0', requires_grad=True)
    w1 = torch.randn(out_channels // out_length, out_length, in_length, kernel_size_h, kernel_size_w,
                     device='cuda:1', requires_grad=True)

    # Forward and backward on cuda:0.
    y0 = CL.capsule_cov2d(a0, w0, stride, padding, dilation, routing_type=routing_type,
                          num_iterations=num_iterations, **kwargs)
    go = torch.randn(y0.size(), device='cuda:0')
    y0.backward(go)

    # Repeat on cuda:1, moving the same upstream gradient across devices.
    y1 = CL.capsule_cov2d(a1, w1, stride, padding, dilation, routing_type=routing_type,
                          num_iterations=num_iterations, **kwargs)
    y1.backward(go.detach().to('cuda:1'))
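# Sketch of a guard for single-GPU machines (a common pytest idiom, assumed here
# rather than taken from this suite); applied as @multigpu_only on test_multigpu.
import pytest

multigpu_only = pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires at least two CUDA devices')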
def forward(self, input):
    return CL.capsule_cov2d(input, self.weight, self.stride, self.padding, self.dilation, self.share_weight,
                            self.routing_type, self.num_iterations, self.bias, **self.kwargs)
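# For context, a minimal sketch of the module the forward above belongs to. The
# constructor signature and weight initialization are assumptions; only the
# attribute names read by forward are taken from the source.
import torch.nn as nn

class CapsuleConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, in_length, out_length, stride=1, padding=0,
                 dilation=1, share_weight=True, routing_type='k_means', num_iterations=3, bias=None, **kwargs):
        super().__init__()
        self.stride, self.padding, self.dilation = stride, padding, dilation
        self.share_weight, self.routing_type, self.num_iterations = share_weight, routing_type, num_iterations
        self.bias, self.kwargs = bias, kwargs
        # Weight layout matches the tests: (out_channels // out_length, out_length, in_length, kH, kW).
        kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.weight = nn.Parameter(torch.randn(out_channels // out_length, out_length, in_length, kh, kw))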
def test_function(batch_size, height, width, in_channels, out_channels, kernel_size_h, kernel_size_w, in_length,
                  out_length, stride, padding, dilation, routing_type, num_iterations, squash):
    x_cpu = torch.randn(batch_size, in_channels, height, width, dtype=torch.double, requires_grad=True)
    w_cpu = torch.randn(out_channels // out_length, out_length, in_length, kernel_size_h, kernel_size_w,
                        dtype=torch.double, requires_grad=True)
    x_gpu = x_cpu.detach().to('cuda').requires_grad_()
    w_gpu = w_cpu.detach().to('cuda').requires_grad_()

    # This variant returns the routed output together with the routing probabilities.
    y_fast, prob_fast = CL.capsule_cov2d(x_gpu, w_gpu, stride, padding, dilation, routing_type=routing_type,
                                         num_iterations=num_iterations, squash=squash)
    y_ref, prob_ref = CL.capsule_cov2d(x_cpu, w_cpu, stride, padding, dilation, routing_type=routing_type,
                                       num_iterations=num_iterations, squash=squash)
    assert torch.allclose(y_fast.cpu(), y_ref)
    assert torch.allclose(prob_fast.cpu(), prob_ref)

    # Backpropagate the same upstream gradient on both devices.
    go_cpu = torch.randn(y_ref.size(), dtype=torch.double)
    go_gpu = go_cpu.detach().to('cuda')
    y_fast.backward(go_gpu)
    gx_fast = x_gpu.grad.clone()
    gw_fast = w_gpu.grad.clone()
    assert gradcheck(partial(CL.capsule_cov2d, stride=stride, padding=padding, dilation=dilation,
                             routing_type=routing_type, num_iterations=num_iterations, squash=squash),
                     (x_gpu, w_gpu))

    y_ref.backward(go_cpu)
    gx_ref = x_cpu.grad.clone()
    gw_ref = w_cpu.grad.clone()
    assert gradcheck(partial(CL.capsule_cov2d, stride=stride, padding=padding, dilation=dilation,
                             routing_type=routing_type, num_iterations=num_iterations, squash=squash),
                     (x_cpu, w_cpu))

    # Input and weight gradients must agree across devices.
    assert torch.allclose(gx_fast.cpu(), gx_ref)
    assert torch.allclose(gw_fast.cpu(), gw_ref)
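# Usage sketch for the (output, probability) variant exercised above; the tensor
# shapes and the routing_type value are illustrative assumptions.
if torch.cuda.is_available():
    x = torch.randn(1, 4, 9, 9, device='cuda')
    # (out_channels // out_length, out_length, in_length, kH, kW) = (2, 4, 2, 3, 3)
    w = torch.randn(2, 4, 2, 3, 3, device='cuda')
    y, prob = CL.capsule_cov2d(x, w, stride=1, padding=1, dilation=1, routing_type='k_means', num_iterations=3,
                               squash=True)
    print(y.size(), prob.size())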