def get_synth_flow(occs,
                   unps,
                   summ_writer,
                   sometimes_zero=False,
                   do_vis=False):
    """Synthesize a second frame and the matching dense 3D flow from frame 0.

    Frame 0 of ``occs`` and ``unps`` is warped by a randomly drawn,
    translation-only rigid transform to create a synthetic frame 1.  The
    flow is the displacement of every voxel-grid coordinate under that
    same transform, expressed in memory (voxel grid) coordinates.

    Args:
        occs: tensor whose shape unpacks as B x S x C x Z x Y x X, with
            S == 2 and C == 1.  Only ``occs[:, 0]`` is used.
        unps: tensor indexed as ``unps[:, 0]``.
            NOTE(review): presumably B x S x 3 x Z x Y x X -- confirm.
        summ_writer: summary writer; only touched when ``do_vis`` is True.
        sometimes_zero: forwarded to ``utils_geom.get_random_rt`` for the
            small-motion candidate.
        do_vis: when True, log the synthetic pair, the flow, and the
            frame-1 data backwarped with the flow.

    Returns:
        Tuple ``(occs, unps, flow, cam1_T_cam0)``: the [frame 0, synthetic
        frame 1] pairs stacked along dim 1, the flow as B x 3 x Z x Y x X,
        and the sampled transform.

    Raises:
        AssertionError: if ``occs`` does not have S == 2 and C == 1.
    """
    B, S, C, Z, Y, X = list(occs.shape)
    # note: `assert (a, b)` tests a non-empty tuple and can never fail,
    # so the two conditions have to be joined explicitly
    assert S == 2 and C == 1, (
        'expected S == 2 and C == 1, got S = %d, C = %d' % (S, C))

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    # both candidates are generated before one is picked
    candidates = [
        utils_geom.get_random_rt(B, r_amount=0.0,
                                 t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B,
            r_amount=0.0,
            t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero)
    ]
    cam1_T_cam0 = random.sample(candidates, k=1)[0]

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    # express the camera-space motion in memory (voxel grid) coordinates
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

        # sanity check: backwarping frame 1 with the flow should give
        # back something close to frame 0
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1,
                                                   flow,
                                                   binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    return occs, unps, flow, cam1_T_cam0
# Example #2
# 0
    def run_test(self, feed):
        """Track the frame-0 object through all S frames and log the results.

        When ``hyp.do_feat3D`` is set, ``self.featnet3D`` features are
        computed for frame 0 and the features of the voxels inside the
        frame-0 object mask become the template.  For every frame ``s``,
        each template feature is correlated against the full feature
        volume of that frame; a soft argmax over each correlation heatmap
        yields one point correspondence, and
        ``utils_track.rigid_transform_3D`` converts the correspondences
        into a transform that carries the frame-0 box to frame ``s``.
        From the second frame on, ``self.vox_util`` is rebuilt around the
        previously estimated object centroid and the voxelized inputs
        stored on ``self`` are regenerated.

        Otherwise a constant-position baseline is scored: the frame-0 box
        is compared against the box of every frame.

        Args:
            feed: dict for this iteration; only ``feed['global_step']`` is
                read here.

        Returns:
            Tuple ``(total_loss, results, early_return)``.  ``total_loss``
            is a scalar tensor (the mean centroid distance in the feat3D
            branch, zero otherwise), ``results['ious']`` is B x S when the
            run completes, and ``early_return`` is True when the sample
            was skipped.

        Side effects:
            Overwrites ``self.obj_clist_camX0`` and
            ``self.original_centroid``; in the feat3D branch also
            ``self.scene_centroid``, ``self.vox_util``, ``self.occ_memXs``,
            ``self.occ_memX0s``, ``self.unp_memXs`` and
            ``self.unp_memX0s``.  Writes summaries via ``self.summ_writer``.
        """
        results = dict()

        # not used below; the lookup is kept so a feed without this key
        # still fails here
        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(
            self.lrt_camX0s)

        self.original_centroid = self.scene_centroid.clone()

        for b in range(self.B):
            if self.score_s[b, 0] < 1.0:
                # we need the template to exist
                print('returning early, since score_s[%d,0] = %.1f' %
                      (b, self.score_s[b, 0].cpu().numpy()))
                return total_loss, results, True
            score_sum_b = torch.sum(self.score_s[b])
            if not (score_sum_b == self.S):
                # the full traj should be valid
                # (report the sum of the sample that failed the check,
                # not the sum over the whole batch)
                print(
                    'returning early, since sum(score_s) = %d, while S = %d' %
                    (score_sum_b.cpu().numpy(), self.S))
                return total_loss, results, True

        if hyp.do_feat3D:

            feat_memX0_input = torch.cat([
                self.occ_memX0s[:, 0],
                self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
            ],
                                         dim=1)
            _, feat_memX0, _ = self.featnet3D(feat_memX0_input)
            B, C, Z, Y, X = list(feat_memX0.shape)
            S = self.S

            obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
                self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
            occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y,
                                                   X)
            # the template mask is the full box mask (not only its
            # occupied voxels)
            obj_mask_memX0 = obj_mask_memX0s[:, 0]

            # discard the known freespace
            _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, 0:1],
                self.xyz_camXs[:, 0:1],
                Z,
                Y,
                X,
                agg=True)
            free_memX0 = free_memX0_.squeeze(1)
            obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

            for b in range(self.B):
                if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                    print(
                        'returning early, since there are not enough valid object points'
                    )
                    return total_loss, results, True

            feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat0_vec = feat0_vec.permute(0, 2, 1)
            # this is B x huge x C

            obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
            occ_mask0_vec = occ_memX0.reshape(B, -1).round()
            free_mask0_vec = free_memX0.reshape(B, -1).round()
            # these are B x huge

            orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
            # this is B x huge x 3

            obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(
                self.lrt_camX0s)
            # cams_T_obj0 is B x S x 4 x 4; we keep the frame-0 entries
            obj_length = obj_lengths[:, 0]
            cam0_T_obj = cams_T_obj0[:, 0]

            # note these come from the vox util of the ORIGINAL centroid;
            # per-frame recentering is compensated with `delta` below
            mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
            cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

            lrt_camIs_g = self.lrt_camX0s.clone()
            lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
            # we will fill this up

            ious = torch.zeros([B, S]).float().cuda()
            point_counts = np.zeros([B, S])
            inb_counts = np.zeros([B, S])

            feat_vis = []
            occ_vis = []

            for s in range(self.S):
                if not (s == 0):
                    # remake the vox util and all the mem data, centered
                    # on the object location estimated for the last frame
                    self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                        lrt_camIs_e[:, s - 1:s])[:, 0]
                    delta = self.scene_centroid - self.original_centroid
                    self.vox_util = vox_util.Vox_util(
                        self.Z,
                        self.Y,
                        self.X,
                        self.set_name,
                        scene_centroid=self.scene_centroid,
                        assert_cube=True)
                    self.occ_memXs = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camXs), self.Z,
                                                   self.Y, self.X))
                    self.occ_memX0s = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                   self.Z, self.Y, self.X))

                    self.unp_memXs = __u(
                        self.vox_util.unproject_rgb_to_mem(
                            __p(self.rgb_camXs), self.Z, self.Y, self.X,
                            __p(self.pix_T_cams)))
                    self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
                        self.camX0s_T_camXs, self.unp_memXs)
                    self.summ_writer.summ_occ('track/reloc_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                else:
                    self.summ_writer.summ_occ('track/init_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                    delta = torch.zeros([B, 3]).float().cuda()
                occ_vis.append(
                    self.summ_writer.summ_occ('',
                                              self.occ_memX0s[:, s],
                                              only_return=True))

                # NOTE(review): this mixes self.Z4/self.Y4 with self.X
                # (not an X4) -- confirm that is intended
                inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s],
                                                 self.Z4,
                                                 self.Y4,
                                                 self.X,
                                                 already_mem=False)
                num_inb = torch.sum(inb.float(), axis=1)
                inb_counts[:, s] = num_inb.cpu().numpy()

                feat_memI_input = torch.cat([
                    self.occ_memX0s[:, s],
                    self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
                ],
                                            dim=1)
                _, feat_memI, _ = self.featnet3D(feat_memI_input)

                self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s,
                                           feat_memI_input,
                                           pca=True)
                self.summ_writer.summ_feat('3D_feats/feat_%d' % s,
                                           feat_memI,
                                           pca=True)
                feat_vis.append(
                    self.summ_writer.summ_feat('',
                                               feat_memI,
                                               pca=True,
                                               only_return=True))

                # collect freespace here, to discard bad matches
                _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                    self.camX0s_T_camXs[:, s:s + 1],
                    self.xyz_camXs[:, s:s + 1],
                    Z,
                    Y,
                    X,
                    agg=True)
                free_memI = free_memI_.squeeze(1)

                feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
                # this is B x C x huge
                feat_vec = feat_vec.permute(0, 2, 1)
                # this is B x huge x C

                memI_T_mem0 = utils_geom.eye_4x4(B)
                # we will fill this up

                # to simplify the impl, we will iterate over the batch dim
                for b in range(B):
                    feat_vec_b = feat_vec[b]
                    feat0_vec_b = feat0_vec[b]
                    obj_mask0_vec_b = obj_mask0_vec[b]
                    occ_mask0_vec_b = occ_mask0_vec[b]
                    free_mask0_vec_b = free_mask0_vec[b]
                    orig_xyz_b = orig_xyz[b]
                    # these are huge x C

                    # developer toggle for the template point selection
                    careful = False
                    if careful:
                        # start with occ points, since these are definitely observed
                        obj_inds_b = torch.where(
                            (occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # also take random non-free non-occ points in the mask
                        ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (
                            1.0 - free_mask0_vec_b)
                        alt_inds_b = torch.where(ok_mask > 0)
                        alt_vec_b = feat0_vec_b[alt_inds_b]
                        alt_xyz0 = orig_xyz_b[alt_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        num = len(alt_xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            perm = np.random.permutation(num)
                            alt_vec_b = alt_vec_b[perm[:max_pts]]
                            alt_xyz0 = alt_xyz0[perm[:max_pts]]

                        obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                        xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                        if s == 0:
                            print('have %d pts in total' % (len(xyz0)))
                    else:
                        # take any points within the mask
                        obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        # trim down to max_pts
                        num = len(xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            print(
                                'have %d pts; taking a random set of %d pts inside'
                                % (num, max_pts))
                            perm = np.random.permutation(num)
                            obj_vec_b = obj_vec_b[perm[:max_pts]]
                            xyz0 = xyz0[perm[:max_pts]]

                    obj_vec_b = obj_vec_b.permute(1, 0)
                    # this is C x med

                    corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                    # this is huge x med

                    heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                    # this is med x 1 x Z x Y x X

                    # make the min zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_min
                    # zero out the freespace
                    heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                    # make the max zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_max
                    # scale up, for numerical stability
                    heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                    xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                    # this is med x 3

                    # the argmax lives in the recentered grid; shift by this
                    # sample's centroid offset and map back to mem coords
                    # (delta is B x 3, so it must be indexed by b: adding
                    # all of it only works when B == 1)
                    xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y,
                                                     X)
                    xyzI_cam += delta[b]
                    xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                    memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                    # record #points, since ransac depends on this
                    point_counts[b, s] = len(xyz0)
                # done stepping through batch

                # eval
                camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0,
                                                 mem_T_cam, cam0_T_obj)
                # this is B x 4 x 4
                lrt_camIs_e[:,
                            s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:,
                                                         s:s + 1]).squeeze(1)
            results['ious'] = ious

            self.summ_writer.summ_rgbs('track/feats', feat_vis)
            self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())

            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())
            self.summ_writer.summ_scalar('track/point_counts',
                                         np.mean(point_counts))
            self.summ_writer.summ_scalar('track/inb_counts',
                                         np.mean(inb_counts))

            lrt_camX0s_e = lrt_camIs_e.clone()
            lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(
                self.camXs_T_camX0s, lrt_camX0s_e)

            if self.include_vis:
                visX_e = []
                for s in range(self.S):
                    visX_e.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_e' % s,
                                                      self.rgb_camXs[:, s],
                                                      lrt_camXs_e[:, s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
                visX_g = []
                for s in range(self.S):
                    visX_g.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_g' % s,
                                                      self.rgb_camXs[:, s],
                                                      self.lrt_camXs[:,
                                                                     s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)

            dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
            # this is B x S
            mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
            median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
            # this is []
            self.summ_writer.summ_scalar('track/centroid_dist_mean',
                                         mean_dist.cpu().item())
            self.summ_writer.summ_scalar('track/centroid_dist_median',
                                         median_dist.cpu().item())

            # the trajectories are always logged (not gated on include_vis)
            self.summ_writer.summ_traj_on_occ('track/traj_e',
                                              obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g',
                                              self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            total_loss += mean_dist  # we won't backprop, but it's nice to plot and print this anyway

        else:
            # baseline: pretend the object never leaves its frame-0 box
            ious = torch.zeros([self.B, self.S]).float().cuda()
            for s in range(self.S):
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    self.lrt_camX0s[:, 0:1],
                    self.lrt_camX0s[:, s:s + 1]).squeeze(1)
            results['ious'] = ious
            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())
            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())

            lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
            self.summ_writer.summ_traj_on_occ('track/traj_e',
                                              obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g',
                                              self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)

        self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, False
def get_gt_flow(obj_lrtlist_camRs,
                obj_scorelist,
                camRs_T_camXs,
                Z,
                Y,
                X,
                K=2,
                mod='',
                vis=True,
                summ_writer=None):
    """Build a dense 3D flow field from box trajectories and egomotion.

    For each of the first ``K`` objects, the box motion between the two
    frames (``obj_lrtlist_camRs``) is combined with the camera poses
    (``camRs_T_camXs``) into one rigid transform, and the displacement it
    induces is kept only inside that object's frame-0 mask.  A flow built
    from the camera poses alone is used wherever no object mask is
    active.  The per-object and background flows are merged with a masked
    mean over the object axis.

    Args:
        obj_lrtlist_camRs: tensor whose shape unpacks as N x B x S x D,
            with S == 2; indexed per object as ``obj_lrtlist_camRs[k]``.
        obj_scorelist: tensor indexed as ``obj_scorelist[k, :, 0:1]``.
        camRs_T_camXs: transforms indexed as ``[:, 0]`` and ``[:, 1]``.
            NOTE(review): presumably B x S x 4 x 4 -- confirm.
        Z, Y, X: resolution of the voxel grid.
        K: number of objects to use, taken from the front of the object
            axis.  NOTE(review): K > N would index out of range.
        mod: suffix appended to the summary names.
        vis: whether to log visualizations; ignored when ``summ_writer``
            is None.
        summ_writer: optional summary writer.

    Returns:
        The merged flow, shaped B x 3 x Z x Y x X.

    Raises:
        AssertionError: if S != 2.
    """
    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2

    # vis defaults to True while summ_writer defaults to None, so every
    # logging site has to check both
    do_vis = vis and (summ_writer is not None)

    # everything below is shared by all objects and by the background
    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    xyz_flat0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem0 = xyz_flat0.reshape(B, Z, Y, X, 3)

    flows = []
    masks = []
    for k in range(K):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z,
            Y,
            X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X

        if do_vis:
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod),
                                  torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        # these are B x 4 x 4

        # object motion in the reference frame, wrapped with the camera
        # poses, then expressed in memory coordinates
        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0,
                                          camR_T_cam0)
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_flat0)
        xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

        # only use these displaced points within the obj mask
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0

        flow = (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)

        if do_vis:
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod),
                                     flow,
                                     clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    # background: the flow given by the camera poses alone
    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_flat0)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    bkg_flow = (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)

    # the background flow applies wherever no object mask is active
    any_mask = torch.max(torch.stack(masks, dim=0), dim=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, dim=0)
    masks = torch.stack(masks, dim=0)
    # masks have one channel; tile them over the three flow channels
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)

    if do_vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x Z x Y x X
    return flow