def get_style_attribute_pairs():  # uses Horovod to parallelize the extraction across GPUs (~n_gpu times faster)
    import horovod.torch as hvd
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    torch.manual_seed(hvd.rank() * 999 + 1)  # distinct seed per rank so each worker samples different latents
    if hvd.rank() == 0:
        print(' * Extracting style-attribute pairs...')
    # build and load the pre-trained attribute predictor on CelebA-HQ
    predictor = model.get_pretrained('attribute-predictor').to(device)
    # build and load the pre-trained anycost generator
    generator = model.get_pretrained('generator', config).to(device)

    predictor.eval()
    generator.eval()

    # randomly generate images and feed them to the predictor
    # configs from https://github.com/genforce/interfacegan
    randomized_noise = False
    truncation_psi = 0.7
    batch_size = 16
    n_batch = 500000 // (batch_size * hvd.size())  # 500k samples in total, split across workers

    styles = []
    attributes = []

    mean_style = generator.mean_style(100000).view(1, 1, -1)  # mean style vector, used for truncation
    assert space in ['w', 'w+', 'z']  # 'space' is a module-level setting selecting the latent space
    for _ in tqdm(range(n_batch), disable=hvd.rank() != 0):
        if space in ['w', 'z']:
            z = torch.randn(batch_size, 1, generator.style_dim, device=device)
        else:
            z = torch.randn(batch_size, generator.n_style, generator.style_dim, device=device)
        images, w = generator(z,
                              return_styles=True,
                              truncation=truncation_psi,
                              truncation_style=mean_style,
                              input_is_style=False,
                              randomize_noise=randomized_noise)
        images = F.interpolate(images.clamp(-1, 1), size=256, mode='bilinear', align_corners=True)
        attr = predictor(images)
        # move to cpu to save memory
        if space == 'w+':
            styles.append(w.to('cpu'))
        elif space == 'w':
            styles.append(w.mean(1, keepdim=True).to('cpu'))  # w is repeated along the style axis; keep a single copy
        else:
            styles.append(z.to('cpu'))
        attributes.append(attr.to('cpu'))

    styles = torch.cat(styles, dim=0)
    attributes = torch.cat(attributes, dim=0)

    styles = hvd.allgather(styles, name='styles')
    attributes = hvd.allgather(attributes, name='attributes')
    if hvd.rank() == 0:
        print(styles.shape, attributes.shape)
        torch.save(attributes, 'attributes_{}.pt'.format(config))
        torch.save(styles, 'styles_{}.pt'.format(config))
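
# A minimal launch sketch (assumptions: the function above lives in a script such
# as extract_pairs.py, and `config`/`space` are set at module level; the script
# name is illustrative, not from the original listing):
#
#   horovodrun -np 4 python extract_pairs.py
#
# Rank 0 saves the gathered pairs, which can later be reloaded with:
#
#   styles = torch.load('styles_{}.pt'.format(config))          # latent codes
#   attributes = torch.load('attributes_{}.pt'.format(config))  # 40 attributes x 2 logits each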
Example #2
    def load_assets(self):
        self.anycost_channel = 1.0
        self.anycost_resolution = 1024

        # build the generator
        self.generator = model.get_pretrained('generator', config).to(device)
        self.generator.eval()
        self.mean_latent = self.generator.mean_style(10000)

        # select only a subset of the directions to use
        '''
        possible keys:
        ['00_5_o_Clock_Shadow', '01_Arched_Eyebrows', '02_Attractive', '03_Bags_Under_Eyes', '04_Bald', '05_Bangs',
         '06_Big_Lips', '07_Big_Nose', '08_Black_Hair', '09_Blond_Hair', '10_Blurry', '11_Brown_Hair', '12_Bushy_Eyebrows',
         '13_Chubby', '14_Double_Chin', '15_Eyeglasses', '16_Goatee', '17_Gray_Hair', '18_Heavy_Makeup', '19_High_Cheekbones',
         '20_Male', '21_Mouth_Slightly_Open', '22_Mustache', '23_Narrow_Eyes', '24_No_Beard', '25_Oval_Face', '26_Pale_Skin',
         '27_Pointy_Nose', '28_Receding_Hairline', '29_Rosy_Cheeks', '30_Sideburns', '31_Smiling', '32_Straight_Hair',
         '33_Wavy_Hair', '34_Wearing_Earrings', '35_Wearing_Hat', '36_Wearing_Lipstick', '37_Wearing_Necklace',
         '38_Wearing_Necktie', '39_Young']
        '''

        direction_map = {
            'smiling': '31_Smiling',
            'young': '39_Young',
            'wavy hair': '33_Wavy_Hair',
            'gray hair': '17_Gray_Hair',
            'blonde hair': '09_Blond_Hair',
            'eyeglass': '15_Eyeglasses',
            'mustache': '22_Mustache',
        }

        boundaries = model.get_pretrained('boundary', config)
        self.direction_dict = dict()
        for k, v in direction_map.items():
            self.direction_dict[k] = boundaries[v].view(1, 1, -1)

        # prepare the latent codes and original images
        file_names = sorted(os.listdir(os.path.join(assets_dir, 'input_images')))
        self.file_names = [f for f in file_names if f.endswith('.png') or f.endswith('.jpg')]
        self.latent_code_list = []
        self.org_image_list = []

        for fname in self.file_names:
            org_image = np.asarray(Image.open(os.path.join(assets_dir, 'input_images', fname)).convert('RGB'))
            latent_code = torch.from_numpy(
                np.load(os.path.join(assets_dir, 'projected_latents',
                                     fname.replace('.jpg', '.npy').replace('.png', '.npy'))))
            self.org_image_list.append(org_image)
            self.latent_code_list.append(latent_code.view(1, -1, 512))

        # set up the initial display
        self.sample_idx = 0
        self.org_latent_code = self.latent_code_list[self.sample_idx]

        # input kwargs for the generator
        self.input_kwargs = {'styles': self.org_latent_code, 'noise': None, 'randomize_noise': False,
                             'input_is_style': True}
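
    # A hypothetical editing helper (illustrative, not part of the original
    # listing): shift the current latent code along one of the stored boundary
    # directions and regenerate the image with the kwargs prepared above.
    def edit_image(self, direction_name, strength):
        styles = self.org_latent_code + strength * self.direction_dict[direction_name]
        kwargs = dict(self.input_kwargs, styles=styles)
        with torch.no_grad():
            image = self.generator(**kwargs)[0].clamp(-1, 1)
        return image
Example #3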
def compute_attribute_consistency(g, sub_g, n_sample, batch_size):
    attr_pred = model.get_pretrained('attribute-predictor').to(device)
    attr_pred.eval()

    n_batch = math.ceil(n_sample * 1. / batch_size / hvd.size())

    accs = 0
    mean_style = g.mean_style(10000)
    with torch.no_grad():
        for _ in tqdm(range(n_batch), disable=hvd.rank() != 0):
            noise = g.make_noise()

            latent = torch.randn(batch_size, 1, 512, device=device)
            kwargs = {
                'styles': latent,
                'truncation': 0.5,
                'truncation_style': mean_style,
                'noise': noise
            }
            img = g(**kwargs)[0].clamp(min=-1., max=1.)
            sub_img = sub_g(**kwargs)[0].clamp(min=-1., max=1.)
            img = adaptive_resize(img, 256)
            sub_img = adaptive_resize(sub_img, 256)

            attr = attr_pred(img).view(-1, 40, 2)
            sub_attr = attr_pred(sub_img).view(-1, 40, 2)

            attr = torch.argmax(attr, dim=2)
            sub_attr = torch.argmax(sub_attr, dim=2)
            this_acc = (attr == sub_attr).float().mean(0)
            accs = accs + this_acc / n_batch
    accs = hvd.allreduce(accs)  # averages across workers by default
    return accs
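
# Usage sketch (assumption: sub_g is a reduced-cost copy of the full generator,
# configured with the same helper used in the examples below):
#
#   import copy
#   from model.dynamic_channel import set_uniform_channel_ratio
#   g = model.get_pretrained('generator', config).to(device)
#   sub_g = copy.deepcopy(g)
#   set_uniform_channel_ratio(sub_g, 0.5)
#   acc = compute_attribute_consistency(g, sub_g, n_sample=10000, batch_size=16)
#   # acc is a length-40 tensor: per-attribute agreement between g and sub_g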
Example #4
    parser.add_argument("--config",
                        type=str,
                        help='config name of the pretrained generator')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--inception', type=str, default=None, required=True)

    parser.add_argument('--channel_ratio', type=float, default=None)
    parser.add_argument('--target_res', type=int, default=None)

    args = parser.parse_args()

    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    generator = model.get_pretrained('generator', args.config).to(device)
    generator.eval()

    # set sub-generator
    if args.channel_ratio:
        from model.dynamic_channel import set_uniform_channel_ratio, CHANNEL_CONFIGS
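        # assumption: CHANNEL_CONFIGS enumerates the uniform width ratios the
        # generator was trained to support (0.25/0.5/0.75/1.0 in the anycost-gan repo)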

        assert args.channel_ratio in CHANNEL_CONFIGS
        set_uniform_channel_ratio(generator, args.channel_ratio)

    if args.target_res is not None:
        generator.target_res = args.target_res

    # compute the flops of the generator (if possible)
    if hvd.rank() == 0:
        try:
Example #5
    parser.add_argument("--eps",
                        type=float,
                        default=1e-4,
                        help="epsilon for numerical stability")
    parser.add_argument("--crop",
                        action="store_true",
                        help="apply center crop to the images")
    parser.add_argument("--sampling",
                        default="end",
                        choices=["end", "full"],
                        help="set endpoint sampling method")

    args = parser.parse_args()

    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    generator = model.get_pretrained('generator', args.config).to(device)
    generator.eval()

    if args.channel_ratio:
        from model.dynamic_channel import set_uniform_channel_ratio
        set_uniform_channel_ratio(generator, args.channel_ratio)

    if args.target_res is not None:
        generator.target_res = args.target_res

    ppl = compute_ppl(generator,
                      n_sample=args.n_sample,
                      batch_size=args.batch_size,
                      space=args.space,
                      sampling=args.sampling,
                      eps=args.eps,
                      crop=args.crop)
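
    # A plausible tail for the truncated listing (assumption): report the result
    # on rank 0, mirroring the rank-0 printing pattern in the other examples.
    if hvd.rank() == 0:
        print('ppl:', ppl)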