def train(self) -> None:
    r"""Main method for DD-PPO.

    Runs decentralized distributed PPO on this worker: initializes the
    process group with ``init_distrib_slurm``, seeds the RNGs per rank,
    builds the environments, the actor-critic agent and (when
    ``ppo_cfg.use_belief_predictor`` is set) the belief predictor, restores
    counters from ``try_to_resume_checkpoint()`` and, if present, from an
    interrupted (requeued) state, then alternates rollout collection and
    PPO updates until ``self.config.NUM_UPDATES`` is reached. Only rank 0
    writes TensorBoard scalars, logs progress and saves checkpoints.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                  tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    # Each rank drives the GPU matching its local rank.
    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += (self.world_rank *
                                     self.config.NUM_PROCESSES)
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    # Only rank 0 creates the checkpoint folder; it is also the only rank
    # that saves checkpoints below.
    if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)
    # The belief predictor only needs distributed wrapping when it is
    # trained online alongside the policy.
    if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
        self.belief_predictor.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))
        if ppo_cfg.use_belief_predictor:
            logger.info(
                "belief predictor number of trainable parameters: {}".
                format(
                    sum(param.numel()
                        for param in self.belief_predictor.parameters()
                        if param.requires_grad)))
        logger.info(f"config: {self.config}")

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    obs_space = self.envs.observation_spaces[0]

    # The external (scene) memory dimension is only defined by the network
    # when external memory is enabled.
    if ppo_cfg.use_external_memory:
        memory_dim = self.actor_critic.net.memory_dim
    else:
        memory_dim = None

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.action_space,
        ppo_cfg.hidden_size,
        ppo_cfg.use_external_memory,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
        memory_dim,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)
    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)

    # Seed step 0 of the rollout buffer with the reset observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs,
                                         1,
                                         device=self.device)
    # Per-env running totals; the logged values are differences between
    # the newest and oldest entries of window_episode_stats.
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # Try to resume at previous checkpoint (independent of interrupted states)
    count_steps_start, count_checkpoints, start_update = self.try_to_resume_checkpoint(
    )
    count_steps = count_steps_start

    # An interrupted (requeued) state overrides the checkpoint counters
    # restored above.
    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.load_state_dict(
                interrupted_state["belief_predictor"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    # Only rank 0 gets a real TensorBoard writer; other ranks get a no-op
    # context manager, and `writer` is only used under world_rank == 0.
    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            # EXIT / REQUEUE are module-level events, presumably set by the
            # signal handlers installed above — confirm in add_signal_handlers.
            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    state_dict = dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    )
                    if self.config.RL.PPO.use_belief_predictor:
                        state_dict[
                            'belief_predictor'] = self.belief_predictor.state_dict(
                            )
                    save_interrupted_state(state_dict)

                requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(rollouts,
                                               current_episode_reward,
                                               running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size):
                    break

            # Signal the other ranks that this worker finished its rollout.
            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.train()
                self.belief_predictor.set_eval_encoders()
            if self._static_smt_encoder:
                self.actor_critic.net.set_eval_encoders()

            if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
                location_predictor_loss, prediction_accuracy = self.train_belief_predictor(
                    rollouts)
            else:
                location_predictor_loss = 0
                prediction_accuracy = 0
            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            # Sum the running episode statistics over all ranks.
            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0)
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            # Sum losses and the number of collected frames over all ranks;
            # the losses are divided by world_size below to get means.
            stats = torch.tensor(
                [
                    value_loss, action_loss, dist_entropy,
                    location_predictor_loss, prediction_accuracy,
                    count_steps_delta
                ],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[5].item()

            if self.world_rank == 0:
                # Reset the rollout counter for the next update.
                num_rollouts_done_store.set("num_done", "0")

                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                    stats[2].item() / self.world_size,
                    stats[3].item() / self.world_size,
                    stats[4].item() / self.world_size,
                ]
                # Totals accumulated over the window (newest minus oldest
                # snapshot, or the only snapshot when the window has one).
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                # Avoid dividing by zero before any episode has finished.
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("Metrics/reward",
                                  deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    for metric, value in metrics.items():
                        writer.add_scalar(f"Metrics/{metric}", value,
                                          count_steps)

                writer.add_scalar("Policy/value_loss", losses[0],
                                  count_steps)
                writer.add_scalar("Policy/policy_loss", losses[1],
                                  count_steps)
                writer.add_scalar("Policy/entropy_loss", losses[2],
                                  count_steps)
                writer.add_scalar("Policy/predictor_loss", losses[3],
                                  count_steps)
                writer.add_scalar("Policy/predictor_accuracy", losses[4],
                                  count_steps)
                writer.add_scalar('Policy/learning_rate',
                                  lr_scheduler.get_lr()[0], count_steps)

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        (count_steps - count_steps_start) /
                        ((time.time() - t_start) + prev_time),
                    ))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))
                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                 for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1

        self.envs.close()
def init_rpc(
    name,
    backend=BackendType.PROCESS_GROUP,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
    r"""
    Initializes RPC primitives such as the local RPC agent
    and distributed autograd, which immediately makes the current
    process ready to send and receive RPCs.

    Arguments:
        name (str): a globally unique name of this node. (e.g.,
            ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
            Name can only contain number, alphabet, underscore, colon,
            and/or dash, and must be shorter than 128 characters.
        backend (BackendType, optional): The type of RPC backend
            implementation. Supported values include
            ``BackendType.PROCESS_GROUP`` (the default) and
            ``BackendType.TENSORPIPE``. See :ref:`rpc-backends` for more
            information.
        rank (int): a globally unique id/rank of this node.
        world_size (int): The number of workers in the group.
        rpc_backend_options (RpcBackendOptions, optional): The options
            passed to the RpcAgent constructor. It must be an agent-specific
            subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
            and contains agent-specific initialization configurations. By
            default, for all agents, it sets the default timeout to 60
            seconds and performs the rendezvous with an underlying process
            group initialized using ``init_method = "env://"``,
            meaning that environment variables ``MASTER_ADDR`` and
            ``MASTER_PORT`` need to be set properly. See
            :ref:`rpc-backends` for more information and find which options
            are available.
    """
    # Both are module-level: the rendezvous iterator must outlive this call,
    # and the counter gives every invocation a distinct key prefix.
    global rendezvous_iterator, _init_counter

    # Missing (or falsy) options fall back to the backend's defaults.
    if not rpc_backend_options:
        rpc_backend_options = backend_registry.construct_rpc_backend_options(
            backend)

    # Rendezvous. Keeping the iterator alive in a global prevents the
    # rendezvous state from being destroyed before every process has
    # finished handshaking.
    rendezvous_iterator = torch.distributed.rendezvous(
        rpc_backend_options.init_method, rank=rank, world_size=world_size)
    base_store, _rdzv_rank, _rdzv_world_size = next(rendezvous_iterator)

    # Namespace this invocation's keys so repeated init_rpc calls do not
    # collide in the shared store.
    with _init_counter_lock:
        store = dist.PrefixStore(f"rpc_prefix_{_init_counter}", base_store)
        _init_counter += 1

    # Distributed autograd is initialized before the RPC backend because
    # _init_rpc_backend synchronizes all processes through the store; doing
    # it afterwards would let a node call torch.distributed.autograd.backward()
    # while peers have not initialized autograd yet.
    dist_autograd._init(rank)

    _set_profiler_node_id(rank)

    # Initialize RPC.
    api._init_rpc_backend(backend, store, name, rank, world_size,
                          rpc_backend_options)
def _create_store(self): return c10d.PrefixStore(self.prefix, self.tcpstore)
def _create_store(self): return dist.PrefixStore(self.prefix, self.filestore)
def init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
    r"""
    Initializes RPC primitives such as the local RPC agent
    and distributed autograd, which immediately makes the current
    process ready to send and receive RPCs.

    Args:
        name (str): a globally unique name of this node. (e.g.,
            ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
            Name can only contain number, alphabet, underscore, colon,
            and/or dash, and must be shorter than 128 characters.
        backend (BackendType, optional): The type of RPC backend
            implementation. Supported values include
            ``BackendType.TENSORPIPE`` (the default) and
            ``BackendType.PROCESS_GROUP``. See :ref:`rpc-backends` for more
            information. When omitted but ``rpc_backend_options`` is given,
            the backend is inferred from the type of the options.
        rank (int): a globally unique id/rank of this node.
        world_size (int): The number of workers in the group.
        rpc_backend_options (RpcBackendOptions, optional): The options
            passed to the RpcAgent constructor. It must be an agent-specific
            subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
            and contains agent-specific initialization configurations. By
            default, for all agents, it sets the default timeout to 60
            seconds and performs the rendezvous with an underlying process
            group initialized using ``init_method = "env://"``,
            meaning that environment variables ``MASTER_ADDR`` and
            ``MASTER_PORT`` need to be set properly. See
            :ref:`rpc-backends` for more information and find which options
            are available.

    Raises:
        TypeError: if ``backend`` is not a ``BackendType`` member, if
            ``rpc_backend_options`` is not an ``RpcBackendOptions``, or if no
            backend matches the type of the given options.
    """
    # Both are module-level: the rendezvous iterator must outlive this call,
    # and the counter gives every invocation a distinct key prefix.
    global rendezvous_iterator, _init_counter

    # --- Argument validation ------------------------------------------------
    backend_is_valid = backend is None or isinstance(
        backend, backend_registry.BackendType)
    if not backend_is_valid:
        raise TypeError("Argument backend must be a member of BackendType")

    options_are_valid = rpc_backend_options is None or isinstance(
        rpc_backend_options, RpcBackendOptions)
    if not options_are_valid:
        raise TypeError(
            "Argument rpc_backend_options must be an instance of RpcBackendOptions"
        )

    # --- Backend selection --------------------------------------------------
    if backend is None:
        if rpc_backend_options is None:
            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]
        else:
            # To avoid breaking users that passed e.g. ProcessGroup options
            # without naming the backend (back when that was the default),
            # pick the first backend whose default options share a type
            # with the supplied ones.
            backend = next(
                (
                    candidate
                    for candidate in BackendType
                    if isinstance(
                        rpc_backend_options,
                        type(
                            backend_registry.construct_rpc_backend_options(
                                candidate)),
                    )
                ),
                None,
            )
            if backend is None:
                raise TypeError(
                    f"Could not infer backend for options {rpc_backend_options}"
                )
            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                logger.warning(
                    f"RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                    f"corresponding to {backend}, hence that backend will be used "
                    f"instead of the default {BackendType.TENSORPIPE}. To silence this "
                    f"warning pass `backend={backend}` explicitly.")

    if backend == BackendType.PROCESS_GROUP:  # type: ignore[attr-defined]
        logger.warning(
            "RPC was initialized with the PROCESS_GROUP backend which is "
            "deprecated and slated to be removed and superseded by the TENSORPIPE "
            "backend. It is recommended to migrate to the TENSORPIPE backend."
        )

    if rpc_backend_options is None:
        # Default-construct the options for the selected backend.
        rpc_backend_options = backend_registry.construct_rpc_backend_options(
            backend)

    # --- Rendezvous -----------------------------------------------------------
    # Keeping the iterator alive in a global prevents the rendezvous state
    # from being destroyed before every process has finished handshaking.
    rendezvous_iterator = torch.distributed.rendezvous(
        rpc_backend_options.init_method, rank=rank, world_size=world_size)
    base_store, _rdzv_rank, _rdzv_world_size = next(rendezvous_iterator)

    # Namespace this invocation's keys so repeated init_rpc calls do not
    # collide in the shared store.
    with _init_counter_lock:
        store = dist.PrefixStore(f"rpc_prefix_{_init_counter}", base_store)
        _init_counter += 1

    # Distributed autograd is initialized before the RPC backend because
    # _init_rpc_backend synchronizes all processes through the store; doing
    # it afterwards would let a node call torch.distributed.autograd.backward()
    # while peers have not initialized autograd yet.
    dist_autograd._init(rank)

    _set_profiler_node_id(rank)

    # Initialize RPC.
    _init_rpc_backend(backend, store, name, rank, world_size,
                      rpc_backend_options)
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
    r"""Main method for training PPO.

    Runs distributed PPO with auxiliary tasks on this worker. The process
    group is initialized with ``init_distrib_slurm``. Every rank collects
    rollouts and updates the agent; rank 0 additionally writes TensorBoard
    scalars, logs progress and saves checkpoints.

    Args:
        ckpt_path: path of a checkpoint to resume from; only read when
            ``ckpt != -1``.
        ckpt: index of the checkpoint being resumed, or -1 to start fresh.
            Newly saved checkpoints are numbered from ``ckpt + 1``.
        start_updates: update index to start (resume) from. Also used to
            estimate the number of frames already trained when the
            checkpoint does not store a step count.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                  tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    # Offset the seeds by rank so workers do not replay identical streams.
    random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)
    np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    self.config.freeze()

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    task_cfg = self.config.TASK_CONFIG.TASK

    observation_space = self.envs.observation_spaces[0]
    aux_cfg = self.config.RL.AUX_TASKS
    init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks(
        aux_cfg, ppo_cfg, task_cfg, observation_space)

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        observation_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_memories=num_recurrent_memories)
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    # Seed step 0 of the rollout buffer with the reset observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                   init_aux_tasks)
    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),  # including bonus
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    # Total number of frames collected across ALL ranks.
    count_steps = 0
    count_checkpoints = 0
    prev_time = 0

    if ckpt != -1:
        logger.info(
            f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
        )
        assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported"
        # This is the checkpoint we start saving at
        count_checkpoints = ckpt + 1
        # Estimate of frames already trained; count_steps is summed over
        # all ranks, so scale the per-rank estimate by world_size.
        count_steps = (start_updates * ppo_cfg.num_steps *
                       self.config.NUM_PROCESSES * self.world_size)
        ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
        self.agent.load_state_dict(ckpt_dict["state_dict"])
        if "optim_state" in ckpt_dict:
            self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
        else:
            logger.warning("No optimizer state loaded, results may be funky")
        # Prefer the exact step count recorded in the checkpoint.
        if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                "extra_state"]:
            count_steps = ckpt_dict["extra_state"]["step"]

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # An interrupted (requeued) state overrides everything restored above.
    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_updates = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    # Only rank 0 gets a real TensorBoard writer; `writer` is only used
    # under world_rank == 0.
    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_updates, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ))

                requeue_job()
                return

            # Frames collected by THIS rank during this update; summed over
            # ranks by the all_reduce below before being added to
            # count_steps (previously count_steps was incremented locally
            # here, leaving this delta at 0 and the global count wrong).
            count_steps_delta = 0
            self.agent.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(rollouts,
                                               current_episode_reward,
                                               running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size):
                    break

            # Signal the other ranks that this worker finished its rollout.
            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
                aux_task_losses,
                aux_dist_entropy,
                aux_weights,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            # Sum the running episode statistics over all ranks.
            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering],
                0).to(self.device)
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            # Layout: [entropy, aux_entropy, value, policy, *aux_losses,
            # frame_delta]. Assumes aux_task_losses is a list of scalars.
            stats = torch.tensor(
                [
                    dist_entropy,
                    aux_dist_entropy,
                ] + [value_loss, action_loss] + aux_task_losses +
                [count_steps_delta],
                device=self.device,
            )
            distrib.all_reduce(stats)

            has_aux_weights = aux_weights is not None and len(
                aux_weights) > 0
            if has_aux_weights:
                # NOTE(review): the reduced tensor is discarded, so rank 0
                # logs its local aux_weights. The collective is kept so all
                # ranks still issue matching all_reduce calls.
                distrib.all_reduce(
                    torch.tensor(aux_weights, device=self.device))
            count_steps += int(stats[-1].item())

            if self.world_rank == 0:
                # Reset the rollout counter for the next update.
                num_rollouts_done_store.set("num_done", "0")

                # Mean over ranks of everything except the frame count.
                avg_stats = [
                    stats[i].item() / self.world_size
                    for i in range(len(stats) - 1)
                ]
                losses = avg_stats[2:]
                dist_entropy, aux_dist_entropy = avg_stats[:2]
                # Totals accumulated over the window (newest minus oldest
                # snapshot, or the only snapshot when the window has one).
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                # Avoid dividing by zero before any episode has finished.
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward",
                    deltas["reward"] / deltas["count"],
                    count_steps,
                )

                writer.add_scalar(
                    "entropy",
                    dist_entropy,
                    count_steps,
                )

                writer.add_scalar("aux_entropy", aux_dist_entropy,
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {
                        k: l
                        for l, k in zip(losses, ["value", "policy"] +
                                        aux_task_strings)
                    },
                    count_steps,
                )

                # aux_weights may be None (see the guard above); zipping
                # None would raise TypeError.
                if has_aux_weights:
                    writer.add_scalars(
                        "aux_weights",
                        {
                            k: l
                            for l, k in zip(aux_weights, aux_task_strings)
                        },
                        count_steps,
                    )

                # Log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    formatted_aux_losses = [
                        "{:.3g}".format(l) for l in aux_task_losses
                    ]
                    logger.info(
                        "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t"
                        .format(
                            update,
                            value_loss,
                            action_loss,
                            formatted_aux_losses,
                            aux_dist_entropy,
                        ))
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        count_steps / ((time.time() - t_start) + prev_time),
                    ))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))
                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                 for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                        dict(step=count_steps))
                    count_checkpoints += 1

        self.envs.close()
def train(self) -> None:
    r"""Main method for DD-PPO.

    Batched-simulator variant: builds the observation/action spaces from
    config, sets up the agent (AMP, TRT, scripted net, distributed
    wrapping), a ``DoubleBufferedRolloutStorage`` and the batched
    environments, then loops until ``self.is_done()`` collecting rollouts
    with ``_n_buffered_sampling`` and updating the agent. Rank 0 logs to
    TensorBoard, periodically saves an interrupted state, and saves
    checkpoints. A final ``ckpt.done.pth`` is written when the loop ends.

    Returns:
        None
    """
    import apex

    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend
    )
    # NOTE(review): signal handlers are disabled, so EXIT/REQUEUE below are
    # presumably never set by signals in this variant — confirm.
    # add_signal_handlers()
    self.timing = Timing()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    set_cpus(self.local_rank, self.world_size)

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += self.world_rank * self.config.SIM_BATCH_SIZE
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    double_buffered = False
    self._num_worker_groups = self.config.NUM_PARALLEL_SCENES

    self._depth = self.config.DEPTH
    self._color = self.config.COLOR

    # The observation space is assembled from config rather than queried
    # from the environments: optional pointgoal, then rgb and/or depth.
    if self.config.TASK.lower() == "pointnav":
        self.observation_space = SpaceDict(
            {
                "pointgoal_with_gps_compass": spaces.Box(
                    low=0.0, high=1.0, shape=(2,), dtype=np.float32
                )
            }
        )
    else:
        self.observation_space = SpaceDict({})

    self.action_space = spaces.Discrete(4)

    if self._color:
        self.observation_space = SpaceDict(
            {
                "rgb": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            }
        )

    if self._depth:
        self.observation_space = SpaceDict(
            {
                "depth": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            }
        )

    ppo_cfg = self.config.RL.PPO
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0:
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)

    self.count_steps = 0
    burn_steps = 0
    burn_time = 0
    count_checkpoints = 0
    prev_time = 0
    self.update = 0

    # Square-root learning-rate scaling relative to a reference of
    # 128 * 2 samples per optimizer step, never below 1.0; disabled (1.0)
    # unless scale_lr is set and ada_scale is not.
    LR_SCALE = (
        max(
            np.sqrt(
                ppo_cfg.num_steps
                * self.config.SIM_BATCH_SIZE
                * ppo_cfg.num_accumulate_steps
                / ppo_cfg.num_mini_batch
                * self.world_size
                / (128 * 2)
            ),
            1.0,
        )
        if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
        else 1.0
    )

    # Half-cosine from 1 at x=0 down to 0 at x=1; 0 afterwards.
    def cosine_decay(x):
        if x < 1:
            return (np.cos(x * np.pi) + 1.0) / 2.0
        else:
            return 0.0

    # Linear warmup from 0.5 * LR_SCALE to LR_SCALE as x goes 0 -> 1.
    def warmup_fn(x):
        return LR_SCALE * (0.5 + 0.5 * x)

    # Cosine decay from LR_SCALE towards LR_SCALE * DECAY_TARGET.
    # DECAY_TARGET is defined below; it is looked up when decay_fn is called.
    def decay_fn(x):
        return LR_SCALE * (DECAY_TARGET + (1 - DECAY_TARGET) * cosine_decay(x))

    # NOTE(review): `or True` makes both conditionals below always take their
    # first branch; the else branches are dead code.
    DECAY_TARGET = (
        0.01 / LR_SCALE
        if self.config.RL.PPO.ada_scale or True
        else (0.25 / LR_SCALE if self.config.RL.DDPPO.scale_lr else 1.0)
    )
    DECAY_PERCENT = 1.0 if self.config.RL.PPO.ada_scale or True else 0.5
    WARMUP_PERCENT = (
        0.01
        if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
        else 0.0
    )

    # LR multiplier driven by training progress (self.percent_done()), not by
    # the scheduler's own step counter.
    def lr_fn():
        x = self.percent_done()
        if x < WARMUP_PERCENT:
            return warmup_fn(x / WARMUP_PERCENT)
        else:
            return decay_fn((x - WARMUP_PERCENT) / DECAY_PERCENT)

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer, lr_lambda=lambda x: lr_fn()
    )

    # Weights are restored before AMP/TRT/scripting/distributed wrapping;
    # the remaining interrupted-state fields are restored further below.
    interrupted_state = load_interrupted_state(resume_from=self.resume_from)
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])

    self.agent.init_amp(self.config.SIM_BATCH_SIZE)
    self.actor_critic.init_trt(self.config.SIM_BATCH_SIZE)
    self.actor_critic.script_net()
    self.agent.init_distributed(find_unused_params=False)

    if self.world_rank == 0:
        logger.info(
            "agent number of trainable parameters: {}".format(
                sum(
                    param.numel()
                    for param in self.agent.parameters()
                    if param.requires_grad
                )
            )
        )

    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        # NOTE(review): the other SpaceDict merges use
        # `**self.observation_space.spaces`; `**self.observation_space`
        # requires the space to support mapping unpacking — confirm.
        self.observation_space = SpaceDict(
            {
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **self.observation_space,
            }
        )
        # NOTE(review): `batch` is not defined earlier in this method, so this
        # line raises NameError whenever self._static_encoder is true —
        # confirm this path is unused or fix it.
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    nenvs = self.config.SIM_BATCH_SIZE
    rollouts = DoubleBufferedRolloutStorage(
        ppo_cfg.num_steps,
        nenvs,
        self.observation_space,
        self.action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.num_recurrent_layers,
        use_data_aug=ppo_cfg.use_data_aug,
        aug_type=ppo_cfg.aug_type,
        double_buffered=double_buffered,
        vtrace=ppo_cfg.vtrace,
    )
    rollouts.to(self.device)
    rollouts.to_fp16()

    self._warmup(rollouts)

    (
        self.envs,
        self._observations,
        self._rewards,
        self._masks,
        self._rollout_infos,
        self._syncs,
    ) = construct_envs(
        self.config,
        num_worker_groups=self.config.NUM_PARALLEL_SCENES,
        double_buffered=double_buffered,
    )

    # Reset each buffer's environments, wait for its sync object, and copy
    # the first observations into step 0 of that buffer's rollout storage.
    def _setup_render_and_populate_initial_frame():
        for idx in range(2 if double_buffered else 1):
            self.envs.reset(idx)
            batch = self._observations[idx]
            self._syncs[idx].wait()

            tree_copy_in_place(
                tree_select(0, rollouts[idx].storage_buffers["observations"]),
                batch,
            )

    _setup_render_and_populate_initial_frame()

    current_episode_reward = torch.zeros(nenvs, 1)
    running_episode_stats = dict(
        count=torch.zeros(nenvs, 1,), reward=torch.zeros(nenvs, 1,),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )
    time_per_frame_window = deque(maxlen=ppo_cfg.reward_window_size)

    # Contiguous env-index slices, one per buffer (a single slice covering
    # all SIM_BATCH_SIZE envs when not double buffered).
    buffer_ranges = []
    for i in range(2 if double_buffered else 1):
        start_ind = buffer_ranges[-1].stop if i > 0 else 0
        buffer_ranges.append(
            slice(
                start_ind,
                start_ind
                + self.config.SIM_BATCH_SIZE // (2 if double_buffered else 1),
            )
        )

    if interrupted_state is not None:
        requeue_stats = interrupted_state["requeue_stats"]

        self.count_steps = requeue_stats["count_steps"]
        self.update = requeue_stats["start_update"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]

        burn_steps = requeue_stats["burn_steps"]
        burn_time = requeue_stats["burn_time"]

        self.agent.ada_scale.load_state_dict(interrupted_state["ada_scale_state"])

        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        # NOTE(review): the save path below never writes "amp_state", so
        # this branch only fires for states produced elsewhere.
        if "amp_state" in interrupted_state:
            apex.amp.load_state_dict(interrupted_state["amp_state"])

        if "grad_scaler_state" in interrupted_state:
            self.agent.grad_scaler.load_state_dict(
                interrupted_state["grad_scaler_state"]
            )

    # Only rank 0 gets a real TensorBoard writer; purge_step drops events
    # logged after the resumed step count.
    with (
        TensorboardWriter(
            self.config.TENSORBOARD_DIR,
            flush_secs=self.flush_secs,
            purge_step=int(self.count_steps),
        )
        if self.world_rank == 0
        else contextlib.suppress()
    ) as writer:
        distrib.barrier()
        t_start = time.time()
        while not self.is_done():
            t_rollout_start = time.time()
            # Record time/steps at the end of burn-in; they are subtracted
            # from the "total fps" figure logged below.
            if self.update == BURN_IN_UPDATES:
                burn_time = t_rollout_start - t_start
                burn_steps = self.count_steps

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    self.percent_done(), final_decay=ppo_cfg.decay_factor,
                )

            # Rank 0 saves an interrupted state on REQUEUE and also every
            # 100 updates, unless benchmarking.
            if (
                not BPS_BENCHMARK
                and (REQUEUE.is_set() or ((self.update + 1) % 100) == 0)
                and self.world_rank == 0
            ):
                requeue_stats = dict(
                    count_steps=self.count_steps,
                    count_checkpoints=count_checkpoints,
                    start_update=self.update,
                    prev_time=(time.time() - t_start) + prev_time,
                    burn_time=burn_time,
                    burn_steps=burn_steps,
                )

                # Half-precision tensors are stored as float32.
                def _cast(param):
                    if "Half" in param.type():
                        param = param.to(dtype=torch.float32)

                    return param

                save_interrupted_state(
                    dict(
                        state_dict={
                            k: _cast(v) for k, v in self.agent.state_dict().items()
                        },
                        ada_scale_state=self.agent.ada_scale.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                        grad_scaler_state=self.agent.grad_scaler.state_dict(),
                    )
                )

            if EXIT.is_set():
                # Drop references to the environment buffers before
                # requeueing.
                self._observations = None
                self._rewards = None
                self._masks = None
                self._rollout_infos = None
                self._syncs = None

                del self.envs
                self.envs = None

                requeue_job()
                return

            self.agent.eval()

            count_steps_delta = self._n_buffered_sampling(
                rollouts,
                current_episode_reward,
                running_episode_stats,
                buffer_ranges,
                ppo_cfg.num_steps,
                num_rollouts_done_store,
            )

            # Signal the other ranks that this worker finished its rollout.
            num_rollouts_done_store.add("num_done", 1)

            # With V-trace enabled, returns are not computed here.
            if not rollouts.vtrace:
                self._compute_returns(ppo_cfg, rollouts)

            (value_loss, action_loss, dist_entropy) = self._update_agent(rollouts)

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

            lr_scheduler.step()

            with self.timing.add_time("Logging"):
                # Sum the running episode statistics over all ranks.
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0,
                ).to(device=self.device)
                distrib.all_reduce(stats)
                stats = stats.to(device="cpu")

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i])

                # Layout: [value, policy, frame_delta, *swap_stats]; the log
                # below reads three swap_stats entries (indices 3..5).
                stats = torch.tensor(
                    [
                        value_loss,
                        action_loss,
                        count_steps_delta,
                        *self.envs.swap_stats,
                    ],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                stats = stats.to(device="cpu")
                count_steps_delta = int(stats[2].item())
                self.count_steps += count_steps_delta

                # NOTE(review): raises ZeroDivisionError if no frames were
                # collected by any rank during this update.
                time_per_frame_window.append(
                    (time.time() - t_rollout_start) / count_steps_delta
                )

                if self.world_rank == 0:
                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    # Totals accumulated over the window (newest minus
                    # oldest snapshot, or the only snapshot if just one).
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }
                    # Avoid dividing by zero before any episode finished.
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        self.count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, self.count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l for l, k in zip(losses, ["value", "policy"])},
                        self.count_steps,
                    )

                    optim = self.agent.optimizer
                    writer.add_scalar(
                        "optimizer/base_lr",
                        optim.param_groups[-1]["lr"],
                        self.count_steps,
                    )
                    # Per-group effective LR (lr * gain) when the optimizer
                    # exposes a "gain" entry in its param groups.
                    if "gain" in optim.param_groups[-1]:
                        for idx, group in enumerate(optim.param_groups):
                            writer.add_scalar(
                                f"optimizer/lr_{idx}",
                                group["lr"] * group["gain"],
                                self.count_steps,
                            )
                            writer.add_scalar(
                                f"optimizer/gain_{idx}",
                                group["gain"],
                                self.count_steps,
                            )

                    # log stats
                    if (
                        self.update > 0
                        and self.update % self.config.LOG_INTERVAL == 0
                    ):
                        logger.info(
                            "update: {}\twindow fps: {:.3f}\ttotal fps: {:.3f}\tframes: {}".format(
                                self.update,
                                1.0
                                / (
                                    sum(time_per_frame_window)
                                    / len(time_per_frame_window)
                                ),
                                (self.count_steps - burn_steps)
                                / ((time.time() - t_start) + prev_time - burn_time),
                                self.count_steps,
                            )
                        )

                        logger.info(
                            "swap percent: {:.3f}\tscenes in use: {:.3f}\tenvs per scene: {:.3f}".format(
                                stats[3].item() / self.world_size,
                                stats[4].item() / self.world_size,
                                stats[5].item() / self.world_size,
                            )
                        )

                        logger.info(
                            "Average window size: {} {}".format(
                                len(window_episode_stats["count"]),
                                " ".join(
                                    "{}: {:.3f}".format(k, v / deltas["count"])
                                    for k, v in deltas.items()
                                    if k != "count"
                                ),
                            )
                        )

                        logger.info(self.timing)
                        # self.envs.print_renderer_stats()

                    # checkpoint model
                    if self.should_checkpoint():
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(
                                step=self.count_steps,
                                wall_clock_time=(
                                    (time.time() - t_start) + prev_time
                                ),
                            ),
                        )
                        count_checkpoints += 1

            self.update += 1

        # Final checkpoint once is_done() reports completion.
        self.save_checkpoint(
            "ckpt.done.pth",
            dict(
                step=self.count_steps,
                wall_clock_time=((time.time() - t_start) + prev_time),
            ),
        )
        self._observations = None
        self._rewards = None
        self._masks = None
        self._rollout_infos = None
        self._syncs = None

        del self.envs
        self.envs = None
def train(self) -> None:
    r"""Main method for DD-PPO.

    Initializes the distributed process group, builds this worker's
    vectorized environments and actor-critic agent, then runs the DD-PPO
    loop for ``self.config.NUM_UPDATES`` updates.  Each update:

    * collects a rollout, possibly cut short when this worker detects it
      is a straggler,
    * runs a PPO update,
    * all-reduces episode and loss statistics across workers.

    On rank 0 the update also writes TensorBoard scalars, logs progress and
    saves checkpoints.

    When the EXIT flag is set, the environments are closed and the method
    returns early.  If REQUEUE is also set, rank 0 first saves an
    interrupted state and requeues the job.  A later run picks that state
    up through ``load_interrupted_state``.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                  tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Offset the task seed per worker.  Multiplying by the number of
    # simulators lets every simulator on every worker get a unique seed.
    # Previously only the python/numpy RNGs were offset by rank, so
    # TASK_CONFIG.SEED (used to build the envs) and the torch RNG were
    # identical on all ranks.
    self.config.TASK_CONFIG.SEED += (self.world_rank *
                                     self.config.NUM_PROCESSES)
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    # Only rank 0 creates the checkpoint folder (and only rank 0 saves
    # checkpoints below).
    if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

    observations = self.envs.reset()
    batch = batch_obs(observations)
    # Put the initial batch on the training device.  The static encoder
    # below is taken from the actor-critic and is fed this batch directly,
    # so CPU tensors break that path when training on a GPU (assumes
    # _setup_actor_critic_agent placed the model on self.device).
    for sensor in batch:
        batch[sensor] = batch[sensor].to(self.device)

    obs_space = self.envs.observation_spaces[0]
    if self._static_encoder:
        # Store the encoder's output as an extra "visual_features"
        # observation; the encoder is kept in eval mode during updates.
        self._encoder = self.actor_critic.net.visual_encoder
        obs_space = SpaceDict({
            "visual_features": spaces.Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=self._encoder.output_shape,
                dtype=np.float32,
            ),
            **obs_space.spaces,
        })
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)

    # Step 0 of the rollout buffer holds the reset observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors.  We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    episode_rewards = torch.zeros(self.envs.num_envs, 1, device=self.device)
    episode_counts = torch.zeros(self.envs.num_envs, 1, device=self.device)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1,
                                         device=self.device)
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # Resume from a state saved by a previously preempted run, if any.
    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    # Only rank 0 gets a TensorBoard writer; other ranks get a no-op context
    # manager and never use ``writer``.
    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    # ``update`` has not run yet, so resume from it.
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ))
                    requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts,
                    current_episode_reward,
                    episode_rewards,
                    episode_counts,
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens.  If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size):
                    break

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self._static_encoder:
                self._encoder.eval()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            # Sum the running episode reward/count tensors over all workers.
            stats = torch.stack([episode_rewards, episode_counts], 0)
            distrib.all_reduce(stats)

            window_episode_reward.append(stats[0].clone())
            window_episode_counts.append(stats[1].clone())

            stats = torch.tensor(
                [value_loss, action_loss, count_steps_delta],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[2].item()

            if self.world_rank == 0:
                # Every worker has taken part in the collective all_reduce
                # above, so the rollout counter can be reset for the next
                # update.
                num_rollouts_done_store.set("num_done", "0")

                # all_reduce summed the losses; divide to report the mean.
                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                ]
                stats = zip(
                    ["count", "reward"],
                    [window_episode_counts, window_episode_reward],
                )
                # Change over the window (or the absolute value when the
                # window only has one entry).
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in stats
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward", deltas["reward"] / deltas["count"], count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        count_steps / ((time.time() - t_start) + prev_time),
                    ))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    window_rewards = (window_episode_reward[-1] -
                                      window_episode_reward[0]).sum()
                    window_counts = (window_episode_counts[-1] -
                                     window_episode_counts[0]).sum()

                    if window_counts > 0:
                        logger.info(
                            "Average window size {} reward: {:.3f}".format(
                                len(window_episode_reward),
                                (window_rewards / window_counts).item(),
                            ))
                    else:
                        logger.info("No episodes finish in current window")

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1

        self.envs.close()
def train(self) -> None:
    r"""Main method for DD-PPO SLAM.

    Sets up distributed training and builds the NavSLAM environments, the
    actor-critic agent and a ``GlobalRolloutStorage``.  It then samples an
    initial set of global goals and runs the DD-PPO loop for
    ``self.config.NUM_UPDATES`` updates.  Rank 0 writes TensorBoard
    scalars, logs progress and saves checkpoints.

    When the EXIT flag is set, the environments are closed and the method
    returns early.  If REQUEUE is also set, rank 0 first saves an
    interrupted state to the default location, which
    ``load_interrupted_state()`` reads on the next run, and requeues the
    job.

    Returns:
        None
    """
    #####################################################################
    ## init distrib and configuration
    #####################################################################
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend
    )
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore(
        "rollout_tracker", tcp_store
    )
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()  # global rank of this process
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank  # GPU index on this node
    self.config.SIMULATOR_GPU_ID = self.local_rank
    print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID)
    print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID)

    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += (
        self.world_rank * self.config.NUM_PROCESSES
    )
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    #####################################################################
    ## build distrib NavSLAMRLEnv environment
    #####################################################################
    print("#############################################################")
    print("## build distrib NavSLAMRLEnv environment")
    print("#############################################################")
    self.envs = construct_envs(
        self.config, get_env_class(self.config.ENV_NAME)
    )

    observations = self.envs.reset()
    print("*************************** observations len:", len(observations))

    # Cast each env's semantic observation to int32 before batching, and
    # print the distinct semantic ids it contains.
    for obs in observations:
        obs["semantic"] = obs["semantic"].astype(np.int32)
        se = list(set(obs["semantic"].ravel()))
        print(se)

    batch = batch_obs(observations, device=self.device)
    print("*************************** batch len:", len(batch))

    #####################################################################
    ## init actor_critic agent
    #####################################################################
    print("#############################################################")
    print("## init actor_critic agent")
    print("#############################################################")
    # Map size is read from the first env's "map_sum" observation.
    self.map_w = observations[0]["map_sum"].shape[0]
    self.map_h = observations[0]["map_sum"].shape[1]

    ppo_cfg = self.config.RL.PPO
    if (
        not os.path.isdir(self.config.CHECKPOINT_FOLDER)
        and self.world_rank == 0
    ):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(observations, ppo_cfg)

    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info(
            "agent number of trainable parameters: {}".format(
                sum(
                    param.numel()
                    for param in self.agent.parameters()
                    if param.requires_grad
                )
            )
        )

    #####################################################################
    ## init Global Rollout Storage
    #####################################################################
    print("#############################################################")
    print("## init Global Rollout Storage")
    print("#############################################################")
    self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step

    # NOTE(review): self.obs_space / self.g_action_space are not assigned in
    # this method; presumably _setup_actor_critic_agent sets them -- confirm.
    rollouts = GlobalRolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.obs_space,
        self.g_action_space,
    )
    rollouts.to(self.device)

    print('rollouts type:', type(rollouts))
    print('--------------------------')

    # Step 0 of the rollout buffer holds the reset observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # Sample an initial global action per env and scale its first two
    # components by the map size to get integer global goals.
    with torch.no_grad():
        step_observation = {
            k: v[rollouts.step] for k, v in rollouts.observations.items()
        }

        _, actions, _, = self.actor_critic.act(
            step_observation,
            rollouts.prev_g_actions[0],
            rollouts.masks[0],
        )

    self.global_goals = [[int(action[0].item() * self.map_w),
                          int(action[1].item() * self.map_h)]
                         for action in actions]

    # batch and observations may contain shared PyTorch CUDA
    # tensors.  We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(
        self.envs.num_envs, 1, device=self.device
    )
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )

    print("*************************** current_episode_reward:", current_episode_reward)
    print("*************************** running_episode_stats:", running_episode_stats)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # Resume from a state saved by a previously preempted run, if any.  The
    # requeue branch below saves to this same default location.
    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"]
        )
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    # Only rank 0 gets a TensorBoard writer; other ranks get a no-op context
    # manager and never use ``writer``.
    with (
        TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        )
        if self.world_rank == 0
        else contextlib.suppress()
    ) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    # Save to the default location read by
                    # load_interrupted_state() above.  This used to be a
                    # hard-coded, machine-specific path that the resume
                    # code never read, so requeued jobs restarted from
                    # scratch.
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        )
                    )
                    print("********************EXIT*********************")
                    requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_global_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens.  If a
                # worker detects it will be a straggler, it preempts itself!
                if (
                    step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                ) and int(num_rollouts_done_store.get("num_done")) > (
                    self.config.RL.DDPPO.sync_frac * self.world_size
                ):
                    break

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self._static_encoder:
                # NOTE(review): self._encoder is never assigned in this
                # method; this presumably relies on _static_encoder being
                # False for the SLAM agent -- confirm.
                self._encoder.eval()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            # Sum each running episode statistic over all workers, in a
            # fixed (sorted) key order so every rank stacks identically.
            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0
            )
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            stats = torch.tensor(
                [value_loss, action_loss, count_steps_delta],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[2].item()

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

                # all_reduce summed the losses; divide to report the mean.
                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                ]
                deltas = {
                    k: (
                        (v[-1] - v[0]).sum().item()
                        if len(v) > 1
                        else v[0].sum().item()
                    )
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward", deltas["reward"] / deltas["count"], count_steps
                )

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info(
                        "update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps
                            / ((time.time() - t_start) + prev_time),
                        )
                    )

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(
                            update, env_time, pth_time, count_steps
                        )
                    )

                    logger.info(
                        "Average window size: {} {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items()
                                if k != "count"
                            ),
                        )
                    )

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    print('=' * 20 + 'Save Model' + '=' * 20)
                    logger.info(
                        "Save Model : {}".format(count_checkpoints)
                    )
                    count_checkpoints += 1

        self.envs.close()
def train(self) -> None:
    r"""Main method for DD-PPO.

    DD-PPO training with auxiliary losses (egomotion error and an
    information term).  It sets up distributed training, builds the
    environments and the actor-critic agent, then runs
    ``self.config.NUM_UPDATES`` updates.  All workers all-reduce the
    episode and loss statistics; rank 0 logs them to TensorBoard and the
    logger.

    At every checkpoint interval, rank 0 saves a model checkpoint and also
    a resumable interrupted state.  When the EXIT flag is set, the
    environments are closed and the method returns early.  If REQUEUE is
    also set, rank 0 first saves the interrupted state and requeues the
    job.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                  tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Presumably the model behind the ego-prediction sensor should run on
    # this worker's local GPU too.
    self.config.TASK_CONFIG.TASK.POINTGOAL_WITH_EGO_PREDICTION_SENSOR.MODEL.GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += (self.world_rank *
                                     self.config.NUM_PROCESSES)
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(
        self.config,
        get_env_class(self.config.ENV_NAME),
        workers_ignore_signals=True,
    )

    ppo_cfg = self.config.RL.PPO
    if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    obs_space = self.envs.observation_spaces[0]
    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        # Both entries hold encoder outputs (or zeros), so they get the same
        # float32 bounds.  "prev_visual_features" previously used uint32
        # bounds, which imposed a lower bound of 0 that "visual_features"
        # does not have.
        obs_space = SpaceDict({
            "visual_features": spaces.Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=self._encoder.output_shape,
                dtype=np.float32,
            ),
            "prev_visual_features": spaces.Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=self._encoder.output_shape,
                dtype=np.float32,
            ),
            **obs_space.spaces,
        })
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)
            batch["prev_visual_features"] = torch.zeros_like(
                batch["visual_features"])

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)

    # Step 0 holds the reset observations.  There is no earlier step, so the
    # "previous" observations start as zeros.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])
        rollouts.previous_observations[sensor][0].copy_(
            torch.zeros_like(batch[sensor]))

    # batch and observations may contain shared PyTorch CUDA
    # tensors.  We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1,
                                         device=self.device)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # Resume from a state saved by a previous run, if any.
    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    # Only rank 0 gets a TensorBoard writer; other ranks get a no-op context
    # manager and never use ``writer``.
    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            # Schedule driven by the number of frames seen so far.
            self.actor_critic.update_ib_beta(count_steps)

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    # ``update`` has not run yet, so resume from it.
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ))
                    requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(rollouts,
                                               current_episode_reward,
                                               running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens.  If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size):
                    break

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self._static_encoder:
                self._encoder.eval()

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
                aux_loss,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            # Sum each running episode statistic over all workers, in a
            # fixed (sorted) key order so every rank stacks identically.
            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0)
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            # All entries except the last (the step count) are loss terms
            # that are averaged over workers after the all_reduce.
            stats = torch.tensor(
                [
                    value_loss,
                    action_loss,
                    aux_loss,
                    AuxLosses.get_loss("egomotion_error"),
                    AuxLosses.get_loss("information"),
                    0.1 * dist_entropy,
                    count_steps_delta,
                ],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[-1].item()

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

                losses = [
                    stats[i].item() / self.world_size
                    for i in range(stats.size(0) - 1)
                ]
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "reward",
                    deltas["reward"] / deltas["count"],
                    count_steps,
                )

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {
                        k: l
                        for l, k in zip(losses, [
                            "value", "policy", "aux", "egomotion_error",
                            "information", "entropy"
                        ])
                    },
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        count_steps / ((time.time() - t_start) + prev_time),
                    ))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join(
                            "{}: {:.3f}".format(k, v / deltas["count"])
                            for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1

                    # Also snapshot a resumable state.  This update has
                    # already finished and count_checkpoints has already
                    # advanced, so a resumed run starts at the next update.
                    # Saving ``update`` here made a resumed run repeat this
                    # update (with an extra LR-scheduler step) and overwrite
                    # the checkpoint just written.
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update + 1,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ))

        self.envs.close()