コード例 #1
0
ファイル: inverter.py プロジェクト: NivC/TediGAN
    def __init__(self,
                 model_name,
                 mode='man',
                 learning_rate=1e-2,
                 iteration=100,
                 reconstruction_loss_weight=1.0,
                 perceptual_loss_weight=5e-5,
                 regularization_loss_weight=2.0,
                 clip_loss_weight=None,
                 description=None,
                 logger=None):
        """Initializes the inverter.

    NOTE: Only Adam optimizer is supported in the optimization process.

    Args:
      model_name: Name of the model on which the inverter is based. The model
        should be first registered in `models/model_settings.py`.
      mode: Mode tag stored as `self.mode`; this constructor does not
        interpret it (presumably read by the editing methods - confirm).
        (default: 'man')
      learning_rate: Learning rate for optimization. (default: 1e-2)
      iteration: Number of iterations for optimization. (default: 100)
      reconstruction_loss_weight: Weight for reconstruction loss. Should always
        be a positive number. (default: 1.0)
      perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual
        loss. (default: 5e-5)
      regularization_loss_weight: Weight for regularization loss from encoder.
        This is essential for in-domain inversion. 0 disables regularization
        loss. Note that the encoder is always constructed here. (default: 2.0)
      clip_loss_weight: Weight for CLIP loss. A falsy value (None or 0)
        disables the CLIP loss entirely. (default: None)
      description: Text prompt that is tokenized for the CLIP loss. Required
        whenever `clip_loss_weight` is truthy. (default: None)
      logger: Logger to record the log message.

    Raises:
      ValueError: If `clip_loss_weight` is set but `description` is None.
    """
        # Fail fast, before any model weights are loaded: tokenizing `None`
        # would otherwise fail deep inside `clip.tokenize`.
        if clip_loss_weight and description is None:
            raise ValueError('`description` is required when '
                             '`clip_loss_weight` is set!')

        # Always define the CLIP attributes so they can be tested for None
        # instead of raising AttributeError when the CLIP loss is disabled.
        self.text_inputs = None
        self.clip_loss = None
        if clip_loss_weight:
            # NOTE: the text tokens are placed on CUDA unconditionally,
            # independent of `self.run_device` (which is set further below).
            self.text_inputs = torch.cat([clip.tokenize(description)]).cuda()
            self.clip_loss = CLIPLoss()

        self.mode = mode
        self.logger = logger
        self.model_name = model_name
        self.gan_type = 'stylegan'

        self.G = StyleGANGenerator(self.model_name, self.logger)
        self.E = StyleGANEncoder(self.model_name, self.logger)
        self.F = PerceptualModel(min_val=self.G.min_val,
                                 max_val=self.G.max_val)
        # Latent code shape: one w vector per synthesis layer.
        self.encode_dim = [self.G.num_layers, self.G.w_space_dim]
        self.run_device = self.G.run_device
        # The encoder must produce codes in exactly the generator's layout.
        assert list(self.encode_dim) == list(self.E.encode_dim)

        assert self.G.gan_type == self.gan_type
        assert self.E.gan_type == self.gan_type

        self.learning_rate = learning_rate
        self.iteration = iteration
        self.loss_pix_weight = reconstruction_loss_weight
        self.loss_feat_weight = perceptual_loss_weight
        self.loss_reg_weight = regularization_loss_weight
        self.loss_clip_weight = clip_loss_weight
        # The pixel reconstruction loss is the one mandatory objective.
        assert self.loss_pix_weight > 0
コード例 #2
0
def main():
    """Main function.

    Loads latent codes from `args.latent_codes_path` when that file exists,
    otherwise samples `args.num` codes. Synthesizes them batch by batch,
    writes every generated image to `config.IMAGE_PATH` as a zero-padded JPEG,
    and saves each remaining output key as `<key>.npy` under
    `config.OUTPUT_PATH`.
    """
    args = parse_args()
    logger = setup_logger(config.OUTPUT_PATH, logger_name='generate_data')

    logger.info('Initializing generator.')
    model_name = config.MODEL_NAME
    gan_type = MODEL_POOL[model_name]['gan_type']
    if gan_type == 'pggan':
        model, kwargs = PGGANGenerator(model_name, logger), {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info('Preparing latent codes.')
    codes_path = args.latent_codes_path
    if not os.path.isfile(codes_path):
        logger.info('  Sample latent codes randomly.')
        latent_codes = model.easy_sample(args.num, **kwargs)
    else:
        logger.info(f'  Load latent codes from `{codes_path}`.')
        latent_codes = model.preprocess(np.load(codes_path), **kwargs)
    # NOTE(review): the latent codes are not written separately here; whether
    # they reach disk depends on which keys `easy_synthesize` returns.
    total_num = latent_codes.shape[0]

    logger.info(f'Generating {total_num} samples.')
    results = defaultdict(list)
    pbar = tqdm(total=total_num, leave=False)
    for batch in model.get_batch_inputs(latent_codes):
        if gan_type == 'pggan':
            outputs = model.easy_synthesize(batch)
        else:  # 'stylegan'; every other type was rejected above.
            outputs = model.easy_synthesize(batch,
                                            **kwargs,
                                            generate_style=args.generate_style,
                                            generate_image=args.generate_image)
        for key, val in outputs.items():
            if key != 'image':
                # Non-image outputs are accumulated and saved at the end.
                results[key].append(val)
                continue
            # Images are numbered by the running progress count; the channel
            # axis is reversed before handing them to OpenCV.
            for image in val:
                target = os.path.join(config.IMAGE_PATH, f'{pbar.n:06d}.jpg')
                cv2.imwrite(target, image[:, :, ::-1])
                pbar.update(1)
        if 'image' not in outputs:
            pbar.update(batch.shape[0])
        if pbar.n == total_num or pbar.n % 1000 == 0:
            logger.debug(f'  Finish {pbar.n:6d} samples.')
    pbar.close()

    logger.info('Saving results.')
    # Keys seen in practice: z, w, wp, f'style{i:02d}', ...
    for key, chunks in results.items():
        np.save(os.path.join(config.OUTPUT_PATH, f'{key}.npy'),
                np.concatenate(chunks, axis=0))
コード例 #3
0
def main():
    """Main function.

    Loads the boundary at `args.boundary_path`, prepares latent codes (loaded
    from `args.input_latent_codes_path` when it exists, otherwise sampled),
    and for every code writes `args.steps` images produced by
    `linear_interpolate` along the boundary to
    `<output_dir>/<sample_id:03d>_<interpolation_id:03d>.jpg`.

    Raises:
      NotImplementedError: If the model's GAN type is unsupported.
      ValueError: If the boundary file does not exist.
    """
    args = parse_args()
    logger = setup_logger(args.output_dir, logger_name='generate_data')

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info(f'Preparing boundary.')
    if not os.path.isfile(args.boundary_path):
        raise ValueError(f'Boundary `{args.boundary_path}` does not exist!')
    boundary = np.load(args.boundary_path)
    np.save(os.path.join(args.output_dir, 'boundary.npy'), boundary)

    logger.info(f'Preparing latent codes.')
    if os.path.isfile(args.input_latent_codes_path):
        logger.info(
            f'  Load latent codes from `{args.input_latent_codes_path}`.')
        latent_codes = np.load(args.input_latent_codes_path)
        latent_codes = model.preprocess(latent_codes, **kwargs)
    else:
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.easy_sample(args.num, **kwargs)
    np.save(os.path.join(args.output_dir, 'latent_codes.npy'), latent_codes)
    total_num = latent_codes.shape[0]

    logger.info(f'Editing {total_num} samples.')
    for sample_id in tqdm(range(total_num), leave=False):
        interpolations = linear_interpolate(
            latent_codes[sample_id:sample_id + 1],
            boundary,
            start_distance=args.start_distance,
            end_distance=args.end_distance,
            steps=args.steps)
        interpolation_id = 0
        for interpolations_batch in model.get_batch_inputs(interpolations):
            # Synthesize the whole batch in one call: `get_batch_inputs`
            # already sizes the batches for the model, so synthesizing one
            # row at a time only added per-call overhead.
            if gan_type == 'pggan':
                outputs = model.easy_synthesize(interpolations_batch)
            else:  # 'stylegan'; every other type was rejected above.
                outputs = model.easy_synthesize(interpolations_batch,
                                                **kwargs)
            for image in outputs['image']:
                save_path = os.path.join(
                    args.output_dir,
                    f'{sample_id:03d}_{interpolation_id:03d}.jpg')
                # Reverse the channel axis before writing with OpenCV.
                cv2.imwrite(save_path, image[:, :, ::-1])
                interpolation_id += 1
        # Sanity check: exactly one image per interpolation step.
        assert interpolation_id == args.steps
        logger.debug(f'  Finished sample {sample_id:3d}.')
    logger.info(f'Successfully edited {total_num} samples.')
コード例 #4
0
ファイル: demo_app.py プロジェクト: XenonLamb/interfacegan
def load_model(model_name, latent_space_type=None, logger=None):
    """Builds the generator for `model_name` while showing a status message.

    Args:
      model_name: Key into `MODEL_POOL` selecting the model to load.
      latent_space_type: Stored in the returned kwargs for StyleGAN models;
        ignored for PGGAN. (default: None)
      logger: Logger passed through to the generator constructor.

    Returns:
      A `(model, kwargs)` tuple. `kwargs` is `{}` for PGGAN and
      `{'latent_space_type': latent_space_type}` for StyleGAN (presumably
      forwarded by callers to the generator's sampling/synthesis methods).

    Raises:
      NotImplementedError: If the model's GAN type is not supported.
    """
    model_load_state = st.text('Loading GAN model...')
    try:
        gan_type = MODEL_POOL[model_name]['gan_type']
        if gan_type == 'pggan':
            model = PGGANGenerator(model_name, logger)
            kwargs = {}
        elif gan_type == 'stylegan':
            model = StyleGANGenerator(model_name, logger)
            kwargs = {'latent_space_type': latent_space_type}
        else:
            raise NotImplementedError(
                f'Not implemented GAN type `{gan_type}`!')
        print('loading', model_name)
    finally:
        # Clear the status line even when loading fails, so the page does not
        # keep claiming that the model is still loading.
        model_load_state.empty()
    return model, kwargs
コード例 #5
0
def main():
    """Main function.

    Builds the generator selected by `args.model_name` and writes the train,
    test and validation image splits, in that order, with `write_images`.
    """
    args = parse_args()
    logger = setup_logger(logger_name='latent_train')

    logger.info('Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type not in ('pggan', 'stylegan'):
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')
    if gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        gan_kwargs = {'latent_space_type': args.latent_space_type}
    else:
        model = PGGANGenerator(args.model_name, logger)
        gan_kwargs = {}

    # (name of the argument holding the sample count, output directory).
    for count_attr, out_dir in (('num_train', './train'),
                                ('num_test', './test'),
                                ('num_val', './validation')):
        write_images(model, getattr(args, count_attr), out_dir,
                     gan_kwargs=gan_kwargs)
コード例 #6
0
def main():
    """Main function.

    Builds the generator, restores the most recent VGG9 checkpoint found in
    `./checkpoints/`, then visualizes either the latent codes in
    `args.codes_file` or sample predictions of the restored network.

    Returns:
      1 if no checkpoint is found; otherwise None.

    Raises:
      NotImplementedError: If the model's GAN type is unsupported.
    """
    args = parse_args()
    logger = setup_logger(logger_name='latent_train')

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info(f'Preparing VGG.')
    # Build the torchvision VGG16-BN architecture. Called without
    # `weights`/`pretrained` it is randomly initialized; the real weights come
    # from the checkpoint restored below. (Earlier commented-out variants used
    # `models.vgg16()` or `vgg_face_dag.Vgg_face_dag()` here.)
    stock_vgg = models.vgg16_bn()
    vgg = VGG9(stock_vgg)

    start_epoch, latest = find_latest_epoch_and_checkpoint("./checkpoints/")
    if not latest:
        print("No checkpoint found! exiting")
        return 1
    print(f"restoring from checkpoint: {latest} (epoch: {start_epoch})")
    # Deserialize onto the CPU so that a checkpoint saved from a GPU also
    # loads on a CPU-only machine; the module is moved to the GPU below when
    # requested.
    vgg.load_state_dict(torch.load(latest, map_location='cpu'))

    if args.use_gpu:
        vgg.cuda()

    if args.codes_file:
        visualize_latent_codes(model, args.codes_file, **kwargs)
    else:
        print("Visualizing model...\n")
        visualize_model(vgg, model, 999, num_examples=4, **kwargs)
コード例 #7
0
ファイル: processor.py プロジェクト: countofkrakow/SpaceFace
def setup():
    """Populates the module-level globals used by this processor.

    Sets `cfg` (output dir and model name), `logger`, `model` (a
    `StyleGANGenerator` built from `cfg['model_name']`), `kwargs` (latent
    space type 'Wp') and `boundaries` (attribute name -> boundary `.npy`
    path, relative to the current working directory).
    """
    global logger, model, boundaries, cfg, kwargs

    cfg = {'output_dir': 'results', 'model_name': 'stylegan_ffhq'}
    logger = setup_logger(logger_name='generate_data')
    model = StyleGANGenerator(cfg['model_name'], logger)
    kwargs = {'latent_space_type': 'Wp'}

    # Attribute name -> InterFaceGAN boundary file.
    boundaries = {
        'age':
            'interfacegan/boundaries/stylegan_ffhq_age_w_boundary.npy',
        'gender':
            'interfacegan/boundaries/stylegan_ffhq_gender_w_boundary.npy',
        'pose':
            'interfacegan/boundaries/stylegan_ffhq_pose_w_boundary.npy',
        'smile':
            'interfacegan/boundaries/stylegan_ffhq_smile_w_boundary.npy',
        'glasses':
            'interfacegan/boundaries/stylegan_ffhq_eyeglasses_w_boundary.npy',
        'beauty':
            'interfacegan/boundaries/beauty_boundary.npy',
    }
コード例 #8
0
class StyleGANInverter(object):
    """Defines the class for StyleGAN inversion.

  Even having the encoder, the output latent code is not good enough to recover
  the target image satisfyingly. To this end, this class optimizes the latent
  code based on gradient descent algorithm. In the optimization process,
  following loss functions will be considered:

  (1) Pixel-wise reconstruction loss. (required)
  (2) Perceptual loss. (optional, but recommended)
  (3) Regularization loss from encoder. (optional, but recommended for in-domain
      inversion)

  NOTE: This implementation always constructs a `StyleGANEncoder` and uses it
  to produce the initial latent code. Setting `regularization_loss_weight` to 0
  only disables the encoder-based regularization loss.
  """
    def __init__(self,
                 model_name,
                 learning_rate=1e-2,
                 iteration=100,
                 reconstruction_loss_weight=1.0,
                 perceptual_loss_weight=5e-5,
                 regularization_loss_weight=2.0,
                 logger=None):
        """Initializes the inverter.

    NOTE: Only Adam optimizer is supported in the optimization process.

    Args:
      model_name: Name of the model on which the inverter is based. The model
        should be first registered in `models/model_settings.py`.
      logger: Logger to record the log message.
      learning_rate: Learning rate for optimization. (default: 1e-2)
      iteration: Number of iterations for optimization. (default: 100)
      reconstruction_loss_weight: Weight for reconstruction loss. Should always
        be a positive number. (default: 1.0)
      perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual
        loss. (default: 5e-5)
      regularization_loss_weight: Weight for regularization loss from encoder.
        This is essential for in-domain inversion. 0 disables regularization
        loss. (default: 2.0)
    """
        self.logger = logger
        self.model_name = model_name
        self.gan_type = 'stylegan'

        self.G = StyleGANGenerator(self.model_name, self.logger)
        self.E = StyleGANEncoder(self.model_name, self.logger)
        self.F = PerceptualModel(min_val=self.G.min_val,
                                 max_val=self.G.max_val)
        # Latent code shape: one w vector per synthesis layer.
        self.encode_dim = [self.G.num_layers, self.G.w_space_dim]
        self.run_device = self.G.run_device
        # The encoder must emit codes in exactly the generator's layout.
        assert list(self.encode_dim) == list(self.E.encode_dim)

        assert self.G.gan_type == self.gan_type
        assert self.E.gan_type == self.gan_type

        self.learning_rate = learning_rate
        self.iteration = iteration
        self.loss_pix_weight = reconstruction_loss_weight
        self.loss_feat_weight = perceptual_loss_weight
        self.loss_reg_weight = regularization_loss_weight
        assert self.loss_pix_weight > 0

    def preprocess(self, image):
        """Preprocesses a single image.

    This function assumes the input numpy array is with shape [height, width,
    channel], channel order `RGB`, and pixel range [0, 255].

    The returned image is with shape [channel, new_height, new_width], where
    `new_height` and `new_width` are specified by the given generative model.
    The channel order of returned image is also specified by the generative
    model. The pixel range is shifted to [min_val, max_val], where `min_val` and
    `max_val` are also specified by the generative model.

    Raises:
      ValueError: If the input is not a uint8 numpy array of shape
        [height, width, channel] with a channel count the generator supports.
    """
        if not isinstance(image, np.ndarray):
            raise ValueError(
                f'Input image should be with type `numpy.ndarray`!')
        if image.dtype != np.uint8:
            raise ValueError(
                f'Input image should be with dtype `numpy.uint8`!')

        if image.ndim != 3 or image.shape[2] not in [1, 3]:
            raise ValueError(
                f'Input should be with shape [height, width, channel], '
                f'where channel equals to 1 or 3!\n'
                f'But {image.shape} is received!')
        if image.shape[2] == 1 and self.G.image_channels == 3:
            image = np.tile(image, (1, 1, 3))
        if image.shape[2] != self.G.image_channels:
            raise ValueError(
                f'Number of channels of input image, which is '
                f'{image.shape[2]}, is not supported by the current '
                f'inverter, which requires {self.G.image_channels} '
                f'channels!')

        if self.G.image_channels == 3 and self.G.channel_order == 'BGR':
            image = image[:, :, ::-1]
        # The image is HWC here, so its spatial size is `shape[:2]`; compare
        # as a tuple so an already correctly sized image is not resized.
        target_size = (self.G.resolution, self.G.resolution)
        if image.shape[:2] != target_size:
            image = cv2.resize(image, target_size)
            # `cv2.resize` drops a trailing singleton channel axis; restore it
            # so the HWC -> CHW transpose below still works.
            if image.ndim == 2:
                image = image[:, :, np.newaxis]
        image = image.astype(np.float32)
        image = image / 255.0 * (self.G.max_val -
                                 self.G.min_val) + self.G.min_val
        image = image.astype(np.float32).transpose(2, 0, 1)

        return image

    def get_init_code(self, image):
        """Gets initial latent codes as the start point for optimization.

    The input image is assumed to have already been preprocessed, meaning to
    have shape [self.G.image_channels, self.G.resolution, self.G.resolution],
    channel order `self.G.channel_order`, and pixel range [self.G.min_val,
    self.G.max_val].

    Returns:
      A float32 numpy array of shape [1, *self.encode_dim].
    """
        x = image[np.newaxis]
        x = self.G.to_tensor(x.astype(np.float32))
        z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim))
        return z.astype(np.float32)

    def invert(self, image, num_viz=0):
        """Inverts the given image to a latent code.

    Basically, this function is based on gradient descent algorithm.

    Args:
      image: Target image to invert, which is assumed to have already been
        preprocessed.
      num_viz: Number of intermediate outputs to visualize. (default: 0)

    Returns:
      A two-element tuple. First one is the inverted code. Second one is a list
        of intermediate results, where first image is the input image, second
        one is the reconstructed result from the initial latent code, remainings
        are from the optimization process every
        `max(1, self.iteration // num_viz)` steps.
    """
        x = image[np.newaxis]
        x = self.G.to_tensor(x.astype(np.float32))
        x.requires_grad = False
        init_z = self.get_init_code(image)
        z = torch.Tensor(init_z).to(self.run_device)
        z.requires_grad = True

        optimizer = torch.optim.Adam([z], lr=self.learning_rate)

        viz_results = []
        viz_results.append(self.G.postprocess(_get_tensor_value(x))[0])
        # The initial reconstruction is only visualized, so do not build an
        # autograd graph for it.
        with torch.no_grad():
            x_init_inv = self.G.net.synthesis(z)
        viz_results.append(
            self.G.postprocess(_get_tensor_value(x_init_inv))[0])
        # Clamp to 1 so that `num_viz > self.iteration` cannot divide by zero.
        viz_interval = max(1, self.iteration // num_viz) if num_viz > 0 else 0
        pbar = tqdm(range(1, self.iteration + 1), leave=True)
        for step in pbar:
            loss = 0.0

            # Reconstruction loss.
            x_rec = self.G.net.synthesis(z)
            loss_pix = torch.mean((x - x_rec)**2)
            loss = loss + loss_pix * self.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss.
            if self.loss_feat_weight:
                x_feat = self.F.net(x)
                x_rec_feat = self.F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                loss = loss + loss_feat * self.loss_feat_weight
                log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'

            # Regularization loss: the re-encoded reconstruction should map
            # back to the current code.
            if self.loss_reg_weight:
                z_rec = self.E.net(x_rec).view(1, *self.encode_dim)
                loss_reg = torch.mean((z - z_rec)**2)
                loss = loss + loss_reg * self.loss_reg_weight
                log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}'

            log_message += f', loss: {_get_tensor_value(loss):.3f}'
            pbar.set_description_str(log_message)
            if self.logger:
                self.logger.debug(f'Step: {step:05d}, '
                                  f'lr: {self.learning_rate:.2e}, '
                                  f'{log_message}')

            # Do optimization.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # `x_rec` was synthesized before this step's update.
            if num_viz > 0 and step % viz_interval == 0:
                viz_results.append(
                    self.G.postprocess(_get_tensor_value(x_rec))[0])

        return _get_tensor_value(z), viz_results

    def easy_invert(self, image, num_viz=0):
        """Wraps functions `preprocess()` and `invert()` together."""
        return self.invert(self.preprocess(image), num_viz)

    def diffuse(self,
                target,
                context,
                center_x,
                center_y,
                crop_x,
                crop_y,
                num_viz=0):
        """Diffuses the target image to a context image.

    Basically, this function is a modified version of `self.invert()`. More
    concretely, the encoder regularizer is removed from the objectives and the
    reconstruction loss is computed from the masked region.

    Args:
      target: Target image (foreground).
      context: Context image (background).
      center_x: The x-coordinate of the crop center.
      center_y: The y-coordinate of the crop center.
      crop_x: The crop size along the x-axis.
      crop_y: The crop size along the y-axis.
      num_viz: Number of intermediate outputs to visualize. (default: 0)

    Returns:
      A two-element tuple. First one is the inverted code. Second one is a list
        of intermediate results, where first image is the direct copy-paste
        image, second one is the reconstructed result from the initial latent
        code, remainings are from the optimization process every
        `max(1, self.iteration // num_viz)` steps.
    """
        image_shape = (self.G.image_channels, self.G.resolution,
                       self.G.resolution)
        mask = np.zeros((1, *image_shape), dtype=np.float32)
        xx = center_x - crop_x // 2
        yy = center_y - crop_y // 2
        # Intersect the crop window with the image: a negative slice start
        # would otherwise index from the far edge and corrupt the mask.
        mask[:, :, max(0, yy):max(0, yy + crop_y),
             max(0, xx):max(0, xx + crop_x)] = 1.0

        target = target[np.newaxis]
        context = context[np.newaxis]
        # Copy-paste composite: target inside the crop, context outside.
        x = target * mask + context * (1 - mask)
        x = self.G.to_tensor(x.astype(np.float32))
        x.requires_grad = False
        mask = self.G.to_tensor(mask.astype(np.float32))
        mask.requires_grad = False

        init_z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim))
        init_z = init_z.astype(np.float32)
        z = torch.Tensor(init_z).to(self.run_device)
        z.requires_grad = True

        optimizer = torch.optim.Adam([z], lr=self.learning_rate)

        viz_results = []
        viz_results.append(self.G.postprocess(_get_tensor_value(x))[0])
        # Visualization only; no autograd graph needed.
        with torch.no_grad():
            x_init_inv = self.G.net.synthesis(z)
        viz_results.append(
            self.G.postprocess(_get_tensor_value(x_init_inv))[0])
        # Clamp to 1 so that `num_viz > self.iteration` cannot divide by zero.
        viz_interval = max(1, self.iteration // num_viz) if num_viz > 0 else 0
        pbar = tqdm(range(1, self.iteration + 1), leave=True)
        for step in pbar:
            loss = 0.0

            # Reconstruction loss (masked region only).
            x_rec = self.G.net.synthesis(z)
            loss_pix = torch.mean(((x - x_rec) * mask)**2)
            loss = loss + loss_pix * self.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss (masked region only).
            if self.loss_feat_weight:
                x_feat = self.F.net(x * mask)
                x_rec_feat = self.F.net(x_rec * mask)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                loss = loss + loss_feat * self.loss_feat_weight
                log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'

            log_message += f', loss: {_get_tensor_value(loss):.3f}'
            pbar.set_description_str(log_message)
            if self.logger:
                self.logger.debug(f'Step: {step:05d}, '
                                  f'lr: {self.learning_rate:.2e}, '
                                  f'{log_message}')

            # Do optimization.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if num_viz > 0 and step % viz_interval == 0:
                viz_results.append(
                    self.G.postprocess(_get_tensor_value(x_rec))[0])

        return _get_tensor_value(z), viz_results

    def easy_diffuse(self, target, context, *args, **kwargs):
        """Wraps functions `preprocess()` and `diffuse()` together."""
        return self.diffuse(self.preprocess(target), self.preprocess(context),
                            *args, **kwargs)
コード例 #9
0
def _save_interpolations(model, interpolations, gan_type, kwargs, output_dir,
                         sample_id):
    """Synthesizes `interpolations` batch by batch and saves them as JPEGs.

    Images go to `<output_dir>/<sample_id:03d>_<index:03d>.jpg`, with the index
    counting up from 0 across all batches.

    Returns:
      The number of images written.

    Raises:
      NotImplementedError: If `gan_type` is neither 'pggan' nor 'stylegan'.
    """
    interpolation_id = 0
    for interpolations_batch in model.get_batch_inputs(interpolations):
        if gan_type == 'pggan':
            outputs = model.easy_synthesize(interpolations_batch)
        elif gan_type == 'stylegan':
            outputs = model.easy_synthesize(interpolations_batch, **kwargs)
        else:
            raise NotImplementedError(
                f'Not implemented GAN type `{gan_type}`!')
        for image in outputs['image']:
            save_path = os.path.join(
                output_dir, f'{sample_id:03d}_{interpolation_id:03d}.jpg')
            # Reverse the channel axis (presumably RGB -> BGR for OpenCV).
            cv2.imwrite(save_path, image[:, :, ::-1])
            interpolation_id += 1
    return interpolation_id


def main():
    """Main function.

    Edits latent codes according to `args.task` ('attribute', 'head_pose' or
    'landmark') and, for the attribute task, `args.method` ('interfacegan',
    'linear' or 'ours'). Every synthesized step is saved as a JPEG under
    `<args.output_dir>/<args.task>/`. Unknown task/method combinations are
    skipped without output.

    Raises:
      NotImplementedError: If the model's GAN type is unsupported.
      ValueError: If the boundary file does not exist.
    """
    args = parse_args()

    # Append the task to the output dir path.
    args.output_dir = os.path.join(args.output_dir, args.task)
    # NOTE(review): the directory is not created here; presumably
    # `setup_logger` creates it - confirm.
    logger = setup_logger(args.output_dir, logger_name='generate_data')

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info(f'Preparing boundary.')
    args.boundary_path = process_bound_path(gan_type, args)
    if not os.path.isfile(args.boundary_path):
        raise ValueError(f'Boundary `{args.boundary_path}` does not exist!')
    boundary = np.load(args.boundary_path)
    np.save(os.path.join(args.output_dir, 'boundary.npy'), boundary)

    logger.info(f'Preparing latent codes.')
    if args.demo:
        demo_code(gan_type, args)

    if os.path.isfile(args.input_latent_codes_path):
        logger.info(
            f'  Load latent codes from `{args.input_latent_codes_path}`.')
        latent_codes = np.load(args.input_latent_codes_path)
        logger.debug(f'  Loaded latent codes of shape {latent_codes.shape}.')
        # Only the first code of a multi-code file is edited.
        if len(latent_codes) > 1:
            latent_codes = np.expand_dims(latent_codes[0], axis=0)
        latent_codes = model.preprocess(latent_codes, **kwargs)
    else:
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.easy_sample(args.num, **kwargs)
    np.save(os.path.join(args.output_dir, 'latent_codes.npy'), latent_codes)
    total_num = latent_codes.shape[0]

    attr_index = args.attr_index
    logger.info(f'Editing {total_num} samples.')
    for sample_id in tqdm(range(total_num), leave=False):
        current_code = latent_codes[sample_id:sample_id + 1]
        interpolations = None
        if args.task == 'attribute':
            if args.method == 'interfacegan':
                # Baseline modification from the initial point.
                interpolations = my_linear_interpolate(
                    current_code,
                    attr_index,
                    boundary,
                    'linear',
                    steps=args.steps,
                    gan_type=gan_type,
                    step_size=args.step_size)
            elif args.method == 'linear':
                # Linear baseline attribute modification.
                interpolations = my_linear_interpolate(
                    current_code.reshape(1, -1),
                    attr_index,
                    boundary,
                    'static_linear',
                    steps=args.steps,
                    condition=args.condition,
                    gan_type=gan_type,
                    step_size=args.step_size)
            elif args.method == 'ours':
                # Piecewise-linear attribute modification.
                interpolations = my_linear_interpolate(
                    current_code.reshape(1, -1),
                    attr_index,
                    boundary,
                    'piecewise_linear',
                    steps=args.steps,
                    condition=args.condition,
                    gan_type=gan_type,
                    step_size=args.step_size)
        elif args.task == 'head_pose':
            # Pose modification.
            interpolations = my_linear_interpolate(
                current_code.reshape(1, -1),
                attr_index,
                boundary,
                'pose_edit',
                steps=args.steps,
                condition=args.condition,
                gan_type=gan_type,
                step_size=args.step_size,
                direction=args.direction)
        elif args.task == 'landmark':
            # Landmark modification.
            # NOTE(review): unlike the other branches, `gan_type` is not
            # passed here - confirm this is intended.
            interpolations = my_linear_interpolate(
                current_code.reshape(1, -1),
                attr_index,
                boundary,
                'piecewise_linear',
                steps=args.steps,
                is_landmark=True,
                condition=args.condition,
                step_size=args.step_size,
                direction=args.direction)

        if interpolations is not None:
            _save_interpolations(model, interpolations, gan_type, kwargs,
                                 args.output_dir, sample_id)
コード例 #10
0
def _save_val_snapshots(E, G, x, val_dataloader, encode_dim, config, epoch,
                        step):
    """Write reconstruction grids for the first few validation batches.

    Each saved image stacks, along the batch axis: the validation inputs,
    their reconstructions, a slice of the current training batch `x`, and
    that slice's reconstructions. Validation batches whose index exceeds
    `config.test_save_step` are not processed.

    Leaves `E.net` in eval mode; `training_loop` switches it back to train
    mode at the start of every step.
    """
    E.net.eval()
    # Snapshots are never back-propagated, so skip building autograd graphs.
    with torch.no_grad():
        for val_step, val_items in enumerate(val_dataloader):
            # Check before running the networks so no work is spent on a
            # batch that will not be saved.
            if val_step > config.test_save_step:
                break
            x_val = val_items.float().cuda()
            batch_size_val = x_val.shape[0]
            # The current training batch can be smaller than a validation
            # batch (e.g. the last batch of an epoch), so size it on its own.
            x_train = x[:batch_size_val]
            batch_size_train = x_train.shape[0]
            z_train = E.net(x_train).view(batch_size_train, *encode_dim)
            x_rec_train = G.net.synthesis(z_train)
            z_val = E.net(x_val).view(batch_size_val, *encode_dim)
            x_rec_val = G.net.synthesis(z_val)
            x_all = torch.cat([x_val, x_rec_val, x_train, x_rec_train], dim=0)
            save_filename = f'epoch_{epoch:03d}_step_{step:04d}_test_{val_step:04d}.png'
            save_filepath = os.path.join(config.save_images, save_filename)
            tvutils.save_image(x_all,
                               filename=save_filepath,
                               nrow=config.test_batch_size,
                               normalize=True,
                               scale_each=True)


def training_loop(config,
                  dataset_args=None,
                  E_lr_args=None,
                  D_lr_args=None,
                  opt_args=None,
                  E_loss_args=None,
                  D_loss_args=None,
                  logger=None,
                  writer=None,
                  image_snapshot_ticks=50,
                  max_epoch=50):
    """Adversarially train the StyleGAN encoder `E` against discriminator `D`.

    The generator's synthesis network stays in eval mode. Every step first
    updates D on (real, reconstructed) images, then updates E on a weighted
    sum of pixel, perceptual and adversarial losses. The encoder weights are
    saved to `config.save_models` after each epoch.

    Args:
      config: run configuration. Fields read here: model_name, gpu_ids,
        train_batch_size, test_batch_size, cuda, test_save_step,
        save_images, save_models.
      dataset_args: forwarded to `CelebaHQ`. Defaults to `{}`.
      E_lr_args, D_lr_args: need `learning_rate`, `decay_rate` and
        `decay_step`. Each scheduler decays once every `decay_step` steps.
      opt_args: extra keyword arguments for `torch.optim.Adam`.
      E_loss_args: need `loss_pix_weight`, `loss_feat_weight` and
        `loss_adv_weight`. A zero feat/adv weight disables that loss.
      D_loss_args: need `loss_real_weight`, `loss_fake_weight`,
        `loss_gp_weight` and `loss_ep_weight`. `loss_ep_weight` is read but
        not used below.
      logger: optional logger for per-step and per-epoch debug messages.
      writer: optional summary writer exposing `add_scalar`.
      image_snapshot_ticks: save validation snapshots every this many steps.
      max_epoch: number of epochs to train.
    """
    # None sentinels avoid sharing one mutable default object across calls.
    dataset_args = {} if dataset_args is None else dataset_args
    E_lr_args = EasyDict() if E_lr_args is None else E_lr_args
    D_lr_args = EasyDict() if D_lr_args is None else D_lr_args
    opt_args = EasyDict() if opt_args is None else opt_args
    E_loss_args = EasyDict() if E_loss_args is None else E_loss_args
    D_loss_args = EasyDict() if D_loss_args is None else D_loss_args

    # parse
    loss_pix_weight = E_loss_args.loss_pix_weight
    loss_feat_weight = E_loss_args.loss_feat_weight
    loss_adv_weight = E_loss_args.loss_adv_weight
    loss_real_weight = D_loss_args.loss_real_weight
    loss_fake_weight = D_loss_args.loss_fake_weight
    loss_gp_weight = D_loss_args.loss_gp_weight
    loss_ep_weight = D_loss_args.loss_ep_weight  # currently unused
    E_learning_rate = E_lr_args.learning_rate
    D_learning_rate = D_lr_args.learning_rate

    # construct dataloader
    train_dataset = CelebaHQ(dataset_args, train=True)
    val_dataset = CelebaHQ(dataset_args, train=False)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.train_batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.test_batch_size,
                                shuffle=False)

    # construct model
    G = StyleGANGenerator(config.model_name, logger, gpu_ids=config.gpu_ids)
    E = StyleGANEncoder(config.model_name, logger, gpu_ids=config.gpu_ids)
    F = PerceptualModel(min_val=G.min_val,
                        max_val=G.max_val,
                        gpu_ids=config.gpu_ids)
    D = StyleGANDiscriminator(config.model_name,
                              logger,
                              gpu_ids=config.gpu_ids)
    G.net.synthesis.eval()
    E.net.train()
    F.net.eval()
    D.net.train()
    # Shape the encoder output is reshaped to before synthesis.
    encode_dim = [G.num_layers, G.w_space_dim]

    # optimizer
    optimizer_E = torch.optim.Adam(E.net.parameters(),
                                   lr=E_learning_rate,
                                   **opt_args)
    optimizer_D = torch.optim.Adam(D.net.parameters(),
                                   lr=D_learning_rate,
                                   **opt_args)
    lr_scheduler_E = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer_E, gamma=E_lr_args.decay_rate)
    lr_scheduler_D = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer_D, gamma=D_lr_args.decay_rate)

    # Guard the per-epoch averages against an empty dataloader.
    num_batches = max(len(train_dataloader), 1)
    global_step = 0
    for epoch in range(max_epoch):
        E_loss_rec = 0.
        E_loss_adv = 0.
        E_loss_feat = 0.
        D_loss_real = 0.
        D_loss_fake = 0.
        D_loss_grad = 0.
        # get_last_lr() reports the rate currently in use; get_lr() is meant
        # for scheduler internals and reports a decayed value after a step.
        learning_rate = lr_scheduler_E.get_last_lr()[0]
        for step, items in enumerate(train_dataloader):
            E.net.train()
            # The scheduler steps per batch, so refresh the logged rate here.
            learning_rate = lr_scheduler_E.get_last_lr()[0]
            x = items.float().cuda()
            batch_size = x.shape[0]
            z = E.net(x).view(batch_size, *encode_dim)
            x_rec = G.net.synthesis(z)

            # ===============================
            #         optimizing D
            # ===============================
            x_real = D.net(x)
            x_fake = D.net(x_rec.detach())
            loss_real = GAN_loss(x_real, real=True)
            loss_fake = GAN_loss(x_fake, real=False)
            # gradient divergence penalty
            loss_gp = div_loss_(D, x, x_rec.detach(), cuda=config.cuda)

            real_val = loss_real.item()
            fake_val = loss_fake.item()
            gp_val = loss_gp.item()
            D_loss_real += real_val
            D_loss_fake += fake_val
            D_loss_grad += gp_val
            log_message = (f'D-[real:{real_val:.3f}, '
                           f'fake:{fake_val:.3f}, '
                           f'gp:{gp_val:.3f}]')
            D_loss = (loss_real_weight * loss_real +
                      loss_fake_weight * loss_fake +
                      loss_gp_weight * loss_gp)
            optimizer_D.zero_grad()
            D_loss.backward()
            optimizer_D.step()

            # ===============================
            #   optimizing E (G stays fixed)
            # ===============================
            # Reconstruction loss.
            loss_pix = torch.mean((x - x_rec)**2)
            pix_val = loss_pix.item()
            E_loss_rec += pix_val
            log_message += f', G-[pix:{pix_val:.3f}'

            # Perceptual loss. It stays the float 0. when disabled, so its
            # value is tracked separately as a Python float.
            loss_feat = 0.
            feat_val = 0.
            if loss_feat_weight:
                x_feat = F.net(x)
                x_rec_feat = F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                feat_val = loss_feat.item()
                E_loss_feat += feat_val
                log_message += f', feat:{feat_val:.3f}'

            # Adversarial loss. It is also a plain 0. when disabled.
            loss_adv = 0.
            adv_val = 0.
            if loss_adv_weight:
                x_adv = D.net(x_rec)
                loss_adv = GAN_loss(x_adv, real=True)
                adv_val = loss_adv.item()
                E_loss_adv += adv_val
                log_message += f', adv:{adv_val:.3f}'
            # Close the G-[ group whether or not the optional losses ran.
            log_message += ']'

            E_loss = (loss_pix_weight * loss_pix +
                      loss_feat_weight * loss_feat +
                      loss_adv_weight * loss_adv)
            log_message += f', loss:{E_loss.item():.3f}'
            optimizer_E.zero_grad()
            E_loss.backward()
            optimizer_E.step()

            if logger:
                logger.debug(f'Epoch:{epoch:03d}, '
                             f'Step:{step:04d}, '
                             f'lr:{learning_rate:.2e}, '
                             f'{log_message}')
            if writer:
                writer.add_scalar('D/loss_real', real_val,
                                  global_step=global_step)
                writer.add_scalar('D/loss_fake', fake_val,
                                  global_step=global_step)
                writer.add_scalar('D/loss_gp', gp_val,
                                  global_step=global_step)
                writer.add_scalar('D/loss', D_loss.item(),
                                  global_step=global_step)
                writer.add_scalar('E/loss_pix', pix_val,
                                  global_step=global_step)
                writer.add_scalar('E/loss_feat', feat_val,
                                  global_step=global_step)
                writer.add_scalar('E/loss_adv', adv_val,
                                  global_step=global_step)
                writer.add_scalar('E/loss', E_loss.item(),
                                  global_step=global_step)

            if step % image_snapshot_ticks == 0:
                _save_val_snapshots(E, G, x, val_dataloader, encode_dim,
                                    config, epoch, step)

            global_step += 1
            if (global_step + 1) % E_lr_args.decay_step == 0:
                lr_scheduler_E.step()
            if (global_step + 1) % D_lr_args.decay_step == 0:
                lr_scheduler_D.step()

        D_loss_real /= num_batches
        D_loss_fake /= num_batches
        D_loss_grad /= num_batches
        E_loss_rec /= num_batches
        E_loss_adv /= num_batches
        E_loss_feat /= num_batches
        log_message_ep = f'D-[real:{D_loss_real:.3f}, fake:{D_loss_fake:.3f}, gp:{D_loss_grad:.3f}], ' \
                         f'G-[pix:{E_loss_rec:.3f}, feat:{E_loss_feat:.3f}, adv:{E_loss_adv:.3f}]'
        if logger:
            logger.debug(f'Epoch: {epoch:03d}, '
                         f'lr: {learning_rate:.2e}, '
                         f'{log_message_ep}')

        save_filename = f'styleganinv_encoder_epoch_{epoch:03d}'
        save_filepath = os.path.join(config.save_models, save_filename)
        # E.net may be wrapped (e.g. DataParallel exposes the model as
        # `.module`). Save the bare model's weights either way.
        net_to_save = getattr(E.net, 'module', E.net)
        torch.save(net_to_save.state_dict(), save_filepath)
コード例 #11
0
# Space the editing directions belong to (default 'w'; per the help text,
# "mid_space" means a feature space of the GAN).
parser.add_argument(
    '--direction_type',
    default='w',
    help=
    'where direction is belong to, "mid_space" for direction belong to feature space of GAN'
)
# Presumably the range of W+ layers to edit (whether w_end is inclusive is
# not visible here) -- TODO confirm against where these are used.
parser.add_argument('--w_start', type=int, default=0)
parser.add_argument('--w_end', type=int, default=7)

# parse_known_args tolerates unrecognized command-line flags (returned in
# other_args) instead of exiting with an error.
args, other_args = parser.parse_known_args()
# Make GPU 0 the current CUDA device.
torch.cuda.set_device(0)
# Output directory for generated results; created if missing.
out_dir = 'static/results'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# Generator used for all edits, in inference mode.
gan = StyleGANGenerator("stylegan_anime")
gan.net.eval()

# NOTE(review): `app` and `Bootstrap` come from earlier in the file (not
# visible here); presumably a Flask app with Flask-Bootstrap -- confirm.
Bootstrap(app)

# Per-attribute editing data, in the same order as ATTRIBUTES_.
direction_list = []
shifts_r_list = []  # NOTE(review): never filled in the visible code.
manipulate_layers_list = []
# NOTE(review): assumes ATTRIBUTES_ has at least 10 entries.
for i in range(10):
    boundary_name = ATTRIBUTES_[i]['boundary']
    direction_path = os.path.join(args.boundaries_dir, f"{boundary_name}.npy")
    # The .npy stores a pickled dict in a 0-d object array; `[()]` unwraps
    # it. allow_pickle=True can execute code from the file: only load
    # trusted boundary files.
    direction_dict = np.load(direction_path, allow_pickle=True)[()]
    # direction vector
    direction_list.append(direction_dict["boundary"])
    # specific operation layers
    manipulate_layers_list.append(direction_dict["manipulate_layers"])
コード例 #12
0
ファイル: new_demo.py プロジェクト: BERYLSHEEP/AdvStyle
def manipulate_test(attribute, output_dir, noise_path, resolution, gan_model,
                    latent_type):
    """Render a sweep of edits along attribute directions for saved latents.

    Each latent file is turned into per-layer `ws` codes. The codes are
    shifted along the boundaries loaded from ./boundaries/<attr>.npy across
    each boundary's stored `shift_range`. `steps` images are synthesized and
    written to <output_dir>/<first attribute>/<name>_<step>.png.

    Args:
      attribute: attribute spec; `parse_attr` turns it into a list of names.
      output_dir: directory that receives the rendered images. It may be
        relative or absolute.
      noise_path: path to a single latent `.npy` file. If None, every file
        found under ./noise/<attr> for each attribute is used.
      resolution: generator output resolution (a power of two). It sets the
        number of layers to 2 * (log2(resolution) - 1).
      gan_model: model name passed to `StyleGANGenerator`.
      latent_type: "ws" (already per-layer codes), "z" (apply mapping then
        truncation) or "w" (apply truncation only).

    Raises:
      ValueError: if `latent_type` is not one of "ws", "z", "w", or if
        `noise_path` is given but does not exist.
    """
    # Validate up front: an unknown type used to leave `ws` unbound and fail
    # with a NameError deep inside the loop.
    if latent_type not in ("ws", "z", "w"):
        raise ValueError(
            f'latent_type must be "ws", "z" or "w", got: {latent_type!r}')

    attr_list = parse_attr(attribute)
    boundary = []
    manipulate_layers = []
    shift_range = []
    for attr in attr_list:
        direction_path = os.path.join("./boundaries", f"{attr}.npy")
        # The .npy stores a pickled dict in a 0-d object array; `[()]`
        # unwraps it. Only load trusted files (allow_pickle=True).
        direction_dict = np.load(direction_path, allow_pickle=True)[()]
        # direction vector
        boundary.append(direction_dict["boundary"])
        # specific operation layers
        manipulate_layers.append(direction_dict["manipulate_layers"])
        # Recommended shift range, e.g. [-10, 10]: -10 is the step in the
        # negative direction and 10 the step in the positive direction.
        shift_range.append(direction_dict["shift_range"])
    # Loop-invariant: convert once, not per latent file.
    shift_range = np.array(shift_range)

    # os.path.join keeps absolute output_dir values absolute. Prefixing
    # "./" to them would turn them into cwd-relative paths.
    image_dir = os.path.join(output_dir, attr_list[0])
    os.makedirs(image_dir, exist_ok=True)

    num_steps = 7
    gan = StyleGANGenerator(gan_model)
    gan.net.eval()

    # log2 is exact for powers of two, unlike math.log(x, 2).
    num_layers = int((math.log2(resolution) - 1) * 2)

    if noise_path is None:
        noise_paths = []
        for attr in attr_list:
            noise_paths += read_image_path(None, f"./noise/{attr}")
    else:
        if not os.path.exists(noise_path):
            raise ValueError(f"noise path is not exist: {noise_path}")
        noise_paths = [noise_path]

    # Pure inference: no autograd graphs are needed.
    with torch.no_grad(), tqdm(total=len(noise_paths)) as pbar:
        for path in noise_paths:
            # splitext keeps dotted stems ("seed.1.npy" -> "seed.1") distinct.
            # split(".")[0] would collapse them into one name and overwrite.
            name = os.path.splitext(os.path.basename(path))[0]
            noise_torch = torch.from_numpy(np.load(path)).float().cuda()

            if latent_type == "ws":
                ws = noise_torch
            elif latent_type == "z":
                ws = gan.net.truncation(gan.net.mapping(noise_torch))
            else:  # "w"
                ws = gan.net.truncation(noise_torch)

            wp_np = torch_to_numpy(ws)
            # To render a single image: use steps=1, set end_distance to x
            # (shift_range[:, 0] <= x <= shift_range[:, 1] is recommended)
            # and start_distance to anything. To render a sweep: use
            # start/end = shift_range[:, 0] / shift_range[:, 1], as here.
            # Result shape: [batch_size, steps, num_layers, *code_shape].
            wp_mani = manipulation(latent_codes=wp_np,
                                   boundary=boundary,
                                   start_distance=shift_range[:, 0],
                                   end_distance=shift_range[:, 1],
                                   steps=num_steps,
                                   layerwise_manipulation=True,
                                   num_layers=num_layers,
                                   manipulation_layers=manipulate_layers,
                                   is_code_layerwise=True,
                                   is_boundary_layerwise=False)

            for step_idx in range(num_steps):
                codes = torch.from_numpy(wp_mani[:, step_idx, :, :])
                images = gan.net.synthesis(codes.float().cuda())
                save_img(images,
                         os.path.join(image_dir, f"{name}_{step_idx}.png"),
                         is_torch=True,
                         is_map=False,
                         trans_type=None)

            pbar.update(1)
コード例 #13
0
def main():
    """Fine-tune a VGG9 network against the GAN and report before/after results.

    Builds the generator named by `args.model_name` and loads pretrained
    VGG-16 weights into a `VGG9` wrapper. With `args.resume`, it also
    restores the latest checkpoint from ./checkpoints/. It then visualizes
    and evaluates the model, trains it (MSE loss, SGD), and evaluates again.

    Raises:
      NotImplementedError: for an unsupported GAN type, or when no
        pretrained VGG-16 file exists at `args.pretrained_vgg_path`.
    """
    args = parse_args()
    logger = setup_logger(logger_name='latent_train')

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    data_transforms = {
        TRAIN:
        transforms.Compose([
            # Data augmentation is a good practice for the train set.
            # Here, we randomly crop the image to 224x224 and
            # randomly flip it horizontally.
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]),
        VAL:
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]),
        TEST:
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
    }

    image_datasets = {
        TRAIN: ImageDataset(
            args.train_dir,
            transform=data_transforms[TRAIN],
        ),
        VAL: ImageDataset(
            args.val_dir,
            transform=data_transforms[VAL],
        ),
        TEST: ImageDataset(
            args.test_dir,
            transform=data_transforms[TEST],
        ),
    }

    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=args.batch_size,
                                       shuffle=True)
        for x in [TRAIN, VAL, TEST]
    }

    dataset_sizes = {x: len(image_datasets[x]) for x in [TRAIN, VAL, TEST]}

    logger.info(f'Preparing VGG.')
    stock_vgg = models.vgg16()
    if os.path.isfile(args.pretrained_vgg_path):
        logger.info(f'  Load vgg-16 state from `{args.pretrained_vgg_path}`.')
        stock_vgg.load_state_dict(torch.load(args.pretrained_vgg_path))
    else:
        raise NotImplementedError(
            f'  VGG16 initialized randomly. Is this really what you want?')
    # A leftover `pdb.set_trace()` used to halt every run here; removed.
    vgg = VGG9(stock_vgg)

    start_epoch = 0
    if args.resume:
        start_epoch, latest = find_latest_epoch_and_checkpoint(
            "./checkpoints/")
        if latest:
            print(
                f"restoring from checkpoint: {latest} (epoch: {start_epoch})")
            vgg.load_state_dict(torch.load(latest))

    # Optimize only the trainable parameters; print their names.
    params_to_update = []
    for name, param in vgg.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)

    if args.use_gpu:
        vgg.cuda()

    criterion = nn.MSELoss()
    optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    print("Visualizing model...\n")
    visualize_model(vgg, model, 999, num_examples=4, **kwargs)

    print("Test before training...\n")
    eval_model(vgg,
               dataloaders,
               dataset_sizes,
               criterion,
               use_gpu=args.use_gpu)

    print("Training model...\n")
    vgg = train_model(vgg,
                      model,
                      dataloaders,
                      dataset_sizes,
                      criterion,
                      optimizer_ft,
                      num_epochs=args.epochs,
                      use_gpu=args.use_gpu,
                      gan_kwargs=kwargs,
                      start_epoch=start_epoch)
    print("Test after training...\n")
    eval_model(vgg,
               dataloaders,
               dataset_sizes,
               criterion,
               use_gpu=args.use_gpu)
コード例 #14
0
def main():
    """Fit a 1x512 latent so VGG features of the GAN output match a real image.

    Approach: with R the real image and Gen(latent) a generated image,
    minimize mse(VGG(R), VGG(Gen(latent))), updating only the latent
    (`weights`). The VGG9 wrapper is built from pretrained VGG-16 weights.
    The final latent is saved to codes.npy.

    Raises:
      NotImplementedError: for an unsupported GAN type, or when no
        pretrained VGG-16 file exists at `args.pretrained_vgg_path`.
    """
    args = parse_args()
    logger = setup_logger(logger_name='latent_train')

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info(f'Preparing VGG.')
    stock_vgg = models.vgg16()
    if os.path.isfile(args.pretrained_vgg_path):
        logger.info(f'  Load vgg-16 state from `{args.pretrained_vgg_path}`.')
        stock_vgg.load_state_dict(torch.load(args.pretrained_vgg_path))
    else:
        raise NotImplementedError(
            f'  VGG16 initialized randomly. Is this really what you want?')
    vgg = VGG9(stock_vgg)

    if args.use_gpu:
        vgg.cuda()

    gan_model = model
    vgg.eval()
    # Only `weights` is optimized. Freezing VGG avoids computing and storing
    # parameter gradients that no optimizer ever uses.
    for param in vgg.parameters():
        param.requires_grad_(False)

    def normalize_image_to_arr(input_image):
        """Resize an RGB HxWx3 image to 224x224 and apply VGG normalization."""
        arr = np.array(input_image).astype(np.float32)

        # Reverse the channel order (RGB -> BGR), then resize.
        arr = cv2.resize(arr[:, :, ::-1], (224, 224))

        # Subtract the per-channel BGR means, then scale by 1/255.
        arr[:, :, 0] -= 103.939
        arr[:, :, 1] -= 116.779
        arr[:, :, 2] -= 123.68
        arr /= 255.0

        return arr

    def image_to_tensor(image):
        """Normalize an image and return it as a (1, 3, 224, 224) float tensor."""
        arr = normalize_image_to_arr(image)
        arr = arr.transpose((2, 0, 1)).astype(np.float32)
        return torch.from_numpy(arr).float().unsqueeze(0)

    def image_from_gan(latent_code_tensor):
        """Synthesize one image (numpy HxWx3) from a latent tensor."""
        latent_code = latent_code_tensor.cpu().detach().numpy()
        gan_outputs = gan_model.easy_synthesize(latent_code, **kwargs)
        for image in gan_outputs['image']:
            return image
        return None

    # convert('RGB') keeps the 3-channel pipeline working for RGBA/grayscale
    # inputs, which would otherwise break normalize_image_to_arr.
    image = Image.open(args.input_image).convert('RGB').resize((224, 224))
    tensor_r = image_to_tensor(np.array(image)).cuda()

    weights = torch.randn(1, 512, device='cuda', requires_grad=True)
    criterion = nn.MSELoss()
    optimizer = optim.SGD([weights], lr=1, momentum=0.1)

    # The target features never change, so compute them once.
    with torch.no_grad():
        features_r = vgg(tensor_r)

    pbar = tqdm(range(args.epochs))
    for _ in pbar:
        # NOTE(review): image_from_gan detaches through numpy, so no gradient
        # flows from the generated image back to `weights`. The only
        # gradient comes from the explicit `* weights` products below, and
        # those rely on broadcasting a (1, 512) tensor against VGG features.
        # Confirm this is the intended objective.
        features_g = vgg(image_to_tensor(image_from_gan(weights)).cuda())

        left = features_g * weights
        right = features_r * weights

        loss = criterion(left, right)
        pbar.set_description(f"loss: {loss.item()}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Release this iteration's features before the next forward pass.
        del features_g
    np.save("codes.npy", weights.cpu().detach().numpy())
コード例 #15
0
def main():
    """Generate GAN samples and keep the ones a classifier labels confidently.

    For each of `attr_num` attributes, up to `args.sample_per_category`
    image ids are collected as positives (score >= threshold) and up to the
    same number as negatives (1 - score >= threshold). Kept images are
    written to `args.output_dir` as <id>.jpg unless `args.no_generated_imgs`
    is set. The per-attribute id lists are pickled to `args.key_file_path`.
    The kept generator outputs and classifier scores are saved as
    <key>.npy. These files are checkpointed every 1000 samples and at the
    end.

    Raises:
      NotImplementedError: for an unsupported GAN type.
      ValueError: if `args.key_file_path` already exists.
    """
    args = parse_args()
    logger = setup_logger(args.output_dir,
                          logger_name='generate_data',
                          immutable=False)

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info(f'Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f'  Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(latent_codes, **kwargs)
    else:
        logger.info(f'  Sample latent codes randomly.')
        # The original InterFaceGAN code does not preprocess sampled codes.
        latent_codes = model.easy_sample(args.num, **kwargs)
    total_num = latent_codes.shape[0]

    logger.info(f'Generating {total_num} samples.')
    results = defaultdict(list)

    # Pretrained attribute classifier: replace YOUR_TASK_MODEL with your own
    # task model. Its `predict` must give `attr_num` scores per image,
    # presumably probabilities in [0, 1] -- confirm for your model.
    attr_num = 16
    attr_clf = YOUR_TASK_MODEL()
    logger.info(f'Classifier loaded.')

    attr_clf.cuda()
    attr_clf.eval()
    # NOTE(review): on an nn.Module this sets a plain attribute and does not
    # freeze parameters. Prediction below runs under torch.no_grad() anyway.
    attr_clf.requires_grad = False

    # Resizes generator output to the 224x224 classifier input.
    downscale = nn.Upsample(size=224, mode='bilinear')
    if os.path.exists(args.key_file_path):
        raise ValueError(f'{args.key_file_path} has existed.')
    # data_index[attr] == [positive image ids, negative image ids]
    data_index = [[[], []] for _ in range(attr_num)]
    image_cnt = args.start_index

    # Checkpoint whenever progress crosses a multiple of this. A plain
    # `n % 1000 == 0` test misses the mark for batch sizes not dividing 1000.
    checkpoint_every = 1000
    last_checkpoint = 0
    # Created after the early `raise` above so it is never left open.
    pbar = tqdm(total=total_num, leave=False)

    for latent_codes_batch in model.get_batch_inputs(latent_codes):
        if gan_type == 'pggan':
            outputs = model.easy_synthesize(latent_codes_batch)
        elif gan_type == 'stylegan':
            outputs = model.easy_synthesize(latent_codes_batch,
                                            **kwargs,
                                            generate_style=args.generate_style,
                                            generate_image=args.generate_image)

        with torch.no_grad():
            images = outputs['image']
            # Within a batch, some images are kept and some are not.
            kept_indices = []
            # NHWC -> NCHW, resized to 224x224 and scaled by 1/255.
            img_tensor = torch.tensor(images.transpose(0, 3, 1, 2)).cuda().float()
            img_tensor = downscale(img_tensor) / 255.
            preds = attr_clf.predict(img_tensor)
            for iid, image in enumerate(images):
                pbar.update(1)
                is_save = False
                for ind in range(attr_num):
                    positives, negatives = data_index[ind]
                    score = preds[iid][ind]
                    if (score >= args.threshold and
                            len(positives) < args.sample_per_category):
                        is_save = True
                        positives.append(image_cnt)
                    elif ((1 - score) >= args.threshold and
                          len(negatives) < args.sample_per_category):
                        is_save = True
                        negatives.append(image_cnt)

                if is_save:
                    kept_indices.append(iid)
                    if not args.no_generated_imgs:
                        save_path = os.path.join(args.output_dir,
                                                 f'{image_cnt:08d}.jpg')
                        # Reverse the channel order for cv2.
                        cv2.imwrite(save_path, image[:, :, ::-1])
                    results['soft_labels'].append(preds[iid].reshape(1, -1))
                    image_cnt += 1

            # Keep the non-image outputs only for the kept samples.
            for key, value in outputs.items():
                if key != 'image' and kept_indices:
                    results[key].append(value[kept_indices])

            crossed_mark = (pbar.n // checkpoint_every >
                            last_checkpoint // checkpoint_every)
            if crossed_mark or pbar.n == total_num:
                last_checkpoint = pbar.n
                print('iter: ', pbar.n)
                for ind in range(attr_num):
                    print("attr_index: {}, pos: {}, neg: {}".format(
                        ind, len(data_index[ind][0]), len(data_index[ind][1])))
                # save data_index
                with open(args.key_file_path, 'wb') as f:
                    pickle.dump(data_index, f)
                logger.debug(f'  Finish {pbar.n:6d} samples.')
                logger.info(f'Saving results.')
                for key, value in results.items():
                    save_path = os.path.join(args.output_dir, f'{key}.npy')
                    np.save(save_path, np.concatenate(value, axis=0))
    pbar.close()
コード例 #16
0
def main():
    """Render boundary-interpolation strips for a set of latent codes.

    Latent codes are loaded from `args.input_latent_codes_path` (and
    preprocessed) or, if that file is missing, sampled with
    `complicate_sample`. The codes are written to
    <OUTPUT_PATH>/latent_codes.npy. For each code, `args.steps` frames are
    interpolated along the boundary at `config.BOUNDARY_PATH`. The frames
    are pasted left-to-right onto one white canvas, saved as
    <OUTPUT_PATH>/<sample_id>.jpg.

    Raises:
      NotImplementedError: for an unsupported GAN type.
      ValueError: if the boundary file does not exist.
    """
    args = parse_args()
    logger = setup_logger(config.OUTPUT_PATH, logger_name='generate_data')

    logger.info(f'Initializing generator.')
    gan_type = MODEL_POOL[config.MODEL_NAME]['gan_type']
    if gan_type == 'pggan':
        model, kwargs = PGGANGenerator(config.MODEL_NAME, logger), {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(config.MODEL_NAME, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info(f'Preparing boundary.')
    boundary_file = config.BOUNDARY_PATH
    if not os.path.isfile(boundary_file):
        raise ValueError(f'Boundary `{boundary_file}` does not exist!')
    boundary = np.load(boundary_file)

    logger.info(f'Preparing latent codes.')
    codes_file = args.input_latent_codes_path
    if not os.path.isfile(codes_file):
        logger.info(f'  Sample latent codes randomly.')
        latent_codes = model.complicate_sample(args.num, **kwargs)
    else:
        logger.info(f'  Load latent codes from `{codes_file}`.')
        latent_codes = model.preprocess(np.load(codes_file), **kwargs)
    np.save(os.path.join(config.OUTPUT_PATH, 'latent_codes.npy'), latent_codes)
    total_num = latent_codes.shape[0]

    logger.info(f'Editing {total_num} samples.')

    for sample_id in tqdm(range(total_num), leave=False):
        frames = linear_interpolate(latent_codes[sample_id:sample_id + 1],
                                    boundary,
                                    start_distance=args.start_distance,
                                    end_distance=args.end_distance,
                                    steps=args.steps)
        frame_idx = 0
        strip = PIL.Image.new(
            'RGB', (config.RESOLUTION * args.steps, config.RESOLUTION),
            'white')
        for batch in model.get_batch_inputs(frames):
            # kwargs is empty for PGGAN, so one call covers both GAN types.
            outputs = model.easy_synthesize(batch, **kwargs)
            for frame in outputs['image']:
                # Each frame occupies its own RESOLUTION-wide column.
                strip.paste(transforms.ToPILImage(mode='RGB')(frame),
                            (config.RESOLUTION * frame_idx, 0))
                frame_idx += 1
        strip.save(os.path.join(config.OUTPUT_PATH, f'{sample_id:03d}.jpg'))
        assert frame_idx == args.steps
        logger.debug(f'  Finished sample {sample_id:3d}.')
    logger.info(f'Successfully edited {total_num} samples.')