def test_stack_intratask_episodes(self):
    """Checks that stacking per-sample episode features adds [batch, samples] dims.

    A two-feature spec (a jpeg image and a float action) is expanded into a
    meta-example spec with `num_samples_in_task` copies. Random numpy data is
    generated for that spec and passed to
    `preprocessors.stack_intra_task_episodes`. Each output feature's shape must
    then be `(batch_size, num_samples_in_task) + <original feature shape>`.
    """
    base_spec = TSpec()
    base_spec.image = utils.ExtendedTensorSpec(
        shape=_DEFAULT_IN_IMAGE_SHAPE,
        dtype=tf.uint8,
        is_optional=False,
        data_format='jpeg',
        name='state/image')
    base_spec.action = utils.ExtendedTensorSpec(
        shape=_DEFAULT_ACTION_SHAPE,
        dtype=tf.float32,
        is_optional=False,
        name='state/action')

    batch_size = 2
    num_samples_in_task = 3

    meta_spec = preprocessors.create_metaexample_spec(
        base_spec, num_samples_in_task, 'condition')
    random_inputs = utils.make_random_numpy(meta_spec, batch_size)
    stacked = preprocessors.stack_intra_task_episodes(
        random_inputs, num_samples_in_task)

    # Both features should gain the same two leading dimensions.
    leading_dims = (batch_size, num_samples_in_task)
    self.assertEqual(stacked.image.shape,
                     leading_dims + _DEFAULT_IN_IMAGE_SHAPE)
    self.assertEqual(stacked.action.shape,
                     leading_dims + _DEFAULT_ACTION_SHAPE)
def test_create_metaexample_spec(self):
    """Checks the key layout and naming of a generated meta-example spec.

    For every flattened feature key `k` and every sample index `i` in
    `range(num_samples_in_task)`, the meta-example spec must contain the key
    `k + '/i'`. The spec's `name` at that key must start with `'condition_ep'`.
    The spec must contain exactly `num_samples_in_task * <number of flat
    features>` keys in total.
    """
    base_spec = TSpec()
    base_spec.image = utils.ExtendedTensorSpec(
        shape=_DEFAULT_IN_IMAGE_SHAPE,
        dtype=tf.uint8,
        is_optional=False,
        data_format='jpeg',
        name='state/image')
    base_spec.action = utils.ExtendedTensorSpec(
        shape=_DEFAULT_ACTION_SHAPE,
        dtype=tf.float32,
        is_optional=False,
        name='state/action')

    num_samples_in_task = 3
    meta_spec = preprocessors.create_metaexample_spec(
        base_spec, num_samples_in_task, 'condition')
    flat_spec = utils.flatten_spec_structure(base_spec)

    # Build the key list once; it is reused for the size check and for every
    # membership check below.
    meta_keys = list(meta_spec.keys())
    expected_total = num_samples_in_task * len(list(flat_spec.keys()))
    self.assertLen(meta_keys, expected_total)

    for feature_key in flat_spec:
        prefix = six.ensure_str(feature_key)
        for sample_idx in range(num_samples_in_task):
            expected_key = prefix + '/{:d}'.format(sample_idx)
            self.assertIn(expected_key, meta_keys)
            spec_name = six.ensure_str(meta_spec[expected_key].name)
            self.assertTrue(spec_name.startswith('condition_ep'))