Exemple #1
0
def plot_sample_preds(images, labels, calnet_preds, pred_dist, dataset):
    """Render a grid of sampled predictions and return it as an RGB image array.

    One row is drawn per item (``pred_dist.shape[1]`` rows). The columns are,
    from left to right: the input image, up to five individual samples taken
    from ``pred_dist``, the mean over all samples, the calibration-net
    prediction and the label.

    Args:
        images: Indexable batch of inputs; every item is rescaled with
            ``(x + 1) / 2`` before being plotted.
        labels: Batch of labels. When ``labels.shape[1] != LABELS_CHANNELS``
            channel 1 is used as a class index and expanded to one-hot.
        calnet_preds: Indexable batch of calibration-net predictions.
        pred_dist: Sampled predictions, indexed as ``pred_dist[sample, item]``.
        dataset: Dataset identifier forwarded to ``_recolour_label``.

    Returns:
        A writable ``uint8`` array of shape ``(height, width, 3)`` holding the
        rasterised figure.
    """
    max_plotted_preds = 5

    n_samples = pred_dist.shape[0]
    n_rows = pred_dist.shape[1]
    # Never try to show more samples than pred_dist actually contains.
    n_plotted_preds = min(max_plotted_preds, n_samples)
    n_cols = n_plotted_preds + 4

    # The one-hot conversion works on the whole batch, so do it once instead
    # of once per plotted row.
    if labels.shape[1] != LABELS_CHANNELS:
        one_hot_labels = torch.eye(LABELS_CHANNELS)[
            labels[:, 1, :, :].long()].permute(0, 3, 1, 2)
    else:
        one_hot_labels = labels

    def to_plottable(one_hot):
        # Shared one-hot -> recoloured label pipeline used by every label panel.
        return _recolour_label(de_torch(_1hot_2_2d(one_hot, sample=True)),
                               dataset=dataset)

    fig = plt.figure(figsize=(n_cols + 2, n_rows + 2))
    canvas = FigureCanvasAgg(fig)

    try:
        for i in range(n_rows):
            input_image = move_color_channel(de_torch((images[i] + 1) / 2))
            if input_image.shape[-1] == 1:
                input_image = input_image.squeeze()

            # Panels are listed in column order; the conversion calls keep the
            # same order as before (samples, average, calnet, label).
            panels = [("Input", input_image)]
            panels.extend(
                (f"Pred {j + 1}", to_plottable(pred_dist[j, i, :, :]))
                for j in range(n_plotted_preds))
            panels.append((f"Avg Pred\nN = {n_samples}",
                           to_plottable(pred_dist[:, i, :, :].mean(0))))
            panels.append(("Cal Pred", to_plottable(calnet_preds[i])))
            panels.append(("Label 0", to_plottable(one_hot_labels[i])))

            for col, (title, picture) in enumerate(panels, start=1):
                # Draw on the figure explicitly rather than through pyplot's
                # implicit "current figure".
                ax = fig.add_subplot(n_rows, n_cols, i * n_cols + col)
                ax.imshow(picture, interpolation="none")
                if i == 0:
                    ax.set_title(title)
                ax.set_xticks([])
                ax.set_yticks([])

        fig.suptitle('Sample predictions')

        # print_to_buffer() draws the figure and returns the RGBA bytes plus
        # the canvas size. This replaces the deprecated
        # canvas.tostring_rgb() / binary-mode np.fromstring combination.
        rgba, (width, height) = canvas.print_to_buffer()
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)

    # Drop the alpha channel; copy() yields a writable, contiguous array
    # (np.frombuffer alone returns a read-only view of the bytes).
    return np.frombuffer(rgba, dtype=np.uint8).reshape(
        height, width, 4)[..., :3].copy()
Exemple #2
0
def plot_comparison_figure(batch, calnet_preds, fake_labels, al_maps,
                           gan_al_maps, generator, calibration_net,
                           discriminator, args):
    """Render inputs, labels, predictions and uncertainty maps side by side.

    Up to five items of the batch are drawn, one per row. The columns are,
    in order and depending on the configuration: input image, label,
    black-box prediction (only when ``batch["bb_preds"]`` was loaded, i.e.
    CITYSCAPES19 with ``args.class_flip``), calibration-net prediction,
    refinement-net prediction, calibration-net aleatoric map and
    refinement-net aleatoric map.

    Args:
        batch: Batch understood by ``unpack_batch``; for CITYSCAPES19 with
            ``args.class_flip`` it must also provide ``batch["bb_preds"]``.
        calnet_preds: Calibration-net predictions (used unless
            ``args.calibration_net == "EmptyCalNet"``).
        fake_labels: Generator predictions (used unless
            ``args.generator == "EmptyGenerator"``).
        al_maps: Aleatoric uncertainty maps of the calibration net.
        gan_al_maps: Aleatoric uncertainty maps of the generator.
        generator, calibration_net, discriminator: Not used for plotting; the
            local references are dropped right away.
        args: Namespace providing ``dataset``, ``class_flip``,
            ``calibration_net`` and ``generator``.

    Returns:
        A writable ``uint8`` array of shape ``(height, width, 3)`` holding the
        rasterised figure.

    Raises:
        AssertionError: If ``al_maps.max()`` exceeds ``MAX_ALEATORIC_GT`` while
            the calibration net is in use.
    """
    max_pics = 5

    is_lidc = args.dataset == "LIDC"
    if is_lidc:
        images, labels, gt_dist = unpack_batch(batch)
        lidc_norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    else:
        images, labels = unpack_batch(batch)
        gt_dist = None
        lidc_norm = None

    if args.dataset == "CITYSCAPES19" and args.class_flip:
        # Channel 1 carries the class index; expand it to one-hot.
        labels = torch.eye(LABELS_CHANNELS)[labels[:, 1, :, :].long()].permute(
            0, 3, 1, 2)
        bb_preds = batch["bb_preds"].to(DEVICE).float()
        bb_preds = torch.eye(LABELS_CHANNELS)[
            bb_preds[:, 1, :, :].long()].permute(0, 3, 1, 2).to(DEVICE)
    else:
        bb_preds = None

    # The black-box column can only be drawn when its predictions were loaded.
    # Previously every CITYSCAPES19 run indexed bb_preds, which is None
    # without class_flip and therefore raised a TypeError.
    show_bb = bb_preds is not None

    # check used model types
    use_calnet = args.calibration_net != "EmptyCalNet"
    use_generator = args.generator != "EmptyGenerator"

    # Drop this function's references to the (unused) networks.
    del (calibration_net, generator, discriminator)

    # Validate once up front instead of once per row, and before any figure
    # is allocated.
    # NOTE(review): the check compares against MAX_ALEATORIC_GT while the
    # message reports MAX_ALEATORIC - confirm which bound is intended.
    if use_calnet:
        assert al_maps.max(
        ) <= MAX_ALEATORIC_GT, "Predicted aleatoric uncertainty not within range: True = 0 < " + str(
            MAX_ALEATORIC) + ", Plottable = " + str(
                al_maps.min().item()) + " < " + str(al_maps.max().item())

    # Figure layout; never plot more rows than the batch contains.
    n_pics = min(max_pics, len(images))
    n_plots = 6 if use_calnet else 4
    if show_bb:
        n_plots += 1

    # Fixed range into which the aleatoric uncertainty maps are normalised.
    al_norm = matplotlib.colors.Normalize(vmin=0, vmax=MAX_ALEATORIC)

    def to_plottable(one_hot):
        # Shared one-hot -> recoloured label pipeline.
        return _recolour_label(de_torch(_1hot_2_2d(one_hot, sample=True)),
                               dataset=args.dataset)

    def pad(tensor):
        # Frame a map with a 2-pixel border of ones so tiled maps stay
        # visually separated.
        return np.pad(tensor.cpu().numpy(),
                      pad_width=2,
                      mode='constant',
                      constant_values=1)

    fig = plt.figure(figsize=(n_plots * 2 + 2, n_pics * 2))
    canvas = FigureCanvasAgg(fig)

    try:
        for idx in range(n_pics):
            # (x + 1) / 2 rescales the input before plotting.
            input_image = move_color_channel(de_torch((images[idx] + 1) / 2))
            if input_image.shape[-1] == 1:
                input_image = input_image.squeeze()

            if is_lidc:
                # Tile the four maps gt_dist[idx, 0..3] into a 2x2 mosaic.
                top = np.concatenate(
                    (pad(gt_dist[idx, 0]), pad(gt_dist[idx, 1])), axis=1)
                bottom = np.concatenate(
                    (pad(gt_dist[idx, 2]), pad(gt_dist[idx, 3])), axis=1)
                true_label = np.concatenate([top, bottom], axis=0)
            else:
                true_label = to_plottable(labels[idx])

            # Conversions keep their original order (label, black-box,
            # generator, calibration net).
            bb_pred = to_plottable(bb_preds[idx]) if show_bb else None
            refnet_pred = (to_plottable(fake_labels[idx])
                           if use_generator else None)
            if use_calnet:
                calnet_al = de_torch(al_maps[idx])
                calnet_pred = to_plottable(calnet_preds[idx])
            if use_generator:
                refnet_al = de_torch(gan_al_maps[idx])

            # Panels in column order: (title, image, extra imshow kwargs).
            panels = [
                ("Input Image", input_image, {}),
                ("Label", true_label, {"norm": lidc_norm}),
            ]
            if show_bb:
                panels.append(("BB Pred", bb_pred, {}))
            if use_calnet:
                panels.append(("CalNet Pred", calnet_pred,
                               {"norm": lidc_norm}))
            if use_generator:
                panels.append(("RefNet Pred", refnet_pred,
                               {"norm": lidc_norm}))
            if use_calnet:
                panels.append(("CalNet Aleatoric", calnet_al,
                               {"cmap": 'hot', "norm": al_norm}))
            if use_generator:
                panels.append(("RefNet Aleatoric", refnet_al,
                               {"cmap": 'hot', "norm": al_norm}))

            for col, (title, picture, extra_kwargs) in enumerate(panels,
                                                                 start=1):
                ax = fig.add_subplot(n_pics, n_plots, idx * n_plots + col)
                ax.imshow(picture, interpolation="none", **extra_kwargs)
                if idx == 0:
                    ax.set_title(title)
                ax.set_xticks([])
                ax.set_yticks([])

        # print_to_buffer() draws the figure and returns the RGBA bytes plus
        # the canvas size. This replaces the deprecated
        # canvas.tostring_rgb() / binary-mode np.fromstring combination.
        rgba, (width, height) = canvas.print_to_buffer()
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)

    # Drop the alpha channel; copy() yields a writable, contiguous array.
    return np.frombuffer(rgba, dtype=np.uint8).reshape(
        height, width, 4)[..., :3].copy()
Exemple #3
0
            transformations.ChangeChannels()
        ]
    )

    dataset = Cityscapes19(mode="test", transform=transform) #TODO NO TEST DIRECTORY IN PROCESSED

    batch_size = 5

    data = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=True, pin_memory=True, num_workers=16)

    data_bar = tqdm(data)

    for i, (batches) in enumerate(data_bar):

        # Visualize batches for DEBUG
        batch_1 = list(iter(data))[8]

        image_1, labels_1 = unpack_batch(batch_1)

        p_preds = batch_1["bb_preds"].to(DEVICE).float()

        image_1 = 255*(image_1+1)/2
        labels_1 = _recolour_label(_1hot_2_2d(labels_1,dim=1), dataset="CITYSCAPES19").permute(0,3,1,2).float().to(DEVICE)
        p_preds = _recolour_label(_1hot_2_2d(p_preds, dim=1), dataset="CITYSCAPES19").permute(0, 3, 1, 2).float().to(DEVICE)

        batch = torch.cat((image_1, labels_1, p_preds), dim=0)

        plt.figure(figsize=(5,10))
        plt.imshow(vutils.make_grid(batch, nrow=batch_size, normalize=True).cpu().numpy().transpose(1, 2, 0))
        plt.show()
Exemple #4
0
    dataset = Cityscapes35(mode="train", transform=transform)

    batch_size = 3

    data = DataLoader(dataset,
                      shuffle=False,
                      batch_size=batch_size,
                      drop_last=True,
                      num_workers=4)

    data_bar = tqdm(data)

    for i, (batches) in enumerate(data_bar):
        # Visualize batches for DEBUG

        batch_1 = batches

        image_1, labels_1 = unpack_batch(batch_1)

        labels_1 = _recolour_label(_1hot_2_2d(labels_1, dim=1),
                                   dataset="CITYSCAPES35").permute(
                                       0, 3, 1, 2).float().to(DEVICE)

        batch = torch.cat((image_1, labels_1), dim=0)

        plt.figure(1)
        plt.imshow(
            vutils.make_grid(batch, nrow=batch_size,
                             normalize=True).cpu().numpy().transpose(1, 2, 0))
        plt.show()
Exemple #5
0
                      batch_size=batch_size,
                      drop_last=True,
                      num_workers=0)

    plotting_batches = next(iter(data))
    batch_1 = plotting_batches

    data_bar = tqdm(data)

    for i, (batches) in enumerate(data_bar):
        # Visualize batches for DEBUG
        image_1, labels_1, dist = unpack_batch(batches)

        image_1 = (image_1 + 1) / 2

        labels_1 = _1hot_2_2d(labels_1, dim=1).float().to(
            constants.DEVICE).unsqueeze(dim=1).repeat(1, 3, 1, 1)

        pad = lambda x: np.pad(
            x.cpu().numpy(), pad_width=2, mode='constant', constant_values=1)

        glued_top = np.concatenate((pad(dist[1, 0]), pad(dist[1, 1])), axis=1)
        glued_bottom = np.concatenate((pad(dist[1, 2]), pad(dist[1, 3])),
                                      axis=1)
        glued_all = np.concatenate([glued_top, glued_bottom], axis=0)

        batch = torch.cat((image_1, labels_1), dim=0)

        plt.figure(1)
        plt.imshow(
            vutils.make_grid(batch, nrow=batch_size,
                             normalize=True).cpu().numpy().transpose(1, 2, 0))
Exemple #6
0
    def plot_sample_preds(self, images, labels, calnet_preds, pred_dist,
                          gt_dist, n_preds, dataset):
        """Render a grid of sampled predictions and return it as an RGB array.

        One row is drawn per item (``pred_dist.shape[1]`` rows). The columns
        are, from left to right: the input image, up to five individual
        samples from ``pred_dist``, the mean over all samples, the
        calibration-net prediction and the ground truth.

        Args:
            images: Indexable batch of inputs; every item is rescaled with
                ``(x + 1) / 2`` before being plotted.
            labels: Batch of labels, only used when ``gt_dist`` is None. When
                ``labels.shape[1] != LABELS_CHANNELS`` channel 1 is used as a
                class index and expanded to one-hot.
            calnet_preds: Indexable batch of calibration-net predictions.
            pred_dist: Sampled predictions, indexed as
                ``pred_dist[sample, item]``.
            gt_dist: Optional ground-truth maps; when given, the maps
                ``gt_dist[i, 0..3]`` are tiled into a 2x2 mosaic and shown
                instead of ``labels``.
            n_preds: Number of samples available; at most five are plotted.
            dataset: Dataset identifier forwarded to ``_recolour_label``;
                ``"LIDC"`` additionally fixes the colour range to [0, 1].

        Returns:
            A writable ``uint8`` array of shape ``(height, width, 3)`` holding
            the rasterised figure.
        """
        max_plotted_preds = 5

        n_samples = pred_dist.shape[0]
        n_rows = pred_dist.shape[1]
        # Cap the sample columns by what pred_dist really holds as well, so a
        # too-large n_preds cannot index past the end.
        n_plotted_preds = min(max_plotted_preds, n_preds, n_samples)
        n_cols = n_plotted_preds + 4

        # norm=None is imshow's default, so a single optional normaliser
        # replaces the duplicated LIDC / non-LIDC imshow branches.
        label_norm = (matplotlib.colors.Normalize(vmin=0, vmax=1)
                      if dataset == "LIDC" else None)

        # The one-hot conversion works on the whole batch, so do it once
        # instead of once per plotted row.
        one_hot_labels = None
        if gt_dist is None:
            if labels.shape[1] != LABELS_CHANNELS:
                one_hot_labels = torch.eye(LABELS_CHANNELS)[
                    labels[:, 1, :, :].long()].permute(0, 3, 1, 2)
            else:
                one_hot_labels = labels

        def to_plottable(one_hot):
            # Shared one-hot -> recoloured label pipeline.
            return _recolour_label(de_torch(_1hot_2_2d(one_hot, sample=True)),
                                   dataset=dataset)

        def pad(tensor):
            # Frame a map with a 2-pixel border of ones so tiled maps stay
            # visually separated.
            return np.pad(tensor.cpu().numpy(),
                          pad_width=2,
                          mode='constant',
                          constant_values=1)

        fig = plt.figure(figsize=(n_cols + 2, n_rows + 2))
        canvas = FigureCanvasAgg(fig)

        try:
            for i in range(n_rows):
                input_image = move_color_channel(
                    de_torch((images[i] + 1) / 2))
                if input_image.shape[-1] == 1:
                    input_image = input_image.squeeze()

                # Panels in column order: (title, image, imshow norm). The
                # input keeps imshow's default scaling, as before.
                panels = [("Input", input_image, None)]
                panels.extend(
                    (f"Pred {j + 1}", to_plottable(pred_dist[j, i, :, :]),
                     label_norm) for j in range(n_plotted_preds))
                panels.append((f"Avg Pred\nN = {n_samples}",
                               to_plottable(pred_dist[:, i, :, :].mean(0)),
                               label_norm))
                panels.append(("CalNet Pred", to_plottable(calnet_preds[i]),
                               label_norm))

                if gt_dist is None:
                    ground_truth = to_plottable(one_hot_labels[i])
                else:
                    top = np.concatenate(
                        (pad(gt_dist[i, 0]), pad(gt_dist[i, 1])), axis=1)
                    bottom = np.concatenate(
                        (pad(gt_dist[i, 2]), pad(gt_dist[i, 3])), axis=1)
                    ground_truth = np.concatenate([top, bottom], axis=0)
                panels.append(("Label 0", ground_truth, label_norm))

                for col, (title, picture, norm) in enumerate(panels, start=1):
                    ax = fig.add_subplot(n_rows, n_cols, i * n_cols + col)
                    ax.imshow(picture, norm=norm, interpolation="none")
                    if i == 0:
                        ax.set_title(title)
                    ax.set_xticks([])
                    ax.set_yticks([])

            fig.suptitle('Sample predictions')

            # print_to_buffer() draws the figure and returns the RGBA bytes
            # plus the canvas size. This replaces the deprecated
            # canvas.tostring_rgb() / binary-mode np.fromstring combination.
            rgba, (width, height) = canvas.print_to_buffer()
        finally:
            # Always release the figure, even when plotting fails.
            plt.close(fig)

        # Drop the alpha channel; copy() yields a writable, contiguous array.
        return np.frombuffer(rgba, dtype=np.uint8).reshape(
            height, width, 4)[..., :3].copy()