Ejemplo n.º 1
0
 def testOverrideRunnerParams(self):
     """Test that gin bindings override the TrainRunner's parameters.

     Loads the DQN gin config together with bindings that set the runner's
     base directory and logging frequency, builds a `TrainRunner`, and checks
     that the runner's `_base_dir` and `_log_every_n` hold the bound values.
     `FLAGS.base_dir` is removed afterwards, even if an assertion fails.
     """
     tf.logging.info('####### Training the DQN agent #####')
     tf.logging.info('####### DQN base_dir: {}'.format(FLAGS.base_dir))
     FLAGS.agent_name = 'dqn'
     FLAGS.gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
     FLAGS.gin_bindings = [
         'TrainRunner.base_dir = "{}"'.format(FLAGS.base_dir),
         'Runner.log_every_n = 1729',
         'WrappedReplayBuffer.replay_capacity = 100'  # To prevent OOM.
     ]
     run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
     runner = run_experiment.TrainRunner(create_agent_fn=train.create_agent)
     try:
         self.assertEqual(runner._base_dir, FLAGS.base_dir)
         self.assertEqual(runner._log_every_n, 1729)
     finally:
         # Clean up even when an assertion fails so that a failing run does
         # not leave the directory behind.
         shutil.rmtree(FLAGS.base_dir)
Ejemplo n.º 2
0
def create_runner(base_dir, create_agent_fn):
  """Builds the experiment runner selected by `FLAGS.schedule`.

  The chosen runner class is called with `base_dir`, `create_agent_fn` and
  `create_pacman_environment`, in that order.

  Args:
    base_dir: Base directory handed to the runner; asserted to be non-None.
    create_agent_fn: Callable forwarded to the runner as its second argument
      (presumably the agent factory -- confirm against the runner's
      signature).

  Returns:
    A `run_experiment.Runner` when the schedule is
    'continuous_train_and_eval', or a `run_experiment.TrainRunner` when it is
    'continuous_train'.

  Raises:
    ValueError: If `FLAGS.schedule` is neither of the two known schedules.
  """
  assert base_dir is not None
  schedule = FLAGS.schedule
  if schedule == 'continuous_train_and_eval':
    # Training and evaluation schedule.
    runner_cls = run_experiment.Runner
  elif schedule == 'continuous_train':
    # Training-only schedule.
    runner_cls = run_experiment.TrainRunner
  else:
    raise ValueError('Unknown schedule: {}'.format(schedule))
  return runner_cls(base_dir, create_agent_fn, create_pacman_environment)