def transform_boxes_to_corners(boxes):
    """Convert a B x N x 9 tensor of boxes into corners shaped B x N x 8 x 3.

    The sequence dimension is packed away with ``utils_basic.pack_seqdim``,
    the per-sequence converter ``transform_boxes_to_corners_single`` is run
    on the packed tensor, and the result is unpacked again.

    Args:
        boxes: tensor with exactly three dimensions whose last one is 9.

    Returns:
        The corners tensor produced by unpacking the converter's output.
    """
    batch, _, box_dim = boxes.shape
    assert box_dim == 9

    # pack -> convert -> unpack, all keyed on the batch size
    packed_boxes = utils_basic.pack_seqdim(boxes, batch)
    packed_corners = transform_boxes_to_corners_single(packed_boxes)
    return utils_basic.unpack_seqdim(packed_corners, batch)
def transform_corners_to_boxes(corners):
    """Convert a B x N x 8 x 3 tensor of corners into boxes.

    All boxes are handled at once: the input is packed with
    ``utils_basic.pack_seqdim``, converted by
    ``transform_corners_to_boxes_single``, moved to the GPU, and unpacked.

    Args:
        corners: four-dimensional tensor with 8 corners of 3 coordinates
            each in its last two dimensions.

    Returns:
        The unpacked boxes tensor (on the CUDA device).
    """
    batch, _, num_corners, coord_dim = corners.shape
    assert num_corners == 8
    assert coord_dim == 3

    packed_corners = utils_basic.pack_seqdim(corners, batch)
    # the converter's output is moved to the GPU before unpacking
    packed_boxes = transform_corners_to_boxes_single(packed_corners).cuda()
    return utils_basic.unpack_seqdim(packed_boxes, batch)
def rotate_tensor_along_y_axis(tensor, gamma):
    """Rotate a 6-D voxel tensor by ``gamma`` about the y axis.

    The tensor (laid out B, S, C, D, H, W) is moved to the CPU, its y (H)
    axis is folded into the channel axis, and every resulting z-x slab is
    rotated about its centre with a kornia affine warp. The result is
    restored to the input layout and returned on the GPU.

    Args:
        tensor: tensor with exactly six dimensions.
        gamma: rotation angle handed to ``kornia.get_rotation_matrix2d``.

    Returns:
        The rotated tensor, moved to the CUDA device.
    """
    batch = tensor.shape[0]
    tensor = tensor.to("cpu")
    assert tensor.ndim == 6, "Tensors should have 6 dimensions."

    # B, S, C, D, H, W  ->  packed BS, C, D, H, W
    packed = utils_basic.pack_seqdim(tensor.float(), batch)
    # swap D and H so the layout is BS, C, H, D, W (i.e. BS, C, y, z, x)
    packed = packed.permute(0, 1, 3, 2, 4)
    BS, C, H, D, W = packed.shape

    # merge the y dimension into the channels; each channel is now a z-x image
    slabs = packed.reshape(BS, C * H, D, W)

    # rotation centre, given as (x, z)
    center = torch.tensor([[slabs.shape[3] / 2, slabs.shape[2] / 2]])
    unit_scale = torch.ones(1)
    angle = torch.ones(1) * gamma

    # one rotation matrix, shared by every packed element
    rot = kornia.get_rotation_matrix2d(center, angle, unit_scale).repeat(
        BS, 1, 1)
    warped = kornia.warp_affine(slabs, rot, dsize=(D, W))

    # undo the channel merge and the axis swap, then restore the seq dim
    warped = warped.reshape(BS, C, H, D, W).permute(0, 1, 3, 2, 4)
    return utils_basic.unpack_seqdim(warped, batch).cuda()
Exemple #4
0
    def __getitem__(self, index):
        """Load record ``index`` and pair it with the shared empty scene.

        The pickle at ``self.records[index]`` and the pickle at
        ``self.empty_scene`` are loaded, both are reduced to ``hyp.S`` views,
        and the camera / image / point-cloud entries are stacked along a new
        leading axis (real scene first, empty scene second).

        Args:
            index: position of the record in ``self.records``.

        Returns:
            dict: the record dict with stacked ``rgb_camXs_raw``,
            ``pix_T_cams_raw``, ``origin_T_camXs_raw``, ``camR_T_origin_raw``
            and ``xyz_camXs_raw``; placeholder ``tree_seq_filename`` /
            ``filename_e`` / ``filename_g`` entries; ``gt_box`` /
            ``gt_scores`` / ``classes`` for ``clevr_vqa``; and
            ``occR_complete`` when ``hyp.use_gt_occs`` is set.

        Raises:
            AssertionError: if ``hyp.dataset_name`` has no reader.
            NotImplementedError: if ``hyp.fixed_view`` is not set; this
                reader only has a fixed-view sampling path.
        """
        supported_datasets = ('kitti', 'clevr', 'real', 'bigbird', 'carla',
                              'carla_mix', 'replica', 'clevr_vqa',
                              'carla_det')
        if hyp.dataset_name not in supported_datasets:
            # raised explicitly so the check is not stripped by `python -O`
            raise AssertionError("reader not ready yet")

        # NOTE: pickle.load can execute arbitrary code; records are assumed
        # to come from a trusted source.
        with open(self.records[index], "rb") as record_file:
            d = dict(pickle.load(record_file))
        with open(self.empty_scene, "rb") as empty_file:
            d_empty = dict(pickle.load(empty_file))

        item_names = [
            'pix_T_cams_raw',
            'origin_T_camXs_raw',
            'camR_T_origin_raw',
            'rgb_camXs_raw',
            'xyz_camXs_raw',
        ]
        if hyp.do_empty:
            item_names += ['empty_rgb_camXs_raw', 'empty_xyz_camXs_raw']

        if hyp.use_gt_occs:
            # Occupancy is built from the full (not yet subsampled) record.
            # Every raw array gets a leading batch dim of 1 so the
            # pack/unpack helpers can be used.
            def pack(x):
                return utils_basic.pack_seqdim(x, 1)

            def unpack(x):
                return utils_basic.unpack_seqdim(x, 1)

            def load_gpu(key):
                return torch.from_numpy(
                    d[key]).unsqueeze(0).cuda().to(torch.float)

            Z2, Y2, X2 = int(hyp.Z / 2), int(hyp.Y / 2), int(hyp.X / 2)
            camRs_T_origin = load_gpu("camR_T_origin_raw")
            origin_T_camRs = unpack(
                utils_geom.safe_inverse(pack(camRs_T_origin)))
            origin_T_camXs = load_gpu("origin_T_camXs_raw")
            camRs_T_camXs = unpack(
                torch.matmul(utils_geom.safe_inverse(pack(origin_T_camRs)),
                             pack(origin_T_camXs)))
            xyz_camXs = load_gpu("xyz_camXs_raw")
            xyz_camRs = unpack(
                utils_geom.apply_4x4(pack(camRs_T_camXs), pack(xyz_camXs)))
            occRs_half = unpack(
                utils_vox.voxelize_xyz(pack(xyz_camRs), Z2, Y2, X2))
            # max over dim 1 (presumably the view axis), then drop the
            # leading batch dim of 1
            occRs_half = torch.max(occRs_half, dim=1).values.squeeze(0)
            occ_complete = occRs_half.cpu().numpy()

        is_clevr_vqa = hyp.dataset_name == "clevr_vqa"
        if is_clevr_vqa:
            d['tree_seq_filename'] = "temp"
            shape_name = d['shape_list']
            color_name = d['color_list']
            material_name = d['material_list']
            all_name = []
            all_style = []
            for obj_idx, shape in enumerate(shape_name):
                style_name = color_name[obj_idx] + "_" + material_name[obj_idx]
                all_name.append(shape + "/" + style_name)
                all_style.append(style_name)

            # choose which attribute serves as the class label
            if hyp.do_shape:
                object_category = shape_name
            elif hyp.do_color:
                object_category = color_name
            elif hyp.do_material:
                object_category = material_name
            elif hyp.do_style:
                object_category = all_style
            else:
                object_category = all_name
            bbox_origin = d['bbox_origin']
        else:
            tree_filename = d['tree_seq_filename']

        if not hyp.fixed_view:
            # no other sampling strategy is implemented for this reader
            raise NotImplementedError(
                "this reader only supports hyp.fixed_view sampling")
        d, _ = non_random_select_single(d, item_names, num_samples=hyp.S)
        # pick the empty-scene views using the selected scene camera poses
        d_empty, _ = specific_select_single_empty(d_empty,
                                                  item_names,
                                                  d['origin_T_camXs_raw'],
                                                  num_samples=hyp.S)

        def prep_rgb(rgb):
            # move channel dim inward, like pytorch wants, and keep only the
            # first three channels
            rgb = np.transpose(rgb, axes=[0, 3, 1, 2])[:, :3]
            return utils_improc.preprocess_color(rgb)

        rgb_camXs = prep_rgb(d['rgb_camXs_raw'])
        rgb_camXs_empty = prep_rgb(d_empty['rgb_camXs_raw'])

        if is_clevr_vqa:
            num_boxes = bbox_origin.shape[0]
            num_pad = hyp.N - num_boxes
            bbox_origin = np.array(bbox_origin)
            score = np.pad(np.ones([num_boxes]), [0, num_pad])
            bbox_origin = np.pad(bbox_origin,
                                 [[0, num_pad], [0, 0], [0, 0]])
            # class names are padded out to hyp.N entries; "0" is the
            # intended placeholder class
            object_category = np.pad(object_category, [[0, num_pad]],
                                     lambda x, y, z, m: "0")
            object_category_empty = np.pad(['0'], [[0, hyp.N - 1]],
                                           lambda x, y, z, m: "0")

            # the empty scene has no boxes: all-zero boxes and scores
            d['gt_box'] = np.stack(
                [bbox_origin.astype(np.float32),
                 np.zeros_like(bbox_origin)])
            d['gt_scores'] = np.stack(
                [score.astype(np.float32),
                 np.zeros_like(score)])
            d['classes'] = np.stack(
                [object_category, object_category_empty]).tolist()

        # real scene first, empty scene second
        d['rgb_camXs_raw'] = np.stack([rgb_camXs, rgb_camXs_empty])
        for key in ('pix_T_cams_raw', 'origin_T_camXs_raw',
                    'camR_T_origin_raw', 'xyz_camXs_raw'):
            d[key] = np.stack([d[key], d_empty[key]])

        if is_clevr_vqa:
            d['tree_seq_filename'] = ["temp"]
        else:
            d['tree_seq_filename'] = [tree_filename, "invalid_tree"]
        d['filename_e'] = ["temp"]
        d['filename_g'] = ["temp"]
        if hyp.use_gt_occs:
            d['occR_complete'] = np.expand_dims(occ_complete, axis=0)
        return d
Exemple #5
0
    def __getitem__(self, index):
        """Load record ``index``, select ``hyp.S`` views and add labels.

        The pickle at ``self.records[index]`` is loaded, reduced to ``hyp.S``
        views, its RGB images are preprocessed, and dataset-specific box /
        score / class entries (padded to ``hyp.N``) are attached.

        Args:
            index: position of the record in ``self.records``.

        Returns:
            dict: the record dict with preprocessed ``rgb_camXs_raw``, the
            original ``tree_seq_filename``, ``filename_g`` / ``filename_e``
            (record path joined with the first / second selected view
            index), and, depending on ``hyp.dataset_name``, ``gt_box``,
            ``gt_scores``, ``classes``, ``predicted_box``,
            ``predicted_scores``, ``rgb_camtop`` and ``origin_T_camXs_top``.
            ``occR_complete`` is added when ``hyp.use_gt_occs`` is set.

        Raises:
            AssertionError: if ``hyp.dataset_name`` has no reader.
        """
        supported_datasets = ('kitti', 'clevr', 'real', 'bigbird', 'carla',
                              'carla_mix', 'carla_det', 'replica',
                              'clevr_vqa')
        dataset = hyp.dataset_name
        if dataset not in supported_datasets:
            # raised explicitly so the check is not stripped by `python -O`
            raise AssertionError("reader not ready yet")

        # NOTE: pickle.load can execute arbitrary code; records are assumed
        # to come from a trusted source.
        record_path = self.records[index]
        with open(record_path, "rb") as record_file:
            d = dict(pickle.load(record_file))

        if hyp.use_gt_occs:
            # Occupancy is built from the full (not yet subsampled) record.
            # Every raw array gets a leading batch dim of 1 so the
            # pack/unpack helpers can be used.
            def pack(x):
                return utils_basic.pack_seqdim(x, 1)

            def unpack(x):
                return utils_basic.unpack_seqdim(x, 1)

            def load_gpu(key):
                return torch.from_numpy(
                    d[key]).unsqueeze(0).cuda().to(torch.float)

            Z2, Y2, X2 = int(hyp.Z / 2), int(hyp.Y / 2), int(hyp.X / 2)
            camRs_T_origin = load_gpu("camR_T_origin_raw")
            origin_T_camRs = unpack(
                utils_geom.safe_inverse(pack(camRs_T_origin)))
            origin_T_camXs = load_gpu("origin_T_camXs_raw")
            camRs_T_camXs = unpack(
                torch.matmul(utils_geom.safe_inverse(pack(origin_T_camRs)),
                             pack(origin_T_camXs)))
            xyz_camXs = load_gpu("xyz_camXs_raw")
            xyz_camRs = unpack(
                utils_geom.apply_4x4(pack(camRs_T_camXs), pack(xyz_camXs)))
            occRs_half = unpack(
                utils_vox.voxelize_xyz(pack(xyz_camRs), Z2, Y2, X2))
            # max over dim 1 (presumably the view axis), then drop the
            # leading batch dim of 1
            occRs_half = torch.max(occRs_half, dim=1).values.squeeze(0)
            occ_complete = occRs_half.cpu().numpy()

        item_names = [
            'pix_T_cams_raw',
            'origin_T_camXs_raw',
            'camR_T_origin_raw',
            'rgb_camXs_raw',
            'xyz_camXs_raw',
        ]
        if hyp.do_empty:
            item_names += ['empty_rgb_camXs_raw', 'empty_xyz_camXs_raw']

        # ---- read the dataset-specific labels from the raw record ----
        if dataset in ("carla_mix", "carla_det"):
            bbox_origin_gt = d['bbox_origin']
            bbox_origin_predicted = d.get('bbox_origin_predicted', [])
            classes = d['obj_name']
            if isinstance(classes, str):
                classes = [classes]
            d['tree_seq_filename'] = "temp"

        if dataset == "replica":
            d['tree_seq_filename'] = "temp"
            object_category = d['object_category_names']
            bbox_origin = d['bbox_origin']

        if dataset == "clevr_vqa":
            d['tree_seq_filename'] = "temp"
            shape_name = d['shape_list']
            color_name = d['color_list']
            material_name = d['material_list']
            all_name = []
            all_style = []
            for obj_idx, shape in enumerate(shape_name):
                style_name = color_name[obj_idx] + "_" + material_name[obj_idx]
                all_name.append(shape + "/" + style_name)
                all_style.append(style_name)

            # choose which attribute serves as the class label
            if hyp.do_shape:
                object_category = shape_name
            elif hyp.do_color:
                object_category = color_name
            elif hyp.do_material:
                object_category = material_name
            elif hyp.do_style:
                object_category = all_style
            else:
                object_category = all_name
            bbox_origin = d['bbox_origin']

        if dataset == "carla":
            # keep the view at camR_index before the views are subsampled
            camR_index = d['camR_index']
            rgb_camtop = d['rgb_camXs_raw'][camR_index:camR_index + 1]
            origin_T_camXs_top = d['origin_T_camXs_raw'][
                camR_index:camR_index + 1]

        tree_filename = d['tree_seq_filename']

        # ---- view selection ----
        # NOTE(review): with do_2d_style_munit the views are selected here
        # and then selected again below; kept as-is, confirm it is intended.
        if hyp.do_2d_style_munit:
            d, indexes = non_random_select_single(d,
                                                  item_names,
                                                  num_samples=hyp.S)

        if not hyp.fixed_view and (self.shuffle
                                   or hyp.randomly_select_views):
            d, indexes = random_select_single(d, item_names, num_samples=hyp.S)
        else:
            d, indexes = non_random_select_single(d,
                                                  item_names,
                                                  num_samples=hyp.S)

        filename_g = "/".join([record_path, str(indexes[0])])
        filename_e = "/".join([record_path, str(indexes[1])])

        def prep_rgb(rgb):
            # move channel dim inward, like pytorch wants, and keep only the
            # first three channels
            rgb = np.transpose(rgb, axes=[0, 3, 1, 2])[:, :3]
            return utils_improc.preprocess_color(rgb)

        def pad_classes(names, num_pad):
            # class names are padded out to hyp.N entries; "0" is the
            # intended placeholder class
            return np.pad(names, [[0, num_pad]], lambda x, y, z, m: "0")

        d['rgb_camXs_raw'] = prep_rgb(d['rgb_camXs_raw'])

        # ---- attach padded labels ----
        if dataset == "carla":
            d['rgb_camtop'] = prep_rgb(rgb_camtop)
            d['origin_T_camXs_top'] = origin_T_camXs_top
            # predicted boxes are not read for carla, so all-zero
            # placeholders are emitted
            d['predicted_box'] = np.zeros([hyp.N, 6]).astype(np.float32)
            d['predicted_scores'] = np.zeros([hyp.N]).astype(np.float32)

        if dataset == "clevr_vqa":
            num_boxes = bbox_origin.shape[0]
            num_pad = hyp.N - num_boxes
            bbox_origin = np.array(bbox_origin)
            score = np.pad(np.ones([num_boxes]), [0, num_pad])
            bbox_origin = np.pad(bbox_origin,
                                 [[0, num_pad], [0, 0], [0, 0]])
            object_category = pad_classes(object_category, num_pad)

            d['gt_box'] = bbox_origin.astype(np.float32)
            d['gt_scores'] = score.astype(np.float32)
            d['classes'] = list(object_category)

        if dataset == "replica":
            if len(bbox_origin) == 0:
                score = np.zeros([hyp.N])
                bbox_origin = np.zeros([hyp.N, 6])
                object_category = np.array(["0"] * hyp.N)
            else:
                num_boxes = len(bbox_origin)
                num_pad = hyp.N - num_boxes
                # boxes arrive as a sequence of tensors; flatten each to 6
                # numbers
                bbox_origin = torch.stack(bbox_origin).numpy().squeeze(
                    1).squeeze(1).reshape([num_boxes, 6])
                bbox_origin = np.array(bbox_origin)
                score = np.pad(np.ones([num_boxes]), [0, num_pad])
                bbox_origin = np.pad(bbox_origin, [[0, num_pad], [0, 0]])
                object_category = pad_classes(object_category, num_pad)
            d['gt_box'] = bbox_origin.astype(np.float32)
            d['gt_scores'] = score.astype(np.float32)
            d['classes'] = list(object_category)

        if dataset in ("carla_mix", "carla_det"):
            # at most three predicted boxes are kept
            bbox_origin_predicted = bbox_origin_predicted[:3]
            if len(bbox_origin_gt.shape) == 1:
                # a single box stored flat; give it a leading box axis
                bbox_origin_gt = np.expand_dims(bbox_origin_gt, 0)
            num_boxes = bbox_origin_gt.shape[0]
            num_pad = hyp.N - num_boxes
            score_gt = np.pad(np.ones([num_boxes]), [0, num_pad])
            bbox_origin_gt = np.pad(bbox_origin_gt, [[0, num_pad], [0, 0]])
            classes = pad_classes(classes, num_pad)

            if len(bbox_origin_predicted) == 0:
                bbox_origin_predicted = np.zeros([hyp.N, 6])
                score_pred = np.zeros([hyp.N]).astype(np.float32)
            else:
                num_pred = bbox_origin_predicted.shape[0]
                score_pred = np.pad(np.ones([num_pred]),
                                    [0, hyp.N - num_pred])
                bbox_origin_predicted = np.pad(
                    bbox_origin_predicted, [[0, hyp.N - num_pred], [0, 0]])

            d['predicted_box'] = bbox_origin_predicted.astype(np.float32)
            d['predicted_scores'] = score_pred.astype(np.float32)
            d['gt_box'] = bbox_origin_gt.astype(np.float32)
            d['gt_scores'] = score_gt.astype(np.float32)
            d['classes'] = list(classes)

        if dataset != "carla" and hyp.do_empty:
            d['empty_rgb_camXs_raw'] = prep_rgb(d['empty_rgb_camXs_raw'])

        if hyp.use_gt_occs:
            d['occR_complete'] = occ_complete
        d['tree_seq_filename'] = tree_filename
        d['filename_e'] = filename_e
        d['filename_g'] = filename_g
        return d
Exemple #6
0
    def forward(self, clist_cam, energy_map, occ_mems, summ_writer):
        """Sample future trajectories and return the weighted rpo loss.

        Args:
            clist_cam: three-dimensional trajectory tensor; the first
                ``self.T_past`` steps are the past, the rest the future.
            energy_map: passed through to ``self.compute_loss``.
            occ_mems: occupancy tensor shaped B x S x C x Z x Y x X; only
                the first step along S is used.
            summ_writer: summary writer; visualisations are emitted only
                when ``summ_writer.save_this`` is truthy.

        Returns:
            The accumulated loss: forward and reverse terms added through
            ``utils_misc.add_loss``.
        """
        total_loss = torch.tensor(0.0).cuda()

        B, _, C, Z, Y, X = occ_mems.shape
        B2, _, _ = clist_cam.shape
        assert B == B2

        def keep_xz(traj):
            # just xz: coordinate 0 is x, coordinate 2 is z
            return torch.stack([traj[:, :, 0], traj[:, :, 2]], dim=2)

        traj_past = keep_xz(clist_cam[:, :self.T_past])
        traj_futu = keep_xz(clist_cam[:, self.T_past:])

        # fold Y into the channels so the first occupancy is a B x C*Y x Z x X map
        feat = occ_mems[:, 0].permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
        all_zero = (feat == 0).all(dim=1, keepdim=True)
        mask = 1.0 - all_zero.float().cuda()
        halfgrid = utils_basic.meshgrid2D(B,
                                          int(Z / 2),
                                          int(X / 2),
                                          stack=True,
                                          norm=True)
        halfgrid = halfgrid.permute(0, 3, 1, 2)
        feat_map, _ = self.compressor(feat, mask, halfgrid)
        pred_map = self.conv2d(feat_map)
        # these are B x C x Z x X

        K = 12  # number of samples
        traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)
        feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
        pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)

        # the K trajectories are sampled in parallel by packing K onto the
        # batch dim
        traj_past_ = utils_basic.pack_seqdim(traj_past, K)
        feat_map_ = utils_basic.pack_seqdim(feat_map, K)
        pred_map_ = utils_basic.pack_seqdim(pred_map, K)
        base_sample_ = torch.randn(K * B, self.T_futu, 2).cuda()
        traj_futu_e = utils_basic.unpack_seqdim(
            self.compute_forward_mapping(feat_map_, pred_map_, base_sample_,
                                         traj_past_), K)
        # this is K x B x T x 2

        if summ_writer.save_this:

            def to_mem(traj_xz):
                # put a fake y back in and convert to memory coordinates
                return utils_vox.Ref2Mem(self.add_fake_y(traj_xz), Z, Y, X)

            sample_vis = []
            for k in range(K):
                vis = summ_writer.summ_traj_on_occ('',
                                                   to_mem(traj_futu_e[k]),
                                                   occ_mems[:, 0],
                                                   already_mem=True,
                                                   only_return=True)
                sample_vis.append(utils_improc.preprocess_color(vis))
                summ_writer.summ_traj_on_occ('rponet/traj_futu_sample_%d' % k,
                                             to_mem(traj_futu_e[k]),
                                             occ_mems[:, 0],
                                             already_mem=True)

            # element-wise max over the K sample visualisations
            mean_vis = torch.max(torch.stack(sample_vis, dim=0), dim=0)[0]
            summ_writer.summ_rgb('rponet/traj_futu_e_mean', mean_vis)

            summ_writer.summ_traj_on_occ('rponet/traj_futu_g',
                                         to_mem(traj_futu),
                                         occ_mems[:, 0],
                                         already_mem=True)

        # forward loss: neg logprob of GT samples under the model
        # reverse loss: neg logprob of estim samples under the (approx) GT
        # (i.e., spatial prior)
        forward_loss, reverse_loss = self.compute_loss(
            feat_map[0], pred_map[0], traj_past[0], traj_futu, traj_futu_e,
            energy_map)
        total_loss = utils_misc.add_loss('rpo/forward_loss', total_loss,
                                         forward_loss,
                                         hyp.rpo2D_forward_coeff, summ_writer)
        total_loss = utils_misc.add_loss('rpo/reverse_loss', total_loss,
                                         reverse_loss,
                                         hyp.rpo2D_reverse_coeff, summ_writer)
        return total_loss
Exemple #7
0
    def forward(self, clist_cam, occs, summ_writer, vox_util, suffix=''):
        total_loss = torch.tensor(0.0).cuda()
        B, S, C, Z, Y, X = list(occs.shape)
        B2, S2, D = list(clist_cam.shape)
        assert (B == B2, S == S2)
        assert (D == 3)

        if summ_writer.save_this:
            summ_writer.summ_traj_on_occ('motioncost/actual_traj',
                                         clist_cam,
                                         occs[:, self.T_past],
                                         vox_util,
                                         sigma=2)

        __p = lambda x: utils_basic.pack_seqdim(x, B)
        __u = lambda x: utils_basic.unpack_seqdim(x, B)

        # occs_ = occs.reshape(B*S, C, Z, Y, X)
        occs_ = __p(occs)
        feats_ = occs_.permute(0, 1, 3, 2, 4).reshape(B * S, C * Y, Z, X)
        masks_ = 1.0 - (feats_ == 0).all(dim=1, keepdim=True).float().cuda()
        halfgrids_ = utils_basic.meshgrid2D(B * S,
                                            int(Z / 2),
                                            int(X / 2),
                                            stack=True,
                                            norm=True).permute(0, 3, 1, 2)
        # feats_ = torch.cat([feats_, grids_], dim=1)
        feats = __u(feats_)
        masks = __u(masks_)
        halfgrids = __u(halfgrids_)
        input_feats = feats[:, :self.T_past]
        input_masks = masks[:, :self.T_past]
        input_halfgrids = halfgrids[:, :self.T_past]
        dense_feats_, _ = self.densifier(__p(input_feats), __p(input_masks),
                                         __p(input_halfgrids))
        dense_feats = __u(dense_feats_)
        super_feat = dense_feats.reshape(B, self.T_past * self.dense_dim,
                                         int(Z / 2), int(X / 2))
        cost_maps = self.motioncoster(super_feat)
        cost_maps = F.interpolate(cost_maps, scale_factor=4, mode='bilinear')
        # this is B x T_futu x Z x X
        cost_maps = cost_maps.clamp(-1000,
                                    1000)  # raquel says this adds stability
        summ_writer.summ_histogram('motioncost/cost_maps_hist', cost_maps)
        summ_writer.summ_oneds('motioncost/cost_maps',
                               torch.unbind(cost_maps.unsqueeze(2), dim=1))

        # next i need to sample some trajectories

        N = hyp.motioncost_num_negs
        sampled_trajs_cam = self.sample_trajs(N, clist_cam)
        # this is B x N x S x 3

        if summ_writer.save_this:
            # for n in list(range(np.min([N, 3]))):
            #     # this is 1 x S x 3
            #     summ_writer.summ_traj_on_occ('motioncost/sample%d_clist' % n,
            #                                  sampled_trajs_cam[0, n].unsqueeze(0),
            #                                  occs[:,self.T_past],
            #                                  # torch.max(occs, dim=1)[0],
            #                                  # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
            #                                  already_mem=False)
            o = []
            for n in list(range(N)):
                o.append(
                    utils_improc.preprocess_color(
                        summ_writer.summ_traj_on_occ(
                            '',
                            sampled_trajs_cam[0, n].unsqueeze(0),
                            occs[0:1, self.T_past],
                            vox_util,
                            only_return=True,
                            sigma=0.5)))
            summ_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
            summ_writer.summ_rgb('motioncost/all_sampled_trajs', summ_vis)

        # smooth loss
        cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
        dz, dx = gradient2D(cost_maps_, absolute=True)
        dt = torch.abs(cost_maps[:, 1:] - cost_maps[:, 0:-1])
        smooth_spatial = torch.mean(dx + dz, dim=1, keepdims=True)
        smooth_time = torch.mean(dt, dim=1, keepdims=True)
        summ_writer.summ_oned('motioncost/smooth_loss_spatial', smooth_spatial)
        summ_writer.summ_oned('motioncost/smooth_loss_time', smooth_time)
        smooth_loss = torch.mean(smooth_spatial) + torch.mean(smooth_time)
        total_loss = utils_misc.add_loss('motioncost/smooth_loss', total_loss,
                                         smooth_loss,
                                         hyp.motioncost_smooth_coeff,
                                         summ_writer)

        # def clamp_xyz(xyz, X, Y, Z):
        #     x, y, z = torch.unbind(xyz, dim=-1)
        #     x = x.clamp(0, X)
        #     y = x.clamp(0, Y)
        #     z = x.clamp(0, Z)
        #     # if zero_y:
        #     #     y = torch.zeros_like(y)
        #     xyz = torch.stack([x,y,z], dim=-1)
        #     return xyz
        def clamp_xz(xz, X, Z):
            """Clamp 2D (x, z) coordinates to the extent of the cost map.

            Args:
                xz: tensor whose last dimension holds (x, z) pairs.
                X: upper bound for the x coordinate (lower bound is 0).
                Z: upper bound for the z coordinate (lower bound is 0).

            Returns:
                A tensor shaped like `xz`, with x limited to [0, X] and
                z limited to [0, Z].
            """
            x, z = torch.unbind(xz, dim=-1)
            x = x.clamp(0, X)
            # z has to be clamped from z itself; deriving it from x would
            # collapse every sample onto the x == z diagonal
            z = z.clamp(0, Z)
            return torch.stack([x, z], dim=-1)

        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is B x S x 3

        # sampled_trajs_cam is B x N x S x 3
        sampled_trajs_cam_ = sampled_trajs_cam.reshape(B, N * S, 3)
        sampled_trajs_mem_ = utils_vox.Ref2Mem(sampled_trajs_cam_, Z, Y, X)
        sampled_trajs_mem = sampled_trajs_mem_.reshape(B, N, S, 3)
        # this is B x N x S x 3

        xyz_pos_ = clist_mem[:, self.T_past:].reshape(B * self.T_futu, 1, 3)
        xyz_neg_ = sampled_trajs_mem[:, :,
                                     self.T_past:].permute(0, 2, 1, 3).reshape(
                                         B * self.T_futu, N, 3)
        # get rid of y
        xz_pos_ = torch.stack([xyz_pos_[:, :, 0], xyz_pos_[:, :, 2]], dim=2)
        xz_neg_ = torch.stack([xyz_neg_[:, :, 0], xyz_neg_[:, :, 2]], dim=2)
        xz_ = torch.cat([xz_pos_, xz_neg_], dim=1)
        xz_ = clamp_xz(xz_, X, Z)
        cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
        cost_ = utils_samp.bilinear_sample2D(cost_maps_, xz_[:, :, 0],
                                             xz_[:, :, 1]).squeeze(1)
        # cost is B*T_futu x 1+N
        cost_pos = cost_[:, 0:1]  # B*T_futu x 1
        cost_neg = cost_[:, 1:]  # B*T_futu x N

        cost_pos = cost_pos.unsqueeze(2)  # B*T_futu x 1 x 1
        cost_neg = cost_neg.unsqueeze(1)  # B*T_futu x 1 x N

        utils_misc.add_loss('motioncost/mean_cost_pos', 0,
                            torch.mean(cost_pos), 0, summ_writer)
        utils_misc.add_loss('motioncost/mean_cost_neg', 0,
                            torch.mean(cost_neg), 0, summ_writer)
        utils_misc.add_loss('motioncost/mean_margin', 0,
                            torch.mean(cost_neg - cost_pos), 0, summ_writer)

        xz_pos = xz_pos_.unsqueeze(2)  # B*T_futu x 1 x 1 x 3
        xz_neg = xz_neg_.unsqueeze(1)  # B*T_futu x 1 x N x 3
        dist = torch.norm(xz_pos - xz_neg, dim=3)  # B*T_futu x 1 x N
        dist = dist / float(
            Z) * 5.0  # normalize for resolution, but upweight it a bit
        margin = F.relu(cost_pos - cost_neg + dist)
        margin = margin.reshape(B, self.T_futu, N)
        # mean over time (in the paper this is a sum)
        margin = torch.mean(margin, dim=1)
        # max over the negatives
        maxmargin = torch.max(margin, dim=1)[0]  # B
        maxmargin_loss = torch.mean(maxmargin)
        total_loss = utils_misc.add_loss('motioncost/maxmargin_loss',
                                         total_loss, maxmargin_loss,
                                         hyp.motioncost_maxmargin_coeff,
                                         summ_writer)

        # now let's see some top k
        # we'll do this for the first el of the batch
        cost_neg = cost_neg.reshape(B, self.T_futu,
                                    N)[0].detach().cpu().numpy()
        futu_mem = sampled_trajs_mem[:, :, self.T_past:].reshape(
            B, N, self.T_futu, 3)[0:1]
        cost_neg = np.reshape(cost_neg, [self.T_futu, N])
        cost_neg = np.sum(cost_neg, axis=0)
        inds = np.argsort(cost_neg, axis=0)

        for n in list(range(2)):
            xyzlist_e_mem = futu_mem[0:1, inds[n]]
            xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
            # this is B x S x 3

            if summ_writer.save_this and n == 0:
                print('xyzlist_e_cam', xyzlist_e_cam[0:1])
                print('xyzlist_g_cam', clist_cam[0:1, self.T_past:])

            dist = torch.norm(clist_cam[0:1, self.T_past:] -
                              xyzlist_e_cam[0:1],
                              dim=2)
            # this is B x T_futu
            meandist = torch.mean(dist)
            utils_misc.add_loss('motioncost/xyz_dist_%d' % n, 0, meandist, 0,
                                summ_writer)

        if summ_writer.save_this:
            # plot the best and worst trajs
            # print('sorted costs:', cost_neg[inds])
            for n in list(range(2)):
                ind = inds[n]
                print('plotting good traj with cost %.2f' % (cost_neg[ind]))
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ('motioncost/best_sampled_traj%d' %
                                             n,
                                             xyzlist_e_mem[0:1],
                                             occs[0:1, self.T_past],
                                             vox_util,
                                             already_mem=True,
                                             sigma=2)

            for n in list(range(2)):
                ind = inds[-(n + 1)]
                print('plotting bad traj with cost %.2f' % (cost_neg[ind]))
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'motioncost/worst_sampled_traj%d' % n,
                    xyzlist_e_mem[0:1],
                    occs[0:1, self.T_past],
                    vox_util,
                    already_mem=True,
                    sigma=2)

        # xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
        # xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
        # summ_writer.summ_traj_on_occ('motioncost/traj_e',
        #                              xyzlist_e_mem,
        #                              torch.max(occs, dim=1)[0],
        #                              already_mem=True,
        #                              sigma=2)
        # summ_writer.summ_traj_on_occ('motioncost/traj_g',
        #                              xyzlist_g_mem,
        #                              torch.max(occs, dim=1)[0],
        #                              already_mem=True,
        #                              sigma=2)

        # scorelist_here = scorelist[:,self.num_given:,0]
        # sql2 = torch.sum((vel_g-vel_e)**2, dim=2)

        # ## yes weightmask
        # weightmask = torch.arange(0, self.num_need, dtype=torch.float32, device=torch.device('cuda'))
        # weightmask = torch.exp(-weightmask**(1./4))
        # # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
        # #         0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
        # weightmask = weightmask.reshape(1, self.num_need)
        # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask)

        # utils_misc.add_loss('motioncost/l2_loss', 0, l2_loss, 0, summ_writer)

        # # # no weightmask:
        # # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)

        # # total_loss = utils_misc.add_loss('motioncost/l2_loss', total_loss, l2_loss, hyp.motioncost_l2_coeff, summ_writer)

        # dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
        # meandist = utils_basic.reduce_masked_mean(dist, scorelist[:,:,0])
        # utils_misc.add_loss('motioncost/xyz_dist_0', 0, meandist, 0, summ_writer)

        # l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # # utils_misc.add_loss('motioncost/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
        # total_loss = utils_misc.add_loss('motioncost/l2_loss_noexp', total_loss, l2_loss_noexp, hyp.motioncost_l2_coeff, summ_writer)

        return total_loss
    def run_test(self, feed):
        """Track the annotated object across the sequence and log the results.

        With `hyp.do_feat3D` enabled, features of frame 0 that fall inside
        the object mask serve as a template. For every frame, the template
        features are correlated against that frame's feature volume, a soft
        argmax yields one correspondence per template point, and
        `utils_track.rigid_transform_3D` fits the motion. From frame 1
        onwards the voxel grid is re-centered on the previous estimate.
        Without `hyp.do_feat3D`, the frame-0 box is simply repeated for
        every frame as a baseline.

        Side effects: overwrites `self.obj_clist_camX0` and
        `self.original_centroid`; in the feature branch also
        `self.scene_centroid`, `self.vox_util`, `self.occ_memXs`,
        `self.occ_memX0s`, `self.unp_memXs` and `self.unp_memX0s`; writes
        summaries through `self.summ_writer`.

        Args:
            feed: mapping; only `feed['global_step']` is read here.

        Returns:
            `(total_loss, results, early_return)`. `early_return` is True
            when the sample was rejected (invalid scores, or too few
            occupied voxels inside the object mask); `results` is then
            empty. Otherwise `results['ious']` is a B x S tensor.
        """
        results = dict()

        global_step = feed['global_step']  # not used below
        total_loss = torch.tensor(0.0).cuda()

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(
            self.lrt_camX0s)

        # remembered so that re-centered grids can be related back to the
        # grid the template was extracted in
        self.original_centroid = self.scene_centroid.clone()

        obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
        obj_length = obj_lengths[:, 0]
        for b in range(self.B):
            if self.score_s[b, 0] < 1.0:
                # we need the template to exist
                print('returning early, since score_s[%d,0] = %.1f' %
                      (b, self.score_s[b, 0].cpu().numpy()))
                return total_loss, results, True
            if not (torch.sum(self.score_s[b]) == self.S):
                # the full traj should be valid
                # (report the sum for element b, i.e. the value just tested)
                print(
                    'returning early, since sum(score_s) = %d, while S = %d' %
                    (torch.sum(self.score_s[b]).cpu().numpy(), self.S))
                return total_loss, results, True

        if hyp.do_feat3D:

            feat_memX0_input = torch.cat([
                self.occ_memX0s[:, 0],
                self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
            ],
                                         dim=1)
            _, feat_memX0, _ = self.featnet3D(feat_memX0_input)
            B, C, Z, Y, X = list(feat_memX0.shape)
            S = self.S

            obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
                self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
            # occupancy of frame 0 at feature resolution; used to require
            # enough observed points inside the object mask
            occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y,
                                                   X)
            obj_mask_memX0 = obj_mask_memX0s[:, 0]

            # discard the known freespace
            _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, 0:1],
                self.xyz_camXs[:, 0:1],
                Z,
                Y,
                X,
                agg=True)
            free_memX0 = free_memX0_.squeeze(1)
            obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

            for b in range(self.B):
                if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                    print(
                        'returning early, since there are not enough valid object points'
                    )
                    return total_loss, results, True

            feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat0_vec = feat0_vec.permute(0, 2, 1)
            # this is B x huge x C

            obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
            occ_mask0_vec = occ_memX0.reshape(B, -1).round()
            free_mask0_vec = free_memX0.reshape(B, -1).round()
            # these are B x huge

            orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
            # this is B x huge x 3

            # obj_length / cams_T_obj0 were already split out of
            # self.lrt_camX0s above
            cam0_T_obj = cams_T_obj0[:, 0]
            # presumably B x 4 x 4 (the frame-0 object pose) -- TODO confirm

            # both transforms belong to the original (pre-loop) grid
            mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
            cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

            lrt_camIs_g = self.lrt_camX0s.clone()
            lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
            # we will fill this up

            ious = torch.zeros([B, S]).float().cuda()
            point_counts = np.zeros([B, S])
            inb_counts = np.zeros([B, S])

            feat_vis = []
            occ_vis = []

            for s in range(self.S):
                if not (s == 0):
                    # remake the vox util and all the mem data, centered on
                    # the object centroid estimated for the previous frame
                    self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                        lrt_camIs_e[:, s - 1:s])[:, 0]
                    delta = self.scene_centroid - self.original_centroid
                    self.vox_util = vox_util.Vox_util(
                        self.Z,
                        self.Y,
                        self.X,
                        self.set_name,
                        scene_centroid=self.scene_centroid,
                        assert_cube=True)
                    self.occ_memXs = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camXs), self.Z,
                                                   self.Y, self.X))
                    self.occ_memX0s = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                   self.Z, self.Y, self.X))

                    self.unp_memXs = __u(
                        self.vox_util.unproject_rgb_to_mem(
                            __p(self.rgb_camXs), self.Z, self.Y, self.X,
                            __p(self.pix_T_cams)))
                    self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
                        self.camX0s_T_camXs, self.unp_memXs)
                    self.summ_writer.summ_occ('track/reloc_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                else:
                    self.summ_writer.summ_occ('track/init_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                    delta = torch.zeros([B, 3]).float().cuda()
                occ_vis.append(
                    self.summ_writer.summ_occ('',
                                              self.occ_memX0s[:, s],
                                              only_return=True))

                # NOTE(review): Z4/Y4 are paired with self.X rather than an
                # X4 here; this only feeds the logged in-bounds count --
                # confirm which resolution is intended
                inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s],
                                                 self.Z4,
                                                 self.Y4,
                                                 self.X,
                                                 already_mem=False)
                num_inb = torch.sum(inb.float(), axis=1)
                inb_counts[:, s] = num_inb.cpu().numpy()

                feat_memI_input = torch.cat([
                    self.occ_memX0s[:, s],
                    self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
                ],
                                            dim=1)
                _, feat_memI, _ = self.featnet3D(feat_memI_input)

                self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s,
                                           feat_memI_input,
                                           pca=True)
                self.summ_writer.summ_feat('3D_feats/feat_%d' % s,
                                           feat_memI,
                                           pca=True)
                feat_vis.append(
                    self.summ_writer.summ_feat('',
                                               feat_memI,
                                               pca=True,
                                               only_return=True))

                # collect freespace here, to discard bad matches
                _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                    self.camX0s_T_camXs[:, s:s + 1],
                    self.xyz_camXs[:, s:s + 1],
                    Z,
                    Y,
                    X,
                    agg=True)
                free_memI = free_memI_.squeeze(1)

                feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
                # this is B x C x huge
                feat_vec = feat_vec.permute(0, 2, 1)
                # this is B x huge x C

                memI_T_mem0 = utils_geom.eye_4x4(B)
                # we will fill this up

                # to simplify the impl, we will iterate over the batch dim
                for b in range(B):
                    feat_vec_b = feat_vec[b]
                    feat0_vec_b = feat0_vec[b]
                    obj_mask0_vec_b = obj_mask0_vec[b]
                    occ_mask0_vec_b = occ_mask0_vec[b]
                    free_mask0_vec_b = free_mask0_vec[b]
                    orig_xyz_b = orig_xyz[b]
                    # these are huge x C

                    # manual toggle between the two template-selection modes
                    careful = False
                    if careful:
                        # start with occ points, since these are definitely observed
                        obj_inds_b = torch.where(
                            (occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # also take random non-free non-occ points in the mask
                        ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (
                            1.0 - free_mask0_vec_b)
                        alt_inds_b = torch.where(ok_mask > 0)
                        alt_vec_b = feat0_vec_b[alt_inds_b]
                        alt_xyz0 = orig_xyz_b[alt_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        num = len(alt_xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            perm = np.random.permutation(num)
                            alt_vec_b = alt_vec_b[perm[:max_pts]]
                            alt_xyz0 = alt_xyz0[perm[:max_pts]]

                        obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                        xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                        if s == 0:
                            print('have %d pts in total' % (len(xyz0)))
                    else:
                        # take any points within the mask
                        obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        # trim down to max_pts
                        num = len(xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            print(
                                'have %d pts; taking a random set of %d pts inside'
                                % (num, max_pts))
                            perm = np.random.permutation(num)
                            obj_vec_b = obj_vec_b[perm[:max_pts]]
                            xyz0 = xyz0[perm[:max_pts]]

                    obj_vec_b = obj_vec_b.permute(1, 0)
                    # this is C x med

                    corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                    # this is huge x med

                    heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                    # this is med x 1 x Z x Y x X

                    # make the min zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_min
                    # zero out the freespace
                    heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                    # make the max zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_max
                    # scale up, for numerical stability
                    heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                    xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                    # this is med x 3

                    # shift the matches by the offset between the current and
                    # the original grid center, so they are expressed in the
                    # grid that xyz0 / mem_T_cam / cam_T_mem belong to
                    xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y,
                                                     X)
                    # use this element's own offset: the full B x 3 delta
                    # cannot be added in place to med x 1 x 3 unless B == 1
                    xyzI_cam = xyzI_cam + delta[b]
                    xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                    memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                    # record #points, since ransac depends on this
                    point_counts[b, s] = len(xyz0)
                # done stepping through batch

                # NOTE(review): cam0_T_camI is not used anywhere below
                mem0_T_memI = utils_geom.safe_inverse(memI_T_mem0)
                cam0_T_camI = utils_basic.matmul3(cam_T_mem, mem0_T_memI,
                                                  mem_T_cam)

                # eval
                camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0,
                                                 mem_T_cam, cam0_T_obj)
                # this is B x 4 x 4
                lrt_camIs_e[:,
                            s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:,
                                                         s:s + 1]).squeeze(1)
            results['ious'] = ious

            self.summ_writer.summ_rgbs('track/feats', feat_vis)
            self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())

            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())
            self.summ_writer.summ_scalar('track/point_counts',
                                         np.mean(point_counts))
            self.summ_writer.summ_scalar('track/inb_counts',
                                         np.mean(inb_counts))

            lrt_camX0s_e = lrt_camIs_e.clone()
            lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(
                self.camXs_T_camX0s, lrt_camX0s_e)

            if self.include_vis:
                visX_e = []
                for s in range(self.S):
                    visX_e.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_e' % s,
                                                      self.rgb_camXs[:, s],
                                                      lrt_camXs_e[:, s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
                visX_g = []
                for s in range(self.S):
                    visX_g.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_g' % s,
                                                      self.rgb_camXs[:, s],
                                                      self.lrt_camXs[:,
                                                                     s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)

            dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
            # this is B x S
            mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
            median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
            # this is []
            self.summ_writer.summ_scalar('track/centroid_dist_mean',
                                         mean_dist.cpu().item())
            self.summ_writer.summ_scalar('track/centroid_dist_median',
                                         median_dist.cpu().item())

            # the trajectory plots are written regardless of include_vis
            self.summ_writer.summ_traj_on_occ('track/traj_e',
                                              obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g',
                                              self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            total_loss += mean_dist  # we won't backprop, but it's nice to plot and print this anyway

        else:
            # baseline: keep the frame-0 box for every frame
            ious = torch.zeros([self.B, self.S]).float().cuda()
            for s in range(self.S):
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    self.lrt_camX0s[:, 0:1],
                    self.lrt_camX0s[:, s:s + 1]).squeeze(1)
            results['ious'] = ious
            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())
            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())

            lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
            self.summ_writer.summ_traj_on_occ('track/traj_e',
                                              obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g',
                                              self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)

        self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, False
    def run_train(self, feed):
        """Run one training pass over the tensors already stored on ``self``.

        Which stages run is controlled by the ``hyp.do_feat3D``,
        ``hyp.do_emb3D`` and ``hyp.do_occ`` switches:

        * ``do_feat3D``: featurize frames 1.. of the occupancy/colour input
          and average them into a single feature volume; with ``do_emb3D``
          frame 0 is additionally featurized by ``featnet3D_slow``.
        * ``do_emb3D``: build and log the "not_ok" reliability mask, and
          (at the end) add the 3D embedding loss.
        * ``do_occ``: add the occupancy prediction loss.

        Args:
            feed: per-step dict; only ``feed['global_step']`` is read here.

        Returns:
            Tuple ``(total_loss, results, False)``; ``results`` is an empty
            dict.
        """
        results = {}

        global_step = feed['global_step']  # looked up but not used below
        total_loss = torch.tensor(0.0).cuda()

        def pack(x):
            # fold the sequence dim into the batch dim
            return utils_basic.pack_seqdim(x, self.B)

        def unpack(x):
            # inverse of pack()
            return utils_basic.unpack_seqdim(x, self.B)

        if hyp.do_feat3D:
            # occupancy stacked with occupancy-masked colour, along dim 2
            feat_memX0s_input = torch.cat(
                [self.occ_memX0s, self.unp_memX0s * self.occ_memX0s], dim=2)

            # frames 1.. go through the main feature net
            feat3D_loss, feat_memX0s_, valid_memX0s_ = self.featnet3D(
                pack(feat_memX0s_input[:, 1:]), self.summ_writer)
            feat_memX0s = unpack(feat_memX0s_)
            valid_memX0s = unpack(valid_memX0s_)
            total_loss += feat3D_loss

            # masked average over the sequence dim
            valid_tiled = valid_memX0s.repeat(1, 1, hyp.feat3D_dim, 1, 1, 1)
            feat_memX0 = utils_basic.reduce_masked_mean(feat_memX0s,
                                                        valid_tiled,
                                                        dim=1)
            valid_memX0 = torch.sum(valid_memX0s, dim=1).clamp(0, 1)
            self.summ_writer.summ_feat('3D_feats/feat_memX0',
                                       feat_memX0,
                                       valid=valid_memX0,
                                       pca=True)
            self.summ_writer.summ_feat('3D_feats/valid_memX0',
                                       valid_memX0,
                                       pca=False)

            if hyp.do_emb3D:
                # frame 0 goes through the slow net
                first_frame_input = feat_memX0s_input[:, 0]
                _, altfeat_memX0, altvalid_memX0 = self.featnet3D_slow(
                    first_frame_input)
                self.summ_writer.summ_feat('3D_feats/altfeat_memX0',
                                           altfeat_memX0,
                                           valid=altvalid_memX0,
                                           pca=True)
                self.summ_writer.summ_feat('3D_feats/altvalid_memX0',
                                           altvalid_memX0,
                                           pca=False)

        if hyp.do_emb3D:
            # resolution of the supervision grids
            if hyp.do_feat3D:
                _, _, Z_, Y_, X_ = feat_memX0.shape
            else:
                Z_, Y_, X_ = self.Z2, self.Y2, self.X2

            occ_memX0s, free_memX0s, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs, self.xyz_camXs, Z_, Y_, X_, agg=False)

            occ_count = torch.sum(occ_memX0s, dim=1)
            free_count = torch.sum(free_memX0s, dim=1)

            not_ok = torch.zeros_like(occ_memX0s[:, 0])
            # a voxel marked occupied exactly once is not ok
            not_ok += (occ_count == 1.0).float()
            # a voxel marked both occupied AND free is not ok
            occ_agg = occ_count.clamp(0, 1)
            free_agg = free_count.clamp(0, 1)
            have_either = (occ_agg + free_agg).clamp(0, 1)
            have_both = occ_agg * free_agg
            not_ok += have_either * have_both
            # a voxel that was never observed is not ok
            not_ok += (have_either == 0.0).float()
            not_ok = not_ok.clamp(0, 1)
            self.summ_writer.summ_occ('rely/not_ok', not_ok)

            # max over the sequence dim of the half-resolution occupancy
            occ_half_any = torch.max(self.occ_memX0s_half, dim=1)[0]
            self.summ_writer.summ_occ('rely/not_ok_occ',
                                      not_ok * occ_half_any)
            self.summ_writer.summ_occ('rely/ok_occ',
                                      (1.0 - not_ok) * occ_half_any)
            self.summ_writer.summ_occ('rely/aggressive_occ', occ_half_any)

            # masking the valid volumes by reliability is switched off
            be_safe = False
            if hyp.do_feat3D and be_safe:
                ok = 1.0 - not_ok
                valid_memX0 = valid_memX0 * ok
                altvalid_memX0 = altvalid_memX0 * ok

        if hyp.do_occ:
            _, _, Z_, Y_, X_ = feat_memX0.shape
            (occ_memX0_sup, free_memX0_sup, _,
             free_memX0s) = self.vox_util.prep_occs_supervision(
                 self.camX0s_T_camXs,
                 self.xyz_camXs,
                 Z_,
                 Y_,
                 X_,
                 agg=True)

            self.summ_writer.summ_occ('occ_sup/occ_sup', occ_memX0_sup)
            self.summ_writer.summ_occ('occ_sup/free_sup', free_memX0_sup)
            self.summ_writer.summ_occs('occ_sup/freeX0s_sup',
                                       torch.unbind(free_memX0s, dim=1))
            self.summ_writer.summ_occs(
                'occ_sup/occX0s_sup',
                torch.unbind(self.occ_memX0s_half, dim=1))
            # the prediction itself is not used further here
            occ_loss, _occ_memX0_pred = self.occnet(altfeat_memX0,
                                                    occ_memX0_sup,
                                                    free_memX0_sup,
                                                    altvalid_memX0,
                                                    self.summ_writer)
            total_loss += occ_loss

        if hyp.do_emb3D:
            # 3D metric-learning loss between the two feature volumes
            emb_loss_3D = self.embnet3D(feat_memX0,
                                        altfeat_memX0,
                                        valid_memX0.round(),
                                        altvalid_memX0.round(),
                                        self.summ_writer)
            total_loss += emb_loss_3D

        self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, False
    def prepare_common_tensors(self, feed, prep_summ=True):
        """Unpack ``feed`` onto ``self`` and voxelize the inputs.

        Sets up the summary writer, sizes, camera transforms, point clouds,
        (for the ``'test'`` set) box/track information, a ``Vox_util``
        centred on a chosen scene centroid, and the occupancy / unprojected
        colour volumes.

        Args:
            feed: dict for this step. Keys read here: ``'set_batch_size'``,
                ``'set_seqlen'``, ``'set_name'``, ``'rgb_camXs'``,
                ``'pix_T_cams'``, ``'origin_T_camXs'``, ``'cams_T_velos'``,
                ``'xyz_veloXs'``; when ``prep_summ`` is true also
                ``'writer'``, ``'global_step'``, ``'set_log_freq'`` and
                ``'just_gif'``; for the ``'test'`` set also ``'boxlists'``,
                ``'scorelists'`` and ``'tidlists'``.
            prep_summ: when false, no summary writer is created
                (``self.summ_writer`` is ``None``) and nothing is logged.

        Returns:
            ``True`` on success. ``False`` (non-test sets only) when the
            random centroid search reaches its 101st attempt.
        """
        if prep_summ:
            self.summ_writer = utils_improc.Summ_writer(
                writer=feed['writer'],
                global_step=feed['global_step'],
                log_freq=feed['set_log_freq'],
                fps=8,
                just_gif=feed['just_gif'],
            )
        else:
            self.summ_writer = None

        self.include_vis = hyp.do_include_vis

        self.B = feed["set_batch_size"]
        self.S = feed["set_seqlen"]

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        self.H, self.W, self.V, self.N = hyp.H, hyp.W, hyp.V, hyp.N
        self.PH, self.PW = hyp.PH, hyp.PW
        self.K = hyp.K

        self.set_name = feed['set_name']
        if self.set_name == 'test':
            self.Z, self.Y, self.X = hyp.Z_test, hyp.Y_test, hyp.X_test
        else:
            self.Z, self.Y, self.X = hyp.Z, hyp.Y, hyp.X

        # half- and quarter-resolution grid sizes
        self.Z2, self.Y2, self.X2 = int(self.Z / 2), int(self.Y / 2), int(
            self.X / 2)
        self.Z4, self.Y4, self.X4 = int(self.Z / 4), int(self.Y / 4), int(
            self.X / 4)

        self.rgb_camXs = feed["rgb_camXs"]
        self.pix_T_cams = feed["pix_T_cams"]

        self.origin_T_camXs = feed["origin_T_camXs"]

        self.cams_T_velos = feed["cams_T_velos"]

        # transforms between every frame and frame 0
        self.camX0s_T_camXs = utils_geom.get_camM_T_camXs(self.origin_T_camXs,
                                                          ind=0)
        self.camXs_T_camX0s = __u(
            utils_geom.safe_inverse(__p(self.camX0s_T_camXs)))

        # point clouds: velodyne -> camX -> camX0
        self.xyz_veloXs = feed["xyz_veloXs"]
        self.xyz_camXs = __u(
            utils_geom.apply_4x4(__p(self.cams_T_velos), __p(self.xyz_veloXs)))
        self.xyz_camX0s = __u(
            utils_geom.apply_4x4(__p(self.camX0s_T_camXs),
                                 __p(self.xyz_camXs)))

        if self.set_name == 'test':
            self.boxlist_camXs = feed["boxlists"]
            self.scorelist_s = feed["scorelists"]
            self.tidlist_s = feed["tidlists"]

            boxlist_camXs_ = __p(self.boxlist_camXs)
            scorelist_s_ = __p(self.scorelist_s)
            tidlist_s_ = __p(self.tidlist_s)
            boxlist_camXs_, tidlist_s_, scorelist_s_ = utils_misc.shuffle_valid_and_sink_invalid_boxes(
                boxlist_camXs_, tidlist_s_, scorelist_s_)
            self.boxlist_camXs = __u(boxlist_camXs_)
            self.scorelist_s = __u(scorelist_s_)
            self.tidlist_s = __u(tidlist_s_)

            self.lrtlist_camXs = __u(
                utils_geom.convert_boxlist_to_lrtlist(__p(self.boxlist_camXs)))

            (
                self.lrt_camXs,
                self.box_camXs,
                self.score_s,
            ) = utils_misc.collect_object_info(self.lrtlist_camXs,
                                               self.boxlist_camXs,
                                               self.tidlist_s,
                                               self.scorelist_s,
                                               1,
                                               mod='X',
                                               do_vis=False,
                                               summ_writer=None)
            self.lrt_camXs = self.lrt_camXs.squeeze(0)
            self.score_s = self.score_s.squeeze(0)
            self.tid_s = torch.ones_like(self.score_s).long()

            self.lrt_camX0s = utils_geom.apply_4x4s_to_lrts(
                self.camX0s_T_camXs, self.lrt_camXs)

            if prep_summ and self.include_vis:
                # NOTE(review): intrinsics of frame 0 are used for every
                # frame s below — confirm this is intended.
                visX_g = [
                    self.summ_writer.summ_lrtlist('',
                                                  self.rgb_camXs[:, s],
                                                  self.lrtlist_camXs[:, s],
                                                  self.scorelist_s[:, s],
                                                  self.tidlist_s[:, s],
                                                  self.pix_T_cams[:, 0],
                                                  only_return=True)
                    for s in range(self.S)
                ]
                self.summ_writer.summ_rgbs('2D_inputs/box_camXs', visX_g)

        if self.set_name == 'test':
            # center on an object, so that it does not fall out of bounds
            self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                self.lrt_camXs)[:, 0]
            self.vox_util = vox_util.Vox_util(
                self.Z,
                self.Y,
                self.X,
                self.set_name,
                scene_centroid=self.scene_centroid,
                assert_cube=True)
        else:
            # center randomly
            # This first draw is overwritten by the first loop iteration
            # below; it is kept so the numpy RNG draw sequence is unchanged.
            scene_centroid_x = np.random.uniform(-8.0, 8.0)
            scene_centroid_y = np.random.uniform(-1.5, 3.0)
            scene_centroid_z = np.random.uniform(10.0, 26.0)
            scene_centroid = np.array(
                [scene_centroid_x, scene_centroid_y,
                 scene_centroid_z]).reshape([1, 3])
            self.scene_centroid = torch.from_numpy(
                scene_centroid).float().cuda()

            # resample the centroid until every batch element has enough
            # points in bounds
            all_ok = False
            num_tries = 0
            while not all_ok:
                scene_centroid_x = np.random.uniform(-8.0, 8.0)
                scene_centroid_y = np.random.uniform(-1.5, 3.0)
                scene_centroid_z = np.random.uniform(10.0, 26.0)
                scene_centroid = np.array(
                    [scene_centroid_x, scene_centroid_y,
                     scene_centroid_z]).reshape([1, 3])
                self.scene_centroid = torch.from_numpy(
                    scene_centroid).float().cuda()
                num_tries += 1

                # try to vox
                self.vox_util = vox_util.Vox_util(
                    self.Z,
                    self.Y,
                    self.X,
                    self.set_name,
                    scene_centroid=self.scene_centroid,
                    assert_cube=True)
                all_ok = True

                # we want to ensure this gives us a few points inbound for each batch el
                # NOTE(review): Z4/Y4 are paired with full-resolution X
                # here — confirm whether self.X4 was intended.
                inb = __u(
                    self.vox_util.get_inbounds(__p(self.xyz_camX0s),
                                               self.Z4,
                                               self.Y4,
                                               self.X,
                                               already_mem=False))
                num_inb = torch.sum(inb.float(), axis=2)
                if torch.min(num_inb) < 100:
                    all_ok = False

                if num_tries > 100:
                    return False

            # The writer is None when prep_summ is false; logging
            # unconditionally here used to raise AttributeError.
            if self.summ_writer is not None:
                self.summ_writer.summ_scalar('zoom_sampling/num_tries',
                                             num_tries)
                self.summ_writer.summ_scalar('zoom_sampling/num_inb',
                                             torch.mean(num_inb).cpu().item())

        # occupancy grids from the point clouds
        self.occ_memXs = __u(
            self.vox_util.voxelize_xyz(__p(self.xyz_camXs), self.Z, self.Y,
                                       self.X))
        self.occ_memX0s = __u(
            self.vox_util.voxelize_xyz(__p(self.xyz_camX0s), self.Z, self.Y,
                                       self.X))
        self.occ_memX0s_half = __u(
            self.vox_util.voxelize_xyz(__p(self.xyz_camX0s), self.Z2, self.Y2,
                                       self.X2))

        # colour unprojected into the grid, then warped into frame 0
        self.unp_memXs = __u(
            self.vox_util.unproject_rgb_to_mem(__p(self.rgb_camXs), self.Z,
                                               self.Y, self.X,
                                               __p(self.pix_T_cams)))
        self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
            self.camX0s_T_camXs, self.unp_memXs)

        if prep_summ and self.include_vis:
            self.summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                                       torch.unbind(self.rgb_camXs, dim=1))
            self.summ_writer.summ_occs('3D_inputs/occ_memXs',
                                       torch.unbind(self.occ_memXs, dim=1))
            self.summ_writer.summ_occs('3D_inputs/occ_memX0s',
                                       torch.unbind(self.occ_memX0s, dim=1))
            self.summ_writer.summ_rgb('2D_inputs/rgb_camX0',
                                      self.rgb_camXs[:, 0])
        return True
Exemple #11
0
    def forward(self, feat, obj_lrtlist_cams, obj_scorelist_s, summ_writer, suffix=''):
        """Sample K future trajectories for the first object.

        The object centroids are extracted from ``obj_lrtlist_cams`` and
        split along the time axis into a past (first 2 steps) and a future
        (remaining ``S-2`` steps). The feature volume is passed through
        ``self.compressor`` and ``self.conv3d``, tiled K times, and handed
        to ``self.sample_forward`` together with Gaussian base noise.

        Args:
            feat: feature volume, B x C x Z x Y x X.
            obj_lrtlist_cams: N x B x S x 19 box/pose list.
            obj_scorelist_s: N x B x S scores (not used here).
            summ_writer: summary writer (not used here).
            suffix: unused.

        Returns:
            ``total_loss``, a zero scalar tensor on the GPU; no loss term is
            added in this method yet.
        """
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(feat.shape)
        N, B2, S, D = list(obj_lrtlist_cams.shape)
        assert(B==B2)
        # obj_scorelist_s is N x B x S

        obj_lrtlist_cams_ = obj_lrtlist_cams.reshape(N*B, S, 19)
        obj_clist_cam_ = utils_geom.get_clist_from_lrtlist(obj_lrtlist_cams_)
        obj_clist_cam = obj_clist_cam_.reshape(N, B, S, 1, 3)
        # obj_clist_cam is N x B x S x 1 x 3
        obj_clist_cam = obj_clist_cam.squeeze(3)
        # obj_clist_cam is N x B x S x 3

        # as with prinet, let's do this for a single object first
        T_past = 2
        T_futu = S-2
        # Split along the TIME axis (dim 1 after picking object 0). The
        # previous code sliced the trailing xyz axis, which produced
        # B x S x 2 and B x S x 1 instead of the shapes noted below.
        traj_past = obj_clist_cam[0,:,:T_past]
        traj_futu = obj_clist_cam[0,:,T_past:]
        # traj_past is B x T_past x 3
        # traj_futu is B x T_futu x 3

        print('traj_past', traj_past.shape)
        print('traj_futu', traj_futu.shape)

        feat_map = self.compressor(feat)
        pred_map = self.conv3d(feat)
        # these are B x C x Z x Y x X

        # each component of the noise is IID Normal

        ## get K samples

        K = 5 # number of samples
        traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)
        feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1, 1)
        pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1, 1)
        # to sample the K trajectories in parallel, we'll pack K onto the batch dim

        __p = lambda x: utils_basic.pack_seqdim(x, K)
        __u = lambda x: utils_basic.unpack_seqdim(x, K)
        traj_past_ = __p(traj_past)
        feat_map_ = __p(feat_map)
        pred_map_ = __p(pred_map)
        # NOTE(review): the base sample is created on the default (CPU)
        # device while total_loss is on CUDA — confirm sample_forward
        # moves it as needed.
        base_sample_ = torch.randn(K*B, T_futu, 3)
        traj_futu_e_ = self.sample_forward(feat_map_, pred_map_, base_sample_, traj_past_)
        # traj_futu_e = __u(traj_futu_e_)
        # # this is K x B x T x 3

        return total_loss
Exemple #12
0
    def forward(self, feed):
        results = dict()

        if 'log_freq' not in feed.keys():
            feed['log_freq'] = None
        start_time = time.time()

        summ_writer = utils_improc.Summ_writer(writer=feed['writer'],
                                               global_step=feed['global_step'],
                                               set_name=feed['set_name'],
                                               log_freq=feed['log_freq'],
                                               fps=8)
        writer = feed['writer']
        global_step = feed['global_step']

        total_loss = torch.tensor(0.0).cuda()
        __p = lambda x: utils_basic.pack_seqdim(x, B)
        __u = lambda x: utils_basic.unpack_seqdim(x, B)

        __pb = lambda x: utils_basic.pack_boxdim(x, hyp.N)
        __ub = lambda x: utils_basic.unpack_boxdim(x, hyp.N)
        if hyp.aug_object_ent_dis:
            __pb_a = lambda x: utils_basic.pack_boxdim(
                x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
            __ub_a = lambda x: utils_basic.unpack_boxdim(
                x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
        else:
            __pb_a = lambda x: utils_basic.pack_boxdim(x, hyp.max_obj_aug)
            __ub_a = lambda x: utils_basic.unpack_boxdim(x, hyp.max_obj_aug)

        B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
        PH, PW = hyp.PH, hyp.PW
        K = hyp.K
        BOX_SIZE = hyp.BOX_SIZE
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
        D = 9

        tids = torch.from_numpy(np.reshape(np.arange(B * N), [B, N]))

        rgb_camXs = feed["rgb_camXs_raw"]
        pix_T_cams = feed["pix_T_cams_raw"]
        camRs_T_origin = feed["camR_T_origin_raw"]
        origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
        origin_T_camXs = feed["origin_T_camXs_raw"]
        camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
        camRs_T_camXs = __u(
            torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                         __p(origin_T_camXs)))
        camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
        camX0_T_camRs = camXs_T_camRs[:, 0]
        camX1_T_camRs = camXs_T_camRs[:, 1]

        camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)

        xyz_camXs = feed["xyz_camXs_raw"]
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(
            __p(pix_T_cams), __p(xyz_camXs), H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_,
                                                       __p(pix_T_cams))

        xyz_camRs = __u(
            utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
        xyz_camX0s = __u(
            utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs)))

        occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))

        occXs_to_Rs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, occXs)
        occXs_to_Rs_45 = cross_corr.rotate_tensor_along_y_axis(occXs_to_Rs, 45)
        occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2))
        occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
        occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2))

        unpXs = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X,
                                           __p(pix_T_cams)))

        unpXs_half = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z2, Y2, X2,
                                           __p(pix_T_cams)))

        unpX0s_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camX0_T_camXs)))))

        unpRs = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z, Y, X,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camRs_T_camXs)))))

        unpRs_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camRs_T_camXs)))))

        dense_xyz_camRs_ = utils_geom.apply_4x4(__p(camRs_T_camXs),
                                                dense_xyz_camXs_)
        inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y,
                                                X).float()
        inbound_camXs_ = torch.reshape(inbound_camXs_, [B * S, 1, H, W])

        depth_camXs = __u(depth_camXs_)
        valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)

        summ_writer.summ_oneds('2D_inputs/depth_camXs',
                               torch.unbind(depth_camXs, dim=1),
                               maxdepth=21.0)
        summ_writer.summ_oneds('2D_inputs/valid_camXs',
                               torch.unbind(valid_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                              torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1))
        summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1),
                              torch.unbind(occXs, dim=1))

        occRs = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X))

        if hyp.do_eval_boxes:
            if hyp.dataset_name == "clevr_vqa":
                gt_boxes_origin_corners = feed['gt_box']
                gt_scores_origin = feed['gt_scores'].detach().cpu().numpy()
                classes = feed['classes']
                scores = gt_scores_origin
                tree_seq_filename = feed['tree_seq_filename']
                gt_boxes_origin = nlu.get_ends_of_corner(
                    gt_boxes_origin_corners)
                gt_boxes_origin_end = torch.reshape(gt_boxes_origin,
                                                    [hyp.B, hyp.N, 2, 3])
                gt_boxes_origin_theta = nlu.get_alignedboxes2thetaformat(
                    gt_boxes_origin_end)
                gt_boxes_origin_corners = utils_geom.transform_boxes_to_corners(
                    gt_boxes_origin_theta)
                gt_boxesR_corners = __ub(
                    utils_geom.apply_4x4(camRs_T_origin[:, 0],
                                         __pb(gt_boxes_origin_corners)))
                gt_boxesR_theta = utils_geom.transform_corners_to_boxes(
                    gt_boxesR_corners)
                gt_boxesR_end = nlu.get_ends_of_corner(gt_boxesR_corners)

            else:
                tree_seq_filename = feed['tree_seq_filename']
                tree_filenames = [
                    join(hyp.root_dataset, i) for i in tree_seq_filename
                    if i != "invalid_tree"
                ]
                invalid_tree_filenames = [
                    join(hyp.root_dataset, i) for i in tree_seq_filename
                    if i == "invalid_tree"
                ]
                num_empty = len(invalid_tree_filenames)
                trees = [pickle.load(open(i, "rb")) for i in tree_filenames]

                len_valid = len(trees)
                if len_valid > 0:
                    gt_boxesR, scores, classes = nlu.trees_rearrange(trees)

                if num_empty > 0:
                    gt_boxesR = np.concatenate([
                        gt_boxesR, empty_gt_boxesR
                    ]) if len_valid > 0 else empty_gt_boxesR
                    scores = np.concatenate([
                        scores, empty_scores
                    ]) if len_valid > 0 else empty_scores
                    classes = np.concatenate([
                        classes, empty_classes
                    ]) if len_valid > 0 else empty_classes

                gt_boxesR = torch.from_numpy(
                    gt_boxesR).cuda().float()  # torch.Size([2, 3, 6])
                gt_boxesR_end = torch.reshape(gt_boxesR, [hyp.B, hyp.N, 2, 3])
                gt_boxesR_theta = nlu.get_alignedboxes2thetaformat(
                    gt_boxesR_end)  #torch.Size([2, 3, 9])
                gt_boxesR_corners = utils_geom.transform_boxes_to_corners(
                    gt_boxesR_theta)

            class_names_ex_1 = "_".join(classes[0])
            summ_writer.summ_text('eval_boxes/class_names', class_names_ex_1)

            gt_boxesRMem_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z2, Y2, X2))
            gt_boxesRMem_end = nlu.get_ends_of_corner(gt_boxesRMem_corners)

            gt_boxesRMem_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesRMem_corners)
            gt_boxesRUnp_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z, Y, X))
            gt_boxesRUnp_end = nlu.get_ends_of_corner(gt_boxesRUnp_corners)

            gt_boxesX0_corners = __ub(
                utils_geom.apply_4x4(camX0_T_camRs, __pb(gt_boxesR_corners)))
            gt_boxesX0Mem_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesX0_corners), Z2, Y2, X2))

            gt_boxesX0Mem_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesX0Mem_corners)

            gt_boxesX0Mem_end = nlu.get_ends_of_corner(gt_boxesX0Mem_corners)
            gt_boxesX0_end = nlu.get_ends_of_corner(gt_boxesX0_corners)

            gt_cornersX0_pix = __ub(
                utils_geom.apply_pix_T_cam(pix_T_cams[:, 0],
                                           __pb(gt_boxesX0_corners)))

            rgb_camX0 = rgb_camXs[:, 0]
            rgb_camX1 = rgb_camXs[:, 1]

            summ_writer.summ_box_by_corners('eval_boxes/gt_boxescamX0',
                                            rgb_camX0, gt_boxesX0_corners,
                                            torch.from_numpy(scores), tids,
                                            pix_T_cams[:, 0])
            unps_vis = utils_improc.get_unps_vis(unpX0s_half, occX0s_half)
            unp_vis = torch.mean(unps_vis, dim=1)
            unps_visRs = utils_improc.get_unps_vis(unpRs_half, occRs_half)
            unp_visRs = torch.mean(unps_visRs, dim=1)
            unps_visRs_full = utils_improc.get_unps_vis(unpRs, occRs)
            unp_visRs_full = torch.mean(unps_visRs_full, dim=1)
            summ_writer.summ_box_mem_on_unp('eval_boxes/gt_boxesR_mem',
                                            unp_visRs, gt_boxesRMem_end,
                                            scores, tids)

            unpX0s_half = torch.mean(unpX0s_half, dim=1)
            unpX0s_half = nlu.zero_out(unpX0s_half, gt_boxesX0Mem_end, scores)

            occX0s_half = torch.mean(occX0s_half, dim=1)
            occX0s_half = nlu.zero_out(occX0s_half, gt_boxesX0Mem_end, scores)

            summ_writer.summ_unp('3D_inputs/unpX0s', unpX0s_half, occX0s_half)

        if hyp.do_feat:
            featXs_input = torch.cat([occXs, occXs * unpXs], dim=2)
            featXs_input_ = __p(featXs_input)

            freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), __p(occXs_half))
            freeXs = __u(freeXs_)
            visXs = torch.clamp(occXs_half + freeXs, 0.0, 1.0)
            mask_ = None

            if (type(mask_) != type(None)):
                assert (list(mask_.shape)[2:5] == list(
                    featXs_input_.shape)[2:5])

            featXs_, feat_loss = self.featnet(featXs_input_,
                                              summ_writer,
                                              mask=__p(occXs))  #mask_)
            total_loss += feat_loss

            validXs = torch.ones_like(visXs)
            _validX00 = validXs[:, 0:1]
            _validX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                     validXs[:, 1:])
            validX0s = torch.cat([_validX00, _validX01], dim=1)
            validRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, validXs)
            visRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, visXs)

            featXs = __u(featXs_)
            _featX00 = featXs[:, 0:1]
            _featX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                    featXs[:, 1:])
            featX0s = torch.cat([_featX00, _featX01], dim=1)

            emb3D_e = torch.mean(featX0s[:, 1:], dim=1)
            vis3D_e_R = torch.max(visRs[:, 1:], dim=1)[0]
            emb3D_g = featX0s[:, 0]
            vis3D_g_R = visRs[:, 0]
            validR_combo = torch.min(validRs, dim=1).values

            summ_writer.summ_feats('3D_feats/featXs_input',
                                   torch.unbind(featXs_input, dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/featXs_output',
                                   torch.unbind(featXs, dim=1),
                                   valids=torch.unbind(validXs, dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/featX0s_output',
                                   torch.unbind(featX0s, dim=1),
                                   valids=torch.unbind(
                                       torch.ones_like(validRs), dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/validRs',
                                   torch.unbind(validRs, dim=1),
                                   pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_e_R', vis3D_e_R, pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_g_R', vis3D_g_R, pca=False)

        if hyp.do_munit:
            object_classes, filenames = nlu.create_object_classes(
                classes, [tree_seq_filename, tree_seq_filename], scores)
            if hyp.do_munit_fewshot:
                emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
                emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
                emb3D_R = emb3D_e_R
                emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                    [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                    scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
                emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2
                content, style = self.munitnet.net.gen_a.encode(emb3D_R_object)
                objects_taken, _ = self.munitnet.net.gen_a.decode(
                    content, style)
                styles = style
                contents = content
            elif hyp.do_3d_style_munit:
                emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
                emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
                emb3D_R = emb3D_e_R
                # st()
                emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                    [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                    scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
                emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2

                camX1_T_R = camXs_T_camRs[:, 1]
                camX0_T_R = camXs_T_camRs[:, 0]
                assert hyp.B == 2
                assert emb3D_e_R_object.shape[0] == 2
                munit_loss, sudo_input_0, sudo_input_1, recon_input_0, recon_input_1, sudo_input_0_cycle, sudo_input_1_cycle, styles, contents, adin = self.munitnet(
                    emb3D_R_object[0:1], emb3D_R_object[1:2])

                if hyp.store_content_style_range:
                    if self.max_content == None:
                        self.max_content = torch.zeros_like(
                            contents[0][0]).cuda() - 100000000
                    if self.min_content == None:
                        self.min_content = torch.zeros_like(
                            contents[0][0]).cuda() + 100000000
                    if self.max_style == None:
                        self.max_style = torch.zeros_like(
                            styles[0][0]).cuda() - 100000000
                    if self.min_style == None:
                        self.min_style = torch.zeros_like(
                            styles[0][0]).cuda() + 100000000
                    self.max_content = torch.max(
                        torch.max(self.max_content, contents[0][0]),
                        contents[1][0])
                    self.min_content = torch.min(
                        torch.min(self.min_content, contents[0][0]),
                        contents[1][0])
                    self.max_style = torch.max(
                        torch.max(self.max_style, styles[0][0]), styles[1][0])
                    self.min_style = torch.min(
                        torch.min(self.min_style, styles[0][0]), styles[1][0])

                    data_to_save = {
                        'max_content': self.max_content.cpu().numpy(),
                        'min_content': self.min_content.cpu().numpy(),
                        'max_style': self.max_style.cpu().numpy(),
                        'min_style': self.min_style.cpu().numpy()
                    }
                    with open('content_style_range.p', 'wb') as f:
                        pickle.dump(data_to_save, f)
                elif hyp.is_contrastive_examples:
                    if hyp.normalize_contrast:
                        content0 = (contents[0] - self.min_content) / (
                            self.max_content - self.min_content + 1e-5)
                        content1 = (contents[1] - self.min_content) / (
                            self.max_content - self.min_content + 1e-5)
                        style0 = (styles[0] - self.min_style) / (
                            self.max_style - self.min_style + 1e-5)
                        style1 = (styles[1] - self.min_style) / (
                            self.max_style - self.min_style + 1e-5)
                    else:
                        content0 = contents[0]
                        content1 = contents[1]
                        style0 = styles[0]
                        style1 = styles[1]

                    # euclid_dist_content = torch.sum(torch.sqrt((content0 - content1)**2))/torch.prod(torch.tensor(content0.shape))
                    # euclid_dist_style = torch.sum(torch.sqrt((style0-style1)**2))/torch.prod(torch.tensor(style0.shape))
                    euclid_dist_content = (content0 - content1).norm(2) / (
                        content0.numel())
                    euclid_dist_style = (style0 -
                                         style1).norm(2) / (style0.numel())

                    content_0_pooled = torch.mean(
                        content0.reshape(list(content0.shape[:2]) + [-1]),
                        dim=-1)
                    content_1_pooled = torch.mean(
                        content1.reshape(list(content1.shape[:2]) + [-1]),
                        dim=-1)

                    euclid_dist_content_pooled = (content_0_pooled -
                                                  content_1_pooled).norm(2) / (
                                                      content_0_pooled.numel())

                    content_0_normalized = content0 / content0.norm()
                    content_1_normalized = content1 / content1.norm()

                    style_0_normalized = style0 / style0.norm()
                    style_1_normalized = style1 / style1.norm()

                    content_0_pooled_normalized = content_0_pooled / content_0_pooled.norm(
                    )
                    content_1_pooled_normalized = content_1_pooled / content_1_pooled.norm(
                    )

                    cosine_dist_content = torch.sum(content_0_normalized *
                                                    content_1_normalized)
                    cosine_dist_style = torch.sum(style_0_normalized *
                                                  style_1_normalized)
                    cosine_dist_content_pooled = torch.sum(
                        content_0_pooled_normalized *
                        content_1_pooled_normalized)

                    print("euclid dist [content, pooled-content, style]: ",
                          euclid_dist_content, euclid_dist_content_pooled,
                          euclid_dist_style)
                    print("cosine sim [content, pooled-content, style]: ",
                          cosine_dist_content, cosine_dist_content_pooled,
                          cosine_dist_style)

            if hyp.run_few_shot_on_munit:
                if (global_step % 300) == 1 or (global_step % 300) == 0:
                    wrong = False
                    try:
                        precision_style = float(self.tp_style) / self.all_style
                        precision_content = float(
                            self.tp_content) / self.all_content
                    except ZeroDivisionError:
                        wrong = True

                    if not wrong:
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_style',
                            precision_style)
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_content',
                            precision_content)
                        # st()
                    self.embed_list_style = defaultdict(lambda: [])
                    self.embed_list_content = defaultdict(lambda: [])
                    self.tp_style = 0
                    self.all_style = 0
                    self.tp_content = 0
                    self.all_content = 0
                    self.check = False
                elif not self.check and not nlu.check_fill_dict(
                        self.embed_list_content, self.embed_list_style):
                    print("Filling \n")
                    for index, class_val in enumerate(object_classes):

                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split(
                                "/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]

                        print(len(self.embed_list_style.keys()), "style class",
                              len(self.embed_list_content), "content class",
                              self.embed_list_content.keys())
                        if len(self.embed_list_style[class_val_style]
                               ) < hyp.few_shot_nums:
                            self.embed_list_style[class_val_style].append(
                                styles[index].squeeze())
                        if len(self.embed_list_content[class_val_content]
                               ) < hyp.few_shot_nums:
                            if hyp.avg_3d:
                                content_val = contents[index]
                                content_val = torch.mean(content_val.reshape(
                                    [content_val.shape[1], -1]),
                                                         dim=-1)
                                # st()
                                self.embed_list_content[
                                    class_val_content].append(content_val)
                            else:
                                self.embed_list_content[
                                    class_val_content].append(
                                        contents[index].reshape([-1]))
                else:
                    self.check = True
                    try:
                        print(float(self.tp_content) / self.all_content)
                        print(float(self.tp_style) / self.all_style)
                    except Exception as e:
                        pass
                    average = True
                    if average:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)

                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)
                    else:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.stack(val,
                                                                         dim=0)

                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.stack(
                                    val, dim=0)
                    for index, class_val in enumerate(object_classes):
                        class_val = class_val
                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split(
                                "/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]

                        style_val = styles[index].squeeze().unsqueeze(0)
                        if not average:
                            embed_list_val_style = torch.cat(list(
                                self.embed_list_style.values()),
                                                             dim=0)
                            embed_list_key_style = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_style.keys()), 1),
                                    hyp.few_shot_nums, 1).reshape([-1]))
                        else:
                            embed_list_val_style = torch.stack(list(
                                self.embed_list_style.values()),
                                                               dim=0)
                            embed_list_key_style = list(
                                self.embed_list_style.keys())
                        embed_list_val_style = utils_basic.l2_normalize(
                            embed_list_val_style, dim=1).permute(1, 0)
                        style_val = utils_basic.l2_normalize(style_val, dim=1)
                        scores_styles = torch.matmul(style_val,
                                                     embed_list_val_style)
                        index_key = torch.argmax(scores_styles,
                                                 dim=1).squeeze()
                        selected_class_style = embed_list_key_style[index_key]
                        self.styles_prediction[class_val_style].append(
                            selected_class_style)
                        if class_val_style == selected_class_style:
                            self.tp_style += 1
                        self.all_style += 1

                        if hyp.avg_3d:
                            content_val = contents[index]
                            content_val = torch.mean(content_val.reshape(
                                [content_val.shape[1], -1]),
                                                     dim=-1).unsqueeze(0)
                        else:
                            content_val = contents[index].reshape(
                                [-1]).unsqueeze(0)
                        if not average:
                            embed_list_val_content = torch.cat(list(
                                self.embed_list_content.values()),
                                                               dim=0)
                            embed_list_key_content = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_content.keys()),
                                        1), hyp.few_shot_nums,
                                    1).reshape([-1]))
                        else:
                            embed_list_val_content = torch.stack(list(
                                self.embed_list_content.values()),
                                                                 dim=0)
                            embed_list_key_content = list(
                                self.embed_list_content.keys())
                        embed_list_val_content = utils_basic.l2_normalize(
                            embed_list_val_content, dim=1).permute(1, 0)
                        content_val = utils_basic.l2_normalize(content_val,
                                                               dim=1)
                        scores_content = torch.matmul(content_val,
                                                      embed_list_val_content)
                        index_key = torch.argmax(scores_content,
                                                 dim=1).squeeze()
                        selected_class_content = embed_list_key_content[
                            index_key]
                        self.content_prediction[class_val_content].append(
                            selected_class_content)
                        if class_val_content == selected_class_content:
                            self.tp_content += 1

                        self.all_content += 1
            # st()
            munit_loss = hyp.munit_loss_weight * munit_loss

            recon_input_obj = torch.cat([recon_input_0, recon_input_1], dim=0)
            recon_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, recon_input_obj, gt_boxesRMem_end, scores)

            sudo_input_obj = torch.cat([sudo_input_0, sudo_input_1], dim=0)
            styled_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, sudo_input_obj, gt_boxesRMem_end, scores)

            styled_emb3D_e_X1 = utils_vox.apply_4x4_to_vox(
                camX1_T_R, styled_emb3D_R)
            styled_emb3D_e_X0 = utils_vox.apply_4x4_to_vox(
                camX0_T_R, styled_emb3D_R)

            emb3D_e_X1 = utils_vox.apply_4x4_to_vox(camX1_T_R, recon_emb3D_R)
            emb3D_e_X0 = utils_vox.apply_4x4_to_vox(camX0_T_R, recon_emb3D_R)

            emb3D_e_X1_og = utils_vox.apply_4x4_to_vox(camX1_T_R, emb3D_R)
            emb3D_e_X0_og = utils_vox.apply_4x4_to_vox(camX0_T_R, emb3D_R)

            emb3D_R_aug_diff = torch.abs(emb3D_R - recon_emb3D_R)

            summ_writer.summ_feat(f'aug_feat/og', emb3D_R)
            summ_writer.summ_feat(f'aug_feat/og_gen', recon_emb3D_R)
            summ_writer.summ_feat(f'aug_feat/og_aug_diff', emb3D_R_aug_diff)

            if hyp.cycle_style_view_loss:
                sudo_input_obj_cycle = torch.cat(
                    [sudo_input_0_cycle, sudo_input_1_cycle], dim=0)
                styled_emb3D_R_cycle = nlu.update_scene_with_objects(
                    emb3D_R, sudo_input_obj_cycle, gt_boxesRMem_end, scores)

                styled_emb3D_e_X0_cycle = utils_vox.apply_4x4_to_vox(
                    camX0_T_R, styled_emb3D_R_cycle)
                styled_emb3D_e_X1_cycle = utils_vox.apply_4x4_to_vox(
                    camX1_T_R, styled_emb3D_R_cycle)
            summ_writer.summ_scalar('munit_loss', munit_loss.cpu().item())
            total_loss += munit_loss

        if hyp.do_occ and hyp.occ_do_cheap:
            occX0_sup, freeX0_sup, _, freeXs = utils_vox.prep_occs_supervision(
                camX0_T_camXs, xyz_camXs, Z2, Y2, X2, agg=True)

            summ_writer.summ_occ('occ_sup/occ_sup', occX0_sup)
            summ_writer.summ_occ('occ_sup/free_sup', freeX0_sup)
            summ_writer.summ_occs('occ_sup/freeXs_sup',
                                  torch.unbind(freeXs, dim=1))
            summ_writer.summ_occs('occ_sup/occXs_sup',
                                  torch.unbind(occXs_half, dim=1))

            occ_loss, occX0s_pred_ = self.occnet(
                torch.mean(featX0s[:, 1:], dim=1), occX0_sup, freeX0_sup,
                torch.max(validX0s[:, 1:], dim=1)[0], summ_writer)
            occX0s_pred = __u(occX0s_pred_)
            total_loss += occ_loss

        if hyp.do_view:
            assert (hyp.do_feat)
            PH, PW = hyp.PH, hyp.PW
            sy = float(PH) / float(hyp.H)
            sx = float(PW) / float(hyp.W)
            assert (sx == 0.5)  # else we need a fancier downsampler
            assert (sy == 0.5)
            projpix_T_cams = __u(
                utils_geom.scale_intrinsics(__p(pix_T_cams), sx, sy))
            # st()

            if hyp.do_munit:
                feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    emb3D_e_X1,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                feat_projX00_og = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    emb3D_e_X1_og,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                # only for checking the style
                styled_feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    styled_emb3D_e_X1,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                if hyp.cycle_style_view_loss:
                    styled_feat_projX00_cycle = utils_vox.apply_pixX_T_memR_to_voxR(
                        projpix_T_cams[:, 0],
                        camX0_T_camXs[:, 1],
                        styled_emb3D_e_X1_cycle,  # use feat1 to predict rgb0
                        hyp.view_depth,
                        PH,
                        PW)

            else:
                feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    featXs[:, 1],  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)
            rgb_X00 = utils_basic.downsample(rgb_camXs[:, 0], 2)
            rgb_X01 = utils_basic.downsample(rgb_camXs[:, 1], 2)
            valid_X00 = utils_basic.downsample(valid_camXs[:, 0], 2)

            view_loss, rgb_e, emb2D_e = self.viewnet(feat_projX00, rgb_X00,
                                                     valid_X00, summ_writer,
                                                     "rgb")

            if hyp.do_munit:
                _, rgb_e, emb2D_e = self.viewnet(feat_projX00_og, rgb_X00,
                                                 valid_X00, summ_writer,
                                                 "rgb_og")
            if hyp.do_munit:
                styled_view_loss, styled_rgb_e, styled_emb2D_e = self.viewnet(
                    styled_feat_projX00, rgb_X00, valid_X00, summ_writer,
                    "recon_style")
                if hyp.cycle_style_view_loss:
                    styled_view_loss_cycle, styled_rgb_e_cycle, styled_emb2D_e_cycle = self.viewnet(
                        styled_feat_projX00_cycle, rgb_X00, valid_X00,
                        summ_writer, "recon_style_cycle")

                rgb_input_1 = torch.cat(
                    [rgb_X01[1], rgb_X01[0], styled_rgb_e[0]], dim=2)
                rgb_input_2 = torch.cat(
                    [rgb_X01[0], rgb_X01[1], styled_rgb_e[1]], dim=2)
                complete_vis = torch.cat([rgb_input_1, rgb_input_2], dim=1)
                summ_writer.summ_rgb('munit/munit_recons_vis',
                                     complete_vis.unsqueeze(0))

            if not hyp.do_munit:
                total_loss += view_loss
            else:
                if hyp.basic_view_loss:
                    total_loss += view_loss
                if hyp.style_view_loss:
                    total_loss += styled_view_loss
                if hyp.cycle_style_view_loss:
                    total_loss += styled_view_loss_cycle

        summ_writer.summ_scalar('loss', total_loss.cpu().item())

        if hyp.save_embed_tsne:
            for index, class_val in enumerate(object_classes):
                class_val_content, class_val_style = class_val.split("/")
                style_val = styles[index].squeeze().unsqueeze(0)
                self.cluster_pool.update(style_val, [class_val_style])
                print(self.cluster_pool.num)

            if self.cluster_pool.is_full():
                embeds, classes = self.cluster_pool.fetch()
                with open("offline_cluster" + '/%st.txt' % 'classes',
                          'w') as f:
                    for index, embed in enumerate(classes):
                        class_val = classes[index]
                        f.write("%s\n" % class_val)
                f.close()
                with open("offline_cluster" + '/%st.txt' % 'embeddings',
                          'w') as f:
                    for index, embed in enumerate(embeds):
                        # embed = utils_basic.l2_normalize(embed,dim=0)
                        print("writing {} embed".format(index))
                        embed_l_s = [str(i) for i in embed.tolist()]
                        embed_str = '\t'.join(embed_l_s)
                        f.write("%s\n" % embed_str)
                f.close()
                st()

        return total_loss, results
def get_synth_flow_v2(xyz_cam0,
                      occ0,
                      unp0,
                      summ_writer,
                      sometimes_zero=False,
                      do_vis=False):
    """Synthesize a second frame plus the dense 3D flow relating it to frame 0.

    A rotation-free random rigid motion ``cam1_T_cam0`` is picked from two
    candidates (a large one with ``t_amount=3.0`` and a small one with
    ``t_amount=0.1``). The point cloud is moved by that motion and
    re-voxelized to obtain ``occ1`` (this version re-voxelizes instead of
    warping ``occ0``), while ``unp0`` is warped into ``unp1``. The flow is
    the displacement of every memory-grid coordinate under the same motion
    expressed in memory coordinates.

    Args:
        xyz_cam0: point cloud handed to ``utils_geom.apply_4x4``
            (presumably B x N x 3 in cam0 coordinates -- TODO confirm).
        occ0: frame-0 occupancy grid; only visualized and stacked here.
        unp0: tensor of shape B x 3 x Z x Y x X (the channel count is
            asserted to be 3).
        summ_writer: summary writer; only touched when ``do_vis`` is True.
        sometimes_zero: forwarded to the small-motion sampler only.
        do_vis: when True, log occupancies, unprojections, the flow and the
            flow-backwarped ("stabilized") frame-1 grids.

    Returns:
        Tuple ``(occs, unps, flow, cam1_T_cam0)`` where ``occs`` and ``unps``
        are ``[x0, x1]`` stacked along dim 1, ``flow`` is
        B x 3 x Z x Y x X, and ``cam1_T_cam0`` is the sampled transform.
    """
    B, C, Z, Y, X = list(unp0.shape)
    assert (C == 3)

    # No rotations are sampled, so the distribution stays purely uniform
    # across translations (rotation would ruin this, since the pivot point
    # is at the camera).
    large_motion = utils_geom.get_random_rt(B, r_amount=0.0, t_amount=3.0)
    small_motion = utils_geom.get_random_rt(B,
                                            r_amount=0.0,
                                            t_amount=0.1,
                                            sometimes_zero=sometimes_zero)
    cam1_T_cam0 = random.sample([large_motion, small_motion], k=1)[0]

    # Build frame 1: move the points and re-voxelize, warp the unprojection.
    xyz_cam1 = utils_geom.apply_4x4(cam1_T_cam0, xyz_cam0)
    occ1 = utils_vox.voxelize_xyz(xyz_cam1, Z, Y, X)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occ_pair = [occ0, occ1]
    unp_pair = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occ_pair)
        summ_writer.summ_unps('synth/unps', unp_pair, occ_pair)

    # Express the camera motion in memory (voxel-grid) coordinates.
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)

    # Flow = where each grid coordinate lands minus where it started.
    grid_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    grid_mem1 = utils_geom.apply_4x4(mem1_T_mem0, grid_mem0)
    grid_shape = (B, Z, Y, X, 3)
    displacement = grid_mem1.reshape(*grid_shape) - grid_mem0.reshape(
        *grid_shape)
    # B x Z x Y x X x 3  ->  B x 3 x Z x Y x X
    flow = displacement.permute(0, 4, 1, 2, 3)

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

        # Backwarp frame 1 with the flow; it should line up with frame 0.
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1,
                                                   flow,
                                                   binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occ_pair, dim=1)
    unps = torch.stack(unp_pair, dim=1)

    return occs, unps, flow, cam1_T_cam0