示例#1
0
 def preprocessor(self):
     """Builds, stores and returns the meta-example preprocessor.

     See base class.

     A `DefaultVRGripperPreprocessor` is configured with this model's episode
     feature and label specification functions and then wrapped in a
     `FixedLenMetaExamplePreprocessor`. A new preprocessor is constructed on
     every call and replaces whatever `self._preprocessor` held before.

     Returns:
       The `FixedLenMetaExamplePreprocessor` now stored in
       `self._preprocessor`.
     """
     # Per-episode preprocessor driven by the model's own episode specs.
     episode_preprocessor = vrgripper_env_models.DefaultVRGripperPreprocessor(
         model_feature_specification_fn=self._episode_feature_specification,
         model_label_specification_fn=self._episode_label_specification)
     # Wrap it so it operates on fixed-length meta examples.
     meta_preprocessor = preprocessors.FixedLenMetaExamplePreprocessor(
         base_preprocessor=episode_preprocessor)
     self._preprocessor = meta_preprocessor
     return self._preprocessor
示例#2
0
  def test_meta_example_preprocess(
      self,
      num_condition_samples_per_task,
      num_inference_samples_per_task,
      outer_batch_size):
    """Checks `FixedLenMetaExamplePreprocessor.preprocess` on mock tensors.

    The mock tensors are sliced into a dataset, batched by `outer_batch_size`
    and preprocessed in TRAIN mode. The evaluated result is then verified for
    matching key sets, the expected image shape, and per-sample action and
    label values that agree with the mock inputs.

    Args:
      num_condition_samples_per_task: Number of condition samples per task
        handed to the preprocessor; also the number of condition sample
        indices compared against the reference tensors.
      num_inference_samples_per_task: Number of inference samples per task
        handed to the preprocessor; also the number of inference sample
        indices compared against the reference tensors.
      outer_batch_size: Batch size used both to create the mock tensors and
        to batch the dataset.
    """
    meta_preprocessor = preprocessors.FixedLenMetaExamplePreprocessor(
        base_preprocessor=MockBasePreprocessor(),
        num_condition_samples_per_task=num_condition_samples_per_task,
        num_inference_samples_per_task=num_inference_samples_per_task)
    mock_data = self._create_mock_tensors(meta_preprocessor, outer_batch_size)
    ref_features, ref_labels = mock_data

    with self.session() as sess:
      train_preprocess = functools.partial(
          meta_preprocessor.preprocess, mode=tf.estimator.ModeKeys.TRAIN)
      pipeline = (
          tf.data.Dataset.from_tensor_slices(mock_data)
          .batch(outer_batch_size, drop_remainder=True)
          .map(map_func=train_preprocess, num_parallel_calls=1))

      meta_features_op, meta_labels_op = (
          pipeline.make_one_shot_iterator().get_next())
      meta_features, meta_labels = sess.run(
          [meta_features_op, meta_labels_op])

      condition = meta_features.condition
      inference = meta_features.inference

      # Condition and inference features must expose the same keys.
      self.assertEqual(
          list(condition.features.keys()),
          list(inference.features.keys()))

      # The labels and the condition labels have to have the same keys.
      label_names = list(condition.labels.keys())
      self.assertEqual(label_names, list(meta_labels.keys()))

      # The image has been resized, so only its shape is checked. The leading
      # outer and inner batch dimensions are stripped before comparing.
      for split in (condition, inference):
        self.assertEqual(
            split.features.image.shape[2:], _DEFAULT_OUT_IMAGE_SHAPE)

      # Every condition sample has to match its flat reference tensor.
      for sample_idx in range(num_condition_samples_per_task):
        action_key = 'condition/features/action/{:d}'.format(sample_idx)
        np.testing.assert_array_almost_equal(
            condition.features['action'][:, sample_idx, Ellipsis],
            ref_features[action_key])
        for label_name in label_names:
          label_key = 'condition/labels/{:s}/{:d}'.format(
              label_name, sample_idx)
          np.testing.assert_array_almost_equal(
              condition.labels[label_name][:, sample_idx, Ellipsis],
              ref_features[label_key])

      # Every inference sample has to match its flat reference tensor; the
      # inference labels are returned separately as the meta labels.
      for sample_idx in range(num_inference_samples_per_task):
        action_key = 'inference/features/action/{:d}'.format(sample_idx)
        np.testing.assert_array_almost_equal(
            inference.features['action'][:, sample_idx, Ellipsis],
            ref_features[action_key])
        for label_name in label_names:
          label_key = six.ensure_str(label_name) + '/{:d}'.format(sample_idx)
          np.testing.assert_array_almost_equal(
              meta_labels[label_name][:, sample_idx, Ellipsis],
              ref_labels[label_key])