def _coalesce_post_step(
    self, losses: Dict[str, float], count_steps_delta: int
) -> Dict[str, float]:
    # Sum the running per-env episode stats across workers and push the
    # result into the moving window used for logging.
    stats_ordering = sorted(self.running_episode_stats.keys())
    stats = torch.stack(
        [self.running_episode_stats[k] for k in stats_ordering], 0
    )
    stats = self._all_reduce(stats)

    for i, k in enumerate(stats_ordering):
        self.window_episode_stats[k].append(stats[i])

    if self._is_distributed:
        # Pack the losses and the step count into a single tensor so one
        # all-reduce covers everything. Keys are sorted so every worker
        # packs the tensor in the same order.
        loss_name_ordering = sorted(losses.keys())
        stats = torch.tensor(
            [losses[k] for k in loss_name_ordering] + [count_steps_delta],
            device="cpu",
            dtype=torch.float32,
        )
        stats = self._all_reduce(stats)
        # Step counts are summed across workers; losses are averaged.
        count_steps_delta = int(stats[-1].item())
        stats /= torch.distributed.get_world_size()

        losses = {
            k: stats[i].item() for i, k in enumerate(loss_name_ordering)
        }

    if self._is_distributed and rank0_only():
        self.num_rollouts_done_store.set("num_done", "0")

    self.num_steps_done += count_steps_delta

    return losses
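# --- Example: loss averaging across workers --------------------------------
# A minimal, standalone sketch of the pattern _coalesce_post_step uses: pack
# the scalar losses and the step count into one tensor, all-reduce once, keep
# the summed step count, and divide the losses by the world size to average
# them. `_demo_coalesce` is a hypothetical name; the trainer's `_all_reduce`
# is assumed to wrap torch.distributed.all_reduce in the same way. Runs as a
# single-process "gloo" group purely for illustration.
import torch
import torch.distributed as distrib


def _demo_coalesce(losses, count_steps_delta):
    # Sort keys so every worker packs the tensor identically.
    loss_name_ordering = sorted(losses.keys())
    stats = torch.tensor(
        [losses[k] for k in loss_name_ordering] + [count_steps_delta],
        dtype=torch.float32,
    )
    distrib.all_reduce(stats)  # element-wise sum over all workers
    total_steps = int(stats[-1].item())  # step counts stay summed
    stats /= distrib.get_world_size()  # losses get averaged
    return (
        {k: stats[i].item() for i, k in enumerate(loss_name_ordering)},
        total_steps,
    )


if __name__ == "__main__":
    distrib.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    print(_demo_coalesce(dict(value_loss=0.5, action_loss=0.1), 128))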
def train(self) -> None:
    r"""Main method for training DD/PPO.

    Returns:
        None
    """
    self._init_train()

    count_checkpoints = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: 1 - self.percent_done(),
    )

    # Resume from a preemption-safe state if one exists.
    resume_state = load_resume_state(self.config)
    if resume_state is not None:
        self.agent.load_state_dict(resume_state["state_dict"])
        self.agent.optimizer.load_state_dict(resume_state["optim_state"])
        lr_scheduler.load_state_dict(resume_state["lr_sched_state"])

        requeue_stats = resume_state["requeue_stats"]
        self.env_time = requeue_stats["env_time"]
        self.pth_time = requeue_stats["pth_time"]
        self.num_steps_done = requeue_stats["num_steps_done"]
        self.num_updates_done = requeue_stats["num_updates_done"]
        self._last_checkpoint_percent = requeue_stats[
            "_last_checkpoint_percent"
        ]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]

        self.running_episode_stats = requeue_stats["running_episode_stats"]
        self.window_episode_stats.update(
            requeue_stats["window_episode_stats"]
        )

    ppo_cfg = self.config.RL.PPO

    with (
        TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs)
        if rank0_only()
        else contextlib.suppress()
    ) as writer:
        while not self.is_done():
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * (
                    1 - self.percent_done()
                )

            if rank0_only() and self._should_save_resume_state():
                requeue_stats = dict(
                    env_time=self.env_time,
                    pth_time=self.pth_time,
                    count_checkpoints=count_checkpoints,
                    num_steps_done=self.num_steps_done,
                    num_updates_done=self.num_updates_done,
                    _last_checkpoint_percent=self._last_checkpoint_percent,
                    prev_time=(time.time() - self.t_start) + prev_time,
                    running_episode_stats=self.running_episode_stats,
                    window_episode_stats=dict(self.window_episode_stats),
                )
                save_resume_state(
                    dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    ),
                    self.config,
                )

            if EXIT.is_set():
                profiling_wrapper.range_pop()  # train update

                # Preempted/interrupted: clean up and requeue the job.
                self.envs.close()
                requeue_job()
                return

            self.agent.eval()
            count_steps_delta = 0

            profiling_wrapper.range_push("rollouts loop")
            profiling_wrapper.range_push("_collect_rollout_step")
            for buffer_index in range(self._nbuffers):
                self._compute_actions_and_step_envs(buffer_index)

            for step in range(ppo_cfg.num_steps):
                is_last_step = (
                    self.should_end_early(step + 1)
                    or (step + 1) == ppo_cfg.num_steps
                )

                for buffer_index in range(self._nbuffers):
                    count_steps_delta += self._collect_environment_result(
                        buffer_index
                    )

                    if (buffer_index + 1) == self._nbuffers:
                        profiling_wrapper.range_pop()  # _collect_rollout_step

                    if not is_last_step:
                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_push(
                                "_collect_rollout_step"
                            )

                        self._compute_actions_and_step_envs(buffer_index)

                if is_last_step:
                    break

            profiling_wrapper.range_pop()  # rollouts loop

            if self._is_distributed:
                self.num_rollouts_done_store.add("num_done", 1)

            (
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent()

            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore

            self.num_updates_done += 1
            losses = self._coalesce_post_step(
                dict(value_loss=value_loss, action_loss=action_loss),
                count_steps_delta,
            )

            self._training_log(writer, losses, prev_time)

            # checkpoint model
            if rank0_only() and self.should_checkpoint():
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth",
                    dict(
                        step=self.num_steps_done,
                        wall_time=(time.time() - self.t_start) + prev_time,
                    ),
                )
                count_checkpoints += 1

            profiling_wrapper.range_pop()  # train update

        self.envs.close()
def _init_train(self):
    if self.config.RL.DDPPO.force_distributed:
        self._is_distributed = True

    if is_slurm_batch_job():
        add_signal_handlers()

    if self._is_distributed:
        local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        if rank0_only():
            logger.info(
                "Initialized DD-PPO with {} workers".format(
                    torch.distributed.get_world_size()
                )
            )

        self.config.defrost()
        self.config.TORCH_GPU_ID = local_rank
        self.config.SIMULATOR_GPU_ID = local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (
            torch.distributed.get_rank() * self.config.NUM_ENVIRONMENTS
        )
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)
        self.num_rollouts_done_store = torch.distributed.PrefixStore(
            "rollout_tracker", tcp_store
        )
        self.num_rollouts_done_store.set("num_done", "0")

    if rank0_only() and self.config.VERBOSE:
        logger.info(f"config: {self.config}")

    profiling_wrapper.configure(
        capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
        num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
    )

    self._init_envs()

    ppo_cfg = self.config.RL.PPO
    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    if rank0_only() and not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    if self._is_distributed:
        self.agent.init_distributed(find_unused_params=True)

    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )

    obs_space = self.obs_space
    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        obs_space = spaces.Dict(
            {
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            }
        )

    self._nbuffers = 2 if ppo_cfg.use_double_buffered_sampler else 1

    self.rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        is_double_buffered=ppo_cfg.use_double_buffered_sampler,
    )
    self.rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(
        observations, device=self.device, cache=self._obs_batching_cache
    )
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)

    if self._static_encoder:
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    self.rollouts.buffers["observations"][0] = batch

    self.current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    self.running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    self.window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )

    self.env_time = 0.0
    self.pth_time = 0.0
    self.t_start = time.time()
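# --- Example: per-rank seed offsets -----------------------------------------
# A minimal sketch of the seeding scheme in _init_train: each rank offsets the
# base seed by rank * NUM_ENVIRONMENTS, so per-environment seeds (assumed here
# to be seed + env index, a common convention) never collide across the whole
# job. `_demo_unique_seeds` and its arguments are illustrative names only.
def _demo_unique_seeds(
    base_seed: int = 100, world_size: int = 2, num_envs: int = 4
) -> None:
    for rank in range(world_size):
        rank_seed = base_seed + rank * num_envs
        print(rank, [rank_seed + i for i in range(num_envs)])


# _demo_unique_seeds() prints non-overlapping seed lists:
# 0 [100, 101, 102, 103]
# 1 [104, 105, 106, 107]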