示例#1
0
def _check_correlation_op(op, forward_msg):
    """Run one forward/backward pass of ``op`` and check gradients reach its inputs.

    ``op`` is called with GPU copies of two random (2, 3, 100, 100) tensors.
    Its output is reduced with ``torch.mean`` and back-propagated.

    Args:
        op: callable taking two CUDA tensors (a correlation model/module).
        forward_msg: message printed once the forward pass has completed.

    Raises:
        AssertionError: if either CPU-side input received no gradient.
    """
    a = Variable(torch.randn(2, 3, 100, 100), requires_grad=True)
    b = Variable(torch.randn(2, 3, 100, 100), requires_grad=True)

    # .cuda() is differentiable, so gradients from backward() flow back and
    # accumulate on a.grad / b.grad.
    y = op(a.cuda(), b.cuda())
    print(y.size())
    print(forward_msg)

    torch.mean(y).backward()

    # Check before touching .grad. Calling .size() on a missing gradient would
    # raise AttributeError and hide the real failure.
    if a.grad is None or b.grad is None:
        raise AssertionError('Backward pass did not produce input gradients')
    print(a.grad.size())
    print(b.grad.size())
    print('Backward pass test passed')


def test_correlation():
    """Smoke-test the correlation op through its functional and module interfaces.

    Requires a CUDA device. ``correlation`` and ``Correlation`` are defined
    elsewhere in the project; both are built with the same positional
    arguments and must propagate gradients back to both inputs.
    """
    _check_correlation_op(correlation(3, 3, 20, 1, 2, 1),
                          'Functional interface test passed')
    _check_correlation_op(Correlation(3, 3, 20, 1, 2, 1),
                          'Module interface test passed')
    def __init__(self, md=4):
        """Create all layers of the network and initialize conv weights.

        input: md --- maximum displacement for the correlation layer
        (default: 4), applied after warping. It is passed to ``Correlation``
        as both ``pad_size`` and ``max_displacement``, and it sets
        ``nd = (2 * md + 1) ** 2``, the cost-volume channel count used by
        every decoder level.
        """
        super(PWCDCNet, self).__init__()

        # Feature pyramid: six levels. Each level starts with a stride-2 conv
        # followed by two stride-1 convs. Attribute names (including the
        # conv6aa / conv6a ordering at level 6) are kept unchanged because
        # they become the module's state_dict keys.
        self.conv1a = conv(3, 16, kernel_size=3, stride=2)
        self.conv1aa = conv(16, 16, kernel_size=3, stride=1)
        self.conv1b = conv(16, 16, kernel_size=3, stride=1)
        self.conv2a = conv(16, 32, kernel_size=3, stride=2)
        self.conv2aa = conv(32, 32, kernel_size=3, stride=1)
        self.conv2b = conv(32, 32, kernel_size=3, stride=1)
        self.conv3a = conv(32, 64, kernel_size=3, stride=2)
        self.conv3aa = conv(64, 64, kernel_size=3, stride=1)
        self.conv3b = conv(64, 64, kernel_size=3, stride=1)
        self.conv4a = conv(64, 96, kernel_size=3, stride=2)
        self.conv4aa = conv(96, 96, kernel_size=3, stride=1)
        self.conv4b = conv(96, 96, kernel_size=3, stride=1)
        self.conv5a = conv(96, 128, kernel_size=3, stride=2)
        self.conv5aa = conv(128, 128, kernel_size=3, stride=1)
        self.conv5b = conv(128, 128, kernel_size=3, stride=1)
        self.conv6aa = conv(128, 196, kernel_size=3, stride=2)
        self.conv6a = conv(196, 196, kernel_size=3, stride=1)
        self.conv6b = conv(196, 196, kernel_size=3, stride=1)

        # Correlation layer. md is used both as the padding and as the
        # maximum displacement, with unit kernel and strides.
        self.corr = Correlation(pad_size=md,
                                kernel_size=1,
                                max_displacement=md,
                                stride1=1,
                                stride2=1,
                                corr_multiply=1)
        self.leakyRELU = nn.LeakyReLU(0.1)

        # Presumably the number of correlation output channels: one per
        # displacement in a (2*md+1) x (2*md+1) window. Confirm against
        # Correlation.
        nd = (2 * md + 1)**2
        # Running totals of the output channels of conv*_0 .. conv*_4
        # (128, 128, 96, 64, 32). conv*_k takes ``od + dd[k-1]`` input
        # channels, so presumably forward() concatenates every previous conv
        # output onto the level input. .tolist() gives plain Python ints
        # rather than numpy integers for the channel counts.
        dd = np.cumsum([128, 128, 96, 64, 32]).tolist()

        # Level 6 decoder: its input is the cost volume only.
        od = nd
        self.conv6_0 = conv(od, 128, kernel_size=3, stride=1)
        self.conv6_1 = conv(od + dd[0], 128, kernel_size=3, stride=1)
        self.conv6_2 = conv(od + dd[1], 96, kernel_size=3, stride=1)
        self.conv6_3 = conv(od + dd[2], 64, kernel_size=3, stride=1)
        self.conv6_4 = conv(od + dd[3], 32, kernel_size=3, stride=1)
        self.predict_flow6 = predict_flow(od + dd[4])
        # 2-channel upsamplers for the next (finer) level.
        self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
        self.upfeat6 = deconv(od + dd[4],
                              2,
                              kernel_size=4,
                              stride=2,
                              padding=1)

        # Levels 5..2: input is the cost volume plus C channels, where C
        # matches conv{5,4,3,2}b's output width (128/96/64/32), plus 4.
        # Presumably the 4 are the 2-channel outputs of deconv and upfeat
        # from the coarser level. Verify in forward().
        od = nd + 128 + 4
        self.conv5_0 = conv(od, 128, kernel_size=3, stride=1)
        self.conv5_1 = conv(od + dd[0], 128, kernel_size=3, stride=1)
        self.conv5_2 = conv(od + dd[1], 96, kernel_size=3, stride=1)
        self.conv5_3 = conv(od + dd[2], 64, kernel_size=3, stride=1)
        self.conv5_4 = conv(od + dd[3], 32, kernel_size=3, stride=1)
        self.predict_flow5 = predict_flow(od + dd[4])
        self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
        self.upfeat5 = deconv(od + dd[4],
                              2,
                              kernel_size=4,
                              stride=2,
                              padding=1)

        od = nd + 96 + 4
        self.conv4_0 = conv(od, 128, kernel_size=3, stride=1)
        self.conv4_1 = conv(od + dd[0], 128, kernel_size=3, stride=1)
        self.conv4_2 = conv(od + dd[1], 96, kernel_size=3, stride=1)
        self.conv4_3 = conv(od + dd[2], 64, kernel_size=3, stride=1)
        self.conv4_4 = conv(od + dd[3], 32, kernel_size=3, stride=1)
        self.predict_flow4 = predict_flow(od + dd[4])
        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
        self.upfeat4 = deconv(od + dd[4],
                              2,
                              kernel_size=4,
                              stride=2,
                              padding=1)

        od = nd + 64 + 4
        self.conv3_0 = conv(od, 128, kernel_size=3, stride=1)
        self.conv3_1 = conv(od + dd[0], 128, kernel_size=3, stride=1)
        self.conv3_2 = conv(od + dd[1], 96, kernel_size=3, stride=1)
        self.conv3_3 = conv(od + dd[2], 64, kernel_size=3, stride=1)
        self.conv3_4 = conv(od + dd[3], 32, kernel_size=3, stride=1)
        self.predict_flow3 = predict_flow(od + dd[4])
        self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
        self.upfeat3 = deconv(od + dd[4],
                              2,
                              kernel_size=4,
                              stride=2,
                              padding=1)

        # Level 2 is the finest decoder level: it has no upfeat2. Its
        # od + dd[4] features feed the dilated dc_conv stack below.
        od = nd + 32 + 4
        self.conv2_0 = conv(od, 128, kernel_size=3, stride=1)
        self.conv2_1 = conv(od + dd[0], 128, kernel_size=3, stride=1)
        self.conv2_2 = conv(od + dd[1], 96, kernel_size=3, stride=1)
        self.conv2_3 = conv(od + dd[2], 64, kernel_size=3, stride=1)
        self.conv2_4 = conv(od + dd[3], 32, kernel_size=3, stride=1)
        self.predict_flow2 = predict_flow(od + dd[4])
        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1)

        # Dilated 3x3 convs with dilation 1, 2, 4, 8, 16, 1 and padding equal
        # to the dilation. That keeps the spatial size for a 3x3 kernel,
        # assuming conv() forwards these arguments to nn.Conv2d.
        self.dc_conv1 = conv(od + dd[4],
                             128,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             dilation=1)
        self.dc_conv2 = conv(128,
                             128,
                             kernel_size=3,
                             stride=1,
                             padding=2,
                             dilation=2)
        self.dc_conv3 = conv(128,
                             128,
                             kernel_size=3,
                             stride=1,
                             padding=4,
                             dilation=4)
        self.dc_conv4 = conv(128,
                             96,
                             kernel_size=3,
                             stride=1,
                             padding=8,
                             dilation=8)
        self.dc_conv5 = conv(96,
                             64,
                             kernel_size=3,
                             stride=1,
                             padding=16,
                             dilation=16)
        self.dc_conv6 = conv(64,
                             32,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             dilation=1)
        self.dc_conv7 = predict_flow(32)

        # He-normal (fan_in) init for every conv / transposed-conv weight, and
        # zero biases. The in-place kaiming_normal_ / zeros_ replace the
        # deprecated nn.init.kaiming_normal; the defaults (a=0,
        # nonlinearity='leaky_relu') are unchanged.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)