コード例 #1
0
def main(_):
    """Play an algorithm in a single, non-parallel environment.

    Sets the random seed from ``FLAGS.random_seed``, parses the gin
    configuration, creates the environment, the data transformer and the
    algorithm, and then calls ``policy_trainer.play`` with the options taken
    from ``FLAGS``. The environment is closed on exit, whether setup or
    playing succeeds or raises.

    Args:
        _: unused positional argument (presumably the argv list handed over
            by the application runner -- confirm against the caller).
    """
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    # From this point on the environment exists, so every later step runs
    # under try/finally: a failure while resetting the environment or while
    # building the transformer/algorithm must still close it.
    try:
        env.reset()
        common.set_global_env(env)
        config = policy_trainer.TrainerConfig(root_dir="")
        data_transformer = create_data_transformer(
            config.data_transformer_ctor, env.observation_spec())
        config.data_transformer = data_transformer
        observation_spec = data_transformer.transformed_observation_spec
        common.set_transformed_observation_spec(observation_spec)
        algorithm = algorithm_ctor(
            observation_spec=observation_spec,
            action_spec=env.action_spec(),
            config=config)
        # The flag is a comma-separated string; an empty/unset flag means
        # that no parameter prefixes are ignored.
        prefixes = FLAGS.ignored_parameter_prefixes
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            # Fall back to "latest" when no checkpoint step is given.
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=prefixes.split(",") if prefixes else [])
    finally:
        env.close()
コード例 #2
0
ファイル: main.py プロジェクト: LiuQiangOpenMind/alf
def train_eval(root_dir):
    """Build a trainer from the trainer configuration and run training.

    A ``TrainerConfig`` is created for ``root_dir``; the trainer it creates
    is initialized and then trained.

    Args:
        root_dir (str): directory for saving summary and checkpoints
    """
    config = policy_trainer.TrainerConfig(root_dir=root_dir)
    runner = config.create_trainer()
    # Initialization has to happen before training starts.
    runner.initialize()
    runner.train()
コード例 #3
0
def train_eval(ml_type, root_dir):
    """Create the trainer matching ``ml_type`` and run its training.

    Args:
        ml_type (str): type of machine learning task, 'rl' or 'sl'
        root_dir (str): directory for saving summary and checkpoints

    Raises:
        ValueError: if ``ml_type`` is neither 'rl' nor 'sl'.
    """
    # The config is built before ``ml_type`` is validated.
    config = policy_trainer.TrainerConfig(root_dir=root_dir)

    if ml_type not in ('rl', 'sl'):
        raise ValueError("Unsupported ml_type: %s" % ml_type)

    # 'rl' selects the RL trainer; the only other valid value is 'sl'.
    trainer_cls = (policy_trainer.RLTrainer
                   if ml_type == 'rl' else policy_trainer.SLTrainer)
    trainer_cls(config).train()