def test_video_ssl_input_pretrain(self):
  """Decodes and parses a fake sequence example with the pretrain config.

  Uses the Kinetics-600 pretraining train-data config and asserts that the
  parsed 'image' feature has shape (32, 224, 224, 3).
  """
  train_cfg = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data
  seq_decoder = video_ssl_input.Decoder()
  parse = video_ssl_input.Parser(train_cfg).parse_fn(train_cfg.is_training)

  # Only the serialized example is needed; the fake label is ignored here.
  fake_example = fake_seq_example()[0]
  serialized = tf.constant(fake_example.SerializeToString())

  features, _ = parse(seq_decoder.decode(serialized))
  self.assertAllEqual(features['image'].shape, (32, 224, 224, 3))
def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
  """Creates the input dataset for this task from a data config.

  Args:
    params: data config used to build the video SSL parser and post-batch
      processor, and passed to the task's dataset/decoder factories.
    input_context: optional context forwarded unchanged to `reader.read`.

  Returns:
    Whatever `input_reader.InputReader.read` returns for this config.
  """
  ssl_parser = video_ssl_input.Parser(input_params=params)
  post_batch_fn = video_ssl_input.PostBatchProcessor(params)
  return input_reader.InputReader(
      params,
      dataset_fn=self._get_dataset_fn(params),
      decoder_fn=self._get_decoder_fn(params),
      parser_fn=ssl_parser.parse_fn(params.is_training),
      postprocess_fn=post_batch_fn,
  ).read(input_context=input_context)
def test_video_ssl_input_linear_eval(self):
  """Decodes and parses a fake sequence example with the linear-eval config.

  Uses the Kinetics-600 linear-eval validation-data config and asserts the
  shapes of the parsed 'image' feature and of the parsed label.
  """
  params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data
  decoder = video_ssl_input.Decoder()
  parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

  # The fake example's own label is discarded: the assertion below is about
  # the label tensor produced by the parser, not the raw input label.
  seq_example, _ = fake_seq_example()
  input_tensor = tf.constant(seq_example.SerializeToString())
  decoded_tensors = decoder.decode(input_tensor)

  image_features, label = parser(decoded_tensors)
  image = image_features['image']
  self.assertAllEqual(image.shape, (960, 256, 256, 3))
  self.assertAllEqual(label.shape, (600, ))