def train(self) -> None: r"""Main method for DD-PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += (self.world_rank * self.config.NUM_PROCESSES) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO if (not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) obs_space = self.envs.observation_spaces[0] if self._static_encoder: self._encoder = self.actor_critic.net.visual_encoder obs_space = SpaceDict({ "visual_features": spaces.Box( low=np.finfo(np.float32).min, high=np.finfo(np.float32).max, shape=self._encoder.output_shape, dtype=np.float32, ), **obs_space.spaces, }) with torch.no_grad(): batch["visual_features"] = self._encoder(batch) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, ) rollouts.to(self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps), ) count_checkpoints += 1 self.envs.close()
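# Illustrative sketch (not part of the trainer above; all names and default
# values here are hypothetical): the straggler-preemption test inside the
# rollout loop can be read as a pure function of rollout progress and how many
# other workers have already finished. A worker only cuts its rollout short
# once it has collected at least short_rollout_threshold of its steps AND more
# than sync_frac of all workers are already done.
def _should_preempt_rollout_example(
    step: int,
    num_steps: int,
    num_done: int,
    world_size: int,
    short_rollout_threshold: float = 0.25,
    sync_frac: float = 0.6,
) -> bool:
    return (
        step >= num_steps * short_rollout_threshold
        and num_done > sync_frac * world_size
    )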
def _worker_fn(
    world_rank: int, world_size: int, port: int, unused_params: bool
):
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    tcp_store = distrib.TCPStore(  # type: ignore
        "127.0.0.1", port, world_size, world_rank == 0
    )
    distrib.init_process_group(
        "gloo", store=tcp_store, rank=world_rank, world_size=world_size
    )

    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
    obs_space = gym.spaces.Dict(
        {
            IntegratedPointGoalGPSAndCompassSensor.cls_uuid: gym.spaces.Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=(2,),
                dtype=np.float32,
            )
        }
    )
    action_space = ActionSpace({"move": EmptySpace()})
    actor_critic = PointNavBaselinePolicy.from_config(
        config, obs_space, action_space
    )
    # This use adds some arbitrary parameters that aren't part of the computation
    # graph, so they will mess up DDP if they aren't correctly ignored by it
    if unused_params:
        actor_critic.unused = nn.Linear(64, 64)

    actor_critic.to(device=device)
    ppo_cfg = config.RL.PPO
    agent = DDPPO(
        actor_critic=actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )
    agent.init_distributed()

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        2,
        obs_space,
        action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
        is_double_buffered=False,
    )
    rollouts.to(device)

    for k, v in rollouts.buffers["observations"].items():
        rollouts.buffers["observations"][k] = torch.randn_like(v)

    # Add two steps so batching works
    rollouts.advance_rollout()
    rollouts.advance_rollout()

    # Get a single batch
    batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))

    # Call eval actions through the internal wrapper that is used in
    # agent.update
    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
        batch["observations"],
        batch["recurrent_hidden_states"],
        batch["prev_actions"],
        batch["masks"],
        batch["actions"],
    )
    # Backprop on things
    (value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()

    # Make sure all ranks have very similar gradients
    for param in actor_critic.parameters():
        if param.grad is not None:
            grads = [param.grad.detach().clone() for _ in range(world_size)]
            distrib.all_gather(grads, grads[world_rank])

            for i in range(world_size):
                assert torch.isclose(grads[i], grads[world_rank]).all()
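# Illustrative sketch (an assumption, not necessarily how the repository's test
# harness launches the function above): a worker function with this signature
# is typically spawned once per rank, e.g. via torch.multiprocessing. The port
# and world size below are arbitrary example values.
def _run_worker_fn_example(
    world_size: int = 2, port: int = 8748, unused_params: bool = False
) -> None:
    import torch.multiprocessing as mp

    # mp.spawn passes the rank as the first positional argument, matching
    # _worker_fn(world_rank, world_size, port, unused_params).
    mp.spawn(
        _worker_fn,
        args=(world_size, port, unused_params),
        nprocs=world_size,
        join=True,
    )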
class PPOTrainer(BaseRLTrainer):
    r"""Trainer class for PPO algorithm
    Paper: https://arxiv.org/abs/1707.06347.
    """

    supported_tasks = ["Nav-v0"]

    SHORT_ROLLOUT_THRESHOLD: float = 0.25
    _is_distributed: bool
    _obs_batching_cache: ObservationBatchingCache
    envs: VectorEnv
    agent: PPO
    actor_critic: Policy

    def __init__(self, config=None):
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            config = interrupted_state["config"]

        super().__init__(config)
        self.actor_critic = None
        self.agent = None
        self.envs = None
        self.obs_transforms = []

        self._static_encoder = False
        self._encoder = None
        self._obs_space = None

        # Distributed if the world size would be
        # greater than 1
        self._is_distributed = get_distrib_size()[2] > 1
        self._obs_batching_cache = ObservationBatchingCache()

    @property
    def obs_space(self):
        if self._obs_space is None and self.envs is not None:
            self._obs_space = self.envs.observation_spaces[0]

        return self._obs_space

    @obs_space.setter
    def obs_space(self, new_obs_space):
        self._obs_space = new_obs_space

    def _all_reduce(self, t: torch.Tensor) -> torch.Tensor:
        r"""All reduce helper method that moves things to the correct
        device and only runs if distributed
        """
        if not self._is_distributed:
            return t

        orig_device = t.device
        t = t.to(device=self.device)
        torch.distributed.all_reduce(t)

        return t.to(device=orig_device)

    def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
        r"""Sets up actor critic and agent for PPO.

        Args:
            ppo_cfg: config node with relevant params

        Returns:
            None
        """
        logger.add_filehandler(self.config.LOG_FILE)

        policy = baseline_registry.get_policy(self.config.RL.POLICY.name)
        observation_space = self.obs_space
        self.obs_transforms = get_active_obs_transforms(self.config)
        observation_space = apply_obs_transforms_obs_space(
            observation_space, self.obs_transforms
        )
        self.actor_critic = policy.from_config(
            self.config, observation_space, self.envs.action_spaces[0]
        )
        self.obs_space = observation_space
        self.actor_critic.to(self.device)

        if (
            self.config.RL.DDPPO.pretrained_encoder
            or self.config.RL.DDPPO.pretrained
        ):
            pretrained_state = torch.load(
                self.config.RL.DDPPO.pretrained_weights, map_location="cpu"
            )

        if self.config.RL.DDPPO.pretrained:
            self.actor_critic.load_state_dict(
                {
                    k[len("actor_critic.") :]: v
                    for k, v in pretrained_state["state_dict"].items()
                }
            )
        elif self.config.RL.DDPPO.pretrained_encoder:
            prefix = "actor_critic.net.visual_encoder."
self.actor_critic.net.visual_encoder.load_state_dict( { k[len(prefix) :]: v for k, v in pretrained_state["state_dict"].items() if k.startswith(prefix) } ) if not self.config.RL.DDPPO.train_encoder: self._static_encoder = True for param in self.actor_critic.net.visual_encoder.parameters(): param.requires_grad_(False) if self.config.RL.DDPPO.reset_critic: nn.init.orthogonal_(self.actor_critic.critic.fc.weight) nn.init.constant_(self.actor_critic.critic.fc.bias, 0) self.agent = (DDPPO if self._is_distributed else PPO)( actor_critic=self.actor_critic, clip_param=ppo_cfg.clip_param, ppo_epoch=ppo_cfg.ppo_epoch, num_mini_batch=ppo_cfg.num_mini_batch, value_loss_coef=ppo_cfg.value_loss_coef, entropy_coef=ppo_cfg.entropy_coef, lr=ppo_cfg.lr, eps=ppo_cfg.eps, max_grad_norm=ppo_cfg.max_grad_norm, use_normalized_advantage=ppo_cfg.use_normalized_advantage, ) def _init_envs(self, config=None): if config is None: config = self.config self.envs = construct_envs( config, get_env_class(config.ENV_NAME), workers_ignore_signals=is_slurm_batch_job(), ) def _init_train(self): if self.config.RL.DDPPO.force_distributed: self._is_distributed = True if is_slurm_batch_job(): add_signal_handlers() if self._is_distributed: local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend ) if rank0_only(): logger.info( "Initialized DD-PPO with {} workers".format( torch.distributed.get_world_size() ) ) self.config.defrost() self.config.TORCH_GPU_ID = local_rank self.config.SIMULATOR_GPU_ID = local_rank # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += ( torch.distributed.get_rank() * self.config.NUM_ENVIRONMENTS ) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) self.num_rollouts_done_store = torch.distributed.PrefixStore( "rollout_tracker", tcp_store ) self.num_rollouts_done_store.set("num_done", "0") if rank0_only() and self.config.VERBOSE: logger.info(f"config: {self.config}") profiling_wrapper.configure( capture_start_step=self.config.PROFILING.CAPTURE_START_STEP, num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE, ) self._init_envs() ppo_cfg = self.config.RL.PPO if torch.cuda.is_available(): self.device = torch.device("cuda", self.config.TORCH_GPU_ID) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") if rank0_only() and not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) if self._is_distributed: self.agent.init_distributed(find_unused_params=True) logger.info( "agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()) ) ) obs_space = self.obs_space if self._static_encoder: self._encoder = self.actor_critic.net.visual_encoder obs_space = spaces.Dict( { "visual_features": spaces.Box( low=np.finfo(np.float32).min, high=np.finfo(np.float32).max, shape=self._encoder.output_shape, dtype=np.float32, ), **obs_space.spaces, } ) self._nbuffers = 2 if ppo_cfg.use_double_buffered_sampler else 1 self.rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, is_double_buffered=ppo_cfg.use_double_buffered_sampler, ) self.rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs( observations, device=self.device, cache=self._obs_batching_cache ) 
batch = apply_obs_transforms_batch(batch, self.obs_transforms) if self._static_encoder: with torch.no_grad(): batch["visual_features"] = self._encoder(batch) self.rollouts.buffers["observations"][0] = batch self.current_episode_reward = torch.zeros(self.envs.num_envs, 1) self.running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) self.window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size) ) self.env_time = 0.0 self.pth_time = 0.0 self.t_start = time.time() @rank0_only @profiling_wrapper.RangeContext("save_checkpoint") def save_checkpoint( self, file_name: str, extra_state: Optional[Dict] = None ) -> None: r"""Save checkpoint with specified name. Args: file_name: file name for checkpoint Returns: None """ checkpoint = { "state_dict": self.agent.state_dict(), "config": self.config, } if extra_state is not None: checkpoint["extra_state"] = extra_state torch.save( checkpoint, os.path.join(self.config.CHECKPOINT_FOLDER, file_name) ) def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict: r"""Load checkpoint of specified path as a dict. Args: checkpoint_path: path of target checkpoint *args: additional positional args **kwargs: additional keyword args Returns: dict containing checkpoint info """ return torch.load(checkpoint_path, *args, **kwargs) METRICS_BLACKLIST = {"top_down_map", "collisions.is_collision"} @classmethod def _extract_scalars_from_info( cls, info: Dict[str, Any] ) -> Dict[str, float]: result = {} for k, v in info.items(): if k in cls.METRICS_BLACKLIST: continue if isinstance(v, dict): result.update( { k + "." + subk: subv for subk, subv in cls._extract_scalars_from_info( v ).items() if (k + "." + subk) not in cls.METRICS_BLACKLIST } ) # Things that are scalar-like will have an np.size of 1. # Strings also have an np.size of 1, so explicitly ban those elif np.size(v) == 1 and not isinstance(v, str): result[k] = float(v) return result @classmethod def _extract_scalars_from_infos( cls, infos: List[Dict[str, Any]] ) -> Dict[str, List[float]]: results = defaultdict(list) for i in range(len(infos)): for k, v in cls._extract_scalars_from_info(infos[i]).items(): results[k].append(v) return results def _compute_actions_and_step_envs(self, buffer_index: int = 0): num_envs = self.envs.num_envs env_slice = slice( int(buffer_index * num_envs / self._nbuffers), int((buffer_index + 1) * num_envs / self._nbuffers), ) t_sample_action = time.time() # sample actions with torch.no_grad(): step_batch = self.rollouts.buffers[ self.rollouts.current_rollout_step_idxs[buffer_index], env_slice, ] profiling_wrapper.range_push("compute actions") ( values, actions, actions_log_probs, recurrent_hidden_states, ) = self.actor_critic.act( step_batch["observations"], step_batch["recurrent_hidden_states"], step_batch["prev_actions"], step_batch["masks"], ) # NB: Move actions to CPU. If CUDA tensors are # sent in to env.step(), that will create CUDA contexts # in the subprocesses. 
# For backwards compatibility, we also call .item() to convert to # an int actions = actions.to(device="cpu") self.pth_time += time.time() - t_sample_action profiling_wrapper.range_pop() # compute actions t_step_env = time.time() for index_env, act in zip( range(env_slice.start, env_slice.stop), actions.unbind(0) ): self.envs.async_step_at(index_env, act.item()) self.env_time += time.time() - t_step_env self.rollouts.insert( next_recurrent_hidden_states=recurrent_hidden_states, actions=actions, action_log_probs=actions_log_probs, value_preds=values, buffer_index=buffer_index, ) def _collect_environment_result(self, buffer_index: int = 0): num_envs = self.envs.num_envs env_slice = slice( int(buffer_index * num_envs / self._nbuffers), int((buffer_index + 1) * num_envs / self._nbuffers), ) t_step_env = time.time() outputs = [ self.envs.wait_step_at(index_env) for index_env in range(env_slice.start, env_slice.stop) ] observations, rewards_l, dones, infos = [ list(x) for x in zip(*outputs) ] self.env_time += time.time() - t_step_env t_update_stats = time.time() batch = batch_obs( observations, device=self.device, cache=self._obs_batching_cache ) batch = apply_obs_transforms_batch(batch, self.obs_transforms) rewards = torch.tensor( rewards_l, dtype=torch.float, device=self.current_episode_reward.device, ) rewards = rewards.unsqueeze(1) not_done_masks = torch.tensor( [[not done] for done in dones], dtype=torch.bool, device=self.current_episode_reward.device, ) done_masks = torch.logical_not(not_done_masks) self.current_episode_reward[env_slice] += rewards current_ep_reward = self.current_episode_reward[env_slice] self.running_episode_stats["reward"][env_slice] += current_ep_reward.where(done_masks, current_ep_reward.new_zeros(())) # type: ignore self.running_episode_stats["count"][env_slice] += done_masks.float() # type: ignore for k, v_k in self._extract_scalars_from_infos(infos).items(): v = torch.tensor( v_k, dtype=torch.float, device=self.current_episode_reward.device, ).unsqueeze(1) if k not in self.running_episode_stats: self.running_episode_stats[k] = torch.zeros_like( self.running_episode_stats["count"] ) self.running_episode_stats[k][env_slice] += v.where(done_masks, v.new_zeros(())) # type: ignore self.current_episode_reward[env_slice].masked_fill_(done_masks, 0.0) if self._static_encoder: with torch.no_grad(): batch["visual_features"] = self._encoder(batch) self.rollouts.insert( next_observations=batch, rewards=rewards, next_masks=not_done_masks, buffer_index=buffer_index, ) self.rollouts.advance_rollout(buffer_index) self.pth_time += time.time() - t_update_stats return env_slice.stop - env_slice.start @profiling_wrapper.RangeContext("_collect_rollout_step") def _collect_rollout_step(self): self._compute_actions_and_step_envs() return self._collect_environment_result() @profiling_wrapper.RangeContext("_update_agent") def _update_agent(self): ppo_cfg = self.config.RL.PPO t_update_model = time.time() with torch.no_grad(): step_batch = self.rollouts.buffers[ self.rollouts.current_rollout_step_idx ] next_value = self.actor_critic.get_value( step_batch["observations"], step_batch["recurrent_hidden_states"], step_batch["prev_actions"], step_batch["masks"], ) self.rollouts.compute_returns( next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau ) self.agent.train() value_loss, action_loss, dist_entropy = self.agent.update( self.rollouts ) self.rollouts.after_update() self.pth_time += time.time() - t_update_model return ( value_loss, action_loss, dist_entropy, ) def _coalesce_post_step( self, 
losses: Dict[str, float], count_steps_delta: int ) -> Dict[str, float]: stats_ordering = sorted(self.running_episode_stats.keys()) stats = torch.stack( [self.running_episode_stats[k] for k in stats_ordering], 0 ) stats = self._all_reduce(stats) for i, k in enumerate(stats_ordering): self.window_episode_stats[k].append(stats[i]) if self._is_distributed: loss_name_ordering = sorted(losses.keys()) stats = torch.tensor( [losses[k] for k in loss_name_ordering] + [count_steps_delta], device="cpu", dtype=torch.float32, ) stats = self._all_reduce(stats) count_steps_delta = int(stats[-1].item()) stats /= torch.distributed.get_world_size() losses = { k: stats[i].item() for i, k in enumerate(loss_name_ordering) } if self._is_distributed and rank0_only(): self.num_rollouts_done_store.set("num_done", "0") self.num_steps_done += count_steps_delta return losses @rank0_only def _training_log( self, writer, losses: Dict[str, float], prev_time: int = 0 ): deltas = { k: ( (v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item() ) for k, v in self.window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], self.num_steps_done, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, self.num_steps_done) writer.add_scalars( "losses", losses, self.num_steps_done, ) # log stats if self.num_updates_done % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tfps: {:.3f}\t".format( self.num_updates_done, self.num_steps_done / ((time.time() - self.t_start) + prev_time), ) ) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format( self.num_updates_done, self.env_time, self.pth_time, self.num_steps_done, ) ) logger.info( "Average window size: {} {}".format( len(self.window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count" ), ) ) def should_end_early(self, rollout_step) -> bool: if not self._is_distributed: return False # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! return ( rollout_step >= self.config.RL.PPO.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(self.num_rollouts_done_store.get("num_done")) >= ( self.config.RL.DDPPO.sync_frac * torch.distributed.get_world_size() ) @profiling_wrapper.RangeContext("train") def train(self) -> None: r"""Main method for training DD/PPO. 
Returns: None """ self._init_train() count_checkpoints = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: 1 - self.percent_done(), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"] ) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] self.env_time = requeue_stats["env_time"] self.pth_time = requeue_stats["pth_time"] self.num_steps_done = requeue_stats["num_steps_done"] self.num_updates_done = requeue_stats["num_updates_done"] self._last_checkpoint_percent = requeue_stats[ "_last_checkpoint_percent" ] count_checkpoints = requeue_stats["count_checkpoints"] prev_time = requeue_stats["prev_time"] self._last_checkpoint_percent = requeue_stats[ "_last_checkpoint_percent" ] ppo_cfg = self.config.RL.PPO with ( TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) if rank0_only() else contextlib.suppress() ) as writer: while not self.is_done(): profiling_wrapper.on_start_step() profiling_wrapper.range_push("train update") if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * ( 1 - self.percent_done() ) if EXIT.is_set(): profiling_wrapper.range_pop() # train update self.envs.close() if REQUEUE.is_set() and rank0_only(): requeue_stats = dict( env_time=self.env_time, pth_time=self.pth_time, count_checkpoints=count_checkpoints, num_steps_done=self.num_steps_done, num_updates_done=self.num_updates_done, _last_checkpoint_percent=self._last_checkpoint_percent, prev_time=(time.time() - self.t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, ) ) requeue_job() return self.agent.eval() count_steps_delta = 0 profiling_wrapper.range_push("rollouts loop") profiling_wrapper.range_push("_collect_rollout_step") for buffer_index in range(self._nbuffers): self._compute_actions_and_step_envs(buffer_index) for step in range(ppo_cfg.num_steps): is_last_step = ( self.should_end_early(step + 1) or (step + 1) == ppo_cfg.num_steps ) for buffer_index in range(self._nbuffers): count_steps_delta += self._collect_environment_result( buffer_index ) if (buffer_index + 1) == self._nbuffers: profiling_wrapper.range_pop() # _collect_rollout_step if not is_last_step: if (buffer_index + 1) == self._nbuffers: profiling_wrapper.range_push( "_collect_rollout_step" ) self._compute_actions_and_step_envs(buffer_index) if is_last_step: break profiling_wrapper.range_pop() # rollouts loop if self._is_distributed: self.num_rollouts_done_store.add("num_done", 1) ( value_loss, action_loss, dist_entropy, ) = self._update_agent() if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() # type: ignore self.num_updates_done += 1 losses = self._coalesce_post_step( dict(value_loss=value_loss, action_loss=action_loss), count_steps_delta, ) self._training_log(writer, losses, prev_time) # checkpoint model if rank0_only() and self.should_checkpoint(): self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict( step=self.num_steps_done, wall_time=(time.time() - self.t_start) + prev_time, ), ) count_checkpoints += 1 profiling_wrapper.range_pop() # train update self.envs.close() def _eval_checkpoint( self, checkpoint_path: str, writer: TensorboardWriter, checkpoint_index: int 
        = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """
        if self._is_distributed:
            raise RuntimeError("Evaluation does not support distributed mode")

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        if config.VERBOSE:
            logger.info(f"env config: {config}")

        self._init_envs(config)
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(
            observations, device=self.device, cache=self._obs_batching_cache
        )
        batch = apply_obs_transforms_batch(batch, self.obs_transforms)

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device="cpu"
        )

        test_recurrent_hidden_states = torch.zeros(
            self.config.NUM_ENVIRONMENTS,
            self.actor_critic.net.num_recurrent_layers,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(
            self.config.NUM_ENVIRONMENTS,
            1,
            device=self.device,
            dtype=torch.long,
        )
        not_done_masks = torch.zeros(
            self.config.NUM_ENVIRONMENTS,
            1,
            device=self.device,
            dtype=torch.bool,
        )
        stats_episodes: Dict[
            Any, Any
        ] = {}  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(self.config.NUM_ENVIRONMENTS)
        ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}."
                )
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        pbar = tqdm.tqdm(total=number_of_eval_episodes)
        self.actor_critic.eval()
        while (
            len(stats_episodes) < number_of_eval_episodes
            and self.envs.num_envs > 0
        ):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)  # type: ignore

            # NB: Move actions to CPU. If CUDA tensors are
            # sent in to env.step(), that will create CUDA contexts
            # in the subprocesses.
# For backwards compatibility, we also call .item() to convert to # an int step_data = [a.item() for a in actions.to(device="cpu")] outputs = self.envs.step(step_data) observations, rewards_l, dones, infos = [ list(x) for x in zip(*outputs) ] batch = batch_obs( observations, device=self.device, cache=self._obs_batching_cache, ) batch = apply_obs_transforms_batch(batch, self.obs_transforms) not_done_masks = torch.tensor( [[not done] for done in dones], dtype=torch.bool, device="cpu", ) rewards = torch.tensor( rewards_l, dtype=torch.float, device="cpu" ).unsqueeze(1) current_episode_reward += rewards next_episodes = self.envs.current_episodes() envs_to_pause = [] n_envs = self.envs.num_envs for i in range(n_envs): if ( next_episodes[i].scene_id, next_episodes[i].episode_id, ) in stats_episodes: envs_to_pause.append(i) # episode ended if not not_done_masks[i].item(): pbar.update() episode_stats = {} episode_stats["reward"] = current_episode_reward[i].item() episode_stats.update( self._extract_scalars_from_info(infos[i]) ) current_episode_reward[i] = 0 # use scene_id + episode_id as unique id for storing stats stats_episodes[ ( current_episodes[i].scene_id, current_episodes[i].episode_id, ) ] = episode_stats if len(self.config.VIDEO_OPTION) > 0: generate_video( video_option=self.config.VIDEO_OPTION, video_dir=self.config.VIDEO_DIR, images=rgb_frames[i], episode_id=current_episodes[i].episode_id, checkpoint_idx=checkpoint_index, metrics=self._extract_scalars_from_info(infos[i]), tb_writer=writer, ) rgb_frames[i] = [] # episode continues elif len(self.config.VIDEO_OPTION) > 0: # TODO move normalization / channel changing out of the policy and undo it here frame = observations_to_image( {k: v[i] for k, v in batch.items()}, infos[i] ) rgb_frames[i].append(frame) not_done_masks = not_done_masks.to(device=self.device) ( self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) = self._pause_envs( envs_to_pause, self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) num_episodes = len(stats_episodes) aggregated_stats = {} for stat_key in next(iter(stats_episodes.values())).keys(): aggregated_stats[stat_key] = ( sum(v[stat_key] for v in stats_episodes.values()) / num_episodes ) for k, v in aggregated_stats.items(): logger.info(f"Average episode {k}: {v:.4f}") step_id = checkpoint_index if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: step_id = ckpt_dict["extra_state"]["step"] writer.add_scalars( "eval_reward", {"average reward": aggregated_stats["reward"]}, step_id, ) metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"} if len(metrics) > 0: writer.add_scalars("eval_metrics", metrics, step_id) self.envs.close()
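# Illustrative sketch (hypothetical driver, not part of the class above): a
# directory of checkpoints produced during training could be evaluated by
# calling _eval_checkpoint once per file, reusing a TensorBoard writer. The
# glob pattern and lexical sort are simplifications.
def _eval_all_checkpoints_example(trainer, checkpoint_dir: str) -> None:
    import glob

    ckpt_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "ckpt.*.pth")))
    with TensorboardWriter(
        trainer.config.TENSORBOARD_DIR, flush_secs=trainer.flush_secs
    ) as writer:
        for idx, ckpt_path in enumerate(ckpt_paths):
            trainer._eval_checkpoint(ckpt_path, writer, checkpoint_index=idx)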
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = torch.device("cuda", self.config.TORCH_GPU_ID) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) observations = self.envs.reset() batch = batch_obs(observations) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, ) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) rollouts.to(self.device) episode_rewards = torch.zeros(self.envs.num_envs, 1) episode_counts = torch.zeros(self.envs.num_envs, 1) current_episode_reward = torch.zeros(self.envs.num_envs, 1) window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size) window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step( rollouts, current_episode_reward, episode_rewards, episode_counts, ) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent( ppo_cfg, rollouts) pth_time += delta_pth_time window_episode_reward.append(episode_rewards.clone()) window_episode_counts.append(episode_counts.clone()) losses = [value_loss, action_loss] stats = zip( ["count", "reward"], [window_episode_counts, window_episode_reward], ) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in stats } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) window_rewards = (window_episode_reward[-1] - window_episode_reward[0]).sum() window_counts = (window_episode_counts[-1] - window_episode_counts[0]).sum() if window_counts > 0: logger.info( "Average window size {} reward: {:3f}".format( len(window_episode_reward), (window_rewards / window_counts).item(), )) else: logger.info("No episodes finish in current window") # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint(f"ckpt.{count_checkpoints}.pth") count_checkpoints += 1 self.envs.close()
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, ) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: 
self.save_checkpoint(f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) # Initialize auxiliary tasks observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = \ self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights = self._update_agent( ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] + aux_task_losses writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) writer.add_scalar( "success", deltas["success"] / deltas["count"], count_steps, ) # Log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {}\t action_loss: {}\taux_task_loss: {} \t aux_entropy {}" .format(update, value_loss, action_loss, aux_task_losses, aux_dist_entropy)) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
def train(self):
    # Get environments for training
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))

    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    # logger.info(
    #     "agent number of parameters: {}".format(
    #         sum(param.numel() for param in self.agent.parameters())
    #     )
    # )
    # Change for the actual value
    cfg = self.config.RL.PPO

    rollouts = RolloutStorage(
        cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        cfg.hidden_size,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations)

    # Debug output: print the shape of each sensor batch
    for sensor in rollouts.observations:
        print(batch[sensor].shape)

    # Copy the information to the wrapper
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    # current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    # window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    # window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
    '''
def train(self) -> None: r""" Main method for training PPO Returns: None """ assert ( self.config is not None ), "trainer is not properly initialized, need to specify config file" self.envs = construct_envs(self.config, NavRLEnv) ppo_cfg = self.config.TRAINER.RL.PPO self.device = torch.device("cuda", ppo_cfg.pth_gpu_id) if not os.path.isdir(ppo_cfg.checkpoint_folder): os.makedirs(ppo_cfg.checkpoint_folder) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) observations = self.envs.reset() batch = batch_obs(observations) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, ) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) rollouts.to(self.device) episode_rewards = torch.zeros(self.envs.num_envs, 1) episode_counts = torch.zeros(self.envs.num_envs, 1) current_episode_reward = torch.zeros(self.envs.num_envs, 1) window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size) window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 with (get_tensorboard_writer( log_dir=ppo_cfg.tensorboard_dir, purge_step=count_steps, flush_secs=30, )) as writer: for update in range(ppo_cfg.num_updates): if ppo_cfg.use_linear_lr_decay: update_linear_schedule( self.agent.optimizer, update, ppo_cfg.num_updates, ppo_cfg.lr, ) if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * ( 1 - update / ppo_cfg.num_updates) for step in range(ppo_cfg.num_steps): t_sample_action = time.time() # sample actions with torch.no_grad(): step_observation = { k: v[step] for k, v in rollouts.observations.items() } ( values, actions, actions_log_probs, recurrent_hidden_states, ) = self.actor_critic.act( step_observation, rollouts.recurrent_hidden_states[step], rollouts.masks[step], ) pth_time += time.time() - t_sample_action t_step_env = time.time() outputs = self.envs.step([a[0].item() for a in actions]) observations, rewards, dones, infos = [ list(x) for x in zip(*outputs) ] env_time += time.time() - t_step_env t_update_stats = time.time() batch = batch_obs(observations) rewards = torch.tensor(rewards, dtype=torch.float) rewards = rewards.unsqueeze(1) masks = torch.tensor( [[0.0] if done else [1.0] for done in dones], dtype=torch.float, ) current_episode_reward += rewards episode_rewards += (1 - masks) * current_episode_reward episode_counts += 1 - masks current_episode_reward *= masks rollouts.insert( batch, recurrent_hidden_states, actions, actions_log_probs, values, rewards, masks, ) count_steps += self.envs.num_envs pth_time += time.time() - t_update_stats window_episode_reward.append(episode_rewards.clone()) window_episode_counts.append(episode_counts.clone()) t_update_model = time.time() with torch.no_grad(): last_observation = { k: v[-1] for k, v in rollouts.observations.items() } next_value = self.actor_critic.get_value( last_observation, rollouts.recurrent_hidden_states[-1], rollouts.masks[-1], ).detach() rollouts.compute_returns(next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau) value_loss, action_loss, dist_entropy = self.agent.update( rollouts) rollouts.after_update() pth_time += time.time() - t_update_model losses = [value_loss, action_loss] stats = zip( ["count", "reward"], [window_episode_counts, window_episode_reward], ) deltas = { k: ((v[-1] - v[0]).sum().item() 
if len(v) > 1 else v[0].sum().item()) for k, v in stats } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % ppo_cfg.log_interval == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) window_rewards = (window_episode_reward[-1] - window_episode_reward[0]).sum() window_counts = (window_episode_counts[-1] - window_episode_counts[0]).sum() if window_counts > 0: logger.info( "Average window size {} reward: {:3f}".format( len(window_episode_reward), (window_rewards / window_counts).item(), )) else: logger.info("No episodes finish in current window") # checkpoint model if update % ppo_cfg.checkpoint_interval == 0: self.save_checkpoint(f"ckpt.{count_checkpoints}.pth") count_checkpoints += 1
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank self.config.freeze() if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks( aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), # including bonus ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 prev_time = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_updates = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() ( delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0).to(self.device) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [ dist_entropy, aux_dist_entropy, ] + [value_loss, action_loss] + aux_task_losses + [count_steps_delta], device=self.device, ) distrib.all_reduce(stats) if aux_weights is not None and len(aux_weights) > 0: distrib.all_reduce( torch.tensor(aux_weights, device=self.device)) count_steps += stats[-1].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") avg_stats = [ stats[i].item() / self.world_size for i in range(len(stats) - 1) ] losses = avg_stats[2:] dist_entropy, aux_dist_entropy = avg_stats[:2] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) # Log stats formatted_aux_losses = [ "{:.3g}".format(l) for l in aux_task_losses ] if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t" .format( update, value_loss, action_loss, formatted_aux_losses, aux_dist_entropy, )) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
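# Illustrative sketch (not part of the trainer): the straggler-preemption test used in the rollout loop above, isolated as a pure function. The default values of short_rollout_threshold and sync_frac below are assumptions; the trainer reads them from self.SHORT_ROLLOUT_THRESHOLD and config.RL.DDPPO.sync_frac.
def should_preempt_rollout(step, num_steps, num_done, world_size,
                           short_rollout_threshold=0.25, sync_frac=0.6):
    # Only consider stopping early once a minimum fraction of the rollout has been collected.
    collected_minimum = step >= num_steps * short_rollout_threshold
    # Stop if more than sync_frac of all workers have already reported their rollout as done,
    # i.e. this worker is likely a straggler that would stall the synchronous update.
    mostly_done_elsewhere = num_done > sync_frac * world_size
    return collected_minimum and mostly_done_elsewhere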
def benchmark(self) -> None: if TIME_DEBUG: s = time.time() #self.config.defrost() #self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_EPISODES = 10 #self.config.freeze() if torch.cuda.device_count() <= 1: self.config.defrost() self.config.TORCH_GPU_ID = 0 self.config.SIMULATOR_GPU_ID = 0 self.config.freeze() self.envs = construct_envs(self.config, eval(self.config.ENV_NAME)) if ADD_IL: self.il_envs = construct_envs(self.config, eval(self.config.ENV_NAME), no_val=True) self.collect_mode = 'RL' if TIME_DEBUG: s = log_time(s, 'construct envs') ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) # if 'SMT' in self.config.POLICY: # sd = torch.load('visual_embedding18.pth') # self.actor_critic.net.visual_encoder.load_state_dict(sd['visual_encoder']) # self.actor_critic.net.prev_action_embedding.load_state_dict(sd['prev_action_embedding']) # self.actor_critic.net.visual_encoder.cuda() # self.actor_critic.net.prev_action_embedding.cuda() # self.envs.setup_embedding_network(self.actor_critic.net.visual_encoder, self.actor_critic.net.prev_action_embedding) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) num_train_processes, num_val_processes = self.config.NUM_PROCESSES, self.config.NUM_VAL_PROCESSES total_processes = num_train_processes + num_val_processes OBS_LIST = self.config.OBS_TO_SAVE self.num_processes = num_train_processes rollouts = RolloutStorage(ppo_cfg.num_steps, num_train_processes, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, self.actor_critic.net.num_recurrent_layers, OBS_LIST=OBS_LIST) rollouts.to(self.device) batch = self.envs.reset() for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_( batch[sensor][:num_train_processes]) self.last_observations = batch self.last_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, total_processes, ppo_cfg.hidden_size).to(self.device) self.last_prev_actions = torch.zeros( total_processes, rollouts.prev_actions.shape[-1]).to(self.device) self.last_masks = torch.zeros(total_processes, 1).to(self.device) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None if ADD_IL: rollouts2 = RolloutStorage( ppo_cfg.num_steps, num_train_processes, self.il_envs.observation_spaces[0], self.il_envs.action_spaces[0], ppo_cfg.hidden_size, self.actor_critic.net.num_recurrent_layers, OBS_LIST=OBS_LIST) rollouts2.to(self.device) batch2 = self.il_envs.reset() for sensor in rollouts2.observations: rollouts2.observations[sensor][0].copy_( batch2[sensor][:num_train_processes]) self.saved_last_obs = batch2 self.saved_last_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, total_processes, ppo_cfg.hidden_size).to(self.device) self.saved_last_prev_actions = torch.zeros( total_processes, rollouts2.prev_actions.shape[-1]).to(self.device) self.saved_last_masks = torch.zeros(total_processes, 1).to(self.device) batch2 = None else: rollouts2 = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) if TIME_DEBUG: s = log_time(s, 'setup all') for update in range(100): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if TIME_DEBUG: s = log_time(s, 'collect rollout start') modes = ['RL'] if ADD_IL: modes += ['IL'] for collect_mode in modes: self.collect_mode = collect_mode use_rollouts = rollouts if self.collect_mode == 'RL' else rollouts2 if ADD_IL: self.exchange_lasts() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(use_rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps #print(delta_env_time, delta_pth_time) if TIME_DEBUG: s = log_time(s, 'collect rollout done') (delta_pth_time, value_loss, action_loss, dist_entropy, il_loss) = self._update_agent(ppo_cfg, rollouts, rollouts2) #print(delta_pth_time) pth_time += delta_pth_time use_rollouts.after_update() if TIME_DEBUG: s = log_time(s, 'update agent') for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1][:self.num_processes] - v[0][:self.num_processes]).sum().item() if len(v) > 1 else v[0][:self.num_processes].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) #self.write_tb('train', writer, deltas, count_steps, losses) eval_deltas = { k: ((v[-1][self.num_processes:] - v[0][self.num_processes:]).sum().item() if len(v) > 1 else v[0][self.num_processes:].sum().item()) for k, v in window_episode_stats.items() } eval_deltas["count"] = max(eval_deltas["count"], 1.0) #self.write_tb('val', writer, eval_deltas, count_steps) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: 
{:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) logger.info("validation metrics: {}".format( " ".join("{}: {:.3f}".format(k, v / eval_deltas["count"]) for k, v in eval_deltas.items() if k != "count"), )) self.envs.close()
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.add_new_based_on_cfg() self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg, train=True) if self.config.PRETRAINED_CHECKPOINT_PATH: ckpt_dict = self.load_checkpoint( self.config.PRETRAINED_CHECKPOINT_PATH, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"], strict=False) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs_augment_aux(observations, self.envs.get_shared_mem()) for sensor in rollouts.observations: if sensor in batch: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None info_data_keys = ["discovered", "collisions_wall", "collisions_prox"] log_data_keys = [ "episode_rewards", "episode_go_rewards", "episode_counts", "current_episode_reward", "current_episode_go_reward" ] + info_data_keys log_data = dict( {k: torch.zeros(self.envs.num_envs, 1) for k in log_data_keys}) info_data = dict({k: log_data[k] for k in info_data_keys}) win_keys = log_data_keys win_keys.pop(win_keys.index("current_episode_reward")) win_keys.pop(win_keys.index("current_episode_go_reward")) windows = dict({ k: deque(maxlen=ppo_cfg.reward_window_size) for k in log_data.keys() }) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) train_steps = min(self.config.NUM_UPDATES, self.config.HARD_NUM_UPDATES) log_interval = self.config.LOG_INTERVAL num_updates = self.config.NUM_UPDATES agent = self.agent ckpt_interval = self.config.CHECKPOINT_INTERVAL with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(train_steps): if ppo_cfg.use_linear_clip_decay: agent.clip_param = ppo_cfg.clip_param * linear_decay( update, num_updates) for step in range(ppo_cfg.num_steps): delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step( rollouts, log_data["current_episode_reward"], log_data["current_episode_go_reward"], log_data["episode_rewards"], log_data["episode_go_rewards"], log_data["episode_counts"], info_data) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy,\ aux_loss = self._update_agent(ppo_cfg, rollouts) # TODO check if LR is init if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() pth_time += delta_pth_time # ================================================================================== # -- Log data for window averaging for k, v in windows.items(): windows[k].append(log_data[k].clone()) value_names = ["value", "policy", "entropy"] + list( aux_loss.keys()) 
losses = [value_loss, action_loss, dist_entropy] + list( aux_loss.values()) stats = zip(list(windows.keys()), list(windows.values())) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in stats } act_ep = deltas["episode_counts"] counts = max(act_ep, 1.0) deltas["episode_counts"] *= counts for k, v in deltas.items(): deltas[k] = v / counts writer.add_scalar(k, deltas[k], count_steps) writer.add_scalars("losses", {k: l for l, k in zip(losses, value_names)}, count_steps) # log stats if update > 0 and update % log_interval == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) if act_ep > 0: log_txt = f"Average window size {len(windows['episode_counts'])}" for k, v in deltas.items(): log_txt += f" | {k}: {v:.3f}" logger.info(log_txt) logger.info( f"Aux losses: {list(zip(value_names, losses))}") else: logger.info("No episodes finish in current window") # ================================================================================== # checkpoint model if update % ckpt_interval == 0: self.save_checkpoint(f"ckpt.{count_checkpoints}.pth") count_checkpoints += 1 self.envs.close()
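# Minimal sketch, assuming the conventional form of the linear_decay schedule passed to LambdaLR throughout these trainers: a multiplier that falls linearly from 1 to 0 over NUM_UPDATES, so the effective learning rate reaches zero on the final update.
def linear_decay(epoch: int, total_num_updates: int) -> float:
    return 1.0 - (epoch / float(total_num_updates))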
def train(self) -> None: r"""Main method for training PPO. Returns: None """ if TIME_DEBUG: s = time.time() self.envs = construct_envs( self.config, eval(self.config.ENV_NAME) ) if TIME_DEBUG: s = log_time(s, 'construct envs') ppo_cfg = self.config.RL.PPO self.device = ( torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu") ) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) if 'SMT' in self.config.POLICY: sd = torch.load('visual_embedding18_explore.pth') if 'Explore' in self.config.POLICY else torch.load('visual_embedding18.pth') self.actor_critic.net.visual_encoder.load_state_dict(sd['visual_encoder']) self.actor_critic.net.prev_action_embedding.load_state_dict(sd['prev_action_embedding']) self.actor_critic.net.visual_encoder.cuda() self.actor_critic.net.prev_action_embedding.cuda() total_num = self.config.NUM_PROCESSES + self.config.NUM_VAL_PROCESSES args_list = {'visual_encoder': self.actor_critic.net.visual_encoder, 'prev_action_embedding': self.actor_critic.net.prev_action_embedding} self.envs.call(['setup_embedding_network']*total_num, [args_list]*total_num) logger.info( "agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()) ) ) num_train_processes, num_val_processes = self.config.NUM_PROCESSES, self.config.NUM_VAL_PROCESSES total_processes = num_train_processes + num_val_processes OBS_LIST = self.config.OBS_TO_SAVE self.num_processes = num_train_processes rollouts = RolloutStorage( ppo_cfg.num_steps, num_train_processes, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, self.actor_critic.net.num_recurrent_layers, OBS_LIST = OBS_LIST ) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor][:num_train_processes]) self.last_observations = batch self.last_recurrent_hidden_states = torch.zeros(self.actor_critic.net.num_recurrent_layers, total_processes, ppo_cfg.hidden_size).to(self.device) self.last_prev_actions = torch.zeros(total_processes, rollouts.prev_actions.shape[-1]).to(self.device) self.last_masks = torch.zeros(total_processes,1).to(self.device) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size) ) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps start_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) if TIME_DEBUG: s = log_time(s, 'setup all') with TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) as writer: for update in range(self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES ) if TIME_DEBUG: s = log_time(s, 'collect rollout start') for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step( rollouts, current_episode_reward, running_episode_stats ) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps if TIME_DEBUG: s = log_time(s, 'collect rollout done') ( delta_pth_time, value_loss, action_loss, dist_entropy, il_loss ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time rollouts.after_update() if TIME_DEBUG: s = log_time(s, 'update agent') for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ( (v[-1][:self.num_processes] - v[0][:self.num_processes]).sum().item() if len(v) > 1 else v[0][:self.num_processes].sum().item() ) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) losses = [value_loss, action_loss, dist_entropy, il_loss] self.write_tb('train', writer, deltas, count_steps, losses) eval_deltas = { k: ( (v[-1][self.num_processes:] - v[0][self.num_processes:]).sum().item() if len(v) > 1 else v[0][self.num_processes:].sum().item() ) for k, v in window_episode_stats.items() } eval_deltas["count"] = max(eval_deltas["count"], 1.0) self.write_tb('val', writer, eval_deltas, count_steps) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tfps: {:.3f}\t".format( update, (count_steps - start_steps) / (time.time() - t_start) ) ) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format( update, env_time, pth_time, count_steps ) ) logger.info( "Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count" ), ) ) logger.info( "validation metrics: {}".format( " ".join( "{}: {:.3f}".format(k, v / eval_deltas["count"]) for k, v in eval_deltas.items() if k != "count" ), ) ) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps) ) count_checkpoints += 1 self.envs.close()
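# Hedged sketch of a write_tb helper consistent with the call sites above (self.write_tb('train', writer, deltas, count_steps, losses) and self.write_tb('val', writer, eval_deltas, count_steps)); the actual method is defined elsewhere in the trainer, so the scalar names and loss ordering here are assumptions.
def write_tb(self, split, writer, deltas, count_steps, losses=None):
    writer.add_scalar(f"{split}/reward", deltas["reward"] / deltas["count"], count_steps)
    metrics = {k: v / deltas["count"] for k, v in deltas.items()
               if k not in {"reward", "count"}}
    if metrics:
        writer.add_scalars(f"{split}/metrics", metrics, count_steps)
    if losses is not None:
        # Ordering mirrors the losses list built above: value, policy, entropy, il.
        writer.add_scalars(f"{split}/losses",
                           dict(zip(["value", "policy", "entropy", "il"], losses)),
                           count_steps)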
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.add_new_based_on_cfg() self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg, train=True) if self.config.PRETRAINED_CHECKPOINT_PATH: ckpt_dict = self.load_checkpoint( self.config.PRETRAINED_CHECKPOINT_PATH, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"], strict=False) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs_augment_aux(observations) for sensor in rollouts.observations: if sensor in batch: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None episode_rewards = torch.zeros(self.envs.num_envs, 1) episode_go_rewards = torch.zeros(self.envs.num_envs, 1) # Grid oracle rewars episode_counts = torch.zeros(self.envs.num_envs, 1) current_episode_reward = torch.zeros(self.envs.num_envs, 1) current_episode_go_reward = torch.zeros(self.envs.num_envs, 1) # Grid oracle rewars window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size) window_episode_go_reward = deque(maxlen=ppo_cfg.reward_window_size) window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) train_steps = min(self.config.NUM_UPDATES, self.config.HARD_NUM_UPDATES) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(train_steps): if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step( rollouts, current_episode_reward, current_episode_go_reward, episode_rewards, episode_go_rewards, episode_counts, ) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy,\ aux_loss = self._update_agent(ppo_cfg, rollouts) # TODO check if LR is init if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() pth_time += delta_pth_time window_episode_reward.append(episode_rewards.clone()) window_episode_go_reward.append(episode_go_rewards.clone()) window_episode_counts.append(episode_counts.clone()) value_names = ["value", "policy", "entropy"] + list( aux_loss.keys()) losses = [value_loss, action_loss, dist_entropy] + list( aux_loss.values()) stats = zip( ["count", "reward", "reward_go"], [ window_episode_counts, window_episode_reward, window_episode_go_reward ], ) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in stats } 
deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) writer.add_scalar("reward_go", deltas["reward_go"] / deltas["count"], count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, value_names)}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) window_rewards = (window_episode_reward[-1] - window_episode_reward[0]).sum() window_go_rewards = (window_episode_go_reward[-1] - window_episode_go_reward[0]).sum() window_counts = (window_episode_counts[-1] - window_episode_counts[0]).sum() if window_counts > 0: logger.info( "Average window size {} reward: {:3f} reward_go: {:3f}" .format( len(window_episode_reward), (window_rewards / window_counts).item(), (window_go_rewards / window_counts).item(), )) logger.info( f"Aux losses: {list(zip(value_names, losses))}") else: logger.info("No episodes finish in current window") # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint(f"ckpt.{count_checkpoints}.pth") count_checkpoints += 1 self.envs.close()
class PPOTrainer_(BaseRLTrainer): supported_tasks = ["Nav-v0"] def __init__(self, config=None): super().__init__(config) self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) self._setup_actor_critic_agent(self.config) self.rollout = RolloutStorage(self.config.RL.PPO.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], self.config.RL.PPO.hidden_size) self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) self.rollout.to(self.device) def _setup_actor_critic_agent(self, ppo_cfg): print(self.config) # len(observation_spaces)==NUM_PROCESSES==4? print(self.envs.observation_spaces[0]) print(self.envs.action_spaces[0]) self.actor_critic = PointNavBaselinePolicy_( cnn_parameter={ "observation_spaces": self.envs.observation_spaces[0], "feature_dim": self.config.RL.PPO.cnn_output_size }, depth_decoder_parameter={}, rnn_parameter={ "input_dim": 0, "hidden_dim": self.config.RL.PPO.hidden_size, "n_layer": 1 }, actor_parameter={ "action_spaces": 4, }, critic_parameter={}, use_splitnet_auxiliary=self.config.RL.PPO.use_splitnet_auxiliary, ) self.agent = PPO_( self.actor_critic, self.config, ) def _collect_rollout_step(self, current_episode_reward, running_episode_stats): ''' obss : "rgb" "depth" "pointgoal_with_gps_compass" ''' ### PASS DATA with torch.no_grad(): obs = { k: v[self.rollout.step] for k, v in self.rollout.observations.items() } hidden_states = self.rollout.recurrent_hidden_states[ self.rollout.step] inverse_dones = self.rollout.masks[self.rollout.step] (actions_log_probs, actions, value, distributions, rnn_hidden_state) = (self.actor_critic.act( obs, hidden_states, inverse_dones)) ### PASS ENV res = self.envs.step([action.item() for action in actions]) ### Process observations, rewards, dones, infos = [list(x) for x in zip(*res)] observations = batch_obs(observations, self.device) rewards = torch.tensor(rewards, dtype=torch.float, device=self.device).unsqueeze(-1) dones = torch.tensor(dones, dtype=torch.float, device=self.device) inverse_dones = torch.tensor([[0] if done else [1] for done in dones], dtype=torch.float, device=self.device) self.rollout.insert( observations, rnn_hidden_state, actions, actions_log_probs, value, rewards, inverse_dones, ) # print("rewards:", rewards) # print("dones:", dones) current_episode_reward += rewards running_episode_stats["reward"] += ( 1 - inverse_dones) * current_episode_reward running_episode_stats["count"] += 1 - inverse_dones for k, v in self._extract_scalars_from_infos(infos).items(): v = torch.tensor(v, dtype=torch.float, device=current_episode_reward.device).unsqueeze(1) if k not in running_episode_stats: running_episode_stats[k] = torch.zeros_like( running_episode_stats["count"]) running_episode_stats[k] += (1 - inverse_dones) * v current_episode_reward *= inverse_dones return def _save_checkpoint(self, checkpoint_name): checkpoint_dict = { "state_dict": self.agent.state_dict(), "config": self.config, } torch.save( checkpoint_dict, os.path.join(self.config.CHECKPOINT_FOLDER, checkpoint_name), ) def load_checkpoint(self, checkpoint_path, *args, **kwargs): return torch.load(checkpoint_path, *args, **kwargs) def _update_agent(self): with torch.no_grad(): last_obs = { k: v[self.rollout.step] for k, v in self.rollout.observations.items() } hidden_states = self.rollout.recurrent_hidden_states[ self.rollout.step] inverse_dones = self.rollout.masks[self.rollout.step] next_value = self.actor_critic.get_value(last_obs, hidden_states, 
inverse_dones).detach() self.rollout.compute_returns(next_value, self.config.RL.PPO.use_gae, self.config.RL.PPO.gamma, self.config.RL.PPO.tau) loss, loss_auxiliary = self.agent.update(self.rollout) self.rollout.after_update() return loss, loss_auxiliary def train(self) -> None: #### init obs = self.envs.reset() #### Cache an initial obs first batch = batch_obs(obs, device=self.device) for sensor in self.rollout.observations: self.rollout.observations[sensor][0].copy_(batch[sensor]) #### Para lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) #### PPO LOG PARA count_steps = 0 current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=self.config.RL.PPO.reward_window_size)) #### Start training with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for epoch in range(self.config.NUM_UPDATES): #### decay if self.config.RL.PPO.use_linear_lr_decay: lr_scheduler.step() if (epoch + 1) % self.config.CHECKPOINT_INTERVAL == 0: self.agent.entropy_coef = self.agent.entropy_coef * 0.9 print(self.agent.entropy_coef) #### Collect rollout print("=== collect rollout ===") for step in range(self.config.RL.PPO.num_steps): self._collect_rollout_step(current_episode_reward, running_episode_stats) count_steps += self.envs.num_envs #### Update loss, loss_auxiliary = self._update_agent() #### LOGGER writer.add_scalars("loss", loss, epoch * self.envs.num_envs) writer.add_scalars("loss_auxiliary", loss_auxiliary, epoch * self.envs.num_envs) #### PPO LOG PARA for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) #### PPO LOG PARA if epoch % (1) == 0: print("count_steps:", count_steps) print("deltas:", deltas) print("metrics:", metrics) if epoch % self.config.CHECKPOINT_INTERVAL == 0: self._save_checkpoint(f"checkpoint.{epoch}.pth") self.envs.close() # Called once per checkpoint # TEST_EPISODE_COUNT sets how many episodes to run # While episodes are still running, store observations in rgb_frames and post-process them (top_down_map, ...) # After dones, generate_video and save to disk and tensorboard def _eval_checkpoint( self, checkpoint_path: str, writer: TensorboardWriter, checkpoint_index: int = 0, ) -> None: r"""Evaluates a single checkpoint. Args: checkpoint_path: path of checkpoint writer: tensorboard writer object for logging to tensorboard checkpoint_index: index of current checkpoint for logging Returns: None """ # Map location CPU is almost always better than mapping to a CUDA device.
ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu") if self.config.EVAL.USE_CKPT_CONFIG: config = self._setup_eval_config(ckpt_dict["config"]) else: config = self.config.clone() ppo_cfg = config.RL.PPO config.defrost() config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT config.freeze() if len(self.config.VIDEO_OPTION) > 0: config.defrost() config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP") config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS") config.freeze() logger.info(f"env config: {config}") self.envs = construct_envs(config, get_env_class(config.ENV_NAME)) self._setup_actor_critic_agent(ppo_cfg) self.agent.load_state_dict(ckpt_dict["state_dict"]) self.actor_critic = self.agent.actor_critic observations = self.envs.reset() batch = batch_obs(observations, device=self.device) current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) test_recurrent_hidden_states = torch.zeros( 1, self.config.NUM_PROCESSES, ppo_cfg.hidden_size, device=self.device, ) prev_actions = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long) not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device) stats_episodes = dict() # dict of dicts that stores stats per episode rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES) ] # type: List[List[np.ndarray]] if len(self.config.VIDEO_OPTION) > 0: os.makedirs(self.config.VIDEO_DIR, exist_ok=True) pbar = tqdm.tqdm(total=self.config.TEST_EPISODE_COUNT) self.actor_critic.eval() while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT and self.envs.num_envs > 0): current_episodes = self.envs.current_episodes() with torch.no_grad(): (_, actions, _, _, test_recurrent_hidden_states) = (self.actor_critic.act( batch, test_recurrent_hidden_states, not_done_masks)) prev_actions.copy_(actions) outputs = self.envs.step([a[0].item() for a in actions]) observations, rewards, dones, infos = [ list(x) for x in zip(*outputs) ] batch = batch_obs(observations, device=self.device) not_done_masks = torch.tensor( [[0.0] if done else [1.0] for done in dones], dtype=torch.float, device=self.device, ) rewards = torch.tensor(rewards, dtype=torch.float, device=self.device).unsqueeze(1) current_episode_reward += rewards next_episodes = self.envs.current_episodes() envs_to_pause = [] n_envs = self.envs.num_envs for i in range(n_envs): if ( next_episodes[i].scene_id, next_episodes[i].episode_id, ) in stats_episodes: envs_to_pause.append(i) # episode ended if not_done_masks[i].item() == 0: pbar.update() episode_stats = dict() episode_stats["reward"] = current_episode_reward[i].item() episode_stats.update( self._extract_scalars_from_info(infos[i])) current_episode_reward[i] = 0 # use scene_id + episode_id as unique id for storing stats stats_episodes[( current_episodes[i].scene_id, current_episodes[i].episode_id, )] = episode_stats if len(self.config.VIDEO_OPTION) > 0: generate_video( video_option=self.config.VIDEO_OPTION, video_dir=self.config.VIDEO_DIR, images=rgb_frames[i], episode_id=current_episodes[i].episode_id, checkpoint_idx=checkpoint_index, metrics=self._extract_scalars_from_info(infos[i]), tb_writer=writer, ) rgb_frames[i] = [] # episode continues elif len(self.config.VIDEO_OPTION) > 0: frame = observations_to_image(observations[i], infos[i]) rgb_frames[i].append(frame) ( self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) = self._pause_envs( envs_to_pause, self.envs, test_recurrent_hidden_states, not_done_masks, 
current_episode_reward, prev_actions, batch, rgb_frames, ) num_episodes = len(stats_episodes) aggregated_stats = dict() for stat_key in next(iter(stats_episodes.values())).keys(): aggregated_stats[stat_key] = ( sum([v[stat_key] for v in stats_episodes.values()]) / num_episodes) for k, v in aggregated_stats.items(): logger.info(f"Average episode {k}: {v:.4f}") step_id = checkpoint_index if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: step_id = ckpt_dict["extra_state"]["step"] writer.add_scalars( "eval_reward", {"average reward": aggregated_stats["reward"]}, step_id, ) metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"} if len(metrics) > 0: writer.add_scalars("eval_metrics", metrics, step_id) self.envs.close() METRICS_BLACKLIST = {"top_down_map", "collisions.is_collision"} @classmethod def _extract_scalars_from_info(cls, info: Dict[str, Any]) -> Dict[str, float]: result = {} for k, v in info.items(): if k in cls.METRICS_BLACKLIST: continue if isinstance(v, dict): result.update({ k + "." + subk: subv for subk, subv in cls._extract_scalars_from_info( v).items() if (k + "." + subk) not in cls.METRICS_BLACKLIST }) # Things that are scalar-like will have an np.size of 1. # Strings also have an np.size of 1, so explicitly ban those elif np.size(v) == 1 and not isinstance(v, str): result[k] = float(v) return result @classmethod def _extract_scalars_from_infos( cls, infos: List[Dict[str, Any]]) -> Dict[str, List[float]]: results = defaultdict(list) for i in range(len(infos)): for k, v in cls._extract_scalars_from_info(infos[i]).items(): results[k].append(v) return results
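# Illustrative usage of the scalar-extraction helpers above (the info dicts are hypothetical): nested dicts are flattened with dotted keys, entries in METRICS_BLACKLIST are dropped, and _extract_scalars_from_infos collates per-env values into lists keyed by metric name.
_example_infos = [
    {"distance_to_goal": 2.5, "success": 0.0, "collisions": {"count": 3, "is_collision": 1}},
    {"distance_to_goal": 0.4, "success": 1.0, "collisions": {"count": 0, "is_collision": 0}},
]
# PPOTrainer_._extract_scalars_from_infos(_example_infos)
# -> {"distance_to_goal": [2.5, 0.4], "success": [0.0, 1.0], "collisions.count": [3.0, 0.0]}
#    ("collisions.is_collision" is filtered out by METRICS_BLACKLIST.)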