def main(_):
    """Entry point: play a trained policy in a freshly-built environment.

    Seeds RNGs, loads the gin configuration named by the flags, constructs a
    non-parallel environment plus the configured data transformer, instantiates
    the algorithm from the gin-configured ctor, and replays checkpointed
    policies via ``policy_trainer.play``. The environment is always closed,
    even if playback raises.
    """
    seed = common.set_random_seed(FLAGS.random_seed)
    conf_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(conf_file, FLAGS.gin_param)

    # The algorithm constructor is configured via gin rather than hard-coded.
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn

    play_env = create_environment(nonparallel=True, seed=seed)
    play_env.reset()
    common.set_global_env(play_env)

    trainer_conf = policy_trainer.TrainerConfig(root_dir="")
    transformer = create_data_transformer(trainer_conf.data_transformer_ctor,
                                          play_env.observation_spec())
    trainer_conf.data_transformer = transformer

    # The algorithm consumes the *transformed* observation spec, which may
    # differ from the raw environment spec.
    obs_spec = transformer.transformed_observation_spec
    common.set_transformed_observation_spec(obs_spec)

    algorithm = algorithm_ctor(
        observation_spec=obs_spec,
        action_spec=play_env.action_spec(),
        config=trainer_conf)

    # Empty flag -> no prefixes to ignore when restoring the checkpoint.
    if FLAGS.ignored_parameter_prefixes:
        ignored_prefixes = FLAGS.ignored_parameter_prefixes.split(",")
    else:
        ignored_prefixes = []

    try:
        policy_trainer.play(
            FLAGS.root_dir,
            play_env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=ignored_prefixes)
    finally:
        play_env.close()
def train_eval(root_dir):
    """Train and evaluate the gin-configured algorithm.

    Builds a ``TrainerConfig`` rooted at ``root_dir``, lets it create the
    matching trainer, then runs initialization followed by training.

    Args:
        root_dir (str): directory for saving summary and checkpoints
    """
    conf = policy_trainer.TrainerConfig(root_dir=root_dir)
    trainer = conf.create_trainer()
    trainer.initialize()
    trainer.train()
def train_eval(ml_type, root_dir):
    """Train and evaluate the algorithm for the given learning task type.

    Args:
        ml_type (str): type of machine learning task, 'rl' or 'sl'
        root_dir (str): directory for saving summary and checkpoints

    Raises:
        ValueError: if ``ml_type`` is neither 'rl' nor 'sl'.
    """
    conf = policy_trainer.TrainerConfig(root_dir=root_dir)
    # Dispatch table instead of an if/elif chain; unknown types still raise.
    trainer_cls = {
        'rl': policy_trainer.RLTrainer,
        'sl': policy_trainer.SLTrainer,
    }.get(ml_type)
    if trainer_cls is None:
        raise ValueError("Unsupported ml_type: %s" % ml_type)
    trainer_cls(conf).train()