示例#1
0
文件: train.py 项目: slin70/batch_rl
def main(unused_argv):
    """Load the gin configuration and run a fixed-replay experiment.

    Args:
        unused_argv: Ignored by this function.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)

    # The agent reads its replay data from `<FLAGS.replay_dir>/replay_logs`.
    logs_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
    agent_factory = functools.partial(create_agent, replay_data_dir=logs_dir)

    run_experiment.FixedReplayRunner(
        FLAGS.base_dir, agent_factory).run_experiment()
示例#2
0
def main(unused_argv):
    """Run a fixed-replay experiment on preference-reward training logs.

    The game name is the second-to-last component of ``FLAGS.exp_dir`` and is
    bound to ``atari_lib.create_atari_environment.game_name`` in gin. The
    replay directory is then pointed at the path returned by
    ``create_logs_for_training``, the experiment is run under
    ``<exp_dir>/<agent>_<preference_model>_<reward_model>``, and finally
    ``pack_agents`` is called.

    Args:
        unused_argv: Ignored by this function.

    Raises:
        NotImplementedError: If ``FLAGS.use_preference_rewards`` is falsy.
    """
    # Normalise first: with a trailing separator (".../<game>/<run>/") a plain
    # double `osp.split` would pick up `<run>` as the game name.
    game = osp.basename(osp.dirname(osp.normpath(FLAGS.exp_dir)))
    gin.bind_parameter('atari_lib.create_atari_environment.game_name', game)

    if not FLAGS.use_preference_rewards:
        raise NotImplementedError(
            'Only runs with use_preference_rewards enabled are supported.')

    training_log_path = create_logs_for_training(FLAGS)
    agent_name = "_".join([
        FLAGS.agent_name, FLAGS.preference_model_type,
        FLAGS.reward_model_type
    ])
    # Redirect the replay directory to the freshly created training logs so
    # the code below reads from them.
    FLAGS.replay_dir = training_log_path

    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    replay_data_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
    create_agent_fn = functools.partial(create_agent,
                                        replay_data_dir=replay_data_dir)
    runner = run_experiment.FixedReplayRunner(
        osp.join(FLAGS.exp_dir, agent_name), create_agent_fn)
    runner.run_experiment()
    pack_agents(FLAGS)
示例#3
0
def main(configs):
    """Run a fixed-replay experiment with offline evaluation configured.

    Args:
        configs: Object providing the ``gin_files``, ``gin_bindings``,
            ``replay_dir``, ``agent_name``, ``init_checkpoint_dir`` and
            ``base_dir`` attributes read below.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(configs.gin_files, configs.gin_bindings)

    agent_factory = functools.partial(
        create_agent,
        replay_data_dir=os.path.join(configs.replay_dir, "replay_logs"),
        agent_name=configs.agent_name,
        init_checkpoint_dir=configs.init_checkpoint_dir,
    )
    env_factory = functools.partial(create_environment)
    runner = run_experiment.FixedReplayRunner(
        configs.base_dir, agent_factory, create_environment_fn=env_factory
    )

    # The evaluation dataset and the checkpoint are both looked up at fixed
    # locations relative to the resolved current working directory.
    working_dir = os.path.realpath(".")
    runner.set_offline_evaluation(
        os.path.join(working_dir, "data/processed/v5_dataset/test_dataset_users/"),
        os.path.join(working_dir, "models/reward_pred_v0_model/release/80_input"),
    )

    runner.run_experiment()