def crop_zoom_from_mem(mem, lrt, Z2, Y2, X2, additive_pad=0.1):
    """Resample a Z2 x Y2 x X2 zoomed crop of ``mem`` around the box ``lrt``.

    Every voxel of the zoom grid is mapped back into ``mem``'s voxel
    coordinates (zoom -> ref via ``Zoom2Ref``, then ref -> mem via
    ``Ref2Mem``), and ``mem`` is sampled at those locations.

    Args:
        mem: feature volume, B x C x Z x Y x X.
        lrt: one box per batch element, B x 19.
        Z2, Y2, X2: resolution of the zoom grid.
        additive_pad: forwarded unchanged to ``Zoom2Ref``.

    Returns:
        The crop, shaped B x C x Z2 x Y2 x X2.
    """
    B, C, Z, Y, X = list(mem.shape)
    B2, E = list(lrt.shape)
    assert (E == 19)
    assert (B == B2)

    # zoom-grid voxel coords -> reference coords -> mem voxel coords
    grid_zoom = utils_basic.gridcloud3D(B, Z2, Y2, X2, norm=False)
    grid_ref = Zoom2Ref(grid_zoom, lrt, Z2, Y2, X2, additive_pad=additive_pad)
    grid_mem = Ref2Mem(grid_ref, Z, Y, X)

    cropped = utils_samp.sample3D(mem, grid_mem, Z2, Y2, X2)
    return torch.reshape(cropped, [B, C, Z2, Y2, X2])
def get_synth_flow(occs, unps, summ_writer, sometimes_zero=False, do_vis=False):
    """Synthesize a second frame by a random rigid translation and return its flow.

    Frame 0 is warped by a randomly chosen translation-only transform (either
    "large" or "small" motion) to produce frame 1. The dense voxel flow
    that this transform induces is returned alongside.

    Args:
        occs: B x S x C x Z x Y x X occupancies; requires S == 2 and C == 1.
        unps: unprojected RGB volumes stacked like ``occs`` (only index 0 of
            the seq dim is read).
        summ_writer: summary writer; only used when ``do_vis`` is True.
        sometimes_zero: forwarded to the small-motion sampler.
        do_vis: if True, write visualization summaries.

    Returns:
        (occs, unps, flow, cam1_T_cam0): the two-frame occ/unp stacks
        (B x 2 x ...), the B x 3 x Z x Y x X flow, and the B x 4 x 4 transform.
    """
    B, S, C, Z, Y, X = list(occs.shape)
    # NOTE: the previous `assert (S == 2, C == 1)` asserted a non-empty tuple,
    # which is always truthy, so the shape was never actually checked.
    assert S == 2 and C == 1

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    candidates = [
        utils_geom.get_random_rt(B, r_amount=0.0, t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B, r_amount=0.0, t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero),
    ]
    cam1_T_cam0 = random.sample(candidates, k=1)[0]

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    # express the camera motion in voxel coordinates
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)
        # warp frame 1 back with the flow; this should reproduce frame 0
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1, flow, binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e], [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)
    return occs, unps, flow, cam1_T_cam0
def compute_mem1_T_mem0_from_object_flow(flow_mem, mask_mem, occ_mem):
    """Fit one rigid mem0 -> mem1 transform per batch element from a flow field.

    For each batch element, the voxels where ``occ * mask > 0.5`` are selected.
    Each selected voxel's grid coordinate is paired with itself displaced by
    its flow vector, and ``rigid_transform_3D`` fits a 4 x 4 transform to
    those correspondences.

    Args:
        flow_mem: B x 3 x Z x Y x X flow; it is added directly to voxel-grid
            coordinates, so presumably it is in voxel units — TODO confirm.
        mask_mem, occ_mem: per-voxel weights whose product selects points;
            presumably B x 1 x Z x Y x X — TODO confirm.

    Returns:
        B x 4 x 4 tensor holding mem1_T_mem0 for every batch element.
    """
    B, C, Z, Y, X = list(flow_mem.shape)
    assert (C == 3)
    transforms = utils_geom.eye_4x4(B)
    grid0 = utils_basic.gridcloud3D(B, Z, Y, X, norm=False)
    # Batch elements are handled one at a time because each one selects a
    # different number of voxels.
    for idx in range(B):
        keep = torch.where((occ_mem[idx] * mask_mem[idx]).reshape(-1) > 0.5)
        # (3, Z*Y*X) -> (Z*Y*X, 3), then keep only the selected voxels
        displacement = flow_mem[idx].reshape(3, -1).permute(1, 0)[keep]
        src = grid0[idx][keep]
        transforms[idx] = rigid_transform_3D(src, src + displacement)
    return transforms
def forward(self, template_feat, search_feat, template_mask, template_lrt, search_lrt, vox_util, lrt_cam0s, summ_writer=None):
    """Estimate cam1_T_cam0 by matching template points into the search volume.

    For each batch element:
      1. Sample ``self.num_pts`` template voxels inside ``template_mask``;
         points are repeated if there are too few.
      2. Correlate their features against every search voxel.
      3. Take a soft argmax to get each point's match location.
      4. Fit a rigid transform between the matched point sets in cam coords.

    The estimate is supervised against the ground-truth motion implied by
    ``lrt_cam0s[:, 0]`` -> ``lrt_cam0s[:, 1]``. The losses are:
      * a corner loss on a unit cube,
      * a rotation loss on Euler angles in degrees,
      * a translation loss.

    Args:
        template_feat: B x C x ZZ x ZY x ZX, the thing we are searching for.
        search_feat: B x C x Z x Y x X, the featuremap we search in.
        template_mask: mask over the template grid; flattened to B x (ZZ*ZY*ZX).
        template_lrt, search_lrt: boxes that define each zoom grid (used by
            ``vox_util.Zoom2Ref``).
        vox_util: object providing ``Zoom2Ref``.
        lrt_cam0s: ground-truth boxes; seq indices 0 and 1 are read.
        summ_writer: forwarded to ``utils_misc.add_loss``.

    Returns:
        (lrt_cam1_e, total_loss): the seq-0 box moved by the estimated
        transform, and the accumulated weighted loss.
    """
    # template_feat is the thing we are searching for; it is B x C x ZZ x ZY x ZX
    # search_feat is the featuremap where we are searching; it is B x C x Z x Y x X
    total_loss = torch.tensor(0.0).cuda()
    B, C, ZZ, ZY, ZX = list(template_feat.shape)
    _, _, Z, Y, X = list(search_feat.shape)
    xyz0_template = utils_basic.gridcloud3D(B, ZZ, ZY, ZX)
    # this is B x med x 3
    xyz0_cam = vox_util.Zoom2Ref(xyz0_template, template_lrt, ZZ, ZY, ZX)
    # ok, next, after i relocate the object in search coords,
    # i need to transform those coords into cam, and then do svd on that
    search_feat = search_feat.view(B, C, -1)
    # this is B x C x huge
    template_feat = template_feat.view(B, C, -1)
    # this is B x C x med
    template_mask = template_mask.view(B, -1)
    # this is B x med
    # start from identity; each batch element's estimate is written in below
    cam1_T_cam0_e = utils_geom.eye_4x4(B)
    # to simplify the impl, we will iterate over the batch dim
    for b in list(range(B)):
        template_feat_b = template_feat[b]
        template_mask_b = template_mask[b]
        search_feat_b = search_feat[b]
        xyz0_cam_b = xyz0_cam[b]
        # take any points within the mask
        # (torch.where with one argument returns a 1-tuple of index tensors)
        inds = torch.where(template_mask_b > 0)
        # gather up
        template_feat_b = template_feat_b.permute(1, 0)
        # this is med x C (one row per template voxel)
        template_feat_b = template_feat_b[inds]
        xyz0_cam_b = xyz0_cam_b[inds]
        # these are (#masked pts) x C and (#masked pts) x 3
        assert (len(xyz0_cam_b) > 8)  # otw we should have returned early
        # i want to have self.num_pts pts every time
        if len(xyz0_cam_b) < self.num_pts:
            reps = int(self.num_pts / len(xyz0_cam_b)) + 1
            print('only have %d pts; repeating %d times...' % (len(xyz0_cam_b), reps))
            xyz0_cam_b = xyz0_cam_b.repeat(reps, 1)
            template_feat_b = template_feat_b.repeat(reps, 1)
        assert (len(xyz0_cam_b) >= self.num_pts)
        # now trim down to a random subset of exactly self.num_pts points
        perm = np.random.permutation(len(xyz0_cam_b))
        xyz0_cam_b = xyz0_cam_b[perm[:self.num_pts]]
        template_feat_b = template_feat_b[perm[:self.num_pts]]
        heat_b = torch.matmul(template_feat_b, search_feat_b)
        # this is self.num_pts x huge
        # it represents each point's heatmap in the search region
        # make the min zero
        heat_b = heat_b - (torch.min(heat_b, dim=1).values).unsqueeze(1)
        # scale up (by the number of search voxels) before the soft argmax
        heat_b = heat_b * float(len(heat_b[0].reshape(-1)))
        heat_b = heat_b.reshape(self.num_pts, 1, Z, Y, X)
        xyz1_search_b = utils_basic.argmax3D(heat_b, hard=False, stack=True)
        # this is self.num_pts x 3, in search-grid voxel coords
        # i need to get to cam coords
        xyz1_cam_b = vox_util.Zoom2Ref(xyz1_search_b.unsqueeze(0), search_lrt[b:b + 1], Z, Y, X).squeeze(0)
        cam1_T_cam0_e[b] = utils_track.rigid_transform_3D(xyz0_cam_b, xyz1_cam_b)
    _, rt_cam0_g = utils_geom.split_lrt(lrt_cam0s[:, 0])
    _, rt_cam1_g = utils_geom.split_lrt(lrt_cam0s[:, 1])
    # these represent ref_T_obj
    cam1_T_cam0_g = torch.matmul(rt_cam1_g, rt_cam0_g.inverse())
    lrt_cam1_e = utils_geom.apply_4x4_to_lrtlist(cam1_T_cam0_e, lrt_cam0s[:, 0:1]).squeeze(1)
    # let's try the cube loss: compare where each transform sends the 8
    # corners of a unit cube centered at the origin
    lx, ly, lz = 1.0, 1.0, 1.0
    x = np.array([lx / 2., lx / 2., -lx / 2., -lx / 2., lx / 2., lx / 2., -lx / 2., -lx / 2.])
    y = np.array([ly / 2., ly / 2., ly / 2., ly / 2., -ly / 2., -ly / 2., -ly / 2., -ly / 2.])
    z = np.array([lz / 2., -lz / 2., -lz / 2., lz / 2., lz / 2., -lz / 2., -lz / 2., lz / 2.])
    xyz = np.stack([x, y, z], axis=1)
    # this is 8 x 3
    xyz = torch.from_numpy(xyz).float().cuda()
    xyz = xyz.reshape(1, 8, 3)
    # this is 1 x 8 x 3; assumes apply_4x4 broadcasts it against the B x 4 x 4
    # transforms — TODO confirm
    xyz_e = utils_geom.apply_4x4(cam1_T_cam0_e, xyz)
    xyz_g = utils_geom.apply_4x4(cam1_T_cam0_g, xyz)
    corner_loss = self.smoothl1(xyz_e, xyz_g)
    total_loss = utils_misc.add_loss('robust/corner_loss', total_loss, corner_loss, hyp.robust_corner_coeff, summ_writer)
    # separate rotation (Euler angles, degrees) and translation losses
    rot_e, t_e = utils_geom.split_rt(cam1_T_cam0_e)
    rot_g, t_g = utils_geom.split_rt(cam1_T_cam0_g)
    rx_e, ry_e, rz_e = utils_geom.rotm2eul(rot_e)
    rx_g, ry_g, rz_g = utils_geom.rotm2eul(rot_g)
    rad_e = torch.stack([rx_e, ry_e, rz_e], dim=1)
    rad_g = torch.stack([rx_g, ry_g, rz_g], dim=1)
    deg_e = utils_geom.rad2deg(rad_e)
    deg_g = utils_geom.rad2deg(rad_g)
    r_loss = self.smoothl1(deg_e, deg_g)
    t_loss = self.smoothl1(t_e, t_g)
    total_loss = utils_misc.add_loss('robust/r_loss', total_loss, r_loss, hyp.robust_r_coeff, summ_writer)
    total_loss = utils_misc.add_loss('robust/t_loss', total_loss, t_loss, hyp.robust_t_coeff, summ_writer)
    return lrt_cam1_e, total_loss
def forward(self, feat_cam0, feat_cam1, mask_mem0, pix_T_cam0, pix_T_cam1, cam1_T_cam0, vox_util, summ_writer=None):
    """Cross-entropy loss between 2D features at corresponding projected 3D points.

    For each batch element with at least ``self.num_samples`` voxels inside
    ``mask_mem0``, that many voxels are sampled at random. Each sample is:
      * moved to cam0 coordinates,
      * transformed to cam1 by ``cam1_T_cam0``,
      * projected into both images.

    The two feature maps are bilinearly sampled at the projections.
    ``self.compute_ce_loss`` then compares the cam0 vectors against the
    (detached) cam1 vectors.

    Args:
        feat_cam0, feat_cam1: B x C x H x W image features.
        mask_mem0: B x 1 x Z x Y x X voxel mask selecting candidate points.
        pix_T_cam0, pix_T_cam1: per-batch projection matrices.
        cam1_T_cam0: B x 4 x 4 rigid transform between the cameras.
        vox_util: object providing ``Mem2Ref``.
        summ_writer: forwarded to ``utils_misc.add_loss``.

    Returns:
        The accumulated loss. It is a zero tensor if no batch element has
        enough masked voxels.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(mask_mem0.shape)
    assert (C == 1)
    B2, C, H, W = list(feat_cam0.shape)
    assert (B == B2)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    mask_mem0 = mask_mem0.reshape(B, Z * Y * X)

    # Collect the valid batch indices and their sampled coordinates together.
    # Previously the samples were written at index b but the tensors were
    # trimmed with [:valid_batches], which selected the wrong (possibly
    # all-zero) entries whenever an earlier batch element was invalid.
    valid_inds = []
    sampled_coords = []
    for b in range(B):
        xyz_mem0_b = xyz_mem0[b][torch.where(mask_mem0[b] > 0)]
        # this is N x 3
        N = xyz_mem0_b.shape[0]
        if N >= self.num_samples:
            perm = np.random.permutation(N)
            sampled_coords.append(xyz_mem0_b[perm[:self.num_samples]])
            valid_inds.append(b)

    valid_batches = len(valid_inds)
    print('valid_batches:', valid_batches)
    if valid_batches == 0:
        # return early
        return total_loss

    sampling_coords_mem0 = torch.stack(sampled_coords, dim=0).float()
    # this is valid_batches x num_samples x 3
    # list indexing gathers the valid batch elements in order, on each tensor's own device
    feat_cam0 = feat_cam0[valid_inds]
    feat_cam1 = feat_cam1[valid_inds]
    pix_T_cam0 = pix_T_cam0[valid_inds]
    pix_T_cam1 = pix_T_cam1[valid_inds]
    cam1_T_cam0 = cam1_T_cam0[valid_inds]

    xyz_cam0 = vox_util.Mem2Ref(sampling_coords_mem0, Z, Y, X)
    xyz_cam1 = utils_geom.apply_4x4(cam1_T_cam0, xyz_cam0)
    # these are B x N x 3
    # now, project both of these, and sample from the feats
    xy_cam0 = utils_geom.apply_pix_T_cam(pix_T_cam0, xyz_cam0)
    xy_cam1 = utils_geom.apply_pix_T_cam(pix_T_cam1, xyz_cam1)
    # these are B x N x 2
    vec0 = utils_samp.bilinear_sample2D(feat_cam0, xy_cam0[:, :, 0], xy_cam0[:, :, 1])
    vec1 = utils_samp.bilinear_sample2D(feat_cam1, xy_cam1[:, :, 0], xy_cam1[:, :, 1])
    # these are B x C x N
    # reshape (not view): the permuted tensor is not contiguous, so view fails
    vec0 = vec0.permute(0, 2, 1).reshape(valid_batches * self.num_samples, C)
    vec1 = vec1.permute(0, 2, 1).reshape(valid_batches * self.num_samples, C)
    # these are (B*N) x C

    # cam1 features serve as the (fixed) target for the cam0 features
    ce_loss = self.compute_ce_loss(vec0, vec1.detach())
    total_loss = utils_misc.add_loss('tri2D/emb_ce_loss', total_loss, ce_loss, hyp.tri_2D_ce_coeff, summ_writer)
    return total_loss
def assemble(bkg_feat0, obj_feat0, origin_T_camRs, camRs_T_zoom):
    """Compose per-timestep feature volumes from one background and one object.

    The background volume is warped into every timestep's camR frame, which
    creates egomotion. The object feature volume is resampled along its
    zoom trajectory and overwrites the background wherever the object lands.

    Args:
        bkg_feat0: B x C x Z x Y x X background features at timestep 0.
        obj_feat0: B x C x Z2 x Y2 x X2 object features (one for the whole traj).
        origin_T_camRs: B x S x 4 x 4 camera poses.
        camRs_T_zoom: B x S x 4 x 4 zoom-to-camR transforms (not rigid).

    Returns:
        (featRs, validRs, bkg_featRs, obj_featRs):
          * featRs: B x S x C x Z x Y x X, l2-normalized over channels.
          * validRs: B x S x 1 x Z x Y x X, marking voxels that are not all-zero.
          * bkg_featRs, obj_featRs: the two components before compositing.
    """
    # note it makes sense to create egomotion here, because
    # we want to predict each view
    B, C, Z, Y, X = list(bkg_feat0.shape)
    B2, C2, Z2, Y2, X2 = list(obj_feat0.shape)
    assert (B == B2)
    assert (C == C2)
    B, S, _, _ = list(origin_T_camRs.shape)

    # utils for packing/unpacking along seq dim
    __p = lambda x: pack_seqdim(x, B)
    __u = lambda x: unpack_seqdim(x, B)

    # warp the timestep-0 background into every timestep
    cam0s_T_camRs = utils_geom.get_camM_T_camXs(origin_T_camRs, ind=0)
    camRs_T_cam0s = __u(utils_geom.safe_inverse(__p(cam0s_T_camRs)))
    bkg_feat0s = bkg_feat0.unsqueeze(1).repeat(1, S, 1, 1, 1, 1)
    bkg_featRs = apply_4x4s_to_voxs(camRs_T_cam0s, bkg_feat0s)

    # now for the objects: for each location in the camR grid, find where it
    # lands in the object's zoom grid
    xyz_mems_ = utils_basic.gridcloud3D(B * S, Z, Y, X, norm=False)
    # this is B*S x Z*Y*X x 3
    xyz_camRs_ = Mem2Ref(xyz_mems_, Z, Y, X)
    camRs_T_zoom_ = __p(camRs_T_zoom)
    zoom_T_camRs_ = camRs_T_zoom_.inverse()  # note this is not a rigid transform
    xyz_zooms_ = utils_geom.apply_4x4(zoom_T_camRs_, xyz_camRs_)

    # we have one object feat for the whole traj, so we tile it up
    obj_feats = obj_feat0.unsqueeze(1).repeat(1, S, 1, 1, 1, 1)
    obj_feats_ = __p(obj_feats)
    # this is B*S x C x Z2 x Y2 x X2
    obj_featRs_ = utils_samp.sample3D(obj_feats_, xyz_zooms_, Z, Y, X)
    obj_featRs = __u(obj_featRs_)

    # overwrite the bkg at the object. The mask previously came from
    # bkg_featRs, which ignored where the object actually is; it now marks
    # voxels where the resampled object features are non-zero in any channel.
    obj_mask = (obj_featRs != 0).any(dim=2, keepdim=True).float()
    featRs = obj_featRs + (1.0 - obj_mask) * bkg_featRs

    # l2 normalize on chans
    featRs = l2_normalize(featRs, dim=2)
    # keep validRs on the same device as featRs (no forced .cuda())
    validRs = 1.0 - (featRs == 0).all(dim=2, keepdim=True).float()
    return featRs, validRs, bkg_featRs, obj_featRs
def run_test(self, feed):
    """Track the seq-0 object through the sequence by 3D feature matching.

    When ``hyp.do_feat3D`` is set:
      1. Features of the object voxels at timestep 0 are matched (soft
         argmax over a freespace-suppressed correlation heatmap) into each
         timestep's feature volume.
      2. A rigid transform is fit to the matches.
      3. The box is propagated and scored against ground truth by IoU.

    The voxel grid is re-centred on the previous estimate at every step.

    Otherwise, a zero-motion baseline (the seq-0 box repeated) is scored.

    Args:
        feed: input dict; not read here (state comes from ``self``).

    Returns:
        (total_loss, results, early_return). ``results['ious']`` is a B x S
        IoU tensor on the normal path. ``early_return`` is True when the
        inputs were unusable and nothing was computed.
    """
    results = dict()
    total_loss = torch.tensor(0.0).cuda()

    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

    self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(self.lrt_camX0s)
    self.original_centroid = self.scene_centroid.clone()

    for b in range(self.B):
        if self.score_s[b, 0] < 1.0:
            # we need the template to exist
            print('returning early, since score_s[%d,0] = %.1f' % (b, self.score_s[b, 0].cpu().numpy()))
            return total_loss, results, True
        if not (torch.sum(self.score_s[b]) == self.S):
            # the full traj should be valid
            # (report this batch element's sum, not the whole batch's)
            print('returning early, since sum(score_s) = %d, while S = %d' % (torch.sum(self.score_s[b]).cpu().numpy(), self.S))
            return total_loss, results, True

    if hyp.do_feat3D:
        feat_memX0_input = torch.cat([
            self.occ_memX0s[:, 0],
            self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
        ], dim=1)
        _, feat_memX0, _ = self.featnet3D(feat_memX0_input)
        B, C, Z, Y, X = list(feat_memX0.shape)
        S = self.S

        obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
            self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
        # only take the occupied voxels
        occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y, X)
        obj_mask_memX0 = obj_mask_memX0s[:, 0]
        # discard the known freespace
        _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs[:, 0:1], self.xyz_camXs[:, 0:1], Z, Y, X, agg=True)
        free_memX0 = free_memX0_.squeeze(1)
        obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

        for b in range(self.B):
            if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                print('returning early, since there are not enough valid object points')
                return total_loss, results, True

        feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1).permute(0, 2, 1)
        # this is B x huge x C
        obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
        occ_mask0_vec = occ_memX0.reshape(B, -1).round()
        free_mask0_vec = free_memX0.reshape(B, -1).round()
        # these are B x huge
        orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
        # this is B x huge x 3

        obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
        obj_length = obj_lengths[:, 0]
        cam0_T_obj = cams_T_obj0[:, 0]
        # this is B x 4 x 4

        # transforms of the ORIGINAL grid; later grids are re-centred, and
        # the per-step `delta` maps matches back into this grid
        mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
        cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

        lrt_camIs_g = self.lrt_camX0s.clone()
        lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
        # we will fill this up
        ious = torch.zeros([B, S]).float().cuda()
        point_counts = np.zeros([B, S])
        inb_counts = np.zeros([B, S])

        feat_vis = []
        occ_vis = []

        careful = False  # if True, mix occupied points with unobserved in-mask points
        max_pts = 2000  # cap on template points per batch element

        for s in range(self.S):
            if s > 0:
                # remake the vox util and all the mem data, centred on the last estimate
                self.scene_centroid = utils_geom.get_clist_from_lrtlist(lrt_camIs_e[:, s - 1:s])[:, 0]
                delta = self.scene_centroid - self.original_centroid
                self.vox_util = vox_util.Vox_util(
                    self.Z, self.Y, self.X, self.set_name,
                    scene_centroid=self.scene_centroid, assert_cube=True)
                self.occ_memXs = __u(self.vox_util.voxelize_xyz(
                    __p(self.xyz_camXs), self.Z, self.Y, self.X))
                self.occ_memX0s = __u(self.vox_util.voxelize_xyz(
                    __p(self.xyz_camX0s), self.Z, self.Y, self.X))
                self.unp_memXs = __u(self.vox_util.unproject_rgb_to_mem(
                    __p(self.rgb_camXs), self.Z, self.Y, self.X, __p(self.pix_T_cams)))
                self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(self.camX0s_T_camXs, self.unp_memXs)
                self.summ_writer.summ_occ('track/reloc_occ_%d' % s, self.occ_memX0s[:, s])
            else:
                self.summ_writer.summ_occ('track/init_occ_%d' % s, self.occ_memX0s[:, s])
                delta = torch.zeros([B, 3]).float().cuda()

            occ_vis.append(self.summ_writer.summ_occ('', self.occ_memX0s[:, s], only_return=True))

            # previously passed self.X here, mixing resolutions with Z4/Y4
            inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s], self.Z4, self.Y4, self.X4, already_mem=False)
            num_inb = torch.sum(inb.float(), dim=1)
            inb_counts[:, s] = num_inb.cpu().numpy()

            feat_memI_input = torch.cat([
                self.occ_memX0s[:, s],
                self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
            ], dim=1)
            _, feat_memI, _ = self.featnet3D(feat_memI_input)

            self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s, feat_memI_input, pca=True)
            self.summ_writer.summ_feat('3D_feats/feat_%d' % s, feat_memI, pca=True)
            feat_vis.append(self.summ_writer.summ_feat('', feat_memI, pca=True, only_return=True))

            # collect freespace here, to discard bad matches
            _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, s:s + 1], self.xyz_camXs[:, s:s + 1], Z, Y, X, agg=True)
            free_memI = free_memI_.squeeze(1)

            feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1).permute(0, 2, 1)
            # this is B x huge x C

            memI_T_mem0 = utils_geom.eye_4x4(B)
            # we will fill this up

            # to simplify the impl, we will iterate over the batch dim
            for b in range(B):
                feat_vec_b = feat_vec[b]
                feat0_vec_b = feat0_vec[b]
                obj_mask0_vec_b = obj_mask0_vec[b]
                occ_mask0_vec_b = occ_mask0_vec[b]
                free_mask0_vec_b = free_mask0_vec[b]
                orig_xyz_b = orig_xyz[b]

                if careful:
                    # start with occ points, since these are definitely observed
                    obj_inds_b = torch.where((occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # also take random non-free non-occ points in the mask
                    ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (1.0 - free_mask0_vec_b)
                    alt_inds_b = torch.where(ok_mask > 0)
                    alt_vec_b = feat0_vec_b[alt_inds_b]
                    alt_xyz0 = orig_xyz_b[alt_inds_b]
                    num = len(alt_xyz0)
                    if num > max_pts:
                        perm = np.random.permutation(num)
                        alt_vec_b = alt_vec_b[perm[:max_pts]]
                        alt_xyz0 = alt_xyz0[perm[:max_pts]]
                    obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                    xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                    if s == 0:
                        print('have %d pts in total' % (len(xyz0)))
                else:
                    # take any points within the mask
                    obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # issues arise when "med" is too large; trim down to max_pts
                    num = len(xyz0)
                    if num > max_pts:
                        print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                        perm = np.random.permutation(num)
                        obj_vec_b = obj_vec_b[perm[:max_pts]]
                        xyz0 = xyz0[perm[:max_pts]]

                obj_vec_b = obj_vec_b.permute(1, 0)
                # this is C x med
                corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                # this is huge x med
                heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                # this is med x 1 x Z x Y x X

                # make the min zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_min
                # zero out the freespace
                heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                # make the max zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_max
                # scale up (by the number of voxels) before the soft argmax
                heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                # this is med x 3, in the current (re-centred) grid
                xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y, X)
                # this is med x 1 x 3; shift by THIS element's delta (1 x 3).
                # Adding the full B x 3 delta broadcast to med x B x 3 when B > 1.
                xyzI_cam = xyzI_cam + delta[b:b + 1]
                xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                # record #points, since ransac depends on this
                point_counts[b, s] = len(xyz0)
            # done stepping through batch

            # eval
            camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0, mem_T_cam, cam0_T_obj)
            # this is B x 4 x 4
            lrt_camIs_e[:, s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:, s:s + 1]).squeeze(1)
        results['ious'] = ious

        self.summ_writer.summ_rgbs('track/feats', feat_vis)
        self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

        for s in range(self.S):
            self.summ_writer.summ_scalar('track/mean_iou_%02d' % s, torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())
        self.summ_writer.summ_scalar('track/point_counts', np.mean(point_counts))
        self.summ_writer.summ_scalar('track/inb_counts', np.mean(inb_counts))

        lrt_camX0s_e = lrt_camIs_e.clone()
        lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(self.camXs_T_camX0s, lrt_camX0s_e)

        if self.include_vis:
            visX_e = []
            for s in range(self.S):
                visX_e.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_e' % s, self.rgb_camXs[:, s], lrt_camXs_e[:, s:s + 1],
                    self.score_s[:, s:s + 1], self.tid_s[:, s:s + 1], self.pix_T_cams[:, 0],
                    only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
            visX_g = []
            for s in range(self.S):
                visX_g.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_g' % s, self.rgb_camXs[:, s], self.lrt_camXs[:, s:s + 1],
                    self.score_s[:, s:s + 1], self.tid_s[:, s:s + 1], self.pix_T_cams[:, 0],
                    only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
        dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
        # this is B x S
        mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
        median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
        # this is []
        self.summ_writer.summ_scalar('track/centroid_dist_mean', mean_dist.cpu().item())
        self.summ_writer.summ_scalar('track/centroid_dist_median', median_dist.cpu().item())

        self.summ_writer.summ_traj_on_occ('track/traj_e', obj_clist_camX0_e, self.occ_memX0s[:, 0],
                                          self.vox_util, already_mem=False, sigma=2)
        self.summ_writer.summ_traj_on_occ('track/traj_g', self.obj_clist_camX0, self.occ_memX0s[:, 0],
                                          self.vox_util, already_mem=False, sigma=2)
        total_loss += mean_dist  # we won't backprop, but it's nice to plot and print this anyway
    else:
        # zero-motion baseline: the seq-0 box, repeated
        ious = torch.zeros([self.B, self.S]).float().cuda()
        for s in range(self.S):
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                self.lrt_camX0s[:, 0:1], self.lrt_camX0s[:, s:s + 1]).squeeze(1)
        results['ious'] = ious
        for s in range(self.S):
            self.summ_writer.summ_scalar('track/mean_iou_%02d' % s, torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())

        lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
        self.summ_writer.summ_traj_on_occ('track/traj_e', obj_clist_camX0_e, self.occ_memX0s[:, 0],
                                          self.vox_util, already_mem=False, sigma=2)
        self.summ_writer.summ_traj_on_occ('track/traj_g', self.obj_clist_camX0, self.occ_memX0s[:, 0],
                                          self.vox_util, already_mem=False, sigma=2)

    self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
    return total_loss, results, False
def get_gt_flow(obj_lrtlist_camRs, obj_scorelist, camRs_T_camXs, Z, Y, X, K=2, mod='', vis=True, summ_writer=None):
    """Build a ground-truth voxel flow field from object trajectories and egomotion.

    For each of the first ``K`` objects, a rigid flow is computed inside the
    object's padded mask. It combines the object's motion between the two
    timesteps with the camera motion. Everywhere outside every object mask,
    the flow is the pure egomotion flow. The pieces are averaged with their
    masks as weights.

    Args:
        obj_lrtlist_camRs: N x B x S x 19 box trajectories in camR; S must be 2.
        obj_scorelist: N x B x S scores matching the boxes.
        camRs_T_camXs: B x S x 4 x 4 camera poses.
        Z, Y, X: voxel grid resolution.
        K: number of objects to use (indexes the first dim of the lists).
        mod: suffix added to summary names.
        vis: if True (and a writer is given), write summaries.
        summ_writer: summary writer; may be None, in which case nothing is written.

    Returns:
        B x 3 x Z x Y x X flow.
    """
    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2
    # every summary call below needs a writer; vis=True with the default
    # summ_writer=None used to crash on the unguarded calls
    do_vis = vis and (summ_writer is not None)

    # these do not depend on the object, so compute them once
    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)
    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    def _flow_from_cam1_T_cam0(cam1_T_cam0, mask=None):
        # Voxel displacement (B x 3 x Z x Y x X) induced by cam1_T_cam0; where
        # `mask` (B x Z x Y x X x 1) is not positive, the displacement is zero.
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)
        xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
        xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
        xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
        if mask is not None:
            cond = (mask > 0.0).float()
            xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0
        return (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)

    flows = []
    masks = []
    for k in range(K):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z, Y, X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X
        if do_vis:
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod), torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        # object motion in the ref frame, combined with the camera motion
        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0, camR_T_cam0)

        # only use the displaced points within the obj mask
        flow = _flow_from_cam1_T_cam0(cam1_T_cam0, mask=obj_mask0.view(B, Z, Y, X, 1))
        if do_vis:
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod), flow, clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    # the background moves only by egomotion (this may be zero motion)
    bkg_flow = _flow_from_cam1_T_cam0(utils_basic.matmul2(cam1_T_camR, camR_T_cam0))

    any_mask = torch.max(torch.stack(masks, axis=0), axis=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, axis=0)
    masks = torch.stack(masks, axis=0)
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)

    if do_vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x Z x Y x X
    return flow