Example #1
0
    def _evaluate(self, Gs, num_gpus):
        """Measure how linearly separable classifier attributes are in Z and W.

        For every GPU, build a graph that samples latents, maps them to
        dlatents, synthesizes images (average-pooled by an integer factor when
        wider than 256) and runs one pre-trained attribute classifier per index
        in ``self.attrib_indices``. After sampling, for each attribute the
        ``self.num_keep`` samples with the most confident predictions are kept,
        a ``sklearn.svm.LinearSVC`` is fitted to predict the classifier's label
        from the latents (``'latents'``) and from ``dlatents[:, -1]``
        (``'dlatents'``), and the conditional entropy of the 2x2 joint
        distribution of (SVM output, classifier label) is recorded per space.

        Args:
            Gs: generator network; cloned once per GPU.
            num_gpus: number of GPUs to build the evaluation graph on.

        NOTE(review): ``conditional_entropies`` is neither reported nor
        returned in this block, so a final scoring step appears to be
        missing - confirm against the rest of the metric.
        """
        minibatch_size = num_gpus * self.minibatch_per_gpu

        # Construct TensorFlow graph for each GPU.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()

                # Generate images.
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True)
                images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True)

                # Downsample to 256x256. The attribute classifiers were built for 256x256.
                if images.shape[2] > 256:
                    factor = images.shape[2] // 256
                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
                    images = tf.reduce_mean(images, axis=[3, 5])

                # Run classifier for each attribute.
                result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
                for attrib_idx in self.attrib_indices:
                    classifier = misc.load_pkl(classifier_urls[attrib_idx])
                    logits = classifier.get_output_for(images, None)
                    # Two-class probabilities from a single logit.
                    predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
                    result_dict[attrib_idx] = predictions
                result_expr.append(result_dict)

        # Sampling loop: each run yields one dict per GPU; concatenate per key.
        results = []
        for _ in range(0, self.num_samples, minibatch_size):
            results += tflib.run(result_expr)
        results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}

        # Calculate conditional entropy for each attribute.
        conditional_entropies = defaultdict(list)
        for attrib_idx in self.attrib_indices:
            # Prune the least confident samples: keep the num_keep samples (among
            # the first num_samples) whose top class probability is highest.
            # mergesort is stable, so ties keep ascending index order.
            confidence = np.max(results[attrib_idx][:self.num_samples], axis=1)
            pruned_indices = np.argsort(-confidence, kind='mergesort')[:self.num_keep]

            # Fit SVM to the remaining samples.
            svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
            for space in ['latents', 'dlatents']:
                svm_inputs = results[space][pruned_indices]
                try:
                    svm = sklearn.svm.LinearSVC()
                    svm.fit(svm_inputs, svm_targets)
                    svm_outputs = svm.predict(svm_inputs)
                except Exception:  # e.g. sklearn rejects targets with a single class
                    svm_outputs = svm_targets  # assume perfect prediction

                # Calculate conditional entropy from the 2x2 joint distribution
                # p[row][col] = fraction of samples with output == row and target == col.
                p = [[np.mean((svm_outputs == row) & (svm_targets == col)) for col in (0, 1)] for row in (0, 1)]
                conditional_entropies[space].append(conditional_entropy(p))
Example #2
0
    def _evaluate(self, Gs, num_gpus):
        """Compute the Frechet Inception Distance (FID) of Gs against the reals.

        Real-image statistics (mean and covariance of Inception activations
        over ``self.num_images`` images) are cached in the file returned by
        ``self._get_cache_file_for_reals``. Fake statistics are recomputed from
        ``self.num_images`` generated images on every call. The reported value
        is ``|mu_fake - mu_real|^2 + Tr(S_fake + S_real - 2 sqrtm(S_fake S_real))``
        (real part only).

        Args:
            Gs: generator network; cloned once per GPU.
            num_gpus: number of GPUs to evaluate on.

        Raises:
            ValueError: if the reals iterator ends before ``self.num_images``
                images were processed. Without this check, the unfilled rows
                of the ``np.empty`` buffer (uninitialized memory) would be
                folded into the statistics and written to the cache file.
        """
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            'https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn'
        )  # inception_v3_features.pkl
        # One activation row per image; filled for the reals, then reused for the fakes.
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Calculate statistics for reals (or load them from the cache).
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            num_filled = 0
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end - begin],
                                                       num_gpus=num_gpus,
                                                       assume_frozen=True)
                num_filled = end
                if end == self.num_images:
                    break
            if num_filled < self.num_images:
                # Refuse to compute (and cache) statistics from uninitialized rows.
                raise ValueError(
                    'Only %d of the %d real images required for FID were available'
                    % (num_filled, self.num_images))
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                images = Gs_clone.get_output_for(latents,
                                                 None,
                                                 is_validation=True,
                                                 randomize_noise=True)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes; the last batch is trimmed to num_images.
        for begin in range(0, self.num_images, minibatch_size):
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID. sqrtm may return a complex matrix, so only the real part is reported.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
Example #3
0
def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):
    """Save a figure comparing images generated with different noise subsets.

    The synthesis network's noise variables (names starting with 'noise') are
    read from a clone of `Gs`. For each entry of `noise_ranges`, every noise
    variable whose index is in that range keeps its original value and all
    others are set to zero. One image per seed is then generated with
    truncation_psi=1 and randomize_noise=False. Images selected by `flips`
    are reversed along axis 2. Presumably that is a horizontal mirror if the
    output layout is NHWC; confirm against synthesis_kwargs.

    The canvas has one column of width `w` per seed and two rows of height
    `h`. Each column is tiled from half-images:
    - top-left comes from noise_ranges[0];
    - top-right comes from noise_ranges[1];
    - bottom-left comes from noise_ranges[2];
    - bottom-right comes from noise_ranges[3].

    So `noise_ranges` needs at least four entries. Generated images are
    assumed to be `w` pixels wide and `h` pixels tall.

    Args:
        png: output path of the figure (also printed).
        Gs: generator network; only a clone of it is modified.
        w, h: width and height of one canvas cell.
        seeds: RandomState seeds; one latent vector is drawn per seed.
        noise_ranges: containers of noise-variable indices to keep enabled.
        flips: index expression selecting which generated images to mirror.
    """
    print(png)
    Gsc = Gs.clone()
    noise_vars = [
        var for name, var in Gsc.components.synthesis.vars.items()
        if name.startswith('noise')
    ]
    noise_pairs = list(zip(noise_vars,
                           tflib.run(noise_vars)))  # [(var, val), ...]
    # np.stack needs a sequence: passing a generator is deprecated since
    # NumPy 1.16 and newer NumPy versions raise TypeError for it.
    latents = np.stack([
        np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds
    ])
    all_images = []
    for noise_range in noise_ranges:
        # Keep noise inputs whose index is in this range; zero out the rest.
        tflib.set_vars({
            var: val * (1 if i in noise_range else 0)
            for i, (var, val) in enumerate(noise_pairs)
        })
        range_images = Gsc.run(latents,
                               None,
                               truncation_psi=1,
                               randomize_noise=False,
                               **synthesis_kwargs)
        range_images[flips, :, :] = range_images[flips, :, ::-1]
        all_images.append(list(range_images))

    # One column per generated image (i.e. per seed). The width used to be
    # hard-coded to two columns, which clipped any additional seeds.
    canvas = PIL.Image.new('RGB', (w * len(latents), h * 2), 'white')
    half_w = w // 2
    left_half = (0, 0, half_w, h)
    right_half = (half_w, 0, w, h)
    # (index into noise_ranges, crop box, x offset within the column, y offset)
    quadrants = (
        (0, left_half, 0, 0),
        (1, right_half, half_w, 0),
        (2, left_half, 0, h),
        (3, right_half, half_w, h),
    )
    for col, col_images in enumerate(zip(*all_images)):
        for range_idx, box, dx, dy in quadrants:
            tile = PIL.Image.fromarray(col_images[range_idx], 'RGB').crop(box)
            canvas.paste(tile, (col * w + dx, dy))
    canvas.save(png)
Example #4
0
def _construct_networks(training_set, G_args, D_args):
    """Build fresh G and D networks for `training_set` and clone G into Gs.

    `training_set.shape[0]` is passed as `num_channels`,
    `training_set.shape[1]` as `resolution` and `training_set.label_size` as
    `label_size`; `G_args` / `D_args` supply the remaining network options.

    Returns:
        Tuple ``(G, D, Gs)`` where ``Gs`` is ``G.clone('Gs')``.
    """
    shape_kwargs = dict(num_channels=training_set.shape[0],
                        resolution=training_set.shape[1],
                        label_size=training_set.label_size)
    G = tflib.Network('G', **shape_kwargs, **G_args)
    D = tflib.Network('D', **shape_kwargs, **D_args)
    Gs = G.clone('Gs')
    return G, D, Gs


def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id='latest',  # Run ID or network pkl to resume from; 'latest', 'restore_partial' or None (= start from scratch).
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train D and G, keeping Gs as a moving average of G's weights.

    Network setup depends on `resume_run_id`:
      * None: build fresh G, D and Gs (= clone of G).
      * 'latest': load the pickle found by misc.locate_latest_pkl(); this also
        overrides `resume_kimg`.
      * 'restore_partial': build fresh networks, then copy the compatible
        trainables from the (G, D, Gs) pickle named by the module-level
        `restore_partial_fn`.
      * anything else: load misc.locate_network_pkl(resume_run_id, resume_snapshot).

    Each iteration runs `minibatch_repeats` x (`D_repeats` D steps + 1 G step).
    Once per tick it prints progress, saves a fake-image grid every
    `image_snapshot_ticks` ticks, and saves a network snapshot plus runs the
    metrics every `network_snapshot_ticks` ticks (also on the first and last
    tick). At the end it writes 'network-final.pkl' to the run dir.

    Raises:
        ValueError: if resume_run_id == 'restore_partial' but
            restore_partial_fn is None.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None:
            # Start from scratch.
            print('Constructing networks...')
            G, D, Gs = _construct_networks(training_set, G_args, D_args)
        elif resume_run_id == 'latest':
            network_pkl, resume_kimg = misc.locate_latest_pkl()
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        elif resume_run_id == 'restore_partial':
            print('Restore partially...')
            # `restore_partial_fn` is not a parameter of this function; it is
            # resolved from module scope.
            if restore_partial_fn is None:
                raise ValueError(
                    "resume_run_id='restore_partial' requires restore_partial_fn to be set")
            # Initialize networks, then restore the (subset of) pre-trained
            # weights: only parameters that match both name and shape.
            G, D, Gs = _construct_networks(training_set, G_args, D_args)
            # NOTE: pickle can execute arbitrary code; only load trusted snapshots.
            with open(restore_partial_fn, 'rb') as f:
                G_partial, D_partial, Gs_partial = pickle.load(f)
            G.copy_compatible_trainables_from(G_partial)
            D.copy_compatible_trainables_from(D_partial)
            Gs.copy_compatible_trainables_from(Gs_partial)
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay so that G's contribution halves every G_smoothing_kimg kimg.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the real networks; other GPUs use shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops. Each D step (which also updates Gs) counts as one minibatch of images.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
Example #5
0
    def _evaluate(self, Gs, num_gpus):
        """Perceptual path length: mean scaled perceptual distance of nearby interpolants.

        For each sample, two latents are drawn and interpolated at t and at
        t + self.epsilon. Interpolation is linear in W (mapping outputs) when
        ``self.space == 'w'``; otherwise it uses ``slerp`` in Z. Images are
        synthesized with the noise variables re-initialized before synthesis,
        cropped and downsampled, then compared with a perceptual distance
        network loaded from a pickle. Each distance is scaled by
        1 / self.epsilon**2. Distances below the 1st and above the 99th
        percentile are discarded and the mean of the rest is reported.

        Args:
            Gs: generator network; cloned once per GPU.
            num_gpus: number of GPUs to build the evaluation graph on.
        """
        minibatch_size = num_gpus * self.minibatch_per_gpu

        # Construct TensorFlow graph.
        distance_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                noise_vars = [
                    var for name, var in
                    Gs_clone.components.synthesis.vars.items()
                    if name.startswith('noise')
                ]

                # Generate random latents and interpolation t-values.
                # Rows 0::2 are the start points, rows 1::2 the end points.
                # Unless sampling == 'full', t is drawn from [0, 0], i.e. always 0.
                lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] +
                                           Gs_clone.input_shape[1:])
                lerp_t = tf.random_uniform(
                    [self.minibatch_per_gpu], 0.0,
                    1.0 if self.sampling == 'full' else 0.0)

                # Interpolate in W or Z.
                # stack(axis=1) + reshape interleaves the results so that rows 0::2
                # hold the t interpolants and rows 1::2 the t + epsilon ones.
                if self.space == 'w':
                    dlat_t01 = Gs_clone.components.mapping.get_output_for(
                        lat_t01, None, is_validation=True)
                    dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]
                    dlat_e0 = tflib.lerp(dlat_t0, dlat_t1,
                                         lerp_t[:, np.newaxis, np.newaxis])
                    dlat_e1 = tflib.lerp(
                        dlat_t0, dlat_t1,
                        lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
                    dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1),
                                          dlat_t01.shape)
                else:  # space == 'z'
                    lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]
                    lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])
                    lat_e1 = slerp(lat_t0, lat_t1,
                                   lerp_t[:, np.newaxis] + self.epsilon)
                    lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1),
                                         lat_t01.shape)
                    dlat_e01 = Gs_clone.components.mapping.get_output_for(
                        lat_e01, None, is_validation=True)

                # Synthesize images.
                # NOTE: the TF op creation order in this loop is left exactly
                # as-is; with a graph-level seed, op seeds depend on it.
                with tf.control_dependencies([
                        var.initializer for var in noise_vars
                ]):  # use same noise inputs for the entire minibatch
                    images = Gs_clone.components.synthesis.get_output_for(
                        dlat_e01, is_validation=True, randomize_noise=False)

                # Crop only the face region.
                # Keeps eighths 3..7 along axis 2 and 2..6 along axis 3
                # (presumably rows and columns of an NCHW batch - confirm).
                c = int(images.shape[2] // 8)
                images = images[:, :, c * 3:c * 7, c * 2:c * 6]

                # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
                # Average-pools by the integer factor images.shape[2] // 256 along
                # axes 2 and 3; assumes both sizes are divisible by that factor.
                if images.shape[2] > 256:
                    factor = images.shape[2] // 256
                    images = tf.reshape(images, [
                        -1, images.shape[1], images.shape[2] // factor, factor,
                        images.shape[3] // factor, factor
                    ])
                    images = tf.reduce_mean(images, axis=[3, 5])

                # Scale dynamic range from [-1,1] to [0,255] for VGG.
                images = (images + 1) * (255 / 2)

                # Evaluate perceptual distance.
                # Each distance is divided by epsilon**2.
                img_e0, img_e1 = images[0::2], images[1::2]
                distance_measure = misc.load_pkl(
                    'https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2'
                )  # vgg16_zhang_perceptual.pkl
                distance_expr.append(
                    distance_measure.get_output_for(img_e0, img_e1) *
                    (1 / self.epsilon**2))

        # Sampling loop.
        # NOTE(review): the loop rounds up to whole minibatches and the result is
        # not trimmed, so slightly more than self.num_samples distances may be used.
        all_distances = []
        for _ in range(0, self.num_samples, minibatch_size):
            all_distances += tflib.run(distance_expr)
        all_distances = np.concatenate(all_distances, axis=0)

        # Reject outliers.
        # Keep distances within the [1st, 99th] percentile range, inclusive.
        lo = np.percentile(all_distances, 1, interpolation='lower')
        hi = np.percentile(all_distances, 99, interpolation='higher')
        filtered_distances = np.extract(
            np.logical_and(lo <= all_distances, all_distances <= hi),
            all_distances)
        self._report_result(np.mean(filtered_distances))
Example #6
0
 def get_random_labels_np(self, minibatch_size):  # => labels
     """Return a NumPy batch of random labels for `minibatch_size` samples.

     The TF expression is built lazily by self.get_random_labels_tf(minibatch_size).
     Previously a single expression was cached regardless of `minibatch_size`,
     so a later call with a different size silently reused the op built for
     the first size. Ops are now cached per size; the one in use is also
     stored in self._tf_labels_np.
     """
     self.configure(minibatch_size)
     # Created lazily so instances whose __init__ predates this cache still work.
     ops_by_size = getattr(self, '_tf_labels_np_by_size', None)
     if ops_by_size is None:
         ops_by_size = self._tf_labels_np_by_size = {}
     labels_op = ops_by_size.get(minibatch_size)
     if labels_op is None:
         labels_op = self.get_random_labels_tf(minibatch_size)
         ops_by_size[minibatch_size] = labels_op
     self._tf_labels_np = labels_op
     return tflib.run(labels_op)
Example #7
0
 def get_minibatch_np(self, minibatch_size, lod=0):  # => images, labels
     """Fetch one minibatch of (images, labels) as NumPy arrays.

     First calls self.configure(minibatch_size, lod). It then evaluates the TF
     expression from self.get_minibatch_tf(). That expression is created on
     the first call only and kept in self._tf_minibatch_np for reuse.
     """
     self.configure(minibatch_size, lod)
     minibatch_op = self._tf_minibatch_np
     if minibatch_op is None:
         minibatch_op = self.get_minibatch_tf()
         self._tf_minibatch_np = minibatch_op
     return tflib.run(minibatch_op)