def main(_):
    """Run a trained algorithm in a single non-parallel environment.

    Seeds the run, parses the gin configuration, builds the environment, the
    data transformer and the algorithm, then hands control to
    ``policy_trainer.play``. The environment is closed once ``play`` returns
    or raises.

    Args:
        _: unused positional argument.
    """
    seed = common.set_random_seed(FLAGS.random_seed)
    gin.parse_config_files_and_bindings(common.get_gin_file(), FLAGS.gin_param)
    build_algorithm = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn

    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)

    # The transformer is built from the raw observation spec and attached to
    # the config before the algorithm is constructed with that config.
    trainer_config = policy_trainer.TrainerConfig(root_dir="")
    transformer = create_data_transformer(
        trainer_config.data_transformer_ctor, env.observation_spec())
    trainer_config.data_transformer = transformer
    transformed_spec = transformer.transformed_observation_spec
    common.set_transformed_observation_spec(transformed_spec)

    algorithm = build_algorithm(
        observation_spec=transformed_spec,
        action_spec=env.action_spec(),
        config=trainer_config)

    try:
        root_dir = FLAGS.root_dir
        play_options = dict(
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
        )
        # An unset/empty flag means "ignore nothing"; otherwise it is a
        # comma-separated list of prefixes.
        if FLAGS.ignored_parameter_prefixes:
            play_options["ignored_parameter_prefixes"] = (
                FLAGS.ignored_parameter_prefixes.split(","))
        else:
            play_options["ignored_parameter_prefixes"] = []
        policy_trainer.play(root_dir, env, algorithm, **play_options)
    finally:
        env.close()
def play(root_dir, algorithm_ctor):
    """Build one environment and an algorithm, then run `policy_trainer.play`.

    Args:
        root_dir (str): directory forwarded to `policy_trainer.play`; per the
            original documentation this is where checkpoints are stored.
        algorithm_ctor (Callable): called with the created environment to
            produce the algorithm. When using gin configuration, bind it via
            `Trainer.algorithm_ctor`.
    """
    single_env = create_environment(num_parallel_environments=1)
    policy_trainer.play(root_dir, single_env, algorithm_ctor(single_env))
def main(_):
    """Parse gin configuration and play an algorithm in one environment.

    The algorithm constructor is read from the gin parameter
    ``TrainerConfig.algorithm_ctor`` and called with the created environment;
    the playback options are taken from command-line flags.

    Args:
        _: unused positional argument.
    """
    gin.parse_config_files_and_bindings(common.get_gin_file(), FLAGS.gin_param)
    build_algorithm = gin.query_parameter('TrainerConfig.algorithm_ctor')

    env = create_environment(num_parallel_environments=1)
    algorithm = build_algorithm(env)

    root_dir = FLAGS.root_dir
    play_options = dict(
        checkpoint_name=FLAGS.checkpoint_name,
        greedy_predict=FLAGS.greedy_predict,
        random_seed=FLAGS.random_seed,
        num_episodes=FLAGS.num_episodes,
        sleep_time_per_step=FLAGS.sleep_time_per_step,
        record_file=FLAGS.record_file,
    )
    policy_trainer.play(root_dir, env, algorithm, **play_options)
def main(_):
    """Play a trained algorithm in one non-parallel environment.

    Seeds the run, parses the gin configuration, resolves the algorithm
    constructor from ``TrainerConfig.algorithm_ctor``, builds the environment
    and the algorithm, and calls ``policy_trainer.play`` with options taken
    from command-line flags.

    The underlying Python environment is closed in a ``finally`` clause, so it
    is released even if resetting the environment, constructing the algorithm
    or playing raises.

    Args:
        _: unused positional argument.
    """
    # NOTE(review): the second argument is the negation of `use_tf_functions`;
    # its exact meaning is defined by `common.set_random_seed` — confirm there.
    seed = common.set_random_seed(FLAGS.random_seed,
                                  not FLAGS.use_tf_functions)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    try:
        env.reset()
        common.set_global_env(env)
        algorithm = algorithm_ctor(
            observation_spec=env.observation_spec(),
            action_spec=env.action_spec())
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_name=FLAGS.checkpoint_name,
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            use_tf_functions=FLAGS.use_tf_functions)
    finally:
        # Previously this ran only on the success path, leaking the
        # environment whenever anything above raised.
        env.pyenv.close()