Example #1
def main():
    # Read an RGB image and its noisy version
    x = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(2, 0,
                                                                  1) / 255.
    y = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1) / 255.

    if torch.cuda.is_available():
        # Move to GPU to make computations faster
        x = x.cuda()
        y = y.cuda()

    # To compute the BRISQUE score as a measure, use the lower-case function from the library.
    # BRISQUE is a no-reference metric, so only the distorted image x is needed.
    brisque_index: torch.Tensor = piq.brisque(x,
                                              data_range=1.,
                                              reduction='none')
    # In order to use BRISQUE as a loss function, use the corresponding PyTorch module.
    # Note: backpropagation is not available with torch==1.5.0;
    # update the environment to the latest torch and torchvision.
    brisque_loss: torch.Tensor = piq.BRISQUELoss(data_range=1.,
                                                 reduction='none')(x)
    print(
        f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}"
    )

    # To compute the Content score as a loss function, use the corresponding PyTorch module.
    # By default the VGG16 model is used, but any feature-extractor model is supported.
    # Don't forget to adjust the layer names accordingly. Features from different layers can be
    # weighted differently via the weights parameter (a sketch follows this block).
    # See other options in the class docstring.
    content_loss = piq.ContentLoss(feature_extractor="vgg16",
                                   layers=("relu3_3", ),
                                   reduction='none')(x, y)
    print(f"ContentLoss: {content_loss.item():0.4f}")

    # To compute DISTS as a loss function, use the corresponding PyTorch module.
    # By default, input images are normalized with ImageNet statistics before being forwarded
    # through the VGG16 model. If there is no need to normalize the data, use
    # mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0], as sketched below.
    dists_loss = piq.DISTS(reduction='none')(x, y)
    print(f"DISTS: {dists_loss.item():0.4f}")

    # To compute FSIM as a measure, use the lower-case function from the library
    fsim_index: torch.Tensor = piq.fsim(x, y, data_range=1., reduction='none')
    # In order to use FSIM as a loss function, use the corresponding PyTorch module
    fsim_loss = piq.FSIMLoss(data_range=1., reduction='none')(x, y)
    print(
        f"FSIM index: {fsim_index.item():0.4f}, loss: {fsim_loss.item():0.4f}")

    # To compute GMSD as a measure, use the lower-case function from the library.
    # This is a port of the MATLAB version from the authors of the original paper.
    # In any case it should be minimized. Values of GMSD usually lie in the [0, 0.35] interval.
    gmsd_index: torch.Tensor = piq.gmsd(x, y, data_range=1., reduction='none')
    # In order to use GMSD as a loss function, use the corresponding PyTorch module:
    gmsd_loss: torch.Tensor = piq.GMSDLoss(data_range=1., reduction='none')(x,
                                                                            y)
    print(
        f"GMSD index: {gmsd_index.item():0.4f}, loss: {gmsd_loss.item():0.4f}")

    # To compute HaarPSI as a measure, use the lower-case function from the library.
    # This is a port of the MATLAB version from the authors of the original paper.
    haarpsi_index: torch.Tensor = piq.haarpsi(x,
                                              y,
                                              data_range=1.,
                                              reduction='none')
    # In order to use HaarPSI as a loss function, use the corresponding PyTorch module
    haarpsi_loss: torch.Tensor = piq.HaarPSILoss(data_range=1.,
                                                 reduction='none')(x, y)
    print(
        f"HaarPSI index: {haarpsi_index.item():0.4f}, loss: {haarpsi_loss.item():0.4f}"
    )

    # To compute LPIPS as a loss function, use the corresponding PyTorch module
    lpips_loss: torch.Tensor = piq.LPIPS(reduction='none')(x, y)
    print(f"LPIPS: {lpips_loss.item():0.4f}")

    # To compute MDSI as a measure, use the lower-case function from the library
    mdsi_index: torch.Tensor = piq.mdsi(x, y, data_range=1., reduction='none')
    # In order to use MDSI as a loss function, use the corresponding PyTorch module
    mdsi_loss: torch.Tensor = piq.MDSILoss(data_range=1., reduction='none')(x,
                                                                            y)
    print(
        f"MDSI index: {mdsi_index.item():0.4f}, loss: {mdsi_loss.item():0.4f}")

    # To compute the MS-SSIM index as a measure, use the lower-case function from the library:
    ms_ssim_index: torch.Tensor = piq.multi_scale_ssim(x, y, data_range=1.)
    # In order to use MS-SSIM as a loss function, use the corresponding PyTorch module:
    ms_ssim_loss = piq.MultiScaleSSIMLoss(data_range=1., reduction='none')(x,
                                                                           y)
    print(
        f"MS-SSIM index: {ms_ssim_index.item():0.4f}, loss: {ms_ssim_loss.item():0.4f}"
    )

    # To compute Multi-Scale GMSD as a measure, use the lower-case function from the library.
    # It can be used both as a measure and as a loss function; in either case it should be minimized.
    # By default, scale weights are initialized with the values from the paper. You can change
    # them by passing a list of 4 values to the scale_weights argument during initialization
    # (see the sketch after this block).
    # Note that input tensors should contain images with height and width of at least
    # 2 ** number_of_scales + 1.
    ms_gmsd_index: torch.Tensor = piq.multi_scale_gmsd(x,
                                                       y,
                                                       data_range=1.,
                                                       chromatic=True,
                                                       reduction='none')
    # In order to use Multi-Scale GMSD as a loss function, use the corresponding PyTorch module
    ms_gmsd_loss: torch.Tensor = piq.MultiScaleGMSDLoss(chromatic=True,
                                                        data_range=1.,
                                                        reduction='none')(x, y)
    print(
        f"MS-GMSDc index: {ms_gmsd_index.item():0.4f}, loss: {ms_gmsd_loss.item():0.4f}"
    )
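
    # Illustrative sketch (not part of the original example): overriding the default
    # scale weights with a custom list of 4 values, as described above. The weight
    # values here are arbitrary choices for demonstration.
    ms_gmsd_custom: torch.Tensor = piq.MultiScaleGMSDLoss(
        scale_weights=[0.3, 0.3, 0.2, 0.2],
        data_range=1.,
        reduction='none')(x, y)
    print(f"MS-GMSD (custom scale weights): {ms_gmsd_custom.item():0.4f}")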

    # To compute PSNR as a measure, use the lower-case function from the library.
    psnr_index = piq.psnr(x, y, data_range=1., reduction='none')
    print(f"PSNR index: {psnr_index.item():0.4f}")

    # To compute PieAPP as a loss function, use the corresponding PyTorch module:
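    # (The stride argument sets how densely the 64x64 patches that PieAPP scores are
    #  sampled; a larger stride is faster but evaluates fewer patches.)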
    pieapp_loss: torch.Tensor = piq.PieAPP(reduction='none', stride=32)(x, y)
    print(f"PieAPP loss: {pieapp_loss.item():0.4f}")

    # To compute the SSIM index as a measure, use the lower-case function from the library:
    ssim_index = piq.ssim(x, y, data_range=1.)
    # In order to use SSIM as a loss function, use the corresponding PyTorch module:
    ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(x, y)
    print(
        f"SSIM index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")

    # To compute the Style score as a loss function, use the corresponding PyTorch module:
    # By default the VGG16 model is used, but any feature-extractor model is supported.
    # Don't forget to adjust the layer names accordingly. Features from different layers can be
    # weighted differently via the weights parameter. See other options in the class docstring.
    style_loss = piq.StyleLoss(feature_extractor="vgg16",
                               layers=("relu3_3", ))(x, y)
    print(f"Style: {style_loss.item():0.4f}")

    # To compute TV as a measure, use the lower-case function from the library:
    tv_index: torch.Tensor = piq.total_variation(x)
    # In order to use TV as a loss function, use the corresponding PyTorch module:
    tv_loss: torch.Tensor = piq.TVLoss(reduction='none')(x)
    print(f"TV index: {tv_index.item():0.4f}, loss: {tv_loss.item():0.4f}")

    # To compute VIF as a measure, use the lower-case function from the library:
    vif_index: torch.Tensor = piq.vif_p(x, y, data_range=1.)
    # In order to use VIF as a loss function, use the corresponding PyTorch class:
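    # (sigma_n_sq is the variance of the visual-noise term in the VIF model; 2.0 is the default.)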
    vif_loss: torch.Tensor = piq.VIFLoss(sigma_n_sq=2.0, data_range=1.)(x, y)
    print(f"VIFp index: {vif_index.item():0.4f}, loss: {vif_loss.item():0.4f}")

    # To compute the VSI score as a measure, use the lower-case function from the library:
    vsi_index: torch.Tensor = piq.vsi(x, y, data_range=1.)
    # In order to use VSI as a loss function, use the corresponding PyTorch module:
    vsi_loss: torch.Tensor = piq.VSILoss(data_range=1.)(x, y)
    print(f"VSI index: {vsi_index.item():0.4f}, loss: {vsi_loss.item():0.4f}")
Example #2
    'GMSD': (2, {
        'piq.gmsd': piq.gmsd,
        'piq.GMSD': piq.GMSDLoss(),
        # 'IQA.GMSD': IQA.GMSD(),
        'piqa.GMSD': piqa.GMSD(),
    }),
    'MS-GMSD': (2, {
        'piq.ms_gmsd': piq.multi_scale_gmsd,
        'piq.MS_GMSD': piq.MultiScaleGMSDLoss(),
        'piqa.MS_GMSD': piqa.MS_GMSD(),
    }),
    'MDSI': (2, {
        'piq.mdsi': piq.mdsi,
        'piq.MDSI-loss': piq.MDSILoss(),
        'piqa.MDSI': piqa.MDSI(),
    }),
    'HaarPSI': (2, {
        'piq.haarpsi': piq.haarpsi,
        'piq.HaarPSI-loss': piq.HaarPSILoss(),
        'piqa.HaarPSI': piqa.HaarPSI(),
    }),
}


def timeit(f, n: int) -> float:
    """Return the total wall-clock time of n calls to f."""
    start = time.perf_counter()

    for _ in range(n):
        f()

    return time.perf_counter() - start
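

# Minimal usage sketch (illustrative, not from the original benchmark; assumes torch,
# piq, and time are imported as in the surrounding file): time one metric on random
# tensors and report the average seconds per call.
if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    y = torch.rand(1, 3, 256, 256)
    runs = 10
    print(f"piq.gmsd: {timeit(lambda: piq.gmsd(x, y), runs) / runs:.6f} s/call")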
Example #3
def train(models, args, hyperparams, loss_dict, loss_stats, **kwargs):
    # Counters and status strings that persist across scales. NOTE: the original
    # excerpt relied on these (and on names such as fract, code_length,
    # lr_per_resolution, the schedulers, and experiment) being defined in an
    # enclosing scope; the counters are initialized here so the loop below is
    # self-consistent.
    scale_num = 0
    step = 0
    epoch_errors, error_msg, make_plot_result_msg = {}, "", ""
    for scale_index, scale in enumerate(hyperparams["valid_scales"]):
        scale_num += 1
        epochs = hyperparams["epochs_by_scale"][scale]
        batch_size = hyperparams["batch_size_by_scale"][scale]
        args["SAVE_IMAGES_EACH"] = 25 if scale > 64 else 50
        args["SAVE_IMAGES_EACH"] = args["SAVE_IMAGES_EACH"] // 10 if scale > 512 else args["SAVE_IMAGES_EACH"]

        if scale_index > 0:
            epoch_len = hyperparams["steps_per_scale"][scale]
            for key, tracker in loss_stats.items():
                tracker.expand_buffer(block_size=epoch_len)

        tracked_images = 0
        total_images_phase = args["TOTAL_IMAGES"] * epochs
        limit = int(0.5 * total_images_phase)
        alpha = utils.find_alpha(tracked_images, limit)    

        n_blocks = int(log2(scale) - 1)

        if args["CONTRAST_ENABLE"]:
            loss_dict["nce"] = []
            for _, batch_size_temp in hyperparams["bs_per_scale"].items():
                loss_dict["nce"].append(PatchNCELoss(hyperparams["nce_t"], batch_size=batch_size_temp).to(args["DEVICE"]))
            models["Fp"] = net.PatchSampleFeatureProjection(scale, patch_size=max(scale//8, 8), gpu_ids=[args["DEVICE"]], nc=5, use_perceptual=True, use_mlp=True)
            models["optF"] = Adam([
                {'params': models["Fp"].parameters()},
            ], lr=0.0002, betas=(0.5, 0.999))
            

        if n_blocks > args["ENABLE_DISC_PATCH"]:
            models["Dp"] = net.PatchDiscriminator(patch_size=max(scale//8, 8), 
                                                n_layers=n_blocks, 
                                                ndf=32, 
                                                no_antialias=True, 
                                                scale=scale).to(args["DEVICE"])
            models["optDp"] = Adam([
                {'params': list(models["G"].parameters()) + list(models["Dp"].parameters())},
            ], lr=lr_per_resolution[4], betas=(0.5, 0.99))
        
        
        warmup = True
        epoch_start = time()
        print(f"[INFO]\t Starting phase {scale_num} at {scale}x{scale} scale training, {total_images_phase} total images this phase - alpha becomes 1 at {limit} and output saved every {args["SAVE_IMAGES_EACH"]} batches ({int(args["SAVE_IMAGES_EACH"]*batch_size)} images)")
        
        # Set necessary learning rate
        for opt in [models["optAE"], models["optG"], models["optD"]]:
            adjust_lr(opt, lr_per_resolution[scale])
        total_batches = len(fract)//batch_size
        extra_images = len(fract)%batch_size
        with tqdm_notebook(total=epochs*total_batches*batch_size, unit='Images', unit_scale=True, unit_divisor=1, desc="Epochs") as pbar:
            dataloader = dataset.make_fractal_alae_dataloader(fract, batch_size, image_size=scale, num_workers=3)
            loss_dict["per"] = GeneralPerceptualLoss(models["D"], 4)
            for epoch in range(epochs):
                for batch_idx, real_samples in enumerate(dataloader):
                    labels = torch.Tensor(np.array(list(range(batch_idx*batch_size, (batch_idx*batch_size)+batch_size)))).to(args["DEVICE"])
                    bs = batch_size
                    ncrops = 1
                    # In the paper 500k with blending & 500k with alpha=1 for each scale
                    alpha = utils.find_alpha(tracked_images, limit)
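                    # (alpha in [0, 1] blends the newly added resolution block in;
                    #  utils.find_alpha presumably ramps it to 1 once tracked_images
                    #  reaches `limit`, matching the log message above.)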

                    # Discriminator loss
                    z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    z2 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])                
                    w = models["F"](z1, scale, z2, p_mix=0.9)
                    
                    real_samp = real_samples.to(args["DEVICE"]).requires_grad_()
                    fake_samples = models["G"](w, scale, alpha).detach()

                    lossD = loss_dict["discriminator"](models["E"], models["D"], alpha, real_samp, fake_samples)
                                        
                    models["optD"].zero_grad()                    
                    loss_stats["D"].update(lossD)
                    models["optD"].step()

                    # Sample 'styles' from a normal distribution to be mixed
                    z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    z2 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    # Project styles to latent space
                    w = models["F"](z1, scale, z2, p_mix=0.9)
                    fake_samples = models["G"](w, scale, alpha)
                    # Generator loss                
                    lossG = loss_dict["generator"](models["E"], models["D"], alpha, fake_samples)
                    
                    models["optG"].zero_grad()
                    loss_stats["G"].update(lossG)

                    z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    z2 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    w = models["F"](z1, scale, z2, p_mix=0.9).detach()

                    lossGAvg = loss_dict["avg_generator"](models["G"], models["G_average"], w, scale, alpha)

                    loss_stats["Ga"].update(lossGAvg)
                    # Sample 'styles' from a normal distribution to be mixed
                    z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    real_samp = real_samples.to(args["DEVICE"])
                    E_z = models["E"](real_samp, alpha)
                    # Project styles to latent space
                    w = models["F"](z1, scale, E_z, p_mix=0.75).detach()              
                    fake_samples = models["G"](w, scale, alpha)
                    real_samples_grad = real_samples.to(args["DEVICE"]).requires_grad_()

                    lossGcon = loss_dict["generator_consistency"](fake_samples, real_samples_grad, 
                                                                loss_fn=lambda x,y: piq.MSID()(x.reshape(batch_size, -1), y.reshape(batch_size, -1)), 
                                                                use_perceptual=True,
                                                                use_ssim=True,
                                                                ssim_weight=10,
                                                                use_ssim_tv=False,
                                                                use_sobel=True,
                                                                sobel_weight=0.1,
                                                                use_sobel_tv=True,
                                                                sobel_fn=nn.L1Loss())

                    loss_stats["Gc"].update(lossGcon)
                    
                    # Sample 'styles' from a normal distribution to be mixed
                    z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                    real_samp = real_samples.to(args["DEVICE"])
                    E_z = models["E"](real_samp, alpha)
                    # Project styles to latent space
                    w = models["F"](z1, scale, E_z, p_mix=0.75).detach()              
                    fake_samples = models["G"](w, scale, alpha)                
                    rand_samples = torch.randn(batch_size, 3, scale, scale).to(args["DEVICE"]).requires_grad_()
                    lossGsanity = loss_dict["generator_consistency"](fake_samples, rand_samples, 
                                                                loss_fn=lambda x,y: loss_dict["fft"](x, y).mean(), 
                                                                use_ssim=False,
                                                                ssim_weight=1000,
                                                                use_ssim_tv=False,
                                                                use_sobel=False,
                                                                sobel_weight=1,
                                                                use_sobel_tv=False,
                                                                sobel_fn=piq.ssim)

                    loss_stats["Gs"].update(lossGsanity)

                    # Reconstruction loss to make generator more like real samples
                    try:
                        real_samp = real_samples.to(args["DEVICE"]).requires_grad_()
                        E_z = models["E"](real_samp, alpha).repeat(batch_size, int(log2(scale)-1), 1).detach()
                        recon_samples = models["G"](E_z, scale, alpha)

                        lossGmsd = piq.MDSILoss(data_range=1.)(stand(recon_samples), stand(real_samp))
                        loss_stats["Gm"].update(lossGmsd)
                    except Exception:  # fall back to an SSIM-based reconstruction loss if MDSI fails
                        loss_stats["Gm"].update(losses.ssim_loss(recon_samples, real_samp))

                    real_samp = real_samples.to(args["DEVICE"]).requires_grad_()
                    E_z = models["E"](real_samp, alpha).repeat(batch_size, int(log2(scale)-1), 1).detach()
                    recon_samples = models["G"](E_z, scale, alpha)

                    lossGrec = loss_dict["generator_consistency"](recon_samples, real_samp, 
                                                                loss_fn=color_vect_loss,
                                                                use_perceptual=False,
                                                                use_ssim=True,
                                                                ssim_weight=100,
                                                                use_ssim_tv=False,
                                                                use_sobel=False,
                                                                sobel_weight=10,
                                                                use_sobel_tv=False,
                                                                sobel_fn=loss_dict["fft"])

                    loss_stats["Gr"].update(lossGrec)
                    models["optG"].step()

                    # Autoencoder loss
                    z_input = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])

                    lossAE = loss_dict["autoencoder"](models["F"], models["G"], models["E"], scale, alpha, z_input, 
                                                    loss_fn=nn.CosineEmbeddingLoss(), 
                                                    labels=torch.cat([torch.ones([batch_size]), torch.zeros([batch_size])]).to(args["DEVICE"]))

                    models["optAE"].zero_grad()
                    loss_stats["AE"].update(lossAE)
                    models["optAE"].step()   
                

                    # Contrastive
                    if args["CONTRAST_ENABLE"]:
                        real_samp = real_samples.to(args["DEVICE"]).requires_grad_()
                        E_z = models["E"](real_samp, alpha).repeat(batch_size, int(log2(scale)-1), 1).detach()
                        recon_samples = models["G"](E_z, scale, alpha)
                        lossNCE = loss_dict["NCE_new"](models["Fp"], 
                                                    loss_dict["nce"], labels, 5, 
                                                    recon_samples, 
                                                    real_samp, 
                                                    bs * ncrops, scale, alpha)
                        models["optF"].zero_grad()
                        models["optG"].zero_grad()
                        loss_stats["NCE"].update(lossNCE)
                        models["optF"].step()
                        models["optG"].step()
                    elif args["CONTRAST_ORIG_ENABLE"]:
                        z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                        z2 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])                
                        w_tgt = models["F"](z1, scale, z2, p_mix=0.9).detach()
                        
                        E_z = models["E"](real_samp, alpha)
                        w_src = models["F"](E_z, scale, E_z, p_mix=0.9)

                        lossNCE = loss_dict["NCE_orig"](models["G"], models["Fp"], 
                                                    loss_dict["nce"], 
                                                    [n for n in range(n_blocks)], 
                                                    w_src, w_tgt, 4, scale, alpha)

                        models["optF"].zero_grad()
                        models["optG"].zero_grad()
                        loss_stats["NCE"].update(lossNCE)
                        models["optF"].step()
                        models["optG"].step()
                    

                    # Discriminator patch loss
                    if n_blocks > args["ENABLE_DISC_PATCH"]:
                        z1 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])
                        z2 = utils.sample_noise(batch_size, code=code_length, device=args["DEVICE"])                
                        w = models["F"](z1, scale, z2, p_mix=0.9).detach()
                        
                        real_samp = real_samples.to(args["DEVICE"]).requires_grad_()
                        fake_samples = models["G"](w, scale, alpha).detach()

                        lossD2 = loss_dict["discriminator_patch"](models["Dp"], real_samp, fake_samples, 
                                                        gamma=5, use_bce=True)
                        
                        models["optDp"].zero_grad()
                        loss_stats["Dp"].update(lossD2)
                        models["optDp"].step()

                    tracked_images += real_samples.shape[0]
                    
                    # Keep average version of Generator
                    models["G_average"].ema(models["G"], beta=0.999) 

                    # increment progress bar  
                    experiment.set_step(step)
                    step += 1
                    increment_amount = real_samples.shape[0]                
                    pbar.update(increment_amount)
                    
                    if (step % args["SAVE_IMAGES_EACH"]) == 0:
                        data = {
                            "real_samp": real_samp,
                            "real_samples": real_samples,
                            "fake_samples": fake_samples,
                            "scale": scale,
                            "alpha": alpha,
                            "field_100": set_random_field_100,
                            "field_9": set_random_field_9
                        }
                        if USE_CLR_DATA:
                            groupsize = ncrops
                        else:
                            groupsize = None
                        epoch_errors, error_msg = create_previews(models, data, step, experiment_root, epoch_errors, error_msg, groupsize=groupsize)
                        
                        # save model checkpoint
                        if step % 100 == 0:
                            make_plot_result_msg = save_model(models, experiment_root, scale, step=step, name="checkpoint")

                        pbar.set_postfix(alpha=round(alpha, 3), epoch_errors=epoch_errors, 
                            error_msg=error_msg, make_plot_result_msg=make_plot_result_msg, 
                            step=step, refresh=False)
                        
                # Save plot of loss
                #make_plot_result_msg = make_plots(loss_stats, experiment_root)
                pbar.set_postfix(alpha=round(alpha, 3), epoch_errors=epoch_errors, error_msg=error_msg, 
                                step=step, make_plot_result_msg=make_plot_result_msg, refresh=False)

                if hyperparams["use_scheduling"] and (alpha > 0.99) and (epoch in [int(epochs * 0.6), int(epochs * 0.8)]):
                    scheduler_D.step()
                    scheduler_G.step()
                    scheduler_AE.step()
            make_plot_result_msg = save_model(models, experiment_root, scale, step=step, name="final")
            pbar.set_postfix(alpha=round(alpha, 3), epoch_errors=epoch_errors, error_msg=error_msg, 
                                step=step, make_plot_result_msg=make_plot_result_msg, refresh=False)