コード例 #1
0
ファイル: FADNet.py プロジェクト: liuguoyou/FADNet-1
    def __init__(self,
                 batchNorm=True,
                 lastRelu=False,
                 resBlock=True,
                 maxdisp=-1,
                 input_channel=3):
        """Build FADNet: a DispNetC stage followed by a DispNetRes stage.

        Args:
            batchNorm: forwarded to both ``DispNetC`` and ``DispNetRes``.
            lastRelu: forwarded to ``DispNetRes``.
            resBlock: stored on the module.
                NOTE(review): it is not forwarded to either sub-network by
                this constructor — confirm that is intended.
            maxdisp: forwarded to both sub-networks.
            input_channel: number of channels per input image. It is
                forwarded to both sub-networks and used to size the input of
                the second stage.
        """
        super(FADNet, self).__init__()
        self.input_channel = input_channel
        self.batchNorm = batchNorm
        self.lastRelu = lastRelu
        self.maxdisp = maxdisp
        self.resBlock = resBlock

        # First Block (DispNetC)
        self.dispnetc = DispNetC(self.batchNorm,
                                 maxdisp=self.maxdisp,
                                 input_channel=input_channel)

        # warp layer and channelnorm layer
        self.channelnorm = ChannelNorm()
        self.resample1 = Resample2d()

        # Second Block (DispNetRes). Its input consists of img0, img1 and
        # img1 warped to img0 (input_channel channels each), plus flow and
        # diff-mag (1 channel each). That gives 3 * input_channel + 2
        # channels, i.e. 11 for RGB. It was previously hard-coded as
        # 3 * 3 + 1 + 1, which mis-sized this stage for any
        # input_channel != 3.
        in_planes = 3 * input_channel + 1 + 1
        self.dispnetres = DispNetRes(in_planes,
                                     self.batchNorm,
                                     lastRelu=self.lastRelu,
                                     maxdisp=self.maxdisp,
                                     input_channel=input_channel)

        self.relu = nn.ReLU(inplace=False)
コード例 #2
0
ファイル: DispNetCS.py プロジェクト: liuguoyou/FADNet-1
    def __init__(self,
                 batchNorm=False,
                 lastRelu=True,
                 resBlock=True,
                 maxdisp=-1,
                 input_channel=3):
        """Build DispNetCS: a DispNetC stage followed by a DispNetS stage.

        Args:
            batchNorm: forwarded to both sub-networks.
            lastRelu: stored on the module only; not passed to either
                sub-network by this constructor.
            resBlock: forwarded to both sub-networks.
            maxdisp: forwarded to both sub-networks.
            input_channel: forwarded to ``DispNetC``. The ``DispNetS`` stage
                is always built with 11 input planes and ``input_channel=3``.
        """
        super(DispNetCS, self).__init__()

        # Constructor options kept on the module.
        self.batchNorm = batchNorm
        self.resBlock = resBlock
        self.lastRelu = lastRelu
        self.maxdisp = maxdisp
        self.input_channel = input_channel

        # Keyword options shared by both sub-networks.
        shared = dict(batchNorm=self.batchNorm,
                      resBlock=self.resBlock,
                      maxdisp=self.maxdisp)

        # Stage 1: DispNetC.
        self.dispnetc = DispNetC(input_channel=input_channel, **shared)

        # Stage 2: DispNetS over 11 planes (6 + 3 + 1 + 1; presumably two RGB
        # images, a warped image, flow and a difference map — confirm
        # against forward()).
        # NOTE(review): 11 and input_channel=3 are fixed regardless of the
        # `input_channel` argument; confirm this stage is meant for RGB only.
        self.dispnets1 = DispNetS(11, input_channel=3, **shared)

        # Warp layer and channel-norm layer.
        self.channelnorm = ChannelNorm()
        self.resample1 = Resample2d()

        self.relu = nn.ReLU(inplace=False)
コード例 #3
0
ファイル: MobileFADNet.py プロジェクト: zhaobinNF/FADNet
    def __init__(self,
                 batchNorm=True,
                 lastRelu=False,
                 resBlock=True,
                 maxdisp=-1,
                 input_channel=3,
                 input_img_shape=None,
                 warp_size=None):
        """Build MobileFADNet: MobileDispNetC followed by MobileDispNetRes.

        Args:
            batchNorm: forwarded to both sub-networks.
            lastRelu: forwarded to ``MobileDispNetRes``.
            resBlock: stored on the module; not forwarded to either
                sub-network by this constructor.
            maxdisp: forwarded to both sub-networks.
            input_channel: number of channels per input image. It is
                forwarded to both sub-networks and used to size the input of
                the second stage.
            input_img_shape: forwarded to ``MobileDispNetC``.
            warp_size: optional ``(B, C, H, W)`` tuple. When given, a fixed
                pair of coordinate tensors of shape ``(B, 1, H, W)`` is
                precomputed and stored in ``self.warp_grid``. Otherwise
                ``self.warp_grid`` is ``None``.
        """
        super(MobileFADNet, self).__init__()
        self.input_channel = input_channel
        self.batchNorm = batchNorm
        self.lastRelu = lastRelu
        self.maxdisp = maxdisp
        self.resBlock = resBlock
        self.warp_size = warp_size  # e.g. (1, 3, 576, 960)

        if warp_size is not None:
            # The channel count is not needed to build the grid.
            B, _, H, W = warp_size
            # NOTE(review): the tensors below are created with .cuda(), so
            # passing warp_size requires a CUDA device at construction time.
            # They are plain attributes (not registered buffers), so
            # Module.to()/.cpu() will not move them.
            # xx[b, 0, h, w] = w and yy[b, 0, h, w] = h.
            xx = torch.arange(0, W).float().cuda().view(1, 1, 1, W).repeat(B, 1, H, 1)
            yy = torch.arange(0, H).float().cuda().view(1, 1, H, 1).repeat(B, 1, 1, W)
            # Only the row coordinates are mapped to [-1, 1] here; the column
            # coordinates stay in pixel units.
            # NOTE(review): presumably so a pixel disparity can be subtracted
            # from xx before it is normalised in the warp step — confirm
            # against forward().
            yy = 2.0 * yy / max(H - 1, 1) - 1.0
            self.warp_grid = (xx, yy)
        else:
            self.warp_grid = None

        # First Block (DispNetC)
        self.dispnetc = MobileDispNetC(self.batchNorm,
                                       maxdisp=self.maxdisp,
                                       input_channel=input_channel,
                                       input_img_shape=input_img_shape)

        # warp layer and channelnorm layer
        self.channelnorm = ChannelNorm()
        self.resample1 = Resample2d()

        # Second Block (DispNetRes). Its input consists of img0, img1 and
        # img1 warped to img0 (input_channel channels each), plus flow and
        # diff-mag (1 channel each). That gives 3 * input_channel + 2
        # channels, i.e. 11 for RGB. It was previously hard-coded as
        # 3 * 3 + 1 + 1, which mis-sized this stage for any
        # input_channel != 3.
        in_planes = 3 * input_channel + 1 + 1
        self.dispnetres = MobileDispNetRes(in_planes,
                                           self.batchNorm,
                                           lastRelu=self.lastRelu,
                                           maxdisp=self.maxdisp,
                                           input_channel=input_channel)

        self.relu = nn.ReLU(inplace=False)