def testCreateRainbowAgent(self):
    """create_agent builds the agent via rainbow_agent.RainbowAgent.

    The RainbowAgent constructor is stubbed to return ``num_actions * 10``.
    The expected result is 70 for an environment with 7 actions, so this
    checks that create_agent picks the Rainbow branch for
    FLAGS.agent_name == 'rainbow', passes the action count through, and
    returns whatever the constructor produced.

    (This single definition replaces a duplicated copy of the same test
    that was being shadowed.)
    """
    FLAGS.agent_name = 'rainbow'
    with mock.patch.object(train, 'rainbow_agent') as mock_rainbow_agent:

        # Stand-in for RainbowAgent(...): encodes num_actions in the result.
        def mock_fn(unused_sess, num_actions):
            return num_actions * 10

        mock_rainbow_agent.RainbowAgent.side_effect = mock_fn
        environment = mock.Mock()
        environment.action_space.n = 7
        self.assertEqual(
            70, train.create_agent(self.test_session(), environment))
def testCreateDQNAgent(self):
    """create_agent builds the agent via dqn_agent.DQNAgent.

    The DQNAgent constructor is stubbed to return ``num_actions * 10``.
    The expected result is 70 for an environment with 7 actions, so this
    checks that create_agent picks the DQN branch for
    FLAGS.agent_name == 'dqn', passes the action count through, and returns
    whatever the constructor produced.

    (This single definition replaces a duplicated copy of the same test
    that was being shadowed.)
    """
    FLAGS.agent_name = 'dqn'
    with mock.patch.object(train, 'dqn_agent') as mock_dqn_agent:

        # Stand-in for DQNAgent(...). It accepts summary_writer, which
        # create_agent presumably supplies, but ignores it.
        def mock_fn(unused_sess, num_actions, summary_writer):
            del summary_writer
            return num_actions * 10

        mock_dqn_agent.DQNAgent.side_effect = mock_fn
        environment = mock.Mock()
        environment.action_space.n = 7
        self.assertEqual(
            70, train.create_agent(self.test_session(), environment))
def testDefaultDQNConfig(self):
    """Verify the default DQN configuration.

    Loads dopamine/agents/dqn/configs/dqn.gin, builds a DQN agent for Pong
    through train.create_agent, and checks its hyperparameters and replay
    memory settings against the expected defaults.
    """
    FLAGS.agent_name = 'dqn'
    run_experiment.load_gin_configs(
        ['dopamine/agents/dqn/configs/dqn.gin'], [])
    # Bind the session before creating the environment so both are
    # evaluated in the same order as passing them inline would.
    session = self.test_session()
    pong_env = run_experiment.create_atari_environment('Pong')
    agent = train.create_agent(session, pong_env)

    expected_attributes = (
        ('gamma', 0.99),
        ('update_horizon', 1),
        ('min_replay_history', 20000),
        ('update_period', 4),
        ('target_update_period', 8000),
        ('epsilon_train', 0.01),
        ('epsilon_eval', 0.001),
        ('epsilon_decay_period', 250000),
    )
    for attribute, expected_value in expected_attributes:
        self.assertEqual(getattr(agent, attribute), expected_value)

    memory = agent._replay.memory
    self.assertEqual(memory._replay_capacity, 1000000)
    self.assertEqual(memory._batch_size, 32)
def testDefaultRainbowConfig(self):
    """Verify the default Rainbow configuration.

    Loads dopamine/agents/rainbow/configs/rainbow.gin, builds a Rainbow agent
    for Pong through train.create_agent, and checks the distributional support
    (51 atoms spanning [-10, 10]), the hyperparameters and the replay memory
    settings against the expected defaults.
    """
    FLAGS.agent_name = 'rainbow'
    run_experiment.load_gin_configs(
        ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
    # Bind the session before creating the environment so both are
    # evaluated in the same order as passing them inline would.
    session = self.test_session()
    pong_env = run_experiment.create_atari_environment('Pong')
    agent = train.create_agent(session, pong_env)

    self.assertEqual(agent._num_atoms, 51)
    support_values = self.evaluate(agent._support)
    self.assertEqual(min(support_values), -10.)
    self.assertEqual(max(support_values), 10.)
    self.assertEqual(len(support_values), 51)

    expected_attributes = (
        ('gamma', 0.99),
        ('update_horizon', 3),
        ('min_replay_history', 20000),
        ('update_period', 4),
        ('target_update_period', 8000),
        ('epsilon_train', 0.01),
        ('epsilon_eval', 0.001),
        ('epsilon_decay_period', 250000),
    )
    for attribute, expected_value in expected_attributes:
        self.assertEqual(getattr(agent, attribute), expected_value)

    memory = agent._replay.memory
    self.assertEqual(memory._replay_capacity, 1000000)
    self.assertEqual(memory._batch_size, 32)