# Shared imports for the examples below. Project-specific helpers such as
# StyleganAutoencoder, Latents, make_image, noise_normalize, interpolate,
# load_config, get_autoencoder, load_weights, build_data_loader,
# build_latent_and_noise_generator and DemoDataset are assumed to be
# importable from the surrounding repository.
import random
from pathlib import Path
from typing import List

import numpy
import torch
from PIL import Image
from torch import nn
from tqdm import tqdm, trange

Example #1
def render_with_shifted_noise(autoencoder: StyleganAutoencoder,
                              latents: Latents,
                              shifting_rounds: int) -> List[List[Image.Image]]:
    # A single random shift factor in [-2, 2) for one round, otherwise
    # evenly spaced shift factors across [-2, 2].
    if shifting_rounds == 1:
        shift_factor = torch.tensor([random.random() * 4 - 2])
    else:
        shift_factor = torch.tensor(numpy.linspace(-2, 2, num=shifting_rounds))

    def generate(latents: Latents) -> torch.Tensor:
        # Decode the latents to an image without tracking gradients.
        with torch.no_grad():
            generated, _ = autoencoder.decoder(
                [latents.latent],
                input_is_latent=autoencoder.is_wplus(latents),
                noise=latents.noise)
        return generated

    # Start every row with the image rendered from the unshifted noise maps.
    shifted_images = [[Image.fromarray(make_image(generate(latents)[0]))]
                      for _ in range(shifting_rounds)]

    for the_round in trange(shifting_rounds, leave=False):
        for i in range(len(latents.noise)):
            # Scale one noise map at a time, render, then restore the original.
            noise_copy = latents.noise[i].clone()
            latents.noise[i] = latents.noise[i] * shift_factor[the_round]
            generated_image = generate(latents)
            generated_image = Image.fromarray(make_image(generated_image[0]))
            shifted_images[the_round].append(generated_image)
            latents.noise[i] = noise_copy

    return shifted_images
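
A minimal usage sketch, assuming `autoencoder` and an encoded `latents` object already exist; the round count and file names are arbitrary choices:

rows = render_with_shifted_noise(autoencoder, latents, shifting_rounds=5)
# Each row starts with the unshifted image, followed by one image per noise map.
for round_index, row in enumerate(rows):
    for column, img in enumerate(row):
        img.save(f"shifted_r{round_index:02d}_c{column:02d}.png")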
Example #2
def make_interpolation_image(steps: int, device: torch.device,
                             autoencoder: nn.Module, is_w_plus: bool,
                             start_latent: torch.Tensor,
                             end_latent: torch.Tensor,
                             start_noises: List[torch.Tensor],
                             end_noises: List[torch.Tensor]) -> Image.Image:
    all_interpolation_images = []
    for interpolation_strategy in ['all', 'latent', 'noise']:
        interpolation_images = []

        # Render the untouched start image as the first tile of the row.
        start_image, _ = autoencoder.decoder(
            [start_latent.to(device)],
            input_is_latent=is_w_plus,
            noise=[n.to(device) for n in start_noises])
        interpolation_images.append(make_image(start_image.squeeze(0)))

        # step_fraction walks from 0.0 to 1.0 in `steps` equal increments.
        for i in trange(steps + 1):
            step_fraction = i / steps
            if interpolation_strategy in ['latent', 'all']:
                latent = interpolate(start_latent, end_latent, step_fraction)
            else:
                latent = start_latent
            latent = latent.to(device)

            if interpolation_strategy in ['noise', 'all']:
                noises = [
                    interpolate(start_noise, end_noise, step_fraction)
                    for start_noise, end_noise in zip(start_noises, end_noises)
                ]
            else:
                # Pure latent interpolation: sample fresh random noise maps
                # instead of reusing the encoded ones.
                noises = autoencoder.decoder.make_noise()
            noises = [noise.to(device) for noise in noises]

            image, _ = autoencoder.decoder([latent],
                                           input_is_latent=is_w_plus,
                                           noise=noises)
            image = make_image(image.squeeze(0))
            interpolation_images.append(image)

        # Render the untouched end image as the last tile of the row.
        end_image, _ = autoencoder.decoder(
            [end_latent.to(device)],
            input_is_latent=is_w_plus,
            noise=[n.to(device) for n in end_noises])
        interpolation_images.append(make_image(end_image.squeeze(0)))

        # Stitch all tiles of this strategy into one horizontal strip.
        all_images = numpy.concatenate(interpolation_images, axis=1)
        image = Image.fromarray(all_images)
        all_interpolation_images.append(image)

    # Stack the three strategy strips ('all', 'latent', 'noise') vertically.
    dest_image = Image.new("RGB", (all_interpolation_images[0].width,
                                   all_interpolation_images[0].height * 3))
    for i, image in enumerate(all_interpolation_images):
        dest_image.paste(image, (0, i * image.height))

    return dest_image
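
A hypothetical call, assuming `start` and `end` are Latents objects produced by `autoencoder.encode` beforehand and `config` matches the loaded checkpoint:

interpolation = make_interpolation_image(
    steps=10,
    device=torch.device('cuda'),
    autoencoder=autoencoder,
    is_w_plus=not config['w_only'],
    start_latent=start.latent,
    end_latent=end.latent,
    start_noises=start.noise,
    end_noises=end.noise)
interpolation.save('interpolation.png')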
Example #3
def main(args):
    root_dir = Path(args.autoencoder_checkpoint).parent.parent
    output_dir = root_dir / args.output_dir
    output_dir.mkdir(exist_ok=True, parents=True)

    config = load_config(args.autoencoder_checkpoint, None)
    config['batch_size'] = 1
    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder,
                               args.autoencoder_checkpoint,
                               key='autoencoder')

    input_image = Path(args.image)
    data_loader = build_data_loader(input_image,
                                    config,
                                    config['absolute'],
                                    shuffle_off=True,
                                    dataset_class=DemoDataset)

    # Fetch a single preprocessed batch and move it to the target device.
    image = next(iter(data_loader))
    image = {k: v.to(args.device) for k, v in image.items()}

    # Run the full encode/decode pass and convert the result to a PIL image.
    reconstructed = Image.fromarray(
        make_image(autoencoder(image['input_image'])[0].squeeze(0)))

    # Save into the output directory created above (the original built the
    # path from args.output_dir directly, bypassing root_dir).
    output_name = output_dir / (
        f"reconstructed_{input_image.stem}"
        f"_stylegan_{config['stylegan_variant']}"
        f"_{'w_only' if config['w_only'] else 'w_plus'}.png")
    reconstructed.save(output_name)
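
The script is presumably driven by argparse; this sketch reconstructs the expected arguments from the attribute accesses in main (flag spellings and defaults are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('autoencoder_checkpoint')
parser.add_argument('image')
parser.add_argument('--output-dir', dest='output_dir', default='reconstructions')
parser.add_argument('--device', default='cuda')
main(parser.parse_args())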
Example #4
def render_color_grid(autoencoder: StyleganAutoencoder, latents: Latents,
                      indices: List[int], grid_size: int,
                      bounds: List[int]) -> List[List[Image.Image]]:
    def generate(latents: Latents) -> torch.Tensor:
        # Decode the latents to an image without tracking gradients.
        with torch.no_grad():
            generated, _ = autoencoder.decoder(
                [latents.latent],
                input_is_latent=autoencoder.is_wplus(latents),
                noise=latents.noise)
        return generated

    assert len(indices) == 2, \
        "render_color_grid only supports rendering two noise indices at once"
    assert len(bounds) == 2, \
        "render_color_grid expects exactly a min and a max bound"

    # Build a grid_size x grid_size grid of shift factors between the bounds.
    shift_factor = numpy.linspace(bounds[0], bounds[1], num=grid_size)
    x_shifts, y_shifts = map(
        numpy.squeeze, numpy.meshgrid(shift_factor, shift_factor, sparse=True))

    # Remember the original noise maps so they can be restored afterwards.
    x_noise_map = latents.noise[indices[0]].clone()
    y_noise_map = latents.noise[indices[1]].clone()

    grid = []
    for y_shift in tqdm(y_shifts, leave=False):
        latents.noise[indices[1]] = y_noise_map.clone() * y_shift
        x_images = []
        for x_shift in tqdm(x_shifts, leave=False):
            latents.noise[indices[0]] = x_noise_map.clone() * x_shift
            generated_image = generate(latents)
            generated_image = Image.fromarray(make_image(generated_image[0]))
            x_images.append(generated_image)
        grid.append(x_images)

    # Restore the unshifted noise maps so the caller's latents stay intact
    # (the original left the last shifted maps in place).
    latents.noise[indices[0]] = x_noise_map
    latents.noise[indices[1]] = y_noise_map

    return grid
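
A short usage sketch, assuming an encoded `latents` object; the indices and bounds are illustrative:

grid = render_color_grid(autoencoder, latents,
                         indices=[0, 1], grid_size=5, bounds=[-2, 2])
# grid[y][x] holds the image with noise map 0 scaled by the x-th factor
# and noise map 1 scaled by the y-th factor.
grid[0][0].save('grid_top_left.png')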
Example #5
def main(args):
    checkpoint_path = Path(args.model_checkpoint)

    config = load_config(checkpoint_path, None)

    autoencoder = get_autoencoder(config).to(args.device)
    load_weights(autoencoder, checkpoint_path, key='autoencoder', strict=True)

    config['batch_size'] = 1
    if args.generate:
        # Sample random latents/noise instead of encoding real images.
        data_loader = build_latent_and_noise_generator(autoencoder, config)
    else:
        data_loader = build_data_loader(args.images,
                                        config,
                                        args.absolute,
                                        shuffle_off=True)

    noise_dest_dir = checkpoint_path.parent.parent / "noise_maps"
    noise_dest_dir.mkdir(parents=True, exist_ok=True)

    num_images = 0
    for idx, batch in enumerate(tqdm(data_loader, total=args.num_images)):
        batch = batch.to(args.device)

        if args.generate:
            latents = batch
            image_names = [Path(f"generate_{idx}.png")]
        else:
            with torch.no_grad():
                latents: Latents = autoencoder.encode(batch)

            image_names = [
                Path(
                    data_loader.dataset.image_data[idx * config['batch_size'] +
                                                   batch_idx])
                for batch_idx in range(len(batch))
            ]

        if args.shift_noise:
            # One row of PIL images per shifting round.
            noise_shifted_images = render_with_shifted_noise(
                autoencoder, latents, args.rounds)

        # Visualise each noise map, upscaled with nearest-neighbour sampling.
        images = []
        for noise_tensors in latents.noise:
            noise_images = make_image(noise_tensors,
                                      normalize_func=noise_normalize)
            images.append([
                Image.fromarray(im).resize(
                    (config['image_size'], config['image_size']),
                    Image.NEAREST) for im in noise_images
            ])

        for batch_idx, (image,
                        orig_file_name) in enumerate(zip(batch, image_names)):
            # One row for the input plus the noise maps; one extra row per
            # shifting round when noise shifting is enabled.
            canvas_width = (len(images) + 1) * config['image_size']
            canvas_height = (config['image_size'] * (args.rounds + 1)
                             if args.shift_noise else config['image_size'])
            full_image = Image.new('RGB', (canvas_width, canvas_height))
            if not args.generate:
                full_image.paste(Image.fromarray(make_image(image)), (0, 0))
            for i, noise_images in enumerate(images):
                full_image.paste(noise_images[batch_idx],
                                 ((i + 1) * config['image_size'], 0))

            if args.shift_noise:
                for i, shifted_images in enumerate(noise_shifted_images):
                    for j, shifted_image in enumerate(shifted_images):
                        full_image.paste(shifted_image,
                                         (j * config['image_size'],
                                          (i + 1) * config['image_size']))

            full_image.save(noise_dest_dir /
                            f"{orig_file_name.stem}_noise.png")

        num_images += len(image_names)
        # Stop once the requested number of images has been processed.
        if num_images >= args.num_images:
            break
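
As in Example #3, the entry point is presumably wired up with argparse; the flag spellings and defaults below are reconstructed from the attribute accesses in main and are assumptions:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('model_checkpoint')
parser.add_argument('images', nargs='?')
parser.add_argument('--absolute', action='store_true')
parser.add_argument('--generate', action='store_true')
parser.add_argument('--shift-noise', dest='shift_noise', action='store_true')
parser.add_argument('--rounds', type=int, default=5)
parser.add_argument('--num-images', dest='num_images', type=int, default=10)
parser.add_argument('--device', default='cuda')
main(parser.parse_args())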