Example no. 1
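These snippets are TF1-era StyleGAN code and assume a common preamble of imports plus an initialized default session. A minimal sketch of that preamble (module layout follows the official NVIDIA StyleGAN repositories; individual forks may differ):

import math
import os
import pickle
from functools import partial

import numpy as np
import PIL.Image
import scipy.linalg
import tensorflow as tf

import dnnlib
import dnnlib.util
import dnnlib.tflib as tflib
from training import misc  # assumption: NVIDIA StyleGAN repo layout

tflib.init_tf()  # creates and installs the default TF1 session the snippets rely on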
def draw_uncurated_result_figure(png, Gs, seed, psi):
    print(png)
    rows = 7
    cols = 7
    latents = np.random.RandomState(seed).randn(rows * cols, Gs.input_shape[1])
    #images = Gs.run(latents, None, **synthesis_kwargs)

    #latents = tf.random_normal([25] + Gs.input_shape[1:])
    # images = Gs.get_output_for(latents, None, is_validation=True, randomize_noise=False, truncation_psi=0.0, truncation_cutoff=14)
    images = Gs.get_output_for(latents,
                               None,
                               is_validation=True,
                               randomize_noise=False,
                               truncation_psi_val=psi,
                               truncation_cutoff_val=8)
    images = tflib.convert_images_to_uint8(images).eval()

    images = np.transpose(images, (0, 2, 3, 1))
    canvas = PIL.Image.new('RGB', (256 * rows, 256 * cols), 'white')
    image_iter = iter(list(images))
    for col in range(cols):
        for row in range(rows):
            image = PIL.Image.fromarray(next(image_iter), 'RGB')
            canvas.paste(image, (row * 256, col * 256))
    canvas.save(png)
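A hedged usage sketch for the figure function above (the snapshot path is hypothetical; the .eval() call inside the function needs the default session that tflib.init_tf() installs):

tflib.init_tf()
_G, _D, Gs = misc.load_pkl('results/network-snapshot.pkl')  # hypothetical path
draw_uncurated_result_figure('uncurated.png', Gs, seed=1000, psi=0.7)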
Example no. 2
    def __init__(self, model, batch_size, randomize_noise=False):
        self.batch_size = batch_size

        gauss_mean, _ = np.load("gaussian_fit.npy")
        # self.initial_dlatents = np.zeros((self.batch_size, 18, 512))
        self.initial_dlatents = np.tile(gauss_mean, (self.batch_size, 18, 1))
        model.components.synthesis.run(self.initial_dlatents,
                                       randomize_noise=randomize_noise,
                                       minibatch_size=self.batch_size,
                                       custom_inputs=[
                                           partial(
                                               create_variable_for_generator,
                                               batch_size=batch_size),
                                           partial(create_stub,
                                                   batch_size=batch_size)
                                       ],
                                       structure='fixed')

        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        self.dlatent_variable = next(v for v in tf.global_variables()
                                     if 'learnable_dlatents' in v.name)
        self.set_dlatents(self.initial_dlatents)

        self.generator_output = self.graph.get_tensor_by_name(
            'G_synthesis_1/_Run/concat:0')
        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)
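set_dlatents is called above but not shown in this snippet; a minimal sketch consistent with the attributes it uses (an assumption, not the original helper):

    def set_dlatents(self, dlatents):
        # Sketch only: expects dlatents of shape (batch_size, 18, 512).
        assert dlatents.shape == (self.batch_size, 18, 512)
        # Variable.load writes the value without adding new assign ops to the graph (TF1 idiom).
        self.dlatent_variable.load(dlatents, self.sess)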
Example no. 3
    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_softmax.pkl')
        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents_b = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                latents_c = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                images = Gs_clone.get_output_for(latents_b, latents_c, None, is_validation=True, randomize_noise=True)
                assert len(images) == 7
                images = tflib.convert_images_to_uint8(images[-1])
                assert images.shape[2] == images.shape[3] == 512
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate activations for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]

        # Calculate IS.
        scores = []
        for i in range(self.num_splits):
            part = activations[i * self.num_images // self.num_splits : (i + 1) * self.num_images // self.num_splits]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        self._report_result(np.mean(scores), suffix='_mean')
        self._report_result(np.std(scores), suffix='_std')
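For reference, the split/KL loop above computes the standard Inception Score; the same logic as a self-contained NumPy function over softmax outputs p of shape (N, num_classes):

def inception_score(p, num_splits=10):
    # Mirrors the loop above: split, per-split KL(p(y|x) || p(y)), exponentiate, average.
    n = p.shape[0]
    scores = []
    for i in range(num_splits):
        part = p[i * n // num_splits:(i + 1) * n // num_splits]
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return np.mean(scores), np.std(scores)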
Example no. 4
def draw_uncurated_result_figure(png, Gs, seed):
    print(png)
    rows = 7
    cols = 7
    canvas = PIL.Image.new('RGB', (512 * rows, 512 * cols), 'white')
    for col in range(cols):
        for row in range(rows):
            latents = np.random.normal(0, 1, [1, Gs.input_shape[1]])
            #images = Gs.run(latents, None, **synthesis_kwargs)

            #latents = tf.random_normal([25] + Gs.input_shape[1:])
            images = Gs.get_output_for(latents,
                                       None,
                                       is_validation=True,
                                       randomize_noise=False,
                                       truncation_psi_val=0.7)
            images = tflib.convert_images_to_uint8(images).eval()

            images = np.transpose(images, (0, 2, 3, 1))

            image = PIL.Image.fromarray(images[0], 'RGB')
            canvas.paste(image, (row * 512, col * 512))
    canvas.save(png)
Example no. 5
    def __init__(self, model, batch_size, randomize_noise=False):
        self.batch_size = batch_size

        self.initial_dlatents = np.zeros((self.batch_size, 14, 512))
        model.components.synthesis.run(self.initial_dlatents,
                                       randomize_noise=randomize_noise,
                                       minibatch_size=self.batch_size,
                                       custom_inputs=[
                                           partial(
                                               create_variable_for_generator,
                                               batch_size=batch_size),
                                           partial(create_stub,
                                                   batch_size=batch_size)
                                       ],
                                       structure='fixed')

        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        in_expr, out_expr = next(
            expr for expr in model.components.synthesis._run_cache.values())

        self.dlatent_variable = next(v for v in in_expr
                                     if 'learnable_dlatents' in v.name)
        self.generator_output = next(v for v in out_expr
                                     if '_Run/concat' in v.name)
        self.set_dlatents(self.initial_dlatents)

        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)
Example no. 6
    def __init__(self, model, batch_size, randomize_noise=False):
        self.batch_size = batch_size

        self.initial_dlatents = np.zeros((self.batch_size, 18, 512))
        model.components.synthesis.run(self.initial_dlatents,
                                       randomize_noise=randomize_noise,
                                       minibatch_size=self.batch_size,
                                       custom_inputs=[
                                           partial(
                                               create_variable_for_generator,
                                               batch_size=batch_size),
                                           partial(create_stub,
                                                   batch_size=batch_size)
                                       ],
                                       structure='fixed')

        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        self.dlatent_variable = self.graph.get_tensor_by_name(
            'G_synthesis_1/_Run/dlatents_in:0')
        self.set_dlatents(self.initial_dlatents)

        self.generator_output = self.graph.get_tensor_by_name(
            'G_synthesis_1/_Run/concat/concat:0')
        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)
Example no. 7
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/inception_v3_softmax.pkl')
        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate activations for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]

        # Calculate IS.
        scores = []
        for i in range(self.num_splits):
            part = activations[i * self.num_images // self.num_splits : (i + 1) * self.num_images // self.num_splits]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        self._report_result(np.mean(scores), suffix='_mean')
        self._report_result(np.std(scores), suffix='_std')
Example no. 8
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_features.pkl'
        )
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end - begin],
                                                       num_gpus=num_gpus,
                                                       assume_frozen=True)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
                images = tflib.convert_images_to_uint8(images)
                if images.shape[1] == 1:
                    # Replicate the grayscale channel to RGB for the Inception network.
                    images = tf.concat([images, images, images], axis=1)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
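The closing lines implement the Fréchet distance between two Gaussians fitted to the activations; isolated as a helper for clarity (same math as above):

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))
    m = np.square(mu1 - mu2).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)
    return np.real(m + np.trace(sigma1 + sigma2 - 2 * s))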
Example no. 9
    def __init__(self, model, batch_size, randomize_noise=False):
        self.batch_size = batch_size

        self.initial_dlatents = np.zeros((self.batch_size, 18, 512))
        model.components.synthesis.run(self.initial_dlatents,
                                       randomize_noise=randomize_noise,
                                       minibatch_size=self.batch_size,
                                       custom_inputs=[
                                           partial(
                                               create_variable_for_generator,
                                               batch_size=batch_size),
                                           partial(create_stub,
                                                   batch_size=batch_size)
                                       ],
                                       structure='fixed')

        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        self.dlatent_variable = next(v for v in tf.global_variables()
                                     if 'learnable_dlatents' in v.name)
        self.set_dlatents(self.initial_dlatents)

        self.generator_output = self.graph.get_tensor_by_name(
            'G_synthesis_1/_Run/concat:0')
        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)
Example no. 10
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        feature_net = misc.load_pkl(
            'https://drive.google.com/uc?id=1MzY4MFpZzE-mNS26pzhYlWN-4vMm2ytu',
            'vgg16.pkl')

        # Calculate features for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            ref_features = misc.load_pkl(cache_file)
        else:
            ref_features = np.empty(
                [self.num_images, feature_net.output_shape[1]],
                dtype=np.float32)
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                ref_features[begin:end] = feature_net.run(images[:end - begin],
                                                          num_gpus=num_gpus,
                                                          assume_frozen=True)
                if end == self.num_images:
                    break
            misc.save_pkl(ref_features, cache_file)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tflex.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                feature_net_clone = feature_net.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(feature_net_clone.get_output_for(images))

        # Calculate features for fakes.
        eval_features = np.empty(
            [self.num_images, feature_net.output_shape[1]], dtype=np.float32)
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            eval_features[begin:end] = np.concatenate(tflib.run(result_expr),
                                                      axis=0)[:end - begin]

        # Calculate precision and recall.
        state = knn_precision_recall_features(
            ref_features=ref_features,
            eval_features=eval_features,
            feature_net=feature_net,
            nhood_sizes=[self.nhood_size],
            row_batch_size=self.row_batch_size,
            col_batch_size=self.row_batch_size,
            num_gpus=num_gpus)
        self._report_result(state.knn_precision[0], suffix='_precision')
        self._report_result(state.knn_recall[0], suffix='_recall')
Example no. 11
    def __init__(self, model, batch_size, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False):
        self.batch_size = batch_size
        self.tiled_dlatent=tiled_dlatent
        self.model_scale = int(2*(math.log(model_res,2)-1)) # For example, 1024 -> 18

        if tiled_dlatent:
            self.initial_dlatents = np.zeros((self.batch_size, 512))
            model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)),
                randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True),
                                                partial(create_stub, batch_size=batch_size)],
                structure='fixed')
        else:
            self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
            model.components.synthesis.run(self.initial_dlatents,
                randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale),
                                                partial(create_stub, batch_size=batch_size)],
                structure='fixed')

        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        self.dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name)
        self.set_dlatents(self.initial_dlatents)

        def get_tensor(name):
            try:
                return self.graph.get_tensor_by_name(name)
            except KeyError:
                return None

        # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist
        # in the graph, so fall back to "G_synthesis". The concat op name also varies by version.
        candidate_names = [
            'G_synthesis_1/_Run/concat:0',
            'G_synthesis_1/_Run/concat/concat:0',
            'G_synthesis_1/_Run/concat_1/concat:0',
            'G_synthesis/_Run/concat:0',
            'G_synthesis/_Run/concat/concat:0',
            'G_synthesis/_Run/concat_1/concat:0',
        ]
        self.generator_output = next(
            (t for t in map(get_tensor, candidate_names) if t is not None), None)
        if self.generator_output is None:
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")
        self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

        # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
        # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
        # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
        clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(clipping_mask, tf.random_normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)
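A sketch of how stochastic_clip_op fits into a latent-optimization loop; train_op, num_steps, and the generator instance are assumptions, not part of this snippet:

sess = tf.get_default_session()
for step in range(num_steps):       # num_steps: assumed iteration budget
    sess.run(train_op)              # train_op: assumed optimizer update on the dlatents
    # Re-sample any dlatent components that left [-clipping_threshold, clipping_threshold].
    sess.run(generator.stochastic_clip_op)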
Example no. 12
    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn'
        )  # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end - begin],
                                                       num_gpus=num_gpus,
                                                       assume_frozen=True)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                images = Gs_clone.get_output_for(
                    latents,
                    None,
                    is_validation=True,
                    randomize_noise=True,
                    truncation_psi_val=self.truncation_psi,
                    truncation_cutoff_val=8)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
Example no. 13
    def __init__(self,
                 model,
                 batch_size=1,
                 clipping_threshold=2,
                 model_res=1024,
                 randomize_noise=False):
        self.batch_size = batch_size
        self.model_scale = int(
            2 * (math.log(model_res, 2) - 1))  # For example, 1024 -> 18

        self.initial_dlatents = np.zeros(
            (self.batch_size, self.model_scale, 512))
        model.components.synthesis.run(self.initial_dlatents,
                                       randomize_noise=randomize_noise,
                                       minibatch_size=self.batch_size,
                                       custom_inputs=[
                                           partial(
                                               create_variable_for_generator,
                                               batch_size=batch_size,
                                               model_scale=self.model_scale),
                                           partial(create_stub,
                                                   batch_size=batch_size)
                                       ],
                                       structure='fixed')

        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        self.dlatent_variable = next(v for v in tf.global_variables()
                                     if 'learnable_dlatents' in v.name)
        self.set_dlatents(self.initial_dlatents)

        try:
            self.generator_output = self.graph.get_tensor_by_name(
                'G_synthesis_1/_Run/concat:0')
        except KeyError:
            # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph.
            self.generator_output = self.graph.get_tensor_by_name(
                'G_synthesis/_Run/concat:0')

        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)

        # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782

        clipping_mask = tf.math.logical_or(
            self.dlatent_variable > clipping_threshold,
            self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(
            clipping_mask, tf.random_normal(shape=self.dlatent_variable.shape),
            self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable,
                                            clipped_values)
Example no. 14
    def generate_images(self, out_dir, img_name, i):
        generator_output = tf.get_default_graph().get_tensor_by_name('G_synthesis_1/_Run/concat:0')
        generated_image = tflib.convert_images_to_uint8(generator_output, nchw_to_nhwc=True, uint8_cast=False)
        generated_image_uint8 = tf.saturate_cast(generated_image, tf.uint8)  # cast the image to dtype uint8
        generated_images = self.sess.run(generated_image_uint8)
        # Generate images from the found dlatents and save them.
        img_array = generated_images[0]
        img = PIL.Image.fromarray(img_array, 'RGB')
        img.save(os.path.join(out_dir, str(img_name) + '_' + str(i) + '.png'), 'PNG')
Example no. 15
    def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs):  # pylint: disable=arguments-differ
        minibatch_size = num_gpus * self.minibatch_per_gpu
        with dnnlib.util.open_url(
                'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_features.pkl'
        ) as f:  # identical to http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
            feature_net = pickle.load(f)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if self.use_cached_real_stats and os.path.isfile(cache_file):
            with open(cache_file, 'rb') as f:
                feat_real = pickle.load(f)
        else:
            feat_real = []
            for images, _labels, num in self._iterate_reals(minibatch_size):
                if self.max_reals is not None:
                    num = min(num, self.max_reals - len(feat_real))
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1])
                feat_real += list(
                    feature_net.run(images,
                                    num_gpus=num_gpus,
                                    assume_frozen=True))[:num]
                if self.max_reals is not None and len(
                        feat_real) >= self.max_reals:
                    break
            feat_real = np.stack(feat_real)
            with open(cache_file, 'wb') as f:
                pickle.dump(feat_real, f)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                feature_net_clone = feature_net.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
                if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(feature_net_clone.get_output_for(images))

        # Calculate statistics for fakes.
        feat_fake = []
        for begin in range(0, self.num_fakes, minibatch_size):
            self._report_progress(begin, self.num_fakes)
            feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0))
        feat_fake = np.stack(feat_fake[:self.num_fakes])

        # Calculate KID.
        kid = compute_kid(feat_real, feat_fake)
        self._report_result(np.real(kid), fmt='%-12.8f')
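compute_kid is referenced but not defined in this snippet. A common implementation is the unbiased polynomial-kernel MMD from Bińkowski et al. (2018); a sketch under that assumption, following the StyleGAN2-ADA convention:

def compute_kid(feat_real, feat_fake, num_subsets=100, max_subset_size=1000):
    # Unbiased MMD^2 with kernel k(x, y) = (x.y / d + 1)^3, averaged over random subsets.
    d = feat_real.shape[1]
    m = min(feat_real.shape[0], feat_fake.shape[0], max_subset_size)
    t = 0
    for _ in range(num_subsets):
        x = feat_fake[np.random.choice(feat_fake.shape[0], m, replace=False)]
        y = feat_real[np.random.choice(feat_real.shape[0], m, replace=False)]
        a = (x @ x.T / d + 1) ** 3 + (y @ y.T / d + 1) ** 3
        b = (x @ y.T / d + 1) ** 3
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    return t / num_subsets / m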
Example no. 16
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        classifier = tf.keras.models.load_model('nets/lutz_new_classifier_tf1.14.h5', compile=False)
        classifier = add_preprocessing(classifier, "nets")
        # if num_gpus > 1: classifier = tf.keras.utils.multi_gpu_model(classifier, num_gpus, cpu_relocation=True) # Runs with undeterministic output
        activations = np.zeros([self.num_images, classifier.output_shape[1]], dtype=np.float32)

        # Calculate statistics for reals (adversarial examples).
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mean_real, std_real = misc.load_pkl(cache_file)
        else:
            for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                images = np.transpose(images, [0, 2, 3, 1]) # nchw to nhwc
                activations[begin:end] = classifier.predict_on_batch(images)[:end-begin]
                if end == self.num_images:
                    break
            mean_real = np.mean(activations)
            std_real = np.std(activations)
            misc.save_pkl((mean_real, std_real), cache_file)
        
        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:], seed=42)
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
                images = tflib.convert_images_to_uint8(images, nchw_to_nhwc=True)
                result_expr.append(images)
        result_expr = tf.concat(result_expr, axis=0)

        # Calculate statistics for fakes (generated examples).
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = classifier.predict_on_batch(result_expr)[:end-begin]
        mean_fake = np.mean(activations)
        std_fake = np.std(activations)

        # Save DCT Fake Score.
        self._report_result(mean_fake, suffix='_mean_gen')
        self._report_result(std_fake, suffix='_std_gen')
        self._report_result(mean_real, suffix='_mean_adv')
        self._report_result(std_real, suffix='_std_adv')

#----------------------------------------------------------------------------
Example no. 17
def generate_images_with_labels(filename, Gs, w, h, num_labels, latents, truncation):
    canvas = PIL.Image.new('RGB', (w * num_labels, h * latents.shape[0]), 'white')
    for i in range(latents.shape[0]):
        for j in range(num_labels):
            onehot = np.zeros((1, num_labels), dtype=np.float32)  # np.float is deprecated
            onehot[0, j] = 1.0
            image = Gs.get_output_for(latents[i:i + 1],  # keep the batch dimension
                                      onehot,
                                      is_validation=True,
                                      randomize_noise=False,
                                      truncation_psi_val=truncation)
            image = tflib.convert_images_to_uint8(image, nchw_to_nhwc=True).eval()
            canvas.paste(PIL.Image.fromarray(image[0], 'RGB'), (j * w, i * h))
    canvas.save(filename)
Example no. 18
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        if self.num_images is None:
            self.num_images = self._get_dataset_obj().num_samples
        num_channels = Gs.output_shape[1]
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_softmax.pkl'
        )

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
                if num_channels == 1:
                    images = tf.repeat(images, 3, axis=1)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)
        results = []
        for _ in range(self.num_repeats):
            # Calculate statistics for fakes.
            for begin in range(0, self.num_images, minibatch_size):
                self._report_progress(begin, self.num_images)
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                        axis=0)[:end - begin]

            # Calculate IS.
            scores = []
            for i in range(self.num_splits):
                part = activations[i * self.num_images //
                                   self.num_splits:(i + 1) * self.num_images //
                                   self.num_splits]
                kl = part * (np.log(part) -
                             np.log(np.expand_dims(np.mean(part, 0), 0)))
                kl = np.mean(np.sum(kl, 1))
                scores.append(np.exp(kl))
            results.append(np.mean(scores))
        self._report_result(np.mean(results))
        if self.num_repeats > 1:
            self._report_result(np.std(results), suffix='-std')
Example no. 19
    def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs):  # pylint: disable=arguments-differ
        minibatch_size = num_gpus * self.minibatch_per_gpu
        with dnnlib.util.open_url(
                'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_softmax.pkl'
        ) as f:
            inception = pickle.load(f)
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device(f'/gpu:{gpu_idx}'):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random.normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
                if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate activations for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]

        # Calculate IS.
        scores = []
        for i in range(self.num_splits):
            part = activations[i * self.num_images // self.num_splits:(i + 1) *
                               self.num_images // self.num_splits]
            kl = part * (np.log(part) -
                         np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        self._report_result(np.mean(scores), suffix='_mean')
        self._report_result(np.std(scores), suffix='_std')
Example no. 20
	def __init__(self, model, batch_size, tiled_dlatent, randomize_noise):
		self.batch_size = batch_size
		self.tiled_dlatent=tiled_dlatent

		if tiled_dlatent:
			self.initial_dlatents = np.zeros((self.batch_size, 512))
			model.components.synthesis.run(np.zeros((self.batch_size, 18, 512)),
				randomize_noise=randomize_noise, minibatch_size=self.batch_size,
				custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True),
												partial(create_stub, batch_size=batch_size)],
				structure='fixed')
		else:
			self.initial_dlatents = np.zeros((self.batch_size, 18, 512))
			model.components.synthesis.run(self.initial_dlatents,
				randomize_noise=randomize_noise, minibatch_size=self.batch_size,
				custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False),
												partial(create_stub, batch_size=batch_size)],
				structure='fixed')

		self.sess = tf.get_default_session()
		self.graph = tf.get_default_graph()

		self.dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name)
		self._assign_dlatent_ph = tf.placeholder(tf.float32, name="assign_dlatent_ph")
		self._assign_dlantent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
		self.set_dlatents(self.initial_dlatents)

		try:
			self.generator_output = self.graph.get_tensor_by_name('G_synthesis_1/_Run/concat:0')
		except KeyError:
			# If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph.
			self.generator_output = self.graph.get_tensor_by_name('G_synthesis/_Run/concat:0')
		self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
		self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

		# Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
		# (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
		# so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
		clipping_mask = tf.math.logical_or(self.dlatent_variable > 2.0, self.dlatent_variable < -2.0)
		clipped_values = tf.where(clipping_mask, tf.random_normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
		self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)
Example no. 21
def transitions(png, w, h, seed):
    col = 0
    canvas = PIL.Image.new('RGB', (w * len(psis), h), 'white')
    baseline_network_pkl = '../results/00002-sgan-car512-2gpu/network-snapshot-023949.pkl'
    G, _D, Gs = misc.load_pkl(baseline_network_pkl)
    for psi in psis:
        latents = np.random.RandomState([seed]).randn(1, Gs.input_shape[1])

        images = Gs.run(latents,
                        None,
                        is_validation=True,
                        randomize_noise=False,
                        truncation_psi_val=psi,
                        truncation_cutoff_val=16)
        images = tflib.convert_images_to_uint8(images).eval()

        images = np.transpose(images, (0, 2, 3, 1))
        canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (col, 0))

        col += w
    canvas.save(png)
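psis is a module-level global that this snippet does not define; a plausible definition and invocation (the values and output path are illustrative):

psis = [1.0, 0.7, 0.5, 0.25, 0.0, -0.25, -0.5, -0.7, -1.0]  # illustrative truncation sweep

tflib.init_tf()  # transitions() calls .eval(), which needs a default session
transitions('truncation_sweep.png', w=512, h=512, seed=1000)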
Example no. 22
    def __init__(self, model, batch_size, randomize_noise=False):
        self.batch_size = batch_size
        #tf.reset_default_graph()
        self.initial_dlatents = np.zeros((self.batch_size, 18, 512))
        model.components.synthesis.run(self.initial_dlatents,
                                       randomize_noise=randomize_noise,
                                       minibatch_size=self.batch_size,
                                       custom_inputs=[
                                           partial(
                                               create_variable_for_generator,
                                               batch_size=batch_size),
                                           partial(create_stub,
                                                   batch_size=batch_size)
                                       ],
                                       structure='fixed')

        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        self.dlatent_variable = next(
            v for v in tf.global_variables() if 'learnable_dlatents' in v.name)

        self.set_dlatents(self.initial_dlatents)

        self.generator_output = self.graph.get_tensor_by_name(
            'G_synthesis_1/_Run/concat/concat:0')
        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)
Example no. 23
def transitions(png, w, h, seed1, seed2, pkl):
    col = 0
    canvas = PIL.Image.new('RGB', (h, w * len(psis)), 'white')
    G, _D, Gs = misc.load_pkl(pkl)
    for psi in psis:
        latent1 = np.random.RandomState([seed1]).randn(1, Gs.input_shape[1])
        latent2 = np.random.RandomState([seed2]).randn(1, Gs.input_shape[1])
        dlatent1 = Gs.components.mapping.get_output_for(latent1,
                                                        None,
                                                        is_validation=True)
        dlatent2 = Gs.components.mapping.get_output_for(latent2,
                                                        None,
                                                        is_validation=True)
        dlatent_int = psi * dlatent1 + (1 - psi) * dlatent2
        images = Gs.components.synthesis.get_output_for(dlatent_int,
                                                        is_validation=True,
                                                        randomize_noise=True)
        images = tflib.convert_images_to_uint8(images).eval()

        images = np.transpose(images, (0, 2, 3, 1))
        canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, col))

        col += w
    canvas.save(png)
Example no. 24
    def set_network(self, Gs, dtype='float16'):
        if Gs is None:
            self._Gs = None
            return
        self._Gs = Gs.clone(randomize_noise=False,
                            dtype=dtype,
                            num_fp16_res=0,
                            fused_modconv=True)

        # Compute dlatent stats.
        self._info(
            f'Computing W midpoint and stddev using {self.dlatent_avg_samples} samples...'
        )
        latent_samples = np.random.RandomState(123).randn(
            self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:])
        dlatent_samples = self._Gs.components.mapping.run(
            latent_samples, None)  # [N, L, C]
        dlatent_samples = dlatent_samples[:, :1, :].astype(
            np.float32)  # [N, 1, C]
        self._dlatent_avg = np.mean(dlatent_samples, axis=0,
                                    keepdims=True)  # [1, 1, C]
        self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg)**2) /
                             self.dlatent_avg_samples)**0.5
        self._info(f'std = {self._dlatent_std:g}')

        # Setup noise inputs.
        self._info('Setting up noise inputs...')
        self._noise_vars = []
        noise_init_ops = []
        noise_normalize_ops = []
        while True:
            n = f'G_synthesis/noise{len(self._noise_vars)}'
            if n not in self._Gs.vars:
                break
            v = self._Gs.vars[n]
            self._noise_vars.append(v)
            noise_init_ops.append(
                tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32)))
            noise_mean = tf.reduce_mean(v)
            noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5
            noise_normalize_ops.append(
                tf.assign(v, (v - noise_mean) / noise_std))
        self._noise_init_op = tf.group(*noise_init_ops)
        self._noise_normalize_op = tf.group(*noise_normalize_ops)

        # Build image output graph.
        self._info('Building image output graph...')
        self._minibatch_size = 1
        self._dlatents_var = tf.Variable(
            tf.zeros([self._minibatch_size] +
                     list(self._dlatent_avg.shape[1:])),
            name='dlatents_var')
        self._dlatent_noise_in = tf.placeholder(tf.float32, [],
                                                name='noise_in')
        dlatents_noise = tf.random.normal(
            shape=self._dlatents_var.shape) * self._dlatent_noise_in
        self._dlatents_expr = tf.tile(
            self._dlatents_var + dlatents_noise,
            [1, self._Gs.components.synthesis.input_shape[1], 1])
        self._images_float_expr = tf.cast(
            self._Gs.components.synthesis.get_output_for(self._dlatents_expr),
            tf.float32)
        self._images_uint8_expr = tflib.convert_images_to_uint8(
            self._images_float_expr, nchw_to_nhwc=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        proc_images_expr = (self._images_float_expr + 1) * (255 / 2)
        sh = proc_images_expr.shape.as_list()
        if sh[2] > 256:
            factor = sh[2] // 256
            proc_images_expr = tf.reduce_mean(tf.reshape(
                proc_images_expr,
                [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]),
                                              axis=[3, 5])

        # Build loss graph.
        self._info('Building loss graph...')
        self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape),
                                              name='target_images_var')
        if self._lpips is None:
            with dnnlib.util.open_url(
                    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl'
            ) as f:
                self._lpips = pickle.load(f)
        self._dist = self._lpips.get_output_for(proc_images_expr,
                                                self._target_images_var)
        self._loss = tf.reduce_sum(self._dist)

        # Build noise regularization graph.
        self._info('Building noise regularization graph...')
        reg_loss = 0.0
        for v in self._noise_vars:
            sz = v.shape[2]
            while True:
                reg_loss += tf.reduce_mean(
                    v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(
                        v * tf.roll(v, shift=1, axis=2))**2
                if sz <= 8:
                    break  # Small enough already
                v = tf.reshape(v, [1, 1, sz // 2, 2, sz // 2, 2])  # Downscale
                v = tf.reduce_mean(v, axis=[3, 5])
                sz = sz // 2
        self._loss += reg_loss * self.regularize_noise_weight

        # Setup optimizer.
        self._info('Setting up optimizer...')
        self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in')
        self._opt = tflib.Optimizer(learning_rate=self._lrate_in)
        self._opt.register_gradients(self._loss,
                                     [self._dlatents_var] + self._noise_vars)
        self._opt_step = self._opt.apply_updates()
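A hypothetical driver loop for the graph built in set_network; the method itself, the step count, and the schedule constants below are assumptions rather than the original projector code:

    def run(self, target_images, num_steps=1000):
        # Initialize dlatents at the W midpoint and reset the noise maps.
        tflib.set_vars({
            self._target_images_var: target_images,
            self._dlatents_var: np.tile(self._dlatent_avg, [self._minibatch_size, 1, 1]),
        })
        tflib.run(self._noise_init_op)
        for step in range(num_steps):
            feed = {self._lrate_in: 0.1, self._dlatent_noise_in: 0.05}  # illustrative constants
            _, loss = tflib.run([self._opt_step, self._loss], feed)
            tflib.run(self._noise_normalize_op)
        return tflib.run(self._dlatents_expr)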
Example no. 25
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn'
        )  # inception_v3_features.pkl
        real_activations = np.empty(
            [self.num_images, inception.output_shape[1]], dtype=np.float32)
        fake_activations = np.empty(
            [self.num_images, inception.output_shape[1]], dtype=np.float32)

        # Construct TensorFlow graph.
        self._configure(self.minibatch_per_gpu, hole_range=self.hole_range)
        real_img_expr = []
        fake_img_expr = []
        real_result_expr = []
        fake_result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                reals, labels = self._get_minibatch_tf()
                reals_tf = tflib.convert_images_from_uint8(reals)
                masks = self._get_random_masks_tf()
                fakes = Gs_clone.get_output_for(latents, labels, reals_tf,
                                                masks, **Gs_kwargs)
                fakes = tflib.convert_images_to_uint8(fakes[:, :3])
                reals = tflib.convert_images_to_uint8(reals_tf[:, :3])
                real_img_expr.append(reals)
                fake_img_expr.append(fakes)
                real_result_expr.append(inception_clone.get_output_for(reals))
                fake_result_expr.append(inception_clone.get_output_for(fakes))

        for begin in tqdm(range(0, self.num_images, minibatch_size)):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            real_results, fake_results = tflib.run(
                [real_result_expr, fake_result_expr])
            real_activations[begin:end] = np.concatenate(real_results,
                                                         axis=0)[:end - begin]
            fake_activations[begin:end] = np.concatenate(fake_results,
                                                         axis=0)[:end - begin]

        # Calculate FID.
        mu_real = np.mean(real_activations, axis=0)
        sigma_real = np.cov(real_activations, rowvar=False)
        mu_fake = np.mean(fake_activations, axis=0)
        sigma_fake = np.cov(fake_activations, rowvar=False)
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist), suffix='-FID')

        svm = sklearn.svm.LinearSVC(dual=False)
        svm_inputs = np.concatenate([real_activations, fake_activations])
        svm_targets = np.array([1] * real_activations.shape[0] +
                               [0] * fake_activations.shape[0])
        svm.fit(svm_inputs, svm_targets)
        self._report_result(1 - svm.score(svm_inputs, svm_targets),
                            suffix='-U')
        real_outputs = svm.decision_function(real_activations)
        fake_outputs = svm.decision_function(fake_activations)
        self._report_result(np.mean(fake_outputs > real_outputs), suffix='-P')
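The SVM block above is a classifier two-sample test: -U reports how inseparable real and fake activations are under a linear classifier (higher means less separable), and -P estimates how often a fake activation outscores a real one under the learned decision function (near 0.5 when the two are indistinguishable). A toy illustration on synthetic features:

import sklearn.svm

real = np.random.randn(256, 8)
fake = np.random.randn(256, 8) + 0.5   # shifted distribution, so partially separable
svm = sklearn.svm.LinearSVC(dual=False)
X = np.concatenate([real, fake])
y = np.array([1] * len(real) + [0] * len(fake))
svm.fit(X, y)
print('U =', 1 - svm.score(X, y))      # inseparability score
print('P =', np.mean(svm.decision_function(fake) > svm.decision_function(real)))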
Example no. 26
data = load('out/dlatents0.npz')
lst = data.files
for item in lst:
    print(item)
    print(data[item].shape)

# mypath = "images/"
# onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]
# print(onlyfiles)

print(data[item])

network_pkl = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl"

# If downloads fails, due to 'Google Drive download quota exceeded' you can try downloading manually from your own Google Drive account
# network_pkl = "/content/drive/My Drive/GAN/stylegan2-ffhq-config-f.pkl"

tflib.init_tf({'rnd.np_random_seed': 303})
print('Loading networks from "%s"...' % network_pkl)
with dnnlib.util.open_url(network_pkl) as fp:
    _G, _D, Gs = pickle.load(fp)
# self._images_float_expr

image_float_expr = tf.cast(Gs.components.synthesis.get_output_for(data[item]),
                           tf.float32)
images_uint8_expr = tflib.convert_images_to_uint8(image_float_expr,
                                                  nchw_to_nhwc=True)
PIL.Image.fromarray(tflib.run(images_uint8_expr)[0],
                    'RGB').save('out/proj.png')
Example no. 27
    def _evaluate(self, Gs, E, Inv, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(config.INCEPTION_PICKLE_DIR) # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)

        announce("Evaluating Reals")
        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
            print("loaded real mu, sigma from cache.")
        else:
            progress = 0
            for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):
                batch_stacks = data[0]
                progress += batch_stacks.shape[0]
                images = batch_stacks[:,0,:,:,:]
                landmarks = batch_stacks[:,1,:,:,:]

                # Compute Inception features on the full images.
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)


                # visualization
                images = images.astype(np.float32) / 255 * 2.0 - 1.0
                landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0

                if idx <= 10:
                    debug_img = np.concatenate([
                        images,    # original portraits
                        landmarks  # original landmarks
                    ], axis=0)
                    debug_img = adjust_pixel_range(debug_img)
                    debug_img = fuse_images(debug_img, row=2, col=minibatch_size)
                    save_image("data_iter_{:08d}.png".format(idx), debug_img)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
        
        announce("Evaluating Generator.")
        # Construct TensorFlow graph.
        result_expr = []
        print("Construct TensorFlow graph.")
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        print("Calculate statistics for fakes.")
        for begin in tqdm(range(0, self.num_images, minibatch_size), position=0, leave=True):
            end = min(begin + minibatch_size, self.num_images)
            #print("result_expr", len(result_expr)) # result_expr is a list!!!
            # results_expr[0].shape = (8, 2048) -> hat nur ein element.
            # weil: eigentlich würde man halt hier die GPUs zusammen konkattenieren.

            res_expr, fakes = tflib.run([result_expr, images])
            activations[begin:end] = np.concatenate(res_expr, axis=0)[:end-begin]

            if begin < 20:
                fakes = fakes.astype(np.float32) / 255 * 2.0 - 1.0
                debug_img = adjust_pixel_range(fakes)
                debug_img = fuse_images(debug_img, row=3, col=minibatch_size)
                save_image("fid_generator_iter_{:08d}.png".format(end), debug_img)


        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        #print("mu_fake={}, sigma_fake={}".format(mu_fake, sigma_fake))
        
        # Calculate FID.
        print("Calculate FID (generator).")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="StyleGAN Generator Only")
        print("Distance StyleGAN", dist)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

        announce("Now evaluating encoder (appearnace)")
        print("building custom encoder graph!")
        with tf.variable_scope('fakeddddoptimizer'):

            # Build graph.
            BATCH_SIZE = self.minibatch_per_gpu
            input_shape = Inv.input_shape
            input_shape[0] = BATCH_SIZE
            latent_shape = Gs.components.synthesis.input_shape
            latent_shape[0] = BATCH_SIZE

            x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
            x_lm = tf.placeholder(tf.float32, shape=input_shape, name='some_landmark')
            x_kp = tf.placeholder(tf.float32, shape=[self.minibatch_per_gpu, 136], name='some_keypoints')

            if self.model_type == "rignet":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_lm, phase=False)
            elif self.model_type == "keypoints":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_kp, phase=False)
            else:
                w_enc = E.get_output_for(x, x_lm, phase=False)

            wp_enc = tf.reshape(w_enc, latent_shape)

            manipulated_images = Gs.components.synthesis.get_output_for(wp_enc, randomize_noise=False)
            manipulated_images = tflib.convert_images_to_uint8(manipulated_images)
            inception_codes = inception_clone.get_output_for(manipulated_images) # shape (8, 2048)

        for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):
            batch_stacks = data[0]
            images = batch_stacks[:,0,:,:,:]    # shape (8, 3, 128, 128)
            landmarks = batch_stacks[:,1,:,:,:] # shape (8, 3, 128, 128)
            images = images.astype(np.float32) / 255 * 2.0 - 1.0
            landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0
            keypoints = np.roll(data[1], shift=1, axis=0)

            begin = idx * minibatch_size
            end = min(begin + minibatch_size, self.num_images) # begin: 0; end: 8

            activations[begin:end], manip = tflib.run([inception_codes, manipulated_images], feed_dict={x:images, x_lm:landmarks, x_kp:keypoints})
            # activations: (5000, 2048)



            if idx < 10:
                print("saving img")
                manip = manip.astype(np.float32) / 255 * 2.0 - 1.0
                debug_img = np.concatenate([
                    images,    # original portraits
                    landmarks, # original landmarks
                    manip
                ], axis=0)
                debug_img = adjust_pixel_range(debug_img)
                debug_img = fuse_images(debug_img, row=3, col=minibatch_size)
                save_image("fid_iter_{:08d}.png".format(idx), debug_img)


            if end == self.num_images:
                break

        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)
        #print("enc_mu_fake={}, enc_sigma_fake={}".format(mu_fake, sigma_fake))


        # Calculate FID.
        print("Calculate FID for encoded samples")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="Our Face-Landmark-Encoder (Apperance)")
        print("distance OUR FACE-LANDMARK-ENCODER", dist)


#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

        announce("Now evaluating encoder. (POSE)")
        print("building custom encoder graph!")
        with tf.variable_scope('fakeddddoptimizer'):

            # Build graph.
            BATCH_SIZE = self.minibatch_per_gpu
            input_shape = Inv.input_shape
            input_shape[0] = BATCH_SIZE
            latent_shape = Gs.components.synthesis.input_shape
            latent_shape[0] = BATCH_SIZE

            x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')
            x_lm = tf.placeholder(tf.float32, shape=input_shape, name='some_landmark')
            x_kp = tf.placeholder(tf.float32, shape=[self.minibatch_per_gpu, 136], name='some_keypoints')

            if self.model_type == "rignet":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_lm, phase=False)
            elif self.model_type == "keypoints":
                w_enc_1 = Inv.get_output_for(x, phase=False)
                wp_enc_1 = tf.reshape(w_enc_1, latent_shape)
                w_enc = E.get_output_for(wp_enc_1, x_kp, phase=False)
            else:
                w_enc = E.get_output_for(x, x_lm, phase=False)

            wp_enc = tf.reshape(w_enc, latent_shape)

            manipulated_images = Gs.components.synthesis.get_output_for(wp_enc, randomize_noise=False)
            manipulated_images = tflib.convert_images_to_uint8(manipulated_images)
            inception_codes = inception_clone.get_output_for(manipulated_images) # shape (8, 2048)

        for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):

            image_data = data[0]
            images = image_data[:,0,:,:,:]
            landmarks = np.roll(image_data[:,1,:,:,:], shift=1, axis=0)
            
            keypoints = np.roll(data[1], shift=1, axis=0)

            images = images.astype(np.float32) / 255 * 2.0 - 1.0
            landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0

            begin = idx * minibatch_size
            end = min(begin + minibatch_size, self.num_images) # begin: 0; end: 8

            activations[begin:end], manip = tflib.run([inception_codes, manipulated_images], feed_dict={x:images, x_lm:landmarks, x_kp:keypoints})
            # activations: (5000, 2048)



            if idx < 10:
                print("saving img")
                manip = manip.astype(np.float32) / 255 * 2.0 - 1.0
                debug_img = np.concatenate([
                    images,    # original portraits
                    landmarks, # shifted landmarks
                    manip
                ], axis=0)
                debug_img = adjust_pixel_range(debug_img)
                debug_img = fuse_images(debug_img, row=3, col=minibatch_size)
                save_image("fid_iter_POSE_{:08d}.png".format(idx), debug_img)


            if end == self.num_images:
                break

        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)
        #print("enc_mu_fake={}, enc_sigma_fake={}".format(mu_fake, sigma_fake))


        # Calculate FID.
        print("Calculate FID for encoded samples (POSE)")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="Our_Face_Landmark_Encoder (Pose)")
        print("distance OUR FACE-LANDMARK-ENCODER (POSE)", dist)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

        announce("Now in domain inversion only encoder.")
        print("building custom in domain inversion graph!")
        with tf.variable_scope('fakedddwdoptimizer'):

            # Build graph.
            BATCH_SIZE = self.minibatch_per_gpu
            input_shape = Inv.input_shape
            input_shape[0] = BATCH_SIZE
            latent_shape = Gs.components.synthesis.input_shape
            latent_shape[0] = BATCH_SIZE

            x = tf.placeholder(tf.float32, shape=input_shape, name='real_image')

            w_enc_1 = Inv.get_output_for(x, phase=False)
            wp_enc_1 = tf.reshape(w_enc_1, latent_shape)

            manipulated_images = Gs.components.synthesis.get_output_for(wp_enc_1, randomize_noise=False)
            manipulated_images = tflib.convert_images_to_uint8(manipulated_images)
            inception_codes = inception_clone.get_output_for(manipulated_images)

        for idx, data in tqdm(enumerate(self._iterate_reals(minibatch_size=minibatch_size)), position=0, leave=True):
            batch_stacks = data[0]
            images = batch_stacks[:,0,:,:,:]
            landmarks = batch_stacks[:,1,:,:,:]
            images = images.astype(np.float32) / 255 * 2.0 - 1.0
            landmarks = landmarks.astype(np.float32) / 255 * 2.0 - 1.0

            #print("landmarks", landmarks.shape)# (8, 3, 128, 128)
            #print("images", images.shape) # (8, 3, 128, 128)
            #print("inception_codes", inception_codes.shape) # (8, 2048)
            #print("activations", activations.shape) # (5000, 2048)
            begin = idx * minibatch_size
            end = min(begin + minibatch_size, self.num_images)
            #print("b,e", begin, end) # 0, 8; ...

            activations[begin:end] = tflib.run(inception_codes, feed_dict={x:images})

            if end == self.num_images:
                break

        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)
        #print("enc_mu_fake={}, enc_sigma_fake={}".format(mu_fake, sigma_fake))


        # Calculate FID.
        print("Calculate FID for IN-DOMAIN-GAN-INVERSION")
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist), suffix="_In-Domain-Inversion_Only")
        print("distance IN-DOMAIN-GAN-INVERSION:", dist)
Example n. 28
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/inception_v3_features.pkl')
        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            for idx, reals in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
                images, labels = reals
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)

                # Replace rotation label with random rotation
                random_index = tf.cast(tf.floor(tf.random_uniform([self.minibatch_per_gpu], minval=0, maxval=8)), dtype=tf.int32)
                random_one_hot = tf.one_hot(random_index, depth=8)
                labels = tf.concat([
                    labels[:, :self.rotation_offset],
                    random_one_hot,
                    labels[:, self.rotation_offset + 8:]
                ], axis=-1)

                # Interpolate with neighboring rotation label
                rotations = labels[:, self.rotation_offset:self.rotation_offset + 8]
                rotation_index = tf.cast(tf.argmax(rotations, axis=1), dtype=tf.int32)
                rotation_shift = tf.cast(
                    tf.where(tf.random_uniform(shape=[self.minibatch_per_gpu], minval=-1, maxval=1) > 0,
                             tf.ones([self.minibatch_per_gpu]),
                             tf.ones([self.minibatch_per_gpu]) * -1), dtype=tf.int32)
                new_rotation_index = tf.mod(rotation_index + rotation_shift, 8)
                new_rotation = tf.cast(tf.one_hot(new_rotation_index, 8), dtype=tf.int32)
                new_rotation = new_rotation * tf.cast(
                    tf.reduce_max(labels[:, self.rotation_offset:self.rotation_offset + 8], axis=-1, keepdims=True),
                    dtype=tf.int32)
                new_rotation = tf.cast(new_rotation, dtype=tf.float32)
                labels_copy = tf.identity(labels)
                labels_neighbor = tf.concat(
                    [labels_copy[:, :self.rotation_offset], new_rotation, labels_copy[:, self.rotation_offset + 8:]], axis=-1)
                interpolation_mag = tf.random_uniform(shape=[self.minibatch_per_gpu, 1, 1], minval=0, maxval=1)
                if self.latent_space == 'w':
                    dlatent_neighbor = Gs_clone.components.mapping.get_output_for(latents, labels_neighbor)
                    dlatent = Gs_clone.components.mapping.get_output_for(latents, labels)
                    interpolation_mag = tf.tile(interpolation_mag, [1, tf.shape(dlatent)[1], tf.shape(dlatent)[2]])
                    dlatent_interpolate = dlatent * interpolation_mag + dlatent_neighbor * (1 - interpolation_mag)
                    images = Gs_clone.components.synthesis.get_output_for(dlatent_interpolate)
                elif self.latent_space == 'z':
                    interpolation_mag = tf.tile(interpolation_mag[:, 0], [1, tf.shape(labels)[1]])
                    labels_interpolate = labels * interpolation_mag + labels_neighbor * (1 - interpolation_mag)
                    images = Gs_clone.get_output_for(latents, labels_interpolate, **Gs_kwargs)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist))
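The rotation handling above swaps the one-hot rotation block of the label vector for a random bin, then interpolates toward a neighboring bin. A minimal numpy sketch of the same index arithmetic, assuming 8 rotation bins starting at rotation_offset (all values are illustrative):

import numpy as np

rotation_offset, num_bins = 4, 8
labels = np.zeros(16)
labels[rotation_offset + 2] = 1.0                 # rotation bin 2 is active

rotations = labels[rotation_offset:rotation_offset + num_bins]
rotation_index = int(np.argmax(rotations))
shift = np.random.choice([-1, 1])                 # pick a random neighboring bin
neighbor_index = (rotation_index + shift) % num_bins

labels_neighbor = labels.copy()
labels_neighbor[rotation_offset:rotation_offset + num_bins] = np.eye(num_bins)[neighbor_index]

mag = np.random.uniform()                         # interpolation magnitude in [0, 1)
labels_interp = labels * mag + labels_neighbor * (1 - mag)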
    def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs):  # pylint: disable=arguments-differ
        minibatch_size = num_gpus * self.minibatch_per_gpu
        with dnnlib.util.open_url(
                'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_features.pkl'
        ) as f:  # identical to http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
            feature_net = pickle.load(f)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if self.use_cached_real_stats and os.path.isfile(cache_file):
            with open(cache_file, 'rb') as f:
                mu_real, sigma_real = pickle.load(f)
        else:
            nfeat = feature_net.output_shape[1]
            mu_real = np.zeros(nfeat)
            sigma_real = np.zeros([nfeat, nfeat])
            num_real = 0
            for images, _labels, num in self._iterate_reals(minibatch_size):
                if self.max_reals is not None:
                    num = min(num, self.max_reals - num_real)
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1])
                for feat in list(
                        feature_net.run(images,
                                        num_gpus=num_gpus,
                                        assume_frozen=True))[:num]:
                    mu_real += feat
                    sigma_real += np.outer(feat, feat)
                    num_real += 1
                if self.max_reals is not None and num_real >= self.max_reals:
                    break
            mu_real /= num_real
            sigma_real /= num_real
            sigma_real -= np.outer(mu_real, mu_real)
            with open(cache_file, 'wb') as f:
                pickle.dump((mu_real, sigma_real), f)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                feature_net_clone = feature_net.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                labels = self._get_random_labels_tf(self.minibatch_per_gpu)
                images = Gs_clone.get_output_for(latents, labels, **G_kwargs)
                if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1])
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(feature_net_clone.get_output_for(images))

        # Calculate statistics for fakes.
        feat_fake = []
        for begin in range(0, self.num_fakes, minibatch_size):
            self._report_progress(begin, self.num_fakes)
            feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0))
        feat_fake = np.stack(feat_fake[:self.num_fakes])
        mu_fake = np.mean(feat_fake, axis=0)
        sigma_fake = np.cov(feat_fake, rowvar=False)

        # Calculate FID.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
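Unlike the earlier variants, the real-statistics loop above never materializes the full activation matrix: it accumulates running sums of the features and their outer products and recovers the covariance at the end. A small self-contained check, assuming random features stand in for the Inception activations, that the streaming estimate matches numpy's batch computation:

import numpy as np

feats = np.random.randn(1000, 16)   # stand-in for [num_real, 2048] activations
mu, sigma, n = np.zeros(16), np.zeros((16, 16)), 0
for feat in feats:
    mu += feat
    sigma += np.outer(feat, feat)
    n += 1
mu /= n
sigma = sigma / n - np.outer(mu, mu)

assert np.allclose(mu, feats.mean(axis=0))
assert np.allclose(sigma, np.cov(feats, rowvar=False, bias=True))  # biased (1/N) covariance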
Example n. 30
    initial_dlatents,
    randomize_noise=True,  # turns out this should never be off when trying to learn dlatents
    minibatch_size=1,
    custom_inputs=[
        partial(create_variable_for_generator, batch_size=1),
        partial(create_stub, batch_size=1)
    ],
    structure='fixed')

dlatent_variable = next(v for v in tf.global_variables()
                        if 'learnable_dlatents' in v.name)
generator_output = tf.get_default_graph().get_tensor_by_name(
    'G_synthesis_1/_Run/G_synthesis/images_out:0')
generated_image = tflib.convert_images_to_uint8(generator_output,
                                                nchw_to_nhwc=True,
                                                uint8_cast=False)
generated_image_uint8 = tf.saturate_cast(generated_image, tf.uint8)

# Loss part
vgg16 = VGG16(include_top=False, input_shape=(512, 512, 3))
perceptual_model = keras.Model(vgg16.input, vgg16.layers[9].output)
generated_img_features = perceptual_model(
    preprocess_input(generated_image, mode="tf"))
ref_img = tf.get_variable('ref_img',
                          shape=generated_image.shape,
                          dtype='float32',
                          initializer=tf.zeros_initializer())
ref_img_features = tf.get_variable('ref_img_features',
                                   shape=generated_img_features.shape,
                                   dtype='float32',