def main(_):
  """Trains a behavioral-cloning ('bc') agent offline and evaluates it.

  Builds a seed-specific log directory under FLAGS.root_dir, then launches
  offline training with a fixed network architecture and optimizer config.
  """
  # Setup log dir ('auto' expands to a datetime stamp for unique run dirs).
  sub_dir = utils.get_datetime() if FLAGS.sub_dir == 'auto' else FLAGS.sub_dir
  log_dir = os.path.join(
      FLAGS.root_dir,
      FLAGS.env_name,
      'bc',
      sub_dir,
      str(FLAGS.seed),
  )
  utils.maybe_makedirs(log_dir)
  # Fixed configuration for behavioral cloning: a single (200, 200) MLP
  # trained with Adam at learning rate 5e-4.
  network_arch = ((200, 200),)
  optimizer_config = (('adam', 5e-4),)
  train_eval_offline.train_eval_offline(
      log_dir=log_dir,
      data_file=None,
      agent_module=agents.AGENT_MODULES_DICT['bc'],
      env_name=FLAGS.env_name,
      n_train=FLAGS.n_train,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=1,
      model_params=network_arch,
      optimizers=optimizer_config,
  )
def main(_):
  """Offline train/eval entry point; persists eval results to results.npy.

  Parses gin configuration, builds a per-seed log directory, runs offline
  training/evaluation for the flagged agent, and saves the returned eval
  results as a numpy file in the log directory.
  """
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  # Setup log dir ('auto' expands to a datetime stamp for unique run dirs).
  run_sub_dir = (
      utils.get_datetime() if FLAGS.sub_dir == 'auto' else FLAGS.sub_dir)
  output_dir = os.path.join(
      FLAGS.root_dir,
      FLAGS.env_name,
      FLAGS.identifier,
      FLAGS.agent_name,
      run_sub_dir,
      str(FLAGS.seed),
  )
  utils.maybe_makedirs(output_dir)
  results = train_eval_offline.train_eval_offline(
      log_dir=output_dir,
      data_file=None,
      agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
      env_name=FLAGS.env_name,
      n_train=FLAGS.n_train,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=FLAGS.n_eval_episodes,
  )
  # Persist evaluation results for later analysis.
  results_path = os.path.join(output_dir, 'results.npy')
  with tf.io.gfile.GFile(results_path, 'w') as fout:
    np.save(fout, results)
def main(_):
  """Runs offline agent training/evaluation with a model/optimizer sweep.

  Selects a network architecture and optimizer settings from integer flags
  (FLAGS.model_arch, FLAGS.opt_params), trains/evaluates the flagged agent,
  and saves the eval results to <log_dir>/results.npy.
  """
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  # Setup log dir ('auto' expands to a datetime stamp).
  if FLAGS.sub_dir == 'auto':
    sub_dir = utils.get_datetime()
  else:
    sub_dir = FLAGS.sub_dir
  log_dir = os.path.join(
      FLAGS.root_dir,
      FLAGS.env_name,
      FLAGS.identifier,
      FLAGS.agent_name,
      sub_dir,
      str(FLAGS.seed),
  )
  utils.maybe_makedirs(log_dir)
  # Network architecture, selected by integer flag; unknown values raise.
  model_arch = None
  if FLAGS.model_arch == 0:
    model_arch = ((200, 200), )
  elif FLAGS.model_arch == 1:
    # NOTE(review): variant 1 bundles two hidden-layer specs with a trailing
    # 2 — presumably (layer_sizes, n_networks); confirm the expected format
    # against train_eval_offline.train_eval_offline's model_params argument.
    model_arch = ((
        (300, 300),
        (200, 200),
    ), 2)
  else:
    raise ValueError()
  # Optimizer (name, learning_rate) tuples, selected by integer flag.
  if FLAGS.opt_params == 0:
    opt_params = (('adam', 1e-5), )
  elif FLAGS.opt_params == 1:
    opt_params = (('adam', 1e-3), ('adam', 3e-5), ('adam', 1e-5))
  elif FLAGS.opt_params == 2:
    opt_params = (('adam', 1e-3), ('adam', 3e-4), ('adam', 1e-5))
  elif FLAGS.opt_params == 3:
    # NOTE(review): 0e-3, 0e-4 and 0e-5 all evaluate to 0.0, i.e. every
    # optimizer gets a zero learning rate and nothing trains. This looks
    # like a typo for 1e-3/1e-4/1e-5 (matching the pattern of the other
    # branches) — confirm intent before relying on this option.
    opt_params = (('adam', 0e-3), ('adam', 0e-4), ('adam', 0e-5))
  else:
    raise ValueError()
  eval_results = train_eval_offline.train_eval_offline(
      log_dir=log_dir,
      data_file=None,
      agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
      env_name=FLAGS.env_name,
      n_train=FLAGS.n_train,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=FLAGS.n_eval_episodes,
      model_params=model_arch,
      optimizers=opt_params,
      value_penalty=bool(FLAGS.value_penalty),
      behavior_ckpt_file=FLAGS.b_ckpt,
      save_freq=FLAGS.save_freq)
  # Persist evaluation results for later analysis.
  results_file = os.path.join(log_dir, 'results.npy')
  with tf.io.gfile.GFile(results_file, 'w') as f:
    np.save(f, eval_results)
def main(_):
  """Trains and evaluates an agent online against the named environment.

  Parses gin configuration, builds the run's log directory, and dispatches
  to train_eval_online with the flagged agent and training budget.
  """
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  # A sub_dir of "auto" expands to a datetime stamp for unique run dirs.
  run_sub_dir = (
      utils.get_datetime() if FLAGS.sub_dir == "auto" else FLAGS.sub_dir)
  output_dir = os.path.join(
      FLAGS.root_dir,
      FLAGS.env_name,
      FLAGS.agent_name,
      run_sub_dir,
  )
  utils.maybe_makedirs(output_dir)
  train_eval_online.train_eval_online(
      log_dir=output_dir,
      agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
      env_name=FLAGS.env_name,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=FLAGS.n_eval_episodes,
      eval_target=FLAGS.eval_target,
  )
def main(_):
  """Runs offline training/evaluation on a pre-collected dataset file.

  Resolves the dataset path from the data_* flags, builds a log directory
  keyed by environment, dataset, training-set size, agent and seed, then
  launches offline train/eval.
  """
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  # Setup data file path (single join of the full dataset path).
  data_file = os.path.join(
      FLAGS.data_root_dir,
      FLAGS.env_name,
      FLAGS.data_name,
      FLAGS.data_sub_dir,
      FLAGS.data_file_name,
  )
  # Setup log dir ('auto' expands to a datetime stamp for unique run dirs).
  run_sub_dir = (
      utils.get_datetime() if FLAGS.sub_dir == 'auto' else FLAGS.sub_dir)
  output_dir = os.path.join(
      FLAGS.root_dir,
      FLAGS.env_name,
      FLAGS.data_name,
      'n' + str(FLAGS.n_train),
      FLAGS.agent_name,
      run_sub_dir,
      str(FLAGS.seed),
  )
  utils.maybe_makedirs(output_dir)
  train_eval_offline.train_eval_offline(
      log_dir=output_dir,
      data_file=data_file,
      agent_module=agents.AGENT_MODULES_DICT[FLAGS.agent_name],
      env_name=FLAGS.env_name,
      n_train=FLAGS.n_train,
      total_train_steps=FLAGS.total_train_steps,
      n_eval_episodes=FLAGS.n_eval_episodes,
  )