def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None):
  """Builds the video classification input pipeline.

  Args:
    params: Dataset configuration; fields read here include
      `image_field_key`, `label_field_key`, `is_training`, and
      `mixup_and_cutmix`.
    input_context: Optional `tf.distribute.InputContext` used to shard the
      input across replicas.

  Returns:
    A `tf.data.Dataset` yielding (features, labels) batches.
  """
  parser = video_input.Parser(
      input_params=params,
      image_key=params.image_field_key,
      label_key=params.label_field_key)
  postprocess_fn = video_input.PostBatchProcessor(params)

  if params.mixup_and_cutmix is not None:
    # Construct the augmenter once at pipeline-build time instead of
    # re-creating it inside the closure for every batch (the closure runs
    # per batch, so the construction was loop-invariant work).
    augmenter = augment.MixupAndCutmix(
        mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
        cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
        prob=params.mixup_and_cutmix.prob,
        label_smoothing=params.mixup_and_cutmix.label_smoothing,
        num_classes=self._get_num_classes())

    def mixup_and_cutmix(features, labels):
      features['image'], labels = augmenter(features['image'], labels)
      return features, labels

    # NOTE(review): this REPLACES the PostBatchProcessor rather than
    # composing with it — confirm that post-batch processing is
    # intentionally skipped when mixup/cutmix is enabled.
    postprocess_fn = mixup_and_cutmix

  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=self._get_dataset_fn(params),
      decoder_fn=self._get_decoder_fn(params),
      parser_fn=parser.parse_fn(params.is_training),
      postprocess_fn=postprocess_fn)
  dataset = reader.read(input_context=input_context)
  return dataset
def test_video_audio_input(self):
  """Parser emits both video frames and a densified audio feature."""
  params = exp_cfg.kinetics600(is_training=True)
  params.feature_shape = (2, 224, 224, 3)
  params.min_image_size = 224
  params.output_audio = True
  params.audio_feature = AUDIO_KEY
  params.audio_feature_shape = (15, 256)

  decoder = video_input.Decoder()
  # Audio is stored sparse in the SequenceExample, hence VarLenFeature.
  decoder.add_feature(params.audio_feature,
                      tf.io.VarLenFeature(dtype=tf.float32))
  parser = video_input.Parser(params).parse_fn(params.is_training)

  seq_example, label = fake_seq_example()
  input_tensor = tf.constant(seq_example.SerializeToString())
  decoded_tensors = decoder.decode(input_tensor)
  output_tensor = parser(decoded_tensors)
  features, label = output_tensor
  image = features['image']
  audio = features['audio']

  self.assertAllEqual(image.shape, (2, 224, 224, 3))
  self.assertAllEqual(label.shape, (600,))
  # Consistency fix: use assertAllEqual for shape checks like the other
  # tests in this file (was assertEqual).
  self.assertAllEqual(audio.shape, (15, 256))
def test_video_input(self):
  """Parsing a fake SequenceExample yields the configured video shape."""
  params = exp_cfg.kinetics600(is_training=True)
  params.feature_shape = (2, 224, 224, 3)
  params.min_image_size = 224

  parse = video_input.Parser(params).parse_fn(params.is_training)
  example, _ = fake_seq_example()
  serialized = tf.constant(example.SerializeToString())
  decoded = video_input.Decoder().decode(serialized)
  features, label = parse(decoded)

  self.assertAllEqual(features['image'].shape, (2, 224, 224, 3))
  self.assertAllEqual(label.shape, (600,))
def test_video_input_image_shape_label_type(self):
  """Grayscale, non-square frames with non-one-hot float32 labels."""
  params = exp_cfg.kinetics600(is_training=True)
  params.feature_shape = (2, 168, 224, 1)
  params.min_image_size = 168
  params.label_dtype = 'float32'
  params.one_hot = False

  parse = video_input.Parser(params).parse_fn(params.is_training)
  example, _ = fake_seq_example()
  decoded = video_input.Decoder().decode(
      tf.constant(example.SerializeToString()))
  features, label = parse(decoded)

  self.assertAllEqual(features['image'].shape, (2, 168, 224, 1))
  self.assertAllEqual(label.shape, (1,))
  self.assertDTypeEqual(label, tf.float32)
def build_inputs(self,
                 params: exp_cfg.DataConfig,
                 input_context: Optional[tf.distribute.InputContext] = None):
  """Builds the classification input pipeline.

  Args:
    params: Dataset configuration; `image_field_key`, `label_field_key`,
      and `is_training` are read here.
    input_context: Optional `tf.distribute.InputContext` for sharding the
      input across replicas.

  Returns:
    A `tf.data.Dataset` of parsed and post-batched examples.
  """
  video_parser = video_input.Parser(
      input_params=params,
      image_key=params.image_field_key,
      label_key=params.label_field_key)
  post_batch = video_input.PostBatchProcessor(params)

  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=self._get_dataset_fn(params),
      decoder_fn=self._get_decoder_fn(params),
      parser_fn=video_parser.parse_fn(params.is_training),
      postprocess_fn=post_batch)
  return reader.read(input_context=input_context)
def test_video_input_augmentation_returns_shape(self):
  """AutoAugment plus temporal stride still yields the configured shape."""
  params = exp_cfg.kinetics600(is_training=True)
  params.feature_shape = (2, 224, 224, 3)
  params.min_image_size = 224
  params.temporal_stride = 2
  params.aug_type = common.Augmentation(
      type='autoaug', autoaug=common.AutoAugment())

  parse = video_input.Parser(params).parse_fn(params.is_training)
  example, _ = fake_seq_example()
  decoded = video_input.Decoder().decode(
      tf.constant(example.SerializeToString()))
  features, label = parse(decoded)

  self.assertAllEqual(features['image'].shape, (2, 224, 224, 3))
  self.assertAllEqual(label.shape, (600,))