def _train_iter(self, iter_num, policy_state, time_step):
    if not self._config.update_counter_every_mini_batch:
        common.get_global_counter().assign_add(1)

    unroll_steps = self._unroll_length * self._envs[0].batch_size
    max_num_steps = unroll_steps
    if iter_num == 0 and self._initial_collect_steps != 0:
        max_num_steps = self._initial_collect_steps

    with record_time("time/driver_run"):
        for _ in range((max_num_steps + unroll_steps - 1) // unroll_steps):
            time_step, policy_state = self._driver.run(
                max_num_steps=unroll_steps,
                time_step=time_step,
                policy_state=policy_state)

    with record_time("time/train"):
        # `train_steps` might be different from `max_num_steps`!
        train_steps = self._algorithm.train(
            num_updates=self._num_updates_per_train_step,
            mini_batch_size=self._mini_batch_size,
            mini_batch_length=self._mini_batch_length,
            whole_replay_buffer_training=self._whole_replay_buffer_training,
            clear_replay_buffer=self._clear_replay_buffer,
            update_counter_every_mini_batch=self._config.
            update_counter_every_mini_batch)

    return time_step, policy_state, train_steps
def _train_iter(self, iter_num, policy_state, time_step):
    if not self._driver_started:
        self._driver.start()
        self._driver_started = True

    if not self._config.update_counter_every_mini_batch:
        common.get_global_counter().assign_add(1)

    with record_time("time/driver_run"):
        if iter_num == 0 and self._initial_collect_steps != 0:
            steps = 0
            while steps < self._initial_collect_steps:
                steps += self._driver.run_async()
        else:
            self._driver.run_async()

    with record_time("time/train"):
        # `train_steps` might be different from `steps`!
        train_steps = self._algorithm.train(
            num_updates=self._num_updates_per_train_step,
            mini_batch_size=self._mini_batch_size,
            mini_batch_length=self._mini_batch_length,
            whole_replay_buffer_training=self._whole_replay_buffer_training,
            clear_replay_buffer=self._clear_replay_buffer,
            update_counter_every_mini_batch=self._config.
            update_counter_every_mini_batch)

    return time_step, policy_state, train_steps
def _train(self, experience, num_updates, mini_batch_size,
           mini_batch_length):
    """Train using experience."""
    experience = self.transform_timestep(experience)
    experience = self.preprocess_experience(experience)

    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(
            self.train_state_spec)) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            common.warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            training_info, loss_info, grads_and_vars = self._update(
                batch,
                weight=tf.cast(tf.shape(batch.step_type)[1], tf.float32) /
                float(mini_batch_size))
            common.get_global_counter().assign_add(1)
            self.training_summary(training_info, loss_info, grads_and_vars)

    self.metric_summary()
    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
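# A minimal standalone sketch of the reshape/transpose that `_train` performs
# on the experience nest, assuming `common.concat_shape` concatenates shape
# fragments and `common.transpose2(x, 0, 1)` swaps the first two axes. Plain
# TensorFlow is used instead of those ALF helpers, and the tensor below is a
# hypothetical stand-in for one field of the experience nest.
import tensorflow as tf

# Hypothetical experience field: 4 environments, 8 unrolled steps, 3 features.
x = tf.random.uniform([4, 8, 3])
mini_batch_length = 2

# Fold the time axis into chunks of `mini_batch_length`: [4, 8, 3] -> [16, 2, 3].
new_shape = tf.concat([[-1, mini_batch_length], tf.shape(x)[2:]], axis=0)
chunks = tf.reshape(x, new_shape)

# Make each chunk time-major before `_update`: [16, 2, 3] -> [2, 16, 3].
time_major = tf.transpose(chunks, perm=[1, 0, 2])
print(chunks.shape, time_major.shape)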
def _iter(self, time_step, policy_state):
    """One training iteration."""
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        return tf.TensorArray(dtype=s.dtype,
                              size=self._train_interval,
                              element_shape=tf.TensorShape(
                                  [batch_size]).concatenate(s.shape))

    training_info_ta = tf.nest.map_structure(
        create_ta,
        self._training_info_spec._replace(
            info=nest_utils.to_distribution_param_spec(
                self._training_info_spec.info)))

    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        [counter, next_time_step, next_state, training_info_ta
         ] = tf.while_loop(cond=lambda *_: True,
                           body=self._train_loop_body,
                           loop_vars=[
                               counter, time_step, policy_state,
                               training_info_ta
                           ],
                           back_prop=True,
                           parallel_iterations=1,
                           maximum_iterations=self._train_interval,
                           name='iter_loop')

    training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                          training_info_ta)
    training_info = nest_utils.params_to_distributions(
        training_info, self._training_info_spec)

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)

    del tape

    self._algorithm.summarize_train(training_info, loss_info, grads_and_vars)
    self._algorithm.summarize_metrics()
    common.get_global_counter().assign_add(1)

    return [next_time_step, next_state]
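# A minimal standalone sketch of the pattern `_iter` uses: run a fixed number
# of steps with tf.while_loop, write each step's output into a tf.TensorArray
# under a GradientTape, then stack() them for a single backward pass. The
# variable `w` and the "rollout step" below are hypothetical stand-ins, not
# ALF code.
import tensorflow as tf

w = tf.Variable(2.0)
train_interval = 5
batch_size = 3

def loop_body(counter, obs, ta):
    out = obs * w                  # stand-in for one rollout step
    ta = ta.write(counter, out)    # record the per-step output
    return counter + 1, out, ta

counter = tf.zeros((), tf.int32)
obs = tf.ones([batch_size])
ta = tf.TensorArray(dtype=tf.float32,
                    size=train_interval,
                    element_shape=tf.TensorShape([batch_size]))

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch([w])
    counter, obs, ta = tf.while_loop(cond=lambda *_: True,
                                     body=loop_body,
                                     loop_vars=[counter, obs, ta],
                                     maximum_iterations=train_interval)
    outputs = ta.stack()           # shape [train_interval, batch_size]
    loss = tf.reduce_mean(outputs)

grads = tape.gradient(loss, [w])
print(outputs.shape, grads[0].numpy())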
def __init__(self,
             env,
             algorithm,
             observation_transformer: Callable = None,
             observers=[],
             metrics=[],
             training=True,
             greedy_predict=False,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None):
    """Create a PolicyDriver.

    Args:
        env (TFEnvironment): A TFEnvironment
        algorithm (OnPolicyAlgorithm): The algorithm for training
        observers (list[Callable]): An optional list of observers that are
            updated after every step in the environment. Each observer is a
            callable(time_step.Trajectory).
        metrics (list[TFStepMetric]): An optional list of metrics.
        training (bool): True for training, False for evaluating
        greedy_predict (bool): use greedy action for evaluation (i.e.
            training==False).
        debug_summaries (bool): A bool to gather debug summaries.
        summarize_grads_and_vars (bool): If True, gradient and network
            variable summaries will be written during training.
        train_step_counter (tf.Variable): An optional counter to increment
            every time a new iteration is started. If None, it will use
            tf.summary.experimental.get_step(). If this is still None, a
            counter will be created.
    """
    metric_buf_size = max(10, env.batch_size)
    standard_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(buffer_size=metric_buf_size),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=metric_buf_size),
    ]
    self._metrics = standard_metrics + metrics
    self._exp_observers = []

    super(PolicyDriver, self).__init__(env, None, observers + self._metrics)

    self._algorithm = algorithm
    self._training = training
    self._greedy_predict = greedy_predict
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._observation_transformer = observation_transformer
    self._train_step_counter = common.get_global_counter(train_step_counter)
    self._proc = psutil.Process(os.getpid())

    if training:
        self._policy_state_spec = algorithm.train_state_spec
    else:
        self._policy_state_spec = algorithm.predict_state_spec
    self._initial_state = self.get_initial_policy_state()
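# The `train_step_counter` default described in the docstring above relies on
# TensorFlow's global summary step. A hedged sketch of that standard TF
# mechanism (the fallback behaviour of `common.get_global_counter` is as
# described in the docstring, not verified against the helper itself):
import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64, trainable=False)
tf.summary.experimental.set_step(step)           # register as the global step
assert tf.summary.experimental.get_step() is step
step.assign_add(1)                               # incremented once per iteration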
def _restore_checkpoint(self):
    global_step = get_global_counter()
    checkpointer = tfa_common.Checkpointer(
        ckpt_dir=os.path.join(self._train_dir, 'algorithm'),
        algorithm=self._algorithm,
        metrics=metric_utils.MetricsGroup(self._driver.get_metrics(),
                                          'metrics'),
        global_step=global_step)
    checkpointer.initialize_or_restore()
    self._checkpointer = checkpointer
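# `tfa_common.Checkpointer` here is the TF-Agents helper
# (tf_agents.utils.common.Checkpointer), which wraps a tf.train.Checkpoint
# built from the keyword arguments plus a CheckpointManager over `ckpt_dir`.
# A minimal standalone sketch with a hypothetical directory and variable:
import tensorflow as tf
from tf_agents.utils import common as tfa_common

global_step = tf.Variable(0, dtype=tf.int64)
checkpointer = tfa_common.Checkpointer(
    ckpt_dir='/tmp/ckpt_demo',               # hypothetical directory
    max_to_keep=3,
    global_step=global_step)
checkpointer.initialize_or_restore()          # restores the latest ckpt if any
checkpointer.save(global_step=global_step.numpy())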
def summarize_metrics(self):
    """Generate summaries for metrics `AverageEpisodeLength`,
    `AverageReturn`..."""
    if self._metrics:
        for metric in self._metrics:
            metric.tf_summaries(train_step=common.get_global_counter(),
                                step_metrics=self._metrics[:2])
    mem = tf.py_function(lambda: self._proc.memory_info().rss // 1e6, [],
                         tf.float32,
                         name='memory_usage')
    if not tf.executing_eagerly():
        mem.set_shape(())
    tf.summary.scalar(name='memory_usage', data=mem)
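# The memory-usage scalar above uses tf.py_function so that a Python-side
# psutil call can be embedded in the summary computation. A standalone sketch
# of the same pattern, run eagerly (the writer path is hypothetical; the
# `set_shape(())` call in the method above is only needed in graph mode,
# where py_function outputs have unknown shape):
import os
import psutil
import tensorflow as tf

proc = psutil.Process(os.getpid())
writer = tf.summary.create_file_writer('/tmp/mem_summary_demo')

# rss is reported in bytes; dividing by 1e6 gives megabytes.
mem = tf.py_function(lambda: proc.memory_info().rss / 1e6, [], tf.float32,
                     name='memory_usage')
with writer.as_default():
    tf.summary.scalar('memory_usage', data=mem, step=0)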
def _eval(self):
    global_step = get_global_counter()
    with tf.summary.record_if(True):
        eager_compute(
            metrics=self._eval_metrics,
            environment=self._eval_env,
            state_spec=self._algorithm.predict_state_spec,
            action_fn=lambda time_step, state: common.algorithm_step(
                algorithm_step_func=self._algorithm.greedy_predict,
                time_step=self._algorithm.transform_timestep(time_step),
                state=state),
            num_episodes=self._num_eval_episodes,
            step_metrics=self._driver.get_step_metrics(),
            train_step=global_step,
            summary_writer=self._eval_summary_writer,
            summary_prefix="Metrics")
    metric_utils.log_metrics(self._eval_metrics)
def _train(self, experience, num_updates, mini_batch_size,
           mini_batch_length, update_counter_every_mini_batch,
           should_summarize):
    """Train using experience."""
    experience = nest_utils.params_to_distributions(
        experience, self.experience_spec)
    experience = self.transform_timestep(experience)
    experience = self.preprocess_experience(experience)
    experience = nest_utils.distributions_to_params(experience)

    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(
            self.train_state_spec)) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            common.warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    scope = get_current_scope()

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            if update_counter_every_mini_batch:
                common.get_global_counter().assign_add(1)
            is_last_mini_batch = tf.logical_and(
                tf.equal(u, num_updates - 1),
                tf.greater_equal(b + mini_batch_size, batch_size))
            do_summary = tf.logical_or(is_last_mini_batch,
                                       update_counter_every_mini_batch)
            common.enable_summary(do_summary)
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            # TensorFlow graph mode loses the original name scope here. We
            # need to restore the original name scope.
            with tf.name_scope(scope):
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1],
                                   tf.float32) / float(mini_batch_size))
            if should_summarize:
                if do_summary:
                    # Putting `if do_summary` under the above `with` statement
                    # does not help. Somehow the `if` statement will also lose
                    # the original name scope.
                    with tf.name_scope(scope):
                        self.summarize_train(training_info, loss_info,
                                             grads_and_vars)

    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
def play(root_dir,
         env,
         algorithm,
         checkpoint_name=None,
         greedy_predict=True,
         random_seed=None,
         num_episodes=10,
         sleep_time_per_step=0.01,
         record_file=None,
         use_tf_functions=True):
    """Play using the latest checkpoint under `train_dir`.

    The following example records the play of a trained model to an mp4
    video:
    ```bash
    python -m alf.bin.play \
        --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \
        --num_episodes=1 \
        --record_file=ppo_bullet_humanoid.mp4
    ```

    Args:
        root_dir (str): same as the root_dir used for `train()`
        env (TFEnvironment): the environment
        algorithm (OnPolicyAlgorithm): the training algorithm
        checkpoint_name (str): name of the checkpoint (e.g. 'ckpt-12800').
            If None, the latest checkpoint under train_dir will be used.
        greedy_predict (bool): use greedy action for evaluation.
        random_seed (None|int): random seed; a random seed is used if None
        num_episodes (int): number of episodes to play
        sleep_time_per_step (float): sleep so many seconds for each step
        record_file (str): if provided, video will be recorded to a file
            instead of shown on the screen.
        use_tf_functions (bool): whether to use tf.function
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.random.set_seed(random_seed)

    global_step = get_global_counter()

    driver = OnPolicyDriver(env=env,
                            algorithm=algorithm,
                            training=False,
                            greedy_predict=greedy_predict)

    ckpt_dir = os.path.join(train_dir, 'algorithm')
    checkpoint = tf.train.Checkpoint(algorithm=algorithm,
                                     metrics=metric_utils.MetricsGroup(
                                         driver.get_metrics(), 'metrics'),
                                     global_step=global_step)
    if checkpoint_name is not None:
        ckpt_path = os.path.join(ckpt_dir, checkpoint_name)
    else:
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_path is not None:
        logging.info("Restore from checkpoint %s" % ckpt_path)
        checkpoint.restore(ckpt_path)
    else:
        logging.info("Checkpoint is not found at %s" % ckpt_dir)

    if not use_tf_functions:
        tf.config.experimental_run_functions_eagerly(True)

    recorder = None
    if record_file is not None:
        recorder = VideoRecorder(env.pyenv.envs[0], path=record_file)
    else:
        # pybullet_envs need to render() before reset() to enable
        # mode='human'
        env.pyenv.envs[0].render(mode='human')
    env.reset()
    if recorder:
        recorder.capture_frame()
    time_step = driver.get_initial_time_step()
    policy_state = driver.get_initial_policy_state()
    episode_reward = 0.
    episode_length = 0
    episodes = 0
    while episodes < num_episodes:
        time_step, policy_state = driver.run(max_num_steps=1,
                                             time_step=time_step,
                                             policy_state=policy_state)
        if recorder:
            recorder.capture_frame()
        else:
            env.pyenv.envs[0].render(mode='human')
            time.sleep(sleep_time_per_step)
        episode_reward += float(time_step.reward)
        if time_step.is_last():
            logging.info("episode_length=%s episode_reward=%s" %
                         (episode_length, episode_reward))
            episode_reward = 0.
            episode_length = 0
            episodes += 1
        else:
            episode_length += 1
    if recorder:
        recorder.close()
    env.reset()
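# The `record_file` branch in play() relies on gym's VideoRecorder
# (gym.wrappers.monitoring.video_recorder.VideoRecorder). A hedged standalone
# sketch of that recording loop with a plain gym environment; 'CartPole-v0'
# and the output path are illustrative only, and running it requires a
# display (or virtual display) plus ffmpeg for the mp4 output.
import gym
from gym.wrappers.monitoring.video_recorder import VideoRecorder

env = gym.make('CartPole-v0')
recorder = VideoRecorder(env, path='/tmp/cartpole_demo.mp4')
env.reset()
recorder.capture_frame()
done = False
while not done:
    _, _, done, _ = env.step(env.action_space.sample())  # random policy
    recorder.capture_frame()
recorder.close()
env.close()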
def _save_checkpoint(self):
    global_step = get_global_counter()
    self._checkpointer.save(global_step=global_step.numpy())
def _iter(self, time_step, policy_state):
    """One training iteration."""
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        return tf.TensorArray(dtype=s.dtype,
                              size=self._train_interval + 1,
                              element_shape=tf.TensorShape(
                                  [batch_size]).concatenate(s.shape))

    training_info_ta = tf.nest.map_structure(create_ta,
                                             self._training_info_spec)

    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        [counter, time_step, policy_state, training_info_ta
         ] = tf.while_loop(cond=lambda *_: True,
                           body=self._train_loop_body,
                           loop_vars=[
                               counter, time_step, policy_state,
                               training_info_ta
                           ],
                           back_prop=True,
                           parallel_iterations=1,
                           maximum_iterations=self._train_interval,
                           name='iter_loop')

    if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
        next_time_step, policy_step, action, transformed_time_step = self._step(
            time_step, policy_state)
        next_state = policy_step.state
    else:
        transformed_time_step = self._algorithm.transform_timestep(time_step)
        policy_step = common.algorithm_step(self._algorithm.rollout,
                                            transformed_time_step,
                                            policy_state)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = time_step
        next_state = policy_state

    action_distribution_param = common.get_distribution_params(
        policy_step.action)

    final_training_info = TrainingInfo(
        action_distribution=action_distribution_param,
        action=action,
        reward=transformed_time_step.reward,
        discount=transformed_time_step.discount,
        step_type=transformed_time_step.step_type,
        info=policy_step.info)

    with tape:
        training_info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), training_info_ta,
            final_training_info)

    training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                          training_info_ta)
    action_distribution = nested_distributions_from_specs(
        self._algorithm.action_distribution_spec,
        training_info.action_distribution)
    training_info = training_info._replace(
        action_distribution=action_distribution)

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)

    del tape

    self._algorithm.training_summary(training_info, loss_info,
                                     grads_and_vars)
    self._algorithm.metric_summary()
    common.get_global_counter().assign_add(1)

    return [next_time_step, next_state]