Exemplo n.º 1
0
    def __call__(self, x):
        """Estimate optical flow between the two frames stacked in ``x``, coarse to fine.

        Args:
            x: the two frames stacked on dim 1 (``b*2*h*w`` per the original
                comment). Channel 0 is the frame that gets warped by the current
                flow estimate; channel 1 is concatenated next to the warped result.

        Returns:
            ``(optical_flow_L1, optical_flow_L2, optical_flow_L3)``:

            * L1 -- flow estimated by ``RNN1``/``RNN2`` on ``self.pool(x)``,
              starting from an all-zero flow;
            * L2 -- L1 resized to x's spatial size plus a residual from
              ``RNN1``/``RNN2``;
            * L3 -- L2 upscaled by ``self.scale`` plus a residual from
              ``self.SR``.
        """
        # Channel slices kept 4-D (equivalent to torch.unsqueeze(x[:, k], 1)).
        frame0 = x[:, 0:1, :, :]
        frame1 = x[:, 1:2, :, :]
        full_h, full_w = x.shape[-2], x.shape[-1]

        # Part 1: coarse estimate on the pooled input.
        x_L1 = self.pool(x)
        b, _, h, w = x_L1.size()
        # Zero initial flow on x_L1's device/dtype (was a hard-coded .cuda(),
        # which broke CPU / non-default-device runs and ignored half precision).
        zero_flow = x_L1.new_zeros(b, 2, h, w)
        optical_flow_L1 = self.RNN2(self.RNN1(torch.cat((x_L1, zero_flow), 1)))

        # Resize to exactly x's spatial size instead of scale_factor=2. With an odd
        # size, self.pool floors (e.g. width 75 -> 37), so a 2x upscale gives 74 and
        # no longer matches the frame warped below. Values are still multiplied by 2
        # (presumably because flow is measured in pixels of its own grid), which is
        # only approximate for odd sizes.
        optical_flow_L1_upscaled = F.interpolate(
            optical_flow_L1, size=(full_h, full_w), mode='bilinear', align_corners=False) * 2

        # Part 2: warp frame 0 with the upscaled coarse flow and predict a residual.
        x_L2 = optical_flow_warp(frame0, optical_flow_L1_upscaled)
        input_L2 = torch.cat((x_L2, frame1, optical_flow_L1_upscaled), 1)
        optical_flow_L2 = self.RNN2(self.RNN1(input_L2)) + optical_flow_L1_upscaled

        # Part 3: warp again with the refined flow; self.SR predicts a residual at
        # self.scale x resolution, added to L2 upscaled by the same factor (values
        # multiplied by self.scale to match the new grid).
        x_L3 = optical_flow_warp(frame0, optical_flow_L2)
        input_L3 = torch.cat((x_L3, frame1, optical_flow_L2), 1)
        residual_L3 = self.SR(self.RNN1(input_L3))
        optical_flow_L2_upscaled = F.interpolate(
            optical_flow_L2, scale_factor=self.scale, mode='bilinear', align_corners=False) * self.scale
        optical_flow_L3 = residual_L3 + optical_flow_L2_upscaled
        return optical_flow_L1, optical_flow_L2, optical_flow_L3
Exemplo n.º 2
0
    def forward(self, x):
        """Super-resolve the centre frame of a clip from motion-compensated neighbours.

        Args:
            x: clip tensor laid out as (batch, frame, channel, height, width),
                i.e. ``b*n*c*h*w``; the centre frame is index ``(n - 1) // 2``.

        Returns:
            ``(flow_L1, flow_L2, flow_L3, SR)``. The three flow lists have one
            entry per input frame: the flow tensor estimated between that frame
            and the centre frame, or an empty list ``[]`` placeholder at the
            centre frame's own position. ``SR`` is the output of ``self.SR``
            applied to the draft cube.
        """
        b, n_frames, c, h, w = x.size()
        idx_center = (n_frames - 1) // 2
        center = x[:, idx_center, :, :, :]

        # Motion estimation: pair every non-centre frame with the centre frame and
        # run all pairs through self.OFR as one batch, stacked along dim 0.
        pairs = [
            torch.cat((x[:, idx_frame, :, :, :], center), 1)
            for idx_frame in range(n_frames)
            if idx_frame != idx_center
        ]
        optical_flow_L1, optical_flow_L2, optical_flow_L3 = self.OFR(torch.cat(pairs, 0))

        # Undo the dim-0 stacking: each level becomes (neighbour, batch, 2, H, W).
        # reshape == view for contiguous outputs, but does not fail otherwise.
        optical_flow_L1 = optical_flow_L1.reshape(-1, b, 2, h // 2, w // 2)
        optical_flow_L2 = optical_flow_L2.reshape(-1, b, 2, h, w)
        optical_flow_L3 = optical_flow_L3.reshape(-1, b, 2, h * self.scale, w * self.scale)

        # Motion compensation: the draft cube starts with the centre frame.
        flow_L1, flow_L2, flow_L3 = [], [], []
        draft_cube = [center]

        for idx_frame in range(n_frames):
            if idx_frame == idx_center:
                # Placeholder keeps the flow lists aligned with frame indices.
                flow_L1.append([])
                flow_L2.append([])
                flow_L3.append([])
                continue

            # Position of this frame among the stacked neighbour pairs
            # (the centre frame was skipped when building them).
            idx = idx_frame if idx_frame < idx_center else idx_frame - 1

            flow_L1.append(optical_flow_L1[idx])
            flow_L2.append(optical_flow_L2[idx])
            flow_L3.append(optical_flow_L3[idx])

            # Subsample the high-resolution flow at each of the scale x scale phase
            # offsets (i, j) and warp this frame with each sub-grid. Dividing by
            # self.scale presumably converts HR-pixel displacements to input-pixel
            # displacements.
            frame = x[:, idx_frame, :, :, :]
            flow_hr = optical_flow_L3[idx]
            for i in range(self.scale):
                for j in range(self.scale):
                    sub_flow = flow_hr[:, :, i::self.scale, j::self.scale] / self.scale
                    draft_cube.append(optical_flow_warp(frame, sub_flow))

        draft_cube = torch.cat(draft_cube, 1)

        # Super-resolution.
        SR = self.SR(draft_cube)

        return flow_L1, flow_L2, flow_L3, SR
Exemplo n.º 3
0
 def forward(self, x0, x1, optical_flow):
     """Return the flow-warping loss between ``x0`` and ``x1``.

     ``x0`` is warped by ``optical_flow`` with ``optical_flow_warp``. The loss is
     the mean absolute difference between ``x1`` and that warped tensor, plus
     ``self.reg_weight`` times ``self.regularization(optical_flow)``.
     """
     x0_warped = optical_flow_warp(x0, optical_flow)
     photometric = (x1 - x0_warped).abs().mean()
     weight = self.reg_weight
     penalty = self.regularization(optical_flow)
     return photometric + weight * penalty