def Mem2Ref(self, xyz_mem, Z, Y, X):
    """Transform points from mem (voxel-grid) coordinates into ref coordinates.

    Args:
        xyz_mem: B x N x 3 points in mem coordinates (must be rank 3; the
            shape is unpacked into three values).
        Z, Y, X: voxel-grid resolution, forwarded to get_ref_T_mem.

    Returns:
        The points mapped through ref_T_mem by utils_geom.apply_4x4.
    """
    batch_size, _num_pts, _num_dims = list(xyz_mem.shape)
    ref_T_mem = self.get_ref_T_mem(batch_size, Z, Y, X)
    return utils_geom.apply_4x4(ref_T_mem, xyz_mem)
def get_synth_flow(occs, unps, summ_writer, sometimes_zero=False, do_vis=False):
    """Synthesize a second frame by a random translation and return its voxel flow.

    Frame 0 of `occs`/`unps` is warped by a random camera motion to produce
    frame 1. The motion is chosen at random between a large translation
    (t_amount=1.0) and a small one (t_amount=0.1). The dense flow that this
    motion induces on the mem grid is returned alongside it.

    No rotation is sampled. This keeps the distribution uniform over
    translations, because the rotation pivot is the camera and a rotation
    would skew it.

    Args:
        occs: B x S x C x Z x Y x X occupancies; S must be 2 and C must be 1.
            Only frame 0 is used as input.
        unps: per-frame unprojected features; only frame 0 is used.
        summ_writer: used only when do_vis is True.
        sometimes_zero: forwarded to get_random_rt for the small-motion sample.
        do_vis: if True, write occupancy/feature/flow summaries.

    Returns:
        (occs, unps, flow, cam1_T_cam0):
        - occs, unps: frame 0 and the synthesized frame 1, stacked on dim 1.
        - flow: B x 3 x Z x Y x X mem-space flow from frame 0 to frame 1.
        - cam1_T_cam0: the sampled transform.
    """
    B, S, C, Z, Y, X = list(occs.shape)
    # Check each condition separately: `assert (a, b)` tests a non-empty
    # tuple, which is always truthy, so it would never fire.
    assert S == 2 and C == 1

    # Both candidates are always sampled, then one is picked at random.
    candidates = [
        utils_geom.get_random_rt(B, r_amount=0.0, t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B, r_amount=0.0, t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero),
    ]
    cam1_T_cam0 = random.sample(candidates, k=1)[0]

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    # express the camera motion as a mem-to-mem transform
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0  # B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)  # B x 3 x Z x Y x X

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)
        # sanity check: backwarping frame 1 by the flow should recover frame 0
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1, flow, binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e], [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)
    return occs, unps, flow, cam1_T_cam0
def rescore_boxlist_with_pointcloud(camX_T_camR, boxlist_camR, xyz_camX, scorelist, tidlist,
                                    thresh=1.0, stride=10):
    """Zero the score of every box whose center has no nearby point in the cloud.

    Args:
        camX_T_camR: B x 4 x 4 transform from camR into camX.
        boxlist_camR: B x N x 9 boxes in camR. Only the first 3 entries (the
            box center) are used here.
        xyz_camX: B x V x 3 point cloud in camX.
        scorelist: B x N scores, multiplied by a 0/1 keep mask.
        tidlist: unused; kept for interface compatibility.
        thresh: a box is kept if its center is closer than this to some point.
        stride: the cloud is subsampled to every `stride`-th point. This
            bounds the size of the B x N x V distance matrix.

    Returns:
        scorelist, with the scores of unsupported boxes set to zero.
    """
    B, N, D = list(boxlist_camR.shape)
    assert (D == 9)

    centers_camX = utils_geom.apply_4x4(camX_T_camR, boxlist_camR[:, :, :3])  # B x N x 3

    points = xyz_camX[:, ::stride].unsqueeze(1)  # B x 1 x V' x 3
    if points.shape[2] == 0:
        # no points at all: no box can be supported
        return scorelist * 0.0

    dists = torch.norm(points - centers_camX.unsqueeze(2), dim=3)  # B x N x V'
    mindists = torch.min(dists, 2)[0]  # B x N
    ok = (mindists < thresh).float()
    return scorelist * ok
def assemble_padded_obj_masklist(self, lrtlist, scorelist, Z, Y, X, coeff=1.0, additive_coeff=0.0):
    """Build a (score-weighted) binary 3D mask for each object box.

    Used when computing the center-surround objectness score. Each box's
    lengths are scaled by `coeff` and padded by `additive_coeff` before the
    inside test, so the mask can be made larger or smaller than the box.

    Args:
        lrtlist: B x N x 19 boxes (lengths + ref_T_obj) in ref coordinates.
        scorelist: B x N per-object scores; each mask is multiplied by its score.
        Z, Y, X: mem-grid resolution.
        coeff: multiplicative scale on the box lengths.
        additive_coeff: additive padding on the box lengths.

    Returns:
        masklist: B x N x 1 x Z x Y x X.
    """
    B, N, D = list(lrtlist.shape)
    assert (D == 19)

    lenlist, ref_T_objlist = utils_geom.split_lrtlist(lrtlist)
    # lenlist is B x N x 3; ref_T_objlist is B x N x 4 x 4
    lenlist_ = lenlist.reshape(B * N, 3)
    ref_T_objlist_ = ref_T_objlist.reshape(B * N, 4, 4)
    obj_T_reflist_ = utils_geom.safe_inverse(ref_T_objlist_)

    # every mem-grid location, expressed in each object's own frame
    xyz_mem_ = gridcloud3d(B * N, Z, Y, X)  # B*N x V x 3, V = Z*Y*X
    xyz_ref_ = self.Mem2Ref(xyz_mem_, Z, Y, X)
    xyz_obj_ = utils_geom.apply_4x4(obj_T_reflist_, xyz_ref_)
    x, y, z = torch.unbind(xyz_obj_, dim=2)  # each B*N x V

    # padded half-extents, each B*N x 1 so they broadcast over V
    half_lens = (lenlist_ * coeff + additive_coeff) / 2.0
    hx, hy, hz = [h.unsqueeze(1) for h in torch.unbind(half_lens, dim=1)]

    inbounds = ((x > -hx) & (x < hx)
                & (y > -hy) & (y < hy)
                & (z > -hz) & (z < hz))

    masklist = inbounds.float().reshape(B, N, 1, Z, Y, X)
    return masklist * scorelist.reshape(B, N, 1, 1, 1, 1)
def Ref2Mem(self, xyz, Z, Y, X, assert_cube=True):
    """Transform points from ref coordinates into mem (voxel-grid) coordinates.

    Args:
        xyz: B x N x 3 points in ref coordinates (the last dim must be 3).
        Z, Y, X: voxel-grid resolution, forwarded to get_mem_T_ref.
        assert_cube: forwarded to get_mem_T_ref.

    Returns:
        The points mapped through mem_T_ref by utils_geom.apply_4x4.
    """
    batch_size, _num_pts, num_dims = list(xyz.shape)
    assert (num_dims == 3)
    mem_T_ref = self.get_mem_T_ref(batch_size, Z, Y, X, assert_cube=assert_cube)
    return utils_geom.apply_4x4(mem_T_ref, xyz)
def apply_4x4_to_vox(self, B_T_A, feat_A, already_mem=False, binary_feat=False, rigid=True):
    """Warp a voxel feature grid from frame A into frame B (a backwarp).

    Args:
        B_T_A: B x 4 x 4 transform. It maps between cam systems when
            already_mem=False, or between mem systems when already_mem=True.
        feat_A: B x C x Z x Y x X features expressed in frame A.
        already_mem: see B_T_A.
        binary_feat: forwarded to utils_samp.resample3d.
        rigid: if True, invert with utils_geom.safe_inverse. Otherwise use a
            general matrix inverse, which is slower but more powerful.

    Returns:
        The features resampled onto frame B's voxel grid.
    """
    B, C, Z, Y, X = list(feat_A.shape)

    # A backwarp asks, for every voxel of B, "where in A should I sample?".
    # That mapping is A_T_B, i.e. the inverse of the given B_T_A.
    if not rigid:
        A_T_B = B_T_A.inverse()
    else:
        A_T_B = utils_geom.safe_inverse(B_T_A)

    if not already_mem:
        # lift the cam-to-cam transform into a mem-to-mem transform
        cam_T_mem = self.get_ref_T_mem(B, Z, Y, X)
        mem_T_cam = self.get_mem_T_ref(B, Z, Y, X)
        A_T_B = matmul3(mem_T_cam, A_T_B, cam_T_mem)

    # every voxel location of B, mapped to where it should read from in A
    grid_B = gridcloud3d(B, Z, Y, X)
    sample_locs_A = utils_geom.apply_4x4(A_T_B, grid_B)
    return utils_samp.resample3d(feat_A, sample_locs_A, binary_feat=binary_feat)
def Zoom2Ref(self, xyz_zoom, lrt_ref, Z, Y, X, additive_pad=0.0):
    """Transform points from zoom (box-centered grid) coordinates into ref coordinates.

    Args:
        xyz_zoom: B x N x 3 points in zoom coordinates (must be rank 3).
        lrt_ref: the box defining the zoom grid, in ref coordinates.
            NOTE(review): presumably B x 19 (an lrt), like Ref2Zoom; the
            older "B x 9" comment looks stale. Confirm.
        Z, Y, X: zoom-grid resolution.
        additive_pad: forwarded to get_ref_T_zoom.

    Returns:
        The points mapped through ref_T_zoom by utils_geom.apply_4x4.
    """
    _batch_size, _num_pts, _ = list(xyz_zoom.shape)
    ref_T_zoom = self.get_ref_T_zoom(lrt_ref, Z, Y, X, additive_pad=additive_pad)
    return utils_geom.apply_4x4(ref_T_zoom, xyz_zoom)
def rescore_boxlist_with_inbound(camX_T_camR, boxlist_camR, tidlist, Z, Y, X, only_cars=True, pad=2.0):
    """Score boxes as valid only if they are tracked, (optionally) big, and padded-inbounds.

    Args:
        camX_T_camR: B x 4 x 4 transform from camR into camX.
        boxlist_camR: B x N x 9 boxes in camR. Entries 0:3 are the center and
            3:6 are the lengths. Presumably 6:9 are rotation angles (confirm).
        tidlist: B x N track ids; -1 marks an invalid entry.
        Z, Y, X: mem-grid resolution used for the inbounds test.
        only_cars: if True, also require the norm of the box lengths to
            exceed 2.0.
        pad: the center must stay inbounds when shifted by +/-pad along x and z.

    Returns:
        scorelist: B x N float 0/1 scores.
    """
    B, N, D = list(boxlist_camR.shape)
    assert (D == 9)

    xyzlist = boxlist_camR[:, :, :3]  # B x N x 3
    # Lengths are entries 3:6. Slicing 3:7 would leak the first rotation
    # component into the size test below.
    lenlist = boxlist_camR[:, :, 3:6]  # B x N x 3
    xyzlist = utils_geom.apply_4x4(camX_T_camR, xyzlist)

    validlist = (tidlist != -1).float()  # B x N
    if only_cars:
        biglist = (torch.norm(lenlist, dim=2) > 2.0).float()
        validlist = validlist * biglist

    xlist, ylist, zlist = torch.unbind(xyzlist, dim=2)

    def _inbounds_shifted(dx, dz):
        # 0/1 mask: is the center still inside the grid after shifting by (dx, 0, dz)?
        shifted = torch.stack([xlist + dx, ylist, zlist + dz], dim=2)
        return utils_vox.get_inbounds(shifted, Z, Y, X, already_mem=False).float()

    inboundlist = (_inbounds_shifted(pad, 0.0)
                   * _inbounds_shifted(-pad, 0.0)
                   * _inbounds_shifted(0.0, pad)
                   * _inbounds_shifted(0.0, -pad))

    return validlist * inboundlist
def Ref2Zoom(self, xyz_ref, lrt_ref, Z, Y, X, additive_pad=0.0):
    """Transform points from ref coordinates into zoom (box-centered grid) coordinates.

    Args:
        xyz_ref: B x N x 3 points in ref coordinates (must be rank 3).
        lrt_ref: B x 19 box defining the zoom grid, in ref coordinates.
        Z, Y, X: zoom-grid resolution.
        additive_pad: forwarded to get_zoom_T_ref.

    Returns:
        The points mapped through zoom_T_ref by utils_geom.apply_4x4.
    """
    _batch_size, _num_pts, _ = list(xyz_ref.shape)
    zoom_T_ref = self.get_zoom_T_ref(lrt_ref, Z, Y, X, additive_pad=additive_pad)
    return utils_geom.apply_4x4(zoom_T_ref, xyz_ref)
def unproject_rgb_to_mem(self, rgb_camB, Z, Y, X, pixB_T_camA):
    """Unproject an image into a voxel grid by sampling it along each voxel's ray.

    Every voxel of camera A's mem grid is projected into camera B's image,
    and the image is sampled at that pixel location. As a result, each
    C-dim pixel is copied along its ray through the grid.

    Args:
        rgb_camB: B x C x H x W image, in camB pixel coordinates.
        Z, Y, X: mem-grid resolution (camA memory coordinates).
        pixB_T_camA: B x 4 x 4 projection from camA into camB pixels.

    Returns:
        values: B x C x Z x Y x X.
    """
    B, C, H, W = list(rgb_camB.shape)
    EPS = 1e-6

    xyz_memA = gridcloud3d(B, Z, Y, X, norm=False)
    xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X)
    xyz_pixB = utils_geom.apply_4x4(pixB_T_camA, xyz_camA)

    # perspective divide; the clamp avoids division by zero or negative depth
    normalizer = torch.unsqueeze(xyz_pixB[:, :, 2], 2)
    xy_pixB = xyz_pixB[:, :, :2] / torch.clamp(normalizer, min=EPS)  # B x N x 2
    # the (floating point) pixel coordinate of each voxel; each is B x N
    x_pixB, y_pixB = xy_pixB[:, :, 0], xy_pixB[:, :, 1]

    y_pixB, x_pixB = normalize_grid2d(y_pixB, x_pixB, H, W)
    # Use a 5D grid_sample with a singleton depth so the output is 3D.
    z_pixB = torch.zeros_like(x_pixB)
    xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], dim=2)
    rgb_camB = rgb_camB.unsqueeze(2)  # B x C x 1 x H x W
    xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
    values = F.grid_sample(rgb_camB, xyz_pixB)
    return torch.reshape(values, (B, C, Z, Y, X))
def resample_to_view(feats, new_T_old, multi=False):
    """Backwarp a sequence of voxel feature grids by a mem-to-mem transform.

    Args:
        feats: B x S x C x Z x Y x X features in reference/canonical mem
            coordinates.
        new_T_old: B x 4 x 4 transform between two mem systems, shared
            across the sequence; or, if multi=True, B x S x 4 x 4 (one per step).
        multi: see new_T_old.

    Returns:
        (feats, valid): feats is B x S x C x Z x Y x X; valid is B x S x 1 x Z x Y x X.
    """
    B, S, C, Z, Y, X = list(feats.shape)

    # Every mem-grid location, as (x, y, z) triples.
    # NOTE(review): assumes meshgrid3d(B, Z, Y, X) returns grids in (z, y, x)
    # order, each B x Z x Y x X; the previous (y, x, z) unpacking permuted
    # the coordinates. Confirm against meshgrid3d.
    grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X)
    xyz_mem = torch.stack([
        torch.reshape(grid_x, [B, -1]),
        torch.reshape(grid_y, [B, -1]),
        torch.reshape(grid_z, [B, -1]),
    ], dim=2)  # B x N x 3

    xyz_mems_ = xyz_mem.unsqueeze(1).repeat(1, S, 1, 1).reshape(B * S, Z * Y * X, 3)
    feats_ = feats.reshape(B * S, C, Z, Y, X)

    if multi:
        new_T_olds = new_T_old
    else:
        new_T_olds = new_T_old.unsqueeze(1).repeat(1, S, 1, 1)
    new_T_olds_ = new_T_olds.reshape(B * S, 4, 4)

    # Each voxel takes whatever value lies at its transformed coordinate,
    # i.e. we backwarp from the "new" coordinates.
    xyz_new_ = utils_geom.apply_4x4(new_T_olds_, xyz_mems_)
    # NOTE(review): elsewhere resample3d is used with a single return value;
    # confirm it still returns (feats, valid) here.
    feats_, valid_ = utils_samp.resample3d(feats_, xyz_new_)
    feats = feats_.reshape(B, S, C, Z, Y, X)
    valid = valid_.reshape(B, S, 1, Z, Y, X)
    return feats, valid
def prep_occs_supervision(self, camRs_T_camXs, xyz_camXs, Z, Y, X, agg=False):
    """Build occupancy/freespace targets in the reference frame R.

    Each step's point cloud is voxelized in its own frame (X) and in R.
    Freespace is computed in X, where the rays are known, and then warped
    into R.

    Args:
        camRs_T_camXs: per-step transforms from camX into camR (packed/unpacked
            along the sequence dim together with xyz_camXs).
        xyz_camXs: B x S x N x 3 point clouds in each camX.
        Z, Y, X: voxel-grid resolution.
        agg: if True, take the max over the sequence dim for the R targets.
            Only meaningful when the scene is static (time frozen).

    Returns:
        (occR or occRs, freeR or freeRs, occXs, freeXs). The R targets are
        thresholded at 0.5 into float 0/1; the X grids are returned unthresholded.
    """
    B, S, N, D = list(xyz_camXs.size())
    assert (D == 3)

    def pack(t):
        return pack_seqdim(t, B)

    def unpack(t):
        return unpack_seqdim(t, B)

    flat_camRs_T_camXs = pack(camRs_T_camXs)
    flat_xyz_camXs = pack(xyz_camXs)
    flat_xyz_camRs = utils_geom.apply_4x4(flat_camRs_T_camXs, flat_xyz_camXs)

    flat_occXs = self.voxelize_xyz(flat_xyz_camXs, Z, Y, X)
    flat_occRs = self.voxelize_xyz(flat_xyz_camRs, Z, Y, X)

    # freespace has to be computed in the capturing view, then warped to R
    flat_freeXs = self.get_freespace(flat_xyz_camXs, flat_occXs)
    flat_freeRs = self.apply_4x4_to_vox(flat_camRs_T_camXs, flat_freeXs)

    occXs, occRs, freeXs, freeRs = [
        unpack(t) for t in (flat_occXs, flat_occRs, flat_freeXs, flat_freeRs)
    ]  # each B x S x 1 x Z x Y x X

    if agg:
        free_target = torch.max(freeRs, dim=1)[0]
        occ_target = torch.max(occRs, dim=1)[0]  # B x 1 x Z x Y x X
    else:
        occ_target, free_target = occRs, freeRs

    return (occ_target > 0.5).float(), (free_target > 0.5).float(), occXs, freeXs
def apply_mem_T_ref_to_lrtlist(self, lrtlist_cam, Z, Y, X, assert_cube=True):
    """Convert a list of boxes (lrts) from cam coordinates into mem coordinates.

    The translation goes through mem_T_cam. The rotation is copied unchanged,
    which assumes cam axes are aligned with mem axes. The lengths are divided
    by the per-axis voxel size.

    Args:
        lrtlist_cam: B x N x 19 boxes in cam coordinates.
        Z, Y, X: mem-grid resolution.
        assert_cube: forwarded to get_mem_T_ref.

    Returns:
        lrtlist_mem: B x N x 19 boxes in mem coordinates.
    """
    B, N, C = list(lrtlist_cam.shape)
    assert (C == 19)
    mem_T_cam = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube)

    def pack(t):
        return utils_basic.pack_seqdim(t, B)

    def unpack(t):
        return utils_basic.unpack_seqdim(t, B)

    lenlist_cam, rtlist_cam = utils_geom.split_lrtlist(lrtlist_cam)  # lenlist is B x N x 3

    rlist_cam_, tlist_cam_ = utils_geom.split_rt(pack(rtlist_cam))
    # rlist_cam_ is B*N x 3 x 3; tlist_cam_ is B*N x 3
    # apply_4x4 handles the translation part (with B x N x 3 input)
    tlist_mem_ = pack(utils_geom.apply_4x4(mem_T_cam, unpack(tlist_cam_)))
    # the rotation does not need to change, since cam is aligned with mem
    rlist_mem_ = rlist_cam_.clone()
    rtlist_mem = unpack(utils_geom.merge_rt(rlist_mem_, tlist_mem_))  # B x N x 4 x 4

    # scale each length by the voxel size along its axis
    xlist, ylist, zlist = lenlist_cam.chunk(3, dim=2)
    vox_size_X = (self.XMAX - self.XMIN) / float(X)
    vox_size_Y = (self.YMAX - self.YMIN) / float(Y)
    vox_size_Z = (self.ZMAX - self.ZMIN) / float(Z)
    lenlist_mem = torch.cat(
        [xlist / vox_size_X, ylist / vox_size_Y, zlist / vox_size_Z], dim=2)

    return utils_geom.merge_lrtlist(lenlist_mem, rtlist_mem)
def apply_pixX_T_memR_to_voxR(self, pix_T_camX, camX_T_camR, voxR, D, H, W, z_far=None,
                              noise_amount=0.0, grid_z_vec=None, logspace_slices=False):
    """Resample a memR voxel grid into a camX-aligned frustum of D depth slices.

    Each of the D depth planes (H x W pixels) is unprojected through
    pix_T_camX into camX, moved into memR, and used to sample voxR.

    Args:
        pix_T_camX: B x 4 x 4 camX intrinsics.
        camX_T_camR: B x 4 x 4 transform from camR into camX.
        voxR: B x C x Z x Y x X voxel features in memR.
        D, H, W: output depth-slice count and image size.
        z_far: far plane; defaults to self.ZMAX. The near plane is fixed at 0.1.
        noise_amount: if > 0, jitter the (linear) slice depths by up to
            0.5 * noise_amount of the slice spacing. Not supported for logspace.
        grid_z_vec: optional precomputed length-D vector of slice depths.
        logspace_slices: space the slices uniformly in log depth.

    Returns:
        (samp, grid_z_vec): samp is the output of utils_samp.sample3d.
        NOTE(review): old comments disagree on samp's layout
        (B x C x D x H x W vs B x H x W x D x C); confirm.
    """
    B, C, Z, Y, X = list(voxR.shape)
    # build the slice depths on the same device as the voxels
    device = voxR.device

    z_near = 0.1
    if z_far is None:
        z_far = self.ZMAX

    if grid_z_vec is None:
        if logspace_slices:
            grid_z_vec = torch.exp(torch.linspace(np.log(z_near), np.log(z_far), steps=D,
                                                  dtype=torch.float32, device=device))
            if noise_amount > 0.:
                print('cannot add noise to logspace sampling yet')
        else:
            grid_z_vec = torch.linspace(z_near, z_far, steps=D,
                                        dtype=torch.float32, device=device)
            if noise_amount > 0.:
                diff = grid_z_vec[1] - grid_z_vec[0]
                noise = torch.rand(grid_z_vec.shape).float().to(device) * diff * 0.5 * noise_amount
                grid_z_vec = grid_z_vec + noise
                grid_z_vec = grid_z_vec.clamp(min=z_near)

    # one H x W depth plane per slice, packed as B*D
    grid_z = torch.reshape(grid_z_vec, [1, 1, D, 1, 1])
    grid_z = grid_z.repeat([B, 1, 1, H, W])
    grid_z = torch.reshape(grid_z, [B * D, 1, H, W])

    pix_T_camX_ = torch.reshape(pix_T_camX.unsqueeze(1).repeat([1, D, 1, 1]), [B * D, 4, 4])
    xyz_camX = utils_geom.depth2pointcloud(grid_z, pix_T_camX_)

    camR_T_camX = utils_geom.safe_inverse(camX_T_camR)
    camR_T_camX_ = torch.reshape(camR_T_camX.unsqueeze(1).repeat([1, D, 1, 1]), [B * D, 4, 4])
    mem_T_cam = self.get_mem_T_ref(B * D, Z, Y, X)
    memR_T_camX = matmul2(mem_T_cam, camR_T_camX_)

    xyz_memR = utils_geom.apply_4x4(memR_T_camX, xyz_camX)
    xyz_memR = torch.reshape(xyz_memR, [B, D * H * W, 3])

    samp = utils_samp.sample3d(voxR, xyz_memR, D, H, W)
    return samp, grid_z_vec
def get_gt_flow(obj_lrtlist_camRs, obj_scorelist, camRs_T_camXs, Z, Y, X, K=2, mod='', vis=True,
                summ_writer=None):
    """Construct a ground-truth 3D flow field from box trajectories and egomotion.

    For each of the first K objects, the voxels inside its frame-0 box get the
    rigid flow implied by the box motion (obj_lrtlist_camRs, collected from a
    moving camR) combined with egomotion (camRs_T_camXs). Voxels outside all
    object masks get the egomotion-only background flow. The per-layer flows
    are combined with utils_basic.reduce_masked_mean.

    Args:
        obj_lrtlist_camRs: N x B x S x 19 object boxes in camR; S must be 2.
        obj_scorelist: N x B x S object scores.
        camRs_T_camXs: B x S x 4 x 4 transforms from camX into camR.
        Z, Y, X: mem-grid resolution.
        K: number of objects to use (the first K along N).
        mod: suffix for summary names.
        vis: write summaries (only if summ_writer is given).
        summ_writer: summary writer, or None.

    Returns:
        flow: B x 3 x Z x Y x X.
    """
    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2
    # summaries need somewhere to go; this avoids calling methods on None
    vis = vis and (summ_writer is not None)

    # quantities shared by every object and by the background
    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)
    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    xyz_mem0_flat = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem0 = xyz_mem0_flat.reshape(B, Z, Y, X, 3)

    def _warped_grid(mem1_T_mem0):
        # where each mem0 voxel lands under mem1_T_mem0, as B x Z x Y x X x 3
        return utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0_flat).reshape(B, Z, Y, X, 3)

    flows = []
    masks = []
    for k in range(K):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z, Y, X, coeff=1.0)
        # B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)  # B x 1 x Z x Y x X

        if vis:
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod), torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])  # B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)

        # object motion in ref, composed with egomotion, lifted to mem
        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0, camR_T_cam0)
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)
        xyz_mem1 = _warped_grid(mem1_T_mem0)

        # only keep the displaced points inside the object mask
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0

        flow = (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)  # B x 3 x Z x Y x X
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)  # B x 1 x Z x Y x X

        if vis:
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod), flow, clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    # background: egomotion only (zero motion is allowed)
    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)
    bkg_flow = (_warped_grid(mem1_T_mem0) - xyz_mem0).permute(0, 4, 1, 2, 3)

    if masks:
        any_mask = torch.max(torch.stack(masks, dim=0), dim=0)[0]
    else:
        # K == 0: no objects, so the background covers everything
        any_mask = torch.zeros_like(bkg_flow[:, :1])
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, dim=0)
    masks = torch.stack(masks, dim=0).repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)  # B x 3 x Z x Y x X

    if vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)
    return flow