def _prepare_networks(self, hparams, sess):
    """Build the simulated-environment graph and the policy network on top of it.

    Attributes set on ``self``:
      action: int32 placeholder of shape (1,) fed to ``simulate``.
      reward, done: outputs of ``simulate(action)`` on the *unwrapped*
        ``SimulatedBatchEnv``.
      observation: ``observ`` of the unwrapped env.
      reset_op: reset op of the unwrapped env, fed the constant index ``[0]``.
      initialize: zero-argument callable that calls ``initialize(sess)`` on the
        base env and then on every wrapper layer, innermost first.
      policy_probs, value: outputs of ``get_policy`` applied to the observation
        of the outermost (fully wrapped) env.

    Args:
      hparams: hyperparameters; ``environment_spec`` (including its optional
        ``wrappers`` sequence of ``(callable, kwargs)`` pairs) and
        ``num_agents`` are read here, and the object is passed to
        ``get_policy``.
      sess: session captured by ``self.initialize``.
    """
    spec = hparams.environment_spec

    # NOTE(review): the action placeholder has shape (1,) while the env is built
    # with hparams.num_agents; presumably num_agents == 1 here -- confirm.
    self.action = tf.placeholder(shape=(1,), dtype=tf.int32)
    base_env = SimulatedBatchEnv(spec, hparams.num_agents)
    self.reward, self.done = base_env.simulate(self.action)
    self.observation = base_env.observ
    self.reset_op = base_env.reset(tf.constant([0], dtype=tf.int32))

    # Stack the configured wrappers on top of the base env. Every layer is kept
    # (innermost first) so that each one can be initialized later.
    wrapper_specs = spec.wrappers
    layers = [base_env]
    for wrapper_spec in (copy.copy(wrapper_specs) if wrapper_specs else []):
        layers.append(wrapper_spec[0](layers[-1], **wrapper_spec[1]))
    outermost = layers[-1]

    def _initialize_layers():
        for layer in layers:
            layer.initialize(sess)

    self.initialize = _initialize_layers

    # "+ 0" creates a new tensor derived from the outermost observation rather
    # than using ``observ`` directly (presumably to decouple from a variable
    # read -- TODO confirm). expand_dims then adds a leading axis of size 1
    # before the policy is built.
    observation_copy = outermost.observ + 0
    actor_critic = get_policy(tf.expand_dims(observation_copy, 0), hparams)
    self.policy_probs = actor_critic.policy.probs[0, 0, :]
    self.value = actor_critic.value[0, :]
# Example 2
class SimulatedBatchGymEnv(Env):
    """SimulatedBatchEnv in a Gym-like interface, environments are batched.

    The constructor builds a private TF graph around a ``SimulatedBatchEnv``
    (constructed from the given ``*args``/``**kwargs``) and creates a session
    for it. Call ``close()`` to release the session.
    """

    def __init__(self, *args, **kwargs):
        with tf.Graph().as_default():
            self._batch_env = SimulatedBatchEnv(*args, **kwargs)

            self._actions_t = tf.placeholder(shape=(self.batch_size, ),
                                             dtype=tf.int32)
            self._rewards_t, self._dones_t = self._batch_env.simulate(
                self._actions_t)
            # Create the observation read under a control dependency on the
            # step's reward so it is ordered after the simulation step (this
            # takes effect as long as ``observ`` builds a new op, e.g. a
            # variable read).
            with tf.control_dependencies([self._rewards_t]):
                self._obs_t = self._batch_env.observ
            # The reset op must consume this placeholder so that the indices
            # fed by reset() actually select which environments are reset.
            # A variable-length shape allows resetting any subset.
            self._indices_t = tf.placeholder(shape=(None, ), dtype=tf.int32)
            self._reset_op = self._batch_env.reset(self._indices_t)

            self._sess = tf.Session()
            self._sess.run(tf.global_variables_initializer())
            self._batch_env.initialize(self._sess)

    @property
    def batch_size(self):
        """Batch size of the underlying SimulatedBatchEnv."""
        return self._batch_env.batch_size

    @property
    def observation_space(self):
        """Observation space reported by the underlying SimulatedBatchEnv."""
        return self._batch_env.observ_space

    @property
    def action_space(self):
        """Action space reported by the underlying SimulatedBatchEnv."""
        return self._batch_env.action_space

    def render(self, mode="human"):
        raise NotImplementedError()

    def reset(self, indices=None):
        """Reset the environments at ``indices`` (all of them by default).

        Args:
          indices: 1-D sequence of int environment indices, or None to reset
            every environment in the batch.

        Returns:
          The value of the underlying env's reset op for the fed indices
          (presumably their fresh observations -- confirm against
          SimulatedBatchEnv.reset).
        """
        if indices is None:
            indices = np.arange(self.batch_size, dtype=np.int32)
        obs = self._sess.run(self._reset_op,
                             feed_dict={self._indices_t: indices})
        # TODO(pmilos): remove if possible
        # obs[:, 0, 0, 0] = 0
        # obs[:, 0, 0, 1] = 255
        return obs

    def step(self, actions):
        """Run one simulation step for the whole batch.

        Args:
          actions: one int action per environment (length ``batch_size``).

        Returns:
          Tuple ``(obs, rewards, dones)`` as evaluated by the session.
        """
        obs, rewards, dones = self._sess.run(
            [self._obs_t, self._rewards_t, self._dones_t],
            feed_dict={self._actions_t: actions})
        return obs, rewards, dones

    def close(self):
        """Close the TF session owned by this object."""
        self._sess.close()
class SimulatedBatchGymEnv(Env):
  """SimulatedBatchEnv in a Gym-like interface, environments are batched.

  The constructor builds a private TF graph around a ``SimulatedBatchEnv``
  (constructed from the given ``*args``/``**kwargs``) and creates a session
  for it. Call ``close()`` to release the session. Only whole-batch resets
  are supported.
  """

  def __init__(self, *args, **kwargs):
    with tf.Graph().as_default():
      self._batch_env = SimulatedBatchEnv(*args, **kwargs)

      self._actions_t = tf.placeholder(shape=(self.batch_size,), dtype=tf.int32)
      self._rewards_t, self._dones_t = self._batch_env.simulate(self._actions_t)
      # Create the observation read under a control dependency on the step's
      # reward so it is ordered after the simulation step (this takes effect as
      # long as ``observ`` builds a new op, e.g. a variable read).
      with tf.control_dependencies([self._rewards_t]):
        self._obs_t = self._batch_env.observ
      self._reset_op = self._batch_env.reset(
          tf.range(self.batch_size, dtype=tf.int32)
      )

      self._sess = tf.Session()
      self._sess.run(tf.global_variables_initializer())
      self._batch_env.initialize(self._sess)

  @property
  def batch_size(self):
    """Batch size of the underlying SimulatedBatchEnv."""
    return self._batch_env.batch_size

  @property
  def observation_space(self):
    """Observation space reported by the underlying SimulatedBatchEnv."""
    return self._batch_env.observ_space

  @property
  def action_space(self):
    """Action space reported by the underlying SimulatedBatchEnv."""
    return self._batch_env.action_space

  def render(self, mode="human"):
    raise NotImplementedError()

  def reset(self, indices=None):
    """Reset every environment in the batch.

    Args:
      indices: must be None; resetting a subset is not implemented.

    Returns:
      The value of the reset op evaluated over all environment indices.

    Raises:
      NotImplementedError: if ``indices`` is not None.
    """
    # Compare with None explicitly. Truthiness is wrong here: np.array([0]) is
    # falsy (it would silently reset the whole batch), and longer numpy arrays
    # raise an ambiguous-truth ValueError.
    if indices is not None:
      raise NotImplementedError(
          "Resetting a subset of environments is not supported.")
    obs = self._sess.run(self._reset_op)
    # TODO(pmilos): remove if possible
    # obs[:, 0, 0, 0] = 0
    # obs[:, 0, 0, 1] = 255
    return obs

  def step(self, actions):
    """Run one simulation step for the whole batch.

    Args:
      actions: one int action per environment (length ``batch_size``).

    Returns:
      Tuple ``(obs, rewards, dones)`` as evaluated by the session.
    """
    obs, rewards, dones = self._sess.run(
        [self._obs_t, self._rewards_t, self._dones_t],
        feed_dict={self._actions_t: actions})
    return obs, rewards, dones

  def close(self):
    """Close the TF session owned by this object."""
    self._sess.close()
# Example 4
class SimulatedBatchGymEnv(Env):
    """SimulatedBatchEnv in a Gym-like interface, environments are batched.

    Builds a private TF graph holding a ``SimulatedBatchEnv`` wrapped by the
    wrappers listed in ``environment_spec.wrappers``. Variables selected by
    the ``"next_frame*"`` scope filter are restored from ``model_dir``.

    Args:
      environment_spec: spec passed to ``SimulatedBatchEnv``. Its ``wrappers``
        attribute (may be empty or None) is a sequence of
        ``(callable, kwargs)`` pairs, applied in order.
      batch_size: batch size of the simulated environment.
      model_dir: checkpoint directory to restore from. The restore is
        mandatory (``must_restore=True``).
      sess: optional session. If omitted, a new session is created; it is
        owned by this object and released by ``close()``.
    """
    def __init__(self,
                 environment_spec,
                 batch_size,
                 model_dir=None,
                 sess=None):
        self.batch_size = batch_size
        # close() only closes the session if this object created it.
        self._owns_sess = sess is None

        with tf.Graph().as_default():
            self._batch_env = SimulatedBatchEnv(environment_spec,
                                                self.batch_size)

            self.action_space = self._batch_env.action_space
            # TODO(kc): check for the stack wrapper and correct number of channels in
            # observation_space
            self.observation_space = self._batch_env.observ_space
            # NOTE(review): the graph above is created inside this constructor,
            # so a caller-supplied ``sess`` will normally be bound to another
            # graph and the ``run`` calls below would fail. Confirm how
            # ``sess`` is meant to be used.
            self._sess = sess if sess is not None else tf.Session()
            self._to_initialize = [self._batch_env]

            environment_wrappers = environment_spec.wrappers
            wrappers = copy.copy(
                environment_wrappers) if environment_wrappers else []

            # Apply each (callable, kwargs) wrapper on top of the previous
            # layer. Every layer is kept so that it can be initialized below.
            for w in wrappers:
                self._batch_env = w[0](self._batch_env, **w[1])
                self._to_initialize.append(self._batch_env)

            # NOTE(review): variables created later by ``simulate`` are not
            # covered by this initializer. Presumably they are the
            # "next_frame*" model variables restored from the checkpoint
            # below -- confirm.
            self._sess.run(tf.global_variables_initializer())
            for wrapped_env in self._to_initialize:
                wrapped_env.initialize(self._sess)

            self._actions_t = tf.placeholder(shape=(batch_size, ),
                                             dtype=tf.int32)
            self._rewards_t, self._dones_t = self._batch_env.simulate(
                self._actions_t)
            # Create the observation read under a control dependency on the
            # step's reward so it is ordered after the simulation step (this
            # takes effect as long as ``observ`` builds a new op, e.g. a
            # variable read).
            with tf.control_dependencies([self._rewards_t]):
                self._obs_t = self._batch_env.observ
            self._reset_op = self._batch_env.reset(
                tf.range(batch_size, dtype=tf.int32))

            env_model_loader = tf.train.Saver(
                var_list=tf.global_variables(scope="next_frame*"))  # pylint:disable=unexpected-keyword-arg
            trainer_lib.restore_checkpoint(model_dir,
                                           saver=env_model_loader,
                                           sess=self._sess,
                                           must_restore=True)

    def render(self, mode="human"):
        raise NotImplementedError()

    def reset(self, indices=None):
        """Reset every environment in the batch.

        Args:
          indices: must be None; resetting a subset is not implemented.

        Returns:
          The value of the reset op evaluated over all environment indices.

        Raises:
          NotImplementedError: if ``indices`` is not None.
        """
        # Compare with None explicitly. Truthiness is wrong here:
        # np.array([0]) is falsy (it would silently reset the whole batch), and
        # longer numpy arrays raise an ambiguous-truth ValueError.
        if indices is not None:
            raise NotImplementedError(
                "Resetting a subset of environments is not supported.")
        obs = self._sess.run(self._reset_op)
        # TODO(pmilos): remove if possible
        # obs[:, 0, 0, 0] = 0
        # obs[:, 0, 0, 1] = 255
        return obs

    def step(self, actions):
        """Run one simulation step for the whole batch.

        Args:
          actions: one int action per environment (length ``batch_size``).

        Returns:
          Tuple ``(obs, rewards, dones)`` as evaluated by the session.
        """
        obs, rewards, dones = self._sess.run(
            [self._obs_t, self._rewards_t, self._dones_t],
            feed_dict={self._actions_t: actions})
        return obs, rewards, dones

    def close(self):
        """Close the TF session, but only if this object created it."""
        if self._owns_sess:
            self._sess.close()