def _get_dummy_input(self, input_type, module=None):
  """Get dummy input for the given input type.

  Args:
    input_type: Either 'image_tensor' or 'tf_example'.
    module: Object providing a `_decode_tf_example` method, used to decode the
      serialized example. Required when `input_type` is 'tf_example'.

  Returns:
    A pair `(raw_input, images)`.
    - For 'image_tensor', both elements are the same random uint8 numpy array
      of shape (1, 8, 64, 64, 3).
    - For 'tf_example', `raw_input` is a one-element list holding a serialized
      video example. `images` is the IMAGE_KEY entry produced by mapping
      `module._decode_tf_example` over that example, with gradients stopped.

  Raises:
    ValueError: If `input_type` is not supported, or if `module` is None for
      the 'tf_example' input type.
  """
  if input_type == 'image_tensor':
    # `high` is exclusive in np.random.randint, so 256 is needed to cover the
    # full uint8 range [0, 255].
    images = np.random.randint(
        low=0, high=256, size=(1, 8, 64, 64, 3), dtype=np.uint8)
    return images, images
  elif input_type == 'tf_example':
    if module is None:
      # Fail with a clear message instead of an AttributeError on None.
      raise ValueError("`module` is required for input_type 'tf_example'.")
    image_key = video_classification.video_input.IMAGE_KEY
    example = tfexample_utils.make_video_test_example(
        image_shape=(64, 64, 3),
        audio_shape=(20, 128),
        label=random.randint(0, 100)).SerializeToString()
    images = tf.nest.map_structure(
        tf.stop_gradient,
        tf.map_fn(
            module._decode_tf_example,
            elems=tf.constant([example]),
            fn_output_signature={
                image_key: tf.string,
            }))
    images = images[image_key]
    return [example], images
  else:
    raise ValueError(f'Unsupported input_type: {input_type}')
def setUp(self):
  """Write a small TFRecord file of synthetic video examples.

  Creates `<temp dir>/data/data.tfrecord` containing two examples built by
  `tfexample_utils.make_video_test_example`. Each has image_shape (36, 36, 3),
  audio_shape (20, 128) and a label drawn with `random.randint(0, 100)`. The
  file path is stored in `self._data_path`.
  """
  super(VideoClassificationTaskTest, self).setUp()
  record_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(record_dir)
  self._data_path = os.path.join(record_dir, 'data.tfrecord')

  tf_examples = []
  for _ in range(2):
    random_label = random.randint(0, 100)
    tf_examples.append(
        tfexample_utils.make_video_test_example(
            image_shape=(36, 36, 3),
            audio_shape=(20, 128),
            label=random_label))

  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=tf_examples)