Example #1
  def _get_dummy_input(self, input_type, module=None):
    """Get dummy input for the given input type."""

    if input_type == 'image_tensor':
      # Random uint8 clip: batch of 1, 8 frames of 64x64 RGB pixels.
      images = np.random.randint(
          low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
      return images, images
    elif input_type == 'tf_example':
      # Serialized test example with synthetic video frames, audio features
      # and a random label.
      example = tfexample_utils.make_video_test_example(
          image_shape=(64, 64, 3),
          audio_shape=(20, 128),
          label=random.randint(0, 100)).SerializeToString()
      # Decode the serialized example with the module's decoder and keep only
      # the encoded image frames.
      images = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._decode_tf_example,
              elems=tf.constant([example]),
              fn_output_signature={
                  video_classification.video_input.IMAGE_KEY: tf.string,
              }))
      images = images[video_classification.video_input.IMAGE_KEY]
      return [example], images
    else:
      raise ValueError(f'Unsupported input type: {input_type}')
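
As a minimal usage sketch (not part of the original test, and assuming self is the tf.test.TestCase subclass that defines _get_dummy_input), the 'image_tensor' branch can be exercised directly; the 'tf_example' branch would additionally need a serving module that exposes _decode_tf_example:

  def test_dummy_image_tensor_input(self):
    # Hypothetical test method: check the output of the 'image_tensor'
    # branch of _get_dummy_input above.
    inputs, expected = self._get_dummy_input('image_tensor')
    self.assertEqual(inputs.shape, (1, 8, 64, 64, 3))  # batch, frames, H, W, RGB
    self.assertIs(inputs, expected)  # the same array is returned as both values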
Example #2
  def setUp(self):
    super(VideoClassificationTaskTest, self).setUp()
    # Write two synthetic video examples to a temporary TFRecord file that
    # the tests can read from.
    data_dir = os.path.join(self.get_temp_dir(), 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.make_video_test_example(
            image_shape=(36, 36, 3),
            audio_shape=(20, 128),
            label=random.randint(0, 100))
        for _ in range(2)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
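
As a hedged follow-up sketch (not part of the original file), the TFRecord written in setUp can be read back with tf.data to confirm that both serialized examples were dumped to self._data_path:

  def test_setup_writes_two_records(self):
    # Hypothetical check: the TFRecord created in setUp should contain
    # exactly the two serialized examples written above.
    records = list(tf.data.TFRecordDataset(self._data_path).as_numpy_iterator())
    self.assertLen(records, 2)
    for serialized in records:
      self.assertIsInstance(serialized, bytes)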