def get_style_attribute_pairs():
    """Sample random styles, predict their attributes, and save both tensors.

    The work is split across Horovod ranks: each rank generates
    ``500000 // (batch_size * hvd.size())`` batches of 16 samples, the
    per-rank results are gathered with ``hvd.allgather``, and rank 0 writes
    ``attributes_{config}.pt`` and ``styles_{config}.pt`` to the current
    working directory.

    Reads the module-level names ``config``, ``space`` and ``device``.
    ``space`` selects what is stored as the "style" of each sample:

    * ``'z'``  -- the sampled input noise ``z``
    * ``'w'``  -- the returned styles averaged over dim 1 (keepdim)
    * ``'w+'`` -- the returned styles as-is

    Raises:
        AssertionError: if ``space`` is not one of ``'w'``, ``'w+'``, ``'z'``.
    """
    # this function is written with horovod to accelerate the extraction (by n_gpu times)
    import horovod.torch as hvd

    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    # the seed depends on the rank, so ranks do not share one random stream
    torch.manual_seed(hvd.rank() * 999 + 1)
    if hvd.rank() == 0:
        print(' * Extracting style-attribute pairs...')

    # build and load the pre-trained attribute predictor on CelebA-HQ
    predictor = model.get_pretrained('attribute-predictor').to(device)
    # build and load the pre-trained anycost generator
    generator = model.get_pretrained('generator', config).to(device)

    predictor.eval()
    generator.eval()

    # randomly generate images and feed them to the predictor
    # configs from https://github.com/genforce/interfacegan
    randomized_noise = False
    truncation_psi = 0.7
    batch_size = 16
    n_batch = 500000 // (batch_size * hvd.size())

    styles = []
    attributes = []

    # Pure inference: without no_grad every tensor appended below would keep
    # its autograd graph alive for the whole run.
    with torch.no_grad():
        mean_style = generator.mean_style(100000).view(1, 1, -1)
        assert space in ['w', 'w+', 'z']
        for _ in tqdm(range(n_batch), disable=hvd.rank() != 0):
            if space in ['w', 'z']:
                # one latent per sample
                z = torch.randn(batch_size, 1, generator.style_dim, device=device)
            else:
                # one latent per style layer
                z = torch.randn(batch_size, generator.n_style, generator.style_dim, device=device)
            images, w = generator(z,
                                  return_styles=True,
                                  truncation=truncation_psi,
                                  truncation_style=mean_style,
                                  input_is_style=False,
                                  randomize_noise=randomized_noise)
            # the predictor is fed 256-pixel images clamped to [-1, 1]
            images = F.interpolate(images.clamp(-1, 1), size=256, mode='bilinear', align_corners=True)
            attr = predictor(images)

            # move to cpu to save memory
            if space == 'w+':
                styles.append(w.to('cpu'))
            elif space == 'w':
                styles.append(w.mean(1, keepdim=True).to('cpu'))  # originally duplicated
            else:
                styles.append(z.to('cpu'))
            attributes.append(attr.to('cpu'))

    styles = torch.cat(styles, dim=0)
    attributes = torch.cat(attributes, dim=0)

    # collect the per-rank results on every rank; only rank 0 writes them out
    styles = hvd.allgather(styles, name='styles')
    attributes = hvd.allgather(attributes, name='attributes')
    if hvd.rank() == 0:
        print(styles.shape, attributes.shape)
        torch.save(attributes, 'attributes_{}.pt'.format(config))
        torch.save(styles, 'styles_{}.pt'.format(config))
def load_assets(self):
    """Load the generator, the editing directions, and the input samples.

    Reads the module-level names ``config``, ``device`` and ``assets_dir``
    and populates the following attributes on ``self``:

    * ``anycost_channel`` / ``anycost_resolution`` -- initial settings
      (1.0 and 1024).
    * ``generator`` -- pretrained generator in eval mode.
    * ``mean_latent`` -- result of ``generator.mean_style(10000)``.
    * ``direction_dict`` -- display name -> boundary viewed as (1, 1, -1).
    * ``file_names`` -- sorted ``.png`` / ``.jpg`` names found in
      ``<assets_dir>/input_images``.
    * ``org_image_list`` -- RGB numpy array for each file.
    * ``latent_code_list`` -- latent loaded from
      ``<assets_dir>/projected_latents/<name>.npy`` for each file, viewed
      as (1, -1, 512).
    * ``sample_idx``, ``org_latent_code``, ``input_kwargs`` -- initial
      display state, pointing at the first sample.
    """
    self.anycost_channel = 1.0
    self.anycost_resolution = 1024

    # build the generator
    self.generator = model.get_pretrained('generator', config).to(device)
    self.generator.eval()
    self.mean_latent = self.generator.mean_style(10000)

    # select only a subset of the directions to use
    # possible keys:
    #   ['00_5_o_Clock_Shadow', '01_Arched_Eyebrows', '02_Attractive', '03_Bags_Under_Eyes', '04_Bald',
    #    '05_Bangs', '06_Big_Lips', '07_Big_Nose', '08_Black_Hair', '09_Blond_Hair', '10_Blurry',
    #    '11_Brown_Hair', '12_Bushy_Eyebrows', '13_Chubby', '14_Double_Chin', '15_Eyeglasses',
    #    '16_Goatee', '17_Gray_Hair', '18_Heavy_Makeup', '19_High_Cheekbones', '20_Male',
    #    '21_Mouth_Slightly_Open', '22_Mustache', '23_Narrow_Eyes', '24_No_Beard', '25_Oval_Face',
    #    '26_Pale_Skin', '27_Pointy_Nose', '28_Receding_Hairline', '29_Rosy_Cheeks', '30_Sideburns',
    #    '31_Smiling', '32_Straight_Hair', '33_Wavy_Hair', '34_Wearing_Earrings', '35_Wearing_Hat',
    #    '36_Wearing_Lipstick', '37_Wearing_Necklace', '38_Wearing_Necktie', '39_Young']
    direction_map = {
        'smiling': '31_Smiling',
        'young': '39_Young',
        'wavy hair': '33_Wavy_Hair',
        'gray hair': '17_Gray_Hair',
        'blonde hair': '09_Blond_Hair',
        'eyeglass': '15_Eyeglasses',
        'mustache': '22_Mustache',
    }

    boundaries = model.get_pretrained('boundary', config)
    self.direction_dict = {
        name: boundaries[key].view(1, 1, -1)
        for name, key in direction_map.items()
    }

    # 3. prepare the latent code and original images
    image_dir = os.path.join(assets_dir, 'input_images')
    latent_dir = os.path.join(assets_dir, 'projected_latents')
    self.file_names = [
        f for f in sorted(os.listdir(image_dir))
        if f.endswith(('.png', '.jpg'))
    ]
    self.latent_code_list = []
    self.org_image_list = []
    for fname in self.file_names:
        # Close the file handle as soon as the pixels are copied out;
        # convert() returns a new image, so the array outlives the file.
        with Image.open(os.path.join(image_dir, fname)) as img:
            org_image = np.asarray(img.convert('RGB'))
        # the latent file carries the image's name with a .npy extension
        latent_name = fname.replace('.jpg', '.npy').replace('.png', '.npy')
        latent_code = torch.from_numpy(np.load(os.path.join(latent_dir, latent_name)))
        self.org_image_list.append(org_image)
        self.latent_code_list.append(latent_code.view(1, -1, 512))

    # set up the initial display
    self.sample_idx = 0
    self.org_latent_code = self.latent_code_list[self.sample_idx]

    # input kwargs for the generator
    self.input_kwargs = {
        'styles': self.org_latent_code,
        'noise': None,
        'randomize_noise': False,
        'input_is_style': True,
    }
def compute_attribute_consistency(g, sub_g, n_sample, batch_size):
    """Measure how often ``g`` and ``sub_g`` agree on predicted attributes.

    Both generators are fed identical latents, noise and truncation
    settings. Their images are resized to 256 and passed through the
    pretrained attribute predictor, whose output is viewed as (-1, 40, 2)
    and reduced with an argmax over the last dim. The score is the fraction
    of samples on which the two decisions match, computed per attribute.

    Args:
        g: reference generator; must provide ``mean_style`` and ``make_noise``.
        sub_g: generator to compare against ``g``.
        n_sample: total number of samples across all Horovod ranks (rounded
            up to a whole number of batches per rank).
        batch_size: number of samples per batch on each rank.

    Returns:
        The per-attribute agreement rate (one entry per attribute, 40 in
        total), averaged over this rank's batches and then combined across
        ranks with ``hvd.allreduce``.
    """
    attr_pred = model.get_pretrained('attribute-predictor').to(device)
    attr_pred.eval()

    # each rank handles an equal share of the batches
    n_batch = math.ceil(n_sample * 1. / batch_size / hvd.size())

    accs = 0

    mean_style = g.mean_style(10000)
    with torch.no_grad():
        for _ in tqdm(range(n_batch), disable=hvd.rank() != 0):
            noise = g.make_noise()

            # Size the batch from the function argument (not the global
            # CLI ``args``) so it matches the n_batch computation above.
            latent = torch.randn(batch_size, 1, 512, device=device)
            # identical inputs for both generators
            kwargs = {
                'styles': latent,
                'truncation': 0.5,
                'truncation_style': mean_style,
                'noise': noise,
            }
            img = g(**kwargs)[0].clamp(min=-1., max=1.)
            sub_img = sub_g(**kwargs)[0].clamp(min=-1., max=1.)
            img = adaptive_resize(img, 256)
            sub_img = adaptive_resize(sub_img, 256)

            # 40 attributes, 2 logits each -> one decision per attribute
            attr = attr_pred(img).view(-1, 40, 2)
            sub_attr = attr_pred(sub_img).view(-1, 40, 2)

            attr = torch.argmax(attr, dim=2)
            sub_attr = torch.argmax(sub_attr, dim=2)
            # agreement rate per attribute within this batch
            this_acc = (attr == sub_attr).float().mean(0)
            # running mean over this rank's batches
            accs = accs + this_acc / n_batch
    accs = hvd.allreduce(accs)
    return accs
parser.add_argument("--config", type=str, help='config name of the pretrained generator') parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--n_sample', type=int, default=50000) parser.add_argument('--inception', type=str, default=None, required=True) parser.add_argument('--channel_ratio', type=float, default=None) parser.add_argument('--target_res', type=int, default=None) args = parser.parse_args() hvd.init() torch.cuda.set_device(hvd.local_rank()) generator = model.get_pretrained('generator', args.config).to(device) generator.eval() # set sub-generator if args.channel_ratio: from model.dynamic_channel import set_uniform_channel_ratio, CHANNEL_CONFIGS assert args.channel_ratio in CHANNEL_CONFIGS set_uniform_channel_ratio(generator, args.channel_ratio) if args.target_res is not None: generator.target_res = args.target_res # compute the flops of the generator (is possible) if hvd.rank() == 0: try:
default=1e-4, help="epsilon for numerical stability") parser.add_argument("--crop", action="store_true", help="apply center crop to the images") parser.add_argument("--sampling", default="end", choices=["end", "full"], help="set endpoint sampling method") args = parser.parse_args() hvd.init() torch.cuda.set_device(hvd.local_rank()) generator = model.get_pretrained('generator', args.config).to(device) generator.eval() if args.channel_ratio: from model.dynamic_channel import set_uniform_channel_ratio set_uniform_channel_ratio(generator, args.channel_ratio) if args.target_res is not None: generator.target_res = args.target_res ppl = compute_ppl(generator, n_sample=args.n_sample, batch_size=args.batch_size, space=args.space, sampling=args.sampling, eps=args.eps,