def train(self) -> None: r"""Main method for DD-PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += (self.world_rank * self.config.NUM_PROCESSES) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO if (not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) obs_space = self.envs.observation_spaces[0] if self._static_encoder: self._encoder = self.actor_critic.net.visual_encoder obs_space = SpaceDict({ "visual_features": spaces.Box( low=np.finfo(np.float32).min, high=np.finfo(np.float32).max, shape=self._encoder.output_shape, dtype=np.float32, ), **obs_space.spaces, }) with torch.no_grad(): batch["visual_features"] = self._encoder(batch) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, ) rollouts.to(self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps), ) count_checkpoints += 1 self.envs.close()
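# NOTE: the LambdaLR schedulers above multiply the base learning rate by
# linear_decay(update, NUM_UPDATES). A minimal sketch of such a helper is
# shown below; it assumes a simple linear ramp from 1 to 0 over training,
# which may differ from the project's actual utility of the same name.
def linear_decay(epoch: int, total_num_updates: int) -> float:
    r"""Multiplicative LR factor that decays linearly from 1.0 to 0.0.

    Args:
        epoch: current update index.
        total_num_updates: total number of updates the schedule spans.
    """
    return 1.0 - (epoch / float(total_num_updates))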
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = torch.device("cuda", self.config.TORCH_GPU_ID) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) observations = self.envs.reset() batch = batch_obs(observations) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, ) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) rollouts.to(self.device) episode_rewards = torch.zeros(self.envs.num_envs, 1) episode_counts = torch.zeros(self.envs.num_envs, 1) current_episode_reward = torch.zeros(self.envs.num_envs, 1) window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size) window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step( rollouts, current_episode_reward, episode_rewards, episode_counts, ) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent( ppo_cfg, rollouts) pth_time += delta_pth_time window_episode_reward.append(episode_rewards.clone()) window_episode_counts.append(episode_counts.clone()) losses = [value_loss, action_loss] stats = zip( ["count", "reward"], [window_episode_counts, window_episode_reward], ) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in stats } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) window_rewards = (window_episode_reward[-1] - window_episode_reward[0]).sum() window_counts = (window_episode_counts[-1] - window_episode_counts[0]).sum() if window_counts > 0: logger.info( "Average window size {} reward: {:3f}".format( len(window_episode_reward), (window_rewards / window_counts).item(), )) else: logger.info("No episodes finish in current window") # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint(f"ckpt.{count_checkpoints}.pth") count_checkpoints += 1 self.envs.close()
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, ) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: 
self.save_checkpoint(f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) # Initialize auxiliary tasks observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = \ self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights = self._update_agent( ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] + aux_task_losses writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) writer.add_scalar( "success", deltas["success"] / deltas["count"], count_steps, ) # Log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {}\t action_loss: {}\taux_task_loss: {} \t aux_entropy {}" .format(update, value_loss, action_loss, aux_task_losses, aux_dist_entropy)) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
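# The resume path above expects checkpoints that carry "state_dict",
# optionally "optim_state", and an "extra_state" dict with a "step" counter.
# Below is a minimal sketch of a checkpoint writer that produces such a file,
# under the assumption that checkpoints are plain torch.save dicts; the real
# trainer's save_checkpoint may store additional fields.
import os
import torch

def save_checkpoint_sketch(agent, config, file_name, extra_state=None):
    checkpoint = {
        "state_dict": agent.state_dict(),
        "optim_state": agent.optimizer.state_dict(),
        "config": config,
    }
    if extra_state is not None:
        checkpoint["extra_state"] = extra_state  # e.g. dict(step=count_steps)
    torch.save(checkpoint, os.path.join(config.CHECKPOINT_FOLDER, file_name))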
def train(self):
    # Get environments for training
    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))

    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    # logger.info(
    #     "agent number of parameters: {}".format(
    #         sum(param.numel() for param in self.agent.parameters())
    #     )
    # )
    # Change for the actual value

    cfg = self.config.RL.PPO
    rollouts = RolloutStorage(
        cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        cfg.hidden_size,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations)

    for sensor in rollouts.observations:
        print(batch[sensor].shape)

    # Copy the information to the wrapper
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    # current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    # window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    # window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
    '''
def replay(self, num_updates, ppo_cfg, lr_scheduler, t_start, pth_time,
           writer, count_steps, count_checkpoints):
    print(".....start memory replay for {} updates.....".format(num_updates))
    env_time = 0
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

    memories = self.memory.recall(num_updates)

    for update in range(num_updates):
        rollouts, episode_rewards, episode_counts = memories[update]

        if ppo_cfg.use_linear_lr_decay:
            lr_scheduler.step()

        if ppo_cfg.use_linear_clip_decay:
            self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                update, self.config.NUM_UPDATES
            )

        (
            delta_pth_time,
            value_loss,
            action_loss,
            dist_entropy,
        ) = self._update_agent_memory(ppo_cfg, rollouts)
        pth_time += delta_pth_time

        window_episode_reward.append(episode_rewards.clone())
        window_episode_counts.append(episode_counts.clone())

        losses = [value_loss, action_loss]
        stats = zip(
            ["count", "reward"],
            [window_episode_counts, window_episode_reward],
        )
        deltas = {
            k: (
                (v[-1] - v[0]).sum().item()
                if len(v) > 1
                else v[0].sum().item()
            )
            for k, v in stats
        }
        deltas["count"] = max(deltas["count"], 1.0)

        writer.add_scalar(
            "reward", deltas["reward"] / deltas["count"], count_steps
        )

        writer.add_scalars(
            "losses",
            {k: l for l, k in zip(losses, ["value", "policy"])},
            count_steps,
        )

        # log stats
        if update > 0 and update % self.config.LOG_INTERVAL == 0:
            logger.info(
                "update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)
                )
            )

            logger.info(
                "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                "frames: {}".format(update, env_time, pth_time, count_steps)
            )

            window_rewards = (
                window_episode_reward[-1] - window_episode_reward[0]
            ).sum()
            window_counts = (
                window_episode_counts[-1] - window_episode_counts[0]
            ).sum()

            if window_counts > 0:
                logger.info(
                    "Average window size {} reward: {:.3f}".format(
                        len(window_episode_reward),
                        (window_rewards / window_counts).item(),
                    )
                )
            else:
                logger.info("No episodes finished in current window")

        # checkpoint model
        if update % self.config.CHECKPOINT_INTERVAL == 0:
            self.save_checkpoint(
                f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)
            )
            count_checkpoints += 1

        count_steps += 1
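# replay() above assumes a `self.memory` object whose recall(n) returns, for
# each of the last n updates, the stored (rollouts, episode_rewards,
# episode_counts) triple. A minimal sketch of such a buffer is shown below
# (hypothetical class, not part of the original code); a real version would
# likely bound its size and move tensors off the GPU before storing them.
from collections import deque

class RolloutMemorySketch:
    def __init__(self, capacity=100):
        self._buffer = deque(maxlen=capacity)

    def remember(self, rollouts, episode_rewards, episode_counts):
        self._buffer.append(
            (rollouts, episode_rewards.clone(), episode_counts.clone())
        )

    def recall(self, num_updates):
        # Return the most recent `num_updates` entries, oldest first.
        return list(self._buffer)[-num_updates:]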
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank self.config.freeze() if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks( aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), # including bonus ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 prev_time = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_updates = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() ( delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0).to(self.device) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [ dist_entropy, aux_dist_entropy, ] + [value_loss, action_loss] + aux_task_losses + [count_steps_delta], device=self.device, ) distrib.all_reduce(stats) if aux_weights is not None and len(aux_weights) > 0: distrib.all_reduce( torch.tensor(aux_weights, device=self.device)) count_steps += stats[-1].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") avg_stats = [ stats[i].item() / self.world_size for i in range(len(stats) - 1) ] losses = avg_stats[2:] dist_entropy, aux_dist_entropy = avg_stats[:2] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) # Log stats formatted_aux_losses = [ "{:.3g}".format(l) for l in aux_task_losses ] if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t" .format( update, value_loss, action_loss, formatted_aux_losses, aux_dist_entropy, )) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
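# In the distributed variants above, scalar losses are averaged across workers
# by packing them into one tensor, calling distrib.all_reduce (which sums the
# tensor in place across ranks), and dividing by the world size on rank 0.
# A compact sketch of that pattern with torch.distributed, assuming the
# process group has already been initialized:
import torch
import torch.distributed as distrib

def all_reduce_mean(values, device, world_size):
    stats = torch.tensor(values, device=device, dtype=torch.float32)
    distrib.all_reduce(stats)  # element-wise sum over all ranks
    return (stats / world_size).tolist()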
def benchmark(self) -> None: if TIME_DEBUG: s = time.time() #self.config.defrost() #self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_EPISODES = 10 #self.config.freeze() if torch.cuda.device_count() <= 1: self.config.defrost() self.config.TORCH_GPU_ID = 0 self.config.SIMULATOR_GPU_ID = 0 self.config.freeze() self.envs = construct_envs(self.config, eval(self.config.ENV_NAME)) if ADD_IL: self.il_envs = construct_envs(self.config, eval(self.config.ENV_NAME), no_val=True) self.collect_mode = 'RL' if TIME_DEBUG: s = log_time(s, 'construct envs') ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) # if 'SMT' in self.config.POLICY: # sd = torch.load('visual_embedding18.pth') # self.actor_critic.net.visual_encoder.load_state_dict(sd['visual_encoder']) # self.actor_critic.net.prev_action_embedding.load_state_dict(sd['prev_action_embedding']) # self.actor_critic.net.visual_encoder.cuda() # self.actor_critic.net.prev_action_embedding.cuda() # self.envs.setup_embedding_network(self.actor_critic.net.visual_encoder, self.actor_critic.net.prev_action_embedding) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) num_train_processes, num_val_processes = self.config.NUM_PROCESSES, self.config.NUM_VAL_PROCESSES total_processes = num_train_processes + num_val_processes OBS_LIST = self.config.OBS_TO_SAVE self.num_processes = num_train_processes rollouts = RolloutStorage(ppo_cfg.num_steps, num_train_processes, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, self.actor_critic.net.num_recurrent_layers, OBS_LIST=OBS_LIST) rollouts.to(self.device) batch = self.envs.reset() for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_( batch[sensor][:num_train_processes]) self.last_observations = batch self.last_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, total_processes, ppo_cfg.hidden_size).to(self.device) self.last_prev_actions = torch.zeros( total_processes, rollouts.prev_actions.shape[-1]).to(self.device) self.last_masks = torch.zeros(total_processes, 1).to(self.device) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None if ADD_IL: rollouts2 = RolloutStorage( ppo_cfg.num_steps, num_train_processes, self.il_envs.observation_spaces[0], self.il_envs.action_spaces[0], ppo_cfg.hidden_size, self.actor_critic.net.num_recurrent_layers, OBS_LIST=OBS_LIST) rollouts2.to(self.device) batch2 = self.il_envs.reset() for sensor in rollouts2.observations: rollouts2.observations[sensor][0].copy_( batch2[sensor][:num_train_processes]) self.saved_last_obs = batch2 self.saved_last_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, total_processes, ppo_cfg.hidden_size).to(self.device) self.saved_last_prev_actions = torch.zeros( total_processes, rollouts2.prev_actions.shape[-1]).to(self.device) self.saved_last_masks = torch.zeros(total_processes, 1).to(self.device) batch2 = None else: rollouts2 = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) if TIME_DEBUG: s = log_time(s, 'setup all') for update in range(100): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if TIME_DEBUG: s = log_time(s, 'collect rollout start') modes = ['RL'] if ADD_IL: modes += ['IL'] for collect_mode in modes: self.collect_mode = collect_mode use_rollouts = rollouts if self.collect_mode == 'RL' else rollouts2 if ADD_IL: self.exchange_lasts() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(use_rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps #print(delta_env_time, delta_pth_time) if TIME_DEBUG: s = log_time(s, 'collect rollout done') (delta_pth_time, value_loss, action_loss, dist_entropy, il_loss) = self._update_agent(ppo_cfg, rollouts, rollouts2) #print(delta_pth_time) pth_time += delta_pth_time use_rollouts.after_update() if TIME_DEBUG: s = log_time(s, 'update agent') for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1][:self.num_processes] - v[0][:self.num_processes]).sum().item() if len(v) > 1 else v[0][:self.num_processes].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) #self.write_tb('train', writer, deltas, count_steps, losses) eval_deltas = { k: ((v[-1][self.num_processes:] - v[0][self.num_processes:]).sum().item() if len(v) > 1 else v[0][self.num_processes:].sum().item()) for k, v in window_episode_stats.items() } eval_deltas["count"] = max(eval_deltas["count"], 1.0) #self.write_tb('val', writer, eval_deltas, count_steps) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: 
{:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) logger.info("validation metrics: {}".format( " ".join("{}: {:.3f}".format(k, v / eval_deltas["count"]) for k, v in eval_deltas.items() if k != "count"), )) self.envs.close()
def train(self) -> None: r"""Main method for DD-PPO SLAM. Returns: None """ ##################################################################### ## init distrib and configuration ##################################################################### self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend ) # self.local_rank = 1 add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore( "rollout_tracker", tcp_store ) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() # server number self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank # gpu number in one server self.config.SIMULATOR_GPU_ID = self.local_rank print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID) print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID) # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += ( self.world_rank * self.config.NUM_PROCESSES ) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") ##################################################################### ## build distrib NavSLAMRLEnv environment ##################################################################### print("#############################################################") print("## build distrib NavSLAMRLEnv environment") print("#############################################################") self.envs = construct_envs( self.config, get_env_class(self.config.ENV_NAME) ) observations = self.envs.reset() print("*************************** observations len:", len(observations)) # semantic process for i in range(len(observations)): observations[i]["semantic"] = observations[i]["semantic"].astype(np.int32) se = list(set(observations[i]["semantic"].ravel())) print(se) # print("*************************** observations type:", observations) # print("*************************** observations type:", observations[0]["map_sum"].shape) # 480*480*23 # print("*************************** observations curr_pose:", observations[0]["curr_pose"]) # [] batch = batch_obs(observations, device=self.device) print("*************************** batch len:", len(batch)) # print("*************************** batch:", batch) # print("************************************* current_episodes:", (self.envs.current_episodes())) ##################################################################### ## init actor_critic agent ##################################################################### print("#############################################################") print("## init actor_critic agent") print("#############################################################") self.map_w = observations[0]["map_sum"].shape[0] self.map_h = observations[0]["map_sum"].shape[1] # print("map_: ", observations[0]["curr_pose"].shape) ppo_cfg = self.config.RL.PPO if ( not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0 ): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(observations, ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info( "agent number of trainable parameters: 
{}".format( sum( param.numel() for param in self.agent.parameters() if param.requires_grad ) ) ) ##################################################################### ## init Global Rollout Storage ##################################################################### print("#############################################################") print("## init Global Rollout Storage") print("#############################################################") self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step rollouts = GlobalRolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.obs_space, self.g_action_space, ) rollouts.to(self.device) print('rollouts type:', type(rollouts)) print('--------------------------') # for k in rollouts.keys(): # print("rollouts: {0}".format(rollouts.observations)) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) with torch.no_grad(): step_observation = { k: v[rollouts.step] for k, v in rollouts.observations.items() } _, actions, _, = self.actor_critic.act( step_observation, rollouts.prev_g_actions[0], rollouts.masks[0], ) self.global_goals = [[int(action[0].item() * self.map_w), int(action[1].item() * self.map_h)] for action in actions] # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros( self.envs.num_envs, 1, device=self.device ) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size) ) print("*************************** current_episode_reward:", current_episode_reward) print("*************************** running_episode_stats:", running_episode_stats) # print("*************************** window_episode_stats:", window_episode_stats) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) # interrupted_state = load_interrupted_state("/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth") interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"] ) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] deif = {} with ( TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) if self.world_rank == 0 else contextlib.suppress() ) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES ) # print("************************************* current_episodes:", type(self.envs.count_episodes())) # print(EXIT.is_set()) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and 
self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, ), "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth" ) print("********************EXIT*********************") requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_global_rollout_step( rollouts, current_episode_reward, running_episode_stats ) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # print("************************************* current_episodes:") for i in range(len(self.envs.current_episodes())): # print(" ", self.envs.current_episodes()[i].episode_id," ", self.envs.current_episodes()[i].scene_id," ", self.envs.current_episodes()[i].object_category) if self.envs.current_episodes()[i].scene_id not in deif: deif[self.envs.current_episodes()[i].scene_id]=[int(self.envs.current_episodes()[i].episode_id)] else: deif[self.envs.current_episodes()[i].scene_id].append(int(self.envs.current_episodes()[i].episode_id)) # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! if ( step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size ): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0 ) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ( (v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item() ) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), ) ) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format( update, env_time, pth_time, count_steps ) ) logger.info( "Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in 
deltas.items() if k != "count" ), ) ) # for k in deif: # deif[k] = list(set(deif[k])) # deif[k].sort() # print("deif: k", k, " : ", deif[k]) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps), ) print('=' * 20 + 'Save Model' + '=' * 20) logger.info( "Save Model : {}".format(count_checkpoints) ) count_checkpoints += 1 self.envs.close()
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs( self.config, get_env_class(self.config.ENV_NAME) ) observation_space = self.envs.observation_spaces[0] ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK aux_cfg = self.config.RL.AUX_TASKS self.device = ( torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu") ) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_dqn_agent(ppo_cfg, task_cfg, aux_cfg, []) self.dataset = RolloutDataset() self.dataloader = DataLoader(self.dataset, batch_size=16, num_workers=0) # Use environment to initialize the metadata for training the model self.envs.close() if self.config.RESUME_CURIOUS: weights = torch.load(self.config.RESUME_CURIOUS)['state_dict'] state_dict = self.q_network.state_dict() weights_new = {} for k, v in weights.items(): if "model_encoder" in k: k = k.replace("model_encoder", "visual_resnet").replace("actor_critic.", "") if k in state_dict: weights_new[k] = v state_dict.update(weights_new) self.q_network.load_state_dict(state_dict) logger.info( "agent number of parameters: {}".format( sum(param.numel() for param in self.q_network.parameters()) ) ) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." ) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.q_network.load_state_dict(ckpt_dict["state_dict"]) self.q_network_target.load_state_dict(ckpt_dict["target_state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) im_size = 256 with TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) as writer: update = 0 for i in range(self.config.NUM_EPOCHS): for im, pointgoal, action, mask, reward in self.dataloader: if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() im, pointgoal, action, mask, reward = collate(im), collate(pointgoal), collate(action), collate(mask), collate(reward) im = im.to(self.device).float() pointgoal = pointgoal.to(self.device).float() mask = mask.to(self.device).float() reward = reward.to(self.device).float() action = action.to(self.device).long() nstep = im.size(1) hidden_states = None hidden_states_target = None # q_vals = [] # q_vals_target = [] step = random.randint(0, nstep-1) output = self.q_network({'rgb': im[:, step]}, None, None) mse_loss = torch.pow(output - im[:, step] / 255., 2).mean() mse_loss.backward() # for step in range(nstep): # q_val, hidden_states = self.q_network({'rgb': im[:, step], 'pointgoal_with_gps_compass': pointgoal[:, step]}, hidden_states, mask[:, step]) # q_val_target, 
hidden_states_target = self.q_network_target({'rgb': im[:, step], 'pointgoal_with_gps_compass': pointgoal[:, step]}, hidden_states_target, mask[:, step]) # q_vals.append(q_val) # q_vals_target.append(q_val_target) # q_vals = torch.stack(q_vals, dim=1) # q_vals_target = torch.stack(q_vals_target, dim=1) # a_select = torch.argmax(q_vals, dim=-1, keepdim=True) # target_select = torch.gather(q_vals_target, -1, a_select) # target = reward + ppo_cfg.gamma * target_select[:, 1:] * mask[:, 1:] # target = target.detach() # pred_q = torch.gather(q_vals[:, :-1], -1, action) # mse_loss = torch.pow(pred_q - target, 2).mean() # mse_loss.backward() # grad_norm = torch.nn.utils.clip_grad_norm(self.q_network.parameters(), 80) self.optimizer.step() self.optimizer.zero_grad() writer.add_scalar( "loss", mse_loss, update, ) # writer.add_scalar( # "q_val", # q_vals.max(), # update, # ) if update % 10 == 0: print("Update: {}, loss: {}".format(update, mse_loss)) if update % 100 == 0: self.sync_model() # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps) ) count_checkpoints += 1 update = update + 1
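# The commented-out block above sketches a Double-DQN style temporal-difference
# target over a sequence of steps: actions are selected with the online
# q_network and evaluated with q_network_target. A self-contained version of
# that computation is shown below. It is a sketch under the stated shape
# assumptions, not the objective the uncommented code actually optimizes
# (which is a pixel-reconstruction loss).
import torch

def double_dqn_loss(q_vals, q_vals_target, action, reward, mask, gamma):
    """q_vals, q_vals_target: [B, T, A]; action, reward: [B, T-1, 1];
    mask: [B, T, 1] with zeros where an episode ended."""
    # Greedy action under the online network, evaluated by the target network.
    a_select = torch.argmax(q_vals, dim=-1, keepdim=True)
    target_select = torch.gather(q_vals_target, -1, a_select)
    # Bootstrap from the next step; the mask zeroes the term at episode ends.
    target = (reward + gamma * target_select[:, 1:] * mask[:, 1:]).detach()
    pred_q = torch.gather(q_vals[:, :-1], -1, action)
    return torch.pow(pred_q - target, 2).mean()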
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    self.add_new_based_on_cfg()
    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg, train=True)

    if self.config.PRETRAINED_CHECKPOINT_PATH:
        ckpt_dict = self.load_checkpoint(
            self.config.PRETRAINED_CHECKPOINT_PATH, map_location="cpu")
        self.agent.load_state_dict(ckpt_dict["state_dict"], strict=False)

    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers)
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs_augment_aux(observations, self.envs.get_shared_mem())

    for sensor in rollouts.observations:
        if sensor in batch:
            rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    info_data_keys = ["discovered", "collisions_wall", "collisions_prox"]
    log_data_keys = [
        "episode_rewards", "episode_go_rewards", "episode_counts",
        "current_episode_reward", "current_episode_go_reward"
    ] + info_data_keys
    log_data = dict(
        {k: torch.zeros(self.envs.num_envs, 1) for k in log_data_keys})
    info_data = dict({k: log_data[k] for k in info_data_keys})

    win_keys = log_data_keys
    win_keys.pop(win_keys.index("current_episode_reward"))
    win_keys.pop(win_keys.index("current_episode_go_reward"))
    windows = dict({
        k: deque(maxlen=ppo_cfg.reward_window_size)
        for k in log_data.keys()
    })

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    train_steps = min(self.config.NUM_UPDATES, self.config.HARD_NUM_UPDATES)
    log_interval = self.config.LOG_INTERVAL
    num_updates = self.config.NUM_UPDATES
    agent = self.agent
    ckpt_interval = self.config.CHECKPOINT_INTERVAL

    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(train_steps):
            if ppo_cfg.use_linear_clip_decay:
                agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, num_updates)

            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts, log_data["current_episode_reward"],
                    log_data["current_episode_go_reward"],
                    log_data["episode_rewards"],
                    log_data["episode_go_rewards"],
                    log_data["episode_counts"], info_data)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            delta_pth_time, value_loss, action_loss, dist_entropy,\
                aux_loss = self._update_agent(ppo_cfg, rollouts)

            # TODO check if LR is init
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            pth_time += delta_pth_time

            # ==================================================================================
            # -- Log data for window averaging
            for k, v in windows.items():
                windows[k].append(log_data[k].clone())

            value_names = ["value", "policy", "entropy"] + list(
                aux_loss.keys())
            losses = [value_loss, action_loss, dist_entropy] + list(
                aux_loss.values())

            stats = zip(list(windows.keys()), list(windows.values()))
            deltas = {
                k: ((v[-1] - v[0]).sum().item()
                    if len(v) > 1 else v[0].sum().item())
                for k, v in stats
            }
            act_ep = deltas["episode_counts"]
            counts = max(act_ep, 1.0)
            deltas["episode_counts"] *= counts
            for k, v in deltas.items():
                deltas[k] = v / counts
                writer.add_scalar(k, deltas[k], count_steps)

            writer.add_scalars("losses",
                               {k: l for l, k in zip(losses, value_names)},
                               count_steps)

            # log stats
            if update > 0 and update % log_interval == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time,
                                        count_steps))

                if act_ep > 0:
                    log_txt = f"Average window size {len(windows['episode_counts'])}"
                    for k, v in deltas.items():
                        log_txt += f" | {k}: {v:.3f}"
                    logger.info(log_txt)
                    logger.info(
                        f"Aux losses: {list(zip(value_names, losses))}")
                else:
                    logger.info("No episodes finished in the current window")

            # ==================================================================================
            # checkpoint model
            if update % ckpt_interval == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1

        self.envs.close()
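# --------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above): each deque in `windows`
# stores cumulative per-environment totals, so subtracting the oldest snapshot
# from the newest yields what was accumulated inside the window, and dividing
# by the number of finished episodes gives a per-episode average. The helper
# name `window_average` and the toy data below are assumptions for illustration.
from collections import deque

import torch


def window_average(windows, count_key="episode_counts"):
    deltas = {
        k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
        for k, v in windows.items()
    }
    # Avoid dividing by zero when no episode finished inside the window.
    count = max(deltas[count_key], 1.0)
    return {k: v / count for k, v in deltas.items()}


# Example: two snapshots of cumulative stats for 2 environments.
example_windows = {
    "episode_counts": deque([torch.tensor([[1.0], [0.0]]),
                             torch.tensor([[3.0], [1.0]])], maxlen=50),
    "episode_rewards": deque([torch.tensor([[2.0], [0.0]]),
                              torch.tensor([[8.0], [4.0]])], maxlen=50),
}
print(window_average(example_windows))  # ~3.33 reward per finished episode
# --------------------------------------------------------------------------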
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    if TIME_DEBUG:
        s = time.time()

    self.envs = construct_envs(
        self.config, eval(self.config.ENV_NAME)
    )
    if TIME_DEBUG:
        s = log_time(s, 'construct envs')

    ppo_cfg = self.config.RL.PPO
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)

    if 'SMT' in self.config.POLICY:
        sd = (torch.load('visual_embedding18_explore.pth')
              if 'Explore' in self.config.POLICY
              else torch.load('visual_embedding18.pth'))
        self.actor_critic.net.visual_encoder.load_state_dict(sd['visual_encoder'])
        self.actor_critic.net.prev_action_embedding.load_state_dict(sd['prev_action_embedding'])
        self.actor_critic.net.visual_encoder.cuda()
        self.actor_critic.net.prev_action_embedding.cuda()
        total_num = self.config.NUM_PROCESSES + self.config.NUM_VAL_PROCESSES
        args_list = {
            'visual_encoder': self.actor_critic.net.visual_encoder,
            'prev_action_embedding': self.actor_critic.net.prev_action_embedding,
        }
        self.envs.call(['setup_embedding_network'] * total_num,
                       [args_list] * total_num)

    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )

    num_train_processes, num_val_processes = (
        self.config.NUM_PROCESSES,
        self.config.NUM_VAL_PROCESSES,
    )
    total_processes = num_train_processes + num_val_processes
    OBS_LIST = self.config.OBS_TO_SAVE
    self.num_processes = num_train_processes

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        num_train_processes,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        self.actor_critic.net.num_recurrent_layers,
        OBS_LIST=OBS_LIST,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(
            batch[sensor][:num_train_processes]
        )

    self.last_observations = batch
    self.last_recurrent_hidden_states = torch.zeros(
        self.actor_critic.net.num_recurrent_layers,
        total_processes,
        ppo_cfg.hidden_size,
    ).to(self.device)
    self.last_prev_actions = torch.zeros(
        total_processes, rollouts.prev_actions.shape[-1]
    ).to(self.device)
    self.last_masks = torch.zeros(total_processes, 1).to(self.device)

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps
    start_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
    if TIME_DEBUG:
        s = log_time(s, 'setup all')

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        for update in range(self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )

            if TIME_DEBUG:
                s = log_time(s, 'collect rollout start')

            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            if TIME_DEBUG:
                s = log_time(s, 'collect rollout done')

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
                il_loss,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time
            rollouts.after_update()

            if TIME_DEBUG:
                s = log_time(s, 'update agent')

            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())

            deltas = {
                k: (
                    (v[-1][:self.num_processes] - v[0][:self.num_processes]).sum().item()
                    if len(v) > 1
                    else v[0][:self.num_processes].sum().item()
                )
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)

            losses = [value_loss, action_loss, dist_entropy, il_loss]
            self.write_tb('train', writer, deltas, count_steps, losses)

            eval_deltas = {
                k: (
                    (v[-1][self.num_processes:] - v[0][self.num_processes:]).sum().item()
                    if len(v) > 1
                    else v[0][self.num_processes:].sum().item()
                )
                for k, v in window_episode_stats.items()
            }
            eval_deltas["count"] = max(eval_deltas["count"], 1.0)

            self.write_tb('val', writer, eval_deltas, count_steps)

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info(
                    "update: {}\tfps: {:.3f}\t".format(
                        update,
                        (count_steps - start_steps) / (time.time() - t_start),
                    )
                )
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(
                        update, env_time, pth_time, count_steps
                    )
                )
                logger.info(
                    "Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join(
                            "{}: {:.3f}".format(k, v / deltas["count"])
                            for k, v in deltas.items()
                            if k != "count"
                        ),
                    )
                )
                logger.info(
                    "validation metrics: {}".format(
                        " ".join(
                            "{}: {:.3f}".format(k, v / eval_deltas["count"])
                            for k, v in eval_deltas.items()
                            if k != "count"
                        ),
                    )
                )

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)
                )
                count_checkpoints += 1

        self.envs.close()
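# --------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above): the method above runs
# training and validation environments inside the same vectorized batch and
# splits the windowed statistics by process index before summing. The helper
# name `split_window_deltas` and the slice-based layout are assumptions for
# illustration only.
def split_window_deltas(window_episode_stats, num_train_processes):
    def delta(v, sl):
        return ((v[-1][sl] - v[0][sl]).sum().item()
                if len(v) > 1 else v[0][sl].sum().item())

    train_sl = slice(None, num_train_processes)   # first N processes train
    val_sl = slice(num_train_processes, None)     # remaining processes validate
    train = {k: delta(v, train_sl) for k, v in window_episode_stats.items()}
    val = {k: delta(v, val_sl) for k, v in window_episode_stats.items()}
    return train, val
# --------------------------------------------------------------------------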
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    self.add_new_based_on_cfg()
    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg, train=True)

    if self.config.PRETRAINED_CHECKPOINT_PATH:
        ckpt_dict = self.load_checkpoint(
            self.config.PRETRAINED_CHECKPOINT_PATH, map_location="cpu")
        self.agent.load_state_dict(ckpt_dict["state_dict"], strict=False)

    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers)
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs_augment_aux(observations)

    for sensor in rollouts.observations:
        if sensor in batch:
            rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_go_rewards = torch.zeros(self.envs.num_envs, 1)  # Grid oracle rewards
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    current_episode_go_reward = torch.zeros(self.envs.num_envs, 1)  # Grid oracle rewards
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_go_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    train_steps = min(self.config.NUM_UPDATES, self.config.HARD_NUM_UPDATES)

    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(train_steps):
            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts,
                    current_episode_reward,
                    current_episode_go_reward,
                    episode_rewards,
                    episode_go_rewards,
                    episode_counts,
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            delta_pth_time, value_loss, action_loss, dist_entropy,\
                aux_loss = self._update_agent(ppo_cfg, rollouts)

            # TODO check if LR is init
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            pth_time += delta_pth_time

            window_episode_reward.append(episode_rewards.clone())
            window_episode_go_reward.append(episode_go_rewards.clone())
            window_episode_counts.append(episode_counts.clone())

            value_names = ["value", "policy", "entropy"] + list(
                aux_loss.keys())
            losses = [value_loss, action_loss, dist_entropy] + list(
                aux_loss.values())

            stats = zip(
                ["count", "reward", "reward_go"],
                [
                    window_episode_counts,
                    window_episode_reward,
                    window_episode_go_reward,
                ],
            )
            deltas = {
                k: ((v[-1] - v[0]).sum().item()
                    if len(v) > 1 else v[0].sum().item())
                for k, v in stats
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                              count_steps)
            writer.add_scalar("reward_go",
                              deltas["reward_go"] / deltas["count"],
                              count_steps)
            writer.add_scalars(
                "losses",
                {k: l for l, k in zip(losses, value_names)},
                count_steps,
            )

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time,
                                        count_steps))

                window_rewards = (window_episode_reward[-1] -
                                  window_episode_reward[0]).sum()
                window_go_rewards = (window_episode_go_reward[-1] -
                                     window_episode_go_reward[0]).sum()
                window_counts = (window_episode_counts[-1] -
                                 window_episode_counts[0]).sum()

                if window_counts > 0:
                    logger.info(
                        "Average window size {} reward: {:.3f} reward_go: {:.3f}"
                        .format(
                            len(window_episode_reward),
                            (window_rewards / window_counts).item(),
                            (window_go_rewards / window_counts).item(),
                        ))
                    logger.info(
                        f"Aux losses: {list(zip(value_names, losses))}")
                else:
                    logger.info("No episodes finished in the current window")

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1

        self.envs.close()
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    #### init
    obs = self.envs.reset()
    #### temporarily store an initial obs
    batch = batch_obs(obs, device=self.device)
    for sensor in self.rollout.observations:
        self.rollout.observations[sensor][0].copy_(batch[sensor])

    #### Parameters
    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    #### PPO LOG PARA
    count_steps = 0
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=self.config.RL.PPO.reward_window_size))

    #### start training
    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for epoch in range(self.config.NUM_UPDATES):
            #### decay
            if self.config.RL.PPO.use_linear_lr_decay:
                lr_scheduler.step()
            if (epoch + 1) % self.config.CHECKPOINT_INTERVAL == 0:
                self.agent.entropy_coef = self.agent.entropy_coef * 0.9
                print(self.agent.entropy_coef)

            #### collect rollout
            print("=== collect rollout ===")
            for step in range(self.config.RL.PPO.num_steps):
                self._collect_rollout_step(current_episode_reward,
                                           running_episode_stats)
                count_steps += self.envs.num_envs

            #### update
            loss, loss_auxiliary = self._update_agent()

            #### LOGGER
            writer.add_scalars("loss", loss, epoch * self.envs.num_envs)
            writer.add_scalars("loss_auxiliary", loss_auxiliary,
                               epoch * self.envs.num_envs)

            #### PPO LOG PARA
            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())

            deltas = {
                k: ((v[-1] - v[0]).sum().item()
                    if len(v) > 1 else v[0].sum().item())
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                              count_steps)

            metrics = {
                k: v / deltas["count"]
                for k, v in deltas.items() if k not in {"reward", "count"}
            }
            if len(metrics) > 0:
                writer.add_scalars("metrics", metrics, count_steps)

            #### PPO LOG PARA
            if epoch % (1) == 0:
                print("count_steps:", count_steps)
                print("deltas:", deltas)
                print("metrics:", metrics)

            if epoch % self.config.CHECKPOINT_INTERVAL == 0:
                self._save_checkpoint(f"checkpoint.{epoch}.pth")

        self.envs.close()
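# --------------------------------------------------------------------------
# Illustrative sketch (not part of the trainers above): every train() variant
# in this file builds its LR schedule the same way, wrapping a `linear_decay`
# helper in torch's LambdaLR so the learning rate falls linearly toward zero
# over NUM_UPDATES. The helper definition and the toy optimizer below are
# assumptions for illustration only.
import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_decay(epoch: int, total_num_updates: int) -> float:
    # Multiplier on the base LR: 1.0 at update 0, 0.0 at the final update.
    return 1.0 - (epoch / float(total_num_updates))


params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2.5e-4)
NUM_UPDATES = 10000
lr_scheduler = LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda x: linear_decay(x, NUM_UPDATES),
)
for update in range(3):
    optimizer.step()
    lr_scheduler.step()  # called once per PPO update in the trainers above
# --------------------------------------------------------------------------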