Example #1
    def summ_lrtlist(self,
                     name,
                     rgbR,
                     lrtlist,
                     scorelist,
                     tidlist,
                     pix_T_cam,
                     only_return=False):
        # rgbR is B x C x H x W
        # lrtlist is B x N x 19 (3 box lengths + a flattened 4x4 RT)
        # scorelist is B x N
        # tidlist is B x N
        # pix_T_cam is B x 4 x 4

        B, C, H, W = list(rgbR.shape)
        B, N, D = list(lrtlist.shape)
        lenlist = lrtlist[:, :, :3].reshape(B, N, 3)
        rtlist = lrtlist[:, :, 3:].reshape(B, N, 4, 4)

        xyzlist_obj = utils_geom.get_xyzlist_from_lenlist(lenlist)
        # this is B x N x 8 x 3

        rtlist_ = rtlist.reshape(B * N, 4, 4)
        xyzlist_obj_ = xyzlist_obj.reshape(B * N, 8, 3)
        xyzlist_cam_ = utils_geom.apply_4x4(rtlist_, xyzlist_obj_)
        xyzlist_cam = xyzlist_cam_.reshape(B, N, 8, 3)

        boxes_vis = self.draw_corners_on_image(rgbR, xyzlist_cam, scorelist,
                                               tidlist, pix_T_cam)
        if not only_return:
            self.summ_rgb(name, boxes_vis)
        return boxes_vis
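
# A minimal sketch (not part of the repo) of how an "lrt" vector is laid out:
# 3 box lengths followed by a flattened 4x4 ref_T_obj pose, so D = 3 + 16 = 19.
# All names and shapes below are illustrative.
import torch

B, N = 2, 5
lenlist = torch.rand(B, N, 3)                         # box sizes (lx, ly, lz)
rtlist = torch.eye(4).view(1, 1, 16).repeat(B, N, 1)  # flattened 4x4 poses
lrtlist = torch.cat([lenlist, rtlist], dim=2)         # B x N x 19
assert lrtlist.shape == (B, N, 19)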
Example #2
def Ref2Mem(xyz, mem_coord):
    # xyz is B x N x 3, in ref coordinates
    # transforms camR coordinates into mem coordinates
    # (0, 0, 0) corresponds to (Xmin, Ymin, Zmin)
    B, N, C = list(xyz.shape)
    mem_T_ref = get_mem_T_ref(B, mem_coord)
    xyz = utils_geom.apply_4x4(mem_T_ref, xyz)
    return xyz
Example #3
def Mem2Ref(xyz_mem, mem_coord):
    # xyz_mem is B x N x 3, in mem coordinates
    # transforms mem coordinates into ref coordinates
    B, N, C = list(xyz_mem.shape)
    ref_T_mem = get_ref_T_mem(B, mem_coord)
    xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
    return xyz_ref
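
# Sketch of the Ref<->Mem round trip (self-contained; apply_4x4 below is a
# stand-in for utils_geom.apply_4x4, and mem_T_ref is a toy grid transform):
# since mem_T_ref and ref_T_mem are inverses, the round trip is the identity.
import torch

def apply_4x4(RT, xyz):                        # RT: B x 4 x 4, xyz: B x N x 3
    ones = torch.ones_like(xyz[:, :, :1])
    xyz1 = torch.cat([xyz, ones], dim=2)       # homogeneous, B x N x 4
    return torch.bmm(xyz1, RT.transpose(1, 2))[:, :, :3]

B, N = 2, 100
mem_T_ref = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
mem_T_ref[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])  # toy (Xmin, Ymin, Zmin) offset
xyz_ref = torch.rand(B, N, 3)
xyz_mem = apply_4x4(mem_T_ref, xyz_ref)              # Ref2Mem
xyz_back = apply_4x4(mem_T_ref.inverse(), xyz_mem)   # Mem2Ref
assert torch.allclose(xyz_back, xyz_ref, atol=1e-5)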
def Ref2Zoom(xyz_ref, lrt_ref, ZY, ZX, ZZ):
    # xyz_ref is B x N x 3, in ref coordinates
    # lrt_ref is B x 19, specifying the box in ref coordinates
    # this transforms ref coordinates into zoom coordinates
    B, N, _ = list(xyz_ref.shape)
    zoom_T_ref = get_zoom_T_ref(lrt_ref, ZY, ZX, ZZ)
    xyz_zoom = utils_geom.apply_4x4(zoom_T_ref, xyz_ref)
    return xyz_zoom
def unproject_rgb_to_mem(rgb_camB, pixB_T_camA, mem_coord, device=None):
    # rgb_camB is B x C x H x W
    # pixB_T_camA is B x 4 x 4 (pix_T_camR)

    # rgb lives in B pixel coords
    # we want everything in A memory coords

    # this puts each C-dim pixel in the rgb_camB
    # along a ray in the voxelgrid
    B, C, H, W = list(rgb_camB.shape)

    Y, X, Z = mem_coord.proto.shape

    xyz_memA = utils_basic.gridcloud3D(B, Z, Y, X, norm=False, device=device)
    # not specifically related to Ref; we are just converting the grid to
    # points here, irrespective of which cam it is associated with.
    xyz_camA = Mem2Ref(xyz_memA, mem_coord)

    xyz_pixB = utils_geom.apply_4x4(pixB_T_camA, xyz_camA)
    # grab the z coordinate so we can divide: x/z, y/z
    normalizer = torch.unsqueeze(xyz_pixB[:, :, 2], 2)
    EPS = 1e-6
    xy_pixB = xyz_pixB[:, :, :2] / (EPS + normalizer)
    # this is B x N x 2
    # this is the (floating point) pixel coordinate of each voxel
    x_pixB, y_pixB = xy_pixB[:, :, 0], xy_pixB[:, :, 1]
    # these are B x N

    use_handwritten = False  # set True for the slow reference implementation
    if use_handwritten:
        # handwritten version
        values = torch.zeros([B, C, Z * Y * X], dtype=torch.float32)
        for b in range(B):
            values[b] = utils_samp.bilinear_sample_single(
                rgb_camB[b], x_pixB[b], y_pixB[b])
    else:
        # native pytorch version; grid_sample expects coordinates in [-1, 1]
        y_pixB, x_pixB = utils_basic.normalize_grid2D(y_pixB, x_pixB, H, W)
        # since we want a 3d output, we need 5d tensors
        z_pixB = torch.zeros_like(x_pixB)
        xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], dim=2)
        rgb_camB = rgb_camB.unsqueeze(2)
        xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
        values = F.grid_sample(rgb_camB, xyz_pixB, mode='nearest')

    values = torch.reshape(values, (B, C, Z, Y, X))
    return values
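
# Sketch of the sampling convention used above: F.grid_sample expects (x, y)
# coordinates normalized to [-1, 1], which is what normalize_grid2D is assumed
# to produce. A self-contained 2D demo:
import torch
import torch.nn.functional as F

img = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)  # B x C x H x W
H, W = 3, 4
x_pix = torch.tensor([[0.0, 3.0]])          # pixel-space x, B x N
y_pix = torch.tensor([[0.0, 2.0]])          # pixel-space y, B x N
x_n = 2.0 * x_pix / (W - 1) - 1.0           # normalize to [-1, 1]
y_n = 2.0 * y_pix / (H - 1) - 1.0
grid = torch.stack([x_n, y_n], dim=2).view(1, 1, 2, 2)  # B x Hout x Wout x 2
vals = F.grid_sample(img, grid, mode='nearest', align_corners=True)
assert vals.flatten().tolist() == [0.0, 11.0]  # top-left and bottom-right pixels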
def apply_4x4_to_vox(B_T_A,
                     feat_A,
                     mem_coord_As=None,
                     mem_coord_Bs=None,
                     already_mem=False,
                     binary_feat=False,
                     rigid=True):
    # B_T_A is B x 4 x 4
    # if already_mem=False, it is a transformation between cam systems
    # if already_mem=True, it is a transformation between mem systems

    # feat_A is B x C x Z x Y x X
    # it represents some scene features in reference/canonical coordinates
    # we want to go from these coords to some target coords

    # since this is a backwarp,
    # the question to ask is:
    # "WHERE in the tensor do you want to sample,
    # to replace each voxel's current value?"

    # the inverse of B_T_A represents this "where";
    # it transforms each coordinate in B
    # to the location we want to sample in A

    B, C, Z, Y, X = list(feat_A.shape)

    # the input is B_T_A, to match the signature of utils_geom.apply_4x4,
    # but for a backwarp we actually need A_T_B
    if rigid:
        A_T_B = utils_geom.safe_inverse(B_T_A)
    else:
        # this op is slower but more powerful
        A_T_B = B_T_A.inverse()

    if not already_mem:
        cam_T_mem = mem_coord_Bs.cam_T_vox.repeat(B, 1, 1)
        mem_T_cam = mem_coord_As.vox_T_cam.repeat(B, 1, 1)
        A_T_B = utils_basic.matmul3(mem_T_cam, A_T_B, cam_T_mem)

    # we want to sample for each location in the bird grid
    xyz_B = utils_basic.gridcloud3D(B, Z, Y, X)
    # this is B x N x 3

    # transform
    xyz_A = utils_geom.apply_4x4(A_T_B, xyz_B)
    # we want each voxel to take its value
    # from whatever is at these A coordinates
    # i.e., we are back-warping from the "A" coords

    feat_B = utils_samp.resample3D(feat_A, xyz_A, binary_feat=binary_feat)
    return feat_B
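
# Sketch of the back-warp logic in 1D (illustrative, not repo code): to shift
# content by +1 voxel (the B_T_A direction), each output location samples at
# its own coordinate transformed by the inverse (A_T_B), i.e. one voxel back:
import torch

feat_A = torch.tensor([10., 20., 30., 40.])   # a 4-voxel "volume"
x_B = torch.arange(4)                         # output (B) coordinates
x_A = (x_B - 1).clamp(0, 3)                   # where to read in A (A_T_B)
feat_B = feat_A[x_A]                          # tensor([10., 10., 20., 30.])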
def assemble_padded_obj_masklist(lrtlist, scorelist, Z, Y, X, coeff=1.0):
    # compute a binary mask in 3D for each object
    # we use this when computing the center-surround objectness score
    # lrtlist is B x N x 19
    # scorelist is B x N

    # returns masklist shaped B x N x 1 x Z x Y x X

    B, N, D = list(lrtlist.shape)
    assert (D == 19)

    lenlist, ref_T_objlist = utils_geom.split_lrtlist(lrtlist)
    # lenlist is B x N x 3
    # ref_T_objlist is B x N x 4 x 4

    lenlist_ = lenlist.reshape(B * N, 3)
    ref_T_objlist_ = ref_T_objlist.reshape(B * N, 4, 4)
    obj_T_reflist_ = utils_geom.safe_inverse(ref_T_objlist_)

    # we want a value for each location in the mem grid
    xyz_mem_ = utils_basic.gridcloud3D(B * N, Z, Y, X)
    # this is B*N x V x 3, where V = Z*Y*X
    xyz_ref_ = Mem2Ref(xyz_mem_, Z, Y, X)
    # this is B*N x V x 3

    lx, ly, lz = torch.unbind(lenlist_, dim=1)
    # these are B*N

    xyz_obj_ = utils_geom.apply_4x4(obj_T_reflist_, xyz_ref_)
    x, y, z = torch.unbind(xyz_obj_, dim=2)
    # these are B*N x V

    lx = lx.unsqueeze(1) * coeff
    ly = ly.unsqueeze(1) * coeff
    lz = lz.unsqueeze(1) * coeff
    # these are B*N x 1

    x_valid = (x > -lx / 2.0) & (x < lx / 2.0)
    y_valid = (y > -ly / 2.0) & (y < ly / 2.0)
    z_valid = (z > -lz / 2.0) & (z < lz / 2.0)
    inbounds = x_valid & y_valid & z_valid
    masklist = inbounds.float()
    masklist = masklist.reshape(B, N, 1, Z, Y, X)
    masklist = masklist * scorelist.view(B, N, 1, 1, 1, 1)
    return masklist
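
# Sketch of the inbounds test above on a single toy box (hypothetical numbers):
# a point is inside the box iff each coordinate is within half the box length.
import torch

lx, ly, lz = 2.0, 1.0, 4.0                  # box lengths
xyz_obj = torch.tensor([[0.0, 0.0, 0.0],    # at the center: inside
                        [1.5, 0.0, 0.0]])   # |x| > lx/2: outside
x, y, z = xyz_obj.unbind(dim=1)
inbounds = (x.abs() < lx / 2.0) & (y.abs() < ly / 2.0) & (z.abs() < lz / 2.0)
assert inbounds.tolist() == [True, False]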
def resample_to_view(feats, new_T_old, multi=False):
    # feats is B x S x C x Z x Y x X
    # it represents some scene features in reference/canonical coordinates
    # we want to go from these coords to some target coords

    # new_T_old is B x 4 x 4
    # it represents a transformation between two "mem" systems
    # or if multi=True, it's B x S x 4 x 4

    B, S, C, Z, Y, X = list(feats.shape)

    # we want to sample for each location in the bird grid
    # xyz_mem = gridcloud3D(B, Z, Y, X)
    grid_y, grid_x, grid_z = meshgrid3D(B, Z, Y, X)
    # these are B x Z x Y x X
    # these represent the mem grid coordinates

    # we need to flatten these into a pointcloud
    x = torch.reshape(grid_x, [B, -1])
    y = torch.reshape(grid_y, [B, -1])
    z = torch.reshape(grid_z, [B, -1])
    # these are B x N

    xyz_mem = torch.stack([x, y, z], dim=2)
    # this is B x N x 3

    xyz_mems = xyz_mem.unsqueeze(1).repeat(1, S, 1, 1)
    # this is B x S x N x 3

    xyz_mems_ = xyz_mems.view(B * S, Y * X * Z, 3)

    feats_ = feats.view(B * S, C, Z, Y, X)

    if multi:
        new_T_olds = new_T_old.clone()
    else:
        new_T_olds = new_T_old.unsqueeze(1).repeat(1, S, 1, 1)
    new_T_olds_ = new_T_olds.view(B * S, 4, 4)

    xyz_new_ = utils_geom.apply_4x4(new_T_olds_, xyz_mems_)
    # we want each voxel to replace its value
    # with whatever is at these new coordinates

    # i.e., we are back-warping from the "new" coords

    feats_, valid_ = utils_samp.resample3D(feats_, xyz_new_)
    feats = feats_.view(B, S, C, Z, Y, X)
    valid = valid_.view(B, S, 1, Z, Y, X)
    return feats, valid
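
# Sketch of the multi flag above (shapes only): with multi=False a single
# B x 4 x 4 transform is tiled across the S views; with multi=True the caller
# supplies one transform per view.
import torch

B, S = 2, 3
new_T_old = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)       # B x 4 x 4
new_T_olds = new_T_old.unsqueeze(1).repeat(1, S, 1, 1)      # B x S x 4 x 4
assert new_T_olds.view(B * S, 4, 4).shape == (B * S, 4, 4)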
def Zoom2Ref(xyz_zoom, lrt_ref, ZY, ZX, ZZ, sensor_camR_T_camXs=None):
    # xyz_zoom is B x N x 3, in zoom coordinates
    # lrt_ref is B x 19, specifying the box in ref coordinates
    # sensor_camR_T_camXs is an optional transformation matrix that
    # converts cam coords to ref coords
    B, N, _ = list(xyz_zoom.shape)
    ref_T_zoom = get_ref_T_zoom(lrt_ref, ZY, ZX, ZZ)
    ref_T_zoom = ref_T_zoom.to(xyz_zoom.device)
    # the origin (0, 0, 0) of xyz_zoom should map to (0, 0, 0.07); verified
    if sensor_camR_T_camXs is not None:
        # this takes from grid_coordinates to box_coordinates to cam_coordinates
        # to ref_coordinates
        ref_T_zoom = utils_basic.matmul2(sensor_camR_T_camXs, ref_T_zoom)
    # remember: these are coordinates in ref_cam, not in memory
    xyz_ref = utils_geom.apply_4x4(ref_T_zoom, xyz_zoom)
    return xyz_ref
def assemble_static_seq(feats, ref0_T_refXs):
    # feats is B x S x C x Z x Y x X
    # it is in mem coords

    # ref0_T_refXs is B x S x 4 x 4
    # it tells us how to warp the static scene

    # ref0 represents a reference frame, not necessarily frame0
    # refXs represents the frames where feats were observed

    B, S, C, Z, Y, X = list(feats.shape)

    # each feat is in its own little coord system
    # we need to get from 0 coords to these coords
    # and sample

    # we want to sample for each location in the bird grid
    # xyz_mem = gridcloud3D(B, Z, Y, X)
    grid_y, grid_x, grid_z = meshgrid3D(B, Z, Y, X)
    # these are B x Z x Y x X
    # these represent the mem grid coordinates

    # we need to flatten these into a pointcloud
    x = torch.reshape(grid_x, [B, -1])
    y = torch.reshape(grid_y, [B, -1])
    z = torch.reshape(grid_z, [B, -1])
    # these are B x N

    xyz_mem = torch.stack([x, y, z], dim=2)
    # this is B x N x 3
    xyz_ref = Mem2Ref(xyz_mem, Z, Y, X)
    # this is B x N x 3
    xyz_refs = xyz_ref.unsqueeze(1).repeat(1, S, 1, 1)
    # this is B x S x N x 3
    xyz_refs_ = torch.reshape(xyz_refs, (B * S, Y * X * Z, 3))

    feats_ = torch.reshape(feats, (B * S, C, Z, Y, X))

    ref0_T_refXs_ = torch.reshape(ref0_T_refXs, (B * S, 4, 4))
    refXs_T_ref0_ = utils_geom.safe_inverse(ref0_T_refXs_)

    xyz_refXs_ = utils_geom.apply_4x4(refXs_T_ref0_, xyz_refs_)
    xyz_memXs_ = Ref2Mem(xyz_refXs_, Z, Y, X)
    feats_, _ = utils_samp.resample3D(feats_, xyz_memXs_)
    feats = torch.reshape(feats_, (B, S, C, Z, Y, X))
    return feats
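
# Sketch of what safe_inverse (used above) is assumed to do for rigid
# transforms: the inverse of [R | t] is [R^T | -R^T t], which avoids a
# general matrix inverse. Illustrative stand-in, not the repo's code:
import torch

def rigid_inverse(RT):                        # RT: B x 4 x 4, rigid
    R, t = RT[:, :3, :3], RT[:, :3, 3:]
    inv = torch.eye(4).repeat(RT.shape[0], 1, 1)
    inv[:, :3, :3] = R.transpose(1, 2)
    inv[:, :3, 3:] = -torch.bmm(R.transpose(1, 2), t)
    return inv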
def apply_pixX_T_memR_to_voxR(pix_T_camX, camX_T_camR, mem_coord, voxR, D, H,
                              W):
    # mats are B x 4 x 4
    # voxR is B x C x Z x Y x X
    # H, W, D indicate how big to make the output
    # returns B x C x D x H x W
    vox_coord = mem_coord.coord
    ZMIN = vox_coord.ZMIN
    ZMAX = vox_coord.ZMAX

    B, C, Z, Y, X = list(voxR.shape)
    z_near = ZMIN
    z_far = ZMAX

    grid_z = torch.linspace(z_near,
                            z_far,
                            steps=D,
                            dtype=torch.float32,
                            device=voxR.device)
    grid_z = torch.reshape(grid_z, [1, 1, D, 1, 1])
    grid_z = grid_z.repeat([B, 1, 1, H, W])
    grid_z = torch.reshape(grid_z, [B * D, 1, H, W])

    pix_T_camX__ = torch.unsqueeze(pix_T_camX, dim=1).repeat([1, D, 1, 1])
    pix_T_camX = torch.reshape(pix_T_camX__, [B * D, 4, 4])
    xyz_camX = utils_geom.depth2pointcloud(grid_z, pix_T_camX)

    camR_T_camX = utils_geom.safe_inverse(camX_T_camR)
    camR_T_camX_ = torch.unsqueeze(camR_T_camX, dim=1).repeat([1, D, 1, 1])
    camR_T_camX = torch.reshape(camR_T_camX_, [B * D, 4, 4])

    mem_T_cam = get_mem_T_ref(B * D, mem_coord)
    memR_T_camX = utils_basic.matmul2(mem_T_cam, camR_T_camX)

    xyz_memR = utils_geom.apply_4x4(memR_T_camX, xyz_camX)
    xyz_memR = torch.reshape(xyz_memR, [B, D * H * W, 3])

    samp = utils_samp.sample3D(voxR, xyz_memR, D, H, W)
    # samp is B x C x D x H x W
    return samp
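
# Sketch of the depth sweep above (illustrative sizes): one constant-depth
# plane per step, tiled across the image, giving B*D depth maps to unproject.
import torch

B, D, H, W = 1, 3, 2, 2
grid_z = torch.linspace(2.0, 10.0, steps=D).view(1, 1, D, 1, 1)
grid_z = grid_z.repeat(B, 1, 1, H, W).view(B * D, 1, H, W)
# grid_z[b * D + d] is a constant plane at the d-th sampled depth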
Example #12
def get_synth_flow(unpRs,
                   occRs,
                   obj_lrtlist_camX0s,
                   obj_scorelist_s,
                   occXs,
                   set_name,
                   K,
                   summ_writer,
                   sometimes_zero=False,
                   sometimes_real=False,
                   do_vis=False):
    B, S, _, Z, Y, X = list(occXs.shape)
    assert (S == 2)
    flowX0 = get_gt_flow(
        obj_lrtlist_camX0s,
        obj_scorelist_s,
        utils_geom.eye_4x4s(B, S),
        occXs[:, 0],
        K=K,
        occ_only=False,  # get the dense flow
        mod='X0',
        summ_writer=summ_writer)

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translation angles.
    # (rotation ruins this, since the pivot point is at the camera)
    cam1_T_cam0 = [
        utils_geom.get_random_rt(B, r_amount=0.0,
                                 t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B,
            r_amount=0.0,
            t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero)
    ]
    cam1_T_cam0 = random.sample(cam1_T_cam0, k=1)[0]

    occ0 = occRs[:, 0]
    unp0 = unpRs[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    # occ1 should be a binary thing, so let's restore that property
    # (round before collecting it into the list below)
    occ1 = torch.round(occ1)
    occs = [occ0, occ1]
    unps = [unp0, unp1]
    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X
    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

    occ0_e = utils_samp.backwarp_using_3D_flow(occ1, flow, binary_feat=True)
    unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
    if do_vis:
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    is_synth = 1
    if sometimes_real and set_name == 'train':
        is_synth = random.randint(0, 1)
        occs = [occRs, occs][is_synth]
        unps = [unpRs, unps][is_synth]
        flow = [flowX0, flow][is_synth]
        cam1_T_cam0 = [utils_geom.eye_4x4(B), cam1_T_cam0][is_synth]

    return occs, unps, flow, cam1_T_cam0, is_synth
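
# Sketch of the flow construction above on a toy grid (self-contained stand-in
# for gridcloud3D + apply_4x4): for a pure translation, the dense flow
# flow = (mem1_T_mem0 @ x) - x is constant everywhere.
import torch

Z = Y = X = 4
zz, yy, xx = torch.meshgrid(torch.arange(Z), torch.arange(Y), torch.arange(X),
                            indexing='ij')
xyz_mem0 = torch.stack([xx, yy, zz], dim=-1).float().view(1, -1, 3)  # 1 x N x 3
t = torch.tensor([1.0, 0.0, 0.0])     # translate +1 voxel in x
xyz_mem1 = xyz_mem0 + t               # a pure-translation mem1_T_mem0
flow = xyz_mem1 - xyz_mem0            # constant [1, 0, 0] at every voxel
assert torch.allclose(flow, t.expand_as(flow))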
Example #13
def get_gt_flow(obj_lrtlist_camRs,
                obj_scorelist,
                camRs_T_camXs,
                occR,
                K=2,
                occ_only=True,
                mod='',
                vis=True,
                summ_writer=None):
    # this constructs the flow field according to the given
    # box trajectories (obj_lrtlist_camRs, collected from a moving camR,
    # so they do not take egomotion into account)
    # and the egomotion itself (encoded in camRs_T_camXs)
    # so, we first generate the flow for all the objects,
    # then fill the background with the ego flow

    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2
    B, _, Z, Y, X = list(occR.shape)

    flows = []
    masks = []
    for k in range(K):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z,
            Y,
            X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X

        camR_T_cam0 = camRs_T_camXs[:, 0]
        camR_T_cam1 = camRs_T_camXs[:, 1]
        cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

        if vis and (summ_writer is not None):
            summ_writer.summ_oned('flow/obj%d_mask0' % k,
                                  torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        obj1_T_ref = utils_geom.safe_inverse(ref_T_obj1)
        # these are B x 4 x 4

        mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
        ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0,
                                          camR_T_cam0)
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

        xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)

        xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
        xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

        # only use these displaced points within the obj mask
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0

        flow = xyz_mem1 - xyz_mem0
        flow = flow.permute(0, 4, 1, 2, 3)
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)

        # flow is centered on frame0, so we use occ0 to mask it
        if occ_only:
            flow = flow * occR

        if vis:
            summ_writer.summ_3D_flow('flow/gt_%d' % k, flow, clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)

    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

    flow = xyz_mem1 - xyz_mem0
    flow = flow.permute(0, 4, 1, 2, 3)
    if occ_only:
        flow = flow * occR

    bkg_flow = flow

    # allow zero motion in the bkg
    any_mask = torch.max(torch.stack(masks, dim=0), dim=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, dim=0)
    masks = torch.stack(masks, dim=0)
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)

    if vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x Z x Y x X
    return flow
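
# Sketch of the masked compositing at the end. reduce_masked_mean is assumed
# to compute sum(x * mask) / sum(mask) over the given dim, so each voxel takes
# the object flow where an object mask is on, and the ego flow elsewhere:
import torch

def reduce_masked_mean(x, mask, dim, eps=1e-6):
    return (x * mask).sum(dim) / (mask.sum(dim) + eps)

obj_flow = torch.ones(2, 3)           # toy "object" flow layer
bkg_flow = torch.zeros(2, 3)          # toy "background" flow layer
obj_mask = torch.tensor([[1., 1., 0.], [0., 1., 1.]])
flows = torch.stack([obj_flow, bkg_flow], dim=0)
masks = torch.stack([obj_mask, 1.0 - obj_mask], dim=0)
flow = reduce_masked_mean(flows, masks, dim=0)   # equals obj_mask here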