Code Example #1
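These snippets are excerpts and omit their imports. A minimal shared set, inferred from the code on this page (project-specific names such as Latents, LatentPaths, StyleganAutoencoder, Projector, make_image, noise_loss, optimize_noise, run_image_reconstruction, clamp_and_unnormalize, and is_image come from the host project and are not importable from a public package):

import os
import random
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy
import torch
import torch.nn.functional as F
from PIL import Image
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm, trange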
def render_with_shifted_noise(autoencoder: StyleganAutoencoder,
                              latents: Latents,
                              shifting_rounds: int) -> List[List[Image.Image]]:
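    """For each round, scale each noise map in turn by that round's shift factor and render; every row starts with the unshifted image."""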
    if shifting_rounds == 1:
        shift_factor = torch.tensor([random.random() * 4 - 2])
    else:
        shift_factor = torch.tensor(numpy.linspace(-2, 2, num=shifting_rounds))

    def generate(latents: Latents) -> torch.Tensor:
        with torch.no_grad():
            generated, _ = autoencoder.decoder(
                [latents.latent],
                input_is_latent=autoencoder.is_wplus(latents),
                noise=latents.noise)
        return generated

    shifted_images = [[Image.fromarray(make_image(generate(latents)[0]))]
                      for _ in range(shifting_rounds)]

    for the_round in trange(shifting_rounds, leave=False):
        for i in range(len(latents.noise)):
            noise_copy = latents.noise[i].clone()
            latents.noise[i] = latents.noise[i] * shift_factor[the_round]
            generated_image = generate(latents)
            generated_image = Image.fromarray(make_image(generated_image[0]))
            shifted_images[the_round].append(generated_image)
            latents.noise[i] = noise_copy

    return shifted_images
Code Example #2
    def do_style_transfer(self, content_latent: Latents, style_latent: Latents, layer_id: int) -> Tuple[torch.Tensor, Optional[LatentPaths]]:
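        # Mix codes: the first layer_id style layers come from the content latent, the rest from the style latent.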
        latent = torch.cat([content_latent.latent[:, :layer_id, :], style_latent.latent[:, layer_id:, :]], dim=1).detach().clone()
        latent = latent.to(self.projector.device)
        # noise = content_latent.noise[:layer_id] + style_latent.noise[layer_id:]
        noise = [n.detach().clone().to(self.projector.device) for n in content_latent.noise]
        latent_and_noise = Latents(latent, noise)

        path = None
        if self.args.post_optimize:
            path, latent_and_noise = self.post_noise_optimize(content_latent, latent_and_noise)

        latent_and_noise = latent_and_noise.to(self.projector.device)
        return self.projector.generate(latent_and_noise)[0], path
Code Example #3
    def test_to(self, device):
        def check_device(latents, dev):
            assert dev in str(latents.latent.device)
            for noise in latents.noise:
                assert dev in str(noise.device)

        latent = torch.ones((1, 14, 512))
        noises = [torch.ones((1, 1, 4, 4)) for _ in range(7)]

        latents = Latents(latent, noises)
        check_device(latents, 'cpu')

        latents = latents.to(device)
        check_device(latents, device)
Code Example #4
def build_latent_and_noise_generator(autoencoder: StyleganAutoencoder,
                                     config: Dict) -> Iterable:
    torch.random.manual_seed(1)
    while True:
        latent_code = torch.randn(config['batch_size'], config['latent_size'])
        noise = autoencoder.decoder.make_noise()
        yield Latents(latent_code, noise)
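The generator above is infinite, so callers must bound it themselves. A minimal usage sketch (hypothetical: it assumes an already-loaded autoencoder and the two config keys read above):

import itertools

config = {'batch_size': 4, 'latent_size': 512}
for latents in itertools.islice(build_latent_and_noise_generator(autoencoder, config), 3):
    print(latents.latent.shape)  # torch.Size([4, 512])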
Code Example #5
    def post_noise_optimize(self, content_latent: Latents, transfer_latent: Latents) -> Tuple[LatentPaths, Latents]:
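        # Build a masked MSE loss between the rendered content and style images, then re-optimize the transfer latent's noise maps against it.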
        content_latent = content_latent.to(self.projector.device)
        transfer_latent = transfer_latent.to(self.projector.device)

        content_image = self.projector.generate(content_latent)[0].detach()
        style_image = self.projector.generate(transfer_latent)[0].detach()
        content_mask = clamp_and_unnormalize(content_image.clone().detach())
        loss_func = noise_loss(
            {"l_mse_1": 1, "l_mse_2": 1},
            content_image,
            style_image,
            (1 - content_mask).detach()
        )

        path, latent_and_noise = optimize_noise(self.args, self.projector, transfer_latent, content_image, loss_func)

        return path, latent_and_noise
Code Example #6
File: autoencoder.py, Project: milesgray/CALAE
    def encode(self, x: torch.Tensor) -> Latents:
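        # Gradients flow into the latent codes only while the latent encoder is being updated.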
        with torch.set_grad_enabled(self.update_latent):
            latent_codes = self.latent_encoder(x).latent

        if self.update_noise:
            noise_codes = self.noise_encoder(x).noise
        else:
            noise_codes = self.decoder.make_noise()

        return Latents(latent=latent_codes, noise=noise_codes)
Code Example #7
    def project(self, latents: Latents, images: torch.Tensor, optimizer: Optimizer, num_steps: int, loss_function: Callable, lr_scheduler: _LRScheduler = None) -> Tuple[LatentPaths, Latents]:
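        # Iteratively optimize the latents to reconstruct the target images, keeping the best result as measured by PSNR.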
        pbar = tqdm(range(num_steps), leave=False)
        latent_path = []
        noise_path = []

        best_latent = best_noise = best_psnr = None

        for i in pbar:
            img_gen, _ = self.generate(latents)

            batch, channel, height, width = img_gen.shape

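            # Average-pool the generated image down to 256px so the loss is always computed at a fixed resolution.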
            if height > 256:
                factor = height // 256

                img_gen = img_gen.reshape(
                    batch, channel, height // factor, factor, width // factor, factor
                )
                img_gen = img_gen.mean([3, 5])

            # n_loss = noise_regularize(noises)
            loss, loss_dict = loss_function(img_gen, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_dict['psnr'] = self.psnr(img_gen, images).item()
            loss_dict['lr'] = optimizer.param_groups[0]["lr"]

            if lr_scheduler is not None:
                lr_scheduler.step()

            self.log.append(loss_dict)

            if best_psnr is None or best_psnr < loss_dict['psnr']:
                best_psnr = loss_dict['psnr']
                best_latent = latents.latent.detach().clone().cpu()
                best_noise = [noise.detach().clone().cpu() for noise in latents.noise]

            if i % self.debug_step == 0:
                latent_path.append(latents.latent.detach().clone().cpu())
                noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])

            loss_description = "; ".join(f"{key}: {value:.6f}" for key, value in loss_dict.items())
            pbar.set_description(loss_description)

            loss_dict['iteration'] = i
            if self.abort_condition is not None and self.abort_condition(loss_dict):
                break

        latent_path.append(latents.latent.detach().clone().cpu())
        noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])

        return LatentPaths(latent_path, noise_path), Latents(best_latent, best_noise)
Code Example #8
    def test_getitem(self, batch_size):
        latent = torch.ones((batch_size, 14, 512))
        noises = [torch.ones((batch_size, 1, 4, 4)) for _ in range(7)]

        latents = Latents(latent, noises)

        for i in range(batch_size):
            sub_latent = latents[i]
            assert sub_latent.latent.shape == (1, 14, 512)
            for noise in sub_latent.noise:
                assert noise.shape == (1, 1, 4, 4)
Code Example #9
    def forward(self, x: torch.Tensor) -> Latents:
        noise_codes = []
        h = x

        for i in range(len(self.resnet_blocks)):
            h = self.resnet_blocks[i](h)
            noise_codes.append(self.to_noise[i](h))
            h = self.intermediate_resnet_blocks[i](h)
            if self.stylegan_variant == 2 and i < len(self.resnet_blocks) - 1:
                noise_codes.append(self.intermediate_to_noise[i](h))

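        # The encoder blocks run fine-to-coarse, so reverse to hand the noise maps to the decoder coarse-to-fine.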
        noise_codes.reverse()

        return Latents(None, noise_codes)
Code Example #10
def render_color_grid(autoencoder: StyleganAutoencoder, latents: Latents,
                      indices: List[int], grid_size: int,
                      bounds: List[int]) -> List[List[Image.Image]]:
    def generate(latents: Latents) -> torch.Tensor:
        with torch.no_grad():
            generated, _ = autoencoder.decoder(
                [latents.latent],
                input_is_latent=autoencoder.is_wplus(latents),
                noise=latents.noise)
        return generated

    assert len(indices) == 2, "render_color_grid only supports rendering exactly two indices at once!"
    assert len(bounds) == 2, "render_color_grid expects exactly two bounds (min and max)!"

    shift_factor = numpy.linspace(bounds[0], bounds[1], num=grid_size)
    x_shifts, y_shifts = map(
        numpy.squeeze, numpy.meshgrid(shift_factor, shift_factor, sparse=True))

    x_noise_map = latents.noise[indices[0]].clone()
    y_noise_map = latents.noise[indices[1]].clone()

    grid = []
    for y_shift in tqdm(y_shifts, leave=False):
        latents.noise[indices[1]] = y_noise_map.clone() * y_shift
        x_images = []
        for x_shift in tqdm(x_shifts, leave=False):
            latents.noise[indices[0]] = x_noise_map.clone() * x_shift
            generated_image = generate(latents)
            generated_image = Image.fromarray(make_image(generated_image[0]))
            x_images.append(generated_image)
        grid.append(x_images)

    # Restore the caller's noise maps, which the loops above scaled in place.
    latents.noise[indices[0]] = x_noise_map
    latents.noise[indices[1]] = y_noise_map
    return grid
Code Example #11
    def create_initial_latent_and_noise(self) -> Latents:
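        # Start from the generator's mean latent (or a random draw around it) plus the generator's default noise maps.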
        n_mean_latent = 10000
        latent_mean, latent_std = self.get_mean_latent(n_mean_latent)

        base_noises = self.generator.make_noise()
        noises = [noise.detach().clone() for noise in base_noises]

        if self.args.no_mean_latent:
            latent_in = torch.normal(0, latent_std.item(), size=(1, self.config['latent_size']),
                                     device=self.device)
        else:
            latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(1, 1)

        if self.args.w_plus:
            latent_in = latent_in.unsqueeze(1).repeat(1, self.generator.n_latent, 1)

        return Latents(latent_in, noises)
Code Example #12
    def forward(self, x: torch.Tensor) -> Latents:
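        # Pool the features after every block and intermediate block into latent codes, then stack them into a w+ code of shape (batch, n_latent, latent_size).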
        latent_codes = []
        h = x

        for i in range(len(self.resnet_blocks)):
            h = self.resnet_blocks[i](h)
            latent_codes.append(self.to_latent[i](F.adaptive_avg_pool2d(
                h, (1, 1))))
            h = self.intermediate_resnet_blocks[i](h)
            latent_codes.append(self.intermediate_to_latent[i](
                F.adaptive_avg_pool2d(h, (1, 1))))

        latent_codes.reverse()
        latent_codes = torch.stack(latent_codes, dim=1)
        latent_codes = latent_codes.squeeze(3).squeeze(3)

        return Latents(latent_codes, None)
Code Example #13
    def forward(self, x: torch.Tensor) -> Latents:
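        # Predict noise maps at every resolution, but only a single latent code from the coarsest feature map.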
        latent_codes = []
        noise_codes = []
        h = x

        for i in range(len(self.resnet_blocks)):
            h = self.resnet_blocks[i](h)
            noise_codes.append(self.to_noise[i](h))
            h = self.intermediate_resnet_blocks[i](h)
            if self.stylegan_variant == 2 and i < len(self.resnet_blocks) - 1:
                noise_codes.append(self.intermediate_to_noise[i](h))

        latent_codes.append(self.to_latent(F.adaptive_avg_pool2d(h, (1, 1))))

        latent_codes.reverse()
        latent_codes = latent_codes[0].squeeze(2).squeeze(2)

        noise_codes.reverse()

        return Latents(latent_codes, noise_codes)
Code Example #14
    def get_latents(self, content_path: Union[str, Path], style_path: Union[str, Path]) -> Tuple[Latents, Latents]:
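        # Latents come either from embedding an image on the fly or from a previously saved projection result.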
        if is_image(content_path):
            content_latents = self.embed_image(content_path, True)
        else:
            embedded_data = torch.load(content_path)
            content_latents = Latents(embedded_data['latent'], embedded_data['noise'])

        if is_image(style_path):
            style_latents = self.embed_image(style_path, False)
        else:
            embedded_data = torch.load(style_path)
            style_latents = Latents(embedded_data['latent'], embedded_data['noise'])

        for latents in [content_latents, style_latents]:
            if len(latents.latent.shape) < 3:
                latents.latent = latents.latent.unsqueeze(0)

        return content_latents.to(self.projector.device), style_latents.to(self.projector.device)
Code Example #15
    def forward(self, x: torch.Tensor) -> Latents:
        resulting_latents = super().forward(x)
        latent_code = resulting_latents.latent.sum(dim=1)
        return Latents(latent_code, resulting_latents.noise)
Code Example #16
def main(args):
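    # Project every image in args.files into the generator's latent space and save reconstructions, latent codes, and optional GIFs.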
    projector = Projector(args)

    transform = projector.get_transforms()

    imgs = []
    image_names = []

    for file_name in os.listdir(args.files):
        # registered_extensions() initializes PIL's plugins; Image.EXTENSION may still be empty before init().
        if os.path.splitext(file_name)[-1].lower() not in Image.registered_extensions():
            continue

        image_name = os.path.join(args.files, file_name)
        img = transform(Image.open(image_name).convert('RGB'))
        image_names.append(image_name)
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(args.device)

    n_mean_latent = 10000
    latent_mean, latent_std = projector.get_mean_latent(n_mean_latent)

    for idx in trange(0, len(imgs), args.batch_size):
        images = imgs[idx:idx + args.batch_size]

        base_noises = projector.generator.make_noise()
        base_noises = [
            noise.repeat(len(images), 1, 1, 1) for noise in base_noises
        ]

        noises = [noise.detach().clone() for noise in base_noises]

        if args.no_mean_latent:
            latent_in = torch.normal(0,
                                     latent_std.item(),
                                     size=(len(images),
                                           projector.config['latent_size']),
                                     device=args.device)
        else:
            latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
                len(images), 1)

        if args.w_plus:
            latent_in = latent_in.unsqueeze(1).repeat(
                1, projector.generator.n_latent, 1)

        # optimize latent vector
        paths, best_latent = run_image_reconstruction(
            args,
            projector,
            Latents(latent_in, noises),
            images,
            do_optimize_noise=args.optimize_noise)

        # result_file = {'noises': noises}

        img_gen, _ = projector.generator(
            [best_latent.latent.cuda()],
            input_is_latent=True,
            noise=[noise.cuda() for noise in best_latent.noise])

        img_ar = make_image(img_gen)

        destination_dir = Path(args.files) / 'projected' / args.destination
        destination_dir.mkdir(parents=True, exist_ok=True)

        path_per_image = paths.split()
        for i in range(len(images)):
            image_name = image_names[idx + i]
            image_latent = best_latent[i]
            result_file = {
                'noise': image_latent.noise,
                'latent': image_latent.latent,
            }
            image_base_name = os.path.splitext(os.path.basename(image_name))[0]
            img_name = image_base_name + '-project.png'
            pil_img = Image.fromarray(img_ar[i])
            pil_img.save(destination_dir / img_name)
            torch.save(result_file,
                       destination_dir / f'results_{image_base_name}.pth')
            if args.create_gif:
                projector.create_gif(path_per_image[i].to(args.device),
                                     image_base_name, destination_dir)
            projector.render_log(destination_dir, image_base_name)

        # cleanup
        del paths
        del best_latent
        torch.cuda.empty_cache()
        projector.reset()