Ejemplo n.º 1
0
  def test_stack_intratask_episodes(self):
    """Stacked meta-examples gain leading (batch, samples-in-task) dims.

    Builds a two-feature spec (jpeg image + float action), expands it into a
    meta-example spec with `num_samples_in_task` copies, samples a random
    numpy batch from it, and checks that `stack_intra_task_episodes` returns
    arrays shaped `(batch_size, num_samples_in_task) + <feature shape>`.
    """
    batch_size = 2
    num_samples_in_task = 3

    feature_spec = TSpec()
    feature_spec.image = utils.ExtendedTensorSpec(
        name='state/image',
        shape=_DEFAULT_IN_IMAGE_SHAPE,
        dtype=tf.uint8,
        data_format='jpeg',
        is_optional=False)
    feature_spec.action = utils.ExtendedTensorSpec(
        name='state/action',
        shape=_DEFAULT_ACTION_SHAPE,
        dtype=tf.float32,
        is_optional=False)

    metaexample_spec = preprocessors.create_metaexample_spec(
        feature_spec, num_samples_in_task, 'condition')
    random_batch = utils.make_random_numpy(metaexample_spec, batch_size)
    stacked = preprocessors.stack_intra_task_episodes(
        random_batch, num_samples_in_task)

    # Every stacked feature should be prefixed by the same two dimensions.
    leading_dims = (batch_size, num_samples_in_task)
    self.assertEqual(
        stacked.image.shape, leading_dims + _DEFAULT_IN_IMAGE_SHAPE)
    self.assertEqual(
        stacked.action.shape, leading_dims + _DEFAULT_ACTION_SHAPE)
Ejemplo n.º 2
0
  def test_create_metaexample_spec(self):
    """Checks the key layout and naming of `create_metaexample_spec` output.

    For a two-feature spec expanded with `num_samples_in_task` copies and the
    'condition' prefix, the meta-example spec must contain exactly one entry
    per (flattened feature key, sample index) pair, keyed as
    '<feature key>/<sample index>', and each entry's name must start with
    'condition_ep'.
    """
    feature_spec = TSpec()
    feature_spec.image = utils.ExtendedTensorSpec(
        shape=_DEFAULT_IN_IMAGE_SHAPE,
        dtype=tf.uint8,
        is_optional=False,
        data_format='jpeg',
        name='state/image')
    feature_spec.action = utils.ExtendedTensorSpec(
        shape=_DEFAULT_ACTION_SHAPE,
        dtype=tf.float32,
        is_optional=False,
        name='state/action')

    num_samples_in_task = 3
    metaexample_spec = preprocessors.create_metaexample_spec(
        feature_spec, num_samples_in_task, 'condition')

    flat_feature_spec = utils.flatten_spec_structure(feature_spec)
    # Materialize the key list once; the original rebuilt it for every
    # membership check inside the nested loop below.
    metaexample_keys = list(metaexample_spec.keys())
    self.assertLen(
        metaexample_keys,
        num_samples_in_task * len(list(flat_feature_spec.keys())))

    for key in flat_feature_spec:
      for i in range(num_samples_in_task):
        meta_example_key = six.ensure_str(key) + '/{:d}'.format(i)
        self.assertIn(meta_example_key, metaexample_keys)
        # assertStartsWith reports the offending name on failure, unlike
        # assertTrue(name.startswith(...)).
        self.assertStartsWith(
            six.ensure_str(metaexample_spec[meta_example_key].name),
            'condition_ep')