Beispiel #1
0
  def _tensor_specs(method_name, unused_kwargs, constructor_kwargs):
    """Returns the output `TensorSpec` nest for a proxied environment method.

    Args:
      method_name: Name of the method whose outputs are described; only
        'initial' and 'step' have specs.
      unused_kwargs: Keyword arguments of the method call. Not read here.
      constructor_kwargs: Constructor keyword arguments. Not read here; the
        image sizes come from `FLAGS` instead.

    Returns:
      'initial': a list of eight observation `TensorSpec`s.
      'step': a tuple `(float32 scalar spec, bool scalar spec, observations)`,
        where `observations` is that same list object.
      Any other name: None.
    """
    make_spec = contrib_framework.TensorSpec
    # The observation list is built eagerly, before `method_name` is checked,
    # so `FLAGS` is read on every call (even when None is returned).
    observations = [
        # Two uint8 [H, W, 3] tensors -- presumably RGB frames (a main view
        # and a graph rendering); confirm against the environment.
        make_spec([FLAGS.height, FLAGS.width, 3], tf.uint8),
        make_spec([FLAGS.graph_height, FLAGS.graph_width, 3], tf.uint8),
        # Two length-2 float64 vectors.
        make_spec([2], tf.float64),
        make_spec([2], tf.float64),
        # Scalars: one float64, one uint8, two int32.
        make_spec([], tf.float64),
        make_spec([], tf.uint8),
        make_spec([], tf.int32),
        make_spec([], tf.int32),
    ]

    if method_name == 'step':
      # Presumably (reward, done, observation) -- TODO confirm with the env.
      return (make_spec([], tf.float32), make_spec([], tf.bool), observations)
    if method_name == 'initial':
      return observations
    return None
Beispiel #2
0
 def test_extended_from_spec(self):
     """`ExtendedTensorSpec.from_spec` yields a spec equal to its source spec."""
     source_spec = contrib_framework.TensorSpec(shape=[1], dtype=np.float32)
     # Source spec stays the first operand, matching the original comparison.
     self.assertEqual(source_spec,
                      utils.ExtendedTensorSpec.from_spec(source_spec))