    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu

        graph_def = tf.GraphDef()
        with misc.open_file_or_url(
                'http://rail.eecs.berkeley.edu/models/lpips/net-lin_alex_v0.1.pb'
        ) as f:
            graph_def.ParseFromString(f.read())

        # Construct TensorFlow graph.
        self._configure(self.minibatch_per_gpu)
        result_expr = []
        for gpu_idx in range(num_gpus):

            def auto_gpu(opr):
                if opr.type in ['SparseToDense', 'Tile', 'GatherV2', 'Pack']:
                    return '/cpu:0'
                else:
                    return '/gpu:%d' % gpu_idx

            with tf.device(auto_gpu):
                Gs_clone = Gs.clone()
                reals, labels = self._get_minibatch_tf()
                reals = tflib.convert_images_from_uint8(reals)
                masks = self._get_random_masks_tf()

                latents0 = tf.random_normal([self.minibatch_per_gpu] +
                                            Gs_clone.input_shape[1:])
                fakes0 = Gs_clone.get_output_for(latents0, labels, reals,
                                                 masks,
                                                 **Gs_kwargs)[:, :3, :, :]
                fakes0 = tf.clip_by_value(fakes0, -1.0, 1.0)

                latents1 = tf.random_normal([self.minibatch_per_gpu] +
                                            Gs_clone.input_shape[1:])
                fakes1 = Gs_clone.get_output_for(latents1, labels, reals,
                                                 masks,
                                                 **Gs_kwargs)[:, :3, :, :]
                fakes1 = tf.clip_by_value(fakes1, -1.0, 1.0)

                distance, = tf.import_graph_def(graph_def,
                                                input_map={
                                                    '0:0': fakes0,
                                                    '1:0': fakes1
                                                },
                                                return_elements=['Reshape_10'])
                result_expr.append(distance.outputs)

        # Run metric
        results = []
        for begin in range(0, self.num_pairs, minibatch_size):
            self._report_progress(begin, self.num_pairs)
            res = tflib.run(result_expr)
            results.append(np.reshape(res, [-1]))
        results = np.concatenate(results)
        self._report_result(np.mean(results))
        self._report_result(np.std(results), suffix='-var')
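# A minimal sketch (assuming a TF 1.x style API via tensorflow.compat.v1) of the
# tf.import_graph_def pattern used above: a frozen GraphDef is re-imported with
# input_map splicing live tensors over its original placeholders, and
# return_elements naming an op whose .outputs are then collected, mirroring
# result_expr.append(distance.outputs). The toy graph and names below
# ('in0', 'in1', 'diff') are illustrative only, not the LPIPS network (whose
# inputs are '0:0' and '1:0').
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Build and serialize a toy graph: diff = in0 - in1.
toy = tf.Graph()
with toy.as_default():
    a = tf.placeholder(tf.float32, [None, 3], name='in0')
    b = tf.placeholder(tf.float32, [None, 3], name='in1')
    tf.subtract(a, b, name='diff')
toy_def = toy.as_graph_def()

# Re-import it, feeding live tensors instead of the original placeholders.
with tf.Graph().as_default() as g:
    x = tf.constant(np.ones([2, 3], np.float32))
    y = tf.constant(np.zeros([2, 3], np.float32))
    # 'diff' (no ':0' suffix) returns the Operation; its .outputs holds the tensors.
    diff_op, = tf.import_graph_def(toy_def,
                                   input_map={'in0:0': x, 'in1:0': y},
                                   return_elements=['diff'])
    with tf.Session(graph=g) as sess:
        print(sess.run(diff_op.outputs))  # [array of ones, shape (2, 3)]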
Code example #2
    def build_perceptual_model(self, generator, discriminator=None):
        # Learning rate
        global_step = tf.Variable(0,
                                  dtype=tf.int32,
                                  trainable=False,
                                  name="global_step")
        incremented_global_step = tf.assign_add(global_step, 1)
        self._reset_global_step = tf.assign(global_step, 0)
        self.learning_rate = tf.train.exponential_decay(
            self.lr,
            incremented_global_step,
            self.decay_steps,
            self.decay_rate,
            staircase=True)
        self.sess.run([self._reset_global_step])

        if self.discriminator_loss is not None:
            self.discriminator = discriminator

        generated_image_tensor = generator.generated_image
        generated_image = tf.image.resize_nearest_neighbor(
            generated_image_tensor, (self.img_size, self.img_size),
            align_corners=True)

        self.ref_img = tf.get_variable('ref_img',
                                       shape=generated_image.shape,
                                       dtype='float32',
                                       initializer=tf.initializers.zeros())
        self.ref_weight = tf.get_variable('ref_weight',
                                          shape=generated_image.shape,
                                          dtype='float32',
                                          initializer=tf.initializers.zeros())
        self.add_placeholder("ref_img")
        self.add_placeholder("ref_weight")

        if (self.vgg_loss is not None):
            # Use locally stored weights if available
            if os.path.isfile(VGG16_WEIGHTS_NOTOP_LOCAL):
                weights = VGG16_WEIGHTS_NOTOP_LOCAL
            else:
                weights = 'imagenet'
            vgg16 = VGG16(include_top=False,
                          weights=weights,
                          input_shape=(self.img_size, self.img_size, 3))
            self.perceptual_model = Model(vgg16.input,
                                          vgg16.layers[self.layer].output)
            generated_img_features = self.perceptual_model(
                preprocess_input(self.ref_weight * generated_image))
            self.ref_img_features = tf.get_variable(
                'ref_img_features',
                shape=generated_img_features.shape,
                dtype='float32',
                initializer=tf.initializers.zeros())
            self.features_weight = tf.get_variable(
                'features_weight',
                shape=generated_img_features.shape,
                dtype='float32',
                initializer=tf.initializers.zeros())
            # Initialize both feature-space reference variables.
            self.sess.run([
                self.ref_img_features.initializer,
                self.features_weight.initializer
            ])
            self.add_placeholder("ref_img_features")
            self.add_placeholder("features_weight")

        if self.perc_model is not None and self.lpips_loss is not None:
            img1 = tflib.convert_images_from_uint8(self.ref_weight *
                                                   self.ref_img,
                                                   nhwc_to_nchw=True)
            img2 = tflib.convert_images_from_uint8(self.ref_weight *
                                                   generated_image,
                                                   nhwc_to_nchw=True)

        self.loss = 0
        # L1 loss on VGG16 features
        if (self.vgg_loss is not None):
            if self.adaptive_loss:
                self.loss += self.vgg_loss * tf_custom_adaptive_loss(
                    self.features_weight * self.ref_img_features,
                    self.features_weight * generated_img_features)
            else:
                self.loss += self.vgg_loss * tf_custom_logcosh_loss(
                    self.features_weight * self.ref_img_features,
                    self.features_weight * generated_img_features)
        # + logcosh loss on image pixels
        if (self.pixel_loss is not None):
            if self.adaptive_loss:
                self.loss += self.pixel_loss * tf_custom_adaptive_rgb_loss(
                    self.ref_weight * self.ref_img,
                    self.ref_weight * generated_image)
            else:
                self.loss += self.pixel_loss * tf_custom_logcosh_loss(
                    self.ref_weight * self.ref_img,
                    self.ref_weight * generated_image)
        # + MS-SIM loss on image pixels
        if (self.mssim_loss is not None):
            self.loss += self.mssim_loss * tf.math.reduce_mean(
                1 - tf.image.ssim_multiscale(self.ref_weight *
                                             self.ref_img, self.ref_weight *
                                             generated_image, 1))
        # + extra perceptual loss on image pixels
        if self.perc_model is not None and self.lpips_loss is not None:
            self.loss += self.lpips_loss * tf.math.reduce_mean(
                self.perc_model.get_output_for(img1, img2))
        # + L1 penalty on dlatent weights
        if self.l1_penalty is not None:
            self.loss += self.l1_penalty * 512 * tf.math.reduce_mean(
                tf.math.abs(generator.dlatent_variable -
                            generator.get_dlatent_avg()))
        # discriminator loss (realism)
        if self.discriminator_loss is not None:
            self.loss += self.discriminator_loss * tf.math.reduce_mean(
                self.discriminator.get_output_for(
                    tflib.convert_images_from_uint8(
                        generated_image_tensor, nhwc_to_nchw=True), self.stub))
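# A small numeric sketch of the staircase schedule that tf.train.exponential_decay
# builds at the top of build_perceptual_model:
# decayed_lr = lr * decay_rate ** floor(global_step / decay_steps).
# The lr/decay numbers below are made-up placeholders, not the class defaults.
import math

def staircase_lr(lr, step, decay_steps, decay_rate):
    return lr * decay_rate ** math.floor(step / decay_steps)

for step in (0, 9, 10, 25):
    print(step, staircase_lr(lr=0.02, step=step, decay_steps=10, decay_rate=0.9))
# 0 -> 0.02, 9 -> 0.02, 10 -> 0.018, 25 -> 0.0162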
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn'
        )  # inception_v3_features.pkl
        real_activations = np.empty(
            [self.num_images, inception.output_shape[1]], dtype=np.float32)
        fake_activations = np.empty(
            [self.num_images, inception.output_shape[1]], dtype=np.float32)

        # Construct TensorFlow graph.
        self._configure(self.minibatch_per_gpu, hole_range=self.hole_range)
        real_img_expr = []
        fake_img_expr = []
        real_result_expr = []
        fake_result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                reals, labels = self._get_minibatch_tf()
                reals_tf = tflib.convert_images_from_uint8(reals)
                masks = self._get_random_masks_tf()
                fakes = Gs_clone.get_output_for(latents, labels, reals_tf,
                                                masks, **Gs_kwargs)
                fakes = tflib.convert_images_to_uint8(fakes[:, :3])
                reals = tflib.convert_images_to_uint8(reals_tf[:, :3])
                real_img_expr.append(reals)
                fake_img_expr.append(fakes)
                real_result_expr.append(inception_clone.get_output_for(reals))
                fake_result_expr.append(inception_clone.get_output_for(fakes))

        for begin in tqdm(range(0, self.num_images, minibatch_size)):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            real_results, fake_results = tflib.run(
                [real_result_expr, fake_result_expr])
            real_activations[begin:end] = np.concatenate(real_results,
                                                         axis=0)[:end - begin]
            fake_activations[begin:end] = np.concatenate(fake_results,
                                                         axis=0)[:end - begin]

        # Calculate FID.
        mu_real = np.mean(real_activations, axis=0)
        sigma_real = np.cov(real_activations, rowvar=False)
        mu_fake = np.mean(fake_activations, axis=0)
        sigma_fake = np.cov(fake_activations, rowvar=False)
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist), suffix='-FID')

        svm = sklearn.svm.LinearSVC(dual=False)
        svm_inputs = np.concatenate([real_activations, fake_activations])
        svm_targets = np.array([1] * real_activations.shape[0] +
                               [0] * fake_activations.shape[0])
        svm.fit(svm_inputs, svm_targets)
        self._report_result(1 - svm.score(svm_inputs, svm_targets),
                            suffix='-U')
        real_outputs = svm.decision_function(real_activations)
        fake_outputs = svm.decision_function(fake_activations)
        self._report_result(np.mean(fake_outputs > real_outputs), suffix='-P')
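# A minimal sketch of the SVM-based separability check at the end of the method
# above, run on synthetic features (random blobs stand in for the Inception
# activations): '-U' is the linear SVM's residual training error and '-P' is the
# paired fraction of fakes scored as more real than reals.
import numpy as np
import sklearn.svm

rng = np.random.RandomState(0)
real_feat = rng.normal(0.0, 1.0, size=(512, 16))
fake_feat = rng.normal(0.5, 1.0, size=(512, 16))  # slightly shifted "fakes"

svm = sklearn.svm.LinearSVC(dual=False)
svm_inputs = np.concatenate([real_feat, fake_feat])
svm_targets = np.array([1] * len(real_feat) + [0] * len(fake_feat))
svm.fit(svm_inputs, svm_targets)

u = 1 - svm.score(svm_inputs, svm_targets)
p = np.mean(svm.decision_function(fake_feat) > svm.decision_function(real_feat))
print(u, p)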
Code example #4
def training_loop(
        run_dir='.',  # Output directory.
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        loss_args={},  # Options for loss function.
        train_dataset_args={},  # Options for dataset to train with.
        metric_dataset_args={},  # Options for dataset to evaluate metrics against.
        augment_args={},  # Options for adaptive augmentations.
        metric_arg_list=[],  # Metrics to evaluate during training.
        num_gpus=1,  # Number of GPUs to use.
        minibatch_size=32,  # Global minibatch size.
        minibatch_gpu=4,  # Number of samples processed at a time by one GPU.
        G_smoothing_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        G_smoothing_rampup=None,  # EMA ramp-up coefficient.
        minibatch_repeats=4,  # Number of minibatches to run in the inner loop.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        resume_pkl=None,  # Network pickle to resume training from.
        abort_fn=None,  # Callback function for determining whether to abort training.
        progress_fn=None,  # Callback function for updating training progress.
):
    assert minibatch_size % (num_gpus * minibatch_gpu) == 0
    start_time = time.time()

    print('Loading training set...')
    training_set = dataset.load_dataset(**train_dataset_args)
    print('Image shape:', np.int32(training_set.shape).tolist())
    print('Label shape:', [training_set.label_size])
    print()

    print('Constructing networks...')
    with tf.device('/gpu:0'):
        G = tflib.Network('G',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **G_args)
        D = tflib.Network('D',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **D_args)
        Gs = G.clone('Gs')
        if resume_pkl is not None:
            print(f'Resuming from "{resume_pkl}"')
            with dnnlib.util.open_url(resume_pkl) as f:
                rG, rD, rGs = pickle.load(f)
            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)
    G.print_layers()
    D.print_layers()

    print('Exporting sample images...')
    grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(
        training_set)
    save_image_grid(grid_reals,
                    os.path.join(run_dir, 'reals.png'),
                    drange=[0, 255],
                    grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=minibatch_gpu)
    # save_image_grid(grid_fakes, os.path.join(
    #     run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    print(f'Replicating networks across {num_gpus} GPUs...')
    G_gpus = [G]
    D_gpus = [D]
    for gpu in range(1, num_gpus):
        with tf.device(f'/gpu:{gpu}'):
            G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
            D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))

    print('Initializing augmentations...')
    aug = None
    if augment_args.get('class_name', None) is not None:
        aug = dnnlib.util.construct_class_by_name(**augment_args)
        aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)

    print('Setting up optimizers...')
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    print('Constructing training graph...')
    data_fetch_ops = []
    training_set.configure(minibatch_gpu)
    for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)):
        with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'):

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                real_images_var = tf.Variable(
                    name='images',
                    trainable=False,
                    initial_value=tf.zeros([minibatch_gpu] +
                                           training_set.shape))
                real_labels_var = tf.Variable(name='labels',
                                              trainable=False,
                                              initial_value=tf.zeros([
                                                  minibatch_gpu,
                                                  training_set.label_size
                                              ]))
                real_images_write, real_labels_write = training_set.get_minibatch_tf()
                real_images_write = tflib.convert_images_from_uint8(
                    real_images_write)
                data_fetch_ops += [
                    tf.assign(real_images_var, real_images_write)
                ]
                data_fetch_ops += [
                    tf.assign(real_labels_var, real_labels_write)
                ]

            # Evaluate loss function and register gradients.
            fake_labels = training_set.get_random_labels_tf(minibatch_gpu)
            terms = dnnlib.util.call_func_by_name(G=G_gpu,
                                                  D=D_gpu,
                                                  aug=aug,
                                                  fake_labels=fake_labels,
                                                  real_images=real_images_var,
                                                  real_labels=real_labels_var,
                                                  **loss_args)
            if lazy_regularization:
                if terms.G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(terms.G_reg * G_reg_interval),
                        G_gpu.trainables)
                if terms.D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(terms.D_reg * D_reg_interval),
                        D_gpu.trainables)
            else:
                if terms.G_reg is not None:
                    terms.G_loss += terms.G_reg
                if terms.D_reg is not None:
                    terms.D_loss += terms.D_reg
            G_opt.register_gradients(tf.reduce_mean(terms.G_loss),
                                     G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(terms.D_loss),
                                     D_gpu.trainables)

    print('Finalizing training ops...')
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
    tflib.init_uninitialized_vars()
    with tf.device('/gpu:0'):
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()

    print('Initializing metrics...')
    summary_log = tf.summary.FileWriter(run_dir)
    metrics = []
    for args in metric_arg_list:
        metric = dnnlib.util.construct_class_by_name(**args)
        metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir)
        metrics.append(metric)

    print(f'Training for {total_kimg} kimg...')
    print()
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    running_mb_counter = 0

    done = False
    while not done:

        # Compute EMA decay parameter.
        Gs_nimg = G_smoothing_kimg * 1000.0
        if G_smoothing_rampup is not None:
            Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
        Gs_beta = 0.5**(minibatch_size / max(Gs_nimg, 1e-8))

        # Run training ops.
        for _repeat_idx in range(minibatch_repeats):
            rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op])
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta})
                if run_D_reg:
                    tflib.run(D_reg_op)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op)
                    if run_G_reg:
                        tflib.run(G_reg_op)
                tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta})
                for _round in rounds:
                    tflib.run(data_fetch_op)
                    tflib.run(D_train_op)
                    if run_D_reg:
                        tflib.run(D_reg_op)

            # Run validation.
            if aug is not None:
                aug.run_validation(minibatch_size=minibatch_size)

        # Tune augmentation parameters.
        if aug is not None:
            aug.tune(minibatch_size * minibatch_repeats)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None
                                                   and abort_fn())
        if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_end_time = time.time()
            total_time = tick_end_time - start_time
            tick_time = tick_end_time - tick_start_time

            # Report progress.
            print(' '.join([
                f"tick {autosummary('Progress/tick', cur_tick):<5d}",
                f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}",
                f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}",
                f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}",
                f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}",
                f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}",
                f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}",
                f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}",
            ]))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            if progress_fn is not None:
                progress_fn(cur_nimg // 1000, total_kimg)

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    done or cur_tick % image_snapshot_ticks == 0):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=minibatch_gpu)
                save_image_grid(grid_fakes,
                                os.path.join(
                                    run_dir,
                                    f'fakes{cur_nimg // 1000:06d}.png'),
                                drange=[-1, 1],
                                grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    done or cur_tick % network_snapshot_ticks == 0):
                pkl = os.path.join(
                    run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
                with open(pkl, 'wb') as f:
                    pickle.dump((G, D, Gs), f)
                if len(metrics):
                    print('Evaluating metrics...')
                    for metric in metrics:
                        metric.run(pkl, num_gpus=num_gpus)

            # Update summaries.
            for metric in metrics:
                metric.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            tick_start_time = time.time()
            maintenance_time = tick_start_time - tick_end_time

    print()
    print('Exiting...')
    summary_log.close()
    training_set.close()
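# A small numeric sketch of the EMA decay computed at the top of each training
# iteration above: G_smoothing_kimg acts as a half-life in thousands of images,
# so beta = 0.5 ** (minibatch_size / half_life_nimg), optionally tightened early
# in training by G_smoothing_rampup. The argument values below are illustrative.
def ema_beta(minibatch_size, cur_nimg, G_smoothing_kimg=10, G_smoothing_rampup=None):
    Gs_nimg = G_smoothing_kimg * 1000.0
    if G_smoothing_rampup is not None:
        Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
    return 0.5 ** (minibatch_size / max(Gs_nimg, 1e-8))

print(ema_beta(minibatch_size=32, cur_nimg=0))                              # ~0.9978
print(ema_beta(minibatch_size=32, cur_nimg=4000, G_smoothing_rampup=0.05))  # ~0.8950 while ramping up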
Code example #5
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn'
        )  # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(
            num_images=self.ref_samples)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            real_activations = np.empty(
                [self.ref_samples, inception.output_shape[1]],
                dtype=np.float32)
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size,
                                        is_training=self.ref_train)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.ref_samples)
                real_activations[begin:end] = inception.run(images[:end -
                                                                   begin, :3],
                                                            num_gpus=num_gpus,
                                                            assume_frozen=True)
                if end == self.ref_samples:
                    break
            mu_real = np.mean(real_activations, axis=0)
            sigma_real = np.cov(real_activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        self._configure(self.minibatch_per_gpu)
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                reals, labels = self._get_minibatch_tf()
                reals = tflib.convert_images_from_uint8(reals)
                masks = self._get_random_masks_tf()
                images = Gs_clone.get_output_for(latents, labels, reals, masks,
                                                 **Gs_kwargs)
                images = images[:, :3]
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            self._report_progress(begin, self.num_images)
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
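# A standalone numeric sketch of the Fréchet distance computed above from the
# two Gaussian fits (mu_real, sigma_real) and (mu_fake, sigma_fake):
# FID = ||mu_fake - mu_real||^2 + Tr(sigma_fake + sigma_real - 2 * sqrtm(sigma_fake @ sigma_real)).
# The random features below stand in for Inception activations.
import numpy as np
import scipy.linalg

rng = np.random.RandomState(0)
real_acts = rng.normal(0.0, 1.0, size=(1000, 8))
fake_acts = rng.normal(0.2, 1.1, size=(1000, 8))

mu_r, sigma_r = np.mean(real_acts, axis=0), np.cov(real_acts, rowvar=False)
mu_f, sigma_f = np.mean(fake_acts, axis=0), np.cov(fake_acts, rowvar=False)

m = np.square(mu_f - mu_r).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_f, sigma_r), disp=False)
fid = m + np.trace(sigma_f + sigma_r - 2 * s)
print(np.real(fid))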