コード例 #1
0
    def test_tf_data_feature_label_keys(self):
        """Checks axis-2 sizes of ScanNet scene mesh features.

        Pulls a single batch (batch_size=1) from the 'scannet_scene' val split
        via data_provider.get_tf_data_dataset, then verifies the size of axis 2
        of each mesh feature's static shape.

        Test is done here because TAP is off in specific dataset tests.
        """
        dataset = data_provider.get_tf_data_dataset(
            dataset_name='scannet_scene',
            split_name='val',
            batch_size=1,
            preprocess_fn=None,
            is_training=True,
            num_readers=2,
            num_parallel_batches=2,
            shuffle_buffer_size=2)
        batch = next(iter(dataset))

        # (feature key, expected size of axis 2); asserted in this order so the
        # first mismatching feature is the one reported.
        expected_dim2 = (
            ('mesh/vertices/positions', 3),
            ('mesh/vertices/normals', 3),
            ('mesh/vertices/colors', 4),
            ('mesh/faces/polygons', 3),
            ('mesh/vertices/semantic_labels', 1),
            ('mesh/vertices/instance_labels', 1),
        )
        for key, size in expected_dim2:
            self.assertEqual(batch[key].get_shape().as_list()[2], size)
コード例 #2
0
 def test_get_tf_data_dataset_tfrecord(self):
     """Reads one batch with dataset_format='tfrecord' and checks one shape.

     Builds the 'waymo_object_per_frame' val split through
     data_provider.get_tf_data_dataset, passing the reader, shuffle, batching
     and prefetch settings explicitly. It then asserts that the
     'cameras/front/extrinsics/R' feature of the first batch has shape
     [1, 3, 3].
     """
     tfrecord_dataset = data_provider.get_tf_data_dataset(
         dataset_name='waymo_object_per_frame',
         split_name='val',
         batch_size=1,
         is_training=True,
         preprocess_fn=None,
         feature_keys=None,
         label_keys=None,
         num_readers=1,
         filenames_shuffle_buffer_size=2,
         num_epochs=0,
         read_block_length=1,
         shuffle_buffer_size=2,
         num_parallel_batches=1,
         num_prefetch_batches=1,
         dataset_format='tfrecord',
     )
     first_batch = next(iter(tfrecord_dataset))
     front_r_shape = first_batch['cameras/front/extrinsics/R'].shape
     self.assertAllEqual(front_r_shape, [1, 3, 3])
コード例 #3
0
  def test_tf_data_feature_label_keys(self):
    """Checks Waymo per-frame camera, lidar and object feature shapes.

    Loads a single batch (batch_size=1) of the 'waymo_object_per_frame' val
    split through data_provider.get_tf_data_dataset and asserts the static
    shapes of the per-camera, per-lidar and object features.

    Test is done here because TAP is off in specific dataset tests.
    """
    dataset = data_provider.get_tf_data_dataset(
        dataset_name='waymo_object_per_frame',
        split_name='val',
        batch_size=1,
        preprocess_fn=None,
        is_training=True,
        num_readers=2,
        num_parallel_batches=2,
        shuffle_buffer_size=2)
    batch = next(iter(dataset))

    def shape_of(key):
      """Returns the static shape of `batch[key]` as a Python list."""
      return batch[key].get_shape().as_list()

    camera_names = ('front', 'front_left', 'front_right', 'side_left',
                    'side_right')
    lidar_names = ('top', 'front', 'side_left', 'side_right', 'rear')

    for name in camera_names:
      prefix = 'cameras/%s/' % name
      self.assertAllEqual(shape_of(prefix + 'extrinsics/t'), np.array([1, 3]))
      self.assertAllEqual(
          shape_of(prefix + 'extrinsics/R'), np.array([1, 3, 3]))
      self.assertAllEqual(
          shape_of(prefix + 'intrinsics/distortion'), np.array([1, 5]))
      self.assertAllEqual(
          shape_of(prefix + 'intrinsics/K'), np.array([1, 3, 3]))
      # Only axis 3 of the image shape is pinned; the other axes are unchecked.
      self.assertAllEqual(shape_of(prefix + 'image')[3], 3)

    for name in lidar_names:
      prefix = 'lidars/%s/' % name
      self.assertEqual(shape_of(prefix + 'pointcloud/positions')[2], 3)
      self.assertEqual(shape_of(prefix + 'pointcloud/intensity')[2], 1)
      self.assertEqual(shape_of(prefix + 'pointcloud/elongation')[2], 1)
      self.assertAllEqual(
          shape_of(prefix + 'extrinsics/R'), np.array([1, 3, 3]))
      self.assertAllEqual(shape_of(prefix + 'extrinsics/t'), np.array([1, 3]))
      self.assertEqual(
          shape_of(prefix + 'camera_projections/positions')[2], 2)
      self.assertEqual(shape_of(prefix + 'camera_projections/ids')[2], 1)

    pose_r_shape = shape_of('objects/pose/R')
    self.assertEqual(pose_r_shape[2], 3)
    self.assertEqual(pose_r_shape[3], 3)
    self.assertEqual(shape_of('objects/pose/t')[2], 3)
    self.assertEqual(shape_of('objects/shape/dimension')[2], 3)
    # The category label is only required to be rank 2.
    self.assertLen(shape_of('objects/category/label'), 2)