def load_latest_checkpoint(checkpoint, directory, device):
    ckpts = get_files(directory, "*.ckpt")
    if ckpts:
        last_ckpt = ckpts[-1]
        checkpoint.restore(last_ckpt, device)
    else:
        raise ValueError("No checkpoints found in {}.".format(directory))

def _trim_checkpoints(self):
    """Trim older checkpoints until `max_to_keep` remain."""
    # Get a list of checkpoints in reverse chronological order.
    ckpts = get_files(self.directory, "*.ckpt")[::-1]
    # Remove the oldest checkpoints until `max_to_keep` remain.
    num_remove = len(ckpts) - self.max_to_keep
    while num_remove > 0:
        ckpt_name = ckpts.pop()
        os.remove(ckpt_name)
        num_remove -= 1

def get_embs(embs_path, embodiment_names, num_traj):
    emb_files = file_utils.get_files(
        osp.join(embs_path, "embs"), "*.npy", sort=False
    )
    embs, frames, pos_neg_labels = {}, {}, {}
    for i, emb_file in enumerate(emb_files):
        emb_name = osp.basename(emb_file).split(".")[0]
        if emb_name not in embodiment_names:
            continue
        with open(emb_file, "rb") as f:
            data = np.load(f, allow_pickle=True).item()
        embs[emb_name] = data["embs"][:num_traj]
        frames[emb_name] = data["frames"][:num_traj]
        pos_neg_labels[emb_name] = data["labels"][:num_traj]
    return embs, frames, pos_neg_labels

def restore_or_initialize(self):
    """Restore items in checkpoint from the latest checkpoint file.

    Returns:
        The global iteration step. This is parsed from the latest
        checkpoint file if one is found, else 0 is returned.
    """
    ckpts = get_files(self.directory, "*.ckpt")
    if ckpts:
        last_ckpt = ckpts[-1]
        status = self.checkpoint.restore(last_ckpt, self.device)
        if not status:
            logging.info("Could not restore latest checkpoint file.")
            return 0
        self.latest_checkpoint = last_ckpt
        return int(osp.basename(last_ckpt).split(".")[0])
    return 0

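# Hedged sketch (assumption, not part of the original source): the restore
# logic above expects checkpoint files to be named "<global_step>.ckpt" and
# relies on `get_files` returning them so the last entry is the most recent.
# This standalone snippet illustrates that naming convention with plain glob;
# `ckpt_dir` is a hypothetical path and numeric sorting is an assumption.
import glob
import os.path as osp


def latest_step(ckpt_dir):
    """Return the largest global step encoded in "*.ckpt" filenames, else 0."""
    paths = glob.glob(osp.join(ckpt_dir, "*.ckpt"))
    if not paths:
        return 0
    steps = [int(osp.basename(p).split(".")[0]) for p in paths]
    return max(steps)
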
def main(args):
    plt.switch_backend("Agg")
    emb_files = file_utils.get_files(
        osp.join(args.embs_path, "embs"), "*.npy", sort=False
    )
    print("Found files: {}".format(emb_files))
    for i, emb_file in enumerate(emb_files):
        embs, frames = [], []
        with open(emb_file, "rb") as f:
            query_dict = np.load(f, allow_pickle=True).item()
        for j in range(len(query_dict["embs"])):
            curr_embs = query_dict["embs"][j]
            if args.l2_normalize:
                curr_embs = [x / (np.linalg.norm(x) + 1e-7) for x in curr_embs]
            embs.append(curr_embs)
            frames.append(query_dict["frames"][j])
        # Generate the output video name.
        video_path = osp.join(args.embs_path, "videos")
        file_utils.mkdir(video_path)
        ext = ".mp4" if args.align else "_original.mp4"
        video_path = osp.join(
            video_path, osp.basename(emb_file).split(".")[0] + ext
        )
        if not args.align:
            print("Playing videos without alignment.")
            create_original_videos(frames, video_path, args.interval)
        else:
            print("Aligning videos.")
            create_video(
                embs,
                frames,
                video_path,
                args.use_dtw,
                query=args.reference_video,
                candidate=args.candidate_video,
                interval=args.interval,
            )

def videos_to_images(in_dir, out_dir, num_cores):
    set_start_method("spawn")
    with Pool(num_cores) as pool:
        action_dirs = file_utils.get_subdirs(in_dir)
        for action_dir in action_dirs:
            print("Processing ", osp.basename(action_dir))
            out_action = osp.join(out_dir, osp.basename(action_dir))
            file_utils.mkdir(out_action)
            videos = file_utils.get_subdirs(action_dir)
            for video in videos:
                out_action_video = osp.join(out_action, osp.basename(video))
                file_utils.mkdir(out_action_video)
                files = file_utils.get_files(video, "*.mp4", False)
                func_args = [[f, out_action_video, video] for f in files]
                for _ in tqdm(
                    pool.imap_unordered(process_video, func_args),
                    total=len(files),
                ):
                    pass
                file_utils.copy_file(
                    osp.join(video, "joint_angles.txt"),
                    osp.join(out_action_video, "joint_angles.txt"),
                )

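# Hedged sketch (assumption): `process_video`, consumed by `pool.imap_unordered`
# above, is not shown here. It is expected to take [video_file, out_dir, video_dir]
# and write the extracted frames to disk; a minimal OpenCV-based version might
# look like the following.
import os.path as osp

import cv2


def process_video(func_args):
    video_file, out_dir, _ = func_args
    cap = cv2.VideoCapture(video_file)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        # Zero-pad frame indices so filenames sort lexicographically.
        cv2.imwrite(osp.join(out_dir, "{:06d}.png".format(idx)), frame)
        idx += 1
    cap.release()
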
def _load_frames(self, vid_dirs):
    assert (
        len(vid_dirs) > 1
    ), f"{self.__class__.__name__} can only operate on multiple videos at a time."
    return [get_files(vd, self._pattern) for vd in vid_dirs]

def _load_frames(self, vid_dirs):
    assert (
        len(vid_dirs) == 1
    ), f"{self.__class__.__name__} can only operate on a single video at a time."
    return get_files(vid_dirs[0], self._pattern)

def main(args):
    emb_files = file_utils.get_files(
        osp.join(args.embs_path, "embs"), "*.npy", sort=False
    )
    print(f"Found {len(emb_files)} files.")
    embodiment_names = [
        "two_hands_two_fingers",
        "tongs",
        "rms",
        # "ski_gloves",
        # "quick_grip",
        "one_hand_two_fingers",
        "one_hand_five_fingers",
        "double_quick_grip",
        # "crab",
        # "quick_grasp",
    ]
    all_embs, all_names, phase_labels = [], [], []
    for i, emb_file in enumerate(emb_files):
        emb_name = osp.basename(emb_file).split(".")[0]
        if emb_name not in embodiment_names:
            continue
        embs = []
        with open(emb_file, "rb") as f:
            query_dict = np.load(f, allow_pickle=True).item()
        for j in range(len(query_dict["embs"])):
            if j == args.num_traj:
                break
            curr_embs = query_dict["embs"][j]
            if args.l2_normalize:
                curr_embs = [x / (np.linalg.norm(x) + 1e-7) for x in curr_embs]
            embs.append(curr_embs)
        all_embs.append(embs)
        all_names.append(emb_name)
        phase_labels.append(query_dict["phase_labels"])

    # Split one_hand_five_fingers into two variants based on the number of
    # unique phase labels in each trajectory.
    h_idx = all_names.index("one_hand_five_fingers")
    h_emb = all_embs[h_idx]
    h_phase_labels = phase_labels[h_idx]
    all_names.pop(h_idx)
    all_embs.pop(h_idx)
    emb_one_hand_five_fingers_1 = []
    emb_one_hand_five_fingers_2 = []
    for emb, pl in zip(h_emb, h_phase_labels):
        if len(np.unique(pl)) == 2:
            emb_one_hand_five_fingers_1.append(emb)
        else:
            emb_one_hand_five_fingers_2.append(emb)
    all_names = all_names + [
        "one_hand_five_fingers_without_slide",
        "one_hand_five_fingers_with_slide",
    ]
    all_embs.append(emb_one_hand_five_fingers_1)
    all_embs.append(emb_one_hand_five_fingers_2)

    num_embs = len(all_embs)
    chamfers = np.zeros((num_embs, num_embs))
    for i in range(len(all_embs)):
        query_embs = all_embs[i]
        for j in range(len(all_embs)):
            if i == j:
                continue
            cand_embs = all_embs[j]
            query_chamfer = []
            for query_emb in query_embs:
                kd_tree_query = KDTree(query_emb)
                chamfer_dist = []
                for cand_emb in cand_embs:
                    chamfer_dist.append(
                        symmetric_chamfer(query_emb, cand_emb, kd_tree_query)
                    )
                query_chamfer.append(np.mean(chamfer_dist))
            chamfers[i, j] = np.mean(query_chamfer)

    # Determine the row-wise min (index 1 skips the zero diagonal).
    mins = []
    for i in range(num_embs):
        mins.append(np.argsort(chamfers[i])[1])

    # Plot the heatmap.
    fig, ax = plt.subplots(figsize=(15, 10))
    im = ax.imshow(chamfers, cmap="RdBu", interpolation="nearest")
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Chamfer Distance", rotation=-90, va="bottom")
    ax.set_xticks(np.arange(num_embs))
    ax.set_yticks(np.arange(num_embs))
    ax.set_xticklabels(all_names)
    ax.set_yticklabels(all_names)
    ax.yaxis.set_tick_params(labelsize=7)
    ax.xaxis.set_tick_params(labelsize=7)
    plt.setp(
        ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor"
    )
    for i in range(num_embs):
        for j in range(num_embs):
            if i == j:
                continue
            txt = "{:.5f}".format(chamfers[i, j])
            if j == mins[i]:
                txt += "_min"
            _ = ax.text(
                j, i, txt, ha="center", va="center", color="black", fontsize=7
            )
    task_name = args.embs_path.split("/")[-2]
    ax.set_title(task_name)
    name = "{}_heatmap.png".format(task_name)
    plt.savefig(osp.join("tmp", name), format="png", dpi=300)
    plt.show()

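# Hedged sketch (assumption): `symmetric_chamfer` is not defined in this file.
# A standard symmetric chamfer distance between two embedding sequences, reusing
# the prebuilt KD-tree over the query embeddings, could be written as follows
# (SciPy's KDTree is assumed; the original may use a different implementation).
import numpy as np
from scipy.spatial import KDTree


def symmetric_chamfer_sketch(query_emb, cand_emb, kd_tree_query):
    query_emb = np.asarray(query_emb)
    cand_emb = np.asarray(cand_emb)
    # Mean distance from each candidate frame to its nearest query frame.
    d_cand_to_query, _ = kd_tree_query.query(cand_emb)
    # Mean distance from each query frame to its nearest candidate frame.
    d_query_to_cand, _ = KDTree(cand_emb).query(query_emb)
    return d_cand_to_query.mean() + d_query_to_cand.mean()
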
if args.output_dir is not None:
    file_utils.mkdir(args.output_dir)
else:
    args.output_dir = args.input_dir
frame_dir = osp.join(args.input_dir, "frames")
label_dir = osp.join(args.input_dir, "labels")
if args.resize:
    args.resize = (args.height, args.width)
    logging.info("Resizing images to ({}, {}).".format(*args.resize))
else:
    args.resize = None

# Read all label files and figure out which videos belong to which class.
label_files = file_utils.get_files(label_dir, "*.mat")
label_classes = [loadmat(lab)["action"][0] for lab in label_files]

# Figure out the train/test splits.
train_val_splits = [loadmat(lab)["train"][0][0] for lab in label_files]

# Now figure out at which indices the class change occurs.
change_idxs = []
for i in range(len(label_classes) - 1):
    c_curr = label_classes[i]
    c_next = label_classes[i + 1]
    if c_curr != c_next:
        change_idxs.append(i + 1)
change_idxs = [0, *change_idxs, len(label_files)]

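# Hedged illustration (assumption, not part of the original script): consecutive
# entries of `change_idxs`, padded with 0 and len(label_files) above, delimit the
# label files belonging to a single action class. A helper that groups them could
# look like this.
def group_files_by_class(label_classes, label_files, change_idxs):
    groups = {}
    for start, end in zip(change_idxs[:-1], change_idxs[1:]):
        groups[label_classes[start]] = label_files[start:end]
    return groups
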
def main(args):
    # Get a list of embedding files, one for each action class.
    emb_files = file_utils.get_files(
        osp.join(args.embs_path, "embs"), "*.npy", sort=False
    )
    print(f"Found {len(emb_files)} files.")
    all_embs, all_names = [], []
    for i, emb_file in enumerate(emb_files):
        embs = []
        with open(emb_file, "rb") as f:
            query_dict = np.load(f, allow_pickle=True).item()
        for j in range(len(query_dict["embs"])):
            curr_embs = query_dict["embs"][j]
            if args.l2_normalize:
                curr_embs = [x / (np.linalg.norm(x) + 1e-7) for x in curr_embs]
            embs.append(curr_embs)
        all_embs.append(embs)
        all_names.append(osp.basename(emb_file).split(".")[0])

    query_name = "rms"
    query_idx = all_names.index(query_name)
    query_embs = all_embs[query_idx]
    all_embs.pop(query_idx)
    all_names.pop(query_idx)

    # Pick a query trajectory.
    emb_query = query_embs[2]  # np.random.choice(query_embs)
    kd_tree_query = KDTree(emb_query)

    # Loop through all other embeddings and find the closest trajectory
    # and its class.
    # chamfer_dists = []
    class_min = -1
    traj_min = -1
    chamfer_dist = 100
    for i, emb_cands in enumerate(all_embs):
        for t, emb_cand in enumerate(emb_cands):
            dist = symmetric_chamfer(emb_query, emb_cand, kd_tree_query)
            if dist < chamfer_dist:
                class_min = i
                traj_min = t
                chamfer_dist = dist
    # chamfer_dists = np.array(chamfer_dists)
    # class_min = chamfer_dists.min(axis=1).argmin()
    # traj_min = chamfer_dists.argmin(axis=1)[class_min]
    print(all_names[class_min], traj_min)

    # Artificially increase / decrease some trajectories.
    # emb_best = all_embs[class_min][traj_min]
    emb_bad = all_embs[all_names.index("hand_5_fingers")][-9]
    # emb_best_mod = modify_emb(emb_best)
    # emb_bad_mod = modify_emb(emb_bad)

    # Compute the dense reward between the query embedding and the
    # closest one as determined above.
    # reward_best = dense_reward(emb_query, emb_best, True)
    reward_bad = dense_reward(emb_query, emb_bad, False)
    # reward_best_mod = dense_reward(emb_query, emb_best_mod)
    # reward_bad_mod = dense_reward(emb_query, emb_bad_mod)

    plt.figure()
    # plt.plot(reward_best, label='best')
    plt.plot(reward_bad, label="random")
    # plt.plot(reward_best_mod, label='best-updownsampled')
    # plt.plot(reward_bad_mod, label='random-updownsampled')
    plt.xlabel("Frame Index")
    plt.ylabel("Reward")
    plt.legend()
    plt.savefig("./reward_rms.png", format="png", dpi=300)
    plt.show()

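# Hedged sketch (assumption): `dense_reward` is not defined in this file. One
# common embedding-based dense reward is the negative distance from each frame
# of a trajectory to its nearest neighbor in the query trajectory; the extra
# boolean flag passed in the calls above is not modeled here.
import numpy as np
from scipy.spatial import KDTree


def dense_reward_sketch(query_emb, cand_emb):
    """Per-frame reward: negative nearest-neighbor distance to the query."""
    dists, _ = KDTree(np.asarray(query_emb)).query(np.asarray(cand_emb))
    return -dists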