def main(unused_argv):
  """Run a fixed-replay experiment configured through flags and gin.

  Reads `FLAGS.gin_files`, `FLAGS.gin_bindings`, `FLAGS.replay_dir` and
  `FLAGS.base_dir`.

  Args:
    unused_argv: Ignored; present so the function matches the entry-point
      signature its caller expects.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)

  # The agent is pointed at the 'replay_logs' sub-directory of the replay dir.
  agent_factory = functools.partial(
      create_agent,
      replay_data_dir=os.path.join(FLAGS.replay_dir, 'replay_logs'))

  run_experiment.FixedReplayRunner(
      FLAGS.base_dir, agent_factory).run_experiment()
def main(unused_argv):
  """Train a fixed-replay agent on generated training logs, then pack it.

  The game name is taken from the second-to-last path component of
  `FLAGS.exp_dir` and bound to
  `atari_lib.create_atari_environment.game_name` via gin. Training logs are
  produced by `create_logs_for_training`, and `FLAGS.replay_dir` is
  overwritten with their location before the runner is built. The runner
  writes to `<FLAGS.exp_dir>/<agent>_<preference model>_<reward model>`.

  Args:
    unused_argv: Ignored; present so the function matches the entry-point
      signature its caller expects.

  Raises:
    NotImplementedError: If `FLAGS.use_preference_rewards` is falsy; no other
      training mode is implemented here.
  """
  # Only the parent directory name of exp_dir is needed; the last component
  # is not used.
  game = osp.basename(osp.dirname(FLAGS.exp_dir))
  gin.bind_parameter('atari_lib.create_atari_environment.game_name', game)

  if not FLAGS.use_preference_rewards:
    raise NotImplementedError(
        'Only training with use_preference_rewards enabled is supported.')

  training_log_path = create_logs_for_training(FLAGS)
  agent_name = "_".join([
      FLAGS.agent_name,
      FLAGS.preference_model_type,
      FLAGS.reward_model_type,
  ])
  # Kept as a flag mutation because FLAGS is handed to pack_agents below and
  # may be read from there.
  FLAGS.replay_dir = training_log_path

  tf.logging.set_verbosity(tf.logging.INFO)
  base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)

  replay_data_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
  create_agent_fn = functools.partial(
      create_agent, replay_data_dir=replay_data_dir)
  runner = run_experiment.FixedReplayRunner(
      osp.join(FLAGS.exp_dir, agent_name), create_agent_fn)
  runner.run_experiment()

  pack_agents(FLAGS)
def main(configs):
    """Run a fixed-replay experiment with offline evaluation enabled.

    Args:
        configs: Object exposing `gin_files`, `gin_bindings`, `replay_dir`,
            `agent_name`, `init_checkpoint_dir` and `base_dir` attributes.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(configs.gin_files, configs.gin_bindings)

    agent_factory = functools.partial(
        create_agent,
        replay_data_dir=os.path.join(configs.replay_dir, "replay_logs"),
        agent_name=configs.agent_name,
        init_checkpoint_dir=configs.init_checkpoint_dir,
    )
    environment_factory = functools.partial(create_environment)

    runner = run_experiment.FixedReplayRunner(
        configs.base_dir,
        agent_factory,
        create_environment_fn=environment_factory,
    )

    # Both offline-evaluation paths are resolved against the current working
    # directory, so the result depends on where the process was started.
    eval_dataset_dir = os.path.join(
        os.path.realpath("."), "data/processed/v5_dataset/test_dataset_users/"
    )
    reward_model_ckpt = os.path.join(
        os.path.realpath("."), "models/reward_pred_v0_model/release/80_input"
    )
    runner.set_offline_evaluation(eval_dataset_dir, reward_model_ckpt)

    runner.run_experiment()