def forward(self, generator: GeneralGenerator,
                discriminator: GeneralDiscriminator,
                batch_1: Dict[str, torch.Tensor],
                batch_2: Dict[str, torch.Tensor],
                batch_3: Dict[str, torch.Tensor]) \
            -> Tuple[
                Any, Dict, torch.Tensor, torch.Tensor, torch.Tensor
            ]:
        """Combined loss function from the triple-consistency paper.

        The generator re-renders ``image_1`` (from ``batch_1``) under the
        landmarks of ``batch_2``. The output is scored by the adversarial,
        self-consistency, triple-consistency, pixel, perceptual (``pp``) and
        identity (``id``) loss terms, and the terms are summed.

        Args:
            generator: called on ``image_1`` concatenated with
                ``landmarks_2`` along ``CHANNEL_DIM``. Also forwarded to the
                self- and triple-consistency losses.
            discriminator: forwarded to the adversarial loss.
            batch_1: source sample; provides ``image_1`` and ``landmarks_1``.
            batch_2: target sample; provides ``image_2`` (reference for the
                pixel and feature losses) and ``landmarks_2``.
            batch_3: only its landmarks (``landmarks_3``) are used, by the
                triple-consistency loss. Its image is ignored.

        Returns:
            Tuple of (total loss, merged dict of every term's ``save``
            output, detached generated image, detached generated image
            concatenated with ``landmarks_2``, detached ``image_2``
            concatenated with ``landmarks_2``).
        """

        # prepare input
        image_1, landmarks_1 = unpack_batch(batch_1)
        image_2, landmarks_2 = unpack_batch(batch_2)
        _, landmarks_3 = unpack_batch(batch_3)
        image_1 = image_1.to(DEVICE).float()
        image_2 = image_2.to(DEVICE).float()
        landmarks_2 = landmarks_2.to(DEVICE).float()
        # NOTE(review): landmarks_1 / landmarks_3 are not moved to DEVICE
        # here, unlike landmarks_2 -- presumably the consistency losses
        # handle that; confirm.

        target_landmarked_input = torch.cat((image_1, landmarks_2),
                                            dim=CHANNEL_DIM)
        target_landmarked_truth = torch.cat((image_2, landmarks_2),
                                            dim=CHANNEL_DIM)
        fake = generator(target_landmarked_input)
        target_landmarked_fake = torch.cat((fake, landmarks_2),
                                           dim=CHANNEL_DIM)

        total_loss = 0

        real_feats, fake_feats = None, None

        if self.pp.active or self.id.active:
            # NOTE(review): position 0 looks like the identity-loss layer
            # (None when id is inactive) and the remaining entries the
            # perceptual layers -- confirm against get_features.
            if self.pp.active and self.id.active:
                feature_selection = (13, 3, 8, 15, 24)
            elif self.pp.active:
                feature_selection = (None, 3, 8, 15, 24)
            else:
                # One-element tuple: the former ``(13)`` was the bare int 13,
                # unlike every other branch, which passes a tuple.
                feature_selection = (13,)

            real_feats, fake_feats = self.get_features(image_2, fake,
                                                       feature_selection)

        # adversarial loss
        loss_adv, save_adv = self.adv(target_landmarked_fake, discriminator)
        total_loss += loss_adv
        del loss_adv

        # consistency losses
        loss_self, save_self = self.self(image_1, fake, landmarks_1, generator)
        total_loss += loss_self
        del loss_self, landmarks_1

        loss_triple, save_triple = self.trip(image_1, fake, landmarks_3,
                                             landmarks_2, generator)
        total_loss += loss_triple
        del loss_triple, landmarks_2

        # pixel losses
        loss_pix, save_pix = self.pix(image_2, fake)
        total_loss += loss_pix
        del loss_pix, image_2

        # style losses (features are None when neither pp nor id is active)
        loss_pp, save_pp = self.pp(real_feats, fake_feats)
        total_loss += loss_pp
        del loss_pp

        loss_id, save_id = self.id(real_feats, fake_feats)
        total_loss += loss_id
        del loss_id

        # merge dicts
        merged = {
            **save_adv,
            **save_pix,
            **save_pp,
            **save_self,
            **save_triple,
            **save_id
        }

        return (total_loss, merged, fake.detach(),
                target_landmarked_fake.detach(),
                target_landmarked_truth.detach())
Example #2
0
def plot_comparison_figure(batch, calnet_preds, fake_labels, al_maps,
                           gan_al_maps, generator, calibration_net,
                           discriminator, args):
    """Render a grid comparing inputs, labels and model predictions.

    One row is drawn for each of the first five samples. Columns, left to
    right:
      1. input image
      2. label
      3. black-box prediction, only when ``bb_preds`` were unpacked
         (CITYSCAPES19 with ``args.class_flip``)
      4. CalNet prediction
      5. RefNet prediction
      6. CalNet aleatoric map
      7. RefNet aleatoric map
    The CalNet / RefNet columns are drawn only when ``args.calibration_net``
    / ``args.generator`` are not the empty placeholder models.

    Args:
        batch: data batch, unpacked with ``unpack_batch``. LIDC batches also
            yield ``gt_dist``, which is drawn as a 2x2 mosaic in the label
            column.
        calnet_preds: calibration-net predictions, indexed per sample.
        fake_labels: refinement-net (generator) predictions.
        al_maps: calibration-net aleatoric uncertainty maps.
        gan_al_maps: generator aleatoric uncertainty maps.
        generator, calibration_net, discriminator: not used for drawing;
            only ``args`` decides which columns appear.
        args: run configuration; reads ``dataset``, ``class_flip``,
            ``calibration_net`` and ``generator``.

    Returns:
        uint8 ``np.ndarray`` of shape (height, width, 3) with the rendered
        RGB figure.

    Raises:
        AssertionError: if CalNet is used and ``al_maps`` exceeds
            ``MAX_ALEATORIC_GT``.
    """
    if args.dataset == "LIDC":
        images, labels, gt_dist = unpack_batch(batch)
        lidc_norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    else:
        images, labels = unpack_batch(batch)
        gt_dist = None
        lidc_norm = None

    bb_preds = None
    if args.dataset == "CITYSCAPES19" and args.class_flip:
        # Channel 1 holds class indices; turn them into channels-first one-hot.
        labels = torch.eye(LABELS_CHANNELS)[labels[:, 1, :, :].long()].permute(
            0, 3, 1, 2)
        bb_preds = batch["bb_preds"].to(DEVICE).float()
        bb_preds = torch.eye(LABELS_CHANNELS)[
            bb_preds[:, 1, :, :].long()].permute(0, 3, 1, 2).to(DEVICE)

    # check used model types
    use_calnet = args.calibration_net != "EmptyCalNet"
    use_generator = args.generator != "EmptyGenerator"
    # Key the black-box column on the data actually being available:
    # CITYSCAPES19 without class_flip has bb_preds = None.
    show_bb = bb_preds is not None

    # Drop the local references; the models are not needed for plotting.
    del (calibration_net, generator, discriminator)

    if use_calnet:
        # Validate once, before the figure exists, so a failure cannot leak it.
        max_al = al_maps.max()
        assert max_al <= MAX_ALEATORIC_GT, (
            "Predicted aleatoric uncertainty not within range: True = 0 < "
            f"{MAX_ALEATORIC_GT}, Plottable = {al_maps.min().item()} < "
            f"{max_al.item()}")

    n_pics = 5
    # Column count follows from the panels actually drawn (no blank columns).
    n_plots = (2 + int(show_bb) + 2 * int(use_calnet)
               + 2 * int(use_generator))

    # Range into which the aleatoric uncertainty maps are normalized.
    al_norm = matplotlib.colors.Normalize(vmin=0, vmax=MAX_ALEATORIC)

    def pad(x):
        # White (value 1) border around each LIDC annotation in the mosaic.
        return np.pad(x.cpu().numpy(),
                      pad_width=2,
                      mode='constant',
                      constant_values=1)

    def as_label_image(one_hot):
        return _recolour_label(de_torch(_1hot_2_2d(one_hot, sample=True)),
                               dataset=args.dataset)

    fig = plt.figure(figsize=(n_plots * 2 + 2, n_pics * 2))
    canvas = FigureCanvasAgg(fig)
    try:
        for idx in range(n_pics):
            # +1 / 2 maps images from [-1, 1] to [0, 1] (assumed input range).
            image = move_color_channel(de_torch((images[idx] + 1) / 2))
            if image.shape[-1] == 1:
                image = image.squeeze()

            if args.dataset == "LIDC":
                top = np.concatenate(
                    (pad(gt_dist[idx, 0]), pad(gt_dist[idx, 1])), axis=1)
                bottom = np.concatenate(
                    (pad(gt_dist[idx, 2]), pad(gt_dist[idx, 3])), axis=1)
                true_label = np.concatenate([top, bottom], axis=0)
            else:
                true_label = as_label_image(labels[idx])

            # (data, title, extra imshow kwargs) in left-to-right order.
            panels = [
                (image, "Input Image", {}),
                (true_label, "Label", {"norm": lidc_norm}),
            ]
            if show_bb:
                panels.append((as_label_image(bb_preds[idx]), "BB Pred", {}))
            if use_calnet:
                panels.append((as_label_image(calnet_preds[idx]),
                               "CalNet Pred", {"norm": lidc_norm}))
            if use_generator:
                panels.append((as_label_image(fake_labels[idx]),
                               "RefNet Pred", {"norm": lidc_norm}))
            if use_calnet:
                panels.append((de_torch(al_maps[idx]), "CalNet Aleatoric",
                               {"cmap": 'hot', "norm": al_norm}))
            if use_generator:
                panels.append((de_torch(gan_al_maps[idx]), "RefNet Aleatoric",
                               {"cmap": 'hot', "norm": al_norm}))

            for col, (data, title, kwargs) in enumerate(panels, start=1):
                plt.subplot(n_pics, n_plots, idx * n_plots + col)
                plt.imshow(data, interpolation="none", **kwargs)
                if idx == 0:
                    plt.title(title)
                plt.xticks([])
                plt.yticks([])

        canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib 3.10);
        # drop alpha and copy so the array outlives the closed figure.
        rgb = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return rgb
Example #3
0
def plot_batch(batch_1,
               batch_2,
               batch_3,
               embedder: GeneralEmbedder,
               generator: GeneralGenerator,
               arguments,
               number_of_pictures: int = 4,
               number_of_batches: int = 1):
    """Render source / generated / target comparisons as an RGB array.

    The generator is run on ``image_1`` (from ``batch_1``) concatenated with
    ``landmarks_2`` (from ``batch_2``) along ``CHANNEL_DIM``. Each row shows,
    left to right:
      1. source landmarks
      2. source image
      3. generated image
      4. target image
      5. target landmarks

    Args:
        batch_1: source batch, unpacked into image and landmarks.
        batch_2: target batch, unpacked into image and landmarks.
        batch_3, embedder, arguments, number_of_batches: unused.
        generator: evaluated in eval mode without gradient tracking. Its
            previous training mode is restored afterwards.
        number_of_pictures: rows to draw; clamped to the batch size.

    Returns:
        uint8 ``np.ndarray`` of shape (height, width, 3) with the rendered
        RGB figure.
    """
    image_1, landmarks_1 = unpack_batch(batch_1)
    image_2, landmarks_2 = unpack_batch(batch_2)

    # Never index past the end of a short batch.
    n_rows = min(number_of_pictures, image_1.shape[0], image_2.shape[0])

    was_training = getattr(generator, "training", False)
    generator.eval()
    try:
        with torch.no_grad():
            # Slicing before the channel concat is equivalent and cheaper.
            combined = torch.cat((image_1[:n_rows], landmarks_2[:n_rows]),
                                 dim=CHANNEL_DIM).to(DEVICE).float()
            generated_images = generator(combined)
    finally:
        if was_training:
            generator.train()

    landmarks_1 = torch.sum(landmarks_1, dim=CHANNEL_DIM)
    landmarks_2 = torch.sum(landmarks_2, dim=CHANNEL_DIM)

    def landmark_panel(summed, row):
        # Binarised, transposed mask replicated into 3 channels for imshow.
        mask = denormalize_picture(de_torch(-1 * summed[row, :, :]),
                                   binarised=True)
        return np.stack((mask.T, mask.T, mask.T), axis=2)

    def image_panel(tensor, row):
        return BGR2RGB_numpy(
            denormalize_picture(de_torch(tensor[row, :, :, :])))

    plots = 5

    fig = plt.figure()
    canvas = FigureCanvasAgg(fig)
    try:
        for row in range(n_rows):
            panels = (
                (landmark_panel(landmarks_1, row), "source"),
                (image_panel(image_1, row), "source"),
                (image_panel(generated_images, row), "generated"),
                (image_panel(image_2, row), "target"),
                (landmark_panel(landmarks_2, row), "target"),
            )
            for col, (data, title) in enumerate(panels, start=1):
                plt.subplot(n_rows, plots, row * plots + col)
                plt.imshow(data)
                plt.title(title)
                plt.xticks([])
                plt.yticks([])

        canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib 3.10);
        # drop alpha and copy so the array outlives the closed figure.
        rgb = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return rgb
Example #4
0
            transformations.ChangeChannels()
        ]
    )

    dataset = Cityscapes19(mode="test", transform=transform) #TODO NO TEST DIRECTORY IN PROCESSED

    batch_size = 5

    data = DataLoader(dataset, shuffle=False, batch_size=batch_size, drop_last=True, pin_memory=True, num_workers=16)

    data_bar = tqdm(data)

    for i, (batches) in enumerate(data_bar):

        # Visualize batches for DEBUG
        batch_1 = list(iter(data))[8]

        image_1, labels_1 = unpack_batch(batch_1)

        p_preds = batch_1["bb_preds"].to(DEVICE).float()

        image_1 = 255*(image_1+1)/2
        labels_1 = _recolour_label(_1hot_2_2d(labels_1,dim=1), dataset="CITYSCAPES19").permute(0,3,1,2).float().to(DEVICE)
        p_preds = _recolour_label(_1hot_2_2d(p_preds, dim=1), dataset="CITYSCAPES19").permute(0, 3, 1, 2).float().to(DEVICE)

        batch = torch.cat((image_1, labels_1, p_preds), dim=0)

        plt.figure(figsize=(5,10))
        plt.imshow(vutils.make_grid(batch, nrow=batch_size, normalize=True).cpu().numpy().transpose(1, 2, 0))
        plt.show()
Example #5
0
    batch_size = 5

    data = DataLoader(dataset,
                      shuffle=False,
                      batch_size=batch_size,
                      drop_last=True,
                      num_workers=0)

    plotting_batches = next(iter(data))
    batch_1 = plotting_batches

    data_bar = tqdm(data)

    for i, (batches) in enumerate(data_bar):
        # Visualize batches for DEBUG
        image_1, labels_1, dist = unpack_batch(batches)

        image_1 = (image_1 + 1) / 2

        labels_1 = _1hot_2_2d(labels_1, dim=1).float().to(
            constants.DEVICE).unsqueeze(dim=1).repeat(1, 3, 1, 1)

        pad = lambda x: np.pad(
            x.cpu().numpy(), pad_width=2, mode='constant', constant_values=1)

        glued_top = np.concatenate((pad(dist[1, 0]), pad(dist[1, 1])), axis=1)
        glued_bottom = np.concatenate((pad(dist[1, 2]), pad(dist[1, 3])),
                                      axis=1)
        glued_all = np.concatenate([glued_top, glued_bottom], axis=0)

        batch = torch.cat((image_1, labels_1), dim=0)