def preprocessor(self):
  """Builds the meta-example preprocessor, stores it and returns it.

  See base class. Every call constructs a new
  `vrgripper_env_models.DefaultVRGripperPreprocessor`. That preprocessor
  receives `self._episode_feature_specification` as
  `model_feature_specification_fn` and `self._episode_label_specification`
  as `model_label_specification_fn`. It is wrapped in a
  `preprocessors.FixedLenMetaExamplePreprocessor`, which is assigned to
  `self._preprocessor` (overwriting any previous value) and returned.

  Returns:
    The newly built `FixedLenMetaExamplePreprocessor`.
  """
  # The inner (per-episode) preprocessor is built as the argument, so it is
  # still constructed before the meta-example wrapper, as in a two-step form.
  self._preprocessor = preprocessors.FixedLenMetaExamplePreprocessor(
      base_preprocessor=vrgripper_env_models.DefaultVRGripperPreprocessor(
          model_feature_specification_fn=self._episode_feature_specification,
          model_label_specification_fn=self._episode_label_specification))
  return self._preprocessor
def test_meta_example_preprocess(
    self, num_condition_samples_per_task, num_inference_samples_per_task,
    outer_batch_size):
  """Checks that FixedLenMetaExamplePreprocessor regroups flat mock tensors.

  Mock tensors are created for the meta-example preprocessor. They are
  batched into `outer_batch_size` tasks and run through `preprocess` in TRAIN
  mode. The resulting condition/inference features and labels are then
  compared sample by sample against the flat reference tensors.

  Args:
    num_condition_samples_per_task: Number of condition samples per task.
    num_inference_samples_per_task: Number of inference samples per task.
    outer_batch_size: Batch size used for `dataset.batch` (the outer batch
      dimension of the preprocessed tensors).
  """
  base_preprocessor = MockBasePreprocessor()
  meta_example_preprocessor = preprocessors.FixedLenMetaExamplePreprocessor(
      base_preprocessor=base_preprocessor,
      num_condition_samples_per_task=num_condition_samples_per_task,
      num_inference_samples_per_task=num_inference_samples_per_task)
  mock_tensors = self._create_mock_tensors(
      meta_example_preprocessor, outer_batch_size)

  with self.session() as sess:
    dataset = tf.data.Dataset.from_tensor_slices(mock_tensors)
    dataset = dataset.batch(outer_batch_size, drop_remainder=True)
    preprocess_fn = functools.partial(
        meta_example_preprocessor.preprocess,
        mode=tf.estimator.ModeKeys.TRAIN)
    dataset = dataset.map(map_func=preprocess_fn, num_parallel_calls=1)
    raw_meta_features, raw_meta_labels = (
        dataset.make_one_shot_iterator().get_next())
    np_raw_meta_features, np_raw_meta_labels = sess.run(
        [raw_meta_features, raw_meta_labels])

  # Everything below operates on the evaluated values returned by sess.run.
  ref_features, ref_labels = mock_tensors
  condition = np_raw_meta_features.condition
  inference = np_raw_meta_features.inference

  # Condition and inference samples must expose the same feature keys.
  self.assertEqual(
      list(condition.features.keys()), list(inference.features.keys()))
  # The labels and the condition labels have to have the same keys.
  self.assertEqual(
      list(condition.labels.keys()), list(np_raw_meta_labels.keys()))

  # The image has been resized. Therefore, we ensure that its shape is
  # correct. Note, we have to strip the outer and inner batch dimensions.
  self.assertEqual(
      condition.features.image.shape[2:], _DEFAULT_OUT_IMAGE_SHAPE)
  self.assertEqual(
      inference.features.image.shape[2:], _DEFAULT_OUT_IMAGE_SHAPE)

  # Axis 1 indexes the sample within a task; each sample must match the flat
  # reference tensor it was regrouped from.
  for i in range(num_condition_samples_per_task):
    np.testing.assert_array_almost_equal(
        condition.features['action'][:, i, Ellipsis],
        ref_features['condition/features/action/{:d}'.format(i)])
    for label_name in condition.labels.keys():
      # Normalize the key exactly as in the inference loop below; formatting
      # a bytes key with '{:s}' would raise TypeError on Python 3.
      np.testing.assert_array_almost_equal(
          condition.labels[label_name][:, i, Ellipsis],
          ref_features['condition/labels/{:s}/{:d}'.format(
              six.ensure_str(label_name), i)])

  for i in range(num_inference_samples_per_task):
    np.testing.assert_array_almost_equal(
        inference.features['action'][:, i, Ellipsis],
        ref_features['inference/features/action/{:d}'.format(i)])
    # Iterate over the keys of the labels actually being checked (asserted
    # above to equal the condition label keys).
    for label_name in np_raw_meta_labels.keys():
      np.testing.assert_array_almost_equal(
          np_raw_meta_labels[label_name][:, i, Ellipsis],
          ref_labels[six.ensure_str(label_name) + '/{:d}'.format(i)])