Ejemplo n.º 1
0
def main(_):
    """Play a trained policy in a single non-parallel environment.

    Calls ``common.set_random_seed`` with ``FLAGS.random_seed``, parses the
    gin configuration, creates the environment, the data transformer and the
    algorithm configured as ``TrainerConfig.algorithm_ctor``, and then hands
    them to ``policy_trainer.play``. The environment is always closed on the
    way out, including when setup or playing raises.

    Args:
        _: unused positional argument.
    """
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    # From here on the environment exists, so every remaining step runs under
    # ``finally``: a failure while resetting the env or constructing the
    # transformer/algorithm must not leave the environment unclosed.
    try:
        env.reset()
        common.set_global_env(env)
        config = policy_trainer.TrainerConfig(root_dir="")
        data_transformer = create_data_transformer(
            config.data_transformer_ctor, env.observation_spec())
        config.data_transformer = data_transformer
        observation_spec = data_transformer.transformed_observation_spec
        common.set_transformed_observation_spec(observation_spec)
        algorithm = algorithm_ctor(
            observation_spec=observation_spec,
            action_spec=env.action_spec(),
            config=config)

        # The flag is a comma-separated string; an empty/unset flag means
        # no prefixes are ignored.
        prefixes_flag = FLAGS.ignored_parameter_prefixes
        ignored_prefixes = prefixes_flag.split(",") if prefixes_flag else []

        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=ignored_prefixes)
    finally:
        env.close()
Ejemplo n.º 2
0
    def __init__(self, config: TrainerConfig):
        """Build the trainer state from a ``TrainerConfig``.

        The root directory is user-expanded and created if missing, the absl
        log file is pointed at it, and the relevant fields of ``config`` are
        copied onto the instance. The algorithm and checkpointer slots are
        left as ``None`` here.

        Args:
            config (TrainerConfig): configuration used to construct this trainer
        """
        Trainer._trainer_progress = _TrainerProgress()

        # Resolve the output directory, make sure it exists and send the absl
        # log file there.
        expanded_root = os.path.expanduser(config.root_dir)
        os.makedirs(expanded_root, exist_ok=True)
        logging.get_absl_handler().use_absl_log_file(log_dir=expanded_root)
        self._root_dir = expanded_root
        self._train_dir = os.path.join(expanded_root, 'train')
        self._eval_dir = os.path.join(expanded_root, 'eval')

        # Keep the whole config around in addition to the copied fields.
        self._config = config

        # Algorithm and checkpointer start out unset.
        self._algorithm_ctor = config.algorithm_ctor
        self._algorithm = None
        self._num_checkpoints = config.num_checkpoints
        self._checkpointer = None

        # Evaluation settings.
        self._evaluate = config.evaluate
        self._eval_interval = config.eval_interval

        # Summary settings.
        self._summary_interval = config.summary_interval
        self._summaries_flush_secs = config.summaries_flush_secs
        self._summary_max_queue = config.summary_max_queue
        self._debug_summaries = config.debug_summaries
        self._summarize_grads_and_vars = config.summarize_grads_and_vars

        # Store whatever seed ``set_random_seed`` reports back.
        self._random_seed = common.set_random_seed(config.random_seed)
Ejemplo n.º 3
0
    def initialize(self):
        """Initializes the Trainer.

        Sets the random seed, toggles eager execution of TF functions, then
        creates the environment, the algorithm and the driver. Finally an
        unwrapped, unregistered environment is created and, when evaluation
        is enabled, reused as the evaluation environment.
        """
        run_eagerly = not self._use_tf_functions
        self._random_seed = common.set_random_seed(self._random_seed,
                                                   run_eagerly)
        tf.config.experimental_run_functions_eagerly(run_eagerly)

        train_env = self._create_environment(random_seed=self._random_seed)
        common.set_global_env(train_env)

        self._algorithm = self._algorithm_ctor(
            observation_spec=train_env.observation_spec(),
            action_spec=train_env.action_spec(),
            debug_summaries=self._debug_summaries)
        algorithm = self._algorithm
        algorithm.set_summary_settings(
            summarize_grads_and_vars=self._summarize_grads_and_vars,
            summarize_action_distributions=(
                self._config.summarize_action_distributions))
        algorithm.use_rollout_state = self._config.use_rollout_state

        self._driver = self._init_driver()

        # An unwrapped env exposes subprocess gin confs that would otherwise
        # be marked as "inoperative". It has to be created last, and it must
        # NOT be registered in self._envs: AsyncOffPolicyTrainer uses every
        # env in self._envs to init AsyncOffPolicyDriver.
        unwrapped_env = self._create_environment(
            nonparallel=True, random_seed=self._random_seed, register=False)
        self._unwrapped_env = unwrapped_env
        if self._evaluate:
            self._eval_env = unwrapped_env
Ejemplo n.º 4
0
Archivo: play.py Proyecto: runjerry/alf
def main(_):
    """Play a trained policy in a single non-parallel environment.

    Calls ``common.set_random_seed`` with ``FLAGS.random_seed``, parses the
    gin configuration, creates the environment and the algorithm configured
    as ``TrainerConfig.algorithm_ctor``, and hands them to
    ``policy_trainer.play``. The underlying ``env.pyenv`` is always closed on
    the way out, including when setup or playing raises.

    Args:
        _: unused positional argument.
    """
    seed = common.set_random_seed(FLAGS.random_seed,
                                  not FLAGS.use_tf_functions)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    # Once the environment exists, run everything under ``finally`` so an
    # exception during reset, algorithm construction or play cannot leave
    # the environment unclosed.
    try:
        env.reset()
        common.set_global_env(env)
        algorithm = algorithm_ctor(observation_spec=env.observation_spec(),
                                   action_spec=env.action_spec())
        policy_trainer.play(FLAGS.root_dir,
                            env,
                            algorithm,
                            checkpoint_name=FLAGS.checkpoint_name,
                            epsilon_greedy=FLAGS.epsilon_greedy,
                            num_episodes=FLAGS.num_episodes,
                            sleep_time_per_step=FLAGS.sleep_time_per_step,
                            record_file=FLAGS.record_file,
                            use_tf_functions=FLAGS.use_tf_functions)
    finally:
        env.pyenv.close()
Ejemplo n.º 5
0
 def setUp(self):
     """Set the random seed to 1 and reset the ALF summary global counter."""
     common.set_random_seed(1)
     alf.summary.reset_global_counter()