コード例 #1
0
  def test_video_ssl_input_pretrain(self):
    """Runs a fake example through the pretraining decoder and parser.

    Uses the `train_data` config from `video_ssl_pretrain_kinetics600()` and
    asserts the parsed 'image' feature has shape (32, 224, 224, 3).
    """
    train_config = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    seq_decoder = video_ssl_input.Decoder()
    parse_fn = video_ssl_input.Parser(train_config).parse_fn(
        train_config.is_training)
    example, _ = fake_seq_example()

    serialized = tf.constant(example.SerializeToString())
    # The parser returns a (features, label) pair; only features are checked.
    features, _ = parse_fn(seq_decoder.decode(serialized))

    self.assertAllEqual(features['image'].shape, (32, 224, 224, 3))
コード例 #2
0
  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
    """Builds the input dataset using the `video_ssl_input` parser.

    Args:
      params: Data config; forwarded to the parser, the post-batch processor,
        the dataset/decoder factories and the `InputReader`.
      input_context: Optional input context forwarded to `InputReader.read`.

    Returns:
      Whatever `InputReader.read` returns for this config (presumably a
      `tf.data.Dataset` -- confirm against `input_reader`).
    """
    ssl_parser = video_ssl_input.Parser(input_params=params)
    batch_postprocessor = video_ssl_input.PostBatchProcessor(params)

    # Resolve each pipeline stage in the same order the reader's arguments
    # were originally evaluated.
    dataset_fn = self._get_dataset_fn(params)
    decoder_fn = self._get_decoder_fn(params)
    parse_fn = ssl_parser.parse_fn(params.is_training)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn,
        decoder_fn=decoder_fn,
        parser_fn=parse_fn,
        postprocess_fn=batch_postprocessor)
    return reader.read(input_context=input_context)
コード例 #3
0
  def test_video_ssl_input_linear_eval(self):
    """Runs a fake example through the linear-eval validation parser.

    Uses the `validation_data` config from
    `video_ssl_linear_eval_kinetics600()` and asserts that the parsed 'image'
    feature has shape (960, 256, 256, 3) and the parsed label has shape (600,).
    """
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
    # The raw label from the fake example is unused; the assertion below checks
    # the label produced by the parser instead.
    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    image_features, label = parser(decoded_tensors)
    image = image_features['image']

    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))