def __init__(self, cfg, generator, gt_loader):
    """Prepare factorized semantics and latent codes for visualization.

    Args:
        cfg: Config namespace; fields read here: `layer_idx`, `seed`,
            `num_samples`, `trunc_psi`, `trunc_layers`, `start_distance`,
            `end_distance`, `step`, `num_semantics`.
        generator: Pre-trained GAN generator (PGGAN / StyleGAN / StyleGAN2),
            assumed to already live on (or be movable to) a CUDA device —
            codes are created with `.cuda()` below.
        gt_loader: Ground-truth data loader; only stored on `self` here,
            presumably consumed by other methods of this class.
    """
    # Factorize weights.
    # `factorize_weight` returns (layers_to_interpret, semantic_boundaries,
    # eigen_values); boundaries are later added to latent codes.
    self.generator = generator
    self.gan_type = parse_gan_type(self.generator)
    self.layers, self.boundaries, self.values = factorize_weight(
        self.generator, cfg.layer_idx)

    # Set random seed (both NumPy and Torch, so sampling is reproducible).
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Prepare codes: sample z-space latents, then map them into the space
    # the boundaries were factorized in (pixel-normed z for PGGAN, truncated
    # w+ for StyleGAN/StyleGAN2).
    codes = torch.randn(cfg.num_samples, self.generator.z_space_dim).cuda()
    if self.gan_type == 'pggan':
        codes = self.generator.layer0.pixel_norm(codes)
    elif self.gan_type in ['stylegan', 'stylegan2']:
        codes = self.generator.mapping(codes)['w']
        codes = self.generator.truncation(codes,
                                          trunc_psi=cfg.trunc_psi,
                                          trunc_layers=cfg.trunc_layers)
    # Keep codes as a NumPy array; downstream editing is done on CPU.
    self.codes = codes.detach().cpu().numpy()

    # Generate visualization pages.
    # NOTE(review): `cfg.step` is the number of interpolation points, not a
    # stride — it feeds `np.linspace` as the sample count.
    self.distances = np.linspace(cfg.start_distance, cfg.end_distance,
                                 cfg.step)
    self.num_sam = cfg.num_samples
    self.num_sem = cfg.num_semantics
    self.gt_loader = gt_loader
def factorize_weight(generator, layer_idx='all'):
    """Factorizes the generator weight to get semantics boundaries.

    Concatenates the (transposed) style/conv weights of the selected layers,
    column-normalizes them, and eigen-decomposes W @ W.T. The eigenvectors
    serve as semantic boundaries and the eigenvalues as their importance.

    Args:
        generator: Generator to factorize.
        layer_idx: Indices of layers to interpret, especially for StyleGAN
            and StyleGAN2. (default: `all`)

    Returns:
        A tuple of (layers_to_interpret, semantic_boundaries, eigen_values),
        where `semantic_boundaries` has one boundary per row.

    Raises:
        ValueError: If the generator type is not supported.
    """
    # Get GAN type.
    gan_type = parse_gan_type(generator)

    # Get layers.
    if gan_type == 'pggan':
        # PGGAN has a single z-space; only the first layer is meaningful.
        layers = [0]
    elif gan_type in ['stylegan', 'stylegan2']:
        if layer_idx == 'all':
            layers = list(range(generator.num_layers))
        else:
            layers = parse_indices(layer_idx,
                                   min_val=0,
                                   max_val=generator.num_layers - 1)
    else:
        # BUGFIX: previously fell through with `layers` unbound, raising
        # UnboundLocalError instead of the documented ValueError.
        raise ValueError(f'Unsupported GAN type `{gan_type}`!')

    # Factorize semantics from weight.
    weights = []
    for idx in layers:
        layer_name = f'layer{idx}'
        if gan_type == 'stylegan2' and idx == generator.num_layers - 1:
            # StyleGAN2 names its last layer as an output (ToRGB) layer.
            layer_name = f'output{idx // 2}'
        if gan_type == 'pggan':
            weight = generator.__getattr__(layer_name).weight
            # Re-arrange conv weight to (in_channels, out * k * k).
            weight = weight.flip(2, 3).permute(1, 0, 2, 3).flatten(1)
        elif gan_type in ['stylegan', 'stylegan2']:
            weight = generator.synthesis.__getattr__(layer_name).style.weight.T
        weights.append(weight.cpu().detach().numpy())
    weight = np.concatenate(weights, axis=1).astype(np.float32)
    # Normalize each latent dimension (column) before decomposition.
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
    # Each row of the returned boundary matrix is one semantic direction.
    return layers, eigen_vectors.T, eigen_values
def get_weights(generator, layer_idx='all', apply_norm=True):
    """Obtains weight matrix from specified generator and layer selection.

    Adapted from `factorize_weight`: builds the same concatenated
    (in_dim x sum_of_out_dims) weight matrix but stops before the
    eigen-decomposition. Only StyleGAN/StyleGAN2 are handled here.

    Args:
        generator: Generator to get.
        layer_idx: Indices of layers to interpret, especially for StyleGAN
            and StyleGAN2. (default: `all`)
        apply_norm: Whether to L2-normalize each column of the concatenated
            weight, matching the preprocessing in `factorize_weight`.
            (default: True)

    Returns:
        A weight matrix.

    Raises:
        ValueError: If the generator type is not supported.
    """
    # Get GAN type.
    gan_type = parse_gan_type(generator)

    # Get layers.
    if gan_type in ['stylegan', 'stylegan2']:
        if layer_idx == 'all':
            layers = list(range(generator.num_layers))
        else:
            layers = parse_indices(layer_idx,
                                   min_val=0,
                                   max_val=generator.num_layers - 1)
    else:
        # BUGFIX: previously any non-StyleGAN type (including 'pggan', which
        # parse_gan_type accepts) left `layers` unbound and crashed with
        # UnboundLocalError instead of the documented ValueError.
        raise ValueError(f'Unsupported GAN type `{gan_type}`!')

    # Collect per-layer style weights.
    weights = []
    for idx in layers:
        layer_name = f'layer{idx}'
        if gan_type == 'stylegan2' and idx == generator.num_layers - 1:
            # StyleGAN2 names its last layer as an output (ToRGB) layer.
            layer_name = f'output{idx // 2}'
        if gan_type in ['stylegan', 'stylegan2']:
            weight = generator.synthesis.__getattr__(layer_name).style.weight.T
        weights.append(weight.cpu().detach().numpy())
    weight = np.concatenate(weights, axis=1).astype(np.float32)
    if apply_norm:
        # NOTE(review): normalization mirrors `factorize_weight`; whether it
        # is required for downstream consumers of the raw matrix is an open
        # question (hence the `apply_norm` switch).
        weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    return weight
def main():
    """Main function (loop for StreamLit).

    Streamlit re-runs this whole function on every widget interaction, so
    the order of widget declarations below is the on-screen order, and all
    persistent state lives in `SessionState`.
    """
    st.title('Closed-Form Factorization of Latent Semantics in GANs')
    st.sidebar.title('Options')
    reset = st.sidebar.button('Reset')

    # Model and layer selection; factorization is presumably cached inside
    # `get_model` / `factorize_model` — TODO confirm, otherwise every rerun
    # refactorizes.
    model_name = st.sidebar.selectbox(
        'Model to Interpret',
        ['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',
         'pggan_celebahq1024'])
    model = get_model(model_name)
    gan_type = parse_gan_type(model)
    layer_idx = st.sidebar.selectbox(
        'Layers to Interpret',
        ['all', '0-1', '2-5', '6-13'])
    layers, boundaries, eigen_values = factorize_model(model, layer_idx)

    num_semantics = st.sidebar.number_input(
        'Number of semantics', value=10, min_value=0, max_value=None, step=1)
    # One slider (edit distance along a semantic direction) per semantic.
    steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
    # Per-GAN-type slider range; the model list above only yields these
    # three types, so `max_step` is always bound.
    if gan_type == 'pggan':
        max_step = 5.0
    elif gan_type == 'stylegan':
        max_step = 2.0
    elif gan_type == 'stylegan2':
        max_step = 15.0
    for sem_idx in steps:
        eigen_value = eigen_values[sem_idx]
        steps[sem_idx] = st.sidebar.slider(
            f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
            value=0.0,
            min_value=-max_step,
            max_value=max_step,
            # NOTE(review): on Reset the slider step becomes 0.0 —
            # presumably to force the widget back to its default; verify
            # this behaves as intended with the Streamlit version in use.
            step=0.04 * max_step if not reset else 0.0)

    image_placeholder = st.empty()
    button_placeholder = st.empty()

    # Prefer precomputed latent codes on disk; fall back to fresh sampling.
    try:
        base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
    except FileNotFoundError:
        base_codes = sample(model, gan_type)

    # Persistent per-session state: current model, sample index, and the
    # active latent code (kept with a leading batch dim of 1).
    state = SessionState.get(model_name=model_name,
                             code_idx=0,
                             codes=base_codes[0:1])
    if state.model_name != model_name:
        # Model switched in the UI: restart from the first precomputed code.
        state.model_name = model_name
        state.code_idx = 0
        state.codes = base_codes[0:1]

    if button_placeholder.button('Random', key=0):
        # Advance through precomputed codes, then sample new ones.
        state.code_idx += 1
        if state.code_idx < base_codes.shape[0]:
            state.codes = base_codes[state.code_idx][np.newaxis]
        else:
            state.codes = sample(model, gan_type)

    # Apply every semantic edit to a copy of the current code.
    code = state.codes.copy()
    for sem_idx, step in steps.items():
        if gan_type == 'pggan':
            code += boundaries[sem_idx:sem_idx + 1] * step
        elif gan_type in ['stylegan', 'stylegan2']:
            # StyleGAN codes are per-layer (w+); edit only selected layers.
            code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
    image = synthesize(model, gan_type, code)
    # `st.image` expects floats in [0, 1]; `synthesize` presumably returns
    # uint8-range values — TODO confirm.
    image_placeholder.image(image / 255.0)
def main():
    """Main function.

    CLI driver: factorizes a generator's weights, samples latent codes,
    moves each code along the top-K semantic directions over a range of
    distances, and writes two HTML grids (sample-major and semantic-major).
    """
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    os.makedirs(args.save_dir, exist_ok=True)

    # Factorize weights.
    generator = load_generator(args.model_name)
    gan_type = parse_gan_type(generator)
    layers, boundaries, values = factorize_weight(generator, args.layer_idx)

    # Set random seed (NumPy and Torch) for reproducible sampling.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Prepare codes: sample z, then map to the space the boundaries live in
    # (pixel-normed z for PGGAN, truncated w+ for StyleGAN/StyleGAN2).
    codes = torch.randn(args.num_samples, generator.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = generator.layer0.pixel_norm(codes)
    elif gan_type in ['stylegan', 'stylegan2']:
        codes = generator.mapping(codes)['w']
        codes = generator.truncation(codes,
                                     trunc_psi=args.trunc_psi,
                                     trunc_layers=args.trunc_layers)
    codes = codes.detach().cpu().numpy()

    # Generate visualization pages.
    # `args.step` is the number of interpolation points fed to linspace.
    distances = np.linspace(args.start_distance, args.end_distance, args.step)
    num_sam = args.num_samples
    num_sem = args.num_semantics

    # vizer_1: grouped by semantic (each group = 1 header row + num_sam rows).
    # vizer_2: grouped by sample (each group = 1 header row + num_sem rows).
    # Column 0 holds labels; columns 1..step hold images.
    vizer_1 = HtmlPageVisualizer(num_rows=num_sem * (num_sam + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    vizer_2 = HtmlPageVisualizer(num_rows=num_sam * (num_sem + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    headers = [''] + [f'Distance {d:.2f}' for d in distances]
    vizer_1.set_headers(headers)
    vizer_2.set_headers(headers)
    for sem_id in range(num_sem):
        value = values[sem_id]
        vizer_1.set_cell(sem_id * (num_sam + 1), 0,
                         text=f'Semantic {sem_id:03d}<br>({value:.3f})',
                         highlight=True)
        for sam_id in range(num_sam):
            vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, 0,
                             text=f'Sample {sam_id:03d}')
    for sam_id in range(num_sam):
        vizer_2.set_cell(sam_id * (num_sem + 1), 0,
                         text=f'Sample {sam_id:03d}',
                         highlight=True)
        for sem_id in range(num_sem):
            value = values[sem_id]
            vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, 0,
                             text=f'Semantic {sem_id:03d}<br>({value:.3f})')

    # Fill the image cells: one synthesized image per
    # (sample, semantic, distance) triple; each image is written into both
    # grid layouts.
    for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
        code = codes[sam_id:sam_id + 1]
        for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
            boundary = boundaries[sem_id:sem_id + 1]
            for col_id, d in enumerate(distances, start=1):
                # Copy so edits at one distance don't accumulate.
                temp_code = code.copy()
                if gan_type == 'pggan':
                    temp_code += boundary * d
                    image = generator(to_tensor(temp_code))['image']
                elif gan_type in ['stylegan', 'stylegan2']:
                    # Edit only the selected layers of the w+ code.
                    temp_code[:, layers, :] += boundary * d
                    image = generator.synthesis(to_tensor(temp_code))['image']
                image = postprocess(image)[0]
                vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, col_id,
                                 image=image)
                vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, col_id,
                                 image=image)

    prefix = (f'{args.model_name}_'
              f'N{num_sam}_K{num_sem}_L{args.layer_idx}_seed{args.seed}')
    vizer_1.save(os.path.join(args.save_dir, f'{prefix}_sample_first.html'))
    vizer_2.save(os.path.join(args.save_dir, f'{prefix}_semantic_first.html'))