class StyleGANInverter(object):
    """Defines the class for StyleGAN inversion.

    Even having the encoder, the output latent code is not good enough to
    recover the target image satisfyingly. To this end, this class optimizes
    the latent code with gradient descent.

    In the optimization process, the following loss functions are considered:

    (1) Pixel-wise reconstruction loss. (required)
    (2) Perceptual loss. (optional, but recommended)
    (3) Regularization loss from encoder. (optional, but recommended for
        in-domain inversion)

    NOTE: This implementation always builds the encoder in `__init__()` and
    uses it to produce the initial latent code, so a model without a
    registered encoder is not supported by the code in this class.
    """

    def __init__(self,
                 model_name,
                 learning_rate=1e-2,
                 iteration=100,
                 reconstruction_loss_weight=1.0,
                 perceptual_loss_weight=5e-5,
                 regularization_loss_weight=2.0,
                 logger=None):
        """Initializes the inverter.

        NOTE: Only Adam optimizer is supported in the optimization process.

        Args:
          model_name: Name of the model on which the inverter is based. The
            model should be first registered in `models/model_settings.py`.
          learning_rate: Learning rate for optimization. (default: 1e-2)
          iteration: Number of iterations for optimization. (default: 100)
          reconstruction_loss_weight: Weight for reconstruction loss. Should
            always be a positive number. (default: 1.0)
          perceptual_loss_weight: Weight for perceptual loss. 0 disables
            perceptual loss. (default: 5e-5)
          regularization_loss_weight: Weight for regularization loss from
            encoder. This is essential for in-domain inversion. 0 disables
            regularization loss. (default: 2.0)
          logger: Logger to record the log message.

        Raises:
          AssertionError: If the generator and encoder disagree on the latent
            code shape or GAN type, or if the reconstruction loss weight is
            not positive.
        """
        self.logger = logger
        self.model_name = model_name
        self.gan_type = 'stylegan'

        self.G = StyleGANGenerator(self.model_name, self.logger)
        self.E = StyleGANEncoder(self.model_name, self.logger)
        self.F = PerceptualModel(min_val=self.G.min_val,
                                 max_val=self.G.max_val)
        self.encode_dim = [self.G.num_layers, self.G.w_space_dim]
        self.run_device = self.G.run_device
        assert list(self.encode_dim) == list(self.E.encode_dim)
        assert self.G.gan_type == self.gan_type
        assert self.E.gan_type == self.gan_type

        self.learning_rate = learning_rate
        self.iteration = iteration
        self.loss_pix_weight = reconstruction_loss_weight
        self.loss_feat_weight = perceptual_loss_weight
        self.loss_reg_weight = regularization_loss_weight
        assert self.loss_pix_weight > 0

    def preprocess(self, image):
        """Preprocesses a single image.

        This function assumes the input numpy array is with shape
        [height, width, channel], channel order `RGB`, and pixel range
        [0, 255].

        The returned image is with shape [channel, new_height, new_width],
        where `new_height` and `new_width` are specified by the given
        generative model. The channel order of returned image is also
        specified by the generative model. The pixel range is shifted to
        [min_val, max_val], where `min_val` and `max_val` are also specified
        by the generative model.

        Args:
          image: `numpy.ndarray` of dtype `numpy.uint8` with 1 or 3 channels.

        Returns:
          A `numpy.float32` array with shape [channel, resolution, resolution].

        Raises:
          ValueError: If the input is not a uint8 ndarray of rank 3, or its
            channel count cannot be matched to the generator.
        """
        if not isinstance(image, np.ndarray):
            raise ValueError(
                'Input image should be with type `numpy.ndarray`!')
        if image.dtype != np.uint8:
            raise ValueError(
                'Input image should be with dtype `numpy.uint8`!')
        if image.ndim != 3 or image.shape[2] not in [1, 3]:
            raise ValueError(
                f'Input should be with shape [height, width, channel], '
                f'where channel equals to 1 or 3!\n'
                f'But {image.shape} is received!')
        # Grayscale input is replicated when the generator expects colour.
        if image.shape[2] == 1 and self.G.image_channels == 3:
            image = np.tile(image, (1, 1, 3))
        if image.shape[2] != self.G.image_channels:
            raise ValueError(
                f'Number of channels of input image, which is '
                f'{image.shape[2]}, is not supported by the current '
                f'inverter, which requires {self.G.image_channels} '
                f'channels!')

        if self.G.image_channels == 3 and self.G.channel_order == 'BGR':
            image = image[:, :, ::-1]

        # Compare the spatial axes (height, width) as a tuple; resize only
        # when the input does not already match the generator resolution.
        resolution = self.G.resolution
        if tuple(image.shape[:2]) != (resolution, resolution):
            image = cv2.resize(image, (resolution, resolution))
            # `cv2.resize` drops the trailing axis of single-channel images.
            if image.ndim == 2:
                image = image[:, :, np.newaxis]

        image = image.astype(np.float32)
        image = (image / 255.0 * (self.G.max_val - self.G.min_val) +
                 self.G.min_val)
        image = image.astype(np.float32).transpose(2, 0, 1)
        return image

    def _encode(self, x):
        """Runs the encoder on an image tensor and returns a float32 code."""
        # No gradient is needed: the result only seeds the optimization.
        with torch.no_grad():
            code = self.E.net(x).view(1, *self.encode_dim)
        return _get_tensor_value(code).astype(np.float32)

    def _is_viz_step(self, step, num_viz):
        """Tells whether the result of `step` should be kept for display."""
        if num_viz <= 0:
            return False
        # Clamp to 1 so that `num_viz > self.iteration` snapshots every step
        # instead of dividing by zero.
        interval = max(1, self.iteration // num_viz)
        return step % interval == 0

    def _build_crop_mask(self, center_x, center_y, crop_x, crop_y):
        """Builds a [1, C, H, W] float32 mask that is 1 inside the crop."""
        resolution = self.G.resolution
        mask = np.zeros(
            (1, self.G.image_channels, resolution, resolution),
            dtype=np.float32)
        left = center_x - crop_x // 2
        top = center_y - crop_y // 2
        # Clamp the window to the image: a negative slice start would
        # otherwise be interpreted as an offset from the far edge.
        x0, x1 = max(left, 0), min(left + crop_x, resolution)
        y0, y1 = max(top, 0), min(top + crop_y, resolution)
        if x1 > x0 and y1 > y0:
            mask[:, :, y0:y1, x0:x1] = 1.0
        return mask

    def _optimize(self, x, init_z, num_viz, mask=None, use_regularizer=True):
        """Optimizes a latent code so that its synthesis matches `x`.

        Args:
          x: Target image tensor with a leading batch axis of size 1.
          init_z: Initial latent code (numpy array).
          num_viz: Number of intermediate outputs to collect.
          mask: Optional tensor multiplied into the pixel residual and into
            both inputs of the perceptual model. `None` means no masking.
          use_regularizer: Whether the encoder regularization loss may be
            applied (it is still skipped when its weight is 0).

        Returns:
          A two-element tuple: the optimized code and the list of collected
          images (target, initial reconstruction, then periodic snapshots).
        """
        z = torch.Tensor(init_z).to(self.run_device)
        z.requires_grad = True
        optimizer = torch.optim.Adam([z], lr=self.learning_rate)

        viz_results = [self.G.postprocess(_get_tensor_value(x))[0]]
        # The initial reconstruction is for display only.
        with torch.no_grad():
            x_init_inv = self.G.net.synthesis(z)
        viz_results.append(
            self.G.postprocess(_get_tensor_value(x_init_inv))[0])

        # The target never changes, so its perceptual features are computed
        # once instead of at every step.
        x_feat = None
        if self.loss_feat_weight:
            with torch.no_grad():
                x_feat = self.F.net(x if mask is None else x * mask)

        pbar = tqdm(range(1, self.iteration + 1), leave=True)
        for step in pbar:
            # Reconstruction loss.
            x_rec = self.G.net.synthesis(z)
            residual = x - x_rec
            if mask is not None:
                residual = residual * mask
            loss_pix = torch.mean(residual ** 2)
            loss = loss_pix * self.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss.
            if self.loss_feat_weight:
                x_rec_feat = self.F.net(
                    x_rec if mask is None else x_rec * mask)
                loss_feat = torch.mean((x_feat - x_rec_feat) ** 2)
                loss = loss + loss_feat * self.loss_feat_weight
                log_message += (
                    f', loss_feat: {_get_tensor_value(loss_feat):.3f}')

            # Regularization loss.
            if use_regularizer and self.loss_reg_weight:
                z_rec = self.E.net(x_rec).view(1, *self.encode_dim)
                loss_reg = torch.mean((z - z_rec) ** 2)
                loss = loss + loss_reg * self.loss_reg_weight
                log_message += (
                    f', loss_reg: {_get_tensor_value(loss_reg):.3f}')

            log_message += f', loss: {_get_tensor_value(loss):.3f}'
            pbar.set_description_str(log_message)
            if self.logger:
                self.logger.debug(f'Step: {step:05d}, '
                                  f'lr: {self.learning_rate:.2e}, '
                                  f'{log_message}')

            # Do optimization.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if self._is_viz_step(step, num_viz):
                viz_results.append(
                    self.G.postprocess(_get_tensor_value(x_rec))[0])

        return _get_tensor_value(z), viz_results

    def get_init_code(self, image):
        """Gets initial latent codes as the start point for optimization.

        The input image is assumed to have already been preprocessed, meaning
        to have shape [self.G.image_channels, self.G.resolution,
        self.G.resolution], channel order `self.G.channel_order`, and pixel
        range [self.G.min_val, self.G.max_val].

        Returns:
          A `numpy.float32` array with shape [1, *self.encode_dim].
        """
        x = self.G.to_tensor(image[np.newaxis].astype(np.float32))
        return self._encode(x)

    def invert(self, image, num_viz=0):
        """Inverts the given image to a latent code.

        Basically, this function is based on gradient descent algorithm.

        Args:
          image: Target image to invert, which is assumed to have already
            been preprocessed.
          num_viz: Number of intermediate outputs to visualize. (default: 0)

        Returns:
          A two-element tuple. First one is the inverted code. Second one is
            a list of intermediate results, where first image is the input
            image, second one is the reconstructed result from the initial
            latent code, remainings are from the optimization process every
            `self.iteration // num_viz` steps (every step if that quotient
            would be 0).
        """
        x = self.G.to_tensor(image[np.newaxis].astype(np.float32))
        x.requires_grad = False
        init_z = self.get_init_code(image)
        return self._optimize(x, init_z, num_viz)

    def easy_invert(self, image, num_viz=0):
        """Wraps functions `preprocess()` and `invert()` together."""
        return self.invert(self.preprocess(image), num_viz)

    def diffuse(self,
                target,
                context,
                center_x,
                center_y,
                crop_x,
                crop_y,
                num_viz=0):
        """Diffuses the target image to a context image.

        Basically, this function is a modified version of `self.invert()`.
        More concretely, the encoder regularizer is removed from the
        objectives and the reconstruction loss is computed from the masked
        region.

        Args:
          target: Target image (foreground).
          context: Context image (background).
          center_x: The x-coordinate of the crop center.
          center_y: The y-coordinate of the crop center.
          crop_x: The crop size along the x-axis.
          crop_y: The crop size along the y-axis.
          num_viz: Number of intermediate outputs to visualize. (default: 0)

        Returns:
          A two-element tuple. First one is the inverted code. Second one is
            a list of intermediate results, where first image is the direct
            copy-paste image, second one is the reconstructed result from the
            initial latent code, remainings are from the optimization process
            every `self.iteration // num_viz` steps (every step if that
            quotient would be 0).
        """
        mask = self._build_crop_mask(center_x, center_y, crop_x, crop_y)

        # Paste the target crop onto the context image.
        x = target[np.newaxis] * mask + context[np.newaxis] * (1 - mask)
        x = self.G.to_tensor(x.astype(np.float32))
        x.requires_grad = False
        mask = self.G.to_tensor(mask.astype(np.float32))
        mask.requires_grad = False

        init_z = self._encode(x)
        return self._optimize(x, init_z, num_viz,
                              mask=mask, use_regularizer=False)

    def easy_diffuse(self, target, context, *args, **kwargs):
        """Wraps functions `preprocess()` and `diffuse()` together."""
        return self.diffuse(self.preprocess(target),
                            self.preprocess(context),
                            *args, **kwargs)
def training_loop(config, dataset_args={}, E_lr_args=EasyDict(), D_lr_args=EasyDict(), opt_args=EasyDict(), E_loss_args=EasyDict(), D_loss_args=EasyDict(), logger=None, writer=None, image_snapshot_ticks=50, max_epoch=50):
    """Trains the StyleGAN encoder adversarially against a discriminator.

    For every training batch the discriminator is updated first (real / fake
    / gradient-penalty losses), then the encoder is updated with pixel,
    perceptual and adversarial losses computed on the image re-synthesized
    from the encoded latent code. Reconstructions are periodically written
    to disk and the encoder weights are saved after each epoch.

    NOTE: The container defaults in the signature are shared between calls;
    this function only reads them and never mutates them.

    Args:
      config: Object providing `train_batch_size`, `test_batch_size`,
        `model_name`, `gpu_ids`, `cuda`, `test_save_step`, `save_images`
        and `save_models`.
      dataset_args: Arguments forwarded to `CelebaHQ`.
      E_lr_args: Encoder schedule with `learning_rate`, `decay_rate` and
        `decay_step`.
      D_lr_args: Discriminator schedule with `learning_rate`, `decay_rate`
        and `decay_step`.
      opt_args: Extra keyword arguments for both Adam optimizers.
      E_loss_args: Encoder loss weights `loss_pix_weight`,
        `loss_feat_weight` and `loss_adv_weight`. A weight of 0 disables the
        perceptual / adversarial term.
      D_loss_args: Discriminator loss weights `loss_real_weight`,
        `loss_fake_weight` and `loss_gp_weight`.
      logger: Optional logger; per-step and per-epoch summaries are emitted
        at debug level.
      writer: Optional summary writer exposing `add_scalar`.
      image_snapshot_ticks: Reconstructions are saved every this many steps
        within an epoch. (default: 50)
      max_epoch: Number of epochs to train. (default: 50)
    """
    # Parse loss weights and learning rates.
    loss_pix_weight = E_loss_args.loss_pix_weight
    loss_feat_weight = E_loss_args.loss_feat_weight
    loss_adv_weight = E_loss_args.loss_adv_weight
    loss_real_weight = D_loss_args.loss_real_weight
    loss_fake_weight = D_loss_args.loss_fake_weight
    loss_gp_weight = D_loss_args.loss_gp_weight
    E_learning_rate = E_lr_args.learning_rate
    D_learning_rate = D_lr_args.learning_rate

    # Construct dataloaders.
    train_dataset = CelebaHQ(dataset_args, train=True)
    val_dataset = CelebaHQ(dataset_args, train=False)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.train_batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.test_batch_size,
                                shuffle=False)

    # Construct models. Only the encoder and discriminator are trained.
    G = StyleGANGenerator(config.model_name, logger, gpu_ids=config.gpu_ids)
    E = StyleGANEncoder(config.model_name, logger, gpu_ids=config.gpu_ids)
    F = PerceptualModel(min_val=G.min_val, max_val=G.max_val,
                        gpu_ids=config.gpu_ids)
    D = StyleGANDiscriminator(config.model_name, logger,
                              gpu_ids=config.gpu_ids)
    G.net.synthesis.eval()
    E.net.train()
    F.net.eval()
    D.net.train()

    encode_dim = [G.num_layers, G.w_space_dim]

    # Optimizers and exponential learning-rate schedules.
    optimizer_E = torch.optim.Adam(E.net.parameters(), lr=E_learning_rate,
                                   **opt_args)
    optimizer_D = torch.optim.Adam(D.net.parameters(), lr=D_learning_rate,
                                   **opt_args)
    lr_scheduler_E = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer_E, gamma=E_lr_args.decay_rate)
    lr_scheduler_D = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer_D, gamma=D_lr_args.decay_rate)

    # Guard the epoch averages against an empty training loader.
    num_batches = max(len(train_dataloader), 1)

    global_step = 0
    for epoch in range(max_epoch):
        E_loss_rec = 0.
        E_loss_adv = 0.
        E_loss_feat = 0.
        D_loss_real = 0.
        D_loss_fake = 0.
        D_loss_grad = 0.
        # Read the rate from the optimizer itself; calling the scheduler's
        # `get_lr()` outside of `step()` is deprecated and can report a
        # value that is one decay ahead.
        learning_rate = optimizer_E.param_groups[0]['lr']

        for step, items in enumerate(train_dataloader):
            E.net.train()
            learning_rate = optimizer_E.param_groups[0]['lr']
            x = items.float().cuda()
            batch_size = x.shape[0]
            z = E.net(x).view(batch_size, *encode_dim)
            x_rec = G.net.synthesis(z)

            # ===============================
            #         optimizing D
            # ===============================
            # The reconstruction is detached so that this update does not
            # back-propagate into the encoder.
            x_real = D.net(x)
            x_fake = D.net(x_rec.detach())
            loss_real = GAN_loss(x_real, real=True)
            loss_fake = GAN_loss(x_fake, real=False)
            # Gradient penalty.
            loss_gp = div_loss_(D, x, x_rec.detach(), cuda=config.cuda)

            D_loss_real += loss_real.item()
            D_loss_fake += loss_fake.item()
            D_loss_grad += loss_gp.item()
            log_message = (f'D-[real:{loss_real.item():.3f}, '
                           f'fake:{loss_fake.item():.3f}, '
                           f'gp:{loss_gp.item():.3f}]')
            D_loss = (loss_real_weight * loss_real +
                      loss_fake_weight * loss_fake +
                      loss_gp_weight * loss_gp)
            optimizer_D.zero_grad()
            D_loss.backward()
            optimizer_D.step()

            # ===============================
            #         optimizing G
            # ===============================
            # Reconstruction loss.
            loss_pix = torch.mean((x - x_rec) ** 2)
            E_loss_rec += loss_pix.item()
            log_message += f', G-[pix:{loss_pix.item():.3f}'

            # Perceptual loss (stays a plain float 0. when disabled).
            loss_feat = 0.
            if loss_feat_weight:
                x_feat = F.net(x)
                x_rec_feat = F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat) ** 2)
                E_loss_feat += loss_feat.item()
                log_message += f', feat:{loss_feat.item():.3f}'

            # Adversarial loss (stays a plain float 0. when disabled).
            loss_adv = 0.
            if loss_adv_weight:
                x_adv = D.net(x_rec)
                loss_adv = GAN_loss(x_adv, real=True)
                E_loss_adv += loss_adv.item()
                log_message += f', adv:{loss_adv.item():.3f}'
            # Close the `G-[` group whether or not the adversarial term ran.
            log_message += ']'

            E_loss = (loss_pix_weight * loss_pix +
                      loss_feat_weight * loss_feat +
                      loss_adv_weight * loss_adv)
            log_message += f', loss:{E_loss.item():.3f}'
            optimizer_E.zero_grad()
            E_loss.backward()
            optimizer_E.step()

            if logger:
                logger.debug(f'Epoch:{epoch:03d}, '
                             f'Step:{step:04d}, '
                             f'lr:{learning_rate:.2e}, '
                             f'{log_message}')
            if writer:
                # `float()` handles both tensors and the float placeholders
                # used when a loss term is disabled.
                scalars = (
                    ('D/loss_real', loss_real),
                    ('D/loss_fake', loss_fake),
                    ('D/loss_gp', loss_gp),
                    ('D/loss', D_loss),
                    ('E/loss_pix', loss_pix),
                    ('E/loss_feat', loss_feat),
                    ('E/loss_adv', loss_adv),
                    ('E/loss', E_loss),
                )
                for tag, value in scalars:
                    writer.add_scalar(tag, float(value),
                                      global_step=global_step)

            if step % image_snapshot_ticks == 0:
                E.net.eval()
                # Snapshots are for inspection only; skip autograd.
                with torch.no_grad():
                    for val_step, val_items in enumerate(val_dataloader):
                        if val_step > config.test_save_step:
                            break
                        x_val = val_items.float().cuda()
                        batch_size_val = x_val.shape[0]
                        # The training batch may be smaller than the
                        # validation batch, so use the actual slice size.
                        x_train = x[:batch_size_val]
                        batch_size_train = x_train.shape[0]
                        z_train = E.net(x_train).view(batch_size_train,
                                                      *encode_dim)
                        x_rec_train = G.net.synthesis(z_train)
                        z_val = E.net(x_val).view(batch_size_val,
                                                  *encode_dim)
                        x_rec_val = G.net.synthesis(z_val)
                        x_all = torch.cat(
                            [x_val, x_rec_val, x_train, x_rec_train], dim=0)
                        save_filename = (f'epoch_{epoch:03d}_step_{step:04d}'
                                         f'_test_{val_step:04d}.png')
                        save_filepath = os.path.join(config.save_images,
                                                     save_filename)
                        # The path is passed positionally because its
                        # keyword name differs across torchvision versions.
                        tvutils.save_image(x_all, save_filepath,
                                           nrow=config.test_batch_size,
                                           normalize=True, scale_each=True)
                E.net.train()

            global_step += 1
            if (global_step + 1) % E_lr_args.decay_step == 0:
                lr_scheduler_E.step()
            if (global_step + 1) % D_lr_args.decay_step == 0:
                lr_scheduler_D.step()

        D_loss_real /= num_batches
        D_loss_fake /= num_batches
        D_loss_grad /= num_batches
        E_loss_rec /= num_batches
        E_loss_adv /= num_batches
        E_loss_feat /= num_batches
        log_message_ep = (f'D-[real:{D_loss_real:.3f}, '
                          f'fake:{D_loss_fake:.3f}, '
                          f'gp:{D_loss_grad:.3f}], '
                          f'G-[pix:{E_loss_rec:.3f}, '
                          f'feat:{E_loss_feat:.3f}, '
                          f'adv:{E_loss_adv:.3f}]')
        if logger:
            logger.debug(f'Epoch: {epoch:03d}, '
                         f'lr: {learning_rate:.2e}, '
                         f'{log_message_ep}')

        save_filename = f'styleganinv_encoder_epoch_{epoch:03d}'
        save_filepath = os.path.join(config.save_models, save_filename)
        # Unwrap a parallel wrapper when present; otherwise save the network
        # directly.
        encoder_net = getattr(E.net, 'module', E.net)
        torch.save(encoder_net.state_dict(), save_filepath)