Esempio n. 1
0
    def __getitem__(self, index):
        """Return sample ``index`` as ``(data, visible, kp_normalized, index)``.

        In pair mode (``self.pair_image``, unsupervised contrastive learning)
        the image is passed twice through ``self.initial_transforms`` and
        ``self.transforms``. The two views are concatenated along dim 0, and
        ``visible`` / ``kp_normalized`` keep their ``-1`` placeholder.

        Otherwise a single view is returned together with
        ``self.visible[index]`` and keypoints. The keypoints are passed through
        ``kp_unnormalize`` and then ``kp_normalize`` using the final ``(H, W)``
        of ``data``.

        When ``self.visualize`` is set, a figure is written to
        ``sanity_check/<image stem>.png``.
        """
        im = Image.open(self.filenames[index]).convert("RGB")
        kp = -1
        kp_normalized = -1  # placeholder returned when no keypoints are produced
        visible = -1
        if self.pair_image:  # unsupervised contrastive learning
            # randomness inside the transforms is what makes the two views differ
            img1 = self.transforms(self.initial_transforms(im))
            img2 = self.transforms(self.initial_transforms(im))
            data = torch.cat([img1, img2], dim=0)
            if self.crop != 0:  # maybe useful for datasets other than celebA/MAFL
                data = data[:, self.crop:-self.crop, self.crop:-self.crop]
        else:  # supervised postprocessing
            kp = self.keypoints[index].copy()
            data = self.transforms(self.initial_transforms(im))
            if self.crop != 0:  # maybe useful for datasets other than celebA/MAFL
                data = data[:, self.crop:-self.crop, self.crop:-self.crop]
                # NOTE(review): kp is described below as normalized to [0, 1],
                # yet a pixel offset is subtracted before kp_unnormalize;
                # confirm this is intended when crop != 0.
                kp = kp - self.crop
            kp = torch.as_tensor(kp)
            _, H, W = data.shape
            # the initial form of kp is normalized to [0,1]
            kp = kp_unnormalize(H, W, kp)
            kp_normalized = kp_normalize(H, W, kp)
            visible = self.visible[index]

        if self.visualize:
            from utils.visualization import norm_range
            fig = plt.figure()
            if self.pair_image:
                im1, im2 = torch.split(data, [3, 3], dim=0)
                ax = fig.add_subplot(121)
                ax.imshow(norm_range(im1).permute(1, 2, 0).cpu().numpy())
                ax = fig.add_subplot(122)
                ax.imshow(norm_range(im2).permute(1, 2, 0).cpu().numpy())
                print(im1.shape, im2.shape)
            else:
                ax = fig.add_subplot(111)
                ax.imshow(norm_range(data).permute(1, 2, 0).cpu().numpy())
                # assumes ``visible`` can index kp's first dim (e.g. a boolean
                # mask) -- TODO confirm
                kp_x = kp[visible][:, 0].numpy()
                kp_y = kp[visible][:, 1].numpy()
                ax.scatter(kp_x, kp_y)
                print(data.shape)
            # ``vis_name`` was previously never defined, so this branch raised
            # NameError; name the output after the source image instead.
            vis_name = os.path.splitext(
                os.path.basename(self.filenames[index]))[0]
            os.makedirs('sanity_check', exist_ok=True)
            fig.savefig(os.path.join('sanity_check', vis_name + '.png'),
                        bbox_inches='tight')
            print(self.filenames[index])
            plt.close(fig)
        return data, visible, kp_normalized, index
Esempio n. 2
0
    def __getitem__(self, index):
        """Load one image and build its merged segmentation mask.

        The image is read from ``<root>/images/<id>.jpg``. The label map starts
        at 0 and is filled from ten greyscale masks
        ``<root>/labels/<id>/<id>_lblNN.png`` (NN = 01..10). Every pixel whose
        mask value exceeds ``self.thresh * 255`` gets class ``NN``, and later
        masks overwrite earlier ones where they overlap.

        Returns:
            dict with ``"data"`` (``self.transforms(self.resizer(im))``, or
            same-shaped Gaussian noise when ``self.rand_in``) and ``"meta"``.
            ``"meta"`` holds ``"im_path"`` (str) and ``"lbls"`` (the label map
            after ``self.label_resizer``, as a tensor).
        """
        im_path = Path(self.root) / "images/{}.jpg".format(self.im_list[index])
        im = Image.open(im_path).convert("RGB")
        name = Path(im_path).stem
        anno_template = str(Path(self.root) / "labels/{}/{}_lbl{:02d}.png")
        seg = np.zeros((im.size[1], im.size[0]), dtype=np.uint8)
        for ii in range(1, 11):
            anno_path = anno_template.format(name, name, ii)
            lbl_im = np.array(Image.open(anno_path).convert("L"))
            assert lbl_im.ndim == 2, "expected greyscale"
            seg[lbl_im > self.thresh * 255] = ii

        seg = Image.fromarray(seg, "L")
        seg = self.label_resizer(seg)
        seg = torch.from_numpy(np.array(seg))
        data = self.resizer(im)
        data = self.transforms(data)

        if self.rand_in:
            # replace the input with noise of the same shape; labels are kept
            data = torch.randn(data.shape)

        if self.visualize:
            from torchvision.utils import make_grid
            from utils.visualization import norm_range
            ims = norm_range(make_grid(data)).permute(1, 2, 0).cpu().numpy()
            plt.close("all")
            fig = plt.figure()  # a new figure window
            ax1 = fig.add_subplot(1, 3, 1)
            ax2 = fig.add_subplot(1, 3, 2)
            ax3 = fig.add_subplot(1, 3, 3)
            # previously plt.axis("off") ran before the figure existed and only
            # affected a throwaway figure
            for ax in (ax1, ax2, ax3):
                ax.axis("off")
            ax1.imshow(ims)
            ax2.imshow(label_colormap(seg).numpy())
            if self.downsample_labels:
                # torch shapes are (H, W) but PIL's resize expects
                # (width, height); NEAREST keeps values valid class ids.
                h, w = seg.shape
                sz = (w * self.downsample_labels, h * self.downsample_labels)
                seg_ = Image.fromarray(seg.numpy()).resize(sz, Image.NEAREST)
                # keep the same (tensor) input type label_colormap gets above
                seg_ = torch.from_numpy(np.array(seg_))
            else:
                seg_ = seg
            ax3.imshow(label_colormap(seg_).numpy())
            ax3.imshow(ims, alpha=0.5)
            zs_dispFig()

        return {"data": data, "meta": {"im_path": str(im_path), "lbls": seg}}
Esempio n. 3
0
def dict_coll(batch):
    """Collate a batch of sample dicts and flatten ``"data"`` to 4D.

    Each sample is combined with PyTorch's ``default_collate``. The collated
    ``"data"`` entry is then reshaped to ``(-1, C, H, W)`` using its last three
    dimensions. For example, samples carrying image pairs of shape
    ``(2, C, H, W)`` become one ``(2 * B, C, H, W)`` tensor. All other keys are
    returned exactly as ``default_collate`` produced them.

    Args:
        batch: list of dicts that each contain a ``"data"`` tensor with at
            least three dimensions.

    Returns:
        The collated dict, with ``"data"`` flattened to 4D.
    """
    cb = torch.utils.data.dataloader.default_collate(batch)
    cb["data"] = cb["data"].reshape((-1,) + cb["data"].shape[-3:])  # Flatten to be 4D
    return cb
Esempio n. 4
0
 def __getitem__(self, index):
     """Load the image at ``self.im_list[index]`` and apply the dataset transforms.

     The image is converted to RGB and passed through
     ``self.initial_transforms``, then ``self.transforms``.

     Returns:
         dict with ``"data"`` (the transformed image) and ``"im_path"`` (the
         ``self.im_list`` entry the image was opened from).
     """
     im_path = self.im_list[index]
     im = Image.open(im_path).convert("RGB")
     data = self.transforms(self.initial_transforms(im))
     return {"data": data, "im_path": im_path}
def plot_images(image, points, path):
    """Save ``image`` with coloured ``points`` overlaid as a borderless figure.

    Args:
        image: tensor whose ``size()`` unpacks to ``(C, H, W)``; rendered via
            ``norm_range``.
        points: array-like of ``K`` rows in ``[-1, 1]``. Column 0 is mapped to
            ``[0, H - 1]`` and drawn on the vertical axis. Column 1 is mapped to
            ``[0, W - 1]`` and drawn on the horizontal axis.
        path: destination passed to ``savefig``.

    The canvas is 1x1 inch and is saved at ``dpi = 2 * H``. Points are coloured
    with ``K`` evenly spaced samples of the ``gist_rainbow`` colormap.
    """
    C, H, W = image.size()
    # float copy: an integer input would otherwise truncate the rescaled values
    points = np.array(points, dtype=float)
    if points.size == 0:
        points = points.reshape(0, 2)
    points[:, 0] = (points[:, 0] + 1.) / 2. * (H - 1)
    points[:, 1] = (points[:, 1] + 1.) / 2. * (W - 1)
    fig = plt.figure()
    fig.set_size_inches(1., 1, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(norm_range(image).permute(1, 2, 0).cpu().numpy())
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains
    cmap = plt.get_cmap('gist_rainbow')
    K = len(points)
    # linspace gives exactly K samples in [0, 1) and is safe for K == 0, unlike
    # np.arange(0, 1, 1 / K)
    colors = [np.array(cmap(v)[:3]) for v in np.linspace(0, 1, K, endpoint=False)]
    for point, color in zip(points, colors):
        ax.scatter(point[1], point[0], c=[color], marker='.')
    fig.savefig(path, dpi=2 * image.shape[1])
    plt.close(fig)
Esempio n. 6
0
    def __getitem__(self, index):
        """Load an image, crop it around its bounding box, and return keypoints.

        The crop arithmetic is a port of MATLAB code (1-based indexing):

        1. Rescale the image so the box width becomes ``bw_`` (52) px.
        2. Cut a window of about ``preresize_sz`` (100) px around the rescaled
           box centre with ``pad_and_crop``. The window is shifted 2 px towards
           smaller y.
        3. Shift and scale the keypoints into that window, convert them to
           0-based, and scale them by ``self.imwidth / preresize_sz``.

        In pair mode (``self.pair_image``) two independently transformed views
        are concatenated along dim 0, and the keypoints are returned as computed
        above. Otherwise one view is returned, and the keypoints are shifted by
        ``self.crop`` and normalized with ``kp_normalize`` for the final
        ``(H, W)``.

        Returns:
            tuple ``(data, kp, kp_normalized, index)``; ``kp_normalized`` stays
            ``-1`` in pair mode.
        """
        im = Image.open(os.path.join(self.root,
                                     self.filenames[index])).convert("RGB")
        kp_normalized = -1  # None

        # Crop bounding box
        xmin, ymin, xmax, ymax = self.bounding_boxes[index]
        keypts = self.keypoints[index]

        # This is basically copied from matlab code and assumes matlab indexing
        bw = xmax - xmin + 1
        bh = ymax - ymin + 1
        bcy = ymin + (bh + 1) / 2
        bcx = xmin + (bw + 1) / 2

        # To simplify the preprocessing, we do two image resizes (can fix later if speed
        # is an issue)
        preresize_sz = 100

        bw_ = 52  # make the (tightly cropped) face 52px
        fac = bw_ / bw  # scale factor mapping the box width to bw_
        imr = im.resize((int(im.width * fac), int(im.height * fac)))

        # box centre and crop window limits in the rescaled image (1-based,
        # inclusive bounds as in the MATLAB original)
        bcx_ = int(np.floor(fac * bcx))
        bcy_ = int(np.floor(fac * bcy))
        bx = bcx_ - bw_ / 2 + 1
        bX = bcx_ + bw_ / 2
        by = bcy_ - bw_ / 2 + 1
        bY = bcy_ + bw_ / 2
        pp = (preresize_sz - bw_) / 2  # margin that grows the window to preresize_sz
        bx = int(bx - pp)
        bX = int(bX + pp)
        by = int(by - pp - 2)
        bY = int(bY + pp - 2)

        # the "- 1" converts the 1-based start indices for pad_and_crop
        imr = pad_and_crop(np.array(imr), [(by - 1), bY, (bx - 1), bX])
        im = Image.fromarray(imr)

        # move keypoints into the rescaled, cropped frame (column 0 is x,
        # column 1 is y)
        cutl = bx - 1
        keypts = keypts.copy() * fac
        keypts[:, 0] = keypts[:, 0] - cutl
        cutt = by - 1
        keypts[:, 1] = keypts[:, 1] - cutt

        kp = keypts - 1  # from matlab to python style
        # presumably initial_transforms resizes the preresize_sz crop to
        # self.imwidth -- TODO confirm
        kp = kp * self.imwidth / preresize_sz
        kp = torch.tensor(kp)

        if self.pair_image:  # unsupervised contrastive learning
            # randomresizecrop is the key to generate pairs of images
            img1 = self.transforms(self.initial_transforms(im))
            img2 = self.transforms(self.initial_transforms(im))
            data = torch.cat([img1, img2], dim=0)
            if self.crop != 0:  # maybe useful for datasets other than celebA/MAFL
                data = data[:, self.crop:-self.crop, self.crop:-self.crop]
        else:  # supervised postprocessing
            data = self.transforms(self.initial_transforms(im))
            if self.crop != 0:  # maybe useful for datasets other than celebA/MAFL
                data = data[:, self.crop:-self.crop, self.crop:-self.crop]
                kp = kp - self.crop
            C, H, W = data.shape
            # kp = torch.tensor(kp)
            kp = torch.as_tensor(kp)
            kp_normalized = kp_normalize(H, W, kp)

        if self.visualize:
            # debugging aid: draws the sample(s) and writes check_dataloader.png
            # to the current working directory
            # from torchvision.utils import make_grid
            from utils.visualization import norm_range
            plt.clf()
            fig = plt.figure()
            if self.pair_image:
                im1, im2 = torch.split(data, [3, 3], dim=0)
                ax = fig.add_subplot(121)
                ax.imshow(norm_range(im1).permute(1, 2, 0).cpu().numpy())
                ax = fig.add_subplot(122)
                ax.imshow(norm_range(im2).permute(1, 2, 0).cpu().numpy())
            else:
                ax = fig.add_subplot(111)
                ax.imshow(norm_range(data).permute(1, 2, 0).cpu().numpy())
                kp_x = kp[:, 0].numpy()
                kp_y = kp[:, 1].numpy()
                ax.scatter(kp_x, kp_y)
            plt.savefig('check_dataloader.png')
            print(os.path.join(self.root, self.filenames[index]))
            plt.close()

        return data, kp, kp_normalized, index
Esempio n. 7
0
    def __getitem__(self, index):
        """Load sample ``index`` as an image (or warped pair) plus metadata.

        Behaviour depends on the instance flags:

        * neither ``self.use_ims`` nor ``self.use_keypoints``: return a dummy
          ``torch.zeros(3, 1, 1)`` with ``{"index": index}`` (early exit used
          when caching is active);
        * ``self.warper`` set with ``returns_pairs``: the image is warped into
          two views. ``data`` stacks them along a new dim 0, and ``meta``
          carries ``flow``, ``grid``, both views and ``index``, plus
          ``kp1``/``kp2`` when keypoints are used;
        * ``self.warper`` set without pairs: one warped view, with ``keypts``,
          ``keypts_normalized`` and ``index`` in ``meta`` when keypoints are used;
        * no warper: the transformed image cropped by ``self.crop`` on every
          border, or a dummy tensor when ``self.use_ims`` is off. Keypoints are
          shifted by ``self.crop`` to match.

        Note: the warper branches read ``im``, which is only loaded when
        ``self.use_ims`` is true.

        Returns:
            dict with keys ``"data"`` and ``"meta"``.
        """
        if not self.use_ims and not self.use_keypoints:
            # early exit when caching is used
            return {"data": torch.zeros(3, 1, 1), "meta": {"index": index}}

        if self.use_ims:
            im = Image.open(os.path.join(self.subdir, self.filenames[index]))
        kp = None
        if self.use_keypoints:
            kp = self.keypoints[index].copy()
        meta = {}

        if self.warper is not None:
            if self.warper.returns_pairs:
                im1 = self.initial_transforms(im.convert("RGB"))
                # warp in the 0-255 range, then cast back to uint8 for PIL
                im1 = TF.to_tensor(im1) * 255

                im1, im2, flow, grid, kp1, kp2 = self.warper(im1,
                                                             keypts=kp,
                                                             crop=self.crop)

                im1 = TF.to_pil_image(im1.to(torch.uint8))
                im2 = TF.to_pil_image(im2.to(torch.uint8))

                im1 = self.transforms(im1)
                im2 = self.transforms(im2)

                data = torch.stack((im1, im2), 0)
                meta = {
                    'flow': flow[0],
                    'grid': grid[0],
                    'im1': im1,
                    'im2': im2,
                    'index': index
                }
                if self.use_keypoints:
                    meta = {**meta, **{'kp1': kp1, 'kp2': kp2}}
            else:
                im1 = self.initial_transforms(im.convert("RGB"))
                im1 = TF.to_tensor(im1) * 255

                im1, kp = self.warper(im1, keypts=kp, crop=self.crop)

                im1 = TF.to_pil_image(im1.to(torch.uint8))
                im1 = self.transforms(im1)

                _, H, W = im1.shape
                data = im1
                if self.use_keypoints:
                    meta = {
                        'keypts': kp,
                        'keypts_normalized': kp_normalize(H, W, kp),
                        'index': index
                    }

        else:
            if self.use_ims:
                data = self.transforms(
                    self.initial_transforms(im.convert("RGB")))
                if self.crop != 0:
                    data = data[:, self.crop:-self.crop, self.crop:-self.crop]
                _, H, W = data.shape
            else:
                #  after caching descriptors, there is no point doing I/O
                H = W = self.imwidth - 2 * self.crop
                data = torch.zeros(3, 1, 1)

            if kp is not None:
                kp = kp - self.crop
                kp = torch.tensor(kp)

            if self.use_keypoints:
                meta = {
                    'keypts': kp,
                    'keypts_normalized': kp_normalize(H, W, kp),
                    'index': index
                }

        if self.visualize:
            from utils.visualization import norm_range
            num_show = 2 if self.warper and self.warper.returns_pairs else 1
            fig = plt.figure()
            for ii in range(num_show):
                im_ = data[ii] if num_show > 1 else data
                ax = fig.add_subplot(1, num_show, ii + 1)
                ax.imshow(norm_range(im_).permute(1, 2, 0).cpu().numpy())
                if self.use_keypoints:
                    if num_show == 2:
                        kp_x = meta["kp{}".format(ii + 1)][:, 0].numpy()
                        kp_y = meta["kp{}".format(ii + 1)][:, 1].numpy()
                    else:
                        kp_x = kp[:, 0].numpy()
                        kp_y = kp[:, 1].numpy()
                    ax.scatter(kp_x, kp_y)
            # Display once every panel is drawn. Previously this and an ipdb
            # breakpoint ran inside the loop, which stopped after the first
            # panel and hung DataLoader worker processes.
            zs_dispFig()
            plt.close(fig)
        return {"data": data, "meta": meta}
Esempio n. 8
0
    def __getitem__(self, index):
        """Load sample ``index``: crop around its bounding box, optionally warp.

        The crop arithmetic is ported from MATLAB (1-based indexing):

        1. Rescale so the box width becomes ``bw_`` (52) px.
        2. Cut a window of about ``preresize_sz`` (100) px around the box centre
           with ``pad_and_crop``, shifted 2 px towards smaller y.
        3. Move the keypoints into that window, convert them to 0-based, and
           scale them by ``self.imwidth / preresize_sz``.

        Image loading and cropping only happen when ``self.use_ims`` is set.
        Keypoints are only returned when ``self.use_keypoints`` is set.

        With a pair-returning ``self.warper``, ``data`` stacks two warped views
        along dim 0, and ``meta`` holds ``flow``, ``grid``, both views and
        ``index``, plus ``kp1``/``kp2`` when keypoints are used. Otherwise
        ``meta`` holds ``keypts``, ``keypts_normalized`` and ``index`` when
        keypoints are used. Note that the warper branches read ``im``, which
        only exists when ``self.use_ims`` is true.

        Returns:
            dict with keys ``"data"`` and ``"meta"``.
        """
        if self.use_ims:
            im = Image.open(os.path.join(self.root,
                                         self.filenames[index])).convert("RGB")
        # Crop bounding box
        xmin, ymin, xmax, ymax = self.bounding_boxes[index]
        keypts = self.keypoints[index]

        # This is basically copied from matlab code and assumes matlab indexing
        bw = xmax - xmin + 1
        bh = ymax - ymin + 1
        bcy = ymin + (bh + 1) / 2
        bcx = xmin + (bw + 1) / 2

        # To simplify the preprocessing, we do two image resizes (can fix later if speed
        # is an issue)
        preresize_sz = 100

        bw_ = 52  # make the (tightly cropped) face 52px
        fac = bw_ / bw
        if self.use_ims:
            imr = im.resize((int(im.width * fac), int(im.height * fac)))

        bcx_ = int(np.floor(fac * bcx))
        bcy_ = int(np.floor(fac * bcy))
        bx = bcx_ - bw_ / 2 + 1
        bX = bcx_ + bw_ / 2
        by = bcy_ - bw_ / 2 + 1
        bY = bcy_ + bw_ / 2
        pp = (preresize_sz - bw_) / 2
        bx = int(bx - pp)
        bX = int(bX + pp)
        by = int(by - pp - 2)
        bY = int(bY + pp - 2)

        if self.use_ims:
            imr = pad_and_crop(np.array(imr), [(by - 1), bY, (bx - 1), bX])
            im = Image.fromarray(imr)

        cutl = bx - 1
        keypts = keypts.copy() * fac
        keypts[:, 0] = keypts[:, 0] - cutl
        cutt = by - 1
        keypts[:, 1] = keypts[:, 1] - cutt

        kp = None
        if self.use_keypoints:
            kp = keypts - 1  # from matlab to python style
            kp = kp * self.imwidth / preresize_sz
            kp = torch.tensor(kp)
        meta = {}

        if self.warper is not None:
            if self.warper.returns_pairs:
                im1 = self.initial_transforms(im.convert("RGB"))
                # warp in the 0-255 range, then cast back to uint8 for PIL
                im1 = TF.to_tensor(im1) * 255

                im1, im2, flow, grid, kp1, kp2 = self.warper(im1,
                                                             keypts=kp,
                                                             crop=self.crop)

                im1 = TF.to_pil_image(im1.to(torch.uint8))
                im2 = TF.to_pil_image(im2.to(torch.uint8))

                im1 = self.transforms(im1)
                im2 = self.transforms(im2)

                data = torch.stack((im1, im2), 0)
                meta = {
                    'flow': flow[0],
                    'grid': grid[0],
                    'im1': im1,
                    'im2': im2,
                    'index': index
                }
                if self.use_keypoints:
                    meta = {**meta, **{'kp1': kp1, 'kp2': kp2}}
            else:
                im1 = self.initial_transforms(im.convert("RGB"))
                im1 = TF.to_tensor(im1) * 255

                im1, kp = self.warper(im1, keypts=kp, crop=self.crop)

                im1 = TF.to_pil_image(im1.to(torch.uint8))
                im1 = self.transforms(im1)

                _, H, W = im1.shape
                data = im1
                if self.use_keypoints:
                    meta = {
                        'keypts': kp,
                        'keypts_normalized': kp_normalize(H, W, kp),
                        'index': index
                    }

        else:
            if self.use_ims:
                data = self.transforms(
                    self.initial_transforms(im.convert("RGB")))
                if self.crop != 0:
                    data = data[:, self.crop:-self.crop, self.crop:-self.crop]
                _, H, W = data.shape
            else:
                # after caching descriptors, there is no point doing I/O
                H = W = self.imwidth - 2 * self.crop
                data = torch.zeros(3, 1, 1)

            if kp is not None:
                # kp is already a tensor here; torch.tensor() would warn about
                # copy-constructing from a tensor on every sample
                kp = torch.as_tensor(kp - self.crop)

            if self.use_keypoints:
                meta = {
                    'keypts': kp,
                    'keypts_normalized': kp_normalize(H, W, kp),
                    'index': index
                }

        if self.visualize:
            from utils.visualization import norm_range
            num_show = 2 if self.warper and self.warper.returns_pairs else 1
            fig = plt.figure()
            for ii in range(num_show):
                im_ = data[ii] if num_show > 1 else data
                ax = fig.add_subplot(1, num_show, ii + 1)
                ax.imshow(norm_range(im_).permute(1, 2, 0).cpu().numpy())
                if self.use_keypoints:
                    if num_show == 2:
                        kp_x = meta["kp{}".format(ii + 1)][:, 0].numpy()
                        kp_y = meta["kp{}".format(ii + 1)][:, 1].numpy()
                    else:
                        kp_x = kp[:, 0].numpy()
                        kp_y = kp[:, 1].numpy()
                    ax.scatter(kp_x, kp_y)
            # Show the figure instead of dropping into an ipdb breakpoint, which
            # hangs DataLoader worker processes; then release it to avoid
            # accumulating one open figure per sample.
            plt.show()
            plt.close(fig)

        return {"data": data, "meta": meta}
Esempio n. 9
0
def evaluation(config, logger=None, eval_data=None):
    """Evaluate descriptor matching of a trained model on warped image pairs.

    Builds the validation dataset named by ``eval_data`` (default
    ``config["dataset"]["type"]``) with a ``tps.Warper``, so each image yields a
    pair of views. Loads the checkpoint at ``config._args.resume`` and iterates
    over batches of two images.

    In each batch, the first view of image 0 is the source. It is matched
    against the second view of image 0 ("same-identity") and the second view of
    image 1 ("different-identity").

    * Keypoint mode (default): each source keypoint is matched with
      ``find_descriptor``, scored with ``compute_pixel_err``, and the mean
      errors are logged at the end.
    * Dense mode (``config["dense_match"]``): the source image is resampled
      through the grid returned by ``dense_desc_match``. The image triplets are
      saved under ``<warp_dir>/<name>/disable_warps<flag>`` and logged to
      TensorBoard.

    Args:
        config: experiment config supporting item access, ``get``,
            ``get_logger`` and ``_args.resume``.
        logger: optional logger; defaults to ``config.get_logger('test')``.
        eval_data: optional dataset class name looked up in ``module_data``.
    """
    device = torch.device('cuda:0' if config["n_gpu"] > 0 else 'cpu')

    if logger is None:
        logger = config.get_logger('test')

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    imwidth = config['dataset']['args']['imwidth']
    root = config["dataset"]["args"]["root"]
    warp_crop_default = config['warper']['args'].get('crop', None)
    crop = config['dataset']['args'].get('crop', warp_crop_default)

    # Want explicit pair warper
    disable_warps = True
    dense_match = config.get("dense_match", False)
    if dense_match and disable_warps:
        # all random warp magnitudes set to zero
        warp_kwargs = dict(warpsd_all=0,
                           warpsd_subset=0,
                           transsd=0,
                           scalesd=0,
                           rotsd=0,
                           im1_multiplier=1,
                           im1_multiplier_aff=1)
    else:
        warp_kwargs = dict(warpsd_all=0.001 * .5,
                           warpsd_subset=0.01 * .5,
                           transsd=0.1 * .5,
                           scalesd=0.1 * .5,
                           rotsd=5 * .5,
                           im1_multiplier=1,
                           im1_multiplier_aff=1)
    warper = tps.Warper(imwidth, imwidth, **warp_kwargs)
    if eval_data is None:
        eval_data = config["dataset"]["type"]
    constructor = getattr(module_data, eval_data)

    # handle the case of the MAFL split, which by default will evaluate on Celeba
    kwargs = {
        "val_split": "mafl"
    } if eval_data == "CelebAPrunedAligned_MAFLVal" else {}
    val_dataset = constructor(
        train=False,
        pair_warper=warper,
        use_keypoints=True,
        imwidth=imwidth,
        crop=crop,
        root=root,
        **kwargs,
    )
    # NOTE: Since the matching is performed with pairs, we fix the ordering and then
    # use all pairs for datasets with even numbers of images, and all but one for
    # datasets that have odd numbers of images (via drop_last=True)
    data_loader = DataLoader(val_dataset,
                             batch_size=2,
                             collate_fn=dict_coll,
                             shuffle=False,
                             drop_last=True)

    # build model architecture
    model = get_instance(module_arch, 'arch', config)
    model.summary()

    # load state dict; map_location lets GPU-saved checkpoints load on CPU hosts
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(clean_state_dict(state_dict))
    if config['n_gpu'] > 1:
        model = model.module

    model = model.to(device)

    writer = None
    if dense_match:
        warp_dir = Path(config["warp_dir"]) / config["name"]
        warp_dir = warp_dir / "disable_warps{}".format(disable_warps)
        warp_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(warp_dir)

    model.eval()
    same_errs = []
    diff_errs = []
    vis = config.get("vis", False)

    torch.manual_seed(0)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            data, meta = batch["data"], batch["meta"]

            if config.get("mini_eval", False) and i > 3:
                break

            data = data.to(device)
            output = model(data)

            descs = output[0]
            descs1 = descs[0::2]  # 1st in pair (more warped)
            descs2 = descs[1::2]  # 2nd in pair
            ims1 = data[0::2].cpu()
            ims2 = data[1::2].cpu()

            im_source = ims1[0]
            im_same = ims2[0]
            im_diff = ims2[1]

            C, imH, imW = im_source.shape
            B, C, H, W = descs1.shape
            stride = imW / W  # image pixels per descriptor cell

            desc_source = descs1[0]
            desc_same = descs2[0]
            desc_diff = descs2[1]

            if not dense_match:
                kp1 = meta['kp1']
                kp2 = meta['kp2']
                kp_source = kp1[0]
                kp_same = kp2[0]
                kp_diff = kp2[1]

            if vis:
                fig = plt.figure()  # a new figure window
                ax1 = fig.add_subplot(1, 3, 1)
                ax2 = fig.add_subplot(1, 3, 2)
                ax3 = fig.add_subplot(1, 3, 3)

                ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                ax2.imshow(norm_range(im_same).permute(1, 2, 0))
                ax3.imshow(norm_range(im_diff).permute(1, 2, 0))

                if not dense_match:
                    ax1.scatter(kp_source[:, 0], kp_source[:, 1], c='g')
                    ax2.scatter(kp_same[:, 0], kp_same[:, 1], c='g')
                    ax3.scatter(kp_diff[:, 0], kp_diff[:, 1], c='g')

            fsrc = desc_source.clone()
            fsame = desc_same.clone()
            fdiff = desc_diff.clone()

            if dense_match:
                grid = dense_desc_match(fsrc, fdiff)
                im_warped = F.grid_sample(im_source.view(1, 3, imH, imW), grid)
                im_warped = im_warped.squeeze(0)
                plt.close("all")
                if config["subplots"]:
                    fig = plt.figure()  # a new figure window
                    ax1 = fig.add_subplot(1, 3, 1)
                    ax2 = fig.add_subplot(1, 3, 2)
                    ax3 = fig.add_subplot(1, 3, 3)
                    ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                    ax2.imshow(norm_range(im_diff).permute(1, 2, 0))
                    ax3.imshow(norm_range(im_warped).permute(1, 2, 0))
                    triplet_dest = warp_dir / "triplet-{:05d}.jpg".format(i)
                    fig.savefig(triplet_dest)
                else:
                    triplet_dest_dir = warp_dir / "triplet-{:05d}".format(i)
                    triplet_dest_dir.mkdir(exist_ok=True, parents=True)
                    for jj, im in enumerate((im_source, im_diff, im_warped)):
                        # borderless 1.5in figure; dpi = image height below
                        fig = plt.figure(figsize=(1.5, 1.5))
                        ax = plt.Axes(fig, [0., 0., 1., 1.])
                        ax.set_axis_off()
                        fig.add_axes(ax)
                        im_ = norm_range(im).permute(1, 2, 0)
                        ax.imshow(im_)
                        dest_path = triplet_dest_dir / "im-{}-{}.jpg".format(
                            i, jj)
                        plt.savefig(str(dest_path), dpi=im_.shape[0])
                writer.add_figure('warp-triplets', fig)
            else:
                for ki, kp in enumerate(kp_source):
                    x, y = np.array(kp)
                    gt_same_x, gt_same_y = np.array(kp_same[ki])
                    gt_diff_x, gt_diff_y = np.array(kp_diff[ki])
                    same_x, same_y = find_descriptor(x, y, fsrc, fsame, stride)

                    err = compute_pixel_err(
                        pred_x=same_x,
                        pred_y=same_y,
                        gt_x=gt_same_x,
                        gt_y=gt_same_y,
                        imwidth=imwidth,
                        crop=crop,
                    )
                    same_errs.append(err)
                    diff_x, diff_y = find_descriptor(x, y, fsrc, fdiff, stride)
                    err = compute_pixel_err(
                        pred_x=diff_x,
                        pred_y=diff_y,
                        gt_x=gt_diff_x,
                        gt_y=gt_diff_y,
                        imwidth=imwidth,
                        crop=crop,
                    )
                    diff_errs.append(err)
                    if vis:
                        ax2.scatter(same_x, same_y, c='b')
                        ax3.scatter(diff_x, diff_y, c='b')

            if vis:
                zs_dispFig()
                fig.savefig('/tmp/matching.pdf')
                # one figure is created per batch; release it to avoid a leak
                plt.close(fig)

    if writer is not None:
        # flush pending TensorBoard events to disk
        writer.close()

    print("")  # cleanup print from tqdm subtraction
    logger.info("Matching Metrics:")
    if same_errs:
        logger.info(f"Mean Pixel Error (same-identity): {np.mean(same_errs)}")
        logger.info(f"Mean Pixel Error (different-identity) {np.mean(diff_errs)}")
    else:
        # np.mean([]) would log nan with a RuntimeWarning (e.g. dense_match mode)
        logger.info("No keypoint matching errors were collected")
Esempio n. 10
0
def tween_scatter(t,
                  im1,
                  im2,
                  scatter1,
                  scatter2,
                  title1,
                  title2,
                  fade_ims=True,
                  heading1=None,
                  heading2=None,
                  frame=None,
                  is_dve=None):
    """Draw one animation frame that blends two grids of images and points.

    The current axes is split into a ``gridsize x gridsize`` grid, where
    ``gridsize = int(sqrt(len(im1)))``. In cell ``gi``:

    * the images ``im1[gi]`` and ``im2[gi]`` are cross-faded when ``fade_ims``
      (otherwise only ``im2[gi]`` is shown);
    * the points move from ``scatter1[gi]`` to ``scatter2[gi]`` with a cosine
      ease, coloured by the module-level ``rainbow``.

    The captions ``title1``/``title2`` are blended by ``t``. When ``heading2``
    is given, figure-level headings ``heading1``/``heading2`` and a footer are
    drawn as well.

    Args:
        t: interpolation position; the alpha formulas assume ``0 <= t <= 1``.
        frame: when equal to ``args.hq_frame_snapshot`` and
            ``args.save_hq_ims`` is set, one 15x15-inch image per cell is saved
            to ``args.fig_dir``.
        is_dve: if truthy, ``-dve`` is appended to those snapshot filenames.

    Relies on the module-level ``args`` and ``rainbow``.
    """
    ax_reset()
    base_subplot = plt.gca()
    # ``plt.subplot(<Axes>)`` is rejected by current Matplotlib; base_subplot is
    # already the current axes, so ``sca`` states the same intent.
    plt.sca(base_subplot)

    gridsize = int(np.sqrt(len(im1)))

    inner_grid = matplotlib.gridspec.GridSpec(gridsize,
                                              gridsize,
                                              hspace=0.05,
                                              wspace=0.05)
    l, b, r, tp = base_subplot.get_position().extents
    inner_grid.update(left=l, bottom=b, right=r, top=tp)

    if fade_ims:
        prev_alpha = np.maximum(0., 1 - 2 * t)
        cur_alpha = np.maximum(0., -1 + 2 * t)
    else:
        prev_alpha = 0.
        cur_alpha = 1.

    ease = (-np.cos(np.pi * t) + 1) / 2  # same for every cell

    for gi in range(gridsize**2):
        plt.gcf().add_subplot(inner_grid[gi])

        ax_reset()

        if prev_alpha:
            plt.imshow(norm_range(im1[gi]).permute(1, 2, 0), alpha=prev_alpha)
        if cur_alpha:
            plt.imshow(norm_range(im2[gi]).permute(1, 2, 0), alpha=cur_alpha)

        scatter_tween = (1 - ease) * scatter1[gi] + ease * scatter2[gi]

        # shrink markers in proportion to the cell's width
        fac = plt.gca().get_position().width / base_subplot.get_position(
        ).width
        plt.scatter(scatter_tween[:, 0],
                    scatter_tween[:, 1],
                    c=rainbow,
                    s=(matplotlib.rcParams['lines.markersize'] * fac)**2)

        if frame == args.hq_frame_snapshot and args.save_hq_ims:
            # create temp figure inline
            prev_fig = plt.gcf()
            hq_fig = plt.figure(figsize=(15, 15))
            inline_ax = plt.subplot(1, 1, 1)
            plt.sca(inline_ax)
            plt.imshow(norm_range(im1[gi]).permute(1, 2, 0))
            plt.xticks([], [])
            plt.yticks([], [])
            fname = f"frame{frame}-match-face{gi}"
            if is_dve:
                fname += "-dve"
            plt.scatter(scatter2[gi][:, 0],
                        scatter2[gi][:, 1],
                        c=rainbow,
                        s=(matplotlib.rcParams['lines.markersize'] * 8)**2)
            plt.savefig(str(Path(args.fig_dir) / f"{fname}.png"))
            zs_dispFig()
            # free the temporary figure; previously one leaked per grid cell
            plt.close(hq_fig)

            # return to prev figure
            plt.figure(prev_fig.number)

    plt.sca(base_subplot)
    ttl1 = plt.text(0.5,
                    -.08,
                    title1,
                    transform=base_subplot.transAxes,
                    horizontalalignment='center')
    ttl2 = plt.text(0.5,
                    -.08,
                    title2,
                    transform=base_subplot.transAxes,
                    horizontalalignment='center')
    if title1 == title2:
        ttl1.set_alpha(1)
        ttl2.set_alpha(0)
    else:
        ttl1.set_alpha(1 - t)
        ttl2.set_alpha(t)

    if heading2 is not None:
        # heading1 is indexed character-by-character below; treat None as empty
        if heading1 is None:
            heading1 = ""
        h1 = plt.suptitle(heading1, x=0.5, y=0.94)
        h2 = plt.text(*h1.get_position(), heading2)

        foot = plt.text(
            0.5, 0.08,
            'DVE enables the use of higher dimensional unsupervised embeddings!'
        )
        foot.update_from(h1)

        h1.set_fontsize('x-large')
        h2.update_from(h1)

        # Prevent flashing from font aliasing and alpha - brittle if not monospace and makes it too bold though
        cover = ''.join([
            heading1[i] if heading1[i] == heading2[i] else '\u00a0'
            for i in range(min(len(heading1), len(heading2)))
        ])
        hc = plt.text(*h1.get_position(), cover)
        hc.update_from(h1)

        h1.set_alpha(1 - t)
        h2.set_alpha(t)
Esempio n. 11
0
# Three-panel figure: ``nodve_ax`` (left), the query image (centre, labelled
# 'Query') and ``dve_ax`` (right). This relies on names defined elsewhere in
# the script: avface_tensor, npts, i_idxs, j_idxs, grow_axis and nudge_axis.
# The order of statements matters because pyplot calls act on the current
# axes.
plt.figure(figsize=(7, 3))

query_ax = plt.subplot(1, 3, 2)
nodve_ax = plt.subplot(1, 3, 1, frameon=False)
dve_ax = plt.subplot(1, 3, 3, frameon=False)

# square side panels, enlarged and nudged by the grow_axis / nudge_axis helpers
nodve_ax.axis('square')
grow_axis(nodve_ax, 0.05)
nudge_axis(nodve_ax, 0.03)

dve_ax.axis('square')
grow_axis(dve_ax, 0.05)
nudge_axis(dve_ax, -0.03)

plt.sca(query_ax)
plt.imshow(norm_range(avface_tensor).permute(1, 2, 0))
# npts colours sampled along the Spectral colormap; also used as a module-level
# name by code that reads ``rainbow``
rainbow = plt.cm.Spectral(np.linspace(0, 1, npts))
plt.xlabel('Query')
# scatter calls without an explicit colour take the next entry of this cycle
plt.gca().set_prop_cycle('color', rainbow)
grow_axis(query_ax, -0.05)
plt.xticks([], [])
plt.yticks([], [])

# marker-size scale: query panel width relative to the DVE panel width
fac = plt.gca().get_position().width / dve_ax.get_position().width

# One scatter call per (i, j) location, with j on the x axis and i on the y
# axis. Each call consumes the next colour of the cycle set above, so point
# colours follow this loop order.
for i in i_idxs:
    for j in j_idxs:
        plt.scatter(j, i, s=(matplotlib.rcParams['lines.markersize'] * fac)**2)


def ax_reset():