def main(_):
    """Play a trained model in the gin-configured environment."""
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    # Resolve the algorithm constructor that was configured via gin.
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    # Playing uses a single (non-parallel) environment.
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
    config = policy_trainer.TrainerConfig(root_dir="")
    data_transformer = create_data_transformer(config.data_transformer_ctor,
                                               env.observation_spec())
    config.data_transformer = data_transformer
    # The algorithm sees the (possibly transformed) observation spec, not
    # the raw spec from the environment.
    observation_spec = data_transformer.transformed_observation_spec
    common.set_transformed_observation_spec(observation_spec)
    algorithm = algorithm_ctor(
        observation_spec=observation_spec,
        action_spec=env.action_spec(),
        config=config)
    try:
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=FLAGS.ignored_parameter_prefixes.split(
                ",") if FLAGS.ignored_parameter_prefixes else [])
    finally:
        # Always close the environment, even if playing raises.
        env.close()
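# A minimal sketch of the absl flag definitions that `main` above reads.
# Only the flag names come from the code above; the types, defaults, and
# help strings are assumptions, not the repo's actual definitions.
from absl import app, flags

flags.DEFINE_string('root_dir', None, 'Root directory of the trained model.')
flags.DEFINE_integer('random_seed', None, 'Random seed for the environment.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding parameters.')
flags.DEFINE_integer('checkpoint_step', None,
                     'Checkpoint step to load; falls back to "latest".')
flags.DEFINE_float('epsilon_greedy', 0.1,
                   'Epsilon for epsilon-greedy action selection.')
flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to play.')
flags.DEFINE_integer('max_episode_length', 0,
                     'Max steps per episode (0 means no limit).')
flags.DEFINE_float('sleep_time_per_step', 0.01,
                   'Seconds to sleep between environment steps.')
flags.DEFINE_string('record_file', None,
                    'If set, record a video to this file.')
flags.DEFINE_string('ignored_parameter_prefixes', '',
                    'Comma-separated parameter prefixes to ignore when '
                    'loading the checkpoint.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)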
def __init__(self, config: TrainerConfig):
    """
    Args:
        config (TrainerConfig): configuration used to construct this trainer
    """
    super().__init__(config)
    self._envs = []
    self._num_env_steps = config.num_env_steps
    self._num_iterations = config.num_iterations
    # Exactly one of num_iterations / num_env_steps must be positive.
    assert (self._num_iterations + self._num_env_steps > 0
            and self._num_iterations * self._num_env_steps == 0), \
        "Must provide #iterations or #env_steps exclusively for training!"
    self._trainer_progress.set_termination_criterion(self._num_iterations,
                                                     self._num_env_steps)

    self._num_eval_episodes = config.num_eval_episodes
    alf.summary.should_summarize_output(config.summarize_output)

    env = self._create_environment(random_seed=self._random_seed)
    logging.info("observation_spec=%s" %
                 pprint.pformat(env.observation_spec()))
    logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
    common.set_global_env(env)

    data_transformer = create_data_transformer(config.data_transformer_ctor,
                                               env.observation_spec())
    self._config.data_transformer = data_transformer
    # The algorithm sees the (possibly transformed) observation spec.
    observation_spec = data_transformer.transformed_observation_spec
    common.set_transformed_observation_spec(observation_spec)

    self._algorithm = self._algorithm_ctor(
        observation_spec=observation_spec,
        action_spec=env.action_spec(),
        env=env,
        config=self._config,
        debug_summaries=self._debug_summaries)

    # Create an unwrapped env to expose subprocess gin confs which otherwise
    # will be marked as "inoperative". This env should be created last.
    # DO NOT register this env in self._envs because AsyncOffPolicyTrainer
    # will use all self._envs to init AsyncOffPolicyDriver!
    if self._evaluate or isinstance(
            env,
            alf.environments.parallel_environment.ParallelAlfEnvironment):
        self._unwrapped_env = self._create_environment(
            nonparallel=True, random_seed=self._random_seed, register=False)
    else:
        self._unwrapped_env = None

    self._eval_env = None
    self._eval_metrics = None
    self._eval_summary_writer = None
    if self._evaluate:
        # Evaluation runs in the single unwrapped env created above.
        self._eval_env = self._unwrapped_env
        self._eval_metrics = [
            alf.metrics.AverageReturnMetric(
                buffer_size=self._num_eval_episodes,
                reward_shape=self._eval_env.reward_spec().shape),
            alf.metrics.AverageEpisodeLengthMetric(
                buffer_size=self._num_eval_episodes),
            alf.metrics.AverageEnvInfoMetric(
                example_env_info=self._eval_env.reset().env_info,
                batch_size=self._eval_env.batch_size,
                buffer_size=self._num_eval_episodes)
        ]
        self._eval_summary_writer = alf.summary.create_summary_writer(
            self._eval_dir, flush_secs=config.summaries_flush_secs)
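# Standalone illustration of the termination criterion enforced by the
# assert in the constructor above. The helper name `_check_termination`
# is hypothetical; the condition and error message are taken verbatim.
def _check_termination(num_iterations: int, num_env_steps: int) -> None:
    # Training must be bounded by exactly one of the two limits: the sum
    # check rejects (0, 0), the product check rejects both being positive.
    assert (num_iterations + num_env_steps > 0
            and num_iterations * num_env_steps == 0), \
        "Must provide #iterations or #env_steps exclusively for training!"

_check_termination(1000, 0)      # OK: iteration-bounded training
_check_termination(0, 1000000)   # OK: env-step-bounded training
# _check_termination(1000, 1000000)  # would raise: both limits given
# _check_termination(0, 0)           # would raise: no limit given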