예제 #1
0
  def __init__(self, batch_env):
    """Batch of environments inside the TensorFlow graph.

    Wraps `batch_env` and allocates a non-trainable variable that mirrors
    the latest batch of observations, so other graph ops can read it.

    Args:
      batch_env: Batch environment.
    """
    self._batch_env = batch_env
    # Cache action space metadata for callers that build action tensors.
    act_space = self._batch_env.action_space
    self.action_shape = list(utils.parse_shape(act_space))
    self.action_dtype = utils.parse_dtype(act_space)
    obs_space = self._batch_env.observation_space
    obs_shape = utils.parse_shape(obs_space)
    obs_dtype = utils.parse_dtype(obs_space)
    with tf.variable_scope('env_temporary'):
      # One observation slot per environment in the batch.
      initial = tf.zeros((len(self._batch_env),) + obs_shape, obs_dtype)
      self._observ = tf.Variable(initial, name='observ', trainable=False)
예제 #2
0
    def simulate(self, action):
        """Step the batch of environments with a batch of actions.

        The results of the step can be accessed from the variables defined
        below.

        Args:
          action: Tensor holding the batch of actions to apply.

        Returns:
          Tuple of tensors: the batch of rewards and of done flags.
        """
        with tf.name_scope('environment/simulate'):
            # Guard against NaN/Inf actions before handing them to the envs.
            if action.dtype in (tf.float16, tf.float32, tf.float64):
                action = tf.check_numerics(action, 'action')
            obs_dtype = utils.parse_dtype(self._batch_env.observation_space)
            step_outputs = tf.py_func(
                lambda a: self._batch_env.step(a)[:3], [action],
                [obs_dtype, tf.float32, tf.bool],
                name='step')
            observ, reward, done = step_outputs
            observ = tf.check_numerics(observ, 'observ')
            reward = tf.check_numerics(reward, 'reward')
            # py_func erases static shapes; restore the batch dimension.
            batch_size = len(self)
            reward.set_shape((batch_size,))
            done.set_shape((batch_size,))
            # Persist the new observations before exposing reward/done.
            with tf.control_dependencies([self._observ.assign(observ)]):
                return tf.identity(reward), tf.identity(done)
예제 #3
0
    def __init__(self, batch_env):
        """Batch of environments inside the TensorFlow graph.

        Stores the wrapped environment, caches action space metadata, and
        creates a non-trainable variable holding the current observations.

        Args:
          batch_env: Batch environment.
        """
        self._batch_env = batch_env
        act_space = self._batch_env.action_space
        obs_space = self._batch_env.observation_space
        self.action_shape = list(utils.parse_shape(act_space))
        self.action_dtype = utils.parse_dtype(act_space)
        obs_shape = utils.parse_shape(obs_space)
        obs_dtype = utils.parse_dtype(obs_space)
        with tf.variable_scope('env_temporary'):
            # Zero-initialized buffer, one row per environment in the batch.
            initial = tf.zeros((len(self._batch_env),) + obs_shape, obs_dtype)
            self._observ = tf.Variable(
                initial, name='observ', trainable=False)
예제 #4
0
  def _reset_non_empty(self, indices):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    obs_dtype = utils.parse_dtype(self._batch_env.observation_space)
    new_observ = tf.py_func(
        self._batch_env.reset, [indices], obs_dtype, name='reset')
    new_observ = tf.check_numerics(new_observ, 'observ')
    # Write the fresh observations into the buffer before returning them.
    update_op = tf.scatter_update(self._observ, indices, new_observ)
    with tf.control_dependencies([update_op]):
      return tf.identity(new_observ)
예제 #5
0
    def _reset_non_empty(self, indices):
        """Reset a subset of the batch of environments.

        Args:
          indices: The batch indices of the environments to reset; defaults
            to all.

        Returns:
          Batch tensor of the new observations.
        """
        obs_dtype = utils.parse_dtype(self._batch_env.observation_space)
        fresh = tf.py_func(
            self._batch_env.reset, [indices], obs_dtype, name='reset')
        fresh = tf.check_numerics(fresh, 'observ')
        # Update the observation buffer at the reset rows, then hand the
        # same observations back to the caller.
        scatter = tf.scatter_update(self._observ, indices, fresh)
        with tf.control_dependencies([scatter]):
            return tf.identity(fresh)
예제 #6
0
  def simulate(self, action):
    """Step the batch of environments with a batch of actions.

    The results of the step can be accessed from the variables defined
    below.

    Args:
      action: Tensor holding the batch of actions to apply.

    Returns:
      Tuple of tensors: the batch of rewards and of done flags.
    """
    with tf.name_scope('environment/simulate'):
      # Only float actions can carry NaN/Inf; validate them up front.
      if action.dtype in (tf.float16, tf.float32, tf.float64):
        action = tf.check_numerics(action, 'action')
      obs_dtype = utils.parse_dtype(self._batch_env.observation_space)
      outputs = tf.py_func(
          lambda a: self._batch_env.step(a)[:3], [action],
          [obs_dtype, tf.float32, tf.bool], name='step')
      observ, reward, done = outputs
      observ = tf.check_numerics(observ, 'observ')
      reward = tf.check_numerics(reward, 'reward')
      # Shapes coming out of py_func are unknown; pin the batch dimension.
      n = len(self)
      reward.set_shape((n,))
      done.set_shape((n,))
      # Commit the new observations before exposing reward/done.
      with tf.control_dependencies([self._observ.assign(observ)]):
        return tf.identity(reward), tf.identity(done)
예제 #7
0
 def action_dtype(self):
   """Return the dtype parsed from the environment's action space."""
   dtype = utils.parse_dtype(self.action_space)
   return dtype
예제 #8
0
 def observ_dtype(self):
   """Return the dtype parsed from the environment's observation space."""
   dtype = utils.parse_dtype(self.observ_space)
   return dtype