Exemplo n.º 1
0
    def on_validation_epoch_end(
            self, _: Trainer,
            pl_module: BaseGenerativeModule) -> dict[str, float]:
        """Compute FID/KID for the module's generator against the eval
        dataset and log every resulting metric on the module.

        Returns the metrics dict produced by torch-fidelity.
        """
        # A logger must be attached, otherwise pl_module.log has no sink.
        assert pl_module.logger

        wrapped = torch_fidelity.GenerativeModelModuleWrapper(
            GeneratorWrapper(pl_module.get_generator()),
            pl_module.get_noise_dim(), 'normal', 0)
        reference_dataset = self.__data_module.generative_eval_dataset()

        metrics = torch_fidelity.calculate_metrics(
            input1=wrapped,
            input1_model_num_samples=10000,
            input2=reference_dataset,
            input2_cache_name=f'{self.__data_module.dataset_name()}-geneval',
            cuda=True,
            isc=False,
            fid=True,
            kid=True,
            verbose=True,
        )

        for metric_name, metric_value in metrics.items():
            pl_module.log(metric_name, metric_value,
                          prog_bar=False, on_epoch=True)

        return metrics
Exemplo n.º 2
0
def test_compare_fid(tmpdir, feature=2048):
    """Check that the whole pipeline gives the same FID as torch-fidelity."""
    from torch_fidelity import calculate_metrics

    metric = FrechetInceptionDistance(feature=feature).cuda()

    # Synthetic uint8 image stacks with different intensity ranges so the
    # two distributions are clearly distinct.
    real_imgs = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8)
    fake_imgs = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)

    batch_size = 10
    for start in range(0, real_imgs.shape[0], batch_size):
        metric.update(real_imgs[start:start + batch_size].cuda(), real=True)

    for start in range(0, fake_imgs.shape[0], batch_size):
        metric.update(fake_imgs[start:start + batch_size].cuda(), real=False)

    reference = calculate_metrics(
        input1=_ImgDataset(real_imgs),
        input2=_ImgDataset(fake_imgs),
        fid=True,
        feature_layer_fid=str(feature),
        batch_size=batch_size,
        save_cpu_ram=True,
    )

    tm_res = metric.compute()

    expected = torch.tensor([reference["frechet_inception_distance"]])
    assert torch.allclose(tm_res.cpu(), expected, atol=1e-3)
Exemplo n.º 3
0
    def test_all(self):
        """Combined ISC/FID/KID run must agree with the per-metric runs."""
        cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''

        limit = 5000
        tmp = tempfile.gettempdir()
        input_1 = os.path.join(tmp, f'cifar10-train-img-{limit}')
        input_2 = os.path.join(tmp, f'cifar10-valid-img-noise-{limit}')

        # Dump both dataset variants to disk via the helper script; the
        # validation split additionally gets noise applied ('-n').
        for registry_name, out_dir, extra in (
                ('cifar10-train', input_1, ()),
                ('cifar10-valid', input_2, ('-n',))):
            res = subprocess.run(
                ('python3', 'utils/util_dump_dataset_as_images.py',
                 registry_name, out_dir, '-l', str(limit)) + extra,
            )
            self.assertEqual(res.returncode, 0, msg=res)

        kwargs = dict(
            cuda=cuda,
            input1_cache_name='test_input_1',
            input2_cache_name='test_input_2',
            save_cpu_ram=True,
        )

        combined = calculate_metrics(input1=input_1, input2=input_2,
                                     isc=True, fid=True, kid=True, **kwargs)

        isc = calculate_isc(1, **kwargs)
        fid = calculate_fid(**kwargs)
        kid = calculate_kid(**kwargs)

        self.assertEqual(isc[KEY_METRIC_ISC_MEAN], combined[KEY_METRIC_ISC_MEAN])
        self.assertEqual(fid[KEY_METRIC_FID], combined[KEY_METRIC_FID])
        self.assertEqual(kid[KEY_METRIC_KID_MEAN], combined[KEY_METRIC_KID_MEAN])
Exemplo n.º 4
0
def calculate_fid(model,
                  fid_dataset,
                  bs,
                  size,
                  num_batches,
                  latent_size,
                  integer_values,
                  save_dir='fid_imgs',
                  device='cuda'):
    """Sample ``num_batches * bs`` images from ``model`` and compute FID.

    Generated images are saved to ``save_dir`` as zero-padded PNGs (mapped
    from the model's [-1, 1] output range to [0, 255]) and that folder is
    compared against ``fid_dataset`` with torch-fidelity.

    Returns the metrics dict produced by ``calculate_metrics`` (FID only).
    """
    # FIX: the output directory was never created, so save_image failed on a
    # fresh run with the default 'fid_imgs' folder.
    os.makedirs(save_dir, exist_ok=True)

    coords = convert_to_coord_format(bs,
                                     size,
                                     size,
                                     device,
                                     integer_values=integer_values)
    for i in range(num_batches):
        z = torch.randn(bs, latent_size, device=device)
        fake_img, _ = model(coords, [z])
        for j in range(bs):
            # NOTE(review): `range=` was deprecated in favour of
            # `value_range=` in newer torchvision releases — kept as-is for
            # the version this project pins; confirm before upgrading.
            torchvision.utils.save_image(
                fake_img[j, :, :, :],
                os.path.join(save_dir, '%s.png' % str(i * bs + j).zfill(5)),
                range=(-1, 1),
                normalize=True)
    metrics_dict = calculate_metrics(save_dir,
                                     fid_dataset,
                                     cuda=True,
                                     isc=False,
                                     fid=True,
                                     kid=False,
                                     verbose=False)
    return metrics_dict
Exemplo n.º 5
0
 def on_test_end(self, trainer, pl_module):
     """Score generated vs. real test images (ISC/FID/KID) and log them."""
     options = dict(
         cuda=True,
         isc=True,
         fid=True,
         kid=True,
         # KID subsets are sized relative to the configured test-set size.
         kid_subset_size=pl_module.hparams.test_size // 10,
         verbose=False,
     )
     scores = calculate_metrics(pl_module.test_dirs['gen'],
                                pl_module.test_dirs['rea'],
                                **options)
     trainer.logger.log_metrics(scores)
     print(scores)
Exemplo n.º 6
0
def compute_FID(path_1, path_2):
    """Compute FID between the two image folders at *path_1* and *path_2*.

    Both folders are loaded as ``ImageFolder`` datasets, cropped/resized to
    229x229 and converted to tensors, then scored with torch-fidelity using
    a custom feature extractor. Returns the metrics dict (FID only).
    """
    # NOTE(review): 229 looks like a typo for Inception's 299 input size,
    # and CenterCrop *before* Resize is unusual — confirm both are
    # intentional. Also, ToTensor yields float tensors in [0, 1]; verify
    # the chosen feature extractor accepts floats rather than uint8.
    transform = T.Compose([T.CenterCrop(229), T.Resize(229), T.ToTensor()])
    dataset_1 = dset.ImageFolder(root=path_1, transform=transform)
    dataset_2 = dset.ImageFolder(root=path_2, transform=transform)
    metrics_dict = calculate_metrics(
        dataset_1,
        dataset_2,
        cuda=True,
        fid=True,
        isc=False,
        kid=False,
        verbose=True,
        # Custom extractor registered elsewhere in the project — presumably
        # trained on malware-derived images; confirm it is registered before
        # calling this function.
        feature_extractor='inception-v3-compat-malware',
        save_cpu_ram=True)

    print(f"{metrics_dict}")
    return metrics_dict
Exemplo n.º 7
0
    def test_all(self):
        """calculate_metrics must agree with the single-metric helpers."""
        cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''

        train_input = 'cifar10-train'
        noisy_valid = Cifar10_RGB(
            tempfile.gettempdir(),
            train=True,
            transform=Compose((
                TransformPILtoRGBTensor(),
                TransformAddNoise()
            )),
            download=True)
        # Unset the registry name so caching does not kick in.
        noisy_valid.name = None

        isc = calculate_isc(train_input, cuda=cuda)
        fid = calculate_fid(train_input, noisy_valid, cuda=cuda)
        kid = calculate_kid(train_input, noisy_valid, cuda=cuda)

        combined = calculate_metrics(train_input, noisy_valid, cuda=cuda,
                                     isc=True, fid=True, kid=True)

        self.assertEqual(isc[KEY_METRIC_ISC_MEAN], combined[KEY_METRIC_ISC_MEAN])
        self.assertEqual(fid[KEY_METRIC_FID], combined[KEY_METRIC_FID])
        self.assertEqual(kid[KEY_METRIC_KID_MEAN], combined[KEY_METRIC_KID_MEAN])
Exemplo n.º 8
0
 def _test_fid_feature_layer(self, feature_size):
     """FID computed at the given feature layer must be strictly positive."""
     # Two small random uint8 datasets with different seeds so their
     # feature distributions differ.
     reference = RandomlyGeneratedDataset(10, 3, 299, 299,
                                          dtype=torch.uint8,
                                          seed=2021)
     candidate = RandomlyGeneratedDataset(10, 3, 299, 299,
                                          dtype=torch.uint8,
                                          seed=2022)
     scores = calculate_metrics(input1=reference,
                                input2=candidate,
                                fid=True,
                                feature_layer_fid=feature_size,
                                verbose=False)
     self.assertTrue(scores[KEY_METRIC_FID] > 0)
Exemplo n.º 9
0
def test_compare_is(tmpdir):
    """Check that the whole pipeline gives the same IS as torch-fidelity."""
    from torch_fidelity import calculate_metrics

    metric = IS(splits=1).cuda()

    # Synthetic uint8 images covering the full intensity range.
    images = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)

    batch_size = 10
    for start in range(0, images.shape[0], batch_size):
        metric.update(images[start:start + batch_size].cuda())

    reference = calculate_metrics(input1=_ImgDataset(images),
                                  isc=True,
                                  isc_splits=1,
                                  batch_size=batch_size,
                                  save_cpu_ram=True)

    tm_mean, tm_std = metric.compute()

    expected = torch.tensor([reference['inception_score_mean']])
    assert torch.allclose(tm_mean.cpu(), expected, atol=1e-3)
Exemplo n.º 10
0
def main():
    """Train an SRGAN-style super-resolution GAN.

    Parses CLI options, pretrains the generator on MSE content loss, then
    trains generator + discriminator adversarially with optional
    differentiable augmentation whose probability is adapted from the sign
    of the discriminator's real-sample outputs.  Periodically validates
    (PSNR / SSIM / LPIPS / FID), logs images and scalars to TensorBoard and
    saves checkpoints.

    Fixes vs. previous revision:
      * checkpoint saved ``g_net``'s weights under the ``'d_net'`` key, so
        the discriminator could never be restored;
      * ``VALIDATION_FREQUENCY`` could round down to 0 (e.g. frequency 1
        with a dataset percentage > 100), making ``epoch %
        VALIDATION_FREQUENCY`` raise ``ZeroDivisionError``.
    """
    parser = ArgumentParser()
    parser.add_argument("--augmentation", action='store_true')
    parser.add_argument("--train-dataset-percentage", type=float, default=100)
    parser.add_argument("--val-dataset-percentage", type=int, default=100)
    parser.add_argument("--label-smoothing", type=float, default=0.9)
    parser.add_argument("--validation-frequency", type=int, default=1)
    args = parser.parse_args()

    ENABLE_AUGMENTATION = args.augmentation
    TRAIN_DATASET_PERCENTAGE = args.train_dataset_percentage
    VAL_DATASET_PERCENTAGE = args.val_dataset_percentage
    LABEL_SMOOTHING_FACTOR = args.label_smoothing
    VALIDATION_FREQUENCY = args.validation_frequency

    if ENABLE_AUGMENTATION:
        augment_batch = AugmentPipe()
        augment_batch.to(device)
    else:
        # Identity stand-in so the training loop can treat both cases alike.
        augment_batch = lambda x: x
        augment_batch.p = 0

    # Scale the schedules so the number of images seen stays roughly
    # constant when training on a subset of the dataset.
    NUM_ADV_EPOCHS = round(NUM_ADV_BASELINE_EPOCHS /
                           (TRAIN_DATASET_PERCENTAGE / 100))
    NUM_PRETRAIN_EPOCHS = round(NUM_BASELINE_PRETRAIN_EPOCHS /
                                (TRAIN_DATASET_PERCENTAGE / 100))
    # FIX: clamp to >= 1 — rounding could yield 0 and later crash on
    # `epoch % VALIDATION_FREQUENCY`.
    VALIDATION_FREQUENCY = max(1, round(VALIDATION_FREQUENCY /
                                        (TRAIN_DATASET_PERCENTAGE / 100)))

    training_start = datetime.datetime.now().isoformat()

    train_set = TrainDatasetFromFolder(train_dataset_dir,
                                       patch_size=PATCH_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    len_train_set = len(train_set)
    train_set = Subset(
        train_set,
        list(
            np.random.choice(
                np.arange(len_train_set),
                int(len_train_set * TRAIN_DATASET_PERCENTAGE / 100), False)))

    val_set = ValDatasetFromFolder(val_dataset_dir,
                                   upscale_factor=UPSCALE_FACTOR)
    len_val_set = len(val_set)
    val_set = Subset(
        val_set,
        list(
            np.random.choice(np.arange(len_val_set),
                             int(len_val_set * VAL_DATASET_PERCENTAGE / 100),
                             False)))

    train_loader = DataLoader(dataset=train_set,
                              num_workers=8,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              prefetch_factor=8)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=2,
                            batch_size=VAL_BATCH_SIZE,
                            shuffle=False,
                            pin_memory=True,
                            prefetch_factor=2)

    epoch_validation_hr_dataset = HrValDatasetFromFolder(
        val_dataset_dir)  # Useful to compute FID metric

    results_folder = Path(
        f"results_{training_start}_CS:{PATCH_SIZE}_US:{UPSCALE_FACTOR}x_TRAIN:{TRAIN_DATASET_PERCENTAGE}%_AUGMENTATION:{ENABLE_AUGMENTATION}"
    )
    results_folder.mkdir(exist_ok=True)
    writer = SummaryWriter(str(results_folder / "tensorboard_log"))
    g_net = Generator(n_residual_blocks=NUM_RESIDUAL_BLOCKS,
                      upsample_factor=UPSCALE_FACTOR)
    d_net = Discriminator(patch_size=PATCH_SIZE)
    lpips_metric = lpips.LPIPS(net='alex')

    g_net.to(device=device)
    d_net.to(device=device)
    lpips_metric.to(device=device)

    g_optimizer = optim.Adam(g_net.parameters(), lr=1e-4)
    d_optimizer = optim.Adam(d_net.parameters(), lr=1e-4)

    bce_loss = BCELoss()
    mse_loss = MSELoss()

    bce_loss.to(device=device)
    mse_loss.to(device=device)
    results = {
        'd_total_loss': [],
        'g_total_loss': [],
        'g_adv_loss': [],
        'g_content_loss': [],
        'd_real_mean': [],
        'd_fake_mean': [],
        'psnr': [],
        'ssim': [],
        'lpips': [],
        'fid': [],
        'rt': [],
        'augment_probability': []
    }

    augment_probability = 0
    num_images = len(train_set) * (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS)
    prediction_list = []
    rt = 0

    for epoch in range(1, NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS + 1):
        train_bar = tqdm(train_loader, ncols=200)
        running_results = {
            'batch_sizes': 0,
            'd_epoch_total_loss': 0,
            'g_epoch_total_loss': 0,
            'g_epoch_adv_loss': 0,
            'g_epoch_content_loss': 0,
            'd_epoch_real_mean': 0,
            'd_epoch_fake_mean': 0,
            'rt': 0,
            'augment_probability': 0
        }
        image_percentage = epoch / (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS) * 100
        g_net.train()
        d_net.train()

        for data, target in train_bar:
            augment_batch.p = torch.tensor([augment_probability],
                                           device=device)
            batch_size = data.size(0)
            running_results["batch_sizes"] += batch_size
            target = target.to(device)
            data = data.to(device)
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)

            if epoch > NUM_PRETRAIN_EPOCHS:
                # Discriminator training (adversarial phase only)
                d_optimizer.zero_grad(set_to_none=True)

                d_real_output = d_net(augment_batch(target))
                # One-sided label smoothing on the real targets.
                d_real_output_loss = bce_loss(
                    d_real_output, real_labels * LABEL_SMOOTHING_FACTOR)

                fake_img = g_net(data)
                d_fake_output = d_net(augment_batch(fake_img))
                d_fake_output_loss = bce_loss(d_fake_output, fake_labels)

                d_total_loss = d_real_output_loss + d_fake_output_loss
                d_total_loss.backward()
                d_optimizer.step()

                d_real_mean = d_real_output.mean()
                d_fake_mean = d_fake_output.mean()

            # Generator training
            g_optimizer.zero_grad(set_to_none=True)

            fake_img = g_net(data)
            if epoch > NUM_PRETRAIN_EPOCHS:
                adversarial_loss = bce_loss(d_net(augment_batch(fake_img)),
                                            real_labels) * ADV_LOSS_BALANCER
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss + adversarial_loss
            else:
                adversarial_loss = mse_loss(torch.zeros(
                    1, device=device), torch.zeros(
                        1,
                        device=device))  # Logging purposes, it is always zero
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss

            g_total_loss.backward()
            g_optimizer.step()

            if epoch > NUM_PRETRAIN_EPOCHS and ENABLE_AUGMENTATION:
                # Adapt the augmentation probability from the mean sign of
                # D's real outputs: above the target -> augment more,
                # below -> augment less (capped at 0.85).
                prediction_list.append(
                    (torch.sign(d_real_output - 0.5)).tolist())
                if len(prediction_list) == RT_BATCH_SMOOTHING_FACTOR:
                    rt_list = [
                        prediction for sublist in prediction_list
                        for prediction in sublist
                    ]
                    rt = mean(rt_list)
                    if mean(rt_list) > AUGMENT_PROB_TARGET:
                        augment_probability = min(
                            0.85,
                            augment_probability + AUGMENT_PROBABABILITY_STEP)
                    else:
                        augment_probability = max(
                            0.,
                            augment_probability - AUGMENT_PROBABABILITY_STEP)
                    prediction_list.clear()

            running_results['g_epoch_total_loss'] += g_total_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            running_results['g_epoch_adv_loss'] += adversarial_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            running_results['g_epoch_content_loss'] += content_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            if epoch > NUM_PRETRAIN_EPOCHS:
                running_results['d_epoch_total_loss'] += d_total_loss.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['d_epoch_real_mean'] += d_real_mean.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['d_epoch_fake_mean'] += d_fake_mean.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['rt'] += rt * batch_size
                running_results[
                    'augment_probability'] += augment_probability * batch_size

            train_bar.set_description(
                desc=f'[{epoch}/{NUM_ADV_EPOCHS + NUM_PRETRAIN_EPOCHS}] '
                f'Loss_D: {running_results["d_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G: {running_results["g_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_adv: {running_results["g_epoch_adv_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_content: {running_results["g_epoch_content_loss"] / running_results["batch_sizes"]:.4f} '
                f'D(x): {running_results["d_epoch_real_mean"] / running_results["batch_sizes"]:.4f} '
                f'D(G(z)): {running_results["d_epoch_fake_mean"] / running_results["batch_sizes"]:.4f} '
                f'rt: {running_results["rt"] / running_results["batch_sizes"]:.4f} '
                f'augment_probability: {running_results["augment_probability"] / running_results["batch_sizes"]:.4f}'
            )

        # Validation runs on the first/last epoch and every
        # VALIDATION_FREQUENCY epochs.
        if epoch == 1 or epoch == (
                NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
        ) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1:
            torch.cuda.empty_cache()
            gc.collect()
            g_net.eval()
            # ...
            images_path = results_folder / Path(f'training_images_results')
            images_path.mkdir(exist_ok=True)

            with torch.no_grad():
                val_bar = tqdm(val_loader, ncols=160)
                val_results = {
                    'epoch_mse': 0,
                    'epoch_ssim': 0,
                    'epoch_psnr': 0,
                    'epoch_avg_psnr': 0,
                    'epoch_avg_ssim': 0,
                    'epoch_lpips': 0,
                    'epoch_avg_lpips': 0,
                    'epoch_fid': 0,
                    'batch_sizes': 0
                }
                val_images = torch.empty((0, 0))
                epoch_validation_sr_dataset = None
                for lr, val_hr_restore, hr in val_bar:
                    batch_size = lr.size(0)
                    val_results['batch_sizes'] += batch_size
                    hr = hr.to(device=device)
                    lr = lr.to(device=device)

                    sr = g_net(lr)
                    sr = torch.clamp(sr, 0., 1.)
                    # Accumulate the SR outputs as a dataset for FID.
                    if epoch_validation_sr_dataset is None:
                        epoch_validation_sr_dataset = SingleTensorDataset(
                            (sr.cpu() * 255).to(torch.uint8))

                    else:
                        epoch_validation_sr_dataset = ConcatDataset(
                            (epoch_validation_sr_dataset,
                             SingleTensorDataset(
                                 (sr.cpu() * 255).to(torch.uint8))))

                    batch_mse = ((sr - hr)**2).data.mean()  # Pixel-wise MSE
                    val_results['epoch_mse'] += batch_mse * batch_size
                    batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                    val_results['epoch_ssim'] += batch_ssim * batch_size
                    val_results['epoch_avg_ssim'] = val_results[
                        'epoch_ssim'] / val_results['batch_sizes']
                    val_results['epoch_psnr'] += 20 * log10(
                        hr.max() / (batch_mse / batch_size)) * batch_size
                    val_results['epoch_avg_psnr'] = val_results[
                        'epoch_psnr'] / val_results['batch_sizes']
                    # LPIPS expects inputs in [-1, 1].
                    val_results['epoch_lpips'] += torch.mean(
                        lpips_metric(hr * 2 - 1, sr * 2 - 1)).to(
                            'cpu', non_blocking=True).detach() * batch_size
                    val_results['epoch_avg_lpips'] = val_results[
                        'epoch_lpips'] / val_results['batch_sizes']

                    val_bar.set_description(
                        desc=
                        f"[converting LR images to SR images] PSNR: {val_results['epoch_avg_psnr']:4f} dB "
                        f"SSIM: {val_results['epoch_avg_ssim']:4f} "
                        f"LPIPS: {val_results['epoch_avg_lpips']:.4f} ")
                    # Keep triplets (restored LR, HR, SR) for TensorBoard up
                    # to NUM_LOGGED_VALIDATION_IMAGES.
                    if val_images.size(0) * val_images.size(
                            1) < NUM_LOGGED_VALIDATION_IMAGES * 3:
                        if val_images.size(0) == 0:
                            val_images = torch.hstack(
                                (display_transform(CENTER_CROP_SIZE)
                                 (val_hr_restore).unsqueeze(0).transpose(0, 1),
                                 display_transform(CENTER_CROP_SIZE)(
                                     hr.data.cpu()).unsqueeze(0).transpose(
                                         0, 1),
                                 display_transform(CENTER_CROP_SIZE)(
                                     sr.data.cpu()).unsqueeze(0).transpose(
                                         0, 1)))
                        else:
                            val_images = torch.cat((
                                val_images,
                                torch.hstack(
                                    (display_transform(CENTER_CROP_SIZE)(
                                        val_hr_restore).unsqueeze(0).transpose(
                                            0, 1),
                                     display_transform(CENTER_CROP_SIZE)(
                                         hr.data.cpu()).unsqueeze(0).transpose(
                                             0, 1),
                                     display_transform(CENTER_CROP_SIZE)(
                                         sr.data.cpu()).unsqueeze(0).transpose(
                                             0, 1)))))
                val_results['epoch_fid'] = calculate_metrics(
                    epoch_validation_sr_dataset,
                    epoch_validation_hr_dataset,
                    cuda=True,
                    fid=True,
                    verbose=True
                )['frechet_inception_distance']  # Set batch_size=1 if you get memory error (inside calculate metric function)

                val_images = val_images.view(
                    (NUM_LOGGED_VALIDATION_IMAGES // 4, -1, 3,
                     CENTER_CROP_SIZE, CENTER_CROP_SIZE))
                val_save_bar = tqdm(val_images,
                                    desc='[saving validation results]',
                                    ncols=160)

                for index, image_batch in enumerate(val_save_bar, start=1):
                    image_grid = utils.make_grid(image_batch,
                                                 nrow=3,
                                                 padding=5)
                    writer.add_image(
                        f'progress{image_percentage:.1f}_index_{index}.png',
                        image_grid)

        # save loss / scores / psnr /ssim
        results['d_total_loss'].append(running_results['d_epoch_total_loss'] /
                                       running_results['batch_sizes'])
        results['g_total_loss'].append(running_results['g_epoch_total_loss'] /
                                       running_results['batch_sizes'])
        results['g_adv_loss'].append(running_results['g_epoch_adv_loss'] /
                                     running_results['batch_sizes'])
        results['g_content_loss'].append(
            running_results['g_epoch_content_loss'] /
            running_results['batch_sizes'])
        results['d_real_mean'].append(running_results['d_epoch_real_mean'] /
                                      running_results['batch_sizes'])
        results['d_fake_mean'].append(running_results['d_epoch_fake_mean'] /
                                      running_results['batch_sizes'])
        results['rt'].append(running_results['rt'] /
                             running_results['batch_sizes'])
        results['augment_probability'].append(
            running_results['augment_probability'] /
            running_results['batch_sizes'])
        if epoch == 1 or epoch == (
                NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
        ) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1:
            results['psnr'].append(val_results['epoch_avg_psnr'])
            results['ssim'].append(val_results['epoch_avg_ssim'])
            results['lpips'].append(val_results['epoch_avg_lpips'])
            results['fid'].append(val_results['epoch_fid'])

        for metric, metric_values in results.items():
            if epoch == 1 or epoch == (
                    NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1 or \
                    metric not in ["psnr", "ssim", "lpips", "fid"]:
                writer.add_scalar(metric, metric_values[-1],
                                  int(image_percentage * num_images * 0.01))

        if epoch == 1 or epoch == (
                NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
        ) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1:
            # save model parameters
            models_path = results_folder / "saved_models"
            models_path.mkdir(exist_ok=True)
            torch.save(
                {
                    'progress': image_percentage,
                    'g_net': g_net.state_dict(),
                    # FIX: previously this saved g_net's weights under the
                    # 'd_net' key, so the discriminator was never persisted.
                    'd_net': d_net.state_dict(),
                    # 'g_optimizer': g_optimizer.state_dict(), Uncomment this if you want resume training
                    # 'd_optimizer': d_optimizer.state_dict(),
                },
                str(models_path / f'progress_{image_percentage:.1f}.tar'))
Exemplo n.º 11
0
def train(args):
    """Train a (optionally class-conditional) GAN on CIFAR-10 with hinge loss.

    Runs `args.num_total_steps` alternating generator/discriminator updates
    with linearly decaying learning rates, evaluates ISC/FID/KID/PPL with
    torch-fidelity every `args.num_epoch_steps` steps, logs losses, metrics
    and sample grids to TensorBoard, and exports the generator to
    TorchScript and ONNX whenever the leading metric improves.
    """
    # set up dataset loader
    os.makedirs(args.dir_dataset, exist_ok=True)
    # Normalize images to [-1, 1] to match the generator's output range —
    # presumably tanh-activated; confirm in the Generator definition.
    ds_transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    ds_instance = torchvision.datasets.CIFAR10(args.dir_dataset, train=True, download=True, transform=ds_transform)
    loader = torch.utils.data.DataLoader(
        ds_instance, batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=4, pin_memory=True
    )
    # Manual iterator so the epoch boundary can be handled inside the step
    # loop (see the StopIteration handler below).
    loader_iter = iter(loader)

    # reinterpret command line inputs
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_classes = 10 if args.conditional else 0  # unconditional
    # Map the chosen leading metric to (result key, initial best value,
    # "is better" comparator) — ISC improves upward, the rest downward.
    leading_metric, last_best_metric, metric_greater_cmp = {
        'ISC': (torch_fidelity.KEY_METRIC_ISC_MEAN, 0.0, float.__gt__),
        'FID': (torch_fidelity.KEY_METRIC_FID, float('inf'), float.__lt__),
        'KID': (torch_fidelity.KEY_METRIC_KID_MEAN, float('inf'), float.__lt__),
        'PPL': (torch_fidelity.KEY_METRIC_PPL_MEAN, float('inf'), float.__lt__),
    }[args.leading_metric]

    # create Generator and Discriminator models
    G = Generator(args.z_size).to(device).train()
    D = Discriminator(not args.disable_sn).to(device).train()

    # initialize persistent noise for observed samples
    z_vis = torch.randn(64, args.z_size, device=device)

    # prepare optimizer and learning rate schedulers (linear decay)
    optim_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.0, 0.9))
    optim_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.0, 0.9))
    scheduler_G = torch.optim.lr_scheduler.LambdaLR(optim_G, lambda step: 1. - step / args.num_total_steps)
    scheduler_D = torch.optim.lr_scheduler.LambdaLR(optim_D, lambda step: 1. - step / args.num_total_steps)

    # initialize logging
    tb = tensorboard.SummaryWriter(log_dir=args.dir_logs)
    pbar = tqdm.tqdm(total=args.num_total_steps, desc='Training', unit='batch')
    os.makedirs(args.dir_logs, exist_ok=True)

    for step in range(args.num_total_steps):
        # read next batch; restart the loader when an epoch is exhausted
        try:
            real_img, real_label = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)
            real_img, real_label = next(loader_iter)
        real_img = real_img.to(device)
        real_label = real_label.to(device)

        # update Generator (freeze D so its gradients are not computed)
        G.requires_grad_(True)
        D.requires_grad_(False)
        z = torch.randn(args.batch_size, args.z_size, device=device)
        optim_D.zero_grad()
        optim_G.zero_grad()
        fake = G(z)
        loss_G = hinge_loss_gen(D(fake))
        loss_G.backward()
        optim_G.step()

        # update Discriminator — num_dis_updates D steps per G step
        G.requires_grad_(False)
        D.requires_grad_(True)
        for d_iter in range(args.num_dis_updates):
            z = torch.randn(args.batch_size, args.z_size, device=device)
            optim_D.zero_grad()
            # NOTE(review): zeroing optim_G here looks redundant (G is
            # frozen above) but is harmless — confirm before removing.
            optim_G.zero_grad()
            fake = G(z)
            loss_D = hinge_loss_dis(D(fake), D(real_img))
            loss_D.backward()
            optim_D.step()

        # log losses and current LR every 10 steps
        if (step + 1) % 10 == 0:
            step_info = {'loss_G': loss_G.cpu().item(), 'loss_D': loss_D.cpu().item()}
            pbar.set_postfix(step_info)
            for k, v in step_info.items():
                tb.add_scalar(f'loss/{k}', v, global_step=step)
            tb.add_scalar(f'LR/lr', scheduler_G.get_last_lr()[0], global_step=step)
        pbar.update(1)

        # decay LR
        scheduler_G.step()
        scheduler_D.step()

        # check if it is validation time
        next_step = step + 1
        if next_step % args.num_epoch_steps != 0:
            continue
        pbar.close()
        G.eval()
        print('Evaluating the generator...')

        # compute and log generative metrics (G sampled through the wrapper)
        metrics = torch_fidelity.calculate_metrics(
            input1=torch_fidelity.GenerativeModelModuleWrapper(G, args.z_size, args.z_type, num_classes),
            input1_model_num_samples=args.num_samples_for_metrics,
            input2='cifar10-train',
            isc=True,
            fid=True,
            kid=True,
            ppl=True,
            ppl_epsilon=1e-2,
            ppl_sample_similarity_resize=64,
        )

        # log metrics
        for k, v in metrics.items():
            tb.add_scalar(f'metrics/{k}', v, global_step=next_step)

        # log observed images (fixed z_vis noise, so grids are comparable)
        samples_vis = G(z_vis).detach().cpu()
        samples_vis = torchvision.utils.make_grid(samples_vis).permute(1, 2, 0).numpy()
        tb.add_image('observations', samples_vis, global_step=next_step, dataformats='HWC')
        samples_vis = PIL.Image.fromarray(samples_vis)
        samples_vis.save(os.path.join(args.dir_logs, f'{next_step:06d}.png'))

        # save the generator if it improved
        if metric_greater_cmp(metrics[leading_metric], last_best_metric):
            print(f'Leading metric {args.leading_metric} improved from {last_best_metric} to {metrics[leading_metric]}')

            last_best_metric = metrics[leading_metric]

            # Export both a TorchScript trace and an ONNX model with a
            # dynamic batch dimension.
            dummy_input = torch.zeros(1, args.z_size, device=device)
            torch.jit.save(torch.jit.trace(G, (dummy_input,)), os.path.join(args.dir_logs, 'generator.pth'))
            torch.onnx.export(G, dummy_input, os.path.join(args.dir_logs, 'generator.onnx'),
                opset_version=11, input_names=['z'], output_names=['rgb'],
                dynamic_axes={'z': {0: 'batch'}, 'rgb': {0: 'batch'}},
            )

        # resume training
        if next_step <= args.num_total_steps:
            pbar = tqdm.tqdm(total=args.num_total_steps, initial=next_step, desc='Training', unit='batch')
            G.train()

    tb.close()
    print(f'Training finished; the model with best {args.leading_metric} value ({last_best_metric}) is saved as '
          f'{args.dir_logs}/generator.onnx and {args.dir_logs}/generator.pth')
Exemplo n.º 12
0
def main():
    """Evaluate a trained super-resolution generator on the FFHQ test set.

    Loads generator weights from ``--model``, super-resolves every image in
    the test split, accumulates PSNR / SSIM / LPIPS per batch and FID over
    the full set, saves (LR-restored, HR, SR) comparison grids under
    ``test_results/<name>/test_images_results/``, and appends one row of
    metrics to ``test_results/global_results.csv``.
    """
    parser = ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--name", required=True, type=str)
    args = parser.parse_args()

    saved_model = torch.load(args.model, map_location=device)
    model_name = args.name

    g_net = Generator(n_residual_blocks=NUM_RESIDUAL_BLOCKS, upsample_factor=UPSCALE_FACTOR)
    lpips_metric = lpips.LPIPS(net='alex')

    g_net.to(device=device)
    lpips_metric.to(device=device)

    g_net.load_state_dict(saved_model['g_net'])

    test_folder = Path("test_results")
    test_folder.mkdir(exist_ok=True)

    results_folder = test_folder / Path(f"{model_name}")
    results_folder.mkdir(exist_ok=True)

    test_set = ValDatasetFromFolder('data/ffhq/images512x512/test_set', upscale_factor=UPSCALE_FACTOR)
    test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=BATCH_SIZE, shuffle=False,
                             pin_memory=True)

    test_hr_dataset = HrValDatasetFromFolder('data/ffhq/images512x512/test_set')

    g_net.eval()
    images_path = results_folder / Path('test_images_results')
    images_path.mkdir(exist_ok=True)

    with torch.no_grad():
        test_bar = tqdm(test_loader, ncols=160)
        test_results = {'psnr': 0, 'ssim': 0, 'lpips': 0, 'fid': 0}
        accumulated_results = {'accumulated_mse': 0, 'accumulated_ssim': 0, 'accumulated_psnr': 0,
                               'accumulated_lpips': 0, 'batch_sizes': 0}
        test_images = torch.empty((0, 0))
        test_sr_dataset = None
        for lr, test_hr_restore, hr in test_bar:
            batch_size = lr.size(0)
            accumulated_results['batch_sizes'] += batch_size
            hr = hr.to(device=device)
            lr = lr.to(device=device)

            sr = g_net(lr)
            sr = torch.clamp(sr, 0., 1.)
            # Collect SR outputs as uint8 for the FID computation after the
            # loop. BUGFIX: compare against None explicitly — `not dataset`
            # invokes __len__/__bool__ on Dataset objects, which is fragile.
            if test_sr_dataset is None:
                test_sr_dataset = SingleTensorDataset((sr.cpu() * 255).to(torch.uint8))
            else:
                test_sr_dataset = ConcatDataset(
                    (test_sr_dataset, SingleTensorDataset((sr.cpu() * 255).to(torch.uint8))))

            batch_mse = ((sr - hr) ** 2).data.mean()  # Pixel-wise MSE (already averaged over the batch)
            accumulated_results['accumulated_mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            accumulated_results['accumulated_ssim'] += batch_ssim * batch_size
            test_results['ssim'] = accumulated_results['accumulated_ssim'] / accumulated_results['batch_sizes']
            # BUGFIX: PSNR = 10 * log10(MAX^2 / MSE). The previous code used
            # 20 * log10(max / (mse / batch_size)), i.e. the wrong exponent
            # and an extra division of an already-averaged MSE by batch size.
            accumulated_results['accumulated_psnr'] += 10 * log10(
                hr.max() ** 2 / batch_mse) * batch_size
            test_results['psnr'] = accumulated_results['accumulated_psnr'] / accumulated_results['batch_sizes']
            # LPIPS expects inputs scaled to [-1, 1].
            accumulated_results['accumulated_lpips'] += torch.mean(lpips_metric(hr * 2 - 1, sr * 2 - 1)).to(
                'cpu', non_blocking=True).detach() * batch_size
            test_results['lpips'] = accumulated_results['accumulated_lpips'] / accumulated_results['batch_sizes']

            test_bar.set_description(
                desc=f"[converting LR images to SR images] PSNR: {test_results['psnr']:4f} dB "
                     f"SSIM: {test_results['ssim']:4f} "
                     f"LPIPS: {test_results['lpips']:.4f} ")
            # Keep the first NUM_LOGGED_TEST_IMAGES triples of
            # (LR-restored, HR, SR) for the comparison grids saved below.
            # test_images has shape (B, 3, C, H, W) once populated, so
            # size(0) * size(1) counts the stored images (0 while empty).
            if test_images.size(0) * test_images.size(1) < NUM_LOGGED_TEST_IMAGES * 3:

                if test_images.size(0) == 0:
                    test_images = torch.hstack(
                        (test_display_transform()(test_hr_restore).unsqueeze(0).transpose(0, 1),
                         test_display_transform()(hr.data.cpu()).unsqueeze(0).transpose(0, 1),
                         test_display_transform()(sr.data.cpu()).unsqueeze(0).transpose(0, 1)))
                else:
                    test_images = torch.cat((test_images,
                                             torch.hstack(
                                                 (test_display_transform()(test_hr_restore).unsqueeze(
                                                     0).transpose(0, 1),
                                                  test_display_transform()(hr.data.cpu()).unsqueeze(
                                                      0).transpose(0, 1),
                                                  test_display_transform()(sr.data.cpu()).unsqueeze(
                                                      0).transpose(0, 1)))))
        # FID between all generated SR images and the ground-truth HR set.
        test_results['fid'] = calculate_metrics(test_sr_dataset, test_hr_dataset,
                                                cuda=True, fid=True, verbose=True)['frechet_inception_distance']

        # NOTE(review): assumes at least NUM_LOGGED_TEST_IMAGES triples were
        # collected and that images are 512x512 — confirm against the dataset.
        test_images = test_images.view((NUM_LOGGED_TEST_IMAGES, -1, 3, 512, 512))
        test_save_bar = tqdm(test_images, desc='[saving test results]', ncols=160)

        for index, image_batch in enumerate(test_save_bar, start=1):
            image_grid = utils.make_grid(image_batch, nrow=3, padding=5)
            utils.save_image(image_grid, str(images_path / f"{index}.png"), padding=5)

    # Append this model's row to the shared results CSV (mode='a': one row
    # per evaluated model, indexed by its name).
    data_frame = pd.DataFrame(data=test_results, index=[model_name])
    data_frame.to_csv(str(test_folder / "global_results.csv"), mode='a', index_label="Name")