def get_synth_flow(occs, unps, summ_writer, sometimes_zero=False, do_vis=False):
    """Synthesize frame 1 by rigidly moving frame 0, and return the induced 3D flow.

    Frame 0 of ``occs``/``unps`` is warped by a randomly sampled pure
    translation ``cam1_T_cam0`` (either a large- or a small-motion draw, chosen
    uniformly). The result becomes frame 1. The voxel-space flow induced by that
    transform is returned together with the new two-frame pair.

    Args:
        occs: tensor shaped B x S x C x Z x Y x X with S == 2 and C == 1
            (checked below). Only frame 0 is read.
        unps: tensor whose frame 0 (``unps[:, 0]``) is warped together with
            the occupancy. Presumably B x S x 3 x Z x Y x X (RGB); confirm.
        summ_writer: summary writer. Only used when ``do_vis`` is True.
        sometimes_zero: forwarded to the small-motion sampler.
        do_vis: if True, write occupancy/rgb/flow summaries, including a
            backwarp sanity check.

    Returns:
        (occs, unps, flow, cam1_T_cam0):
        - occs, unps: frames 0 and 1 stacked along dim 1.
        - flow: shaped B x 3 x Z x Y x X. It is mem1 coordinates minus mem0
          coordinates for every voxel.
        - cam1_T_cam0: the sampled transform.
    """
    B, S, C, Z, Y, X = list(occs.shape)
    # The original `assert (S == 2, C == 1)` asserted a non-empty tuple, which is
    # always truthy, so the shape was never actually checked.
    assert S == 2 and C == 1, (
        'expected occs shaped B x 2 x 1 x Z x Y x X, got %s' % (list(occs.shape),))

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    motion_options = [
        utils_geom.get_random_rt(B, r_amount=0.0, t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B, r_amount=0.0, t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero),
    ]
    # pick one of the two motion regimes uniformly at random
    cam1_T_cam0 = random.choice(motion_options)

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    # express the camera motion in voxel coordinates:
    # mem1_T_mem0 = mem_T_cam * cam1_T_cam0 * cam_T_mem
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)

    # flow at each voxel = its transformed location minus its original location
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)
        # sanity check: backwarping frame 1 with the flow should approximate frame 0
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1, flow, binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e], [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)
    return occs, unps, flow, cam1_T_cam0
def run_test(self, feed):
    """Track the frame-0 object box through all S frames by 3D feature matching.

    For each frame s:
    1. 3D features are computed in a voxel grid. For s > 0, the grid is
       re-centered on the previous frame's estimated object centroid.
    2. Voxels inside the frame-0 object mask are matched against the frame-s
       features with a soft 3D argmax.
    3. A rigid transform fitted to those correspondences moves the frame-0 box
       to frame s.
    Per-frame IoUs and centroid distances against ground truth are logged.

    When ``hyp.do_feat3D`` is off, a zero-motion baseline (the frame-0 box
    repeated over time) is evaluated instead.

    Side effects:
    - ``self.obj_clist_camX0`` and ``self.original_centroid`` are (re)set.
    - When tracking runs, every s > 0 overwrites ``self.scene_centroid``,
      ``self.vox_util``, ``self.occ_memXs``, ``self.occ_memX0s``,
      ``self.unp_memXs`` and ``self.unp_memX0s``. They are left relocated
      to the last frame's grid.

    Args:
        feed: dict. Only ``feed['global_step']`` is read (and currently unused).

    Returns:
        (total_loss, results, returned_early):
        - total_loss: a 0-d tensor on the GPU. It is 0, plus the mean
          centroid distance when tracking ran.
        - results: holds 'ious' (B x S) unless we returned early.
        - returned_early: True when this batch was skipped.
    """
    results = dict()

    global_step = feed['global_step']  # read but currently unused
    total_loss = torch.tensor(0.0).cuda()

    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

    self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(self.lrt_camX0s)

    # remember where the grid started; per-frame relocations are measured from here
    self.original_centroid = self.scene_centroid.clone()

    obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
    obj_length = obj_lengths[:, 0]
    for b in range(self.B):
        if self.score_s[b, 0] < 1.0:
            # we need the template to exist
            print('returning early, since score_s[%d,0] = %.1f' %
                  (b, self.score_s[b, 0].cpu().numpy()))
            return total_loss, results, True
        if torch.sum(self.score_s[b]) != self.S:
            # the full traj should be valid.
            # Report the sum for this batch element (the old message summed the
            # whole batch, which did not match the condition being tested).
            print('returning early, since sum(score_s) = %d, while S = %d' %
                  (torch.sum(self.score_s[b]).cpu().numpy(), self.S))
            return total_loss, results, True

    if hyp.do_feat3D:
        # template features: occupancy plus occupancy-masked rgb of frame 0
        feat_memX0_input = torch.cat([
            self.occ_memX0s[:, 0],
            self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
        ], dim=1)
        _, feat_memX0, valid_memX0 = self.featnet3D(feat_memX0_input)
        B, C, Z, Y, X = list(feat_memX0.shape)
        S = self.S

        obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
            self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
        # occupancy at feature resolution, used for the "enough points" check below
        occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y, X)
        obj_mask_memX0 = obj_mask_memX0s[:, 0]

        # discard the known freespace
        _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs[:, 0:1],
            self.xyz_camXs[:, 0:1],
            Z, Y, X,
            agg=True)
        free_memX0 = free_memX0_.squeeze(1)
        obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

        for b in range(self.B):
            if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                print('returning early, since there are not enough valid object points')
                return total_loss, results, True

        feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
        # this is B x C x huge
        feat0_vec = feat0_vec.permute(0, 2, 1)
        # this is B x huge x C

        obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
        occ_mask0_vec = occ_memX0.reshape(B, -1).round()
        free_mask0_vec = free_memX0.reshape(B, -1).round()
        # these are B x huge

        orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
        # this is B x huge x 3

        # cams_T_obj0 (computed above, from the unchanged self.lrt_camX0s) is B x S x 4 x 4
        cam0_T_obj = cams_T_obj0[:, 0]

        # these come from the ORIGINAL (un-relocated) vox_util, so the estimated
        # memI_T_mem0 must be expressed in the original grid (see `delta` below)
        mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
        cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

        lrt_camIs_g = self.lrt_camX0s.clone()
        lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
        # we will fill this up

        ious = torch.zeros([B, S]).float().cuda()
        point_counts = np.zeros([B, S])
        inb_counts = np.zeros([B, S])

        feat_vis = []
        occ_vis = []

        # experiment switch: True restricts the template to occupied voxels plus a
        # random subset of unobserved (non-free, non-occ) voxels inside the mask
        careful = False
        # cap on template points per batch element, to bound the correlation cost
        max_pts = 2000

        for s in range(self.S):
            if s == 0:
                self.summ_writer.summ_occ('track/init_occ_%d' % s, self.occ_memX0s[:, s])
                delta = torch.zeros([B, 3]).float().cuda()
            else:
                # remake the vox util and all the mem data, centered on the
                # previous frame's estimated object centroid
                self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                    lrt_camIs_e[:, s - 1:s])[:, 0]
                delta = self.scene_centroid - self.original_centroid
                self.vox_util = vox_util.Vox_util(
                    self.Z, self.Y, self.X,
                    self.set_name,
                    scene_centroid=self.scene_centroid,
                    assert_cube=True)
                self.occ_memXs = __u(self.vox_util.voxelize_xyz(
                    __p(self.xyz_camXs), self.Z, self.Y, self.X))
                self.occ_memX0s = __u(self.vox_util.voxelize_xyz(
                    __p(self.xyz_camX0s), self.Z, self.Y, self.X))
                self.unp_memXs = __u(self.vox_util.unproject_rgb_to_mem(
                    __p(self.rgb_camXs), self.Z, self.Y, self.X, __p(self.pix_T_cams)))
                self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
                    self.camX0s_T_camXs, self.unp_memXs)
                self.summ_writer.summ_occ('track/reloc_occ_%d' % s, self.occ_memX0s[:, s])

            occ_vis.append(
                self.summ_writer.summ_occ('', self.occ_memX0s[:, s], only_return=True))

            # NOTE(review): Z4/Y4 are paired with full-resolution self.X here;
            # this only feeds the inb_counts statistic. Confirm whether X4 was intended.
            inb = self.vox_util.get_inbounds(
                self.xyz_camX0s[:, s], self.Z4, self.Y4, self.X, already_mem=False)
            num_inb = torch.sum(inb.float(), axis=1)
            inb_counts[:, s] = num_inb.cpu().numpy()

            feat_memI_input = torch.cat([
                self.occ_memX0s[:, s],
                self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
            ], dim=1)
            _, feat_memI, valid_memI = self.featnet3D(feat_memI_input)

            self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s, feat_memI_input, pca=True)
            self.summ_writer.summ_feat('3D_feats/feat_%d' % s, feat_memI, pca=True)
            feat_vis.append(
                self.summ_writer.summ_feat('', feat_memI, pca=True, only_return=True))

            # collect freespace here, to discard bad matches
            _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, s:s + 1],
                self.xyz_camXs[:, s:s + 1],
                Z, Y, X,
                agg=True)
            free_memI = free_memI_.squeeze(1)

            feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat_vec = feat_vec.permute(0, 2, 1)
            # this is B x huge x C

            memI_T_mem0 = utils_geom.eye_4x4(B)
            # we will fill this up

            # to simplify the impl, we will iterate over the batch dim
            for b in range(B):
                feat_vec_b = feat_vec[b]
                feat0_vec_b = feat0_vec[b]
                obj_mask0_vec_b = obj_mask0_vec[b]
                occ_mask0_vec_b = occ_mask0_vec[b]
                free_mask0_vec_b = free_mask0_vec[b]
                orig_xyz_b = orig_xyz[b]
                # these are huge x C

                if careful:
                    # start with occ points, since these are definitely observed
                    obj_inds_b = torch.where((occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # these are med x C

                    # also take random non-free non-occ points in the mask
                    ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (1.0 - free_mask0_vec_b)
                    alt_inds_b = torch.where(ok_mask > 0)
                    alt_vec_b = feat0_vec_b[alt_inds_b]
                    alt_xyz0 = orig_xyz_b[alt_inds_b]
                    # these are med x C

                    # issues arise when "med" is too large
                    num = len(alt_xyz0)
                    if num > max_pts:
                        perm = np.random.permutation(num)
                        alt_vec_b = alt_vec_b[perm[:max_pts]]
                        alt_xyz0 = alt_xyz0[perm[:max_pts]]
                    obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                    xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                    if s == 0:
                        print('have %d pts in total' % (len(xyz0)))
                else:
                    # take any points within the mask
                    obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # these are med x C

                    # issues arise when "med" is too large; trim down to max_pts
                    num = len(xyz0)
                    if num > max_pts:
                        print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                        perm = np.random.permutation(num)
                        obj_vec_b = obj_vec_b[perm[:max_pts]]
                        xyz0 = xyz0[perm[:max_pts]]

                obj_vec_b = obj_vec_b.permute(1, 0)
                # this is C x med

                corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                # this is huge x med

                heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                # this is med x 1 x Z x Y x X

                # make the min zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_min
                # zero out the freespace, so it can never be the best match
                heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                # make the max zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_max
                # scale up by the number of voxels, to sharpen the soft argmax
                heat_b = heat_b * float(Z * Y * X)

                xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                # this is med x 3, in the (possibly relocated) frame-s grid

                # Convert the matches into the ORIGINAL grid, so they pair with xyz0.
                # Presumably Vox_util's ref<->mem mapping is shifted by the scene
                # centroid. If so, the relocated Ref2Mem applied to (p + delta)
                # equals the original grid's Ref2Mem of p. Confirm against Vox_util.
                xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y, X)
                xyzI_cam += delta
                xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                # record #points, since ransac depends on this
                point_counts[b, s] = len(xyz0)
            # done stepping through batch

            # eval: move the frame-0 box by the estimated motion
            camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0, mem_T_cam, cam0_T_obj)
            # this is B x 4 x 4
            lrt_camIs_e[:, s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:, s:s + 1]).squeeze(1)
        results['ious'] = ious

        self.summ_writer.summ_rgbs('track/feats', feat_vis)
        self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

        for s in range(self.S):
            self.summ_writer.summ_scalar(
                'track/mean_iou_%02d' % s, torch.mean(ious[:, s]).cpu().item())

        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())
        self.summ_writer.summ_scalar('track/point_counts', np.mean(point_counts))
        self.summ_writer.summ_scalar('track/inb_counts', np.mean(inb_counts))

        lrt_camX0s_e = lrt_camIs_e.clone()
        lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(self.camXs_T_camX0s, lrt_camX0s_e)

        if self.include_vis:
            visX_e = []
            for s in range(self.S):
                visX_e.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_e' % s,
                    self.rgb_camXs[:, s],
                    lrt_camXs_e[:, s:s + 1],
                    self.score_s[:, s:s + 1],
                    self.tid_s[:, s:s + 1],
                    self.pix_T_cams[:, 0],
                    only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
            visX_g = []
            for s in range(self.S):
                visX_g.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_g' % s,
                    self.rgb_camXs[:, s],
                    self.lrt_camXs[:, s:s + 1],
                    self.score_s[:, s:s + 1],
                    self.tid_s[:, s:s + 1],
                    self.pix_T_cams[:, 0],
                    only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)

        dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
        # this is B x S
        mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
        median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
        # this is []
        self.summ_writer.summ_scalar('track/centroid_dist_mean', mean_dist.cpu().item())
        self.summ_writer.summ_scalar('track/centroid_dist_median', median_dist.cpu().item())

        # trajectory visualizations are always written (they used to sit under `if (True):`)
        self.summ_writer.summ_traj_on_occ('track/traj_e',
                                          obj_clist_camX0_e,
                                          self.occ_memX0s[:, 0],
                                          self.vox_util,
                                          already_mem=False,
                                          sigma=2)
        self.summ_writer.summ_traj_on_occ('track/traj_g',
                                          self.obj_clist_camX0,
                                          self.occ_memX0s[:, 0],
                                          self.vox_util,
                                          already_mem=False,
                                          sigma=2)
        total_loss += mean_dist  # we won't backprop, but it's nice to plot and print this anyway

    else:
        # zero-motion baseline: compare the frame-0 box against every frame
        ious = torch.zeros([self.B, self.S]).float().cuda()
        for s in range(self.S):
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                self.lrt_camX0s[:, 0:1], self.lrt_camX0s[:, s:s + 1]).squeeze(1)
        results['ious'] = ious
        for s in range(self.S):
            self.summ_writer.summ_scalar(
                'track/mean_iou_%02d' % s, torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())

        lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
        self.summ_writer.summ_traj_on_occ('track/traj_e',
                                          obj_clist_camX0_e,
                                          self.occ_memX0s[:, 0],
                                          self.vox_util,
                                          already_mem=False,
                                          sigma=2)
        self.summ_writer.summ_traj_on_occ('track/traj_g',
                                          self.obj_clist_camX0,
                                          self.occ_memX0s[:, 0],
                                          self.vox_util,
                                          already_mem=False,
                                          sigma=2)

    self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
    return total_loss, results, False
def get_gt_flow(obj_lrtlist_camRs,
                obj_scorelist,
                camRs_T_camXs,
                Z, Y, X,
                K=2,
                mod='',
                vis=True,
                summ_writer=None):
    """Construct a ground-truth 3D flow field from box trajectories and egomotion.

    Each of the first K objects gets its own flow field:
    - Its box motion between the two frames (from ``obj_lrtlist_camRs``) is
      composed with the camera transforms from ``camRs_T_camXs``.
    - The resulting transform is applied to the voxel grid.
    - The displacement is kept only inside that object's frame-0 mask.
    The background gets the pure egomotion flow, weighted by the complement of
    the union of the object masks. All K+1 fields are then averaged per voxel,
    using the masks as weights.

    Args:
        obj_lrtlist_camRs: N x B x S x D box trajectories; S must be 2.
        obj_scorelist: per-object, per-frame validity scores. They are indexed
            as ``[k, :, 0:1]``, so presumably N x B x S (confirm).
        camRs_T_camXs: B x S x 4 x 4 transforms (indexed ``[:, 0]`` and ``[:, 1]``).
        Z, Y, X: voxel grid size.
        K: number of objects (leading entries of the N axis) to use.
        mod: suffix appended to summary names.
        vis: write summaries. Ignored when ``summ_writer`` is None; previously
            the default ``vis=True`` with ``summ_writer=None`` crashed.
        summ_writer: summary writer, or None.

    Returns:
        flow shaped B x 3 x Z x Y x X.
    """
    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2
    vis = vis and (summ_writer is not None)

    # loop-invariant camera transforms and voxel<->ref mappings
    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)
    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    xyz_mem0_flat = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem0 = xyz_mem0_flat.reshape(B, Z, Y, X, 3)

    def _warp_grid(cam1_T_cam0):
        # Apply mem_T_ref * cam1_T_cam0 * ref_T_mem to every voxel center.
        # Returns B x Z x Y x X x 3.
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)
        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0_flat)
        return xyz_mem1.reshape(B, Z, Y, X, 3)

    flows = []
    masks = []
    for k in range(K):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z, Y, X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X

        if vis:
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod),
                                  torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        # object motion, composed with the camera transforms
        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0, camR_T_cam0)
        xyz_mem1 = _warp_grid(cam1_T_cam0)

        # only use these displaced points within the obj mask;
        # elsewhere the voxel stays put (zero flow)
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0
        flow = xyz_mem1 - xyz_mem0
        flow = flow.permute(0, 4, 1, 2, 3)
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)

        if vis:
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod), flow, clip=4.0)
        masks.append(obj_mask0)
        flows.append(flow)

    # background: pure egomotion between the two camera frames
    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    bkg_flow = (_warp_grid(cam1_T_cam0) - xyz_mem0).permute(0, 4, 1, 2, 3)

    # the background weight is 1 wherever no object mask is set
    any_mask = torch.max(torch.stack(masks, axis=0), axis=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, axis=0)
    masks = torch.stack(masks, axis=0)
    # (K+1) x B x 1 x Z x Y x X -> repeat the weights across the 3 flow channels
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)
    if vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x Z x Y x X
    return flow