예제 #1
0
def main(input, label, img_transforms=None, tensor_transforms=None):
    """Visualize the effect of segmentation transforms on one image/label pair.

    The transforms are built with ``obj_factory`` and composed, image
    transforms first and then tensor transforms. They are applied jointly to
    the image and its label. A matplotlib figure with two rows is then shown:
    the original image next to its colorized label, and the transformed image
    next to the transformed, colorized label.

    Args:
        input: Path (or file object) of the input image; converted to RGB.
        label: Path (or file object) of the label image. If it has a palette
            (``getpalette()`` is not None), that palette colorizes the
            transformed label; otherwise the transformed label is shown in
            grayscale, like ``convert('RGB')`` renders the original.
        img_transforms: Spec passed to ``obj_factory`` for the image-level
            transforms, or None for no image transforms.
        tensor_transforms: Spec passed to ``obj_factory`` for the
            tensor-level transforms, or None for no tensor transforms.
    """
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.img_utils import tensor2rgb
    from hyperseg.datasets.seg_transforms import Compose
    from PIL import Image
    from torch.nn.functional import interpolate

    # Initialize transforms
    img_transforms = obj_factory(
        img_transforms) if img_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    transform = Compose(img_transforms + tensor_transforms)

    # Read input image and corresponding label
    img = Image.open(input).convert('RGB')
    lbl = Image.open(label)
    palette = lbl.getpalette()

    # Apply transformations
    img_t, lbl_t = transform(img, lbl)

    # Some transforms return a sequence of images (presumably a multi-scale
    # pyramid); display only its last entry.
    if isinstance(img_t, (list, tuple)):
        img_t = img_t[-1]

    # The image and label are concatenated side by side below, so their
    # spatial sizes must match. Nearest-neighbour resizing keeps the class
    # ids intact. reshape() tolerates a leading singleton dim on the label.
    if lbl_t.shape[-2:] != img_t.shape[-2:]:
        lbl_t = interpolate(lbl_t.float().reshape(1, 1, *lbl_t.shape[-2:]),
                            size=img_t.shape[-2:],
                            mode='nearest').long().squeeze()

    # Render results
    img, lbl = np.array(img), np.array(lbl.convert('RGB'))
    img_t = tensor2rgb(img_t)
    lbl_ids = lbl_t.squeeze().cpu().numpy().astype('uint8')
    if palette is not None:
        lbl_img = Image.fromarray(lbl_ids, mode='P')
        lbl_img.putpalette(palette)
    else:
        # putpalette(None) would raise; fall back to a grayscale rendering.
        lbl_img = Image.fromarray(lbl_ids)
    lbl_t = np.array(lbl_img.convert('RGB'))

    render_img_orig = np.concatenate((img, lbl), axis=1)
    render_img_transformed = np.concatenate((img_t, lbl_t), axis=1)
    _, ax = plt.subplots(2, 1, figsize=(8, 8))
    ax[0].imshow(render_img_orig)
    ax[1].imshow(render_img_transformed)
    plt.show()
예제 #2
0
def _predict_subset(data_loader, model, device):
    """Run ``model`` over every batch of ``data_loader`` and gather CPU results.

    Returns:
        Tuple ``(inputs, preds, targets)`` of tensors concatenated along dim 0.
        When a batch input is a list/tuple, all of its elements are fed to the
        model, but only the first one is kept as the displayed input.
    """
    inputs, preds, targets = [], [], []
    # Inference only. Without no_grad, each accumulated prediction would keep
    # its batch's autograd graph alive until the function returns.
    with torch.no_grad():
        for batch_input, target in tqdm(data_loader, unit='batches',
                                        file=sys.stdout):
            if isinstance(batch_input, (list, tuple)):
                # Build a new list instead of assigning in place: a collated
                # tuple does not support item assignment.
                batch_input = [x.to(device) for x in batch_input]
                display_input = batch_input[0]
            else:
                batch_input = batch_input.to(device)
                display_input = batch_input

            pred = model(batch_input)

            inputs.append(display_input.cpu())
            preds.append(pred.cpu())
            # The target is only needed on CPU; no device round trip.
            targets.append(target.cpu())
    return (torch.cat(inputs, dim=0), torch.cat(preds, dim=0),
            torch.cat(targets, dim=0))


def _load_display_sources(display_sources, dataset_len, indices, size):
    """Load, for each source directory, the PNG images selected by ``indices``.

    Each directory must contain exactly ``dataset_len`` PNG files. File i in
    sorted filename order is matched to dataset sample i; this assumes the
    dataset enumerates its samples in the same order (verify against the
    dataset). Each image is padded on the right/bottom (via ``pad``) up to
    ``size`` = (width, height) when smaller.

    Returns:
        One tensor per source directory, stacked along a new dim 0 and kept
        on CPU. Presumably (len(indices), H, W) int64 label maps for
        single-channel PNGs; confirm.
    """
    size = np.array(size)
    source_images = []
    for display_source in display_sources:
        # glob() returns files in arbitrary order; sort so position i
        # reliably refers to the same file on every run.
        img_paths = sorted(glob(os.path.join(display_source, '*.png')))
        assert len(img_paths) == dataset_len, 'all display sources must be directories with the same number' \
                                              ' of images as the dataset'
        imgs = []
        for img_path in np.array(img_paths)[indices]:
            with Image.open(img_path) as img:
                # (left, top, right, bottom): only grow to the right/bottom.
                padding = (0, 0) + tuple(
                    np.maximum(size - np.array(img.size), 0))
                padded = np.array(pad(img, padding))
            imgs.append(
                torch.from_numpy(padded.astype(np.int64)).unsqueeze(0))
        source_images.append(torch.cat(imgs, dim=0))
    return source_images


def display_subset(dataset,
                   indices,
                   model,
                   device,
                   batch_size=16,
                   scale=0.5,
                   alpha=0.75,
                   with_input=True,
                   dpi=100,
                   display_sources=None,
                   ignore_index=0):
    """Predict on a subset of ``dataset`` and show segmentations in one grid.

    Grid panels, in order: the normalized input (if ``with_input``), one
    blended segmentation per entry of ``display_sources``, the model
    prediction, and the ground truth. Each panel is blended with
    ``dataset.color_map`` by ``blend_seg``.

    The model is called as-is. Its train/eval mode is not changed here, so
    callers presumably want ``model.eval()`` beforehand.

    Args:
        dataset: Dataset yielding ``(input, target)`` pairs and exposing
            ``color_map``.
        indices: Sample indices of ``dataset`` to display.
        model: Callable applied to each batch input on ``device``.
        device: Device the inputs are moved to before calling ``model``.
        batch_size: Batch size used for inference.
        scale: Display scale factor applied to the grid's pixel size.
        alpha: Blending factor forwarded to ``blend_seg``.
        with_input: Whether to include the raw input as the first panel.
        dpi: Figure DPI; together with ``scale`` it determines figure size.
        display_sources: Optional directories of PNG label images (one per
            dataset sample) to display as extra panels.
        ignore_index: Label value forwarded to ``blend_seg`` for the
            prediction and ground-truth panels.
    """
    data_loader = DataLoader(Subset(dataset, indices),
                             batch_size=batch_size,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=False,
                             shuffle=False)
    inputs, preds, targets = _predict_subset(data_loader, model, device)

    # Min-max rescale the inputs to [-1, 1] for display. A constant input
    # would otherwise divide by zero (NaN); it maps to -1 instead.
    in_min, in_max = inputs.min(), inputs.max()
    span = in_max - in_min
    if span == 0:
        span = 1
    inputs = (inputs - in_min) / span * 2. - 1.

    # Load display sources; they stay on CPU like inputs/preds/targets so all
    # panels share a device for blending and grid assembly.
    source_images = []
    if display_sources is not None:
        source_images = _load_display_sources(
            display_sources, len(dataset), indices,
            (inputs.shape[-1], inputs.shape[-2]))

    # NOTE(review): the display-source panels are blended without
    # ignore_index, unlike the prediction/ground truth panels; confirm this
    # is intended.
    seg_sources = [
        blend_seg(inputs, src_img, dataset.color_map, alpha=alpha)
        for src_img in source_images
    ]
    seg_pred = blend_seg(inputs,
                         preds,
                         dataset.color_map,
                         alpha=alpha,
                         ignore_index=ignore_index)
    seg_gt = blend_seg(inputs,
                       targets,
                       dataset.color_map,
                       alpha=alpha,
                       ignore_index=ignore_index)

    panels = ([inputs] if with_input else []) + seg_sources + [seg_pred, seg_gt]
    grid = tensor2rgb(make_grid(*panels, normalize=False, padding=0))

    # grid.shape[1::-1] is (width, height) in pixels, assuming tensor2rgb
    # returns an (H, W, C) array. Use true division: flooring to whole inches
    # can yield a zero-sized figure, which matplotlib rejects.
    fig_size = tuple((np.array(grid.shape[1::-1]) * scale / dpi).tolist())
    plt.figure(figsize=fig_size, dpi=dpi)
    plt.imshow(grid)
    plt.axis('off')
    plt.show()