Example #1
def main(_):

    def create_agent(sess, environment, summary_writer=None, memory=None):
        ag = agents[FLAGS.agent](num_actions=environment.action_space.n)
        if memory is not None:
            ag._replay = memory
            ag._replay.replay_capacity = int(50000 * 0.2)  # capacity must be an int
        return ag
    
    # Split FLAGS.experiment into its experiment/value combination.
    exp, value = FLAGS.experiment.split('=')

    agent_name = agents[FLAGS.agent].__name__

    gin_file = f'Configs/{FLAGS.agent}_{FLAGS.env}.gin'

    gin.clear_config()
    gin_bindings = get_gin_bindings(exp, agent_name, FLAGS.seed, value, FLAGS.test)
    gin.parse_config_files_and_bindings([gin_file], gin_bindings, skip_unknown=False)
    LOG_PATH = os.path.join(f'{FLAGS.base_path}/{FLAGS.agent}/{FLAGS.env}/{exp}_{value}_online', f'test{FLAGS.seed}')
    print(f"Saving data at {LOG_PATH}")
    agent_runner = run_experiment.TrainRunner(LOG_PATH, create_agent)

    print(f'Training agent {FLAGS.seed}, please be patient, this may take a while...')
    agent_runner.run_experiment()
    print('Done training!')
    if FLAGS.test:
        # Touch an empty, timestamped marker file to record that this
        # test configuration finished.
        from datetime import datetime
        dt = datetime.now().strftime("%H-%M-%d-%m")
        os.makedirs(f"test_logs/{dt}", exist_ok=True)
        open(f"test_logs/{dt}/{FLAGS.env}_{FLAGS.agent}_{exp}_{FLAGS.type}", 'x').close()
Example #2
def create_exploration_runner(base_dir,
                              create_agent_fn,
                              schedule='continuous_train_and_eval'):
    """Creates an experiment Runner.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a Tensorflow session and a
     Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
    assert base_dir is not None
    # Continuously runs training and eval until max num_iterations is reached.
    if schedule == 'continuous_train_and_eval':
        return run_experiment.Runner(base_dir, create_agent_fn)
    # Continuously runs training until max num_iterations is reached.
    elif schedule == 'continuous_train':
        return run_experiment.TrainRunner(base_dir, create_agent_fn)
    else:
        raise ValueError('Unknown schedule: {}'.format(schedule))
Example #3
def create_runner(base_dir,
                  create_agent_fn,
                  schedule='continuous_train_and_eval'):
    """Creates an experiment Runner.

  TODO(b/): Figure out the right idiom to create a Runner. The current mechanism
  of using a number of flags will not scale and is not elegant.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
    assert base_dir is not None
    # Continuously runs training and eval until max num_iterations is reached.
    if schedule == 'continuous_train_and_eval':
        return run_experiment.Runner(base_dir, create_agent_fn,
                                     atari_lib.create_atari_environment)
    # Continuously runs training until max num_iterations is reached.
    elif schedule == 'continuous_train':
        return run_experiment.TrainRunner(base_dir, create_agent_fn,
                                          atari_lib.create_atari_environment)
    else:
        raise ValueError('Unknown schedule: {}'.format(schedule))
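
For context, here is a sketch of driving create_runner as defined above; the agent factory is a stand-in built on Dopamine's stock DQNAgent, and the base directory is hypothetical:

from dopamine.agents.dqn import dqn_agent

def my_create_agent(sess, environment, summary_writer=None):
    # Stand-in factory matching the (sess, environment) signature
    # documented in the Args section above.
    return dqn_agent.DQNAgent(sess=sess,
                              num_actions=environment.action_space.n,
                              summary_writer=summary_writer)

runner = create_runner('/tmp/dqn_run', my_create_agent,
                       schedule='continuous_train')
runner.run_experiment()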
Example #4
def main(unused_argv):
    """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
    logging.set_verbosity(logging.INFO)
    base_dir = FLAGS.base_dir
    gin_files = FLAGS.gin_files
    gin_bindings = FLAGS.gin_bindings
    run_experiment.load_gin_configs(gin_files, gin_bindings)
    runner = run_experiment.TrainRunner(base_dir, create_metric_agent)
    runner.run_experiment()
Example #5
def testOverrideRunnerParams(self):
    """Tests overriding TrainRunner parameters via gin bindings."""
    tf.logging.info('####### Training the DQN agent #####')
    tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
    gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
    gin_bindings = [
        'TrainRunner.base_dir = "{}"'.format(self._base_dir),
        'Runner.log_every_n = 1729',
        'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
        "create_agent.agent_name = 'dqn'"
    ]
    run_experiment.load_gin_configs(gin_files, gin_bindings)
    runner = run_experiment.TrainRunner(
        create_agent_fn=run_experiment.create_agent,
        create_environment_fn=atari_lib.create_atari_environment)
    self.assertEqual(runner._base_dir, self._base_dir)
    self.assertEqual(runner._log_every_n, 1729)
    shutil.rmtree(self._base_dir)
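
Because TrainRunner is gin-configurable, the TrainRunner.base_dir binding above supplies the constructor's otherwise-required base_dir argument, which is why the call omits it. Without the binding, the equivalent explicit construction would look like this (directory hypothetical):

runner = run_experiment.TrainRunner(
    base_dir='/tmp/dqn_test',
    create_agent_fn=run_experiment.create_agent,
    create_environment_fn=atari_lib.create_atari_environment)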
Example #6
def main(_):
    agent_name = utils.agents[FLAGS.agent].__name__

    def create_agent(sess, environment, summary_writer=None):
        return utils.agents[FLAGS.agent](
            num_actions=environment.action_space.n)

    LOG_PATH = os.path.join(FLAGS.base_path, "baselines",
                            f'{FLAGS.agent}/{FLAGS.env}')
    gin_file = f'./Configs/{FLAGS.agent}_{FLAGS.env}.gin'
    gin_bindings = [f"{agent_name}.seed=1729"]
    gin.clear_config()
    gin.parse_config_files_and_bindings([gin_file],
                                        gin_bindings,
                                        skip_unknown=False)

    agent_runner = run_experiment.TrainRunner(LOG_PATH, create_agent)
    print('Training agent, please be patient, this may take a while...')
    agent_runner.run_experiment()
    print('Done training!')
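
The {agent_name}.seed=1729 binding relies on the agent classes being gin-configurable, so constructor arguments can be bound by name. A toy illustration of that mechanism (class name hypothetical):

import gin

@gin.configurable
class ToyAgent:
    # Because the class is gin-configurable, "ToyAgent.seed = 1729"
    # binds the constructor's seed argument by name.
    def __init__(self, num_actions, seed=None):
        self.num_actions = num_actions
        self.seed = seed

gin.parse_config(['ToyAgent.seed = 1729'])
agent = ToyAgent(num_actions=4)
assert agent.seed == 1729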
Example #7
def main(_):
    def create_agent(sess, environment, summary_writer=None, memory=None):
        ag = utils.agents[FLAGS.agent](num_actions=environment.action_space.n)
        return ag

    path = FLAGS.base_path
    grp = FLAGS.experiment
    values = utils.sample_group(FLAGS.category, grp, FLAGS.sample_seed)

    agent_name = utils.agents[FLAGS.agent].__name__

    if FLAGS.category == "atari_100k":
        gin_file = f'Configs/{FLAGS.agent}_atari_100k.gin'
        gin.clear_config()
        gin_bindings = [f'create_atari_environment.game_name="{FLAGS.env}"']
    else:
        gin_file = f'Configs/{FLAGS.agent}_{FLAGS.env}.gin'
        gin.clear_config()
        gin_bindings = []

    for exp, value in zip(utils.suites[FLAGS.category].groups[grp], values):
        gin_bindings.extend(
            utils.get_gin_bindings(exp, agent_name, FLAGS.rl_seed, value,
                                   False))
    gin.parse_config_files_and_bindings([gin_file],
                                        gin_bindings,
                                        skip_unknown=False)
    LOG_PATH = os.path.join(
        f'{path}/{FLAGS.agent}/{FLAGS.env}/{FLAGS.sample_seed}_{grp}_{utils.repr_values(values)}',
        f'{FLAGS.rl_seed}')
    logging.info(f"Saving data at {LOG_PATH}")
    agent_runner = run_experiment.TrainRunner(LOG_PATH, create_agent)

    logging.info(
        f'Training agent {FLAGS.rl_seed}, please be patient, '
        'this may take a while...')
    agent_runner.run_experiment()
    logging.info('Done training!')
Example #8
def main(argv):
  del argv
  if FLAGS.seed is not None and FLAGS.set_seed:
    print('Seed set to %i.' % FLAGS.seed)
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

  (agent_type, log_dir, run_name, create_environment_fn,
   gin_file) = parse_flags()
  print('Flags parsed.')
  print('log_dir = {}'.format(log_dir))

  print(run_name)

  def create_agent_fn(sess, environment, summary_writer):
    """Creates the appropriate agent."""
    if agent_type == 'dqn':
      return dqn_agent.DQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)
    elif agent_type == 'iqn':
      return implicit_quantile_agent.ImplicitQuantileAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)
    elif agent_type == 'al_dqn':
      return al_dqn.ALDQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)
    elif agent_type == 'al_iqn':
      return al_iqn.ALImplicitQuantileAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)
    elif agent_type == 'sail_dqn':
      return sail_dqn.SAILDQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)
    elif agent_type == 'sail_iqn':
      return sail_iqn.SAILImplicitQuantileAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)
    else:
      raise ValueError('Unknown agent: %s' % agent_type)

  if gin_file:
    load_gin_configs([gin_file], FLAGS.gin_bindings)

  print("Let's run!")
  runner = run_experiment.TrainRunner(log_dir, create_agent_fn,
                                      create_environment_fn)

  print('Agent of type %s created.' % agent_type)
  # pylint: disable=protected-access
  for k in sorted(runner._agent.__dict__):
    if not k.startswith('_'):
      print(k, runner._agent.__dict__[k])
  print()

  # pylint: enable=protected-access
  runner.run_experiment()
Example #9
def main(argv):
  del argv
  if FLAGS.seed is not None and FLAGS.set_seed:
    print('Seed set to %i.' % FLAGS.seed)
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

  (agent_type, log_dir, run_name, create_environment_fn,
   gin_file) = parse_flags()
  print('Flags parsed.')
  print('log_dir = {}'.format(log_dir))

  print(run_name)

  def create_agent_fn(sess, environment, summary_writer):
    """Creates the appropriate agent."""
    if agent_type == 'dqn':
      return dqn_agent.DQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          summary_writer=summary_writer)

    if agent_type == 'al_dqn':
      return al_dqn.ALDQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          alpha=FLAGS.shaping_scale,
          persistent=FLAGS.persistent,
          summary_writer=summary_writer)

    if agent_type == 'm_dqn':
      return m_dqn.MunchausenDQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          tau=FLAGS.tau,
          alpha=FLAGS.alpha,
          clip_value_min=FLAGS.clip_value_min,
          interact=FLAGS.interact,
          summary_writer=summary_writer)

    if agent_type == 'm_iqn':
      return m_iqn.MunchausenIQNAgent(
          sess=sess,
          num_actions=environment.action_space.n,
          tau=FLAGS.tau,
          alpha=FLAGS.alpha,
          interact=FLAGS.interact,
          clip_value_min=FLAGS.clip_value_min,
          summary_writer=summary_writer)

    raise ValueError('Unknown agent: %s' % agent_type)

  if gin_file:
    load_gin_configs([gin_file], FLAGS.gin_bindings)

  print("Let's run!")
  runner = run_experiment.TrainRunner(log_dir, create_agent_fn,
                                      create_environment_fn)

  print('Agent of type %s created.' % agent_type)
  # pylint: disable=protected-access
  for k in sorted(runner._agent.__dict__):
    if not k.startswith('_'):
      print(k, runner._agent.__dict__[k])
  print()

  # pylint: enable=protected-access
  runner.run_experiment()
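
As a design note, the growing if-chain in create_agent_fn of the last two examples can be collapsed into a lookup table; a minimal sketch using the same module names as above, with functools.partial carrying the Munchausen-specific flags:

import functools

def make_create_agent_fn(agent_type):
    # Table mapping agent names to constructors; extra per-agent
    # hyperparameters are baked in with functools.partial.
    agent_ctors = {
        'dqn': dqn_agent.DQNAgent,
        'al_dqn': functools.partial(al_dqn.ALDQNAgent,
                                    alpha=FLAGS.shaping_scale,
                                    persistent=FLAGS.persistent),
        'm_dqn': functools.partial(m_dqn.MunchausenDQNAgent,
                                   tau=FLAGS.tau, alpha=FLAGS.alpha,
                                   clip_value_min=FLAGS.clip_value_min,
                                   interact=FLAGS.interact),
    }
    if agent_type not in agent_ctors:
        raise ValueError('Unknown agent: %s' % agent_type)
    ctor = agent_ctors[agent_type]

    def create_agent_fn(sess, environment, summary_writer):
        return ctor(sess=sess,
                    num_actions=environment.action_space.n,
                    summary_writer=summary_writer)

    return create_agent_fn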
Example #10
    parser = argparse.ArgumentParser()
    parser.add_argument('environment_filename', nargs='?')
    parser.add_argument('--docker_training', action='store_true')
    parser.set_defaults(docker_training=False)
    args = parser.parse_args()
    env = ObstacleTowerEnv(args.environment_filename,
                           docker_training=args.docker_training,
                           retro=False,
                           timeout_wait=304)

    from dopamine.agents.rainbow import rainbow_agent
    from dopamine.discrete_domains import run_experiment
    from keepitpossible.common import unity_lib

    runner = run_experiment.TrainRunner(
        base_dir,
        create_agent,
        create_environment_fn=create_otc_environment(env))

    if env.is_grading():
        # Grading mode: run a single evaluation pass and report its score.
        episode_reward = run_evaluation(runner, env)
        print(episode_reward)
    else:
        # Otherwise train indefinitely, logging each episode's reward.
        while True:
            episode_reward = run_episode(runner)
            print(episode_reward)
                            f"{agent_name}.seed=None"
                        ] if seed is False else [
                            f"{agent_name}.seed={i}",
                            f"{agent_name}.initzer = @{initializer}"
                        ]
                    else:
                        mode = '"' + inits[init]['mode'] + '"'
                        distribution = '"' + inits[init]['distribution'] + '"'
                        gin_bindings = [
                            f"{agent_name}.seed=None"
                        ] if seed is False else [
                            f"{agent_name}.seed={i}",
                            f"{agent_name}.initzer = @{initializer}()",
                            f"{initializer}.scale = 1",
                            f"{initializer}.mode = {mode}",
                            f"{initializer}.distribution = {distribution}"
                        ]

                    gin.clear_config()
                    gin.parse_config_files_and_bindings([gin_file],
                                                        gin_bindings,
                                                        skip_unknown=False)
                    agent_runner = run_experiment.TrainRunner(
                        LOG_PATH, create_agent)

                    print(f'Will train agent {agent} in {env}, run {i}, '
                          'please be patient, this may take a while...')
                    agent_runner.run_experiment()
                    print('Done training!')
print('Finished!')