コード例 #1
0
def main():
    """Fit one generator latent per data point so that image distances match point distances.

    Loads the generator/discriminator checkpoints below, takes the first ``dataSize`` vectors of
    ``AuthorDataset`` and their ``l2_sqr_dist_matrix`` as the target, and then optimizes only the
    latent codes: on random mini-batches it minimizes the mean squared error between the target
    distances and ``scale * E-LPIPS + bias`` of the generated images. Every 100 iterations it
    prints/logs the loss and writes image grids, a UMAP scatter plot, distance-matrix renders,
    a survey sheet and checkpoints into a timestamped directory under ``runs/``.
    """

    dataSize = 32
    batchSize = 8
    elpipsBatchSize = 1
    # imageSize = 32
    imageSize = 64
    nz = 100

    # discCheckpointPath = r'E:\projects\visus\PyTorch-GAN\implementations\dcgan\checkpoints\2020_07_10_15_53_34\disc_step4800.pth'
    discCheckpointPath = r'E:\projects\visus\pytorch-examples\dcgan\out\netD_epoch_24.pth'
    genCheckpointPath = r'E:\projects\visus\pytorch-examples\dcgan\out\netG_epoch_24.pth'

    # Generated images are treated as lying in [-1, 1] (gpu_images_to_numpy maps them to [0, 1]
    # with (x + 1) / 2). The LPIPS `normalize=True` flag is meant for [0, 1] inputs and would
    # rescale them a second time, so the metric receives the generator output unchanged.
    lpipsNormalize = False

    gpu = torch.device('cuda')

    def gpu_images_to_numpy(images):
        """Map an NCHW tensor with values in [-1, 1] to an NHWC numpy array in [0, 1]."""
        imagesNumpy = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
        return (imagesNumpy + 1) / 2

    # For now we normalize the vectors to have norm 1, but don't make sure
    # that the data has certain mean/std.
    pointDataset = AuthorDataset(
        jsonPath=r'E:\out\scripts\metaphor-vis\authors-all.json'
    )

    # Take top N points.
    points = np.asarray([pointDataset[i][0] for i in range(dataSize)])
    distPointsCpu = l2_sqr_dist_matrix(torch.tensor(points)).numpy()

    # The points never change, so their 2D UMAP embedding (used by the scatter plot) is computed
    # once here instead of at every logging step.
    authorVectorsProj = umap.UMAP(n_neighbors=min(5, dataSize), random_state=1337).fit_transform(points)

    latents = torch.tensor(np.random.normal(0.0, 1.0, (dataSize, nz)),
                           requires_grad=True, dtype=torch.float32, device=gpu)

    # Affine map from image distances to point distances. Only the latents are stepped below
    # (optimizerScale.step() is disabled), so these keep their initial values.
    scale = torch.tensor(2.7, requires_grad=True, dtype=torch.float32, device=gpu)  # todo Re-check!
    bias = torch.tensor(0.0, requires_grad=True, dtype=torch.float32, device=gpu)  # todo Re-check!

    lpips = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True).to(gpu)
    # lossModel = lpips
    config = elpips.Config()
    config.batch_size = elpipsBatchSize  # Ensemble size for ELPIPS.
    config.set_scale_levels_by_image_size(imageSize, imageSize)
    lossModel = elpips.ElpipsMetric(config, lpips).to(gpu)

    def lpips_distance_matrix(imagesGpu, bs):
        """Return an (N, N) numpy array whose entry [i, j] is lossModel(image i, image j).

        The matrix is filled in bs x bs tiles to bound GPU memory; N must be divisible by bs.
        Must be called under torch.no_grad() (the tiles are converted with .numpy()).
        """
        imageNumber = imagesGpu.shape[0]
        assert imageNumber % bs == 0
        distances = np.zeros((imageNumber, imageNumber))
        for startA in range(0, imageNumber, bs):
            imagesA = imagesGpu[startA:startA + bs]
            for startB in range(0, imageNumber, bs):
                imagesB = imagesGpu[startB:startB + bs]
                # Flat pair k is (imagesA[k // bs], imagesB[k % bs]), so after the reshape
                # tile[r, c] compares imagesA[r] with imagesB[c] (row = A, column = B).
                tile = lossModel(imagesA.repeat_interleave(repeats=bs, dim=0).contiguous(),
                                 imagesB.repeat(repeats=(bs, 1, 1, 1)).contiguous(),
                                 normalize=lpipsNormalize)
                distances[startA:startA + bs, startB:startB + bs] = tile.cpu().numpy().reshape((bs, bs))
        return distances

    discriminator = Discriminator(3, 64, 1)
    if discCheckpointPath:
        discriminator.load_state_dict(torch.load(discCheckpointPath))
    else:
        discriminator.init_params()
    discriminator = discriminator.to(gpu)

    generator = Generator(nz=nz, ngf=64)
    if genCheckpointPath:
        generator.load_state_dict(torch.load(genCheckpointPath))
    else:
        generator.init_params()
    generator = generator.to(gpu)

    # optimizerImages = torch.optim.Adam([images, scale], lr=1e-2, betas=(0.9, 0.999))
    optimizerScale = torch.optim.Adam([scale, bias], lr=0.001)
    # optimizerGen = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    # optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.9, 0.999))
    # optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerLatents = torch.optim.Adam([latents], lr=5e-3, betas=(0.9, 0.999))

    # Persistent figures reused at every logging step. The per-step figures created further
    # below use their own names so they never replace these.
    fig, axes = plt.subplots(nrows=2, ncols=batchSize // 2)

    fig2 = plt.figure()
    ax2 = fig2.add_subplot(1, 1, 1)

    outPath = os.path.join('runs', datetime.datetime.today().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(outPath)

    summaryWriter = SummaryWriter(outPath)

    for batchIndex in range(10000):

        # noinspection PyTypeChecker
        randomIndices = np.random.randint(0, dataSize, batchSize).tolist()  # type: List[int]
        # # randomIndices = list(range(dataSize))  # type: List[int]
        distTarget = torch.tensor(distPointsCpu[randomIndices, :][:, randomIndices], dtype=torch.float32, device=gpu)
        latentsBatch = latents[randomIndices]

        imageBatchFake = generator(latentsBatch[:, :, None, None].float())

        # todo It's possible to compute this more efficiently, but would require re-implementing lpips.
        # For now, compute the full BSxBS matrix row-by-row to avoid memory issues.
        lossDistTotal = torch.tensor(0.0, device=gpu)
        distanceRows = []
        for iRow in range(batchSize):
            # Row iRow: distances from fake image iRow to every fake image in the batch.
            distPredFlat = lossModel(imageBatchFake[iRow].repeat(repeats=(batchSize, 1, 1, 1)).contiguous(),
                                     imageBatchFake, normalize=lpipsNormalize)
            distPred = distPredFlat.reshape((1, batchSize))
            distanceRows.append(distPred)
            lossDist = torch.sum((distTarget[iRow] - (distPred * scale + bias)) ** 2)  # Squared error of one row.
            lossDistTotal += lossDist

        lossDistTotal /= batchSize * batchSize  # Compute the mean.

        distPredFull = torch.cat(distanceRows, dim=0)

        # discPred = discriminator(imageBatchFake)
        # lossRealness = bceLoss(discPred, torch.ones(imageBatchFake.shape[0], device=gpu))
        # lossGen = lossDist + 1.0 * lossRealness
        lossLatents = lossDistTotal

        optimizerLatents.zero_grad()
        # optimizerScale.zero_grad()
        lossLatents.backward()
        optimizerLatents.step()
        # optimizerScale.step()

        if batchIndex % 100 == 0:
            msg = 'iter {} loss dist {:.3f} scale: {:.3f} bias: {:.3f}'.format(batchIndex, lossDistTotal.item(), scale.item(), bias.item())
            print(msg)

            summaryWriter.add_scalar('loss-dist', lossDistTotal.item(), global_step=batchIndex)

            imageBatchFakeCpu = gpu_images_to_numpy(imageBatchFake)
            for iCol, ax in enumerate(axes.flatten()[:batchSize]):
                ax.imshow(imageBatchFakeCpu[iCol])
            fig.suptitle(msg)

            with torch.no_grad():
                images = gpu_images_to_numpy(generator(latents[..., None, None]))
                plot_image_scatter(ax2, authorVectorsProj, images, downscaleRatio=2)

            fig.savefig(os.path.join(outPath, f'batch_{batchIndex}.png'))
            fig2.savefig(os.path.join(outPath, f'scatter_{batchIndex}.png'))
            plt.close(fig)
            plt.close(fig2)

            with torch.no_grad():
                imagesGpu = generator(latents[..., None, None])

                # Compute LPIPS distances, batch to avoid memory issues.
                distPredEval = lpips_distance_matrix(imagesGpu, bs=min(imagesGpu.shape[0], 8))
                distPredEval = (distPredEval * scale.item() + bias.item())

                # Move to the CPU and append an alpha channel for rendering.
                images = gpu_images_to_numpy(imagesGpu)
                images = [np.concatenate([im, np.ones(im.shape[:-1] + (1,))], axis=-1) for im in images]

                distPoints = distPointsCpu
                assert np.abs(distPoints - distPoints.T).max() < 1e-5
                distPoints = np.minimum(distPoints, distPoints.T)  # Remove rounding errors, guarantee symmetry.
                config = DistanceMatrixConfig()
                config.dataRange = (0., 4.)
                _, pointIndicesSorted = render_distance_matrix(
                    os.path.join(outPath, f'dist_point_{batchIndex}.png'),
                    distPoints,
                    images,
                    config=config
                )

                # todo The symmetry doesn't hold for E-LPIPS, since it's stochastic.
                distPredEval = np.minimum(distPredEval, distPredEval.T)  # Remove rounding errors, guarantee symmetry.
                config = DistanceMatrixConfig()
                config.dataRange = (0., 4.)
                render_distance_matrix(
                    os.path.join(outPath, f'dist_images_{batchIndex}.png'),
                    distPredEval,
                    images,
                    config=config
                )

                config = DistanceMatrixConfig()
                config.dataRange = (0., 4.)
                render_distance_matrix(
                    os.path.join(outPath, f'dist_images_aligned_{batchIndex}.png'),
                    distPredEval,
                    images,
                    predefinedOrder=pointIndicesSorted,
                    config=config
                )

                # Target vs. predicted distances of the last training batch, mapped with the same
                # affine transform (scale, bias) that the loss uses.
                figDist, axesDist = plt.subplots(ncols=2)
                axesDist[0].matshow(distTarget.cpu().numpy(), vmin=0, vmax=4)
                axesDist[1].matshow(distPredFull.detach().cpu().numpy() * scale.item() + bias.item(), vmin=0, vmax=4)
                figDist.savefig(os.path.join(outPath, f'batch_dist_{batchIndex}.png'))
                plt.close(figDist)

                # Triplet survey: for random (left, mid, right) triples, record which outer point
                # is closer to the middle one and render the three generated images.
                surveySize = 30
                figSurvey, axesSurvey = plt.subplots(nrows=3, ncols=surveySize, figsize=(surveySize, 3))
                assert len(images) == dataSize
                allIndices = list(range(dataSize))
                with open(os.path.join(outPath, f'survey_{batchIndex}.txt'), 'w') as file:
                    for iCol in range(surveySize):
                        surveyIndices = random.sample(allIndices, k=3)
                        leftToMid = distPointsCpu[surveyIndices[0], surveyIndices[1]]
                        rightToMid = distPointsCpu[surveyIndices[2], surveyIndices[1]]

                        correctAnswer = 'left' if leftToMid < rightToMid else 'right'
                        file.write("{}\t{}\t{}\t{}\t{}\n".format(iCol, correctAnswer, leftToMid, rightToMid,
                                                                 str(tuple(surveyIndices))))

                        for iRow in (0, 1, 2):
                            axesSurvey[iRow][iCol].imshow(images[surveyIndices[iRow]])

                figSurvey.savefig(os.path.join(outPath, f'survey_{batchIndex}.png'))
                plt.close(figSurvey)

            torch.save(generator.state_dict(), os.path.join(outPath, 'gen_{}.pth'.format(batchIndex)))
            # Separate file name: saving under 'gen_' would overwrite the generator checkpoint.
            torch.save(discriminator.state_dict(), os.path.join(outPath, 'disc_{}.pth'.format(batchIndex)))

    summaryWriter.close()
コード例 #2
0
def main():
    """Directly optimize one image per random point so LPIPS distances match point distances.

    ``dataSize`` points come from ``generate_points`` and their pairwise Euclidean distances
    (``scipy.spatial.distance_matrix`` with p=2) are the target. The images (initialized from cat
    photos or noise) are optimized together with a global ``scale`` so that ``scale * LPIPS``
    matches the target on random mini-batches. A discriminator, trained alongside them against
    real cat images, adds a realness term. Every 100 iterations an image grid and a scatter plot
    are written into a timestamped directory under ``images/``.
    """

    dataSize = 128
    batchSize = 8
    # imageSize = 32
    imageSize = 64
    initWithCats = True

    # discCheckpointPath = r'E:\projects\visus\PyTorch-GAN\implementations\dcgan\checkpoints\2020_07_10_15_53_34\disc_step4800.pth'
    # discCheckpointPath = r'E:\projects\visus\pytorch-examples\dcgan\out\netD_epoch_24.pth'
    discCheckpointPath = None

    gpu = torch.device('cuda')

    # Real images: ToTensor gives [0, 1], then Normalize([0.5], [0.5]) maps them to [-1, 1].
    imageRootPath = r'E:\data\cat-vs-dog\cat'
    catDataset = CatDataset(
        imageSubdirPath=imageRootPath,
        transform=transforms.Compose([
            transforms.Resize((imageSize, imageSize)),
            # torchvision.transforms.functional.to_grayscale,
            transforms.ToTensor(),
            # transforms.Lambda(lambda x: torch.reshape(x, x.shape[1:])),
            transforms.Normalize([0.5], [0.5])
        ]))

    sampler = InfiniteSampler(catDataset)
    catLoader = DataLoader(catDataset, batch_size=batchSize, sampler=sampler)

    def to_disc_range(imagesUnit):
        """Map optimized images from [0, 1] to [-1, 1], the range of the real images above.

        The optimized images are clamped to [0, 1] (needed for LPIPS with normalize=True). Without
        this mapping the discriminator could tell fakes from reals by value range alone.
        """
        return imagesUnit * 2 - 1

    # Generate a random distance matrix.
    # # Make a matrix with positive values.
    # distancesCpu = np.clip(np.random.normal(0.5, 1.0 / 3, (dataSize, dataSize)), 0, 1)
    # # Make it symmetrical.
    # distancesCpu = np.matmul(distancesCpu, distancesCpu.T)

    # Generate random points and compute distances, guaranteeing that the triangle rule isn't broken.
    randomPoints = generate_points(dataSize)
    distancesCpu = scipy.spatial.distance_matrix(randomPoints, randomPoints, p=2)

    if initWithCats:
        imagePaths = random.choices(glob.glob(os.path.join(imageRootPath, '*')), k=dataSize)
        catImages = []
        for p in imagePaths:
            image = imageio.imread(p)
            # Force an (H, W, 3) layout so every image stacks into one (N, 3, H, W) array.
            if image.ndim == 2:
                image = image[..., None]
            if image.shape[-1] < 3:  # Grayscale (optionally with alpha): replicate the luminance.
                image = np.repeat(image[..., :1], 3, axis=-1)
            image = image[..., :3]  # Drop an alpha channel, if any.
            image = skimage.transform.resize(image, (imageSize, imageSize), 1).transpose(2, 0, 1)
            catImages.append(image)

        imagesInitCpu = np.asarray(catImages)
    else:
        imagesInitCpu = np.clip(
            np.random.normal(0.5, 0.5 / 3, (dataSize, 3, imageSize, imageSize)), 0, 1)

    images = torch.tensor(imagesInitCpu, requires_grad=True, dtype=torch.float32, device=gpu)

    scale = torch.tensor(1.0, requires_grad=True, dtype=torch.float32, device=gpu)

    lossModel = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True).to(gpu)
    lossBce = torch.nn.BCELoss()

    # discriminator = Discriminator(imageSize, 3)
    discriminator = Discriminator(3, 64, 1)
    if discCheckpointPath:
        discriminator.load_state_dict(torch.load(discCheckpointPath))
    else:
        discriminator.init_params()
    discriminator = discriminator.to(gpu)

    optimizerImages = torch.optim.Adam([images, scale], lr=1e-3, betas=(0.9, 0.999))
    # optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.9, 0.999))
    optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(nrows=2, ncols=batchSize // 2)

    fig2 = plt.figure()
    ax2 = fig2.add_subplot(1, 1, 1)

    outPath = os.path.join('images', datetime.datetime.today().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(outPath)

    catIter = iter(catLoader)
    for batchIndex in range(10000):

        realImageBatch, _ = next(catIter)  # type: Tuple[torch.Tensor, Any]
        realImageBatch = realImageBatch.to(gpu)

        # noinspection PyTypeChecker
        randomIndices = np.random.randint(0, dataSize, batchSize).tolist()  # type: List[int]
        # randomIndices = list(range(dataSize))  # type: List[int]
        distanceBatch = torch.tensor(distancesCpu[randomIndices, :][:, randomIndices],
                                     dtype=torch.float32, device=gpu)
        imageBatch = images[randomIndices].contiguous()

        # Flat pair k is (imageBatch[k // B], imageBatch[k % B]), so distPredMat[r, c] compares
        # image r with image c. normalize=True because `images` is kept in [0, 1].
        distPred = lossModel(
            imageBatch.repeat_interleave(repeats=batchSize, dim=0).contiguous(),
            imageBatch.repeat(repeats=(batchSize, 1, 1, 1)).contiguous(),
            normalize=True)
        distPredMat = distPred.reshape((batchSize, batchSize))

        lossDist = torch.sum((distanceBatch - distPredMat * scale) ** 2)  # Sum of squared errors.
        discPred = discriminator(to_disc_range(imageBatch))
        lossRealness = lossBce(discPred, torch.ones(imageBatch.shape[0], 1, device=gpu))
        lossImages = lossDist + 100.0 * lossRealness  # todo
        # lossImages = lossRealness  # todo

        optimizerImages.zero_grad()
        lossImages.backward()
        optimizerImages.step()

        lossDiscReal = lossBce(discriminator(realImageBatch),
                               torch.ones(realImageBatch.shape[0], 1, device=gpu))
        lossDiscFake = lossBce(discriminator(to_disc_range(imageBatch.detach())),
                               torch.zeros(imageBatch.shape[0], 1, device=gpu))
        lossDisc = (lossDiscFake + lossDiscReal) / 2
        # lossDisc = torch.tensor(0)

        # zero_grad() also discards the discriminator gradients left by lossImages.backward().
        optimizerDisc.zero_grad()
        lossDisc.backward()
        optimizerDisc.step()

        with torch.no_grad():
            # todo  We're clamping all the images every batch, can we do clamp only the ones updated?
            # images = torch.clamp(images, 0, 1)  # For some reason this was making the training worse.
            # In-place clamp keeps `images` the same leaf tensor that the optimizer updates.
            images.clamp_(0, 1)

        if batchIndex % 100 == 0:
            msg = 'iter {}, loss images {:.3f}, loss dist {:.3f}, loss real {:.3f}, loss disc {:.3f}, scale: {:.3f}'.format(
                batchIndex, lossImages.item(), lossDist.item(),
                lossRealness.item(), lossDisc.item(), scale.item())
            print(msg)
            # print(discPred.tolist())
            imageBatchCpu = imageBatch.detach().cpu().numpy().transpose(0, 2, 3, 1)
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(imageBatchCpu[i])
            fig.suptitle(msg)

            imagesAllCpu = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
            plot_image_scatter(ax2, randomPoints, imagesAllCpu, downscaleRatio=2)

            fig.savefig(os.path.join(outPath, 'batch_{}.png'.format(batchIndex)))
            fig2.savefig(os.path.join(outPath, 'scatter_{}.png'.format(batchIndex)))
コード例 #3
0
def main():
    """Train a generator mapping author vectors to images whose LPIPS distances match vector distances.

    Each iteration draws a batch of ``AuthorDataset`` vectors and a batch of real CIFAR-10 images.
    The generator is trained so that ``scale * LPIPS`` between its outputs matches the in-batch
    ``l2_sqr_dist_matrix`` of the vectors, plus a realness term from a discriminator that is
    trained alongside it to separate real from generated images. Every 100 iterations it prints
    the losses and writes image grids, a UMAP scatter plot, distance-matrix renders and
    checkpoints into a timestamped directory under ``runs/``.
    """

    dataSize = 128
    batchSize = 8
    # imageSize = 32
    imageSize = 64

    # discCheckpointPath = r'E:\projects\visus\PyTorch-GAN\implementations\dcgan\checkpoints\2020_07_10_15_53_34\disc_step4800.pth'
    # discCheckpointPath = r'E:\projects\visus\pytorch-examples\dcgan\out\netD_epoch_24.pth'
    discCheckpointPath = None

    # Real images are normalized to [-1, 1] below and fed to the same discriminator as the fakes,
    # and gpu_images_to_numpy maps generator output to [0, 1] with (x + 1) / 2, so the generator
    # output is treated as [-1, 1]. LPIPS' normalize=True is meant for [0, 1] inputs and would
    # rescale it a second time, so the metric receives the images unchanged.
    lpipsNormalize = False
    scatterPointNumber = 200  # Points shown in the UMAP scatter plot.
    evalImageNumber = 48  # Points used for the full distance-matrix renders.
    evalBatchSize = 8  # Tile size of the evaluation LPIPS matrix; must divide evalImageNumber.

    gpu = torch.device('cuda')

    def gpu_images_to_numpy(images):
        """Map an NCHW tensor with values in [-1, 1] to an NHWC numpy array in [0, 1]."""
        imagesNumpy = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
        return (imagesNumpy + 1) / 2

    # imageDataset = CatDataset(
    #     imageSubdirPath=r'E:\data\cat-vs-dog\cat',
    #     transform=transforms.Compose(
    #         [
    #             transforms.Resize((imageSize, imageSize)),
    #             transforms.ToTensor(),
    #             transforms.Normalize([0.5], [0.5])
    #         ]
    #     )
    # )

    imageDataset = datasets.CIFAR10(root=r'e:\data\images\cifar10', download=True,
                                    transform=transforms.Compose([
                                        transforms.Resize((imageSize, imageSize)),
                                        transforms.ToTensor(),
                                        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                        transforms.Normalize([0.5], [0.5]),
                                    ]))

    # For now we normalize the vectors to have norm 1, but don't make sure
    # that the data has certain mean/std.
    pointDataset = AuthorDataset(
        jsonPath=r'E:\out\scripts\metaphor-vis\authors-all.json'
    )

    imageLoader = DataLoader(imageDataset, batch_size=batchSize, sampler=InfiniteSampler(imageDataset))
    pointLoader = DataLoader(pointDataset, batch_size=batchSize, sampler=InfiniteSampler(pointDataset))

    # Generate a random distance matrix.
    # # Make a matrix with positive values.
    # distancesCpu = np.clip(np.random.normal(0.5, 1.0 / 3, (dataSize, dataSize)), 0, 1)
    # # Make it symmetrical.
    # distancesCpu = np.matmul(distancesCpu, distancesCpu.T)

    # Generate random points and compute distances, guaranteeing that the triangle rule isn't broken.
    # randomPoints = generate_points(dataSize)
    # distancesCpu = scipy.spatial.distance_matrix(randomPoints, randomPoints, p=2)

    # catImagePath = os.path.expandvars(r'${DEV_METAPHOR_DATA_PATH}/cats/cat.247.jpg')
    # catImage = skimage.transform.resize(imageio.imread(catImagePath), (64, 64), 1).transpose(2, 0, 1)

    # imagesInitCpu = np.clip(np.random.normal(0.5, 0.5 / 3, (dataSize, 3, imageSize, imageSize)), 0, 1)
    # imagesInitCpu = np.clip(np.tile(catImage, (dataSize, 1, 1, 1)) + np.random.normal(0., 0.5 / 6, (dataSize, 3, 64, 64)), 0, 1)
    # images = torch.tensor(imagesInitCpu, requires_grad=True, dtype=torch.float32, device=gpu)

    # The evaluation inputs are fixed for the whole run, so they are loaded, embedded with UMAP and
    # turned into a distance matrix once instead of at every logging step.
    scatterPoints = np.asarray([pointDataset[i][0] for i in range(scatterPointNumber)], dtype=np.float32)
    scatterPointsProj = umap.UMAP(n_neighbors=5, random_state=1337).fit_transform(scatterPoints)

    evalPoints = np.asarray([pointDataset[i][0] for i in range(evalImageNumber)], dtype=np.float32)
    evalDistPoints = l2_sqr_dist_matrix(torch.tensor(evalPoints, dtype=torch.double)).numpy()
    assert np.abs(evalDistPoints - evalDistPoints.T).max() < 1e-5
    evalDistPoints = np.minimum(evalDistPoints, evalDistPoints.T)  # Remove rounding errors, guarantee symmetry.

    scale = torch.tensor(4.0, requires_grad=True, dtype=torch.float32, device=gpu)

    lossModel = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True).to(gpu)
    bceLoss = torch.nn.BCELoss()

    def lpips_distance_matrix(imagesGpu, bs):
        """Return an (N, N) numpy array whose entry [i, j] is lossModel(image i, image j).

        The matrix is filled in bs x bs tiles to bound GPU memory; N must be divisible by bs.
        Must be called under torch.no_grad() (the tiles are converted with .numpy()).
        """
        imageNumber = imagesGpu.shape[0]
        assert imageNumber % bs == 0
        distances = np.zeros((imageNumber, imageNumber))
        for startA in range(0, imageNumber, bs):
            imagesA = imagesGpu[startA:startA + bs]
            for startB in range(0, imageNumber, bs):
                imagesB = imagesGpu[startB:startB + bs]
                # Flat pair k is (imagesA[k // bs], imagesB[k % bs]), so after the reshape
                # tile[r, c] compares imagesA[r] with imagesB[c] (row = A, column = B).
                tile = lossModel(imagesA.repeat_interleave(repeats=bs, dim=0).contiguous(),
                                 imagesB.repeat(repeats=(bs, 1, 1, 1)).contiguous(),
                                 normalize=lpipsNormalize)
                distances[startA:startA + bs, startB:startB + bs] = tile.cpu().numpy().reshape((bs, bs))
        return distances

    # discriminator = Discriminator(imageSize, 3)
    discriminator = Discriminator(3, 64, 1)
    if discCheckpointPath:
        discriminator.load_state_dict(torch.load(discCheckpointPath))
    else:
        discriminator.init_params()

    discriminator = discriminator.to(gpu)

    generator = Generator(nz=pointDataset[0][0].shape[0], ngf=64)
    generator.init_params()
    generator = generator.to(gpu)

    # todo init properly, if training
    # discriminator.apply(weights_init_normal)

    # optimizerImages = torch.optim.Adam([images, scale], lr=1e-2, betas=(0.9, 0.999))
    optimizerScale = torch.optim.Adam([scale], lr=0.001)
    optimizerGen = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    # optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.9, 0.999))
    optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    import matplotlib.pyplot as plt
    # Top half of the grid shows fake images, bottom half real ones.
    fig, axes = plt.subplots(nrows=2 * 2, ncols=batchSize // 2)

    fig2 = plt.figure()
    ax2 = fig2.add_subplot(1, 1, 1)

    outPath = os.path.join('runs', datetime.datetime.today().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(outPath)

    imageIter = iter(imageLoader)
    pointIter = iter(pointLoader)
    for batchIndex in range(10000):

        imageBatchReal, _ = next(imageIter)  # type: Tuple[torch.Tensor, Any]
        imageBatchReal = imageBatchReal.to(gpu)

        vectorBatch, _ = next(pointIter)
        vectorBatch = vectorBatch.to(gpu)
        distanceBatch = l2_sqr_dist_matrix(vectorBatch)  # In-batch vector distances.

        imageBatchFake = generator(vectorBatch[:, :, None, None].float())

        # todo It's possible to compute this more efficiently, but would require re-implementing lpips.
        # Flat pair k is (fake[k // B], fake[k % B]), so distPredMat[r, c] compares fake image r with c.
        distImages = lossModel(imageBatchFake.repeat_interleave(repeats=batchSize, dim=0).contiguous(),
                               imageBatchFake.repeat(repeats=(batchSize, 1, 1, 1)).contiguous(),
                               normalize=lpipsNormalize)
        distPredMat = distImages.reshape((batchSize, batchSize))

        lossDist = torch.sum((distanceBatch - distPredMat * scale) ** 2)  # Sum of squared errors.
        discPred = discriminator(imageBatchFake)
        lossRealness = bceLoss(discPred, torch.ones(imageBatchFake.shape[0], device=gpu))
        lossGen = lossDist + 1.0 * lossRealness

        optimizerGen.zero_grad()
        optimizerScale.zero_grad()
        lossGen.backward()
        optimizerGen.step()
        optimizerScale.step()

        lossDiscReal = bceLoss(discriminator(imageBatchReal), torch.ones(imageBatchReal.shape[0], device=gpu))
        lossDiscFake = bceLoss(discriminator(imageBatchFake.detach()), torch.zeros(imageBatchFake.shape[0], device=gpu))
        lossDisc = (lossDiscFake + lossDiscReal) / 2
        # lossDisc = torch.tensor(0)

        # zero_grad() also discards the discriminator gradients left by lossGen.backward().
        optimizerDisc.zero_grad()
        lossDisc.backward()
        optimizerDisc.step()

        if batchIndex % 100 == 0:
            msg = 'iter {}, loss gen {:.3f}, loss dist {:.3f}, loss real {:.3f}, loss disc {:.3f}, scale: {:.3f}'.format(
                batchIndex, lossGen.item(), lossDist.item(), lossRealness.item(), lossDisc.item(), scale.item()
            )
            print(msg)

            imageBatchFakeCpu = gpu_images_to_numpy(imageBatchFake)
            imageBatchRealCpu = gpu_images_to_numpy(imageBatchReal)
            for i, ax in enumerate(axes.flatten()[:batchSize]):
                ax.imshow(imageBatchFakeCpu[i])
            for i, ax in enumerate(axes.flatten()[batchSize:]):
                ax.imshow(imageBatchRealCpu[i])
            fig.suptitle(msg)

            with torch.no_grad():
                images = gpu_images_to_numpy(generator(torch.tensor(scatterPoints[..., None, None], device=gpu)))
                plot_image_scatter(ax2, scatterPointsProj, images, downscaleRatio=2)

            fig.savefig(os.path.join(outPath, f'batch_{batchIndex}.png'))
            fig2.savefig(os.path.join(outPath, f'scatter_{batchIndex}.png'))
            plt.close(fig)
            plt.close(fig2)

            with torch.no_grad():
                imagesGpu = generator(torch.tensor(evalPoints[..., None, None], device=gpu))

                # Compute LPIPS distances, batch to avoid memory issues.
                distImagesEval = lpips_distance_matrix(imagesGpu, evalBatchSize)

                # Move to the CPU and append an alpha channel for rendering.
                images = gpu_images_to_numpy(imagesGpu)
                images = [np.concatenate([im, np.ones(im.shape[:-1] + (1,))], axis=-1) for im in images]

                config = DistanceMatrixConfig()
                config.dataRange = (0., 4.)
                render_distance_matrix(os.path.join(outPath, f'dist_point_{batchIndex}.png'),
                                       evalDistPoints,
                                       images,
                                       config=config)

                assert np.abs(distImagesEval - distImagesEval.T).max() < 1e-5
                distImagesEval = np.minimum(distImagesEval, distImagesEval.T)  # Remove rounding errors, guarantee symmetry.
                config = DistanceMatrixConfig()
                config.dataRange = (0., 1.)
                render_distance_matrix(os.path.join(outPath, f'dist_images_{batchIndex}.png'),
                                       distImagesEval,
                                       images,
                                       config=config)

            torch.save(generator.state_dict(), os.path.join(outPath, 'gen_{}.pth'.format(batchIndex)))
            torch.save(discriminator.state_dict(), os.path.join(outPath, 'disc_{}.pth'.format(batchIndex)))