Example #1
    def first_cnn(self, h, resolution, channel, test=False):
        with nn.parameter_scope("phase_{}".format(resolution)):
            # affine is 1x1 conv with 4x4 kernel and 3x3 pad.
            with nn.parameter_scope("conv1"):
                h = affine(h,
                           channel * 4 * 4,
                           with_bias=not self.use_bn,
                           use_wscale=self.use_wscale,
                           use_he_backward=self.use_he_backward)
                h = F.reshape(h, (h.shape[0], channel, 4, 4))
                h = pixel_wise_feature_vector_normalization(
                    BN(h, use_bn=self.use_bn, test=test))
                h = self.activation(h)
            with nn.parameter_scope("conv2"):
                h = conv(h,
                         channel,
                         kernel=(3, 3),
                         pad=(1, 1),
                         stride=(1, 1),
                         with_bias=not self.use_bn,
                         use_wscale=self.use_wscale,
                         use_he_backward=self.use_he_backward)
                h = pixel_wise_feature_vector_normalization(
                    BN(h, use_bn=self.use_bn, test=test))
                h = self.activation(h)
        return h
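Every example on this page calls pixel_wise_feature_vector_normalization. For reference, here is a minimal sketch of what this helper computes in the PGGAN setup (the "pixelwise normalization" of Karras et al.), assuming nnabla's F.mean / F.pow_scalar / F.broadcast API; the repository's own implementation may differ in details:

import nnabla.functions as F

def pixel_wise_feature_vector_normalization(h, eps=1e-8):
    # Normalize each feature vector (across the channel axis) to unit
    # root mean square.
    mean = F.mean(F.pow_scalar(h, 2.0), axis=1, keepdims=True)
    deno = F.pow_scalar(mean + eps, 0.5)
    return F.div2(h, F.broadcast(deno, h.shape))

Applied to a latent z with hyper_sphere=True, this projects each 512-D latent onto a sphere of radius sqrt(512); a quick numeric check:

import numpy as np
import nnabla as nn

z = nn.Variable.from_numpy_array(np.random.randn(4, 512, 1, 1))
zn = pixel_wise_feature_vector_normalization(z)
zn.forward()
print(np.mean(zn.d ** 2, axis=1).squeeze())  # ~1.0 for every sample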
Example #2
def compute_metric(gen, batch_size, img_num, latent, hyper_sphere):
    num_batches = img_num // batch_size

    # Sample two independent sets of generated images; the MS-SSIM between
    # them serves as a similarity (i.e. inverse diversity) measure.
    img1 = []
    for k in range(num_batches):
        logger.info("generating at iter={} / {}".format(k, num_batches))
        z_data = np.random.randn(batch_size, latent, 1, 1)
        z = nn.Variable.from_numpy_array(z_data)
        z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
        y = gen(z, test=True)
        img = y.d.transpose(0, 2, 3, 1)
        img1.append(img)
    img1 = np.concatenate(img1, axis=0)
    img2 = []

    for k in range(num_batches):
        logger.info("generating at iter={} / {}".format(k, num_batches))
        z_data = np.random.randn(batch_size, latent, 1, 1)
        z = nn.Variable.from_numpy_array(z_data)
        z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
        y = gen(z, test=True)
        img = y.d.transpose(0, 2, 3, 1)
        img2.append(img)
    img2 = np.concatenate(img2, axis=0)

    img1 = np.uint8((img1 + 1.) / 2. * 255)
    img2 = np.uint8((img2 + 1.) / 2. * 255)

    return msssim(img1, img2,
                  max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, weights=None)
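A hypothetical usage of this metric; the parameter file path below is a placeholder, and the load_gen keyword values mirror the other examples on this page. A lower MS-SSIM between two independently sampled sets indicates more diverse generations:

gen = load_gen("path/to/generator_params.h5",
               use_bn=False, last_act='tanh',
               use_wscale=True, use_he_backward=False)
score = compute_metric(gen, batch_size=16, img_num=160,
                       latent=512, hyper_sphere=True)
print("MS-SSIM over generated samples: {}".format(score))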
Example #3
    def cnn(self, h, resolution, channel, test):
        """CNN block

        After one upsampling, the following operations are performed twice:

        1. Conv
        2. (Optional) batch normalization
        3. Pixel-wise feature vector normalization
        4. Activation
        """
        h = F.unpooling(h, kernel=(2, 2))
        with nn.parameter_scope("phase_{}".format(resolution)):
            with nn.parameter_scope("conv1"):
                h = conv(h, channel, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                         with_bias=not self.use_bn,
                         use_wscale=self.use_wscale,
                         use_he_backward=self.use_he_backward)
                h = pixel_wise_feature_vector_normalization(
                    BN(h, use_bn=self.use_bn, test=test))
                h = self.activation(h)
            with nn.parameter_scope("conv2"):
                h = conv(h, channel, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                         with_bias=not self.use_bn,
                         use_wscale=self.use_wscale,
                         use_he_backward=self.use_he_backward)
                h = pixel_wise_feature_vector_normalization(
                    BN(h, use_bn=self.use_bn, test=test))
                h = self.activation(h)
        return h
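The single F.unpooling call is what doubles the spatial resolution at each phase (4 -> 8 -> 16 -> ...). A self-contained shape check, independent of the surrounding class:

import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable.from_numpy_array(
    np.zeros((1, 512, 4, 4), dtype=np.float32))
y = F.unpooling(x, kernel=(2, 2))  # nearest-neighbor 2x upsampling
print(y.shape)  # (1, 512, 8, 8)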
Example #4
def generate_interpolated_images(model_load_path,
                                 batch_size=16,
                                 n_latent=512,
                                 use_bn=False,
                                 hyper_sphere=True,
                                 last_act='tanh',
                                 use_wscale=True,
                                 use_he_backward=False,
                                 resolution_list=[4, 8, 16, 32, 64, 128],
                                 channel_list=[512, 512, 256, 128, 64, 32]):
    # Generate
    gen = load_gen(model_load_path,
                   use_bn=use_bn,
                   last_act=last_act,
                   use_wscale=use_wscale,
                   use_he_backward=use_he_backward)
    z_data0 = np.random.randn(1, n_latent, 1, 1)
    z_data1 = np.random.randn(1, n_latent, 1, 1)
    imgs = []
    for i in range(batch_size):
        # Linear interpolation between the two latents; alpha sweeps 0 -> 1.
        alpha = 1. * i / (batch_size - 1)
        z_data = (1 - alpha) * z_data0 + alpha * z_data1
        z = nn.Variable.from_numpy_array(z_data)
        z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
        y = gen(z, test=True)
        imgs.append(y.d)
    imgs = np.concatenate(imgs, axis=0)
    return imgs
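The returned array is in NCHW layout with float values in roughly [-1, 1] (tanh output). A hypothetical follow-up that converts the interpolation strip to uint8 for saving; the parameter file name is a placeholder:

imgs = generate_interpolated_images("path/to/generator_params.h5")
imgs_u8 = np.uint8(np.clip((imgs + 1.) / 2. * 255, 0, 255))  # NCHW, uint8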
Example #5
def generate_flipped_images(gen,
                            latent_vector,
                            hyper_sphere=True,
                            save_dir=None):
    """
    generate flipped images
    Args:
        gen : generator
        latent_vector(numpy.ndarray) : latent_vector
        hyper_sphere (bool) : default True
        save_dir (str) : directory to save the images
    """
    if not path.isdir(save_dir):
        mkdir(save_dir)
    z_data = np.reshape(latent_vector,
                        (latent_vector.shape[0], latent_vector.shape[1], 1, 1))
    z = nn.Variable.from_numpy_array(z_data)
    z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
    batch_size = 64  # process the images in batches of 64
    num_images = latent_vector.shape[0]
    iterations = int(num_images / batch_size)
    if num_images % batch_size != 0:
        iterations += 1
    count = 0
    for ell in range(iterations):
        y = gen(z[ell * batch_size:(ell + 1) * batch_size], test=True)
        images = convert_images_to_uint8(y, drange=[-1, 1])
        for i in range(images.shape[0]):
            imsave(save_dir + '/gen_' + str(count) + '.jpg',
                   images[i],
                   channel_first=True)
            count += 1

    print("all paired images generated")
def compute_metric(di,
                   gen,
                   latent,
                   num_minibatch,
                   nhoods_per_image,
                   nhood_size,
                   level_list,
                   dir_repeats,
                   dirs_per_repeat,
                   hyper_sphere=True):
    logger.info("Generate images")
    st = time.time()
    real_descriptor = [[] for _ in level_list]
    fake_descriptor = [[] for _ in level_list]
    for k in range(num_minibatch):
        logger.info("iter={} / {}".format(k, num_minibatch))
        real, _ = di.next()
        real = np.uint8((real + 1.) / 2. * 255)

        B = len(real)
        z_data = np.random.randn(B, latent, 1, 1)
        z = nn.Variable.from_numpy_array(z_data)
        z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
        y = gen(z)
        fake = np.uint8((y.d + 1.) / 2. * 255)

        for i, desc in enumerate(
                generate_laplacian_pyramid(real, len(level_list))):
            real_descriptor[i].append(
                get_descriptors_for_minibatch(desc, nhood_size,
                                              nhoods_per_image))
        for i, desc in enumerate(
                generate_laplacian_pyramid(fake, len(level_list))):
            fake_descriptor[i].append(
                get_descriptors_for_minibatch(desc, nhood_size,
                                              nhoods_per_image))
    logger.info(
        "Elapsed time for generating images: {} [s]".format(time.time() - st))

    logger.info("Compute Sliced Wasserstein Distance")
    scores = []
    for i, level in enumerate(level_list):
        st = time.time()
        real = finalize_descriptors(real_descriptor[i])
        fake = finalize_descriptors(fake_descriptor[i])
        scores.append(
            sliced_wasserstein(real, fake, dir_repeats, dirs_per_repeat))
        logger.info("Level: {}, dist: {}".format(level, scores[-1]))
        logger.info("Elapsed time: {} [s] at {}-th level".format(
            time.time() - st, i))
    return scores
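sliced_wasserstein is imported from the repository's helpers. For reference, a compact sketch of the standard sliced Wasserstein distance over two equally sized (N, D) descriptor arrays; the repository implementation may differ, e.g. in descriptor normalization:

import numpy as np

def sliced_wasserstein_sketch(A, B, dir_repeats, dirs_per_repeat):
    # Project both descriptor sets onto random unit directions, sort the
    # projections, and average the absolute differences.
    dists = []
    for _ in range(dir_repeats):
        dirs = np.random.randn(A.shape[1], dirs_per_repeat)
        dirs /= np.sqrt(np.sum(dirs ** 2, axis=0, keepdims=True))
        proj_a = np.sort(A.dot(dirs), axis=0)
        proj_b = np.sort(B.dot(dirs), axis=0)
        dists.append(np.mean(np.abs(proj_a - proj_b)))
    return np.mean(dists)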
Example #7
def generate_images(model_load_path,
                    batch_size=16, n_latent=512, use_bn=False,
                    hyper_sphere=True, last_act='tanh',
                    use_wscale=True, use_he_backward=False,
                    resolution_list=[4, 8, 16, 32, 64, 128],
                    channel_list=[512, 512, 256, 128, 64, 32]):
    # Generate
    gen = load_gen(model_load_path, use_bn=use_bn, last_act=last_act,
                   use_wscale=use_wscale, use_he_backward=use_he_backward)
    z_data = np.random.randn(batch_size, n_latent, 1, 1)
    z = nn.Variable.from_numpy_array(z_data)
    z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
    y = gen(z, test=True)
    return y.d
Example #8
def generate_images(gen,
                    num_images,
                    n_latent=512,
                    hyper_sphere=True,
                    save_dir=None,
                    latent_vector=None):
    """
    generate the images
    Args:
        gen : load generator
        num_images (int) : number of images to generate
        n_latent (int) : 512-D latent space trained on the CelebA
        hyper_sphere (bool) : default True
        save_dir (str) : directory to save the images
        latent_vector (str) : path to save the latent vectors(.pkl file)
    """

    if not path.isdir(save_dir):
        mkdir(save_dir)
    z_data = np.random.randn(num_images, n_latent, 1, 1)
    # Saving latent vectors
    with open(latent_vector, 'wb+') as f:
        pickle.dump(z_data.reshape((num_images, n_latent)), f)
    z = nn.Variable.from_numpy_array(z_data)
    z = pixel_wise_feature_vector_normalization(z) if hyper_sphere else z
    batch_size = 64
    iterations = int(num_images / batch_size)
    if num_images % batch_size != 0:
        iterations += 1
    count = 0
    for ell in range(iterations):
        y = gen(z[ell * batch_size:(ell + 1) * batch_size], test=True)
        images = convert_images_to_uint8(y, drange=[-1, 1])
        for i in range(images.shape[0]):
            imsave(save_dir + '/gen_' + str(count) + '.jpg',
                   images[i],
                   channel_first=True)
            count += 1
    print("images are generated")
Example #9
    def _transition(self, epoch_per_resolution):
        batch_size = self.di.batch_size
        resolution = self.gen.resolution_list[-1]
        phase = "{}to{}".format(
            self.gen.resolution_list[-2], self.gen.resolution_list[-1])
        logger.info("phase : {}".format(phase))

        kernel_size = self.resolution_list[-1] // resolution
        kernel = (kernel_size, kernel_size)

        total_itr = (self.di.size // batch_size + 1) * epoch_per_resolution
        global_itr = 1.
        alpha = global_itr / total_itr

        for epoch in range(epoch_per_resolution):
            logger.info("epoch : {}".format(epoch + 1))
            itr = 0
            current_epoch = self.di.epoch
            while self.di.epoch == current_epoch:
                img, _ = self.di.next()
                x = nn.Variable.from_numpy_array(img)

                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha, test=True)
                # Detach the generator output so that the discriminator
                # update below does not backpropagate into the generator.
                y = y.unlinked()
                y.need_grad = False
                x_r = F.average_pooling(x, kernel=kernel)

                p_real = self.dis.transition(x_r, alpha)
                p_fake = self.dis.transition(y, alpha)

                loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                                  + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)

                # Update the discriminator once every n_critic iterations.
                if itr % self.n_critic + 1 == self.n_critic:
                    with nn.parameter_scope("discriminator"):
                        self.solver_dis.set_parameters(nn.get_parameters(),
                                                       reset=False, retain_state=True)
                        self.solver_dis.zero_grad()
                        loss_dis.backward(clear_buffer=True)
                        self.solver_dis.update()

                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha, test=False)
                p_fake = self.dis.transition(y, alpha)

                loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2))
                with nn.parameter_scope("generator"):
                    self.solver_gen.set_parameters(
                        nn.get_parameters(), reset=False, retain_state=True)
                    self.solver_gen.zero_grad()
                    loss_gen.backward(clear_buffer=True)
                    self.solver_gen.update()

                itr += 1
                global_itr += 1.
                alpha = global_itr / total_itr

            # Save a tile of test samples every save_image_interval epochs.
            if epoch % self.save_image_interval + 1 == self.save_image_interval:
                z = nn.Variable.from_numpy_array(self.z_test)
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha)
                img_name = "phase_{}_epoch_{}".format(phase, epoch + 1)
                self.monitor_image_tile.add(
                    img_name, F.unpooling(y, kernel=kernel))
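gen.transition(z, alpha, ...) blends the previous-resolution output path with the newly added block, with alpha ramping linearly from near 0 to 1 over the phase. A schematic sketch of that fade-in (not the repository's exact code; out_old and out_new are hypothetical intermediate outputs):

def fade_in(out_old, out_new, alpha):
    # out_old: previous-resolution head's output, upsampled 2x.
    # out_new: output of the newly added higher-resolution block.
    return (1 - alpha) * out_old + alpha * out_new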
Example #10
    def _train(self, epoch_per_resolution, each_save=False):
        batch_size = self.di.batch_size
        resolution = self.gen.resolution_list[-1]
        logger.info("phase : {}".format(resolution))

        kernel_size = self.resolution_list[-1] // resolution
        kernel = (kernel_size, kernel_size)

        img_name = "original_phase_{}".format(resolution)
        img, _ = self.di.next()
        self.monitor_image_tile.add(img_name, img)

        for epoch in range(epoch_per_resolution):
            logger.info("epoch : {}".format(epoch + 1))
            itr = 0
            current_epoch = self.di.epoch
            while self.di.epoch == current_epoch:
                img, _ = self.di.next()
                x = nn.Variable.from_numpy_array(img)
                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=True)

                # Detach the generator output so that the discriminator
                # update below does not backpropagate into the generator.
                y = y.unlinked()
                y.need_grad = False
                x_r = F.average_pooling(x, kernel=kernel)

                p_real = self.dis(x_r)
                p_fake = self.dis(y)
                p_real.persistent, p_fake.persistent = True, True

                loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                                  + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)
                loss_dis.persistent = True

                # Update the discriminator once every n_critic iterations.
                if itr % self.n_critic + 1 == self.n_critic:
                    with nn.parameter_scope("discriminator"):
                        self.solver_dis.set_parameters(nn.get_parameters(),
                                                       reset=False, retain_state=True)
                        self.solver_dis.zero_grad()
                        loss_dis.backward(clear_buffer=True)
                        self.solver_dis.update()
                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=False)
                p_fake = self.dis(y)
                p_fake.persistent = True

                loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2.))
                loss_gen.persistent = True

                with nn.parameter_scope("generator"):
                    self.solver_gen.set_parameters(nn.get_parameters(),
                                                   reset=False, retain_state=True)
                    self.solver_gen.zero_grad()
                    loss_gen.backward(clear_buffer=True)
                    self.solver_gen.update()

                # Monitor
                self.monitor_p_real.add(
                    self.global_itr, p_real.d.copy().mean())
                self.monitor_p_fake.add(
                    self.global_itr, p_fake.d.copy().mean())
                self.monitor_loss_dis.add(self.global_itr, loss_dis.d.copy())
                self.monitor_loss_gen.add(self.global_itr, loss_gen.d.copy())

                itr += 1
                self.global_itr += 1

            # Save a tile of test samples every save_image_interval epochs.
            if epoch % self.save_image_interval + 1 == self.save_image_interval:
                z = nn.Variable.from_numpy_array(self.z_test)
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=True)
                img_name = "phase_{}_epoch_{}".format(resolution, epoch + 1)
                self.monitor_image_tile.add(
                    img_name, F.unpooling(y, kernel=kernel))

            if each_save:
                self.gen.save_parameters(self.monitor_path, "Gen_phase_{}_epoch_{}".format(
                    self.resolution_list[-1], epoch+1))
                self.dis.save_parameters(self.monitor_path, "Dis_phase_{}_epoch_{}".format(
                    self.resolution_list[-1], epoch+1))
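The objectives in _train and _transition are the least-squares GAN losses with an extra weight on the fake term. Isolated for reference (this simply mirrors the expressions used in the training loops above):

import nnabla.functions as F

def lsgan_losses(p_real, p_fake, l2_fake_weight=1.0):
    # Discriminator: push real scores toward 1 and fake scores toward 0.
    loss_dis = F.mean(F.pow_scalar(p_real - 1.0, 2.0)
                      + F.pow_scalar(p_fake, 2.0) * l2_fake_weight)
    # Generator: push fake scores toward 1.
    loss_gen = F.mean(F.pow_scalar(p_fake - 1.0, 2.0))
    return loss_dis, loss_gen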