def __getitem__(self, index):
    im_path = Path(self.root) / "images/{}.jpg".format(self.im_list[index])
    im = Image.open(im_path).convert("RGB")
    name = Path(im_path).stem
    anno_template = str(Path(self.root) / "labels/{}/{}_lbl{:02d}.png")

    # Build a single label map by stacking the per-part annotation masks:
    # later parts overwrite earlier ones wherever their greyscale response
    # exceeds the threshold.
    seg = np.zeros((im.size[1], im.size[0]), dtype=np.uint8)
    for ii in range(1, 11):
        anno_path = anno_template.format(name, name, ii)
        lbl_im = np.array(Image.open(anno_path).convert("L"))
        assert lbl_im.ndim == 2, "expected greyscale"
        seg[lbl_im > self.thresh * 255] = ii
    seg = Image.fromarray(seg, "L")

    seg = self.label_resizer(seg)
    seg = torch.from_numpy(np.array(seg))
    data = self.resizer(im)
    data = self.transforms(data)

    if self.rand_in:
        data = torch.randn(data.shape)

    if self.visualize:
        from torchvision.utils import make_grid
        from utils.visualization import norm_range
        ims = norm_range(make_grid(data)).permute(1, 2, 0).cpu().numpy()
        plt.close("all")
        plt.axis("off")
        fig = plt.figure()  # a new figure window
        ax1 = fig.add_subplot(1, 3, 1)
        ax2 = fig.add_subplot(1, 3, 2)
        ax3 = fig.add_subplot(1, 3, 3)
        ax1.imshow(ims)
        ax2.imshow(label_colormap(seg).numpy())
        if self.downsample_labels:
            sz = tuple([x * self.downsample_labels for x in seg.size()])
            seg_ = np.array(Image.fromarray(seg.numpy()).resize(sz))
        else:
            seg_ = seg
        ax3.imshow(label_colormap(seg_).numpy())
        ax3.imshow(ims, alpha=0.5)
        zs_dispFig()

    return {"data": data, "meta": {"im_path": str(im_path), "lbls": seg}}
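# A minimal sketch (toy values, independent of the dataset class above) of the
# mask-stacking rule used in __getitem__: each part mask overwrites earlier
# ones wherever its greyscale response exceeds thresh * 255.
import numpy as np

thresh = 0.5
lbl_ims = [
    np.array([[0, 200], [0, 0]], dtype=np.uint8),    # part 1
    np.array([[0, 255], [180, 0]], dtype=np.uint8),  # part 2 overlaps part 1
]
seg = np.zeros((2, 2), dtype=np.uint8)
for ii, lbl_im in enumerate(lbl_ims, start=1):
    seg[lbl_im > thresh * 255] = ii
print(seg)  # [[0 2] [2 0]] -> the contested pixel ends up labelled as part 2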
def show_im(im_cv2, title_str):
    """Quick image visualiser.

    Args:
        im_cv2 (ndarray): input image in BGR format
        title_str (str): title header
    """
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    im = im_cv2[:, :, ::-1]  # BGR (OpenCV) -> RGB
    ax.imshow(im)
    plt.title(title_str)
    zs_dispFig()
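# Hypothetical usage of show_im: cv2.imread returns BGR, which the function
# flips to RGB before display. "example.jpg" is a placeholder path.
import cv2

im_bgr = cv2.imread("example.jpg")
if im_bgr is not None:
    show_im(im_bgr, "sanity check")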
def vis_sequence(self, sequence: torch.Tensor):
    """Display the frames of a (C, T, H, W) clip in a grid."""
    columns = 4
    # NOTE: the grid size must be derived from the temporal dimension (the
    # frames are indexed below as sequence[:, j]), with ceiling division so
    # that no frames are dropped
    num_frames = sequence.shape[1]
    rows = (num_frames + columns - 1) // columns
    figsize = (32, (16 // columns) * rows)
    plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(rows, columns)
    for j in range(num_frames):
        plt.subplot(gs[j])
        plt.axis("off")
        im = sequence[:, j].permute(1, 2, 0).cpu().numpy()
        # rescale to [0, 1] for display (guard against all-zero frames)
        im = im - im.min()
        im = im / max(im.max(), 1E-8)
        plt.imshow(im)
    if self.disp_fig:
        from zsvision.zs_iterm import zs_dispFig
        zs_dispFig()
def parse_video_content(video_idx, video_path, store_compressed, vis,
                        resize_res, total_videos, processes):
    frames = []
    markers = 100
    # log progress (approximately `markers` times in total) when running with
    # multiple processes
    if processes > 1 and video_idx % int(max(total_videos, markers) / markers) == 0:
        pct = 100 * video_idx / total_videos
        print(f"processing {video_idx}/{total_videos} [{pct:.1f}%] [{video_path}]")
    cap = cv2.VideoCapture(str(video_path))
    orig_dims = None
    while True:
        ret, rgb = cap.read()
        if not ret:
            break
        # BGR (OpenCV) to RGB
        rgb = rgb[:, :, [2, 1, 0]]
        if store_compressed:
            buffer = io.BytesIO()
            im = Image.fromarray(rgb)
            orig_dims = im.size
            resized = im.resize((resize_res, resize_res))
            resized.save(buffer, format="JPEG", quality=store_compressed)
            if vis:
                plt.imshow(resized)
                zs_dispFig()
            rgb = buffer
        else:
            # apply Gul-style preprocessing: scale the shorter side to
            # resize_res, preserving the aspect ratio
            iH, iW, iC = rgb.shape
            if iW > iH:
                nH, nW = resize_res, int(resize_res * iW / iH)
            else:
                nH, nW = int(resize_res * iH / iW), resize_res
            orig_dims = (iH, iW)
            rgb = resize_generic(rgb, nH, nW, interp="bilinear")
        frames.append(rgb)
    cap.release()
    # NOTE: record the frame count before the ndarray conversion, since the
    # truthiness of a multi-element ndarray (`if frames:`) is ambiguous and
    # raises a ValueError
    num_frames = len(frames)
    if not store_compressed:
        frames = np.array(frames)
    store = {"data": frames, "orig_dims": orig_dims}
    if num_frames:
        assert orig_dims is not None, "expected orig_dims to be set"
    return store
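# A standalone sketch of the "Gul-style" resize rule used above: the shorter
# side is scaled to resize_res and the longer side preserves the aspect ratio.
def preproc_dims(iH, iW, resize_res):
    if iW > iH:
        return resize_res, int(resize_res * iW / iH)
    return int(resize_res * iH / iW), resize_res

assert preproc_dims(480, 640, 256) == (256, 341)  # landscape input
assert preproc_dims(640, 480, 256) == (341, 256)  # portrait input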
def retrieval_as_classification(sims, query_masks=None):
    """Compute classification metrics from a similarity matrix."""
    assert sims.ndim == 2, "expected a matrix"
    assert query_masks is not None, "query_masks are required for this metric"

    # switch axes of query-labels and video
    sims = sims.T
    query_masks = query_masks.T
    dists = -sims
    num_queries, num_labels = sims.shape
    break_ties = "averaging"

    query_ranks = []
    for ii in range(num_queries):
        row_dists = dists[ii, :]
        # NOTE: Using distance subtraction to perform the ranking is easier to make
        # deterministic than using argsort, which suffers from the issue of defining
        # "stability" for equal distances. Example of distance subtraction code:
        # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py
        sorted_dists = np.sort(row_dists)

        label_ranks = []
        for gt_label in np.where(query_masks[ii, :])[0]:
            ranks = np.where((sorted_dists - row_dists[gt_label]) == 0)[0]
            if break_ties == "optimistically":
                rank = ranks[0]
            elif break_ties == "averaging":
                # NOTE: If there is more than one caption per video, it's possible
                # for the method to do "worse than chance" in the degenerate case
                # when all similarities are tied. TODO(Samuel): Address this case.
                rank = ranks.mean()
            else:
                raise ValueError(f"unknown tie-breaking method: {break_ties}")
            label_ranks.append(rank)
        # Avoid penalising for assigning higher similarity to other gt labels. This
        # is done by subtracting out the better ranked query labels. Note that this
        # step introduces a slight skew in favour of videos with lots of labels. We
        # can address this later with a normalisation step if needed.
        label_ranks = [x - idx for idx, x in enumerate(label_ranks)]

        # Include all labels in the final calculation
        query_ranks.extend(label_ranks)
    query_ranks = np.array(query_ranks)
    return cols2metrics(query_ranks, num_queries=len(query_ranks))
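# A toy check (toy values only) of the averaged-rank tie-breaking above: with
# all similarities equal, each of four relevant labels receives the mid rank
# of the 4-way tie, before the better-ranked labels are subtracted out.
import numpy as np

dists = -np.ones(4)  # one query row, four tied labels
sorted_dists = np.sort(dists)
ranks = [np.where(sorted_dists - dists[g] == 0)[0].mean() for g in range(4)]
print(ranks)                                     # [1.5, 1.5, 1.5, 1.5]
print([x - idx for idx, x in enumerate(ranks)])  # [1.5, 0.5, -0.5, -1.5]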
def v2t_metrics(sims, query_masks=None):
    """Compute retrieval metrics from a similarity matrix.

    Args:
        sims (th.Tensor): N x M matrix of similarities between embeddings, where
             x_{i,j} = <text_embd[i], vid_embed[j]>
        query_masks (th.Tensor): mask any missing captions from the dataset

    Returns:
        (dict[str:float]): retrieval metrics

    NOTES: We find the closest "GT caption" in the style of VSE, which corresponds
    to finding the rank of the closest relevant caption in embedding space:
    github.com/ryankiros/visual-semantic-embedding/blob/master/evaluation.py#L52-L56
    """
    # switch axes of text and video
    sims = sims.T

    assert sims.ndim == 2, "expected a matrix"
    num_queries, num_caps = sims.shape
    dists = -sims
    caps_per_video = num_caps // num_queries
    break_ties = "averaging"
    MISSING_VAL = 1E8

    query_ranks = []
    for ii in range(num_queries):
        row_dists = dists[ii, :]
        if query_masks is not None:
            # Set missing queries to have a distance of infinity. A missing query
            # refers to a query position `n` for a video that had less than `n`
            # captions (for example, a few MSRVTT videos only have 19 queries)
            row_dists[np.logical_not(query_masks.reshape(-1))] = MISSING_VAL

        # NOTE: Using distance subtraction to perform the ranking is easier to make
        # deterministic than using argsort, which suffers from the issue of defining
        # "stability" for equal distances. Example of distance subtraction code:
        # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py
        sorted_dists = np.sort(row_dists)

        min_rank = np.inf
        for jj in range(ii * caps_per_video, (ii + 1) * caps_per_video):
            if row_dists[jj] == MISSING_VAL:
                # skip rankings of missing captions
                continue
            ranks = np.where((sorted_dists - row_dists[jj]) == 0)[0]
            if break_ties == "optimistically":
                rank = ranks[0]
            elif break_ties == "averaging":
                # NOTE: If there is more than one caption per video, it's possible
                # for the method to do "worse than chance" in the degenerate case
                # when all similarities are tied. TODO(Samuel): Address this case.
                rank = ranks.mean()
            else:
                raise ValueError(f"unknown tie-breaking method: {break_ties}")
            if rank < min_rank:
                min_rank = rank
        query_ranks.append(min_rank)
    query_ranks = np.array(query_ranks)
    return cols2metrics(query_ranks, num_queries)
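# The toy example from an old debugging branch of v2t_metrics, kept as a
# standalone sanity check (assumes v2t_metrics and its cols2metrics helper are
# in scope). The matrix below is built video-by-caption, so it is transposed
# to match the documented caption-by-video input, which the function then
# re-transposes internally.
import numpy as np

sims = np.ones((3, 3))
sims[0, 0] = 2   # video 0: unique best caption -> rank 0
sims[1, 1] = 2   # video 1: unique best caption -> rank 0
sims[2, :] = 2   # video 2: three-way tie -> averaged rank 1.0
print(v2t_metrics(sims.T, query_masks=None))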
def t2v_metrics(sims, query_masks=None):
    """Compute retrieval metrics from a similarity matrix.

    Args:
        sims (th.Tensor): N x M matrix of similarities between embeddings, where
             x_{i,j} = <text_embd[i], vid_embed[j]>
        query_masks (th.Tensor): mask any missing queries from the dataset (two
             videos in MSRVTT only have 19, rather than 20 captions)

    Returns:
        (dict[str:float]): retrieval metrics
    """
    assert sims.ndim == 2, "expected a matrix"
    num_queries, num_vids = sims.shape
    dists = -sims
    sorted_dists = np.sort(dists, axis=1)

    # The indices are computed such that they slice out the ground truth distances
    # from the pseudo-rectangular dist matrix
    queries_per_video = num_queries // num_vids
    gt_idx = [[np.ravel_multi_index([ii, jj], (num_queries, num_vids))
               for ii in range(jj * queries_per_video, (jj + 1) * queries_per_video)]
              for jj in range(num_vids)]
    gt_idx = np.array(gt_idx)
    gt_dists = dists.reshape(-1)[gt_idx.reshape(-1)]
    gt_dists = gt_dists[:, np.newaxis]
    rows, cols = np.where((sorted_dists - gt_dists) == 0)  # find column position of GT

    # --------------------------------
    # NOTE: Breaking ties
    # --------------------------------
    # We sometimes need to break ties (in general, these should occur extremely
    # rarely, but there are pathological cases when they can distort the scores,
    # such as when the similarity matrix is all zeros). Previous implementations
    # (e.g. the t2i evaluation functions used here:
    # https://github.com/niluthpol/multimodal_vtt/blob/master/evaluation.py and
    # here: https://github.com/linxd5/VSE_Pytorch/blob/master/evaluation.py#L87)
    # generally break ties "optimistically". However, if the similarity matrix is
    # constant this can evaluate to a perfect ranking. A principled option is to
    # average over all possible partial orderings implied by the ties. See this
    # paper for a discussion:
    #   McSherry, Frank, and Marc Najork,
    #   "Computing information retrieval performance measures efficiently in the
    #   presence of tied scores." European conference on information retrieval.
    #   Springer, Berlin, Heidelberg, 2008.
    # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.145.8892&rep=rep1&type=pdf
    break_ties = "averaging"

    if rows.size > num_queries:
        assert np.unique(rows).size == num_queries, "issue in metric evaluation"
        if break_ties == "optimistically":
            _, idx = np.unique(rows, return_index=True)
            cols = cols[idx]
        elif break_ties == "averaging":
            # fast implementation, based on this code:
            # https://stackoverflow.com/a/49239335
            locs = np.argwhere((sorted_dists - gt_dists) == 0)

            # Find the split indices
            steps = np.diff(locs[:, 0])
            splits = np.nonzero(steps)[0] + 1
            splits = np.insert(splits, 0, 0)

            # Compute the result columns
            summed_cols = np.add.reduceat(locs[:, 1], splits)
            counts = np.diff(np.append(splits, locs.shape[0]))
            avg_cols = summed_cols / counts
            cols = avg_cols

    msg = "expected ranks to match queries ({} vs {})"
    assert cols.size == num_queries, msg.format(cols.size, num_queries)

    if query_masks is not None:
        # remove invalid queries
        assert query_masks.size == num_queries, "invalid query mask shape"
        cols = cols[query_masks.reshape(-1).astype(bool)]
        assert cols.size == query_masks.sum(), "masking was not applied correctly"
        # update number of queries to account for those that were missing
        num_queries = query_masks.sum()

    return cols2metrics(cols, num_queries)
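# A standalone check of the vectorised tie-averaging above (np.add.reduceat
# over the tie groups), mirroring the "slow but interpretable" version that
# used to live in the debugging branch. Values are toy examples.
import numpy as np

rows = np.array([0, 0, 1, 2, 2, 2])  # query index for each tied hit
cols = np.array([3, 4, 0, 1, 2, 3])  # candidate rank positions
locs = np.stack([rows, cols], axis=1)
splits = np.insert(np.nonzero(np.diff(locs[:, 0]))[0] + 1, 0, 0)
summed = np.add.reduceat(locs[:, 1], splits)
counts = np.diff(np.append(splits, locs.shape[0]))
avg_fast = summed / counts
avg_slow = np.array([cols[rows == q].mean() for q in np.unique(rows)])
assert np.array_equal(avg_fast, avg_slow)
print(avg_fast)  # [3.5 0.  2. ]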
        # NOTE: this snippet begins mid-way through the per-class detection loop
        # of the demo script (the leading `np.hstack` call is reconstructed from
        # its truncated tail)
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)

cfg.TEST.HAS_RPN = True  # Use RPN for proposals
if net_name == 'ResNet-50':
    pt = 'resnet50_rfcn_pascal_debug.proto'
    model = 'resnet50_rfcn_final.caffemodel'
else:
    raise ValueError('{} not recognised'.format(net_name))
prototxt = pjoin(model_dir, 'proto', pt)
caffemodel = pjoin(model_dir, 'weights', model)

if not os.path.isfile(caffemodel):
    raise IOError(('{:s} not found.\nDid you run ./data/script/'
                   'fetch_faster_rcnn_models.sh?').format(caffemodel))

# psroi pooling only works on gpu
caffe.set_mode_gpu()
caffe.set_device(2)
cfg.GPU_ID = 2
net = caffe.Net(prototxt, caffemodel, caffe.TEST)
demo(net, im_path, blob_save_path)
plt.show()
zs_dispFig()
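# The `nms` call above comes from the compiled R-FCN/Faster R-CNN utilities; a
# pure-NumPy sketch of the same greedy IoU suppression is shown here for
# reference, under the usual convention that dets is (N, 5) rows of
# [x1, y1, x2, y2, score].
import numpy as np

def nms_numpy(dets, thresh):
    x1, y1, x2, y2, scores = dets.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # process boxes from highest score down
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # intersection of the kept box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0, xx2 - xx1 + 1) * np.maximum(0, yy2 - yy1 + 1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # suppress boxes that overlap the kept box too strongly
        order = order[np.where(iou <= thresh)[0] + 1]
    return keep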
def __getitem__(self, index):
    if not self.use_ims and not self.use_keypoints:
        # early exit when caching is used
        return {"data": torch.zeros(3, 1, 1), "meta": {"index": index}}

    if self.use_ims:
        im = Image.open(os.path.join(self.subdir, self.filenames[index]))
    kp = None
    if self.use_keypoints:
        kp = self.keypoints[index].copy()
    meta = {}

    if self.warper is not None:
        if self.warper.returns_pairs:
            im1 = self.initial_transforms(im.convert("RGB"))
            im1 = TF.to_tensor(im1) * 255
            im1, im2, flow, grid, kp1, kp2 = self.warper(im1, keypts=kp, crop=self.crop)

            im1 = im1.to(torch.uint8)
            im2 = im2.to(torch.uint8)
            im1 = TF.to_pil_image(im1)
            im2 = TF.to_pil_image(im2)
            im1 = self.transforms(im1)
            im2 = self.transforms(im2)

            C, H, W = im1.shape
            data = torch.stack((im1, im2), 0)
            meta = {
                'flow': flow[0],
                'grid': grid[0],
                'im1': im1,
                'im2': im2,
                'index': index,
            }
            if self.use_keypoints:
                meta = {**meta, **{'kp1': kp1, 'kp2': kp2}}
        else:
            im1 = self.initial_transforms(im.convert("RGB"))
            im1 = TF.to_tensor(im1) * 255
            im1, kp = self.warper(im1, keypts=kp, crop=self.crop)
            im1 = im1.to(torch.uint8)
            im1 = TF.to_pil_image(im1)
            im1 = self.transforms(im1)
            C, H, W = im1.shape
            data = im1
            if self.use_keypoints:
                meta = {
                    'keypts': kp,
                    'keypts_normalized': kp_normalize(H, W, kp),
                    'index': index,
                }
    else:
        if self.use_ims:
            data = self.transforms(self.initial_transforms(im.convert("RGB")))
            if self.crop != 0:
                data = data[:, self.crop:-self.crop, self.crop:-self.crop]
            C, H, W = data.shape
        else:
            # after caching descriptors, there is no point doing I/O
            H = W = self.imwidth - 2 * self.crop
            data = torch.zeros(3, 1, 1)
        if kp is not None:
            kp = kp - self.crop
            kp = torch.tensor(kp)
        if self.use_keypoints:
            meta = {
                'keypts': kp,
                'keypts_normalized': kp_normalize(H, W, kp),
                'index': index,
            }

    if self.visualize:
        from utils.visualization import norm_range
        num_show = 2 if self.warper and self.warper.returns_pairs else 1
        plt.clf()
        fig = plt.figure()
        for ii in range(num_show):
            im_ = data[ii] if num_show > 1 else data
            ax = fig.add_subplot(1, num_show, ii + 1)
            ax.imshow(norm_range(im_).permute(1, 2, 0).cpu().numpy())
            if self.use_keypoints:
                if num_show == 2:
                    kp_x = meta["kp{}".format(ii + 1)][:, 0].numpy()
                    kp_y = meta["kp{}".format(ii + 1)][:, 1].numpy()
                else:
                    kp_x = kp[:, 0].numpy()
                    kp_y = kp[:, 1].numpy()
                ax.scatter(kp_x, kp_y)
        zs_dispFig()

    return {"data": data, "meta": meta}
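# kp_normalize is imported from this repo's utils; a likely-equivalent sketch
# (an assumption, shown for reference only) maps pixel keypoints into the
# [-1, 1] coordinate convention used by F.grid_sample.
import torch

def kp_normalize_sketch(H, W, kp):
    kp = kp.clone().float()
    kp[..., 0] = 2 * kp[..., 0] / (W - 1) - 1
    kp[..., 1] = 2 * kp[..., 1] / (H - 1) - 1
    return kp

print(kp_normalize_sketch(100, 100, torch.tensor([[0, 0], [99, 99]])))
# tensor([[-1., -1.], [ 1.,  1.]])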
def evaluation(config, logger=None, eval_data=None):
    device = torch.device('cuda:0' if config["n_gpu"] > 0 else 'cpu')
    if logger is None:
        logger = config.get_logger('test')

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    imwidth = config['dataset']['args']['imwidth']
    root = config["dataset"]["args"]["root"]
    warp_crop_default = config['warper']['args'].get('crop', None)
    crop = config['dataset']['args'].get('crop', warp_crop_default)

    # Want explicit pair warper
    disable_warps = True
    dense_match = config.get("dense_match", False)
    if dense_match and disable_warps:
        warp_kwargs = dict(
            warpsd_all=0,
            warpsd_subset=0,
            transsd=0,
            scalesd=0,
            rotsd=0,
            im1_multiplier=1,
            im1_multiplier_aff=1,
        )
    else:
        warp_kwargs = dict(
            warpsd_all=0.001 * .5,
            warpsd_subset=0.01 * .5,
            transsd=0.1 * .5,
            scalesd=0.1 * .5,
            rotsd=5 * .5,
            im1_multiplier=1,
            im1_multiplier_aff=1,
        )
    warper = tps.Warper(imwidth, imwidth, **warp_kwargs)
    if eval_data is None:
        eval_data = config["dataset"]["type"]
    constructor = getattr(module_data, eval_data)

    # handle the case of the MAFL split, which by default will evaluate on Celeba
    kwargs = {"val_split": "mafl"} if eval_data == "CelebAPrunedAligned_MAFLVal" else {}
    val_dataset = constructor(
        train=False,
        pair_warper=warper,
        use_keypoints=True,
        imwidth=imwidth,
        crop=crop,
        root=root,
        **kwargs,
    )

    # NOTE: Since the matching is performed with pairs, we fix the ordering and then
    # use all pairs for datasets with even numbers of images, and all but one for
    # datasets that have odd numbers of images (via drop_last=True)
    data_loader = DataLoader(
        val_dataset,
        batch_size=2,
        collate_fn=dict_coll,
        shuffle=False,
        drop_last=True,
    )

    # build model architecture
    model = get_instance(module_arch, 'arch', config)
    model.summary()

    # load state dict
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(clean_state_dict(state_dict))
    if config['n_gpu'] > 1:
        model = model.module

    model = model.to(device)

    if dense_match:
        warp_dir = Path(config["warp_dir"]) / config["name"]
        warp_dir = warp_dir / "disable_warps{}".format(disable_warps)
        if not warp_dir.exists():
            warp_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(warp_dir)

    model.eval()

    same_errs = []
    diff_errs = []

    torch.manual_seed(0)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            data, meta = batch["data"], batch["meta"]
            if config.get("mini_eval", False) and i > 3:
                break
            data = data.to(device)
            output = model(data)
            descs = output[0]
            descs1 = descs[0::2]  # 1st in pair (more warped)
            descs2 = descs[1::2]  # 2nd in pair
            ims1 = data[0::2].cpu()
            ims2 = data[1::2].cpu()

            im_source = ims1[0]
            im_same = ims2[0]
            im_diff = ims2[1]

            C, imH, imW = im_source.shape
            B, C, H, W = descs1.shape
            stride = imW / W

            desc_source = descs1[0]
            desc_same = descs2[0]
            desc_diff = descs2[1]

            if not dense_match:
                kp1 = meta['kp1']
                kp2 = meta['kp2']
                kp_source = kp1[0]
                kp_same = kp2[0]
                kp_diff = kp2[1]

            if config.get("vis", False):
                fig = plt.figure()  # a new figure window
                ax1 = fig.add_subplot(1, 3, 1)
                ax2 = fig.add_subplot(1, 3, 2)
                ax3 = fig.add_subplot(1, 3, 3)
                ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                ax2.imshow(norm_range(im_same).permute(1, 2, 0))
                ax3.imshow(norm_range(im_diff).permute(1, 2, 0))
                if not dense_match:
                    ax1.scatter(kp_source[:, 0], kp_source[:, 1], c='g')
                    ax2.scatter(kp_same[:, 0], kp_same[:, 1], c='g')
                    ax3.scatter(kp_diff[:, 0], kp_diff[:, 1], c='g')

            # NOTE: the descriptors are matched without L2 normalisation
            fsrc = desc_source.clone()
            fsame = desc_same.clone()
            fdiff = desc_diff.clone()

            if dense_match:
                grid = dense_desc_match(fsrc, fdiff)
                im_warped = F.grid_sample(im_source.view(1, 3, imH, imW), grid)
                im_warped = im_warped.squeeze(0)
                plt.close("all")
                if config["subplots"]:
                    fig = plt.figure()  # a new figure window
                    ax1 = fig.add_subplot(1, 3, 1)
                    ax2 = fig.add_subplot(1, 3, 2)
                    ax3 = fig.add_subplot(1, 3, 3)
                    ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                    ax2.imshow(norm_range(im_diff).permute(1, 2, 0))
                    ax3.imshow(norm_range(im_warped).permute(1, 2, 0))
                    triplet_dest = warp_dir / "triplet-{:05d}.jpg".format(i)
                    fig.savefig(triplet_dest)
                else:
                    triplet_dest_dir = warp_dir / "triplet-{:05d}".format(i)
                    if not triplet_dest_dir.exists():
                        triplet_dest_dir.mkdir(exist_ok=True, parents=True)
                    for jj, im in enumerate((im_source, im_diff, im_warped)):
                        plt.axis("off")
                        fig = plt.figure(figsize=(1.5, 1.5))
                        ax = plt.Axes(fig, [0., 0., 1., 1.])
                        ax.set_axis_off()
                        fig.add_axes(ax)
                        im_ = norm_range(im).permute(1, 2, 0)
                        ax.imshow(im_)
                        dest_path = triplet_dest_dir / "im-{}-{}.jpg".format(i, jj)
                        plt.savefig(str(dest_path), dpi=im_.shape[0])
                writer.add_figure('warp-triplets', fig)
            else:
                for ki, kp in enumerate(kp_source):
                    x, y = np.array(kp)
                    gt_same_x, gt_same_y = np.array(kp_same[ki])
                    gt_diff_x, gt_diff_y = np.array(kp_diff[ki])
                    same_x, same_y = find_descriptor(x, y, fsrc, fsame, stride)

                    err = compute_pixel_err(
                        pred_x=same_x,
                        pred_y=same_y,
                        gt_x=gt_same_x,
                        gt_y=gt_same_y,
                        imwidth=imwidth,
                        crop=crop,
                    )
                    same_errs.append(err)

                    diff_x, diff_y = find_descriptor(x, y, fsrc, fdiff, stride)
                    err = compute_pixel_err(
                        pred_x=diff_x,
                        pred_y=diff_y,
                        gt_x=gt_diff_x,
                        gt_y=gt_diff_y,
                        imwidth=imwidth,
                        crop=crop,
                    )
                    diff_errs.append(err)

                    if config.get("vis", False):
                        ax2.scatter(same_x, same_y, c='b')
                        ax3.scatter(diff_x, diff_y, c='b')

                if config.get("vis", False):
                    zs_dispFig()
                    fig.savefig('/tmp/matching.pdf')

    print("")  # cleanup print from tqdm subtraction
    logger.info("Matching Metrics:")
    logger.info(f"Mean Pixel Error (same-identity): {np.mean(same_errs)}")
    logger.info(f"Mean Pixel Error (different-identity): {np.mean(diff_errs)}")
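# `find_descriptor` is defined elsewhere in this repo; a hypothetical sketch of
# the nearest-neighbour lookup it performs: read off the source descriptor at
# (x, y), correlate it with every target location, and map the argmax back to
# pixel coordinates through the feature stride. The clamping and interpolation
# details of the real helper may differ.
import torch

def find_descriptor_sketch(x, y, source, target, stride):
    C, H, W = source.shape
    query = source[:, int(y / stride), int(x / stride)]  # C-dim query vector
    corr = (target * query[:, None, None]).sum(dim=0)    # H x W score map
    ty, tx = divmod(int(corr.argmax()), W)               # unravel flat argmax
    return tx * stride, ty * stride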
def tween_scatter(t, im1, im2, scatter1, scatter2, title1, title2,
                  fade_ims=True, heading1=None, heading2=None, frame=None,
                  is_dve=None):
    ax_reset()
    base_subplot = plt.gca()
    plt.sca(base_subplot)
    gridsize = int(np.sqrt(len(im1)))
    inner_grid = matplotlib.gridspec.GridSpec(gridsize, gridsize,
                                              hspace=0.05, wspace=0.05)
    bb = base_subplot.get_position()
    l, b, r, tp = bb.extents
    inner_grid.update(left=l, bottom=b, right=r, top=tp)

    # cross-fade between the two image grids
    if fade_ims:
        prev_alpha = np.maximum(0., 1 - 2 * t)
        cur_alpha = np.maximum(0., -1 + 2 * t)
    else:
        prev_alpha = 0.
        cur_alpha = 1.

    for gi in range(gridsize**2):
        gax = plt.gcf().add_subplot(inner_grid[gi])
        ax_reset()
        if prev_alpha:
            plt.imshow(norm_range(im1[gi]).permute(1, 2, 0), alpha=prev_alpha)
        if cur_alpha:
            plt.imshow(norm_range(im2[gi]).permute(1, 2, 0), alpha=cur_alpha)

        # cosine ease-in/ease-out for the keypoint positions
        ease = (-np.cos(np.pi * t) + 1) / 2
        scatter_tween = (1 - ease) * scatter1[gi] + ease * scatter2[gi]
        fac = plt.gca().get_position().width / base_subplot.get_position().width
        plt.scatter(scatter_tween[:, 0], scatter_tween[:, 1], c=rainbow,
                    s=(matplotlib.rcParams['lines.markersize'] * fac)**2)

        if frame == args.hq_frame_snapshot and args.save_hq_ims:
            # create temp figure inline
            prev_fig = plt.gcf()
            plt.figure(figsize=(15, 15))
            inline_ax = plt.subplot(1, 1, 1)
            plt.sca(inline_ax)
            plt.imshow(norm_range(im1[gi]).permute(1, 2, 0))
            plt.xticks([], [])
            plt.yticks([], [])
            fname = f"frame{frame}-match-face{gi}"
            if is_dve is not None and is_dve:
                fname += "-dve"
            plt.scatter(scatter2[gi][:, 0], scatter2[gi][:, 1], c=rainbow,
                        s=(matplotlib.rcParams['lines.markersize'] * 8)**2)
            plt.savefig(str(Path(args.fig_dir) / f"{fname}.png"))
            zs_dispFig()
            # return to prev figure
            plt.figure(prev_fig.number)

    plt.sca(base_subplot)
    ttl1 = plt.text(0.5, -.08, title1, transform=base_subplot.transAxes,
                    horizontalalignment='center')
    ttl2 = plt.text(0.5, -.08, title2, transform=base_subplot.transAxes,
                    horizontalalignment='center')
    if title1 == title2:
        ttl1.set_alpha(1)
        ttl2.set_alpha(0)
    else:
        ttl1.set_alpha(1 - t)
        ttl2.set_alpha(t)

    if heading2 is not None:
        h1 = plt.suptitle(heading1, x=0.5, y=0.94)
        h2 = plt.text(*h1.get_position(), heading2)
        foot = plt.text(
            0.5, 0.08,
            'DVE enables the use of higher dimensional unsupervised embeddings!')
        foot.update_from(h1)
        h1.set_fontsize('x-large')
        h2.update_from(h1)
        # Prevent flashing from font aliasing and alpha - brittle if not
        # monospace, and makes it too bold, though
        cover = ''.join([
            heading1[i] if heading1[i] == heading2[i] else '\u00a0'
            for i in range(min(len(heading1), len(heading2)))
        ])
        hc = plt.text(*h1.get_position(), cover)
        hc.update_from(h1)
        h1.set_alpha(1 - t)
        h2.set_alpha(t)
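# The cosine ease used for the keypoint tween above, shown standalone: it
# starts and ends with zero velocity, which is what makes the scatter
# interpolation look smooth at the frame boundaries.
import numpy as np

t = np.linspace(0, 1, 5)
print(((-np.cos(np.pi * t) + 1) / 2).round(3))  # [0. 0.146 0.5 0.854 1.]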