def forward(self, frames, F_kprime_to_k, F_n_to_k_s, F_k_to_n_s):
    """Fuse a window of frames into one frame rendered at the target view k'.

    Each frame i (or its spatial features) is carried to the key frame k --
    forward-splatted with ``F_n_to_k_s[i]``, or backward-warped with
    ``F_k_to_n_s[i]`` when ``args.all_backward > 0`` -- and then on to k'
    with ``F_kprime_to_k`` (backward warp, or forward splat when
    ``args.bundle_forward_flow > 0``). The warped candidates are pooled,
    decoded and blended according to the flags in ``self.args``.

    Args:
        frames: list of image tensors indexed ``(batch, channel, h, w)``; the
            base frame is included. First and last frame must match in size.
        F_kprime_to_k: flow tensor used for the second (k -> k') warp.
        F_n_to_k_s: per-frame flows used to forward-splat frame i onto k.
        F_k_to_n_s: per-frame flows used to backward-warp frame i onto k and
            for the forward/backward flow-consistency error.

    Returns:
        The fused prediction (in ``noDL_CNNAggregation`` mode, the weighted
        colour aggregate). Bottom/right padding added internally is cropped.

    Notes:
        * The caller's lists are never modified; padding works on copies.
        * ``args.inference_with_frame_selection`` and ``args.seamless`` move
          data through NumPy/OpenCV and only look at batch element 0.
    """
    import numpy as np

    h0 = int(frames[0].size(2))
    w0 = int(frames[0].size(3))
    if h0 != int(frames[-1].size(2)) or w0 != int(frames[-1].size(3)):
        sys.exit('Frame sizes do not match')

    # Temporal weights over the window: a (len(frames), 1) Gaussian kernel.
    gaussian_filter = cv2.getGaussianKernel(len(frames), -1)

    if self.args.inference_with_frame_selection > 0:
        # Drop every frame that is BOTH blurrier (lower variance of the
        # Laplacian) AND has a smaller F_n_to_k flow energy than the median.
        def variance_of_laplacian(image):
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            return cv2.Laplacian(gray, cv2.CV_32F).var()

        var_lap_values = [
            variance_of_laplacian(
                np.transpose(frame.detach().cpu().numpy()[0], (1, 2, 0)))
            for frame in frames
        ]
        flow_values = [
            np.sum(F_n_to_k_s[i].detach().cpu().numpy()[0] ** 2)
            for i in range(len(frames))
        ]
        mid_value = sorted(var_lap_values)[len(frames) // 2]
        mid_flow_value = sorted(flow_values)[len(frames) // 2]
        kept = [
            i for i in range(len(frames))
            if not (var_lap_values[i] < mid_value
                    and flow_values[i] < mid_flow_value)
        ]
        frames = [frames[i] for i in kept]
        F_n_to_k_s = [F_n_to_k_s[i] for i in kept]
        F_k_to_n_s = [F_k_to_n_s[i] for i in kept]
        # Keep the kept frames' Gaussian weights, renormalised to sum to 1.
        gaussian_filter = np.expand_dims(
            np.stack([gaussian_filter[i, 0] for i in kept], 0), -1)
        gaussian_filter = gaussian_filter / np.sum(gaussian_filter)

    # Replicate-pad bottom/right so height and width are multiples of 4.
    minimum_size = 4
    pad_h = (-h0) % minimum_size
    pad_w = (-w0) % minimum_size
    h_padded = pad_h != 0
    w_padded = pad_w != 0
    if h_padded or w_padded:
        padding = (0, pad_w, 0, pad_h)
        F_kprime_to_k = F.pad(F_kprime_to_k, padding, mode='replicate')
        # New lists: never assign into the caller's lists.
        frames = [F.pad(f, padding, mode='replicate') for f in frames]
        F_n_to_k_s = [F.pad(f, padding, mode='replicate') for f in F_n_to_k_s]
        # The backward flows must match the padded size too (all_backward,
        # CNN_flowError and residual-transfer paths consume them).
        if F_k_to_n_s is not None:
            F_k_to_n_s = [
                None if f is None else F.pad(f, padding, mode='replicate')
                for f in F_k_to_n_s
            ]

    def crop_to_input(ten):
        """Remove the bottom/right padding added above (if any)."""
        # NOTE(review): with FOV_expansion > 0 the output is the enlarged
        # canvas, so this top-left crop looks questionable -- confirm.
        if h_padded:
            ten = ten[:, :, 0:h0, :]
        if w_padded:
            ten = ten[:, :, :, 0:w0]
        return ten

    # Margin added around every tensor before warping so that content moved
    # outside the frame is not clipped.
    W = 256
    H = 256
    rep_pad = torch.nn.ReplicationPad2d((W, W, H, H))
    zero_pad = torch.nn.ZeroPad2d((W, W, H, H))

    def splat(ten, flow):
        """Forward-warp ``ten`` along ``flow`` with softsplat 'average'."""
        return softsplat.FunctionSoftsplat(
            tenInput=ten, tenFlow=flow, tenMetric=None, strType='average')

    def normalise_weights(weights):
        """Normalise stacked per-frame weights over the frame axis (dim 0)."""
        if self.args.gumbel > 0:
            return gumbel_softmax(weights, hard=True, dim=0)
        return torch.softmax(weights, 0)

    n_frames = len(frames)

    if self.args.noDL_CNNAggregation > 0:
        # No learnable parameters except the final aggregation network:
        # warp the raw frames, their coverage masks and their flows.
        tenOnes = zero_pad(torch.ones_like(frames[0])[:, 0:1, :, :]).detach()
        F_kprime_to_k_pad = zero_pad(F_kprime_to_k)
        tenWarpedFeat = []
        tenWarpedMask = []
        tenWarpedFlow = []
        for idx, feat in enumerate(frames):
            ref_frame_flow = rep_pad(F_n_to_k_s[idx])
            # First stage: forward-splat onto k.
            warped_first = splat(rep_pad(feat), ref_frame_flow)
            mask_first = splat(tenOnes, ref_frame_flow)
            flow_first = splat(ref_frame_flow, ref_frame_flow)
            # Second stage: backward-warp to k', then drop the margin.
            tenWarpedFeat.append(backwarp(
                tenInput=warped_first, tenFlow=F_kprime_to_k_pad)[:, :, H:-H, W:-W])
            tenWarpedMask.append(backwarp(
                tenInput=mask_first, tenFlow=F_kprime_to_k_pad)[:, :, H:-H, W:-W])
            tenWarpedFlow.append(backwarp(
                tenInput=flow_first, tenFlow=F_kprime_to_k_pad)[:, :, H:-H, W:-W])

        ref = n_frames // 2
        color_tensor = torch.stack(tenWarpedFeat, 0)
        weight_tensor = normalise_weights(torch.stack([
            self.aggregationWeightingNetwork(torch.cat([
                tenWarpedFeat[i], tenWarpedMask[i], tenWarpedFlow[i],
                tenWarpedFeat[ref], tenWarpedMask[ref],
                torch.abs(tenWarpedFeat[i] - tenWarpedFeat[ref])
            ], 1)) for i in range(len(tenWarpedFeat))
        ], 0))
        total_mask = torch.sum(torch.stack(tenWarpedMask, 0), 0)
        fused = torch.sum(color_tensor * weight_tensor, dim=0)
        # Zero out pixels that no frame covers.
        fused = fused * torch.where(
            total_mask > 0, torch.ones_like(total_mask),
            torch.zeros_like(total_mask))
        return crop_to_input(fused)

    features = [self.spatialFeatureNetwork(frame) for frame in frames]

    tenOnes = zero_pad(torch.ones_like(features[0])[:, 0:1, :, :]).detach()
    if self.args.FOV_expansion > 0:
        F_kprime_to_k_pad = rep_pad(F_kprime_to_k)
    else:
        F_kprime_to_k_pad = zero_pad(F_kprime_to_k)

    def warp_to_target(ten):
        """Second stage: move a tensor aligned with frame k into view k'."""
        if self.args.bundle_forward_flow > 0:
            return splat(ten, F_kprime_to_k_pad)
        return backwarp(tenInput=ten, tenFlow=F_kprime_to_k_pad)

    def to_canvas(ten):
        """Drop the warp margin unless FOV expansion keeps the full canvas."""
        if self.args.FOV_expansion <= 0:
            return ten[:, :, H:-H, W:-W]
        return ten

    tenWarpedFeat = []
    tenWarpedMask = []
    for idx, feat in enumerate(features):
        if self.args.all_backward > 0:
            ref_frame_flow = rep_pad(F_k_to_n_s[idx])
            warped_first = backwarp(tenInput=rep_pad(feat), tenFlow=ref_frame_flow)
            mask_first = backwarp(tenInput=tenOnes, tenFlow=ref_frame_flow)
        else:
            ref_frame_flow = rep_pad(F_n_to_k_s[idx])
            warped_first = splat(rep_pad(feat), ref_frame_flow)
            mask_first = splat(tenOnes, ref_frame_flow)
        tenWarpedFeat.append(to_canvas(warp_to_target(warped_first)))
        tenWarpedMask.append(to_canvas(warp_to_target(mask_first)))

    # Reference frame: highest temporal weight after frame selection,
    # otherwise the centre of the window.
    if self.args.inference_with_frame_selection > 0:
        middle_idx = int(np.argmax(np.array(gaussian_filter)))
    else:
        middle_idx = n_frames // 2

    # Pooling. NOTE: an unknown pooling_type leaves the pooled feature unset.
    if self.args.pooling_with_mask <= 0:
        global_average_pooled_feature = torch.stack(tenWarpedFeat, -1).mean(-1)
    elif self.args.pooling_with_center_bias <= 0:
        global_summed_feature = torch.stack(
            [f * m for f, m in zip(tenWarpedFeat, tenWarpedMask)], -1).sum(-1)
        global_summed_weights = torch.stack(tenWarpedMask, -1).sum(-1)
        global_average_pooled_feature = global_summed_feature / torch.clamp(
            global_summed_weights, min=1e-6)
    elif self.args.pooling_type in ('gaussian', 'max'):
        color_tensor = torch.stack(tenWarpedFeat, 0)
        weight_tensor = torch.stack(
            [tenWarpedMask[i] * gaussian_filter[i, 0] for i in range(n_frames)], 0)
        if self.args.pooling_type == 'gaussian':
            global_average_pooled_feature = torch.sum(
                color_tensor * weight_tensor, dim=0) / torch.clamp(
                    torch.sum(weight_tensor, dim=0), min=1e-6)
        else:
            # Per pixel, take the candidate with the largest weight; the
            # winning index is repeated across frames and channels (dim 2).
            reps = [1] * color_tensor.dim()
            reps[0] = color_tensor.size(0)
            reps[2] = color_tensor.size(2)
            index = weight_tensor.max(0)[1].unsqueeze(0).repeat(tuple(reps))
            global_average_pooled_feature = torch.gather(
                color_tensor, 0, index).select(0, 0)
    elif self.args.pooling_type == 'CNN':
        color_tensor = torch.stack(tenWarpedFeat, 0)
        weight_tensor = normalise_weights(torch.stack([
            self.aggregationWeightingNetwork(torch.cat([
                tenWarpedFeat[i], tenWarpedMask[i],
                tenWarpedFeat[middle_idx], tenWarpedMask[middle_idx]
            ], 1)) for i in range(n_frames)
        ], 0))
        global_average_pooled_feature = torch.sum(color_tensor * weight_tensor, dim=0)
    elif self.args.pooling_type == 'CNN_flowError':
        def length_sq(x):
            return torch.sum(x**2, dim=1, keepdim=True)

        # Forward/backward flow-consistency error, carried into view k'.
        flow_errors = []
        for i in range(n_frames):
            flow_fw_warped = backwarp(F_n_to_k_s[i], F_k_to_n_s[i])
            error = length_sq(F_k_to_n_s[i] + flow_fw_warped)
            if self.args.bundle_forward_flow > 0:
                error = splat(error, F_kprime_to_k)
            else:
                error = backwarp(tenInput=error, tenFlow=F_kprime_to_k)
            if self.args.FOV_expansion > 0:
                error = rep_pad(error)
            flow_errors.append(error)

        color_tensor = torch.stack(tenWarpedFeat, 0)
        weight_tensor = normalise_weights(torch.stack([
            self.aggregationWeightingNetwork(torch.cat([
                tenWarpedFeat[i], tenWarpedMask[i], flow_errors[i],
                tenWarpedFeat[middle_idx], tenWarpedMask[middle_idx]
            ], 1)) for i in range(n_frames)
        ], 0))
        global_average_pooled_feature = torch.sum(color_tensor * weight_tensor, dim=0)

    # Decoder.
    if self.args.no_pooling <= 0 and self.args.single_decoder > 0:
        I_pred, _ = self.refinementNetwork(global_average_pooled_feature)
        return crop_to_input(I_pred)

    per_frame_only = self.args.no_pooling > 0 and self.args.single_decoder <= 0
    I_preds = []
    Cs = []
    for i in range(n_frames):
        if per_frame_only:
            decoder_input = torch.cat([tenWarpedFeat[i], tenWarpedMask[i]], 1)
        elif self.args.decoder_with_mask <= 0:
            decoder_input = torch.cat(
                [tenWarpedFeat[i], global_average_pooled_feature], 1)
        else:
            decoder_input = torch.cat([
                tenWarpedFeat[i], global_average_pooled_feature, tenWarpedMask[i]
            ], 1)
        frame_pred, C = self.refinementNetwork(decoder_input)
        I_preds.append(frame_pred)
        Cs.append(C)

    # Residual detail transfer: re-add what the shared decoder fails to
    # reconstruct from a frame's own features, warped like the features.
    if self.args.center_residual_detail_transfer > 0:
        ones = torch.ones_like(frames[0][:, 0:1, :, :])
        center_idx = n_frames // 2
        reconstructed_frame, _ = self.refinementNetwork(torch.cat(
            [features[center_idx], features[center_idx], ones], 1))
        delta_frame = frames[center_idx] - reconstructed_frame
        warped_delta_frame = backwarp(delta_frame, F_kprime_to_k)
        I_preds[center_idx] = (
            I_preds[center_idx] + warped_delta_frame * tenWarpedMask[center_idx])
    elif self.args.residual_detail_transfer > 0:
        ones = torch.ones_like(frames[0][:, 0:1, :, :])
        for i in range(n_frames):
            reconstructed_frame, _ = self.refinementNetwork(
                torch.cat([features[i], features[i], ones], 1))
            delta_frame = frames[i] - reconstructed_frame
            if self.args.all_backward > 0:
                # Uses frame i's own backward flow (was the stale `idx`).
                warped_first = backwarp(tenInput=rep_pad(delta_frame),
                                        tenFlow=rep_pad(F_k_to_n_s[i]))
            else:
                warped_first = splat(rep_pad(delta_frame), rep_pad(F_n_to_k_s[i]))
            tenWarpedResidual = to_canvas(warp_to_target(warped_first))
            I_preds[i] = I_preds[i] + tenWarpedResidual * tenWarpedMask[i]

    # Blend the per-frame predictions with a softmax over confidences.
    if not self.args.softargmax_with_mask:
        scores = Cs
    else:
        scores = [C * m for C, m in zip(Cs, tenWarpedMask)]
    softmax_conf = torch.softmax(torch.stack(scores, -1), -1)
    I_pred = I_preds[0] * softmax_conf[..., 0]
    for i in range(1, n_frames):
        I_pred = I_pred + I_preds[i] * softmax_conf[..., i]

    if self.args.seamless > 0:
        # Multi-band (5-level Laplacian pyramid) blend: the warped reference
        # frame weighted by its coverage mask, the prediction by 1 - mask.
        tenWarpedKey = to_canvas(warp_to_target(
            splat(rep_pad(frames[middle_idx]), rep_pad(F_n_to_k_s[middle_idx]))))

        def to_hwc(ten):
            # Batch element 0, channel order reversed, (C, H, W) -> (H, W, C).
            return np.transpose(
                np.clip(ten.detach().cpu().numpy(), 0.0, 1.0)[0, ::-1], (1, 2, 0))

        npWarpedKey = to_hwc(tenWarpedKey)
        np_I_pred = to_hwc(I_pred)
        key_mask = tenWarpedMask[middle_idx].detach().cpu().numpy()
        np_mask_A = np.clip(key_mask, 0.0, 1.0)[0, 0, :, :]
        np_mask_B = np.clip(1.0 - key_mask, 0.0, 1.0)[0, 0, :, :]

        def GaussianPyramid(img, leveln):
            GP = [img]
            for i in range(leveln - 1):
                GP.append(cv2.pyrDown(GP[i]))
            return GP

        def LaplacianPyramid(img, leveln):
            LP = []
            for _ in range(leveln - 1):
                next_img = cv2.pyrDown(img)
                LP.append(img - cv2.pyrUp(next_img, img.shape[1::-1]))
                img = next_img
            LP.append(img)
            return LP

        def blend_pyramid(LPA, LPB, MA, MB):
            blended = []
            for la, lb, ma, mb in zip(LPA, LPB, MA, MB):
                if ma.ndim < 3 and mb.ndim < 3:
                    # 2-D masks are broadcast over the colour channels.
                    blended.append(la * np.expand_dims(ma, -1) + lb * np.expand_dims(mb, -1))
                else:
                    blended.append(la * ma + lb * mb)
            return blended

        def reconstruct(LS):
            img = LS[-1]
            for lev_img in LS[-2::-1]:
                img = cv2.pyrUp(img, lev_img.shape[1::-1])
                img += lev_img
            return img

        # Replicate-pad (split evenly between both sides) to a multiple of 32
        # so every pyramid level halves exactly.
        minimum_multi_band = 32
        rows, cols = np_mask_A.shape[:2]
        pad_t = ((-rows) % minimum_multi_band) // 2
        pad_b = (-rows) % minimum_multi_band - pad_t
        pad_l = ((-cols) % minimum_multi_band) // 2
        pad_r = (-cols) % minimum_multi_band - pad_l

        def pad_border(img):
            return cv2.copyMakeBorder(np.ascontiguousarray(img), pad_t, pad_b,
                                      pad_l, pad_r, cv2.BORDER_REPLICATE)

        np_mask_A = pad_border(np_mask_A)
        np_mask_B = pad_border(np_mask_B)
        npWarpedKey = pad_border(npWarpedKey)
        np_I_pred = pad_border(np_I_pred)

        MA = GaussianPyramid(np_mask_A, 5)
        MB = GaussianPyramid(np_mask_B, 5)
        LPA = LaplacianPyramid(npWarpedKey, 5)
        LPB = LaplacianPyramid(np_I_pred, 5)
        frame_A = np.clip(reconstruct(blend_pyramid(LPA, LPB, MA, MB)), 0.0, 1.0)
        frame_A = frame_A[pad_t:frame_A.shape[0] - pad_b,
                          pad_l:frame_A.shape[1] - pad_r]
        I_pred = torch.from_numpy(np.ascontiguousarray(np.expand_dims(
            np.transpose(frame_A[:, :, ::-1], (2, 0, 1)), 0).astype(np.float32)
        )).to(frames[0].device)

    return crop_to_input(I_pred)
cv2.imread(filename='./images/second.png', flags=-1).transpose( 2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() tenFlow = torch.FloatTensor( numpy.ascontiguousarray( read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda() tenMetric = torch.nn.functional.l1_loss(input=tenFirst, target=backwarp(tenInput=tenSecond, tenFlow=tenFlow), reduction='none').mean(1, True) for intTime, fltTime in enumerate(numpy.linspace(0.0, 1.0, 11).tolist()): tenSummation = softsplat.FunctionSoftsplat(tenInput=tenFirst, tenFlow=tenFlow * fltTime, tenMetric=None, strType='summation') tenAverage = softsplat.FunctionSoftsplat(tenInput=tenFirst, tenFlow=tenFlow * fltTime, tenMetric=None, strType='average') tenLinear = softsplat.FunctionSoftsplat( tenInput=tenFirst, tenFlow=tenFlow * fltTime, tenMetric=(0.3 - tenMetric).clamp(0.0000001, 1.0), strType='linear' ) # finding a good linearly metric is difficult, and it is not invariant to translations tenSoftmax = softsplat.FunctionSoftsplat( tenInput=tenFirst, tenFlow=tenFlow * fltTime, tenMetric=-20.0 * tenMetric,
def splat_rgb_img(ret, ratio, R_w2t, t_w2t, j, H, W, focal, fwd_flow):
    """Splat sample ``j`` of the dynamic and static RGBA outputs into a target view.

    Dynamic points are moved a fraction ``ratio`` along their scene flow
    (``raw_sf_ref2post`` when ``fwd_flow`` is true, else ``raw_sf_ref2prev``),
    transformed with ``(R_w2t, t_w2t)`` and projected. Static points are only
    transformed and projected. Both RGBA sets are then forward-splatted
    (softsplat, ``'average'``) along the resulting 2-D pixel displacement.

    Args:
        ret: dict of tensors indexed here as ``ret[key][:, :, j, ...]``;
            presumably laid out ``(H, W, n_samples, ...)`` -- TODO confirm.
        ratio: interpolation factor along the scene flow.
        R_w2t, t_w2t: rotation / translation passed to ``se3_transform_points``.
        j: sample index along the third axis of the ``ret`` tensors.
        H, W, focal: image height, width and focal length.
        fwd_flow: selects which scene-flow entry of ``ret`` to use.

    Returns:
        ``(alpha_dy, rgb_dy, alpha_rig, rgb_rig)``: channel 3 and channels 0-2
        of the splatted dynamic and static RGBA, batch element 0.
    """
    import softsplat

    def to_nchw_cuda(hwc):
        # (d0, d1, C) -> (1, C, d0, d1), contiguous, on the GPU.
        return hwc.permute(2, 0, 1).unsqueeze(0).contiguous().cuda()

    def splat(rgba, flow):
        return softsplat.FunctionSoftsplat(tenInput=rgba, tenFlow=flow,
                                           tenMetric=None, strType='average')

    R = R_w2t.unsqueeze(0).unsqueeze(0)
    t = t_w2t.unsqueeze(0).unsqueeze(0)

    # Dynamic model: interpolate along the scene flow, then project.
    pts_ref = ret['pts_ref'][:, :, j, :3]
    pts_ref_e_G = NDC2Euclidean(pts_ref, H, W, focal)
    sf_key = 'raw_sf_ref2post' if fwd_flow else 'raw_sf_ref2prev'
    pts_post_e_G = NDC2Euclidean(pts_ref + ret[sf_key][:, :, j, :], H, W, focal)
    pts_mid_e_G = (pts_post_e_G - pts_ref_e_G) * ratio + pts_ref_e_G
    pts_2d_mid = perspective_projection(
        se3_transform_points(pts_mid_e_G, R, t), H, W, focal)

    # (H, W, 2) grid of integer pixel coordinates (x, y). Built on the
    # projections' device so the subtraction below works regardless of the
    # default tensor type; broadcasting avoids torch.meshgrid's indexing warning.
    device = pts_2d_mid.device
    xs = torch.linspace(0, W - 1, W).to(device).view(1, W).expand(H, W)
    ys = torch.linspace(0, H - 1, H).to(device).view(H, 1).expand(H, W)
    pts_2d_original = torch.stack([xs, ys], -1)

    raw_rgba_dy = torch.cat(
        [ret['raw_rgb'], ret['raw_alpha'].unsqueeze(-1)], dim=-1)[:, :, j, :]
    flow_2d = to_nchw_cuda(pts_2d_mid - pts_2d_original)
    splat_raw_rgba_dy = splat(to_nchw_cuda(raw_rgba_dy), flow_2d)

    # Static model: rigid transform + projection only.
    pts_2d_rig = perspective_projection(
        se3_transform_points(pts_ref_e_G, R, t), H, W, focal)
    flow_2d_rig = to_nchw_cuda(pts_2d_rig - pts_2d_original)
    raw_rgba_rig = torch.cat(
        [ret['raw_rgb_rigid'], ret['raw_alpha_rigid'].unsqueeze(-1)],
        dim=-1)[:, :, j, :]
    splat_raw_rgba_rig = splat(to_nchw_cuda(raw_rgba_rig), flow_2d_rig)

    return (splat_raw_rgba_dy[0, 3:4, :, :], splat_raw_rgba_dy[0, 0:3, :, :],
            splat_raw_rgba_rig[0, 3:4, :, :], splat_raw_rgba_rig[0, 0:3, :, :])
def img2tex_forwardwarp(img_tensor,
                        iuv_tensor,
                        class_num=25,
                        tex_patch_size=256,
                        img_space_size=512,
                        inpaint_func=None,
                        warp_mask=True,
                        use_prob=False):
    """Forward-warp an image into per-part texture patches using an IUV map.

    For every part channel ``1 .. class_num - 1`` of ``iuv_tensor``, pixels
    whose part value exceeds a threshold (``1 / class_num`` when ``use_prob``,
    else ``0``) are splatted (softsplat, ``'average'``) from image space to
    their UV location scaled to ``tex_patch_size``. The image is weighted by
    the part channel before splatting. Channel 0 is skipped (presumably
    background -- confirm).

    Args:
        img_tensor: ``(b, c, h, w)`` image.
        iuv_tensor: ``class_num`` part channels followed by U and V, which
            must lie in ``[0, 1]``.
        class_num: number of part channels in ``iuv_tensor``.
        tex_patch_size: side length of each square texture patch.
        img_space_size: int or ``(rows, cols)`` passed to ``make_meshgrid``;
            presumably the spatial size of the inputs -- confirm.
        inpaint_func: optional callable applied to each cropped patch.
        warp_mask: also splat the part channel and return it as a mask.
        use_prob: threshold the part channel at ``1 / class_num``.

    Returns:
        ``TextureIm`` of shape ``(b, P, c, T, T)`` (float32), plus
        ``TextureMask`` ``(b, P, 1, T, T)`` when ``warp_mask``; ``T`` is
        ``tex_patch_size`` and ``P = max(24, class_num - 1)``.
    """
    b, c, h, w = img_tensor.shape
    assert iuv_tensor.shape[1] == class_num + 2, "input I channel should be onehot encoded"
    assert iuv_tensor[:, -2:].min() >= 0 and iuv_tensor[:, -2:].max(
    ) <= 1, "input UV channels should be normalized into [0,1]"

    # Originally hard-coded to 24 parts; keep that as the minimum so existing
    # output shapes are unchanged, but grow when class_num needs more slots.
    num_parts = max(24, class_num - 1)
    device = iuv_tensor.device
    TextureIm = torch.zeros([b, num_parts, c, tex_patch_size, tex_patch_size],
                            dtype=torch.float64, device=device)
    if warp_mask:
        TextureMask = torch.zeros([b, num_parts, 1, tex_patch_size, tex_patch_size],
                                  dtype=torch.float64, device=device)

    # The pixel grid does not depend on the part; build it once.
    if isinstance(img_space_size, tuple):
        coords = make_meshgrid(b, img_space_size[0], img_space_size[1], norm=False)
    else:
        coords = make_meshgrid(b, img_space_size, img_space_size, norm=False)
    coords = coords.to(device)
    threshold = 1. / class_num if use_prob else 0

    # Use range(1, 23) instead to ignore the face parts.
    for part_ind in range(1, class_num):
        prob = iuv_tensor[:, part_ind:part_ind + 1]

        # Flow from each pixel to its (masked) UV location in the patch.
        part_mask = (prob > threshold).float()
        uv = part_mask * iuv_tensor[:, -2:] * (tex_patch_size - 1)
        flow = (uv - coords).float()

        tex_patch = softsplat.FunctionSoftsplat(tenInput=(img_tensor * prob).float(),
                                                tenFlow=flow,
                                                tenMetric=None,
                                                strType='average')
        tex_patch = tex_patch[:, :, :tex_patch_size, :tex_patch_size]
        if inpaint_func is not None:
            tex_patch = inpaint_func(tex_patch)
        TextureIm[:, part_ind - 1] += tex_patch

        if warp_mask:
            # Splat the part channel itself (always as float32).
            tex_prob = softsplat.FunctionSoftsplat(tenInput=prob.float(),
                                                   tenFlow=flow,
                                                   tenMetric=None,
                                                   strType='average')
            tex_prob = tex_prob[:, :, :tex_patch_size, :tex_patch_size]
            if inpaint_func is not None:
                tex_prob = inpaint_func(tex_prob)
            TextureMask[:, part_ind - 1] += tex_prob

    if warp_mask:
        return TextureIm.float(), TextureMask.float()
    return TextureIm.float()
HHH = 256 tenOnes = torch.ones_like(input_frames[0])[:, 0:1, :, :] tenOnes = torch.nn.ZeroPad2d( (WWW, WWW, HHH, HHH))(tenOnes).detach() F_kprime_to_k_pad = torch.nn.ReplicationPad2d( (WWW, WWW, HHH, HHH))(F_kprime_to_k) tenWarpedFeat = [] tenWarpedMask = [] for iii, feat in enumerate(input_frames): """padding for forward warping""" ref_frame_flow = torch.nn.ReplicationPad2d( (WWW, WWW, HHH, HHH))(forward_flows[iii]) """first forward warping""" tenMaskFirst = softsplat.FunctionSoftsplat( tenInput=tenOnes, tenFlow=ref_frame_flow, tenMetric=None, strType='average') """second backward warping""" tenMaskSecond = backwarp(tenInput=tenMaskFirst, tenFlow=F_kprime_to_k_pad) """back to original resolution""" tenMask = tenMaskSecond tenWarpedMask.append(tenMask) weight_tensor = torch.stack(tenWarpedMask, 0) output_mask = torch.sum(weight_tensor, dim=0) output_mask = torch.clamp(output_mask, max=1.0) # imwrite(output_mask, str(idx-GAUSSIAN_FILTER_KSIZE//2).zfill(5)+'_mask.png', range=(0, 1)) large_mask_chain.append(output_mask.detach().cpu())
pic_B.transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() tenFlow = torch.FloatTensor( numpy.ascontiguousarray( read_flo(arg_Flow).transpose(2, 0, 1)[None, :, :, :])).cuda() tenMetric = torch.nn.functional.l1_loss( input=tenFirst, target=backwarp(tenInput=tenSecond, tenFlow=tenFlow), reduction='none').mean(1, True) img_array.append(pic_A) for t in time: tenSoftmax = softsplat.FunctionSoftsplat(tenInput=tenFirst, tenFlow=tenFlow * t, tenMetric=-20.0 * tenMetric, strType='softmax') img = tenSoftmax[0, :, :, :].cpu().numpy().transpose(1, 2, 0) img = (img * 255).astype(numpy.uint8) img_array.append(img) # end print('-- processing %d / %d' % (filecounts, filecounts)) print('Creating video...') out = cv2.VideoWriter('project.mp4', cv2.VideoWriter_fourcc(*'mp4v'), arg_FPS * FPS, (arg_width, arg_height)) print('Total images : ' + str(len(img_array))) for i in range(len(img_array)):