def create_exploration_agent(sess, environment, agent_name=None,
                             summary_writer=None, debug_mode=False):
  """Creates an exploration agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create. Supported agents are
      dqn_cts, rainbow_cts, dqn_pixelcnn, rainbow_pixelcnn, dqn_rnd,
      rainbow_rnd, noisy_dqn, and noisy_rainbow; any other name is passed
      through to run_experiment.create_agent.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  if agent_name == 'dqn_cts':
    return intrinsic_dqn_agent.CTSDQNAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'rainbow_cts':
    return intrinsic_rainbow_agent.CTSRainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'dqn_pixelcnn':
    return intrinsic_dqn_agent.PixelCNNDQNAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'rainbow_pixelcnn':
    return intrinsic_rainbow_agent.PixelCNNRainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'dqn_rnd':
    return intrinsic_dqn_agent.RNDDQNAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'rainbow_rnd':
    return intrinsic_rainbow_agent.RNDRainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'noisy_dqn':
    return noisy_dqn_agent.NoisyDQNAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'noisy_rainbow':
    return noisy_rainbow_agent.NoisyRainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    return run_experiment.create_agent(sess, environment, agent_name,
                                       summary_writer, debug_mode)
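A minimal usage sketch, assuming the Dopamine Atari helpers used in the later examples (atari_lib.create_atari_environment) and a TF1-style session; the begin_episode call assumes the returned agent follows the standard Dopamine agent interface.

import tensorflow.compat.v1 as tf

# Sketch only: build an Atari environment and hand it to the factory above.
environment = atari_lib.create_atari_environment(game_name='Pong')
with tf.Session() as sess:
  agent = create_exploration_agent(sess, environment, agent_name='rainbow_cts')
  sess.run(tf.global_variables_initializer())
  # Standard Dopamine agent loop entry point.
  observation = environment.reset()
  action = agent.begin_episode(observation)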
Example #2
    @mock.patch.object(implicit_quantile_agent, 'ImplicitQuantileAgent')
    def testCreateImplicitQuantileAgent(self, mock_implicit_quantile_agent):
        def mock_fn(unused_sess, num_actions, summary_writer):
            del summary_writer
            return num_actions * 10

        mock_implicit_quantile_agent.side_effect = mock_fn
        environment = mock.Mock()
        environment.action_space.n = 7
        self.assertEqual(
            70,
            run_experiment.create_agent(self.test_session(),
                                        environment,
                                        agent_name='implicit_quantile'))
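The same mocking pattern could cover the exploration factory from the first example; a minimal sketch, assuming the test sits in the same tf.test-style test case and that mock.patch.object stands in for the CTSDQNAgent constructor referenced there (the test name is hypothetical).

    @mock.patch.object(intrinsic_dqn_agent, 'CTSDQNAgent')
    def testCreateCTSDQNAgent(self, mock_cts_dqn_agent):
        # The mock replaces the agent constructor, so only the factory's
        # wiring (num_actions taken from the environment) is exercised.
        def mock_fn(unused_sess, num_actions, summary_writer):
            del summary_writer
            return num_actions * 10

        mock_cts_dqn_agent.side_effect = mock_fn
        environment = mock.Mock()
        environment.action_space.n = 4
        self.assertEqual(
            40,
            create_exploration_agent(self.test_session(), environment,
                                     agent_name='dqn_cts'))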
Example #3
def create_parametric_agent(sess,
                            environment,
                            agent_name=None,
                            summary_writer=None,
                            debug_mode=False):
    """Creates an agent.

    Args:
        sess: A `tf.Session` object for running associated ops.
        environment: A gym environment (e.g. Atari 2600).
        agent_name: str, name of the agent to create. Supported agents are
            parametric_dqn, parametric_rainbow, and
            parametric_implicit_quantile; any other name is passed through to
            run_experiment.create_agent.
        summary_writer: A Tensorflow summary writer to pass to the agent
            for in-agent training statistics in Tensorboard.
        debug_mode: bool, whether to output Tensorboard summaries. If set to
            true, the agent will output in-episode statistics to Tensorboard.
            Disabled by default as this results in slower training.

    Returns:
        agent: An RL agent.

    Raises:
        ValueError: If `agent_name` is not in supported list.
    """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'parametric_dqn':
        return parametric_agents.ParametricDQNAgent(
            sess,
            num_actions=environment.action_space.n,
            environment=environment.environment,
            summary_writer=summary_writer)
    elif agent_name == 'parametric_rainbow':
        return parametric_agents.ParametricRainbowAgent(
            sess,
            num_actions=environment.action_space.n,
            environment=environment.environment,
            summary_writer=summary_writer)
    elif agent_name == 'parametric_implicit_quantile':
        return parametric_agents.ParametricImplicitQuantileAgent(
            sess,
            num_actions=environment.action_space.n,
            environment=environment.environment,
            summary_writer=summary_writer)
    else:
        return run_experiment.create_agent(sess, environment, agent_name,
                                           summary_writer, debug_mode)
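A minimal usage sketch, assuming the environment is a Dopamine-style wrapper that exposes both action_space and the underlying gym environment via .environment, as the calls above require; atari_lib.create_atari_environment (used in the later examples) returns such a wrapper.

import tensorflow.compat.v1 as tf

# Sketch only: the wrapper's .environment attribute is what the parametric
# agents above receive alongside num_actions.
environment = atari_lib.create_atari_environment(game_name='Pong')
with tf.Session() as sess:
    agent = create_parametric_agent(
        sess, environment, agent_name='parametric_rainbow')
    sess.run(tf.global_variables_initializer())
    # Any name outside the parametric_* set falls through to
    # run_experiment.create_agent instead.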
Example #4
 def testDefaultDQNConfig(self):
     """Verify the default DQN configuration."""
     run_experiment.load_gin_configs(
         ['dopamine/agents/dqn/configs/dqn.gin'], [])
     agent = run_experiment.create_agent(
         self.test_session(),
         atari_lib.create_atari_environment(game_name='Pong'))
     self.assertEqual(agent.gamma, 0.99)
     self.assertEqual(agent.update_horizon, 1)
     self.assertEqual(agent.min_replay_history, 20000)
     self.assertEqual(agent.update_period, 4)
     self.assertEqual(agent.target_update_period, 8000)
     self.assertEqual(agent.epsilon_train, 0.01)
     self.assertEqual(agent.epsilon_eval, 0.001)
     self.assertEqual(agent.epsilon_decay_period, 250000)
     self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
     self.assertEqual(agent._replay.memory._batch_size, 32)
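load_gin_configs also accepts a second list of gin bindings, so defaults like the ones verified above can be overridden without editing the config file. A minimal sketch, assuming the same dqn.gin path and that DQNAgent.gamma is the binding name used in that file (the test name is hypothetical).

 def testGinBindingsOverrideDefault(self):
     """Sketch: a gin binding passed to load_gin_configs overrides dqn.gin."""
     run_experiment.load_gin_configs(
         ['dopamine/agents/dqn/configs/dqn.gin'],
         ['DQNAgent.gamma = 0.9'])
     agent = run_experiment.create_agent(
         self.test_session(),
         atari_lib.create_atari_environment(game_name='Pong'))
     self.assertEqual(agent.gamma, 0.9)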
Example #5
 def testDefaultRainbowConfig(self):
     """Verify the default Rainbow configuration."""
     run_experiment.load_gin_configs(
         ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
     agent = run_experiment.create_agent(
         self.test_session(),
         atari_lib.create_atari_environment(game_name='Pong'))
     self.assertEqual(agent._num_atoms, 51)
     support = self.evaluate(agent._support)
     self.assertEqual(min(support), -10.)
     self.assertEqual(max(support), 10.)
     self.assertEqual(len(support), 51)
     self.assertEqual(agent.gamma, 0.99)
     self.assertEqual(agent.update_horizon, 3)
     self.assertEqual(agent.min_replay_history, 20000)
     self.assertEqual(agent.update_period, 4)
     self.assertEqual(agent.target_update_period, 8000)
     self.assertEqual(agent.epsilon_train, 0.01)
     self.assertEqual(agent.epsilon_eval, 0.001)
     self.assertEqual(agent.epsilon_decay_period, 250000)
     self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
     self.assertEqual(agent._replay.memory._batch_size, 32)
Example #6
 def testNoAgentName(self):
     with self.assertRaises(AssertionError):
         _ = run_experiment.create_agent(self.test_session(), mock.Mock())