Example No. 1
    def evaluate(self, verbose=True):
        """
        Evaluate the FID.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            fid (float): Metric value.
        """
        utils.unwrap_module(self.G).set_truncation(
            truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
        self.G.eval()
        features = []

        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
            progress.write('FID: Gathering statistics for fakes...', step=False)

        remaining = self.num_samples
        for i in range(0, self.num_samples, self.batch_size):

            latents, latent_labels = self.prior_generator(
                batch_size=min(self.batch_size, remaining))
            if latent_labels is not None and self.labels:
                # self.labels holds one label tensor per dataloader batch, so index by batch.
                latent_labels = self.labels[i // self.batch_size].to(self.device)
                length = min(len(latents), len(latent_labels))
                latents, latent_labels = latents[:length], latent_labels[:length]

            with torch.no_grad():
                fakes = self.G(latents, labels=latent_labels)
                batch_features = self.fid_model(fakes)
            batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
            features.append(batch_features.cpu())

            remaining -= len(latents)
            if verbose:
                progress.step()

        if verbose:
            progress.write('FID: Statistics for fakes gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        mu_fake = np.mean(features, axis=0)
        sigma_fake = np.cov(features, rowvar=False)

        # Frechet distance between the two Gaussians:
        # |mu_f - mu_r|^2 + Tr(sigma_f + sigma_r - 2 * sqrt(sigma_f . sigma_r))
        m = np.square(mu_fake - self.mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, self.sigma_real), disp=False)
        dist = m + np.trace(sigma_fake + self.sigma_real - 2 * s)
        return float(np.real(dist))
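
The closing lines compute the Fréchet distance between Gaussians fitted to the real and fake feature sets. A minimal standalone sketch of the same computation (NumPy/SciPy only; `frechet_distance`, `feats_real` and `feats_fake` are hypothetical names standing in for the feature arrays gathered above):

import numpy as np
import scipy.linalg

def frechet_distance(feats_real, feats_fake):
    # Fit a Gaussian (mean and covariance) to each feature set.
    mu_r, sigma_r = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
    mu_f, sigma_f = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)
    # FID = |mu_f - mu_r|^2 + Tr(sigma_f + sigma_r - 2 * sqrt(sigma_f . sigma_r))
    covmean, _ = scipy.linalg.sqrtm(sigma_f.dot(sigma_r), disp=False)
    dist = np.square(mu_f - mu_r).sum() + np.trace(sigma_f + sigma_r - 2 * covmean)
    return float(np.real(dist))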
Example No. 2
def project_images(G, images, name_prefix, args):

    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    if args.gpu and len(args.gpu) > 1:
        warnings.warn(
            'Multi GPU is not available for projection. ' + \
            'Using device {}'.format(device)
        )
    G = utils.unwrap_module(G).to(device)

    lpips_model = stylegan2.external_models.lpips.LPIPS_VGG16(
        pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    proj = stylegan2.project.Projector(G=G,
                                       dlatent_avg_samples=10000,
                                       dlatent_avg_label=args.label,
                                       dlatent_device=device,
                                       dlatent_batch_size=1024,
                                       lpips_model=lpips_model,
                                       lpips_size=256)

    for i in range(0, len(images), args.batch_size):
        target = images[i:i + args.batch_size]
        proj.start(target=target,
                   num_steps=args.num_steps + 1,
                   initial_learning_rate=args.initial_learning_rate,
                   initial_noise_factor=args.initial_noise_factor,
                   lr_rampdown_length=args.lr_rampdown_length,
                   lr_rampup_length=args.lr_rampup_length,
                   noise_ramp_length=args.noise_ramp_length,
                   regularize_noise_weight=args.regularize_noise_weight,
                   verbose=True,
                   verbose_prefix='Projecting image(s) {}/{}'.format(
                       i + len(target), len(images)))
        snapshot_steps = set(args.num_steps - np.linspace(
            0, args.num_steps, args.num_snapshots, endpoint=False, dtype=int))
        for k, image in enumerate(
                utils.tensor_to_PIL(target,
                                    pixel_min=args.pixel_min,
                                    pixel_max=args.pixel_max)):
            image.save(
                os.path.join(args.output, name_prefix[i + k] + 'target.png'))
        for j in range(args.num_steps):
            proj.step()
            # `snapshot_steps` counts completed steps, so compare against j + 1;
            # otherwise the final snapshot (at step num_steps) is never saved.
            if j + 1 in snapshot_steps:
                generated = utils.tensor_to_PIL(proj.generate(),
                                                pixel_min=args.pixel_min,
                                                pixel_max=args.pixel_max)
                dlatents = proj.get_dlatent()
                for k, image in enumerate(generated):
                    # Save each image's own dlatent rather than the whole batch.
                    torch.save(
                        dlatents[k],
                        os.path.join(args.output,
                                     name_prefix[i + k] + 'step%04d.pt' % (j + 1)))
                    image.save(
                        os.path.join(args.output,
                                     name_prefix[i + k] + 'step%04d.png' % (j + 1)))
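
A minimal usage sketch for this function, assuming a generator `G`, an image tensor `images` of shape (N, C, H, W) and a list `name_prefix` of N filename prefixes are already prepared; every value in the Namespace below is a hypothetical choice, not a documented default:

from argparse import Namespace

args = Namespace(
    gpu=[0], batch_size=4, num_steps=1000, num_snapshots=5,
    initial_learning_rate=0.1, initial_noise_factor=0.05,
    lr_rampdown_length=0.25, lr_rampup_length=0.05,
    noise_ramp_length=0.75, regularize_noise_weight=1e5,
    pixel_min=-1, pixel_max=1, label=None, output='projections')

project_images(G, images, name_prefix, args)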
Example No. 3
    def __init__(self,
                 G,
                 prior_generator,
                 device=None,
                 num_samples=50000,
                 epsilon=1e-4,
                 use_dlatent=True,
                 full_sampling=False,
                 crop=None,
                 lpips_model=None,
                 lpips_size=None):
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        assert torch.device(prior_generator.device) == device, \
            'Prior generator device ({}) '.format(torch.device(prior_generator.device)) + \
            'is not the same as the specified (or inferred from the model) ' + \
            'device ({}) for the PPL evaluation.'.format(device)
        G.eval().to(device)
        self.G_mapping = G.G_mapping
        self.G_synthesis = G.G_synthesis
        if device_ids:
            self.G_mapping = torch.nn.DataParallel(self.G_mapping,
                                                   device_ids=device_ids)
            self.G_synthesis = torch.nn.DataParallel(self.G_synthesis,
                                                     device_ids=device_ids)
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.epsilon = epsilon
        self.use_dlatent = use_dlatent
        self.full_sampling = full_sampling
        self.crop = crop
        self.batch_size = self.prior_generator.batch_size
        if lpips_model is None:
            warnings.warn(
                'Using the default LPIPS distance metric based on VGG 16. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1]; please specify an lpips module if you want ' + \
                'to use other kinds of data formats.'
            )
            lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
            if device_ids:
                lpips_model = torch.nn.DataParallel(lpips_model,
                                                    device_ids=device_ids)
            lpips_size = lpips_size or 256
        self.lpips_model = lpips_model.eval().to(device)
        self.lpips_size = lpips_size
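
A minimal construction sketch, assuming this `__init__` belongs to a `PPL` evaluator class (the assertion message suggests as much) and that a trained generator `G` plus a `utils.PriorGenerator` instance `prior` (created elsewhere, on the same device as `G`) are in scope:

ppl = PPL(G, prior, num_samples=50000, epsilon=1e-4,
          use_dlatent=True, full_sampling=False)
score = ppl.evaluate()  # see Example No. 4
print('PPL:', score)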
Example No. 4
    def evaluate(self, verbose=True):
        """
        Evaluate the PPL.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            ppl (float): Metric value.
        """
        distances = []
        batch_size = self.batch_size
        if self.full_sampling:
            batch_size = 2 * batch_size

        if verbose:
            progress = utils.ProgressWriter(
                np.ceil(self.num_samples / self.batch_size))
            progress.write('PPL: Evaluating metric...', step=False)

        for _ in range(0, self.num_samples, self.batch_size):
            utils.unwrap_module(self.G_synthesis).static_noise()

            latents, latent_labels = self.prior_generator(
                batch_size=batch_size)
            if latent_labels is not None and self.full_sampling:
                # Labels should be the same for the first and second half of latents
                latent_labels = latent_labels.view(2, -1)[0].repeat(2)

            if self.use_dlatent:
                with torch.no_grad():
                    dlatents = self.G_mapping(latents=latents,
                                              labels=latent_labels)
                dlatents = self.prep_latents(dlatents)
            else:
                latents = self.prep_latents(latents)
                with torch.no_grad():
                    dlatents = self.G_mapping(latents=latents,
                                              labels=latent_labels)

            dlatents = dlatents.unsqueeze(1).repeat(
                1, len(utils.unwrap_module(self.G_synthesis)), 1)

            with torch.no_grad():
                output = self.G_synthesis(dlatents)

            output = self.crop_data(output)
            output = self._scale_for_lpips(output)

            output_a, output_b = output[:self.batch_size], output[self.batch_size:]

            with torch.no_grad():
                dist = self.lpips_model(output_a, output_b)

            distances.append(dist.cpu() * (1 / self.epsilon**2))

            if verbose:
                progress.step()

        if verbose:
            progress.write('PPL: Evaluated!', step=False)
            progress.close()

        distances = torch.cat(distances, dim=0).numpy()
        # Drop the lowest and highest 1% of distances before averaging.
        # (NumPy >= 1.22 renames the `interpolation` argument to `method`.)
        lo = np.percentile(distances, 1, interpolation='lower')
        hi = np.percentile(distances, 99, interpolation='higher')
        filtered_distances = np.extract(
            np.logical_and(lo <= distances, distances <= hi), distances)
        return float(np.mean(filtered_distances))
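
The last few lines clip outliers before averaging. A standalone sketch of the same percentile filtering (NumPy only; `filtered_mean` is a hypothetical helper name):

import numpy as np

def filtered_mean(distances, lo_pct=1, hi_pct=99):
    # Average only the samples between the 1st and 99th percentiles.
    lo, hi = np.percentile(distances, [lo_pct, hi_pct])
    return float(distances[(distances >= lo) & (distances <= hi)].mean())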
Example No. 5
    def __init__(self,
                 G,
                 prior_generator,
                 dataset,
                 device=None,
                 num_samples=50000,
                 fid_model=None,
                 fid_size=None,
                 truncation_psi=None,
                 truncation_cutoff=None,
                 reals_batch_size=None,
                 reals_data_workers=0,
                 verbose=True):
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        assert torch.device(prior_generator.device) == device, \
            'Prior generator device ({}) '.format(torch.device(prior_generator.device)) + \
            'is not the same as the specified (or inferred from the model) ' + \
            'device ({}) for the FID evaluation.'.format(device)
        G.eval().to(device)
        if device_ids:
            G = torch.nn.DataParallel(G, device_ids=device_ids)
        self.G = G
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.batch_size = self.prior_generator.batch_size
        if fid_model is None:
            warnings.warn(
                'Using the default FID model based on Inception V3. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1]; please specify another module if you want ' + \
                'to use other kinds of data formats.'
            )
            fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
            if device_ids:
                fid_model = torch.nn.DataParallel(fid_model, device_ids)
        self.fid_model = fid_model.eval().to(device)
        self.fid_size = fid_size

        dataset = _TruncatedDataset(dataset, self.num_samples)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=reals_batch_size or self.batch_size,
            num_workers=reals_data_workers
        )
        features = []
        self.labels = []

        if verbose:
            progress = utils.ProgressWriter(
                np.ceil(self.num_samples / (reals_batch_size or self.batch_size)))
            progress.write('FID: Gathering statistics for reals...', step=False)

        for batch in dataloader:
            data = batch
            if isinstance(batch, (tuple, list)):
                data = batch[0]
                if len(batch) > 1:
                    self.labels.append(batch[1])
            data = self._scale_for_fid(data).to(self.device)
            with torch.no_grad():
                batch_features = self.fid_model(data)
            batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
            features.append(batch_features.cpu())
            if verbose:
                progress.step()

        if verbose:
            progress.write('FID: Statistics for reals gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        self.mu_real = np.mean(features, axis=0)
        self.sigma_real = np.cov(features, rowvar=False)
        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff
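
`_TruncatedDataset` is referenced but not shown in this snippet. A minimal sketch of such a wrapper (an assumption about its behavior, not the repository's exact implementation) that caps how many items the dataloader can draw:

import torch.utils.data

class _TruncatedDataset(torch.utils.data.Dataset):

    def __init__(self, dataset, max_len):
        self.dataset = dataset
        self.max_len = max_len

    def __len__(self):
        # Expose at most `max_len` items of the wrapped dataset.
        return min(len(self.dataset), self.max_len)

    def __getitem__(self, index):
        return self.dataset[index]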
Example No. 6
# SAVE_PER is used below but never defined in this snippet; 100 steps is an assumed interval.
SAVE_PER = 100

def project_images(G, images, name_prefix, args):

    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    if args.gpu and len(args.gpu) > 1:
        warnings.warn(
            'Multi GPU is not available for projection. ' + \
            'Using device {}'.format(device)
        )
    G = utils.unwrap_module(G).to(device)

    lpips_model = stylegan2.external_models.lpips.LPIPS_VGG16(
        pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    proj = stylegan2.project.Projector(
        G=G,
        dlatent_avg_samples=10000,
        dlatent_avg_label=args.label,
        dlatent_device=device,
        dlatent_batch_size=1024,
        lpips_model=lpips_model,
        lpips_size=256
    )

    for i in range(0, len(images), args.batch_size):
        target = images[i: i + args.batch_size]
        proj.start(
            target=target,
            num_steps=args.num_steps,
            initial_learning_rate=args.initial_learning_rate,
            initial_noise_factor=args.initial_noise_factor,
            lr_rampdown_length=args.lr_rampdown_length,
            lr_rampup_length=args.lr_rampup_length,
            noise_ramp_length=args.noise_ramp_length,
            regularize_noise_weight=args.regularize_noise_weight,
            verbose=True,
            verbose_prefix='Projecting image(s) {}/{}'.format(
                i + len(target), len(images)),
            noise_layers=5
        )
        # snapshot_steps = set(
        #     args.num_steps - np.linspace(
        #         0, args.num_steps, args.num_snapshots, endpoint=False, dtype=int))
        os.makedirs(os.path.join(args.output, 'target+BEST'), exist_ok=True)
        for k, image in enumerate(utils.tensor_to_PIL(
                target, pixel_min=args.pixel_min, pixel_max=args.pixel_max)):
            image.save(os.path.join(
                args.output, 'target+BEST', name_prefix[i + k] + 'target.png'))
        for j in range(args.num_steps):
            current_image, loss_dict_step, best_output = proj.step()
            # if j in snapshot_steps:
            #     generated = utils.tensor_to_PIL(
            #         proj.generate(), pixel_min=args.pixel_min, pixel_max=args.pixel_max)
            #     for k, image in enumerate(generated):
            #         image.save(os.path.join(
            #             args.output, name_prefix[i + k] + 'step%04d.png' % (j + 1)))
            if (j % SAVE_PER == 0 and j != 0) or j == args.num_steps - 1:
                current_image = utils.tensor_to_PIL(
                    current_image, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
                best_output = utils.tensor_to_PIL(
                    best_output, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
                for k, (image, best_image) in enumerate(zip(current_image, best_output)):
                    L2, GEOCROSS = loss_dict_step['L2'], loss_dict_step['GEOCROSS']
                    save_name = f'{name_prefix[i + k]}-loss-{L2:.2f}+GEOCROSS-{GEOCROSS:.2f}-{j}.png'
                    if j != args.num_steps - 1:
                        image.save(os.path.join(args.output, save_name))
                    else:
                        save_name = f'{name_prefix[i + k]}-BEST-{j}.png'
                        best_image.save(os.path.join(args.output, 'target+BEST', save_name))