Пример #1
0
    def forward(self, diff1, diff2, tensorInput1, tensorInput2):
        '''
        Synthesize a frame from two inputs plus two guidance tensors.

        diff1, diff2, tensorInput1 and tensorInput2 are concatenated on the
        channel axis (dim 1) and passed through a five-level encoder/decoder
        with additive skip connections.  The finest decoder features predict
        a vertical and a horizontal kernel for each input frame; each padded
        input is filtered with its pair by sepconv.FunctionSepconv and the two
        filtered frames are summed.

        NOTE(review): shapes of diff1/diff2 are not visible here -- presumably
        [bcz, C, height, width] matching the inputs; confirm.
        '''
        tensorFeature = torch.cat([diff1, diff2, tensorInput1, tensorInput2], 1)

        # Encoder: keep every level's pre-pooling activation for the skips.
        listSkip = []
        for intLevel in range(1, 6):
            tensorFeature = getattr(self, 'moduleConv%d' % intLevel)(tensorFeature)
            listSkip.append(tensorFeature)
            tensorFeature = getattr(self, 'modulePool%d' % intLevel)(tensorFeature)

        # Decoder, levels 5..2: Deconv -> Upsample, then add the encoder
        # activation of the same level (listSkip is 0-based).
        for intLevel in range(5, 1, -1):
            tensorFeature = getattr(self, 'moduleDeconv%d' % intLevel)(tensorFeature)
            tensorFeature = getattr(self, 'moduleUpsample%d' % intLevel)(tensorFeature)
            tensorFeature = tensorFeature + listSkip[intLevel - 1]

        # One kernel pair per input frame, applied to the padded frame.
        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical1(tensorFeature),
            self.moduleHorizontal1(tensorFeature))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2), self.moduleVertical2(tensorFeature),
            self.moduleHorizontal2(tensorFeature))

        return tensorDot1 + tensorDot2
    def forward(self, tensorInput1, tensorInput2):
        '''
        Synthesize a frame from two inputs with predicted separable kernels.

        tensorInput1/2 : [bcz, 3, height, width]

        The inputs are concatenated on the channel axis and run through a
        five-level encoder/decoder with additive skips.  The finest decoder
        features predict a vertical and a horizontal kernel per input, which
        sepconv.FunctionSepconv applies to the padded input.  The two filtered
        frames are summed.
        '''
        tensorFeature = torch.cat([tensorInput1, tensorInput2], 1)

        # Encoder.  Shape notes carried over from the original annotations
        # ([C, H, W], apparently for a 128x128 input):
        #   Conv1 [32, 128, 128]  Conv2 [64, 64, 64]  Conv3 [128, 32, 32]
        #   Conv4 [256, 16, 16]   Conv5 [512, 8, 8]
        # Each pre-pooling activation is kept for the decoder skips.
        listSkip = []
        for intLevel in range(1, 6):
            tensorFeature = getattr(self, 'moduleConv%d' % intLevel)(tensorFeature)
            listSkip.append(tensorFeature)
            tensorFeature = getattr(self, 'modulePool%d' % intLevel)(tensorFeature)

        # Decoder, levels 5..2: Deconv, Upsample (doubling H and W per the
        # original annotations), then add the same-level encoder activation.
        # The final result is [64, 64, 64] per those annotations.
        for intLevel in range(5, 1, -1):
            tensorFeature = getattr(self, 'moduleDeconv%d' % intLevel)(tensorFeature)
            tensorFeature = getattr(self, 'moduleUpsample%d' % intLevel)(tensorFeature)
            tensorFeature = tensorFeature + listSkip[intLevel - 1]

        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical1(tensorFeature),
            self.moduleHorizontal1(tensorFeature))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2), self.moduleVertical2(tensorFeature),
            self.moduleHorizontal2(tensorFeature))

        return tensorDot1 + tensorDot2
Пример #3
0
    def forward(self, tensorInput1, tensorInput2):
        """Blend two frames using per-input kernels predicted by an encoder/decoder.

        The two inputs are concatenated on the channel axis.  Five Conv/Pool
        encoder stages feed four Deconv/Upsample decoder stages, each followed
        by an addition of the same-level encoder output.  The resulting
        features predict vertical/horizontal kernels that
        sepconv.FunctionSepconv applies to each padded input.  Returns the sum
        of the two filtered inputs.
        """
        # Encoder: each Conv output is kept for the decoder's skip additions.
        tensorConv1 = self.moduleConv1(torch.cat([tensorInput1, tensorInput2], 1))
        tensorConv2 = self.moduleConv2(self.modulePool1(tensorConv1))
        tensorConv3 = self.moduleConv3(self.modulePool2(tensorConv2))
        tensorConv4 = self.moduleConv4(self.modulePool3(tensorConv3))
        tensorConv5 = self.moduleConv5(self.modulePool4(tensorConv4))

        # Decoder: Deconv -> Upsample, then add the same-level encoder feature.
        tensorDecoded = self.moduleUpsample5(
            self.moduleDeconv5(self.modulePool5(tensorConv5))) + tensorConv5
        tensorDecoded = self.moduleUpsample4(
            self.moduleDeconv4(tensorDecoded)) + tensorConv4
        tensorDecoded = self.moduleUpsample3(
            self.moduleDeconv3(tensorDecoded)) + tensorConv3
        tensorDecoded = self.moduleUpsample2(
            self.moduleDeconv2(tensorDecoded)) + tensorConv2

        # Filter each padded input with its own predicted kernel pair.
        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical1(tensorDecoded),
            self.moduleHorizontal1(tensorDecoded))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2), self.moduleVertical2(tensorDecoded),
            self.moduleHorizontal2(tensorDecoded))

        return tensorDot1 + tensorDot2
Пример #4
0
    def forward(self, frames):
        """Interpolate a frame from a sequence of equally-shaped frames.

        Every frame must have the same shape (checked with an assert), and the
        shape must unpack into four dimensions (the last two are read as
        height and width).  With exactly four frames, the list is reordered to
        (frames[1], frames[2], frames[0], frames[3]) before use.  Each frame
        is zero-padded on the bottom/right so height and width are multiples
        of 32.  self.get_kernel then predicts two kernel pairs from the padded
        list.  Padded frames 0 and 1 are filtered via sepconv.FunctionSepconv,
        and the sum is cropped back to the original height and width.
        """
        assert np.all([frame.shape == frames[0].shape for frame in frames])

        _, _, height, width = frames[0].shape

        if len(frames) == 4:
            frames = [frames[i] for i in (1, 2, 0, 3)]

        # Distance to the next multiple of 32 (0 when already aligned).
        pad_h = -height % 32
        pad_w = -width % 32

        def pad_frame(frame):
            # Bottom padding first, then right padding, each only if needed.
            if pad_h:
                frame = F.pad(frame, (0, 0, 0, pad_h))
            if pad_w:
                frame = F.pad(frame, (0, pad_w, 0, 0))
            return frame

        padded_frames = [pad_frame(frame) for frame in frames]

        Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(
            padded_frames)

        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(padded_frames[0]), Vertical1, Horizontal1)
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(padded_frames[1]), Vertical2, Horizontal2)

        frame1 = tensorDot1 + tensorDot2

        # Crop the alignment padding back off.
        if pad_h:
            frame1 = frame1[:, :, 0:height, :]
        if pad_w:
            frame1 = frame1[:, :, :, 0:width]

        return frame1
Пример #5
0
	def forward(self, tenFirst, tenSecond):
		"""Predict per-pixel separable kernels from two frames and blend them.

		Five conv levels, each preceded by 2x2 average pooling from level 2 on,
		form the encoder.  Four deconv/upsample levels form the decoder, each
		consuming the previous output plus the same-level encoder output.  The
		final features (tenCombine) drive netVertical*/netHorizontal*.  Their
		outputs are applied by sepconv.FunctionSepconv to each replicate-padded
		input, and the two filtered frames are summed.
		"""
		def downsample(tenInput):
			# 2x2 average pooling with stride 2.
			return torch.nn.functional.avg_pool2d(input=tenInput, kernel_size=2, stride=2, count_include_pad=False)

		tenConv1 = self.netConv1(torch.cat([ tenFirst, tenSecond ], 1))
		tenConv2 = self.netConv2(downsample(tenConv1))
		tenConv3 = self.netConv3(downsample(tenConv2))
		tenConv4 = self.netConv4(downsample(tenConv3))
		tenConv5 = self.netConv5(downsample(tenConv4))

		tenDecoded = self.netUpsample5(self.netDeconv5(downsample(tenConv5)))
		tenDecoded = self.netUpsample4(self.netDeconv4(tenDecoded + tenConv5))
		tenDecoded = self.netUpsample3(self.netDeconv3(tenDecoded + tenConv4))
		tenDecoded = self.netUpsample2(self.netDeconv2(tenDecoded + tenConv3))

		tenCombine = tenDecoded + tenConv2

		# floor(51 / 2) == 25 pixels of replicate padding on every side;
		# 51 is presumably the predicted kernel length -- confirm.
		intPad = int(math.floor(51 / 2.0))
		tenFirst = torch.nn.functional.pad(input=tenFirst, pad=[ intPad ] * 4, mode='replicate')
		tenSecond = torch.nn.functional.pad(input=tenSecond, pad=[ intPad ] * 4, mode='replicate')

		tenDot1 = sepconv.FunctionSepconv(tenInput=tenFirst, tenVertical=self.netVertical1(tenCombine), tenHorizontal=self.netHorizontal1(tenCombine))
		tenDot2 = sepconv.FunctionSepconv(tenInput=tenSecond, tenVertical=self.netVertical2(tenCombine), tenHorizontal=self.netHorizontal2(tenCombine))

		return tenDot1 + tenDot2
Пример #6
0
    def forward(self, frame0, frame2):
        """Synthesize the frame between frame0 and frame2.

        Both frames are zero-padded on the bottom/right so that height and
        width become multiples of 32.  self.get_kernel predicts one kernel
        pair per frame, each padded frame is filtered with its pair by
        sepconv.FunctionSepconv, and the padding is cropped from the sum.

        Raises:
            ValueError: if the two frames differ in height or width.  This
                previously called sys.exit(), which terminated the whole
                process instead of letting the caller handle the mismatch.
        """
        h0, w0 = frame0.size(2), frame0.size(3)
        h2, w2 = frame2.size(2), frame2.size(3)
        if h0 != h2 or w0 != w2:
            raise ValueError('Frame sizes do not match')

        # Distance to the next multiple of 32 (0 when already aligned).
        pad_h = -h0 % 32
        pad_w = -w0 % 32

        if pad_h:
            frame0 = F.pad(frame0, (0, 0, 0, pad_h))
            frame2 = F.pad(frame2, (0, 0, 0, pad_h))
        if pad_w:
            frame0 = F.pad(frame0, (0, pad_w, 0, 0))
            frame2 = F.pad(frame2, (0, pad_w, 0, 0))

        Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(
            frame0, frame2)

        tensorDot1 = sepconv.FunctionSepconv()(self.modulePad(frame0),
                                               Vertical1, Horizontal1)
        tensorDot2 = sepconv.FunctionSepconv()(self.modulePad(frame2),
                                               Vertical2, Horizontal2)

        frame1 = tensorDot1 + tensorDot2

        # Crop the alignment padding back off.
        if pad_h:
            frame1 = frame1[:, :, 0:h0, :]
        if pad_w:
            frame1 = frame1[:, :, :, 0:w0]

        return frame1
Пример #7
0
    def forward(self, Frame1, Frame3):
        """Synthesize Frame2 between Frame1 and Frame3.

        Both frames are zero-padded on the right/bottom so that width and
        height become multiples of 32.  self.estimate_kernel predicts one
        kernel pair per frame, each padded frame is filtered via
        sepconv.FunctionSepconv, and the padding is cropped from the sum.

        Returns:
            (Frame2, Ver1, Hor1, Ver2, Hor2) -- the synthesized frame and the
            four kernels returned by self.estimate_kernel.

        Raises:
            ValueError: if the two frames differ in height or width.  This
                previously called sys.exit(), which terminated the whole
                process instead of letting the caller handle the mismatch.
        """
        h_1, w_1 = Frame1.size(2), Frame1.size(3)
        h_3, w_3 = Frame3.size(2), Frame3.size(3)

        # Make sure frame size is same
        if h_1 != h_3 or w_1 != w_3:
            raise ValueError('Size mismatch')

        # Distance to the next multiple of 32 (0 when already aligned).
        pad_w = -w_1 % 32
        pad_h = -h_1 % 32

        if pad_w:
            Frame1 = F.pad(Frame1, (0, pad_w, 0, 0))
            Frame3 = F.pad(Frame3, (0, pad_w, 0, 0))
        if pad_h:
            Frame1 = F.pad(Frame1, (0, 0, 0, pad_h))
            Frame3 = F.pad(Frame3, (0, 0, 0, pad_h))

        Ver1, Hor1, Ver2, Hor2 = self.estimate_kernel(Frame1, Frame3)

        tenDot1 = sepconv.FunctionSepconv()(self.modulePad(Frame1), Ver1, Hor1)
        tenDot2 = sepconv.FunctionSepconv()(self.modulePad(Frame3), Ver2, Hor2)

        Frame2 = tenDot1 + tenDot2

        # Crop the alignment padding back off.
        if pad_h:
            Frame2 = Frame2[:, :, 0:h_1, :]
        if pad_w:
            Frame2 = Frame2[:, :, :, 0:w_1]

        return Frame2, Ver1, Hor1, Ver2, Hor2
Пример #8
0
    def forward(self, frames):
        """Interpolate a frame from a stacked clip of frames.

        frames has five dimensions, unpacked as (_, f, _, h, w); dimension 1
        indexes frames.  A copy of the clip is zero-padded on the bottom/right
        so h and w become multiples of 32, and self.get_kernel predicts two
        kernel pairs from the whole padded clip.  The frames at indices f // 4
        and f // 4 + 1 are filtered via sepconv.FunctionSepconv, and the sum
        is cropped back to h x w.
        """
        _, num_frames, _, height, width = frames.shape

        # Distance to the next multiple of 32 (0 when already aligned).
        pad_h = -height % 32
        pad_w = -width % 32

        # get_kernel always receives a copy, never the caller's tensor.
        padded_frames = frames.clone()
        if pad_h:
            padded_frames = F.pad(padded_frames, (0, 0, 0, pad_h))
        if pad_w:
            padded_frames = F.pad(padded_frames, (0, pad_w, 0, 0))

        Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(
            padded_frames)

        # Integer equivalents of int(f / 4) and int(1 + f / 4) for f >= 0.
        idx_before = num_frames // 4
        idx_after = idx_before + 1

        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(padded_frames[:, idx_before]), Vertical1, Horizontal1)
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(padded_frames[:, idx_after]), Vertical2, Horizontal2)

        frame1 = tensorDot1 + tensorDot2

        # Crop the alignment padding back off.
        if pad_h:
            frame1 = frame1[:, :, 0:height, :]
        if pad_w:
            frame1 = frame1[:, :, :, 0:width]

        return frame1
Пример #9
0
    def forward(self, tensorFirst, tensorSecond):
        """Synthesize a frame from two inputs with predicted separable kernels.

        The inputs are concatenated on the channel axis and passed through a
        five-level encoder/decoder with additive skips.  The final features
        predict a vertical/horizontal kernel pair per input, which
        sepconv.FunctionSepconv applies to the padded input.  Returns the sum
        of the two filtered inputs.

        The previous body was a debugging stub.  It printed kernel sizes,
        discarded the sepconv result, never computed the second input's
        contribution, and returned a hard-coded torch.zeros(1, 1, 512, 640)
        that did not depend on the input.
        """
        tensorJoin = torch.cat([tensorFirst, tensorSecond], 1)

        tensorConv1 = self.moduleConv1(tensorJoin)
        tensorPool1 = self.modulePool1(tensorConv1)

        tensorConv2 = self.moduleConv2(tensorPool1)
        tensorPool2 = self.modulePool2(tensorConv2)

        tensorConv3 = self.moduleConv3(tensorPool2)
        tensorPool3 = self.modulePool3(tensorConv3)

        tensorConv4 = self.moduleConv4(tensorPool3)
        tensorPool4 = self.modulePool4(tensorConv4)

        tensorConv5 = self.moduleConv5(tensorPool4)
        tensorPool5 = self.modulePool5(tensorConv5)

        tensorDeconv5 = self.moduleDeconv5(tensorPool5)
        tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)

        tensorCombine = tensorUpsample5 + tensorConv5

        tensorDeconv4 = self.moduleDeconv4(tensorCombine)
        tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)

        tensorCombine = tensorUpsample4 + tensorConv4

        tensorDeconv3 = self.moduleDeconv3(tensorCombine)
        tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)

        tensorCombine = tensorUpsample3 + tensorConv3

        tensorDeconv2 = self.moduleDeconv2(tensorCombine)
        tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)

        tensorCombine = tensorUpsample2 + tensorConv2

        # Call the Function object (not .forward directly), as the sibling
        # implementations do, so the op is invoked through its normal entry point.
        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorFirst), self.moduleVertical1(tensorCombine),
            self.moduleHorizontal1(tensorCombine))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorSecond), self.moduleVertical2(tensorCombine),
            self.moduleHorizontal2(tensorCombine))

        return tensorDot1 + tensorDot2
Пример #10
0
    def forward(self, diff, tensorInput1, tensorInput2):
        '''
        Synthesize a frame twice (separable convolution and flow warping)
        and fuse the two results.

        tensorInput1/2 : [bcz, 3, height, width]
        diff:            [bcz, 2, height, width]

        Returns a tuple (tensorDot, tensorWarp1, tensorRet):
            tensorDot   -- sum of the two sepconv-filtered input frames
            tensorWarp1 -- tensorInput2 warped by self.opt.warp using the
                           prediction of the diff encoder/decoder
            tensorRet   -- self.fuse applied to cat([tensorDot, tensorWarp1], 1)

        diff is not modified.
        '''
        # Scaled by 2 "just for favor of warping" (original author's note).
        # This is done out-of-place: the previous in-place `diff *= 2.0`
        # silently doubled the caller's tensor, and it raises under autograd
        # when diff is a leaf that requires grad.
        diff = diff * 2.0

        tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)

        # ---------------- Part 1: predict the back-forward optical flow from
        # diff and warp tensorInput2 with it (per the original author's note).
        tensorOptConv1 = self.optConv1(diff)  #[32, 128, 128]
        tensorOptPool1 = self.optPool1(tensorOptConv1)

        tensorOptConv2 = self.optConv2(tensorOptPool1)  #[64, 64, 64]
        tensorOptPool2 = self.optPool2(tensorOptConv2)

        tensorOptConv3 = self.optConv3(tensorOptPool2)  #[128, 32, 32]
        tensorOptPool3 = self.optPool3(tensorOptConv3)

        tensorOptConv4 = self.optConv4(tensorOptPool3)  #[256, 16, 16]
        tensorOptPool4 = self.optPool4(tensorOptConv4)

        tensorOptConv5 = self.optConv5(tensorOptPool4)  #[512, 8, 8]
        tensorOptPool5 = self.optPool5(tensorOptConv5)

        tensorOptDeconv5 = self.optDeconv5(tensorOptPool5)
        tensorOptUpsample5 = self.optUpsample5(tensorOptDeconv5)
        tensorCombine = tensorOptUpsample5 + tensorOptConv5

        tensorOptDeconv4 = self.optDeconv4(tensorCombine)
        tensorOptUpsample4 = self.optUpsample4(tensorOptDeconv4)
        tensorCombine = tensorOptUpsample4 + tensorOptConv4

        tensorOptDeconv3 = self.optDeconv3(tensorCombine)
        tensorOptUpsample3 = self.optUpsample3(tensorOptDeconv3)
        tensorCombine = tensorOptUpsample3 + tensorOptConv3

        tensorOptDeconv2 = self.optDeconv2(tensorCombine)
        tensorOptUpsample2 = self.optUpsample2(tensorOptDeconv2)
        tensorCombine = tensorOptUpsample2 + tensorOptConv2

        # Unlike the sepconv branch below, this decoder goes up to level 1.
        tensorOptDeconv1 = self.optDeconv1(tensorCombine)
        tensorOptUpsample1 = self.optUpsample1(tensorOptDeconv1)
        tensorCombine = tensorOptUpsample1 + tensorOptConv1

        tensorOptPred1 = self.optPred(tensorCombine)

        # Warp the raw image
        tensorWarp1 = self.opt.warp(tensorOptPred1, tensorInput2)

        # ---------------- Part 2: separable-convolution synthesis from the
        # concatenated input frames.
        tensorConv1 = self.moduleConv1(tensorJoin)  #[32, 128, 128]
        tensorPool1 = self.modulePool1(tensorConv1)

        tensorConv2 = self.moduleConv2(tensorPool1)  #[64, 64, 64]
        tensorPool2 = self.modulePool2(tensorConv2)

        tensorConv3 = self.moduleConv3(tensorPool2)  #[128, 32, 32]
        tensorPool3 = self.modulePool3(tensorConv3)

        tensorConv4 = self.moduleConv4(tensorPool3)  #[256, 16, 16]
        tensorPool4 = self.modulePool4(tensorConv4)

        tensorConv5 = self.moduleConv5(tensorPool4)  #[512, 8, 8]
        tensorPool5 = self.modulePool5(tensorConv5)

        tensorDeconv5 = self.moduleDeconv5(tensorPool5)  #[512, 4, 4]
        tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  #[512, 8, 8]

        tensorCombine = tensorUpsample5 + tensorConv5  #[512, 8, 8]

        tensorDeconv4 = self.moduleDeconv4(tensorCombine)  #[256, 8, 8]
        tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  #[256, 16, 16]

        tensorCombine = tensorUpsample4 + tensorConv4  #[256, 16, 16]

        tensorDeconv3 = self.moduleDeconv3(tensorCombine)  #[128, 16, 16]
        tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  #[128, 32, 32]

        tensorCombine = tensorUpsample3 + tensorConv3  #[128, 32, 32]

        tensorDeconv2 = self.moduleDeconv2(tensorCombine)  #[64, 32, 32]
        tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  #[64, 64, 64]

        tensorCombine = tensorUpsample2 + tensorConv2  #[64, 64, 64]

        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical1(tensorCombine),
            self.moduleHorizontal1(tensorCombine))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2), self.moduleVertical2(tensorCombine),
            self.moduleHorizontal2(tensorCombine))

        tensorDot = tensorDot1 + tensorDot2
        # Fuse the two candidate frames, stacked on the channel axis.
        tensorRet = self.fuse(torch.cat([tensorDot, tensorWarp1], 1))

        return tensorDot, tensorWarp1, tensorRet
Пример #11
0
    def forward(self, tensorInput1, tensorInput2):
        '''
        Blend two sepconv-filtered frames with a learned per-pixel mask.

        tensorInput1/2 : [bcz, 3, height, width]

        Mask branch: an encoder/decoder (conv1..conv3 with pooling, a
        bottleneck, then deconv1..deconv3 with 2x bilinear upsampling and
        channel-concatenated skips) ends in conv4 and tanh.  The tanh output
        is rescaled from [-1, 1] to [0, 1].
        Kernel branch: a five-level encoder/decoder predicts a separable
        kernel pair per input, applied by sepconv.FunctionSepconv.
        Returns mask * tensorDot1 + (1 - mask) * tensorDot2.
        '''
        tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)

        def upsample2x(tensor):
            # F.interpolate is the non-deprecated equivalent of
            # F.upsample with identical arguments.
            return nn.functional.interpolate(tensor,
                                             scale_factor=2,
                                             mode='bilinear',
                                             align_corners=False)

        # ---------------- Mask branch ----------------
        conv1 = self.relu(self.conv1(tensorJoin))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3 = self.relu(self.conv3(self.pool(conv2)))

        x = self.relu(self.bottleneck(self.pool(conv3)))

        x = self.relu(self.deconv1(torch.cat([upsample2x(x), conv3], dim=1)))
        x = self.relu(self.deconv2(torch.cat([upsample2x(x), conv2], dim=1)))
        x = self.relu(self.deconv3(torch.cat([upsample2x(x), conv1], dim=1)))

        # torch.tanh replaces the deprecated nn.functional.tanh (same result).
        mask = torch.tanh(self.conv4(x))

        # ---------------- Kernel branch ----------------
        tensorConv1 = self.moduleConv1(tensorJoin)  #[32, 128, 128]
        tensorPool1 = self.modulePool1(tensorConv1)

        tensorConv2 = self.moduleConv2(tensorPool1)  #[64, 64, 64]
        tensorPool2 = self.modulePool2(tensorConv2)

        tensorConv3 = self.moduleConv3(tensorPool2)  #[128, 32, 32]
        tensorPool3 = self.modulePool3(tensorConv3)

        tensorConv4 = self.moduleConv4(tensorPool3)  #[256, 16, 16]
        tensorPool4 = self.modulePool4(tensorConv4)

        tensorConv5 = self.moduleConv5(tensorPool4)  #[512, 8, 8]
        tensorPool5 = self.modulePool5(tensorConv5)

        tensorDeconv5 = self.moduleDeconv5(tensorPool5)  #[512, 4, 4]
        tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  #[512, 8, 8]

        tensorCombine = tensorUpsample5 + tensorConv5  #[512, 8, 8]

        tensorDeconv4 = self.moduleDeconv4(tensorCombine)  #[256, 8, 8]
        tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  #[256, 16, 16]

        tensorCombine = tensorUpsample4 + tensorConv4  #[256, 16, 16]

        tensorDeconv3 = self.moduleDeconv3(tensorCombine)  #[128, 16, 16]
        tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  #[128, 32, 32]

        tensorCombine = tensorUpsample3 + tensorConv3  #[128, 32, 32]

        tensorDeconv2 = self.moduleDeconv2(tensorCombine)  #[64, 32, 32]
        tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  #[64, 64, 64]

        tensorCombine = tensorUpsample2 + tensorConv2  #[64, 64, 64]

        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical1(tensorCombine),
            self.moduleHorizontal1(tensorCombine))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2), self.moduleVertical2(tensorCombine),
            self.moduleHorizontal2(tensorCombine))

        # Map tanh's [-1, 1] range to a [0, 1] blend weight.
        mask = 0.5 * (1.0 + mask)
        # Tile the channel dimension 3x; presumably mask has a single channel
        # and this matches the 3-channel frames -- confirm conv4's out_channels.
        mask = mask.repeat([1, 3, 1, 1])
        x = mask * tensorDot1 + (1.0 - mask) * tensorDot2

        return x
Пример #12
0
    def forward(self,
                tensorInput1,
                tensorInput2,
                tensorResidual=None,
                tensorHidden=None):
        '''
        Synthesize a frame with kernels conditioned on a ConvLSTM state.

        tensorInput1/2 : [bcz, 3, height, width]
        tensorResidual:  [bcz, 3, height, width]
        tensorHidden:(tuple or None) ([bcz, hidden_dim, height, width])

        When tensorResidual is None this is treated as the first time step:
        a zero residual shaped like tensorInput1 is encoded, and moduleLSTM is
        called without tensorHidden (any tensorHidden passed is ignored).

        Returns (tensorDot1 + tensorDot2, (tensorH_next, tensorC_next)): the
        predicted frame and the next ConvLSTM state.
        '''
        # ------------------- LSTM Part --------------------
        bFirstStep = tensorResidual is None
        if bFirstStep:
            # zeros_like keeps tensorInput1's shape, device and dtype.  The
            # previous code always built a float32 tensor and moved it with
            # a hard-coded .cuda(), which broke CPU and multi-GPU use.
            tensorResidual = torch.zeros_like(tensorInput1)

        tensorEncRes = self.moduleDownH(self.moduleConvH(tensorResidual))
        if bFirstStep:
            # No previous state yet, so tensorHidden is not passed.
            tensorH_next, tensorC_next = self.moduleLSTM(tensorEncRes)
        else:
            tensorH_next, tensorC_next = self.moduleLSTM(
                tensorEncRes, tensorHidden)

        # ------------------- Encoder Part -----------------
        tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)

        tensorConv1 = self.moduleConv1(tensorJoin)  #[32, 128, 128]
        tensorPool1 = self.modulePool1(tensorConv1)

        tensorConv2 = self.moduleConv2(tensorPool1)  #[64, 64, 64]
        tensorPool2 = self.modulePool2(tensorConv2)

        tensorConv3 = self.moduleConv3(tensorPool2)  #[128, 32, 32]
        tensorPool3 = self.modulePool3(tensorConv3)

        tensorConv4 = self.moduleConv4(tensorPool3)  #[256, 16, 16]
        tensorPool4 = self.modulePool4(tensorConv4)

        tensorConv5 = self.moduleConv5(tensorPool4)  #[512, 8, 8]
        tensorPool5 = self.modulePool5(tensorConv5)

        # ------------------- Decoder Part -----------------
        tensorDeconv5 = self.moduleDeconv5(tensorPool5)  #[512, 4, 4]
        tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  #[512, 8, 8]

        tensorCombine = tensorUpsample5 + tensorConv5  #[512, 8, 8]

        tensorDeconv4 = self.moduleDeconv4(tensorCombine)  #[256, 8, 8]
        tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  #[256, 16, 16]

        tensorCombine = tensorUpsample4 + tensorConv4  #[256, 16, 16]

        tensorDeconv3 = self.moduleDeconv3(tensorCombine)  #[128, 16, 16]
        tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  #[128, 32, 32]

        tensorCombine = tensorUpsample3 + tensorConv3  #[128, 32, 32]

        tensorDeconv2 = self.moduleDeconv2(tensorCombine)  #[64, 32, 32]
        tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  #[64, 64, 64]

        tensorCombine1 = tensorUpsample2 + tensorConv2  #[64, 64, 64]

        # Condition the kernel predictors on the new LSTM hidden state by
        # channel concatenation (requires matching spatial size).
        tensorCombine = torch.cat([tensorCombine1, tensorH_next], 1)

        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical11(tensorCombine),
            self.moduleHorizontal11(tensorCombine))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2), self.moduleVertical22(tensorCombine),
            self.moduleHorizontal22(tensorCombine))

        # Return the predicted tensor and the next state of convLSTM
        return tensorDot1 + tensorDot2, (tensorH_next, tensorC_next)
Пример #13
0
        moduleNetwork.load_state_dict(torch.load("../models/" + predir))

    for epoch in range(100):
        for n in range(train):
            #making train data
            image3b = torch.ones((1, 3, 178, 178)).cuda(
            )  ##25 pixel wider picture for each direction to synthesis kernel and I1, I,2
            image1b = torch.ones((1, 3, 178, 178)).cuda()
            image1 = torch.ones(1, 3, 128, 128).cuda()
            image2 = torch.ones(1, 3, 128, 128).cuda()
            image3 = torch.ones(1, 3, 128, 128).cuda()

            #forward caluclation
            Kernel = moduleNetwork.forward(image1, image3)
            kernelDiv = torch.chunk(Kernel, 4, dim=3)
            tensorDot1 = sepconv.FunctionSepconv().forward(
                image1b, kernelDiv[0], kernelDiv[1]).detach()
            tensorDot2 = sepconv.FunctionSepconv().forward(
                image3b, kernelDiv[2], kernelDiv[3]).detach()
            tensorDot1.requires_grad = True
            tensorDot2.requires_grad = True
            tensorCombine = tensorDot1 + tensorDot2

            #backward caluclation
            loss = loss_fn(tensorCombine, image2)
            value_loss = loss.item()
            loss.backward()
            kgrad1 = sepconv.FunctionSepconv().backward(
                tensorDot1.grad,
                (tensorCombine, image1b, kernelDiv[0], kernelDiv[1]))
            kgrad2 = sepconv.FunctionSepconv().backward(
                tensorDot2.grad,
    def forward(self,
                tensorInput1,
                tensorInput2,
                tensorResidual=None,
                tensorHidden=None):
        '''
        tensorInput1/2 : [bcz, 3, height, width]
        tensorResidual:  [bcz, 3, height, width]
        tensorHidden:(tuple or None) ([bcz, hidden_dim, height, width])
        When the LSTM_state is Noe, it means that its the first time step
        '''
        batch_size = tensorInput1.size(0)
        # ------------------- LSTM Part --------------------
        if tensorResidual is None:
            tensorResidual = var(
                torch.zeros(batch_size, tensorInput1.size(1),
                            tensorInput1.size(2),
                            tensorInput1.size(3))).cuda()
            tensorEncRes = self.moduleDownH(self.moduleConvH(tensorResidual))
            tensorH_next, tensorC_next = self.moduleLSTM(
                tensorEncRes)  # Hence we also don't have the tensorHidden
        else:
            tensorEncRes = self.moduleDownH(self.moduleConvH(tensorResidual))
            tensorH_next, tensorC_next = self.moduleLSTM(
                tensorEncRes, tensorHidden)

        # ------------------------- I use the convolution with stride of 2 to work as a downsample function~, which accords to the resolution of [128, 128], [64, 64], [32, 32]
        tensorL0 = tensorH_next
        tensorL1 = self.moduleDownLSTM1(tensorL0)
        tensorL2 = self.moduleDownLSTM2(tensorL1)
        # ------------------------- I use the convolution with stride of 2 to work as a downsample function~, which accords to the resolution of [128, 128], [64, 64], [32, 32]

        tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)
        tensorConv1 = self.moduleConv1(tensorJoin)
        tensorPool1 = self.modulePool1(tensorConv1)

        tensorConv2 = self.moduleConv2(tensorPool1)
        tensorPool2 = self.modulePool2(tensorConv2)

        tensorConv3 = self.moduleConv3(tensorPool2)
        tensorPool3 = self.modulePool3(tensorConv3)

        tensorConv4 = self.moduleConv4(tensorPool3)
        tensorPool4 = self.modulePool4(tensorConv4)

        tensorConv5 = self.moduleConv5(tensorPool4)
        tensorPool5 = self.modulePool5(tensorConv5)

        tensorDeconv5 = self.moduleDeconv5(tensorPool5)
        tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)

        tensorCombine = tensorUpsample5 + tensorConv5

        tensorDeconv4 = self.moduleDeconv4(tensorCombine)
        tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)

        tensorCombine = tensorUpsample4 + tensorConv4

        # ------- LSTM combine ------------
        tensorCombineL2 = torch.cat([tensorCombine, tensorL2],
                                    1)  # This channel is 256 + 128 = 384
        # ------- LSTM combine ------------

        tensorDot1_a = sepconv.FunctionSepconv()(self.modulePad_a(
            func.upsample(tensorInput1,
                          size=(tensorInput1.shape[2] // 4,
                                tensorInput1.shape[3] // 4),
                          mode='bilinear',
                          align_corners=True)), self.mv1_a_(tensorCombineL2),
                                                 self.mh1_a_(tensorCombineL2))
        tensorDot2_a = sepconv.FunctionSepconv()(self.modulePad_a(
            func.upsample(tensorInput2,
                          size=(tensorInput1.shape[2] // 4,
                                tensorInput1.shape[3] // 4),
                          mode='bilinear',
                          align_corners=True)), self.mv2_a_(tensorCombineL2),
                                                 self.mh2_a_(tensorCombineL2))

        tensorDeconv3 = self.moduleDeconv3(tensorCombine)
        tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)

        tensorCombine = tensorUpsample3 + tensorConv3

        # ------- LSTM combine ------------
        tensorCombineL1 = torch.cat([tensorCombine, tensorL1],
                                    1)  # This channel is 128 + 64 = 192
        # ------- LSTM combine ------------
        tensorDot1_b = sepconv.FunctionSepconv()(self.modulePad_b(
            func.upsample(tensorInput1,
                          size=(tensorInput1.shape[2] // 2,
                                tensorInput1.shape[3] // 2),
                          mode='bilinear',
                          align_corners=True)), self.mv1_b_(tensorCombineL1),
                                                 self.mh1_b_(tensorCombineL1))
        tensorDot2_b = sepconv.FunctionSepconv()(self.modulePad_b(
            func.upsample(tensorInput2,
                          size=(tensorInput1.shape[2] // 2,
                                tensorInput1.shape[3] // 2),
                          mode='bilinear',
                          align_corners=True)), self.mv2_b_(tensorCombineL1),
                                                 self.mh2_b_(tensorCombineL1))

        tensorDeconv2 = self.moduleDeconv2(tensorCombine)
        tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)

        tensorCombine = tensorUpsample2 + tensorConv2
        # ------- LSTM combine ------------
        tensorCombineL0 = torch.cat([tensorCombine, tensorL0],
                                    1)  # This channel is 64 + 32 = 96
        # ------- LSTM combine ------------
        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1),
            self.moduleVertical1_(tensorCombineL0),
            self.moduleHorizontal1_(tensorCombineL0))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput2),
            self.moduleVertical2_(tensorCombineL0),
            self.moduleHorizontal2_(tensorCombineL0))

        return tensorDot1 + tensorDot2, tensorDot1_a + tensorDot2_a, tensorDot1_b + tensorDot2_b, (
            tensorH_next, tensorC_next)
Пример #15
0
    def forward(self, tensorInput1, tensorInput2):
        '''
        Blend two frames with separable convolution after flow-warping the second.

        tensorInput1/2 : [bcz, 3, height, width]

        Part 1 runs the ``opt*`` conv/pool -> deconv/upsample stack (with
        additive skip connections) on the channel-concatenated inputs, maps the
        result to a flow prediction with ``self.optPred`` and warps
        tensorInput2 with ``self.opt.warp``.
        Part 2 runs the ``module*`` stack on the same concatenated inputs; its
        features feed the per-pixel vertical/horizontal kernel heads.
        tensorInput1 and the warped tensorInput2 are each filtered with
        ``sepconv.FunctionSepconv`` and the two results are summed.

        Returns the sum of the two separable-convolution outputs.

        The [C, H, W] comments below appear to assume a 128x128 input
        (TODO confirm).
        '''

        tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)

        # ---------------- Part 1: predict a flow field and warp tensorInput2
        tensorOptConv1 = self.optConv1(tensorJoin)  #[32, 128, 128]
        tensorOptPool1 = self.optPool1(tensorOptConv1)

        tensorOptConv2 = self.optConv2(tensorOptPool1)  #[64, 64, 64]
        tensorOptPool2 = self.optPool2(tensorOptConv2)

        tensorOptConv3 = self.optConv3(tensorOptPool2)  #[128, 32, 32]
        tensorOptPool3 = self.optPool3(tensorOptConv3)

        tensorOptConv4 = self.optConv4(tensorOptPool3)  #[256, 16, 16]
        tensorOptPool4 = self.optPool4(tensorOptConv4)

        tensorOptConv5 = self.optConv5(tensorOptPool4)  #[512, 8, 8]
        tensorOptPool5 = self.optPool5(tensorOptConv5)

        # Decoder: each level upsamples and adds the matching encoder output.
        tensorOptDeconv5 = self.optDeconv5(tensorOptPool5)
        tensorOptUpsample5 = self.optUpsample5(tensorOptDeconv5)
        tensorCombine = tensorOptUpsample5 + tensorOptConv5

        tensorOptDeconv4 = self.optDeconv4(tensorCombine)
        tensorOptUpsample4 = self.optUpsample4(tensorOptDeconv4)
        tensorCombine = tensorOptUpsample4 + tensorOptConv4

        tensorOptDeconv3 = self.optDeconv3(tensorCombine)
        tensorOptUpsample3 = self.optUpsample3(tensorOptDeconv3)
        tensorCombine = tensorOptUpsample3 + tensorOptConv3

        tensorOptDeconv2 = self.optDeconv2(tensorCombine)
        tensorOptUpsample2 = self.optUpsample2(tensorOptDeconv2)
        tensorCombine = tensorOptUpsample2 + tensorOptConv2

        # Unlike Part 2, the flow branch decodes all the way back to level 1.
        tensorOptDeconv1 = self.optDeconv1(tensorCombine)
        tensorOptUpsample1 = self.optUpsample1(tensorOptDeconv1)
        tensorCombine = tensorOptUpsample1 + tensorOptConv1

        tensorOptPred = self.optPred(tensorCombine)

        # The whole prediction is used to warp the second input frame.
        tensorWarp = self.opt.warp(tensorOptPred, tensorInput2)

        # ---------------- Part 2: predict the separable kernels
        tensorConv1 = self.moduleConv1(tensorJoin)  #[32, 128, 128]
        tensorPool1 = self.modulePool1(tensorConv1)

        tensorConv2 = self.moduleConv2(tensorPool1)  #[64, 64, 64]
        tensorPool2 = self.modulePool2(tensorConv2)

        tensorConv3 = self.moduleConv3(tensorPool2)  #[128, 32, 32]
        tensorPool3 = self.modulePool3(tensorConv3)

        tensorConv4 = self.moduleConv4(tensorPool3)  #[256, 16, 16]
        tensorPool4 = self.modulePool4(tensorConv4)

        tensorConv5 = self.moduleConv5(tensorPool4)  #[512, 8, 8]
        tensorPool5 = self.modulePool5(tensorConv5)

        tensorDeconv5 = self.moduleDeconv5(tensorPool5)  #[512, 4, 4]
        tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  #[512, 8, 8]

        tensorCombine = tensorUpsample5 + tensorConv5  #[512, 8, 8]

        tensorDeconv4 = self.moduleDeconv4(tensorCombine)  #[256, 8, 8]
        tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  #[256, 16, 16]

        tensorCombine = tensorUpsample4 + tensorConv4  #[256, 16, 16]

        tensorDeconv3 = self.moduleDeconv3(tensorCombine)  #[128, 16, 16]
        tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  #[128, 32, 32]

        tensorCombine = tensorUpsample3 + tensorConv3  #[128, 32, 32]

        tensorDeconv2 = self.moduleDeconv2(tensorCombine)  #[64, 32, 32]
        tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  #[64, 64, 64]

        tensorCombine = tensorUpsample2 + tensorConv2  #[64, 64, 64]

        # Filter the first input and the *warped* second input with their own
        # predicted vertical/horizontal kernels, then sum.
        tensorDot1 = sepconv.FunctionSepconv()(
            self.modulePad(tensorInput1), self.moduleVertical1(tensorCombine),
            self.moduleHorizontal1(tensorCombine))
        tensorDot2 = sepconv.FunctionSepconv()(
            self.modulePad(tensorWarp), self.moduleVertical2(tensorCombine),
            self.moduleHorizontal2(tensorCombine))
        return tensorDot1 + tensorDot2
Пример #16
0
    def forward(self, frames):
        '''
        Predict an output frame from a stack of input frames via separable convolution.

        frames : 5-D tensor unpacked as (_, f, _, h, w); single frames are taken
                 with ``frames[:, i]`` (presumably
                 [batch, num_frames, channels, height, width] -- TODO confirm).

        Height and width are zero-padded at the bottom/right to multiples of 32
        before ``self.get_kernel`` is called; its eight kernel outputs are also
        stored on ``self.interpolation_kernels``. The returned frame is cropped
        back to h x w and is the sum of:
          * separable convolutions of frames f // 4 and f // 4 + 1 (kernels
            V1/H1 and V2/H2), plus, when ``self.kl_d_size`` is set, the same
            applied to ``self.down_l``-downscaled frames and upscaled again with
            ``self.up_l``;
          * when ``self.kq_size`` / ``self.kq_d_size`` are set, the analogous
            terms for frames 0 and 3 (kernels VQ1/HQ1 and VQ2/HQ2).
        '''
        _, f, _, h, w = frames.shape

        h_padded = False
        w_padded = False
        padded_frames = frames.clone()

        # Pad the last two dims (bottom / right) up to the next multiple of 32.
        if h % 32 != 0:
            pad_h = 32 - (h % 32)
            padded_frames = F.pad(padded_frames, (0, 0, 0, pad_h))
            h_padded = True

        if w % 32 != 0:
            pad_w = 32 - (w % 32)
            padded_frames = F.pad(padded_frames, (0, pad_w, 0, 0))
            w_padded = True

        # get kernels from subnets (also kept on self.interpolation_kernels)
        V1, H1, V2, H2, VQ1, HQ1, VQ2, HQ2 = self.interpolation_kernels = self.get_kernel(
            padded_frames)

        # Frames on either side of the target; frames 1 and 2 when f == 4.
        frame_before = f // 4
        frame_after = frame_before + 1

        tensorDotL = sepconv.FunctionSepconv()(self.modulePad_l(
            padded_frames[:, frame_before]), V1[0], H1[0])
        tensorDotR = sepconv.FunctionSepconv()(self.modulePad_l(
            padded_frames[:, frame_after]), V2[0], H2[0])

        if self.kl_d_size is not None:
            # downscale input frames
            im1d = self.down_l(padded_frames[:, frame_before])
            im2d = self.down_l(padded_frames[:, frame_after])

            # convolve and upscale back to original size
            tensorDotL += self.up_l(sepconv.FunctionSepconv()(
                self.modulePad_ld(im1d), V1[1], H1[1]))
            tensorDotR += self.up_l(sepconv.FunctionSepconv()(
                self.modulePad_ld(im2d), V2[1], H2[1]))

        # NOTE(review): frames 0 and 3 are hard-coded here, which presumably
        # assumes f == 4 -- confirm before using other frame counts.
        if self.kq_size is not None:
            tensorDotLL = sepconv.FunctionSepconv()(self.modulePad_q(
                padded_frames[:, 0]), VQ1[0], HQ1[0])
            tensorDotRR = sepconv.FunctionSepconv()(self.modulePad_q(
                padded_frames[:, 3]), VQ2[0], HQ2[0])
        else:
            # Scalar zeros so the sum below works without the outer-frame terms.
            tensorDotLL = tensorDotRR = 0

        if self.kq_d_size is not None:
            im1qd = self.down_q(padded_frames[:, 0])
            im2qd = self.down_q(padded_frames[:, 3])

            tensorDotLL += self.up_q(sepconv.FunctionSepconv()(
                self.modulePad_qd(im1qd), VQ1[1], HQ1[1]))
            tensorDotRR += self.up_q(sepconv.FunctionSepconv()(
                self.modulePad_qd(im2qd), VQ2[1], HQ2[1]))

        frame_out = tensorDotL + tensorDotR + tensorDotLL + tensorDotRR

        # Undo the padding added above.
        if h_padded:
            frame_out = frame_out[:, :, 0:h, :]
        if w_padded:
            frame_out = frame_out[:, :, :, 0:w]

        return frame_out