示例#1
0
    def test_video_ssl_input_pretrain(self):
        """Check the parsed 'image' shape for the pretrain train_data config.

        A fake sequence example is serialized, run through the video SSL
        decoder and then through the parser built from the Kinetics-600
        pretrain training-data config; the resulting 'image' feature must
        have shape (32, 224, 224, 3).
        """
        data_cfg = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

        ssl_decoder = video_ssl_input.Decoder()
        parser_factory = video_ssl_input.Parser(data_cfg)
        parse_fn = parser_factory.parse_fn(data_cfg.is_training)
        example_proto, _ = fake_seq_example()

        serialized = tf.constant(example_proto.SerializeToString())
        # The parser returns a (features, label) pair; only the features
        # are inspected by this test.
        features, _ = parse_fn(ssl_decoder.decode(serialized))

        self.assertAllEqual(features['image'].shape, (32, 224, 224, 3))
示例#2
0
    def test_video_ssl_input_linear_eval(self):
        """Check parsed image and label shapes for the linear-eval config.

        A fake sequence example is serialized, decoded, and parsed with the
        parser built from the Kinetics-600 linear-eval validation-data
        config. The 'image' feature must have shape (960, 256, 256, 3) and
        the parsed label must have shape (600,).
        """
        params = exp_cfg.video_ssl_linear_eval_kinetics600(
        ).task.validation_data

        decoder = video_ssl_input.Decoder()
        parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
        # The raw label returned alongside the fake example is not used;
        # the label asserted below is the one produced by the parser.
        seq_example, _ = fake_seq_example()

        input_tensor = tf.constant(seq_example.SerializeToString())
        decoded_tensors = decoder.decode(input_tensor)
        image_features, label = parser(decoded_tensors)
        image = image_features['image']

        self.assertAllEqual(image.shape, (960, 256, 256, 3))
        self.assertAllEqual(label.shape, (600, ))
示例#3
0
 def _get_decoder_fn(self, params):
     """Return the `decode` method of a new `video_ssl_input.Decoder`.

     `params` is accepted but not used by this implementation.
     """
     return video_ssl_input.Decoder().decode