Example #1
    def __init__(self, evaluation, div=1.0, iterative_refinement=False,
                 refinement_at_all_levels=False, refinement_at_adaptive_reso=True,
                 batch_norm=True, pyramid_type='VGG', md=4, upfeat_channels=2, dense_connection=True,
                 consensus_network=False, cyclic_consistency=True, decoder_inputs='corr_flow_feat'):
        """
        input: md --- maximum displacement (for correlation. default: 4), after warpping

        """
        super(GLUNet_model, self).__init__()
        self.div = div
        self.pyramid_type = pyramid_type
        self.leakyRELU = nn.LeakyReLU(0.1)
        self.l2norm = FeatureL2Norm()
        self.iterative_refinement = iterative_refinement  # only used during evaluation

        # where to put the refinement networks
        self.refinement_at_all_levels = refinement_at_all_levels
        self.refinement_at_adaptive_reso = refinement_at_adaptive_reso

        # definition of the inputs to the decoders
        self.decoder_inputs = decoder_inputs
        self.dense_connection = dense_connection
        self.upfeat_channels = upfeat_channels

        # improvement of the global correlation
        self.cyclic_consistency = cyclic_consistency
        self.consensus_network = consensus_network
        if self.cyclic_consistency:
            self.corr = FeatureCorrelation(shape='4D', normalization=False)
        elif consensus_network:
            ncons_kernel_sizes = [3, 3, 3]
            ncons_channels = [10, 10, 1]
            self.corr = FeatureCorrelation(shape='4D', normalization=False)
            # normalisation of the correlation volume is applied explicitly in the forward pass
            self.NeighConsensus = NeighConsensus(use_cuda=True,
                                                 kernel_sizes=ncons_kernel_sizes,
                                                 channels=ncons_channels)
        else:
            self.corr = CorrelationVolume()


        dd = np.cumsum([128, 128, 96, 64, 32])
        # 16x16
        nd = 16 * 16  # global correlation
        od = nd + 2
        self.decoder4 = CMDTop(in_channels=od, bn=batch_norm)
        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1)

        # 32x32
        nd = (2 * md + 1) ** 2  # constrained correlation, 4 pixels on each side
        if self.decoder_inputs == 'corr_flow_feat':
            od = nd + 2
        elif self.decoder_inputs == 'corr':
            od = nd
        elif self.decoder_inputs == 'corr_flow':
            od = nd + 2
        if dense_connection:
            self.decoder3 = OpticalFlowEstimator(in_channels=od, batch_norm=batch_norm)
            input_to_refinement = od + dd[4]
        else:
            self.decoder3 = OpticalFlowEstimatorNoDenseConnection(in_channels=od, batch_norm=batch_norm)
            input_to_refinement = 32

        # weights for refinement module
        if self.refinement_at_all_levels or self.refinement_at_adaptive_reso:
            self.dc_conv1 = conv(input_to_refinement, 128, kernel_size=3, stride=1, padding=1,  dilation=1, batch_norm=batch_norm)
            self.dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2, batch_norm=batch_norm)
            self.dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4, batch_norm=batch_norm)
            self.dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8, batch_norm=batch_norm)
            self.dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16, batch_norm=batch_norm)
            self.dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1, batch_norm=batch_norm)
            self.dc_conv7 = predict_flow(32)

        # 1/8 of original resolution
        nd = (2 * md + 1) ** 2  # constrained correlation, 4 pixels on each side
        if self.decoder_inputs == 'corr_flow_feat':
            od = nd + 2
        elif self.decoder_inputs == 'corr':
            od = nd
        elif self.decoder_inputs == 'corr_flow':
            od = nd + 2
        if dense_connection:
            self.decoder2 = OpticalFlowEstimator(in_channels=od, batch_norm=batch_norm)
            input_to_refinement = od + dd[4]
        else:
            self.decoder2 = OpticalFlowEstimatorNoDenseConnection(in_channels=od, batch_norm=batch_norm)
            input_to_refinement = 32
        if self.decoder_inputs == 'corr_flow_feat':
            self.upfeat2 = deconv(input_to_refinement, self.upfeat_channels, kernel_size=4, stride=2, padding=1)

        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1)

        if refinement_at_all_levels:
            # weights for refinement module
            self.dc_conv1_level2 = conv(input_to_refinement, 128, kernel_size=3, stride=1, padding=1,  dilation=1, batch_norm=batch_norm)
            self.dc_conv2_level2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2, batch_norm=batch_norm)
            self.dc_conv3_level2 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4, batch_norm=batch_norm)
            self.dc_conv4_level2 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8, batch_norm=batch_norm)
            self.dc_conv5_level2 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16, batch_norm=batch_norm)
            self.dc_conv6_level2 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1, batch_norm=batch_norm)
            self.dc_conv7_level2 = predict_flow(32)

        # 1/4 of original resolution
        nd = (2 * md + 1) ** 2  # constrained correlation, 4 pixels on each side
        if self.decoder_inputs == 'corr_flow_feat':
            od = nd + self.upfeat_channels + 2
        elif self.decoder_inputs == 'corr':
            od = nd
        elif self.decoder_inputs == 'corr_flow':
            od = nd + 2
        if dense_connection:
            self.decoder1 = OpticalFlowEstimator(in_channels=od, batch_norm=batch_norm)
            input_to_refinement = od+dd[4]
        else:
            self.decoder1 = OpticalFlowEstimatorNoDenseConnection(in_channels=od, batch_norm=batch_norm)
            input_to_refinement = 32

        self.l_dc_conv1 = conv(input_to_refinement, 128, kernel_size=3, stride=1, padding=1,  dilation=1, batch_norm=batch_norm)
        self.l_dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2, batch_norm=batch_norm)
        self.l_dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4, batch_norm=batch_norm)
        self.l_dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8, batch_norm=batch_norm)
        self.l_dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16, batch_norm=batch_norm)
        self.l_dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1, batch_norm=batch_norm)
        self.l_dc_conv7 = predict_flow(32)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
                if m.bias is not None:
                    m.bias.data.zero_()

        if pyramid_type == 'ResNet':
            self.pyramid = ResNetPyramid()
        else:
            self.pyramid = VGGPyramid()

        self.evaluation = evaluation
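
A minimal instantiation sketch for the constructor above (not part of the original example). It assumes GLUNet_model and its dependencies (VGGPyramid, FeatureCorrelation, CMDTop, etc.) are importable from the surrounding module; the import path shown is hypothetical, and the forward-pass signature is not covered here.

import torch

# Hypothetical import path -- adjust to wherever GLUNet_model is defined in your project.
from models.GLUNet import GLUNet_model

# Default configuration from the constructor above: VGG pyramid, cyclic-consistency
# post-processing of the global correlation, local correlation with md=4, dense
# connections, and refinement at the adaptive resolution only.
net = GLUNet_model(evaluation=True,
                   pyramid_type='VGG',
                   md=4,
                   cyclic_consistency=True,
                   dense_connection=True,
                   decoder_inputs='corr_flow_feat')
net.eval()  # evaluation=True builds the model for inference-time use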
Example #2
    def __init__(self,
                 evaluation,
                 div=1.0,
                 batch_norm=True,
                 pyramid_type='VGG',
                 md=4,
                 cyclic_consistency=False,
                 consensus_network=True,
                 iterative_refinement=False):
        """
        input: md --- maximum displacement (for correlation. default: 4), after warpping

        """
        super(SemanticGLUNet_model, self).__init__()
        self.div = div
        self.pyramid_type = pyramid_type
        self.leakyRELU = nn.LeakyReLU(0.1)
        self.iterative_refinement = iterative_refinement
        self.cyclic_consistency = cyclic_consistency
        self.consensus_network = consensus_network
        if self.cyclic_consistency:
            self.corr = FeatureCorrelation(shape='4D', normalization=False)
        elif consensus_network:
            ncons_kernel_sizes = [3, 3, 3]
            ncons_channels = [10, 10, 1]
            self.corr = FeatureCorrelation(shape='4D', normalization=False)
            # normalisation of the correlation volume is applied explicitly in the forward pass

            self.NeighConsensus = NeighConsensus(
                use_cuda=True,
                kernel_sizes=ncons_kernel_sizes,
                channels=ncons_channels)
        else:
            self.corr = CorrelationVolume()
        # L2 feature normalisation
        self.l2norm = FeatureL2Norm()

        dd = np.cumsum([128, 128, 96, 64, 32])
        # weights for decoder at different levels
        nd = 16 * 16  # global correlation
        od = nd + 2
        self.decoder4 = CMDTop(in_channels=od, bn=batch_norm)
        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1)

        nd = (2 * md + 1)**2  # constrained correlation, 4 pixels on each side
        od = nd + 2
        self.decoder3 = OpticalFlowEstimator(in_channels=od,
                                             batch_norm=batch_norm)

        # weights for refinement module
        self.dc_conv1 = conv(od + dd[4],
                             128,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             dilation=1,
                             batch_norm=batch_norm)
        self.dc_conv2 = conv(128,
                             128,
                             kernel_size=3,
                             stride=1,
                             padding=2,
                             dilation=2,
                             batch_norm=batch_norm)
        self.dc_conv3 = conv(128,
                             128,
                             kernel_size=3,
                             stride=1,
                             padding=4,
                             dilation=4,
                             batch_norm=batch_norm)
        self.dc_conv4 = conv(128,
                             96,
                             kernel_size=3,
                             stride=1,
                             padding=8,
                             dilation=8,
                             batch_norm=batch_norm)
        self.dc_conv5 = conv(96,
                             64,
                             kernel_size=3,
                             stride=1,
                             padding=16,
                             dilation=16,
                             batch_norm=batch_norm)
        self.dc_conv6 = conv(64,
                             32,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             dilation=1,
                             batch_norm=batch_norm)
        self.dc_conv7 = predict_flow(32)

        # 1/8 of original resolution
        nd = (2 * md + 1)**2  # constrained correlation, 4 pixels on each side
        od = nd + 2  # only gets the upsampled flow
        self.decoder2 = OpticalFlowEstimator(in_channels=od,
                                             batch_norm=batch_norm)
        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
        self.upfeat2 = deconv(od + dd[4],
                              2,
                              kernel_size=4,
                              stride=2,
                              padding=1)

        # 1/4 of original resolution
        nd = (2 * md + 1)**2  # constrained correlation, 4 pixels on each side
        od = nd + 4
        self.decoder1 = OpticalFlowEstimator(in_channels=od,
                                             batch_norm=batch_norm)

        self.l_dc_conv1 = conv(od + dd[4],
                               128,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               dilation=1,
                               batch_norm=batch_norm)
        self.l_dc_conv2 = conv(128,
                               128,
                               kernel_size=3,
                               stride=1,
                               padding=2,
                               dilation=2,
                               batch_norm=batch_norm)
        self.l_dc_conv3 = conv(128,
                               128,
                               kernel_size=3,
                               stride=1,
                               padding=4,
                               dilation=4,
                               batch_norm=batch_norm)
        self.l_dc_conv4 = conv(128,
                               96,
                               kernel_size=3,
                               stride=1,
                               padding=8,
                               dilation=8,
                               batch_norm=batch_norm)
        self.l_dc_conv5 = conv(96,
                               64,
                               kernel_size=3,
                               stride=1,
                               padding=16,
                               dilation=16,
                               batch_norm=batch_norm)
        self.l_dc_conv6 = conv(64,
                               32,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               dilation=1,
                               batch_norm=batch_norm)
        self.l_dc_conv7 = predict_flow(32)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
                if m.bias is not None:
                    m.bias.data.zero_()

        if pyramid_type == 'ResNet':
            self.pyramid = ResNetPyramid()
        else:
            self.pyramid = VGGPyramid()

        self.evaluation = evaluation
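
As a quick sanity check on the decoder input sizes used by both constructors, the channel counts follow directly from the correlation settings; the snippet below is plain arithmetic derived from the code above and does not require any model code.

md = 4                        # default maximum displacement
nd_global = 16 * 16           # 256 channels: dense global correlation over a 16x16 grid
nd_local = (2 * md + 1) ** 2  # 81 channels: local correlation within +/- md pixels

print(nd_global + 2)          # 258 -> decoder4 (CMDTop): correlation + 2 flow channels
print(nd_local + 2)           # 83  -> decoder3 / decoder2: correlation + upsampled flow
print(nd_local + 4)           # 85  -> decoder1 in SemanticGLUNet: correlation + flow + 2 upsampled feature channels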