def train(self) -> None: r"""Main method for pre-training Encoder-Decoder Feature Extractor for EQA. Returns: None """ config = self.config eqa_cnn_pretrain_dataset = EQACNNPretrainDataset(config) train_loader = DataLoader( eqa_cnn_pretrain_dataset, batch_size=config.IL.EQACNNPretrain.batch_size, shuffle=True, ) logger.info("[ train_loader has {} samples ]".format( len(eqa_cnn_pretrain_dataset))) model = MultitaskCNN() model.train().to(self.device) optim = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=float(config.IL.EQACNNPretrain.lr), ) depth_loss = torch.nn.SmoothL1Loss() ae_loss = torch.nn.SmoothL1Loss() seg_loss = torch.nn.CrossEntropyLoss() epoch, t = 1, 0 with TensorboardWriter(config.TENSORBOARD_DIR, flush_secs=self.flush_secs) as writer: while epoch <= config.IL.EQACNNPretrain.max_epochs: start_time = time.time() avg_loss = 0.0 for batch in train_loader: t += 1 idx, gt_rgb, gt_depth, gt_seg = batch optim.zero_grad() gt_rgb = gt_rgb.to(self.device) gt_depth = gt_depth.to(self.device) gt_seg = gt_seg.to(self.device) pred_seg, pred_depth, pred_rgb = model(gt_rgb) l1 = seg_loss(pred_seg, gt_seg.long()) l2 = ae_loss(pred_rgb, gt_rgb) l3 = depth_loss(pred_depth, gt_depth) loss = l1 + (10 * l2) + (10 * l3) avg_loss += loss.item() if t % config.LOG_INTERVAL == 0: logger.info( "[ Epoch: {}; iter: {}; loss: {:.3f} ]".format( epoch, t, loss.item())) writer.add_scalar("total_loss", loss, t) writer.add_scalars( "individual_losses", { "seg_loss": l1, "ae_loss": l2, "depth_loss": l3 }, t, ) loss.backward() optim.step() end_time = time.time() time_taken = "{:.1f}".format((end_time - start_time) / 60) avg_loss = avg_loss / len(train_loader) logger.info( "[ Epoch {} completed. Time taken: {} minutes. ]".format( epoch, time_taken)) logger.info("[ Average loss: {:.3f} ]".format(avg_loss)) print("-----------------------------------------") if epoch % config.CHECKPOINT_INTERVAL == 0: self.save_checkpoint(model.state_dict(), "epoch_{}.ckpt".format(epoch)) epoch += 1
def _eval_checkpoint(
    self,
    checkpoint_path: str,
    writer: TensorboardWriter,
    checkpoint_index: int = 0,
) -> None:
    r"""Evaluates a single checkpoint.

    Args:
        checkpoint_path: path of checkpoint
        writer: tensorboard writer object for logging to tensorboard
        checkpoint_index: index of cur checkpoint for logging

    Returns:
        None
    """
    config = self.config

    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = self.config.EVAL.SPLIT
    config.freeze()

    eqa_cnn_pretrain_dataset = EQACNNPretrainDataset(config, mode="val")

    eval_loader = DataLoader(
        eqa_cnn_pretrain_dataset,
        batch_size=config.IL.EQACNNPretrain.batch_size,
        shuffle=False,
    )

    logger.info(
        "[ eval_loader has {} samples ]".format(
            len(eqa_cnn_pretrain_dataset)
        )
    )

    model = MultitaskCNN()

    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)
    model.to(self.device).eval()

    depth_loss = torch.nn.SmoothL1Loss()
    ae_loss = torch.nn.SmoothL1Loss()
    seg_loss = torch.nn.CrossEntropyLoss()

    t = 0
    avg_loss = 0.0
    avg_l1 = 0.0
    avg_l2 = 0.0
    avg_l3 = 0.0

    with torch.no_grad():
        for batch in eval_loader:
            t += 1

            idx, gt_rgb, gt_depth, gt_seg = batch
            gt_rgb = gt_rgb.to(self.device)
            gt_depth = gt_depth.to(self.device)
            gt_seg = gt_seg.to(self.device)

            pred_seg, pred_depth, pred_rgb = model(gt_rgb)
            l1 = seg_loss(pred_seg, gt_seg.long())
            l2 = ae_loss(pred_rgb, gt_rgb)
            l3 = depth_loss(pred_depth, gt_depth)

            loss = l1 + (10 * l2) + (10 * l3)

            avg_loss += loss.item()
            avg_l1 += l1.item()
            avg_l2 += l2.item()
            avg_l3 += l3.item()

            if t % config.LOG_INTERVAL == 0:
                logger.info(
                    "[ Iter: {}; loss: {:.3f} ]".format(t, loss.item()),
                )

            if (
                config.EVAL_SAVE_RESULTS
                and t % config.EVAL_SAVE_RESULTS_INTERVAL == 0
            ):
                result_id = "ckpt_{}_{}".format(
                    checkpoint_index, idx[0].item()
                )
                result_path = os.path.join(
                    self.config.RESULTS_DIR, result_id
                )

                self._save_results(
                    gt_rgb,
                    pred_rgb,
                    gt_seg,
                    pred_seg,
                    gt_depth,
                    pred_depth,
                    result_path,
                )

    avg_loss /= len(eval_loader)
    avg_l1 /= len(eval_loader)
    avg_l2 /= len(eval_loader)
    avg_l3 /= len(eval_loader)

    writer.add_scalar("avg val total loss", avg_loss, checkpoint_index)
    writer.add_scalars(
        "avg val individual_losses",
        {"seg_loss": avg_l1, "ae_loss": avg_l2, "depth_loss": avg_l3},
        checkpoint_index,
    )

    logger.info("[ Average loss: {:.3f} ]".format(avg_loss))
    logger.info("[ Average seg loss: {:.3f} ]".format(avg_l1))
    logger.info("[ Average autoencoder loss: {:.4f} ]".format(avg_l2))
    logger.info("[ Average depth loss: {:.4f} ]".format(avg_l3))
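# NOTE: a hypothetical driver for evaluating saved checkpoints; the glob
# pattern matches the "epoch_{}.ckpt" names written by train() above, but
# the checkpoint directory is invented:
#
#     with TensorboardWriter(config.TENSORBOARD_DIR, flush_secs=30) as writer:
#         ckpts = sorted(glob.glob("data/eqa/checkpoints/epoch_*.ckpt"))
#         for i, ckpt_path in enumerate(ckpts):
#             trainer._eval_checkpoint(ckpt_path, writer, checkpoint_index=i)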
class NavDataset(wds.Dataset):
    """PyTorch dataset for PACMAN-based navigation"""

    def __init__(
        self,
        config: Config,
        env: habitat.Env,
        device: torch.device,
        max_controller_actions: int = 5,
    ):
        """
        Args:
            config: Config
            env: habitat Env
            device: torch.device
            max_controller_actions (int)
        """
        self.config = config.TASK_CONFIG
        self.env = env
        self.episodes = self.env._dataset.episodes
        self.max_controller_actions = max_controller_actions
        self.device = device
        self.sim = self.env.sim

        # sorting and making episode ids consecutive for simpler indexing
        self.sort_episodes()

        self.q_vocab = self.env._dataset.question_vocab
        self.ans_vocab = self.env._dataset.answer_vocab

        self.eval_save_results = config.EVAL_SAVE_RESULTS

        if self.config.DATASET.SPLIT == config.EVAL.SPLIT:
            self.mode = "val"
        else:
            self.mode = "train"

        self.frame_dataset_path = config.FRAME_DATASET_PATH.format(
            split=self.mode
        )

        self.calc_max_length()
        self.restructure_ans_vocab()

        cnn_kwargs = {
            "only_encoder": True,
            "checkpoint_path": config.EQA_CNN_PRETRAIN_CKPT_PATH,
        }
        self.cnn = MultitaskCNN(**cnn_kwargs)
        self.cnn.eval()
        self.cnn.to(self.device)

        self.scene_episode_dict = get_scene_episode_dict(self.episodes)
        self.preprocess_actions()

        if self.mode == "val":
            ctr = 0
            # reassign episode ids so that episodes with the same scene
            # are grouped together
            for scene in tqdm(
                self.scene_episode_dict.keys(),
                desc="going through all scenes from dataset",
            ):
                for episode in self.scene_episode_dict[scene]:
                    episode.episode_id = ctr
                    ctr += 1

            self.sort_episodes(consecutive_ids=False)

        group_by_keys = filters.Curried(self.group_by_keys_)
        super().__init__(
            urls=self.frame_dataset_path + ".tar",
            initial_pipeline=[group_by_keys()],
        )

        if not self.cache_exists():
            """
            for each scene > load scene in memory > save frames for each
            episode corresponding to each scene
            """
            logger.info(
                "[ Dataset cache not present / is incomplete. ]"
                "\n[ Saving episode frames to disk. ]"
            )

            logger.info(
                "Number of {} episodes: {}".format(
                    self.mode, len(self.episodes)
                )
            )

            ctr = 0

            for scene in tqdm(
                list(self.scene_episode_dict.keys()),
                desc="Going through all scenes from dataset",
            ):
                self.load_scene(scene)

                for episode in tqdm(
                    self.scene_episode_dict[scene],
                    desc="Saving episode frames for each scene",
                ):
                    pos_queue = episode.shortest_paths[0]
                    self.save_frame_queue(pos_queue, episode.episode_id)

            logger.info("[ Saved all episodes' frames to disk. ]")

            create_tar_archive(
                self.frame_dataset_path + ".tar",
                self.frame_dataset_path,
            )
            logger.info("[ Tar archive created. ]")

            logger.info(
                "[ Deleting dataset folder. This will take a few minutes. ]"
            )
            delete_folder(self.frame_dataset_path)

            logger.info("[ Frame dataset is ready. ]")
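    # NOTE: an illustration (ids invented) of the frame cache this
    # constructor builds, inferred from save_frame_queue below:
    #
    #     <FRAME_DATASET_PATH>.tar
    #     |-- 0000.000.jpg    # episode 0000, frame 000
    #     |-- 0000.001.jpg
    #     |-- ...
    #     `-- 0123.014.jpg    # episode 0123, frame 014
    #
    # webdataset's base_plus_ext splits "0000.000.jpg" into the prefix
    # "0000" and the suffix "000.jpg", so group_by_keys_ below can collect
    # all frames of an episode into a single sample keyed by the
    # zero-padded episode id.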
]") def flat_to_hierarchical_actions(self, actions: List, controller_action_lim: int): assert len(actions) != 0 controller_action_ctr = 0 planner_actions, controller_actions = [1], [] prev_action = 1 pq_idx, cq_idx, ph_idx = [], [], [] ph_trck = 0 for i in range(len(actions)): if actions[i] != prev_action: planner_actions.append(actions[i]) pq_idx.append(i) if i > 0: ph_idx.append(ph_trck) if actions[i] == prev_action: controller_actions.append(1) controller_action_ctr += 1 else: controller_actions.append(0) controller_action_ctr = 0 ph_trck += 1 cq_idx.append(i) prev_action = actions[i] if controller_action_ctr == controller_action_lim - 1: prev_action = False return planner_actions, controller_actions, pq_idx, cq_idx, ph_idx def get_img_features(self, img: Union[np.ndarray, torch.Tensor], preprocess: bool = False) -> torch.Tensor: if preprocess: img = ((torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0).view(1, 3, 256, 256).to(self.device)) with torch.no_grad(): return self.cnn(img) def get_hierarchical_features_till_spawn( self, idx: int, actions: np.ndarray, backtrack_steps: int = 0, max_controller_actions: int = 5, ): action_length = len(actions) pa, ca, pq_idx, cq_idx, ph_idx = self.flat_to_hierarchical_actions( actions=actions, controller_action_lim=max_controller_actions) # count how many actions of same type have been encountered before # starting navigation backtrack_controller_steps = actions[0:action_length - backtrack_steps + 1: # noqa: E203 ][::-1] counter = 0 if len(backtrack_controller_steps) > 0: while (counter <= self.max_controller_actions) and ( counter < len(backtrack_controller_steps) and (backtrack_controller_steps[counter] == backtrack_controller_steps[0])): counter += 1 target_pos_idx = action_length - backtrack_steps controller_step = True if target_pos_idx in pq_idx: controller_step = False pq_idx_pruned = [v for v in pq_idx if v <= target_pos_idx] pa_pruned = pa[:len(pq_idx_pruned) + 1] raw_img_feats = (self.get_img_features( self.frame_queue).cpu().numpy().copy()) controller_img_feat = torch.from_numpy( raw_img_feats[target_pos_idx].copy()) controller_action_in = pa_pruned[-1] - 2 planner_img_feats = torch.from_numpy( raw_img_feats[pq_idx_pruned].copy()) planner_actions_in = torch.from_numpy(np.array(pa_pruned[:-1]) - 1) init_pos = self.episodes[idx].shortest_paths[0][target_pos_idx] return ( planner_actions_in, planner_img_feats, controller_step, controller_action_in, controller_img_feat, init_pos, counter, ) def preprocess_actions(self) -> None: """ actions before - 0 - FWD; 1 - LEFT; 2 - RIGHT; 3 - STOP; actions after - 0 - NULL; 1 - START; 2 - FWD; 3 - LEFT; 4 - RIGHT; 5 - STOP; """ for ep in self.episodes: ep.actions = [x.action + 2 for x in ep.shortest_paths[0]] ep.action_length = len(ep.actions) ( planner_actions, controller_actions, pq_idx, cq_idx, ph_idx, ) = self.flat_to_hierarchical_actions( actions=ep.actions, controller_action_lim=self.max_controller_actions, ) # padding actions with 0 diff = self.max_action_len - ep.action_length for _ in range(diff): ep.actions.append(0) ep.actions = torch.Tensor(ep.actions) ep.planner_actions = ep.actions.clone().fill_(0) ep.controller_actions = ep.actions.clone().fill_(-1) ep.planner_hidden_idx = ep.actions.clone().fill_(0) ep.planner_pos_queue_idx, ep.controller_pos_queue_idx = [], [] ep.planner_actions[:len(planner_actions)] = torch.Tensor( planner_actions) ep.controller_actions[:len(controller_actions)] = torch.Tensor( controller_actions) ep.planner_action_length = len(planner_actions) - 1 
    def group_by_keys_(
        self,
        data: Generator,
        keys: Callable[[str], Tuple[str]] = base_plus_ext,
        lcase: bool = True,
        suffixes=None,
    ):
        """Returns function over iterator that groups key, value pairs into
        samples: a custom pipeline for grouping episode info & images in
        the webdataset.

        keys: function that splits the key into key and extension
            (base_plus_ext)
        lcase: convert suffixes to lower case (Default value = True)
        """
        current_sample = {}
        for fname, value in data:
            prefix, suffix = keys(fname)
            if prefix is None:
                continue
            if lcase:
                suffix = suffix.lower()
            if not current_sample or prefix != current_sample["__key__"]:
                if valid_sample(current_sample):
                    yield current_sample
                current_sample = dict(__key__=prefix)

                episode_id = int(prefix[prefix.rfind("/") + 1 :])
                current_sample["episode_id"] = self.episodes[
                    episode_id
                ].episode_id

                question = self.episodes[episode_id].question.question_tokens
                if len(question) < self.max_q_len:
                    diff = self.max_q_len - len(question)
                    for _ in range(diff):
                        question.append(0)

                current_sample["question"] = np.array(question, dtype=np.int_)
                current_sample["answer"] = self.ans_vocab.word2idx(
                    self.episodes[episode_id].question.answer_text
                )
            if suffix in current_sample:
                raise ValueError(
                    f"{fname}: duplicate file name in tar file "
                    f"{suffix} {current_sample.keys()}"
                )
            if suffixes is None or suffix in suffixes:
                current_sample[suffix] = value

        if valid_sample(current_sample):
            yield current_sample

    def calc_max_length(self) -> None:
        r"""Calculates max length of questions and actions.

        This will be used for padding questions and actions with 0s so
        that they all have the same length.
        """
        self.max_q_len = max(
            len(episode.question.question_tokens)
            for episode in self.episodes
        )
        self.max_action_len = max(
            len(episode.shortest_paths[0]) for episode in self.episodes
        )

    def restructure_ans_vocab(self) -> None:
        r"""Restructures answer vocab so that each answer id corresponds
        to a numerical index starting from 0 for the first answer.
        """
        for idx, key in enumerate(
            sorted(self.ans_vocab.word2idx_dict.keys())
        ):
            self.ans_vocab.word2idx_dict[key] = idx
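    # NOTE: a small invented illustration of restructure_ans_vocab:
    #
    #     before: {"bedroom": 7, "brown": 23, "kitchen": 4}
    #     after:  {"bedroom": 0, "brown": 1, "kitchen": 2}
    #
    # i.e. answers are re-indexed 0..N-1 in sorted key order, so the ids
    # can serve directly as class indices for an answer classifier.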
    def get_vocab_dicts(self) -> Tuple[VocabDict, VocabDict]:
        r"""Returns Q&A VocabDicts"""
        return self.q_vocab, self.ans_vocab

    def sort_episodes(self, consecutive_ids: bool = True) -> None:
        # TODO: can be done in mp3d_eqa_dataset class too?
        self.episodes = sorted(
            self.episodes, key=lambda x: int(x.episode_id)
        )
        if consecutive_ids:
            for idx, ep in enumerate(self.episodes):
                ep.episode_id = idx

    def save_frame_queue(
        self,
        pos_queue: List[ShortestPathPoint],
        episode_id: str,
    ) -> None:
        r"""Writes episode's frame queue to disk."""
        for idx, pos in enumerate(pos_queue):
            observation = self.env.sim.get_observations_at(
                pos.position, pos.rotation
            )
            img = observation["rgb"]
            idx = "{0:0=3d}".format(idx)
            episode_id = "{0:0=4d}".format(int(episode_id))
            new_path = os.path.join(
                self.frame_dataset_path, "{}.{}".format(episode_id, idx)
            )
            # reverse channel order (RGB -> BGR) for cv2 before writing
            cv2.imwrite(new_path + ".jpg", img[..., ::-1])

    def cache_exists(self) -> bool:
        if os.path.exists(self.frame_dataset_path + ".tar"):
            return True
        else:
            os.makedirs(self.frame_dataset_path, exist_ok=True)
            return False

    def load_scene(self, scene: str) -> None:
        self.config.defrost()
        self.config.SIMULATOR.SCENE = scene
        self.config.freeze()
        self.env.sim.reconfigure(self.config.SIMULATOR)
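    # NOTE: rough shape (ids and values invented) of one grouped webdataset
    # sample `x` as it reaches map_dataset_sample below:
    #
    #     {
    #         "__key__": "0042",                      # zero-padded episode id
    #         "episode_id": 42,
    #         "question": np.array([4, 12, 9, 0, 0]), # 0-padded to max_q_len
    #         "answer": 3,
    #         "000.jpg": <frame 0>,
    #         "001.jpg": <frame 1>,
    #         ...
    #     }
    #
    # map_dataset_sample's list(x.values())[4:] relies on this ordering:
    # the first four entries are metadata, everything after is frames.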
    def map_dataset_sample(self, x: Dict) -> Tuple:
        """Mapper function to pre-process webdataset sample, example:
        img features, planner & controller actions etc.

        Args:
            x: webdataset sample containing ep_id, question, answer and imgs

        Returns:
            Processed sample containing img features, planner & controller
            actions etc.
        """
        idx = x["episode_id"]
        question = x["question"]
        answer = x["answer"]

        if len(question) < self.max_q_len:
            diff = self.max_q_len - len(question)
            for _ in range(diff):
                question.append(0)

        self.frame_queue = np.array(
            [img.transpose(2, 0, 1) / 255.0 for img in list(x.values())[4:]]
        )
        self.frame_queue = torch.Tensor(self.frame_queue).to(self.device)

        if self.mode == "val":
            # works only with batch size 1
            actions = self.episodes[idx].actions
            action_length = self.episodes[idx].action_length
            scene = self.episodes[idx].scene_id

            if scene != self.config.SIMULATOR.SCENE:
                logger.info("[ Loading scene - {}]".format(scene))
                self.config.defrost()
                self.config.SIMULATOR.SCENE = scene
                self.config.freeze()
                self.env.sim.reconfigure(self.config.SIMULATOR)

            goal_pos = self.episodes[idx].goals[0].position

            return idx, question, answer, actions, action_length, goal_pos

        planner_actions = self.episodes[idx].planner_actions
        controller_actions = self.episodes[idx].controller_actions

        planner_hidden_idx = self.episodes[idx].planner_hidden_idx

        planner_action_length = self.episodes[idx].planner_action_length
        controller_action_length = self.episodes[
            idx
        ].controller_action_length

        raw_img_feats = (
            self.get_img_features(self.frame_queue).cpu().numpy().copy()
        )

        img_feats = np.zeros(
            (self.max_action_len, raw_img_feats.shape[1]), dtype=np.float32
        )
        img_feats[: raw_img_feats.shape[0], :] = raw_img_feats.copy()

        planner_pos_queue_idx = self.episodes[idx].planner_pos_queue_idx
        controller_pos_queue_idx = self.episodes[
            idx
        ].controller_pos_queue_idx

        planner_img_feats = np.zeros(
            (self.max_action_len, img_feats.shape[1]), dtype=np.float32
        )
        planner_img_feats[
            : self.episodes[idx].planner_action_length
        ] = img_feats[tuple(planner_pos_queue_idx)]

        planner_actions_in = planner_actions.clone() - 1
        planner_actions_out = planner_actions[1:].clone() - 2

        planner_actions_in[planner_action_length:].fill_(0)
        planner_mask = planner_actions_out.clone().gt(-1)

        if len(planner_actions_out) > planner_action_length:
            planner_actions_out[planner_action_length:].fill_(0)

        controller_img_feats = np.zeros(
            (self.max_action_len, img_feats.shape[1]), dtype=np.float32
        )
        controller_img_feats[:controller_action_length] = img_feats[
            tuple(controller_pos_queue_idx)
        ]

        controller_actions_in = self.episodes[idx].actions.clone() - 2
        if len(controller_actions_in) > controller_action_length:
            controller_actions_in[controller_action_length:].fill_(0)

        controller_out = controller_actions
        controller_mask = controller_out.clone().gt(-1)
        if len(controller_out) > controller_action_length:
            controller_out[controller_action_length:].fill_(0)

        # zero out forced controller return
        for i in range(controller_action_length):
            if (
                i >= self.max_controller_actions - 1
                and controller_out[i] == 0
                and (
                    self.max_controller_actions == 1
                    or controller_out[
                        i - self.max_controller_actions + 1 : i  # noqa: E203
                    ].sum()
                    == self.max_controller_actions - 1
                )
            ):
                controller_mask[i] = 0

        return (
            idx,
            question,
            answer,
            planner_img_feats,
            planner_actions_in,
            planner_actions_out,
            planner_action_length,
            planner_mask,
            controller_img_feats,
            controller_actions_in,
            planner_hidden_idx,
            controller_out,
            controller_action_length,
            controller_mask,
        )

    def __len__(self) -> int:
        return len(self.episodes)
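# NOTE: a minimal sketch (not part of this file) of how NavDataset might be
# consumed; the .decode("rgb") stage and the batch size are assumptions,
# and `config`, `env`, `device` must already exist:
#
#     dataset = NavDataset(config, env, device)
#     dataset = dataset.decode("rgb").map(dataset.map_dataset_sample)
#     loader = DataLoader(dataset, batch_size=1)  # val mode needs batch size 1
#     for idx, question, answer, *rest in loader:
#         ...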