Code example #1
def Ref2Mem(xyz, Z, Y, X, bounds='default'):
    # xyz is B x N x 3, in ref coordinates
    # transforms velo coordinates into mem coordinates
    B, N, C = list(xyz.shape)
    mem_T_ref = get_mem_T_ref(B, Z, Y, X, bounds=bounds)
    xyz = utils_geom.apply_4x4(mem_T_ref, xyz)
    return xyz
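Every example on this page calls utils_geom.apply_4x4, whose definition is not shown here. As a rough mental model only, a minimal sketch is given below (apply_4x4_sketch is our name, not the project's, and the real implementation may differ in details): it applies a batch of 4x4 transforms to batched 3D points via homogeneous coordinates.

import torch

def apply_4x4_sketch(RT, xyz):
    # RT is B x 4 x 4, xyz is B x N x 3
    B, N, _ = list(xyz.shape)
    ones = torch.ones(B, N, 1, dtype=xyz.dtype, device=xyz.device)
    xyz1 = torch.cat([xyz, ones], dim=2)  # B x N x 4, homogeneous
    xyz2 = torch.matmul(RT, xyz1.transpose(1, 2))  # B x 4 x N
    return xyz2.transpose(1, 2)[:, :, :3]  # back to B x N x 3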
Code example #2
def Mem2Ref(xyz_mem, Z, Y, X, bounds='default'):
    # xyz is B x N x 3, in mem coordinates
    # transforms mem coordinates into ref coordinates
    B, N, C = list(xyz_mem.shape)
    ref_T_mem = get_ref_T_mem(B, Z, Y, X, bounds=bounds)
    xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
    return xyz_ref
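Assuming get_mem_T_ref and get_ref_T_mem return inverse matrices for the same grid and bounds, Ref2Mem and Mem2Ref above should undo each other. A hedged usage sketch (shapes and tolerance are illustrative, and the two functions above are assumed to be in scope):

import torch

B, N, Z, Y, X = 2, 128, 32, 32, 32
xyz_ref = torch.randn(B, N, 3)
xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
xyz_back = Mem2Ref(xyz_mem, Z, Y, X)
assert torch.allclose(xyz_ref, xyz_back, atol=1e-4)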
Code example #3
    def summ_lrtlist(self,
                     name,
                     rgbR,
                     lrtlist,
                     scorelist,
                     tidlist,
                     pix_T_cam,
                     only_return=False):
        # rgbR is B x C x H x W
        # lrtlist is B x N x 19
        # scorelist is B x N
        # tidlist is B x N
        # pix_T_cam is B x 4 x 4

        B, C, H, W = list(rgbR.shape)
        B, N, D = list(lrtlist.shape)
        lenlist = lrtlist[:, :, :3].reshape(B, N, 3)
        rtlist = lrtlist[:, :, 3:].reshape(B, N, 4, 4)

        xyzlist_obj = utils_geom.get_xyzlist_from_lenlist(lenlist)
        # this is B x N x 8 x 3

        rtlist_ = rtlist.reshape(B * N, 4, 4)
        xyzlist_obj_ = xyzlist_obj.reshape(B * N, 8, 3)
        xyzlist_cam_ = utils_geom.apply_4x4(rtlist_, xyzlist_obj_)
        xyzlist_cam = xyzlist_cam_.reshape(B, N, 8, 3)

        boxes_vis = self.draw_corners_on_image(rgbR, xyzlist_cam, scorelist,
                                               tidlist, pix_T_cam)
        if not only_return:
            self.summ_rgb(name, boxes_vis)
        return boxes_vis
Code example #4
def Zoom2Ref(xyz_zoom, lrt_ref, Z, Y, X, additive_pad=0.1):
    # xyz_zoom is B x N x 3, in zoom coordinates
    # lrt_ref is B x 19, specifying the box in ref coordinates
    B, N, _ = list(xyz_zoom.shape)
    ref_T_zoom = get_ref_T_zoom(lrt_ref, Z, Y, X, additive_pad=additive_pad)
    xyz_ref = utils_geom.apply_4x4(ref_T_zoom, xyz_zoom)
    return xyz_ref
Code example #5
File: locnet.py  Project: shamitlal/CoCoNets
    def convert_params_to_lrt(self, obj_len, obj_xyz_sce, obj_rot, cam_T_sce):
        # this borrows from utils_geom.convert_box_to_ref_T_obj
        B = list(obj_xyz_sce.shape)[0]

        obj_xyz_cam = utils_geom.apply_4x4(cam_T_sce,
                                           obj_xyz_sce.unsqueeze(1)).squeeze(1)
        # # compose with lengths
        # lrt = utils_geom.merge_lrt(obj_len, cam_T_obj)
        # return lrt

        rot0 = utils_geom.eye_3x3(B)
        # tra = torch.stack([x, y, z], axis=1)
        center_T_ref = utils_geom.merge_rt(rot0, -obj_xyz_cam)
        # center_T_ref is B x 4 x 4

        t0 = torch.zeros([B, 3])
        obj_T_center = utils_geom.merge_rt(obj_rot, t0)
        # this is B x 4 x 4

        # we want obj_T_ref
        # first we translate to the center,
        # and then rotate around the origin
        obj_T_ref = utils_basic.matmul2(obj_T_center, center_T_ref)

        # return the inverse of this, so that we can transform obj corners into cam coords
        ref_T_obj = utils_geom.safe_inverse(obj_T_ref)
        # return ref_T_obj

        # compose with lengths
        lrt = utils_geom.merge_lrt(obj_len, ref_T_obj)
        return lrt
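convert_params_to_lrt relies on utils_geom.merge_rt to pack a rotation and a translation into a single homogeneous matrix. A minimal sketch of that helper under the standard 4x4 layout (an assumption; the project's utility may differ, e.g. in device handling):

import torch

def merge_rt_sketch(r, t):
    # r is B x 3 x 3, t is B x 3; returns B x 4 x 4
    B = list(r.shape)[0]
    rt = torch.eye(4, dtype=r.dtype, device=r.device).unsqueeze(0).repeat(B, 1, 1)
    rt[:, :3, :3] = r
    rt[:, :3, 3] = t
    return rt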
Code example #6
def get_synth_flow(occs,
                   unps,
                   summ_writer,
                   sometimes_zero=False,
                   do_vis=False):
    B, S, C, Z, Y, X = list(occs.shape)
    assert S == 2 and C == 1

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    cam1_T_cam0 = [
        utils_geom.get_random_rt(B, r_amount=0.0,
                                 t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B,
            r_amount=0.0,
            t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero)
    ]
    cam1_T_cam0 = random.sample(cam1_T_cam0, k=1)[0]

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X
    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

    if do_vis:
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1,
                                                   flow,
                                                   binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    return occs, unps, flow, cam1_T_cam0
Code example #7
def rescore_boxlist_with_pointcloud(camX_T_camR,
                                    boxlist_camR,
                                    xyz_camX,
                                    scorelist,
                                    tidlist,
                                    thresh=1.0):
    # boxlist_camR is B x N x 9
    B, N, D = list(boxlist_camR.shape)
    assert (D == 9)
    xyzlist = boxlist_camR[:, :, :3]
    # this is B x N x 3
    lenlist = boxlist_camR[:, :, 3:6]
    # this is B x N x 3

    xyzlist = utils_geom.apply_4x4(camX_T_camR, xyzlist)

    # xyz_camX is B x V x 3
    xyz_camX = xyz_camX[:, ::10]
    xyz_camX = xyz_camX.unsqueeze(1)
    # xyz_camX is B x 1 x V x 3
    xyzlist = xyzlist.unsqueeze(2)
    # xyzlist is B x N x 1 x 3

    dists = torch.norm(xyz_camX - xyzlist, dim=3)
    # this is B x N x V

    mindists = torch.min(dists, 2)[0]
    ok = (mindists < thresh).float()
    scorelist = scorelist * ok
    return scorelist
Code example #8
def Ref2Zoom(xyz_ref, lrt_ref, Z, Y, X, additive_pad=0.1):
    # xyz_ref is B x N x 3, in ref coordinates
    # lrt_ref is B x 19, specifying the box in ref coordinates
    # this transforms ref coordinates into zoom coordinates
    B, N, _ = list(xyz_ref.shape)
    zoom_T_ref = get_zoom_T_ref(lrt_ref, Z, Y, X, additive_pad=additive_pad)
    xyz_zoom = utils_geom.apply_4x4(zoom_T_ref, xyz_ref)
    return xyz_zoom
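Like the mem/ref pair, Ref2Zoom and Zoom2Ref (example #4) are expected to invert each other for the same box and grid. A hedged usage sketch, assuming the 19-dim lrt layout of 3 lengths followed by a flattened 4x4 (the unit-box lrt below is purely illustrative):

import torch

B, N, Z, Y, X = 2, 64, 16, 16, 16
lrt_ref = torch.cat([torch.ones(B, 3),
                     torch.eye(4).reshape(1, 16).repeat(B, 1)], dim=1)  # B x 19
xyz_ref = torch.randn(B, N, 3)
xyz_zoom = Ref2Zoom(xyz_ref, lrt_ref, Z, Y, X)
xyz_back = Zoom2Ref(xyz_zoom, lrt_ref, Z, Y, X)
assert torch.allclose(xyz_ref, xyz_back, atol=1e-4)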
Code example #9
def gen_list_of_bboxes(tree, boxes=[], ref_T_mem=None):
    for i in range(0, tree.num_children):
        updated_tree, boxes = gen_list_of_bboxes(tree.children[i],
                                                 boxes=boxes,
                                                 ref_T_mem=ref_T_mem)
        tree.children[i] = updated_tree
    if tree.function == "describe":
        coordinates_M = get_coordinates(tree)
        coordinates_M = np.expand_dims(coordinates_M, 0).astype(np.float32)
        coordinates_R = utils_geom.apply_4x4(torch.tensor(ref_T_mem),
                                             torch.tensor(coordinates_M))
        camR_T_origin = get_camRTorigin()
        coordinates_R = utils_geom.apply_4x4(
            torch.tensor(camR_T_origin),
            torch.tensor(coordinates_R, dtype=torch.float32))
        # coords = np.squeeze(coords.cpu().numpy(),axis=0)
        cube_coordinates, tree_box = gen_cube_coordinates(coordinates_R)
        tree.bbox_origin = tree_box
        boxes.append(cube_coordinates)
    return tree, boxes
Code example #10
def rescore_boxlist_with_inbound(camX_T_camR,
                                 boxlist_camR,
                                 tidlist,
                                 Z,
                                 Y,
                                 X,
                                 only_cars=True,
                                 pad=2.0):
    # boxlist_camR is B x N x 9
    B, N, D = list(boxlist_camR.shape)
    assert (D == 9)
    xyzlist = boxlist_camR[:, :, :3]
    # this is B x N x 3
    lenlist = boxlist_camR[:, :, 3:6]
    # this is B x N x 3

    xyzlist = utils_geom.apply_4x4(camX_T_camR, xyzlist)

    validlist = 1.0 - (torch.eq(tidlist,
                                -1 * torch.ones_like(tidlist))).float()
    # this is B x N

    if only_cars:
        biglist = (torch.norm(lenlist, dim=2) > 2.0).float()
        validlist = validlist * biglist

    xlist, ylist, zlist = torch.unbind(xyzlist, dim=2)
    inboundlist_0 = utils_vox.get_inbounds(torch.stack(
        [xlist + pad, ylist, zlist], dim=2),
                                           Z,
                                           Y,
                                           X,
                                           already_mem=False).float()
    inboundlist_1 = utils_vox.get_inbounds(torch.stack(
        [xlist - pad, ylist, zlist], dim=2),
                                           Z,
                                           Y,
                                           X,
                                           already_mem=False).float()
    inboundlist_2 = utils_vox.get_inbounds(torch.stack(
        [xlist, ylist, zlist + pad], dim=2),
                                           Z,
                                           Y,
                                           X,
                                           already_mem=False).float()
    inboundlist_3 = utils_vox.get_inbounds(torch.stack(
        [xlist, ylist, zlist - pad], dim=2),
                                           Z,
                                           Y,
                                           X,
                                           already_mem=False).float()
    inboundlist = inboundlist_0 * inboundlist_1 * inboundlist_2 * inboundlist_3
    scorelist = validlist * inboundlist
    return scorelist
Code example #11
def apply_4x4_to_vox(B_T_A,
                     feat_A,
                     already_mem=False,
                     binary_feat=False,
                     rigid=True):
    # B_T_A is B x 4 x 4
    # if already_mem=False, it is a transformation between cam systems
    # if already_mem=True, it is a transformation between mem systems

    # feat_A is B x C x Z x Y x X
    # it represents some scene features in reference/canonical coordinates
    # we want to go from these coords to some target coords

    # since this is a backwarp,
    # the question to ask is:
    # "WHERE in the tensor do you want to sample,
    # to replace each voxel's current value?"

    # the inverse of B_T_A represents this "where";
    # it transforms each coordinate in B
    # to the location we want to sample in A

    B, C, Z, Y, X = list(feat_A.shape)

    # we take B_T_A as input, to match the signature of utils_geom.apply_4x4,
    # but for this backwarp we actually need A_T_B
    if rigid:
        A_T_B = utils_geom.safe_inverse(B_T_A)
    else:
        # this op is slower but more powerful
        A_T_B = B_T_A.inverse()

    if not already_mem:
        cam_T_mem = get_ref_T_mem(B, Z, Y, X)
        mem_T_cam = get_mem_T_ref(B, Z, Y, X)
        A_T_B = matmul3(mem_T_cam, A_T_B, cam_T_mem)

    # we want to sample for each location in the bird grid
    xyz_B = gridcloud3D(B, Z, Y, X)
    # this is B x N x 3

    # transform
    xyz_A = utils_geom.apply_4x4(A_T_B, xyz_B)
    # we want each voxel to take its value
    # from whatever is at these A coordinates
    # i.e., we are back-warping from the "A" coords

    # feat_B = F.grid_sample(feat_A, normalize_grid(xyz_A, Z, Y, X))
    feat_B = utils_samp.resample3D(feat_A, xyz_A, binary_feat=binary_feat)

    # feat_B, valid = utils_samp.resample3D(feat_A, xyz_A, binary_feat=binary_feat)
    # return feat_B, valid
    return feat_B
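apply_4x4_to_vox delegates the actual sampling to utils_samp.resample3D. The sketch below is one plausible reading of such a resampler, assuming voxel coordinates arrive in (x, y, z) order over a Z x Y x X grid and that trilinear sampling via F.grid_sample is acceptable; it is not the project's implementation.

import torch
import torch.nn.functional as F

def resample3D_sketch(feat, xyz, binary_feat=False):
    # feat is B x C x Z x Y x X; xyz is B x (Z*Y*X) x 3, in voxel coords (x, y, z)
    B, C, Z, Y, X = list(feat.shape)
    # normalize voxel coords to [-1, 1] for grid_sample (assumed convention)
    x = 2.0 * xyz[:, :, 0] / max(X - 1, 1) - 1.0
    y = 2.0 * xyz[:, :, 1] / max(Y - 1, 1) - 1.0
    z = 2.0 * xyz[:, :, 2] / max(Z - 1, 1) - 1.0
    grid = torch.stack([x, y, z], dim=2).reshape(B, Z, Y, X, 3)
    out = F.grid_sample(feat, grid, align_corners=True)
    if binary_feat:
        out = (out > 0.5).float()  # re-binarize, as the binary_feat flag suggests
    return out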
Code example #12
def unproject_rgb_to_mem(rgb_camB, Z, Y, X, pixB_T_camA):
    # rgb_camB is B x C x H x W
    # pixB_T_camA is B x 4 x 4

    # rgb lives in B pixel coords
    # we want everything in A memory coords

    # this puts each C-dim pixel in the rgb_camB
    # along a ray in the voxelgrid
    B, C, H, W = list(rgb_camB.shape)

    xyz_memA = gridcloud3D(B, Z, Y, X, norm=False)
    # grid_z, grid_y, grid_x = meshgrid3D(B, Z, Y, X)
    # # these are B x Z x Y x X
    # # these represent the mem grid coordinates

    # # we need to convert these to pixel coordinates
    # x = torch.reshape(grid_x, [B, -1])
    # y = torch.reshape(grid_y, [B, -1])
    # z = torch.reshape(grid_z, [B, -1])
    # # these are B x N
    # xyz_mem = torch.stack([x, y, z], dim=2)

    xyz_camA = Mem2Ref(xyz_memA, Z, Y, X)

    xyz_pixB = utils_geom.apply_4x4(pixB_T_camA, xyz_camA)
    normalizer = torch.unsqueeze(xyz_pixB[:, :, 2], 2)
    EPS = 1e-6
    xy_pixB = xyz_pixB[:, :, :2] / torch.clamp(normalizer, min=EPS)
    # this is B x N x 2
    # this is the (floating point) pixel coordinate of each voxel
    x_pixB, y_pixB = xy_pixB[:, :, 0], xy_pixB[:, :, 1]
    # these are B x N

    if (0):
        # handwritten version
        values = torch.zeros([B, C, Z * Y * X], dtype=torch.float32)
        for b in list(range(B)):
            values[b] = utils_samp.bilinear_sample_single(
                rgb_camB[b], x_pixB[b], y_pixB[b])
    else:
        # native pytorch version
        y_pixB, x_pixB = normalize_grid2D(y_pixB, x_pixB, H, W)
        # since we want a 3d output, we need 5d tensors
        z_pixB = torch.zeros_like(x_pixB)
        xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], axis=2)
        rgb_camB = rgb_camB.unsqueeze(2)
        xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
        values = F.grid_sample(rgb_camB, xyz_pixB)

    values = torch.reshape(values, (B, C, Z, Y, X))
    return values
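unproject_rgb_to_mem normalizes pixel coordinates with normalize_grid2D before calling F.grid_sample. A hedged sketch of that normalization, assuming the usual mapping of pixel indices to [-1, 1] (the real helper may treat align_corners differently):

def normalize_grid2D_sketch(y, x, H, W):
    # y, x are B x N pixel coordinates; map them to [-1, 1] for grid_sample
    y = 2.0 * y / max(H - 1, 1) - 1.0
    x = 2.0 * x / max(W - 1, 1) - 1.0
    return y, x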
Code example #13
def assemble_padded_obj_masklist(lrtlist, scorelist, Z, Y, X, coeff=1.0):
    # compute a binary mask in 3D for each object
    # we use this when computing the center-surround objectness score
    # lrtlist is B x N x 19
    # scorelist is B x N

    # returns masklist shaped B x N x 1 x Z x Y x X

    B, N, D = list(lrtlist.shape)
    assert (D == 19)
    masks = torch.zeros(B, N, Z, Y, X)

    lenlist, ref_T_objlist = utils_geom.split_lrtlist(lrtlist)
    # lenlist is B x N x 3
    # ref_T_objlist is B x N x 4 x 4

    lenlist_ = lenlist.reshape(B * N, 3)
    ref_T_objlist_ = ref_T_objlist.reshape(B * N, 4, 4)
    obj_T_reflist_ = utils_geom.safe_inverse(ref_T_objlist_)

    # we want a value for each location in the mem grid
    xyz_mem_ = gridcloud3D(B * N, Z, Y, X)
    # this is B*N x V x 3, where V = Z*Y*X
    xyz_ref_ = Mem2Ref(xyz_mem_, Z, Y, X)
    # this is B*N x V x 3

    lx, ly, lz = torch.unbind(lenlist_, dim=1)
    # these are B*N

    # ref_T_obj = convert_box_to_ref_T_obj(boxes3D)
    # obj_T_ref = ref_T_obj.inverse()

    xyz_obj_ = utils_geom.apply_4x4(obj_T_reflist_, xyz_ref_)
    x, y, z = torch.unbind(xyz_obj_, dim=2)
    # these are B*N x V

    lx = lx.unsqueeze(1) * coeff
    ly = ly.unsqueeze(1) * coeff
    lz = lz.unsqueeze(1) * coeff
    # these are B*N x 1

    x_valid = (x > -lx / 2.0).byte() & (x < lx / 2.0).byte()
    y_valid = (y > -ly / 2.0).byte() & (y < ly / 2.0).byte()
    z_valid = (z > -lz / 2.0).byte() & (z < lz / 2.0).byte()
    inbounds = x_valid.byte() & y_valid.byte() & z_valid.byte()
    masklist = inbounds.float()
    # print(masklist.shape)
    masklist = masklist.reshape(B, N, 1, Z, Y, X)
    # print(masklist.shape)
    # print(scorelist.shape)
    masklist = masklist * scorelist.view(B, N, 1, 1, 1, 1)
    return masklist
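Examples #3, #5 and #13 all treat an lrt as a 19-dim vector: 3 box lengths followed by a flattened 4x4 ref_T_obj. Minimal sketches of the split/merge helpers under that assumption (the _sketch names are ours):

import torch

def merge_lrt_sketch(lens, rt):
    # lens is B x 3, rt is B x 4 x 4; returns B x 19
    B = list(lens.shape)[0]
    return torch.cat([lens, rt.reshape(B, 16)], dim=1)

def split_lrtlist_sketch(lrtlist):
    # lrtlist is B x N x 19; returns lenlist B x N x 3, rtlist B x N x 4 x 4
    B, N, D = list(lrtlist.shape)
    lenlist = lrtlist[:, :, :3]
    rtlist = lrtlist[:, :, 3:].reshape(B, N, 4, 4)
    return lenlist, rtlist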
Code example #14
def resample_to_view(feats, new_T_old, multi=False):
    # feats is B x S x c x Y x X x Z
    # it represents some scene features in reference/canonical coordinates
    # we want to go from these coords to some target coords

    # new_T_old is B x 4 x 4
    # it represents a transformation between two "mem" systems
    # or if multi=True, it's B x S x 4 x 4

    B, S, C, Z, Y, X = list(feats.shape)

    # we want to sample for each location in the bird grid
    # xyz_mem = gridcloud3D(B, Z, Y, X)
    grid_y, grid_x, grid_z = meshgrid3D(B, Z, Y, X)
    # these are B x BY x BX x BZ
    # these represent the mem grid coordinates

    # we need to convert these to pixel coordinates
    x = torch.reshape(grid_x, [B, -1])
    y = torch.reshape(grid_y, [B, -1])
    z = torch.reshape(grid_z, [B, -1])
    # these are B x N

    xyz_mem = torch.stack([x, y, z], dim=2)
    # this is B x N x 3

    xyz_mems = xyz_mem.unsqueeze(1).repeat(1, S, 1, 1)
    # this is B x S x N x 3

    xyz_mems_ = xyz_mems.view(B * S, Y * X * Z, 3)

    feats_ = feats.view(B * S, C, Z, Y, X)

    if multi:
        new_T_olds = new_T_old.clone()
    else:
        new_T_olds = new_T_old.unsqueeze(1).repeat(1, S, 1, 1)
    new_T_olds_ = new_T_olds.view(B * S, 4, 4)

    xyz_new_ = utils_geom.apply_4x4(new_T_olds_, xyz_mems_)
    # we want each voxel to replace its value
    # with whatever is at these new coordinates

    # i.e., we are back-warping from the "new" coords

    feats_, valid_ = utils_samp.resample3D(feats_, xyz_new_)
    feats = feats_.view(B, S, C, Z, Y, X)
    valid = valid_.view(B, S, 1, Z, Y, X)
    return feats, valid
Code example #15
    def test_bbox_projection(self, xyz_camXs_origin_agg, origin_T_camXs,
                             pix_T_camXs, rgb_camX, xyz_camXs, f):
        rgb_camX = rgb_camX.astype(np.float32)

        objs_info = f['objects_info']
        for obj_info in objs_info:
            if obj_info['category_name'] == "chair":

                bbox_center = obj_info['bbox_center']
                bbox_size = obj_info['bbox_size']
                print("bbox center and size are: ", bbox_center, bbox_size)

        xmin, xmax = bbox_center[0] - bbox_size[0] / 2., bbox_center[
            0] + bbox_size[0] / 2.
        ymin, ymax = bbox_center[1] - bbox_size[1] / 2., bbox_center[
            1] + bbox_size[1] / 2.
        zmin, zmax = bbox_center[2] - bbox_size[2] / 2., bbox_center[
            2] + bbox_size[2] / 2.

        bbox_origin_ends = np.array([xmin, ymin, zmin, xmax, ymax, zmax])
        bbox_origin_theta = nlu.get_alignedboxes2thetaformat(
            torch.tensor(bbox_origin_ends).reshape(1, 1, 2, 3).float())
        bbox_origin_corners = utils_geom.transform_boxes_to_corners(
            bbox_origin_theta)

        nlu.only_visualize(nlu.make_pcd(xyz_camXs_origin_agg.numpy()),
                           bbox_origin_ends.reshape(1, -1))

        print("Ends of bbox in origin are: ", bbox_origin_ends)

        camX_T_origin = utils_geom.safe_inverse(
            torch.tensor(origin_T_camXs).unsqueeze(0)).float()
        bbox_corners_camX = utils_geom.apply_4x4(
            camX_T_origin,
            bbox_origin_corners.squeeze(0).float())
        bbox_ends_camX = nlu.get_ends_of_corner(
            bbox_corners_camX.permute(0, 2, 1)).permute(0, 2, 1)

        ends_camX = bbox_ends_camX.reshape(1, -1).numpy()
        print("ends in camX are: ", ends_camX)

        nlu.only_visualize(nlu.make_pcd(xyz_camXs), ends_camX)
        plt.imshow(rgb_camX)
        plt.show(block=True)
        utils_pointcloud.draw_boxes_on_rgb(rgb_camX,
                                           pix_T_camXs,
                                           ends_camX,
                                           visualize=True)
Code example #16
def assemble_static_seq(feats, ref0_T_refXs):
    # feats is B x S x C x Y x X x Z
    # it is in mem coords

    # ref0_T_refXs is B x S x 4 x 4
    # it tells us how to warp the static scene

    # ref0 represents a reference frame, not necessarily frame0
    # refXs represents the frames where feats were observed

    B, S, C, Z, Y, X = list(feats.shape)

    # each feat is in its own little coord system
    # we need to get from 0 coords to these coords
    # and sample

    # we want to sample for each location in the bird grid
    # xyz_mem = gridcloud3D(B, Z, Y, X)
    grid_y, grid_x, grid_z = meshgrid3D(B, Z, Y, X)
    # these are B x BY x BX x BZ
    # these represent the mem grid coordinates

    # we need to convert these to pixel coordinates
    x = torch.reshape(grid_x, [B, -1])
    y = torch.reshape(grid_y, [B, -1])
    z = torch.reshape(grid_z, [B, -1])
    # these are B x N

    xyz_mem = torch.stack([x, y, z], dim=2)
    # this is B x N x 3
    xyz_ref = Mem2Ref(xyz_mem, Z, Y, X)
    # this is B x N x 3
    xyz_refs = xyz_ref.unsqueeze(1).repeat(1, S, 1, 1)
    # this is B x S x N x 3
    xyz_refs_ = torch.reshape(xyz_refs, (B * S, Y * X * Z, 3))

    feats_ = torch.reshape(feats, (B * S, C, Z, Y, X))

    ref0_T_refXs_ = torch.reshape(ref0_T_refXs, (B * S, 4, 4))
    refXs_T_ref0_ = utils_geom.safe_inverse(ref0_T_refXs_)

    xyz_refXs_ = utils_geom.apply_4x4(refXs_T_ref0_, xyz_refs_)
    xyz_memXs_ = Ref2Mem(xyz_refXs_, Z, Y, X)
    feats_, _ = utils_samp.resample3D(feats_, xyz_memXs_)
    feats = torch.reshape(feats_, (B, S, C, Z, Y, X))
    return feats
Code example #17
def prep_occs_supervision(camRs_T_camXs, xyz_camXs, Z, Y, X, agg=False):
    B, S, N, D = list(xyz_camXs.size())
    assert (D == 3)
    # occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))

    # utils for packing/unpacking along seq dim
    __p = lambda x: pack_seqdim(x, B)
    __u = lambda x: unpack_seqdim(x, B)

    camRs_T_camXs_ = __p(camRs_T_camXs)
    xyz_camXs_ = __p(xyz_camXs)
    xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, xyz_camXs_)
    occXs_ = voxelize_xyz(xyz_camXs_, Z, Y, X)
    occRs_ = voxelize_xyz(xyz_camRs_, Z, Y, X)

    # note we must compute freespace in the given view,
    # then warp to the target view
    freeXs_ = get_freespace(xyz_camXs_, occXs_)
    freeRs_ = apply_4x4_to_vox(camRs_T_camXs_, freeXs_)

    occXs = __u(occXs_)
    occRs = __u(occRs_)
    freeXs = __u(freeXs_)
    freeRs = __u(freeRs_)
    # these are B x S x 1 x Z x Y x X

    if agg:
        # note we should only agg if we are in STATIC mode (time frozen)
        freeR = torch.max(freeRs, dim=1)[0]
        occR = torch.max(occRs, dim=1)[0]
        # these are B x 1 x Z x Y x X
        occR = (occR > 0.5).float()
        freeR = (freeR > 0.5).float()
        return occR, freeR, occXs, freeXs
    else:
        occRs = (occRs > 0.5).float()
        freeRs = (freeRs > 0.5).float()
        return occRs, freeRs, occXs, freeXs
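prep_occs_supervision (and the larger example #21) folds the sequence dimension into the batch dimension with pack_seqdim / unpack_seqdim before calling apply_4x4 on packed tensors. Hedged sketches of those helpers, assuming they are plain reshapes:

import torch

def pack_seqdim_sketch(x, B):
    # x is B x S x ...; returns (B*S) x ...
    shape = list(x.shape)
    return x.reshape([B * shape[1]] + shape[2:])

def unpack_seqdim_sketch(x, B):
    # x is (B*S) x ...; returns B x S x ...
    shape = list(x.shape)
    return x.reshape([B, shape[0] // B] + shape[1:])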
Code example #18
def apply_pixX_T_memR_to_voxR(pix_T_camX, camX_T_camR, voxR, D, H, W):
    # mats are B x 4 x 4
    # voxR is B x C x Z x Y x X
    # H, W, D indicate how big to make the output
    # returns B x C x D x H x W

    B, C, Z, Y, X = list(voxR.shape)
    z_near = hyp.ZMIN
    z_far = hyp.ZMAX

    grid_z = torch.linspace(z_near,
                            z_far,
                            steps=D,
                            dtype=torch.float32,
                            device=torch.device('cuda'))
    # grid_z = torch.exp(torch.linspace(np.log(z_near), np.log(z_far), steps=D, dtype=torch.float32, device=torch.device('cuda')))
    grid_z = torch.reshape(grid_z, [1, 1, D, 1, 1])
    grid_z = grid_z.repeat([B, 1, 1, H, W])
    grid_z = torch.reshape(grid_z, [B * D, 1, H, W])

    pix_T_camX__ = torch.unsqueeze(pix_T_camX, axis=1).repeat([1, D, 1, 1])
    pix_T_camX = torch.reshape(pix_T_camX__, [B * D, 4, 4])
    xyz_camX = utils_geom.depth2pointcloud(grid_z, pix_T_camX)

    camR_T_camX = utils_geom.safe_inverse(camX_T_camR)
    camR_T_camX_ = torch.unsqueeze(camR_T_camX, dim=1).repeat([1, D, 1, 1])
    camR_T_camX = torch.reshape(camR_T_camX_, [B * D, 4, 4])

    mem_T_cam = get_mem_T_ref(B * D, Z, Y, X)
    memR_T_camX = matmul2(mem_T_cam, camR_T_camX)

    xyz_memR = utils_geom.apply_4x4(memR_T_camX, xyz_camX)
    xyz_memR = torch.reshape(xyz_memR, [B, D * H * W, 3])

    samp = utils_samp.sample3D(voxR, xyz_memR, D, H, W)
    # samp is B x H x W x D x C
    return samp
Code example #19
    def forward(self,
                template_feat,
                search_feat,
                template_mask,
                template_lrt,
                search_lrt,
                vox_util,
                lrt_cam0s,
                summ_writer=None):
        # template_feat is the thing we are searching for; it is B x C x ZZ x ZY x ZX
        # search_feat is the featuremap where we are searching; it is B x C x Z x Y x X
        total_loss = torch.tensor(0.0).cuda()

        B, C, ZZ, ZY, ZX = list(template_feat.shape)
        _, _, Z, Y, X = list(search_feat.shape)

        xyz0_template = utils_basic.gridcloud3D(B, ZZ, ZY, ZX)
        # this is B x med x 3
        xyz0_cam = vox_util.Zoom2Ref(xyz0_template, template_lrt, ZZ, ZY, ZX)
        # ok, next, after i relocate the object in search coords,
        # i need to transform those coords into cam, and then do svd on that

        # print('template_feat', template_feat.shape)
        # print('search_feat', search_feat.shape)

        search_feat = search_feat.view(B, C, -1)
        # this is B x C x huge
        template_feat = template_feat.view(B, C, -1)
        # this is B x C x med
        template_mask = template_mask.view(B, -1)
        # this is B x med

        # next i need to sample
        # i would like to take N random samples within the mask

        cam1_T_cam0_e = utils_geom.eye_4x4(B)

        # to simplify the impl, we will iterate over the batch dim
        for b in list(range(B)):
            template_feat_b = template_feat[b]
            template_mask_b = template_mask[b]
            search_feat_b = search_feat[b]
            xyz0_cam_b = xyz0_cam[b]

            # print('xyz0_cam_b', xyz0_cam_b.shape)
            # print('template_mask_b', template_mask_b.shape)
            # print('template_mask_b sum', torch.sum(template_mask_b).cpu().numpy())

            # take any points within the mask
            inds = torch.where(template_mask_b > 0)

            # gather up
            template_feat_b = template_feat_b.permute(1, 0)
            # this is med x C
            template_feat_b = template_feat_b[inds]
            xyz0_cam_b = xyz0_cam_b[inds]
            # these are self.num_pts x C

            # print('inds', inds)
            # not sure why this is a tuple
            # inds = inds[0]

            # trim down to self.num_pts
            # inds = inds.squeeze()
            assert (len(xyz0_cam_b) > 8)  # otw we should have returned early

            # i want to have self.num_pts pts every time
            if len(xyz0_cam_b) < self.num_pts:
                reps = int(self.num_pts / len(xyz0_cam_b)) + 1
                print('only have %d pts; repeating %d times...' %
                      (len(xyz0_cam_b), reps))
                xyz0_cam_b = xyz0_cam_b.repeat(reps, 1)
                template_feat_b = template_feat_b.repeat(reps, 1)
            assert (len(xyz0_cam_b) >= self.num_pts)
            # now trim down
            perm = np.random.permutation(len(xyz0_cam_b))
            # print('perm', perm[:10])
            xyz0_cam_b = xyz0_cam_b[perm[:self.num_pts]]
            template_feat_b = template_feat_b[perm[:self.num_pts]]

            heat_b = torch.matmul(template_feat_b, search_feat_b)
            # this is self.num_pts x huge
            # it represents each point's heatmap in the search region

            # make the min zero
            heat_b = heat_b - (torch.min(heat_b, dim=1).values).unsqueeze(1)
            # scale up, for numerical stability
            heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

            heat_b = heat_b.reshape(self.num_pts, 1, Z, Y, X)
            xyz1_search_b = utils_basic.argmax3D(heat_b,
                                                 hard=False,
                                                 stack=True)
            # this is self.num_pts x 3

            # i need to get to cam coords
            xyz1_cam_b = vox_util.Zoom2Ref(xyz1_search_b.unsqueeze(0),
                                           search_lrt[b:b + 1], Z, Y,
                                           X).squeeze(0)

            # print('xyz0, xyz1', xyz0_cam_b.shape, xyz1_cam_b.shape)
            # cam1_T_cam0_e[b] = utils_track.rigid_transform_3D(xyz0_cam_b, xyz1_cam_b)

            # cam1_T_cam0_e[b] = utils_track.differentiable_rigid_transform_3D(xyz0_cam_b, xyz1_cam_b)
            cam1_T_cam0_e[b] = utils_track.rigid_transform_3D(
                xyz0_cam_b, xyz1_cam_b)

        _, rt_cam0_g = utils_geom.split_lrt(lrt_cam0s[:, 0])
        _, rt_cam1_g = utils_geom.split_lrt(lrt_cam0s[:, 1])
        # these represent ref_T_obj
        cam1_T_cam0_g = torch.matmul(rt_cam1_g, rt_cam0_g.inverse())

        # cam1_T_cam0_e = cam1_T_cam0_g
        lrt_cam1_e = utils_geom.apply_4x4_to_lrtlist(cam1_T_cam0_e,
                                                     lrt_cam0s[:,
                                                               0:1]).squeeze(1)
        # lrt_cam1_g = lrt_cam0s[:,1]

        # _, rt_cam1_e = utils_geom.split_lrt(lrt_cam1_e)
        # _, rt_cam1_g = utils_geom.split_lrt(lrt_cam1_g)

        # let's try the cube loss
        lx, ly, lz = 1.0, 1.0, 1.0
        x = np.array([
            lx / 2., lx / 2., -lx / 2., -lx / 2., lx / 2., lx / 2., -lx / 2.,
            -lx / 2.
        ])
        y = np.array([
            ly / 2., ly / 2., ly / 2., ly / 2., -ly / 2., -ly / 2., -ly / 2.,
            -ly / 2.
        ])
        z = np.array([
            lz / 2., -lz / 2., -lz / 2., lz / 2., lz / 2., -lz / 2., -lz / 2.,
            lz / 2.
        ])
        xyz = np.stack([x, y, z], axis=1)
        # this is 8 x 3
        xyz = torch.from_numpy(xyz).float().cuda()
        xyz = xyz.reshape(1, 8, 3)
        # this is B x 8 x 3

        # xyz_e = utils_geom.apply_4x4(rt_cam1_e, xyz)
        # xyz_g = utils_geom.apply_4x4(rt_cam1_g, xyz)
        xyz_e = utils_geom.apply_4x4(cam1_T_cam0_e, xyz)
        xyz_g = utils_geom.apply_4x4(cam1_T_cam0_g, xyz)

        # print('xyz_e', xyz_e.detach().cpu().numpy())
        # print('xyz_g', xyz_g.detach().cpu().numpy())

        corner_loss = self.smoothl1(xyz_e, xyz_g)
        total_loss = utils_misc.add_loss('robust/corner_loss', total_loss,
                                         corner_loss, hyp.robust_corner_coeff,
                                         summ_writer)

        # rot_e, t_e = utils_geom.split_rt(rt_cam1_e)
        # rot_g, t_g = utils_geom.split_rt(rt_cam1_g)
        rot_e, t_e = utils_geom.split_rt(cam1_T_cam0_e)
        rot_g, t_g = utils_geom.split_rt(cam1_T_cam0_g)
        rx_e, ry_e, rz_e = utils_geom.rotm2eul(rot_e)
        rx_g, ry_g, rz_g = utils_geom.rotm2eul(rot_g)

        rad_e = torch.stack([rx_e, ry_e, rz_e], dim=1)
        rad_g = torch.stack([rx_g, ry_g, rz_g], dim=1)
        deg_e = utils_geom.rad2deg(rad_e)
        deg_g = utils_geom.rad2deg(rad_g)

        r_loss = self.smoothl1(deg_e, deg_g)
        t_loss = self.smoothl1(t_e, t_g)

        total_loss = utils_misc.add_loss('robust/r_loss', total_loss, r_loss,
                                         hyp.robust_r_coeff, summ_writer)
        total_loss = utils_misc.add_loss('robust/t_loss', total_loss, t_loss,
                                         hyp.robust_t_coeff, summ_writer)
        # print('r_loss', r_loss.detach().cpu().numpy())
        # print('t_loss', t_loss.detach().cpu().numpy())

        return lrt_cam1_e, total_loss
Code example #20
def job(data_dir):
    _, image_folder = data_dir
    current_dir = image_folder.split("/")[-1]
    print(current_dir)

    for FOV in [49.5]:
        print("%d FOV" % FOV)
        focal_length = get_focal(FOV)
        focal_length = 2.1875
        intrinsics = get_intrinsic_matrix_np(focal_length, H, W)
        pix_T_cams_ = []
        rgb_camXs_ = []
        xyz_camXs_ = []
        depths = []

        empty_rgb_camXs_ = []
        empty_xyz_camXs_ = []
        empty_depths = []
        camR_T_origin_ = []
        origin_T_camXs_ = []
        # tree_data_dir_save
        for cam_name in all_cameras:
            out_fn = current_dir
            out_fn += '.p'
            out_fn = '%s/%s' % (out_dir, out_fn)
            rgb_file = join(image_folder,
                            "%s_%s.png" % (current_dir, cam_name))
            rgb = imread(rgb_file)
            rgb_camXs_.append(process_rgbs(rgb))
            if empty_table:
                empty_rgb_file = join(empty_data_dir, "%s" % empty_dir,
                                      "%s_%s.png" % (empty_dir, cam_name))
                empty_rgb = imread(empty_rgb_file)
                empty_rgb_camXs_.append(process_rgbs(empty_rgb))
            depth_file = rgb_file.replace("images",
                                          "depth").replace("png", "exr")
            if baxter:
                depth = convert_exr_to_numpy(depth_file)
            else:
                depth = np.array(imageio.imread(depth_file,
                                                format='EXR-FI'))[:, :, 0]

            depth = process_depths(depth)
            depths.append(depth)
            if empty_table:
                empty_depth_file = empty_rgb_file.replace("images",
                                                          "depth").replace(
                                                              "png", "exr")
                if baxter:
                    empty_depth = convert_exr_to_numpy(empty_depth_file)
                else:
                    empty_depth = np.array(
                        imageio.imread(empty_depth_file,
                                       format='EXR-FI'))[:, :, 0]
                empty_depth = process_depths(empty_depth)
                empty_depths.append(empty_depth)
            xyz_camXs_.append(process_xyz(depth, intrinsics).cpu().numpy())
            if empty_table:
                empty_xyz_camXs_.append(
                    process_xyz(empty_depth, intrinsics).cpu().numpy())
            theta, phi = cam_name.split("_")
            theta = float(theta)
            phi = float(phi)
            origin_T_X = np.expand_dims(get_extrensic_np(theta, phi, RADIUS),
                                        0)
            origin_T_camXs_.append(origin_T_X)
            pix_T_cams_.append(intrinsics)
            camR_T_origin_.append(get_camRTorigin())
        if DO_TREE:
            tree_file = "/".join(
                rgb_file.replace("images", "trees").split("/")[:-1]) + ".tree"
            tree = pickle.load(open(tree_file, "rb"))
            ref_T_mem = get_ref_T_mem()

            tree_filename = out_fn.split("/")[-1].replace(".p", ".tree")
            tree_updated_file = join(tree_data_dir, tree_filename)
            # tree_folder_info = join(tree_data_dir_save,"train",tree_filename)

            updated_tree, boxes = gen_list_of_bboxes(tree,
                                                     boxes=[],
                                                     ref_T_mem=ref_T_mem)
            # tree_updated_file = tree_file.replace("trees","trees_updated")
            # tree_updated_file = join(out_base_dir,tree_folder_info)
            pickle.dump(updated_tree, open(tree_updated_file, "wb"))
            cube_coordinates = np.stack(boxes)
        else:
            tree_updated_file = "invalid_tree"
        pix_T_cams = np.stack(pix_T_cams_, axis=0)
        rgb_camXs = np.stack(rgb_camXs_, axis=0)
        xyz_camXs = np.stack(xyz_camXs_, axis=0)
        camR_T_origin = np.stack(camR_T_origin_, axis=0)
        origin_T_camXs = np.squeeze(np.stack(origin_T_camXs_, axis=0),
                                    axis=1).astype(np.float32)
        depths = np.stack(depths, axis=0)
        # st()
        if empty_table:
            empty_rgb_camXs = np.stack(empty_rgb_camXs_, axis=0)
            empty_xyz_camXs = np.stack(empty_xyz_camXs_, axis=0)
            empty_depths = np.stack(empty_depths, axis=0)
        assert origin_T_camXs.shape == (NUM_VIEWS, 4, 4)
        assert rgb_camXs.dtype == np.uint8
        assert pix_T_cams.shape == (NUM_VIEWS, 4, 4)
        xyz_camXs = np.squeeze(xyz_camXs, 1)
        rgb_camXs_raw = rgb_camXs
        xyz_camXs_raw = xyz_camXs
        depths_raw = depths
        origin_T_camXs_raw = origin_T_camXs
        camR_T_origin_raw = camR_T_origin
        pix_T_cams_raw = pix_T_cams
        if empty_table:
            empty_rgb_camXs_raw = empty_rgb_camXs
            empty_xyz_camXs_raw = empty_xyz_camXs
            empty_depths_raw = empty_depths
            empty_xyz_camXs_raw = np.squeeze(empty_xyz_camXs_raw, 1)

        comptype = "GZIP"
        if VISUALIZE:
            mkdir("preprocess_vis/dump_npy_vis")
            ax_points = None
            origin_T_camXs_selected = []
            for cam_num in random.sample(list(range(0, NUM_VIEWS)), 5):
                pix_to_cam_current = pix_T_cams[cam_num]
                rgb_current = rgb_camXs[cam_num]
                xyz_camX_current = xyz_camXs[cam_num]
                depth_current = depths[cam_num]
                origin_T_camXs_current = origin_T_camXs[cam_num]
                origin_T_camXs_selected.append(origin_T_camXs_current)
                xyz_origin = utils_geom.apply_4x4(
                    torch.tensor(origin_T_camXs_current, dtype=torch.float32),
                    torch.tensor(np.expand_dims(xyz_camX_current, axis=0)))
                camR_T_origin = get_camRTorigin()
                xyz_R = utils_geom.apply_4x4(
                    torch.tensor(camR_T_origin, dtype=torch.float32),
                    torch.tensor(xyz_origin))

                xyz_origin = xyz_origin.cpu().numpy()
                xyz_R = xyz_R.cpu().numpy()

                if not os.path.exists(
                        'preprocess_vis/dump_npy_vis/%s_rgb_cam%d.png' %
                    (current_dir, cam_num)):
                    scipy.misc.imsave(
                        'preprocess_vis/dump_npy_vis/%s_rgb_cam%d.png' %
                        (current_dir, cam_num), rgb_current)
                    scipy.misc.imsave(
                        'preprocess_vis/dump_npy_vis/%s_depth_cam%d.png' %
                        (current_dir, cam_num), depth_current)
                fig, ax_points = utils_pyplot_vis.plot_pointcloud(
                    xyz_R[0, ::10],
                    fig_id=3,
                    ax=ax_points,
                    xlims=[-8.0, 8.0],
                    ylims=[-8.0, 8.0],
                    zlims=[5, 21.0],
                    coord="xright-ydown")
                if DO_TREE:
                    fig, ax_points = utils_pyplot_vis.plot_cube(
                        cube_coordinates, fig=fig, ax=ax_points)
                # fig, ax_points = utils_pyplot_vis.plot_pointcloud(xyz_R[0], fig_id=3, ax=ax_points, xlims = [-10.0, 10.0], ylims = [-10.0, 10.0], zlims=[5, 21.0],coord="xright-ydown")
                # if DO_TREE:
                # 	fig, ax_points = utils_pyplot_vis.plot_cube(cube_coordinates,fig=fig,ax=ax_points)
            # utils.pyplot_vis.plot_cam(tf.concat(origin_T_camXs_selected, 0), fig_id=2, xlims = [-13.0, 13.0], ylims = [-13.0, 13.0], zlims=[-13, 13.0], length=2.0)
            print(cam_num)
            pyplot.show()

        # folder_info
        if empty_table:
            feature = {
                'tree_seq_filename': tree_updated_file,
                'pix_T_cams_raw': pix_T_cams_raw,
                'origin_T_camXs_raw': origin_T_camXs_raw,
                'rgb_camXs_raw': rgb_camXs_raw,
                "camR_T_origin_raw": camR_T_origin_raw,
                # 'depth_camXs_raw': depths_raw,
                'xyz_camXs_raw': xyz_camXs_raw,
                'empty_rgb_camXs_raw': empty_rgb_camXs_raw,
                'empty_xyz_camXs_raw': empty_xyz_camXs_raw,
            }
        else:
            feature = {
                'tree_seq_filename': tree_updated_file,
                'pix_T_cams_raw': pix_T_cams_raw,
                'origin_T_camXs_raw': origin_T_camXs_raw,
                'rgb_camXs_raw': rgb_camXs_raw,
                "camR_T_origin_raw": camR_T_origin_raw,
                'xyz_camXs_raw': xyz_camXs_raw,
            }
        feature_np = feature
        shape_dict = print_feature_shapes(feature)
        pickle.dump(feature_np, open(out_fn, "wb"))
        sys.stdout.write('.')
        sys.stdout.flush()
Code example #21
    def forward(self, feed):
        results = dict()

        if 'log_freq' not in feed.keys():
            feed['log_freq'] = None
        start_time = time.time()

        summ_writer = utils_improc.Summ_writer(writer=feed['writer'],
                                               global_step=feed['global_step'],
                                               set_name=feed['set_name'],
                                               log_freq=feed['log_freq'],
                                               fps=8)
        writer = feed['writer']
        global_step = feed['global_step']

        total_loss = torch.tensor(0.0).cuda()
        __p = lambda x: utils_basic.pack_seqdim(x, B)
        __u = lambda x: utils_basic.unpack_seqdim(x, B)

        __pb = lambda x: utils_basic.pack_boxdim(x, hyp.N)
        __ub = lambda x: utils_basic.unpack_boxdim(x, hyp.N)
        if hyp.aug_object_ent_dis:
            __pb_a = lambda x: utils_basic.pack_boxdim(
                x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
            __ub_a = lambda x: utils_basic.unpack_boxdim(
                x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
        else:
            __pb_a = lambda x: utils_basic.pack_boxdim(x, hyp.max_obj_aug)
            __ub_a = lambda x: utils_basic.unpack_boxdim(x, hyp.max_obj_aug)

        B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
        PH, PW = hyp.PH, hyp.PW
        K = hyp.K
        BOX_SIZE = hyp.BOX_SIZE
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
        D = 9

        tids = torch.from_numpy(np.reshape(np.arange(B * N), [B, N]))

        rgb_camXs = feed["rgb_camXs_raw"]
        pix_T_cams = feed["pix_T_cams_raw"]
        camRs_T_origin = feed["camR_T_origin_raw"]
        origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
        origin_T_camXs = feed["origin_T_camXs_raw"]
        camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
        camRs_T_camXs = __u(
            torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                         __p(origin_T_camXs)))
        camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
        camX0_T_camRs = camXs_T_camRs[:, 0]
        camX1_T_camRs = camXs_T_camRs[:, 1]

        camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)

        xyz_camXs = feed["xyz_camXs_raw"]
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(
            __p(pix_T_cams), __p(xyz_camXs), H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_,
                                                       __p(pix_T_cams))

        xyz_camRs = __u(
            utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
        xyz_camX0s = __u(
            utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs)))

        occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))

        occXs_to_Rs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, occXs)
        occXs_to_Rs_45 = cross_corr.rotate_tensor_along_y_axis(occXs_to_Rs, 45)
        occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2))
        occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
        occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2))

        unpXs = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X,
                                           __p(pix_T_cams)))

        unpXs_half = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z2, Y2, X2,
                                           __p(pix_T_cams)))

        unpX0s_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camX0_T_camXs)))))

        unpRs = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z, Y, X,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camRs_T_camXs)))))

        unpRs_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camRs_T_camXs)))))

        dense_xyz_camRs_ = utils_geom.apply_4x4(__p(camRs_T_camXs),
                                                dense_xyz_camXs_)
        inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y,
                                                X).float()
        inbound_camXs_ = torch.reshape(inbound_camXs_, [B * S, 1, H, W])

        depth_camXs = __u(depth_camXs_)
        valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)

        summ_writer.summ_oneds('2D_inputs/depth_camXs',
                               torch.unbind(depth_camXs, dim=1),
                               maxdepth=21.0)
        summ_writer.summ_oneds('2D_inputs/valid_camXs',
                               torch.unbind(valid_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                              torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1))
        summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1),
                              torch.unbind(occXs, dim=1))

        occRs = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X))

        if hyp.do_eval_boxes:
            if hyp.dataset_name == "clevr_vqa":
                gt_boxes_origin_corners = feed['gt_box']
                gt_scores_origin = feed['gt_scores'].detach().cpu().numpy()
                classes = feed['classes']
                scores = gt_scores_origin
                tree_seq_filename = feed['tree_seq_filename']
                gt_boxes_origin = nlu.get_ends_of_corner(
                    gt_boxes_origin_corners)
                gt_boxes_origin_end = torch.reshape(gt_boxes_origin,
                                                    [hyp.B, hyp.N, 2, 3])
                gt_boxes_origin_theta = nlu.get_alignedboxes2thetaformat(
                    gt_boxes_origin_end)
                gt_boxes_origin_corners = utils_geom.transform_boxes_to_corners(
                    gt_boxes_origin_theta)
                gt_boxesR_corners = __ub(
                    utils_geom.apply_4x4(camRs_T_origin[:, 0],
                                         __pb(gt_boxes_origin_corners)))
                gt_boxesR_theta = utils_geom.transform_corners_to_boxes(
                    gt_boxesR_corners)
                gt_boxesR_end = nlu.get_ends_of_corner(gt_boxesR_corners)

            else:
                tree_seq_filename = feed['tree_seq_filename']
                tree_filenames = [
                    join(hyp.root_dataset, i) for i in tree_seq_filename
                    if i != "invalid_tree"
                ]
                invalid_tree_filenames = [
                    join(hyp.root_dataset, i) for i in tree_seq_filename
                    if i == "invalid_tree"
                ]
                num_empty = len(invalid_tree_filenames)
                trees = [pickle.load(open(i, "rb")) for i in tree_filenames]

                len_valid = len(trees)
                if len_valid > 0:
                    gt_boxesR, scores, classes = nlu.trees_rearrange(trees)

                if num_empty > 0:
                    gt_boxesR = np.concatenate([
                        gt_boxesR, empty_gt_boxesR
                    ]) if len_valid > 0 else empty_gt_boxesR
                    scores = np.concatenate([
                        scores, empty_scores
                    ]) if len_valid > 0 else empty_scores
                    classes = np.concatenate([
                        classes, empty_classes
                    ]) if len_valid > 0 else empty_classes

                gt_boxesR = torch.from_numpy(
                    gt_boxesR).cuda().float()  # torch.Size([2, 3, 6])
                gt_boxesR_end = torch.reshape(gt_boxesR, [hyp.B, hyp.N, 2, 3])
                gt_boxesR_theta = nlu.get_alignedboxes2thetaformat(
                    gt_boxesR_end)  #torch.Size([2, 3, 9])
                gt_boxesR_corners = utils_geom.transform_boxes_to_corners(
                    gt_boxesR_theta)

            class_names_ex_1 = "_".join(classes[0])
            summ_writer.summ_text('eval_boxes/class_names', class_names_ex_1)

            gt_boxesRMem_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z2, Y2, X2))
            gt_boxesRMem_end = nlu.get_ends_of_corner(gt_boxesRMem_corners)

            gt_boxesRMem_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesRMem_corners)
            gt_boxesRUnp_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z, Y, X))
            gt_boxesRUnp_end = nlu.get_ends_of_corner(gt_boxesRUnp_corners)

            gt_boxesX0_corners = __ub(
                utils_geom.apply_4x4(camX0_T_camRs, __pb(gt_boxesR_corners)))
            gt_boxesX0Mem_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesX0_corners), Z2, Y2, X2))

            gt_boxesX0Mem_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesX0Mem_corners)

            gt_boxesX0Mem_end = nlu.get_ends_of_corner(gt_boxesX0Mem_corners)
            gt_boxesX0_end = nlu.get_ends_of_corner(gt_boxesX0_corners)

            gt_cornersX0_pix = __ub(
                utils_geom.apply_pix_T_cam(pix_T_cams[:, 0],
                                           __pb(gt_boxesX0_corners)))

            rgb_camX0 = rgb_camXs[:, 0]
            rgb_camX1 = rgb_camXs[:, 1]

            summ_writer.summ_box_by_corners('eval_boxes/gt_boxescamX0',
                                            rgb_camX0, gt_boxesX0_corners,
                                            torch.from_numpy(scores), tids,
                                            pix_T_cams[:, 0])
            unps_vis = utils_improc.get_unps_vis(unpX0s_half, occX0s_half)
            unp_vis = torch.mean(unps_vis, dim=1)
            unps_visRs = utils_improc.get_unps_vis(unpRs_half, occRs_half)
            unp_visRs = torch.mean(unps_visRs, dim=1)
            unps_visRs_full = utils_improc.get_unps_vis(unpRs, occRs)
            unp_visRs_full = torch.mean(unps_visRs_full, dim=1)
            summ_writer.summ_box_mem_on_unp('eval_boxes/gt_boxesR_mem',
                                            unp_visRs, gt_boxesRMem_end,
                                            scores, tids)

            unpX0s_half = torch.mean(unpX0s_half, dim=1)
            unpX0s_half = nlu.zero_out(unpX0s_half, gt_boxesX0Mem_end, scores)

            occX0s_half = torch.mean(occX0s_half, dim=1)
            occX0s_half = nlu.zero_out(occX0s_half, gt_boxesX0Mem_end, scores)

            summ_writer.summ_unp('3D_inputs/unpX0s', unpX0s_half, occX0s_half)

        if hyp.do_feat:
            featXs_input = torch.cat([occXs, occXs * unpXs], dim=2)
            featXs_input_ = __p(featXs_input)

            freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), __p(occXs_half))
            freeXs = __u(freeXs_)
            visXs = torch.clamp(occXs_half + freeXs, 0.0, 1.0)
            mask_ = None

            if mask_ is not None:
                assert (list(mask_.shape)[2:5] == list(
                    featXs_input_.shape)[2:5])

            featXs_, feat_loss = self.featnet(featXs_input_,
                                              summ_writer,
                                              mask=__p(occXs))  #mask_)
            total_loss += feat_loss

            validXs = torch.ones_like(visXs)
            _validX00 = validXs[:, 0:1]
            _validX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                     validXs[:, 1:])
            validX0s = torch.cat([_validX00, _validX01], dim=1)
            validRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, validXs)
            visRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, visXs)

            featXs = __u(featXs_)
            _featX00 = featXs[:, 0:1]
            _featX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                    featXs[:, 1:])
            featX0s = torch.cat([_featX00, _featX01], dim=1)

            emb3D_e = torch.mean(featX0s[:, 1:], dim=1)
            vis3D_e_R = torch.max(visRs[:, 1:], dim=1)[0]
            emb3D_g = featX0s[:, 0]
            vis3D_g_R = visRs[:, 0]
            validR_combo = torch.min(validRs, dim=1).values

            summ_writer.summ_feats('3D_feats/featXs_input',
                                   torch.unbind(featXs_input, dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/featXs_output',
                                   torch.unbind(featXs, dim=1),
                                   valids=torch.unbind(validXs, dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/featX0s_output',
                                   torch.unbind(featX0s, dim=1),
                                   valids=torch.unbind(
                                       torch.ones_like(validRs), dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/validRs',
                                   torch.unbind(validRs, dim=1),
                                   pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_e_R', vis3D_e_R, pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_g_R', vis3D_g_R, pca=False)

        if hyp.do_munit:
            object_classes, filenames = nlu.create_object_classes(
                classes, [tree_seq_filename, tree_seq_filename], scores)
            if hyp.do_munit_fewshot:
                emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
                emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
                emb3D_R = emb3D_e_R
                emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                    [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                    scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
                emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2
                content, style = self.munitnet.net.gen_a.encode(emb3D_R_object)
                objects_taken, _ = self.munitnet.net.gen_a.decode(
                    content, style)
                styles = style
                contents = content
            elif hyp.do_3d_style_munit:
                emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
                emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
                emb3D_R = emb3D_e_R
                # st()
                emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                    [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                    scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
                emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2

                camX1_T_R = camXs_T_camRs[:, 1]
                camX0_T_R = camXs_T_camRs[:, 0]
                assert hyp.B == 2
                assert emb3D_e_R_object.shape[0] == 2
                munit_loss, sudo_input_0, sudo_input_1, recon_input_0, recon_input_1, sudo_input_0_cycle, sudo_input_1_cycle, styles, contents, adin = self.munitnet(
                    emb3D_R_object[0:1], emb3D_R_object[1:2])

                if hyp.store_content_style_range:
                    if self.max_content is None:
                        self.max_content = torch.zeros_like(
                            contents[0][0]).cuda() - 100000000
                    if self.min_content is None:
                        self.min_content = torch.zeros_like(
                            contents[0][0]).cuda() + 100000000
                    if self.max_style is None:
                        self.max_style = torch.zeros_like(
                            styles[0][0]).cuda() - 100000000
                    if self.min_style is None:
                        self.min_style = torch.zeros_like(
                            styles[0][0]).cuda() + 100000000
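                    # maintain a running element-wise min/max of the content and style codes, and dump the ranges to disk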
                    self.max_content = torch.max(
                        torch.max(self.max_content, contents[0][0]),
                        contents[1][0])
                    self.min_content = torch.min(
                        torch.min(self.min_content, contents[0][0]),
                        contents[1][0])
                    self.max_style = torch.max(
                        torch.max(self.max_style, styles[0][0]), styles[1][0])
                    self.min_style = torch.min(
                        torch.min(self.min_style, styles[0][0]), styles[1][0])

                    data_to_save = {
                        'max_content': self.max_content.cpu().numpy(),
                        'min_content': self.min_content.cpu().numpy(),
                        'max_style': self.max_style.cpu().numpy(),
                        'min_style': self.min_style.cpu().numpy()
                    }
                    with open('content_style_range.p', 'wb') as f:
                        pickle.dump(data_to_save, f)
                elif hyp.is_contrastive_examples:
                    if hyp.normalize_contrast:
                        content0 = (contents[0] - self.min_content) / (
                            self.max_content - self.min_content + 1e-5)
                        content1 = (contents[1] - self.min_content) / (
                            self.max_content - self.min_content + 1e-5)
                        style0 = (styles[0] - self.min_style) / (
                            self.max_style - self.min_style + 1e-5)
                        style1 = (styles[1] - self.min_style) / (
                            self.max_style - self.min_style + 1e-5)
                    else:
                        content0 = contents[0]
                        content1 = contents[1]
                        style0 = styles[0]
                        style1 = styles[1]

                    # euclid_dist_content = torch.sum(torch.sqrt((content0 - content1)**2))/torch.prod(torch.tensor(content0.shape))
                    # euclid_dist_style = torch.sum(torch.sqrt((style0-style1)**2))/torch.prod(torch.tensor(style0.shape))
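                    # Euclidean and (below) cosine distances between the two content/style codes, as a quick contrastive check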
                    euclid_dist_content = (content0 - content1).norm(2) / (
                        content0.numel())
                    euclid_dist_style = (style0 -
                                         style1).norm(2) / (style0.numel())

                    content_0_pooled = torch.mean(
                        content0.reshape(list(content0.shape[:2]) + [-1]),
                        dim=-1)
                    content_1_pooled = torch.mean(
                        content1.reshape(list(content1.shape[:2]) + [-1]),
                        dim=-1)

                    euclid_dist_content_pooled = (content_0_pooled -
                                                  content_1_pooled).norm(2) / (
                                                      content_0_pooled.numel())

                    content_0_normalized = content0 / content0.norm()
                    content_1_normalized = content1 / content1.norm()

                    style_0_normalized = style0 / style0.norm()
                    style_1_normalized = style1 / style1.norm()

                    content_0_pooled_normalized = content_0_pooled / content_0_pooled.norm(
                    )
                    content_1_pooled_normalized = content_1_pooled / content_1_pooled.norm(
                    )

                    cosine_dist_content = torch.sum(content_0_normalized *
                                                    content_1_normalized)
                    cosine_dist_style = torch.sum(style_0_normalized *
                                                  style_1_normalized)
                    cosine_dist_content_pooled = torch.sum(
                        content_0_pooled_normalized *
                        content_1_pooled_normalized)

                    print("euclid dist [content, pooled-content, style]: ",
                          euclid_dist_content, euclid_dist_content_pooled,
                          euclid_dist_style)
                    print("cosine sim [content, pooled-content, style]: ",
                          cosine_dist_content, cosine_dist_content_pooled,
                          cosine_dist_style)

            if hyp.run_few_shot_on_munit:
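                # every ~300 steps: log few-shot retrieval precision, then reset the per-class embedding banks and counters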
                if (global_step % 300) == 1 or (global_step % 300) == 0:
                    wrong = False
                    try:
                        precision_style = float(self.tp_style) / self.all_style
                        precision_content = float(
                            self.tp_content) / self.all_content
                    except ZeroDivisionError:
                        wrong = True

                    if not wrong:
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_style',
                            precision_style)
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_content',
                            precision_content)
                        # st()
                    self.embed_list_style = defaultdict(lambda: [])
                    self.embed_list_content = defaultdict(lambda: [])
                    self.tp_style = 0
                    self.all_style = 0
                    self.tp_content = 0
                    self.all_content = 0
                    self.check = False
                elif not self.check and not nlu.check_fill_dict(
                        self.embed_list_content, self.embed_list_style):
                    print("Filling \n")
                    for index, class_val in enumerate(object_classes):

                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split(
                                "/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]

                        print(len(self.embed_list_style.keys()), "style class",
                              len(self.embed_list_content), "content class",
                              self.embed_list_content.keys())
                        if len(self.embed_list_style[class_val_style]
                               ) < hyp.few_shot_nums:
                            self.embed_list_style[class_val_style].append(
                                styles[index].squeeze())
                        if len(self.embed_list_content[class_val_content]
                               ) < hyp.few_shot_nums:
                            if hyp.avg_3d:
                                content_val = contents[index]
                                content_val = torch.mean(content_val.reshape(
                                    [content_val.shape[1], -1]),
                                                         dim=-1)
                                # st()
                                self.embed_list_content[
                                    class_val_content].append(content_val)
                            else:
                                self.embed_list_content[
                                    class_val_content].append(
                                        contents[index].reshape([-1]))
                else:
                    self.check = True
                    try:
                        print(float(self.tp_content) / self.all_content)
                        print(float(self.tp_style) / self.all_style)
                    except Exception as e:
                        pass
                    average = True
                    if average:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)

                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)
                    else:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.stack(val,
                                                                         dim=0)

                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.stack(
                                    val, dim=0)
                    for index, class_val in enumerate(object_classes):
                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split(
                                "/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]

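                        # style retrieval: cosine similarity of this object's style code against the stored per-class style embeddings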
                        style_val = styles[index].squeeze().unsqueeze(0)
                        if not average:
                            embed_list_val_style = torch.cat(list(
                                self.embed_list_style.values()),
                                                             dim=0)
                            embed_list_key_style = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_style.keys()), 1),
                                    hyp.few_shot_nums, 1).reshape([-1]))
                        else:
                            embed_list_val_style = torch.stack(list(
                                self.embed_list_style.values()),
                                                               dim=0)
                            embed_list_key_style = list(
                                self.embed_list_style.keys())
                        embed_list_val_style = utils_basic.l2_normalize(
                            embed_list_val_style, dim=1).permute(1, 0)
                        style_val = utils_basic.l2_normalize(style_val, dim=1)
                        scores_styles = torch.matmul(style_val,
                                                     embed_list_val_style)
                        index_key = torch.argmax(scores_styles,
                                                 dim=1).squeeze()
                        selected_class_style = embed_list_key_style[index_key]
                        self.styles_prediction[class_val_style].append(
                            selected_class_style)
                        if class_val_style == selected_class_style:
                            self.tp_style += 1
                        self.all_style += 1

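                        # content retrieval: same nearest-neighbor lookup, using the content code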
                        if hyp.avg_3d:
                            content_val = contents[index]
                            content_val = torch.mean(content_val.reshape(
                                [content_val.shape[1], -1]),
                                                     dim=-1).unsqueeze(0)
                        else:
                            content_val = contents[index].reshape(
                                [-1]).unsqueeze(0)
                        if not average:
                            embed_list_val_content = torch.cat(list(
                                self.embed_list_content.values()),
                                                               dim=0)
                            embed_list_key_content = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_content.keys()),
                                        1), hyp.few_shot_nums,
                                    1).reshape([-1]))
                        else:
                            embed_list_val_content = torch.stack(list(
                                self.embed_list_content.values()),
                                                                 dim=0)
                            embed_list_key_content = list(
                                self.embed_list_content.keys())
                        embed_list_val_content = utils_basic.l2_normalize(
                            embed_list_val_content, dim=1).permute(1, 0)
                        content_val = utils_basic.l2_normalize(content_val,
                                                               dim=1)
                        scores_content = torch.matmul(content_val,
                                                      embed_list_val_content)
                        index_key = torch.argmax(scores_content,
                                                 dim=1).squeeze()
                        selected_class_content = embed_list_key_content[
                            index_key]
                        self.content_prediction[class_val_content].append(
                            selected_class_content)
                        if class_val_content == selected_class_content:
                            self.tp_content += 1

                        self.all_content += 1
            # st()
            munit_loss = hyp.munit_loss_weight * munit_loss

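            # paste the reconstructed and style-swapped object tensors back into the scene tensor (R coordinates), then warp the scenes into the X0/X1 camera frames below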
            recon_input_obj = torch.cat([recon_input_0, recon_input_1], dim=0)
            recon_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, recon_input_obj, gt_boxesRMem_end, scores)

            sudo_input_obj = torch.cat([sudo_input_0, sudo_input_1], dim=0)
            styled_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, sudo_input_obj, gt_boxesRMem_end, scores)

            styled_emb3D_e_X1 = utils_vox.apply_4x4_to_vox(
                camX1_T_R, styled_emb3D_R)
            styled_emb3D_e_X0 = utils_vox.apply_4x4_to_vox(
                camX0_T_R, styled_emb3D_R)

            emb3D_e_X1 = utils_vox.apply_4x4_to_vox(camX1_T_R, recon_emb3D_R)
            emb3D_e_X0 = utils_vox.apply_4x4_to_vox(camX0_T_R, recon_emb3D_R)

            emb3D_e_X1_og = utils_vox.apply_4x4_to_vox(camX1_T_R, emb3D_R)
            emb3D_e_X0_og = utils_vox.apply_4x4_to_vox(camX0_T_R, emb3D_R)

            emb3D_R_aug_diff = torch.abs(emb3D_R - recon_emb3D_R)

            summ_writer.summ_feat(f'aug_feat/og', emb3D_R)
            summ_writer.summ_feat(f'aug_feat/og_gen', recon_emb3D_R)
            summ_writer.summ_feat(f'aug_feat/og_aug_diff', emb3D_R_aug_diff)

            if hyp.cycle_style_view_loss:
                sudo_input_obj_cycle = torch.cat(
                    [sudo_input_0_cycle, sudo_input_1_cycle], dim=0)
                styled_emb3D_R_cycle = nlu.update_scene_with_objects(
                    emb3D_R, sudo_input_obj_cycle, gt_boxesRMem_end, scores)

                styled_emb3D_e_X0_cycle = utils_vox.apply_4x4_to_vox(
                    camX0_T_R, styled_emb3D_R_cycle)
                styled_emb3D_e_X1_cycle = utils_vox.apply_4x4_to_vox(
                    camX1_T_R, styled_emb3D_R_cycle)
            summ_writer.summ_scalar('munit_loss', munit_loss.cpu().item())
            total_loss += munit_loss

        if hyp.do_occ and hyp.occ_do_cheap:
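            # aggregate the point clouds into occupancy / free-space supervision in X0 coordinates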
            occX0_sup, freeX0_sup, _, freeXs = utils_vox.prep_occs_supervision(
                camX0_T_camXs, xyz_camXs, Z2, Y2, X2, agg=True)

            summ_writer.summ_occ('occ_sup/occ_sup', occX0_sup)
            summ_writer.summ_occ('occ_sup/free_sup', freeX0_sup)
            summ_writer.summ_occs('occ_sup/freeXs_sup',
                                  torch.unbind(freeXs, dim=1))
            summ_writer.summ_occs('occ_sup/occXs_sup',
                                  torch.unbind(occXs_half, dim=1))

            occ_loss, occX0s_pred_ = self.occnet(
                torch.mean(featX0s[:, 1:], dim=1), occX0_sup, freeX0_sup,
                torch.max(validX0s[:, 1:], dim=1)[0], summ_writer)
            occX0s_pred = __u(occX0s_pred_)
            total_loss += occ_loss

        if hyp.do_view:
            assert (hyp.do_feat)
            PH, PW = hyp.PH, hyp.PW
            sy = float(PH) / float(hyp.H)
            sx = float(PW) / float(hyp.W)
            assert (sx == 0.5)  # else we need a fancier downsampler
            assert (sy == 0.5)
            projpix_T_cams = __u(
                utils_geom.scale_intrinsics(__p(pix_T_cams), sx, sy))
            # st()

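            # project the 3D feature volumes into camera 0's image plane so viewnet can predict RGB there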
            if hyp.do_munit:
                feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    emb3D_e_X1,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                feat_projX00_og = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    emb3D_e_X1_og,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                # only for checking the style
                styled_feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    styled_emb3D_e_X1,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                if hyp.cycle_style_view_loss:
                    styled_feat_projX00_cycle = utils_vox.apply_pixX_T_memR_to_voxR(
                        projpix_T_cams[:, 0],
                        camX0_T_camXs[:, 1],
                        styled_emb3D_e_X1_cycle,  # use feat1 to predict rgb0
                        hyp.view_depth,
                        PH,
                        PW)

            else:
                feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    featXs[:, 1],  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)
            rgb_X00 = utils_basic.downsample(rgb_camXs[:, 0], 2)
            rgb_X01 = utils_basic.downsample(rgb_camXs[:, 1], 2)
            valid_X00 = utils_basic.downsample(valid_camXs[:, 0], 2)

            view_loss, rgb_e, emb2D_e = self.viewnet(feat_projX00, rgb_X00,
                                                     valid_X00, summ_writer,
                                                     "rgb")

            if hyp.do_munit:
                _, rgb_e, emb2D_e = self.viewnet(feat_projX00_og, rgb_X00,
                                                 valid_X00, summ_writer,
                                                 "rgb_og")
            if hyp.do_munit:
                styled_view_loss, styled_rgb_e, styled_emb2D_e = self.viewnet(
                    styled_feat_projX00, rgb_X00, valid_X00, summ_writer,
                    "recon_style")
                if hyp.cycle_style_view_loss:
                    styled_view_loss_cycle, styled_rgb_e_cycle, styled_emb2D_e_cycle = self.viewnet(
                        styled_feat_projX00_cycle, rgb_X00, valid_X00,
                        summ_writer, "recon_style_cycle")

                rgb_input_1 = torch.cat(
                    [rgb_X01[1], rgb_X01[0], styled_rgb_e[0]], dim=2)
                rgb_input_2 = torch.cat(
                    [rgb_X01[0], rgb_X01[1], styled_rgb_e[1]], dim=2)
                complete_vis = torch.cat([rgb_input_1, rgb_input_2], dim=1)
                summ_writer.summ_rgb('munit/munit_recons_vis',
                                     complete_vis.unsqueeze(0))

            if not hyp.do_munit:
                total_loss += view_loss
            else:
                if hyp.basic_view_loss:
                    total_loss += view_loss
                if hyp.style_view_loss:
                    total_loss += styled_view_loss
                if hyp.cycle_style_view_loss:
                    total_loss += styled_view_loss_cycle

        summ_writer.summ_scalar('loss', total_loss.cpu().item())

        if hyp.save_embed_tsne:
            for index, class_val in enumerate(object_classes):
                class_val_content, class_val_style = class_val.split("/")
                style_val = styles[index].squeeze().unsqueeze(0)
                self.cluster_pool.update(style_val, [class_val_style])
                print(self.cluster_pool.num)

            if self.cluster_pool.is_full():
                embeds, classes = self.cluster_pool.fetch()
                with open("offline_cluster" + '/%st.txt' % 'classes',
                          'w') as f:
                    for class_val in classes:
                        f.write("%s\n" % class_val)
                with open("offline_cluster" + '/%st.txt' % 'embeddings',
                          'w') as f:
                    for index, embed in enumerate(embeds):
                        # embed = utils_basic.l2_normalize(embed,dim=0)
                        print("writing {} embed".format(index))
                        embed_l_s = [str(i) for i in embed.tolist()]
                        embed_str = '\t'.join(embed_l_s)
                        f.write("%s\n" % embed_str)
                st()

        return total_loss, results
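
Note: the run_few_shot_on_munit branch above classifies each object's style and content codes by cosine similarity against a bank of per-class embeddings (averaged over a few stored examples). The following is a minimal, self-contained sketch of that lookup; the function name, tensor sizes, and class labels are made up for illustration and are not part of the repository's API.

import torch
import torch.nn.functional as F

def nearest_class_by_cosine(query, bank_embeds, bank_labels):
    # query: 1 x C code; bank_embeds: K x C, one (averaged) embedding per class
    query = F.normalize(query, dim=1)
    bank = F.normalize(bank_embeds, dim=1)
    scores = torch.matmul(query, bank.t())  # 1 x K cosine similarities
    return bank_labels[scores.argmax(dim=1).item()]

# toy usage with hypothetical classes
bank = torch.randn(3, 16)
labels = ["red_metal", "blue_rubber", "green_metal"]
print(nearest_class_by_cosine(torch.randn(1, 16), bank, labels))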
Code example #22
0
    def __getitem__(self, index):
        if hyp.dataset_name in ('kitti', 'clevr', 'real', 'bigbird', 'carla', 'carla_mix', 'replica', 'clevr_vqa', 'carla_det'):
            # print(index)
            # st()
            filename = self.records[index]
            d = pickle.load(open(filename, "rb"))
            d = dict(d)

            d_empty = pickle.load(open(self.empty_scene, "rb"))
            d_empty = dict(d_empty)
            # st()
        # elif hyp.dataset_name=="carla":
        #     filename = self.records[index]
        #     d = np.load(filename)
        #     d = dict(d)

        #     d['rgb_camXs_raw'] = d['rgb_camXs']
        #     d['pix_T_cams_raw'] = d['pix_T_cams']
        #     d['tree_seq_filename'] = "dummy_tree_filename"
        #     d['origin_T_camXs_raw'] = d['origin_T_camXs']
        #     d['camR_T_origin_raw'] = utils_geom.safe_inverse(torch.from_numpy(d['origin_T_camRs'])).numpy()
        #     d['xyz_camXs_raw'] = d['xyz_camXs']

        else:
            assert (False)  # reader not ready yet

        if hyp.do_empty:
            item_names = [
                'pix_T_cams_raw',
                'origin_T_camXs_raw',
                'camR_T_origin_raw',
                'rgb_camXs_raw',
                'xyz_camXs_raw',
                'empty_rgb_camXs_raw',
                'empty_xyz_camXs_raw',
            ]
        else:
            item_names = [
                'pix_T_cams_raw',
                'origin_T_camXs_raw',
                'camR_T_origin_raw',
                'rgb_camXs_raw',
                'xyz_camXs_raw',
            ]

        if hyp.use_gt_occs:
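            # build a half-resolution ground-truth occupancy grid in the reference (R) frame from the raw point clouds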
            __p = lambda x: utils_basic.pack_seqdim(x, 1)
            __u = lambda x: utils_basic.unpack_seqdim(x, 1)

            B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
            PH, PW = hyp.PH, hyp.PW
            K = hyp.K
            BOX_SIZE = hyp.BOX_SIZE
            Z, Y, X = hyp.Z, hyp.Y, hyp.X
            Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
            Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
            D = 9
            pix_T_cams = torch.from_numpy(
                d["pix_T_cams_raw"]).unsqueeze(0).cuda().to(torch.float)
            camRs_T_origin = torch.from_numpy(
                d["camR_T_origin_raw"]).unsqueeze(0).cuda().to(torch.float)
            origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
            origin_T_camXs = torch.from_numpy(
                d["origin_T_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
            camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
            camRs_T_camXs = __u(
                torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                             __p(origin_T_camXs)))
            camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
            camX0_T_camRs = camXs_T_camRs[:, 0]
            camX1_T_camRs = camXs_T_camRs[:, 1]
            camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)
            xyz_camXs = torch.from_numpy(
                d["xyz_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
            xyz_camRs = __u(
                utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
            depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(
                __p(pix_T_cams), __p(xyz_camXs), H, W)
            dense_xyz_camXs_ = utils_geom.depth2pointcloud(
                depth_camXs_, __p(pix_T_cams))
            occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))
            occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2,
                                                    X2))
            occRs_half = torch.max(occRs_half, dim=1).values.squeeze(0)
            occ_complete = occRs_half.cpu().numpy()

        # if hyp.do_time_flip:
        #     d = random_time_flip_single(d,item_names)
        # if the sequence length > 2, select S frames
        # filename = d['raw_seq_filename']
        original_filename = filename
        original_filename_empty = self.empty_scene

        # st()
        if hyp.dataset_name == "clevr_vqa":
            d['tree_seq_filename'] = "temp"
            pix_T_cams = d['pix_T_cams_raw']
            num_cams = pix_T_cams.shape[0]
            # padding_1 = torch.zeros([num_cams,1,3])
            # padding_2 = torch.zeros([num_cams,4,1])
            # padding_2[:,3] = 1.0
            # st()
            # pix_T_cams = torch.cat([pix_T_cams,padding_1],dim=1)
            # pix_T_cams = torch.cat([pix_T_cams,padding_2],dim=2)
            # st()
            shape_name = d['shape_list']
            color_name = d['color_list']
            material_name = d['material_list']
            all_name = []
            all_style = []
            for index in range(len(shape_name)):
                name = shape_name[index] + "/" + color_name[
                    index] + "_" + material_name[index]
                style_name = color_name[index] + "_" + material_name[index]
                all_name.append(name)
                all_style.append(style_name)

            # st()

            if hyp.do_shape:
                class_name = shape_name
            elif hyp.do_color:
                class_name = color_name
            elif hyp.do_material:
                class_name = material_name
            elif hyp.do_style:
                class_name = all_style
            else:
                class_name = all_name

            object_category = class_name
            bbox_origin = d['bbox_origin']
            # bbox_origin = torch.cat([bbox_origin],dim=0)
            # object_category = object_category
            bbox_origin_empty = np.zeros_like(bbox_origin)
            object_category_empty = ['0']
        # st()
        if not hyp.dataset_name == "clevr_vqa":
            filename = d['tree_seq_filename']
            filename_empty = d_empty['tree_seq_filename']
        if hyp.fixed_view:
            d, indexes = non_random_select_single(d,
                                                  item_names,
                                                  num_samples=hyp.S)
            d_empty, indexes_empty = specific_select_single_empty(
                d_empty,
                item_names,
                d['origin_T_camXs_raw'],
                num_samples=hyp.S)

        filename_g = "/".join([original_filename, str(indexes[0])])
        filename_e = "/".join([original_filename, str(indexes[1])])

        filename_g_empty = "/".join([original_filename_empty, str(indexes[0])])
        filename_e_empty = "/".join([original_filename_empty, str(indexes[1])])

        rgb_camXs = d['rgb_camXs_raw']
        rgb_camXs_empty = d_empty['rgb_camXs_raw']
        # move channel dim inward, like pytorch wants
        # rgb_camRs = np.transpose(rgb_camRs, axes=[0, 3, 1, 2])
        rgb_camXs = np.transpose(rgb_camXs, axes=[0, 3, 1, 2])
        rgb_camXs = rgb_camXs[:, :3]
        rgb_camXs = utils_improc.preprocess_color(rgb_camXs)

        rgb_camXs_empty = np.transpose(rgb_camXs_empty, axes=[0, 3, 1, 2])
        rgb_camXs_empty = rgb_camXs_empty[:, :3]
        rgb_camXs_empty = utils_improc.preprocess_color(rgb_camXs_empty)

        if hyp.dataset_name == "clevr_vqa":
            num_boxes = bbox_origin.shape[0]
            bbox_origin = np.array(bbox_origin)
            score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
            bbox_origin = np.pad(bbox_origin,
                                 [[0, hyp.N - num_boxes], [0, 0], [0, 0]])
            object_category = np.pad(object_category, [[0, hyp.N - num_boxes]],
                                     lambda x, y, z, m: "0")
            object_category_empty = np.pad(object_category_empty,
                                           [[0, hyp.N - 1]],
                                           lambda x, y, z, m: "0")

            # st()
            score_empty = np.zeros_like(score)
            bbox_origin_empty = np.zeros_like(bbox_origin)
            d['gt_box'] = np.stack(
                [bbox_origin.astype(np.float32), bbox_origin_empty])
            d['gt_scores'] = np.stack([score.astype(np.float32), score_empty])
            try:
                d['classes'] = np.stack(
                    [object_category, object_category_empty]).tolist()
            except Exception as e:
                st()

        d['rgb_camXs_raw'] = np.stack([rgb_camXs, rgb_camXs_empty])
        d['pix_T_cams_raw'] = np.stack(
            [d["pix_T_cams_raw"], d_empty["pix_T_cams_raw"]])
        d['origin_T_camXs_raw'] = np.stack(
            [d["origin_T_camXs_raw"], d_empty["origin_T_camXs_raw"]])
        d['camR_T_origin_raw'] = np.stack(
            [d["camR_T_origin_raw"], d_empty["camR_T_origin_raw"]])
        d['xyz_camXs_raw'] = np.stack(
            [d["xyz_camXs_raw"], d_empty["xyz_camXs_raw"]])
        # d['rgb_camXs_raw'] = rgb_camXs
        # d['tree_seq_filename'] = filename
        if not hyp.dataset_name == "clevr_vqa":
            d['tree_seq_filename'] = [filename, "invalid_tree"]
        else:
            d['tree_seq_filename'] = ["temp"]
        # st()
        d['filename_e'] = ["temp"]
        d['filename_g'] = ["temp"]
        if hyp.use_gt_occs:
            d['occR_complete'] = np.expand_dims(occ_complete, axis=0)
        return d
Code example #23
0
    def __getitem__(self, index):
        if hyp.dataset_name in ('kitti', 'clevr', 'real', 'bigbird', 'carla', 'carla_mix', 'carla_det', 'replica', 'clevr_vqa'):
            # print(index)
            filename = self.records[index]
            d = pickle.load(open(filename, "rb"))
            d = dict(d)
        # elif hyp.dataset_name=="carla":
        #     filename = self.records[index]
        #     d = np.load(filename)
        #     d = dict(d)

        #     d['rgb_camXs_raw'] = d['rgb_camXs']
        #     d['pix_T_cams_raw'] = d['pix_T_cams']
        #     d['tree_seq_filename'] = "dummy_tree_filename"
        #     d['origin_T_camXs_raw'] = d['origin_T_camXs']
        #     d['camR_T_origin_raw'] = utils_geom.safe_inverse(torch.from_numpy(d['origin_T_camRs'])).numpy()
        #     d['xyz_camXs_raw'] = d['xyz_camXs']

        else:
            assert (False)  # reader not ready yet

        # st()
        # if hyp.save_gt_occs:
        # pickle.dump(d,open(filename, "wb"))
        # st()
        # st()
        if hyp.use_gt_occs:
            __p = lambda x: utils_basic.pack_seqdim(x, 1)
            __u = lambda x: utils_basic.unpack_seqdim(x, 1)

            B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
            PH, PW = hyp.PH, hyp.PW
            K = hyp.K
            BOX_SIZE = hyp.BOX_SIZE
            Z, Y, X = hyp.Z, hyp.Y, hyp.X
            Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
            Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
            D = 9
            pix_T_cams = torch.from_numpy(
                d["pix_T_cams_raw"]).unsqueeze(0).cuda().to(torch.float)
            camRs_T_origin = torch.from_numpy(
                d["camR_T_origin_raw"]).unsqueeze(0).cuda().to(torch.float)
            origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
            origin_T_camXs = torch.from_numpy(
                d["origin_T_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
            camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
            camRs_T_camXs = __u(
                torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                             __p(origin_T_camXs)))
            camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
            camX0_T_camRs = camXs_T_camRs[:, 0]
            camX1_T_camRs = camXs_T_camRs[:, 1]
            camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)
            xyz_camXs = torch.from_numpy(
                d["xyz_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
            xyz_camRs = __u(
                utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
            depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(
                __p(pix_T_cams), __p(xyz_camXs), H, W)
            dense_xyz_camXs_ = utils_geom.depth2pointcloud(
                depth_camXs_, __p(pix_T_cams))
            occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))
            occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2,
                                                    X2))
            occRs_half = torch.max(occRs_half, dim=1).values.squeeze(0)
            occ_complete = occRs_half.cpu().numpy()

            # st()

        if hyp.do_empty:
            item_names = [
                'pix_T_cams_raw',
                'origin_T_camXs_raw',
                'camR_T_origin_raw',
                'rgb_camXs_raw',
                'xyz_camXs_raw',
                'empty_rgb_camXs_raw',
                'empty_xyz_camXs_raw',
            ]
        else:
            item_names = [
                'pix_T_cams_raw',
                'origin_T_camXs_raw',
                'camR_T_origin_raw',
                'rgb_camXs_raw',
                'xyz_camXs_raw',
            ]

        # if hyp.do_time_flip:
        #     d = random_time_flip_single(d,item_names)
        # if the sequence length > 2, select S frames
        # filename = d['raw_seq_filename']
        original_filename = filename
        if hyp.dataset_name == "carla_mix" or hyp.dataset_name == "carla_det":
            bbox_origin_gt = d['bbox_origin']
            if 'bbox_origin_predicted' in d:
                bbox_origin_predicted = d['bbox_origin_predicted']
            else:
                bbox_origin_predicted = []
            classes = d['obj_name']

            if isinstance(classes, str):
                classes = [classes]
            # st()

            d['tree_seq_filename'] = "temp"
        if hyp.dataset_name == "replica":
            d['tree_seq_filename'] = "temp"
            object_category = d['object_category_names']
            bbox_origin = d['bbox_origin']

        if hyp.dataset_name == "clevr_vqa":
            d['tree_seq_filename'] = "temp"
            pix_T_cams = d['pix_T_cams_raw']
            num_cams = pix_T_cams.shape[0]
            # padding_1 = torch.zeros([num_cams,1,3])
            # padding_2 = torch.zeros([num_cams,4,1])
            # padding_2[:,3] = 1.0
            # st()
            # pix_T_cams = torch.cat([pix_T_cams,padding_1],dim=1)
            # pix_T_cams = torch.cat([pix_T_cams,padding_2],dim=2)
            # st()
            shape_name = d['shape_list']
            color_name = d['color_list']
            material_name = d['material_list']
            all_name = []
            all_style = []
            for index in range(len(shape_name)):
                name = shape_name[index] + "/" + color_name[
                    index] + "_" + material_name[index]
                style_name = color_name[index] + "_" + material_name[index]
                all_name.append(name)
                all_style.append(style_name)

            # st()

            if hyp.do_shape:
                class_name = shape_name
            elif hyp.do_color:
                class_name = color_name
            elif hyp.do_material:
                class_name = material_name
            elif hyp.do_style:
                class_name = all_style
            else:
                class_name = all_name

            object_category = class_name
            bbox_origin = d['bbox_origin']
            # st()

        if hyp.dataset_name == "carla":
            camR_index = d['camR_index']
            rgb_camtop = d['rgb_camXs_raw'][camR_index:camR_index + 1]
            origin_T_camXs_top = d['origin_T_camXs_raw'][
                camR_index:camR_index + 1]
            # predicted_box  = d['bbox_origin_predicted']
            predicted_box = []
        filename = d['tree_seq_filename']
        if hyp.do_2d_style_munit:
            d, indexes = non_random_select_single(d,
                                                  item_names,
                                                  num_samples=hyp.S)

        # st()
        if hyp.fixed_view:
            d, indexes = non_random_select_single(d,
                                                  item_names,
                                                  num_samples=hyp.S)
        elif self.shuffle or hyp.randomly_select_views:
            d, indexes = random_select_single(d, item_names, num_samples=hyp.S)
        else:
            d, indexes = non_random_select_single(d,
                                                  item_names,
                                                  num_samples=hyp.S)

        filename_g = "/".join([original_filename, str(indexes[0])])
        filename_e = "/".join([original_filename, str(indexes[1])])

        rgb_camXs = d['rgb_camXs_raw']
        # move channel dim inward, like pytorch wants
        # rgb_camRs = np.transpose(rgb_camRs, axes=[0, 3, 1, 2])

        rgb_camXs = np.transpose(rgb_camXs, axes=[0, 3, 1, 2])
        rgb_camXs = rgb_camXs[:, :3]
        rgb_camXs = utils_improc.preprocess_color(rgb_camXs)

        if hyp.dataset_name == "carla":
            rgb_camtop = np.transpose(rgb_camtop, axes=[0, 3, 1, 2])
            rgb_camtop = rgb_camtop[:, :3]
            rgb_camtop = utils_improc.preprocess_color(rgb_camtop)
            d['rgb_camtop'] = rgb_camtop
            d['origin_T_camXs_top'] = origin_T_camXs_top
            if len(predicted_box) == 0:
                predicted_box = np.zeros([hyp.N, 6])
                score = np.zeros([hyp.N]).astype(np.float32)
            else:
                num_boxes = predicted_box.shape[0]
                score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
                predicted_box = np.pad(predicted_box,
                                       [[0, hyp.N - num_boxes], [0, 0]])
            d['predicted_box'] = predicted_box.astype(np.float32)
            d['predicted_scores'] = score.astype(np.float32)
        if hyp.dataset_name == "clevr_vqa":
            num_boxes = bbox_origin.shape[0]
            bbox_origin = np.array(bbox_origin)
            score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
            bbox_origin = np.pad(bbox_origin,
                                 [[0, hyp.N - num_boxes], [0, 0], [0, 0]])
            object_category = np.pad(object_category, [[0, hyp.N - num_boxes]],
                                     lambda x, y, z, m: "0")

            d['gt_box'] = bbox_origin.astype(np.float32)
            d['gt_scores'] = score.astype(np.float32)
            d['classes'] = list(object_category)

        if hyp.dataset_name == "replica":
            if len(bbox_origin) == 0:
                score = np.zeros([hyp.N])
                bbox_origin = np.zeros([hyp.N, 6])
                object_category = ["0"] * hyp.N
                object_category = np.array(object_category)
            else:
                num_boxes = len(bbox_origin)
                bbox_origin = torch.stack(bbox_origin).numpy().squeeze(
                    1).squeeze(1).reshape([num_boxes, 6])
                bbox_origin = np.array(bbox_origin)
                score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
                bbox_origin = np.pad(bbox_origin,
                                     [[0, hyp.N - num_boxes], [0, 0]])
                object_category = np.pad(object_category,
                                         [[0, hyp.N - num_boxes]],
                                         lambda x, y, z, m: "0")
            d['gt_box'] = bbox_origin.astype(np.float32)
            d['gt_scores'] = score.astype(np.float32)
            d['classes'] = list(object_category)
            # st()

        if hyp.dataset_name == "carla_mix" or hyp.dataset_name == "carla_det":
            bbox_origin_predicted = bbox_origin_predicted[:3]
            if len(bbox_origin_gt.shape) == 1:
                bbox_origin_gt = np.expand_dims(bbox_origin_gt, 0)
            num_boxes = bbox_origin_gt.shape[0]
            # st()
            score_gt = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
            bbox_origin_gt = np.pad(bbox_origin_gt,
                                    [[0, hyp.N - num_boxes], [0, 0]])
            # st()
            classes = np.pad(classes, [[0, hyp.N - num_boxes]],
                             lambda x, y, z, m: "0")

            if len(bbox_origin_predicted) == 0:
                bbox_origin_predicted = np.zeros([hyp.N, 6])
                score_pred = np.zeros([hyp.N]).astype(np.float32)
            else:
                num_boxes = bbox_origin_predicted.shape[0]
                score_pred = np.pad(np.ones([num_boxes]),
                                    [0, hyp.N - num_boxes])
                bbox_origin_predicted = np.pad(
                    bbox_origin_predicted, [[0, hyp.N - num_boxes], [0, 0]])

            d['predicted_box'] = bbox_origin_predicted.astype(np.float32)
            d['predicted_scores'] = score_pred.astype(np.float32)
            d['gt_box'] = bbox_origin_gt.astype(np.float32)
            d['gt_scores'] = score_gt.astype(np.float32)
            d['classes'] = list(classes)

        d['rgb_camXs_raw'] = rgb_camXs

        if hyp.dataset_name != "carla" and hyp.do_empty:
            empty_rgb_camXs = d['empty_rgb_camXs_raw']
            # move channel dim inward, like pytorch wants
            empty_rgb_camXs = np.transpose(empty_rgb_camXs, axes=[0, 3, 1, 2])
            empty_rgb_camXs = empty_rgb_camXs[:, :3]
            empty_rgb_camXs = utils_improc.preprocess_color(empty_rgb_camXs)
            d['empty_rgb_camXs_raw'] = empty_rgb_camXs
        # st()
        if hyp.use_gt_occs:
            d['occR_complete'] = occ_complete
        d['tree_seq_filename'] = filename
        d['filename_e'] = filename_e
        d['filename_g'] = filename_g
        return d
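
Note: several dataset branches above pad a variable number of boxes, scores, and class names up to a fixed hyp.N. A small standalone sketch of that padding pattern (the function name and inputs are illustrative, not the repository's API):

import numpy as np

def pad_boxes(boxes, names, N):
    # boxes: k x 6 float array, names: list of k strings; pad both to length N
    k = boxes.shape[0]
    scores = np.pad(np.ones([k], dtype=np.float32), [0, N - k])
    boxes = np.pad(boxes.astype(np.float32), [[0, N - k], [0, 0]])
    names = list(names) + ["0"] * (N - k)
    return boxes, scores, names

boxes, scores, names = pad_boxes(np.zeros([2, 6]), ["car", "truck"], N=5)
print(boxes.shape, scores, names)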
Code example #24
0
    def forward(self, feed):
        results = dict()
        summ_writer = utils_improc.Summ_writer(writer=feed['writer'],
                                               global_step=feed['global_step'],
                                               set_name=feed['set_name'],
                                               fps=8)

        writer = feed['writer']
        global_step = feed['global_step']

        total_loss = torch.tensor(0.0)

        __p = lambda x: pack_seqdim(x, B)
        __u = lambda x: unpack_seqdim(x, B)

        B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
        PH, PW = hyp.PH, hyp.PW
        K = hyp.K
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        D = 9

        rgb_camRs = feed["rgb_camRs"]
        rgb_camXs = feed["rgb_camXs"]
        pix_T_cams = feed["pix_T_cams"]
        cam_T_velos = feed["cam_T_velos"]
        boxlist_camRs = feed["boxes3D"]
        tidlist_s = feed["tids"]  # coordinate-less and plural
        scorelist_s = feed["scores"]  # coordinate-less and plural
        # # postproc the boxes:
        # scorelist_s = __u(utils_misc.rescore_boxlist_with_inbound(__p(boxlist_camRs), __p(tidlist_s), Z, Y, X))
        boxlist_camRs_, tidlist_s_, scorelist_s_ = __p(boxlist_camRs), __p(
            tidlist_s), __p(scorelist_s)
        boxlist_camRs_, tidlist_s_, scorelist_s_ = utils_misc.shuffle_valid_and_sink_invalid_boxes(
            boxlist_camRs_, tidlist_s_, scorelist_s_)
        boxlist_camRs = __u(boxlist_camRs_)
        tidlist_s = __u(tidlist_s_)
        scorelist_s = __u(scorelist_s_)

        origin_T_camRs = feed["origin_T_camRs"]
        origin_T_camRs_ = __p(origin_T_camRs)
        origin_T_camXs = feed["origin_T_camXs"]
        origin_T_camXs_ = __p(origin_T_camXs)

        camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
        camX0_T_camXs_ = __p(camX0_T_camXs)
        camRs_T_camXs_ = torch.matmul(origin_T_camRs_.inverse(),
                                      origin_T_camXs_)
        camXs_T_camRs_ = camRs_T_camXs_.inverse()
        camRs_T_camXs = __u(camRs_T_camXs_)
        camXs_T_camRs = __u(camXs_T_camRs_)

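        # bring the velodyne point clouds into camera (X), reference (R), and first-camera (X0) coordinates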
        xyz_veloXs = feed["xyz_veloXs"]
        xyz_camXs = __u(utils_geom.apply_4x4(__p(cam_T_velos),
                                             __p(xyz_veloXs)))
        xyz_camRs = __u(
            utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
        xyz_camX0s = __u(
            utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs)))

        occRs = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X))
        occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))
        occX0s = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z, Y, X))

        occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
        occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2))
        occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2))

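        # unproject the RGB images into voxel grids in the R and X frames (and warp the X grids into X0)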
        unpRs = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z, Y, X,
                __p(torch.matmul(pix_T_cams, camXs_T_camRs))))
        unpXs = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X,
                                           __p(pix_T_cams)))
        unpX0s = utils_vox.apply_4x4_to_voxs(camX0_T_camXs, unpXs)

        unpRs_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                __p(torch.matmul(pix_T_cams, camXs_T_camRs))))

        #####################
        ## visualize what we got
        #####################
        summ_writer.summ_rgbs('2D_inputs/rgb_camRs',
                              torch.unbind(rgb_camRs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                              torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_occs('3D_inputs/occRs', torch.unbind(occRs, dim=1))
        summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1))
        summ_writer.summ_unps('3D_inputs/unpRs', torch.unbind(unpRs, dim=1),
                              torch.unbind(occRs, dim=1))
        summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1),
                              torch.unbind(occXs, dim=1))
        summ_writer.summ_unps('3D_inputs/unpX0s', torch.unbind(unpX0s, dim=1),
                              torch.unbind(occX0s, dim=1))

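        # convert the 3D boxes into lrt (length + rigid transform) lists and move them into each camera frame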
        lrtlist_camRs = __u(
            utils_geom.convert_boxlist_to_lrtlist(boxlist_camRs_)).reshape(
                B, S, N, 19)
        lrtlist_camXs = __u(
            utils_geom.apply_4x4_to_lrtlist(__p(camXs_T_camRs),
                                            __p(lrtlist_camRs)))
        # stabilize boxes for ego/cam motion
        lrtlist_camX0s = __u(
            utils_geom.apply_4x4_to_lrtlist(__p(camX0_T_camXs),
                                            __p(lrtlist_camXs)))
        # these are is B x S x N x 19

        summ_writer.summ_lrtlist('lrtlist_camR0', rgb_camRs[:, 0],
                                 lrtlist_camRs[:, 0], scorelist_s[:, 0],
                                 tidlist_s[:, 0], pix_T_cams[:, 0])
        summ_writer.summ_lrtlist('lrtlist_camR1', rgb_camRs[:, 1],
                                 lrtlist_camRs[:, 1], scorelist_s[:, 1],
                                 tidlist_s[:, 1], pix_T_cams[:, 1])
        summ_writer.summ_lrtlist('lrtlist_camX0', rgb_camXs[:, 0],
                                 lrtlist_camXs[:, 0], scorelist_s[:, 0],
                                 tidlist_s[:, 0], pix_T_cams[:, 0])
        summ_writer.summ_lrtlist('lrtlist_camX1', rgb_camXs[:, 1],
                                 lrtlist_camXs[:, 1], scorelist_s[:, 1],
                                 tidlist_s[:, 1], pix_T_cams[:, 1])
        (
            obj_lrtlist_camXs,
            obj_scorelist_s,
        ) = utils_misc.collect_object_info(lrtlist_camXs,
                                           tidlist_s,
                                           scorelist_s,
                                           pix_T_cams,
                                           K,
                                           mod='X',
                                           do_vis=True,
                                           summ_writer=summ_writer)
        (
            obj_lrtlist_camRs,
            obj_scorelist_s,
        ) = utils_misc.collect_object_info(lrtlist_camRs,
                                           tidlist_s,
                                           scorelist_s,
                                           pix_T_cams,
                                           K,
                                           mod='R',
                                           do_vis=True,
                                           summ_writer=summ_writer)
        (
            obj_lrtlist_camX0s,
            obj_scorelist_s,
        ) = utils_misc.collect_object_info(lrtlist_camX0s,
                                           tidlist_s,
                                           scorelist_s,
                                           pix_T_cams,
                                           K,
                                           mod='X0',
                                           do_vis=False)

        masklist_memR = utils_vox.assemble_padded_obj_masklist(
            lrtlist_camRs[:, 0], scorelist_s[:, 0], Z, Y, X, coeff=1.0)
        masklist_memX = utils_vox.assemble_padded_obj_masklist(
            lrtlist_camXs[:, 0], scorelist_s[:, 0], Z, Y, X, coeff=1.0)
        # obj_mask_memR is B x N x 1 x Z x Y x X
        summ_writer.summ_occ('obj/masklist_memR',
                             torch.sum(masklist_memR, dim=1))
        summ_writer.summ_occ('obj/masklist_memX',
                             torch.sum(masklist_memX, dim=1))

        # to do tracking or whatever, i need to be able to extract a 3d object crop
        cropX0_obj0 = utils_vox.crop_zoom_from_mem(occXs[:, 0],
                                                   lrtlist_camXs[:, 0, 0], Z2,
                                                   Y2, X2)
        cropX0_obj1 = utils_vox.crop_zoom_from_mem(occXs[:, 0],
                                                   lrtlist_camXs[:, 0, 1], Z2,
                                                   Y2, X2)
        cropR0_obj0 = utils_vox.crop_zoom_from_mem(occRs[:, 0],
                                                   lrtlist_camRs[:, 0, 0], Z2,
                                                   Y2, X2)
        cropR0_obj1 = utils_vox.crop_zoom_from_mem(occRs[:, 0],
                                                   lrtlist_camRs[:, 0, 1], Z2,
                                                   Y2, X2)
        # print('got it:')
        # print(cropX00.shape)
        # summ_writer.summ_occ('crops/cropX0_obj0', cropX0_obj0)
        # summ_writer.summ_occ('crops/cropX0_obj1', cropX0_obj1)
        summ_writer.summ_feat('crops/cropX0_obj0', cropX0_obj0, pca=False)
        summ_writer.summ_feat('crops/cropX0_obj1', cropX0_obj1, pca=False)
        summ_writer.summ_feat('crops/cropR0_obj0', cropR0_obj0, pca=False)
        summ_writer.summ_feat('crops/cropR0_obj1', cropR0_obj1, pca=False)

        if hyp.do_feat:
            if hyp.flow_do_synth_rt:
                result = utils_misc.get_synth_flow(unpRs_half,
                                                   occRs_half,
                                                   obj_lrtlist_camX0s,
                                                   obj_scorelist_s,
                                                   occXs_half,
                                                   feed['set_name'],
                                                   K=K,
                                                   summ_writer=summ_writer,
                                                   sometimes_zero=True,
                                                   sometimes_real=False)
                occXs, unpXs, flowX0, camX1_T_camX0, is_synth = result
            else:
                # ego-stabilized flow from X00 to X01
                flowX0 = utils_misc.get_gt_flow(
                    obj_lrtlist_camX0s,
                    obj_scorelist_s,
                    utils_geom.eye_4x4s(B, S),
                    occXs_half[:, 0],
                    K=K,
                    occ_only=False,  # get the dense flow
                    mod='X0',
                    summ_writer=summ_writer)

            # occXs is B x S x 1 x H x W x D
            # unpXs is B x S x 3 x H x W x D
            # featXs_input = torch.cat([occXs, occXs*unpXs], dim=2)
            featX0s_input = torch.cat([occX0s, occX0s * unpX0s], dim=2)
            featX0s_input_ = __p(featX0s_input)
            featX0s_, validX0s_, feat_loss = self.featnet(
                featX0s_input_, summ_writer)
            total_loss += feat_loss
            featX0s = __u(featX0s_)
            # _featX00 = featXs[:,0:1]
            # _featX01 = utils_vox.apply_4x4_to_voxs(camX0_T_camXs[:,1:], featXs[:,1:])
            # featX0s = torch.cat([_featX00, _featX01], dim=1)

            validX0s = 1.0 - (featX0s == 0).all(
                dim=2,
                keepdim=True).float()  # this should be B x S x 1 x H x W x D

            summ_writer.summ_feats('3D_feats/featX0s_input',
                                   torch.unbind(featX0s_input, dim=1),
                                   pca=True)
            # summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featX0s_output',
                                   torch.unbind(featX0s, dim=1),
                                   pca=True)

        if hyp.do_flow:
            # total flow from X0 to X1
            flowX = utils_misc.get_gt_flow(
                obj_lrtlist_camXs,
                obj_scorelist_s,
                camX0_T_camXs,
                occXs_half[:, 0],
                K=K,
                occ_only=False,  # get the dense flow
                mod='X',
                vis=False,
                summ_writer=None)

            # # vis this to confirm it's ok (it is)
            # unpX0_e = utils_samp.backwarp_using_3D_flow(unpXs[:,1], flowX)
            # occX0_e = utils_samp.backwarp_using_3D_flow(occXs[:,1], flowX)
            # summ_writer.summ_unps('flow/backwarpX', [unpX0s[:,0], unpX0_e], [occXs[:,0], occX0_e])

            # unpX0_e = utils_samp.backwarp_using_3D_flow(unpX0s[:,1], flowX0)
            # occX0_e = utils_samp.backwarp_using_3D_flow(occX0s[:,1], flowX0, binary_feat=True)
            # summ_writer.summ_unps('flow/backwarpX0', [unpX0s[:,0], unpX0_e], [occXs[:,0], occX0_e])

            flow_loss, flowX0_pred = self.flownet(
                featX0s[:, 0],
                featX0s[:, 1],
                flowX0,  # gt flow
                torch.max(validX0s[:, 1:], dim=1)[0],
                is_synth,
                summ_writer)
            total_loss += flow_loss

            # g = flowX.reshape(-1)
            # summ_writer.summ_histogram('flowX_g_nonzero_hist', g[torch.abs(g)>0.01])

            # g = flowX0.reshape(-1)
            # e = flowX0_pred.reshape(-1)
            # summ_writer.summ_histogram('flowX0_g_nonzero_hist', g[torch.abs(g)>0.01])
            # summ_writer.summ_histogram('flowX0_e_nonzero_hist', e[torch.abs(g)>0.01])
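            # Added sketch (not part of the original loss): a masked end-point-error
            # diagnostic for the predicted flow. It only reuses tensors computed above;
            # treating validX0s[:, 0] as the valid-voxel mask is an assumption.
            epe_map = torch.norm(flowX0_pred - flowX0, dim=1, keepdim=True)  # B x 1 x Z x Y x X at the flow resolution
            valid0 = validX0s[:, 0]
            epe = (epe_map * valid0).sum() / (valid0.sum() + 1e-6)
            summ_writer.summ_scalar('flow/epe_masked_sketch', epe.cpu().item())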

        summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results
Code example #25
0
File: trinet2D.py Project: shamitlal/CoCoNets
    def forward(self,
                feat_cam0,
                feat_cam1,
                mask_mem0,
                pix_T_cam0,
                pix_T_cam1,
                cam1_T_cam0,
                vox_util,
                summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(mask_mem0.shape)
        assert (C == 1)

        B2, C, H, W = list(feat_cam0.shape)
        assert (B == B2)

        go_slow = False  # set True to use the per-example (slow) reference path below
        if go_slow:
            xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
            mask_mem0 = mask_mem0.reshape(B, Z * Y * X)
            vec0_list = []
            vec1_list = []
            for b in list(range(B)):
                xyz_mem0_b = xyz_mem0[b]
                mask_mem0_b = mask_mem0[b]
                xyz_mem0_b = xyz_mem0_b[torch.where(mask_mem0_b > 0)]
                # this is N x 3

                N, D = list(xyz_mem0_b.shape)
                if N > self.num_samples:
                    # to not waste time, i will subsample right here
                    perm = np.random.permutation(N)
                    xyz_mem0_b = xyz_mem0_b[perm[:self.num_samples]]
                    # this is num_samples x 3 (smaller than before)

                xyz_cam0_b = vox_util.Mem2Ref(xyz_mem0_b.unsqueeze(0), Z, Y, X)
                xyz_cam1_b = utils_geom.apply_4x4(cam1_T_cam0[b:b + 1],
                                                  xyz_cam0_b)
                # these are N x 3
                # now, i need to project both of these, and sample from the feats

                xy_cam0_b = utils_geom.apply_pix_T_cam(pix_T_cam0[b:b + 1],
                                                       xyz_cam0_b).squeeze(0)
                xy_cam1_b = utils_geom.apply_pix_T_cam(pix_T_cam1[b:b + 1],
                                                       xyz_cam1_b).squeeze(0)
                # these are N x 2

                vec0 = utils_samp.bilinear_sample_single(
                    feat_cam0[b], xy_cam0_b[:, 0], xy_cam0_b[:, 1])
                vec1 = utils_samp.bilinear_sample_single(
                    feat_cam1[b], xy_cam1_b[:, 0], xy_cam1_b[:, 1])
                # these are C x N
                # (an equivalent route is to normalize the pixel coords with
                # utils_basic.normalize_grid2D and sample with F.grid_sample on a
                # 1 x 1 x N x 2 grid; the direct bilinear sample is kept here because
                # its C x N output is what the cat/permute below expects)

                vec0_list.append(vec0)
                vec1_list.append(vec1)

            vec0 = torch.cat(vec0_list, dim=1).permute(1, 0)
            vec1 = torch.cat(vec1_list, dim=1).permute(1, 0)
        else:
            xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
            mask_mem0 = mask_mem0.reshape(B, Z * Y * X)

            valid_batches = 0
            sampling_coords_mem0 = torch.zeros(B, self.num_samples,
                                               3).float().cuda()
            valid_feat_cam0 = torch.zeros_like(feat_cam0)
            valid_feat_cam1 = torch.zeros_like(feat_cam1)
            valid_pix_T_cam0 = torch.zeros_like(pix_T_cam0)
            valid_pix_T_cam1 = torch.zeros_like(pix_T_cam1)
            valid_cam1_T_cam0 = torch.zeros_like(cam1_T_cam0)

            # sampling_coords_mem1 = torch.zeros(B, self.num_samples, 3).float().cuda()
            for b in list(range(B)):
                xyz_mem0_b = xyz_mem0[b]
                mask_mem0_b = mask_mem0[b]
                xyz_mem0_b = xyz_mem0_b[torch.where(mask_mem0_b > 0)]
                # this is N x 3

                N, D = list(xyz_mem0_b.shape)
                if N >= self.num_samples:
                    perm = np.random.permutation(N)
                    xyz_mem0_b = xyz_mem0_b[perm[:self.num_samples]]
                    # this is num_samples x 3 (smaller than before)

                    valid_batches += 1
                    # sampling_coords_mem0[valid_batches] = xyz_mem0_b

                    sampling_coords_mem0[b] = xyz_mem0_b
                    valid_feat_cam0[b] = feat_cam0[b]
                    valid_feat_cam1[b] = feat_cam1[b]
                    valid_pix_T_cam0[b] = pix_T_cam0[b]
                    valid_pix_T_cam1[b] = pix_T_cam1[b]
                    valid_cam1_T_cam0[b] = cam1_T_cam0[b]

            print('valid_batches:', valid_batches)
            if valid_batches == 0:
                # return early
                return total_loss

            # trim down
            sampling_coords_mem0 = sampling_coords_mem0[:valid_batches]
            feat_cam0 = valid_feat_cam0[:valid_batches]
            feat_cam1 = valid_feat_cam1[:valid_batches]
            pix_T_cam0 = valid_pix_T_cam0[:valid_batches]
            pix_T_cam1 = valid_pix_T_cam1[:valid_batches]
            cam1_T_cam0 = valid_cam1_T_cam0[:valid_batches]

            xyz_cam0 = vox_util.Mem2Ref(sampling_coords_mem0, Z, Y, X)
            xyz_cam1 = utils_geom.apply_4x4(cam1_T_cam0, xyz_cam0)
            # these are B x N x 3
            # now, i need to project both of these, and sample from the feats

            xy_cam0 = utils_geom.apply_pix_T_cam(pix_T_cam0, xyz_cam0)
            xy_cam1 = utils_geom.apply_pix_T_cam(pix_T_cam1, xyz_cam1)
            # these are B x N x 2

            vec0 = utils_samp.bilinear_sample2D(feat_cam0, xy_cam0[:, :, 0],
                                                xy_cam0[:, :, 1])
            vec1 = utils_samp.bilinear_sample2D(feat_cam1, xy_cam1[:, :, 0],
                                                xy_cam1[:, :, 1])
            # these are B x C x N

            vec0 = vec0.permute(0, 2, 1).view(valid_batches * self.num_samples,
                                              C)
            vec1 = vec1.permute(0, 2, 1).view(valid_batches * self.num_samples,
                                              C)

        print('vec0', vec0.shape)
        print('vec1', vec1.shape)
        # these are N x C

        # # where g is valid, we use it as reference and pull up e
        # margin_loss = self.compute_margin_loss(B, C, D, H, W, emb_e_vec, emb_g_vec.detach(), vis_g_vec, 'g', True, summ_writer)
        # l2_loss = reduce_masked_mean(sql2_on_axis(emb_e-emb_g.detach(), 1, keepdim=True), vis_g)
        # total_loss = utils_misc.add_loss('emb3D/emb_3D_ml_loss', total_loss, margin_loss, hyp.emb_3D_ml_coeff, summ_writer)
        # total_loss = utils_misc.add_loss('emb3D/emb_3D_l2_loss', total_loss, l2_loss, hyp.emb_3D_l2_coeff, summ_writer)

        ce_loss = self.compute_ce_loss(vec0, vec1.detach())
        total_loss = utils_misc.add_loss('tri2D/emb_ce_loss', total_loss,
                                         ce_loss, hyp.tri_2D_ce_coeff,
                                         summ_writer)

        # l2_loss_im = torch.mean(sql2_on_axis(emb_e-emb_g, 1, keepdim=True), dim=3)
        # if summ_writer is not None:
        #     summ_writer.summ_oned('emb3D/emb_3D_l2_loss', l2_loss_im)
        #     summ_writer.summ_feats('emb3D/embs_3D', [emb_e, emb_g], pca=True)
        return total_loss
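
A minimal sketch of what compute_ce_loss above is assumed to compute: an InfoNCE-style cross-entropy over the N sampled correspondence pairs, where row i of vec0 should match row i of vec1 and every other row serves as a negative. The helper name and the temperature value are assumptions, not taken from the project.

import torch
import torch.nn.functional as F

def compute_ce_loss_sketch(vec0, vec1, temperature=0.07):
    # vec0, vec1: N x C; row i of vec0 corresponds to row i of vec1
    vec0 = F.normalize(vec0, dim=1)
    vec1 = F.normalize(vec1, dim=1)
    logits = torch.matmul(vec0, vec1.t()) / temperature  # N x N similarity matrix
    labels = torch.arange(vec0.shape[0], device=vec0.device)  # positives sit on the diagonal
    return F.cross_entropy(logits, labels)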
Code example #26
0
    def forward(self, feed, moc_init_done=False, debug=False):
        summ_writer = utils_improc.Summ_writer(
            writer = feed['writer'],
            global_step = feed['global_step'],
            set_name= feed['set_name'],
            fps=8)

        writer = feed['writer']
        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        ### ... All things sensor ... ###
        sensor_rgbs = feed['sensor_imgs']
        sensor_depths = feed['sensor_depths']
        center_sensor_H, center_sensor_W = sensor_depths[0][0].shape[-2] // 2, sensor_depths[0][0].shape[-1] // 2  # rows (H) then cols (W)
        ### ... All things sensor end ... ###

        # 1. Form the memory tensor using the feat net and visual images.
        # check what all do you need for this and create only those things

        ##  .... Input images ....  ##
        rgb_camRs = feed['rgb_camRs']
        rgb_camXs = feed['rgb_camXs']
        ##  .... Input images end ....  ##

        ## ... Hyperparams ... ##
        B, H, W, V, S = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S
        __p = lambda x: pack_seqdim(x, B)
        __u = lambda x: unpack_seqdim(x, B)
        PH, PW = hyp.PH, hyp.PW
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z/2), int(Y/2), int(X/2)
        ## ... Hyperparams end ... ##

        ## .... VISUAL TRANSFORMS BEGIN .... ##
        pix_T_cams = feed['pix_T_cams']
        pix_T_cams_ = __p(pix_T_cams)
        origin_T_camRs = feed['origin_T_camRs']
        origin_T_camRs_ = __p(origin_T_camRs)
        origin_T_camXs = feed['origin_T_camXs']
        origin_T_camXs_ = __p(origin_T_camXs)
        camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse(
            origin_T_camRs_), origin_T_camXs_)
        camXs_T_camRs_ = utils_geom.safe_inverse(camRs_T_camXs_)
        camRs_T_camXs = __u(camRs_T_camXs_)
        camXs_T_camRs = __u(camXs_T_camRs_)
        pix_T_cams_ = utils_geom.pack_intrinsics(pix_T_cams_[:, 0, 0], pix_T_cams_[:, 1, 1], pix_T_cams_[:, 0, 2],
            pix_T_cams_[:, 1, 2])
        pix_T_camRs_ = torch.matmul(pix_T_cams_, camXs_T_camRs_)
        pix_T_camRs = __u(pix_T_camRs_)
        ## ... VISUAL TRANSFORMS END ... ##

        ## ... SENSOR TRANSFORMS BEGIN ... ##
        sensor_origin_T_camXs = feed['sensor_extrinsics']
        sensor_origin_T_camXs_ = __p(sensor_origin_T_camXs)
        sensor_origin_T_camRs = feed['sensor_origin_T_camRs']
        sensor_origin_T_camRs_ = __p(sensor_origin_T_camRs)
        sensor_camRs_T_origin_ = utils_geom.safe_inverse(sensor_origin_T_camRs_)

        sensor_camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse(
            sensor_origin_T_camRs_), sensor_origin_T_camXs_)
        sensor_camXs_T_camRs_ = utils_geom.safe_inverse(sensor_camRs_T_camXs_)

        sensor_camRs_T_camXs = __u(sensor_camRs_T_camXs_)
        sensor_camXs_T_camRs = __u(sensor_camXs_T_camRs_)

        sensor_pix_T_cams = feed['sensor_intrinsics']
        sensor_pix_T_cams_ = __p(sensor_pix_T_cams)
        sensor_pix_T_cams_ = utils_geom.pack_intrinsics(sensor_pix_T_cams_[:, 0, 0], sensor_pix_T_cams_[:, 1, 1],
            sensor_pix_T_cams_[:, 0, 2], sensor_pix_T_cams_[:, 1, 2])
        sensor_pix_T_camRs_ = torch.matmul(sensor_pix_T_cams_, sensor_camXs_T_camRs_)
        sensor_pix_T_camRs = __u(sensor_pix_T_camRs_)
        ## .... SENSOR TRANSFORMS END .... ##

        ## .... Visual Input point clouds .... ##
        xyz_camXs = feed['xyz_camXs']
        xyz_camXs_ = __p(xyz_camXs)
        xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, xyz_camXs_)  # (40, 4, 4) (B*S, N, 3)
        xyz_camRs = __u(xyz_camRs_)
        assert all([torch.allclose(xyz_camR, inp_xyz_camR) for xyz_camR, inp_xyz_camR in zip(
            xyz_camRs, feed['xyz_camRs']
        )]), "computation of xyz_camR here and those computed in input do not match"
        ## .... Visual Input point clouds end .... ##

        ## ... Sensor input point clouds ... ##
        sensor_xyz_camXs = feed['sensor_xyz_camXs']
        sensor_xyz_camXs_ = __p(sensor_xyz_camXs)
        sensor_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_xyz_camXs_)
        sensor_xyz_camRs = __u(sensor_xyz_camRs_)
        assert all([torch.allclose(sensor_xyz, inp_sensor_xyz) for sensor_xyz, inp_sensor_xyz in zip(
            sensor_xyz_camRs, feed['sensor_xyz_camRs']
        )]), "the sensor_xyz_camRs computed in forward do not match those computed in input"

        ## ... visual occupancy computation voxelize the pointcloud from above ... ##
        occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X)
        occXs_ = utils_vox.voxelize_xyz(xyz_camXs_, Z, Y, X)
        occRs_half_ = utils_vox.voxelize_xyz(xyz_camRs_, Z2, Y2, X2)
        occXs_half_ = utils_vox.voxelize_xyz(xyz_camXs_, Z2, Y2, X2)
        ## ... visual occupancy computation end ... NOTE: no unpacking ##

        ## .. visual occupancy computation for sensor inputs .. ##
        sensor_occRs_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z, Y, X)
        sensor_occXs_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z, Y, X)
        sensor_occRs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z2, Y2, X2)
        sensor_occXs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z2, Y2, X2)

        ## ... unproject rgb images ... ##
        unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_camRs_)
        unpXs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_cams_)
        ## ... unproject rgb finish ... NOTE: no unpacking ##

        ## ... Make depth images ... ##
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(pix_T_cams_, xyz_camXs_, H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, pix_T_cams_)
        dense_xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, dense_xyz_camXs_)
        inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float()
        inbound_camXs_ = torch.reshape(inbound_camXs_, [B*S, 1, H, W])
        valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)
        ## ... Make depth images ... ##

        ## ... Make sensor depth images ... ##
        sensor_depth_camXs_, sensor_valid_camXs_ = utils_geom.create_depth_image(sensor_pix_T_cams_,
            sensor_xyz_camXs_, H, W)
        sensor_dense_xyz_camXs_ = utils_geom.depth2pointcloud(sensor_depth_camXs_, sensor_pix_T_cams_)
        sensor_dense_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_dense_xyz_camXs_)
        sensor_inbound_camXs_ = utils_vox.get_inbounds(sensor_dense_xyz_camRs_, Z, Y, X).float()
        sensor_inbound_camXs_ = torch.reshape(sensor_inbound_camXs_, [B*hyp.sensor_S, 1, H, W])
        sensor_valid_camXs = __u(sensor_valid_camXs_) * __u(sensor_inbound_camXs_)
        ### .. Done making sensor depth images .. ##

        ### ... Sanity check ... Write to tensorboard ... ###
        summ_writer.summ_oneds('2D_inputs/depth_camXs', torch.unbind(__u(depth_camXs_), dim=1))
        summ_writer.summ_oneds('2D_inputs/valid_camXs', torch.unbind(valid_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camRs', torch.unbind(rgb_camRs, dim=1))
        summ_writer.summ_occs('3d_inputs/occXs', torch.unbind(__u(occXs_), dim=1), reduce_axes=[2])
        summ_writer.summ_unps('3d_inputs/unpXs', torch.unbind(__u(unpXs_), dim=1),\
            torch.unbind(__u(occXs_), dim=1))

        # A different approach for viewing occRs of sensors
        sensor_occRs = __u(sensor_occRs_)
        vis_sensor_occRs = torch.max(sensor_occRs, dim=1, keepdim=True)[0]
        # summ_writer.summ_occs('3d_inputs/sensor_occXs', torch.unbind(__u(sensor_occXs_), dim=1),
        #     reduce_axes=[2])
        summ_writer.summ_occs('3d_inputs/sensor_occRs', torch.unbind(vis_sensor_occRs, dim=1), reduce_axes=[2])

        ### ... code for visualizing sensor depths and sensor rgbs ... ###
        # summ_writer.summ_oneds('2D_inputs/depths_sensor', torch.unbind(sensor_depths, dim=1))
        # summ_writer.summ_rgbs('2D_inputs/rgbs_sensor', torch.unbind(sensor_rgbs, dim=1))
        # summ_writer.summ_oneds('2D_inputs/validXs_sensor', torch.unbind(sensor_valid_camXs, dim=1))

        if summ_writer.save_this:
            unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, matmul2(pix_T_cams_, camXs_T_camRs_))
            unpRs = __u(unpRs_)
            occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X)
            summ_writer.summ_occs('3d_inputs/occRs', torch.unbind(__u(occRs_), dim=1), reduce_axes=[2])
            summ_writer.summ_unps('3d_inputs/unpRs', torch.unbind(unpRs, dim=1),\
                torch.unbind(__u(occRs_), dim=1))
        ### ... Sanity check ... Writing to tensoboard complete ... ###
        results = dict()  # filled below (eval recall etc.); needs to be a dict

        mask_ = None
        ### ... Visual featnet part .... ###
        if hyp.do_feat:
            featXs_input = torch.cat([__u(occXs_), __u(occXs_)*__u(unpXs_)], dim=2)  # B, S, 4, H, W, D
            featXs_input_ = __p(featXs_input)

            freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), occXs_half_)
            freeXs = __u(freeXs_)
            visXs = torch.clamp(__u(occXs_half_) + freeXs, 0.0, 1.0)

            if mask_ is not None:
                assert(list(mask_.shape)[2:5] == list(featXs_input.shape)[2:5])
            featXs_, validXs_, _ = self.featnet(featXs_input_, summ_writer, mask=occXs_)
            # total_loss += feat_loss  # Note no need of loss

            validXs, featXs = __u(validXs_), __u(featXs_) # unpacked into B, S, C, D, H, W
            # bring everything to ref_frame
            validRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, validXs)
            visRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, visXs)
            featRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, featXs)  # This is now in memory coordinates

            emb3D_e = torch.mean(featRs[:, 1:], dim=1)  # context, or the features of the scene
            emb3D_g = featRs[:, 0]  # this is to predict, basically I will pass emb3D_e as input and hope to predict emb3D_g
            vis3D_e = torch.max(validRs[:, 1:], dim=1)[0] * torch.max(visRs[:, 1:], dim=1)[0]
            vis3D_g = validRs[:, 0] * visRs[:, 0]

            #### ... I do not think I need this ... ####
            results = {}
        #     # if hyp.do_eval_recall:
        #     #     results['emb3D_e'] = emb3D_e
        #     #     results['emb3D_g'] = emb3D_g
        #     #### ... Check if you need the above

            summ_writer.summ_feats('3D_feats/featXs_input', torch.unbind(featXs_input, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featRs_output', torch.unbind(featRs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/validRs', torch.unbind(validRs, dim=1), pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_e', vis3D_e, pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_g', vis3D_g, pca=False)

            # I need to aggregate the features and detach to prevent the backward pass on featnet
            featRs = torch.mean(featRs, dim=1)
            featRs = featRs.detach()
            #  ... HERE I HAVE THE VISUAL FEATURE TENSOR ... WHICH IS MADE USING 5 EVENLY SPACED VIEWS #

        # FOR THE TOUCH PART, I HAVE THE OCC and THE AIM IS TO PREDICT FEATURES FROM THEM #
        if hyp.do_touch_feat:
            # 1. Pass all the sensor depth images through the backbone network
            input_sensor_depths = __p(sensor_depths)
            sensor_features_ = self.backbone_2D(input_sensor_depths)

            # should normalize these feature tensors
            sensor_features_ = l2_normalize(sensor_features_, dim=1)

            sensor_features = __u(sensor_features_)
            assert torch.allclose(torch.norm(sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\
                "sensor features should be unit-norm after l2_normalize"

            if hyp.do_eval_recall:
              results['sensor_features'] = sensor_features_
              results['sensor_depths'] = input_sensor_depths
              results['object_img'] = rgb_camRs
              results['sensor_imgs'] = __p(sensor_rgbs)

            # if moco is used do the same procedure as above but with a different network #
            if hyp.do_moc or hyp.do_eval_recall:
                # 1. Pass all the sensor depth images through the key network
                key_input_sensor_depths = copy.deepcopy(__p(sensor_depths)) # bx1024x1x16x16->(2048x1x16x16)
                self.key_touch_featnet.eval()
                with torch.no_grad():
                    key_sensor_features_ = self.key_touch_featnet(key_input_sensor_depths)

                key_sensor_features_ = l2_normalize(key_sensor_features_, dim=1)
                key_sensor_features = __u(key_sensor_features_)
                assert torch.allclose(torch.norm(key_sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\
                    "key sensor features should be unit-norm after l2_normalize"

        # doing the same procedure for moco but with a different network end #

        # metric learning between voxel-point visual features and sensor features
        if hyp.do_touch_embML and not hyp.do_touch_forward:
            # trial 1: do not pass the features through an additional 3D encoder-decoder.
            # Compute the location in ref_frame that the center of each depth image occupies;
            # at each of these locations, sample from the visual tensor. Those samples form
            # the positive pairs; the negatives are everything except the positive.
            sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_x = sensor_depths_centers_x.cuda()
            sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_y = sensor_depths_centers_y.cuda()
            sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W]

            # Next use Pixels2Camera to unproject all of these together.
            # merge the batch and the sequence dimension
            sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1)  # BxHxW as required by Pixels2Camera
            sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1)
            sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1)

            fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_)
            sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y,
                sensor_depths_centers_z, fx, fy, x0, y0)

            # finally use apply4x4 to get the locations in ref_cam
            sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_)

            # NOTE: convert them to memory coordinates; the name suggests xyz ordering, worth confirming against utils_vox.Ref2Mem
            sensor_depths_centers_in_mem_ = utils_vox.Ref2Mem(sensor_depths_centers_in_ref_cam_, Z2, Y2, X2)
            sensor_depths_centers_in_mem = sensor_depths_centers_in_mem_.reshape(hyp.B, hyp.sensor_S, -1)

            if debug:
                print('assert that you are not entering here')
                from IPython import embed; embed()
                # form a (0, 1) volume here at these locations and see if it resembles a cup
                dim1 = X2 * Y2 * Z2
                dim2 = X2 * Y2
                dim3 = X2
                binary_voxel_grid = torch.zeros((hyp.B, X2, Y2, Z2))
                # NOTE: Z is the leading dimension
                rounded_idxs = torch.round(sensor_depths_centers_in_mem)
                flat_idxs = dim2 * rounded_idxs[0, :, 0] + dim3 * rounded_idxs[0, :, 1] + rounded_idxs[0, :, 2]
                flat_idxs1 = dim2 * rounded_idxs[1, :, 0] + dim3 * rounded_idxs[1, :, 1] + rounded_idxs[1, :, 2]
                flat_idxs1 = flat_idxs1 + dim1
                flat_idxs1 = flat_idxs1.long()
                flat_idxs = flat_idxs.long()

                flattened_grid = binary_voxel_grid.flatten()
                flattened_grid[flat_idxs] = 1.
                flattened_grid[flat_idxs1] = 1.

                binary_voxel_grid = flattened_grid.view(B, X2, Y2, Z2)

                assert binary_voxel_grid[0].sum() == len(torch.unique(flat_idxs)), "some indexes are missed here"
                assert binary_voxel_grid[1].sum() == len(torch.unique(flat_idxs1)), "some indexes are missed here"

                # o3d.io.write_voxel_grid("forward_pass_save/grid0.ply", binary_voxel_grid[0])
                # o3d.io.write_voxel_grid("forward_pass_save/grid1.ply", binary_voxel_grid[0])
                # need to save these voxels
                save_voxel(binary_voxel_grid[0].cpu().numpy(), "forward_pass_save/grid0.binvox")
                save_voxel(binary_voxel_grid[1].cpu().numpy(), "forward_pass_save/grid1.binvox")
                from IPython import embed; embed()

            # use grid sample to get the visual touch tensor at these locations, NOTE: visual tensor features shape is (B, C, N)
            visual_tensor_features = utils_samp.bilinear_sample3D(featRs, sensor_depths_centers_in_mem[:, :, 0],
                sensor_depths_centers_in_mem[:, :, 1], sensor_depths_centers_in_mem[:, :, 2])
            visual_feature_tensor = visual_tensor_features.permute(0, 2, 1)
            # pack it
            visual_feature_tensor_ = __p(visual_feature_tensor)
            C = list(visual_feature_tensor.shape)[-1]
            print('C=', C)

            # do the metric learning; this is the same as before.
            # the code is largely copied from embnet3d.py, with minor changes
            emb_vec = torch.stack((sensor_features_, visual_feature_tensor_), dim=1).view(B*self.num_samples*self.batch_k, C)
            y = torch.stack([torch.range(0,self.num_samples*B-1), torch.range(0,self.num_samples*B-1)], dim=1).view(self.num_samples*B*self.batch_k)
            a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec)

            # I need to write my own version of margin loss since the negatives and anchors may not be same dim
            d_ap = torch.sqrt(torch.sum((positives - anchors)**2, dim=1) + 1e-8)
            pos_loss = torch.clamp(d_ap - self.beta + self._margin, min=0.0)

            # TODO: expand the dims of anchors, tile them, and compute the negative loss
            # (a hedged sketch of this missing piece follows below)
            # do the pair count where you average over contributing pairs only
            # this is your total loss
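            # Hedged sketch of the missing negative term (assumptions: `negatives` comes
            # back from the sampler as A x K x C, i.e. K negatives per anchor, and the
            # loss is weighted with hyp.emb_3D_ml_coeff like the other touch ML loss).
            d_an = torch.sqrt(torch.sum((negatives - anchors.unsqueeze(1)) ** 2, dim=2) + 1e-8)  # A x K
            neg_loss = torch.clamp(self.beta - d_an + self._margin, min=0.0)
            # average only over pairs that actually contribute a nonzero loss
            pair_cnt = (pos_loss > 0).float().sum() + (neg_loss > 0).float().sum()
            margin_loss = (pos_loss.sum() + neg_loss.sum()) / torch.clamp(pair_cnt, min=1.0)
            total_loss = utils_misc.add_loss('embtouch/emb_touch_ml_loss', total_loss,
                margin_loss, hyp.emb_3D_ml_coeff, summ_writer)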


            # A further idea is to check which volumetric locations each depth image corresponds to:
            # unproject the entire depth image, convert to ref, and then sample.

        if hyp.do_touch_forward:
            ## ... Begin code for getting crops from visual memory ... ##
            sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_x = sensor_depths_centers_x.cuda()
            sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_y = sensor_depths_centers_y.cuda()
            sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W]

            # Next use Pixels2Camera to unproject all of these together.
            # merge the batch and the sequence dimension
            sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1)
            sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1)
            sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1)

            fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_)
            sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y,
                sensor_depths_centers_z, fx, fy, x0, y0)
            sensor_depths_centers_in_world_ = utils_geom.apply_4x4(sensor_origin_T_camXs_, sensor_depths_centers_in_camXs_)  # not used by the algorithm
            ## this will be later used for visualization hence saving it here for now
            sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_)  # not used by the algorithm

            sensor_depths_centers_in_camXs = __u(sensor_depths_centers_in_camXs_).squeeze(2)

            # There may be a better way to do this: for each camera in the batch, build a box of size (ch, cw, cd)
            # TODO: the rotation entries encode the deviation of the box from axis-aligned; decide whether that is wanted here
            tB, tN, _ = list(sensor_depths_centers_in_camXs.shape)  # 2, 512, _
            boxlist = torch.zeros(tB, tN, 9)  # 2, 512, 9
            boxlist[:, :, :3] = sensor_depths_centers_in_camXs  # this lies on the object
            boxlist[:, :, 3:6] = torch.FloatTensor([hyp.contextW, hyp.contextH, hyp.contextD])

            # convert the boxlist to lrtlist and to cuda
            # the rt here transforms the from box coordinates to camera coordinates
            box_lrtlist = utils_geom.convert_boxlist_to_lrtlist(boxlist)

            # Now I will use crop_zoom_from_mem functionality to get the features in each of the boxes
            # I will do it for each of the box separately as required by the api
            context_grid_list = list()
            for m in range(box_lrtlist.shape[1]):
                curr_box = box_lrtlist[:, m, :]
                context_grid = utils_vox.crop_zoom_from_mem(featRs, curr_box, 8, 8, 8,
                    sensor_camRs_T_camXs[:, m, :, :])
                context_grid_list.append(context_grid)
            context_grid_list = torch.stack(context_grid_list, dim=1)
            context_grid_list_ = __p(context_grid_list)
            ## ... up to here no randomness has been introduced ... ##
            ## ... End code for getting crops around this center of certain height, width and depth ... ##

            ## ... Begin code for passing the context grid through 3D CNN to obtain a vector ... ##
            sensor_cam_locs = feed['sensor_locs']  # these are in origin coordinates
            sensor_cam_quats = feed['sensor_quats']  # these too are in origin (world) coordinates
            sensor_cam_locs_ = __p(sensor_cam_locs)
            sensor_cam_quats_ = __p(sensor_cam_quats)
            sensor_cam_locs_in_R_ = utils_geom.apply_4x4(sensor_camRs_T_origin_, sensor_cam_locs_.unsqueeze(1)).squeeze(1)
            # TODO: confirm that this rotation re-expression is correct (a sanity-check sketch follows below)
            get_r_mat = lambda cam_quat: transformations.quaternion_matrix_py(cam_quat)
            rot_mat_Xs_ = torch.from_numpy(np.stack(list(map(get_r_mat, sensor_cam_quats_.cpu().numpy())))).to(sensor_cam_locs_.device).float()
            rot_mat_Rs_ = torch.bmm(sensor_camRs_T_origin_, rot_mat_Xs_)
            get_quat = lambda r_mat: transformations.quaternion_from_matrix_py(r_mat)
            sensor_quats_in_R_ = torch.from_numpy(np.stack(list(map(get_quat, rot_mat_Rs_.cpu().numpy())))).to(sensor_cam_locs_.device).float()
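            # Added sanity-check sketch for the TODO above (assumption: the transformations
            # helpers round-trip a rotation exactly up to quaternion sign).
            if debug:
                R_back = transformations.quaternion_matrix_py(
                    sensor_quats_in_R_[0].cpu().numpy())
                assert np.allclose(R_back[:3, :3], rot_mat_Rs_[0, :3, :3].cpu().numpy(),
                                   atol=1e-4), "quaternion/rotation round-trip mismatch"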

            pred_features_ = self.context_net(context_grid_list_,\
                sensor_cam_locs_in_R_, sensor_quats_in_R_)

            # normalize
            pred_features_ = l2_normalize(pred_features_, dim=1)
            pred_features = __u(pred_features_)

            # if doing moco I have to pass the inputs through the key(slow) network as well #
            if hyp.do_moc or hyp.do_eval_recall:
                key_context_grid_list_ = copy.deepcopy(context_grid_list_)
                key_sensor_cam_locs_in_R_ = copy.deepcopy(sensor_cam_locs_in_R_)
                key_sensor_quats_in_R_ = copy.deepcopy(sensor_quats_in_R_)
                self.key_context_net.eval()
                with torch.no_grad():
                    key_pred_features_ = self.key_context_net(key_context_grid_list_,\
                        key_sensor_cam_locs_in_R_, key_sensor_quats_in_R_)

                # normalize; the key features must be unit-norm to match the (normalized) query features
                key_pred_features_ = l2_normalize(key_pred_features_, dim=1)
                key_pred_features = __u(key_pred_features_)
            # end passing of the input through the slow network this is necessary for moco #
            ## ... End code for passing the context grid through 3D CNN to obtain a vector ... ##

        ## ... Begin code for doing metric learning between pred_features and sensor features ... ##
        # 1. Subsample both based on the number of positive samples
        if hyp.do_touch_embML:
            assert(hyp.do_touch_forward)
            assert(hyp.do_touch_feat)
            perm = torch.randperm(len(pred_features_))  ## 1024
            chosen_sensor_feats_ = sensor_features_[perm[:self.num_pos_samples*hyp.B]]
            chosen_pred_feats_ = pred_features_[perm[:self.num_pos_samples*B]]

            # 2. form the emb_vec and get pos and negative samples for the batch
            emb_vec = torch.stack((chosen_sensor_feats_, chosen_pred_feats_), dim=1).view(hyp.B*self.num_pos_samples*self.batch_k, -1)
            y = torch.stack([torch.range(0, self.num_pos_samples*B-1), torch.range(0, self.num_pos_samples*B-1)],\
                dim=1).view(B*self.num_pos_samples*self.batch_k) # (0, 0, 1, 1, ..., 255, 255)

            a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec)

            # 3. Compute the loss, ML loss and the l2 distance betwee the embeddings
            margin_loss, _ = self.criterion(anchors, positives, negatives, self.beta, y[a_indices])
            total_loss = utils_misc.add_loss('embtouch/emb_touch_ml_loss', total_loss, margin_loss,
                hyp.emb_3D_ml_coeff, summ_writer)

            # the l2 loss between the embeddings
            l2_loss = torch.nn.functional.mse_loss(chosen_sensor_feats_, chosen_pred_feats_)
            total_loss = utils_misc.add_loss('embtouch/emb_l2_loss', total_loss, l2_loss,
                hyp.emb_3D_l2_coeff, summ_writer)
        ## ... End code for doing metric learning between pred_features and sensor_features ... ##

        ## ... Begin code for doing moc inspired ML between pred_features and sensor_features ... ##
        if hyp.do_moc and moc_init_done:
            moc_loss = self.moc_ml_net(sensor_features_, key_sensor_features_,\
                pred_features_, key_pred_features_, summ_writer)
            total_loss += moc_loss
        ## ... End code for doing moc inspired ML between pred_features and sensor_feature ... ##

        ## ... add code for filling up results needed for eval recall ... ##
        if hyp.do_eval_recall and moc_init_done:
            results['context_features'] = pred_features_
            results['sensor_depth_centers_in_world'] = sensor_depths_centers_in_world_
            results['sensor_depths_centers_in_ref_cam'] = sensor_depths_centers_in_ref_cam_
            results['object_name'] = feed['object_name']

            # I will do precision recall here at different recall values and summarize it using tensorboard
            recalls = [1, 5, 10, 50, 100, 200]
            # also should not include any gradients because of this
            # fast_sensor_emb_e = sensor_features_
            # fast_context_emb_e = pred_features_
            # slow_sensor_emb_g = key_sensor_features_
            # slow_context_emb_g = key_context_features_
            fast_sensor_emb_e = sensor_features_.clone().detach()
            fast_context_emb_e = pred_features_.clone().detach()

            # I will do multiple eval recalls here
            slow_sensor_emb_g = key_sensor_features_.clone().detach()
            slow_context_emb_g = key_pred_features_.clone().detach()

            # assuming the above thing goes well
            fast_sensor_emb_e = fast_sensor_emb_e.cpu().numpy()
            fast_context_emb_e = fast_context_emb_e.cpu().numpy()
            slow_sensor_emb_g = slow_sensor_emb_g.cpu().numpy()
            slow_context_emb_g = slow_context_emb_g.cpu().numpy()

            # now also move the vis to numpy and plot it using matplotlib
            vis_e = __p(sensor_rgbs)
            vis_g = __p(sensor_rgbs)
            np_vis_e = vis_e.cpu().detach().numpy()
            np_vis_e = np.transpose(np_vis_e, [0, 2, 3, 1])
            np_vis_g = vis_g.cpu().detach().numpy()
            np_vis_g = np.transpose(np_vis_g, [0, 2, 3, 1])

            # bring it back to original color
            np_vis_g = ((np_vis_g+0.5) * 255).astype(np.uint8)
            np_vis_e = ((np_vis_e+0.5) * 255).astype(np.uint8)

            # now compare fast_sensor_emb_e with slow_context_emb_g
            # since I am doing positive against this
            fast_sensor_emb_e_list = [fast_sensor_emb_e, np_vis_e]
            slow_context_emb_g_list = [slow_context_emb_g, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_sensor_emb_e_list, slow_context_emb_g_list, recalls=recalls
            )

            # finally plot the nearest neighbour retrieval and move ahead
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_sensor_slow_context')

            # plot the precisions at different recalls
            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_sensor_slow_context/recall@{re}',\
                    prec[pr])

            # now compare fast_context_emb_e with slow_sensor_emb_g
            fast_context_emb_e_list = [fast_context_emb_e, np_vis_e]
            slow_sensor_emb_g_list = [slow_sensor_emb_g, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_context_emb_e_list, slow_sensor_emb_g_list, recalls=recalls
            )
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_context_slow_sensor')

            # plot the precisions at different recalls
            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_context_slow_sensor/recall@{re}',\
                    prec[pr])


            # now finally compare both the fast, I presume we want them to go closer too
            fast_sensor_list = [fast_sensor_emb_e, np_vis_e]
            fast_context_list = [fast_context_emb_e, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_sensor_list, fast_context_list, recalls=recalls
            )
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_sensor_fast_context')

            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_sensor_fast_context/recall@{re}',\
                    prec[pr])

        ## ... done code for filling up results needed for eval recall ... ##
        summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, [key_sensor_features_, key_pred_features_]
Code example #27
0
def assemble(bkg_feat0, obj_feat0, origin_T_camRs, camRs_T_zoom):
    # let's first assemble the seq of background tensors
    # this should effectively CREATE egomotion
    # i fully expect we can do this all in one shot

    # note it makes sense to create egomotion here, because
    # we want to predict each view

    B, C, Z, Y, X = list(bkg_feat0.shape)
    B2, C2, Z2, Y2, X2 = list(obj_feat0.shape)
    assert (B == B2)
    assert (C == C2)

    B, S, _, _ = list(origin_T_camRs.shape)
    # ok, we have everything we need
    # for each timestep, we want to warp the bkg to this timestep

    # utils for packing/unpacking along seq dim
    __p = lambda x: pack_seqdim(x, B)
    __u = lambda x: unpack_seqdim(x, B)

    # we in fact have utils for this already
    cam0s_T_camRs = utils_geom.get_camM_T_camXs(origin_T_camRs, ind=0)
    camRs_T_cam0s = __u(utils_geom.safe_inverse(__p(cam0s_T_camRs)))

    bkg_feat0s = bkg_feat0.unsqueeze(1).repeat(1, S, 1, 1, 1, 1)
    bkg_featRs = apply_4x4s_to_voxs(camRs_T_cam0s, bkg_feat0s)

    # now for the objects

    # we want to sample for each location in the bird grid
    xyz_mems_ = utils_basic.gridcloud3D(B * S, Z, Y, X, norm=False)
    # this is B*S x Z*Y*X x 3
    xyz_camRs_ = Mem2Ref(xyz_mems_, Z, Y, X)
    camRs_T_zoom_ = __p(camRs_T_zoom)
    zoom_T_camRs_ = camRs_T_zoom_.inverse()  # note this is not a rigid transform
    xyz_zooms_ = utils_geom.apply_4x4(zoom_T_camRs_, xyz_camRs_)

    # we will do the whole traj at once (per obj)
    # note we just have one feat for the whole traj, so we tile up
    obj_feats = obj_feat0.unsqueeze(1).repeat(1, S, 1, 1, 1, 1)
    obj_feats_ = __p(obj_feats)
    # this is B*S x C x Z x Y x X (the object's own grid size)

    # to sample, we need feats_ in ZYX order
    obj_featRs_ = utils_samp.sample3D(obj_feats_, xyz_zooms_, Z, Y, X)
    obj_featRs = __u(obj_featRs_)

    # overweigh objects, so that we essentially overwrite
    # featRs = 0.05*bkg_featRs + 0.95*obj_featRs

    # overwrite the bkg at the object
    obj_mask = (obj_featRs > 0).float()  # mask where the warped object has support
    featRs = obj_featRs + (1.0 - obj_mask) * bkg_featRs

    # note the normalization (next) will restore magnitudes for the bkg

    # # featRs = bkg_featRs
    # featRs = obj_featRs

    # l2 normalize on chans
    featRs = l2_normalize(featRs, dim=2)

    validRs = 1.0 - (featRs == 0).all(dim=2, keepdim=True).float().cuda()

    return featRs, validRs, bkg_featRs, obj_featRs
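
A hedged usage sketch for assemble, showing only the expected tensor layout; the sizes, random features, and identity poses are placeholders (assumptions), not values from the project, and the module's own utils and a CUDA device are assumed to be available.

if __name__ == '__main__':
    B, S, C = 2, 3, 32
    Z, Y, X = 64, 16, 64        # background (scene) grid
    Zo, Yo, Xo = 16, 16, 16     # object (zoom) grid
    bkg_feat0 = torch.randn(B, C, Z, Y, X).cuda()            # background features at timestep 0
    obj_feat0 = torch.randn(B, C, Zo, Yo, Xo).cuda()         # one object feature, shared over the traj
    origin_T_camRs = torch.eye(4).repeat(B, S, 1, 1).cuda()  # per-timestep reference poses (placeholder)
    camRs_T_zoom = torch.eye(4).repeat(B, S, 1, 1).cuda()    # per-timestep object poses (placeholder)
    featRs, validRs, bkg_featRs, obj_featRs = assemble(
        bkg_feat0, obj_feat0, origin_T_camRs, camRs_T_zoom)
    # featRs / bkg_featRs / obj_featRs: B x S x C x Z x Y x X; validRs: B x S x 1 x Z x Y x X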
Code example #28
0
    def get_object_info(self, f, rgb_camX, pix_T_camX, origin_T_camX):
        # st()
        objs_info = f['objects_info']
        object_dict = {}
        if self.visualize:
            plt.imshow(rgb_camX[..., :3])
            plt.show(block=True)
        for obj_info in objs_info:
            classname = obj_info['category_name']
            if classname in self.ignore_classes:
                continue
            category = obj_info['category_id']
            instance_id = obj_info['instance_id']
            bbox_center = obj_info['bbox_center']
            bbox_size = obj_info['bbox_size']

            xmin, xmax = bbox_center[0] - bbox_size[0] / 2., bbox_center[0] + bbox_size[0] / 2.
            ymin, ymax = bbox_center[1] - bbox_size[1] / 2., bbox_center[1] + bbox_size[1] / 2.
            zmin, zmax = bbox_center[2] - bbox_size[2] / 2., bbox_center[2] + bbox_size[2] / 2.
            bbox_volume = (xmax - xmin) * (ymax - ymin) * (zmax - zmin)

            bbox_origin_ends = np.array([xmin, ymin, zmin, xmax, ymax, zmax])
            bbox_origin_ends = torch.tensor(bbox_origin_ends).reshape(
                1, 1, 2, 3).float()
            bbox_origin_theta = nlu.get_alignedboxes2thetaformat(
                bbox_origin_ends)
            bbox_origin_corners = utils_geom.transform_boxes_to_corners(
                bbox_origin_theta).float()
            camX_T_origin = utils_geom.safe_inverse(
                torch.tensor(origin_T_camX).unsqueeze(0)).float()
            bbox_corners_camX = utils_geom.apply_4x4(
                camX_T_origin.float(),
                bbox_origin_corners.squeeze(0).float())
            bbox_corners_pixX = utils_geom.apply_pix_T_cam(
                torch.tensor(pix_T_camX).unsqueeze(0).float(),
                bbox_corners_camX)
            bbox_ends_pixX = nlu.get_ends_of_corner(
                bbox_corners_pixX.permute(0, 2, 1)).permute(0, 2, 1)

            bbox_ends_pixX_np = bbox_ends_pixX.squeeze(0).numpy().astype(int)
            # clamp x to the image width and y to the image height separately
            bbox_ends_pixX_np[:, 0] = np.clip(bbox_ends_pixX_np[:, 0], 0, rgb_camX.shape[1])
            bbox_ends_pixX_np[:, 1] = np.clip(bbox_ends_pixX_np[:, 1], 0, rgb_camX.shape[0])
            bbox_area = (bbox_ends_pixX_np[1, 1] - bbox_ends_pixX_np[0, 1]) * (
                bbox_ends_pixX_np[1, 0] - bbox_ends_pixX_np[0, 0])
            print("Volume and area occupied by class {} is {} and {}".format(
                classname, bbox_volume, bbox_area))

            semantic = f['semantic_camX']
            instance_id_pixel_cnt = np.where(semantic == instance_id)[0].shape[0]
            object_to_bbox_ratio = instance_id_pixel_cnt / bbox_area
            print(
                "Num pixels in semantic map {}. Ratio of pixels to bbox area {}. Ratio of pixels to bbox volume {}."
                .format(instance_id_pixel_cnt, object_to_bbox_ratio,
                        instance_id_pixel_cnt / bbox_volume))
            if self.visualize:
                # print("bbox ends are: ", bbox_ends_pixX_np)
                cropped_rgb = rgb_camX[
                    bbox_ends_pixX_np[0, 1]:bbox_ends_pixX_np[1, 1],
                    bbox_ends_pixX_np[0, 0]:bbox_ends_pixX_np[1, 0], :3]
                plt.imshow(cropped_rgb)
                plt.show(block=True)
            if bbox_area < self.bbox_area_thresh:
                continue

            if object_to_bbox_ratio < self.occlusion_thresh:
                continue

            object_dict[instance_id] = (classname, category, instance_id,
                                        bbox_origin_ends)

        return object_dict
Code example #29
0
    def process(self):
        fnames = []
        for scene_cnt, scene_dir in enumerate(self.scene_dirs):
            print("Processing scene {}. Scene number {}".format(
                scene_dir, scene_cnt))

            rgb_camXs = []
            depth_camXs = []
            pix_T_camXs = []
            origin_T_camXs = []
            xyz_camXs = []
            habitat_pix_T_camXs = []
            scene_bbox_ends = []
            scene_category_ids = []
            scene_category_names = []
            scene_instance_ids = []
            scene_object_dict = {}
            pickle_files = [
                os.path.join(scene_dir, f) for f in os.listdir(scene_dir)
                if f.endswith('.p')
            ]
            pickle_files = sorted(pickle_files)
            for cnt, pickle_file in enumerate(pickle_files):

                f = pickle.load(open(pickle_file, "rb"))
                if cnt == 0:
                    sample_f = f

                rgb_camXs.append(f['rgb_camX'])
                depth_camXs.append(f['depth_camX'])
                pix_T_camXs.append(self.get_pix_T_camX())
                habitat_pix_T_camXs.append(self.get_habitat_pix_T_camX())
                origin_T_camXs.append(
                    self.get_origin_T_camX(f['sensor_pos'], f['sensor_rot']))
                print("count of pickle file is: ", cnt)
                object_dict = self.get_object_info(f, rgb_camXs[-1],
                                                   pix_T_camXs[-1],
                                                   origin_T_camXs[-1])
                # st()
                for instance_id in object_dict:
                    if instance_id not in scene_object_dict:
                        scene_object_dict[instance_id] = []

                    scene_object_dict[instance_id].append(
                        object_dict[instance_id])

            scene_category_ids, scene_instance_ids, scene_category_names, scene_bbox_ends = self.select_frequently_occuring_objects(
                scene_object_dict)
            habitat_pix_T_camXs = np.stack(habitat_pix_T_camXs)
            rgb_camXs_to_save = np.stack(rgb_camXs)[:, :, :, :3]
            rgb_camXs = np.stack(rgb_camXs)[:, :, :, :3].astype(
                np.float32) / 255.
            origin_T_camXs = np.stack(origin_T_camXs)
            depth_camXs = np.stack(depth_camXs)
            pix_T_camXs = np.stack(pix_T_camXs)
            xyz_habitatCamXs = self.generate_xyz_habitatCamXs(
                depth_camXs, rgb_camXs, habitat_pix_T_camXs)

            if self.visualize:
                print(
                    "Showing pointclouds in habitat_camXs coordinate ref frame"
                )
                for xyz_habitatCamX in xyz_habitatCamXs:
                    pcd = nlu.make_pcd(xyz_habitatCamX)
                    o3d.visualization.draw_geometries([pcd, self.mesh_frame])

            # Get xyz_camXs in the pydisco coordinate frame.
            # Since it's a 180 deg rotation, habitatCamX_T_camX and its inverse are the same, so no inverse is taken.
            xyz_camXs = utils_geom.apply_4x4(
                torch.tensor(self.get_habitatCamX_T_camX()).repeat(
                    xyz_habitatCamXs.shape[0], 1, 1),
                torch.tensor(xyz_habitatCamXs)).numpy()

            if self.visualize:
                print(
                    "Showing pointclouds in pydisco_camXs coordinate ref frame"
                )
                for xyz_camX, rgb_camX in zip(xyz_camXs, rgb_camXs):
                    pcd = nlu.make_pcd(xyz_camX)
                    o3d.visualization.draw_geometries([pcd, self.mesh_frame])

                    pix_T_camX = pix_T_camXs[0]
                    depth, _ = utils_geom.create_depth_image(
                        torch.tensor(pix_T_camX).unsqueeze(0).float(),
                        torch.tensor(xyz_camX).unsqueeze(0).float(), self.H,
                        self.W)
                    depth[torch.where(depth > 10)] = 0
                    utils_pointcloud.visualize_colored_pcd(
                        depth.squeeze(0).squeeze(0).numpy(), rgb_camX,
                        pix_T_camX)

            xyz_camXs_origin = utils_geom.apply_4x4(
                torch.tensor(origin_T_camXs), torch.tensor(xyz_camXs))
            xyz_camXs_origin_agg = xyz_camXs_origin.reshape(-1, 3)

            # Visualize aggregated pointcloud
            if self.visualize:
                print("Showing aggregated pointclouds")
                pcd_list = [self.mesh_frame]
                for xyz_camX_origin in xyz_camXs_origin:
                    pcd_list.append(nlu.make_pcd(xyz_camX_origin))
                o3d.visualization.draw_geometries(pcd_list)

            if self.visualize:
                self.test_bbox_projection(xyz_camXs_origin_agg,
                                          origin_T_camXs[0], pix_T_camXs[0],
                                          rgb_camXs[0], xyz_camXs[0], sample_f)

            # object_data = self.get_objects_in_scene()
            # First num_camR_candidates views will be our camR candidates
            for num_save in range(self.num_camR_candidates):
                camX1_T_origin = utils_geom.safe_inverse(
                    torch.tensor(
                        origin_T_camXs[num_save]).unsqueeze(0)).float().repeat(
                            origin_T_camXs.shape[0], 1, 1).numpy()
                data_to_save = {
                    "camR_index": num_save,
                    "object_category_ids": scene_category_ids,
                    "object_category_names": scene_category_names,
                    "object_instance_ids": scene_instance_ids,
                    "bbox_origin": scene_bbox_ends,
                    "pix_T_cams_raw": pix_T_camXs,
                    "camR_T_origin_raw": camX1_T_origin,
                    "xyz_camXs_raw": xyz_camXs,
                    "origin_T_camXs_raw": origin_T_camXs,
                    'rgb_camXs_raw': rgb_camXs_to_save
                }
                cur_epoch = str(time()).replace(".", "")
                pickle_fname = cur_epoch + ".p"
                fnames.append(pickle_fname)
                with open(os.path.join(dump_dir, pickle_fname), 'wb') as f:
                    pickle.dump(data_to_save, f)

        return fnames
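
A small, hypothetical reader for the pickles written above, included only to document the saved keys; the file path is a placeholder.

import pickle

with open('example_dump.p', 'rb') as f:   # placeholder path
    d = pickle.load(f)
for k in ['camR_index', 'object_category_ids', 'object_category_names',
          'object_instance_ids', 'bbox_origin', 'pix_T_cams_raw',
          'camR_T_origin_raw', 'xyz_camXs_raw', 'origin_T_camXs_raw',
          'rgb_camXs_raw']:
    print(k, type(d[k]))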
Code example #30
0
    def prepare_common_tensors(self, feed, prep_summ=True):
        results = dict()

        if prep_summ:
            self.summ_writer = utils_improc.Summ_writer(
                writer=feed['writer'],
                global_step=feed['global_step'],
                log_freq=feed['set_log_freq'],
                fps=8,
                just_gif=feed['just_gif'],
            )
        else:
            self.summ_writer = None

        self.include_vis = hyp.do_include_vis

        self.B = feed["set_batch_size"]
        self.S = feed["set_seqlen"]

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        self.H, self.W, self.V, self.N = hyp.H, hyp.W, hyp.V, hyp.N
        self.PH, self.PW = hyp.PH, hyp.PW
        self.K = hyp.K

        self.set_name = feed['set_name']
        # print('set_name', self.set_name)
        if self.set_name == 'test':
            self.Z, self.Y, self.X = hyp.Z_test, hyp.Y_test, hyp.X_test
        else:
            self.Z, self.Y, self.X = hyp.Z, hyp.Y, hyp.X
        # print('Z, Y, X = %d, %d, %d' % (self.Z, self.Y, self.X))

        self.Z2, self.Y2, self.X2 = int(self.Z / 2), int(self.Y / 2), int(
            self.X / 2)
        self.Z4, self.Y4, self.X4 = int(self.Z / 4), int(self.Y / 4), int(
            self.X / 4)

        self.rgb_camXs = feed["rgb_camXs"]
        self.pix_T_cams = feed["pix_T_cams"]

        self.origin_T_camXs = feed["origin_T_camXs"]

        self.cams_T_velos = feed["cams_T_velos"]

        self.camX0s_T_camXs = utils_geom.get_camM_T_camXs(self.origin_T_camXs,
                                                          ind=0)
        self.camXs_T_camX0s = __u(
            utils_geom.safe_inverse(__p(self.camX0s_T_camXs)))

        self.xyz_veloXs = feed["xyz_veloXs"]
        self.xyz_camXs = __u(
            utils_geom.apply_4x4(__p(self.cams_T_velos), __p(self.xyz_veloXs)))
        self.xyz_camX0s = __u(
            utils_geom.apply_4x4(__p(self.camX0s_T_camXs),
                                 __p(self.xyz_camXs)))

        if self.set_name == 'test':
            self.boxlist_camXs = feed["boxlists"]
            self.scorelist_s = feed["scorelists"]
            self.tidlist_s = feed["tidlists"]

            boxlist_camXs_ = __p(self.boxlist_camXs)
            scorelist_s_ = __p(self.scorelist_s)
            tidlist_s_ = __p(self.tidlist_s)
            boxlist_camXs_, tidlist_s_, scorelist_s_ = utils_misc.shuffle_valid_and_sink_invalid_boxes(
                boxlist_camXs_, tidlist_s_, scorelist_s_)
            self.boxlist_camXs = __u(boxlist_camXs_)
            self.scorelist_s = __u(scorelist_s_)
            self.tidlist_s = __u(tidlist_s_)

            # self.boxlist_camXs[:,0], self.scorelist_s[:,0], self.tidlist_s[:,0] = utils_misc.shuffle_valid_and_sink_invalid_boxes(
            #     self.boxlist_camXs[:,0], self.tidlist_s[:,0], self.scorelist_s[:,0])

            # self.score_s = feed["scorelists"]
            # self.tid_s = torch.ones_like(self.score_s).long()
            # self.lrt_camRs = utils_geom.convert_boxlist_to_lrtlist(self.box_camRs)
            # self.lrt_camXs = utils_geom.apply_4x4s_to_lrts(self.camXs_T_camRs, self.lrt_camRs)
            # self.lrt_camX0s = utils_geom.apply_4x4s_to_lrts(self.camX0s_T_camXs, self.lrt_camXs)
            # self.lrt_camR0s = utils_geom.apply_4x4s_to_lrts(self.camR0s_T_camRs, self.lrt_camRs)

            # boxlist_camXs_ = __p(self.boxlist_camXs)
            # boxlist_camXs_ = __p(self.boxlist_camXs)

            # lrtlist_camXs = __u(utils_geom.convert_boxlist_to_lrtlist(__p(self.boxlist_camXs))).reshape(
            #     self.B, self.S, self.N, 19)

            self.lrtlist_camXs = __u(
                utils_geom.convert_boxlist_to_lrtlist(__p(self.boxlist_camXs)))

            # print('lrtlist_camXs', lrtlist_camXs.shape)
            # # self.B, self.S, self.N, 19)
            # # lrtlist_camXs = __u(utils_geom.apply_4x4_to_lrtlist(__p(camXs_T_camRs), __p(lrtlist_camRs)))
            # self.summ_writer.summ_lrtlist('2D_inputs/lrtlist_camX0', self.rgb_camXs[:,0], lrtlist_camXs[:,0],
            #                               self.scorelist_s[:,0], self.tidlist_s[:,0], self.pix_T_cams[:,0])
            # self.summ_writer.summ_lrtlist('2D_inputs/lrtlist_camX1', self.rgb_camXs[:,1], lrtlist_camXs[:,1],
            #                               self.scorelist_s[:,1], self.tidlist_s[:,1], self.pix_T_cams[:,1])

            (
                self.lrt_camXs,
                self.box_camXs,
                self.score_s,
            ) = utils_misc.collect_object_info(self.lrtlist_camXs,
                                               self.boxlist_camXs,
                                               self.tidlist_s,
                                               self.scorelist_s,
                                               1,
                                               mod='X',
                                               do_vis=False,
                                               summ_writer=None)
            self.lrt_camXs = self.lrt_camXs.squeeze(0)
            self.score_s = self.score_s.squeeze(0)
            self.tid_s = torch.ones_like(self.score_s).long()

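            # bring the object boxes into the camX0 frame for every timestep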
            self.lrt_camX0s = utils_geom.apply_4x4s_to_lrts(
                self.camX0s_T_camXs, self.lrt_camXs)

            if prep_summ and self.include_vis:
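                # draw the gt boxes on each frame and log the sequence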
                visX_g = []
                for s in list(range(self.S)):
                    visX_g.append(
                        self.summ_writer.summ_lrtlist('',
                                                      self.rgb_camXs[:, s],
                                                      self.lrtlist_camXs[:, s],
                                                      self.scorelist_s[:, s],
                                                      self.tidlist_s[:, s],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('2D_inputs/box_camXs', visX_g)
                # visX_g = []
                # for s in list(range(self.S)):
                #     visX_g.append(self.summ_writer.summ_lrtlist(
                #         'track/box_camX%d_g' % s, self.rgb_camXs[:,s], self.lrt_camXs[:,s:s+1],
                #         self.score_s[:,s:s+1], self.tid_s[:,s:s+1], self.pix_T_cams[:,0], only_return=True))
                # self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

        if self.set_name == 'test':
            # center on an object, so that it does not fall out of bounds
            self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                self.lrt_camXs)[:, 0]
            self.vox_util = vox_util.Vox_util(
                self.Z,
                self.Y,
                self.X,
                self.set_name,
                scene_centroid=self.scene_centroid,
                assert_cube=True)
        else:
            # center randomly, on a non-outlier point

            all_ok = False
            num_tries = 0
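            # rejection-sample the centroid: keep re-drawing until enough points land inbounds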
            while not all_ok:
                scene_centroid_x = np.random.uniform(-8.0, 8.0)
                scene_centroid_y = np.random.uniform(-1.5, 3.0)
                scene_centroid_z = np.random.uniform(10.0, 26.0)
                scene_centroid = np.array(
                    [scene_centroid_x, scene_centroid_y,
                     scene_centroid_z]).reshape([1, 3])
                self.scene_centroid = torch.from_numpy(
                    scene_centroid).float().cuda()
                num_tries += 1

                # try to vox
                self.vox_util = vox_util.Vox_util(
                    self.Z,
                    self.Y,
                    self.X,
                    self.set_name,
                    scene_centroid=self.scene_centroid,
                    assert_cube=True)
                all_ok = True

                # we want to ensure this gives us a few points inbound for each batch el
                inb = __u(
                    self.vox_util.get_inbounds(__p(self.xyz_camX0s),
                                               self.Z4,
                                               self.Y4,
                                               self.X4,
                                               already_mem=False))
                num_inb = torch.sum(inb.float(), axis=2)
                if torch.min(num_inb) < 100:
                    all_ok = False

                if num_tries > 100:
                    return False
            self.summ_writer.summ_scalar('zoom_sampling/num_tries', num_tries)
            self.summ_writer.summ_scalar('zoom_sampling/num_inb',
                                         torch.mean(num_inb).cpu().item())

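        # voxelize the pointclouds into occupancy grids, at full and half resolution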
        self.occ_memXs = __u(
            self.vox_util.voxelize_xyz(__p(self.xyz_camXs), self.Z, self.Y,
                                       self.X))
        self.occ_memX0s = __u(
            self.vox_util.voxelize_xyz(__p(self.xyz_camX0s), self.Z, self.Y,
                                       self.X))
        self.occ_memX0s_half = __u(
            self.vox_util.voxelize_xyz(__p(self.xyz_camX0s), self.Z2, self.Y2,
                                       self.X2))

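        # unproject RGB into the 3D grid along camera rays, then warp into the camX0 frame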
        self.unp_memXs = __u(
            self.vox_util.unproject_rgb_to_mem(__p(self.rgb_camXs), self.Z,
                                               self.Y, self.X,
                                               __p(self.pix_T_cams)))
        self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
            self.camX0s_T_camXs, self.unp_memXs)

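        # log input visualizations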
        if prep_summ and self.include_vis:
            self.summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                                       torch.unbind(self.rgb_camXs, dim=1))
            self.summ_writer.summ_occs('3D_inputs/occ_memXs',
                                       torch.unbind(self.occ_memXs, dim=1))
            self.summ_writer.summ_occs('3D_inputs/occ_memX0s',
                                       torch.unbind(self.occ_memX0s, dim=1))
            self.summ_writer.summ_rgb('2D_inputs/rgb_camX0',
                                      self.rgb_camXs[:, 0])
            # self.summ_writer.summ_oned('2D_inputs/depth_camX0', self.depth_camXs[:,0], maxval=20.0)
            # self.summ_writer.summ_oned('2D_inputs/valid_camX0', self.valid_camXs[:,0], norm=False)
        return True