Ejemplo n.º 1
0
    def evaluate(self, verbose=True):
        """
        Evaluate the FID.

        Generates fakes batch by batch with `self.G`, extracts features
        with `self.fid_model` and compares the mean and covariance of
        those features against the statistics of the reals
        (`self.mu_real`, `self.sigma_real`).
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            fid (float): Metric value.
        """
        utils.unwrap_module(self.G).set_truncation(
            truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
        self.G.eval()
        features = []

        # Only created when verbose; every later use is guarded so that
        # `verbose=False` does not hit an undefined name.
        progress = None
        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
            progress.write('FID: Gathering statistics for fakes...', step=False)

        remaining = self.num_samples
        batch_starts = range(0, self.num_samples, self.batch_size)
        for batch_idx, _ in enumerate(batch_starts):

            latents, latent_labels = self.prior_generator(
                batch_size=min(self.batch_size, remaining))
            # `self.labels` holds one entry per batch, so it has to be
            # indexed with the batch counter and not the sample offset.
            # When there are fewer stored label batches than fake batches
            # the labels drawn by the prior generator are kept.
            if latent_labels is not None and batch_idx < len(self.labels):
                latent_labels = self.labels[batch_idx].to(self.device)
                length = min(len(latents), len(latent_labels))
                latents, latent_labels = latents[:length], latent_labels[:length]

            with torch.no_grad():
                fakes = self.G(latents, labels=latent_labels)
                batch_features = self.fid_model(fakes)
            # Keep the first two dimensions and average over all others.
            batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
            features.append(batch_features.cpu())

            remaining -= len(latents)
            if verbose:
                progress.step()

        if verbose:
            progress.write('FID: Statistics for fakes gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        mu_fake = np.mean(features, axis=0)
        sigma_fake = np.cov(features, rowvar=False)

        # Squared distance between the means plus
        # trace(sigma_fake + sigma_real - 2 * sqrt(sigma_fake @ sigma_real)).
        m = np.square(mu_fake - self.mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, self.sigma_real), disp=False)
        dist = m + np.trace(sigma_fake + self.sigma_real - 2*s)
        # sqrtm may return a complex matrix; only the real part is reported.
        return float(np.real(dist))
Ejemplo n.º 2
0
def generate_images(G, args):
    """
    Generate one image per seed in `args.seeds` and save it to disk.

    Each image is written to `args.output` as 'seed%04d.png'. With at
    most one GPU the latent, the noise tensors and the label of an image
    are all drawn from a `RandomState` seeded with that image's seed.
    With several GPUs the generator is wrapped in `DataParallel` and
    its noise is randomized through `G.random_noise()` instead.
    Arguments:
        G: Generator network.
        args: Options; reads `gpu`, `truncation_psi`, `seeds`,
            `batch_size`, `pixel_min`, `pixel_max` and `output`.
    """
    latent_size, label_size = G.latent_size, G.label_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)

    multi_gpu = len(args.gpu) > 1
    if multi_gpu:
        warnings.warn(
            'Noise can not be randomized based on the seed '
            'when using more than 1 GPU device. Noise will '
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()

    def get_batch(seeds):
        """Build latents, labels and (single device only) noise for `seeds`."""
        seeded_noise = not multi_gpu
        latent_rows = []
        label_rows = []
        noise_columns = [[] for _ in noise_reference] if seeded_noise else None

        for seed in seeds:
            rng = np.random.RandomState(seed)
            # Draw order per seed: latent, then noise, then label.
            latent_rows.append(torch.from_numpy(rng.randn(latent_size)))
            if seeded_noise:
                for column, ref in zip(noise_columns, noise_reference):
                    column.append(torch.from_numpy(rng.randn(*ref.size()[1:])))
            if label_size:
                label_rows.append(torch.tensor([rng.randint(0, label_size)]))

        latents = torch.stack(latent_rows, dim=0).to(
            device=device, dtype=torch.float32)

        labels = None
        if label_rows:
            labels = torch.cat(label_rows, dim=0).to(
                device=device, dtype=torch.int64)

        noise_tensors = None
        if seeded_noise:
            noise_tensors = [
                torch.stack(column, dim=0).to(device=device, dtype=torch.float32)
                for column in noise_columns
            ]
        return latents, labels, noise_tensors

    progress = utils.ProgressWriter(len(args.seeds))
    progress.write('Generating images...', step=False)

    for start in range(0, len(args.seeds), args.batch_size):
        batch_seeds = args.seeds[start:start + args.batch_size]
        latents, labels, noise_tensors = get_batch(batch_seeds)
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            generated = G(latents, labels=labels)
        images = utils.tensor_to_PIL(
            generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        for seed, img in zip(batch_seeds, images):
            img.save(os.path.join(args.output, 'seed%04d.png' % seed))
            progress.step()

    progress.write('Done!', step=False)
    progress.close()
Ejemplo n.º 3
0
def style_mixing_example(G, args):
    """
    Generate style-mixed images for every (row seed, column seed) pair.

    The disentangled latent of each row seed has the layers listed in
    `args.style_layers` replaced by those of the column seed before it
    is passed to the synthesis network. Every image is saved to
    `args.output` as '<row>-<col>.png'; with `args.grid` set, a
    'grid.png' overview is written as well.
    Arguments:
        G: Generator network; its `G_mapping`, `G_synthesis` and
            `dlatent_avg` attributes are used directly.
        args: Options; reads `style_layers`, `gpu`, `row_seeds`,
            `col_seeds`, `truncation_psi`, `pixel_min`, `pixel_max`,
            `output` and `grid`.
    """
    assert max(args.style_layers) < len(G), \
        'Style layer indices can not be larger than ' + \
        'number of style layers ({}) of the generator.'.format(len(G))
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    if len(args.gpu) > 1:
        warnings.warn(
            'Multi GPU is not available for style mixing example. Using device {}'
            .format(device))
    G.to(device)
    G.static_noise()
    latent_size, label_size = G.latent_size, G.label_size
    G_mapping, G_synthesis = G.G_mapping, G.G_synthesis

    # Every seed is mapped only once, even if used as both row and column.
    all_seeds = list(set(args.row_seeds + args.col_seeds))
    seed_latents = [
        torch.from_numpy(np.random.RandomState(seed).randn(latent_size))
        for seed in all_seeds
    ]
    all_z = torch.stack(seed_latents).to(device=device, dtype=torch.float32)
    labels = None
    if label_size:
        labels = torch.zeros(len(all_z), dtype=torch.int64, device=device)

    print('Generating disentangled latents...')
    with torch.no_grad():
        all_w = G_mapping(latents=all_z, labels=labels)
    # One copy of the disentangled latent per synthesis layer.
    all_w = all_w.unsqueeze(1).repeat(1, len(G_synthesis), 1)

    w_avg = G.dlatent_avg

    if args.truncation_psi != 1:
        all_w = w_avg + args.truncation_psi * (all_w - w_avg)

    w_dict = dict(zip(all_seeds, all_w))

    progress = utils.ProgressWriter(len(all_w))
    progress.write('Generating images...', step=False)

    unmixed = []
    with torch.no_grad():
        for w in all_w:
            unmixed.append(G_synthesis(w.unsqueeze(0)))
            progress.step()

    progress.write('Done!', step=False)
    progress.close()

    all_images = torch.cat(unmixed, dim=0)

    # Unmixed images are stored under the key (seed, seed).
    image_dict = {}
    for seed, image in zip(all_seeds, all_images):
        image_dict[(seed, seed)] = image

    progress = utils.ProgressWriter(len(args.row_seeds) * len(args.col_seeds))
    progress.write('Generating style-mixed images...', step=False)

    for row_seed in args.row_seeds:
        for col_seed in args.col_seeds:
            mixed_w = w_dict[row_seed].clone()
            mixed_w[args.style_layers] = w_dict[col_seed][args.style_layers]
            with torch.no_grad():
                mixed_image = G_synthesis(mixed_w.unsqueeze(0)).squeeze(0)
            image_dict[(row_seed, col_seed)] = mixed_image
            progress.step()

    progress.write('Done!', step=False)
    progress.close()

    progress = utils.ProgressWriter(len(image_dict))
    progress.write('Saving images...', step=False)

    # Iterate over a snapshot because the values are replaced in the loop.
    for (row_seed, col_seed), tensor_image in list(image_dict.items()):
        pil_image = utils.tensor_to_PIL(
            tensor_image, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        image_dict[(row_seed, col_seed)] = pil_image
        filename = '%d-%d.png' % (row_seed, col_seed)
        pil_image.save(os.path.join(args.output, filename))
        progress.step()

    progress.write('Done!', step=False)
    progress.close()

    if not args.grid:
        return

    print('\n\nSaving style-mixed grid...')
    H, W = all_images.size()[2:]
    canvas_size = (W * (len(args.col_seeds) + 1), H * (len(args.row_seeds) + 1))
    canvas = Image.new('RGB', canvas_size, 'black')
    # The first row and first column show the unmixed images; the
    # top-left corner cell stays empty.
    for row_idx, row_seed in enumerate([None] + args.row_seeds):
        for col_idx, col_seed in enumerate([None] + args.col_seeds):
            if row_seed is None and col_seed is None:
                continue
            key = (
                col_seed if row_seed is None else row_seed,
                row_seed if col_seed is None else col_seed,
            )
            canvas.paste(image_dict[key], (W * col_idx, H * row_idx))
    canvas.save(os.path.join(args.output, 'grid.png'))
    print('Done!')
Ejemplo n.º 4
0
    def evaluate(self, verbose=True):
        """
        Evaluate the PPL.

        For each batch, latents are drawn from the prior generator,
        mapped to disentangled latents and synthesized. The output is
        split into two halves whose LPIPS distance is scaled by
        1 / epsilon^2. Distances outside the 1st to 99th percentile
        range are discarded before averaging.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            ppl (float): Metric value.
        """
        # With full sampling twice as many latents are requested per step.
        if self.full_sampling:
            samples_per_step = 2 * self.batch_size
        else:
            samples_per_step = self.batch_size

        if verbose:
            progress = utils.ProgressWriter(
                np.ceil(self.num_samples / self.batch_size))
            progress.write('PPL: Evaluating metric...', step=False)

        scaled_distances = []
        for _ in range(0, self.num_samples, self.batch_size):
            utils.unwrap_module(self.G_synthesis).static_noise()

            latents, latent_labels = self.prior_generator(
                batch_size=samples_per_step)
            if self.full_sampling and latent_labels is not None:
                # Labels should be the same for the first and second half of latents
                first_half = latent_labels.view(2, -1)[0]
                latent_labels = first_half.repeat(2)

            # `prep_latents` is applied after the mapping network when
            # `use_dlatent` is set and before it otherwise.
            if not self.use_dlatent:
                latents = self.prep_latents(latents)
            with torch.no_grad():
                dlatents = self.G_mapping(latents=latents,
                                          labels=latent_labels)
            if self.use_dlatent:
                dlatents = self.prep_latents(dlatents)

            num_layers = len(utils.unwrap_module(self.G_synthesis))
            dlatents = dlatents.unsqueeze(1).repeat(1, num_layers, 1)

            with torch.no_grad():
                output = self.G_synthesis(dlatents)

            output = self._scale_for_lpips(self.crop_data(output))

            output_a = output[:self.batch_size]
            output_b = output[self.batch_size:]

            with torch.no_grad():
                dist = self.lpips_model(output_a, output_b)

            scaled_distances.append(dist.cpu() * (1 / self.epsilon**2))

            if verbose:
                progress.step()

        if verbose:
            progress.write('PPL: Evaluated!', step=False)
            progress.close()

        distances = torch.cat(scaled_distances, dim=0).numpy()
        lo = np.percentile(distances, 1, interpolation='lower')
        hi = np.percentile(distances, 99, interpolation='higher')
        inside = np.logical_and(lo <= distances, distances <= hi)
        return float(np.mean(np.extract(inside, distances)))
Ejemplo n.º 5
0
    def __init__(self,
                 G,
                 prior_generator,
                 dataset,
                 device=None,
                 num_samples=50000,
                 fid_model=None,
                 fid_size=None,
                 truncation_psi=None,
                 truncation_cutoff=None,
                 reals_batch_size=None,
                 reals_data_workers=0,
                 verbose=True):
        """
        Set up the FID metric and gather feature statistics for the reals.

        The dataset is iterated once; the features that `fid_model`
        produces for it are reduced to a mean (`self.mu_real`) and a
        covariance (`self.sigma_real`). Labels found in the batches are
        kept in `self.labels`, one entry per batch.
        Arguments:
            G: Generator, optionally wrapped in `torch.nn.DataParallel`.
            prior_generator (utils.PriorGenerator): Generator of latents
                (and labels). Its `batch_size` is used for the fakes
                and its device must match `device`.
            dataset: Dataset of reals. Items may be the data itself or a
                tuple/list of (data, label).
            device (optional): Device to run on. Inferred from the
                parameters of `G` when None.
            num_samples (int): Number of samples. Default value is 50000.
            fid_model (torch.nn.Module, optional): Feature extractor.
                Default is an Inception V3 feature extractor that
                expects values in the range [-1, 1].
            fid_size (optional): Stored as `self.fid_size`.
                NOTE(review): presumably consumed by `_scale_for_fid`,
                which is not visible here -- confirm.
            truncation_psi (optional): Stored as `self.truncation_psi`.
            truncation_cutoff (optional): Stored as
                `self.truncation_cutoff`.
            reals_batch_size (int, optional): Batch size for the reals.
                Falls back to the batch size of `prior_generator`.
            reals_data_workers (int): Number of dataloader workers.
                Default value is 0.
            verbose (bool): Write progress to stdout.
                Default value is True.
        """
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        # The message has to be built from `prior_generator.device`;
        # passing the generator object itself to `torch.device` raises a
        # TypeError that hides the assertion.
        assert torch.device(prior_generator.device) == device, \
            'Prior generator device ({}) '.format(torch.device(prior_generator.device)) + \
            'is not the same as the specified (or inferred from the model) ' + \
            'device ({}) for the FID evaluation.'.format(device)
        G.eval().to(device)
        # Re-wrap the generator if it arrived wrapped in DataParallel.
        if device_ids:
            G = torch.nn.DataParallel(G, device_ids=device_ids)
        self.G = G
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.batch_size = self.prior_generator.batch_size
        if fid_model is None:
            warnings.warn(
                'Using default fid model metric based on Inception V3. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1], please specify another module if you want ' + \
                'to use other kinds of data formats.'
            )
            fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
            if device_ids:
                fid_model = torch.nn.DataParallel(fid_model, device_ids)
        self.fid_model = fid_model.eval().to(device)
        self.fid_size = fid_size

        reals_batch_size = reals_batch_size or self.batch_size
        dataset = _TruncatedDataset(dataset, self.num_samples)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=reals_batch_size,
            num_workers=reals_data_workers
        )
        features = []
        self.labels = []

        # Only created when verbose; every later use is guarded so that
        # `verbose=False` does not hit an undefined name.
        progress = None
        if verbose:
            progress = utils.ProgressWriter(
                np.ceil(self.num_samples / reals_batch_size))
            progress.write('FID: Gathering statistics for reals...', step=False)

        for batch in dataloader:
            data = batch
            if isinstance(batch, (tuple, list)):
                data = batch[0]
                if len(batch) > 1:
                    self.labels.append(batch[1])
            data = self._scale_for_fid(data).to(self.device)
            with torch.no_grad():
                batch_features = self.fid_model(data)
            # Keep the first two dimensions and average over all others.
            batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
            features.append(batch_features.cpu())
            if verbose:
                progress.step()

        if verbose:
            progress.write('FID: Statistics for reals gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        self.mu_real = np.mean(features, axis=0)
        self.sigma_real = np.cov(features, rowvar=False)
        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff
Ejemplo n.º 6
0
def run_labeling(G, C, tags, args):
    """
    Generate images, label them with a classifier and pickle the results.

    Runs `args.iter` iterations of `args.batch_size` latents each. The
    generated images are clamped to [0, 1], resized to 299x299 and passed
    to the classifier `C`. The input latents, the disentangled latents
    and the transformed labels of all iterations are written as one
    tuple to 'result.pkl' in `args.output`.
    Arguments:
        G: Generator network; must support `return_dlatents=True`.
        C: Classifier exposing `get_inputs()` and `run(...)`.
            NOTE(review): looks like an onnxruntime session -- confirm.
        tags: Passed to `transform_labels` together with the first
            output of the classifier.
        args: Options; reads `gpu`, `truncation_psi`, `seed`, `iter`,
            `batch_size` and `output`.
    """
    latent_size = G.latent_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)

    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)

    if len(args.gpu) > 1:
        warnings.warn(
            'Noise can not be randomized based on the seed ' + \
            'when using more than 1 GPU device. Noise will ' + \
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()

    rnd = np.random.RandomState(args.seed)

    if len(args.gpu) <= 1:
        # One seeded noise tensor per noise reference, with a leading
        # dimension of size 1, shared by all generated images.
        noise_tensors = [
            torch.stack(
                [torch.from_numpy(rnd.randn(*ref.size()[1:]))], dim=0
            ).to(device=device, dtype=torch.float32)
            for ref in noise_reference
        ]
        G.static_noise(noise_tensors=noise_tensors)

    progress = utils.ProgressWriter(args.iter)
    progress.write('Generating images...', step=False)

    qlatents_data = []
    dlatents_data = []
    labels_data = []
    for _ in range(args.iter):
        qlatents = torch.from_numpy(
            rnd.randn(args.batch_size, latent_size)
        ).to(device=device, dtype=torch.float32)
        with torch.no_grad():
            generated, dlatents = G(latents=qlatents, return_dlatents=True)
            images = generated.clamp_(min=0, max=1)
            # 299 is the input size of the model
            images = F.interpolate(images, size=(299, 299), mode='bilinear')
            ort_inputs = {C.get_inputs()[0].name: images.cpu().numpy()}
            predicted_labels = C.run(None, ort_inputs)
            # transform labels to dict
            labels = transform_labels(tags, predicted_labels[0])

        # Extend in place; rebuilding the lists with `+` on every
        # iteration copies all previous results each time.
        qlatents_data.extend(qlatents.detach().cpu().numpy().tolist())
        dlatents_data.extend(dlatents.detach().cpu().numpy().tolist())
        labels_data.extend(labels)

        progress.step()

    out_path = os.path.join(args.output, 'result.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump((qlatents_data, dlatents_data, labels_data), f)

    progress.write('Done!', step=False)
    progress.close()
def interpolate(G, args):
    """
    Generate linear latent interpolations and write them as images/video.

    For every seed in `args.seeds` the latent of `seed` is interpolated
    linearly towards the latent of `seed + 1` in
    `args.interpolation_step` steps. Each frame is saved to
    `args.output` as '<seed>_<index>.png' and appended to the video
    `args.animation_filename`.
    Arguments:
        G: Generator network.
        args: Options; reads `gpu`, `truncation_psi`, `seeds`,
            `interpolation_step`, `output`, `animation_filename`,
            `animation_fps`, `animation_frame_size`, `pixel_min` and
            `pixel_max`.
    """
    latent_size = G.latent_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)
    if len(args.gpu) > 1:
        warnings.warn(
            'Noise can not be randomized based on the seed ' + \
            'when using more than 1 GPU device. Noise will ' + \
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()

    def gen_latent(seed):
        """Return the latent drawn from a RandomState seeded with `seed`."""
        return torch.from_numpy(np.random.RandomState(seed).randn(latent_size))

    def interpolate_generator(seed, step):
        """Yield `step` pairs of (latent, noise tensors) for `seed`."""
        if len(args.gpu) <= 1:
            # Every noise reference gets its own RandomState seeded with
            # the same seed; the noise is shared by all frames of a seed.
            noise_tensors = [
                torch.stack(
                    [torch.from_numpy(
                        np.random.RandomState(seed).randn(*ref.size()[1:]))],
                    dim=0,
                ).to(device=device, dtype=torch.float32)
                for ref in noise_reference
            ]
        else:
            noise_tensors = None

        latent1 = gen_latent(seed)
        latent2 = gen_latent(seed + 1)
        if step > 1:
            d_latents = (latent2 - latent1) / float(step - 1)
        else:
            # A single step has nothing to interpolate towards; dividing
            # by `step - 1 == 0` would fill the latent with NaN.
            d_latents = torch.zeros_like(latent1)
        for i in range(step):
            yield latent1 + i * d_latents, noise_tensors

    progress = utils.ProgressWriter(len(args.seeds) * args.interpolation_step)
    progress.write('Generating images...', step=False)

    video_writer = None
    if args.interpolation_step:
        fourcc_ = cv2.VideoWriter_fourcc(*'avc1')
        video_writer = cv2.VideoWriter(
            filename=os.path.join(args.output, args.animation_filename),
            fourcc=fourcc_,
            fps=args.animation_fps,
            apiPreference=cv2.CAP_ANY,
            frameSize=args.animation_frame_size)

    try:
        for seed in args.seeds:
            for i, (latent, noise_tensors) in enumerate(
                    interpolate_generator(seed, args.interpolation_step)):
                latents = torch.stack([latent], dim=0).to(
                    device=device, dtype=torch.float32)

                if noise_tensors is not None:
                    G.static_noise(noise_tensors=noise_tensors)

                with torch.no_grad():
                    generated = G(latents, labels=None)
                images = utils.tensor_to_PIL(generated,
                                             pixel_min=args.pixel_min,
                                             pixel_max=args.pixel_max)
                for img in images:
                    img.save(os.path.join(args.output, f'{seed}_{i}.png'))
                    if video_writer is not None:
                        # Convert RGB to BGR
                        frame = np.array(img)[:, :, ::-1].copy()
                        video_writer.write(frame)

                    progress.step()
    finally:
        # Release the writer so the video file is closed even if
        # generation fails midway.
        if video_writer is not None:
            video_writer.release()

    progress.write('Done!', step=False)
    progress.close()