Example #1
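The entry point of a play script built on ALF (Agent Learning Framework): it parses the gin configuration, constructs a non-parallel environment and the gin-configured algorithm, and then replays a trained policy from the checkpoints under FLAGS.root_dir.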
def main(_):
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
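    # Look up the algorithm constructor that was configured through gin.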
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
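    # Build a minimal TrainerConfig, create the data transformer, and publish
    # the transformed observation spec globally.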
    config = policy_trainer.TrainerConfig(root_dir="")
    data_transformer = create_data_transformer(config.data_transformer_ctor,
                                               env.observation_spec())
    config.data_transformer = data_transformer
    observation_spec = data_transformer.transformed_observation_spec
    common.set_transformed_observation_spec(observation_spec)
    algorithm = algorithm_ctor(
        observation_spec=observation_spec,
        action_spec=env.action_spec(),
        config=config)
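    # Play episodes with the restored policy, closing the environment afterwards.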
    try:
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=FLAGS.ignored_parameter_prefixes.split(
                ",") if FLAGS.ignored_parameter_prefixes else [])
    finally:
        env.close()
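The example above only shows main(); in a runnable script it is typically wired to absl flags and app.run. The sketch below is an assumption: the flag names mirror the FLAGS fields used in main(), but their exact types and defaults are guesses.

from absl import app, flags

flags.DEFINE_string('root_dir', None, 'Directory containing the trained model.')
flags.DEFINE_integer('random_seed', None, 'Random seed to use.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the gin config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin parameter bindings.')
flags.DEFINE_integer('checkpoint_step', None, 'Checkpoint step to load; latest if unset.')
flags.DEFINE_float('epsilon_greedy', 0.1, 'Epsilon for action sampling.')
flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to play.')
flags.DEFINE_integer('max_episode_length', 0, 'Max steps per episode (0 means unlimited).')
flags.DEFINE_float('sleep_time_per_step', 0.01, 'Sleep time between environment steps.')
flags.DEFINE_string('record_file', None, 'If set, record a video to this file.')
flags.DEFINE_string('ignored_parameter_prefixes', '',
                    'Comma-separated parameter prefixes to ignore when restoring.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)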
Example #2
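The constructor of an RL trainer built on ALF: it validates the termination criterion, creates the training environment, the data transformer, and the algorithm, and optionally sets up an unwrapped environment together with metrics and a summary writer for evaluation.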
    def __init__(self, config: TrainerConfig):
        """

        Args:
            config (TrainerConfig): configuration used to construct this trainer
        """
        super().__init__(config)

        self._envs = []
        self._num_env_steps = config.num_env_steps
        self._num_iterations = config.num_iterations
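        # Exactly one of num_iterations / num_env_steps may be positive: their
        # sum must be positive and their product must be zero.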
        assert (self._num_iterations + self._num_env_steps > 0
                and self._num_iterations * self._num_env_steps == 0), \
            "Must provide #iterations or #env_steps exclusively for training!"
        self._trainer_progress.set_termination_criterion(
            self._num_iterations, self._num_env_steps)

        self._num_eval_episodes = config.num_eval_episodes
        alf.summary.should_summarize_output(config.summarize_output)

        env = self._create_environment(random_seed=self._random_seed)
        logging.info("observation_spec=%s" %
                     pprint.pformat(env.observation_spec()))
        logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
        common.set_global_env(env)

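        # Create the data transformer and make its transformed observation spec
        # globally available.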
        data_transformer = create_data_transformer(
            config.data_transformer_ctor, env.observation_spec())
        self._config.data_transformer = data_transformer
        observation_spec = data_transformer.transformed_observation_spec
        common.set_transformed_observation_spec(observation_spec)

        self._algorithm = self._algorithm_ctor(
            observation_spec=observation_spec,
            action_spec=env.action_spec(),
            env=env,
            config=self._config,
            debug_summaries=self._debug_summaries)

        # Create an unwrapped env to expose subprocess gin configs that would
        # otherwise be marked as "inoperative". This env should be created last.
        # DO NOT register this env in self._envs because AsyncOffPolicyTrainer
        # will use all self._envs to init AsyncOffPolicyDriver!
        if self._evaluate or isinstance(
                env,
                alf.environments.parallel_environment.ParallelAlfEnvironment):
            self._unwrapped_env = self._create_environment(
                nonparallel=True,
                random_seed=self._random_seed,
                register=False)
        else:
            self._unwrapped_env = None
        self._eval_env = None
        self._eval_metrics = None
        self._eval_summary_writer = None
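        # When evaluation is enabled, evaluate on the unwrapped env and track
        # average return, episode length, and env-info metrics.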
        if self._evaluate:
            self._eval_env = self._unwrapped_env
            self._eval_metrics = [
                alf.metrics.AverageReturnMetric(
                    buffer_size=self._num_eval_episodes,
                    reward_shape=self._eval_env.reward_spec().shape),
                alf.metrics.AverageEpisodeLengthMetric(
                    buffer_size=self._num_eval_episodes),
                alf.metrics.AverageEnvInfoMetric(
                    example_env_info=self._eval_env.reset().env_info,
                    batch_size=self._eval_env.batch_size,
                    buffer_size=self._num_eval_episodes)
            ]
            self._eval_summary_writer = alf.summary.create_summary_writer(
                self._eval_dir, flush_secs=config.summaries_flush_secs)
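For orientation, a trainer like this is driven from a TrainerConfig. The sketch below is hypothetical: it assumes TrainerConfig accepts the fields read above (num_iterations, num_env_steps, num_eval_episodes, evaluate) as constructor keyword arguments, and that the trainer exposes a train() entry point; MyTrainer stands in for the concrete trainer class whose __init__ is shown above.

config = TrainerConfig(
    root_dir="/tmp/my_experiment",
    num_iterations=1000,   # iteration-based termination ...
    num_env_steps=0,       # ... so env-step-based termination must stay 0
    num_eval_episodes=3,
    evaluate=True)
trainer = MyTrainer(config)  # MyTrainer: placeholder for the trainer class above
trainer.train()              # assumed public entry point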