Example #1
File: inverter.py  Project: NivC/TediGAN
    def __init__(self,
                 model_name,
                 mode='man',
                 learning_rate=1e-2,
                 iteration=100,
                 reconstruction_loss_weight=1.0,
                 perceptual_loss_weight=5e-5,
                 regularization_loss_weight=2.0,
                 clip_loss_weight=None,
                 description=None,
                 logger=None):
        """Initializes the inverter.

    NOTE: Only the Adam optimizer is supported in the optimization process.

    Args:
      model_name: Name of the model on which the inverter is based. The model
        should be first registered in `models/model_settings.py`.
      mode: Running mode of the inverter. (default: 'man')
      learning_rate: Learning rate for optimization. (default: 1e-2)
      iteration: Number of iterations for optimization. (default: 100)
      reconstruction_loss_weight: Weight for reconstruction loss. Should always
        be a positive number. (default: 1.0)
      perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual
        loss. (default: 5e-5)
      regularization_loss_weight: Weight for regularization loss from encoder.
        This is essential for in-domain inversion. However, this loss will be
        automatically ignored if the generative model does not include a valid
        encoder. 0 disables regularization loss. (default: 2.0)
      clip_loss_weight: Weight for CLIP loss. None disables CLIP loss.
      description: Text description for the target image, used to compute the
        CLIP loss. Required when `clip_loss_weight` is set.
      logger: Logger to record the log message.
    """

        if clip_loss_weight:
            self.text_inputs = torch.cat([clip.tokenize(description)]).cuda()
            self.clip_loss = CLIPLoss()

        self.mode = mode
        self.logger = logger
        self.model_name = model_name
        self.gan_type = 'stylegan'

        self.G = StyleGANGenerator(self.model_name, self.logger)
        self.E = StyleGANEncoder(self.model_name, self.logger)
        self.F = PerceptualModel(min_val=self.G.min_val,
                                 max_val=self.G.max_val)
        self.encode_dim = [self.G.num_layers, self.G.w_space_dim]
        self.run_device = self.G.run_device
        assert list(self.encode_dim) == list(self.E.encode_dim)

        assert self.G.gan_type == self.gan_type
        assert self.E.gan_type == self.gan_type

        self.learning_rate = learning_rate
        self.iteration = iteration
        self.loss_pix_weight = reconstruction_loss_weight
        self.loss_feat_weight = perceptual_loss_weight
        self.loss_reg_weight = regularization_loss_weight
        self.loss_clip_weight = clip_loss_weight
        assert self.loss_pix_weight > 0
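
A minimal usage sketch for this text-guided variant. The model name, image
path, and description are illustrative assumptions, and the snippet assumes
this `__init__` belongs to the same `StyleGANInverter` class whose
`easy_invert()` helper is shown in Example #2:

    import cv2

    # 'styleganinv_ffhq256' is an assumed name; the model must be registered
    # in `models/model_settings.py` first.
    inverter = StyleGANInverter('styleganinv_ffhq256',
                                iteration=100,
                                clip_loss_weight=2.0,
                                description='a smiling young woman')
    image = cv2.imread('face.jpg')[:, :, ::-1].copy()  # BGR -> RGB, uint8
    latent_code, viz_results = inverter.easy_invert(image, num_viz=4)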
Example #2

class StyleGANInverter(object):
    """Defines the class for StyleGAN inversion.

  Even with the encoder, the output latent code is not good enough to recover
  the target image satisfyingly. To this end, this class optimizes the latent
  code with a gradient descent algorithm. The optimization process considers
  the following loss functions:

  (1) Pixel-wise reconstruction loss. (required)
  (2) Perceptual loss. (optional, but recommended)
  (3) Regularization loss from encoder. (optional, but recommended for in-domain
      inversion)

  NOTE: The encoder can be missing for inversion, in which case the latent code
  will be randomly initialized and the regularization loss will be ignored.
  """
    def __init__(self,
                 model_name,
                 learning_rate=1e-2,
                 iteration=100,
                 reconstruction_loss_weight=1.0,
                 perceptual_loss_weight=5e-5,
                 regularization_loss_weight=2.0,
                 logger=None):
        """Initializes the inverter.

    NOTE: Only the Adam optimizer is supported in the optimization process.

    Args:
      model_name: Name of the model on which the inverter is based. The model
        should be first registered in `models/model_settings.py`.
      learning_rate: Learning rate for optimization. (default: 1e-2)
      iteration: Number of iterations for optimization. (default: 100)
      reconstruction_loss_weight: Weight for reconstruction loss. Should always
        be a positive number. (default: 1.0)
      perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual
        loss. (default: 5e-5)
      regularization_loss_weight: Weight for regularization loss from encoder.
        This is essential for in-domain inversion. However, this loss will be
        automatically ignored if the generative model does not include a valid
        encoder. 0 disables regularization loss. (default: 2.0)
      logger: Logger to record the log message.
    """
        self.logger = logger
        self.model_name = model_name
        self.gan_type = 'stylegan'

        self.G = StyleGANGenerator(self.model_name, self.logger)
        self.E = StyleGANEncoder(self.model_name, self.logger)
        self.F = PerceptualModel(min_val=self.G.min_val,
                                 max_val=self.G.max_val)
        self.encode_dim = [self.G.num_layers, self.G.w_space_dim]
        self.run_device = self.G.run_device
        assert list(self.encode_dim) == list(self.E.encode_dim)

        assert self.G.gan_type == self.gan_type
        assert self.E.gan_type == self.gan_type

        self.learning_rate = learning_rate
        self.iteration = iteration
        self.loss_pix_weight = reconstruction_loss_weight
        self.loss_feat_weight = perceptual_loss_weight
        self.loss_reg_weight = regularization_loss_weight
        assert self.loss_pix_weight > 0

    def preprocess(self, image):
        """Preprocesses a single image.

    This function assumes the input numpy array is with shape [height, width,
    channel], channel order `RGB`, and pixel range [0, 255].

    The returned image is with shape [channel, new_height, new_width], where
    `new_height` and `new_width` are specified by the given generative model.
    The channel order of returned image is also specified by the generative
    model. The pixel range is shifted to [min_val, max_val], where `min_val` and
    `max_val` are also specified by the generative model.
    """
        if not isinstance(image, np.ndarray):
            raise ValueError(
                'Input image should be with type `numpy.ndarray`!')
        if image.dtype != np.uint8:
            raise ValueError(
                'Input image should be with dtype `numpy.uint8`!')

        if image.ndim != 3 or image.shape[2] not in [1, 3]:
            raise ValueError(
                f'Input should be with shape [height, width, channel], '
                f'where channel equals to 1 or 3!\n'
                f'But {image.shape} is received!')
        if image.shape[2] == 1 and self.G.image_channels == 3:
            image = np.tile(image, (1, 1, 3))
        if image.shape[2] != self.G.image_channels:
            raise ValueError(
                f'Number of channels of input image, which is '
                f'{image.shape[2]}, is not supported by the current '
                f'inverter, which requires {self.G.image_channels} '
                f'channels!')

        if self.G.image_channels == 3 and self.G.channel_order == 'BGR':
            image = image[:, :, ::-1]
        if image.shape[:2] != (self.G.resolution, self.G.resolution):
            image = cv2.resize(image, (self.G.resolution, self.G.resolution))
        image = image.astype(np.float32)
        image = image / 255.0 * (self.G.max_val -
                                 self.G.min_val) + self.G.min_val
        image = image.astype(np.float32).transpose(2, 0, 1)

        return image

    def get_init_code(self, image):
        """Gets initial latent codes as the start point for optimization.

    The input image is assumed to have already been preprocessed, i.e., to have
    have shape [self.G.image_channels, self.G.resolution, self.G.resolution],
    channel order `self.G.channel_order`, and pixel range [self.G.min_val,
    self.G.max_val].
    """
        x = image[np.newaxis]
        x = self.G.to_tensor(x.astype(np.float32))
        z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim))
        return z.astype(np.float32)

    def invert(self, image, num_viz=0):
        """Inverts the given image to a latent code.

    Basically, this function optimizes the latent code with a gradient descent
    algorithm.

    Args:
      image: Target image to invert, which is assumed to have already been
        preprocessed.
      num_viz: Number of intermediate outputs to visualize. (default: 0)

    Returns:
      A two-element tuple. The first element is the inverted code. The second is
        a list of intermediate results, where the first image is the input
        image, the second is the reconstruction from the initial latent code,
        and the remaining images are taken from the optimization process every
        `self.iteration // num_viz` steps.
    """
        x = image[np.newaxis]
        x = self.G.to_tensor(x.astype(np.float32))
        x.requires_grad = False
        init_z = self.get_init_code(image)
        z = torch.Tensor(init_z).to(self.run_device)
        z.requires_grad = True

        optimizer = torch.optim.Adam([z], lr=self.learning_rate)

        viz_results = []
        viz_results.append(self.G.postprocess(_get_tensor_value(x))[0])
        x_init_inv = self.G.net.synthesis(z)
        viz_results.append(
            self.G.postprocess(_get_tensor_value(x_init_inv))[0])
        pbar = tqdm(range(1, self.iteration + 1), leave=True)
        for step in pbar:
            loss = 0.0

            # Reconstruction loss.
            x_rec = self.G.net.synthesis(z)
            loss_pix = torch.mean((x - x_rec)**2)
            loss = loss + loss_pix * self.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss.
            if self.loss_feat_weight:
                x_feat = self.F.net(x)
                x_rec_feat = self.F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                loss = loss + loss_feat * self.loss_feat_weight
                log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'

            # Regularization loss.
            if self.loss_reg_weight:
                z_rec = self.E.net(x_rec).view(1, *self.encode_dim)
                loss_reg = torch.mean((z - z_rec)**2)
                loss = loss + loss_reg * self.loss_reg_weight
                log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}'

            log_message += f', loss: {_get_tensor_value(loss):.3f}'
            pbar.set_description_str(log_message)
            if self.logger:
                self.logger.debug(f'Step: {step:05d}, '
                                  f'lr: {self.learning_rate:.2e}, '
                                  f'{log_message}')

            # Do optimization.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if num_viz > 0 and step % max(self.iteration // num_viz, 1) == 0:
                viz_results.append(
                    self.G.postprocess(_get_tensor_value(x_rec))[0])

        return _get_tensor_value(z), viz_results

    def easy_invert(self, image, num_viz=0):
        """Wraps functions `preprocess()` and `invert()` together."""
        return self.invert(self.preprocess(image), num_viz)

    def diffuse(self,
                target,
                context,
                center_x,
                center_y,
                crop_x,
                crop_y,
                num_viz=0):
        """Diffuses the target image to a context image.

    Basically, this function is a modified version of `self.invert()`. More
    concretely, the encoder regularizer is removed from the objective and the
    reconstruction loss is computed only within the masked region.

    Args:
      target: Target image (foreground).
      context: Context image (background).
      center_x: The x-coordinate of the crop center.
      center_y: The y-coordinate of the crop center.
      crop_x: The crop size along the x-axis.
      crop_y: The crop size along the y-axis.
      num_viz: Number of intermediate outputs to visualize. (default: 0)

    Returns:
      A two-element tuple. The first element is the inverted code. The second is
        a list of intermediate results, where the first image is the direct
        copy-paste composite, the second is the reconstruction from the initial
        latent code, and the remaining images are taken from the optimization
        process every `self.iteration // num_viz` steps.
    """
        image_shape = (self.G.image_channels, self.G.resolution,
                       self.G.resolution)
        mask = np.zeros((1, *image_shape), dtype=np.float32)
        xx = center_x - crop_x // 2
        yy = center_y - crop_y // 2
        mask[:, :, yy:yy + crop_y, xx:xx + crop_x] = 1.0
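        # Only the cropped foreground region contributes to the reconstruction
        # loss below; the context enters through the composite input x.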

        target = target[np.newaxis]
        context = context[np.newaxis]
        x = target * mask + context * (1 - mask)
        x = self.G.to_tensor(x.astype(np.float32))
        x.requires_grad = False
        mask = self.G.to_tensor(mask.astype(np.float32))
        mask.requires_grad = False

        init_z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim))
        init_z = init_z.astype(np.float32)
        z = torch.Tensor(init_z).to(self.run_device)
        z.requires_grad = True

        optimizer = torch.optim.Adam([z], lr=self.learning_rate)

        viz_results = []
        viz_results.append(self.G.postprocess(_get_tensor_value(x))[0])
        x_init_inv = self.G.net.synthesis(z)
        viz_results.append(
            self.G.postprocess(_get_tensor_value(x_init_inv))[0])
        pbar = tqdm(range(1, self.iteration + 1), leave=True)
        for step in pbar:
            loss = 0.0

            # Reconstruction loss.
            x_rec = self.G.net.synthesis(z)
            loss_pix = torch.mean(((x - x_rec) * mask)**2)
            loss = loss + loss_pix * self.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss.
            if self.loss_feat_weight:
                x_feat = self.F.net(x * mask)
                x_rec_feat = self.F.net(x_rec * mask)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                loss = loss + loss_feat * self.loss_feat_weight
                log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'

            log_message += f', loss: {_get_tensor_value(loss):.3f}'
            pbar.set_description_str(log_message)
            if self.logger:
                self.logger.debug(f'Step: {step:05d}, '
                                  f'lr: {self.learning_rate:.2e}, '
                                  f'{log_message}')

            # Do optimization.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if num_viz > 0 and step % max(self.iteration // num_viz, 1) == 0:
                viz_results.append(
                    self.G.postprocess(_get_tensor_value(x_rec))[0])

        return _get_tensor_value(z), viz_results

    def easy_diffuse(self, target, context, *args, **kwargs):
        """Wraps functions `preprocess()` and `diffuse()` together."""
        return self.diffuse(self.preprocess(target), self.preprocess(context),
                            *args, **kwargs)
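
A short usage sketch for the class above. The model name and file paths are
assumptions (the model must be registered in `models/model_settings.py`), and
the crop coordinates are illustrative values for a 256x256 model:

    import cv2

    inverter = StyleGANInverter('styleganinv_ffhq256', iteration=100)

    # `easy_invert()` preprocesses a raw RGB uint8 image, then optimizes.
    target = cv2.imread('target.jpg')[:, :, ::-1].copy()   # BGR -> RGB
    latent_code, viz_results = inverter.easy_invert(target, num_viz=4)

    # `easy_diffuse()` pastes a crop of the target onto a context image and
    # inverts the composite (the encoder regularizer is dropped internally).
    context = cv2.imread('context.jpg')[:, :, ::-1].copy()
    latent_code, viz_results = inverter.easy_diffuse(target, context,
                                                     center_x=128, center_y=128,
                                                     crop_x=96, crop_y=96,
                                                     num_viz=4)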
Example #3
def training_loop(config,
                  dataset_args={},
                  E_lr_args=EasyDict(),
                  D_lr_args=EasyDict(),
                  opt_args=EasyDict(),
                  E_loss_args=EasyDict(),
                  D_loss_args=EasyDict(),
                  logger=None,
                  writer=None,
                  image_snapshot_ticks=50,
                  max_epoch=50):
    # Parse the loss weights and learning rates from the argument dicts.
    loss_pix_weight = E_loss_args.loss_pix_weight
    loss_feat_weight = E_loss_args.loss_feat_weight
    loss_adv_weight = E_loss_args.loss_adv_weight
    loss_real_weight = D_loss_args.loss_real_weight
    loss_fake_weight = D_loss_args.loss_fake_weight
    loss_gp_weight = D_loss_args.loss_gp_weight
    loss_ep_weight = D_loss_args.loss_ep_weight
    E_learning_rate = E_lr_args.learning_rate
    D_learning_rate = D_lr_args.learning_rate

    # construct dataloader
    train_dataset = CelebaHQ(dataset_args, train=True)
    val_dataset = CelebaHQ(dataset_args, train=False)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.train_batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.test_batch_size,
                                shuffle=False)

    # construct model
    G = StyleGANGenerator(config.model_name, logger, gpu_ids=config.gpu_ids)
    E = StyleGANEncoder(config.model_name, logger, gpu_ids=config.gpu_ids)
    F = PerceptualModel(min_val=G.min_val,
                        max_val=G.max_val,
                        gpu_ids=config.gpu_ids)
    D = StyleGANDiscriminator(config.model_name,
                              logger,
                              gpu_ids=config.gpu_ids)
    G.net.synthesis.eval()
    E.net.train()
    F.net.eval()
    D.net.train()
    encode_dim = [G.num_layers, G.w_space_dim]

    # optimizer
    optimizer_E = torch.optim.Adam(E.net.parameters(),
                                   lr=E_learning_rate,
                                   **opt_args)
    optimizer_D = torch.optim.Adam(D.net.parameters(),
                                   lr=D_learning_rate,
                                   **opt_args)
    lr_scheduler_E = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer_E, gamma=E_lr_args.decay_rate)
    lr_scheduler_D = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer_D, gamma=D_lr_args.decay_rate)

    global_step = 0
    for epoch in range(max_epoch):
        E_loss_rec = 0.
        E_loss_adv = 0.
        E_loss_feat = 0.
        D_loss_real = 0.
        D_loss_fake = 0.
        D_loss_grad = 0.
        learning_rate = lr_scheduler_E.get_lr()[0]
        for step, items in enumerate(train_dataloader):
            E.net.train()
            x = items.float().cuda()
            batch_size = x.shape[0]
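            # Encode the real batch and reconstruct it through the generator;
            # G stays frozen, only E and D are updated in this loop.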
            z = E.net(x).view(batch_size, *encode_dim)
            x_rec = G.net.synthesis(z)

            # ===============================
            #         optimizing D
            # ===============================

            x_real = D.net(x)
            x_fake = D.net(x_rec.detach())
            loss_real = GAN_loss(x_real, real=True)
            loss_fake = GAN_loss(x_fake, real=False)
            # gradient div
            loss_gp = div_loss_(D, x, x_rec.detach(), cuda=config.cuda)
            # loss_gp = div_loss(D, x, x_real)

            D_loss_real += loss_real.item()
            D_loss_fake += loss_fake.item()
            D_loss_grad += loss_gp.item()
            log_message = f'D-[real:{loss_real.cpu().detach().numpy():.3f}, ' \
                          f'fake:{loss_fake.cpu().detach().numpy():.3f}, ' \
                          f'gp:{loss_gp.cpu().detach().numpy():.3f}]'
            D_loss = (loss_real_weight * loss_real +
                      loss_fake_weight * loss_fake +
                      loss_gp_weight * loss_gp)
            optimizer_D.zero_grad()
            D_loss.backward()
            optimizer_D.step()

            # ===============================
            #         optimizing E
            # ===============================
            # Reconstruction loss.
            loss_pix = torch.mean((x - x_rec)**2)
            E_loss_rec += loss_pix.item()
            log_message += f', G-[pix:{loss_pix.cpu().detach().numpy():.3f}'

            # Perceptual loss.
            loss_feat = 0.
            if loss_feat_weight:
                x_feat = F.net(x)
                x_rec_feat = F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                E_loss_feat += loss_feat.item()
                log_message += f', feat:{loss_feat.cpu().detach().numpy():.3f}'

            # Adversarial loss.
            loss_adv = 0.
            if loss_adv_weight:
                x_adv = D.net(x_rec)
                loss_adv = GAN_loss(x_adv, real=True)
                E_loss_adv += loss_adv.item()
                log_message += f', adv:{loss_adv.cpu().detach().numpy():.3f}]'

            E_loss = (loss_pix_weight * loss_pix +
                      loss_feat_weight * loss_feat +
                      loss_adv_weight * loss_adv)
            log_message += f', loss:{E_loss.cpu().detach().numpy():.3f}'
            optimizer_E.zero_grad()
            E_loss.backward()
            optimizer_E.step()

            # pbar.set_description_str(log_message)
            if logger:
                logger.debug(f'Epoch:{epoch:03d}, '
                             f'Step:{step:04d}, '
                             f'lr:{learning_rate:.2e}, '
                             f'{log_message}')
            if writer:
                writer.add_scalar('D/loss_real',
                                  loss_real.item(),
                                  global_step=global_step)
                writer.add_scalar('D/loss_fake',
                                  loss_fake.item(),
                                  global_step=global_step)
                writer.add_scalar('D/loss_gp',
                                  loss_gp.item(),
                                  global_step=global_step)
                writer.add_scalar('D/loss',
                                  D_loss.item(),
                                  global_step=global_step)
                writer.add_scalar('E/loss_pix',
                                  loss_pix.item(),
                                  global_step=global_step)
                writer.add_scalar('E/loss_feat',
                                  loss_feat.item(),
                                  global_step=global_step)
                writer.add_scalar('E/loss_adv',
                                  loss_adv.item(),
                                  global_step=global_step)
                writer.add_scalar('E/loss',
                                  E_loss.item(),
                                  global_step=global_step)

            if step % image_snapshot_ticks == 0:
                E.net.eval()
                for val_step, val_items in enumerate(val_dataloader):
                    if val_step > config.test_save_step:
                        break
                    x_val = val_items.float().cuda()
                    batch_size_val = x_val.shape[0]
                    x_train = x[:batch_size_val, :, :, :]
                    z_train = E.net(x_train).view(batch_size_val, *encode_dim)
                    x_rec_train = G.net.synthesis(z_train)
                    z_val = E.net(x_val).view(batch_size_val, *encode_dim)
                    x_rec_val = G.net.synthesis(z_val)
                    x_all = torch.cat([x_val, x_rec_val, x_train, x_rec_train],
                                      dim=0)
                    save_filename = f'epoch_{epoch:03d}_step_{step:04d}_test_{val_step:04d}.png'
                    save_filepath = os.path.join(config.save_images,
                                                 save_filename)
                    tvutils.save_image(x_all,
                                       save_filepath,
                                       nrow=config.test_batch_size,
                                       normalize=True,
                                       scale_each=True)

            global_step += 1
            if (global_step + 1) % E_lr_args.decay_step == 0:
                lr_scheduler_E.step()
            if (global_step + 1) % D_lr_args.decay_step == 0:
                lr_scheduler_D.step()

        D_loss_real /= len(train_dataloader)
        D_loss_fake /= len(train_dataloader)
        D_loss_grad /= len(train_dataloader)
        E_loss_rec /= len(train_dataloader)
        E_loss_adv /= len(train_dataloader)
        E_loss_feat /= len(train_dataloader)
        log_message_ep = f'D-[real:{D_loss_real:.3f}, fake:{D_loss_fake:.3f}, gp:{D_loss_grad:.3f}], ' \
                         f'G-[pix:{E_loss_rec:.3f}, feat:{E_loss_feat:.3f}, adv:{E_loss_adv:.3f}]'
        if logger:
            logger.debug(f'Epoch: {epoch:03d}, '
                         f'lr: {learning_rate:.2e}, '
                         f'{log_message_ep}')

        save_filename = f'styleganinv_encoder_epoch_{epoch:03d}'
        save_filepath = os.path.join(config.save_models, save_filename)
        torch.save(E.net.module.state_dict(), save_filepath)
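
A hedged invocation sketch for `training_loop()`. Every field below is read by
the function, but the concrete values and paths are assumptions, `EasyDict` is
assumed to come from the `easydict` package, and `dataset_args` is left empty
as a placeholder (the real CelebaHQ dataset presumably needs data paths):

    from easydict import EasyDict

    config = EasyDict(model_name='styleganinv_ffhq256',  # assumed name
                      gpu_ids=[0],
                      cuda=True,
                      train_batch_size=8,
                      test_batch_size=4,
                      test_save_step=1,
                      save_images='work_dir/images',
                      save_models='work_dir/models')
    E_lr_args = EasyDict(learning_rate=1e-4, decay_rate=0.8, decay_step=10000)
    D_lr_args = EasyDict(learning_rate=1e-4, decay_rate=0.8, decay_step=10000)
    E_loss_args = EasyDict(loss_pix_weight=1.0,
                           loss_feat_weight=5e-5,
                           loss_adv_weight=0.1)
    D_loss_args = EasyDict(loss_real_weight=1.0,
                           loss_fake_weight=1.0,
                           loss_gp_weight=5.0,
                           loss_ep_weight=0.001)

    training_loop(config,
                  dataset_args={},
                  E_lr_args=E_lr_args,
                  D_lr_args=D_lr_args,
                  E_loss_args=E_loss_args,
                  D_loss_args=D_loss_args,
                  logger=None,
                  writer=None,
                  image_snapshot_ticks=50,
                  max_epoch=50)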