def __call__(self, x):
    """Coarse-to-fine optical flow estimation over three pyramid levels.

    Args:
        x: tensor of shape (b, 2, h, w) — a pair of grayscale frames
           stacked along the channel dimension (frame 0 is warped toward
           frame 1 at each level).

    Returns:
        Tuple of three flow tensors:
        - optical_flow_L1: (b, 2, h/2, w/2), coarse flow at half resolution
        - optical_flow_L2: (b, 2, h, w), refined flow at input resolution
        - optical_flow_L3: (b, 2, h*scale, w*scale), super-resolved flow
    """
    # Part 1: estimate a coarse flow at half resolution. The flow input
    # channels are zero-initialized since no prior estimate exists yet.
    x_L1 = self.pool(x)
    b, c, h, w = x_L1.size()
    # Device/dtype follow the input so the module also works on CPU and
    # on non-default CUDA devices (previously hard-coded .cuda()).
    zero_flow = torch.zeros(b, 2, h, w, device=x.device, dtype=x.dtype)
    input_L1 = torch.cat((x_L1, zero_flow), 1)
    optical_flow_L1 = self.RNN2(self.RNN1(input_L1))
    # Upsample to exactly the image size rather than scale_factor=2:
    # for odd input sizes scale_factor produces an off-by-one shape
    # (e.g. (66, 74) vs the required (66, 75)) that breaks the warp below.
    # Flow values are doubled to match the doubled spatial resolution.
    image_shape = torch.unsqueeze(x[:, 0, :, :], 1).shape
    optical_flow_L1_upscaled = F.interpolate(
        optical_flow_L1, size=(image_shape[2], image_shape[3]),
        mode='bilinear', align_corners=False) * 2

    # Part 2: warp frame 0 toward frame 1 with the coarse flow, then
    # predict a residual refinement at full resolution.
    x_L2 = optical_flow_warp(torch.unsqueeze(x[:, 0, :, :], 1), optical_flow_L1_upscaled)
    input_L2 = torch.cat((x_L2, torch.unsqueeze(x[:, 1, :, :], 1), optical_flow_L1_upscaled), 1)
    optical_flow_L2 = self.RNN2(self.RNN1(input_L2)) + optical_flow_L1_upscaled

    # Part 3: warp again with the refined flow and predict a residual
    # on top of the bilinearly-upscaled L2 flow at the SR resolution.
    x_L3 = optical_flow_warp(torch.unsqueeze(x[:, 0, :, :], 1), optical_flow_L2)
    input_L3 = torch.cat((x_L3, torch.unsqueeze(x[:, 1, :, :], 1), optical_flow_L2), 1)
    optical_flow_L3 = self.SR(self.RNN1(input_L3)) + \
        F.interpolate(optical_flow_L2, scale_factor=self.scale,
                      mode='bilinear', align_corners=False) * self.scale
    return optical_flow_L1, optical_flow_L2, optical_flow_L3
def forward(self, x):
    """Multi-frame super-resolution with explicit motion compensation.

    Args:
        x: tensor of shape (b, n, c, h, w) — a stack of n frames; the
           center frame is the reference to be super-resolved.

    Returns:
        Tuple (flow_L1, flow_L2, flow_L3, SR):
        - flow_L1/L2/L3: per-frame lists of flow tensors at each pyramid
          level (an empty list [] at the center-frame index, since the
          reference frame needs no flow);
        - SR: the super-resolved center frame.
    """
    b, n_frames, c, h, w = x.size()
    idx_center = (n_frames - 1) // 2

    # Motion estimation: pair every non-center frame with the center
    # frame and run the flow estimator on all pairs as one batch.
    # NOTE: renamed from `input`, which shadowed the builtin.
    flow_L1 = []
    flow_L2 = []
    flow_L3 = []
    frame_pairs = []
    for idx_frame in range(n_frames):
        if idx_frame != idx_center:
            frame_pairs.append(torch.cat((x[:, idx_frame, :, :, :], x[:, idx_center, :, :, :]), 1))
    optical_flow_L1, optical_flow_L2, optical_flow_L3 = self.OFR(torch.cat(frame_pairs, 0))

    # Un-batch: leading dim becomes (n_frames - 1) pair indices.
    optical_flow_L1 = optical_flow_L1.view(-1, b, 2, h // 2, w // 2)
    optical_flow_L2 = optical_flow_L2.view(-1, b, 2, h, w)
    optical_flow_L3 = optical_flow_L3.view(-1, b, 2, h * self.scale, w * self.scale)

    # Motion compensation: build the draft cube, starting with the raw
    # center frame.
    draft_cube = []
    draft_cube.append(x[:, idx_center, :, :, :])
    for idx_frame in range(n_frames):
        if idx_frame == idx_center:
            # Placeholder so list indices still line up with frame indices.
            flow_L1.append([])
            flow_L2.append([])
            flow_L3.append([])
        else:
            # Map the frame index to its position in the batched pair
            # output (the center frame was skipped during pairing).
            if idx_frame < idx_center:
                idx = idx_frame
            if idx_frame > idx_center:
                idx = idx_frame - 1
            flow_L1.append(optical_flow_L1[idx, :, :, :, :])
            flow_L2.append(optical_flow_L2[idx, :, :, :, :])
            flow_L3.append(optical_flow_L3[idx, :, :, :, :])
            # Generate drafts by subsampling the SR flow optical_flow_L3
            # on a scale x scale phase grid, one draft per (i, j) offset.
            for i in range(self.scale):
                for j in range(self.scale):
                    draft = optical_flow_warp(
                        x[:, idx_frame, :, :, :],
                        optical_flow_L3[idx, :, :, i::self.scale, j::self.scale] / self.scale)
                    draft_cube.append(draft)
    draft_cube = torch.cat(draft_cube, 1)

    # Super-resolution of the center frame from the draft cube.
    SR = self.SR(draft_cube)
    return flow_L1, flow_L2, flow_L3, SR
def forward(self, x0, x1, optical_flow):
    """Warping loss: photometric L1 error plus flow regularization.

    Warps x0 toward x1 using the given flow, then returns the mean
    absolute photometric difference plus a weighted smoothness term
    computed on the flow itself.
    """
    warped = optical_flow_warp(x0, optical_flow)
    photometric = torch.mean(torch.abs(x1 - warped))
    smoothness = self.regularization(optical_flow)
    return photometric + self.reg_weight * smoothness