def forward_3_frames(self, x0_pyramid, x1_pyramid, x2_pyramid):
    # outputs
    flows = []

    # init
    b_size, _, h_x1, w_x1 = x1_pyramid[0].size()
    init_dtype = x1_pyramid[0].dtype
    init_device = x1_pyramid[0].device
    flow = torch.zeros(b_size, 4, h_x1, w_x1, dtype=init_dtype,
                       device=init_device).float()

    for l, (x0, x1, x2) in enumerate(zip(x0_pyramid, x1_pyramid, x2_pyramid)):
        # warping
        if l == 0:
            x0_warp = x0
            x2_warp = x2
        else:
            flow = F.interpolate(flow * 2, scale_factor=2, mode='bilinear',
                                 align_corners=True)
            x0_warp = flow_warp(x0, flow[:, :2])
            x2_warp = flow_warp(x2, flow[:, 2:])

        # correlation
        corr_10, corr_12 = self.corr(x1, x0_warp), self.corr(x1, x2_warp)
        corr_relu_10 = self.leakyRELU(corr_10)
        corr_relu_12 = self.leakyRELU(corr_12)

        # concat and estimate flow
        x1_1by1 = self.conv_1x1[l](x1)
        feat_10 = [x1_1by1, corr_relu_10, corr_relu_12, flow[:, :2], -flow[:, 2:]]
        feat_12 = [x1_1by1, corr_relu_12, corr_relu_10, flow[:, 2:], -flow[:, :2]]
        x_intm_10, flow_res_10 = self.flow_estimators(torch.cat(feat_10, dim=1))
        x_intm_12, flow_res_12 = self.flow_estimators(torch.cat(feat_12, dim=1))
        flow_res = torch.cat([flow_res_10, flow_res_12], dim=1)
        flow = flow + flow_res

        feat_10 = [x_intm_10, x_intm_12, flow[:, :2], -flow[:, 2:]]
        feat_12 = [x_intm_12, x_intm_10, flow[:, 2:], -flow[:, :2]]
        flow_res_10 = self.context_networks(torch.cat(feat_10, dim=1))
        flow_res_12 = self.context_networks(torch.cat(feat_12, dim=1))
        flow_res = torch.cat([flow_res_10, flow_res_12], dim=1)
        flow = flow + flow_res

        flows.append(flow)

        if l == self.output_level:
            break

    if self.upsample:
        flows = [F.interpolate(flow * 4, scale_factor=4, mode='bilinear',
                               align_corners=True) for flow in flows]
    flows_10 = [flo[:, :2] for flo in flows[::-1]]
    flows_12 = [flo[:, 2:] for flo in flows[::-1]]
    return flows_10, flows_12
def _forward(self, x1_pyramid, x2_pyramid, neg=False):
    flows = []
    for i, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
        if i == 0:
            corr = self.corr(x1, x2)
            feat, flow = self.flow_estimators[i](corr)
            if neg:
                flow = -F.relu(-flow)
            else:
                flow = F.relu(flow)
        else:
            # predict the normalized disparity to stay consistent with MonoDepth,
            # so its hyper-parameters can be reused
            up_flow = F.interpolate(flow, scale_factor=2, mode='bilinear',
                                    align_corners=True)
            zeros = torch.zeros_like(up_flow)
            x2_warp = flow_warp(x2, torch.cat([up_flow, zeros], dim=1))
            corr = self.corr(x1, x2_warp)
            F.leaky_relu_(corr)
            feat, flow = self.flow_estimators[i](
                torch.cat([corr, x1, up_flow], dim=1))
            flow = flow + up_flow
            if neg:
                flow = -F.relu(-flow)
            else:
                flow = F.relu(flow)

        if self.context_networks[i]:
            flow_fine = self.context_networks[i](torch.cat([flow, feat], dim=1))
            flow = flow + flow_fine
            if neg:
                flow = -F.relu(-flow)
            else:
                flow = F.relu(flow)

        if neg:
            flows.append(-flow)
        else:
            flows.append(flow)

        if len(flows) == self.n_out:
            break

    flows = [
        F.interpolate(flow * 4, scale_factor=4, mode='bilinear',
                      align_corners=True) for flow in flows
    ]
    return flows[::-1]
def get_occu_mask_bidirection(flow12, flow21, scale=0.1, bias=0.5):
    flow21_warped = flow_warp(flow21, flow12, pad='zeros')
    flow12_diff = flow12 + flow21_warped
    mag = (flow12 * flow12).sum(1, keepdim=True) + \
          (flow21_warped * flow21_warped).sum(1, keepdim=True)
    occ_thresh = scale * mag + bias
    occ = (flow12_diff * flow12_diff).sum(1, keepdim=True) > occ_thresh
    return occ.float()
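# Minimal usage sketch (illustration only, not part of the original code):
# the bidirectional check above flags pixels whose forward flow and warped
# backward flow do not cancel out, relative to the flow magnitude. Shapes are
# assumed to be [B, 2, H, W]; the random tensors and the helper name
# `_demo_occlusion_masks` are placeholders (assumes `import torch` as in the
# rest of this file).
def _demo_occlusion_masks():
    flow12 = torch.randn(1, 2, 64, 128)   # forward flow, frame 1 -> frame 2
    flow21 = torch.randn(1, 2, 64, 128)   # backward flow, frame 2 -> frame 1
    occu_mask1 = 1 - get_occu_mask_bidirection(flow12, flow21)  # 1 = likely visible in frame 2
    occu_mask2 = 1 - get_occu_mask_bidirection(flow21, flow12)  # 1 = likely visible in frame 1
    return occu_mask1, occu_mask2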
def forward_2_frames(self, x1_pyramid, x2_pyramid):
    # outputs
    flows = []

    # init
    b_size, _, h_x1, w_x1 = x1_pyramid[0].size()
    init_dtype = x1_pyramid[0].dtype
    init_device = x1_pyramid[0].device
    flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype,
                       device=init_device).float()

    for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
        # warping
        if l == 0:
            x2_warp = x2
        else:
            flow = F.interpolate(flow * 2, scale_factor=2, mode='bilinear',
                                 align_corners=True)
            x2_warp = flow_warp(x2, flow)

        # correlation
        out_corr = self.corr(x1, x2_warp)
        out_corr_relu = self.leakyRELU(out_corr)

        # concat and estimate flow
        x1_1by1 = self.conv_1x1[l](x1)
        x_intm, flow_res = self.flow_estimators(
            torch.cat([out_corr_relu, x1_1by1, flow], dim=1))
        flow = flow + flow_res

        flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))
        flow = flow + flow_fine

        flows.append(flow)

        # upsampling or post-processing
        if l == self.output_level:
            break

    if self.upsample:
        flows = [
            F.interpolate(flow * 4, scale_factor=4, mode='bilinear',
                          align_corners=True) for flow in flows
        ]
    return flows[::-1]
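# `flow_warp` is imported from the repo's warp utilities and is not shown in
# this file. The sketch below is an assumption of how such a warp is typically
# implemented with F.grid_sample: build a pixel grid, add the flow, normalize
# to [-1, 1], and bilinearly sample the source tensor. The name
# `flow_warp_sketch` is hypothetical (assumes `import torch` and
# `import torch.nn.functional as F`, and torch >= 1.10 for `indexing='ij'`).
def flow_warp_sketch(x, flow, pad='border', mode='bilinear'):
    """Warp x [B, C, H, W] by flow [B, 2, H, W] (flow[:, 0] = dx, flow[:, 1] = dy)."""
    B, _, H, W = x.size()
    # base sampling grid in pixel coordinates
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, dtype=x.dtype, device=x.device),
        torch.arange(W, dtype=x.dtype, device=x.device),
        indexing='ij')
    grid = torch.stack([grid_x, grid_y], dim=0).unsqueeze(0)  # [1, 2, H, W]
    # displaced grid, normalized to [-1, 1] for grid_sample
    vgrid = grid + flow
    vgrid_x = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
    vgrid = torch.stack([vgrid_x, vgrid_y], dim=3)            # [B, H, W, 2]
    return F.grid_sample(x, vgrid, mode=mode, padding_mode=pad,
                         align_corners=True)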
def _forward(self, x1_pyramid, x2_pyramid):
    flows = []
    for i, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
        if i == 0:
            corr = self.corr(x1, x2)
            feat, flow = self.flow_estimators[i](corr)
        else:
            up_flow = F.interpolate(flow * 2, scale_factor=2, mode='bilinear',
                                    align_corners=True)
            x2_warp = flow_warp(x2, up_flow)
            corr = self.corr(x1, x2_warp)
            F.leaky_relu_(corr)
            flow_feat = [corr, x1, up_flow]
            feat, flow = self.flow_estimators[i](torch.cat(flow_feat, dim=1))
            flow = flow + up_flow

        if self.context_networks[i]:
            flow_fine = self.context_networks[i](torch.cat([flow, feat], dim=1))
            flow = flow + flow_fine

        flows.append(flow)
        if len(flows) == self.n_out:
            break

    flows = [
        F.interpolate(flow * 4, scale_factor=4, mode='bilinear',
                      align_corners=True) for flow in flows
    ]
    return flows[::-1]
def forward(self, pyramid_disp, fl_bl, pyramid_K, pyramid_K_inv, raw_W,
            pyramid_flow, images):
    """
    :param pyramid_disp: multi-scale disparities n * [B x h x w]
    :param fl_bl: focal length * baseline [B]
    :param pyramid_K: multi-scale intrinsics n * [B, 3, 3]
    :param pyramid_K_inv: multi-scale inverse intrinsics n * [B, 3, 3]
    :param raw_W: original image width [B]
    :param pyramid_flow: multi-scale forward/backward flows n * [B x 4 x h x w]
    :param images: image pairs B x 6 x H x W
    :return:
    """
    B = images.size(0)
    im1_origin = images[:, :3]
    im2_origin = images[:, 3:]

    pyramid_l_photomatric = []
    pyramid_l_smooth = []
    pyramid_l_consistancy = []
    pyramid_l_photomatric_rigid = []
    pyramid_rigid_mask = []

    s = 1.  # normalization factor for the smoothness term
    for i, (disp, flow, K, K_inv, md) in enumerate(
            zip(pyramid_disp, pyramid_flow, pyramid_K, pyramid_K_inv,
                self.cfg.pyramid_md)):
        # only the first valid_s scales contribute to the loss
        if i >= self.cfg.valid_s:
            break

        _, _, h, w = flow.size()
        if i == 0 and self.cfg.norm_smooth:
            s = min(h, w)

        disp = F.interpolate(disp.unsqueeze(1), (h, w), mode='bilinear',
                             align_corners=True).squeeze(1) * raw_W.reshape(-1, 1, 1)
        depth = fl_bl.reshape(-1, 1, 1) / disp.clamp(min=1e-3)  # [B, h, w]

        # use the largest-scale depth and flow to estimate the pose
        if i == 0:
            pose_mat, _, inlier_ratio = depth_flow2pose_pt(
                depth, flow[:, :2], K, K_inv, gs=16, th=2.,
                method=self.cfg.PnP_method)
        rigid_flow = depth_pose2flow_pt(depth, pose_mat, K, K_inv)

        # resize images to match the size of this layer
        im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
        im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')

        im1_recons, occu_mask1 = flow_warp(im2_scaled, flow[:, :2], flow[:, 2:])
        im1_recons_rigid = flow_warp(im2_scaled, rigid_flow)

        th_mask = EPE(flow[:, :2], rigid_flow) < self.cfg.mask_th / 2 ** i

        flow_e = F.pad(SSIM(im1_scaled, im1_recons, md=md),
                       [md] * 4).mean(1, keepdim=True)  # [B, 1, h, w]
        rigid_e = F.pad(SSIM(im1_scaled, im1_recons_rigid, md=md),
                        [md] * 4).mean(1, keepdim=True)
        dist_e = rigid_e - flow_e
        dist_e = gaussianblur_pt(dist_e, (11, 11), 5)
        delta = percentile_pt(dist_e, th=self.cfg.recons_p).reshape(-1, 1, 1, 1)
        rigid_mask = dist_e < delta  # [B, 1, h, w]
        rigid_mask = rigid_mask & th_mask

        # mask out regions where the depth is unreliable
        rigid_mask = rigid_mask & (depth.unsqueeze(1) < 80)

        # if the pose estimation failed, rigid_mask should be all false
        valid_poses = (inlier_ratio > 0.2).type_as(rigid_mask)
        rigid_mask = rigid_mask & valid_poses.reshape(-1, 1, 1, 1)
        rigid_mask = rigid_mask.float()

        # in occluded regions, rigid_mask may be either true or false
        if self.cfg.mask_with_occu:
            # the original tf implementation:
            rigid_mask = (rigid_mask + (occu_mask1 < 0.2).float()).clamp(0., 1.)

        if self.cfg.smooth_mask_by == 'th':
            sm_mask = 1 - (th_mask & (depth.unsqueeze(1) < 80)).float()
        else:
            sm_mask = 1 - rigid_mask  # same as the paper

        l_photomatric = self.loss_photomatric(im1_scaled, im1_recons, occu_mask1)
        l_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled, sm_mask)
        l_consistancy = self.loss_consistancy(flow[:, :2], rigid_flow.detach(),
                                              rigid_mask)  # occlusion mask?
        l_photomatric_rigid = self.loss_photomatric(im1_scaled, im1_recons_rigid,
                                                    rigid_mask)

        pyramid_l_photomatric.append(l_photomatric * self.cfg.w_scales[i])
        pyramid_l_smooth.append(l_smooth * self.cfg.w_sm_scales[i])
        pyramid_l_consistancy.append(l_consistancy * self.cfg.w_cons_scales[i])
        pyramid_l_photomatric_rigid.append(
            l_photomatric_rigid * self.cfg.w_rigid_scales[i])
        pyramid_rigid_mask.append(
            rigid_mask.mean() * B / (valid_poses.sum() + 1e-6))

    w_l_pohotometric = sum(pyramid_l_photomatric)
    w_l_pohotometric_rigid = sum(pyramid_l_photomatric_rigid)
    w_l_smooth = sum(pyramid_l_smooth)
    w_l_consistancy = sum(pyramid_l_consistancy)

    final_loss = w_l_pohotometric + \
        self.cfg.w_rigid_warp * w_l_pohotometric_rigid + \
        self.cfg.w_smooth * w_l_smooth + \
        self.cfg.w_cons * w_l_consistancy

    return final_loss, w_l_pohotometric, w_l_pohotometric_rigid, \
        1000 * w_l_smooth, w_l_consistancy, \
        sum(pyramid_rigid_mask) / len(pyramid_disp), \
        inlier_ratio.mean()
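# `depth_flow2pose_pt` and `depth_pose2flow_pt` are repo helpers not shown
# here. The sketch below illustrates the second step under standard pinhole
# assumptions: back-project each pixel with its depth, transform by the
# relative pose [R | t], re-project with K, and take the pixel displacement as
# the rigid flow. The name, the [B, 3, 4] pose shape, and the clamping constant
# are assumptions, not the repo's exact API (assumes `import torch`).
def rigid_flow_from_depth_pose(depth, pose_mat, K, K_inv):
    """depth: [B, H, W], pose_mat: [B, 3, 4], K / K_inv: [B, 3, 3] -> flow [B, 2, H, W]."""
    B, H, W = depth.shape
    device, dtype = depth.device, depth.dtype
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, device=device, dtype=dtype),
        torch.arange(W, device=device, dtype=dtype), indexing='ij')
    ones = torch.ones_like(grid_x)
    pix = torch.stack([grid_x, grid_y, ones], dim=0).reshape(1, 3, -1)  # [1, 3, HW]
    # back-project to camera coordinates and apply the relative pose
    cam = K_inv @ pix * depth.reshape(B, 1, -1)             # [B, 3, HW]
    cam = pose_mat[:, :, :3] @ cam + pose_mat[:, :, 3:]     # [B, 3, HW]
    # re-project and convert to a pixel displacement
    proj = K @ cam
    proj = proj[:, :2] / proj[:, 2:].clamp(min=1e-6)        # [B, 2, HW]
    flow = proj - pix[:, :2]
    return flow.reshape(B, 2, H, W)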
imgs = [imageio.imread(img).astype(np.float32) for img in args.img_list]
h, w = imgs[0].shape[:2]

flow_12 = ts.run(imgs)['flows_fw'][0]
flow_12 = resize_flow(flow_12, (h, w))
np_flow_12 = flow_12[0].detach().cpu().numpy().transpose([1, 2, 0])

vis_flow = flowpy.flow_to_rgb(np_flow_12)
cv2.imwrite(t0[0] + " " + t1[0] + ".png", vis_flow)

im1 = cv2.imread("/content/drive/MyDrive/data1/NATL_AN_2007-01-03.png")
im2 = cv2.imread("/content/drive/MyDrive/data1/NATL_AN_2007-01-04.png")

# warp the second image towards the first with the predicted flow
im2_t = torch.from_numpy(im2.astype(np.float32).transpose([2, 0, 1])).unsqueeze(0).to(flow_12.device)
warped = flow_warp(im2_t, flow_12, pad='border', mode='bilinear')
warped = warped[0].detach().cpu().numpy().transpose([1, 2, 0])
cv2.imwrite("warped" + t0[0] + " " + t1[0] + ".png", warped)


from math import log10, sqrt

def PSNR(original, compressed):
    mse = np.mean((original - compressed) ** 2)
    if mse == 0:
        # an MSE of zero means the two images are identical,
        # so PSNR is not meaningful
        return 100
    max_pixel = 255.0
    psnr = 20 * log10(max_pixel / sqrt(mse))
    return psnr

# 5. Compute the Structural Similarity Index (SSIM) between the two
# images, ensuring that the difference image is returned
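# The SSIM step described above is not implemented in this snippet. A minimal
# sketch using scikit-image (an assumption; the repo's own SSIM() helper used
# in the losses may compute it differently). `channel_axis` requires
# scikit-image >= 0.19; the function name `ssim_with_diff` is hypothetical.
from skimage.metrics import structural_similarity

def ssim_with_diff(original, compressed):
    """Return the SSIM score and the full difference image for two uint8 BGR images."""
    score, diff = structural_similarity(original, compressed,
                                        channel_axis=2, full=True)
    # `diff` is float in [0, 1]; scale to [0, 255] before saving with cv2.imwrite
    diff = (diff * 255).astype("uint8")
    return score, diff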
def forward(self, output, target):
    """
    :param output: multi-scale forward/backward flows n * [B x 4 x h x w]
    :param target: image pairs N x 6 x H x W
    :return:
    """
    pyramid_flows = output
    im1_origin = target[:, :3]
    im2_origin = target[:, 3:]

    pyramid_smooth_losses = []
    pyramid_warp_losses = []
    self.pyramid_occu_mask1 = []
    self.pyramid_occu_mask2 = []

    s = 1.
    for i, flow in enumerate(pyramid_flows):
        if self.cfg.w_scales[i] == 0:
            pyramid_warp_losses.append(0)
            pyramid_smooth_losses.append(0)
            continue

        b, _, h, w = flow.size()

        # resize images to match the size of this layer
        im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
        im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')

        im1_recons = flow_warp(im2_scaled, flow[:, :2], pad=self.cfg.warp_pad)
        im2_recons = flow_warp(im1_scaled, flow[:, 2:], pad=self.cfg.warp_pad)

        if i == 0:
            if self.cfg.occ_from_back:
                occu_mask1 = 1 - get_occu_mask_backward(flow[:, 2:], th=0.2)
                occu_mask2 = 1 - get_occu_mask_backward(flow[:, :2], th=0.2)
            else:
                occu_mask1 = 1 - get_occu_mask_bidirection(flow[:, :2], flow[:, 2:])
                occu_mask2 = 1 - get_occu_mask_bidirection(flow[:, 2:], flow[:, :2])
        else:
            occu_mask1 = F.interpolate(self.pyramid_occu_mask1[0], (h, w),
                                       mode='nearest')
            occu_mask2 = F.interpolate(self.pyramid_occu_mask2[0], (h, w),
                                       mode='nearest')

        self.pyramid_occu_mask1.append(occu_mask1)
        self.pyramid_occu_mask2.append(occu_mask2)

        loss_warp = self.loss_photomatric(im1_scaled, im1_recons, occu_mask1)

        if i == 0:
            s = min(h, w)

        loss_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled)

        if self.cfg.with_bk:
            loss_warp += self.loss_photomatric(im2_scaled, im2_recons,
                                               occu_mask2)
            loss_smooth += self.loss_smooth(flow[:, 2:] / s, im2_scaled)

            loss_warp /= 2.
            loss_smooth /= 2.

        pyramid_warp_losses.append(loss_warp)
        pyramid_smooth_losses.append(loss_smooth)

    pyramid_warp_losses = [
        l * w for l, w in zip(pyramid_warp_losses, self.cfg.w_scales)
    ]
    pyramid_smooth_losses = [
        l * w for l, w in zip(pyramid_smooth_losses, self.cfg.w_sm_scales)
    ]

    warp_loss = sum(pyramid_warp_losses)
    smooth_loss = self.cfg.w_smooth * sum(pyramid_smooth_losses)
    total_loss = warp_loss + smooth_loss

    return total_loss, warp_loss, smooth_loss, pyramid_flows[0].abs().mean()
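# `loss_photomatric` is a class method not shown in this file. A minimal sketch
# of a masked photometric term (an assumption, not the repo's exact
# formulation): a robust (|x| + eps)^q penalty on the photometric difference,
# averaged over the pixels the soft occlusion mask marks as visible. The name
# and the eps/q values are illustrative only.
def photometric_loss_sketch(im_ref, im_recons, occu_mask, eps=0.01, q=0.4):
    diff = (im_ref - im_recons).abs().mean(dim=1, keepdim=True)  # [B, 1, h, w]
    robust = (diff + eps) ** q                                   # robust penalty
    return (robust * occu_mask).sum() / (occu_mask.sum() + 1e-6)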
h, w = imgs[0].shape[:2]

res_dict = ts.run(imgs)
flow_12 = res_dict['flows_fw'][0]
flow_21 = res_dict['flows_bw'][0]
flow_12 = resize_flow(flow_12, (h, w))  # [1, 2, H, W]
flow_21 = resize_flow(flow_21, (h, w))  # [1, 2, H, W]

occu_mask1 = 1 - get_occu_mask_bidirection(flow_12, flow_21)  # [1, 1, H, W]
occu_mask2 = 1 - get_occu_mask_bidirection(flow_21, flow_12)
back_occu_mask1 = get_occu_mask_backward(flow_21)
back_occu_mask2 = get_occu_mask_backward(flow_12)

warped_image_12 = flow_warp(torch.from_numpy(
    np.transpose(imgs[1], [2, 0, 1])).unsqueeze(0).cuda(), flow_12, pad='border')
warped_image_21 = flow_warp(torch.from_numpy(
    np.transpose(imgs[0], [2, 0, 1])).unsqueeze(0).cuda(), flow_21, pad='border')

np_warped_image12 = warped_image_12[0].detach().cpu().numpy().transpose([1, 2, 0])
np_warped_image21 = warped_image_21[0].detach().cpu().numpy().transpose([1, 2, 0])

np_flow_12 = flow_12[0].detach().cpu().numpy().transpose([1, 2, 0])
np_flow_21 = flow_21[0].detach().cpu().numpy().transpose([1, 2, 0])

# vx = np_flow_12[:, :, 0]
# vy = np_flow_12[:, :, 1]
# f = open(os.path.join(r'G:\ARFlow-master\data\flow_dataset\ceshi_tmp', name + '_vx.bin'), 'wb')
def forward(self, output, target):
    """
    :param output: multi-scale forward/backward flows n * [B x 4 x h x w]
    :param target: image pairs N x 6 x H x W
    :return:
    """
    pyramid_flows = output
    im1_origin = target[:, :3]
    im2_origin = target[:, 3:]

    pyramid_smooth_losses = []
    pyramid_warp_losses = []
    self.pyramid_occu_mask1 = []
    self.pyramid_occu_mask2 = []

    s = 1.
    for i, flow in enumerate(pyramid_flows):
        b, _, h, w = flow.size()

        # resize images to match the size of this layer
        im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
        im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')

        im1_recons, occu_mask1 = flow_warp(im2_scaled, flow[:, :2], flow[:, 2:])
        im2_recons, occu_mask2 = flow_warp(im1_scaled, flow[:, 2:], flow[:, :2])

        self.pyramid_occu_mask1.append(occu_mask1)
        self.pyramid_occu_mask2.append(occu_mask2)

        if self.cfg.hard_occu:
            occu_mask1 = (occu_mask1 > self.cfg.hard_occu_th).float()
            occu_mask2 = (occu_mask2 > self.cfg.hard_occu_th).float()

        loss_photomatric = self.loss_photomatric(im1_scaled, im1_recons,
                                                 occu_mask1)

        if i == 0 and self.cfg.norm_smooth:
            s = min(h, w)

        if self.cfg.s_mask:
            loss_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled,
                                           occu_mask1)
        else:
            loss_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled, None)

        if self.cfg.with_bk:
            loss_photomatric += self.loss_photomatric(im2_scaled, im2_recons,
                                                      occu_mask2)
            if self.cfg.s_mask:
                loss_smooth += self.loss_smooth(flow[:, 2:] / s, im2_scaled,
                                                occu_mask2)
            else:
                loss_smooth += self.loss_smooth(flow[:, 2:] / s, im2_scaled,
                                                None)

            loss_photomatric /= 2.
            loss_smooth /= 2.

        pyramid_warp_losses.append(loss_photomatric)
        pyramid_smooth_losses.append(loss_smooth)

    pyramid_warp_losses = [
        l * w for l, w in zip(pyramid_warp_losses, self.cfg.w_scales)
    ]
    pyramid_smooth_losses = [
        l * w for l, w in zip(pyramid_smooth_losses, self.cfg.w_sm_scales)
    ]

    return sum(pyramid_warp_losses) + self.cfg.w_smooth * sum(pyramid_smooth_losses), \
        sum(pyramid_warp_losses), self.cfg.w_smooth * sum(pyramid_smooth_losses), \
        pyramid_flows[0].abs().mean()
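# `loss_smooth` is likewise not shown here. A common first-order edge-aware
# smoothness term, sketched as an assumption rather than the repo's exact
# method: flow gradients are down-weighted where the image has strong
# gradients. The name and the weighting factor 10.0 are illustrative only.
def smooth_loss_sketch(flow, image):
    def grad_x(t):
        return t[:, :, :, :-1] - t[:, :, :, 1:]

    def grad_y(t):
        return t[:, :, :-1, :] - t[:, :, 1:, :]

    img_gx = grad_x(image).abs().mean(dim=1, keepdim=True)
    img_gy = grad_y(image).abs().mean(dim=1, keepdim=True)
    w_x = torch.exp(-img_gx * 10.0)   # edge-aware weights
    w_y = torch.exp(-img_gy * 10.0)
    return (grad_x(flow).abs() * w_x).mean() + (grad_y(flow).abs() * w_y).mean()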
def _validate_with_gt2(self):
    import cv2
    import torch.nn.functional as F
    from utils.warp_utils import flow_warp
    from utils.misc_utils import plot_imgs

    batch_time = AverageMeter()

    error_names = ['EPE', 'E_noc', 'E_occ', 'F1_all']
    error_meters = AverageMeter(i=len(error_names))

    self.model.eval()
    self.model = self.model.float()

    end = time.time()
    for i_step, data in enumerate(self.valid_loader):
        img1, img2 = data['img1'], data['img2']
        img_pair = torch.cat([img1, img2], 1).to(self.device)

        # compute output
        flow = self.model(img_pair, with_bk=True)[0]
        _, _, h, w = flow.size()

        im1_origin = img_pair[:, :3]
        _, occu_mask1 = flow_warp(im1_origin, flow[:, :2], flow[:, 2:])

        res = list(map(load_flow, data['flow_occ']))
        gt_flows, occ_masks = [r[0] for r in res], [r[1] for r in res]
        res = list(map(load_flow, data['flow_noc']))
        _, noc_masks = [r[0] for r in res], [r[1] for r in res]

        gt_flows = [np.concatenate([flow, occ_mask, noc_mask], axis=2)
                    for flow, occ_mask, noc_mask in
                    zip(gt_flows, occ_masks, noc_masks)]

        pred_flows = flow[:, :2].detach().cpu().numpy().transpose([0, 2, 3, 1])
        es = evaluate_kitti_flow(gt_flows, pred_flows)
        error_meters.update([l.item() for l in es], img_pair.size(0))

        plot_list = []
        occu_mask1 = (occu_mask1 < 0.2).detach().cpu().numpy()[0, 0] * 255
        plot_list.append({'im': occu_mask1, 'title': 'occu mask 1'})

        gt_occu_mask1 = (noc_masks[0] - occ_masks[0])[:, :, 0].astype(
            np.float32) * 255
        plot_list.append({'im': gt_occu_mask1, 'title': 'gt occu mask 1'})

        plot_imgs(plot_list,
                  save_path='./tmp/occu_soft_hard/occu_hard_{:03d}.jpg'.format(
                      i_step))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i_step % self.cfg.print_freq == 0:
            self._log.info('Test: [{0}/{1}]\t Time {2}\t '.format(
                i_step, self.cfg.valid_size, batch_time) + ' '.join(
                    map('{:.2f}'.format, error_meters.avg)))

        if i_step > self.cfg.valid_size:
            break

    # write errors to the tensorboard summary
    for value, name in zip(error_meters.avg, error_names):
        self.summary_writer.add_scalar('Valid_' + name, value, self.i_epoch)

    # To limit the disk space used during debugging, save the model only
    # after more than cfg.save_iter iterations.
    if self.i_iter > self.cfg.save_iter:
        self.save_model(error_meters.avg[0], 'KITTI_flow')

    return error_meters.avg, error_names
def forward_2_frames(self, x1_pyramid, x2_pyramid):
    # outputs
    flows = []

    # init
    b_size, _, h_x1, w_x1 = x1_pyramid[0].size()
    init_dtype = x1_pyramid[0].dtype
    init_device = x1_pyramid[0].device
    flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype,
                       device=init_device).float()

    for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):
        # Example pyramid feature shapes (output_level == 4, batch size 2):
        #   l=0: [2, 192,  6,  13]
        #   l=1: [2, 128, 12,  26]
        #   l=2: [2,  96, 24,  52]
        #   l=3: [2,  64, 48, 104]
        #   l=4: [2,  32, 96, 208]

        # warping
        if l == 0:
            x2_warp = x2
        else:
            flow = F.interpolate(flow * 2, scale_factor=2, mode='bilinear',
                                 align_corners=True)
            x2_warp = flow_warp(x2, flow)

        # correlation: compares x1 against x2 warped towards x1.
        # The cost volume always has 81 channels (a 9x9 search window),
        # regardless of the input channel count, e.g. [2, 192, 6, 13] -> [2, 81, 6, 13].
        out_corr = self.corr(x1, x2_warp)
        out_corr_relu = self.leakyRELU(out_corr)

        # concat and estimate flow
        x1_1by1 = self.conv_1x1[l](x1)  # compresses the feature channels to 32
        x_intm, flow_res = self.flow_estimators(
            torch.cat([out_corr_relu, x1_1by1, flow], dim=1))
        flow = flow + flow_res

        flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))
        flow = flow + flow_fine

        flows.append(flow)

        # upsampling or post-processing
        if l == self.output_level:
            break

    if self.upsample:
        flows = [F.interpolate(flow * 4, scale_factor=4, mode='bilinear',
                               align_corners=True) for flow in flows]
    return flows[::-1]