Пример #1
0
 def testLoadGinConfigs(self, mock_parse_config_files_and_bindings):
     """Verify load_gin_configs forwards files and bindings to gin."""
     config_files = ['file1', 'file2', 'file3']
     config_bindings = ['binding1', 'binding2']
     run_experiment.load_gin_configs(config_files, config_bindings)
     # gin's parser must be invoked exactly once: files positionally,
     # bindings by keyword, and unknown bindings must not be skipped.
     self.assertEqual(mock_parse_config_files_and_bindings.call_count, 1)
     call_args, call_kwargs = mock_parse_config_files_and_bindings.call_args
     self.assertEqual(call_args[0], config_files)
     self.assertEqual(call_kwargs['bindings'], config_bindings)
     self.assertFalse(call_kwargs['skip_unknown'])
Пример #2
0
 def testLoadGinConfigs(self, mock_parse_config_files_and_bindings):
   """Test that load_gin_configs forwards files and bindings to gin."""
   gin_files = ['file1', 'file2', 'file3']
   gin_bindings = ['binding1', 'binding2']
   run_experiment.load_gin_configs(gin_files, gin_bindings)
   # gin's parser should be called exactly once, with the files passed
   # positionally, the bindings by keyword, and unknown bindings not skipped.
   self.assertEqual(1, mock_parse_config_files_and_bindings.call_count)
   mock_args, mock_kwargs = mock_parse_config_files_and_bindings.call_args
   self.assertEqual(gin_files, mock_args[0])
   self.assertEqual(gin_bindings, mock_kwargs['bindings'])
   self.assertFalse(mock_kwargs['skip_unknown'])
Пример #3
0
def launch_experiment(create_runner_fn, create_agent_fn):
  """Launches the experiment.

  Loads gin configs and bindings from flags, constructs a runner rooted at
  FLAGS.base_dir, and runs it to completion.

  Args:
    create_runner_fn: A function that takes as args a base directory and a
      function for creating an agent and returns a `Runner`-like object.
    create_agent_fn: A function that takes as args a Tensorflow session and
      an environment, and returns an agent.
  """
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = create_runner_fn(FLAGS.base_dir, create_agent_fn)
  runner.run_experiment()
Пример #4
0
def launch_experiment(create_runner_fn, create_agent_fn):
  """Loads gin configs from flags, builds the runner, and runs it.

  Args:
    create_runner_fn: Callable taking (base_dir, create_agent_fn) and
      returning a `Runner`-like object.
    create_agent_fn: Callable taking a Tensorflow session and an Atari 2600
      Gym environment, and returning an agent.
  """
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  experiment_runner = create_runner_fn(FLAGS.base_dir, create_agent_fn)
  experiment_runner.run_experiment()
Пример #5
0
 def testDefaultGinRainbow(self):
     """Test RainbowAgent default configuration using default gin."""
     tf.logging.info('####### Training the RAINBOW agent #####')
     tf.logging.info('####### RAINBOW base_dir: {}'.format(FLAGS.base_dir))
     FLAGS.agent_name = 'rainbow'
     FLAGS.gin_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
     FLAGS.gin_bindings = [
         'WrappedReplayBuffer.replay_capacity = 100'  # To prevent OOM.
     ]
     run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
     runner = run_experiment.Runner(FLAGS.base_dir, train.create_agent)
     # The default rainbow.gin is expected to configure an Adam optimizer
     # with a learning rate near 6.25e-5.
     self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
     self.assertNear(0.0000625, runner._agent.optimizer._lr, 0.0001)
     # Remove the artifacts written under base_dir during the run.
     shutil.rmtree(FLAGS.base_dir)
 def testDefaultGinDqn(self):
   """The stock dqn.gin must configure RMSProp with lr near 2.5e-4."""
   tf.logging.info('####### Training the DQN agent #####')
   tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'dqn'
   FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
   # Shrink the replay buffer so the test does not run out of memory.
   FLAGS.gin_bindings = ['WrappedReplayBuffer.replay_capacity = 100']
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   experiment_runner = dopamine.atari.runner.Runner(FLAGS.base_dir,
                                                    train.create_agent)
   self.assertIsInstance(experiment_runner._agent.optimizer,
                         tf.train.RMSPropOptimizer)
   self.assertNear(0.00025, experiment_runner._agent.optimizer._learning_rate,
                   0.0001)
   shutil.rmtree(FLAGS.base_dir)
Пример #7
0
 def testDefaultGinDqn(self):
   """Test DQNAgent configuration using the default gin config."""
   tf.logging.info('####### Training the DQN agent #####')
   tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'dqn'
   FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
   FLAGS.gin_bindings = [
       'WrappedReplayBuffer.replay_capacity = 100'  # To prevent OOM.
   ]
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   runner = run_experiment.Runner(FLAGS.base_dir, train.create_agent)
   # The default dqn.gin is expected to configure an RMSProp optimizer
   # with a learning rate near 2.5e-4.
   self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
   self.assertNear(0.00025, runner._agent.optimizer._learning_rate, 0.0001)
   # Remove the artifacts written under base_dir during the run.
   shutil.rmtree(FLAGS.base_dir)
 def testOverrideRunnerParams(self):
   """Test overriding Runner/TrainRunner parameters via gin bindings."""
   tf.logging.info('####### Training the DQN agent #####')
   tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'dqn'
   FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
   FLAGS.gin_bindings = [
       'TrainRunner.base_dir = "{}"'.format(FLAGS.base_dir),
       'Runner.log_every_n = 1729',
       'WrappedReplayBuffer.replay_capacity = 100'  # To prevent OOM.
   ]
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   runner = dopamine.atari.train_runner.TrainRunner(create_agent_fn=train.create_agent)
   # The bindings above should propagate into the runner's attributes.
   self.assertEqual(runner._base_dir, FLAGS.base_dir)
   self.assertEqual(runner._log_every_n, 1729)
   shutil.rmtree(FLAGS.base_dir)
Пример #9
0
 def testDefaultGinRainbow(self):
   """Rainbow's default gin should yield Adam with lr near 6.25e-5."""
   tf.logging.info('####### Training the RAINBOW agent #####')
   tf.logging.info('####### RAINBOW base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'rainbow'
   FLAGS.gin_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
   # Keep the replay buffer tiny to avoid OOM in the test environment.
   FLAGS.gin_bindings = ['WrappedReplayBuffer.replay_capacity = 100']
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   experiment_runner = run_experiment.Runner(FLAGS.base_dir,
                                             train.create_agent)
   self.assertIsInstance(experiment_runner._agent.optimizer,
                         tf.train.AdamOptimizer)
   self.assertNear(0.0000625, experiment_runner._agent.optimizer._lr, 0.0001)
   shutil.rmtree(FLAGS.base_dir)
Пример #10
0
 def testOverrideRunnerParams(self):
   """Gin bindings should override TrainRunner base_dir and log cadence."""
   tf.logging.info('####### Training the DQN agent #####')
   tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'dqn'
   FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
   overrides = [
       'TrainRunner.base_dir = "{}"'.format(FLAGS.base_dir),
       'Runner.log_every_n = 1729',
       # Keep the replay buffer tiny to prevent OOM.
       'WrappedReplayBuffer.replay_capacity = 100'
   ]
   FLAGS.gin_bindings = overrides
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   experiment_runner = run_experiment.TrainRunner(
       create_agent_fn=train.create_agent)
   self.assertEqual(FLAGS.base_dir, experiment_runner._base_dir)
   self.assertEqual(1729, experiment_runner._log_every_n)
   shutil.rmtree(FLAGS.base_dir)
  def testOverrideGinDqn(self):
    """Test DQNAgent configuration overridden with AdamOptimizer."""
    tf.logging.info('####### Training the DQN agent #####')
    tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
    FLAGS.agent_name = 'dqn'
    FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
    FLAGS.gin_bindings = [
        'DQNAgent.optimizer = @tf.train.AdamOptimizer()',
        'tf.train.AdamOptimizer.learning_rate = 100',
        'WrappedReplayBuffer.replay_capacity = 100'  # To prevent OOM.
    ]

    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    runner = dopamine.atari.runner.Runner(FLAGS.base_dir, train.create_agent)
    # The bindings above should replace the default optimizer with Adam
    # using the (deliberately extreme) learning rate of 100.
    self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
    self.assertEqual(100, runner._agent.optimizer._lr)
    shutil.rmtree(FLAGS.base_dir)
 def testDefaultGinImplicitQuantileIcml(self):
   """Test default ImplicitQuantile configuration using ICML gin."""
   tf.logging.info('###### Training the Implicit Quantile agent #####')
   FLAGS.agent_name = 'implicit_quantile'
   # Use a fresh timestamped base_dir so concurrent runs do not collide.
   FLAGS.base_dir = os.path.join(
       '/tmp/dopamine_tests',
       datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
   tf.logging.info('###### IQN base dir: {}'.format(FLAGS.base_dir))
   FLAGS.gin_files = ['dopamine/agents/'
                      'implicit_quantile/configs/implicit_quantile_icml.gin']
   FLAGS.gin_bindings = [
       'Runner.num_iterations=0',
   ]
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   runner = dopamine.atari.runner.Runner(FLAGS.base_dir, train.create_agent)
   # The ICML config is expected to keep the full 1M-transition buffer.
   self.assertEqual(1000000, runner._agent._replay.memory._replay_capacity)
   shutil.rmtree(FLAGS.base_dir)
 def testDefaultDQNConfig(self):
   """Verify the default DQN configuration."""
   FLAGS.agent_name = 'dqn'
   run_experiment.load_gin_configs(
       ['dopamine/agents/dqn/configs/dqn.gin'], [])
   agent = train.create_agent(self.test_session(),
                              run_experiment.create_atari_environment('Pong'))
   # Hyperparameters the stock dqn.gin is expected to set, checked in order.
   expected_attrs = [
       ('gamma', 0.99),
       ('update_horizon', 1),
       ('min_replay_history', 20000),
       ('update_period', 4),
       ('target_update_period', 8000),
       ('epsilon_train', 0.01),
       ('epsilon_eval', 0.001),
       ('epsilon_decay_period', 250000),
   ]
   for attr_name, expected_value in expected_attrs:
     self.assertEqual(getattr(agent, attr_name), expected_value)
   self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
   self.assertEqual(agent._replay.memory._batch_size, 32)
Пример #14
0
 def testDefaultDQNConfig(self):
   """Verify the default DQN configuration."""
   FLAGS.agent_name = 'dqn'
   run_experiment.load_gin_configs(
       ['dopamine/agents/dqn/configs/dqn.gin'], [])
   agent = train.create_agent(self.test_session(),
                              run_experiment.create_atari_environment('Pong'))
   # Hyperparameters expected from the stock dqn.gin file.
   self.assertEqual(agent.gamma, 0.99)
   self.assertEqual(agent.update_horizon, 1)
   self.assertEqual(agent.min_replay_history, 20000)
   self.assertEqual(agent.update_period, 4)
   self.assertEqual(agent.target_update_period, 8000)
   self.assertEqual(agent.epsilon_train, 0.01)
   self.assertEqual(agent.epsilon_eval, 0.001)
   self.assertEqual(agent.epsilon_decay_period, 250000)
   # Replay buffer sizing from the same config.
   self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
   self.assertEqual(agent._replay.memory._batch_size, 32)
Пример #15
0
  def testOverrideGinDqn(self):
    """Gin bindings should swap DQN's optimizer to Adam with lr=100."""
    tf.logging.info('####### Training the DQN agent #####')
    tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
    FLAGS.agent_name = 'dqn'
    FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
    FLAGS.gin_bindings = [
        'DQNAgent.optimizer = @tf.train.AdamOptimizer()',
        'tf.train.AdamOptimizer.learning_rate = 100',
        # Keep the replay buffer tiny to prevent OOM.
        'WrappedReplayBuffer.replay_capacity = 100'
    ]
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    experiment_runner = run_experiment.Runner(FLAGS.base_dir,
                                              train.create_agent)
    optimizer = experiment_runner._agent.optimizer
    self.assertIsInstance(optimizer, tf.train.AdamOptimizer)
    self.assertEqual(optimizer._lr, 100)
    shutil.rmtree(FLAGS.base_dir)
Пример #16
0
 def testDefaultGinImplicitQuantileIcml(self):
   """ICML gin for ImplicitQuantile should keep a 1M-slot replay buffer."""
   tf.logging.info('###### Training the Implicit Quantile agent #####')
   FLAGS.agent_name = 'implicit_quantile'
   # Give each run its own timestamped scratch directory.
   run_stamp = datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S')
   FLAGS.base_dir = os.path.join('/tmp/dopamine_tests', run_stamp)
   tf.logging.info('###### IQN base dir: {}'.format(FLAGS.base_dir))
   FLAGS.gin_files = ['dopamine/agents/'
                      'implicit_quantile/configs/implicit_quantile_icml.gin']
   FLAGS.gin_bindings = ['Runner.num_iterations=0']
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   experiment_runner = run_experiment.Runner(FLAGS.base_dir,
                                             train.create_agent)
   replay_memory = experiment_runner._agent._replay.memory
   self.assertEqual(1000000, replay_memory._replay_capacity)
   shutil.rmtree(FLAGS.base_dir)
 def testOverrideGinRainbow(self):
   """Test RainbowAgent configuration overridden with RMSPropOptimizer."""
   tf.logging.info('####### Training the RAINBOW agent #####')
   tf.logging.info('####### RAINBOW base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'rainbow'
   FLAGS.gin_files = [
       'dopamine/agents/rainbow/configs/rainbow.gin',
   ]
   FLAGS.gin_bindings = [
       'RainbowAgent.optimizer = @tf.train.RMSPropOptimizer()',
       'tf.train.RMSPropOptimizer.learning_rate = 100',
       'WrappedReplayBuffer.replay_capacity = 100'  # To prevent OOM.
   ]
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   runner = dopamine.atari.runner.Runner(FLAGS.base_dir, train.create_agent)
   # The bindings above should install RMSProp with the (deliberately
   # extreme) learning rate of 100.
   self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
   self.assertEqual(100, runner._agent.optimizer._learning_rate)
   shutil.rmtree(FLAGS.base_dir)
Пример #18
0
 def testOverrideGinRainbow(self):
   """Gin bindings should swap Rainbow's optimizer to RMSProp with lr=100."""
   tf.logging.info('####### Training the RAINBOW agent #####')
   tf.logging.info('####### RAINBOW base_dir: {}'.format(FLAGS.base_dir))
   FLAGS.agent_name = 'rainbow'
   FLAGS.gin_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
   FLAGS.gin_bindings = [
       'RainbowAgent.optimizer = @tf.train.RMSPropOptimizer()',
       'tf.train.RMSPropOptimizer.learning_rate = 100',
       # Keep the replay buffer tiny to prevent OOM.
       'WrappedReplayBuffer.replay_capacity = 100'
   ]
   run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
   experiment_runner = run_experiment.Runner(FLAGS.base_dir,
                                             train.create_agent)
   optimizer = experiment_runner._agent.optimizer
   self.assertIsInstance(optimizer, tf.train.RMSPropOptimizer)
   self.assertEqual(optimizer._learning_rate, 100)
   shutil.rmtree(FLAGS.base_dir)
Пример #19
0
 def testOverrideGinImplicitQuantile(self):
     """Test ImplicitQuantile configuration overriding using IQN gin."""
     tf.logging.info('###### Training the Implicit Quantile agent #####')
     FLAGS.agent_name = 'implicit_quantile'
     # Use a fresh timestamped base_dir so concurrent runs do not collide.
     FLAGS.base_dir = os.path.join(
         '/tmp/dopamine_tests',
         datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
     tf.logging.info('###### IQN base dir: {}'.format(FLAGS.base_dir))
     FLAGS.gin_files = [
         'dopamine/agents/'
         'implicit_quantile/configs/implicit_quantile.gin'
     ]
     FLAGS.gin_bindings = [
         'Runner.num_iterations=0',
         'WrappedPrioritizedReplayBuffer.replay_capacity = 1000',
     ]
     run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
     runner = run_experiment.Runner(FLAGS.base_dir, train.create_agent)
     # The binding above should shrink the replay buffer to 1000 slots.
     self.assertEqual(1000, runner._agent._replay.memory._replay_capacity)
     shutil.rmtree(FLAGS.base_dir)
 def testDefaultRainbowConfig(self):
   """Verify the default Rainbow configuration."""
   FLAGS.agent_name = 'rainbow'
   run_experiment.load_gin_configs(
       ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
   agent = train.create_agent(self.test_session(),
                              run_experiment.create_atari_environment('Pong'))
   # The categorical distribution should use 51 atoms spanning [-10, 10].
   self.assertEqual(agent._num_atoms, 51)
   support = self.evaluate(agent._support)
   self.assertEqual(min(support), -10.)
   self.assertEqual(max(support), 10.)
   self.assertEqual(len(support), 51)
   # Hyperparameters expected from the stock rainbow.gin file.
   self.assertEqual(agent.gamma, 0.99)
   self.assertEqual(agent.update_horizon, 3)
   self.assertEqual(agent.min_replay_history, 20000)
   self.assertEqual(agent.update_period, 4)
   self.assertEqual(agent.target_update_period, 8000)
   self.assertEqual(agent.epsilon_train, 0.01)
   self.assertEqual(agent.epsilon_eval, 0.001)
   self.assertEqual(agent.epsilon_decay_period, 250000)
   # Replay buffer sizing from the same config.
   self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
   self.assertEqual(agent._replay.memory._batch_size, 32)
Пример #21
0
 def testDefaultRainbowConfig(self):
   """Verify the default Rainbow configuration."""
   FLAGS.agent_name = 'rainbow'
   run_experiment.load_gin_configs(
       ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
   agent = train.create_agent(self.test_session(),
                              run_experiment.create_atari_environment('Pong'))
   self.assertEqual(agent._num_atoms, 51)
   # The categorical support should span [-10, 10] with 51 atoms.
   support = self.evaluate(agent._support)
   self.assertEqual(min(support), -10.)
   self.assertEqual(max(support), 10.)
   self.assertEqual(len(support), 51)
   # Scalar hyperparameters the stock rainbow.gin is expected to set,
   # checked in the same order as the original assertions.
   expected_attrs = [
       ('gamma', 0.99),
       ('update_horizon', 3),
       ('min_replay_history', 20000),
       ('update_period', 4),
       ('target_update_period', 8000),
       ('epsilon_train', 0.01),
       ('epsilon_eval', 0.001),
       ('epsilon_decay_period', 250000),
   ]
   for attr_name, expected_value in expected_attrs:
     self.assertEqual(getattr(agent, attr_name), expected_value)
   self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
   self.assertEqual(agent._replay.memory._batch_size, 32)