Beispiel #1
0
    def __getitem__(self, index):
        """Load one image together with its fused multi-part label map.

        Part labels ``ii`` in 1..10 are read from
        ``<root>/labels/<name>/<name>_lbl<ii:02d>.png``. Pixels whose greyscale
        value exceeds ``self.thresh * 255`` receive label ``ii``; where part
        masks overlap, later labels overwrite earlier ones. Unlabelled pixels
        stay 0.

        Args:
            index (int): position in ``self.im_list``.

        Returns:
            dict: ``{"data": <transformed image>, "meta": {"im_path": str,
            "lbls": <label tensor>}}``.
        """
        im_path = Path(self.root) / "images/{}.jpg".format(self.im_list[index])
        im = Image.open(im_path).convert("RGB")
        name = im_path.stem
        anno_template = str(Path(self.root) / "labels/{}/{}_lbl{:02d}.png")
        # PIL reports (width, height); the numpy label map is (rows, cols)
        seg = np.zeros((im.size[1], im.size[0]), dtype=np.uint8)
        for ii in range(1, 11):
            anno_path = anno_template.format(name, name, ii)
            lbl_im = np.array(Image.open(anno_path).convert("L"))
            assert lbl_im.ndim == 2, "expected greyscale"
            # assumes self.thresh is a fraction of the 8-bit range — TODO confirm
            seg[lbl_im > self.thresh * 255] = ii

        seg = Image.fromarray(seg, "L")
        seg = self.label_resizer(seg)
        seg = torch.from_numpy(np.array(seg))
        data = self.resizer(im)
        data = self.transforms(data)

        if self.rand_in:
            # replace the input with Gaussian noise of the same shape
            data = torch.randn(data.shape)

        if self.visualize:
            from torchvision.utils import make_grid
            from utils.visualization import norm_range
            ims = norm_range(make_grid(data)).permute(1, 2, 0).cpu().numpy()
            plt.close("all")
            fig = plt.figure()  # a new figure window
            ax1 = fig.add_subplot(1, 3, 1)
            ax2 = fig.add_subplot(1, 3, 2)
            ax3 = fig.add_subplot(1, 3, 3)
            for ax in (ax1, ax2, ax3):
                ax.axis("off")
            ax1.imshow(ims)
            ax2.imshow(label_colormap(seg).numpy())
            if self.downsample_labels:
                # torch sizes are (H, W) but PIL's resize expects (W, H); use
                # nearest-neighbour so that no new (invalid) label values appear
                height, width = seg.size()
                sz = (width * self.downsample_labels,
                      height * self.downsample_labels)
                upsampled = Image.fromarray(seg.numpy()).resize(sz, Image.NEAREST)
                # keep the same tensor type as the branch above for label_colormap
                seg_ = torch.from_numpy(np.array(upsampled))
            else:
                seg_ = seg
            ax3.imshow(label_colormap(seg_).numpy())
            ax3.imshow(ims, alpha=0.5)
            zs_dispFig()

        return {"data": data, "meta": {"im_path": str(im_path), "lbls": seg}}
Beispiel #2
0
def show_im(im_cv2, title_str):
    """Display an OpenCV-style image in a frameless, full-canvas figure.

    Args:
        im_cv2 (ndarray): H x W x 3 image whose channels are in BGR order.
        title_str (str): text placed as the title of the image axes.
    """
    figure = plt.figure(frameon=False)
    # one axes covering the whole canvas, with ticks and frame hidden
    canvas_ax = plt.Axes(figure, [0.0, 0.0, 1.0, 1.0])
    canvas_ax.set_axis_off()
    figure.add_axes(canvas_ax)
    # reverse the channel axis: BGR -> RGB for matplotlib
    rgb = im_cv2[:, :, ::-1]
    canvas_ax.imshow(rgb)
    plt.title(title_str)
    zs_dispFig()
Beispiel #3
0
 def vis_sequence(self, sequence: torch.Tensor):
     """Plot every frame of a clip on a grid with four columns.

     Frames are taken along dimension 1 (``sequence[:, j]``), and each frame is
     permuted from (C, H, W) to (H, W, C) for ``plt.imshow``. The clip is
     therefore expected to be laid out as (C, T, H, W). Each frame is min-max
     normalised independently before display.

     Args:
         sequence (torch.Tensor): clip of shape (C, T, H, W).
     """
     columns = 4
     # frames live on dim 1 (see the indexing below), not dim 0
     num_frames = sequence.shape[1]
     if num_frames == 0:
         return
     # ceil-divide so that a partially filled last row still gets allocated
     rows = (num_frames + columns - 1) // columns
     figsize = (32, (16 // columns) * rows)
     plt.figure(figsize=figsize)
     gs = gridspec.GridSpec(rows, columns)
     # only visit existing frames; the grid may have empty trailing cells
     for j in range(num_frames):
         plt.subplot(gs[j])
         plt.axis("off")
         im = sequence[:, j].permute(1, 2, 0).cpu().numpy()
         im = im - im.min()
         peak = im.max()
         if peak > 0:  # a constant frame would otherwise become 0/0 = NaN
             im = im / peak
         plt.imshow(im)
     if self.disp_fig:
         from zsvision.zs_iterm import zs_dispFig
         zs_dispFig()
Beispiel #4
0
def parse_video_content(video_idx, video_path, store_compressed, vis,
                        resize_res, total_videos, processes):
    """Decode every frame of a video and resize it for storage.

    Args:
        video_idx (int): index of this video (used for progress logging).
        video_path (str or Path): video file, opened with OpenCV.
        store_compressed (int or falsy): if truthy, each frame is resized to
            ``resize_res x resize_res`` and JPEG-encoded into an
            ``io.BytesIO`` buffer, using this value as the JPEG quality.
            Otherwise each frame's shorter side is resized to ``resize_res``
            (aspect ratio kept) and the frames are stacked into an ndarray.
        vis (bool): in compressed mode, display every resized frame.
        resize_res (int): target resolution in pixels.
        total_videos (int): total number of videos, for progress logging.
        processes (int): number of workers; progress is printed only when > 1.

    Returns:
        dict: ``{"data": frames, "orig_dims": dims}``. ``frames`` is a list of
        ``BytesIO`` buffers (compressed) or an ndarray of RGB frames
        (uncompressed). NOTE: ``orig_dims`` is ``(W, H)`` in compressed mode
        (PIL ``Image.size``) but ``(H, W)`` in uncompressed mode. It is None
        when no frame could be decoded.
    """
    frames = []
    markers = 100
    # print roughly `markers` progress lines over the whole dataset
    if processes > 1 and video_idx % int(
            max(total_videos, markers) / markers) == 0:
        pct = 100 * video_idx / total_videos
        print(
            f"processing {video_idx}/{total_videos} [{pct:.1f}%] [{video_path}]"
        )
    orig_dims = None
    cap = cv2.VideoCapture(str(video_path))
    try:
        while True:
            ret, rgb = cap.read()
            if not ret:
                break
            # BGR (OpenCV) to RGB
            rgb = rgb[:, :, [2, 1, 0]]
            if store_compressed:
                # store_compressed doubles as the JPEG quality setting
                buffer = io.BytesIO()
                im = Image.fromarray(rgb)
                orig_dims = im.size
                resized = im.resize((resize_res, resize_res))
                resized.save(buffer, format="JPEG", quality=store_compressed)
                if vis:
                    plt.imshow(resized)
                    zs_dispFig()
                rgb = buffer
            else:
                # apply Gul-style preproc: resize the shorter side to
                # resize_res while preserving the aspect ratio
                iH, iW, iC = rgb.shape
                if iW > iH:
                    nH, nW = resize_res, int(resize_res * iW / iH)
                else:
                    nH, nW = int(resize_res * iH / iW), resize_res
                orig_dims = (iH, iW)
                rgb = resize_generic(rgb, nH, nW, interp="bilinear")
            frames.append(rgb)
    finally:
        # release the capture handle even if decoding or resizing raises
        cap.release()
    if not store_compressed:
        frames = np.array(frames)
    store = {"data": frames, "orig_dims": orig_dims}
    # use len() rather than truthiness: `frames` may be an ndarray here, and
    # bool() of a multi-element array raises ValueError
    if len(frames):
        assert orig_dims is not None, "expected orig_dims to be set"
    return store
Beispiel #5
0
def retrieval_as_classification(sims, query_masks=None):
    """Compute classification metrics from a similarity matrix.

    Both inputs are transposed first. Each row of ``sims.T`` is a query whose
    ground-truth labels are the non-zero entries of the same row of
    ``query_masks.T``. Each ground-truth label is ranked among all labels.

    Args:
        sims (ndarray): 2-D similarity matrix.
        query_masks (ndarray): same shape as ``sims``; non-zero entries mark
            ground-truth (label, query) pairs. Required despite the ``None``
            default.

    Returns:
        dict: output of ``cols2metrics`` over the ranks of every ground-truth
        pair.

    Raises:
        ValueError: if ``query_masks`` is None.
    """
    assert sims.ndim == 2, "expected a matrix"
    if query_masks is None:
        raise ValueError("retrieval_as_classification requires query_masks")

    # switch axes of query-labels and video
    sims = sims.T
    query_masks = query_masks.T
    dists = -sims
    num_queries = sims.shape[0]
    break_ties = "averaging"

    query_ranks = []
    for ii in range(num_queries):
        row_dists = dists[ii, :]

        # NOTE: Using distance subtraction to perform the ranking is easier to make
        # deterministic than using argsort, which suffers from the issue of defining
        # "stability" for equal distances.  Example of distance subtraction code:
        # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py
        sorted_dists = np.sort(row_dists)

        label_ranks = []
        for gt_label in np.where(query_masks[ii, :])[0]:
            ranks = np.where((sorted_dists - row_dists[gt_label]) == 0)[0]
            if break_ties == "optimistically":
                rank = ranks[0]
            elif break_ties == "averaging":
                # NOTE: If there is more than one caption per video, its possible for the
                # method to do "worse than chance" in the degenerate case when all
                # similarities are tied.  TODO(Samuel): Address this case.
                rank = ranks.mean()
            else:
                raise ValueError(f"unknown tie-breaking method: {break_ties}")
            label_ranks.append(rank)
        # Avoid penalising for assigning higher similarity to other gt labels by
        # subtracting out the better ranked gt labels: the k-th best gt label has
        # exactly k gt labels ahead of it. np.where yields labels in index order,
        # not rank order, so sort first (otherwise ranks can go negative).
        # This step introduces a slight skew in favour of videos with lots of
        # labels; we can address this later with a normalisation step if needed.
        label_ranks = [x - idx for idx, x in enumerate(sorted(label_ranks))]

        # Include all labels in the final calculation
        query_ranks.extend(label_ranks)
    query_ranks = np.array(query_ranks)

    return cols2metrics(query_ranks, num_queries=len(query_ranks))
Beispiel #6
0
def v2t_metrics(sims, query_masks=None):
    """Compute video-to-text retrieval metrics from a similarity matrix.

    Args:
        sims (ndarray): N x M matrix of similarities between embeddings, where
             x_{i,j} = <text_embd[i], vid_embed[j]>. It is transposed so each
             row is a video query. Video ``ii`` owns captions
             ``ii * (N // M)`` .. ``(ii + 1) * (N // M) - 1``.
        query_masks (ndarray, optional): mask any missing captions from the
             dataset. Flattened, it has one entry per caption column; zero
             marks a missing caption.

    Returns:
        (dict[str:float]): retrieval metrics. A video whose captions are all
        missing contributes a rank of ``inf``.

    NOTES: We find the closest "GT caption" in the style of VSE, which corresponds
    to finding the rank of the closest relevant caption in embedding space:
    github.com/ryankiros/visual-semantic-embedding/blob/master/evaluation.py#L52-L56
    """
    # switch axes of text and video
    sims = sims.T

    assert sims.ndim == 2, "expected a matrix"
    num_queries, num_caps = sims.shape
    dists = -sims  # new array, so the in-place masking below never touches `sims`
    caps_per_video = num_caps // num_queries
    break_ties = "averaging"

    MISSING_VAL = 1E8
    # Missing queries get a distance of "infinity". A missing query refers to
    # a query position `n` for a video that had less than `n` captions (for
    # example, a few MSRVTT videos only have 19 queries). The mask is the same
    # for every row, so compute it once.
    missing = None
    if query_masks is not None:
        missing = np.logical_not(query_masks.reshape(-1))

    query_ranks = []
    for ii in range(num_queries):
        row_dists = dists[ii, :]
        if missing is not None:
            row_dists[missing] = MISSING_VAL

        # NOTE: Using distance subtraction to perform the ranking is easier to make
        # deterministic than using argsort, which suffers from the issue of defining
        # "stability" for equal distances.  Example of distance subtraction code:
        # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py
        sorted_dists = np.sort(row_dists)

        min_rank = np.inf
        for jj in range(ii * caps_per_video, (ii + 1) * caps_per_video):
            if row_dists[jj] == MISSING_VAL:
                # skip rankings of missing captions
                continue
            ranks = np.where((sorted_dists - row_dists[jj]) == 0)[0]
            if break_ties == "optimistically":
                rank = ranks[0]
            elif break_ties == "averaging":
                # NOTE: If there is more than one caption per video, its possible for the
                # method to do "worse than chance" in the degenerate case when all
                # similarities are tied.  TODO(Samuel): Address this case.
                rank = ranks.mean()
            else:
                raise ValueError(f"unknown tie-breaking method: {break_ties}")
            # keep the best-ranked ground-truth caption (VSE-style)
            min_rank = min(min_rank, rank)
        query_ranks.append(min_rank)
    query_ranks = np.array(query_ranks)

    return cols2metrics(query_ranks, num_queries)
Beispiel #7
0
def t2v_metrics(sims, query_masks=None):
    """Compute text-to-video retrieval metrics from a similarity matrix.

    Args:
        sims (ndarray): N x M matrix of similarities between embeddings, where
             x_{i,j} = <text_embd[i], vid_embed[j]>. Queries are grouped by
             video: query ``i`` has video ``i // (N // M)`` as its ground
             truth, so N must be a multiple of M.
        query_masks (ndarray, optional): mask any missing queries from the
             dataset (two videos in MSRVTT only have 19, rather than 20
             captions). Must have N entries; non-zero marks a valid query.

    Returns:
        (dict[str:float]): retrieval metrics

    Raises:
        ValueError: if N is not a multiple of M.
    """
    assert sims.ndim == 2, "expected a matrix"
    num_queries, num_vids = sims.shape
    dists = -sims
    sorted_dists = np.sort(dists, axis=1)

    queries_per_video = num_queries // num_vids
    if queries_per_video * num_vids != num_queries:
        raise ValueError(
            f"expected the number of queries ({num_queries}) to be a multiple "
            f"of the number of videos ({num_vids})")

    # Slice out the ground truth distances from the pseudo-rectangular dist
    # matrix: query `ii` is paired with video `ii // queries_per_video`.
    query_idx = np.arange(num_queries)
    gt_dists = dists[query_idx, query_idx // queries_per_video]
    gt_dists = gt_dists[:, np.newaxis]
    rows, cols = np.where((sorted_dists - gt_dists) == 0)  # find column position of GT

    # --------------------------------
    # NOTE: Breaking ties
    # --------------------------------
    # We sometimes need to break ties (in general, these should occur extremely rarely,
    # but there are pathological cases when they can distort the scores, such as when
    # the similarity matrix is all zeros). Previous implementations (e.g. the t2i
    # evaluation function used
    # here: https://github.com/niluthpol/multimodal_vtt/blob/master/evaluation.py and
    # here: https://github.com/linxd5/VSE_Pytorch/blob/master/evaluation.py#L87) generally
    # break ties "optimistically".  However, if the similarity matrix is constant this
    # can evaluate to a perfect ranking. A principled option is to average over all
    # possible partial orderings implied by the ties. See this paper for a discussion:
    #    McSherry, Frank, and Marc Najork,
    #    "Computing information retrieval performance measures efficiently in the presence
    #    of tied scores." European conference on information retrieval. Springer, Berlin,
    #    Heidelberg, 2008.
    # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.145.8892&rep=rep1&type=pdf

    # break_ties = "optimistically"
    break_ties = "averaging"

    if rows.size > num_queries:
        assert np.unique(rows).size == num_queries, "issue in metric evaluation"
        if break_ties == "optimistically":
            _, idx = np.unique(rows, return_index=True)
            cols = cols[idx]
        elif break_ties == "averaging":
            # fast implementation, based on this code:
            # https://stackoverflow.com/a/49239335
            # np.where returns matches in row-major order, so each query's
            # matches are contiguous: split wherever the row index changes.
            splits = np.nonzero(np.diff(rows))[0] + 1
            splits = np.insert(splits, 0, 0)

            # average the tied column positions within each row
            summed_cols = np.add.reduceat(cols, splits)
            counts = np.diff(np.append(splits, rows.size))
            cols = summed_cols / counts

    msg = "expected ranks to match queries ({} vs {}) "
    assert cols.size == num_queries, msg.format(cols.size, num_queries)

    if query_masks is not None:
        # remove invalid queries
        assert query_masks.size == num_queries, "invalid query mask shape"
        # builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
        cols = cols[query_masks.reshape(-1).astype(bool)]
        assert cols.size == query_masks.sum(), "masking was not applied correctly"
        # update number of queries to account for those that were missing
        num_queries = query_masks.sum()

    return cols2metrics(cols, num_queries)
Beispiel #8
0
            (cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)


cfg.TEST.HAS_RPN = True  # Use RPN for proposals

# Single source of truth for the GPU index: caffe's active device and
# cfg.GPU_ID must agree.
GPU_DEVICE_ID = 2

if net_name == 'ResNet-50':
    pt = 'resnet50_rfcn_pascal_debug.proto'
    model = 'resnet50_rfcn_final.caffemodel'
else:
    raise ValueError('{} not recognised'.format(net_name))

prototxt = pjoin(model_dir, 'proto', pt)
caffemodel = pjoin(model_dir, 'weights', model)

if not os.path.isfile(caffemodel):
    # FileNotFoundError subclasses IOError/OSError, so existing handlers still match
    raise FileNotFoundError(('{:s} not found.\nDid you run ./data/script/'
                             'fetch_faster_rcnn_models.sh?').format(caffemodel))

# psroi pooling only works on gpu
caffe.set_mode_gpu()
caffe.set_device(GPU_DEVICE_ID)
cfg.GPU_ID = GPU_DEVICE_ID
net = caffe.Net(prototxt, caffemodel, caffe.TEST)

demo(net, im_path, blob_save_path)
plt.show()
zs_dispFig()
Beispiel #9
0
    def __getitem__(self, index):
        """Return one (optionally warped) sample together with its metadata.

        The behaviour depends on the dataset configuration:

        * neither ``use_ims`` nor ``use_keypoints``: a 3x1x1 zero placeholder
          (early exit used when caching is used);
        * ``warper`` returning pairs: two warped views stacked along dim 0,
          with flow/grid (and keypoints, if enabled) in ``meta``;
        * ``warper`` returning a single view: that warped view;
        * no ``warper``: the transformed (and cropped) image, or a 3x1x1
          placeholder when images are disabled.

        Args:
            index (int): sample index into ``self.filenames`` / ``self.keypoints``.

        Returns:
            dict: ``{"data": tensor, "meta": dict}``. ``meta`` always contains
            ``"index"``.

        Raises:
            ValueError: if a warper is configured but ``use_ims`` is False.
        """
        if not self.use_ims and not self.use_keypoints:
            # early exit when caching is used
            return {"data": torch.zeros(3, 1, 1), "meta": {"index": index}}

        im = None
        if self.use_ims:
            im = Image.open(os.path.join(self.subdir, self.filenames[index]))
        kp = None
        if self.use_keypoints:
            kp = self.keypoints[index].copy()
        # always carry the index, consistent with the early-exit path above
        meta = {"index": index}

        if self.warper is not None:
            if im is None:
                raise ValueError("warping requires use_ims=True")
            # the warper consumes 0-255 valued tensors
            im1 = self.initial_transforms(im.convert("RGB"))
            im1 = TF.to_tensor(im1) * 255
            if self.warper.returns_pairs:
                im1, im2, flow, grid, kp1, kp2 = self.warper(im1,
                                                             keypts=kp,
                                                             crop=self.crop)

                im1 = im1.to(torch.uint8)
                im2 = im2.to(torch.uint8)
                im1 = TF.to_pil_image(im1)
                im2 = TF.to_pil_image(im2)
                im1 = self.transforms(im1)
                im2 = self.transforms(im2)

                data = torch.stack((im1, im2), 0)
                meta = {
                    'flow': flow[0],
                    'grid': grid[0],
                    'im1': im1,
                    'im2': im2,
                    'index': index
                }
                if self.use_keypoints:
                    meta = {**meta, **{'kp1': kp1, 'kp2': kp2}}
            else:
                im1, kp = self.warper(im1, keypts=kp, crop=self.crop)

                im1 = im1.to(torch.uint8)
                im1 = TF.to_pil_image(im1)
                im1 = self.transforms(im1)

                C, H, W = im1.shape
                data = im1
                if self.use_keypoints:
                    meta = {
                        'keypts': kp,
                        'keypts_normalized': kp_normalize(H, W, kp),
                        'index': index
                    }

        else:
            if self.use_ims:
                data = self.transforms(
                    self.initial_transforms(im.convert("RGB")))
                if self.crop != 0:
                    data = data[:, self.crop:-self.crop, self.crop:-self.crop]
                C, H, W = data.shape
            else:
                #  after caching descriptors, there is no point doing I/O
                H = W = self.imwidth - 2 * self.crop
                data = torch.zeros(3, 1, 1)

            if kp is not None:
                # shift keypoints into the cropped image's coordinate frame
                kp = kp - self.crop
                kp = torch.tensor(kp)

            if self.use_keypoints:
                meta = {
                    'keypts': kp,
                    'keypts_normalized': kp_normalize(H, W, kp),
                    'index': index
                }

        if self.visualize:
            from utils.visualization import norm_range
            num_show = 2 if self.warper and self.warper.returns_pairs else 1
            plt.clf()
            fig = plt.figure()
            for ii in range(num_show):
                im_ = data[ii] if num_show > 1 else data
                ax = fig.add_subplot(1, num_show, ii + 1)
                ax.imshow(norm_range(im_).permute(1, 2, 0).cpu().numpy())
                if self.use_keypoints:
                    pts = meta["kp{}".format(ii + 1)] if num_show == 2 else kp
                    ax.scatter(pts[:, 0].numpy(), pts[:, 1].numpy())
                zs_dispFig()
        return {"data": data, "meta": meta}
Beispiel #10
0
def evaluation(config, logger=None, eval_data=None):
    """Evaluate descriptor matching of a trained checkpoint on pairs of images.

    The validation set is loaded two samples per batch (each sample itself a
    warped pair). Two modes are supported:

    * ``dense_match`` false (default): for every ground-truth keypoint of the
      first image, the matching descriptor location is searched for in its
      warped counterpart ("same identity") and in the other sample of the
      batch ("different identity"). The mean pixel errors are logged.
    * ``dense_match`` true: the source image is warped onto the other sample
      via ``dense_desc_match``. The (source, target, warped) triplets are saved
      under ``config["warp_dir"]`` and logged to TensorBoard.

    Args:
        config: parsed configuration. It is indexed like a dict and also
            provides ``get_logger`` and ``_args.resume`` (the checkpoint path).
        logger (logging.Logger, optional): defaults to
            ``config.get_logger('test')``.
        eval_data (str, optional): name of the dataset class in
            ``module_data``; defaults to ``config["dataset"]["type"]``.
    """
    device = torch.device('cuda:0' if config["n_gpu"] > 0 else 'cpu')

    if logger is None:
        logger = config.get_logger('test')

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    imwidth = config['dataset']['args']['imwidth']
    root = config["dataset"]["args"]["root"]
    # a dataset-level crop takes precedence over the warper's crop
    warp_crop_default = config['warper']['args'].get('crop', None)
    crop = config['dataset']['args'].get('crop', warp_crop_default)

    # Want explicit pair warper
    disable_warps = True
    dense_match = config.get("dense_match", False)
    if dense_match and disable_warps:
        # every perturbation std-dev set to zero (presumably an identity warp)
        rotsd = 0
        scalesd = 0
        warp_kwargs = dict(warpsd_all=0,
                           warpsd_subset=0,
                           transsd=0,
                           scalesd=scalesd,
                           rotsd=rotsd,
                           im1_multiplier=1,
                           im1_multiplier_aff=1)
    else:
        warp_kwargs = dict(warpsd_all=0.001 * .5,
                           warpsd_subset=0.01 * .5,
                           transsd=0.1 * .5,
                           scalesd=0.1 * .5,
                           rotsd=5 * .5,
                           im1_multiplier=1,
                           im1_multiplier_aff=1)
    warper = tps.Warper(imwidth, imwidth, **warp_kwargs)
    if eval_data is None:
        eval_data = config["dataset"]["type"]
    constructor = getattr(module_data, eval_data)

    # handle the case of the MAFL split, which by default will evaluate on Celeba
    kwargs = {
        "val_split": "mafl"
    } if eval_data == "CelebAPrunedAligned_MAFLVal" else {}
    val_dataset = constructor(
        train=False,
        pair_warper=warper,
        use_keypoints=True,
        imwidth=imwidth,
        crop=crop,
        root=root,
        **kwargs,
    )
    # NOTE: Since the matching is performed with pairs, we fix the ordering and then
    # use all pairs for datasets with even numbers of images, and all but one for
    # datasets that have odd numbers of images (via drop_last=True)
    data_loader = DataLoader(val_dataset,
                             batch_size=2,
                             collate_fn=dict_coll,
                             shuffle=False,
                             drop_last=True)

    # build model architecture
    model = get_instance(module_arch, 'arch', config)
    model.summary()

    # load state dict; map_location lets GPU-saved checkpoints load on CPU hosts
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(clean_state_dict(state_dict))
    if config['n_gpu'] > 1:
        model = model.module

    model = model.to(device)

    writer = None
    if dense_match:
        warp_dir = Path(config["warp_dir"]) / config["name"]
        warp_dir = warp_dir / "disable_warps{}".format(disable_warps)
        if not warp_dir.exists():
            warp_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(warp_dir)

    model.eval()
    same_errs = []
    diff_errs = []

    # fix the RNG so that evaluation is repeatable across runs
    torch.manual_seed(0)
    try:
        with torch.no_grad():
            for i, batch in enumerate(tqdm(data_loader)):
                data, meta = batch["data"], batch["meta"]

                if config.get("mini_eval", False) and i > 3:
                    break

                data = data.to(device)
                output = model(data)

                descs = output[0]
                descs1 = descs[0::2]  # 1st in pair (more warped)
                descs2 = descs[1::2]  # 2nd in pair
                ims1 = data[0::2].cpu()
                ims2 = data[1::2].cpu()

                im_source = ims1[0]
                im_same = ims2[0]
                im_diff = ims2[1]

                _, imH, imW = im_source.shape
                _, _, _, W = descs1.shape
                # image pixels per descriptor-map cell (horizontal)
                stride = imW / W

                desc_source = descs1[0]
                desc_same = descs2[0]
                desc_diff = descs2[1]

                if not dense_match:
                    kp1 = meta['kp1']
                    kp2 = meta['kp2']
                    kp_source = kp1[0]
                    kp_same = kp2[0]
                    kp_diff = kp2[1]

                if config.get("vis", False):
                    fig = plt.figure()  # a new figure window
                    ax1 = fig.add_subplot(1, 3, 1)
                    ax2 = fig.add_subplot(1, 3, 2)
                    ax3 = fig.add_subplot(1, 3, 3)

                    ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                    ax2.imshow(norm_range(im_same).permute(1, 2, 0))
                    ax3.imshow(norm_range(im_diff).permute(1, 2, 0))

                    if not dense_match:
                        ax1.scatter(kp_source[:, 0], kp_source[:, 1], c='g')
                        ax2.scatter(kp_same[:, 0], kp_same[:, 1], c='g')
                        ax3.scatter(kp_diff[:, 0], kp_diff[:, 1], c='g')

                # raw (un-normalised) descriptors
                fsrc = desc_source.clone()
                fsame = desc_same.clone()
                fdiff = desc_diff.clone()

                if dense_match:
                    grid = dense_desc_match(fsrc, fdiff)
                    im_warped = F.grid_sample(im_source.view(1, 3, imH, imW), grid)
                    im_warped = im_warped.squeeze(0)
                    plt.close("all")
                    if config["subplots"]:
                        fig = plt.figure()  # a new figure window
                        ax1 = fig.add_subplot(1, 3, 1)
                        ax2 = fig.add_subplot(1, 3, 2)
                        ax3 = fig.add_subplot(1, 3, 3)
                        ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                        ax2.imshow(norm_range(im_diff).permute(1, 2, 0))
                        ax3.imshow(norm_range(im_warped).permute(1, 2, 0))
                        triplet_dest = warp_dir / "triplet-{:05d}.jpg".format(i)
                        fig.savefig(triplet_dest)
                    else:
                        triplet_dest_dir = warp_dir / "triplet-{:05d}".format(i)
                        if not triplet_dest_dir.exists():
                            triplet_dest_dir.mkdir(exist_ok=True, parents=True)
                        for jj, im in enumerate((im_source, im_diff, im_warped)):
                            # borderless figure whose single axes fills the canvas
                            fig = plt.figure(figsize=(1.5, 1.5))
                            ax = plt.Axes(fig, [0., 0., 1., 1.])
                            ax.set_axis_off()
                            fig.add_axes(ax)
                            im_ = norm_range(im).permute(1, 2, 0)
                            ax.imshow(im_)
                            dest_path = triplet_dest_dir / "im-{}-{}.jpg".format(
                                i, jj)
                            # dpi = image height so the saved JPEG keeps a
                            # resolution tied to the input image
                            plt.savefig(str(dest_path), dpi=im_.shape[0])
                    writer.add_figure('warp-triplets', fig)
                else:
                    for ki, kp in enumerate(kp_source):
                        x, y = np.array(kp)
                        gt_same_x, gt_same_y = np.array(kp_same[ki])
                        gt_diff_x, gt_diff_y = np.array(kp_diff[ki])
                        same_x, same_y = find_descriptor(x, y, fsrc, fsame, stride)

                        err = compute_pixel_err(
                            pred_x=same_x,
                            pred_y=same_y,
                            gt_x=gt_same_x,
                            gt_y=gt_same_y,
                            imwidth=imwidth,
                            crop=crop,
                        )
                        same_errs.append(err)
                        diff_x, diff_y = find_descriptor(x, y, fsrc, fdiff, stride)
                        err = compute_pixel_err(
                            pred_x=diff_x,
                            pred_y=diff_y,
                            gt_x=gt_diff_x,
                            gt_y=gt_diff_y,
                            imwidth=imwidth,
                            crop=crop,
                        )
                        diff_errs.append(err)
                        if config.get("vis", False):
                            ax2.scatter(same_x, same_y, c='b')
                            ax3.scatter(diff_x, diff_y, c='b')

                if config.get("vis", False):
                    zs_dispFig()
                    fig.savefig('/tmp/matching.pdf')
                    # release the per-batch figure so figures don't accumulate
                    plt.close(fig)
    finally:
        if writer is not None:
            # flush pending TensorBoard events and release the file handle
            writer.close()

    print("")  # cleanup print from tqdm subtraction
    logger.info("Matching Metrics:")
    # errors are only collected in keypoint mode; avoid logging mean([]) == nan
    if same_errs:
        logger.info(f"Mean Pixel Error (same-identity): {np.mean(same_errs)}")
    if diff_errs:
        logger.info(f"Mean Pixel Error (different-identity) {np.mean(diff_errs)}")
Beispiel #11
0
def tween_scatter(t,
                  im1,
                  im2,
                  scatter1,
                  scatter2,
                  title1,
                  title2,
                  fade_ims=True,
                  heading1=None,
                  heading2=None,
                  frame=None,
                  is_dve=None):
    ax_reset()
    base_subplot = plt.gca()
    plt.subplot(base_subplot)

    gridsize = int(np.sqrt(len(im1)))

    inner_grid = matplotlib.gridspec.GridSpec(gridsize,
                                              gridsize,
                                              hspace=0.05,
                                              wspace=0.05)
    bb = base_subplot.get_position()
    l, b, r, tp = bb.extents
    inner_grid.update(left=l, bottom=b, right=r, top=tp)

    if fade_ims:
        prev_alpha = np.maximum(0., 1 - 2 * t)
        cur_alpha = np.maximum(0., -1 + 2 * t)
    else:
        prev_alpha = 0.
        cur_alpha = 1.

    for gi in range(gridsize**2):
        gax = plt.gcf().add_subplot(inner_grid[gi])

        ax_reset()

        if prev_alpha:
            plt.imshow(norm_range(im1[gi]).permute(1, 2, 0), alpha=prev_alpha)
        if cur_alpha:
            plt.imshow(norm_range(im2[gi]).permute(1, 2, 0), alpha=cur_alpha)

        ease = (-np.cos(np.pi * t) + 1) / 2
        scatter_tween = (1 - ease) * scatter1[gi] + ease * scatter2[gi]

        fac = plt.gca().get_position().width / base_subplot.get_position(
        ).width
        plt.scatter(scatter_tween[:, 0],
                    scatter_tween[:, 1],
                    c=rainbow,
                    s=(matplotlib.rcParams['lines.markersize'] * fac)**2)

        if frame == args.hq_frame_snapshot and args.save_hq_ims:
            # create temp figure inline
            prev_fig = plt.gcf()
            plt.figure(figsize=(15, 15))
            inline_ax = plt.subplot(1, 1, 1)
            plt.sca(inline_ax)
            plt.imshow(norm_range(im1[gi]).permute(1, 2, 0))
            plt.xticks([], [])
            plt.yticks([], [])
            fname = f"frame{frame}-match-face{gi}"
            if is_dve is not None and is_dve:
                fname += "-dve"
            plt.scatter(scatter2[gi][:, 0],
                        scatter2[gi][:, 1],
                        c=rainbow,
                        s=(matplotlib.rcParams['lines.markersize'] * 8)**2)
            plt.savefig(str(Path(args.fig_dir) / f"{fname}.png"))
            zs_dispFig()

            # return to prev figure
            plt.figure(prev_fig.number)

    plt.sca(base_subplot)
    ttl1 = plt.text(0.5,
                    -.08,
                    title1,
                    transform=base_subplot.transAxes,
                    horizontalalignment='center')
    ttl2 = plt.text(0.5,
                    -.08,
                    title2,
                    transform=base_subplot.transAxes,
                    horizontalalignment='center')
    if title1 == title2:
        ttl1.set_alpha(1)
        ttl2.set_alpha(0)
    else:
        ttl1.set_alpha(1 - t)
        ttl2.set_alpha(t)

    if heading2 is not None:
        h1 = plt.suptitle(heading1, x=0.5, y=0.94)
        h2 = plt.text(*h1.get_position(), heading2)

        foot = plt.text(
            0.5, 0.08,
            'DVE enables the use of higher dimensional unsupervised embeddings!'
        )
        foot.update_from(h1)

        h1.set_fontsize('x-large')
        h2.update_from(h1)

        # Prevent flashing from font aliasing and alpha - brittle if not monospace and makes it too bold though
        cover = ''.join([
            heading1[i] if heading1[i] == heading2[i] else '\u00a0'
            for i in range(min(len(heading1), len(heading2)))
        ])
        hc = plt.text(*h1.get_position(), cover)
        hc.update_from(h1)

        h1.set_alpha(1 - t)
        h2.set_alpha(t)