示例#1
0
  def testCreateRainbowAgent(self):
    """Checks create_agent's result when agent_name is set to 'rainbow'.

    train.rainbow_agent is patched so that calling RainbowAgent returns ten
    times the num_actions argument it receives; with a 7-action environment
    the test expects create_agent to hand back 70.
    """
    FLAGS.agent_name = 'rainbow'
    fake_env = mock.Mock()
    fake_env.action_space.n = 7
    with mock.patch.object(train, 'rainbow_agent') as patched_module:
      # Parameter names are kept as-is so keyword-style calls still bind.
      patched_module.RainbowAgent.side_effect = (
          lambda unused_sess, num_actions: num_actions * 10)
      built = train.create_agent(self.test_session(), fake_env)
      self.assertEqual(70, built)
示例#2
0
    def testCreateRainbowAgent(self):
        """With agent_name 'rainbow', create_agent should yield the stub's value."""
        FLAGS.agent_name = 'rainbow'

        def fake_rainbow_ctor(unused_sess, num_actions):
            # Encode the action count in the return value so the assertion
            # below can tell which count reached the constructor.
            return num_actions * 10

        with mock.patch.object(train, 'rainbow_agent') as rainbow_module:
            rainbow_module.RainbowAgent.side_effect = fake_rainbow_ctor
            env = mock.Mock()
            env.action_space.n = 7
            session = self.test_session()
            result = train.create_agent(session, env)
            self.assertEqual(70, result)
示例#3
0
  def testCreateDQNAgent(self):
    """Checks create_agent's result when agent_name is set to 'dqn'.

    train.dqn_agent is patched so that calling DQNAgent returns ten times the
    num_actions argument it receives; with a 7-action environment the test
    expects create_agent to hand back 70.
    """
    FLAGS.agent_name = 'dqn'
    fake_env = mock.Mock()
    fake_env.action_space.n = 7
    with mock.patch.object(train, 'dqn_agent') as patched_module:
      # summary_writer is accepted (so the call binds) but otherwise ignored.
      patched_module.DQNAgent.side_effect = (
          lambda unused_sess, num_actions, summary_writer: num_actions * 10)
      built = train.create_agent(self.test_session(), fake_env)
      self.assertEqual(70, built)
示例#4
0
    def testCreateDQNAgent(self):
        """With agent_name 'dqn', create_agent should yield the stub's value."""
        FLAGS.agent_name = 'dqn'

        def fake_dqn_ctor(unused_sess, num_actions, summary_writer):
            # The writer is part of the signature only so the call binds.
            del summary_writer
            return num_actions * 10

        with mock.patch.object(train, 'dqn_agent') as dqn_module:
            dqn_module.DQNAgent.side_effect = fake_dqn_ctor
            env = mock.Mock()
            env.action_space.n = 7
            session = self.test_session()
            result = train.create_agent(session, env)
            self.assertEqual(70, result)
 def testDefaultDQNConfig(self):
   """Check the agent attributes obtained after loading dqn.gin."""
   FLAGS.agent_name = 'dqn'
   gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
   run_experiment.load_gin_configs(gin_files, [])
   # Session first, then environment: same evaluation order as a single
   # nested call expression.
   session = self.test_session()
   pong_env = run_experiment.create_atari_environment('Pong')
   agent = train.create_agent(session, pong_env)

   # (attribute name, expected value) pairs, asserted in this order.
   expected_attributes = (
       ('gamma', 0.99),
       ('update_horizon', 1),
       ('min_replay_history', 20000),
       ('update_period', 4),
       ('target_update_period', 8000),
       ('epsilon_train', 0.01),
       ('epsilon_eval', 0.001),
       ('epsilon_decay_period', 250000),
   )
   for attr_name, expected_value in expected_attributes:
     self.assertEqual(getattr(agent, attr_name), expected_value)

   # Replay buffer settings live on the nested memory object.
   replay_memory = agent._replay.memory
   self.assertEqual(replay_memory._replay_capacity, 1000000)
   self.assertEqual(replay_memory._batch_size, 32)
示例#6
0
 def testDefaultDQNConfig(self):
   """Asserts the values exposed by a DQN agent built from dqn.gin."""
   FLAGS.agent_name = 'dqn'
   config_paths = ['dopamine/agents/dqn/configs/dqn.gin']
   no_bindings = []
   run_experiment.load_gin_configs(config_paths, no_bindings)

   sess = self.test_session()
   environment = run_experiment.create_atari_environment('Pong')
   dqn = train.create_agent(sess, environment)

   # Discounting and update schedule.
   self.assertEqual(dqn.gamma, 0.99)
   self.assertEqual(dqn.update_horizon, 1)
   self.assertEqual(dqn.min_replay_history, 20000)
   self.assertEqual(dqn.update_period, 4)
   self.assertEqual(dqn.target_update_period, 8000)

   # Epsilon settings.
   self.assertEqual(dqn.epsilon_train, 0.01)
   self.assertEqual(dqn.epsilon_eval, 0.001)
   self.assertEqual(dqn.epsilon_decay_period, 250000)

   # Replay buffer settings.
   self.assertEqual(dqn._replay.memory._replay_capacity, 1000000)
   self.assertEqual(dqn._replay.memory._batch_size, 32)
 def testDefaultRainbowConfig(self):
   """Check the agent attributes obtained after loading rainbow.gin."""
   FLAGS.agent_name = 'rainbow'
   gin_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
   run_experiment.load_gin_configs(gin_files, [])
   # Session first, then environment: same evaluation order as a single
   # nested call expression.
   session = self.test_session()
   pong_env = run_experiment.create_atari_environment('Pong')
   agent = train.create_agent(session, pong_env)

   # The support is evaluated to concrete values before its bounds and
   # length are inspected.
   self.assertEqual(agent._num_atoms, 51)
   support = self.evaluate(agent._support)
   self.assertEqual(min(support), -10.)
   self.assertEqual(max(support), 10.)
   self.assertEqual(len(support), 51)

   # (attribute name, expected value) pairs, asserted in this order.
   expected_attributes = (
       ('gamma', 0.99),
       ('update_horizon', 3),
       ('min_replay_history', 20000),
       ('update_period', 4),
       ('target_update_period', 8000),
       ('epsilon_train', 0.01),
       ('epsilon_eval', 0.001),
       ('epsilon_decay_period', 250000),
   )
   for attr_name, expected_value in expected_attributes:
     self.assertEqual(getattr(agent, attr_name), expected_value)

   # Replay buffer settings live on the nested memory object.
   replay_memory = agent._replay.memory
   self.assertEqual(replay_memory._replay_capacity, 1000000)
   self.assertEqual(replay_memory._batch_size, 32)
示例#8
0
 def testDefaultRainbowConfig(self):
   """Asserts the values exposed by a Rainbow agent built from rainbow.gin."""
   FLAGS.agent_name = 'rainbow'
   config_paths = ['dopamine/agents/rainbow/configs/rainbow.gin']
   no_bindings = []
   run_experiment.load_gin_configs(config_paths, no_bindings)

   sess = self.test_session()
   environment = run_experiment.create_atari_environment('Pong')
   rainbow = train.create_agent(sess, environment)

   # Atom count and evaluated support.
   self.assertEqual(rainbow._num_atoms, 51)
   support_values = self.evaluate(rainbow._support)
   self.assertEqual(min(support_values), -10.)
   self.assertEqual(max(support_values), 10.)
   self.assertEqual(len(support_values), 51)

   # Discounting and update schedule.
   self.assertEqual(rainbow.gamma, 0.99)
   self.assertEqual(rainbow.update_horizon, 3)
   self.assertEqual(rainbow.min_replay_history, 20000)
   self.assertEqual(rainbow.update_period, 4)
   self.assertEqual(rainbow.target_update_period, 8000)

   # Epsilon settings.
   self.assertEqual(rainbow.epsilon_train, 0.01)
   self.assertEqual(rainbow.epsilon_eval, 0.001)
   self.assertEqual(rainbow.epsilon_decay_period, 250000)

   # Replay buffer settings.
   self.assertEqual(rainbow._replay.memory._replay_capacity, 1000000)
   self.assertEqual(rainbow._replay.memory._batch_size, 32)