def act(self, observations):
    batch = batch_obs([observations], device=self.device)

    with torch.no_grad():
        _, action, _, self.test_recurrent_hidden_states = self.actor_critic.act(
            batch,
            self.test_recurrent_hidden_states,
            self.prev_actions,
            self.not_done_masks,
            deterministic=False,
        )
        # Keep masks as not-done until reset() is called at the end of the episode
        self.not_done_masks.fill_(1.0)
        self.prev_actions.copy_(action)

    return action.item()
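
# Illustrative usage only (not part of the trainer): a minimal evaluation driver for an
# agent exposing the act() method above. Both `agent` and `env` are hypothetical
# arguments, not objects defined in this file; `env` is assumed to follow the usual
# gym-style reset()/step() interface returning (observations, reward, done, info), and
# agent.reset() is an assumed helper that zeroes hidden states, prev_actions and masks.
def run_single_episode(agent, env):
    """Roll out one episode with an agent exposing act() and return the summed reward."""
    observations = env.reset()
    agent.reset()
    episode_reward, done = 0.0, False
    while not done:
        action = agent.act(observations)  # act() returns a Python int action id
        observations, reward, done, _ = env.step(action)
        episode_reward += reward
    return episode_reward
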
def train(self) -> None:
    r"""Main method for DD-PPO.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()

    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
    num_rollouts_done_store.set("num_done", "0")

    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()

    self.config.defrost()
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += self.world_rank * self.config.NUM_PROCESSES
    self.config.freeze()

    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)

    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")

    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))

    ppo_cfg = self.config.RL.PPO
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0:
        os.makedirs(self.config.CHECKPOINT_FOLDER)

    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)
    if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
        self.belief_predictor.init_distributed(find_unused_params=True)

    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))
        if ppo_cfg.use_belief_predictor:
            logger.info("belief predictor number of trainable parameters: {}".format(
                sum(param.numel() for param in self.belief_predictor.parameters()
                    if param.requires_grad)))
        logger.info(f"config: {self.config}")

    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)

    obs_space = self.envs.observation_spaces[0]
    if ppo_cfg.use_external_memory:
        memory_dim = self.actor_critic.net.memory_dim
    else:
        memory_dim = None

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.action_space,
        ppo_cfg.hidden_size,
        ppo_cfg.use_external_memory,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
        memory_dim,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)

    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # Try to resume at previous checkpoint (independent of interrupted states)
    count_steps_start, count_checkpoints, start_update = self.try_to_resume_checkpoint()
    count_steps = count_steps_start

    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.load_state_dict(interrupted_state["belief_predictor"])
        self.agent.optimizer.load_state_dict(interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    with (
        TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs)
        if self.world_rank == 0
        else contextlib.suppress()
    ) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    state_dict = dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    )
                    if self.config.RL.PPO.use_belief_predictor:
                        state_dict["belief_predictor"] = self.belief_predictor.state_dict()
                    save_interrupted_state(state_dict)

                requeue_job()
                return

            count_steps_delta = 0
            self.agent.eval()
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.eval()

            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps

                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (
                    step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                ) and int(num_rollouts_done_store.get("num_done")) > (
                    self.config.RL.DDPPO.sync_frac * self.world_size
                ):
                    break

            num_rollouts_done_store.add("num_done", 1)

            self.agent.train()
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.train()
                self.belief_predictor.set_eval_encoders()
            if self._static_smt_encoder:
                self.actor_critic.net.set_eval_encoders()

            if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
                location_predictor_loss, prediction_accuracy = self.train_belief_predictor(rollouts)
            else:
                location_predictor_loss = 0
                prediction_accuracy = 0

            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time

            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0)
            distrib.all_reduce(stats)

            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())

            stats = torch.tensor(
                [
                    value_loss,
                    action_loss,
                    dist_entropy,
                    location_predictor_loss,
                    prediction_accuracy,
                    count_steps_delta,
                ],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[5].item()

            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")

                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                    stats[2].item() / self.world_size,
                    stats[3].item() / self.world_size,
                    stats[4].item() / self.world_size,
                ]
                deltas = {
                    k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("Metrics/reward",
                                  deltas["reward"] / deltas["count"], count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    for metric, value in metrics.items():
                        writer.add_scalar(f"Metrics/{metric}", value, count_steps)

                writer.add_scalar("Policy/value_loss", losses[0], count_steps)
                writer.add_scalar("Policy/policy_loss", losses[1], count_steps)
                writer.add_scalar("Policy/entropy_loss", losses[2], count_steps)
                writer.add_scalar("Policy/predictor_loss", losses[3], count_steps)
                writer.add_scalar("Policy/predictor_accuracy", losses[4], count_steps)
                writer.add_scalar("Policy/learning_rate",
                                  lr_scheduler.get_lr()[0], count_steps)

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        (count_steps - count_steps_start)
                        / ((time.time() - t_start) + prev_time),
                    ))
                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time, count_steps))
                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join(
                            "{}: {:.3f}".format(k, v / deltas["count"])
                            for k, v in deltas.items()
                            if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1

        self.envs.close()
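
# Worked example of the straggler-preemption rule used in the rollout loop above. The
# helper and the numbers below are illustrative only (default-style values, not read from
# this repository's configs); they simply mirror the break condition in the loop.
def _should_preempt_rollout(step, num_steps, num_done, sync_frac, world_size,
                            short_rollout_threshold=0.25):
    """Mirror of the break condition: preempt once enough peers finished their rollouts."""
    return (step >= num_steps * short_rollout_threshold
            and num_done > sync_frac * world_size)

# e.g. _should_preempt_rollout(step=40, num_steps=150, num_done=5, sync_frac=0.6, world_size=8)
# -> True, because 40 >= 0.25 * 150 = 37.5 and 5 > 0.6 * 8 = 4.8, so a slow worker cuts its
# rollout short instead of stalling the synchronous update for the whole group.
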
def _eval_checkpoint(self, checkpoint_path: str, writer: TensorboardWriter,
                     checkpoint_index: int = 0) -> Dict:
    r"""Evaluates a single checkpoint.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        A dictionary with the mean episode reward and the mean of each task metric.
    """
    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)

    # Map location CPU is almost always better than mapping to a CUDA device.
    ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

    if self.config.EVAL.USE_CKPT_CONFIG:
        config = self._setup_eval_config(ckpt_dict["config"])
    else:
        config = self.config.clone()

    ppo_cfg = config.RL.PPO

    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    if self.config.DISPLAY_RESOLUTION != config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH:
        model_resolution = config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH
        config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH = \
            config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT = \
            config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH = \
            config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT = \
            self.config.DISPLAY_RESOLUTION
    else:
        model_resolution = self.config.DISPLAY_RESOLUTION
    config.freeze()

    if len(self.config.VIDEO_OPTION) > 0:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.TASK_CONFIG.TASK.SENSORS.append("AUDIOGOAL_SENSOR")
        config.freeze()
    elif "top_down_map" in self.config.VISUALIZATION_OPTION:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.freeze()

    logger.info(f"env config: {config}")
    self.envs = construct_envs(config, get_env_class(config.ENV_NAME),
                               auto_reset_done=False)
    if self.config.DISPLAY_RESOLUTION != model_resolution:
        observation_space = self.envs.observation_spaces[0]
        observation_space.spaces['depth'].shape = (model_resolution, model_resolution, 1)
        observation_space.spaces['rgb'].shape = (model_resolution, model_resolution, 3)
    else:
        observation_space = self.envs.observation_spaces[0]
    self._setup_actor_critic_agent(ppo_cfg)

    self.agent.load_state_dict(ckpt_dict["state_dict"])
    self.actor_critic = self.agent.actor_critic

    self.metric_uuids = []
    for metric_name in self.config.TASK_CONFIG.TASK.MEASUREMENTS:
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, \
            "invalid measurement type {}".format(metric_cfg.TYPE)
        self.metric_uuids.append(
            measure_type(sim=None, task=None, config=None)._get_uuid())

    observations = self.envs.reset()
    if self.config.DISPLAY_RESOLUTION != model_resolution:
        resize_observation(observations, model_resolution)
    batch = batch_obs(observations, self.device)

    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)
    current_episode_reaching_waypoint = torch.zeros(self.envs.num_envs, 1, device=self.device)
    current_episode_cant_reach_waypoint = torch.zeros(self.envs.num_envs, 1, device=self.device)
    current_episode_step_count = torch.zeros(self.envs.num_envs, 1, device=self.device)

    test_recurrent_hidden_states = torch.zeros(
        self.actor_critic.net.num_recurrent_layers,
        self.config.NUM_PROCESSES,
        ppo_cfg.hidden_size,
        device=self.device,
    )
    prev_actions = torch.zeros(self.config.NUM_PROCESSES, 2,
                               device=self.device, dtype=torch.long)
    not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device)
    stats_episodes = dict()  # dict of dicts that stores stats per episode

    rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)]  # type: List[List[np.ndarray]]
    audios = [[] for _ in range(self.config.NUM_PROCESSES)]
    if len(self.config.VIDEO_OPTION) > 0:
        os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

    t = tqdm(total=self.config.TEST_EPISODE_COUNT)
    while (
        len(stats_episodes) < self.config.TEST_EPISODE_COUNT
        and self.envs.num_envs > 0
    ):
        current_episodes = self.envs.current_episodes()

        with torch.no_grad():
            _, actions, _, test_recurrent_hidden_states, distributions = self.actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                deterministic=True,
            )

            prev_actions.copy_(actions)

        outputs = self.envs.step([{"action": a[0].item()} for a in actions])
        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        if config.DISPLAY_RESOLUTION != model_resolution:
            resize_observation(observations, model_resolution)
        batch = batch_obs(observations, self.device)
        if len(self.config.VIDEO_OPTION) > 0:
            rgb_frames[0] += infos[0]['rgb_frames']
            audios[0] += infos[0]['audios']

        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=self.device,
        )
        logging.debug('Reward: {}'.format(rewards[0]))

        rewards = torch.tensor(rewards, dtype=torch.float,
                               device=self.device).unsqueeze(1)
        current_episode_reward += rewards
        current_episode_step_count += 1
        next_episodes = self.envs.current_episodes()
        n_envs = self.envs.num_envs
        envs_to_pause = []
        for i in range(n_envs):
            # pause envs which run out of episodes
            if (
                next_episodes[i].scene_id,
                next_episodes[i].episode_id,
            ) in stats_episodes:
                envs_to_pause.append(i)
                logging.info('Pause env {} and remaining number of envs: {}'.format(
                    i, self.envs.num_envs))

            current_episode_reaching_waypoint[i] += infos[i]['reaching_waypoint']
            current_episode_cant_reach_waypoint[i] += infos[i]['cant_reach_waypoint']

            # episode ended
            if not_done_masks[i].item() == 0:
                episode_stats = dict()
                for metric_uuid in self.metric_uuids:
                    episode_stats[metric_uuid] = infos[i][metric_uuid]
                episode_stats["reward"] = current_episode_reward[i].item()
                episode_stats['geodesic_distance'] = \
                    current_episodes[i].info['geodesic_distance']
                episode_stats['euclidean_distance'] = norm(
                    np.array(current_episodes[i].goals[0].position)
                    - np.array(current_episodes[i].start_position))
                episode_stats["reaching_waypoint"] = \
                    current_episode_reaching_waypoint[i].item() / \
                    current_episode_step_count[i].item()
                episode_stats["cant_reach_waypoint"] = \
                    current_episode_cant_reach_waypoint[i].item() / \
                    current_episode_step_count[i].item()
                current_episode_reaching_waypoint[i] = 0
                current_episode_cant_reach_waypoint[i] = 0
                current_episode_step_count[i] = 0
                current_episode_reward[i] = 0
                logging.debug(episode_stats)
                # use scene_id + episode_id as unique id for storing stats
                stats_episodes[(
                    current_episodes[i].scene_id,
                    current_episodes[i].episode_id,
                )] = episode_stats
                t.update()

                if len(self.config.VIDEO_OPTION) > 0:
                    if self.config.VISUALIZE_FAILURE_ONLY and infos[i]['success'] > 0:
                        pass
                    else:
                        fps = self.config.TASK_CONFIG.SIMULATOR.VIEW_CHANGE_FPS \
                            if self.config.TASK_CONFIG.SIMULATOR.CONTINUOUS_VIEW_CHANGE else 1
                        if 'sound' in current_episodes[i].info:
                            sound = current_episodes[i].info['sound']
                        else:
                            sound = current_episodes[i].sound_id.split('/')[1][:-4]
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i][:-1],
                            scene_name=current_episodes[i].scene_id.split('/')[3],
                            sound=sound,
                            sr=self.config.TASK_CONFIG.SIMULATOR.AUDIO.RIR_SAMPLING_RATE,
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metric_name='spl',
                            metric_value=infos[i]['spl'],
                            tb_writer=writer,
                            audios=audios[i][:-1],
                            fps=fps,
                        )

                    rgb_frames[i] = []
                    audios[i] = []

                if "top_down_map" in self.config.VISUALIZATION_OPTION:
                    top_down_map = plot_top_down_map(infos[i])
                    scene = current_episodes[i].scene_id.split('/')[-3]
                    writer.add_image(
                        '{}_{}_{}/{}'.format(
                            config.EVAL.SPLIT, scene,
                            current_episodes[i].episode_id,
                            config.BASE_TASK_CONFIG_PATH.split('/')[-1][:-5]),
                        top_down_map,
                        dataformats='WHC')

        (
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        ) = self._pause_envs(
            envs_to_pause,
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        )

    aggregated_stats = dict()
    for stat_key in next(iter(stats_episodes.values())).keys():
        aggregated_stats[stat_key] = sum(
            [v[stat_key] for v in stats_episodes.values()])
    num_episodes = len(stats_episodes)

    stats_file = os.path.join(
        config.TENSORBOARD_DIR,
        '{}_stats_{}.json'.format(config.EVAL.SPLIT, config.SEED))
    new_stats_episodes = {','.join(key): value for key, value in stats_episodes.items()}
    with open(stats_file, 'w') as fo:
        json.dump(new_stats_episodes, fo)

    episode_reward_mean = aggregated_stats["reward"] / num_episodes
    episode_reaching_waypoint_mean = aggregated_stats["reaching_waypoint"] / num_episodes
    episode_cant_reach_waypoint_mean = aggregated_stats["cant_reach_waypoint"] / num_episodes
    episode_metrics_mean = {}
    for metric_uuid in self.metric_uuids:
        episode_metrics_mean[metric_uuid] = aggregated_stats[metric_uuid] / num_episodes

    logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
    logger.info(f"Average episode reaching_waypoint: {episode_reaching_waypoint_mean:.6f}")
    logger.info(f"Average episode cant_reach_waypoint: {episode_cant_reach_waypoint_mean:.6f}")
    for metric_uuid in self.metric_uuids:
        logger.info(
            f"Average episode {metric_uuid}: {episode_metrics_mean[metric_uuid]:.6f}")

    if not config.EVAL.SPLIT.startswith('test'):
        writer.add_scalar("{}/reward".format(config.EVAL.SPLIT),
                          episode_reward_mean, checkpoint_index)
        writer.add_scalar("{}/reaching_waypoint".format(config.EVAL.SPLIT),
                          episode_reaching_waypoint_mean, checkpoint_index)
        writer.add_scalar("{}/cant_reach_waypoint".format(config.EVAL.SPLIT),
                          episode_cant_reach_waypoint_mean, checkpoint_index)
        for metric_uuid in self.metric_uuids:
            writer.add_scalar(f"{config.EVAL.SPLIT}/{metric_uuid}",
                              episode_metrics_mean[metric_uuid], checkpoint_index)

    self.envs.close()

    result = {
        'episode_reward_mean': episode_reward_mean,
        'episode_reaching_waypoint_mean': episode_reaching_waypoint_mean,
        'episode_cant_reach_waypoint_mean': episode_cant_reach_waypoint_mean,
    }
    for metric_uuid in self.metric_uuids:
        result['episode_{}_mean'.format(metric_uuid)] = episode_metrics_mean[metric_uuid]

    return result
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    global lr_lambda

    logger.info(f"config: {self.config}")
    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)

    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME),
                               auto_reset_done=False)

    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))

    rollouts = RolloutStorage(ppo_cfg.num_steps, self.envs.num_envs,
                              self.envs.observation_spaces[0],
                              self.envs.action_spaces[0], ppo_cfg.hidden_size)
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    # episode_rewards and episode_counts accumulate over the entire training course
    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_spls = torch.zeros(self.envs.num_envs, 1)
    episode_steps = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    episode_distances = torch.zeros(self.envs.num_envs, 1)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    current_episode_step = torch.zeros(self.envs.num_envs, 1)
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_spl = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_step = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_distances = deque(maxlen=ppo_cfg.reward_window_size)

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    if ppo_cfg.use_linear_lr_decay:
        def lr_lambda(x):
            return linear_decay(x, self.config.NUM_UPDATES)
    elif ppo_cfg.use_exponential_lr_decay:
        def lr_lambda(x):
            return exponential_decay(x, self.config.NUM_UPDATES,
                                     ppo_cfg.exp_decay_lambda)
    else:
        def lr_lambda(x):
            return 1

    lr_scheduler = LambdaLR(optimizer=self.agent.optimizer, lr_lambda=lr_lambda)

    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay or ppo_cfg.use_exponential_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts, current_episode_reward, current_episode_step,
                    episode_rewards, episode_spls, episode_counts,
                    episode_steps, episode_distances)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                ppo_cfg, rollouts)
            pth_time += delta_pth_time

            window_episode_reward.append(episode_rewards.clone())
            window_episode_spl.append(episode_spls.clone())
            window_episode_step.append(episode_steps.clone())
            window_episode_counts.append(episode_counts.clone())
            window_episode_distances.append(episode_distances.clone())

            losses = [value_loss, action_loss, dist_entropy]
            stats = zip(
                ["count", "reward", "step", 'spl', 'distance'],
                [window_episode_counts, window_episode_reward, window_episode_step,
                 window_episode_spl, window_episode_distances],
            )
            deltas = {
                k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                for k, v in stats
            }
            deltas["count"] = max(deltas["count"], 1.0)

            # this reward is averaged over all episodes that finished during the last
            # window_size updates; the approximate number of steps is window_size * num_steps
            writer.add_scalar("Environment/Reward",
                              deltas["reward"] / deltas["count"], count_steps)
            writer.add_scalar("Environment/SPL",
                              deltas["spl"] / deltas["count"], count_steps)
            logging.debug('Number of steps: {}'.format(deltas["step"] / deltas["count"]))
            writer.add_scalar("Environment/Episode_length",
                              deltas["step"] / deltas["count"], count_steps)
            writer.add_scalar("Environment/Distance_to_goal",
                              deltas["distance"] / deltas["count"], count_steps)

            # writer.add_scalars(
            #     "losses",
            #     {k: l for l, k in zip(losses, ["value", "policy"])},
            #     count_steps,
            # )
            writer.add_scalar('Policy/Value_Loss', value_loss, count_steps)
            writer.add_scalar('Policy/Action_Loss', action_loss, count_steps)
            writer.add_scalar('Policy/Entropy', dist_entropy, count_steps)
            writer.add_scalar('Policy/Learning_Rate',
                              lr_scheduler.get_lr()[0], count_steps)

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / ((time.time() - t_start) + prev_time)))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time, count_steps))

                window_rewards = (window_episode_reward[-1] - window_episode_reward[0]).sum()
                window_counts = (window_episode_counts[-1] - window_episode_counts[0]).sum()
                if window_counts > 0:
                    logger.info("Average window size {} reward: {:3f}".format(
                        len(window_episode_reward),
                        (window_rewards / window_counts).item(),
                    ))
                else:
                    logger.info("No episodes finished in the current window")

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1

        self.envs.close()
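
# The lr_lambda branches above return a multiplicative factor that LambdaLR applies to the
# optimizer's base learning rate at every scheduler step. The two helpers below are a
# minimal sketch of what such decay functions can look like; they are assumptions for
# illustration, not this repository's own linear_decay / exponential_decay implementations.
def linear_decay_sketch(update: int, total_num_updates: int) -> float:
    """Scale factor that falls linearly from 1.0 to 0.0 over training."""
    return 1.0 - update / float(total_num_updates)

def exponential_decay_sketch(update: int, total_num_updates: int, decay_lambda: float) -> float:
    """Scale factor that decays exponentially with training progress."""
    return float(np.exp(-decay_lambda * update / float(total_num_updates)))

# With LambdaLR, the effective rate is lr(update) = base_lr * factor(update), e.g.
# base_lr * linear_decay_sketch(update, NUM_UPDATES).
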
def _collect_rollout_step(self, rollouts, current_episode_reward,
                          current_episode_step, episode_rewards, episode_spls,
                          episode_counts, episode_steps, episode_distances):
    pth_time = 0.0
    env_time = 0.0

    t_sample_action = time.time()
    # sample actions
    with torch.no_grad():
        step_observation = {
            k: v[rollouts.step] for k, v in rollouts.observations.items()
        }

        (values, actions, actions_log_probs, recurrent_hidden_states,
         distributions) = self.actor_critic.act(
            step_observation,
            rollouts.recurrent_hidden_states[rollouts.step],
            rollouts.prev_actions[rollouts.step],
            rollouts.masks[rollouts.step])

    pth_time += time.time() - t_sample_action

    t_step_env = time.time()

    outputs = self.envs.step([{"action": a[0].item()} for a in actions])
    observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
    logging.debug('Reward: {}'.format(rewards[0]))

    env_time += time.time() - t_step_env

    t_update_stats = time.time()
    batch = batch_obs(observations)
    rewards = torch.tensor(rewards, dtype=torch.float)
    rewards = rewards.unsqueeze(1)

    masks = torch.tensor([[0.0] if done else [1.0] for done in dones],
                         dtype=torch.float)
    spls = torch.tensor([[info['spl']] for info in infos])
    distances = torch.tensor([[info['distance_to_goal']] for info in infos])

    current_episode_reward += rewards
    current_episode_step += 1
    # current_episode_reward keeps accumulating rewards across updates while the
    # episode is still running; it is flushed into episode_rewards (and the episode
    # count is incremented) only for environments whose episode just finished.
    episode_rewards += (1 - masks) * current_episode_reward
    episode_spls += (1 - masks) * spls
    episode_steps += (1 - masks) * current_episode_step
    episode_counts += 1 - masks
    episode_distances += (1 - masks) * distances
    current_episode_reward *= masks
    current_episode_step *= masks

    rollouts.insert(batch, recurrent_hidden_states, actions, actions_log_probs,
                    values, rewards, masks)

    pth_time += time.time() - t_update_stats

    return pth_time, env_time, self.envs.num_envs
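
# Tiny worked example of the (1 - masks) bookkeeping above, with made-up values for two
# parallel environments. masks is 0.0 where an episode just finished and 1.0 otherwise:
#
#   masks                   = [[0.0], [1.0]]        # env 0 finished, env 1 still running
#   current_episode_reward  = [[5.0], [2.0]]
#   episode_rewards        += (1 - masks) * current_episode_reward   # += [[5.0], [0.0]]
#   episode_counts         += 1 - masks                              # += [[1.0], [0.0]]
#   current_episode_reward *= masks                                  # -> [[0.0], [2.0]]
#
# so the per-episode accumulator is flushed into the running totals exactly once per
# finished episode, and it resets only for the environment that finished.
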
def _eval_checkpoint(self, checkpoint_path: str, writer: TensorboardWriter,
                     checkpoint_index: int = 0) -> Dict:
    r"""Evaluates a single checkpoint.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        A dictionary with the mean episode reward and the mean of each task metric.
    """
    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)

    # Map location CPU is almost always better than mapping to a CUDA device.
    ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

    if self.config.EVAL.USE_CKPT_CONFIG:
        config = self._setup_eval_config(ckpt_dict["config"])
    else:
        config = self.config.clone()

    ppo_cfg = config.RL.PPO

    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    if self.config.DISPLAY_RESOLUTION != config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH:
        model_resolution = config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH
        config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH = \
            config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT = \
            config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH = \
            config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT = \
            self.config.DISPLAY_RESOLUTION
    else:
        model_resolution = self.config.DISPLAY_RESOLUTION
    config.freeze()

    if len(self.config.VIDEO_OPTION) > 0:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()
    elif "top_down_map" in self.config.VISUALIZATION_OPTION:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.freeze()

    logger.info(f"env config: {config}")
    self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
    if self.config.DISPLAY_RESOLUTION != model_resolution:
        observation_space = self.envs.observation_spaces[0]
        observation_space.spaces['depth'].shape = (model_resolution, model_resolution, 1)
        observation_space.spaces['rgb'].shape = (model_resolution, model_resolution, 3)
    else:
        observation_space = self.envs.observation_spaces[0]
    self._setup_actor_critic_agent(ppo_cfg, observation_space)

    self.agent.load_state_dict(ckpt_dict["state_dict"])
    self.actor_critic = self.agent.actor_critic

    if self.config.RL.PPO.use_belief_predictor and "belief_predictor" in ckpt_dict:
        self.belief_predictor.load_state_dict(ckpt_dict["belief_predictor"])

    self.metric_uuids = []
    # get name of performance metric, e.g. "spl"
    for metric_name in self.config.TASK_CONFIG.TASK.MEASUREMENTS:
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, \
            "invalid measurement type {}".format(metric_cfg.TYPE)
        self.metric_uuids.append(
            measure_type(sim=None, task=None, config=None)._get_uuid())

    observations = self.envs.reset()
    if config.DISPLAY_RESOLUTION != model_resolution:
        obs_copy = resize_observation(observations, model_resolution)
    else:
        obs_copy = observations
    batch = batch_obs(obs_copy, self.device,
                      skip_list=['view_point_goals', 'intermediate', 'oracle_action_sensor'])

    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)

    if self.actor_critic.net.num_recurrent_layers == -1:
        num_recurrent_layers = 1
    else:
        num_recurrent_layers = self.actor_critic.net.num_recurrent_layers
    test_recurrent_hidden_states = torch.zeros(
        num_recurrent_layers,
        self.config.NUM_PROCESSES,
        ppo_cfg.hidden_size,
        device=self.device,
    )
    if ppo_cfg.use_external_memory:
        test_em = ExternalMemory(
            self.config.NUM_PROCESSES,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
            self.actor_critic.net.memory_dim,
        )
        test_em.to(self.device)
    else:
        test_em = None
    prev_actions = torch.zeros(self.config.NUM_PROCESSES, 1,
                               device=self.device, dtype=torch.long)
    not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device)
    stats_episodes = dict()  # dict of dicts that stores stats per episode

    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)
        descriptor_pred_gt = [[] for _ in range(self.config.NUM_PROCESSES)]
        for i in range(len(descriptor_pred_gt)):
            category_prediction = np.argmax(batch['category_belief'].cpu().numpy()[i])
            location_prediction = batch['location_belief'].cpu().numpy()[i]
            category_gt = np.argmax(batch['category'].cpu().numpy()[i])
            location_gt = batch['pointgoal_with_gps_compass'].cpu().numpy()[i]
            geodesic_distance = -1
            pair = (category_prediction, location_prediction, category_gt,
                    location_gt, geodesic_distance)
            if 'view_point_goals' in observations[i]:
                pair += (observations[i]['view_point_goals'],)
            descriptor_pred_gt[i].append(pair)

    rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)]  # type: List[List[np.ndarray]]
    audios = [[] for _ in range(self.config.NUM_PROCESSES)]
    if len(self.config.VIDEO_OPTION) > 0:
        os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

    self.actor_critic.eval()
    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.eval()

    t = tqdm(total=self.config.TEST_EPISODE_COUNT)
    while (
        len(stats_episodes) < self.config.TEST_EPISODE_COUNT
        and self.envs.num_envs > 0
    ):
        current_episodes = self.envs.current_episodes()

        with torch.no_grad():
            _, actions, _, test_recurrent_hidden_states, test_em_features = self.actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                test_em.memory[:, 0] if ppo_cfg.use_external_memory else None,
                test_em.masks if ppo_cfg.use_external_memory else None,
                deterministic=False,
            )

            prev_actions.copy_(actions)

        actions = [a[0].item() for a in actions]
        outputs = self.envs.step(actions)

        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        if config.DISPLAY_RESOLUTION != model_resolution:
            obs_copy = resize_observation(observations, model_resolution)
        else:
            obs_copy = observations
        batch = batch_obs(obs_copy, self.device,
                          skip_list=['view_point_goals', 'intermediate', 'oracle_action_sensor'])
        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=self.device,
        )

        # Update external memory
        if ppo_cfg.use_external_memory:
            test_em.insert(test_em_features, not_done_masks)
        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.update(batch, dones)
            for i in range(len(descriptor_pred_gt)):
                category_prediction = np.argmax(batch['category_belief'].cpu().numpy()[i])
                location_prediction = batch['location_belief'].cpu().numpy()[i]
                category_gt = np.argmax(batch['category'].cpu().numpy()[i])
                location_gt = batch['pointgoal_with_gps_compass'].cpu().numpy()[i]
                if dones[i]:
                    geodesic_distance = -1
                else:
                    geodesic_distance = infos[i]['distance_to_goal']
                pair = (category_prediction, location_prediction, category_gt,
                        location_gt, geodesic_distance)
                if 'view_point_goals' in observations[i]:
                    pair += (observations[i]['view_point_goals'],)
                descriptor_pred_gt[i].append(pair)

        for i in range(self.envs.num_envs):
            if len(self.config.VIDEO_OPTION) > 0:
                if self.config.RL.PPO.use_belief_predictor:
                    pred = descriptor_pred_gt[i][-1]
                else:
                    pred = None
                if config.TASK_CONFIG.SIMULATOR.CONTINUOUS_VIEW_CHANGE and \
                        'intermediate' in observations[i]:
                    for observation in observations[i]['intermediate']:
                        frame = observations_to_image(observation, infos[i], pred=pred)
                        rgb_frames[i].append(frame)
                    del observations[i]['intermediate']
                if "rgb" not in observations[i]:
                    observations[i]["rgb"] = np.zeros(
                        (self.config.DISPLAY_RESOLUTION,
                         self.config.DISPLAY_RESOLUTION, 3))
                frame = observations_to_image(observations[i], infos[i], pred=pred)
                rgb_frames[i].append(frame)
                audios[i].append(observations[i]['audiogoal'])

        rewards = torch.tensor(rewards, dtype=torch.float,
                               device=self.device).unsqueeze(1)
        current_episode_reward += rewards
        next_episodes = self.envs.current_episodes()
        envs_to_pause = []
        for i in range(self.envs.num_envs):
            # pause envs which run out of episodes
            if (
                next_episodes[i].scene_id,
                next_episodes[i].episode_id,
            ) in stats_episodes:
                envs_to_pause.append(i)

            # episode ended
            if not_done_masks[i].item() == 0:
                episode_stats = dict()
                for metric_uuid in self.metric_uuids:
                    episode_stats[metric_uuid] = infos[i][metric_uuid]
                episode_stats["reward"] = current_episode_reward[i].item()
                episode_stats['geodesic_distance'] = \
                    current_episodes[i].info['geodesic_distance']
                episode_stats['euclidean_distance'] = norm(
                    np.array(current_episodes[i].goals[0].position)
                    - np.array(current_episodes[i].start_position))
                episode_stats['audio_duration'] = int(current_episodes[i].duration)
                episode_stats['gt_na'] = int(current_episodes[i].info['num_action'])
                if self.config.RL.PPO.use_belief_predictor:
                    episode_stats['gt_na'] = int(current_episodes[i].info['num_action'])
                    episode_stats['descriptor_pred_gt'] = descriptor_pred_gt[i][:-1]
                    descriptor_pred_gt[i] = [descriptor_pred_gt[i][-1]]
                logging.debug(episode_stats)
                current_episode_reward[i] = 0
                # use scene_id + episode_id as unique id for storing stats
                stats_episodes[(
                    current_episodes[i].scene_id,
                    current_episodes[i].episode_id,
                )] = episode_stats
                t.update()

                if len(self.config.VIDEO_OPTION) > 0:
                    fps = self.config.TASK_CONFIG.SIMULATOR.VIEW_CHANGE_FPS \
                        if self.config.TASK_CONFIG.SIMULATOR.CONTINUOUS_VIEW_CHANGE else 1
                    if 'sound' in current_episodes[i].info:
                        sound = current_episodes[i].info['sound']
                    else:
                        sound = current_episodes[i].sound_id.split('/')[1][:-4]
                    generate_video(
                        video_option=self.config.VIDEO_OPTION,
                        video_dir=self.config.VIDEO_DIR,
                        images=rgb_frames[i][:-1],
                        scene_name=current_episodes[i].scene_id.split('/')[3],
                        sound=sound,
                        sr=self.config.TASK_CONFIG.SIMULATOR.AUDIO.RIR_SAMPLING_RATE,
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metric_name='spl',
                        metric_value=infos[i]['spl'],
                        tb_writer=writer,
                        audios=audios[i][:-1],
                        fps=fps,
                    )

                    # observations has been reset but info has not;
                    # to be consistent, do not use the last frame
                    rgb_frames[i] = []
                    audios[i] = []

                if "top_down_map" in self.config.VISUALIZATION_OPTION:
                    if self.config.RL.PPO.use_belief_predictor:
                        pred = episode_stats['descriptor_pred_gt'][-1]
                    else:
                        pred = None
                    top_down_map = plot_top_down_map(
                        infos[i],
                        dataset=self.config.TASK_CONFIG.SIMULATOR.SCENE_DATASET,
                        pred=pred)
                    scene = current_episodes[i].scene_id.split('/')[3]
                    sound = current_episodes[i].sound_id.split('/')[1][:-4]
                    writer.add_image(
                        f"{config.EVAL.SPLIT}_{scene}_{current_episodes[i].episode_id}_{sound}/"
                        f"{infos[i]['spl']}",
                        top_down_map,
                        dataformats='WHC')

        if not self.config.RL.PPO.use_belief_predictor:
            descriptor_pred_gt = None
        (
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            test_em,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        ) = self._pause_envs(
            envs_to_pause,
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
            test_em,
            descriptor_pred_gt,
        )

    # dump stats for each episode
    stats_file = os.path.join(
        config.TENSORBOARD_DIR,
        '{}_stats_{}.json'.format(config.EVAL.SPLIT, config.SEED))
    with open(stats_file, 'w') as fo:
        json.dump(
            {','.join(key): value for key, value in stats_episodes.items()},
            fo, cls=NpEncoder)

    aggregated_stats = dict()
    for stat_key in next(iter(stats_episodes.values())).keys():
        if stat_key in ['audio_duration', 'gt_na', 'descriptor_pred_gt', 'view_point_goals']:
            continue
        aggregated_stats[stat_key] = sum(
            [v[stat_key] for v in stats_episodes.values()])
    num_episodes = len(stats_episodes)

    episode_reward_mean = aggregated_stats["reward"] / num_episodes
    episode_metrics_mean = {}
    for metric_uuid in self.metric_uuids:
        episode_metrics_mean[metric_uuid] = aggregated_stats[metric_uuid] / num_episodes

    logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
    for metric_uuid in self.metric_uuids:
        logger.info(
            f"Average episode {metric_uuid}: {episode_metrics_mean[metric_uuid]:.6f}")

    if not config.EVAL.SPLIT.startswith('test'):
        writer.add_scalar("{}/reward".format(config.EVAL.SPLIT),
                          episode_reward_mean, checkpoint_index)
        for metric_uuid in self.metric_uuids:
            writer.add_scalar(f"{config.EVAL.SPLIT}/{metric_uuid}",
                              episode_metrics_mean[metric_uuid], checkpoint_index)

    self.envs.close()

    result = {'episode_reward_mean': episode_reward_mean}
    for metric_uuid in self.metric_uuids:
        result['episode_{}_mean'.format(metric_uuid)] = episode_metrics_mean[metric_uuid]

    return result
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    logger.info(f"config: {self.config}")
    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)

    # add_signal_handlers()
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME),
                               workers_ignore_signals=True)

    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))

    if ppo_cfg.use_external_memory:
        memory_dim = self.actor_critic.net.memory_dim
    else:
        memory_dim = None

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        ppo_cfg.use_external_memory,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
        memory_dim,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations)
    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    interrupted_state = load_interrupted_state(model_dir=self.config.MODEL_DIR)
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(interrupted_state["optimizer_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_scheduler_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set():
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optimizer_state=self.agent.optimizer.state_dict(),
                            lr_scheduler_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ),
                        model_dir=self.config.MODEL_DIR,
                    )

                requeue_job()
                return

            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                ppo_cfg, rollouts)
            pth_time += delta_pth_time

            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())

            deltas = {
                k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar("Metrics/reward",
                              deltas["reward"] / deltas["count"], count_steps)

            # Check to see if there are any metrics
            # that haven't been logged yet
            metrics = {
                k: v / deltas["count"]
                for k, v in deltas.items()
                if k not in {"reward", "count"}
            }
            if len(metrics) > 0:
                # writer.add_scalars("metrics", metrics, count_steps)
                for metric, value in metrics.items():
                    writer.add_scalar(f"Metrics/{metric}", value, count_steps)

            writer.add_scalar("Policy/value_loss", value_loss, count_steps)
            writer.add_scalar("Policy/policy_loss", action_loss, count_steps)
            writer.add_scalar("Policy/entropy_loss", dist_entropy, count_steps)
            writer.add_scalar('Policy/learning_rate',
                              lr_scheduler.get_lr()[0], count_steps)

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time, count_steps))
                logger.info("Average window size: {} {}".format(
                    len(window_episode_stats["count"]),
                    " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                             for k, v in deltas.items() if k != "count"),
                ))

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1

        self.envs.close()
def _collect_rollout_step(self, rollouts, current_episode_reward, running_episode_stats):
    pth_time = 0.0
    env_time = 0.0

    t_sample_action = time.time()
    # sample actions
    with torch.no_grad():
        step_observation = {
            k: v[rollouts.step] for k, v in rollouts.observations.items()
        }

        external_memory = None
        external_memory_masks = None
        if self.config.RL.PPO.use_external_memory:
            external_memory = rollouts.external_memory[:, rollouts.step].contiguous()
            external_memory_masks = rollouts.external_memory_masks[rollouts.step]

        (values, actions, actions_log_probs, recurrent_hidden_states,
         external_memory_features) = self.actor_critic.act(
            step_observation,
            rollouts.recurrent_hidden_states[rollouts.step],
            rollouts.prev_actions[rollouts.step],
            rollouts.masks[rollouts.step],
            external_memory,
            external_memory_masks,
        )

    pth_time += time.time() - t_sample_action

    t_step_env = time.time()

    outputs = self.envs.step([a[0].item() for a in actions])
    observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
    logging.debug('Reward: {}'.format(rewards[0]))

    env_time += time.time() - t_step_env

    t_update_stats = time.time()
    batch = batch_obs(observations, device=self.device)
    rewards = torch.tensor(rewards, dtype=torch.float,
                           device=current_episode_reward.device)
    rewards = rewards.unsqueeze(1)

    masks = torch.tensor([[0.0] if done else [1.0] for done in dones],
                         dtype=torch.float,
                         device=current_episode_reward.device)

    current_episode_reward += rewards
    running_episode_stats["reward"] += (1 - masks) * current_episode_reward
    running_episode_stats["count"] += 1 - masks
    for k, v in self._extract_scalars_from_infos(infos).items():
        v = torch.tensor(v, dtype=torch.float,
                         device=current_episode_reward.device).unsqueeze(1)
        if k not in running_episode_stats:
            running_episode_stats[k] = torch.zeros_like(running_episode_stats["count"])
        running_episode_stats[k] += (1 - masks) * v

    current_episode_reward *= masks

    rollouts.insert(
        batch,
        recurrent_hidden_states,
        actions,
        actions_log_probs,
        values,
        rewards.to(device=self.device),
        masks.to(device=self.device),
        external_memory_features,
    )

    if self.config.RL.PPO.use_belief_predictor:
        step_observation = {
            k: v[rollouts.step] for k, v in rollouts.observations.items()
        }
        self.belief_predictor.update(step_observation, dones)
        for sensor in [LocationBelief.cls_uuid, CategoryBelief.cls_uuid]:
            rollouts.observations[sensor][rollouts.step].copy_(step_observation[sensor])

    pth_time += time.time() - t_update_stats

    return pth_time, env_time, self.envs.num_envs