Code example #1
def create_agent_fn(sess, environment, summary_writer):
  """Creates the appropriate agent."""
  # `agent_type` is expected to be bound in the enclosing scope (e.g. an
  # argument of the surrounding runner-creation function).
  if agent_type == 'dqn':
    return dqn_agent.DQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'iqn':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'al_dqn':
    return al_dqn.ALDQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'al_iqn':
    return al_iqn.ALImplicitQuantileAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'sail_dqn':
    return sail_dqn.SAILDQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'sail_iqn':
    return sail_iqn.SAILImplicitQuantileAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    raise ValueError('Wrong agent %s' % agent_type)
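In Dopamine, a factory like this is normally handed to the framework's Runner, which calls it with a session, an environment, and a summary writer. A minimal usage sketch, assuming the standard dopamine.discrete_domains.run_experiment API and gin configuration for the environment; the output directory is a placeholder and nothing below comes from the project above:

from dopamine.discrete_domains import run_experiment

# Runner builds the environment (from gin settings), opens a TF session, and
# calls create_agent_fn(sess, environment, summary_writer=...) internally.
runner = run_experiment.Runner(
    base_dir='/tmp/agent_experiment',  # placeholder checkpoint/log directory
    create_agent_fn=create_agent_fn)
runner.run_experiment()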
Code example #2
File: train_rgb.py  Project: kbehouse/dopamine
def create_agent(sess, environment, summary_writer=None):
    """Creates an RL agent (DQN, Rainbow-RGB, or implicit quantile).

    Args:
      sess: A `tf.Session` object for running associated ops.
      environment: The training environment (not used in this snippet; the
        number of actions is hard-coded to 5).
      summary_writer: A Tensorflow summary writer to pass to the agent
        for in-agent training statistics in Tensorboard.

    Returns:
      agent: An RL agent.

    Raises:
      ValueError: If `FLAGS.agent_name` is not in the supported list.
    """
    if not FLAGS.debug_mode:
        summary_writer = None
    if FLAGS.agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=5,
                                  summary_writer=summary_writer)
    elif FLAGS.agent_name == 'rainbow':
        return rainbow_rgb_agent.RainbowRGBAgent(sess,
                                                 num_actions=5,
                                                 summary_writer=summary_writer)
    elif FLAGS.agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess, num_actions=5, summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
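The function reads FLAGS.agent_name and FLAGS.debug_mode, so the surrounding module must define those flags with absl. A hedged sketch of what those definitions could look like (default values are illustrative, not taken from kbehouse/dopamine):

from absl import flags

flags.DEFINE_string('agent_name', 'dqn', 'Name of the agent to create.')
flags.DEFINE_bool('debug_mode', False,
                  'If True, write in-agent Tensorboard summaries.')

FLAGS = flags.FLAGS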
Code example #3
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=environment.action_space.n,
                                  summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
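A hedged usage sketch for this factory, assuming Dopamine's Atari helpers are installed; the game and agent choices are illustrative only:

import tensorflow as tf
from dopamine.discrete_domains import atari_lib

environment = atari_lib.create_atari_environment(game_name='Pong')
with tf.compat.v1.Session() as sess:
  # The TF agents ('dqn', 'rainbow', 'implicit_quantile') need the session;
  # the 'jax_*' branches above ignore it entirely.
  agent = create_agent(sess, environment, agent_name='dqn')
  sess.run(tf.compat.v1.global_variables_initializer())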
Code example #4
def testCreateAgentWithDefaults(self):
  # Verifies that we can create and train an agent with the default values.
  with self.test_session(use_gpu=False) as sess:
    agent = implicit_quantile_agent.ImplicitQuantileAgent(sess, num_actions=4)
    sess.run(tf.compat.v1.global_variables_initializer())
    observation = np.ones([84, 84, 1])
    agent.begin_episode(observation)
    agent.step(reward=1, observation=observation)
    agent.end_episode(reward=1)
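The test drives the agent's begin_episode/step/end_episode interface by hand; in normal training that cycle is driven by an environment loop, roughly as in this simplified, hedged sketch of what Dopamine's Runner does per episode:

observation = environment.reset()
action = agent.begin_episode(observation)
done = False
while not done:
  observation, reward, done, _ = environment.step(action)
  if done:
    agent.end_episode(reward)
  else:
    action = agent.step(reward, observation)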
Code example #5
def create_agent(sess, agent_name, num_actions,
                 observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
                 observation_dtype=atari_lib.NATURE_DQN_DTYPE,
                 stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
                 summary_writer=None):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    agent_name: str, name of the agent to create.
    num_actions: int, number of actions the agent can take at any state.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    observation_shape: tuple of ints describing the observation shape.
    observation_dtype: tf.DType, specifies the type of the observations. Note
      that if your inputs are continuous, you should set this to tf.float32.
    stack_size: int, number of frames to use in state stack.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list or one of the
      GAIRL submodules is not in supported list when the chosen agent is GAIRL.
  """
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(
      sess, num_actions, observation_shape=observation_shape,
      observation_dtype=observation_dtype, stack_size=stack_size,
      summary_writer=summary_writer
    )
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
      sess, num_actions, observation_shape=observation_shape,
      observation_dtype=observation_dtype, stack_size=stack_size,
      summary_writer=summary_writer
    )
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
      sess, num_actions, summary_writer=summary_writer
    )
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
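The observation_shape, observation_dtype, and stack_size arguments are what make this variant usable outside Atari. A hedged example with made-up values for a 64x64 grayscale, float-valued environment, assuming a tf.compat.v1 session named sess is already open:

agent = create_agent(
    sess,
    agent_name='dqn',
    num_actions=4,
    observation_shape=(64, 64),    # non-Atari frame size
    observation_dtype=tf.float32,  # continuous pixel intensities
    stack_size=2)                  # stack two frames instead of the usual four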
Code example #6
    def _create_test_agent(self, sess):
        # This dummy network allows us to deterministically anticipate that the
        # state-action-quantile outputs will be equal to sum of the
        # corresponding quantile inputs.
        # State/Quantile shapes will be k x 1, (N x batch_size) x 1,
        # or (N' x batch_size) x 1.

        class MockImplicitQuantileNetwork(tf.keras.Model):
            """Custom tf.keras.Model used in tests."""
            def __init__(self, num_actions, quantile_embedding_dim, **kwargs):
                # With all-ones kernel weights every action receives the same
                # value, so the argmax deterministically resolves to action 0.
                super(MockImplicitQuantileNetwork, self).__init__(**kwargs)
                self.num_actions = num_actions
                self.layer = tf.keras.layers.Dense(
                    self.num_actions,
                    kernel_initializer=tf.ones_initializer(),
                    bias_initializer=tf.zeros_initializer())

            def call(self, state, num_quantiles):
                batch_size = state.get_shape().as_list()[0]
                inputs = tf.constant(np.ones(
                    (batch_size * num_quantiles, self.num_actions)),
                                     dtype=tf.float32)
                quantiles_shape = [num_quantiles * batch_size, 1]
                quantiles = tf.ones(quantiles_shape)
                return atari_lib.ImplicitQuantileNetworkType(
                    self.layer(inputs), quantiles)

        agent = implicit_quantile_agent.ImplicitQuantileAgent(
            sess=sess,
            network=MockImplicitQuantileNetwork,
            num_actions=self._num_actions,
            kappa=1.0,
            num_tau_samples=2,
            num_tau_prime_samples=3,
            num_quantile_samples=4)
        # This ensures non-random action choices (since epsilon_eval = 0.0) and
        # skips the train_step.
        agent.eval_mode = True
        sess.run(tf.global_variables_initializer())
        return agent
Code example #7
def create_agent(sess, environment):
    """Creates a DQN agent.

    Args:
      sess: A `tf.Session` object for running associated ops.
      environment: An Atari 2600 Gym environment.

    Returns:
      agent: An RL agent.

    Raises:
      ValueError: If `agent_name` is not in supported list.
    """
    if FLAGS.agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n)
    elif FLAGS.agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess, num_actions=environment.action_space.n)
    elif FLAGS.agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess, num_actions=environment.action_space.n)
    else:
        raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))