예제 #1
0
def test_ri_optim(milestones, addition):
    """Smoke-test ``MultiStepRI`` against a range of module layouts.

    A schedule is created for every candidate module and stepped 20 times.
    The function contains no assertions, so it succeeds as long as neither
    constructing the schedule nor calling ``step()`` raises.

    Args:
        milestones: Forwarded unchanged as the second argument of
            ``MultiStepRI`` (presumably supplied by pytest parametrization).
        addition: Forwarded unchanged as the third argument of
            ``MultiStepRI``.
    """
    # Every candidate is built up front, in this exact order, before any
    # schedule is created.
    plain_conv = nn.Conv2d(1, 3, 3)
    single_linear = CapsuleLinear(10, 8, 16, num_iterations=1)
    single_conv = CapsuleConv2d(8, 16, 3, 4, 8)
    conv_then_linear = nn.Sequential(
        nn.Conv2d(1, 20, 5),
        CapsuleLinear(10, 8, 16),
    )
    nested = nn.ModuleList([
        nn.Sequential(
            nn.Conv2d(1, 5, 3),
            CapsuleLinear(10, 8, 16, num_iterations=2),
        ),
        nn.Sequential(
            CapsuleLinear(10, 8, 16),
            CapsuleConv2d(8, 16, 3, 4, 8),
        ),
        CapsuleLinear(10, 8, 16),
    ])

    candidates = (plain_conv, single_linear, single_conv, conv_then_linear,
                  nested)
    for candidate in candidates:
        scheduler = MultiStepRI(candidate, milestones, addition, verbose=True)
        for _ in range(20):
            scheduler.step()
예제 #2
0
def test_dropout_optim(milestones, addition):
    """Smoke-test ``MultiStepDropout`` against a range of module layouts.

    A schedule is created for every candidate module and stepped 10 times.
    The function contains no assertions, so it succeeds as long as neither
    constructing the schedule nor calling ``step()`` raises.

    Args:
        milestones: Forwarded unchanged as the second argument of
            ``MultiStepDropout`` (presumably supplied by pytest
            parametrization).
        addition: Forwarded unchanged as the third argument of
            ``MultiStepDropout``.
    """
    # Every candidate is built up front, in this exact order, before any
    # schedule is created.
    plain_conv = nn.Conv2d(1, 3, 3)
    linear_with_dropout = CapsuleLinear(10, 8, 16, dropout=0.1)
    single_conv = CapsuleConv2d(8, 16, 3, 4, 8)
    conv_relu_linear = nn.Sequential(
        nn.Conv2d(1, 20, 5),
        nn.ReLU(),
        CapsuleLinear(10, 8, 16),
    )
    nested = nn.ModuleList([
        nn.Sequential(
            nn.Conv2d(1, 5, 3),
            nn.ReLU(),
            CapsuleLinear(10, 8, 16, dropout=0.2),
        ),
        nn.Sequential(
            CapsuleLinear(10, 8, 16),
            CapsuleConv2d(8, 16, 3, 4, 8),
        ),
        CapsuleLinear(10, 8, 16),
    ])

    candidates = (plain_conv, linear_with_dropout, single_conv,
                  conv_relu_linear, nested)
    for candidate in candidates:
        scheduler = MultiStepDropout(candidate, milestones, addition,
                                     verbose=True)
        for _ in range(10):
            scheduler.step()
예제 #3
0
def test_module(batch_size, height, width, in_channels, out_channels,
                kernel_size_h, kernel_size_w, in_length, out_length, stride,
                padding, dilation, routing_type, kwargs, num_iterations):
    """Check that ``CapsuleConv2d`` gives matching results on CPU and CUDA.

    One layer is built from the given arguments and applied to the same
    random input twice: first on the CPU, then after the layer and the input
    have been moved to ``'cuda'``. The flattened outputs must agree within an
    absolute tolerance of 1e-5.

    Args:
        batch_size, in_channels, height, width: Sizes of the random input
            tensor, used in that order by ``torch.randn``.
        kernel_size_h, kernel_size_w: Combined into the ``(h, w)`` tuple
            passed as the layer's third positional argument.
        out_channels, in_length, out_length, stride, padding, dilation,
        routing_type, num_iterations: Passed positionally to
            ``CapsuleConv2d``.
        kwargs: Mapping expanded into extra keyword arguments of the layer.
    """
    kernel_size = (kernel_size_h, kernel_size_w)
    layer = CapsuleConv2d(in_channels, out_channels, kernel_size, in_length,
                          out_length, stride, padding, dilation, routing_type,
                          num_iterations, **kwargs)
    inputs = torch.randn(batch_size, in_channels, height, width)

    # The CPU pass has to run first: ``Module.to`` moves the layer's
    # parameters in place, so afterwards the layer lives on the GPU.
    out_cpu = layer(inputs)
    out_gpu = layer.to('cuda')(inputs.to('cuda'))

    expected = approx(out_cpu.view(-1).tolist(), abs=1e-5)
    assert out_gpu.view(-1).tolist() == expected
예제 #4
0
def test_module(batch_size, height, width, in_channels, out_channels,
                kernel_size_h, kernel_size_w, in_length, out_length, stride,
                padding, dilation, routing_type, num_iterations, squash):
    """Check that ``CapsuleConv2d`` gives matching results on CPU and CUDA.

    One layer is built from the given arguments and applied to the same
    random input twice: first on the CPU, then after the layer and the input
    have been moved to ``'cuda'``. The layer's result is unpacked into two
    values, and each CUDA value must be ``torch.allclose`` (default
    tolerances) to its CPU counterpart.

    Args:
        batch_size, in_channels, height, width: Sizes of the random input
            tensor, used in that order by ``torch.randn``.
        kernel_size_h, kernel_size_w: Combined into the ``(h, w)`` tuple
            passed as the layer's third positional argument.
        out_channels, in_length, out_length, stride, padding, dilation:
            Passed positionally to ``CapsuleConv2d``.
        routing_type, num_iterations, squash: Passed to ``CapsuleConv2d`` as
            keyword arguments of the same names.
    """
    kernel_size = (kernel_size_h, kernel_size_w)
    routing_options = dict(routing_type=routing_type,
                           num_iterations=num_iterations,
                           squash=squash)
    layer = CapsuleConv2d(in_channels, out_channels, kernel_size, in_length,
                          out_length, stride, padding, dilation,
                          **routing_options)
    inputs = torch.randn(batch_size, in_channels, height, width)

    # The CPU pass has to run first: ``Module.to`` moves the layer's
    # parameters in place, so afterwards the layer lives on the GPU.
    out_cpu, prob_cpu = layer(inputs)
    out_gpu, prob_gpu = layer.to('cuda')(inputs.to('cuda'))

    # Compare the first output, then the second, each copied back to host.
    for from_gpu, from_cpu in ((out_gpu, out_cpu), (prob_gpu, prob_cpu)):
        assert torch.allclose(from_gpu.cpu(), from_cpu)