def get_synth_flow(occs, unps, summ_writer, sometimes_zero=False, do_vis=False):
    B, S, C, Z, Y, X = list(occs.shape)
    assert S == 2 and C == 1
    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    cam1_T_cam0 = [
        utils_geom.get_random_rt(B, r_amount=0.0, t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B, r_amount=0.0, t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero),
    ]
    cam1_T_cam0 = random.sample(cam1_T_cam0, k=1)[0]

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0  # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)  # this is B x 3 x Z x Y x X

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

    if do_vis:
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1, flow, binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e], [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    return occs, unps, flow, cam1_T_cam0
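# Hedged usage sketch (not from the original training loop): the grid sizes and
# random inputs below are illustrative assumptions about what get_synth_flow
# expects, namely S=2 views and single-channel occupancy grids, on CUDA like the
# rest of this file.
def _demo_get_synth_flow():
    B, S, C, Z, Y, X = 2, 2, 1, 32, 32, 32
    occs = (torch.rand(B, S, C, Z, Y, X) > 0.5).float().cuda()  # binary occupancy
    unps = torch.rand(B, S, 3, Z, Y, X).cuda()  # unprojected image features
    # summ_writer is only used when do_vis=True, so None is fine here
    occs_synth, unps_synth, flow, cam1_T_cam0 = get_synth_flow(
        occs, unps, summ_writer=None, sometimes_zero=True, do_vis=False)
    # flow is B x 3 x Z x Y x X, in memory (voxel) coordinates
    assert list(flow.shape) == [B, 3, Z, Y, X]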
def forward(self, feat, summ_writer, mask=None, prefix=""):
    total_loss = torch.tensor(0.0).cuda()
    B, C, D, H, W = list(feat.shape)

    if not hyp.onlyocc:
        summ_writer.summ_feat(f'feat/{prefix}feat0_input', feat)

    if hyp.feat_do_rt:
        # apply a random rt to the feat
        # Y_T_X = utils_geom.get_random_rt(B, r_amount=5.0, t_amount=8.0).cuda()
        # Y_T_X = utils_geom.get_random_rt(B, r_amount=1.0, t_amount=8.0).cuda()
        Y_T_X = utils_geom.get_random_rt(B, r_amount=1.0, t_amount=4.0).cuda()
        feat = utils_vox.apply_4x4_to_vox(Y_T_X, feat)
        summ_writer.summ_feat(f'feat/{prefix}feat1_rt', feat)

    if hyp.feat_do_flip:
        # randomly flip the input
        flip0 = torch.rand(1)
        flip1 = torch.rand(1)
        flip2 = torch.rand(1)
        if flip0 > 0.5:
            # transpose width/depth (rotate 90deg)
            feat = feat.permute(0, 1, 4, 3, 2)
        if flip1 > 0.5:
            # flip depth
            feat = feat.flip(2)
        if flip2 > 0.5:
            # flip width
            feat = feat.flip(4)
        summ_writer.summ_feat(f'feat/{prefix}feat2_flip', feat)

    if hyp.feat_do_sb:
        feat = self.net(feat, mask)
    elif hyp.feat_do_sparse_invar:
        feat, mask = self.net(feat, mask)
    else:
        if hyp.feat_quantize:
            feat, feat_uq, loss, encodings, perplexity = self.net(feat)
            total_loss = utils_misc.add_loss('feat_loss', total_loss, loss,
                                             hyp.feat_coeff, summ_writer)
            summ_writer.summ_scalar('feat/perplexity', perplexity)
            summ_writer.summ_histogram('feat/encodings', encodings)
            ## Visualizing encodings will make training very slow.
            ## Use this only for debugging.
            # feat_uq = feat_uq[:1]
            # encodings = encodings[:1]
            # B, C, D2, H2, W2 = feat_uq.shape
            # feat_uq = feat_uq.permute(0, 2, 3, 4, 1)  # [B, D, H, W, C]
            # feat_uq = feat_uq.reshape(B*D2*H2*W2, C)
            # encodings = encodings.flatten()  # [B*D2*H2*W2]
            # summ_writer.summ_embeddings('feat/emb_before_vqvae', feat_uq, encodings)
            del feat_uq, encodings, perplexity  # cleanup
        else:
            feat = self.net(feat)

    feat = l2_normalize(feat, dim=1)
    summ_writer.summ_feat(f'feat/{prefix}feat3_out', feat)

    if hyp.feat_do_flip:
        if flip2 > 0.5:
            # unflip width
            feat = feat.flip(4)
        if flip1 > 0.5:
            # unflip depth
            feat = feat.flip(2)
        if flip0 > 0.5:
            # untranspose width/depth
            feat = feat.permute(0, 1, 4, 3, 2)
        summ_writer.summ_feat(f'feat/{prefix}feat4_unflip', feat)

    if hyp.feat_do_rt:
        # undo the random rt
        X_T_Y = utils_geom.safe_inverse(Y_T_X)
        feat = utils_vox.apply_4x4_to_vox(X_T_Y, feat)
        summ_writer.summ_feat(f'feat/{prefix}feat5_unrt', feat)

    # valid_mask = 1.0 - (feat==0).all(dim=1, keepdim=True).float()
    # if hyp.feat_do_sparse_invar:
    #     valid_mask = valid_mask * mask

    return feat, total_loss
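# Hedged sketch of the flip/unflip bookkeeping used in forward() above: the toy
# check below only verifies that the transpose/flip round trip is the identity
# (the self.net call is omitted); the tensor shape is an illustrative assumption.
def _check_flip_roundtrip():
    feat = torch.randn(2, 8, 16, 16, 16)  # B x C x D x H x W
    out = feat.permute(0, 1, 4, 3, 2)  # transpose width/depth
    out = out.flip(2).flip(4)          # flip depth, then width
    out = out.flip(4).flip(2)          # unflip width, then depth
    out = out.permute(0, 1, 4, 3, 2)   # untranspose width/depth
    assert torch.equal(out, feat)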
def forward(self, feat, summ_writer=None, comp_mask=None):
    total_loss = torch.tensor(0.0).cuda()
    B, C, D, H, W = list(feat.shape)

    if summ_writer is not None:
        summ_writer.summ_feat('feat/feat0_input', feat, pca=False)
    if comp_mask is not None:
        if summ_writer is not None:
            summ_writer.summ_feat('feat/mask_input', comp_mask, pca=False)

    if hyp.feat_do_rt:
        # apply a random rt to the feat
        # Y_T_X = utils_geom.get_random_rt(B, r_amount=5.0, t_amount=8.0).cuda()
        # Y_T_X = utils_geom.get_random_rt(B, r_amount=1.0, t_amount=8.0).cuda()
        Y_T_X = utils_geom.get_random_rt(B, r_amount=1.0, t_amount=4.0).cuda()
        feat = utils_vox.apply_4x4_to_vox(Y_T_X, feat)
        if comp_mask is not None:
            comp_mask = utils_vox.apply_4x4_to_vox(Y_T_X, comp_mask)
        if summ_writer is not None:
            summ_writer.summ_feat('feat/feat1_rt', feat, pca=False)

    if hyp.feat_do_flip:
        # randomly flip the input
        flip0 = torch.rand(1)
        flip1 = torch.rand(1)
        flip2 = torch.rand(1)
        if flip0 > 0.5:
            # transpose width/depth (rotate 90deg)
            feat = feat.permute(0, 1, 4, 3, 2)
            if comp_mask is not None:
                comp_mask = comp_mask.permute(0, 1, 4, 3, 2)
        if flip1 > 0.5:
            # flip depth
            feat = feat.flip(2)
            if comp_mask is not None:
                comp_mask = comp_mask.flip(2)
        if flip2 > 0.5:
            # flip width
            feat = feat.flip(4)
            if comp_mask is not None:
                comp_mask = comp_mask.flip(4)
        if summ_writer is not None:
            summ_writer.summ_feat('feat/feat2_flip', feat, pca=False)

    if hyp.feat_do_sparse_conv:
        feat, comp_mask = self.net(feat, comp_mask)
        if summ_writer is not None:
            summ_writer.summ_feat('feat/mask_output', comp_mask, pca=False)
    elif hyp.feat_do_sparse_invar:
        feat, comp_mask = self.net(feat, comp_mask)
    else:
        feat = self.net(feat)

    # smooth loss
    dz, dy, dx = gradient3D(feat, absolute=True)
    smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('feat/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('feat/smooth_loss', total_loss,
                                     smooth_loss, hyp.feat_smooth_coeff, summ_writer)

    # feat = l2_normalize(feat, dim=1)
    if summ_writer is not None:
        summ_writer.summ_feat('feat/feat3_out', feat)

    if hyp.feat_do_flip:
        if flip2 > 0.5:
            # unflip width
            feat = feat.flip(4)
        if flip1 > 0.5:
            # unflip depth
            feat = feat.flip(2)
        if flip0 > 0.5:
            # untranspose width/depth
            feat = feat.permute(0, 1, 4, 3, 2)
        if summ_writer is not None:
            summ_writer.summ_feat('feat/feat4_unflip', feat)

    if hyp.feat_do_rt:
        # undo the random rt
        X_T_Y = utils_geom.safe_inverse(Y_T_X)
        feat = utils_vox.apply_4x4_to_vox(X_T_Y, feat)
        if summ_writer is not None:
            summ_writer.summ_feat('feat/feat5_unrt', feat)

    valid_mask = 1.0 - (feat == 0).all(dim=1, keepdim=True).float()
    if hyp.feat_do_sparse_conv and (comp_mask is not None):
        valid_mask = valid_mask * comp_mask
    if summ_writer is not None:
        summ_writer.summ_feat('feat/valid_mask', valid_mask, pca=False)

    return feat, valid_mask, total_loss
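# Hedged sketch of the smoothness term computed above, written with explicit
# finite differences instead of the repo's gradient3D helper (the helper's exact
# padding behavior is an assumption here; this is only meant to illustrate the
# |dz| + |dy| + |dx| penalty on the feature grid).
def _demo_smooth_loss(feat):
    # feat: B x C x D x H x W
    dz = torch.abs(feat[:, :, 1:, :, :] - feat[:, :, :-1, :, :])
    dy = torch.abs(feat[:, :, :, 1:, :] - feat[:, :, :, :-1, :])
    dx = torch.abs(feat[:, :, :, :, 1:] - feat[:, :, :, :, :-1])
    return dz.mean() + dy.mean() + dx.mean()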