def get_synth_flow(occs,
                   unps,
                   summ_writer,
                   sometimes_zero=False,
                   do_vis=False):
    """Create a synthetic (occs, unps, flow) training sample from frame 0.

    A random rigid motion (translation only) is sampled, frame 0 is warped by
    it to fabricate frame 1, and the dense voxel-space flow between the two
    frames is computed analytically from the sampled motion.

    Args:
        occs: occupancy grid, B x S x C x Z x Y x X, with S==2 and C==1.
        unps: unprojected features with the same B x S leading dims.
        summ_writer: summary writer, used only when do_vis is True.
        sometimes_zero: forwarded to the small-motion sampler so it can
            occasionally return zero motion.
        do_vis: if True, write occ/unp/flow visual summaries.

    Returns:
        occs: B x 2 x C x Z x Y x X stacked (frame0, synthesized frame1).
        unps: the features stacked the same way along dim 1.
        flow: B x 3 x Z x Y x X voxel flow from frame0 to frame1.
        cam1_T_cam0: the sampled 4x4 camera motion.
    """
    B, S, C, Z, Y, X = list(occs.shape)
    # bug fix: the original `assert (S == 2, C == 1)` asserted a 2-tuple,
    # which is always truthy, so the shape check could never fire.
    assert S == 2 and C == 1

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    candidates = [
        utils_geom.get_random_rt(B, r_amount=0.0,
                                 t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B,
            r_amount=0.0,
            t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero)
    ]
    # pick one of the two motion scales uniformly at random
    cam1_T_cam0 = random.choice(candidates)

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    # fabricate frame 1 by warping frame 0 with the sampled motion
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    # express the camera motion in memory (voxel) coordinates, then take the
    # displacement of every voxel center as the ground-truth flow
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X
    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

    if do_vis:
        # sanity check: backwarping frame 1 by the flow should recover frame 0
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1,
                                                   flow,
                                                   binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    return occs, unps, flow, cam1_T_cam0
    def forward(self, feat, summ_writer, mask=None, prefix=""):
        """Encode a voxel feature grid, undoing any train-time augmentation.

        Optionally perturbs the input with a random rigid transform and/or
        random axis flips, runs the network, then inverts the perturbations
        so the output lives in the original coordinate frame.

        Args:
            feat: input feature volume, B x C x D x H x W.
            summ_writer: summary writer for feature visualizations.
            mask: optional mask forwarded to the sparse net variants.
            prefix: string prepended to every summary name.

        Returns:
            (feat, total_loss): l2-normalized output features and the
            accumulated (quantization) loss.
        """
        total_loss = torch.tensor(0.0).cuda()
        B, C, D, H, W = list(feat.shape)
        if not hyp.onlyocc:
            summ_writer.summ_feat(f'feat/{prefix}feat0_input', feat)

        if hyp.feat_do_rt:
            # perturb the volume with a random rigid transform; inverted below
            Y_T_X = utils_geom.get_random_rt(B, r_amount=1.0, t_amount=4.0).cuda()
            feat = utils_vox.apply_4x4_to_vox(Y_T_X, feat)
            summ_writer.summ_feat(f'feat/{prefix}feat1_rt', feat)

        if hyp.feat_do_flip:
            # three independent coin flips; each axis op is undone below
            do_swap = torch.rand(1) > 0.5   # transpose width/depth (90deg rotation)
            do_flip_d = torch.rand(1) > 0.5  # mirror along depth
            do_flip_w = torch.rand(1) > 0.5  # mirror along width
            if do_swap:
                feat = feat.permute(0, 1, 4, 3, 2)
            if do_flip_d:
                feat = feat.flip(2)
            if do_flip_w:
                feat = feat.flip(4)
            summ_writer.summ_feat(f'feat/{prefix}feat2_flip', feat)

        if hyp.feat_do_sb:
            feat = self.net(feat, mask)
        elif hyp.feat_do_sparse_invar:
            feat, mask = self.net(feat, mask)
        elif hyp.feat_quantize:
            feat, feat_uq, loss, encodings, perplexity = self.net(feat)
            total_loss = utils_misc.add_loss('feat_loss', total_loss,
                                             loss, hyp.feat_coeff, summ_writer)
            summ_writer.summ_scalar('feat/perplexity', perplexity)
            summ_writer.summ_histogram('feat/encodings', encodings)
            # NOTE: visualizing the pre-quantization embeddings (via
            # summ_embeddings) makes training very slow; intentionally omitted.
            del feat_uq, encodings, perplexity  # cleanup
        else:
            feat = self.net(feat)
        feat = l2_normalize(feat, dim=1)
        summ_writer.summ_feat(f'feat/{prefix}feat3_out', feat)

        if hyp.feat_do_flip:
            # undo the flips in reverse order of application
            if do_flip_w:
                feat = feat.flip(4)
            if do_flip_d:
                feat = feat.flip(2)
            if do_swap:
                # permute(0,1,4,3,2) is its own inverse
                feat = feat.permute(0, 1, 4, 3, 2)
            summ_writer.summ_feat(f'feat/{prefix}feat4_unflip', feat)

        if hyp.feat_do_rt:
            # map back to the unperturbed coordinate frame
            X_T_Y = utils_geom.safe_inverse(Y_T_X)
            feat = utils_vox.apply_4x4_to_vox(X_T_Y, feat)
            summ_writer.summ_feat(f'feat/{prefix}feat5_unrt', feat)

        return feat, total_loss
# Example #3
# 0
    def forward(self, feat, summ_writer=None, comp_mask=None):
        """Encode a voxel feature grid with a smoothness loss, undoing augmentation.

        Optionally perturbs the input (random rigid transform, random axis
        flips), runs the network, adds a spatial-gradient smoothness loss,
        then inverts the perturbations on the output.

        Args:
            feat: input feature volume, B x C x D x H x W.
            summ_writer: optional summary writer; all visualization is skipped
                when None.
            comp_mask: optional completeness mask, transformed alongside feat
                and consumed/produced by the sparse net variants.

        Returns:
            (feat, valid_mask, total_loss): output features, a mask of nonzero
            output voxels (intersected with comp_mask in the sparse-conv case),
            and the accumulated smoothness loss.
        """
        total_loss = torch.tensor(0.0).cuda()
        B, C, D, H, W = list(feat.shape)

        if summ_writer is not None:
            summ_writer.summ_feat('feat/feat0_input', feat, pca=False)
        if comp_mask is not None:
            if summ_writer is not None:
                summ_writer.summ_feat('feat/mask_input', comp_mask, pca=False)

        if hyp.feat_do_rt:
            # apply a random rt to the feat; inverted below
            Y_T_X = utils_geom.get_random_rt(B, r_amount=1.0,
                                             t_amount=4.0).cuda()
            feat = utils_vox.apply_4x4_to_vox(Y_T_X, feat)
            if comp_mask is not None:
                comp_mask = utils_vox.apply_4x4_to_vox(Y_T_X, comp_mask)
            if summ_writer is not None:
                summ_writer.summ_feat('feat/feat1_rt', feat, pca=False)

        if hyp.feat_do_flip:
            # randomly flip the input; each flip is undone below
            flip0 = torch.rand(1)
            flip1 = torch.rand(1)
            flip2 = torch.rand(1)
            if flip0 > 0.5:
                # transpose width/depth (rotate 90deg)
                feat = feat.permute(0, 1, 4, 3, 2)
                if comp_mask is not None:
                    comp_mask = comp_mask.permute(0, 1, 4, 3, 2)
            if flip1 > 0.5:
                # flip depth
                feat = feat.flip(2)
                if comp_mask is not None:
                    comp_mask = comp_mask.flip(2)
            if flip2 > 0.5:
                # flip width
                feat = feat.flip(4)
                if comp_mask is not None:
                    comp_mask = comp_mask.flip(4)
            if summ_writer is not None:
                summ_writer.summ_feat('feat/feat2_flip', feat, pca=False)

        if hyp.feat_do_sparse_conv:
            feat, comp_mask = self.net(feat, comp_mask)
            if summ_writer is not None:
                summ_writer.summ_feat('feat/mask_output', comp_mask, pca=False)
        elif hyp.feat_do_sparse_invar:
            feat, comp_mask = self.net(feat, comp_mask)
        else:
            feat = self.net(feat)

        # smooth loss: penalize the mean absolute spatial gradient
        dz, dy, dx = gradient3D(feat, absolute=True)
        # fix: use torch's canonical `keepdim` kwarg (the `keepdims` alias is
        # rejected by older torch and was inconsistent with the usage below)
        smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
        if summ_writer is not None:
            summ_writer.summ_oned('feat/smooth_loss',
                                  torch.mean(smooth_vox, dim=3))
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('feat/smooth_loss', total_loss,
                                         smooth_loss, hyp.feat_smooth_coeff,
                                         summ_writer)

        # feat = l2_normalize(feat, dim=1)
        if summ_writer is not None:
            summ_writer.summ_feat('feat/feat3_out', feat)

        if hyp.feat_do_flip:
            # undo the flips in reverse order of application
            if flip2 > 0.5:
                # unflip width
                feat = feat.flip(4)
            if flip1 > 0.5:
                # unflip depth
                feat = feat.flip(2)
            if flip0 > 0.5:
                # untranspose width/depth (permute is its own inverse)
                feat = feat.permute(0, 1, 4, 3, 2)
            if summ_writer is not None:
                summ_writer.summ_feat('feat/feat4_unflip', feat)

        if hyp.feat_do_rt:
            # undo the random rt
            X_T_Y = utils_geom.safe_inverse(Y_T_X)
            feat = utils_vox.apply_4x4_to_vox(X_T_Y, feat)
            if summ_writer is not None:
                summ_writer.summ_feat('feat/feat5_unrt', feat)

        # a voxel is valid if any channel of the output is nonzero
        valid_mask = 1.0 - (feat == 0).all(dim=1, keepdim=True).float()
        if hyp.feat_do_sparse_conv and (comp_mask is not None):
            valid_mask = valid_mask * comp_mask
        if summ_writer is not None:
            summ_writer.summ_feat('feat/valid_mask', valid_mask, pca=False)
        return feat, valid_mask, total_loss