def create_exploration_agent(sess, environment, agent_name=None, summary_writer=None, debug_mode=False):
  """Creates an exploration agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create. Supported names are the
      intrinsic-reward variants (dqn_cts, rainbow_cts, dqn_pixelcnn,
      rainbow_pixelcnn, dqn_rnd, rainbow_rnd) and the noisy-network variants
      (noisy_dqn, noisy_rainbow); any other name is delegated to
      `run_experiment.create_agent`.
    summary_writer: A Tensorflow summary writer to pass to the agent for
      in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    AssertionError: If `agent_name` is None.
    ValueError: If `agent_name` is not in the supported list (raised by the
      `run_experiment.create_agent` fallback).
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  # Dispatch table replaces the original if/elif chain, which had a stray
  # `if` in the middle (harmless only because every branch returned).
  agent_classes = {
      'dqn_cts': intrinsic_dqn_agent.CTSDQNAgent,
      'rainbow_cts': intrinsic_rainbow_agent.CTSRainbowAgent,
      'dqn_pixelcnn': intrinsic_dqn_agent.PixelCNNDQNAgent,
      'rainbow_pixelcnn': intrinsic_rainbow_agent.PixelCNNRainbowAgent,
      'dqn_rnd': intrinsic_dqn_agent.RNDDQNAgent,
      'rainbow_rnd': intrinsic_rainbow_agent.RNDRainbowAgent,
      'noisy_dqn': noisy_dqn_agent.NoisyDQNAgent,
      'noisy_rainbow': noisy_rainbow_agent.NoisyRainbowAgent,
  }
  if agent_name in agent_classes:
    return agent_classes[agent_name](
        sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  # Unrecognized names fall back to the standard Dopamine agents.
  return run_experiment.create_agent(sess, environment, agent_name,
                                     summary_writer, debug_mode)
def testCreateImplicitQuantileAgent(self, mock_implicit_quantile_agent):
  """Checks create_agent wires num_actions into the IQN constructor."""
  # Stand-in constructor: returns a value derived from num_actions so we
  # can verify the argument plumbing without building a real agent.
  def fake_constructor(unused_sess, num_actions, summary_writer):
    del summary_writer
    return num_actions * 10
  mock_implicit_quantile_agent.side_effect = fake_constructor
  env = mock.Mock()
  env.action_space.n = 7
  agent = run_experiment.create_agent(
      self.test_session(), env, agent_name='implicit_quantile')
  self.assertEqual(70, agent)
def create_parametric_agent(sess, environment, agent_name=None, summary_writer=None, debug_mode=False):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent for
      in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  # Known parametric variants, dispatched by name.
  parametric_classes = {
      'parametric_dqn': parametric_agents.ParametricDQNAgent,
      'parametric_rainbow': parametric_agents.ParametricRainbowAgent,
      'parametric_implicit_quantile':
          parametric_agents.ParametricImplicitQuantileAgent,
  }
  constructor = parametric_classes.get(agent_name)
  if constructor is not None:
    return constructor(
        sess,
        num_actions=environment.action_space.n,
        environment=environment.environment,
        summary_writer=summary_writer)
  # Anything else is handled by the standard Dopamine factory.
  return run_experiment.create_agent(sess, environment, agent_name,
                                     summary_writer, debug_mode)
def testDefaultDQNConfig(self):
  """Verify the default DQN configuration."""
  run_experiment.load_gin_configs(
      ['dopamine/agents/dqn/configs/dqn.gin'], [])
  agent = run_experiment.create_agent(
      self.test_session(),
      atari_lib.create_atari_environment(game_name='Pong'))
  # (actual, expected) pairs for the gin-configured defaults.
  checks = [
      (agent.gamma, 0.99),
      (agent.update_horizon, 1),
      (agent.min_replay_history, 20000),
      (agent.update_period, 4),
      (agent.target_update_period, 8000),
      (agent.epsilon_train, 0.01),
      (agent.epsilon_eval, 0.001),
      (agent.epsilon_decay_period, 250000),
      (agent._replay.memory._replay_capacity, 1000000),
      (agent._replay.memory._batch_size, 32),
  ]
  for actual, expected in checks:
    self.assertEqual(actual, expected)
def testDefaultRainbowConfig(self):
  """Verify the default Rainbow configuration."""
  run_experiment.load_gin_configs(
      ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
  agent = run_experiment.create_agent(
      self.test_session(),
      atari_lib.create_atari_environment(game_name='Pong'))
  # Distributional head: 51 atoms spanning [-10, 10].
  self.assertEqual(agent._num_atoms, 51)
  support = self.evaluate(agent._support)
  self.assertEqual(min(support), -10.)
  self.assertEqual(max(support), 10.)
  self.assertEqual(len(support), 51)
  # (actual, expected) pairs for the remaining gin-configured defaults.
  checks = [
      (agent.gamma, 0.99),
      (agent.update_horizon, 3),
      (agent.min_replay_history, 20000),
      (agent.update_period, 4),
      (agent.target_update_period, 8000),
      (agent.epsilon_train, 0.01),
      (agent.epsilon_eval, 0.001),
      (agent.epsilon_decay_period, 250000),
      (agent._replay.memory._replay_capacity, 1000000),
      (agent._replay.memory._batch_size, 32),
  ]
  for actual, expected in checks:
    self.assertEqual(actual, expected)
def testNoAgentName(self):
  """create_agent must refuse to run without an explicit agent name."""
  environment = mock.Mock()
  with self.assertRaises(AssertionError):
    run_experiment.create_agent(self.test_session(), environment)