def train(self) -> None: r"""Main method for DD-PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += (self.world_rank * self.config.NUM_PROCESSES) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO if (not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) obs_space = self.envs.observation_spaces[0] if self._static_encoder: self._encoder = self.actor_critic.net.visual_encoder obs_space = SpaceDict({ "visual_features": spaces.Box( low=np.finfo(np.float32).min, high=np.finfo(np.float32).max, shape=self._encoder.output_shape, dtype=np.float32, ), **obs_space.spaces, }) with torch.no_grad(): batch["visual_features"] = self._encoder(batch) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, ) rollouts.to(self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps), ) count_checkpoints += 1 self.envs.close()
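# --- Example (not part of the trainer above) --------------------------------
# The LR scheduler and the clip-parameter annealing above both call a
# `linear_decay` helper.  A minimal sketch, assuming the conventional
# definition used with torch.optim.lr_scheduler.LambdaLR (a multiplicative
# factor that falls linearly from 1 to 0 over NUM_UPDATES); reproduced here
# as an assumption rather than copied from this file.

def linear_decay(epoch: int, total_num_updates: int) -> float:
    """Return a multiplier that decays linearly from 1.0 (epoch 0) to 0.0."""
    return 1.0 - (epoch / float(total_num_updates))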
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank self.config.freeze() if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks( aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), # including bonus ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 prev_time = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_updates = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() ( delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0).to(self.device) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [ dist_entropy, aux_dist_entropy, ] + [value_loss, action_loss] + aux_task_losses + [count_steps_delta], device=self.device, ) distrib.all_reduce(stats) if aux_weights is not None and len(aux_weights) > 0: distrib.all_reduce( torch.tensor(aux_weights, device=self.device)) count_steps += stats[-1].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") avg_stats = [ stats[i].item() / self.world_size for i in range(len(stats) - 1) ] losses = avg_stats[2:] dist_entropy, aux_dist_entropy = avg_stats[:2] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) # Log stats formatted_aux_losses = [ "{:.3g}".format(l) for l in aux_task_losses ] if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t" .format( update, value_loss, action_loss, formatted_aux_losses, aux_dist_entropy, )) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
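# --- Example (not part of the trainer above) --------------------------------
# A hedged sketch of how this resumable variant might be invoked.  The
# `resume_training` wrapper, the checkpoint path, and the update count are
# illustrative placeholders, not names taken from the code above.

def resume_training(trainer) -> None:
    """Resume from a hypothetical checkpoint index 7 saved after 2500 updates."""
    trainer.train(
        ckpt_path="data/checkpoints/ckpt.7.pth",  # hypothetical checkpoint file
        ckpt=7,              # checkpoint index to resume at; -1 (default) trains from scratch
        start_updates=2500,  # prior update count, used to re-derive count_steps when
    )                        # the checkpoint carries no extra_state["step"]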
def train(self) -> None: r"""Main method for training DD/PPO. Returns: None """ self._init_train() count_checkpoints = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: 1 - self.percent_done(), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"] ) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] self.env_time = requeue_stats["env_time"] self.pth_time = requeue_stats["pth_time"] self.num_steps_done = requeue_stats["num_steps_done"] self.num_updates_done = requeue_stats["num_updates_done"] self._last_checkpoint_percent = requeue_stats[ "_last_checkpoint_percent" ] count_checkpoints = requeue_stats["count_checkpoints"] prev_time = requeue_stats["prev_time"] self._last_checkpoint_percent = requeue_stats[ "_last_checkpoint_percent" ] ppo_cfg = self.config.RL.PPO with ( TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) if rank0_only() else contextlib.suppress() ) as writer: while not self.is_done(): profiling_wrapper.on_start_step() profiling_wrapper.range_push("train update") if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * ( 1 - self.percent_done() ) if EXIT.is_set(): profiling_wrapper.range_pop() # train update self.envs.close() if REQUEUE.is_set() and rank0_only(): requeue_stats = dict( env_time=self.env_time, pth_time=self.pth_time, count_checkpoints=count_checkpoints, num_steps_done=self.num_steps_done, num_updates_done=self.num_updates_done, _last_checkpoint_percent=self._last_checkpoint_percent, prev_time=(time.time() - self.t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, ) ) requeue_job() return self.agent.eval() count_steps_delta = 0 profiling_wrapper.range_push("rollouts loop") profiling_wrapper.range_push("_collect_rollout_step") for buffer_index in range(self._nbuffers): self._compute_actions_and_step_envs(buffer_index) for step in range(ppo_cfg.num_steps): is_last_step = ( self.should_end_early(step + 1) or (step + 1) == ppo_cfg.num_steps ) for buffer_index in range(self._nbuffers): count_steps_delta += self._collect_environment_result( buffer_index ) if (buffer_index + 1) == self._nbuffers: profiling_wrapper.range_pop() # _collect_rollout_step if not is_last_step: if (buffer_index + 1) == self._nbuffers: profiling_wrapper.range_push( "_collect_rollout_step" ) self._compute_actions_and_step_envs(buffer_index) if is_last_step: break profiling_wrapper.range_pop() # rollouts loop if self._is_distributed: self.num_rollouts_done_store.add("num_done", 1) ( value_loss, action_loss, dist_entropy, ) = self._update_agent() if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() # type: ignore self.num_updates_done += 1 losses = self._coalesce_post_step( dict(value_loss=value_loss, action_loss=action_loss), count_steps_delta, ) self._training_log(writer, losses, prev_time) # checkpoint model if rank0_only() and self.should_checkpoint(): self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict( step=self.num_steps_done, wall_time=(time.time() - self.t_start) + prev_time, ), ) count_checkpoints += 1 profiling_wrapper.range_pop() # train update self.envs.close()
def train(self) -> None: r"""Main method for DD-PPO SLAM. Returns: None """ ##################################################################### ## init distrib and configuration ##################################################################### self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend ) # self.local_rank = 1 add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore( "rollout_tracker", tcp_store ) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() # server number self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank # gpu number in one server self.config.SIMULATOR_GPU_ID = self.local_rank print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID) print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID) # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += ( self.world_rank * self.config.NUM_PROCESSES ) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") ##################################################################### ## build distrib NavSLAMRLEnv environment ##################################################################### print("#############################################################") print("## build distrib NavSLAMRLEnv environment") print("#############################################################") self.envs = construct_envs( self.config, get_env_class(self.config.ENV_NAME) ) observations = self.envs.reset() print("*************************** observations len:", len(observations)) # semantic process for i in range(len(observations)): observations[i]["semantic"] = observations[i]["semantic"].astype(np.int32) se = list(set(observations[i]["semantic"].ravel())) print(se) # print("*************************** observations type:", observations) # print("*************************** observations type:", observations[0]["map_sum"].shape) # 480*480*23 # print("*************************** observations curr_pose:", observations[0]["curr_pose"]) # [] batch = batch_obs(observations, device=self.device) print("*************************** batch len:", len(batch)) # print("*************************** batch:", batch) # print("************************************* current_episodes:", (self.envs.current_episodes())) ##################################################################### ## init actor_critic agent ##################################################################### print("#############################################################") print("## init actor_critic agent") print("#############################################################") self.map_w = observations[0]["map_sum"].shape[0] self.map_h = observations[0]["map_sum"].shape[1] # print("map_: ", observations[0]["curr_pose"].shape) ppo_cfg = self.config.RL.PPO if ( not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0 ): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(observations, ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info( "agent number of trainable parameters: 
{}".format( sum( param.numel() for param in self.agent.parameters() if param.requires_grad ) ) ) ##################################################################### ## init Global Rollout Storage ##################################################################### print("#############################################################") print("## init Global Rollout Storage") print("#############################################################") self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step rollouts = GlobalRolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.obs_space, self.g_action_space, ) rollouts.to(self.device) print('rollouts type:', type(rollouts)) print('--------------------------') # for k in rollouts.keys(): # print("rollouts: {0}".format(rollouts.observations)) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) with torch.no_grad(): step_observation = { k: v[rollouts.step] for k, v in rollouts.observations.items() } _, actions, _, = self.actor_critic.act( step_observation, rollouts.prev_g_actions[0], rollouts.masks[0], ) self.global_goals = [[int(action[0].item() * self.map_w), int(action[1].item() * self.map_h)] for action in actions] # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros( self.envs.num_envs, 1, device=self.device ) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size) ) print("*************************** current_episode_reward:", current_episode_reward) print("*************************** running_episode_stats:", running_episode_stats) # print("*************************** window_episode_stats:", window_episode_stats) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) # interrupted_state = load_interrupted_state("/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth") interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"] ) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] deif = {} with ( TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) if self.world_rank == 0 else contextlib.suppress() ) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES ) # print("************************************* current_episodes:", type(self.envs.count_episodes())) # print(EXIT.is_set()) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and 
self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, ), "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth" ) print("********************EXIT*********************") requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_global_rollout_step( rollouts, current_episode_reward, running_episode_stats ) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # print("************************************* current_episodes:") for i in range(len(self.envs.current_episodes())): # print(" ", self.envs.current_episodes()[i].episode_id," ", self.envs.current_episodes()[i].scene_id," ", self.envs.current_episodes()[i].object_category) if self.envs.current_episodes()[i].scene_id not in deif: deif[self.envs.current_episodes()[i].scene_id]=[int(self.envs.current_episodes()[i].episode_id)] else: deif[self.envs.current_episodes()[i].scene_id].append(int(self.envs.current_episodes()[i].episode_id)) # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! if ( step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size ): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0 ) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ( (v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item() ) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), ) ) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format( update, env_time, pth_time, count_steps ) ) logger.info( "Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in 
deltas.items() if k != "count" ), ) ) # for k in deif: # deif[k] = list(set(deif[k])) # deif[k].sort() # print("deif: k", k, " : ", deif[k]) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps), ) print('=' * 20 + 'Save Model' + '=' * 20) logger.info( "Save Model : {}".format(count_checkpoints) ) count_checkpoints += 1 self.envs.close()