Exemplo n.º 1
0
def main(_):
    """Train and evaluate an offline agent, then save its evaluation results.

    The run's log directory is
    ``root_dir/env_name/identifier/agent_name/<sub_dir>/<seed>``. Here
    ``<sub_dir>`` is ``utils.get_datetime()`` when ``--sub_dir=auto`` and the
    flag value otherwise. The object returned by
    ``train_eval_offline.train_eval_offline`` is written to
    ``<log_dir>/results.npy`` with ``np.save``.

    Args:
      _: Unused (presumably the argv list supplied by the app runner).
    """
    logging.set_verbosity(logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    # Setup log dir.
    if FLAGS.sub_dir == 'auto':
        sub_dir = utils.get_datetime()
    else:
        sub_dir = FLAGS.sub_dir
    log_dir = os.path.join(
        FLAGS.root_dir,
        FLAGS.env_name,
        FLAGS.identifier,
        FLAGS.agent_name,
        sub_dir,
        str(FLAGS.seed),
    )
    utils.maybe_makedirs(log_dir)
    eval_results = train_eval_offline.train_eval_offline(
        log_dir=log_dir,
        data_file=None,
        agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
        env_name=FLAGS.env_name,
        n_train=FLAGS.n_train,
        total_train_steps=FLAGS.total_train_steps,
        n_eval_episodes=FLAGS.n_eval_episodes,
    )

    # np.save emits the binary .npy format, so the handle must be opened in
    # binary mode rather than text mode.
    results_file = os.path.join(log_dir, 'results.npy')
    with tf.io.gfile.GFile(results_file, 'wb') as f:
        np.save(f, eval_results)
Exemplo n.º 2
0
def main(_):
  """Train a behavior-cloning ('bc') agent offline with a fixed setup.

  Uses one (200, 200) network spec and a single ('adam', 5e-4) optimizer
  spec, and evaluates on one episode. Logs go to
  root_dir/env_name/bc/<sub_dir>/<seed>, where <sub_dir> is
  utils.get_datetime() when --sub_dir=auto and the flag value otherwise.

  Args:
    _: Unused (presumably the argv list supplied by the app runner).
  """
  run_name = utils.get_datetime() if FLAGS.sub_dir == 'auto' else FLAGS.sub_dir
  path_parts = [FLAGS.root_dir, FLAGS.env_name, 'bc', run_name, str(FLAGS.seed)]
  log_dir = os.path.join(*path_parts)

  # Hard-coded hyperparameters for this BC baseline.
  network_spec = ((200, 200),)
  optimizer_spec = (('adam', 5e-4),)

  utils.maybe_makedirs(log_dir)
  train_eval_offline.train_eval_offline(
      log_dir=log_dir,
      data_file=None,
      agent_module=agents.AGENT_MODULES_DICT['bc'],
      env_name=FLAGS.env_name,
      n_train=FLAGS.n_train,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=1,
      model_params=network_spec,
      optimizers=optimizer_spec,
  )
Exemplo n.º 3
0
def main(_):
    """Train and evaluate an offline agent using flag-selected presets.

    ``--model_arch`` selects a network preset and ``--opt_params`` selects an
    optimizer preset. Both are validated before any directory is created.
    Logs go to
    ``root_dir/env_name/identifier/agent_name/<sub_dir>/<seed>``, where
    ``<sub_dir>`` is ``utils.get_datetime()`` when ``--sub_dir=auto`` and the
    flag value otherwise. The returned evaluation results are written to
    ``<log_dir>/results.npy`` with ``np.save``.

    Args:
      _: Unused (presumably the argv list supplied by the app runner).

    Raises:
      ValueError: if ``--model_arch`` or ``--opt_params`` does not name a
        known preset.
    """
    logging.set_verbosity(logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    # Network presets keyed by --model_arch.
    model_arch_presets = {
        0: ((200, 200), ),
        1: ((
            (300, 300),
            (200, 200),
        ), 2),
    }
    # Optimizer presets keyed by --opt_params. Each entry is a tuple of
    # ('adam', value) pairs, presumably (optimizer, learning rate).
    opt_params_presets = {
        0: (('adam', 1e-5), ),
        1: (('adam', 1e-3), ('adam', 3e-5), ('adam', 1e-5)),
        2: (('adam', 1e-3), ('adam', 3e-4), ('adam', 1e-5)),
        # NOTE(review): 0e-3 / 0e-4 / 0e-5 all equal 0.0. This looks like a
        # deliberate zero-rate ("no update") setting; confirm.
        3: (('adam', 0e-3), ('adam', 0e-4), ('adam', 0e-5)),
    }

    # Validate flags up front so a bad value does not leave behind an empty
    # log directory.
    if FLAGS.model_arch not in model_arch_presets:
        raise ValueError(
            'Unknown --model_arch={!r}; expected one of {}.'.format(
                FLAGS.model_arch, sorted(model_arch_presets)))
    if FLAGS.opt_params not in opt_params_presets:
        raise ValueError(
            'Unknown --opt_params={!r}; expected one of {}.'.format(
                FLAGS.opt_params, sorted(opt_params_presets)))
    model_arch = model_arch_presets[FLAGS.model_arch]
    opt_params = opt_params_presets[FLAGS.opt_params]

    # Setup log dir.
    if FLAGS.sub_dir == 'auto':
        sub_dir = utils.get_datetime()
    else:
        sub_dir = FLAGS.sub_dir
    log_dir = os.path.join(
        FLAGS.root_dir,
        FLAGS.env_name,
        FLAGS.identifier,
        FLAGS.agent_name,
        sub_dir,
        str(FLAGS.seed),
    )
    utils.maybe_makedirs(log_dir)

    eval_results = train_eval_offline.train_eval_offline(
        log_dir=log_dir,
        data_file=None,
        agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
        env_name=FLAGS.env_name,
        n_train=FLAGS.n_train,
        total_train_steps=FLAGS.total_train_steps,
        n_eval_episodes=FLAGS.n_eval_episodes,
        model_params=model_arch,
        optimizers=opt_params,
        value_penalty=bool(FLAGS.value_penalty),
        behavior_ckpt_file=FLAGS.b_ckpt,
        save_freq=FLAGS.save_freq)

    # np.save emits the binary .npy format, so open the handle in binary mode.
    results_file = os.path.join(log_dir, 'results.npy')
    with tf.io.gfile.GFile(results_file, 'wb') as f:
        np.save(f, eval_results)
Exemplo n.º 4
0
def main(_):
    """Collect a dataset using a data-config module chosen by flags.

    Output goes to ``root_dir/env_name/data_name/sub_dir``. The config module
    ``<config_dir>.<config_file>`` is imported dynamically, and its
    ``get_data_config(env_name, policy_root_dir)`` result is passed to
    ``collect_data``.

    Args:
      _: Unused (presumably the argv list supplied by the app runner).
    """
    logging.set_verbosity(logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    env_name = FLAGS.env_name
    log_dir = os.path.join(FLAGS.root_dir, env_name, FLAGS.data_name,
                           FLAGS.sub_dir)
    utils.maybe_makedirs(log_dir)

    module_path = '{}.{}'.format(FLAGS.config_dir, FLAGS.config_file)
    config_module = importlib.import_module(module_path)
    data_config = config_module.get_data_config(env_name,
                                                FLAGS.policy_root_dir)

    collect_data(
        log_dir=log_dir,
        data_config=data_config,
        n_samples=FLAGS.n_samples,
        env_name=env_name,
        n_eval_episodes=FLAGS.n_eval_episodes,
    )
Exemplo n.º 5
0
def main(_):
    """Train and evaluate an agent online.

    Logs go to ``root_dir/env_name/agent_name/<sub_dir>``. Here ``<sub_dir>``
    is ``utils.get_datetime()`` when ``--sub_dir=auto`` and the flag value
    otherwise. This path has no seed component.

    Args:
      _: Unused (presumably the argv list supplied by the app runner).
    """
    logging.set_verbosity(logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    run_name = FLAGS.sub_dir
    if run_name == "auto":
        run_name = utils.get_datetime()
    agent_name = FLAGS.agent_name
    log_dir = os.path.join(FLAGS.root_dir, FLAGS.env_name, agent_name,
                           run_name)
    utils.maybe_makedirs(log_dir)

    train_eval_online.train_eval_online(
        log_dir=log_dir,
        agent_module=agents.AGENT_MODULES_DICT[agent_name],
        env_name=FLAGS.env_name,
        total_train_steps=FLAGS.total_train_steps,
        n_eval_episodes=FLAGS.n_eval_episodes,
        eval_target=FLAGS.eval_target,
    )
def main(_):
  """Train and evaluate an agent offline from a pre-collected data file.

  The dataset path is
  data_root_dir/env_name/data_name/data_sub_dir/data_file_name. Logs go to
  root_dir/env_name/data_name/n<n_train>/agent_name/<sub_dir>/<seed>, where
  <sub_dir> is utils.get_datetime() when --sub_dir=auto and the flag value
  otherwise.

  Args:
    _: Unused (presumably the argv list supplied by the app runner).
  """
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  # Where the offline dataset lives.
  dataset_dir = os.path.join(FLAGS.data_root_dir, FLAGS.env_name,
                             FLAGS.data_name, FLAGS.data_sub_dir)
  dataset_path = os.path.join(dataset_dir, FLAGS.data_file_name)

  # Where this run's logs go; the n<n_train> level separates dataset sizes.
  run_name = utils.get_datetime() if FLAGS.sub_dir == 'auto' else FLAGS.sub_dir
  size_tag = 'n{}'.format(FLAGS.n_train)
  log_dir = os.path.join(FLAGS.root_dir, FLAGS.env_name, FLAGS.data_name,
                         size_tag, FLAGS.agent_name, run_name,
                         str(FLAGS.seed))
  utils.maybe_makedirs(log_dir)

  train_eval_offline.train_eval_offline(
      log_dir=log_dir,
      data_file=dataset_path,
      agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
      env_name=FLAGS.env_name,
      n_train=FLAGS.n_train,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=FLAGS.n_eval_episodes,
  )
Exemplo n.º 7
0
def main(_):
    """Run BC to get a behavior policy, then train the main agent offline.

    ``train_bc()`` returns its log directory and a sub-dir name. The
    checkpoint ``<bc_log_dir>/agent_behavior`` is handed to the main agent,
    whose logs go to ``root_dir/env_name/agent_name/<sub_dir>/<seed>`` with
    the same sub-dir name.

    Args:
      _: Unused (presumably the argv list supplied by the app runner).
    """
    logging.set_verbosity(logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    bc_dir, shared_sub_dir = train_bc()
    behavior_ckpt = os.path.join(bc_dir, 'agent_behavior')

    log_dir = os.path.join(FLAGS.root_dir, FLAGS.env_name, FLAGS.agent_name,
                           shared_sub_dir, str(FLAGS.seed))

    # Fixed hyperparameters: two network specs with a trailing count of 2,
    # and three ('adam', value) optimizer entries.
    first_net, second_net = (300, 300), (200, 200)
    network_spec = ((first_net, second_net), 2)
    optimizer_spec = (('adam', 1e-3), ('adam', 3e-5), ('adam', 1e-5))

    utils.maybe_makedirs(log_dir)
    train_eval_offline.train_eval_offline(
        log_dir=log_dir,
        data_file=None,
        agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
        env_name=FLAGS.env_name,
        n_train=FLAGS.n_train,
        total_train_steps=FLAGS.total_train_steps,
        n_eval_episodes=FLAGS.n_eval_episodes,
        model_params=network_spec,
        optimizers=optimizer_spec,
        behavior_ckpt_file=behavior_ckpt,
        value_penalty=bool(FLAGS.value_penalty),
        save_freq=FLAGS.save_freq,
        alpha=FLAGS.alpha,
    )