コード例 #1
0
  def test_video_ssl_input_pretrain(self):
    """Decodes and parses a fake sequence example with the pretrain config.

    Uses the train-data config of `video_ssl_pretrain_kinetics600` and checks
    that the parsed 'image' feature has shape (32, 224, 224, 3).
    """
    train_cfg = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    decode_fn = video_ssl_input.Decoder().decode
    parse_fn = video_ssl_input.Parser(train_cfg).parse_fn(train_cfg.is_training)
    fake_example, _ = fake_seq_example()

    serialized = tf.constant(fake_example.SerializeToString())
    features, _ = parse_fn(decode_fn(serialized))

    self.assertAllEqual(features['image'].shape, (32, 224, 224, 3))
コード例 #2
0
  def test_video_ssl_input_linear_eval(self):
    """Decodes and parses a fake sequence example with the linear-eval config.

    Uses the validation-data config of `video_ssl_linear_eval_kinetics600` and
    checks the shapes of the parsed 'image' feature and of the parsed label.
    """
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
    # Only the serialized example is needed; the label asserted below is the
    # one produced by the parser, not the raw fixture label.
    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    # NOTE(review): presumably a one-hot over the 600 Kinetics-600 classes —
    # confirm against the Parser's label handling.
    self.assertAllEqual(label.shape, (600,))
コード例 #3
0
 def _get_decoder_fn(self, params):
   """Returns the bound `decode` method of a fresh `video_ssl_input.Decoder`.

   Args:
     params: Not used by this implementation (presumably kept so the
       signature matches other `_get_decoder_fn` implementations — confirm).
   """
   return video_ssl_input.Decoder().decode