Example 1
def pose_estimation_loss(ori_vertices, y_true, y_pred):
    """Pose estimation loss used for training.

    This loss measures the average squared distance between vertices of the
    mesh in rest pose and the same vertices rotated by the predicted
    quaternion and then by the inverse of the ground-truth quaternion.
    Comparing this loss with a regular L2 loss on the quaternion values is
    left as an exercise to the interested reader.

    Args:
      ori_vertices: Vertices of the mesh in rest pose. [num_vertices, 3]
      y_true: The ground-truth quaternion. [batch, 4]
      y_pred: The predicted quaternion to evaluate the loss for. [batch, 4]

    Returns:
      A scalar tensor containing the loss described above.
    """

    # ori_vertices.shape: (num_vertices, 3)
    # corners.shape: (num_vertices, 1, 3)
    corners = tf.expand_dims(ori_vertices, axis=1)

    # transformed_corners.shape: (num_vertices, batch, 3)
    # The quaternion shape gets pre-padded with 1's following standard
    # broadcasting rules.
    transformed_corners = quaternion.rotate(corners, y_pred)

    # recovered_corners.shape: (num_vertices, batch, 3)
    recovered_corners = quaternion.rotate(transformed_corners,
                                          quaternion.inverse(y_true))

    # vertex_error.shape: (num_vertices, batch)
    vertex_error = tf.reduce_sum((recovered_corners - corners)**2, axis=-1)

    return tf.reduce_mean(vertex_error)
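A minimal sketch of calling this loss in eager mode, assuming the usual TensorFlow Graphics import; the vertex cloud and quaternions below are made-up illustration data:

import numpy as np
import tensorflow as tf
from tensorflow_graphics.geometry.transformation import quaternion

vertices = np.random.uniform(-1.0, 1.0, (100, 3)).astype(np.float32)
q_true = quaternion.from_euler(np.zeros((8, 3), dtype=np.float32))  # identity rotations
q_pred = quaternion.from_euler(
    np.random.uniform(-0.1, 0.1, (8, 3)).astype(np.float32))
loss = pose_estimation_loss(vertices, q_true, q_pred)  # scalar; 0 when q_pred == q_true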
Example 2
def loss_xy_z_q_separate(keypoints_3D_model, labels, init_p_q, net_out,
                         im_center, image_scaling, cam_mat):

    pos_init = init_p_q['pos']

    xy_image_delta = net_out['xy']

    xy_image = im_center + xy_image_delta / tf.expand_dims(image_scaling, axis=-1)

    z_model = pos_init[:, 2]

    # back-project the predicted image location into the camera frame at the
    # initial depth: x = (u - cx) / fx * z, y = (v - cy) / fy * z
    x_model_cam_frame = (xy_image[:, 0] - cam_mat[:, 0, 2]) / cam_mat[:, 0, 0]
    x_model_cam_frame = tf.expand_dims(tf.multiply(x_model_cam_frame, z_model), axis=-1)
    y_model_cam_frame = (xy_image[:, 1] - cam_mat[:, 1, 2]) / cam_mat[:, 1, 1]
    y_model_cam_frame = tf.expand_dims(tf.multiply(y_model_cam_frame, z_model), axis=-1)

    z_model = tf.expand_dims(z_model, axis=-1)
    pose_model_lateral = tf.concat((x_model_cam_frame, y_model_cam_frame, z_model), axis=-1)

    # the network predicts a scaled log-depth correction; exp() turns it into
    # a multiplicative update that moves the position along the viewing ray
    z_out = net_out['z'] / 3000.0
    pos_out = tf.multiply(tf.exp(z_out), pose_model_lateral)

    keypoints_3D = quaternion.rotate(keypoints_3D_model, tf.expand_dims(labels['quat'], axis=-2))

    quat_init = tf.expand_dims(init_p_q['q'], axis=1)

    # normalize quaternion to unit length
    quat = net_out['q'] / 10.0
    quat_len = tf.expand_dims(tf.norm(quat, axis=-1), axis=-1)
    quat = tf.div(quat, quat_len)

    keypoints_3D_predicted = quaternion.rotate(keypoints_3D_model, quat_init)
    keypoints_3D_predicted = quaternion.rotate(keypoints_3D_predicted, tf.expand_dims(quat, axis=1))

    # add the translations to the keypoints
    keypoints_3D = keypoints_3D + tf.expand_dims(labels["pos"], axis=1)
    keypoints_3D_predicted = keypoints_3D_predicted + tf.expand_dims(pos_out, axis=1)

    keypoint_distance = keypoints_3D_predicted - keypoints_3D
    keypoint_distance = tf.norm(keypoint_distance, axis=-1)
    loss_keypoint = tf.reduce_mean(0.5 * tf.square(keypoint_distance), axis=-1)

    loss = 0.5 * loss_keypoint

    loss = tf.reduce_mean(loss)

    outputs = {
        'xy': net_out['xy'],
        'z': z_out,
        'q': quat,
    }

    return loss, outputs
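For reference, a sketch of the input shapes this function expects (batch b, k model keypoints). All values below are made-up, and a TF1-style runtime is assumed since the function uses tf.div:

import numpy as np
import tensorflow as tf

b, k = 2, 8
kp_model = np.random.uniform(-1.0, 1.0, (k, 3)).astype(np.float32)
identity_q = np.tile([0., 0., 0., 1.], (b, 1)).astype(np.float32)
labels = {'pos': np.zeros((b, 3), np.float32), 'quat': identity_q}
init_p_q = {'pos': np.tile([0., 0., 1000.], (b, 1)).astype(np.float32),
            'q': identity_q}
net_out = {'xy': np.zeros((b, 2), np.float32),  # predicted image-space offset
           'z': np.zeros((b, 1), np.float32),   # predicted depth correction
           'q': 10.0 * identity_q}              # divided by 10 inside
im_center = np.zeros((b, 2), np.float32)
image_scaling = np.ones((b,), np.float32)
cam_mat = np.tile(np.diag([500., 500., 1.]).astype(np.float32)[None], (b, 1, 1))

loss, outputs = loss_xy_z_q_separate(kp_model, labels, init_p_q, net_out,
                                     im_center, image_scaling, cam_mat)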
Example 3
def stage_loss_keypoints(pos_out, quat_out, pos_gt, quat_gt,
                         keypoints_3D_model):

    keypoints_3D = quaternion.rotate(keypoints_3D_model,
                                     tf.expand_dims(quat_gt, axis=-2))
    keypoints_3D_predicted = quaternion.rotate(
        keypoints_3D_model, tf.expand_dims(quat_out, axis=1))

    keypoints_3D = keypoints_3D + tf.expand_dims(pos_gt, axis=1)
    keypoints_3D_predicted = keypoints_3D_predicted + tf.expand_dims(pos_out,
                                                                     axis=1)

    keypoint_distance = keypoints_3D_predicted - keypoints_3D
    keypoint_distance = tf.norm(keypoint_distance, axis=-1)
    loss_keypoint = tf.reduce_mean(0.5 * tf.square(keypoint_distance), axis=-1)

    return loss_keypoint
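A minimal usage sketch with made-up data; with identical predictions and ground truth the per-sample loss is zero, and the returned tensor has shape [batch]:

import numpy as np
import tensorflow as tf

batch, k = 4, 8
kp_model = np.random.uniform(-1.0, 1.0, (k, 3)).astype(np.float32)
identity_q = np.tile([0., 0., 0., 1.], (batch, 1)).astype(np.float32)
pos = np.zeros((batch, 3), np.float32)

loss = stage_loss_keypoints(pos, identity_q, pos, identity_q, kp_model)
# loss: zeros of shape [batch]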
Example 4
def rotate(vertices, features, num_samples, rot_range, angle_fixed):
    '''Randomly rotates a mesh and its per-vertex feature points.

    :param vertices: [num_vertices, 3]
    :param features: [FEAT_CAP, 4]; column 0 is a validity mask, columns 1:4
        are 3D feature points.
    :param num_samples: number of random rotations to draw.
    :param rot_range: per-axis rotation ranges [3]; angles are drawn
        uniformly from [-rot_range[i], rot_range[i]].
    :param angle_fixed: if True, no rotation is applied and the inputs are
        returned with a leading batch axis of size 1.
    :return: rotated vertices [num_samples, num_vertices, 3], rotated
        features [num_samples, FEAT_CAP, 3], and the mask [FEAT_CAP].
    '''
    vertices = vertices.astype(np.float32)
    # [FEAT_CAP]
    mask = features[:, 0].astype(bool)
    # [FEAT_CAP, 3]
    features = features[:, 1:].astype(np.float32)
    if angle_fixed:
        vertices = vertices[np.newaxis, :, :]
        features = features[np.newaxis, :, :]
        return vertices, features, mask

    random_angles_x = np.random.uniform(-rot_range[0], rot_range[0],
                                        num_samples).astype(np.float32)
    random_angles_y = np.random.uniform(-rot_range[1], rot_range[1],
                                        num_samples).astype(np.float32)
    random_angles_z = np.random.uniform(-rot_range[2], rot_range[2],
                                        num_samples).astype(np.float32)

    # random_angles.shape: (num_samples, 3)
    random_angles = np.stack(
        [random_angles_x, random_angles_y, random_angles_z], axis=1)
    # debug helper (unused): evenly spaced angles about the x axis
    regular_angles = np.concatenate([
        np.linspace(-np.pi, np.pi, num_samples)[:, np.newaxis],
        np.zeros((num_samples, 2))
    ], axis=-1).astype(np.float32)

    # random_quaternion.shape: (num_samples, 4)
    random_quaternion = quaternion.from_euler(random_angles)

    # vertices.shape : (num_samples, num_vertices, 3)
    vertices = quaternion.rotate(vertices[tf.newaxis, :, :],
                                 random_quaternion[:, tf.newaxis, :])
    # features.shape : (num_samples, FEAT_CAP, 3)
    features = quaternion.rotate(features[tf.newaxis, :, :],
                                 random_quaternion[:, tf.newaxis, :])

    return np.array(vertices), np.array(features), mask
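A sketch of calling this augmentation helper in eager mode (the TF tensors are converted back to NumPy at the end); FEAT_CAP and the input data are made-up:

import numpy as np
import tensorflow as tf
from tensorflow_graphics.geometry.transformation import quaternion

FEAT_CAP = 16
verts = np.random.uniform(-1.0, 1.0, (100, 3))
feats = np.concatenate([np.ones((FEAT_CAP, 1)),                    # validity mask
                        np.random.uniform(-1.0, 1.0, (FEAT_CAP, 3))], axis=-1)

v_rot, f_rot, mask = rotate(verts, feats, num_samples=8,
                            rot_range=[np.pi, np.pi, np.pi], angle_fixed=False)
# v_rot: (8, 100, 3), f_rot: (8, FEAT_CAP, 3), mask: (FEAT_CAP,)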
Example 5
  def test_rotate_random(self):
    """Tests that the rotate provide the same results as quaternion.rotate."""
    random_axis, random_angle = test_helpers.generate_random_test_axis_angle()
    tensor_shape = random_angle.shape[:-1]
    random_point = np.random.normal(size=tensor_shape + (3,))

    random_quaternion = quaternion.from_axis_angle(random_axis, random_angle)
    ground_truth = quaternion.rotate(random_point, random_quaternion)
    prediction = axis_angle.rotate(random_point, random_axis, random_angle)

    self.assertAllClose(ground_truth, prediction, rtol=1e-6)
Example 6
    def test_rotate_vs_rotate_quaternion_random(self):
        """Tests that rotation_matrix_3d.rotate provides the same results as quaternion.rotate."""
        random_euler_angle = test_helpers.generate_random_test_euler_angles()
        tensor_tile = random_euler_angle.shape[:-1]

        random_matrix = rotation_matrix_3d.from_euler(random_euler_angle)
        random_quaternion = quaternion.from_rotation_matrix(random_matrix)
        random_point = np.random.normal(size=tensor_tile + (3, ))
        ground_truth = quaternion.rotate(random_point, random_quaternion)
        prediction = rotation_matrix_3d.rotate(random_point, random_matrix)

        self.assertAllClose(ground_truth, prediction, rtol=1e-6)
Example 7
  def test_from_euler_random(self):
    """Tests that quaternions can be constructed from Euler angles."""
    random_euler_angles = test_helpers.generate_random_test_euler_angles()
    tensor_shape = random_euler_angles.shape[:-1]

    random_matrix = rotation_matrix_3d.from_euler(random_euler_angles)
    random_quaternion = quaternion.from_euler(random_euler_angles)
    random_point = np.random.normal(size=tensor_shape + (3,))
    rotated_with_matrix = rotation_matrix_3d.rotate(random_point, random_matrix)
    rotated_with_quaternion = quaternion.rotate(random_point, random_quaternion)

    self.assertAllClose(rotated_with_matrix, rotated_with_quaternion)
Example 8
  def test_rotate_jacobian_random(self):
    """Test the Jacobian of the rotate function."""
    x_matrix_init = test_helpers.generate_random_test_quaternions()
    x_matrix = tf.convert_to_tensor(value=x_matrix_init)
    tensor_shape = x_matrix_init.shape[:-1] + (3,)
    x_point_init = np.random.uniform(size=tensor_shape)
    x_point = tf.convert_to_tensor(value=x_point_init)

    y = quaternion.rotate(x_point, x_matrix)

    self.assert_jacobian_is_correct(x_matrix, x_matrix_init, y)
    self.assert_jacobian_is_correct(x_point, x_point_init, y)
Example 9
  def test_rotate_random(self):
    """Tests the rotation using a quaternion vs a rotation matrix."""
    random_quaternion = test_helpers.generate_random_test_quaternions()
    tensor_shape = random_quaternion.shape[:-1]
    random_point = np.random.normal(size=tensor_shape + (3,))

    rotated_point_quaternion = quaternion.rotate(random_point,
                                                 random_quaternion)
    matrix = rotation_matrix_3d.from_quaternion(random_quaternion)
    rotated_point_matrix = rotation_matrix_3d.rotate(random_point, matrix)

    self.assertAllClose(
        rotated_point_matrix, rotated_point_quaternion, rtol=1e-3)
Example 10
    def test_between_two_vectors_3d_random(self):
        """Checks the extracted rotation between two 3d vectors."""
        tensor_size = np.random.randint(3)
        tensor_shape = np.random.randint(1, 10, size=(tensor_size)).tolist()
        source = np.random.random(tensor_shape + [3]).astype(np.float32)
        target = np.random.random(tensor_shape + [3]).astype(np.float32)

        rotation = quaternion.between_two_vectors_3d(source, target)
        rec_target = quaternion.rotate(source, rotation)

        self.assertAllClose(tf.nn.l2_normalize(target, axis=-1),
                            tf.nn.l2_normalize(rec_target, axis=-1))
        # Checks that resulting quaternions are normalized.
        self.assertAllEqual(quaternion.is_normalized(rotation),
                            np.full(tensor_shape + [1], True))
Example 11
def read_meshes(mesh_folder):

    mesh_vertices_list = []

    for it_mesh_file in linemod_cls_names:

        it_mesh_file_full = os.path.join(mesh_folder, it_mesh_file + '_color_4000.ply')
        vertices_model = read_ply_file(it_mesh_file_full)

        vertices_model = tf.convert_to_tensor(vertices_model, dtype=tf.float32)
        init_rot_quat = tf.cast(quaternion.from_rotation_matrix(R_init), tf.float32)
        vertices_model = quaternion.rotate(vertices_model, init_rot_quat)

        vertices_model = tf.expand_dims(vertices_model, axis=0)

        mesh_vertices_list.append(vertices_model)

    return tf.concat(mesh_vertices_list, axis=0)
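This helper leans on module-level globals that are not shown in the snippet. A sketch of the assumed definitions, with hypothetical values for illustration only:

import os
import numpy as np
import tensorflow as tf
from tensorflow_graphics.geometry.transformation import quaternion

linemod_cls_names = ['ape', 'cam']        # hypothetical LINEMOD class names
R_init = np.eye(3, dtype=np.float32)      # hypothetical initial model rotation

def read_ply_file(path):
    # stub for illustration; a real implementation would parse the PLY file
    return np.random.uniform(-1.0, 1.0, (4000, 3)).astype(np.float32)

meshes = read_meshes('/path/to/meshes')   # -> tensor of shape [2, 4000, 3]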
Example 12
def rotate(vertices, num_samples):
    vertices = vertices.astype(np.float32)
    # random_angles.shape: (num_samples, 3)
    random_angles = np.random.uniform(-np.pi, np.pi,
                                      (num_samples, 3)).astype(np.float32)

    # debug helper (unused): evenly spaced angles about the x axis
    regular_angles = np.concatenate([
        np.linspace(-np.pi, np.pi, num_samples)[:, np.newaxis],
        np.zeros((num_samples, 2))
    ], axis=-1).astype(np.float32)

    # random_quaternion.shape: (num_samples, 4)
    random_quaternion = quaternion.from_euler(random_angles)

    # data.shape : (num_samples, num_vertices, 3)
    data = quaternion.rotate(vertices[tf.newaxis, :, :],
                             random_quaternion[:, tf.newaxis, :])

    return np.array(data), np.array(random_quaternion)
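A minimal usage sketch (eager mode assumed): the helper returns rotated copies of the point cloud together with the ground-truth quaternions, which makes it handy for generating supervised rotation-estimation data:

import numpy as np
import tensorflow as tf
from tensorflow_graphics.geometry.transformation import quaternion

verts = np.random.uniform(-1.0, 1.0, (100, 3))
data, quats = rotate(verts, num_samples=32)
# data: (32, 100, 3) rotated point clouds, quats: (32, 4) ground-truth rotations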
Example 13
def energy(vertices_rest_pose,
           vertices_deformed_pose,
           quaternions,
           edges,
           vertex_weight=None,
           edge_weight=None,
           conformal_energy=True,
           aggregate_loss=True,
           name=None):
    """Estimates an As Conformal As Possible (ACAP) fitting energy.

  For a given mesh in rest pose, this function evaluates a variant of the ACAP
  [1] fitting energy for a batch of deformed meshes. The vertex weights and edge
  weights are defined on the rest pose.

  The method implemented here is similar to [2], but with an added free
  variable capturing a scale factor per vertex.

  [1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro.
  "As-Conformal-As-Possible Surface Registration." Computer Graphics Forum.
  Vol. 33. No. 5. 2014.
  [2]: Olga Sorkine and Marc Alexa.
  "As-rigid-as-possible surface modeling". Symposium on Geometry Processing.
  Vol. 4. 2007.

  Note:
    In the description of the arguments, V corresponds to
      the number of vertices in the mesh, and E to the number of edges in this
      mesh.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    vertices_rest_pose: A tensor of shape `[V, 3]` containing the position of
      all the vertices of the mesh in rest pose.
    vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]` containing
      the position of all the vertices of the mesh in deformed pose.
    quaternions: A tensor of shape `[A1, ..., An, V, 4]` defining a rigid
      transformation to apply to each vertex of the rest pose. See Section 2
      from [1] for further details.
    edges: A tensor of shape `[E, 2]` defining indices of vertices that are
      connected by an edge.
    vertex_weight: An optional tensor of shape `[V]` defining the weight
      associated with each vertex. Defaults to a tensor of ones.
    edge_weight: A tensor of shape `[E]` defining the weight of edges. Common
      choices for these weights include uniform weighting, and cotangent
      weights. Defaults to a tensor of ones.
    conformal_energy: A `bool` indicating whether each vertex is associated with
      a scale factor or not. If this parameter is True, scaling information must
      be encoded in the norm of `quaternions`. If this parameter is False, this
      function implements the energy described in [2].
    aggregate_loss: A `bool` defining whether the returned loss should be an
      aggregate measure. When True, the mean squared error is returned. When
      False, returns two losses for every edge of the mesh.
    name: A name for this op. Defaults to "as_conformal_as_possible_energy".

  Returns:
    When aggregate_loss is `True`, returns a tensor of shape `[A1, ..., An]`
    containing the ACAP energies. When aggregate_loss is `False`, returns a
    tensor of shape `[A1, ..., An, 2*E]` containing each term of the summation
    described in the equation 7 of [2].

  Raises:
    ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`,
    `quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported.
  """
    with tf.compat.v1.name_scope(name, "as_conformal_as_possible_energy", [
            vertices_rest_pose, vertices_deformed_pose, quaternions, edges,
            conformal_energy, vertex_weight, edge_weight
    ]):
        vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose)
        vertices_deformed_pose = tf.convert_to_tensor(
            value=vertices_deformed_pose)
        quaternions = tf.convert_to_tensor(value=quaternions)
        edges = tf.convert_to_tensor(value=edges)
        if vertex_weight is not None:
            vertex_weight = tf.convert_to_tensor(value=vertex_weight)
        if edge_weight is not None:
            edge_weight = tf.convert_to_tensor(value=edge_weight)

        shape.check_static(tensor=vertices_rest_pose,
                           tensor_name="vertices_rest_pose",
                           has_rank=2,
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=vertices_deformed_pose,
                           tensor_name="vertices_deformed_pose",
                           has_rank_greater_than=1,
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=quaternions,
                           tensor_name="quaternions",
                           has_rank_greater_than=1,
                           has_dim_equals=(-1, 4))
        shape.compare_batch_dimensions(tensors=(vertices_deformed_pose,
                                                quaternions),
                                       last_axes=(-3, -3),
                                       broadcast_compatible=False)
        shape.check_static(tensor=edges,
                           tensor_name="edges",
                           has_rank=2,
                           has_dim_equals=(-1, 2))
        tensors_with_vertices = [
            vertices_rest_pose, vertices_deformed_pose, quaternions
        ]
        names_with_vertices = [
            "vertices_rest_pose", "vertices_deformed_pose", "quaternions"
        ]
        axes_with_vertices = [-2, -2, -2]
        if vertex_weight is not None:
            shape.check_static(tensor=vertex_weight,
                               tensor_name="vertex_weight",
                               has_rank=1)
            tensors_with_vertices.append(vertex_weight)
            names_with_vertices.append("vertex_weight")
            axes_with_vertices.append(0)
        shape.compare_dimensions(tensors=tensors_with_vertices,
                                 axes=axes_with_vertices,
                                 tensor_names=names_with_vertices)
        if edge_weight is not None:
            shape.check_static(tensor=edge_weight,
                               tensor_name="edge_weight",
                               has_rank=1)
            shape.compare_dimensions(tensors=(edges, edge_weight),
                                     axes=(0, 0),
                                     tensor_names=("edges", "edge_weight"))

        if not conformal_energy:
            quaternions = quaternion.normalize(quaternions)
        # Extracts the indices of vertices.
        indices_i, indices_j = tf.unstack(edges, axis=-1)
        # Extracts the vertices we need per term.
        vertices_i_rest = tf.gather(vertices_rest_pose, indices_i, axis=-2)
        vertices_j_rest = tf.gather(vertices_rest_pose, indices_j, axis=-2)
        vertices_i_deformed = tf.gather(vertices_deformed_pose,
                                        indices_i,
                                        axis=-2)
        vertices_j_deformed = tf.gather(vertices_deformed_pose,
                                        indices_j,
                                        axis=-2)
        # Extracts the weights we need per term.
        weights_shape = vertices_i_rest.shape.as_list()[-2]
        if vertex_weight is not None:
            weight_i = tf.gather(vertex_weight, indices_i)
            weight_j = tf.gather(vertex_weight, indices_j)
        else:
            weight_i = weight_j = tf.ones(weights_shape,
                                          dtype=vertices_rest_pose.dtype)
        weight_i = tf.expand_dims(weight_i, axis=-1)
        weight_j = tf.expand_dims(weight_j, axis=-1)
        if edge_weight is not None:
            weight_ij = edge_weight
        else:
            weight_ij = tf.ones(weights_shape, dtype=vertices_rest_pose.dtype)
        weight_ij = tf.expand_dims(weight_ij, axis=-1)
        # Extracts the rotation we need per term.
        quaternion_i = tf.gather(quaternions, indices_i, axis=-2)
        quaternion_j = tf.gather(quaternions, indices_j, axis=-2)
        # Computes the energy.
        deformed_ij = vertices_i_deformed - vertices_j_deformed
        rotated_rest_ij = quaternion.rotate(
            (vertices_i_rest - vertices_j_rest), quaternion_i)
        energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij)
        deformed_ji = vertices_j_deformed - vertices_i_deformed
        rotated_rest_ji = quaternion.rotate(
            (vertices_j_rest - vertices_i_rest), quaternion_j)
        energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji)
        energy_ij_squared = vector.dot(energy_ij, energy_ij, keepdims=False)
        energy_ji_squared = vector.dot(energy_ji, energy_ji, keepdims=False)
        if aggregate_loss:
            average_energy_ij = tf.reduce_mean(input_tensor=energy_ij_squared,
                                               axis=-1)
            average_energy_ji = tf.reduce_mean(input_tensor=energy_ji_squared,
                                               axis=-1)
            return (average_energy_ij + average_energy_ji) / 2.0
        return tf.concat((energy_ij_squared, energy_ji_squared), axis=-1)
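A minimal sketch exercising the energy on a single triangle; with the deformed pose equal to the rest pose and identity quaternions the energy is zero. Only standard TensorFlow is needed on top of the modules the snippet already imports:

import tensorflow as tf

rest = tf.constant([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
deformed = rest                                       # no deformation
identity_quats = tf.constant([[0., 0., 0., 1.]] * 3)  # one unit quaternion per vertex
edges = tf.constant([[0, 1], [1, 2], [2, 0]])

acap = energy(rest, deformed, identity_quats, edges, conformal_energy=False)
# acap == 0.0 (a scalar, since there are no batch dimensions)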
Example 14
        def parser(tfrecord_file):

            features = {
                "img": tf.FixedLenFeature([], tf.string),
                "cls_indexes": tf.VarLenFeature(dtype=tf.int64),
                "obj_num": tf.FixedLenFeature([1], tf.int64),
                "pos": tf.VarLenFeature(dtype=tf.float32),
                "quat": tf.VarLenFeature(dtype=tf.float32),
                "init_pose": tf.FixedLenFeature([3], tf.float32),
                "init_quat": tf.FixedLenFeature([4], tf.float32),
                "K_init_all": tf.FixedLenFeature([13 * 3 * 3], tf.float32),
            }

            fs = tf.parse_single_example(tfrecord_file, features=features)

            cls_indexes = tf.sparse.to_dense(fs["cls_indexes"])
            obj_num = fs["obj_num"]
            pos_all = tf.reshape(tf.sparse.to_dense(fs["pos"]), [-1, 3])
            quat_all = tf.reshape(tf.sparse.to_dense(fs["quat"]), [-1, 4])

            # pick one of the objects present in the image uniformly at random
            rand_obj_ind = tf.random_uniform([1], minval=0.0, maxval=tf.cast(obj_num, dtype=tf.float32))
            rand_obj_ind = tf.cast(tf.floor(rand_obj_ind), dtype=tf.int32)
            rand_obj_ind = cls_indexes[rand_obj_ind[0]]

            # locate the selected class id within cls_indexes to index its pose
            cls_indexes_one_hot = tf.one_hot(cls_indexes, len(linemod_cls_names))
            cls_indexes_one_hot_obj = cls_indexes_one_hot[:, rand_obj_ind]

            cls_indexes_ind = tf.argmax(cls_indexes_one_hot_obj)

            quat = quat_all[cls_indexes_ind, :]
            pos = pos_all[cls_indexes_ind, :] * scaling_2_mm

            image_decoded = tf.image.decode_png(fs["img"], channels=3)

            vertices_model = meshes_list[rand_obj_ind]

            # init pose
            pos_init = fs["init_pose"] * scaling_2_mm

            quat_init = fs["init_quat"]

            K_init_all = tf.reshape(fs["K_init_all"], [-1, 3, 3])
            camera_intrinsics_tensor = K_init_all[cls_indexes_ind, :]

            # computing the image center by projecting the initial 3D position
            im_center = tf.matmul(tf.expand_dims(pos_init, axis=0), tf.transpose(camera_intrinsics_tensor))
            im_center = tf.squeeze(tf.div(im_center, im_center[:, -1]))

            # computing crop area
            transformed_corners = quaternion.rotate(vertices_model, quat_init) + pos_init

            # computing scaling and crop size based on extreme projection points
            corners_projected_cv = tf.matmul(transformed_corners, tf.transpose(camera_intrinsics_tensor))
            corners_projected_cv = tf.div(corners_projected_cv, tf.expand_dims(corners_projected_cv[:, -1], axis=-1))

            y_min = tf.reduce_min(corners_projected_cv[:, 1], axis=0)
            y_max = tf.reduce_max(corners_projected_cv[:, 1], axis=0)
            x_min = tf.reduce_min(corners_projected_cv[:, 0], axis=0)
            x_max = tf.reduce_max(corners_projected_cv[:, 0], axis=0)

            distance_to_center = tf.stack([
                tf.abs(y_min - im_center[1]),
                tf.abs(y_max - im_center[1]),
                tf.abs(x_min - im_center[0]),
                tf.abs(x_max - im_center[0]),
            ], axis=-1)

            crop_width_half = lambda_mask_scaling * tf.reduce_max(distance_to_center, axis=0)

            bb_xy_bounds = tf.convert_to_tensor([im_center[1] - crop_width_half,
                                                 im_center[0] - crop_width_half,
                                                 im_center[1] + crop_width_half,
                                                 im_center[0] + crop_width_half,
                                                ])

            bb_box_scaling = tf.convert_to_tensor([image_params["img_size"][1], image_params["img_size"][0],
                                                   image_params["img_size"][1], image_params["img_size"][0]],
                                                  dtype=tf.float32)

            bb_box_normalized = tf.div(bb_xy_bounds, bb_box_scaling)

            img_cropped = tf.image.crop_and_resize(tf.expand_dims(image_decoded, axis=0),
                                                   tf.expand_dims(bb_box_normalized, axis=0),
                                                   box_ind=[0],
                                                   crop_size=[-crop_dimensions["y_crop_dim"][0] + crop_dimensions["y_crop_dim"][1],
                                                              -crop_dimensions["x_crop_dim"][0] + crop_dimensions["x_crop_dim"][1]])

            img_cropped = tf.squeeze(img_cropped)

            labels = {
                "pos": tf.cast(pos, dtype=tf.float32),
                "quat": tf.cast(quat, dtype=tf.float32)
            }

            fs = {
                "img": img_cropped,
                "pos": tf.cast(pos, dtype=tf.float32),
                "quat": tf.cast(quat, dtype=tf.float32),
                "pos_init": tf.cast(pos_init, dtype=tf.float32),
                "quat_init": tf.cast(quat_init, dtype=tf.float32),
                "object_ind": [rand_obj_ind],
                "cam_mat": camera_intrinsics_tensor,
            }

            return fs, labels
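A sketch of wiring the parser into an input pipeline, assuming a list of TFRecord paths (tfrecord_files is hypothetical) and the same TF1-style runtime the parser itself relies on:

import tensorflow as tf

dataset = tf.data.TFRecordDataset(tfrecord_files)
dataset = dataset.map(parser, num_parallel_calls=4)
dataset = dataset.shuffle(256).batch(8).prefetch(1)
features, labels = dataset.make_one_shot_iterator().get_next()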