# Common imports for the examples below. Project-specific helpers
# (load_config, get_autoencoder, load_weights, build_data_loader, make_image,
# Latents, the dataset/evaluator/trainer classes, ...) come from the
# surrounding repository and are not shown here.
import argparse
import functools
import json
import os
import random
import statistics
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy
import torch
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from tqdm.contrib import tenumerate


# Example #1
def main(args):
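    """Reconstruct a single image with a trained autoencoder and save the result."""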
    root_dir = Path(args.autoencoder_checkpoint).parent.parent
    output_dir = root_dir / args.output_dir
    output_dir.mkdir(exist_ok=True, parents=True)

    config = load_config(args.autoencoder_checkpoint, None)
    config['batch_size'] = 1
    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder,
                               args.autoencoder_checkpoint,
                               key='autoencoder')

    input_image = Path(args.image)
    data_loader = build_data_loader(input_image,
                                    config,
                                    config['absolute'],
                                    shuffle_off=True,
                                    dataset_class=DemoDataset)

    image = next(iter(data_loader))
    image = {k: v.to(args.device) for k, v in image.items()}

    reconstructed = Image.fromarray(
        make_image(autoencoder(image['input_image'])[0].squeeze(0)))

    # save into the directory created above (root_dir / args.output_dir)
    output_name = output_dir / (
        f"reconstructed_{input_image.stem}"
        f"_stylegan_{config['stylegan_variant']}"
        f"_{'w_only' if config['w_only'] else 'w_plus'}.png")
    reconstructed.save(output_name)
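

# A minimal, hypothetical CLI wrapper for the example above. The attribute
# names (autoencoder_checkpoint, image, output_dir, device) mirror what main()
# accesses; the original script's actual flags and defaults may differ.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Reconstruct an image with a trained autoencoder")
    parser.add_argument('autoencoder_checkpoint',
                        help="path to a trained autoencoder checkpoint")
    parser.add_argument('image', help="image to embed and reconstruct")
    parser.add_argument('--output-dir', default='reconstructions',
                        help="output dir, relative to the train run root")
    parser.add_argument('--device', default='cuda',
                        help="device to run on, e.g. 'cuda' or 'cpu'")
    main(parser.parse_args())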


# Example #2
def embed_images(
        args: argparse.Namespace, config: dict, dataset: Path
) -> Tuple[numpy.ndarray, List[numpy.ndarray], numpy.ndarray]:
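    """Embed all images of a dataset with the autoencoder's encoder.

    Returns the stacked latent codes, one stacked array per noise map, and the
    corresponding image names.
    """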
    data_loader = build_data_loader(dataset,
                                    config,
                                    config['absolute'],
                                    shuffle_off=True)
    if args.num_samples is not None:
        random.seed(args.seed if args.seed != 'none' else None)
        data_loader.dataset.image_data = random.sample(
            data_loader.dataset.image_data, args.num_samples)

    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder,
                               args.autoencoder_checkpoint,
                               key='autoencoder')

    latent_codes = []
    noises = None
    image_names = []

    for idx, batch in enumerate(tqdm(data_loader)):
        if isinstance(batch, dict):
            batch = batch['input_image']
        batch = batch.to(args.device)
        with torch.no_grad():
            latents: Latents = autoencoder.encode(batch)
        # convert the encoded latents and noise maps to numpy arrays on the CPU
        latents = latents.numpy()
        latent_codes.append(latents.latent)
        if noises is None:
            # keep one accumulator list per noise map (one per resolution)
            noises = [[noise] for noise in latents.noise]
        else:
            for noise, latent_noise in zip(noises, latents.noise):
                noise.append(latent_noise)

        for batch_idx in range(len(batch)):
            image_names.append(
                data_loader.dataset.image_data[idx * config['batch_size'] +
                                               batch_idx])

    latent_codes = numpy.concatenate(latent_codes, axis=0)
    noises = [numpy.concatenate(noise, axis=0) for noise in noises]

    return latent_codes, noises, numpy.array(image_names)
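

# Hedged usage sketch: one way to persist the arrays returned by embed_images()
# in a single .npz file. The key 'image_names' matches what the interpolation
# example further below reads; the remaining key names are assumptions.
def save_embeddings(out_file: Path, latent_codes, noises, image_names):
    numpy.savez(
        out_file,
        latent_codes=latent_codes,
        image_names=image_names,
        **{f"noise_{i}": noise for i, noise in enumerate(noises)},
    )


# Example #3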
def evaluate_denoising(args):
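    """Evaluate a denoising-trained autoencoder on a test set and report the
    mean PSNR/SSIM, optionally saving qualitative results."""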
    config = load_config(args.model_checkpoint, None)
    args.test_dataset = Path(args.test_dataset)

    assert config['denoising'] is True or config['black_and_white_denoising'] is True, \
        "you are supplying a train run that has not been trained for denoising! Aborting"

    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder, args.model_checkpoint, key='autoencoder', strict=True)

    config['batch_size'] = 1
    data_loader = build_data_loader(args.test_dataset,
                                    config,
                                    config['absolute'],
                                    shuffle_off=True,
                                    dataset_class=DenoisingEvaluationDataset)

    metrics = defaultdict(list)
    psnr_ssim_evaluator = PSNRSSIMEvaluator()

    train_run_root_dir = Path(args.model_checkpoint).parent.parent
    evaluation_root = train_run_root_dir / 'evaluation' / f"denoise_{args.dataset_name}"
    evaluation_root.mkdir(parents=True, exist_ok=True)

    for i, batch in tenumerate(data_loader, leave=False):
        batch = {k: v.to(args.device) for k, v in batch.items()}
        with torch.no_grad():
            denoised = autoencoder(batch['noisy'])

        noisy = clamp_and_unnormalize(batch['noisy'])
        original = clamp_and_unnormalize(batch['original'])
        denoised = clamp_and_unnormalize(denoised)

        if args.save:
            save_dir = evaluation_root / "qualitative" / args.test_dataset.stem
            save_dir.mkdir(exist_ok=True, parents=True)
            save_images([original[0], noisy[0], denoised[0]], save_dir, i)

        psnr, ssim = psnr_ssim_evaluator.psnr_and_ssim(denoised, original)

        metrics['psnr'].append(float(psnr.cpu().numpy()))
        metrics['ssim'].append(float(ssim.cpu().numpy()))

    metrics = {k: statistics.mean(v) for k, v in metrics.items()}

    evaluation_file = evaluation_root / f'denoising_{args.test_dataset.stem}.json'
    with evaluation_file.open('w') as f:
        json.dump(metrics, f, indent='\t')


# Example #4
def evaluate_checkpoint(checkpoint: str, dataset: dict, args: argparse.Namespace):
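    """Run FID and/or reconstruction evaluation for a single checkpoint,
    skipping metrics that are disabled or have already been computed."""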
    checkpoint = Path(checkpoint)
    train_run_root_dir = checkpoint.parent.parent
    evaluation_root = train_run_root_dir / 'evaluation'
    evaluation_root.mkdir(exist_ok=True)

    dataset_name = dataset.pop('name')
    to_evaluate = has_not_been_evaluated(checkpoint.name, dataset_name, evaluation_root)
    if not args.fid:
        to_evaluate['fid'] = False
    if not args.reconstruction:
        to_evaluate['reconstruction'] = False

    if not any(to_evaluate.values()):
        # there is nothing to evaluate
        return

    config = load_config(checkpoint, None)

    dataset = {k: Path(v) for k, v in dataset.items()}

    autoencoder = get_autoencoder(config).to('cuda')
    autoencoder = load_weights(autoencoder, checkpoint, key='autoencoder', strict=True)

    config['batch_size'] = 1

    dataset_class = get_dataset_class(argparse.Namespace(**config))
    data_loaders = {
        key: build_data_loader(value, config, config['absolute'], shuffle_off=True, dataset_class=dataset_class)
        for key, value in dataset.items()
    }

    if to_evaluate['fid']:
        fid_result = evaluate_fid(autoencoder, data_loaders, dataset)
        save_eval_result(fid_result, "fid", evaluation_root, dataset_name, checkpoint.name)

    if to_evaluate['reconstruction']:
        reconstruction_result = evaluate_reconstruction(autoencoder, data_loaders)
        save_eval_result(reconstruction_result, "reconstruction", evaluation_root, dataset_name, checkpoint.name)

    del autoencoder
    torch.cuda.empty_cache()


# Example #5
def main(args):
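    """Render interpolations between randomly chosen pairs of embedded images."""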
    embedding_dir = Path(args.embedding_file).parent
    embedded_data = numpy.load(args.embedding_file, mmap_mode='r')

    # the checkpoint name is the third-to-last '_'-separated token of the
    # embedding file's stem
    checkpoint_for_embedding = embedding_dir.parent / 'checkpoints' / f"{Path(args.embedding_file).stem.split('_')[-3]}.pt"

    config = load_config(checkpoint_for_embedding, None)
    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder,
                               checkpoint_for_embedding,
                               key='autoencoder',
                               strict=True)

    num_images = len(embedded_data['image_names'])

    interpolation_dir = embedding_dir / 'interpolations'
    interpolation_dir.mkdir(parents=True, exist_ok=True)

    is_w_plus = not config['w_only']

    for _ in range(args.num_images):
        start_image_idx, end_image_idx = random.sample(list(range(num_images)),
                                                       k=2)

        start_latent, start_noises = load_embeddings(embedded_data,
                                                     start_image_idx)
        end_latent, end_noises = load_embeddings(embedded_data, end_image_idx)

        for steps in args.steps:
            result = make_interpolation_image(steps, args.device, autoencoder,
                                              is_w_plus, start_latent,
                                              end_latent, start_noises,
                                              end_noises)
            result.save(
                str(interpolation_dir /
                    f"{start_image_idx}_{end_image_idx}_all_{steps}_steps.png")
            )


# Example #6
def main(args: argparse.Namespace):
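    """Compute the FID score of a trained autoencoder on a dataset and save it."""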
    dest_dir = Path(args.model_checkpoint).parent.parent / 'evaluation'
    dest_dir.mkdir(parents=True, exist_ok=True)

    config = load_config(args.model_checkpoint, None)
    dataset = Path(args.dataset)

    config['batch_size'] = args.batch_size
    data_loader = build_data_loader(dataset,
                                    config,
                                    config['absolute'],
                                    shuffle_off=True)
    fid_calculator = FID(args.num_samples, device=args.device)

    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder,
                               args.model_checkpoint,
                               key='autoencoder')

    fid_score = fid_calculator(autoencoder, data_loader, args.dataset)

    save_fid_score(fid_score, dest_dir, args.dataset_name)

    print(f"FID Score for {args.dataset_name} is {fid_score}.")


# Example #7
def main(args):
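    """Visualize the noise maps the encoder predicts for a set of images (or
    for generated latents when args.generate is set)."""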
    checkpoint_path = Path(args.model_checkpoint)

    config = load_config(checkpoint_path, None)

    autoencoder = get_autoencoder(config).to(args.device)
    load_weights(autoencoder, checkpoint_path, key='autoencoder', strict=True)

    config['batch_size'] = 1
    if args.generate:
        data_loader = build_latent_and_noise_generator(autoencoder, config)
    else:
        data_loader = build_data_loader(args.images,
                                        config,
                                        args.absolute,
                                        shuffle_off=True)

    noise_dest_dir = checkpoint_path.parent.parent / "noise_maps"
    noise_dest_dir.mkdir(parents=True, exist_ok=True)

    num_images = 0
    for idx, batch in enumerate(tqdm(data_loader, total=args.num_images)):
        batch = batch.to(args.device)

        if args.generate:
            latents = batch
            image_names = [Path(f"generate_{idx}.png")]
        else:
            with torch.no_grad():
                latents: Latents = autoencoder.encode(batch)

            image_names = [
                Path(
                    data_loader.dataset.image_data[idx * config['batch_size'] +
                                                   batch_idx])
                for batch_idx in range(len(batch))
            ]

        if args.shift_noise:
            noise_shifted_tensors = render_with_shifted_noise(
                autoencoder, latents, args.rounds)

        images = []
        for noise_tensors in latents.noise:
            noise_images = make_image(noise_tensors,
                                      normalize_func=noise_normalize)
            images.append([
                Image.fromarray(im).resize(
                    (config['image_size'], config['image_size']),
                    Image.NEAREST) for im in noise_images
            ])

        for batch_idx, (image,
                        orig_file_name) in enumerate(zip(batch, image_names)):
            # canvas: one column for the input image plus one per noise map;
            # when args.shift_noise is set, one extra row per shifting round
            full_image = Image.new(
                'RGB',
                ((len(images) + 1) * config['image_size'],
                 config['image_size'] if not args.shift_noise else
                 config['image_size'] * (args.rounds + 1)))
            if not args.generate:
                full_image.paste(Image.fromarray(make_image(image)), (0, 0))
            for i, noise_images in enumerate(images):
                full_image.paste(noise_images[batch_idx],
                                 ((i + 1) * config['image_size'], 0))

            if args.shift_noise:
                for i, shifted_images in enumerate(noise_shifted_tensors):
                    for j, shifted_image in enumerate(shifted_images):
                        full_image.paste(shifted_image,
                                         (j * config['image_size'],
                                          (i + 1) * config['image_size']))

            full_image.save(noise_dest_dir /
                            f"{orig_file_name.stem}_noise.png")

        num_images += len(image_names)
        if num_images >= args.num_images:
            break


# Example #8
def main(args, rank, world_size):
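    """Train the autoencoder (optionally with a discriminator) on one rank of a
    possibly distributed training run, with logging, evaluation, FID
    computation, snapshotting, and LR scheduling attached as extensions."""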
    config = load_yaml_config(args.config)
    config = merge_config_and_args(config, args)

    dataset_class = get_dataset_class(args)
    train_data_loader = build_data_loader(args.images,
                                          config,
                                          args.absolute,
                                          dataset_class=dataset_class)

    autoencoder = get_autoencoder(config, init_ckpt=args.stylegan_checkpoint)
    discriminator = None
    if args.use_discriminator:
        discriminator = get_discriminator(config)

    if args.disable_update_for == 'latent':
        assert args.autoencoder is not None, "if you want to only train noise, we need an autoencoder checkpoint!"
        print(
            f"Loading encoder weights from {args.autoencoder} for noise-only training."
        )
        load_weights(autoencoder.encoder,
                     args.autoencoder,
                     key='encoder',
                     strict=False)
    elif args.autoencoder is not None:
        print(f"Loading all weights from {args.autoencoder}.")
        load_weights(autoencoder, args.autoencoder, key='autoencoder')

    optimizer_opts = {
        'betas': (config['beta1'], config['beta2']),
        'weight_decay': config['weight_decay'],
        'lr': float(config['lr']),
    }

    if args.disable_update_for != 'none':
        if float(config['lr_to_noise']) != float(config['lr']):
            print(
                "Warning: updates for some parts of the networks are disabled. "
                f"Therefore 'lr_to_noise'={config['lr_to_noise']} is ignored.")
        optimizer = GradientClipAdam(autoencoder.trainable_parameters(),
                                     **optimizer_opts)
    else:
        main_param_group, noise_param_group = autoencoder.trainable_parameters(
            as_groups=(["to_noise", "intermediate_to_noise"], ))
        noise_param_group['lr'] = float(config['lr_to_noise'])
        optimizer = GradientClipAdam([main_param_group, noise_param_group],
                                     **optimizer_opts)

    if world_size > 1:
        distributed = functools.partial(DDP,
                                        device_ids=[rank],
                                        find_unused_parameters=True,
                                        broadcast_buffers=False,
                                        output_device=rank)
        autoencoder = distributed(autoencoder.to('cuda'))
        if discriminator is not None:
            discriminator = distributed(discriminator.to('cuda'))
    else:
        autoencoder = autoencoder.to('cuda')
        if discriminator is not None:
            discriminator = discriminator.to('cuda')

    if discriminator is not None:
        discriminator_optimizer = GradientClipAdam(discriminator.parameters(),
                                                   **optimizer_opts)
        updater = AutoencoderDiscriminatorUpdater(
            iterators={'images': train_data_loader},
            networks={
                'autoencoder': autoencoder,
                'discriminator': discriminator
            },
            optimizers={
                'main': optimizer,
                'discriminator': discriminator_optimizer
            },
            device='cuda',
            copy_to_device=world_size == 1,
            disable_update_for=args.disable_update_for,
        )
    else:
        updater = AutoencoderUpdater(
            iterators={'images': train_data_loader},
            networks={'autoencoder': autoencoder},
            optimizers={'main': optimizer},
            device='cuda',
            copy_to_device=world_size == 1,
            disable_update_for=args.disable_update_for,
        )

    trainer = DistributedTrainer(updater,
                                 stop_trigger=get_trigger(
                                     (config['max_iter'], 'iteration')))

    logger = WandBLogger(
        args.log_dir,
        args,
        config,
        os.path.dirname(os.path.realpath(__file__)),
        trigger=get_trigger((config['log_iter'], 'iteration')),
        master=rank == 0,
        project_name="One Model to Generate them All",
        run_name=args.log_name,
    )

    if args.val_images is not None:
        val_data_loader = build_data_loader(args.val_images,
                                            config,
                                            args.absolute,
                                            shuffle_off=True,
                                            dataset_class=dataset_class)

        evaluator = Evaluator(val_data_loader,
                              logger,
                              AutoEncoderEvalFunc(autoencoder, rank),
                              rank,
                              trigger=get_trigger((1, 'epoch')))
        trainer.extend(evaluator)

    fid_extension = FIDScore(
        autoencoder.module if isinstance(autoencoder, DDP) else autoencoder,
        val_data_loader if args.val_images is not None else train_data_loader,
        dataset_path=args.val_images if args.val_images is not None else args.images,
        trigger=(1, 'epoch'))
    trainer.extend(fid_extension)

    if rank == 0:
        snapshot_autoencoder = autoencoder if not isinstance(
            autoencoder, DDP) else autoencoder.module
        snapshotter = Snapshotter(
            {
                'autoencoder': snapshot_autoencoder,
                'encoder': snapshot_autoencoder.encoder,
                'decoder': snapshot_autoencoder.decoder,
                'optimizer': optimizer,
            },
            args.log_dir,
            trigger=get_trigger((config['snapshot_save_iter'], 'iteration')))
        trainer.extend(snapshotter)

        plot_images = []
        if args.val_images is not None:

            def fill_plot_images(data_loader):
                image_list = []
                num_images = 0
                for batch in data_loader:
                    for image in batch['input_image']:
                        image_list.append(image)
                        num_images += 1
                        if num_images > config['display_size']:
                            return image_list
                raise RuntimeError(
                    f"Could not gather enough plot images for display size {config['display_size']}."
                )

            plot_images = fill_plot_images(val_data_loader)
        else:
            for i in range(config['display_size']):
                if hasattr(train_data_loader.sampler, 'set_epoch'):
                    train_data_loader.sampler.set_epoch(i)
                plot_images.append(
                    next(iter(train_data_loader))['input_image'][0])
        image_plotter = ImagePlotter(plot_images, [autoencoder],
                                     args.log_dir,
                                     trigger=get_trigger(
                                         (config['image_save_iter'],
                                          'iteration')),
                                     plot_to_logger=True)
        trainer.extend(image_plotter)

    schedulers = {
        "encoder": CosineAnnealingLR(optimizer,
                                     config["max_iter"],
                                     eta_min=1e-8)
    }
    lr_scheduler = LRScheduler(schedulers,
                               trigger=get_trigger((1, 'iteration')))
    trainer.extend(lr_scheduler)

    trainer.extend(logger)

    synchronize()
    trainer.train()