def __init__(self, args):
    super().__init__()
    self.device = args.device
    self.args = args
    self.initial_lr = self.args.lr
    self.lr = self.args.lr
    self.lr_rampdown_length = 0.4

    # Init global step
    self.global_step = 0
    self.step_start = 0

    # Init Generator (frozen, inference only)
    self.g = style_gan_2.PretrainedGenerator1024().eval().to(self.device)
    for param in self.g.parameters():
        param.requires_grad = False

    # Define audio encoder
    self.audio_encoder = models.AudioExpressionNet3(args.T).to(self.device).train()

    # Print # parameters
    print("# params {} (trainable {})".format(
        utils.count_params(self.audio_encoder),
        utils.count_trainable_params(self.audio_encoder)
    ))

    # Select optimizer and loss criterion
    self.optim = torch.optim.Adam(self.audio_encoder.parameters(), lr=self.lr)
    self.lpips = PerceptualLoss(model='net-lin', net='vgg', gpu_id=args.gpu)

    if self.args.cont or self.args.test:
        path = self.args.model_path
        self.load(path)
        self.step_start = self.global_step

    # Mouth mask for image
    mouth_mask = torch.load(
        'saves/pre-trained/tagesschau_mouth_mask_5std.pt').to(self.device)
    # eyes_mask = torch.load('saves/pre-trained/tagesschau_eyes_mask_3std.pt').to(self.device)
    self.image_mask = mouth_mask.clamp(0., 1.)
    # self.image_mask = (mouth_mask + eyes_mask).clamp(0., 1.)

    # MSE mask
    self.mse_mask = torch.load(
        'saves/pre-trained/mse_mask_var+1.pt')[4:8].unsqueeze(0).to(self.device)

    # Set up tensorboard
    if not self.args.debug and not self.args.test:
        tb_dir = self.args.save_dir
        # self.writer = SummaryWriter(tb_dir)
        self.train_writer = utils.HparamWriter(tb_dir + 'train/')
        self.val_writer = utils.HparamWriter(tb_dir + 'val/')
        self.train_writer.log_hyperparams(self.args)
        print(f"Logging run to {tb_dir}")

    # Create save dir
    os.makedirs(self.args.save_dir + 'models', exist_ok=True)
    os.makedirs(self.args.save_dir + 'sample', exist_ok=True)
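# Illustrative sketch (an assumption, not code from this repo): the
# freeze-generator / train-encoder pattern that the __init__ above sets up.
# torchvision's resnet18 stands in for PretrainedGenerator1024 and a small
# MLP stands in for AudioExpressionNet3; only the encoder's parameters are
# handed to the optimizer, so the pretrained weights never change.
import torch
import torch.nn as nn
from torchvision.models import resnet18

frozen = resnet18().eval()
for p in frozen.parameters():
    p.requires_grad = False  # pretrained weights stay fixed

encoder = nn.Sequential(
    nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 512)
).train()
optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)  # encoder params only

# Sanity check: gradients can only flow into the encoder.
assert not any(p.requires_grad for p in frozen.parameters())
assert all(p.requires_grad for p in encoder.parameters())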
def __init__(self, g, num_steps=1000, initial_learning_rate=0.1,
             initial_noise_factor=0.05, verbose=True):
    self.num_steps = num_steps
    self.n_mean_latent = 10000
    self.initial_lr = initial_learning_rate
    self.initial_noise_factor = initial_noise_factor
    self.lr_rampdown_length = 0.25
    self.lr_rampup_length = 0.05
    self.noise_ramp_length = 0.75
    self.regularize_noise_weight = 1e5
    self.verbose = verbose

    self.latent_expr = None
    self.lpips = None
    self.target_images = None
    self.imag_gen = None
    self.loss = None
    self.lr = None
    self.cur_step = None

    self.g_ema = g
    self.device = next(g.parameters()).device

    # Find latent stats
    self._info('Finding W midpoint and stddev using %d samples...' % self.n_mean_latent)
    torch.manual_seed(123)
    with torch.no_grad():
        noise_sample = torch.randn(self.n_mean_latent, 512, device=self.device)
        latent_out = self.g_ema.style(noise_sample)

    self.latent_mean = latent_out.mean(0)
    self.latent_std = ((latent_out - self.latent_mean).pow(2).sum() / self.n_mean_latent) ** 0.5
    self._info('std = {}'.format(self.latent_std))

    self.latent_in = self.latent_mean.detach().clone().unsqueeze(0)
    self.latent_in = self.latent_in.repeat(self.g_ema.n_latent, 1)
    self.latent_in.requires_grad = True

    # Find noise inputs and make them trainable leaves; without
    # requires_grad they would silently never be optimized.
    self.noises = []
    for noise in g.noises:
        noise = noise.detach().clone().to(self.device)
        noise.requires_grad = True
        self.noises.append(noise)

    # Init optimizer
    self.opt = torch.optim.Adam([self.latent_in] + self.noises, lr=self.initial_lr)

    # Init loss function
    self.lpips = PerceptualLoss(model='net-lin', net='vgg').to(self.device)
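# Sketch of the ramped learning-rate schedule that the constants above
# (lr_rampup_length, lr_rampdown_length) conventionally drive in StyleGAN2
# projectors: linear warm-up at the start, cosine decay at the end. This is
# an assumption about how the class updates its LR each step, not a verbatim
# copy of its implementation.
import math

def ramped_lr(step, num_steps, base_lr=0.1, rampup=0.05, rampdown=0.25):
    t = step / num_steps
    # Cosine rampdown over the final `rampdown` fraction of steps
    lr_ramp = min(1.0, (1.0 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    # Linear rampup over the first `rampup` fraction of steps
    lr_ramp *= min(1.0, t / rampup)
    return base_lr * lr_ramp

# e.g. ramped_lr(0, 1000) == 0.0; ramped_lr(500, 1000) == 0.1 (full LR)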
def __init__(self, args):
    super().__init__()
    self.device = args.device
    self.args = args
    self.initial_lr = self.args.lr
    self.lr = self.args.lr
    self.lr_rampdown_length = 0.3
    self.lr_rampup_length = 0.1

    # Load generator (frozen, inference only)
    self.g = style_gan_2.PretrainedGenerator1024().eval().to(self.device)
    for param in self.g.parameters():
        param.requires_grad = False
    self.latent_avg = self.g.latent_avg.repeat(
        18, 1).unsqueeze(0).to(self.device)

    # Init global step
    self.global_step = 0

    # Define encoder model
    self.e = resnetEncoder().train().to(self.device)

    # Print # parameters
    print("# params {} (trainable {})".format(
        utils.count_params(self.e),
        utils.count_trainable_params(self.e)
    ))

    # Select optimizer and loss criterion
    self.optim = torch.optim.Adam(self.e.parameters(), lr=self.initial_lr)
    self.criterion = PerceptualLoss(
        model='net-lin', net='vgg', gpu_id=args.gpu)

    # Load model and optimizer checkpoint
    if self.args.cont or self.args.test or self.args.run:
        path = self.args.model_path
        self.load(path)

    # Set up tensorboard
    if not self.args.debug and not self.args.test and not self.args.run:
        tb_dir = 'tensorboard_runs/encode_stylegan/' + \
            self.args.save_dir.split('/')[-2]
        self.writer = SummaryWriter(tb_dir)
        print(f"Logging run to {tb_dir}")

    # Create save dir
    os.makedirs(self.args.save_dir + 'models', exist_ok=True)
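# Illustrative sketch of what the latent_avg line above computes: the
# generator's mean W latent (512,) is broadcast across the 18 style inputs
# of a 1024px StyleGAN2 generator, giving a (1, 18, 512) "W+" code that an
# encoder can predict residual offsets from. The random w_avg below is a
# stand-in for g.latent_avg.
import torch

w_avg = torch.randn(512)                       # stand-in for g.latent_avg
latent_avg = w_avg.repeat(18, 1).unsqueeze(0)  # -> (1, 18, 512)
assert latent_avg.shape == (1, 18, 512)

# An encoder output `offsets` of shape (B, 18, 512) would then give
# w_plus = latent_avg + offsets, a common residual parameterization.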
def main(hparams):
    # Set up perceptual loss
    device = 'cuda:0'
    percept = PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    utils.print_hparams(hparams)

    # Get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses, lpips_scores, z_hats = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}

    A = utils.get_A(hparams)
    noise_batch = hparams.noise_std * np.random.standard_t(
        2, size=(hparams.batch_size, hparams.num_measurements))

    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check whether the image has already been
            # saved by *all* estimators. If yes, skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([os.path.isfile(save_path)
                            for save_path in save_paths.values()])
            if is_saved:
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input
        x_batch_list = [x.reshape(1, hparams.n_input)
                        for _, x in x_batch_dict.items()]
        x_batch = np.concatenate(x_batch_list)

        # Construct noise and measurements
        y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams)

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch, z_hat_batch, m_loss_batch = estimator(A, y_batch, hparams)

            for i, key in enumerate(x_batch_dict.keys()):
                x = xs_dict[key]
                y_train = y_batch[i]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][key] = x_hat

                # Compute and store measurement, l2, and LPIPS losses
                measurement_losses[model_type][key] = m_loss_batch[key]
                l2_losses[model_type][key] = utils.get_l2_loss(x_hat, x)
                lpips_scores[model_type][key] = utils.get_lpips_score(
                    percept, x_hat, x, hparams.image_shape)
                z_hats[model_type][key] = z_hat_batch[i]

        print('Processed up to image {0} / {1}'.format(key + 1, len(xs_dict)))

        # Checkpointing
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             lpips_scores, z_hats, save_image, hparams)
            x_hats_dict = {model_type: {} for model_type in hparams.model_types}
            print('\nProcessed and saved first {0} images\n'.format(key + 1))

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         lpips_scores, z_hats, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            measurement_loss_list = list(measurement_losses[model_type].values())
            l2_loss_list = list(l2_losses[model_type].values())
            mean_m_loss = np.mean(measurement_loss_list)
            mean_l2_loss = np.mean(l2_loss_list)
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some images were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not '
              'fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
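# Sketch of the linear measurement model the loop above relies on. Exactly
# how utils.get_measurements combines x, A, and the noise is an assumption
# here; the shapes and the heavy-tailed Student-t noise (df=2) mirror the
# noise_batch construction in main().
import numpy as np

rng = np.random.default_rng(0)
n_input, num_measurements, batch_size = 64, 16, 4
A = rng.standard_normal((n_input, num_measurements)) / np.sqrt(num_measurements)
x_batch = rng.standard_normal((batch_size, n_input))
noise = 0.1 * rng.standard_t(df=2, size=(batch_size, num_measurements))
y_batch = x_batch @ A + noise   # y = xA + noise, shape (batch, m)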