Ejemplo n.º 1
0
def main(argv):
  """Seeds RNGs, parses flags, builds a TrainRunner and runs the experiment.

  Args:
    argv: Command-line arguments. Unused; configuration is read from FLAGS
      and from `parse_flags()`.
  """
  del argv  # Unused.

  seed = FLAGS.seed
  if seed is not None and FLAGS.set_seed:
    print('Seed set to %i.' % seed)
    # Seed Python's, NumPy's and TensorFlow's random number generators.
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)

  parsed = parse_flags()
  agent_type, log_dir, run_name, create_environment_fn, gin_file = parsed
  print('Flags parsed.')
  print('log_dir = {}'.format(log_dir))

  print(run_name)

  def create_agent_fn(sess, environment, summary_writer):
    """Builds the agent selected by `agent_type`.

    Args:
      sess: Session handed to the agent constructor (presumably a TF
        session; not checked here).
      environment: Environment whose `action_space.n` is passed as the
        number of actions.
      summary_writer: Summary writer forwarded to the agent.

    Returns:
      A newly constructed agent instance.

    Raises:
      ValueError: If `agent_type` is not 'dqn', 'al_dqn', 'm_dqn' or
        'm_iqn'.
    """
    if agent_type == 'dqn':
      agent_cls = dqn_agent.DQNAgent
      extra_kwargs = {}
    elif agent_type == 'al_dqn':
      agent_cls = al_dqn.ALDQNAgent
      # ALDQNAgent receives the `shaping_scale` flag as its `alpha`.
      extra_kwargs = {
          'alpha': FLAGS.shaping_scale,
          'persistent': FLAGS.persistent,
      }
    elif agent_type in ('m_dqn', 'm_iqn'):
      if agent_type == 'm_dqn':
        agent_cls = m_dqn.MunchausenDQNAgent
      else:
        agent_cls = m_iqn.MunchausenIQNAgent
      # Both Munchausen agents take the same extra keyword arguments.
      extra_kwargs = {
          'tau': FLAGS.tau,
          'alpha': FLAGS.alpha,
          'clip_value_min': FLAGS.clip_value_min,
          'interact': FLAGS.interact,
      }
    else:
      raise ValueError('Wrong agent %s' % agent_type)

    return agent_cls(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer,
        **extra_kwargs)

  if gin_file:
    load_gin_configs([gin_file], FLAGS.gin_bindings)

  print('lets run!')
  runner = run_experiment.TrainRunner(
      log_dir, create_agent_fn, create_environment_fn)

  print('Agent of type %s created.' % agent_type)
  # pylint: disable=protected-access
  agent_state = runner._agent.__dict__
  # Dump the agent's public (non-underscore) attributes in sorted order.
  for attr_name in sorted(agent_state):
    if attr_name.startswith('_'):
      continue
    print(attr_name, agent_state[attr_name])
  print()

  # pylint: enable=protected-access
  runner.run_experiment()
Ejemplo n.º 2
0
def main(argv):
  """Optionally seeds RNGs, then builds and runs a TrainRunner.

  Args:
    argv: Command-line arguments. Unused; configuration comes from FLAGS and
      `parse_flags()`.
  """
  del argv  # Unused.

  should_seed = FLAGS.seed is not None and FLAGS.set_seed
  if should_seed:
    print('Seed set to %i.' % FLAGS.seed)
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

  flag_values = parse_flags()
  (agent_type, log_dir, run_name, create_environment_fn,
   gin_file) = flag_values
  print('Flags parsed.')
  print('log_dir = {}'.format(log_dir))

  print(run_name)

  def create_agent_fn(sess, environment, summary_writer):
    """Instantiates the agent class registered under `agent_type`.

    Args:
      sess: Session passed to the agent constructor (presumably a TF
        session; not checked here).
      environment: Environment whose `action_space.n` is used as the number
        of actions.
      summary_writer: Summary writer forwarded to the agent.

    Returns:
      A newly constructed agent instance.

    Raises:
      ValueError: If `agent_type` has no entry in the table below.
    """
    # Each entry resolves its class lazily, so only the selected agent's
    # module attribute is ever looked up.
    agent_classes = {
        'dqn': lambda: dqn_agent.DQNAgent,
        'iqn': lambda: implicit_quantile_agent.ImplicitQuantileAgent,
        'al_dqn': lambda: al_dqn.ALDQNAgent,
        'al_iqn': lambda: al_iqn.ALImplicitQuantileAgent,
        'sail_dqn': lambda: sail_dqn.SAILDQNAgent,
        'sail_iqn': lambda: sail_iqn.SAILImplicitQuantileAgent,
    }
    resolve_cls = agent_classes.get(agent_type)
    if resolve_cls is None:
      raise ValueError('Wrong agent %s' % agent_type)
    agent_cls = resolve_cls()
    return agent_cls(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)

  if gin_file:
    load_gin_configs([gin_file], FLAGS.gin_bindings)

  print('lets run!')
  runner = run_experiment.TrainRunner(
      log_dir, create_agent_fn, create_environment_fn)

  print('Agent of type %s created.' % agent_type)
  # pylint: disable=protected-access
  agent_attrs = runner._agent.__dict__
  public_names = [
      name for name in sorted(agent_attrs) if not name.startswith('_')
  ]
  for name in public_names:
    print(name, agent_attrs[name])
  print()

  # pylint: enable=protected-access
  runner.run_experiment()
Ejemplo n.º 3
0
def set_seed(seed=0):
    """Seed TensorFlow's random number generation via `tf.set_random_seed`.

    Only TensorFlow is seeded here; Python's `random` module and NumPy's
    global generator are not touched by this function.

    Args:
        seed: Value forwarded unchanged to `tf.set_random_seed`.
            Defaults to 0.
    """
    tf_seed_fn = tf.set_random_seed
    tf_seed_fn(seed)