def _compute_actions_and_step_envs(self, buffer_index: int = 0):
    num_envs = self.envs.num_envs
    env_slice = slice(
        int(buffer_index * num_envs / self._nbuffers),
        int((buffer_index + 1) * num_envs / self._nbuffers),
    )

    t_sample_action = time.time()

    # Sample actions.
    with torch.no_grad():
        step_batch = self.rollouts.buffers[
            self.rollouts.current_rollout_step_idxs[buffer_index],
            env_slice,
        ]

        profiling_wrapper.range_push("compute actions")
        (
            values,
            actions,
            actions_log_probs,
            recurrent_hidden_states,
        ) = self.actor_critic.act(
            step_batch["observations"],
            step_batch["recurrent_hidden_states"],
            step_batch["prev_actions"],
            step_batch["masks"],
        )

    # NB: Move actions to CPU. If CUDA tensors are
    # sent in to env.step(), that will create CUDA contexts
    # in the subprocesses.
    # For backwards compatibility, we also call .item() to convert to
    # an int.
    actions = actions.to(device="cpu")
    self.pth_time += time.time() - t_sample_action

    profiling_wrapper.range_pop()  # compute actions

    t_step_env = time.time()

    for index_env, act in zip(
        range(env_slice.start, env_slice.stop), actions.unbind(0)
    ):
        if self.using_velocity_ctrl:
            step_action = action_to_velocity_control(act)
        else:
            step_action = act.item()
        self.envs.async_step_at(index_env, step_action)

    self.env_time += time.time() - t_step_env

    self.rollouts.insert(
        next_recurrent_hidden_states=recurrent_hidden_states,
        actions=actions,
        action_log_probs=actions_log_probs,
        value_preds=values,
        buffer_index=buffer_index,
    )
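# A minimal, standalone sketch (hypothetical numbers) of how env_slice above
# partitions a vectorized set of environments across double buffers: with
# num_envs = 8 and _nbuffers = 2, buffer 0 owns envs [0, 4) and buffer 1 owns
# envs [4, 8), so action computation for one half can overlap with
# simulation stepping of the other half.
num_envs = 8
nbuffers = 2
for buffer_index in range(nbuffers):
    env_slice = slice(
        int(buffer_index * num_envs / nbuffers),
        int((buffer_index + 1) * num_envs / nbuffers),
    )
    print(buffer_index, list(range(env_slice.start, env_slice.stop)))
# -> 0 [0, 1, 2, 3]
# -> 1 [4, 5, 6, 7]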
def train(self) -> None:
    r"""Main method for training DD/PPO.

    Returns:
        None
    """
    self._init_train()

    count_checkpoints = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: 1 - self.percent_done(),
    )

    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"]
        )
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        self.env_time = requeue_stats["env_time"]
        self.pth_time = requeue_stats["pth_time"]
        self.num_steps_done = requeue_stats["num_steps_done"]
        self.num_updates_done = requeue_stats["num_updates_done"]
        self._last_checkpoint_percent = requeue_stats[
            "_last_checkpoint_percent"
        ]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]

    ppo_cfg = self.config.RL.PPO

    with (
        TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        )
        if rank0_only()
        else contextlib.suppress()
    ) as writer:
        while not self.is_done():
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * (
                    1 - self.percent_done()
                )

            if EXIT.is_set():
                profiling_wrapper.range_pop()  # train update

                self.envs.close()

                if REQUEUE.is_set() and rank0_only():
                    requeue_stats = dict(
                        env_time=self.env_time,
                        pth_time=self.pth_time,
                        count_checkpoints=count_checkpoints,
                        num_steps_done=self.num_steps_done,
                        num_updates_done=self.num_updates_done,
                        _last_checkpoint_percent=self._last_checkpoint_percent,
                        prev_time=(time.time() - self.t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        )
                    )

                requeue_job()
                return

            self.agent.eval()
            count_steps_delta = 0
            profiling_wrapper.range_push("rollouts loop")

            profiling_wrapper.range_push("_collect_rollout_step")
            for buffer_index in range(self._nbuffers):
                self._compute_actions_and_step_envs(buffer_index)

            for step in range(ppo_cfg.num_steps):
                is_last_step = (
                    self.should_end_early(step + 1)
                    or (step + 1) == ppo_cfg.num_steps
                )

                for buffer_index in range(self._nbuffers):
                    count_steps_delta += self._collect_environment_result(
                        buffer_index
                    )

                    if (buffer_index + 1) == self._nbuffers:
                        profiling_wrapper.range_pop()  # _collect_rollout_step

                    if not is_last_step:
                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_push(
                                "_collect_rollout_step"
                            )

                        self._compute_actions_and_step_envs(buffer_index)

                if is_last_step:
                    break

            profiling_wrapper.range_pop()  # rollouts loop

            if self._is_distributed:
                self.num_rollouts_done_store.add("num_done", 1)

            (
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent()

            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore

            self.num_updates_done += 1
            losses = self._coalesce_post_step(
                dict(value_loss=value_loss, action_loss=action_loss),
                count_steps_delta,
            )

            self._training_log(writer, losses, prev_time)

            # Checkpoint model.
            if rank0_only() and self.should_checkpoint():
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth",
                    dict(
                        step=self.num_steps_done,
                        wall_time=(time.time() - self.t_start) + prev_time,
                    ),
                )
                count_checkpoints += 1

            profiling_wrapper.range_pop()  # train update

        self.envs.close()
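# A minimal sketch (hypothetical file name and payload) of the
# save_interrupted_state / load_interrupted_state pattern used above:
# everything needed to resume after preemption -- model, optimizer, LR
# scheduler, and bookkeeping counters -- round-trips through a single
# torch.save'd dict.
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(4, 2)
optimizer = Adam(model.parameters(), lr=2.5e-4)
lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

torch.save(
    dict(
        state_dict=model.state_dict(),
        optim_state=optimizer.state_dict(),
        lr_sched_state=lr_scheduler.state_dict(),
        requeue_stats=dict(num_steps_done=0, count_checkpoints=0),
    ),
    "interrupted_state.pth",
)

resumed = torch.load("interrupted_state.pth")
model.load_state_dict(resumed["state_dict"])
optimizer.load_state_dict(resumed["optim_state"])
lr_scheduler.load_state_dict(resumed["lr_sched_state"])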
def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
    advantages = self.get_advantages(rollouts)

    value_loss_epoch = 0.0
    action_loss_epoch = 0.0
    dist_entropy_epoch = 0.0

    for _e in range(self.ppo_epoch):
        profiling_wrapper.range_push("PPO.update epoch")
        data_generator = rollouts.recurrent_generator(
            advantages, self.num_mini_batch
        )

        for batch in data_generator:
            (
                values,
                action_log_probs,
                dist_entropy,
                _,
            ) = self.actor_critic.evaluate_actions(
                batch["observations"],
                batch["recurrent_hidden_states"],
                batch["prev_actions"],
                batch["masks"],
                batch["actions"],
            )

            ratio = torch.exp(action_log_probs - batch["action_log_probs"])
            surr1 = ratio * batch["advantages"]
            surr2 = (
                torch.clamp(
                    ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
                )
                * batch["advantages"]
            )
            action_loss = -(torch.min(surr1, surr2).mean())

            if self.use_clipped_value_loss:
                value_pred_clipped = batch["value_preds"] + (
                    values - batch["value_preds"]
                ).clamp(-self.clip_param, self.clip_param)
                value_losses = (values - batch["returns"]).pow(2)
                value_losses_clipped = (
                    value_pred_clipped - batch["returns"]
                ).pow(2)
                value_loss = 0.5 * torch.max(
                    value_losses, value_losses_clipped
                )
            else:
                value_loss = 0.5 * (batch["returns"] - values).pow(2)

            value_loss = value_loss.mean()
            dist_entropy = dist_entropy.mean()

            self.optimizer.zero_grad()
            total_loss = (
                value_loss * self.value_loss_coef
                + action_loss
                - dist_entropy * self.entropy_coef
            )

            self.before_backward(total_loss)
            total_loss.backward()
            self.after_backward(total_loss)

            self.before_step()
            self.optimizer.step()
            self.after_step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()

        profiling_wrapper.range_pop()  # PPO.update epoch

    num_updates = self.ppo_epoch * self.num_mini_batch

    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
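# A minimal, self-contained sketch (toy tensors, hypothetical values) of the
# two clipped objectives above: the PPO surrogate clamps the importance
# ratio to [1 - clip, 1 + clip] and takes the pessimistic minimum, while the
# value loss takes the elementwise max of the unclipped and clipped squared
# errors against the returns.
import torch

clip_param = 0.2
new_logp = torch.tensor([-1.0, -0.5, -2.0])
old_logp = torch.tensor([-1.2, -0.4, -1.0])
advantages = torch.tensor([0.5, -1.0, 2.0])

ratio = torch.exp(new_logp - old_logp)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
action_loss = -torch.min(surr1, surr2).mean()

values = torch.tensor([1.0, 0.2, -0.3])
value_preds = torch.tensor([0.5, 0.1, 0.4])  # value estimates at rollout time
returns = torch.tensor([1.5, 0.0, 0.1])

value_pred_clipped = value_preds + (values - value_preds).clamp(
    -clip_param, clip_param
)
value_loss = 0.5 * torch.max(
    (values - returns).pow(2), (value_pred_clipped - returns).pow(2)
).mean()
print(action_loss.item(), value_loss.item())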
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    profiling_wrapper.configure(
        capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
        num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
    )

    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.obs_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats: DefaultDict[str, deque] = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),  # type: ignore
    )

    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(self.config.NUM_UPDATES):
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")

            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            profiling_wrapper.range_push("rollouts loop")
            for _step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(rollouts,
                                               current_episode_reward,
                                               running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps
            profiling_wrapper.range_pop()  # rollouts loop

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())

            deltas = {
                k: ((v[-1] - v[0]).sum().item()
                    if len(v) > 1 else v[0].sum().item())
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                              count_steps)

            # Check to see if there are any metrics
            # that haven't been logged yet.
            metrics = {
                k: v / deltas["count"]
                for k, v in deltas.items() if k not in {"reward", "count"}
            }
            if len(metrics) > 0:
                writer.add_scalars("metrics", metrics, count_steps)

            losses = [value_loss, action_loss]
            writer.add_scalars(
                "losses",
                {k: l for l, k in zip(losses, ["value", "policy"])},
                count_steps,
            )

            # Log stats.
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)))

                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time,
                                        count_steps))

                logger.info("Average window size: {} {}".format(
                    len(window_episode_stats["count"]),
                    " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                             for k, v in deltas.items() if k != "count"),
                ))

            # Checkpoint model.
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth",
                                     dict(step=count_steps))
                count_checkpoints += 1

            profiling_wrapper.range_pop()  # train update

        self.envs.close()
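# A minimal sketch of the LambdaLR schedule used above, assuming linear_decay
# is the usual 1 - t/T ramp (toy NUM_UPDATES value). The lambda returns a
# multiplier on the base learning rate, so the LR anneals linearly toward
# zero over the course of training.
import torch
from torch.optim.lr_scheduler import LambdaLR

def linear_decay(epoch: int, total_num_updates: int) -> float:
    # Multiplier applied to the base LR at a given update index.
    return 1 - (epoch / float(total_num_updates))

NUM_UPDATES = 4
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2.5e-4)
scheduler = LambdaLR(
    optimizer, lr_lambda=lambda x: linear_decay(x, NUM_UPDATES)
)

for update in range(NUM_UPDATES):
    print(update, optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()
# LR falls linearly: 2.5e-4, 1.875e-4, 1.25e-4, 6.25e-5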
def _collect_rollout_step(self, rollouts, current_episode_reward,
                          running_episode_stats):
    pth_time = 0.0
    env_time = 0.0

    t_sample_action = time.time()
    # Sample actions.
    with torch.no_grad():
        step_observation = {
            k: v[rollouts.step]
            for k, v in rollouts.observations.items()
        }

        profiling_wrapper.range_push("compute actions")
        (
            values,
            actions,
            actions_log_probs,
            recurrent_hidden_states,
        ) = self.actor_critic.act(
            step_observation,
            rollouts.recurrent_hidden_states[rollouts.step],
            rollouts.prev_actions[rollouts.step],
            rollouts.masks[rollouts.step],
        )

    pth_time += time.time() - t_sample_action

    t_step_env = time.time()
    # NB: Move actions to CPU. If CUDA tensors are
    # sent in to env.step(), that will create CUDA contexts
    # in the subprocesses.
    # For backwards compatibility, we also call .item() to convert to
    # an int.
    step_data = [a.item() for a in actions.to(device="cpu")]
    profiling_wrapper.range_pop()  # compute actions

    outputs = self.envs.step(step_data)
    observations, rewards_l, dones, infos = [
        list(x) for x in zip(*outputs)
    ]

    env_time += time.time() - t_step_env

    t_update_stats = time.time()
    batch = batch_obs(observations, device=self.device)
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)

    rewards = torch.tensor(rewards_l, dtype=torch.float,
                           device=current_episode_reward.device)
    rewards = rewards.unsqueeze(1)

    masks = torch.tensor(
        [[0.0] if done else [1.0] for done in dones],
        dtype=torch.float,
        device=current_episode_reward.device,
    )

    current_episode_reward += rewards
    running_episode_stats["reward"] += (
        1 - masks
    ) * current_episode_reward  # type: ignore
    running_episode_stats["count"] += 1 - masks  # type: ignore
    for k, v_k in self._extract_scalars_from_infos(infos).items():
        v = torch.tensor(
            v_k, dtype=torch.float,
            device=current_episode_reward.device).unsqueeze(1)
        if k not in running_episode_stats:
            running_episode_stats[k] = torch.zeros_like(
                running_episode_stats["count"])

        running_episode_stats[k] += (1 - masks) * v  # type: ignore

    current_episode_reward *= masks

    if self._static_encoder:
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    rollouts.insert(
        batch,
        recurrent_hidden_states,
        actions,
        actions_log_probs,
        values,
        rewards,
        masks,
    )

    pth_time += time.time() - t_update_stats

    return pth_time, env_time, self.envs.num_envs
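# A minimal, self-contained sketch (toy values) of the mask bookkeeping
# above: masks are 0.0 where an episode just finished and 1.0 elsewhere, so
# (1 - masks) selects finished episodes when accumulating per-episode reward
# and episode counts, and multiplying by masks resets finished accumulators.
import torch

dones = [False, True, False]
rewards = torch.tensor([[0.1], [0.5], [0.2]])
current_episode_reward = torch.tensor([[1.0], [2.0], [3.0]])
running_reward = torch.zeros(3, 1)
running_count = torch.zeros(3, 1)

masks = torch.tensor([[0.0] if done else [1.0] for done in dones])

current_episode_reward += rewards
running_reward += (1 - masks) * current_episode_reward  # only env 1 counts
running_count += 1 - masks
current_episode_reward *= masks  # env 1 resets for its next episode

print(running_reward.squeeze(1), running_count.squeeze(1))
# -> tensor([0.0000, 2.5000, 0.0000]) tensor([0., 1., 0.])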
def update(self, rollouts):
    advantages = self.get_advantages(rollouts)

    value_loss_epoch = 0
    action_loss_epoch = 0
    dist_entropy_epoch = 0
    aux_losses_epoch = [0] * len(self.aux_tasks)
    aux_entropy_epoch = 0
    aux_weights_epoch = [0] * len(self.aux_tasks)

    for e in range(self.ppo_epoch):
        profiling_wrapper.range_push("PPO.update epoch")
        # This data generator steps through the rollout
        # (gathering n=batch_size processes' rollouts).
        data_generator = rollouts.recurrent_generator(
            advantages, self.num_mini_batch,
        )

        for sample in data_generator:
            (
                obs_batch,
                recurrent_hidden_states_batch,
                actions_batch,
                prev_actions_batch,
                value_preds_batch,
                return_batch,
                masks_batch,
                old_action_log_probs_batch,
                adv_targ,
            ) = sample

            (
                values,
                action_log_probs,
                dist_entropy,
                # Used to encourage trajectory memory (it's the same as the
                # final RNN feature, due to the GRU).
                final_rnn_state,
                rnn_features,
                individual_rnn_features,
                aux_dist_entropy,
                aux_weights,
            ) = self.actor_critic.evaluate_actions(
                obs_batch,
                recurrent_hidden_states_batch,
                prev_actions_batch,
                masks_batch,
                actions_batch,
            )

            ratio = torch.exp(
                action_log_probs - old_action_log_probs_batch
            )
            surr1 = ratio * adv_targ
            surr2 = (
                torch.clamp(
                    ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
                )
                * adv_targ
            )
            action_loss = -torch.min(surr1, surr2).mean()

            # Value loss is MSE against the empirical returns.
            if self.use_clipped_value_loss:
                value_pred_clipped = value_preds_batch + (
                    values - value_preds_batch
                ).clamp(-self.clip_param, self.clip_param)
                value_losses = (values - return_batch).pow(2)
                value_losses_clipped = (
                    value_pred_clipped - return_batch
                ).pow(2)
                value_loss = (
                    0.5
                    * torch.max(value_losses, value_losses_clipped).mean()
                )
            else:
                value_loss = 0.5 * (return_batch - values).pow(2).mean()

            total_aux_loss = 0
            aux_losses = []
            if len(self.aux_tasks) > 0:  # Only nonempty in training
                raw_losses = self.actor_critic.evaluate_aux_losses(
                    sample, final_rnn_state, rnn_features,
                    individual_rnn_features
                )
                aux_losses = torch.stack(raw_losses)
                total_aux_loss = torch.sum(aux_losses, dim=0)

            self.optimizer.zero_grad()
            total_loss = (
                value_loss * self.value_loss_coef
                + action_loss
                + total_aux_loss * self.aux_loss_coef
                - dist_entropy * self.entropy_coef
            )
            if aux_dist_entropy is not None:
                total_loss -= aux_dist_entropy * self.aux_cfg.entropy_coef

            self.before_backward(total_loss)
            total_loss.backward()
            self.after_backward(total_loss)

            self.before_step()
            self.optimizer.step()
            self.after_step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()
            if aux_dist_entropy is not None:
                aux_entropy_epoch += aux_dist_entropy.item()
            for i, aux_loss in enumerate(aux_losses):
                aux_losses_epoch[i] += aux_loss.item()
            if aux_weights is not None:
                for i, aux_weight in enumerate(aux_weights):
                    aux_weights_epoch[i] += aux_weight.item()

        profiling_wrapper.range_pop()  # PPO.update epoch

    num_updates = self.ppo_epoch * self.num_mini_batch

    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates
    # Average the per-task epoch accumulators (not the last minibatch's
    # aux_losses, which may be shorter-lived).
    for i in range(len(aux_losses_epoch)):
        aux_losses_epoch[i] /= num_updates
    if aux_weights is not None:
        for i in range(len(aux_weights_epoch)):
            aux_weights_epoch[i] /= num_updates

    return (
        value_loss_epoch,
        action_loss_epoch,
        dist_entropy_epoch,
        aux_losses_epoch,
        aux_entropy_epoch,
        aux_weights_epoch,
    )
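# A minimal sketch (toy scalar losses, hypothetical coefficient) of how the
# auxiliary-task losses above are combined: each task returns a scalar loss,
# torch.stack collects them into one tensor, and their weighted sum joins
# the PPO objective so all task heads are optimized in a single backward
# pass.
import torch

raw_losses = [torch.tensor(0.30), torch.tensor(0.10), torch.tensor(0.05)]
aux_losses = torch.stack(raw_losses)  # shape: (num_tasks,)
total_aux_loss = torch.sum(aux_losses, dim=0)

aux_loss_coef = 2.0
action_loss = torch.tensor(0.5)
total_loss = action_loss + total_aux_loss * aux_loss_coef
print(total_loss.item())  # 0.5 + 0.45 * 2.0 = 1.4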
def train(self) -> None:
    r"""Main method for DD-PPO.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend
    )
    add_signal_handlers()

    profiling_wrapper.configure(
        capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
        num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
    )

    # Stores the number of workers that have finished their rollout.
    num_rollouts_done_store = distrib.PrefixStore(
        "rollout_tracker", tcp_store
    )
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get
    # unique seeds.
    self.config.TASK_CONFIG.SEED += (
        self.world_rank * self.config.NUM_PROCESSES
    )
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(
        self.config,
        get_env_class(self.config.ENV_NAME),
        workers_ignore_signals=True,
    )

    ppo_cfg = self.config.RL.PPO
    if (
        not os.path.isdir(self.config.CHECKPOINT_FOLDER)
        and self.world_rank == 0
    ):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info(
            "agent number of trainable parameters: {}".format(
                sum(
                    param.numel()
                    for param in self.agent.parameters()
                    if param.requires_grad
                )
            )
        )

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)

    obs_space = self.obs_space
    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        obs_space = spaces.Dict(
            {
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            }
        )
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(
        self.envs.num_envs, 1, device=self.device
    )
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats: DefaultDict[str, deque] = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps: float = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),  # type: ignore
    )

    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"]
        )
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    with (
        TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        )
        if self.world_rank == 0
        else contextlib.suppress()
    ) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")

            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )

            if EXIT.is_set():
                profiling_wrapper.range_pop()  # train update

                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        )
                    )

                requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            profiling_wrapper.range_push("rollouts loop")
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (
                    step
                    >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                ) and int(num_rollouts_done_store.get("num_done")) > (
                    self.config.RL.DDPPO.sync_frac * self.world_size
                ):
                    break
            profiling_wrapper.range_pop()  # rollouts loop

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self._static_encoder:
                self._encoder.eval()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            stats_ordering = sorted(running_episode_stats.keys())
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0
            )
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            stats = torch.tensor(
                [value_loss, action_loss, count_steps_delta],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[2].item()

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                ]
                deltas = {
                    k: (
                        (v[-1] - v[0]).sum().item()
                        if len(v) > 1
                        else v[0].sum().item()
                    )
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward",
                    deltas["reward"] / deltas["count"],
                    count_steps,
                )

                # Check to see if there are any metrics
                # that haven't been logged yet.
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # Log stats.
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info(
                        "update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps
                            / ((time.time() - t_start) + prev_time),
                        )
                    )

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(
                            update, env_time, pth_time, count_steps
                        )
                    )

                    logger.info(
                        "Average window size: {} {}".format(
                            len(window_episode_stats["count"]),
                            " ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items()
                                if k != "count"
                            ),
                        )
                    )

                # Checkpoint model.
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1

            profiling_wrapper.range_pop()  # train update

        self.envs.close()
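# A minimal, single-process sketch (gloo backend, world_size=1, hypothetical
# values and port) of the all_reduce pattern above: each worker contributes
# its local [value_loss, action_loss, count_steps_delta], all_reduce sums
# them in place across workers, and losses are divided by world_size to get
# means while the step count is kept as the global sum.
import torch
import torch.distributed as distrib

distrib.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
)
world_size = distrib.get_world_size()

stats = torch.tensor([0.25, 0.10, 128.0])  # local per-worker stats
distrib.all_reduce(stats)                  # in-place sum over all workers

mean_value_loss = stats[0].item() / world_size
mean_action_loss = stats[1].item() / world_size
total_steps = stats[2].item()
print(mean_value_loss, mean_action_loss, total_steps)
distrib.destroy_process_group()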
def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
    advantages = self.get_advantages(rollouts)

    value_loss_epoch = 0.0
    action_loss_epoch = 0.0
    dist_entropy_epoch = 0.0

    for _e in range(self.ppo_epoch):
        profiling_wrapper.range_push("PPO.update epoch")
        data_generator = rollouts.recurrent_generator(
            advantages, self.num_mini_batch)

        for sample in data_generator:
            (
                obs_batch,
                recurrent_hidden_states_batch,
                actions_batch,
                prev_actions_batch,
                value_preds_batch,
                return_batch,
                masks_batch,
                old_action_log_probs_batch,
                adv_targ,
            ) = sample

            # Reshape to do in a single forward pass for all steps.
            (
                values,
                action_log_probs,
                dist_entropy,
                _,
            ) = self.actor_critic.evaluate_actions(
                obs_batch,
                recurrent_hidden_states_batch,
                prev_actions_batch,
                masks_batch,
                actions_batch,
            )

            ratio = torch.exp(action_log_probs -
                              old_action_log_probs_batch)
            surr1 = ratio * adv_targ
            surr2 = (torch.clamp(ratio, 1.0 - self.clip_param,
                                 1.0 + self.clip_param) * adv_targ)
            action_loss = -torch.min(surr1, surr2).mean()

            if self.use_clipped_value_loss:
                value_pred_clipped = value_preds_batch + (
                    values - value_preds_batch).clamp(
                        -self.clip_param, self.clip_param)
                value_losses = (values - return_batch).pow(2)
                value_losses_clipped = (value_pred_clipped -
                                        return_batch).pow(2)
                value_loss = (
                    0.5 *
                    torch.max(value_losses, value_losses_clipped).mean())
            else:
                value_loss = 0.5 * (return_batch - values).pow(2).mean()

            self.optimizer.zero_grad()
            total_loss = (value_loss * self.value_loss_coef + action_loss -
                          dist_entropy * self.entropy_coef)

            self.before_backward(total_loss)
            total_loss.backward()
            self.after_backward(total_loss)

            self.before_step()
            self.optimizer.step()
            self.after_step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()

        profiling_wrapper.range_pop()  # PPO.update epoch

    num_updates = self.ppo_epoch * self.num_mini_batch

    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
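# A minimal sketch of the before_backward / before_step hook pattern used in
# both update() variants above. The hook bodies here are assumptions for
# illustration: a common choice is to clip the global gradient norm in
# before_step (hypothetical max_grad_norm value), and a distributed subclass
# could use the backward hooks to overlap gradient synchronization.
import torch
import torch.nn as nn

class PPOWithHooks:
    def __init__(self, actor_critic: nn.Module, max_grad_norm: float = 0.5):
        self.actor_critic = actor_critic
        self.max_grad_norm = max_grad_norm
        self.optimizer = torch.optim.Adam(
            actor_critic.parameters(), lr=2.5e-4
        )

    def before_backward(self, loss: torch.Tensor) -> None:
        pass  # e.g., scale the loss, or open a grad-sync window

    def after_backward(self, loss: torch.Tensor) -> None:
        pass  # e.g., finish an async gradient reduction

    def before_step(self) -> None:
        # Clip the global gradient norm before the optimizer step.
        nn.utils.clip_grad_norm_(
            self.actor_critic.parameters(), self.max_grad_norm
        )

    def after_step(self) -> None:
        pass  # e.g., log step timings

    def step(self, total_loss: torch.Tensor) -> None:
        self.optimizer.zero_grad()
        self.before_backward(total_loss)
        total_loss.backward()
        self.after_backward(total_loss)
        self.before_step()
        self.optimizer.step()
        self.after_step()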