Example #1
0
    def test_video_ssl_input_pretrain(self):
        """Round-trips a fake sequence example through the pretrain pipeline.

        Serializes the fake example, decodes it with `video_ssl_input.Decoder`,
        parses it with the pretrain train-data parser, and checks the shape of
        the resulting `'image'` feature.
        """
        train_cfg = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

        decode = video_ssl_input.Decoder().decode
        parse = video_ssl_input.Parser(train_cfg).parse_fn(
            train_cfg.is_training)

        example, _ = fake_seq_example()
        serialized = tf.constant(example.SerializeToString())
        features, _ = parse(decode(serialized))

        self.assertAllEqual(features['image'].shape, (32, 224, 224, 3))
Example #2
0
    def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
        """Creates the input pipeline for this task's video data.

        Wires a `video_ssl_input.Parser` and `video_ssl_input.PostBatchProcessor`
        built from `params` into an `input_reader.InputReader`, together with the
        task's dataset and decoder functions, and reads from it.

        Args:
          params: data config used for the reader, the parser and the
            post-batch processor; `params.is_training` selects the parse mode.
          input_context: optional input context forwarded unchanged to
            `InputReader.read`.

        Returns:
          The value returned by `InputReader.read(input_context=...)`.
        """
        # Build parser and post-batch processor first, as the pipeline
        # components consumed by the reader.
        ssl_parser = video_ssl_input.Parser(input_params=params)
        post_batch = video_ssl_input.PostBatchProcessor(params)

        reader_kwargs = dict(
            dataset_fn=self._get_dataset_fn(params),
            decoder_fn=self._get_decoder_fn(params),
            parser_fn=ssl_parser.parse_fn(params.is_training),
            postprocess_fn=post_batch,
        )
        return input_reader.InputReader(params, **reader_kwargs).read(
            input_context=input_context)
Example #3
0
    def test_video_ssl_input_linear_eval(self):
        """Round-trips a fake sequence example through the linear-eval pipeline.

        Uses the linear-eval validation-data config. It decodes a serialized fake
        example, parses it, and checks the shapes of the `'image'` feature and
        of the label returned by the parser.
        """
        params = exp_cfg.video_ssl_linear_eval_kinetics600(
        ).task.validation_data

        decoder = video_ssl_input.Decoder()
        parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
        # The label from fake_seq_example() is unused: the label under test is
        # the one produced by the parser below.
        seq_example, _ = fake_seq_example()

        input_tensor = tf.constant(seq_example.SerializeToString())
        decoded_tensors = decoder.decode(input_tensor)
        image_features, label = parser(decoded_tensors)
        image = image_features['image']

        self.assertAllEqual(image.shape, (960, 256, 256, 3))
        self.assertAllEqual(label.shape, (600, ))