# NOTE: imports reconstructed to match the original PULSE repository's module
# layout (stylegan.py, SphericalOptimizer.py, loss.py, drive.py); adjust the
# paths if your local layout differs.
from functools import partial
from pathlib import Path
import time

import numpy as np
import torch

from stylegan import G_synthesis, G_mapping
from SphericalOptimizer import SphericalOptimizer
from loss import LossBuilder
from drive import open_url


class PULSE(torch.nn.Module):
    def __init__(self, cache_dir, verbose=True):
        super(PULSE, self).__init__()
        self.synthesis = G_synthesis().cuda()
        self.verbose = verbose
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        if self.verbose:
            print("Loading Synthesis Network")
        with open_url(
                "https://drive.google.com/uc?id=1OjFOETBvDE3FOd1OmTP-ujSh3hCi7rgQ",
                cache_dir=cache_dir,
                verbose=verbose) as f:
            self.synthesis.load_state_dict(torch.load(f))
        for param in self.synthesis.parameters():
            param.requires_grad = False
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            if self.verbose:
                print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()
            with open_url(
                    "https://drive.google.com/uc?id=1eVGz3JTo7QuRUK7dqdlf5UkCcFJyGrUQ",
                    cache_dir=cache_dir,
                    verbose=verbose) as f:
                mapping.load_state_dict(torch.load(f))
            if self.verbose:
                print("\tRunning Mapping Network")
            with torch.no_grad():
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512),
                                     dtype=torch.float32,
                                     device="cuda")
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {
                    "mean": latent_out.mean(0),
                    "std": latent_out.std(0)
                }
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                if self.verbose:
                    print("\tSaved \"gaussian_fit.pt\"")

    def forward(self, ref_im, seed, loss_str, eps, noise_type,
                num_trainable_noise_layers, tile_latent, bad_noise_layers,
                opt_name, learning_rate, steps, lr_schedule, save_intermediate,
                **kwargs):
        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if (tile_latent):
            latent = torch.randn((batch_size, 1, 512),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')
        else:
            latent = torch.randn((batch_size, 18, 512),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

            if (noise_type == 'zero' or i in [
                    int(layer) for layer in bad_noise_layers.split('.')
            ]):
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif (noise_type == 'fixed'):
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif (noise_type == 'trainable'):
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if (i < num_trainable_noise_layers):
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            'linear1cycledrop': lambda x: (9 * (1 - np.abs(x / (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) / (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if (tile_latent):
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] +
                                   self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if (loss < min_loss):
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict['L2']
            if (loss_l2 < min_l2):
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if (save_intermediate):
                yield (best_im.cpu().detach().clamp(0, 1),
                       loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
        if self.verbose:
            print(best_summary + current_info)
        if (min_l2 <= eps):
            yield (gen_im.clone().cpu().detach().clamp(0, 1),
                   loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print(
                "Could not find a face that downscales correctly within epsilon"
            )
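# --- Illustrative usage sketch (not part of the original code) ---
# forward() above is a generator: iterating it drives the latent optimization
# and yields (HR, LR) image pairs. The helper below shows one way a caller
# might consume it. The hyperparameter values mirror the defaults of the
# original PULSE repository's run.py (an assumption; adjust as needed), and
# torchvision is assumed to be available for saving the outputs.
def run_pulse_example(model, ref_im):
    """Drive a PULSE instance (the class defined directly above) on a batch of
    low-resolution reference faces `ref_im` (N x 3 x h x w, values in [0, 1],
    on the GPU, e.g. 32 x 32) and save the final HR/LR pair."""
    from torchvision.utils import save_image  # assumed dependency

    images = model(
        ref_im,
        seed=0,
        loss_str="100*L2+0.05*GEOCROSS",
        eps=2e-3,
        noise_type="trainable",
        num_trainable_noise_layers=5,
        tile_latent=False,
        bad_noise_layers="17",
        opt_name="adam",
        learning_rate=0.4,
        steps=100,
        lr_schedule="linear1cycledrop",
        save_intermediate=False,
    )
    # With save_intermediate=False the generator yields at most one pair,
    # and only if the downscaling loss reached eps.
    for hr_im, lr_im in images:
        save_image(hr_im, "pulse_hr.png")
        save_image(lr_im, "pulse_lr.png")
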
# Modified PULSE variant: loads the StyleGAN weights from local files, adds
# extra knobs (psi scaling, initial variable values, a custom optimizer and a
# step post-processing hook), and yields result dicts instead of tuples.
class PULSE(torch.nn.Module):
    def __init__(self, cache_dir, verbose=True):
        super().__init__()
        self.synthesis = G_synthesis().cuda()
        self.verbose = verbose
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        if self.verbose:
            print("Loading Synthesis Network")
        # download synthesis.pt
        # https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8
        self.synthesis.load_state_dict(torch.load("synthesis.pt"))
        # with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
        #     self.synthesis.load_state_dict(torch.load(f))
        for param in self.synthesis.parameters():
            param.requires_grad = False
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            if self.verbose:
                print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()
            # download mapping.pt
            # https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k
            mapping.load_state_dict(torch.load("mapping.pt"))
            # with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
            #     mapping.load_state_dict(torch.load(f))
            if self.verbose:
                print("\tRunning Mapping Network")
            with torch.no_grad():
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512),
                                     dtype=torch.float32,
                                     device="cuda")
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {
                    "mean": latent_out.mean(0),
                    "std": latent_out.std(0),
                }
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                if self.verbose:
                    print('\tSaved "gaussian_fit.pt"')

    def forward(
        self,
        ref_im,
        loss_str,
        eps,
        noise_type,
        num_trainable_noise_layers,
        tile_latent,
        bad_noise_layers,
        opt_name,
        learning_rate,
        steps,
        lr_schedule,
        save_intermediate,
        seed=0,
        var_list_initial_values=None,
        step_postprocess=None,
        psi=1.0,
        **kwargs,
    ):
        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if tile_latent:
            latent = torch.randn(
                (batch_size, 1, 512),
                dtype=torch.float,
                requires_grad=True,
                device="cuda",
            )
        else:
            latent = torch.randn(
                (batch_size, 18, 512),
                dtype=torch.float,
                requires_grad=True,
                device="cuda",
            )
        # Scale the initial latent toward the origin (truncation-style) when psi < 1
        with torch.no_grad():
            latent *= psi

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

            if noise_type == "zero" or i in [
                    int(layer) for layer in bad_noise_layers.split(".")
            ]:
                new_noise = torch.zeros(res, dtype=torch.float, device="cuda")
                new_noise.requires_grad = False
            elif noise_type == "fixed":
                new_noise = torch.randn(res, dtype=torch.float, device="cuda")
                new_noise.requires_grad = False
            elif noise_type == "trainable":
                new_noise = torch.randn(res, dtype=torch.float, device="cuda")
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        if var_list_initial_values is not None:
            assert len(var_list) == len(var_list_initial_values)
            with torch.no_grad():
                for var, initial_value in zip(var_list,
                                              var_list_initial_values):
                    var.copy_(initial_value)

        opt_dict = {
            "sgd": torch.optim.SGD,
            "adam": torch.optim.Adam,
            "sgdm": partial(torch.optim.SGD, momentum=0.9),
            "adamax": torch.optim.Adamax,
            "custom": partial(torch.optim.AdamW, betas=(0.9, 0.99)),
        }
        opt_func = opt_dict[opt_name]

        # StepPostProcessOptimizer is assumed to be a project-local wrapper
        # (not part of the original PULSE release) that applies
        # step_postprocess after each optimizer step.
        if step_postprocess is None:
            opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)
        else:
            opt = StepPostProcessOptimizer(opt_func,
                                           var_list,
                                           step_postprocess=step_postprocess,
                                           lr=learning_rate)

        schedule_dict = {
            "fixed": lambda x: 1,
            "linear1cycle": lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            "linear1cycledrop": lambda x: (9 * (1 - np.abs(x / (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) / (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] +
                                   self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict["TOTAL"] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f"BEST ({j+1}) | " + " | ".join(
                    [f"{x}: {y:.4f}" for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict["L2"]
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield dict(
                    final=False,
                    min_l2=min_l2,
                    HR=best_im.cpu().detach().clamp(0, 1),
                    LR=loss_builder.D(best_im).cpu().detach().clamp(0, 1),
                )

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f" | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}"
        if self.verbose:
            print(best_summary + current_info)

        yield dict(
            final=True,
            min_l2=min_l2,
            success=min_l2 <= eps,
            HR=gen_im.clone().cpu().detach().clamp(0, 1),
            LR=loss_builder.D(best_im).cpu().detach().clamp(0, 1),
            var_list=var_list,
            loss_dict=loss_dict,
        )
        # if min_l2 <= eps:
        #     yield (
        #         gen_im.clone().cpu().detach().clamp(0, 1),
        #         loss_builder.D(best_im).cpu().detach().clamp(0, 1),
        #     )
        # else:
        #     print("Could not find a face that downscales correctly within epsilon")

    def var_list_to_latent_and_noise(
        self,
        var_list,
        noise_type,
        num_trainable_noise_layers,
        bad_noise_layers,
        seed=0,
        **kwargs,
    ):
        with torch.no_grad():
            if seed:
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                torch.backends.cudnn.deterministic = True

            latent = var_list[0].clone()
            batch_size = latent.shape[0]

            noise = []  # stores all of the noise tensors
            noise_vars = []  # stores the noise tensors that we want to optimize on

            for i in range(18):
                # dimension of the ith noise tensor
                res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

                if noise_type == "zero" or i in [
                        int(layer) for layer in bad_noise_layers.split(".")
                ]:
                    new_noise = torch.zeros(res,
                                            dtype=torch.float,
                                            device="cuda")
                    new_noise.requires_grad = False
                elif noise_type == "fixed":
                    new_noise = torch.randn(res,
                                            dtype=torch.float,
                                            device="cuda")
                    new_noise.requires_grad = False
                elif noise_type == "trainable":
                    new_noise = torch.randn(res,
                                            dtype=torch.float,
                                            device="cuda")
                    new_noise.requires_grad = False
                    if i < num_trainable_noise_layers:
                        noise_vars.append(new_noise)
                else:
Exception("unknown noise type") noise.append(new_noise) assert len(var_list) - 1 == len(noise_vars) for noise_input, noise_var in zip(var_list[1:], noise_vars): assert noise_input.shape == noise_var.shape noise_var.copy_(noise_input.clone()) return latent, noise def synthesize(self, latent, noise, tile_latent, **kwargs): with torch.no_grad(): # Duplicate latent in case tile_latent = True if tile_latent: latent_in = latent.expand(-1, 18, -1) else: latent_in = latent # Apply learned linear mapping to match latent distribution to that of the mapping network latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"]) # Normalize image to [0,1] instead of [-1,1] gen_im = (self.synthesis(latent_in, noise) + 1) / 2 gen_im = gen_im.cpu().detach().clamp(0, 1) return gen_im