Example #1
def test_function(batch_size, in_capsules, out_capsules, in_length, out_length, routing_type, share_weight,
                  num_iterations, squash):
    x_cpu = torch.randn(batch_size, in_capsules, in_length, dtype=torch.double, requires_grad=True)
    if share_weight:
        w_cpu = torch.randn(out_capsules, out_length, in_length, dtype=torch.double, requires_grad=True)
    else:
        w_cpu = torch.randn(out_capsules, in_capsules, out_length, in_length, dtype=torch.double, requires_grad=True)
    x_gpu = x_cpu.detach().to('cuda').requires_grad_()
    w_gpu = w_cpu.detach().to('cuda').requires_grad_()
    y_fast, prob_fast = CL.capsule_linear(x_gpu, w_gpu, share_weight, routing_type, num_iterations, squash)
    y_ref, prob_ref = CL.capsule_linear(x_cpu, w_cpu, share_weight, routing_type, num_iterations, squash)
    assert torch.allclose(y_fast.cpu(), y_ref)
    assert torch.allclose(prob_fast.cpu(), prob_ref)

    go_cpu = torch.randn(y_ref.size(), dtype=torch.double)
    go_gpu = go_cpu.detach().to('cuda')
    y_fast.backward(go_gpu)
    gx_fast = x_gpu.grad.clone()
    gw_fast = w_gpu.grad.clone()
    assert gradcheck(
        partial(CL.capsule_linear, share_weight=share_weight, routing_type=routing_type, num_iterations=num_iterations,
                squash=squash), (x_gpu, w_gpu))

    y_ref.backward(go_cpu)
    gx_ref = x_cpu.grad.clone()
    gw_ref = w_cpu.grad.clone()
    assert gradcheck(
        partial(CL.capsule_linear, share_weight=share_weight, routing_type=routing_type, num_iterations=num_iterations,
                squash=squash), (x_cpu, w_cpu))

    assert torch.allclose(gx_fast.cpu(), gx_ref)
    assert torch.allclose(gw_fast.cpu(), gw_ref)
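These snippets come from a pytest suite and omit their imports; the header below is a sketch of what they assume (a CUDA device is required, and the capsule_layer import path is a guess inferred from the CL.capsule_linear calls, not confirmed by the source):

# Assumed imports for the test snippets on this page.
from functools import partial

import torch
from pytest import approx            # used by the later examples
from torch.autograd import gradcheck

import capsule_layer as CL           # inferred import path, adjust to the actual package layout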
Example #2
def test_function(batch_size, in_capsules, out_capsules, in_length, out_length, routing_type, kwargs, share_weight,
                  num_iterations):
    x_cpu = torch.randn(batch_size, in_capsules, in_length, dtype=torch.double, requires_grad=True)
    if share_weight:
        w_cpu = torch.randn(out_capsules, out_length, in_length, dtype=torch.double, requires_grad=True)
    else:
        w_cpu = torch.randn(out_capsules, in_capsules, out_length, in_length, dtype=torch.double, requires_grad=True)
    x_gpu = x_cpu.detach().to('cuda').requires_grad_()
    w_gpu = w_cpu.detach().to('cuda').requires_grad_()
    y_fast = CL.capsule_linear(x_gpu, w_gpu, share_weight, routing_type, num_iterations, **kwargs)
    y_ref = CL.capsule_linear(x_cpu, w_cpu, share_weight, routing_type, num_iterations, **kwargs)
    assert y_fast.view(-1).tolist() == approx(y_ref.view(-1).tolist())

    go_cpu = torch.randn(y_ref.size(), dtype=torch.double)
    go_gpu = go_cpu.detach().to('cuda')
    y_fast.backward(go_gpu)
    gx_fast = x_gpu.grad.clone()
    gw_fast = w_gpu.grad.clone()
    assert gradcheck(
        partial(CL.capsule_linear, share_weight=share_weight, routing_type=routing_type, num_iterations=num_iterations,
                **kwargs), (x_gpu, w_gpu))

    y_ref.backward(go_cpu)
    gx_ref = x_cpu.grad.clone()
    gw_ref = w_cpu.grad.clone()
    assert gradcheck(
        partial(CL.capsule_linear, share_weight=share_weight, routing_type=routing_type, num_iterations=num_iterations,
                **kwargs), (x_cpu, w_cpu))

    assert gx_fast.view(-1).tolist() == approx(gx_ref.view(-1).tolist())
    assert gw_fast.view(-1).tolist() == approx(gw_ref.view(-1).tolist())
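A test with this many arguments is normally driven by stacked pytest.mark.parametrize decorators. The sketch below shows one way to wire that up; every concrete value (capsule counts, 'dynamic' routing, empty kwargs) is illustrative rather than taken from the project's test suite:

# Illustrative parametrization only; the real suite's values and routing kwargs may differ.
import pytest

@pytest.mark.parametrize('batch_size, in_capsules, out_capsules, in_length, out_length',
                         [(2, 8, 4, 6, 8)])
@pytest.mark.parametrize('routing_type, kwargs', [('dynamic', {})])
@pytest.mark.parametrize('share_weight', [True, False])
@pytest.mark.parametrize('num_iterations', [1, 3])
def test_function(batch_size, in_capsules, out_capsules, in_length, out_length,
                  routing_type, kwargs, share_weight, num_iterations):
    ...  # body as in Example #2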
Example #3
def test_multigpu(batch_size, in_capsules, out_capsules, in_length, out_length, routing_type, kwargs, share_weight,
                  num_iterations):
    a0 = torch.randn(batch_size, in_capsules, in_length, device='cuda:0', requires_grad=True)
    a1 = torch.randn(batch_size, in_capsules, in_length, device='cuda:1', requires_grad=True)
    if share_weight:
        w0 = torch.randn(out_capsules, out_length, in_length, device='cuda:0', requires_grad=True)
        w1 = torch.randn(out_capsules, out_length, in_length, device='cuda:1', requires_grad=True)
    else:
        w0 = torch.randn(out_capsules, in_capsules, out_length, in_length, device='cuda:0', requires_grad=True)
        w1 = torch.randn(out_capsules, in_capsules, out_length, in_length, device='cuda:1', requires_grad=True)
    y0 = CL.capsule_linear(a0, w0, share_weight, routing_type, num_iterations, **kwargs)
    go = torch.randn(y0.size(), device='cuda:0')
    y0.backward(go)
    y1 = CL.capsule_linear(a1, w1, share_weight, routing_type, num_iterations, **kwargs)
    y1.backward(go.detach().to('cuda:1'))
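test_multigpu assumes at least two visible CUDA devices; a common guard (an addition for illustration, not part of the original snippet) is a pytest skip marker:

# Skip the multi-GPU test on machines with fewer than two CUDA devices.
import pytest
import torch

multi_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2,
                               reason='requires at least two CUDA devices')

# applied as: @multi_gpu above def test_multigpu(...)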
Example #4
def forward(self, input):
    return CL.capsule_linear(input, self.weight, self.share_weight,
                             self.routing_type, self.num_iterations,
                             self.bias, **self.kwargs)
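Example #4 is the forward method of a capsule linear module that stores its weight, bias, and routing configuration as attributes. The sketch below is a hypothetical minimal module built around that forward; the constructor arguments, the weight and bias shapes (inferred from the tests above), and the 'dynamic' default are assumptions, not the library's actual CapsuleLinear implementation:

# Hypothetical wrapper illustrating the attributes used by forward(); the real
# CapsuleLinear class in the library may differ.
import torch
import torch.nn as nn

import capsule_layer as CL  # assumed import path


class CapsuleLinearSketch(nn.Module):
    def __init__(self, in_capsules, out_capsules, in_length, out_length,
                 share_weight=True, routing_type='dynamic', num_iterations=3,
                 bias=False, **kwargs):
        super().__init__()
        self.share_weight = share_weight
        self.routing_type = routing_type
        self.num_iterations = num_iterations
        self.kwargs = kwargs
        if share_weight:
            # One transformation matrix shared across all input capsules.
            self.weight = nn.Parameter(torch.randn(out_capsules, out_length, in_length))
        else:
            # A separate transformation matrix per (output, input) capsule pair.
            self.weight = nn.Parameter(torch.randn(out_capsules, in_capsules, out_length, in_length))
        # Bias shape is an assumption: one vector per output capsule.
        self.bias = nn.Parameter(torch.zeros(out_capsules, out_length)) if bias else None

    def forward(self, input):
        return CL.capsule_linear(input, self.weight, self.share_weight,
                                 self.routing_type, self.num_iterations,
                                 self.bias, **self.kwargs)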