def train_agent(agent):
    r"""Trains an agent of the given type with the example PPO PointNav config.

    Args:
        agent: agent type to train, e.g. "ppo_agent".

    Returns:
        None.
    """
    config_file = r"/home/userone/workspace/bed1/ppo_custom/ppo_pointnav_example.yaml"
    config = get_config(config_file, None)
    # config = get_config1(agent)
    env = construct_envs(config, get_env_class(config.ENV_NAME))
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = config.DATA_PATH_SET
    config.TASK_CONFIG.DATASET.SCENES_DIR = config.SCENE_DIR_SET
    config.freeze()
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    if agent == "ppo_agent":
        trainer = ppo_trainer.PPOTrainer(config)
        trainer.train(env)
def _load_real_data(self, config):
    real_mem = self.memory_real
    config = config.clone()
    config.defrost()
    config.TASK_CONFIG.DATASET.TYPE = "PepperRobot"
    config.ENV_NAME = "PepperPlaybackEnv"
    config.NUM_PROCESSES = 1
    config.freeze()
    real_env = construct_envs(config, get_env_class(config.ENV_NAME))
    observations = real_env.reset()
    real_mem.insert([observations[0]["rgb"], observations[0]["depth"]])
    print(f"Loading {self.real_mem_size} ROBOT observations ...")
    for i in range(self.real_mem_size):
        outputs = real_env.step([0])
        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        real_mem.insert([observations[0]["rgb"], observations[0]["depth"]])
    print(f"Loaded {self.real_mem_size} observations from ROBOT dataset")
def _init_envs(self, config=None):
    if config is None:
        config = self.config
    self.envs = construct_envs(
        config,
        get_env_class(config.ENV_NAME),
        workers_ignore_signals=is_slurm_batch_job(),
    )
def __init__(self, config=None):
    super().__init__(config)
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))
    self._setup_actor_critic_agent(self.config)
    self.rollout = RolloutStorage(
        self.config.RL.PPO.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        self.config.RL.PPO.hidden_size,
    )
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    self.rollout.to(self.device)
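# The trainers below repeatedly call batch_obs() to turn a list of per-env
# observation dicts into batched tensors. A minimal sketch of that pattern
# (an assumption, not the source's batch_obs implementation):
from collections import defaultdict
from typing import Dict, List

import numpy as np
import torch


def batch_obs_sketch(
    observations: List[Dict[str, np.ndarray]], device: torch.device = None
) -> Dict[str, torch.Tensor]:
    batch = defaultdict(list)
    for obs in observations:
        for sensor, value in obs.items():
            batch[sensor].append(torch.as_tensor(value))
    # One tensor per sensor, shape (num_envs, *sensor_shape)
    return {
        sensor: torch.stack(values, dim=0).to(device)
        for sensor, values in batch.items()
    }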
def run_random_agent():
    config_file = r"/home/userone/workspace/bed1/ppo_custom/ppo_pointnav_example.yaml"
    config = get_config(config_file, None)
    env = construct_envs(config, get_env_class(config.ENV_NAME))
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = config.DATA_PATH_SET
    config.TASK_CONFIG.DATASET.SCENES_DIR = config.SCENE_DIR_SET
    config.freeze()
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    random_agent.run(config, env, 3)
def trigger_transfer_learn(agent):
    config_file = r"/home/userone/workspace/bed1/ppo_custom/ppo_pointnav_example.yaml"
    config = get_config(config_file, None)
    env = construct_envs(config, get_env_class(config.ENV_NAME))
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = config.DATA_PATH_SET
    config.TASK_CONFIG.DATASET.SCENES_DIR = config.SCENE_DIR_SET
    config.freeze()
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    if agent == "ppo_agent":
        trainer = transfer_ppo.PPOTrainer(config)
        trainer.train(env)
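# Hedged usage sketch (not in the source): dispatching the three entry points
# above from a small CLI. The flag names and the main() wrapper are assumptions.
def main():
    import argparse

    parser = argparse.ArgumentParser(description="PointNav PPO runner")
    parser.add_argument("--mode", choices=["train", "transfer", "random"], default="train")
    parser.add_argument("--agent", default="ppo_agent")
    args = parser.parse_args()

    if args.mode == "train":
        train_agent(args.agent)
    elif args.mode == "transfer":
        trigger_transfer_learn(args.agent)
    else:
        run_random_agent()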
def __init__(
    self,
    config: Any,
    video_option: List[str],
    video_dir: str,
    vid_filename_metrics: Set[str],
    traj_save_dir: str = None,
    should_save_fn=None,
    writer=None,
) -> None:
    env_class = get_env_class(config.ENV_NAME)
    env = make_env_fn(env_class=env_class, config=config)
    self._gym_env = HabGymWrapper(env, save_orig_obs=True)
    self._video_option = video_option
    self._video_dir = video_dir
    self._writer = writer
    self._vid_filename_metrics = vid_filename_metrics
    self._traj_save_path = traj_save_dir
    self._should_save_fn = should_save_fn
def test_gym_wrapper_contract(config_file):
    config = baselines_get_config(config_file)
    env_class = get_env_class(config.ENV_NAME)
    env = habitat_baselines.utils.env_utils.make_env_fn(
        env_class=env_class, config=config
    )
    env = HabGymWrapper(env)
    env = HabRenderWrapper(env)
    assert isinstance(env.action_space, spaces.Box)
    obs = env.reset()
    assert isinstance(obs, np.ndarray), f"Obs {obs}"
    assert obs.shape == env.observation_space.shape
    obs, _, _, info = env.step(env.action_space.sample())
    assert isinstance(obs, np.ndarray), f"Obs {obs}"
    assert obs.shape == env.observation_space.shape
    frame = env.render()
    assert isinstance(frame, np.ndarray)
    assert len(frame.shape) == 3 and frame.shape[-1] == 3
    for _, v in info.items():
        assert not isinstance(v, dict)
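# Hedged note: the contract test above takes config_file as an argument, so it
# is presumably driven by a pytest parameterization along these lines. The YAML
# path below is illustrative, not confirmed by the source.
import pytest


@pytest.mark.parametrize(
    "config_file",
    ["habitat_baselines/config/rearrange/ddppo_rearrangepick.yaml"],
)
def test_gym_wrapper_contract_example(config_file):
    test_gym_wrapper_contract(config_file)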
def test_rearrange_task(test_cfg_path):
    config = baselines_get_config(test_cfg_path)
    env_class = get_env_class(config.ENV_NAME)
    env = habitat_baselines.utils.env_utils.make_env_fn(
        env_class=env_class, config=config
    )
    with env:
        for _ in range(10):
            env.reset()
            done = False
            while not done:
                action = env.action_space.sample()
                habitat.logger.info(
                    f"Action : "
                    f"{action['action']}, "
                    f"args: {action['action_args']}."
                )
                _, _, done, info = env.step(action=action)
                logger.info(info)
def test_rearrange_task():
    config = baselines_get_config(
        "habitat_baselines/config/rearrange/ddppo_rearrangepick.yaml"
    )
    # if not RearrangeDatasetV0.check_config_paths_exist(config.TASK_CONFIG.DATASET):
    #     pytest.skip("Test skipped as dataset files are missing.")
    env_class = get_env_class(config.ENV_NAME)
    env = habitat_baselines.utils.env_utils.make_env_fn(
        env_class=env_class, config=config
    )
    with env:
        for _ in range(10):
            env.reset()
            done = False
            while not done:
                action = env.action_space.sample()
                habitat.logger.info(
                    f"Action : "
                    f"{action['action']}, "
                    f"args: {action['action_args']}."
                )
                _, _, done, info = env.step(action=action)
                logger.info(info)
def _eval_checkpoint(
    self,
    checkpoint_path: str,
    writer: TensorboardWriter,
    checkpoint_index: int = 0,
) -> None:
    r"""Evaluates a single checkpoint.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        None
    """
    ckpt_dict = self.load_checkpoint(checkpoint_path, map_location=self.device)
    config = self._setup_eval_config(ckpt_dict["config"])
    ppo_cfg = config.RL.PPO
    if len(self.config.VIDEO_OPTION) > 0:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()
    logger.info(f"env config: {config}")
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))
    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.load_state_dict(ckpt_dict["state_dict"])
    self.actor_critic = self.agent.actor_critic
    # get name of performance metric, e.g. "spl"
    metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
    metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
    measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
    assert measure_type is not None, "invalid measurement type {}".format(
        metric_cfg.TYPE
    )
    self.metric_uuid = measure_type(None, None)._get_uuid()
    observations = self.envs.reset()
    batch = batch_obs(observations)
    for sensor in batch:
        batch[sensor] = batch[sensor].to(self.device)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)
    test_recurrent_hidden_states = torch.zeros(
        self.actor_critic.net.num_recurrent_layers,
        self.config.NUM_PROCESSES,
        ppo_cfg.hidden_size,
        device=self.device,
    )
    prev_actions = torch.zeros(
        self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
    )
    not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device)
    stats_episodes = dict()  # dict of dicts that stores stats per episode
    # one independent list per env; [[]] * n would alias a single shared list
    rgb_frames = [
        [] for _ in range(self.config.NUM_PROCESSES)
    ]  # type: List[List[np.ndarray]]
    if len(self.config.VIDEO_OPTION) > 0:
        os.makedirs(self.config.VIDEO_DIR, exist_ok=True)
    while (
        len(stats_episodes) < self.config.TEST_EPISODE_COUNT
        and self.envs.num_envs > 0
    ):
        current_episodes = self.envs.current_episodes()
        with torch.no_grad():
            _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                deterministic=False,
            )
            prev_actions.copy_(actions)
        outputs = self.envs.step([a[0].item() for a in actions])
        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(self.device)
        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=self.device,
        )
        rewards = torch.tensor(
            rewards, dtype=torch.float, device=self.device
        ).unsqueeze(1)
        current_episode_reward += rewards
        next_episodes = self.envs.current_episodes()
        envs_to_pause = []
        n_envs = self.envs.num_envs
        for i in range(n_envs):
            if (
                next_episodes[i].scene_id,
                next_episodes[i].episode_id,
            ) in stats_episodes:
                envs_to_pause.append(i)
            # episode ended
            if not_done_masks[i].item() == 0:
                episode_stats = dict()
                episode_stats[self.metric_uuid] = infos[i][self.metric_uuid]
                episode_stats["success"] = int(infos[i][self.metric_uuid] > 0)
                episode_stats["reward"] = current_episode_reward[i].item()
                current_episode_reward[i] = 0
                # use scene_id + episode_id as unique id for storing stats
                stats_episodes[
                    (
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )
                ] = episode_stats
                if len(self.config.VIDEO_OPTION) > 0:
                    generate_video(
                        video_option=self.config.VIDEO_OPTION,
                        video_dir=self.config.VIDEO_DIR,
                        images=rgb_frames[i],
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metric_name=self.metric_uuid,
                        metric_value=infos[i][self.metric_uuid],
                        tb_writer=writer,
                    )
                    rgb_frames[i] = []
            # episode continues
            elif len(self.config.VIDEO_OPTION) > 0:
                frame = observations_to_image(observations[i], infos[i])
                rgb_frames[i].append(frame)
        # pausing self.envs with no new episode
        if len(envs_to_pause) > 0:
            state_index = list(range(self.envs.num_envs))
            for idx in reversed(envs_to_pause):
                state_index.pop(idx)
                self.envs.pause_at(idx)
            # indexing along the batch dimensions; hidden states are
            # (num_layers, num_envs, hidden), so the env dim is dim 1
            test_recurrent_hidden_states = test_recurrent_hidden_states[
                :, state_index
            ]
            not_done_masks = not_done_masks[state_index]
            current_episode_reward = current_episode_reward[state_index]
            prev_actions = prev_actions[state_index]
            for k, v in batch.items():
                batch[k] = v[state_index]
            if len(self.config.VIDEO_OPTION) > 0:
                rgb_frames = [rgb_frames[i] for i in state_index]
    aggregated_stats = dict()
    for stat_key in next(iter(stats_episodes.values())).keys():
        aggregated_stats[stat_key] = sum(
            v[stat_key] for v in stats_episodes.values()
        )
    num_episodes = len(stats_episodes)
    episode_reward_mean = aggregated_stats["reward"] / num_episodes
    episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
    episode_success_mean = aggregated_stats["success"] / num_episodes
    logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
    logger.info(f"Average episode success: {episode_success_mean:.6f}")
    logger.info(f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}")
    writer.add_scalars(
        "eval_reward",
        {"average reward": episode_reward_mean},
        checkpoint_index,
    )
    writer.add_scalars(
        f"eval_{self.metric_uuid}",
        {f"average {self.metric_uuid}": episode_metric_mean},
        checkpoint_index,
    )
    writer.add_scalars(
        "eval_success",
        {"average success": episode_success_mean},
        checkpoint_index,
    )
    self.envs.close()
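# Minimal sketch of the env-pausing pattern used above: finished envs are
# dropped from the vectorized batch by keeping only the still-active indices.
# Values here are illustrative.
envs_to_pause = [1]
state_index = list(range(3))  # envs 0, 1, 2
for idx in reversed(envs_to_pause):
    state_index.pop(idx)  # -> [0, 2]
# every per-env tensor is then re-indexed with state_index along its env dim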
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))
    ppo_cfg = self.config.RL.PPO
    self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )
    observations = self.envs.reset()
    batch = batch_obs(observations)
    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
    )
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])
    rollouts.to(self.device)
    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        for update in range(self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()
            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts,
                    current_episode_reward,
                    episode_rewards,
                    episode_counts,
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps
            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time
            window_episode_reward.append(episode_rewards.clone())
            window_episode_counts.append(episode_counts.clone())
            losses = [value_loss, action_loss]
            stats = zip(
                ["count", "reward"],
                [window_episode_counts, window_episode_reward],
            )
            deltas = {
                k: (
                    (v[-1] - v[0]).sum().item()
                    if len(v) > 1
                    else v[0].sum().item()
                )
                for k, v in stats
            }
            deltas["count"] = max(deltas["count"], 1.0)
            writer.add_scalar(
                "reward", deltas["reward"] / deltas["count"], count_steps
            )
            writer.add_scalars(
                "losses",
                {k: l for l, k in zip(losses, ["value", "policy"])},
                count_steps,
            )
            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info(
                    "update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)
                    )
                )
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time, count_steps)
                )
                window_rewards = (
                    window_episode_reward[-1] - window_episode_reward[0]
                ).sum()
                window_counts = (
                    window_episode_counts[-1] - window_episode_counts[0]
                ).sum()
                if window_counts > 0:
                    logger.info(
                        "Average window size {} reward: {:.3f}".format(
                            len(window_episode_reward),
                            (window_rewards / window_counts).item(),
                        )
                    )
                else:
                    logger.info("No episodes finish in current window")
            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1
        self.envs.close()
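# The trainers above pass linear_decay to LambdaLR. A minimal sketch of that
# schedule (an assumption based on how it is called, not the source's helper):
# a multiplier that falls linearly from 1 to 0 over total_num_updates.
def linear_decay_sketch(update: int, total_num_updates: int) -> float:
    return 1.0 - (update / float(total_num_updates))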
def train(self):
    # Get environments for training
    self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME))
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    # logger.info(
    #     "agent number of parameters: {}".format(
    #         sum(param.numel() for param in self.agent.parameters())
    #     )
    # )
    # Change for the actual value
    cfg = self.config.RL.PPO
    rollouts = RolloutStorage(
        cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        cfg.hidden_size,
    )
    rollouts.to(self.device)
    observations = self.envs.reset()
    batch = batch_obs(observations)
    for sensor in rollouts.observations:
        print(batch[sensor].shape)
    # Copy the information to the wrapper
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])
    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None
    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    # current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    # window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    # window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
    # The remainder of this training loop was commented out in the source
    # (an unterminated ''' block), so it is omitted here.
def __init__(
    self,
    wrapper,
    config_file,
    base_config_file=None,
    horizon=None,
    gamma=0.99,
    width=None,
    height=None,
):
    """
    Constructor. For more details on how to pass YAML configuration files,
    please see <MUSHROOM_RL PATH>/examples/habitat/README.md

    Args:
        wrapper (str): wrapper for converting observations and actions
            (e.g., HabitatRearrangeWrapper);
        config_file (str): path to the YAML file specifying the RL task
            configuration (see <HABITAT_LAB PATH>/habitat_baselines/configs/);
        base_config_file (str, None): path to an optional YAML file, used as
            'BASE_TASK_CONFIG_PATH' in the first YAML
            (see <HABITAT_LAB PATH>/configs/);
        horizon (int, None): the horizon;
        gamma (float, 0.99): the discount factor;
        width (int, None): width of the pixel observation. If None, the value
            specified in the config file is used;
        height (int, None): height of the pixel observation. If None, the
            value specified in the config file is used.
    """
    # MDP creation
    self._not_pybullet = False
    self._first = True

    if base_config_file is None:
        base_config_file = config_file

    config = get_config(
        config_paths=config_file,
        opts=["BASE_TASK_CONFIG_PATH", base_config_file],
    )
    config.defrost()

    if horizon is None:
        # Get the default horizon
        horizon = config.TASK_CONFIG.ENVIRONMENT.MAX_EPISODE_STEPS
    # Hack to ignore gym time limit
    config.TASK_CONFIG.ENVIRONMENT.MAX_EPISODE_STEPS = horizon + 1

    # Overwrite all RGB width / height used for the TASK (not SIMULATOR)
    for k in config["TASK_CONFIG"]["SIMULATOR"]:
        if "rgb" in k.lower():
            if height is not None:
                config["TASK_CONFIG"]["SIMULATOR"][k]["HEIGHT"] = height
            if width is not None:
                config["TASK_CONFIG"]["SIMULATOR"][k]["WIDTH"] = width

    config.freeze()

    env_class = get_env_class(config.ENV_NAME)
    env = make_env_fn(env_class=env_class, config=config)
    env = globals()[wrapper](env)
    self.env = env

    self._img_size = env.observation_space.shape[0:2]

    # MDP properties
    action_space = self.env.action_space
    observation_space = Box(
        low=0.0, high=255.0, shape=(3, self._img_size[1], self._img_size[0])
    )
    mdp_info = MDPInfo(observation_space, action_space, gamma, horizon)

    if isinstance(action_space, Discrete):
        self._convert_action = lambda a: a[0]
    else:
        self._convert_action = lambda a: a

    self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), 1 / 10)

    Environment.__init__(self, mdp_info)
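# Hedged usage sketch for the MushroomRL Habitat environment above. The class
# name "Habitat" and the YAML paths are assumptions for illustration only.
# mdp = Habitat(
#     wrapper="HabitatRearrangeWrapper",
#     config_file="habitat_baselines/config/rearrange/rl_pick.yaml",
#     horizon=200,
#     gamma=0.99,
#     width=128,
#     height=128,
# )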
def _eval_checkpoint(
    self,
    checkpoint_path: str,
    writer: TensorboardWriter,
    checkpoint_index: int = 0,
) -> None:
    r"""Evaluates a single checkpoint. Assumes episode IDs are unique.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        None
    """
    logger.info(f"checkpoint_path: {checkpoint_path}")
    if self.config.EVAL.USE_CKPT_CONFIG:
        config = self._setup_eval_config(
            self.load_checkpoint(checkpoint_path, map_location="cpu")["config"]
        )
    else:
        config = self.config.clone()
    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    config.TASK_CONFIG.TASK.NDTW.SPLIT = config.EVAL.SPLIT
    config.TASK_CONFIG.TASK.SDTW.SPLIT = config.EVAL.SPLIT
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
    if len(config.VIDEO_OPTION) > 0:
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
    config.freeze()
    # setup agent
    self.envs = construct_envs_auto_reset_false(
        config, get_env_class(config.ENV_NAME)
    )
    self.device = (
        torch.device("cuda", config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    self._setup_actor_critic_agent(config.MODEL, True, checkpoint_path)
    observations = self.envs.reset()
    observations = transform_obs(
        observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID
    )
    batch = batch_obs(observations, self.device)
    eval_recurrent_hidden_states = torch.zeros(
        self.actor_critic.net.num_recurrent_layers,
        config.NUM_PROCESSES,
        config.MODEL.STATE_ENCODER.hidden_size,
        device=self.device,
    )
    prev_actions = torch.zeros(
        config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
    )
    not_done_masks = torch.zeros(config.NUM_PROCESSES, 1, device=self.device)
    stats_episodes = {}  # dict of dicts that stores stats per episode
    if len(config.VIDEO_OPTION) > 0:
        os.makedirs(config.VIDEO_DIR, exist_ok=True)
        rgb_frames = [[] for _ in range(config.NUM_PROCESSES)]
    self.actor_critic.eval()
    while (
        self.envs.num_envs > 0
        and len(stats_episodes) < config.EVAL.EPISODE_COUNT
    ):
        current_episodes = self.envs.current_episodes()
        with torch.no_grad():
            (
                _,
                actions,
                _,
                eval_recurrent_hidden_states,
            ) = self.actor_critic.act(
                batch,
                eval_recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                deterministic=True,
            )
            # actions = batch["vln_oracle_action_sensor"].long()
            prev_actions.copy_(actions)
        outputs = self.envs.step([a[0].item() for a in actions])
        observations, _, dones, infos = [list(x) for x in zip(*outputs)]
        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=self.device,
        )
        # reset envs and observations if necessary
        for i in range(self.envs.num_envs):
            if len(config.VIDEO_OPTION) > 0:
                frame = observations_to_image(observations[i], infos[i])
                frame = append_text_to_image(
                    frame, current_episodes[i].instruction.instruction_text
                )
                rgb_frames[i].append(frame)
            if not dones[i]:
                continue
            stats_episodes[current_episodes[i].episode_id] = infos[i]
            observations[i] = self.envs.reset_at(i)[0]
            prev_actions[i] = torch.zeros(1, dtype=torch.long)
            if len(config.VIDEO_OPTION) > 0:
                generate_video(
                    video_option=config.VIDEO_OPTION,
                    video_dir=config.VIDEO_DIR,
                    images=rgb_frames[i],
                    episode_id=current_episodes[i].episode_id,
                    checkpoint_idx=checkpoint_index,
                    metrics={
                        "SPL": round(
                            stats_episodes[current_episodes[i].episode_id]["spl"],
                            6,
                        )
                    },
                    tb_writer=writer,
                )
                del stats_episodes[current_episodes[i].episode_id]["top_down_map"]
                del stats_episodes[current_episodes[i].episode_id]["collisions"]
                rgb_frames[i] = []
        observations = transform_obs(
            observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID
        )
        batch = batch_obs(observations, self.device)
        envs_to_pause = []
        next_episodes = self.envs.current_episodes()
        for i in range(self.envs.num_envs):
            if next_episodes[i].episode_id in stats_episodes:
                envs_to_pause.append(i)
        (
            self.envs,
            eval_recurrent_hidden_states,
            not_done_masks,
            prev_actions,
            batch,
        ) = self._pause_envs(
            envs_to_pause,
            self.envs,
            eval_recurrent_hidden_states,
            not_done_masks,
            prev_actions,
            batch,
        )
    self.envs.close()
    aggregated_stats = {}
    num_episodes = len(stats_episodes)
    for stat_key in next(iter(stats_episodes.values())).keys():
        aggregated_stats[stat_key] = (
            sum(v[stat_key] for v in stats_episodes.values()) / num_episodes
        )
    split = config.TASK_CONFIG.DATASET.SPLIT
    with open(f"stats_ckpt_{checkpoint_index}_{split}.json", "w") as f:
        json.dump(aggregated_stats, f, indent=4)
    logger.info(f"Episodes evaluated: {num_episodes}")
    checkpoint_num = checkpoint_index + 1
    for k, v in aggregated_stats.items():
        logger.info(f"Average episode {k}: {v:.6f}")
        writer.add_scalar(f"eval_{split}_{k}", v, checkpoint_num)
def _update_dataset(self, data_it):
    if torch.cuda.is_available():
        with torch.cuda.device(self.device):
            torch.cuda.empty_cache()
    if self.envs is None:
        self.envs = construct_envs(
            self.config, get_env_class(self.config.ENV_NAME)
        )
    recurrent_hidden_states = torch.zeros(
        self.actor_critic.net.num_recurrent_layers,
        self.config.NUM_PROCESSES,
        self.config.MODEL.STATE_ENCODER.hidden_size,
        device=self.device,
    )
    prev_actions = torch.zeros(
        self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
    )
    not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device)
    observations = self.envs.reset()
    observations = transform_obs(
        observations, self.config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID
    )
    batch = batch_obs(observations, self.device)
    episodes = [[] for _ in range(self.envs.num_envs)]
    skips = [False for _ in range(self.envs.num_envs)]
    # Populate dones with False initially
    dones = [False for _ in range(self.envs.num_envs)]
    # https://arxiv.org/pdf/1011.0686.pdf
    # Theoretically, any beta function is fine so long as it converges to
    # zero as data_it -> inf. The paper suggests starting with beta = 1 and
    # exponential decay.
    if self.config.DAGGER.P == 0.0:
        # in Python 0.0 ** 0.0 == 1.0, but we want 0.0
        beta = 0.0
    else:
        beta = self.config.DAGGER.P ** data_it
    ensure_unique_episodes = beta == 1.0

    def hook_builder(tgt_tensor):
        def hook(m, i, o):
            tgt_tensor.set_(o.cpu())

        return hook

    rgb_features = None
    rgb_hook = None
    if self.config.MODEL.RGB_ENCODER.cnn_type == "TorchVisionResNet50":
        rgb_features = torch.zeros((1,), device="cpu")
        rgb_hook = self.actor_critic.net.rgb_encoder.layer_extract.register_forward_hook(
            hook_builder(rgb_features)
        )
    depth_features = None
    depth_hook = None
    if self.config.MODEL.DEPTH_ENCODER.cnn_type == "VlnResnetDepthEncoder":
        depth_features = torch.zeros((1,), device="cpu")
        depth_hook = self.actor_critic.net.depth_encoder.visual_encoder.register_forward_hook(
            hook_builder(depth_features)
        )
    collected_eps = 0
    if ensure_unique_episodes:
        ep_ids_collected = set(
            ep.episode_id for ep in self.envs.current_episodes()
        )
    with tqdm.tqdm(total=self.config.DAGGER.UPDATE_SIZE) as pbar, lmdb.open(
        self.lmdb_features_dir,
        map_size=int(self.config.DAGGER.LMDB_MAP_SIZE),
    ) as lmdb_env, torch.no_grad():
        start_id = lmdb_env.stat()["entries"]
        txn = lmdb_env.begin(write=True)
        while collected_eps < self.config.DAGGER.UPDATE_SIZE:
            if ensure_unique_episodes:
                envs_to_pause = []
                current_episodes = self.envs.current_episodes()
            for i in range(self.envs.num_envs):
                if dones[i] and not skips[i]:
                    ep = episodes[i]
                    traj_obs = batch_obs(
                        [step[0] for step in ep], device=torch.device("cpu")
                    )
                    del traj_obs["vln_oracle_action_sensor"]
                    for k, v in traj_obs.items():
                        traj_obs[k] = v.numpy()
                    transposed_ep = [
                        traj_obs,
                        np.array([step[1] for step in ep], dtype=np.int64),
                        np.array([step[2] for step in ep], dtype=np.int64),
                    ]
                    txn.put(
                        str(start_id + collected_eps).encode(),
                        msgpack_numpy.packb(transposed_ep, use_bin_type=True),
                    )
                    pbar.update()
                    collected_eps += 1
                    if (
                        collected_eps % self.config.DAGGER.LMDB_COMMIT_FREQUENCY
                    ) == 0:
                        txn.commit()
                        txn = lmdb_env.begin(write=True)
                    if ensure_unique_episodes:
                        if current_episodes[i].episode_id in ep_ids_collected:
                            envs_to_pause.append(i)
                        else:
                            ep_ids_collected.add(current_episodes[i].episode_id)
                if dones[i]:
                    episodes[i] = []
            if ensure_unique_episodes:
                (
                    self.envs,
                    recurrent_hidden_states,
                    not_done_masks,
                    prev_actions,
                    batch,
                ) = self._pause_envs(
                    envs_to_pause,
                    self.envs,
                    recurrent_hidden_states,
                    not_done_masks,
                    prev_actions,
                    batch,
                )
                if self.envs.num_envs == 0:
                    break
            (
                _,
                actions,
                _,
                recurrent_hidden_states,
            ) = self.actor_critic.act(
                batch,
                recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                deterministic=False,
            )
            # print("action: ", actions)
            # print("batch[vln_oracle_action_sensor]: ", batch["vln_oracle_action_sensor"])
            actions = torch.where(
                torch.rand_like(actions, dtype=torch.float) < beta,
                batch["vln_oracle_action_sensor"].long(),
                actions,
            )
            for i in range(self.envs.num_envs):
                if rgb_features is not None:
                    observations[i]["rgb_features"] = rgb_features[i]
                    del observations[i]["rgb"]
                if depth_features is not None:
                    observations[i]["depth_features"] = depth_features[i]
                    del observations[i]["depth"]
                episodes[i].append(
                    (
                        observations[i],
                        prev_actions[i].item(),
                        batch["vln_oracle_action_sensor"][i].item(),
                    )
                )
            skips = batch["vln_oracle_action_sensor"].long() == -1
            actions = torch.where(skips, torch.zeros_like(actions), actions)
            skips = skips.squeeze(-1).to(device="cpu", non_blocking=True)
            prev_actions.copy_(actions)
            outputs = self.envs.step([a[0].item() for a in actions])
            observations, rewards, dones, _ = [list(x) for x in zip(*outputs)]
            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )
            observations = transform_obs(
                observations,
                self.config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID,
            )
            batch = batch_obs(observations, self.device)
        txn.commit()
    self.envs.close()
    self.envs = None
    if rgb_hook is not None:
        rgb_hook.remove()
    if depth_hook is not None:
        depth_hook.remove()
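# Minimal sketch of the DAgger mixing schedule used above (beta = P ** data_it,
# with the 0.0 ** 0 == 1.0 corner case forced to 0): with probability beta the
# oracle action is taken, otherwise the policy's own action.
def dagger_beta(p: float, data_it: int) -> float:
    if p == 0.0:
        return 0.0  # avoid Python's 0.0 ** 0 == 1.0
    return p ** data_it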
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs( self.config, get_env_class(self.config.ENV_NAME) ) observation_space = self.envs.observation_spaces[0] ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK aux_cfg = self.config.RL.AUX_TASKS self.device = ( torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu") ) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_dqn_agent(ppo_cfg, task_cfg, aux_cfg, []) self.dataset = RolloutDataset() self.dataloader = DataLoader(self.dataset, batch_size=16, num_workers=0) # Use environment to initialize the metadata for training the model self.envs.close() if self.config.RESUME_CURIOUS: weights = torch.load(self.config.RESUME_CURIOUS)['state_dict'] state_dict = self.q_network.state_dict() weights_new = {} for k, v in weights.items(): if "model_encoder" in k: k = k.replace("model_encoder", "visual_resnet").replace("actor_critic.", "") if k in state_dict: weights_new[k] = v state_dict.update(weights_new) self.q_network.load_state_dict(state_dict) logger.info( "agent number of parameters: {}".format( sum(param.numel() for param in self.q_network.parameters()) ) ) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." ) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.q_network.load_state_dict(ckpt_dict["state_dict"]) self.q_network_target.load_state_dict(ckpt_dict["target_state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) im_size = 256 with TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) as writer: update = 0 for i in range(self.config.NUM_EPOCHS): for im, pointgoal, action, mask, reward in self.dataloader: if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() im, pointgoal, action, mask, reward = collate(im), collate(pointgoal), collate(action), collate(mask), collate(reward) im = im.to(self.device).float() pointgoal = pointgoal.to(self.device).float() mask = mask.to(self.device).float() reward = reward.to(self.device).float() action = action.to(self.device).long() nstep = im.size(1) hidden_states = None hidden_states_target = None # q_vals = [] # q_vals_target = [] step = random.randint(0, nstep-1) output = self.q_network({'rgb': im[:, step]}, None, None) mse_loss = torch.pow(output - im[:, step] / 255., 2).mean() mse_loss.backward() # for step in range(nstep): # q_val, hidden_states = self.q_network({'rgb': im[:, step], 'pointgoal_with_gps_compass': pointgoal[:, step]}, hidden_states, mask[:, step]) # q_val_target, 
hidden_states_target = self.q_network_target({'rgb': im[:, step], 'pointgoal_with_gps_compass': pointgoal[:, step]}, hidden_states_target, mask[:, step]) # q_vals.append(q_val) # q_vals_target.append(q_val_target) # q_vals = torch.stack(q_vals, dim=1) # q_vals_target = torch.stack(q_vals_target, dim=1) # a_select = torch.argmax(q_vals, dim=-1, keepdim=True) # target_select = torch.gather(q_vals_target, -1, a_select) # target = reward + ppo_cfg.gamma * target_select[:, 1:] * mask[:, 1:] # target = target.detach() # pred_q = torch.gather(q_vals[:, :-1], -1, action) # mse_loss = torch.pow(pred_q - target, 2).mean() # mse_loss.backward() # grad_norm = torch.nn.utils.clip_grad_norm(self.q_network.parameters(), 80) self.optimizer.step() self.optimizer.zero_grad() writer.add_scalar( "loss", mse_loss, update, ) # writer.add_scalar( # "q_val", # q_vals.max(), # update, # ) if update % 10 == 0: print("Update: {}, loss: {}".format(update, mse_loss)) if update % 100 == 0: self.sync_model() # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps) ) count_checkpoints += 1 update = update + 1
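# Hedged sketch of what sync_model() above presumably does: copy (or Polyak-
# average) the online Q-network weights into the target network. tau=1.0
# reproduces a hard sync; the helper name and signature are assumptions.
import torch


def sync_target_network(
    online: torch.nn.Module, target: torch.nn.Module, tau: float = 1.0
) -> None:
    with torch.no_grad():
        for p_online, p_target in zip(online.parameters(), target.parameters()):
            p_target.mul_(1.0 - tau).add_(tau * p_online)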
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    self.envs = construct_envs(
        self.config,
        get_env_class(self.config.ENV_NAME),
        devices=self._assign_devices(),
    )
    # set up configurations
    ppo_cfg = self.config.RL.PPO
    ans_cfg = self.config.RL.ANS
    # set up device
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    # set up checkpoint folder
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    # create rollouts for mapper
    self.mapper_rollouts = self._create_mapper_rollouts(ans_cfg)
    self.depth_projection_net = DepthProjectionNet(
        ans_cfg.SEMANTIC_ANTICIPATOR.EGO_PROJECTION
    )
    self._setup_anticipator(ppo_cfg, ans_cfg)
    logger.info(
        "mapper_agent number of parameters: {}".format(
            sum(param.numel() for param in self.mapper_agent.parameters())
        )
    )
    mapper_rollouts = self.mapper_rollouts
    # ===================== Create statistics buffers =====================
    statistics_dict = {}
    # Mapper statistics
    statistics_dict["mapper"] = defaultdict(
        lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
    )
    # Overall count statistics
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    # ==================== Measuring memory consumption ===================
    total_memory_size = 0
    print("=================== Mapper rollouts ======================")
    for k, v in mapper_rollouts.observations.items():
        mem = v.element_size() * v.nelement() * 1e-9
        print(f"key: {k:<40s}, memory: {mem:>10.4f} GB")
        total_memory_size += mem
    print(f"Total memory: {total_memory_size:>10.4f} GB")
    # Resume checkpoint if available
    (
        num_updates_start,
        count_steps_start,
        count_checkpoints,
    ) = self.resume_checkpoint()
    count_steps = count_steps_start
    imH, imW = ans_cfg.image_scale_hw
    M = ans_cfg.overall_map_size
    # ==================== Create state variables =================
    state_estimates = {
        # Agent's pose estimate
        "pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device),
        # Agent's map
        "map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device),
        "recurrent_hidden_states": torch.zeros(
            1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size
        ).to(self.device),
        "visited_states": torch.zeros(self.envs.num_envs, 1, M, M).to(self.device),
    }
    ground_truth_states = {
        # To measure area seen
        "visible_occupancy": torch.zeros(
            self.envs.num_envs, 2, M, M, device=self.device
        ),
        "pose": torch.zeros(self.envs.num_envs, 3, device=self.device),
        "prev_global_reward_metric": torch.zeros(
            self.envs.num_envs, 1, device=self.device
        ),
    }
    masks = torch.zeros(self.envs.num_envs, 1)
    episode_step_count = torch.zeros(self.envs.num_envs, 1, device=self.device)
    # ==================== Reset the environments =================
    observations = self.envs.reset()
    batch = self._prepare_batch(observations)
    prev_batch = batch
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        local_reward=torch.zeros(self.envs.num_envs, 1),
        global_reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )
    # Useful variables
    NUM_MAPPER_STEPS = ans_cfg.MAPPER.num_mapper_steps
    NUM_LOCAL_STEPS = ppo_cfg.num_local_steps
    NUM_GLOBAL_STEPS = ppo_cfg.num_global_steps
    GLOBAL_UPDATE_INTERVAL = NUM_GLOBAL_STEPS * ans_cfg.goal_interval
    NUM_GLOBAL_UPDATES_PER_EPISODE = self.config.T_EXP // GLOBAL_UPDATE_INTERVAL
    NUM_GLOBAL_UPDATES = (
        self.config.NUM_EPISODES
        * NUM_GLOBAL_UPDATES_PER_EPISODE
        // self.config.NUM_PROCESSES
    )
    # Sanity checks
    assert (
        NUM_MAPPER_STEPS % NUM_LOCAL_STEPS == 0
    ), "Mapper steps must be a multiple of global steps interval"
    assert (
        NUM_LOCAL_STEPS == ans_cfg.goal_interval
    ), "Local steps must be same as subgoal sampling interval"
    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        for update in range(num_updates_start, NUM_GLOBAL_UPDATES):
            for step in range(NUM_GLOBAL_STEPS):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                    prev_batch,
                    batch,
                    state_estimates,
                    ground_truth_states,
                ) = self._collect_rollout_step(
                    batch,
                    prev_batch,
                    episode_step_count,
                    state_estimates,
                    ground_truth_states,
                    masks,
                    mapper_rollouts,
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps
                # Useful flags
                UPDATE_MAPPER_FLAG = (
                    episode_step_count[0].item() % NUM_MAPPER_STEPS == 0
                )
                # ------------------------ update mapper --------------------------
                if UPDATE_MAPPER_FLAG:
                    (
                        delta_pth_time,
                        update_metrics_mapper,
                    ) = self._update_mapper_agent(mapper_rollouts)
                    for k, v in update_metrics_mapper.items():
                        statistics_dict["mapper"][k].append(v)
                    pth_time += delta_pth_time
            # -------------------------- log statistics -----------------------
            for k, v in statistics_dict.items():
                logger.info(
                    "=========== {:20s} ============".format(k + " stats")
                )
                for kp, vp in v.items():
                    if len(vp) > 0:
                        writer.add_scalar(f"{k}/{kp}", np.mean(vp), count_steps)
                        logger.info(f"{kp:25s}: {np.mean(vp).item():10.5f}")
            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())
            deltas = {
                k: (
                    (v[-1] - v[0]).sum().item()
                    if len(v) > 1
                    else v[0].sum().item()
                )
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)
            fps = (count_steps - count_steps_start) / (time.time() - t_start)
            writer.add_scalar("fps", fps, count_steps)
            if update > 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(update, fps))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time, count_steps)
                )
                logger.info(
                    "Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join(
                            "{}: {:.3f}".format(k, v / deltas["count"])
                            for k, v in deltas.items()
                            if k != "count"
                        ),
                    )
                )
            pth_time += delta_pth_time
            # At episode termination, manually set masks to zeros.
            if episode_step_count[0].item() == self.config.T_EXP:
                masks.fill_(0)
            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth",
                    dict(step=count_steps, update=update),
                )
                count_checkpoints += 1
            if episode_step_count[0].item() == self.config.T_EXP:
                seed = int(time.time())
                print("change random seed to {}".format(seed))
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                observations = self.envs.reset()
                batch = self._prepare_batch(observations)
                prev_batch = batch
                # Reset episode step counter
                episode_step_count.fill_(0)
                # Reset states
                for k in ground_truth_states.keys():
                    ground_truth_states[k].fill_(0)
                for k in state_estimates.keys():
                    state_estimates[k].fill_(0)
    self.envs.close()
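# Worked example (illustrative numbers, not from the source) of the
# step-interval arithmetic that the sanity checks above enforce: with
# goal_interval = 25 and num_global_steps = 20, one T_EXP = 1000 episode
# spans exactly two global updates.
goal_interval = 25  # == NUM_LOCAL_STEPS
num_global_steps = 20
T_EXP = 1000
global_update_interval = num_global_steps * goal_interval  # 500 steps
assert T_EXP % global_update_interval == 0
num_global_updates_per_episode = T_EXP // global_update_interval  # == 2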
def _eval_checkpoint(
    self,
    checkpoint_path: str,
    writer: TensorboardWriter,
    checkpoint_index: int = 0,
) -> None:
    r"""Evaluates a single checkpoint.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        None
    """
    from habitat_baselines.common.environments import get_env_class

    # Map location CPU is almost always better than mapping to a CUDA device.
    ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")
    if self.config.EVAL.USE_CKPT_CONFIG:
        config = self._setup_eval_config(ckpt_dict["config"])
    else:
        config = self.config.clone()
    ppo_cfg = config.RL.PPO
    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
    config.freeze()
    if len(self.config.VIDEO_OPTION) > 0:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()
    # logger.info(f"env config: {config}")
    self.envs = construct_envs_habitat(config, get_env_class(config.ENV_NAME))
    self.observation_space = SpaceDict(
        {
            "pointgoal_with_gps_compass": spaces.Box(
                low=0.0, high=1.0, shape=(2,), dtype=np.float32
            )
        }
    )
    if self.config.COLOR:
        self.observation_space = SpaceDict(
            {
                "rgb": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            }
        )
    if self.config.DEPTH:
        self.observation_space = SpaceDict(
            {
                "depth": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            }
        )
    self.action_space = self.envs.action_spaces[0]
    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.load_state_dict(ckpt_dict["state_dict"])
    self.actor_critic = self.agent.actor_critic
    self.actor_critic.script_net()
    self.actor_critic.eval()
    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device)
    test_recurrent_hidden_states = torch.zeros(
        self.config.NUM_PROCESSES,
        self.actor_critic.num_recurrent_layers,
        ppo_cfg.hidden_size,
        device=self.device,
    )
    prev_actions = torch.zeros(
        self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
    )
    not_done_masks = torch.zeros(
        self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.bool
    )
    stats_episodes = dict()  # dict of dicts that stores stats per episode
    rgb_frames = [
        [] for _ in range(self.config.NUM_PROCESSES)
    ]  # type: List[List[np.ndarray]]
    if len(self.config.VIDEO_OPTION) > 0:
        os.makedirs(self.config.VIDEO_DIR, exist_ok=True)
    number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
    if number_of_eval_episodes == -1:
        number_of_eval_episodes = sum(self.envs.number_of_episodes)
    else:
        total_num_eps = sum(self.envs.number_of_episodes)
        if total_num_eps < number_of_eval_episodes:
            logger.warn(
                f"Config specified {number_of_eval_episodes} eval episodes"
                f", dataset only has {total_num_eps}."
            )
            logger.warn(f"Evaluating with {total_num_eps} instead.")
            number_of_eval_episodes = total_num_eps
    evals_per_ep = 5
    count_per_ep = defaultdict(lambda: 0)
    pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep)
    self.actor_critic.eval()
    while (
        len(stats_episodes) < (number_of_eval_episodes * evals_per_ep)
        and self.envs.num_envs > 0
    ):
        current_episodes = self.envs.current_episodes()
        with torch.no_grad():
            (
                _,
                dist_result,
                test_recurrent_hidden_states,
            ) = self.actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                prev_actions,
                not_done_masks,
                deterministic=False,
            )
            actions = dist_result["actions"]
            prev_actions.copy_(actions)
        actions = actions.to("cpu")
        outputs = self.envs.step([a[0].item() for a in actions])
        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        batch = batch_obs(observations, device=self.device)
        not_done_masks = torch.tensor(
            [[False] if done else [True] for done in dones],
            dtype=torch.bool,
            device=self.device,
        )
        rewards = torch.tensor(
            rewards, dtype=torch.float, device=self.device
        ).unsqueeze(1)
        current_episode_reward += rewards
        next_episodes = self.envs.current_episodes()
        envs_to_pause = []
        n_envs = self.envs.num_envs
        for i in range(n_envs):
            next_count_key = (
                next_episodes[i].scene_id,
                next_episodes[i].episode_id,
            )
            if count_per_ep[next_count_key] == evals_per_ep:
                envs_to_pause.append(i)
            # episode ended
            if not_done_masks[i].item() == 0:
                pbar.update()
                episode_stats = dict()
                episode_stats["reward"] = current_episode_reward[i].item()
                episode_stats.update(self._extract_scalars_from_info(infos[i]))
                current_episode_reward[i] = 0
                # use scene_id + episode_id as unique id for storing stats
                count_key = (
                    current_episodes[i].scene_id,
                    current_episodes[i].episode_id,
                )
                count_per_ep[count_key] = count_per_ep[count_key] + 1
                ep_key = (count_key, count_per_ep[count_key])
                stats_episodes[ep_key] = episode_stats
                if len(self.config.VIDEO_OPTION) > 0:
                    generate_video(
                        video_option=self.config.VIDEO_OPTION,
                        video_dir=self.config.VIDEO_DIR,
                        images=rgb_frames[i],
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metrics=self._extract_scalars_from_info(infos[i]),
                        tb_writer=writer,
                    )
                    rgb_frames[i] = []
            # episode continues
            elif len(self.config.VIDEO_OPTION) > 0:
                frame = observations_to_image(observations[i], infos[i])
                rgb_frames[i].append(frame)
        (
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        ) = self._pause_envs(
            envs_to_pause,
            self.envs,
            test_recurrent_hidden_states,
            not_done_masks,
            current_episode_reward,
            prev_actions,
            batch,
            rgb_frames,
        )
    self.envs.close()
    pbar.close()
    num_episodes = len(stats_episodes)
    aggregated_stats = dict()
    for stat_key in next(iter(stats_episodes.values())).keys():
        values = [
            v[stat_key] for v in stats_episodes.values() if v[stat_key] is not None
        ]
        if len(values) > 0:
            aggregated_stats[stat_key] = sum(values) / len(values)
        else:
            aggregated_stats[stat_key] = 0
    for k, v in aggregated_stats.items():
        logger.info(f"Average episode {k}: {v:.4f}")
    step_id = checkpoint_index
    if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
        step_id = ckpt_dict["extra_state"]["step"]
    writer.add_scalars(
        "eval_reward",
        {"average reward": aggregated_stats["reward"]},
        step_id,
    )
    metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
    if len(metrics) > 0:
        writer.add_scalars("eval_metrics", metrics, step_id)
    self.num_frames = step_id
    time.sleep(30)
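# Minimal sketch of the multi-trial bookkeeping above: each
# (scene_id, episode_id) pair is evaluated evals_per_ep times, and every
# trial gets its own stats key. Names mirror the function's variables.
from collections import defaultdict

evals_per_ep = 5
count_per_ep = defaultdict(int)
stats_episodes = {}


def record_trial(scene_id: str, episode_id: str, stats: dict) -> None:
    count_key = (scene_id, episode_id)
    count_per_ep[count_key] += 1
    stats_episodes[(count_key, count_per_ep[count_key])] = stats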
def train(self) -> None: r"""Main method for DD-PPO SLAM. Returns: None """ ##################################################################### ## init distrib and configuration ##################################################################### self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend ) # self.local_rank = 1 add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore( "rollout_tracker", tcp_store ) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() # server number self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank # gpu number in one server self.config.SIMULATOR_GPU_ID = self.local_rank print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID) print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID) # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += ( self.world_rank * self.config.NUM_PROCESSES ) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") ##################################################################### ## build distrib NavSLAMRLEnv environment ##################################################################### print("#############################################################") print("## build distrib NavSLAMRLEnv environment") print("#############################################################") self.envs = construct_envs( self.config, get_env_class(self.config.ENV_NAME) ) observations = self.envs.reset() print("*************************** observations len:", len(observations)) # semantic process for i in range(len(observations)): observations[i]["semantic"] = observations[i]["semantic"].astype(np.int32) se = list(set(observations[i]["semantic"].ravel())) print(se) # print("*************************** observations type:", observations) # print("*************************** observations type:", observations[0]["map_sum"].shape) # 480*480*23 # print("*************************** observations curr_pose:", observations[0]["curr_pose"]) # [] batch = batch_obs(observations, device=self.device) print("*************************** batch len:", len(batch)) # print("*************************** batch:", batch) # print("************************************* current_episodes:", (self.envs.current_episodes())) ##################################################################### ## init actor_critic agent ##################################################################### print("#############################################################") print("## init actor_critic agent") print("#############################################################") self.map_w = observations[0]["map_sum"].shape[0] self.map_h = observations[0]["map_sum"].shape[1] # print("map_: ", observations[0]["curr_pose"].shape) ppo_cfg = self.config.RL.PPO if ( not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0 ): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(observations, ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info( "agent number of trainable parameters: 
{}".format( sum( param.numel() for param in self.agent.parameters() if param.requires_grad ) ) ) ##################################################################### ## init Global Rollout Storage ##################################################################### print("#############################################################") print("## init Global Rollout Storage") print("#############################################################") self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step rollouts = GlobalRolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.obs_space, self.g_action_space, ) rollouts.to(self.device) print('rollouts type:', type(rollouts)) print('--------------------------') # for k in rollouts.keys(): # print("rollouts: {0}".format(rollouts.observations)) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) with torch.no_grad(): step_observation = { k: v[rollouts.step] for k, v in rollouts.observations.items() } _, actions, _, = self.actor_critic.act( step_observation, rollouts.prev_g_actions[0], rollouts.masks[0], ) self.global_goals = [[int(action[0].item() * self.map_w), int(action[1].item() * self.map_h)] for action in actions] # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros( self.envs.num_envs, 1, device=self.device ) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size) ) print("*************************** current_episode_reward:", current_episode_reward) print("*************************** running_episode_stats:", running_episode_stats) # print("*************************** window_episode_stats:", window_episode_stats) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) # interrupted_state = load_interrupted_state("/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth") interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"] ) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] deif = {} with ( TensorboardWriter( self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs ) if self.world_rank == 0 else contextlib.suppress() ) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES ) # print("************************************* current_episodes:", type(self.envs.count_episodes())) # print(EXIT.is_set()) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and 
self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, ), "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth" ) print("********************EXIT*********************") requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_global_rollout_step( rollouts, current_episode_reward, running_episode_stats ) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # print("************************************* current_episodes:") for i in range(len(self.envs.current_episodes())): # print(" ", self.envs.current_episodes()[i].episode_id," ", self.envs.current_episodes()[i].scene_id," ", self.envs.current_episodes()[i].object_category) if self.envs.current_episodes()[i].scene_id not in deif: deif[self.envs.current_episodes()[i].scene_id]=[int(self.envs.current_episodes()[i].episode_id)] else: deif[self.envs.current_episodes()[i].scene_id].append(int(self.envs.current_episodes()[i].episode_id)) # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! if ( step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size ): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0 ) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ( (v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item() ) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), ) ) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format( update, env_time, pth_time, count_steps ) ) logger.info( "Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in 
deltas.items() if k != "count"
                            ),
                        )
                    )

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    logger.info("Saved checkpoint {}".format(count_checkpoints))
                    count_checkpoints += 1

        self.envs.close()
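# A minimal sketch of the global-goal conversion performed above: the global
# policy emits normalized (x, y) actions in [0, 1], which are scaled by the
# map dimensions (self.map_w / self.map_h in the trainer) to obtain integer
# map-cell goals. The function name and example sizes are illustrative.
import torch

def actions_to_global_goals(actions: torch.Tensor, map_w: int, map_h: int):
    # One [x_cell, y_cell] goal per environment in the batch.
    return [
        [int(a[0].item() * map_w), int(a[1].item() * map_h)]
        for a in actions
    ]

# e.g. two envs, actions sampled uniformly from [0, 1]^2 on a 480x480 map
goals = actions_to_global_goals(torch.rand(2, 2), map_w=480, map_h=480)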
def _eval_checkpoint( self, checkpoint_path: str, writer: TensorboardWriter, checkpoint_index: int = 0, ) -> None: r"""Evaluates a single checkpoint. Args: checkpoint_path: path of checkpoint writer: tensorboard writer object for logging to tensorboard checkpoint_index: index of cur checkpoint for logging Returns: None """ # Map location CPU is almost always better than mapping to a CUDA device. ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu") if self.config.EVAL.USE_CKPT_CONFIG: config = self._setup_eval_config(ckpt_dict["config"]) else: config = self.config.clone() ppo_cfg = config.RL.PPO ans_cfg = config.RL.ANS config.defrost() config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT config.freeze() self.envs = construct_envs(config, get_env_class(config.ENV_NAME)) self._setup_actor_critic_agent(ppo_cfg, ans_cfg) # Convert the state_dict of mapper_agent to mapper mapper_dict = { k.replace("mapper.", ""): v for k, v in ckpt_dict["mapper_state_dict"].items() } # Converting the state_dict of local_agent to just the local_policy. local_dict = { k.replace("actor_critic.", ""): v for k, v in ckpt_dict["local_state_dict"].items() } # Strict = False is set to ignore to handle the case where # pose_estimator is not required. self.mapper.load_state_dict(mapper_dict, strict=False) self.local_actor_critic.load_state_dict(local_dict) # Set models to evaluation self.mapper.eval() self.local_actor_critic.eval() number_of_eval_episodes = self.config.TEST_EPISODE_COUNT if number_of_eval_episodes == -1: number_of_eval_episodes = sum(self.envs.number_of_episodes) else: total_num_eps = sum(self.envs.number_of_episodes) if total_num_eps < number_of_eval_episodes: logger.warn( f"Config specified {number_of_eval_episodes} eval episodes" ", dataset only has {total_num_eps}." 
) logger.warn(f"Evaluating with {total_num_eps} instead.") number_of_eval_episodes = total_num_eps M = ans_cfg.overall_map_size V = ans_cfg.MAPPER.map_size s = ans_cfg.MAPPER.map_scale imH, imW = ans_cfg.image_scale_hw assert ( self.envs.num_envs == 1 ), "Number of environments needs to be 1 for evaluation" # Define metric accumulators # Navigation metrics navigation_metrics = { "success_rate": Metric(), "spl": Metric(), "distance_to_goal": Metric(), "time": Metric(), "softspl": Metric(), } per_difficulty_navigation_metrics = { "easy": { "success_rate": Metric(), "spl": Metric(), "distance_to_goal": Metric(), "time": Metric(), "softspl": Metric(), }, "medium": { "success_rate": Metric(), "spl": Metric(), "distance_to_goal": Metric(), "time": Metric(), "softspl": Metric(), }, "hard": { "success_rate": Metric(), "spl": Metric(), "distance_to_goal": Metric(), "time": Metric(), "softspl": Metric(), }, } times_per_episode = deque() times_per_step = deque() # Define a simple function to return episode difficulty based on # the geodesic distance def classify_difficulty(gd): if gd < 5.0: return "easy" elif gd < 10.0: return "medium" else: return "hard" eval_start_time = time.time() # Reset environments only for the very first batch observations = self.envs.reset() for ep in range(number_of_eval_episodes): # ============================== Reset agent ============================== # Reset agent states state_estimates = { "pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device), "map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device), "recurrent_hidden_states": torch.zeros( 1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size ).to(self.device), } # Reset ANS states self.ans_net.reset() self.not_done_masks = torch.zeros(self.envs.num_envs, 1, device=self.device) self.prev_actions = torch.zeros(self.envs.num_envs, 1, device=self.device) self.prev_batch = None self.ep_time = torch.zeros(self.envs.num_envs, 1, device=self.device) # =========================== Episode loop ================================ ep_start_time = time.time() current_episodes = self.envs.current_episodes() for ep_step in range(self.config.T_MAX): step_start_time = time.time() # ============================ Action step ============================ batch = self._prepare_batch(observations) if self.prev_batch is None: self.prev_batch = copy.deepcopy(batch) prev_pose_estimates = state_estimates["pose_estimates"] with torch.no_grad(): ( _, _, mapper_outputs, local_policy_outputs, state_estimates, ) = self.ans_net.act( batch, self.prev_batch, state_estimates, self.ep_time, self.not_done_masks, deterministic=ans_cfg.LOCAL_POLICY.deterministic_flag, ) actions = local_policy_outputs["actions"] # Make masks not done till reset (end of episode) self.not_done_masks = torch.ones( self.envs.num_envs, 1, device=self.device ) self.prev_actions.copy_(actions) if ep_step == 0: state_estimates["pose_estimates"].copy_(prev_pose_estimates) self.ep_time += 1 # Update prev batch for k, v in batch.items(): self.prev_batch[k].copy_(v) # Remap actions from exploration to navigation agent. 
actions_rmp = self._remap_actions(actions) # =========================== Environment step ======================== outputs = self.envs.step([a[0].item() for a in actions_rmp]) observations, _, dones, infos = [list(x) for x in zip(*outputs)] times_per_step.append(time.time() - step_start_time) # ============================ Process metrics ======================== if dones[0]: times_per_episode.append(time.time() - ep_start_time) mins_per_episode = np.mean(times_per_episode).item() / 60.0 eta_completion = mins_per_episode * ( number_of_eval_episodes - ep - 1 ) secs_per_step = np.mean(times_per_step).item() for i in range(self.envs.num_envs): episode_id = int(current_episodes[i].episode_id) curr_metrics = { "spl": infos[i]["spl"], "softspl": infos[i]["softspl"], "success_rate": infos[i]["success"], "time": ep_step + 1, "distance_to_goal": infos[i]["distance_to_goal"], } # Estimate difficulty of episode episode_difficulty = classify_difficulty( current_episodes[i].info["geodesic_distance"] ) for k, v in curr_metrics.items(): navigation_metrics[k].update(v, 1.0) per_difficulty_navigation_metrics[episode_difficulty][ k ].update(v, 1.0) logger.info(f"====> {ep}/{number_of_eval_episodes} done") for k, v in curr_metrics.items(): logger.info(f"{k:25s} : {v:10.3f}") logger.info("{:25s} : {:10d}".format("episode_id", episode_id)) logger.info(f"Time per episode: {mins_per_episode:.3f} mins") logger.info(f"Time per step: {secs_per_step:.3f} secs") logger.info(f"ETA: {eta_completion:.3f} mins") # For navigation, terminate episode loop when dones is called break # done-for if checkpoint_index == 0: try: eval_ckpt_idx = self.config.EVAL_CKPT_PATH_DIR.split("/")[-1].split( "." )[1] logger.add_filehandler( f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_final_{eval_ckpt_idx}.txt" ) except: logger.add_filehandler( f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_{checkpoint_index}.txt" ) else: logger.add_filehandler( f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_{checkpoint_index}.txt" ) logger.info( f"======= Evaluating over {number_of_eval_episodes} episodes =============" ) logger.info(f"=======> Navigation metrics") for k, v in navigation_metrics.items(): logger.info(f"{k}: {v.get_metric():.3f}") writer.add_scalar(f"navigation/{k}", v.get_metric(), checkpoint_index) for diff, diff_metrics in per_difficulty_navigation_metrics.items(): logger.info(f"=============== {diff:^10s} metrics ==============") for k, v in diff_metrics.items(): logger.info(f"{k}: {v.get_metric():.3f}") writer.add_scalar( f"{diff}_navigation/{k}", v.get_metric(), checkpoint_index ) total_eval_time = (time.time() - eval_start_time) / 60.0 logger.info(f"Total evaluation time: {total_eval_time:.3f} mins") self.envs.close()
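# The per-difficulty accumulators above rely on a `Metric` object exposing
# update(value, weight) / get_metric(). A minimal weighted running-mean
# sketch of that interface (assumed, not necessarily the repo's exact class):
class Metric:
    def __init__(self):
        self._sum = 0.0
        self._count = 0.0

    def update(self, value: float, weight: float = 1.0) -> None:
        self._sum += value * weight
        self._count += weight

    def get_metric(self) -> float:
        # Guard against division by zero before the first update.
        return self._sum / max(self._count, 1.0)

m = Metric()
m.update(0.8)
m.update(0.4)
assert abs(m.get_metric() - 0.6) < 1e-9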
def train(self) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) if not os.path.isdir(self.config.CHECKPOINT_FOLDER): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, self.envs.observation_spaces[0], self.envs.action_spaces[0], ppo_cfg.hidden_size, ) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: 
self.save_checkpoint(f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
def _eval_checkpoint( self, checkpoint_path: str, writer: TensorboardWriter, checkpoint_index: int = 0, ) -> None: r"""Evaluates a single checkpoint. Args: checkpoint_path: path of checkpoint writer: tensorboard writer object for logging to tensorboard checkpoint_index: index of cur checkpoint for logging Returns: None """ # Map location CPU is almost always better than mapping to a CUDA device. ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu") if self.config.EVAL.USE_CKPT_CONFIG: config = self._setup_eval_config(ckpt_dict["config"]) else: config = self.config.clone() ppo_cfg = config.RL.PPO config.defrost() config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT config.freeze() if len(self.config.VIDEO_OPTION) > 0: config.defrost() config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP") config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS") config.freeze() logger.info(f"env config: {config}") self.envs = construct_envs(config, get_env_class(config.ENV_NAME)) self._setup_actor_critic_agent(ppo_cfg) self.actor_critic.eval() if self._static_encoder: self._encoder = self.agent.actor_critic.net.visual_encoder self.agent.load_state_dict(ckpt_dict["state_dict"]) self.actor_critic = self.agent.actor_critic observations = self.envs.reset() batch = batch_obs(observations, device=self.device) if self._static_encoder: batch["visual_features"] = self._encoder(batch) batch["prev_visual_features"] = torch.zeros_like( batch["visual_features"]) current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) test_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, self.config.NUM_PROCESSES, ppo_cfg.hidden_size, device=self.device, ) prev_actions = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long) not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device) stats_episodes = dict() # dict of dicts that stores stats per episode rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES) ] # type: List[List[np.ndarray]] if len(self.config.VIDEO_OPTION) > 0: os.makedirs(self.config.VIDEO_DIR, exist_ok=True) number_of_eval_episodes = self.config.TEST_EPISODE_COUNT if number_of_eval_episodes == -1: number_of_eval_episodes = sum(self.envs.number_of_episodes) else: total_num_eps = sum(self.envs.number_of_episodes) if total_num_eps < number_of_eval_episodes: logger.warn( f"Config specified {number_of_eval_episodes} eval episodes" ", dataset only has {total_num_eps}.") logger.warn(f"Evaluating with {total_num_eps} instead.") number_of_eval_episodes = total_num_eps pbar = tqdm.tqdm(total=number_of_eval_episodes) self.actor_critic.eval() while (len(stats_episodes) < number_of_eval_episodes and self.envs.num_envs > 0): current_episodes = self.envs.current_episodes() with torch.no_grad(): step_batch = batch ( _, actions, _, test_recurrent_hidden_states, ) = self.actor_critic.act( batch, test_recurrent_hidden_states, prev_actions, not_done_masks, deterministic=False, ) prev_actions.copy_(actions) outputs = self.envs.step([a[0].item() for a in actions]) observations, rewards, dones, infos = [ list(x) for x in zip(*outputs) ] batch = batch_obs(observations, device=self.device) if self._static_encoder: batch["prev_visual_features"] = step_batch["visual_features"] batch["visual_features"] = self._encoder(batch) not_done_masks = torch.tensor( [[0.0] if done else [1.0] for done in dones], dtype=torch.float, device=self.device, ) rewards = torch.tensor(rewards, dtype=torch.float, device=self.device).unsqueeze(1) 
current_episode_reward += rewards next_episodes = self.envs.current_episodes() envs_to_pause = [] n_envs = self.envs.num_envs for i in range(n_envs): if ( next_episodes[i].scene_id, next_episodes[i].episode_id, ) in stats_episodes: envs_to_pause.append(i) # episode ended if not_done_masks[i].item() == 0: pbar.update() episode_stats = dict() episode_stats["reward"] = current_episode_reward[i].item() episode_stats.update( self._extract_scalars_from_info(infos[i])) current_episode_reward[i] = 0 # use scene_id + episode_id as unique id for storing stats stats_episodes[( current_episodes[i].scene_id, current_episodes[i].episode_id, )] = episode_stats if len(self.config.VIDEO_OPTION) > 0: generate_video( video_option=self.config.VIDEO_OPTION, video_dir=self.config.VIDEO_DIR, images=rgb_frames[i], episode_id=current_episodes[i].episode_id, checkpoint_idx=checkpoint_index, metrics=self._extract_scalars_from_info(infos[i]), tb_writer=writer, ) rgb_frames[i] = [] # episode continues elif len(self.config.VIDEO_OPTION) > 0: frame = observations_to_image(observations[i], infos[i]) rgb_frames[i].append(frame) ( self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) = self._pause_envs( envs_to_pause, self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) num_episodes = len(stats_episodes) aggregated_stats = dict() for stat_key in next(iter(stats_episodes.values())).keys(): aggregated_stats[stat_key] = ( sum([v[stat_key] for v in stats_episodes.values()]) / num_episodes) for k, v in aggregated_stats.items(): logger.info(f"Average episode {k}: {v:.4f}") step_id = checkpoint_index if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: step_id = ckpt_dict["extra_state"]["step"] writer.add_scalars( "eval_reward", {"average reward": aggregated_stats["reward"]}, step_id, ) metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"} if len(metrics) > 0: writer.add_scalars("eval_metrics", metrics, step_id) self.envs.close()
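# When an environment's episode is already recorded, the eval loop pauses it
# and drops its row from every batched tensor via _pause_envs. A sketch of
# that row-slicing step (the real helper also pauses the vectorized envs
# themselves and handles the hidden-state batch dimension separately):
import torch

def drop_paused_rows(envs_to_pause, *tensors):
    if len(envs_to_pause) == 0:
        return tensors
    keep = [i for i in range(tensors[0].size(0)) if i not in set(envs_to_pause)]
    index = torch.tensor(keep, dtype=torch.long)
    return tuple(t.index_select(0, index) for t in tensors)

masks = torch.zeros(4, 1)
rewards = torch.arange(4.0).unsqueeze(1)
masks, rewards = drop_paused_rows([1, 3], masks, rewards)  # batch 4 -> 2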
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) # Initialize auxiliary tasks observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = \ self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.agent.parameters()))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights = self._update_agent( ppo_cfg, rollouts) pth_time += delta_pth_time for k, v in running_episode_stats.items(): window_episode_stats[k].append(v.clone()) deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) writer.add_scalar("reward", deltas["reward"] / deltas["count"], count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) losses = [value_loss, action_loss] + aux_task_losses writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) writer.add_scalar( "success", deltas["success"] / deltas["count"], count_steps, ) # Log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {}\t action_loss: {}\taux_task_loss: {} \t aux_entropy {}" .format(update, value_loss, action_loss, aux_task_losses, aux_dist_entropy)) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / (time.time() - t_start))) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join("{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
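# The aux-task update above returns per-task losses and weights alongside the
# PPO losses. A hedged sketch of how such an objective is typically combined
# (function name and coefficients are illustrative, not this repo's values):
import torch

def combined_loss(value_loss, action_loss, aux_task_losses, aux_weights,
                  value_coef=0.5):
    aux = sum(w * l for w, l in zip(aux_weights, aux_task_losses))
    return value_coef * value_loss + action_loss + aux

loss = combined_loss(
    torch.tensor(0.9), torch.tensor(0.2),
    aux_task_losses=[torch.tensor(0.3), torch.tensor(0.1)],
    aux_weights=[0.5, 1.0],
)  # -> 0.5*0.9 + 0.2 + (0.15 + 0.1) = 0.9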
def train(self) -> None: r"""Main method for training DAgger. Returns: None """ os.makedirs(self.lmdb_features_dir, exist_ok=True) os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True) if self.config.DAGGER.PRELOAD_LMDB_FEATURES: try: lmdb.open(self.lmdb_features_dir, readonly=True) except lmdb.Error as err: logger.error( "Cannot open database for teacher forcing preload.") raise err else: with lmdb.open(self.lmdb_features_dir, map_size=int(self.config.DAGGER.LMDB_MAP_SIZE) ) as lmdb_env, lmdb_env.begin(write=True) as txn: txn.drop(lmdb_env.open_db()) split = self.config.TASK_CONFIG.DATASET.SPLIT self.config.defrost() self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split # if doing teacher forcing, don't switch the scene until it is complete if self.config.DAGGER.P == 1.0: self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = ( -1) self.config.freeze() if self.config.DAGGER.PRELOAD_LMDB_FEATURES: # when preloadeding features, its quicker to just load one env as we just # need the observation space from it. single_proc_config = self.config.clone() single_proc_config.defrost() single_proc_config.NUM_PROCESSES = 1 single_proc_config.freeze() self.envs = construct_envs(single_proc_config, get_env_class(self.config.ENV_NAME)) else: self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) self._setup_actor_critic_agent( self.config.MODEL, self.config.DAGGER.LOAD_FROM_CKPT, self.config.DAGGER.CKPT_TO_LOAD, ) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.actor_critic.parameters()))) logger.info("agent number of trainable parameters: {}".format( sum(p.numel() for p in self.actor_critic.parameters() if p.requires_grad))) if self.config.DAGGER.PRELOAD_LMDB_FEATURES: self.envs.close() del self.envs self.envs = None with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs, purge_step=0) as writer: for dagger_it in range(self.config.DAGGER.ITERATIONS): step_id = 0 if not self.config.DAGGER.PRELOAD_LMDB_FEATURES: self._update_dataset(dagger_it + ( 1 if self.config.DAGGER.LOAD_FROM_CKPT else 0)) if torch.cuda.is_available(): with torch.cuda.device(self.device): torch.cuda.empty_cache() gc.collect() dataset = IWTrajectoryDataset( self.lmdb_features_dir, self.config.DAGGER.USE_IW, inflection_weight_coef=self.config.MODEL. inflection_weight_coef, lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE, batch_size=self.config.DAGGER.BATCH_SIZE, ) AuxLosses.activate() for epoch in tqdm.trange(self.config.DAGGER.EPOCHS): diter = torch.utils.data.DataLoader( dataset, batch_size=self.config.DAGGER.BATCH_SIZE, shuffle=False, collate_fn=collate_fn, pin_memory=False, drop_last=True, # drop last batch if smaller num_workers=0, ) for batch in tqdm.tqdm(diter, total=dataset.length // dataset.batch_size, leave=False): ( observations_batch, prev_actions_batch, not_done_masks, corrected_actions_batch, weights_batch, ) = batch observations_batch = { k: v.to(device=self.device, non_blocking=True) for k, v in observations_batch.items() } try: loss, action_loss, aux_loss = self._update_agent( observations_batch, prev_actions_batch.to(device=self.device, non_blocking=True), not_done_masks.to(device=self.device, non_blocking=True), corrected_actions_batch.to(device=self.device, non_blocking=True), weights_batch.to(device=self.device, non_blocking=True), ) except: logger.info( "ERROR: failed to update agent. Updating agent with batch size of 1." 
) loss, action_loss, aux_loss = 0, 0, 0 prev_actions_batch = prev_actions_batch.cpu() not_done_masks = not_done_masks.cpu() corrected_actions_batch = corrected_actions_batch.cpu( ) weights_batch = weights_batch.cpu() observations_batch = { k: v.cpu() for k, v in observations_batch.items() } for i in range(not_done_masks.size(0)): output = self._update_agent( { k: v[i].to(device=self.device, non_blocking=True) for k, v in observations_batch.items() }, prev_actions_batch[i].to( device=self.device, non_blocking=True), not_done_masks[i].to(device=self.device, non_blocking=True), corrected_actions_batch[i].to( device=self.device, non_blocking=True), weights_batch[i].to(device=self.device, non_blocking=True), ) loss += output[0] action_loss += output[1] aux_loss += output[2] logger.info(f"train_loss: {loss}") logger.info(f"train_action_loss: {action_loss}") logger.info(f"train_aux_loss: {aux_loss}") logger.info(f"Batches processed: {step_id}.") logger.info( f"On DAgger iter {dagger_it}, Epoch {epoch}.") writer.add_scalar(f"train_loss_iter_{dagger_it}", loss, step_id) writer.add_scalar( f"train_action_loss_iter_{dagger_it}", action_loss, step_id) writer.add_scalar(f"train_aux_loss_iter_{dagger_it}", aux_loss, step_id) step_id += 1 self.save_checkpoint( f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth" ) AuxLosses.deactivate()
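# DAgger's defining knob is the expert-mixing probability (DAGGER.P above;
# P == 1.0 is pure teacher forcing, which is why scene switching is frozen in
# that case). A minimal sketch of the mixing rule, with an annealing schedule
# that is illustrative rather than this repo's exact one:
import random

def choose_action(expert_action: int, policy_action: int, beta: float) -> int:
    # Follow the expert with probability beta, else execute the learner.
    return expert_action if random.random() < beta else policy_action

p = 0.75
betas = [p ** k for k in range(4)]  # e.g. decay beta per DAgger iteration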
def _eval_checkpoint(self, checkpoint_path: str, writer: TensorboardWriter, checkpoint_index: int = 0, log_diagnostics=[], output_dir='.', label='.', num_eval_runs=1) -> None: r"""Evaluates a single checkpoint. Args: checkpoint_path: path of checkpoint writer: tensorboard writer object for logging to tensorboard checkpoint_index: index of cur checkpoint for logging Returns: None """ if checkpoint_index == -1: ckpt_file = checkpoint_path.split('/')[-1] split_info = ckpt_file.split('.') checkpoint_index = split_info[1] # Map location CPU is almost always better than mapping to a CUDA device. ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu") if self.config.EVAL.USE_CKPT_CONFIG: config = self._setup_eval_config(ckpt_dict["config"]) else: config = self.config.clone() ppo_cfg = config.RL.PPO task_cfg = config.TASK_CONFIG.TASK config.defrost() config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT config.freeze() if len(self.config.VIDEO_OPTION) > 0: config.defrost() config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP") config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS") config.freeze() logger.info(f"env config: {config}") self.envs = construct_envs(config, get_env_class(config.ENV_NAME)) # pass in aux config if we're doing attention aux_cfg = self.config.RL.AUX_TASKS self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg) # Check if we accidentally recorded `visual_resnet` in our checkpoint and drop it (it's redundant with `visual_encoder`) ckpt_dict['state_dict'] = { k: v for k, v in ckpt_dict['state_dict'].items() if 'visual_resnet' not in k } self.agent.load_state_dict(ckpt_dict["state_dict"]) logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) self.actor_critic = self.agent.actor_critic observations = self.envs.reset() batch = batch_obs(observations, device=self.device) current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) test_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, self.config.NUM_PROCESSES, ppo_cfg.hidden_size, device=self.device, ) _, num_recurrent_memories, _ = self._setup_auxiliary_tasks( aux_cfg, ppo_cfg, task_cfg, is_eval=True) if self.config.RL.PPO.policy in MULTIPLE_BELIEF_CLASSES: aux_tasks = self.config.RL.AUX_TASKS.tasks num_recurrent_memories = len(self.config.RL.AUX_TASKS.tasks) test_recurrent_hidden_states = test_recurrent_hidden_states.unsqueeze( 2).repeat(1, 1, num_recurrent_memories, 1) prev_actions = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long) not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device) stats_episodes = dict() # dict of dicts that stores stats per episode rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES) ] # type: List[List[np.ndarray]] if len(self.config.VIDEO_OPTION) > 0: os.makedirs(self.config.VIDEO_DIR, exist_ok=True) number_of_eval_episodes = self.config.TEST_EPISODE_COUNT if number_of_eval_episodes == -1: number_of_eval_episodes = sum(self.envs.number_of_episodes) else: total_num_eps = sum(self.envs.number_of_episodes) if total_num_eps < number_of_eval_episodes: logger.warn( f"Config specified {number_of_eval_episodes} eval episodes" ", dataset only has {total_num_eps}.") logger.warn(f"Evaluating with {total_num_eps} instead.") number_of_eval_episodes = total_num_eps videos_cap = 2 # number of videos to generate per checkpoint if len(log_diagnostics) > 0: videos_cap = 10 # video_indices = 
random.sample(range(self.config.TEST_EPISODE_COUNT), # min(videos_cap, self.config.TEST_EPISODE_COUNT)) video_indices = range(10) print(f"Videos: {video_indices}") total_stats = [] dones_per_ep = dict() # Logging more extensive evaluation stats for analysis if len(log_diagnostics) > 0: d_stats = {} for d in log_diagnostics: d_stats[d] = [ [] for _ in range(self.config.NUM_PROCESSES) ] # stored as nested list envs x timesteps x k (# tasks) pbar = tqdm.tqdm(total=number_of_eval_episodes * num_eval_runs) self.agent.eval() while (len(stats_episodes) < number_of_eval_episodes * num_eval_runs and self.envs.num_envs > 0): current_episodes = self.envs.current_episodes() with torch.no_grad(): weights_output = None if self.config.RL.PPO.policy in MULTIPLE_BELIEF_CLASSES: weights_output = torch.empty(self.envs.num_envs, len(aux_tasks)) ( _, actions, _, test_recurrent_hidden_states, ) = self.actor_critic.act(batch, test_recurrent_hidden_states, prev_actions, not_done_masks, deterministic=False, weights_output=weights_output) prev_actions.copy_(actions) for i in range(self.envs.num_envs): if Diagnostics.actions in log_diagnostics: d_stats[Diagnostics.actions][i].append( prev_actions[i].item()) if Diagnostics.weights in log_diagnostics: aux_weights = None if weights_output is None else weights_output[ i] if aux_weights is not None: d_stats[Diagnostics.weights][i].append( aux_weights.half().tolist()) outputs = self.envs.step([a[0].item() for a in actions]) observations, rewards, dones, infos = [ list(x) for x in zip(*outputs) ] batch = batch_obs(observations, device=self.device) not_done_masks = torch.tensor( [[0.0] if done else [1.0] for done in dones], dtype=torch.float, device=self.device, ) rewards = torch.tensor(rewards, dtype=torch.float, device=self.device).unsqueeze(1) current_episode_reward += rewards next_episodes = self.envs.current_episodes() envs_to_pause = [] n_envs = self.envs.num_envs for i in range(n_envs): next_k = ( next_episodes[i].scene_id, next_episodes[i].episode_id, ) if dones_per_ep.get(next_k, 0) == num_eval_runs: envs_to_pause.append(i) # wait for the rest if not_done_masks[i].item() == 0: episode_stats = dict() episode_stats["reward"] = current_episode_reward[i].item() episode_stats.update( self._extract_scalars_from_info(infos[i])) current_episode_reward[i] = 0 # use scene_id + episode_id as unique id for storing stats k = ( current_episodes[i].scene_id, current_episodes[i].episode_id, ) dones_per_ep[k] = dones_per_ep.get(k, 0) + 1 if dones_per_ep.get(k, 0) == 1 and len( self.config.VIDEO_OPTION) > 0 and len( stats_episodes) in video_indices: logger.info(f"Generating video {len(stats_episodes)}") category = getattr(current_episodes[i], "object_category", "") if category != "": category += "_" try: generate_video( video_option=self.config.VIDEO_OPTION, video_dir=self.config.VIDEO_DIR, images=rgb_frames[i], episode_id=current_episodes[i].episode_id, checkpoint_idx=checkpoint_index, metrics=self._extract_scalars_from_info( infos[i]), tag=f"{category}{label}", tb_writer=writer, ) except Exception as e: logger.warning(str(e)) rgb_frames[i] = [] stats_episodes[( current_episodes[i].scene_id, current_episodes[i].episode_id, dones_per_ep[k], )] = episode_stats if len(log_diagnostics) > 0: diagnostic_info = dict() for metric in log_diagnostics: diagnostic_info[metric] = d_stats[metric][i] d_stats[metric][i] = [] if Diagnostics.top_down_map in log_diagnostics: top_down_map = torch.tensor([]) if len(self.config.VIDEO_OPTION) > 0: top_down_map = infos[i]["top_down_map"]["map"] 
top_down_map = maps.colorize_topdown_map( top_down_map, fog_of_war_mask=None) diagnostic_info.update( dict(top_down_map=top_down_map)) total_stats.append( dict( stats=episode_stats, did_stop=bool(prev_actions[i] == 0), episode_info=attr.asdict(current_episodes[i]), info=diagnostic_info, )) pbar.update() # episode continues else: if len(self.config.VIDEO_OPTION) > 0: aux_weights = None if weights_output is None else weights_output[ i] frame = observations_to_image( observations[i], infos[i], current_episode_reward[i].item(), aux_weights, aux_tasks) rgb_frames[i].append(frame) if Diagnostics.gps in log_diagnostics: d_stats[Diagnostics.gps][i].append( observations[i]["gps"].tolist()) if Diagnostics.heading in log_diagnostics: d_stats[Diagnostics.heading][i].append( observations[i]["heading"].tolist()) ( self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) = self._pause_envs( envs_to_pause, self.envs, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) num_episodes = len(stats_episodes) aggregated_stats = dict() for stat_key in next(iter(stats_episodes.values())).keys(): aggregated_stats[stat_key] = ( sum([v[stat_key] for v in stats_episodes.values()]) / num_episodes) for k, v in aggregated_stats.items(): logger.info(f"Average episode {k}: {v:.4f}") step_id = checkpoint_index if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: step_id = ckpt_dict["extra_state"]["step"] writer.add_scalars( "eval_reward", {"average reward": aggregated_stats["reward"]}, step_id, ) metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"} if len(metrics) > 0: writer.add_scalars("eval_metrics", metrics, step_id) logger.info("eval_metrics") logger.info(metrics) if len(log_diagnostics) > 0: os.makedirs(output_dir, exist_ok=True) eval_fn = f"{label}.json" with open(os.path.join(output_dir, eval_fn), 'w', encoding='utf-8') as f: json.dump(total_stats, f, ensure_ascii=False, indent=4) self.envs.close()
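# All of the eval loops above finish by averaging per-episode stats into one
# number per metric. A compact, self-contained sketch of that aggregation:
def aggregate_episode_stats(stats_episodes: dict) -> dict:
    num_episodes = len(stats_episodes)
    keys = next(iter(stats_episodes.values())).keys()
    return {
        k: sum(ep[k] for ep in stats_episodes.values()) / num_episodes
        for k in keys
    }

stats = {
    ("scene_a", "0", 1): {"reward": 1.0, "spl": 0.5},
    ("scene_a", "1", 1): {"reward": 3.0, "spl": 0.75},
}
assert aggregate_episode_stats(stats) == {"reward": 2.0, "spl": 0.625}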
def eval(self, checkpoint_path): r"""Evaluates a single checkpoint. Args: checkpoint_path: path of checkpoint writer: tensorboard writer object for logging to tensorboard checkpoint_index: index of cur checkpoint for logging Returns: None """ self.device = (torch.device("cuda", self.config.TORCH_GPU_ID) if torch.cuda.is_available() else torch.device("cpu")) # Map location CPU is almost always better than mapping to a CUDA device. ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu") if self.config.EVAL.USE_CKPT_CONFIG: config = self._setup_eval_config(ckpt_dict["config"]) else: config = self.config.clone() ppo_cfg = config.RL.PPO config.defrost() config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT config.freeze() if len(self.config.VIDEO_OPTION) > 0: config.defrost() config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP") config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS") config.freeze() self.env = construct_envs(config, get_env_class(config.ENV_NAME)) self._setup_actor_critic_agent(ppo_cfg) self.agent.load_state_dict(ckpt_dict["state_dict"]) self.actor_critic = self.agent.actor_critic # get name of performance metric, e.g. "spl" metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0] metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name) measure_type = baseline_registry.get_measure(metric_cfg.TYPE) assert measure_type is not None, "invalid measurement type {}".format( metric_cfg.TYPE) self.metric_uuid = measure_type(sim=None, task=None, config=None)._get_uuid() observations = self.env.reset() batch = batch_obs(observations, self.device) current_episode_reward = torch.zeros(self.env.num_envs, 1, device=self.device) test_recurrent_hidden_states = torch.zeros( self.actor_critic.net.num_recurrent_layers, self.config.NUM_PROCESSES, ppo_cfg.hidden_size, device=self.device, ) prev_actions = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long) not_done_masks = torch.zeros(self.config.NUM_PROCESSES, 1, device=self.device) stats_episodes = dict() # dict of dicts that stores stats per episode rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES) ] # type: List[List[np.ndarray]] if len(self.config.VIDEO_OPTION) > 0: os.makedirs(self.config.VIDEO_DIR, exist_ok=True) self.actor_critic.eval() while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT and self.env.num_envs > 0): current_episodes = self.env.current_episodes() with torch.no_grad(): ( _, actions, _, test_recurrent_hidden_states, ) = self.actor_critic.act( batch, test_recurrent_hidden_states, prev_actions, not_done_masks, deterministic=False, ) prev_actions.copy_(actions) outputs = self.env.step([a[0].item() for a in actions]) observations, rewards, dones, infos = [ list(x) for x in zip(*outputs) ] batch = batch_obs(observations, self.device) not_done_masks = torch.tensor( [[0.0] if done else [1.0] for done in dones], dtype=torch.float, device=self.device, ) rewards = torch.tensor(rewards, dtype=torch.float, device=self.device).unsqueeze(1) current_episode_reward += rewards next_episodes = self.env.current_episodes() envs_to_pause = [] n_envs = self.env.num_envs for i in range(n_envs): if ( next_episodes[i].scene_id, next_episodes[i].episode_id, ) in stats_episodes: envs_to_pause.append(i) # episode ended if not_done_masks[i].item() == 0: episode_stats = dict() episode_stats[self.metric_uuid] = infos[i][ self.metric_uuid] episode_stats["success"] = int( infos[i][self.metric_uuid] > 0) episode_stats["reward"] = current_episode_reward[i].item() current_episode_reward[i] = 
0 # use scene_id + episode_id as unique id for storing stats stats_episodes[( current_episodes[i].scene_id, current_episodes[i].episode_id, )] = episode_stats if len(self.config.VIDEO_OPTION) > 0: generate_video( video_option=self.config.VIDEO_OPTION, video_dir=self.config.VIDEO_DIR, images=rgb_frames[i], episode_id=current_episodes[i].episode_id, checkpoint_idx=0, metric_name=self.metric_uuid, metric_value=infos[i][self.metric_uuid], ) rgb_frames[i] = [] # episode continues elif len(self.config.VIDEO_OPTION) > 0: frame = observations_to_image(observations[i], infos[i]) rgb_frames[i].append(frame) ( self.env, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) = self._pause_envs( envs_to_pause, self.env, test_recurrent_hidden_states, not_done_masks, current_episode_reward, prev_actions, batch, rgb_frames, ) aggregated_stats = dict() for stat_key in next(iter(stats_episodes.values())).keys(): aggregated_stats[stat_key] = sum( [v[stat_key] for v in stats_episodes.values()]) num_episodes = len(stats_episodes) episode_reward_mean = aggregated_stats["reward"] / num_episodes episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes episode_success_mean = aggregated_stats["success"] / num_episodes print(f"Average episode reward: {episode_reward_mean:.6f}") print(f"Average episode success: {episode_success_mean:.6f}") print(f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}") if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: step_id = ckpt_dict["extra_state"]["step"] print("eval_reward", {"average reward": episode_reward_mean}) print( f"eval_{self.metric_uuid}", {f"average {self.metric_uuid}": episode_metric_mean}, ) print("eval_success", {"average success": episode_success_mean}) self.env.close()
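# A sketch of the reward bookkeeping these eval loops share: masks are 0.0
# where an episode just ended and 1.0 otherwise, so finished returns can be
# read off and then zeroed in one step (dummy `dones` for illustration):
import torch

dones = [False, True, False]
not_done_masks = torch.tensor([[0.0] if d else [1.0] for d in dones])
rewards = torch.tensor([0.1, 1.0, -0.2]).unsqueeze(1)

current_episode_reward = torch.zeros(3, 1)
current_episode_reward += rewards
finished = current_episode_reward[not_done_masks.squeeze(1) == 0]
current_episode_reward *= not_done_masks  # reset envs that just finished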
def _eval_checkpoint_bruce(
        self,
        checkpoint_path: str,
        checkpoint_index: int = 0,
    ):
        r"""Sets up evaluation state for a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            Tuple of (actor_critic, batch, not_done_masks,
            test_recurrent_hidden_states).
        """
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")
        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()
        ppo_cfg = config.RL.PPO
        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()
        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()
        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic
        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)
        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device=self.device
        )
        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(
            self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
        )
        not_done_masks = torch.zeros(
            self.config.NUM_PROCESSES, 1, device=self.device
        )
        return (
            self.actor_critic,
            batch,
            not_done_masks,
            test_recurrent_hidden_states,
        )
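# The "map to CPU first" comment repeated above boils down to this: load
# checkpoint tensors onto host memory, then move only what you need to the
# device. A sketch of the load_checkpoint helper these trainers assume:
import torch

def load_checkpoint(path: str, map_location="cpu"):
    # map_location="cpu" avoids GPU OOM and device-mismatch surprises when
    # a checkpoint was saved from a different CUDA device.
    return torch.load(path, map_location=map_location)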
def train(self) -> None: r"""Main method for DD-PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank # Multiply by the number of simulators to make sure they also get unique seeds self.config.TASK_CONFIG.SEED += (self.world_rank * self.config.NUM_PROCESSES) self.config.freeze() random.seed(self.config.TASK_CONFIG.SEED) np.random.seed(self.config.TASK_CONFIG.SEED) torch.manual_seed(self.config.TASK_CONFIG.SEED) if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO if (not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0): os.makedirs(self.config.CHECKPOINT_FOLDER) self._setup_actor_critic_agent(ppo_cfg) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) obs_space = self.envs.observation_spaces[0] if self._static_encoder: self._encoder = self.actor_critic.net.visual_encoder obs_space = SpaceDict({ "visual_features": spaces.Box( low=np.finfo(np.float32).min, high=np.finfo(np.float32).max, shape=self._encoder.output_shape, dtype=np.float32, ), **obs_space.spaces, }) with torch.no_grad(): batch["visual_features"] = self._encoder(batch) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, obs_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, ) rollouts.to(self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! 
batch = None observations = None current_episode_reward = torch.zeros(self.envs.num_envs, 1, device=self.device) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1, device=self.device), reward=torch.zeros(self.envs.num_envs, 1, device=self.device), ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 start_update = 0 prev_time = 0 lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_update = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_update, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps_delta += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() if self._static_encoder: self._encoder.eval() ( delta_pth_time, value_loss, action_loss, dist_entropy, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [value_loss, action_loss, count_steps_delta], device=self.device, ) distrib.all_reduce(stats) count_steps += stats[2].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") losses = [ stats[0].item() / self.world_size, stats[1].item() / self.world_size, ] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", {k: l for l, k in zip(losses, ["value", "policy"])}, count_steps, ) # log stats if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"ckpt.{count_checkpoints}.pth", dict(step=count_steps), ) count_checkpoints += 1 self.envs.close()
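# The straggler preemption in the DD-PPO loops above reduces to one test:
# once a worker is past a fraction of its rollout and enough peers have
# finished theirs, it cuts its own rollout short. Threshold values below are
# illustrative defaults, not necessarily this repo's:
def should_preempt(step, num_steps, num_done, world_size,
                   short_rollout_threshold=0.25, sync_frac=0.6):
    return (step >= num_steps * short_rollout_threshold
            and num_done > sync_frac * world_size)

assert should_preempt(step=40, num_steps=128, num_done=50, world_size=64)
assert not should_preempt(step=10, num_steps=128, num_done=50, world_size=64)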
def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None: r"""Main method for training PPO. Returns: None """ self.local_rank, tcp_store = init_distrib_slurm( self.config.RL.DDPPO.distrib_backend) add_signal_handlers() # Stores the number of workers that have finished their rollout num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store) num_rollouts_done_store.set("num_done", "0") self.world_rank = distrib.get_rank() self.world_size = distrib.get_world_size() random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) self.config.defrost() self.config.TORCH_GPU_ID = self.local_rank self.config.SIMULATOR_GPU_ID = self.local_rank self.config.freeze() if torch.cuda.is_available(): self.device = torch.device("cuda", self.local_rank) torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) ppo_cfg = self.config.RL.PPO task_cfg = self.config.TASK_CONFIG.TASK observation_space = self.envs.observation_spaces[0] aux_cfg = self.config.RL.AUX_TASKS init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks( aux_cfg, ppo_cfg, task_cfg, observation_space) rollouts = RolloutStorage( ppo_cfg.num_steps, self.envs.num_envs, observation_space, self.envs.action_spaces[0], ppo_cfg.hidden_size, num_recurrent_memories=num_recurrent_memories) rollouts.to(self.device) observations = self.envs.reset() batch = batch_obs(observations, device=self.device) for sensor in rollouts.observations: rollouts.observations[sensor][0].copy_(batch[sensor]) # batch and observations may contain shared PyTorch CUDA # tensors. We must explicitly clear them here otherwise # they will be kept in memory for the entire duration of training! batch = None observations = None self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg, init_aux_tasks) self.agent.init_distributed(find_unused_params=True) if self.world_rank == 0: logger.info("agent number of trainable parameters: {}".format( sum(param.numel() for param in self.agent.parameters() if param.requires_grad))) current_episode_reward = torch.zeros(self.envs.num_envs, 1) running_episode_stats = dict( count=torch.zeros(self.envs.num_envs, 1), reward=torch.zeros(self.envs.num_envs, 1), # including bonus ) window_episode_stats = defaultdict( lambda: deque(maxlen=ppo_cfg.reward_window_size)) t_start = time.time() env_time = 0 pth_time = 0 count_steps = 0 count_checkpoints = 0 prev_time = 0 if ckpt != -1: logger.info( f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly." 
) assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported" # This is the checkpoint we start saving at count_checkpoints = ckpt + 1 count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu") self.agent.load_state_dict(ckpt_dict["state_dict"]) if "optim_state" in ckpt_dict: self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"]) else: logger.warn("No optimizer state loaded, results may be funky") if "extra_state" in ckpt_dict and "step" in ckpt_dict[ "extra_state"]: count_steps = ckpt_dict["extra_state"]["step"] lr_scheduler = LambdaLR( optimizer=self.agent.optimizer, lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), ) interrupted_state = load_interrupted_state() if interrupted_state is not None: self.agent.load_state_dict(interrupted_state["state_dict"]) self.agent.optimizer.load_state_dict( interrupted_state["optim_state"]) lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) requeue_stats = interrupted_state["requeue_stats"] env_time = requeue_stats["env_time"] pth_time = requeue_stats["pth_time"] count_steps = requeue_stats["count_steps"] count_checkpoints = requeue_stats["count_checkpoints"] start_updates = requeue_stats["start_update"] prev_time = requeue_stats["prev_time"] with (TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs) if self.world_rank == 0 else contextlib.suppress()) as writer: for update in range(start_updates, self.config.NUM_UPDATES): if ppo_cfg.use_linear_lr_decay: lr_scheduler.step() if ppo_cfg.use_linear_clip_decay: self.agent.clip_param = ppo_cfg.clip_param * linear_decay( update, self.config.NUM_UPDATES) if EXIT.is_set(): self.envs.close() if REQUEUE.is_set() and self.world_rank == 0: requeue_stats = dict( env_time=env_time, pth_time=pth_time, count_steps=count_steps, count_checkpoints=count_checkpoints, start_update=update, prev_time=(time.time() - t_start) + prev_time, ) save_interrupted_state( dict( state_dict=self.agent.state_dict(), optim_state=self.agent.optimizer.state_dict(), lr_sched_state=lr_scheduler.state_dict(), config=self.config, requeue_stats=requeue_stats, )) requeue_job() return count_steps_delta = 0 self.agent.eval() for step in range(ppo_cfg.num_steps): ( delta_pth_time, delta_env_time, delta_steps, ) = self._collect_rollout_step(rollouts, current_episode_reward, running_episode_stats) pth_time += delta_pth_time env_time += delta_env_time count_steps += delta_steps # This is where the preemption of workers happens. If a # worker detects it will be a straggler, it preempts itself! 
if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD ) and int(num_rollouts_done_store.get("num_done")) > ( self.config.RL.DDPPO.sync_frac * self.world_size): break num_rollouts_done_store.add("num_done", 1) self.agent.train() ( delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights, ) = self._update_agent(ppo_cfg, rollouts) pth_time += delta_pth_time stats_ordering = list(sorted(running_episode_stats.keys())) stats = torch.stack( [running_episode_stats[k] for k in stats_ordering], 0).to(self.device) distrib.all_reduce(stats) for i, k in enumerate(stats_ordering): window_episode_stats[k].append(stats[i].clone()) stats = torch.tensor( [ dist_entropy, aux_dist_entropy, ] + [value_loss, action_loss] + aux_task_losses + [count_steps_delta], device=self.device, ) distrib.all_reduce(stats) if aux_weights is not None and len(aux_weights) > 0: distrib.all_reduce( torch.tensor(aux_weights, device=self.device)) count_steps += stats[-1].item() if self.world_rank == 0: num_rollouts_done_store.set("num_done", "0") avg_stats = [ stats[i].item() / self.world_size for i in range(len(stats) - 1) ] losses = avg_stats[2:] dist_entropy, aux_dist_entropy = avg_stats[:2] deltas = { k: ((v[-1] - v[0]).sum().item() if len(v) > 1 else v[0].sum().item()) for k, v in window_episode_stats.items() } deltas["count"] = max(deltas["count"], 1.0) writer.add_scalar( "reward", deltas["reward"] / deltas["count"], count_steps, ) writer.add_scalar( "entropy", dist_entropy, count_steps, ) writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps) # Check to see if there are any metrics # that haven't been logged yet metrics = { k: v / deltas["count"] for k, v in deltas.items() if k not in {"reward", "count"} } if len(metrics) > 0: writer.add_scalars("metrics", metrics, count_steps) writer.add_scalars( "losses", { k: l for l, k in zip(losses, ["value", "policy"] + aux_task_strings) }, count_steps, ) writer.add_scalars( "aux_weights", {k: l for l, k in zip(aux_weights, aux_task_strings)}, count_steps, ) # Log stats formatted_aux_losses = [ "{:.3g}".format(l) for l in aux_task_losses ] if update > 0 and update % self.config.LOG_INTERVAL == 0: logger.info( "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t" .format( update, value_loss, action_loss, formatted_aux_losses, aux_dist_entropy, )) logger.info("update: {}\tfps: {:.3f}\t".format( update, count_steps / ((time.time() - t_start) + prev_time), )) logger.info( "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" "frames: {}".format(update, env_time, pth_time, count_steps)) logger.info("Average window size: {} {}".format( len(window_episode_stats["count"]), " ".join( "{}: {:.3f}".format(k, v / deltas["count"]) for k, v in deltas.items() if k != "count"), )) # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint( f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)) count_checkpoints += 1 self.envs.close()
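# Rank-0 logging above divides all-reduced losses by world_size to report a
# cross-worker mean. A hedged sketch of that reduction; it assumes
# torch.distributed is already initialized (e.g. via init_distrib_slurm):
import torch
import torch.distributed as distrib

def reduce_mean_losses(value_loss, action_loss, world_size, device):
    stats = torch.tensor([value_loss, action_loss], device=device)
    distrib.all_reduce(stats)  # default op: SUM over workers
    return (stats / world_size).tolist()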