Example #1
        def get_landmark_row(lm, kp, img_list):
            displayed_imgs = [lm]

            # NumPy buffers that will be filled and fed to the graph
            batch_images = np.zeros(input_shape, np.uint8)
            batch_lms = np.zeros(input_shape, np.uint8)
            batch_kps = np.zeros([self.minibatch_per_gpu, 136], np.float32)

            for img_idx in tqdm(range(0, len(img_list), self.minibatch_per_gpu), leave=False):
                batch = img_list[img_idx:img_idx + self.minibatch_per_gpu]
                for i, image in enumerate(batch):
                    batch_images[i] = np.transpose(image, [2, 0, 1])
                    batch_lms[i] = np.transpose(lm, [2, 0, 1])
                    batch_kps[i] = kp.flatten()

                inputs = batch_images.astype(np.float32) / 255 * 2.0 - 1.0
                inputs_lm = batch_lms.astype(np.float32) / 255 * 2.0 - 1.0
                input_kp = batch_kps

                # Run the manipulation graph.
                outputs = tflib.run(manipulated_images, {x: inputs, x_lm: inputs_lm, x_kp: input_kp})
                outputs = adjust_pixel_range(outputs, min_val=0, max_val=255) # 16 x 128 x 128 x 3
                for i, _ in enumerate(batch):
                    displayed_imgs.append(outputs[i])

            return displayed_imgs
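
The helper adjust_pixel_range is called throughout these examples but never defined in them. A minimal NumPy sketch, assuming it linearly maps images from [min_val, max_val] to uint8 [0, 255] and moves channels last; the real helper in the surrounding codebase may differ in details:

import numpy as np

def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'):
    # Linearly rescale from [min_val, max_val] to [0, 255], round, and clip.
    images = (images.astype(np.float32) - min_val) * 255.0 / (max_val - min_val)
    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
    # Move channels last so the result can be saved directly as an image.
    if channel_order == 'NCHW':
        images = np.transpose(images, [0, 2, 3, 1])
    return images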
Example #2
def main():
  """Main function."""
  args = parse_args()
  os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
  style_dir = args.style_dir
  style_dir_name = os.path.basename(style_dir.rstrip('/'))
  assert os.path.exists(style_dir)
  assert os.path.exists(f'{style_dir}/image_list.txt')
  assert os.path.exists(f'{style_dir}/inverted_codes.npy')
  content_dir = args.content_dir
  content_dir_name = os.path.basename(content_dir.rstrip('/'))
  assert os.path.exists(content_dir)
  assert os.path.exists(f'{content_dir}/image_list.txt')
  assert os.path.exists(f'{content_dir}/inverted_codes.npy')
  output_dir = args.output_dir or 'results/style_mixing'
  job_name = f'{style_dir_name}_STYLIZE_{content_dir_name}'
  logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

  # Load model.
  logger.info(f'Loading generator.')
  tflib.init_tf({'rnd.np_random_seed': 1000})
  with open(args.model_path, 'rb') as f:
    _, _, _, Gs = pickle.load(f)

  # Build graph.
  logger.info(f'Building graph.')
  sess = tf.get_default_session()
  num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3]
  wp = tf.placeholder(
      tf.float32, [args.batch_size, num_layers, latent_dim], name='latent_code')
  x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
  mix_layers = list(range(args.mix_layer_start_idx, num_layers))

  # Load image and codes.
  logger.info(f'Loading images and corresponding inverted latent codes.')
  style_list = []
  with open(f'{style_dir}/image_list.txt', 'r') as f:
    for line in f:
      name = os.path.splitext(os.path.basename(line.strip()))[0]
      assert os.path.exists(f'{style_dir}/{name}_ori.png')
      style_list.append(name)
  logger.info(f'Loading inverted latent codes.')
  style_codes = np.load(f'{style_dir}/inverted_codes.npy')
  assert style_codes.shape[0] == len(style_list)
  num_styles = style_codes.shape[0]
  content_list = []
  with open(f'{content_dir}/image_list.txt', 'r') as f:
    for line in f:
      name = os.path.splitext(os.path.basename(line.strip()))[0]
      assert os.path.exists(f'{content_dir}/{name}_ori.png')
      content_list.append(name)
  logger.info(f'Loading inverted latent codes.')
  content_codes = np.load(f'{content_dir}/inverted_codes.npy')
  assert content_codes.shape[0] == len(content_list)
  num_contents = content_codes.shape[0]

  # Mix styles.
  logger.info(f'Start style mixing.')
  viz_size = None if args.viz_size == 0 else args.viz_size
  visualizer = HtmlPageVisualizer(
      num_rows=num_styles + 1, num_cols=num_contents + 1, viz_size=viz_size)
  visualizer.set_headers(
      ['Style'] +
      [f'Content {i:03d}' for i in range(num_contents)]
  )
  for style_idx, style_name in enumerate(style_list):
    style_image = load_image(f'{style_dir}/{style_name}_ori.png')
    visualizer.set_cell(style_idx + 1, 0, image=style_image)
  for content_idx, content_name in enumerate(content_list):
    content_image = load_image(f'{content_dir}/{content_name}_ori.png')
    visualizer.set_cell(0, content_idx + 1, image=content_image)

  codes = mix_style(style_codes=style_codes,
                    content_codes=content_codes,
                    num_layers=num_layers,
                    mix_layers=mix_layers)
  inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32)
  for style_idx in tqdm(range(num_styles), leave=False):
    output_images = []
    for idx in range(0, num_contents, args.batch_size):
      batch = codes[style_idx, idx:idx + args.batch_size]
      inputs[0:len(batch)] = batch
      images = sess.run(x, feed_dict={wp: inputs})
      output_images.append(images[0:len(batch)])
    output_images = adjust_pixel_range(np.concatenate(output_images, axis=0))
    for content_idx, output_image in enumerate(output_images):
      visualizer.set_cell(style_idx + 1, content_idx + 1, image=output_image)

  # Save results.
  visualizer.save(f'{output_dir}/{job_name}.html')
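
mix_style is imported from the codebase's editing utilities and not shown here. A sketch of the assumed semantics: broadcast the content codes over all styles, then overwrite the chosen layers with the style code, yielding one mixed W+ code per (style, content) pair:

import numpy as np

def mix_style(style_codes, content_codes, num_layers, mix_layers):
    # style_codes:   [num_styles, num_layers, latent_dim]
    # content_codes: [num_contents, num_layers, latent_dim]
    # Returns [num_styles, num_contents, num_layers, latent_dim]; num_layers is
    # kept only to mirror the call site above.
    num_styles = style_codes.shape[0]
    codes = np.tile(content_codes[np.newaxis], (num_styles, 1, 1, 1))
    for layer_idx in mix_layers:
        # Replace this layer of every content code with the style code.
        codes[:, :, layer_idx] = style_codes[:, np.newaxis, layer_idx]
    return codes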
Example #3
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    assert os.path.exists(args.image_list)
    image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
    output_dir = args.output_dir or f'results/inversion/{image_list_name}'
    logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger')

    logger.info(f'Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    # Build graph.
    logger.info(f'Building graph.')
    sess = tf.get_default_session()
    input_shape = E.input_shape
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    x_255 = (tf.transpose(x, [0, 2, 3, 1]) + 1) / 2 * 255
    latent_shape = Gs.components.synthesis.input_shape
    latent_shape[0] = args.batch_size
    wp = tf.get_variable(shape=latent_shape, name='latent_code')
    x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
    x_rec_255 = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) / 2 * 255
    if args.random_init:
        logger.info(f'  Use random initialization for optimization.')
        wp_rnd = tf.random.normal(shape=latent_shape, name='latent_code_init')
        setter = tf.assign(wp, wp_rnd)
    else:
        logger.info(
            f'  Use encoder output as the initialization for optimization.')
        w_enc = E.get_output_for(x, is_training=False)
        wp_enc = tf.reshape(w_enc, latent_shape)
        setter = tf.assign(wp, wp_enc)

    # Settings for optimization.
    logger.info(f'Setting configuration for optimization.')
    perceptual_model = PerceptualModel([image_size, image_size], False)
    x_feat = perceptual_model(x_255)
    x_rec_feat = perceptual_model(x_rec_255)
    loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1])
    loss_pix = tf.reduce_mean(tf.square(x - x_rec), axis=[1, 2, 3])
    if args.domain_regularizer:
        logger.info(f'  Involve encoder for optimization.')
        w_enc_new = E.get_output_for(x_rec, is_training=False)
        wp_enc_new = tf.reshape(w_enc_new, latent_shape)
        loss_enc = tf.reduce_mean(tf.square(wp - wp_enc_new), axis=[1, 2])
    else:
        logger.info(f'  Do NOT involve encoder for optimization.')
        loss_enc = 0
    loss = (loss_pix + args.loss_weight_feat * loss_feat +
            args.loss_weight_enc * loss_enc)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    train_op = optimizer.minimize(loss, var_list=[wp])
    tflib.init_uninitialized_vars()

    # Load image list.
    logger.info(f'Loading image list.')
    image_list = []
    with open(args.image_list, 'r') as f:
        for line in f:
            image_list.append(line.strip())

    # Invert images.
    logger.info(f'Start inversion.')
    save_interval = args.num_iterations // args.num_results
    headers = ['Name', 'Original Image', 'Encoder Output']
    for step in range(1, args.num_iterations + 1):
        if step == args.num_iterations or step % save_interval == 0:
            headers.append(f'Step {step:06d}')
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=len(image_list),
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    images = np.zeros(input_shape, np.uint8)
    names = ['' for _ in range(args.batch_size)]
    latent_codes_enc = []
    latent_codes = []
    for img_idx in tqdm(range(0, len(image_list), args.batch_size),
                        leave=False):
        # Load inputs.
        batch = image_list[img_idx:img_idx + args.batch_size]
        for i, image_path in enumerate(batch):
            image = resize_image(load_image(image_path),
                                 (image_size, image_size))
            images[i] = np.transpose(image, [2, 0, 1])
            names[i] = os.path.splitext(os.path.basename(image_path))[0]
        inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
        # Run encoder.
        sess.run([setter], {x: inputs})
        outputs = sess.run([wp, x_rec])
        latent_codes_enc.append(outputs[0][0:len(batch)])
        outputs[1] = adjust_pixel_range(outputs[1])
        for i, _ in enumerate(batch):
            image = np.transpose(images[i], [1, 2, 0])
            save_image(f'{output_dir}/{names[i]}_ori.png', image)
            save_image(f'{output_dir}/{names[i]}_enc.png', outputs[1][i])
            visualizer.set_cell(i + img_idx, 0, text=names[i])
            visualizer.set_cell(i + img_idx, 1, image=image)
            visualizer.set_cell(i + img_idx, 2, image=outputs[1][i])
        # Optimize latent codes.
        col_idx = 3
        for step in tqdm(range(1, args.num_iterations + 1), leave=False):
            sess.run(train_op, {x: inputs})
            if step == args.num_iterations or step % save_interval == 0:
                outputs = sess.run([wp, x_rec])
                outputs[1] = adjust_pixel_range(outputs[1])
                for i, _ in enumerate(batch):
                    if step == args.num_iterations:
                        save_image(f'{output_dir}/{names[i]}_inv.png',
                                   outputs[1][i])
                    visualizer.set_cell(i + img_idx,
                                        col_idx,
                                        image=outputs[1][i])
                col_idx += 1
        latent_codes.append(outputs[0][0:len(batch)])

    # Save results.
    os.system(f'cp {args.image_list} {output_dir}/image_list.txt')
    np.save(f'{output_dir}/encoded_codes.npy',
            np.concatenate(latent_codes_enc, axis=0))
    np.save(f'{output_dir}/inverted_codes.npy',
            np.concatenate(latent_codes, axis=0))
    visualizer.save(f'{output_dir}/inversion.html')
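
The files written at the end (image_list.txt plus inverted_codes.npy) are exactly what the style-mixing and interpolation examples assert on before loading. A minimal sketch of reading the codes back and regenerating one image, reusing Gs, sess and adjust_pixel_range from the script above (names and paths are illustrative):

codes = np.load(f'{output_dir}/inverted_codes.npy')  # [N, num_layers, latent_dim]
wp = tf.placeholder(tf.float32, [1] + list(codes.shape[1:]), name='wp_replay')
x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
image = sess.run(x_rec, {wp: codes[:1]})             # NCHW, range [-1, 1]
image = adjust_pixel_range(image)[0]                 # uint8 HWC, ready to save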
Example #4
    def _evaluate(self, Gs, E, Inv, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(config.INCEPTION_PICKLE_DIR) # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)

        announce("Evaluating Reals")
        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
            print("loaded real mu, sigma from cache.")
        else:
            progress = 0
            for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):
                batch_stacks = data[0]
                progress += batch_stacks.shape[0]
                images = batch_stacks[:,0,:,:,:]
                landmarks = batch_stacks[:,1,:,:,:]

                # compute Inception features on the full images
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)


                # visualization
                images = images.astype(np.float32) / 255 * 2.0 - 1.0
                landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0

                if idx <= 10:
                    debug_img = np.concatenate([
                        images,   # original portraits
                        landmarks # original landmarks
                    ], axis=0)
                    debug_img = adjust_pixel_range(debug_img)
                    debug_img = fuse_images(debug_img, row=2, col=minibatch_size)
                    save_image("data_iter_{:08d}.png".format(idx), debug_img)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
        
        announce("Evaluating Generator.")
        # Construct TensorFlow graph.
        result_expr = []
        print("Construct TensorFlow graph.")
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        print("Calculate statistics for fakes.")
        for begin in tqdm(range(0, self.num_images, minibatch_size), position=0, leave=True):
            end = min(begin + minibatch_size, self.num_images)
            #print("result_expr", len(result_expr)) # result_expr is a list!!!
            # results_expr[0].shape = (8, 2048) -> hat nur ein element.
            # weil: eigentlich würde man halt hier die GPUs zusammen konkattenieren.

            res_expr, fakes = tflib.run([result_expr, images])
            activations[begin:end] = np.concatenate(res_expr, axis=0)[:end-begin]

            if begin < 20:
                fakes = fakes.astype(np.float32) / 255 * 2.0 - 1.0
                debug_img = fakes
                debug_img = adjust_pixel_range(debug_img)
                debug_img = fuse_images(debug_img, row=3, col=minibatch_size)
                save_image("fid_generator_iter_{}08d.png".format(end), debug_img)


        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        #print("mu_fake={}, sigma_fake={}".format(mu_fake, sigma_fake))
        
        # Calculate FID.
        print("Calculate FID (generator).")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="StyleGAN Generator Only")
        print("Distance StyleGAN", dist)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

        announce("Now evaluating encoder (appearnace)")
        print("building custom encoder graph!")
        with tf.variable_scope('fakeddddoptimizer'):

            # Build graph.
            BATCH_SIZE = self.minibatch_per_gpu
            input_shape = Inv.input_shape
            input_shape[0] = BATCH_SIZE
            latent_shape = Gs.components.synthesis.input_shape
            latent_shape[0] = BATCH_SIZE

            x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
            x_lm = tf.placeholder(tf.float32, shape=input_shape, name='some_landmark')
            x_kp = tf.placeholder(tf.float32, shape=[self.minibatch_per_gpu, 136], name='some_keypoints')

            if self.model_type == "rignet":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_lm, phase=False)
            elif self.model_type == "keypoints":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_kp, phase=False)
            else:
                w_enc = E.get_output_for(x, x_lm, phase=False)

            wp_enc = tf.reshape(w_enc, latent_shape)

            manipulated_images = Gs.components.synthesis.get_output_for(wp_enc, randomize_noise=False)
            manipulated_images = tflib.convert_images_to_uint8(manipulated_images)
            inception_codes = inception_clone.get_output_for(manipulated_images) # shape (8, 2048)

        for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):
            batch_stacks = data[0]
            images = batch_stacks[:,0,:,:,:]    # shape (8, 3, 128, 128)
            landmarks = batch_stacks[:,1,:,:,:] # shape (8, 3, 128, 128)
            images = images.astype(np.float32) / 255 * 2.0 - 1.0
            landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0
            keypoints = np.roll(data[1], shift=1, axis=0)

            begin = idx * minibatch_size
            end = min(begin + minibatch_size, self.num_images) # begin: 0; end: 8

            activations[begin:end], manip  = tflib.run([inception_codes, manipulated_images], feed_dict={x:images, x_lm:landmarks, x_kp:keypoints})
            # activations: (5000, 2048)



            if idx < 10:
                print("saving img")
                manip = manip.astype(np.float32) / 255 * 2.0 - 1.0
                debug_img = np.concatenate([
                    images,    # original portraits
                    landmarks, # original landmarks
                    manip      # manipulated images
                ], axis=0)
                debug_img = adjust_pixel_range(debug_img)
                debug_img = fuse_images(debug_img, row=3, col=minibatch_size)
                save_image("fid_iter_{:08d}.png".format(idx), debug_img)


            if end == self.num_images:
                break

        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)
        #print("enc_mu_fake={}, enc_sigma_fake={}".format(mu_fake, sigma_fake))


        # Calculate FID.
        print("Calculate FID for encoded samples")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="Our Face-Landmark-Encoder (Appearance)")
        print("distance OUR FACE-LANDMARK-ENCODER", dist)


#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

        announce("Now evaluating encoder. (POSE)")
        print("building custom encoder graph!")
        with tf.variable_scope('fakeddddoptimizer'):

            # Build graph.
            BATCH_SIZE = self.minibatch_per_gpu
            input_shape = Inv.input_shape
            input_shape[0] = BATCH_SIZE
            latent_shape = Gs.components.synthesis.input_shape
            latent_shape[0] = BATCH_SIZE

            x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
            x_lm = tf.placeholder(tf.float32, shape=input_shape, name='some_landmark')
            x_kp = tf.placeholder(tf.float32, shape=[self.minibatch_per_gpu, 136], name='some_keypoints')

            if self.model_type == "rignet":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_lm, phase=False)
            elif self.model_type == "keypoints":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_kp, phase=False)
            else:
                w_enc = E.get_output_for(x, x_lm, phase=False)

            wp_enc = tf.reshape(w_enc, latent_shape)

            manipulated_images = Gs.components.synthesis.get_output_for(wp_enc, randomize_noise=False)
            manipulated_images = tflib.convert_images_to_uint8(manipulated_images)
            inception_codes = inception_clone.get_output_for(manipulated_images) # shape (8, 2048)

        for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):

            image_data = data[0]
            images = image_data[:,0,:,:,:]
            landmarks = np.roll(image_data[:,1,:,:,:], shift=1, axis=0)
            
            keypoints = np.roll(data[1], shift=1, axis=0)

            images = images.astype(np.float32) / 255 * 2.0 - 1.0
            landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0

            begin = idx * minibatch_size
            end = min(begin + minibatch_size, self.num_images) # begin: 0; end: 8

            activations[begin:end], manip  = tflib.run([inception_codes, manipulated_images], feed_dict={x:images, x_lm:landmarks, x_kp:keypoints})
            # activations: (5000, 2048)



            if idx < 10:
                print("saving img")
                manip = manip.astype(np.float32) / 255 * 2.0 - 1.0
                debug_img = np.concatenate([
                    images,    # original portraits
                    landmarks, # shuffled landmarks
                    manip      # manipulated images
                ], axis=0)
                debug_img = adjust_pixel_range(debug_img)
                debug_img = fuse_images(debug_img, row=3, col=minibatch_size)
                save_image("fid_iter_POSE_{:08d}.png".format(idx), debug_img)


            if end == self.num_images:
                break

        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)
        #print("enc_mu_fake={}, enc_sigma_fake={}".format(mu_fake, sigma_fake))


        # Calculate FID.
        print("Calculate FID for encoded samples (POSE)")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="Our_Face_Landmark_Encoder (Pose)")
        print("distance OUR FACE-LANDMARK-ENCODER (POSE)", dist)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

        announce("Now in domain inversion only encoder.")
        print("building custom in domain inversion graph!")
        with tf.variable_scope('fakedddwdoptimizer'):

            # Build graph.
            BATCH_SIZE = self.minibatch_per_gpu
            input_shape = Inv.input_shape
            input_shape[0] = BATCH_SIZE
            latent_shape = Gs.components.synthesis.input_shape
            latent_shape[0] = BATCH_SIZE

            x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')

            w_enc_1 = Inv.get_output_for(x, phase=False)
            wp_enc_1 = tf.reshape(w_enc_1, latent_shape)

            manipulated_images = Gs.components.synthesis.get_output_for(wp_enc_1, randomize_noise=False)
            manipulated_images = tflib.convert_images_to_uint8(manipulated_images)
            inception_codes = inception_clone.get_output_for(manipulated_images)

        for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):
            batch_stacks = data[0]
            images = batch_stacks[:,0,:,:,:]
            landmarks = batch_stacks[:,1,:,:,:]
            images = images.astype(np.float32) / 255 * 2.0 - 1.0
            landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0

            #print("landmarks", landmarks.shape)# (8, 3, 128, 128)
            #print("images", images.shape) # (8, 3, 128, 128)
            #print("inception_codes", inception_codes.shape) # (8, 2048)
            #print("activations", activations.shape) # (5000, 2048)
            begin = idx * minibatch_size
            end = min(begin + minibatch_size, self.num_images)
            #print("b,e", begin, end) # 0, 8; ...

            activations[begin:end]  = tflib.run(inception_codes, feed_dict={x:images})

            if end == self.num_images:
                break

        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)
        #print("enc_mu_fake={}, enc_sigma_fake={}".format(mu_fake, sigma_fake))


        # Calculate FID.
        print("Calculate FID for IN-DOMAIN-GAN-INVERSION")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="_In-Domain-Inversion_Only")
        print("distance IN-DOMAIN-GAN-INVERSION:", dist)
Example #5
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    src_dir = args.src_dir
    src_dir_name = os.path.basename(src_dir.rstrip('/'))
    assert os.path.exists(src_dir)
    assert os.path.exists(f'{src_dir}/image_list.txt')
    assert os.path.exists(f'{src_dir}/inverted_codes.npy')
    dst_dir = args.dst_dir
    dst_dir_name = os.path.basename(dst_dir.rstrip('/'))
    assert os.path.exists(dst_dir)
    assert os.path.exists(f'{dst_dir}/image_list.txt')
    assert os.path.exists(f'{dst_dir}/inverted_codes.npy')
    output_dir = args.output_dir or 'results/interpolation'
    job_name = f'{src_dir_name}_TO_{dst_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info(f'Loading generator.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        _, _, _, Gs = pickle.load(f)

    # Build graph.
    logger.info(f'Building graph.')
    sess = tf.get_default_session()
    num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3]
    wp = tf.placeholder(tf.float32, [args.batch_size, num_layers, latent_dim],
                        name='latent_code')
    x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)

    # Load image and codes.
    logger.info(f'Loading images and corresponding inverted latent codes.')
    src_list = []
    with open(f'{src_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{src_dir}/{name}_ori.png')
            src_list.append(name)
    src_codes = np.load(f'{src_dir}/inverted_codes.npy')
    assert src_codes.shape[0] == len(src_list)
    num_src = src_codes.shape[0]
    dst_list = []
    with open(f'{dst_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{dst_dir}/{name}_ori.png')
            dst_list.append(name)
    dst_codes = np.load(f'{dst_dir}/inverted_codes.npy')
    assert dst_codes.shape[0] == len(dst_list)
    num_dst = dst_codes.shape[0]

    # Interpolate images.
    logger.info(f'Start interpolation.')
    step = args.step + 2
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_src * num_dst,
                                    num_cols=step + 2,
                                    viz_size=viz_size)
    visualizer.set_headers(['Source', 'Source Inversion'] +
                           [f'Step {i:02d}' for i in range(1, step - 1)] +
                           ['Target Inversion', 'Target'])

    inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32)
    for src_idx in tqdm(range(num_src), leave=False):
        src_code = src_codes[src_idx:src_idx + 1]
        src_path = f'{src_dir}/{src_list[src_idx]}_ori.png'
        codes = interpolate(src_codes=np.repeat(src_code, num_dst, axis=0),
                            dst_codes=dst_codes,
                            step=step)
        for dst_idx in tqdm(range(num_dst), leave=False):
            dst_path = f'{dst_dir}/{dst_list[dst_idx]}_ori.png'
            output_images = []
            for idx in range(0, step, args.batch_size):
                batch = codes[dst_idx, idx:idx + args.batch_size]
                inputs[0:len(batch)] = batch
                images = sess.run(x, feed_dict={wp: inputs})
                output_images.append(images[0:len(batch)])
            output_images = adjust_pixel_range(
                np.concatenate(output_images, axis=0))

            row_idx = src_idx * num_dst + dst_idx
            visualizer.set_cell(row_idx, 0, image=load_image(src_path))
            visualizer.set_cell(row_idx, step + 1, image=load_image(dst_path))
            for s, output_image in enumerate(output_images):
                if s == 5 and row_idx == 2:
                    save_image(f'./results/interpolation/005_int.png',
                               output_image)
                visualizer.set_cell(row_idx, s + 1, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
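
interpolate comes from the same editing utilities as mix_style and is likewise not shown. A sketch of the assumed behaviour: linear interpolation from each source code to its paired target code, endpoints included, so that the `step` columns run from source inversion to target inversion:

import numpy as np

def interpolate(src_codes, dst_codes, step):
    # src_codes/dst_codes: [num, num_layers, latent_dim], paired row by row.
    # Returns [num, step, num_layers, latent_dim].
    alphas = np.linspace(0.0, 1.0, step).reshape(1, step, 1, 1)
    src = src_codes[:, np.newaxis]
    dst = dst_codes[:, np.newaxis]
    return (1 - alphas) * src + alphas * dst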
Example #6
def training_loop(
                  submit_config,
                  Encoder_args            = {},
                  E_opt_args              = {},
                  D_opt_args              = {},
                  E_loss_args             = EasyDict(),
                  D_loss_args             = {},
                  lr_args                 = EasyDict(),
                  tf_config               = {},
                  dataset_args            = EasyDict(),
                  decoder_pkl             = EasyDict(),
                  drange_data             = [0, 255],
                  drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
                  mirror_augment          = False,
                  resume_run_id           = config.ENCODER_PICKLE_DIR,     # Run ID or network pkl to resume training from, None = start from scratch.
                  resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
                  image_snapshot_ticks    = 1,        # How often to export image snapshots?
                  network_snapshot_ticks  = 4,       # How often to export network snapshots?
                  max_iters               = 150000):

    tflib.init_tf(tf_config)

    with tf.name_scope('input'):
        real_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='real_image_train')
        real_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='real_image_test')
        real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, D, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args)
            start = 0

    E.print_layers(); Gs.print_layers(); D.print_layers()

    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model, reals=real_gpu, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')

    summary_log = tf.summary.FileWriter(config.GDRIVE_PATH)

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    init_pascal = tf.initialize_variables(
        [global_step0],
        name='init_pascal'
    )
    sess.run(init_pascal)
    
    print('Optimization starts!!!')
    
    
    for it in range(start, max_iters):

        batch_images = sess.run(image_batch_train)
        feed_dict_1 = {real_train: batch_images}
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_ = sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

        cur_nimg += submit_config.batch_size

        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)
            
            
            
            
        if it % 500 == 0:
            batch_images_test = sess.run(image_batch_test)
            batch_images_test = misc.adjust_dynamic_range(batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
            samples2 = sess.run(fake_X_val, feed_dict={real_test: batch_images_test})
            orin_recon = np.concatenate([batch_images_test, samples2], axis=0)
            orin_recon = adjust_pixel_range(orin_recon)
            orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test)
            # save image results during training, first row is original images and the second row is reconstructed images
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), orin_recon)

            # save image to gdrive
            img_path = os.path.join(config.GDRIVE_PATH, 'images', ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, orin_recon)

        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg



            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)
                
                # save network snapshot to gdrive
                pkl_drive = os.path.join(config.GDRIVE_PATH, 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
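
Both training loops in this listing drive their optimizers with tf.train.exponential_decay, advanced through the add_global0 assign-add. What that schedule computes, as a NumPy one-liner for sanity-checking lr_args settings:

import numpy as np

def decayed_lr(lr0, global_step, decay_step, decay_rate, staircase=False):
    # tf.train.exponential_decay: lr0 * decay_rate ** (global_step / decay_step),
    # with the exponent floored when staircase=True.
    p = global_step / decay_step
    if staircase:
        p = np.floor(p)
    return lr0 * decay_rate ** p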
Example #7
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    assert os.path.exists(args.target_list)
    target_list_name = os.path.splitext(os.path.basename(args.target_list))[0]
    assert os.path.exists(args.context_list)
    context_list_name = os.path.splitext(os.path.basename(
        args.context_list))[0]
    output_dir = args.output_dir or f'results/diffusion'
    job_name = f'{target_list_name}_TO_{context_list_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    logger.info(f'Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]
    crop_size = args.crop_size
    crop_x = args.center_x - crop_size // 2
    crop_y = args.center_y - crop_size // 2
    mask = np.zeros((1, image_size, image_size, 3), dtype=np.float32)
    mask[:, crop_y:crop_y + crop_size, crop_x:crop_x + crop_size, :] = 1.0

    # Build graph.
    logger.info(f'Building graph.')
    sess = tf.get_default_session()
    input_shape = E.input_shape
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    x_mask = (tf.transpose(x, [0, 2, 3, 1]) + 1) * mask - 1
    x_mask_255 = (x_mask + 1) / 2 * 255
    latent_shape = Gs.components.synthesis.input_shape
    latent_shape[0] = args.batch_size
    wp = tf.get_variable(shape=latent_shape, name='latent_code')
    x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
    x_rec_mask = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) * mask - 1
    x_rec_mask_255 = (x_rec_mask + 1) / 2 * 255

    w_enc = E.get_output_for(x, phase=False)
    wp_enc = tf.reshape(w_enc, latent_shape)
    setter = tf.assign(wp, wp_enc)

    # Settings for optimization.
    logger.info(f'Setting configuration for optimization.')
    perceptual_model = PerceptualModel([image_size, image_size], False)
    x_feat = perceptual_model(x_mask_255)
    x_rec_feat = perceptual_model(x_rec_mask_255)
    loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1])
    loss_pix = tf.reduce_mean(tf.square(x_mask - x_rec_mask), axis=[1, 2, 3])

    loss = loss_pix + args.loss_weight_feat * loss_feat
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    train_op = optimizer.minimize(loss, var_list=[wp])
    tflib.init_uninitialized_vars()

    # Load image list.
    logger.info(f'Loading target images and context images.')
    target_list = []
    with open(args.target_list, 'r') as f:
        for line in f:
            target_list.append(line.strip())
    num_targets = len(target_list)
    context_list = []
    with open(args.context_list, 'r') as f:
        for line in f:
            context_list.append(line.strip())
    num_contexts = len(context_list)
    num_pairs = num_targets * num_contexts

    # Invert images.
    logger.info(f'Start diffusion.')
    save_interval = args.num_iterations // args.num_results
    headers = [
        'Target Image', 'Context Image', 'Stitched Image', 'Encoder Output'
    ]
    for step in range(1, args.num_iterations + 1):
        if step == args.num_iterations or step % save_interval == 0:
            headers.append(f'Step {step:06d}')
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_pairs,
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    images = np.zeros(input_shape, np.uint8)
    latent_codes_enc = []
    latent_codes = []
    for target_idx in tqdm(range(num_targets), desc='Target ID', leave=False):
        # Load target.
        target_image = resize_image(load_image(target_list[target_idx]),
                                    (image_size, image_size))
        visualizer.set_cell(target_idx * num_contexts, 0, image=target_image)
        for context_idx in tqdm(range(0, num_contexts, args.batch_size),
                                desc='Context ID',
                                leave=False):
            row_idx = target_idx * num_contexts + context_idx
            batch = context_list[context_idx:context_idx + args.batch_size]
            for i, context_image_path in enumerate(batch):
                context_image = resize_image(load_image(context_image_path),
                                             (image_size, image_size))
                visualizer.set_cell(row_idx + i, 1, image=context_image)
                context_image[crop_y:crop_y + crop_size, crop_x:crop_x +
                              crop_size] = (target_image[crop_y:crop_y +
                                                         crop_size,
                                                         crop_x:crop_x +
                                                         crop_size])
                visualizer.set_cell(row_idx + i, 2, image=context_image)
                images[i] = np.transpose(context_image, [2, 0, 1])
            inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
            # Run encoder.
            sess.run([setter], {x: inputs})
            outputs = sess.run([wp, x_rec])
            latent_codes_enc.append(outputs[0][0:len(batch)])
            outputs[1] = adjust_pixel_range(outputs[1])
            for i, _ in enumerate(batch):
                visualizer.set_cell(row_idx + i, 3, image=outputs[1][i])
            # Optimize latent codes.
            col_idx = 4
            for step in tqdm(range(1, args.num_iterations + 1), leave=False):
                sess.run(train_op, {x: inputs})
                if step == args.num_iterations or step % save_interval == 0:
                    outputs = sess.run([wp, x_rec])
                    outputs[1] = adjust_pixel_range(outputs[1])
                    for i, _ in enumerate(batch):
                        visualizer.set_cell(row_idx + i,
                                            col_idx,
                                            image=outputs[1][i])
                    col_idx += 1
            latent_codes.append(outputs[0][0:len(batch)])

    # Save results.
    code_shape = [num_targets, num_contexts] + list(latent_shape[1:])
    np.save(f'{output_dir}/{job_name}_encoded_codes.npy',
            np.concatenate(latent_codes_enc, axis=0).reshape(code_shape))
    np.save(f'{output_dir}/{job_name}_inverted_codes.npy',
            np.concatenate(latent_codes, axis=0).reshape(code_shape))
    visualizer.save(f'{output_dir}/{job_name}.html')
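
A note on the masking arithmetic used above: for images in [-1, 1], the expression (x + 1) * mask - 1 keeps pixels unchanged where mask == 1 and forces them to -1 (black) where mask == 0, so both the pixel and perceptual losses only compare the cropped target region. A tiny check:

import numpy as np

x = np.array([-1.0, 0.0, 1.0])
mask = np.array([1.0, 0.0, 1.0])
print((x + 1) * mask - 1)  # -> [-1. -1.  1.]: the masked-out pixel becomes black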
Example #8
def training_loop(
                  submit_config,
                  Encoder_args            = {},
                  E_opt_args              = {},
                  D_opt_args              = {},
                  E_loss_args             = EasyDict(),
                  D_loss_args             = {},
                  lr_args                 = EasyDict(),
                  tf_config               = {},
                  dataset_args            = EasyDict(),
                  decoder_pkl             = EasyDict(),
                  inversion_pkl           = EasyDict(),
                  drange_data             = [0, 255],
                  drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
                  mirror_augment          = False,
                  resume_run_id           = config.ENCODER_PICKLE_DIR,     # Run ID or network pkl to resume training from, None = start from scratch.
                  resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
                  image_snapshot_ticks    = 1,        # How often to export image snapshots?
                  network_snapshot_ticks  = 4,       # How often to export network snapshots?
                  max_iters               = 150000):

    tflib.init_tf(tf_config)

    with tf.name_scope('input'):
        placeholder_real_portraits_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_train')
        placeholder_real_landmarks_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_train')
        placeholder_real_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_train')
        placeholder_landmarks_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_landmarks_shuffled_train')


        placeholder_real_portraits_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_test')
        placeholder_real_landmarks_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_test')
        placeholder_real_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_test')
        placeholder_real_landmarks_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_shuffled_test')

        real_split_landmarks = tf.split(placeholder_real_landmarks_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_portraits = tf.split(placeholder_real_portraits_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_shuffled = tf.split(placeholder_real_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_lm_shuffled = tf.split(placeholder_landmarks_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        
        placeholder_training_flag = tf.placeholder(tf.string, name='placeholder_training_flag')

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, _, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) # don't use pre-trained discriminator!
            num_layers = Gs.components.synthesis.input_shape[1]

            # here we add a new discriminator!
            D = tflib.Network('D',  # registered name of the network
                              num_channels=3, resolution=128, label_size=0,  # arguments required by the build function
                              func_name="training.networks_stylegan.D_basic") # build function of this network; nothing else was passed in d_args
                              # Inputs are not passed at construction time (the build function is not called yet);
                              # the Network wrapper inspects it and requires the inputs later, in get_output_for().
            print("Created new Discriminator!")

            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args)
            start = 0
        Inv, _, _, _ = misc.load_pkl(inversion_pkl.inversion_pkl)

    E.print_layers(); Gs.print_layers(); D.print_layers()

    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            Inv_gpu = Inv if gpu == 0 else Inv.clone(Inv.name + '_shadow')
            perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False)
            real_portraits_gpu = process_reals(real_split_portraits[gpu], mirror_augment, drange_data, drange_net)
            shuffled_portraits_gpu = process_reals(real_split_shuffled[gpu], mirror_augment, drange_data, drange_net)
            real_landmarks_gpu = process_reals(real_split_landmarks[gpu], mirror_augment, drange_data, drange_net)
            shuffled_landmarks_gpu = process_reals(real_split_lm_shuffled[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, perceptual_model=perceptual_model, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, shuffled_landmarks=shuffled_landmarks_gpu, training_flag=placeholder_training_flag, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, training_flag=placeholder_training_flag, **D_loss_args) # change signature in ...
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)
    inv_X_val = test_inversion(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)
    
    #sampled_portraits_val = sample_random_portraits(Gs, submit_config.batch_size)
    #sampled_portraits_val_test = sample_random_portraits(Gs, submit_config.batch_size_test)

    sess = tf.get_default_session()

    print('Getting training data...')
    # each batch element is a (2, C, H, W) stack: portrait + landmark image
    stack_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    stack_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')
    
    stack_batch_train_secondary = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train_secondary')
    stack_batch_test_secondary = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test_secondary')

    summary_log = tf.summary.FileWriter(config.getGdrivePath())

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    init_fix = tf.initialize_variables(
        [global_step0],
        name='init_fix'
    )
    sess.run(init_fix)
    
    print('Optimization starts!!!')
    
    
    # here is the actual training loop: all iterations
    for it in range(start, max_iters):
        batch_stacks = sess.run(stack_batch_train)
        batch_portraits = batch_stacks[:,0,:,:,:]
        batch_landmarks = batch_stacks[:,1,:,:,:]
        
        batch_stacks_secondary = sess.run(stack_batch_train_secondary)
        batch_shuffled = batch_stacks_secondary[:,0,:,:,:]
        batch_lm_shuffled = batch_stacks_secondary[:,1,:,:,:]
        
        
        training_flag = "pose"
        
        feed_dict_1 = {placeholder_real_portraits_train: batch_portraits, placeholder_real_landmarks_train: batch_landmarks, placeholder_real_shuffled_train:batch_shuffled, placeholder_landmarks_shuffled_train:batch_lm_shuffled, placeholder_training_flag: training_flag}
        # run the encoder and discriminator updates; the feed provides portrait + landmark batches
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_= sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

        cur_nimg += submit_config.batch_size

        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)
            
            
            
            
        if it % 500 == 0:
            batch_stacks_test = sess.run(stack_batch_test)
            batch_portraits_test = batch_stacks_test[:,0,:,:,:]
            batch_landmarks_test = batch_stacks_test[:,1,:,:,:]
            
            batch_stacks_test_secondary = sess.run(stack_batch_test_secondary)
            batch_shuffled_test = batch_stacks_test_secondary[:,0,:,:,:]
            batch_shuffled_lm_test = batch_stacks_test_secondary[:,1,:,:,:]

            batch_portraits_test = misc.adjust_dynamic_range(batch_portraits_test.astype(np.float32), [0, 255], [-1., 1.])
            batch_landmarks_test = misc.adjust_dynamic_range(batch_landmarks_test.astype(np.float32), [0, 255], [-1., 1.])
            batch_shuffled_test = misc.adjust_dynamic_range(batch_shuffled_test.astype(np.float32), [0, 255], [-1., 1.])
            batch_shuffled_lm_test = misc.adjust_dynamic_range(batch_shuffled_lm_test.astype(np.float32), [0, 255], [-1., 1.])

            # first: input + target landmarks = manipulated image
            samples_manipulated = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_shuffled_lm_test})

            # second: manipulated image + original landmarks = cycle reconstruction
            samples_reconstructed = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: samples_manipulated, placeholder_real_landmarks_test: batch_landmarks_test})

            # also: show direct reconstruction
            samples_direct_rec = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test})

            # show results of the inversion
            portraits_inverted = sess.run(inv_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test})

            # grid rows: original landmarks, original portraits, direct reconstruction, shuffled landmarks, manipulated, cycle reconstruction, inversion
            debug_img = np.concatenate([
                    batch_landmarks_test, # original landmarks
                    batch_portraits_test, # original portraits,
                    samples_direct_rec, # direct
                    batch_shuffled_lm_test, # shuffled landmarks
                    samples_manipulated, # manipulated images
                    samples_reconstructed, # cycle-reconstructed images
                    portraits_inverted # inversion results
                ], axis=0)

            debug_img = adjust_pixel_range(debug_img)
            debug_img = fuse_images(debug_img, row=7, col=submit_config.batch_size_test)  # seven image groups are stacked above
            # save the debug grid during training; each row is one group, in the order listed above
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), debug_img)

            # save image to gdrive
            img_path = os.path.join(config.getGdrivePath(), 'images', ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, debug_img)

        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)
                
                # save network snapshot to gdrive
                pkl_drive = os.path.join(config.getGdrivePath(), 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
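
adjust_pixel_range and fuse_images are imported from the surrounding codebase and not shown here. A minimal sketch of the behavior the calls above assume (illustrative, not the actual implementation):

import numpy as np

def adjust_pixel_range_sketch(images, min_val=-1.0, max_val=1.0):
    """Assumed behavior: map NCHW floats in [min_val, max_val] to uint8 NHWC in [0, 255]."""
    images = (images - min_val) * 255.0 / (max_val - min_val)
    images = np.clip(images, 0, 255).astype(np.uint8)
    return images.transpose(0, 2, 3, 1)

def fuse_images_sketch(images, row, col):
    """Assumed behavior: tile a (row*col, H, W, C) batch into a single (row*H, col*W, C) grid."""
    n, h, w, c = images.shape
    assert n == row * col
    grid = images.reshape(row, col, h, w, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(row * h, col * w, c)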
Example #9
def encode(_target_image, _context_image, _output_dir):
    gpu_id = '0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    print(_target_image)
    assert os.path.exists('./static/' + _target_image)
    _output_dir = _output_dir[:-4]  # strip the 4-character file extension (e.g. '.png')
    output_dir = './static/' + _output_dir

    tflib.init_tf({'rnd.np_random_seed': 1000})
    model_path = './styleganinv_face_256.pkl'
    with open(model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    crop_size = 110  # default crop size.
    center_x = 125
    center_y = 145
    crop_x = center_x - crop_size // 2  # default coordinate-X
    crop_y = center_y - crop_size // 2  # default coordinate-Y

    mask = np.zeros((1, image_size, image_size, 3), dtype=np.float32)
    mask[:, crop_y:crop_y + crop_size, crop_x:crop_x + crop_size, :] = 1.0

    # Build graph.
    sess = tf.get_default_session()

    batch_size = 4
    input_shape = E.input_shape
    input_shape[0] = batch_size  # default batch size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    x_mask = (tf.transpose(x, [0, 2, 3, 1]) + 1) * mask - 1
    x_mask_255 = (x_mask + 1) / 2 * 255
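    # x is in [-1, 1], so (x + 1) * mask - 1 keeps the crop region unchanged
    # and drives everything outside the mask to -1 (black); x_mask_255 then
    # rescales the masked image to [0, 255] for the perceptual network below.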

    latent_shape = Gs.components.synthesis.input_shape
    latent_shape[0] = batch_size  # default batch size
    wp = tf.get_variable(shape=latent_shape, name='latent_code')
    x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
    x_rec_mask = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) * mask - 1
    x_rec_mask_255 = (x_rec_mask + 1) / 2 * 255

    w_enc = E.get_output_for(x, phase=False)
    wp_enc = tf.reshape(w_enc, latent_shape)
    setter = tf.assign(wp, wp_enc)
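    # `setter` warm-starts the latent variable wp from the encoder output;
    # the optimization below then refines wp directly while E and Gs stay fixed.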

    # Settings for optimization.
    print("Diffusion : Settings for Optimization.")
    perceptual_model = PerceptualModel([image_size, image_size], False)
    x_feat = perceptual_model(x_mask_255)
    x_rec_feat = perceptual_model(x_rec_mask_255)
    loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1])
    loss_pix = tf.reduce_mean(tf.square(x_mask - x_rec_mask), axis=[1, 2, 3])

    loss_weight_feat = 5e-5
    learning_rate = 0.01
    loss = loss_pix + loss_weight_feat * loss_feat  # perceptual loss weight for optimization (default 5e-5)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(loss, var_list=[wp])
    tflib.init_uninitialized_vars()

    # Invert image
    num_iterations = 100
    num_results = 5
    save_interval = num_iterations // num_results

    images = np.zeros(input_shape, np.uint8)

    print("Load target image.")
    _target_image = './static/' + _target_image
    target_image = resize_image(load_image(_target_image),
                                (image_size, image_size))
    save_image('./' + output_dir + '_tar.png', target_image)

    print("Load context image.")
    context_image = getContextImage(_context_image)
    context_image = resize_image(load_image(context_image),
                                 (image_size, image_size))
    save_image('./' + output_dir + '_cont.png', context_image)

    # Inverting Context Image.
    # context_image = invert(model_path, getContextImage(_context_image), wp, latent_shape)
    save_image('./' + output_dir + '_cont_inv.png', context_image)

    # Create Stitched Image
    # context_image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size] = (
    #     target_image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
    # )
    # context_image[crop_y:crop_y + 170, crop_x - 70:crop_x + crop_size + 190] = (
    #     target_image[crop_y:crop_y + 170, crop_x - 70:crop_x + crop_size + 190]
    # )
    print("Cropping Image...")
    # context_image = cropImage(target_image, context_image)

    target_image, rect = cropWithWhite(target_image)
    target_image = fourChannels(target_image)
    target_image = cut(target_image)
    target_image = transBg(target_image)

    context_image = createStitchedImage(context_image, target_image, rect)
    save_image('./' + output_dir + '_sti.png', context_image)
    images[0] = np.transpose(context_image, [2, 0, 1])

    inputs = images.astype(np.float32) / 255 * 2.0 - 1.0  # scale to [-1, 1] (avoid shadowing the built-in input)

    # Run encoder
    print("Start Diffusion.")
    sess.run([setter], {x: inputs})
    output = sess.run([wp, x_rec])
    output[1] = adjust_pixel_range(output[1])

    col_idx = 4  # leftover visualizer column index; never read in this example
    for step in tqdm(range(1, num_iterations + 1), leave=False):
        sess.run(train_op, {x: inputs})
        if step == num_iterations or step % save_interval == 0:
            output = sess.run([wp, x_rec])
            output[1] = adjust_pixel_range(output[1])
            if step == num_iterations:
                save_image(f'{output_dir}.png', output[1][0])
            col_idx += 1
    exit()
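
A hypothetical invocation of encode (file names are illustrative; the target image must already exist under ./static/):

# Hypothetical usage sketch; 'target.png', 'some_context', and the output
# path are illustrative, not taken from the source.
encode('target.png', 'some_context', 'out/result.png')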
Example #10
def invert(model_path, _image, _wp, _latent_shape):
    print("Inverting")
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    # Build graph.
    print("Inverting : Build Graph.")
    sess = tf.get_default_session()

    batch_size = 4
    input_shape = E.input_shape
    input_shape[0] = batch_size  # default batch size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    x_255 = (tf.transpose(x, [0, 2, 3, 1]) + 1) / 2 * 255

    wp = _wp
    x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)
    x_rec_255 = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) / 2 * 255

    w_enc = E.get_output_for(x, phase=False)
    wp_enc = tf.reshape(w_enc, _latent_shape)
    setter = tf.assign(wp, wp_enc)

    # Settings for optimization.
    print("Inverting : Settings for Optimization.")
    perceptual_model = PerceptualModel([image_size, image_size], False)
    x_feat = perceptual_model(x_255)
    x_rec_feat = perceptual_model(x_rec_255)
    loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1])
    loss_pix = tf.reduce_mean(tf.square(x - x_rec), axis=[1, 2, 3])
    w_enc_new = E.get_output_for(x_rec, phase=False)
    wp_enc_new = tf.reshape(w_enc_new, _latent_shape)
    loss_enc = tf.reduce_mean(tf.square(wp - wp_enc_new), axis=[1, 2])
    loss = (loss_pix + 5e-5 * loss_feat + 2.0 * loss_enc)
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, var_list=[wp])
    tflib.init_uninitialized_vars()

    # Invert image
    print("Start Inverting.")
    num_iterations = 40
    num_results = 2
    save_interval = num_iterations // num_results
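    # With num_iterations=40 and num_results=2, save_interval is 20, so
    # intermediate outputs are fetched at steps 20 and 40.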

    context_images = np.zeros(input_shape, np.uint8)

    context_image = resize_image(load_image(_image), (image_size, image_size))

    # Inverting Context Image.
    context_images[0] = np.transpose(context_image, [2, 0, 1])
    context_input = context_images.astype(np.float32) / 255 * 2.0 - 1.0

    sess.run([setter], {x: context_input})
    context_output = sess.run([wp, x_rec])
    context_output[1] = adjust_pixel_range(context_output[1])
    context_image = np.transpose(context_images[0], [1, 2, 0])

    for step in tqdm(range(1, num_iterations + 1), leave=False):
        sess.run(train_op, {x: context_input})
        if step == num_iterations or step % save_interval == 0:
            context_output = sess.run([wp, x_rec])
            context_output[1] = adjust_pixel_range(context_output[1])
            if step == num_iterations: context_image = context_output[1][0]

    return context_image
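
invert optimizes a caller-owned latent variable in place; a minimal wiring sketch (shapes illustrative):

# Wiring sketch for invert; the latent shape [N, num_layers, latent_dim]
# is illustrative and must match Gs.components.synthesis.input_shape.
import tensorflow as tf

batch_size = 4
latent_shape = [batch_size, 14, 512]
wp = tf.get_variable('latent_code', shape=latent_shape, dtype=tf.float32)
recovered = invert('./styleganinv_face_256.pkl', './static/context.png',
                   wp, latent_shape)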
Example #11
def training_loop(
        submit_config,
        Encoder_args={},
        D_args={},
        G_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args=EasyDict(),
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        filter=64,  # Minimum number of feature maps in any layer.
        filter_max=512,  # Maximum number of feature maps in any layer.
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        d_scale=0.1,  # If > 0, the discriminator is also updated.
        pretrained_D=True,  # Whether to use a pre-trained discriminator.
        max_iters=150000):

    tflib.init_tf(tf_config)

    with tf.name_scope('Input'):
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                    name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                   name='real_image_test')
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, PreD, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E_gpu0',
                              size=submit_config.image_size,
                              filter=filter,
                              filter_max=filter_max,
                              num_layers=num_layers,
                              is_training=True,
                              num_gpus=submit_config.num_gpus,
                              **Encoder_args)
            OriD = tflib.Network('D_ori',
                                 resolution=submit_config.image_size,
                                 label_size=0,
                                 **D_args)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            if pretrained_D:
                D = PreD
            else:
                D = OriD
            start = 0
        Gs_vars_pairs = {
            name: tflib.run(val)
            for name, val in Gs.components.synthesis.vars.items()
        }
        for g_name, g_val in G_style_mod.vars.items():
            tflib.set_vars({g_val: Gs_vars_pairs[g_name]})

    E.print_layers()
    Gs.print_layers()
    D.print_layers()

    global_step0 = tf.Variable(start,
                               trainable=False,
                               name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step0,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)
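    # tf.train.exponential_decay with staircase=True yields a piecewise-constant
    # schedule: lr = learning_rate * decay_rate ** (global_step0 // decay_step).
    # add_global0 advances that step through the control dependency placed on
    # the gradient registration below.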

    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    if d_scale > 0:
        D_opt = tflib.Optimizer(name='TrainD',
                                learning_rate=learning_rate,
                                **D_opt_args)

    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('Building Graph on GPU %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            E_gpu = E if gpu == 0 else E.clone(E.name[:-1] + str(gpu))
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = G_style_mod if gpu == 0 else G_style_mod.clone(
                G_style_mod.name + '_shadow')
            feature_model = PerceptualModel(img_size=[
                E_loss_args.perceptual_img_size,
                E_loss_args.perceptual_img_size
            ],
                                            multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    feature_model=feature_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                if d_scale > 0:
                    D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    if d_scale > 0:
        D_train_op = D_opt.apply_updates()

    print('Building testing graph...')
    fake_X_val = test(E, G_style_mod, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):

        batch_images = sess.run(image_batch_train)
        feed_dict = {real_train: batch_images}
        _, recon_, adv_, lr = sess.run(
            [E_train_op, E_loss_rec, E_loss_adv, learning_rate], feed_dict)
        if d_scale > 0:
            _, d_r_, d_f_, d_g_ = sess.run(
                [D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict)

        cur_nimg += submit_config.batch_size

        run_time = time.time() - start_time
        iter_time = run_time / (it - start + 1)
        eta_time = iter_time * (max_iters - it - 1)

        if it % 50 == 0:
            print(
                'Iter: %06d/%d, lr: %-.8f recon_loss: %-6.4f adv_loss: %-6.4f run_time:%-12s eta_time:%-12s'
                % (it, max_iters, lr, recon_, adv_,
                   dnnlib.util.format_time(time.time() - start_time),
                   dnnlib.util.format_time(eta_time)))
            if d_scale > 0:
                print('d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f ' %
                      (d_r_, d_f_, d_g_))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                batch_images_test = sess.run(image_batch_test)
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
                recon = sess.run(fake_X_val,
                                 feed_dict={real_test: batch_images_test})
                orin_recon = np.concatenate([batch_images_test, recon], axis=0)
                orin_recon = adjust_pixel_range(orin_recon)
                orin_recon = fuse_images(orin_recon,
                                         row=2,
                                         col=submit_config.batch_size_test)
                # save image results during training, first row is original images and the second row is reconstructed images
                save_image(
                    '%s/iter_%09d.png' % (submit_config.run_dir, cur_nimg),
                    orin_recon)

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%09d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)

    misc.save_pkl((E, G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
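
The tick bookkeeping above advances one tick per 65,000 processed images; a standalone sketch of the resulting snapshot cadence (batch size illustrative):

# Cadence sketch; batch_size is illustrative.
batch_size = 16
images_per_tick = 65000
network_snapshot_ticks = 10  # as in the defaults above
iters_per_tick = -(-images_per_tick // batch_size)  # ceil(65000 / 16) = 4063
print('network snapshot roughly every %d iterations'
      % (iters_per_tick * network_snapshot_ticks))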
Example #12
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    image_dir = args.image_dir
    image_dir_name = os.path.basename(image_dir.rstrip('/'))
    assert os.path.exists(image_dir)
    assert os.path.exists(f'{image_dir}/image_list.txt')
    assert os.path.exists(f'{image_dir}/inverted_codes.npy')
    boundary_path = args.boundary_path
    assert os.path.exists(boundary_path)
    boundary_name = os.path.splitext(os.path.basename(boundary_path))[0]
    output_dir = args.output_dir or 'results/manipulation'
    job_name = f'{boundary_name.upper()}_{image_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info(f'Loading generator.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        _, _, _, Gs = pickle.load(f)

    # Build graph.
    logger.info(f'Building graph.')
    sess = tf.get_default_session()
    num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3]
    wp = tf.placeholder(tf.float32, [args.batch_size, num_layers, latent_dim],
                        name='latent_code')
    x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False)

    # Load image, codes, and boundary.
    logger.info(f'Loading images and corresponding inverted latent codes.')
    image_list = []
    with open(f'{image_dir}/image_list.txt', 'r') as f:
        for line in f:
            name = os.path.splitext(os.path.basename(line.strip()))[0]
            assert os.path.exists(f'{image_dir}/{name}_ori.png')
            assert os.path.exists(f'{image_dir}/{name}_inv.png')
            image_list.append(name)
    latent_codes = np.load(f'{image_dir}/inverted_codes.npy')
    assert latent_codes.shape[0] == len(image_list)
    num_images = latent_codes.shape[0]
    logger.info(f'Loading boundary.')
    boundary_file = np.load(boundary_path, allow_pickle=True)[()]
    if isinstance(boundary_file, dict):
        boundary = boundary_file['boundary']
        manipulate_layers = boundary_file['meta_data']['manipulate_layers']
    else:
        boundary = boundary_file
        manipulate_layers = args.manipulate_layers
    if manipulate_layers:
        logger.info(f'  Manipulating on layers `{manipulate_layers}`.')
    else:
        logger.info(f'  Manipulating on ALL layers.')

    # Manipulate images.
    logger.info(f'Start manipulation.')
    step = args.step
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=num_images,
                                    num_cols=step + 3,
                                    viz_size=viz_size)
    visualizer.set_headers(['Name', 'Origin', 'Inverted'] +
                           [f'Step {i:02d}' for i in range(1, step + 1)])
    for img_idx, img_name in enumerate(image_list):
        ori_image = load_image(f'{image_dir}/{img_name}_ori.png')
        inv_image = load_image(f'{image_dir}/{img_name}_inv.png')
        visualizer.set_cell(img_idx, 0, text=img_name)
        visualizer.set_cell(img_idx, 1, image=ori_image)
        visualizer.set_cell(img_idx, 2, image=inv_image)

    codes = manipulate(latent_codes=latent_codes,
                       boundary=boundary,
                       start_distance=args.start_distance,
                       end_distance=args.end_distance,
                       step=step,
                       layerwise_manipulation=True,
                       num_layers=num_layers,
                       manipulate_layers=manipulate_layers,
                       is_code_layerwise=True,
                       is_boundary_layerwise=True)
    inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32)
    for img_idx in tqdm(range(num_images), leave=False):
        output_images = []
        for idx in range(0, step, args.batch_size):
            batch = codes[img_idx, idx:idx + args.batch_size]
            inputs[0:len(batch)] = batch
            images = sess.run(x, feed_dict={wp: inputs})
            output_images.append(images[0:len(batch)])
        output_images = adjust_pixel_range(
            np.concatenate(output_images, axis=0))
        for s, output_image in enumerate(output_images):
            visualizer.set_cell(img_idx, s + 3, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
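
manipulate is imported from the codebase; a rough sketch of the layer-wise shift it is assumed to perform on the selected layers (illustrative, not the library's implementation):

import numpy as np

def manipulate_sketch(code, boundary, distances, layers):
    """code, boundary: (num_layers, latent_dim); returns (len(distances), num_layers, latent_dim)."""
    out = np.tile(code[np.newaxis], (len(distances), 1, 1))
    for i, d in enumerate(distances):
        out[i, layers] += d * boundary[layers]  # shift only the chosen layers
    return out

# Illustrative call: 7 steps between start and end distance on layers 4-6.
steps = np.linspace(-3.0, 3.0, 7)
edited = manipulate_sketch(np.zeros((14, 512)), np.random.randn(14, 512),
                           steps, [4, 5, 6])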
Example #13
def main():
    """Main function."""
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    assert os.path.exists(args.image_list)
    image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
    output_dir = args.output_dir or f'results/ghfeat/{image_list_name}'
    logger = setup_logger(output_dir, 'extract_feature.log',
                          'inversion_logger')

    logger.info(f'Loading model.')
    tflib.init_tf({'rnd.np_random_seed': 1000})
    with open(args.model_path, 'rb') as f:
        E, _, _, Gs = pickle.load(f)

    # Get input size.
    image_size = E.input_shape[2]
    assert image_size == E.input_shape[3]

    G_args = EasyDict(func_name='training.networks_stylegan.G_synthesis')
    G_style_mod = tflib.Network('G_StyleMod',
                                resolution=image_size,
                                label_size=0,
                                **G_args)
    Gs_vars_pairs = {
        name: tflib.run(val)
        for name, val in Gs.components.synthesis.vars.items()
    }
    for g_name, g_val in G_style_mod.vars.items():
        tflib.set_vars({g_val: Gs_vars_pairs[g_name]})

    # Build graph.
    logger.info(f'Building graph.')
    sess = tf.get_default_session()
    input_shape = E.input_shape
    input_shape[0] = args.batch_size
    x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
    ghfeat = E.get_output_for(x, is_training=False)
    x_rec = G_style_mod.get_output_for(ghfeat, randomize_noise=False)

    # Load image list.
    logger.info(f'Loading image list.')
    image_list = []
    with open(args.image_list, 'r') as f:
        for line in f:
            image_list.append(line.strip())

    # Extract GH-Feat from images.
    logger.info(f'Start feature extraction.')
    headers = ['Name', 'Original Image', 'Encoder Output']
    viz_size = None if args.viz_size == 0 else args.viz_size
    visualizer = HtmlPageVisualizer(num_rows=len(image_list),
                                    num_cols=len(headers),
                                    viz_size=viz_size)
    visualizer.set_headers(headers)

    images = np.zeros(input_shape, np.uint8)
    names = ['' for _ in range(args.batch_size)]
    features = []
    for img_idx in tqdm(range(0, len(image_list), args.batch_size),
                        leave=False):
        # Load inputs.
        batch = image_list[img_idx:img_idx + args.batch_size]
        for i, image_path in enumerate(batch):
            image = resize_image(load_image(image_path),
                                 (image_size, image_size))
            images[i] = np.transpose(image, [2, 0, 1])
            names[i] = os.path.splitext(os.path.basename(image_path))[0]
        inputs = images.astype(np.float32) / 255 * 2.0 - 1.0
        # Run encoder.
        outputs = sess.run([ghfeat, x_rec], {x: inputs})
        features.append(outputs[0][0:len(batch)])
        outputs[1] = adjust_pixel_range(outputs[1])
        for i, _ in enumerate(batch):
            image = np.transpose(images[i], [1, 2, 0])
            save_image(f'{output_dir}/{names[i]}_ori.png', image)
            save_image(f'{output_dir}/{names[i]}_enc.png', outputs[1][i])
            visualizer.set_cell(i + img_idx, 0, text=names[i])
            visualizer.set_cell(i + img_idx, 1, image=image)
            visualizer.set_cell(i + img_idx, 2, image=outputs[1][i])

    # Save results.
    os.system(f'cp {args.image_list} {output_dir}/image_list.txt')
    np.save(f'{output_dir}/ghfeat.npy', np.concatenate(features, axis=0))
    visualizer.save(f'{output_dir}/reconstruction.html')
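
ghfeat.npy stores one feature row per line of image_list.txt, in order; a quick consumer sketch (the results directory is illustrative):

# Reload the extracted features and pair them with the listed image paths.
import numpy as np

with open('results/ghfeat/my_list/image_list.txt') as f:
    names = [line.strip() for line in f]
feats = np.load('results/ghfeat/my_list/ghfeat.npy')
assert feats.shape[0] == len(names)
print(names[0], feats[0].shape)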