Example #1
0
def process_reals(x, labels, lod, mirror_augment, drange_data, drange_net):
    """Prepare a minibatch of real images before feeding it to the networks.

    The steps run in this order:

    1. Cast ``x`` to float32 and remap its dynamic range from
       ``drange_data`` to ``drange_net``.
    2. If ``mirror_augment`` is truthy, reverse axis 3 for a random subset of
       the samples (one uniform draw per sample, kept as-is when < 0.5).
    3. Crossfade towards a copy in which every 2x2 block over axes 2 and 3
       is replaced by its mean, weighted by the fractional part of ``lod``.
    4. Replicate every element ``2**floor(lod)`` times along axes 2 and 3.

    ``x`` is handled as a rank-4 tensor throughout (presumably NCHW --
    confirm against the caller).

    Returns:
        The tuple ``(x, labels)``; ``labels`` is passed through untouched.
    """
    with tf.name_scope('DynamicRange'):
        as_float = tf.cast(x, tf.float32)
        x = misc.adjust_dynamic_range(as_float, drange_data, drange_net)

    if mirror_augment:
        with tf.name_scope('MirrorAugment'):
            # One boolean per sample: True keeps the image, False flips it.
            keep_original = tf.random_uniform([tf.shape(x)[0]]) < 0.5
            flipped = tf.reverse(x, [3])
            x = tf.where(keep_original, x, flipped)

    # Smooth crossfade between consecutive levels-of-detail.
    with tf.name_scope('FadeLOD'):
        dims = tf.shape(x)
        # Expose each 2x2 neighbourhood as its own pair of axes (3 and 5).
        blocks = tf.reshape(
            x, [-1, dims[1], dims[2] // 2, 2, dims[3] // 2, 2])
        blocks = tf.reduce_mean(blocks, axis=[3, 5], keepdims=True)
        blocks = tf.tile(blocks, [1, 1, 1, 2, 1, 2])
        coarse = tf.reshape(blocks, [-1, dims[1], dims[2], dims[3]])
        fade = lod - tf.floor(lod)
        x = tflib.lerp(x, coarse, fade)

    # Upscale to match the expected input/output size of the networks.
    with tf.name_scope('UpscaleLOD'):
        dims = tf.shape(x)
        repeat = tf.cast(2**tf.floor(lod), tf.int32)
        # Insert singleton axes after axes 2 and 3, tile them, fold back.
        expanded = tf.reshape(x, [-1, dims[1], dims[2], 1, dims[3], 1])
        expanded = tf.tile(expanded, [1, 1, 1, repeat, 1, repeat])
        x = tf.reshape(
            expanded, [-1, dims[1], dims[2] * repeat, dims[3] * repeat])

    return x, labels
Example #2
0
 def grow(x, res, lod):
     """Recursively build the image output for resolution level ``res``.

     Runs ``block`` for this level and then chains conditional branches
     through ``cset`` (presumably ``cset(cur, cond, new)`` yields a callable
     that picks ``new`` when ``cond`` holds -- confirm against its
     definition in the enclosing scope):

     * default: this level's RGB output upscaled by ``2**lod``;
     * when ``lod_in > lod``: a lerp between this level's RGB and the
       upscaled RGB of level ``res - 1``, weighted by ``lod_in - lod``,
       then upscaled by ``2**lod``;
     * when ``lod_in < lod`` (only wired in if ``lod > 0``): the result of
       recursing into level ``res + 1`` with ``lod - 1``.
     """
     features = block(res, x)

     def render_this_level():
         return upscale2d(torgb(res, features), 2**lod)

     def render_crossfade():
         current_rgb = torgb(res, features)
         previous_rgb = upscale2d(torgb(res - 1, x))
         mixed = tflib.lerp(current_rgb, previous_rgb, lod_in - lod)
         return upscale2d(mixed, 2**lod)

     def render_finer_level():
         return grow(features, res + 1, lod - 1)

     output = cset(render_this_level, (lod_in > lod), render_crossfade)
     if lod > 0:
         output = cset(output, (lod_in < lod), render_finer_level)
     return output()
Example #3
0
 def grow(res, lod):
     """Recursively build the feature tensor for resolution level ``res``.

     The input to this level's ``block`` is ``fromrgb`` of ``images_in``
     downscaled by ``2**lod``; if ``lod > 0`` a ``cset`` branch on
     ``lod_in < lod`` substitutes the result of recursing into level
     ``res + 1``.  For ``res > 2`` the block output is additionally wrapped
     in a ``cset`` branch on ``lod_in > lod`` that lerps it with
     ``fromrgb`` of the image downscaled by ``2**(lod + 1)``, weighted by
     ``lod_in - lod``.  (``cset`` presumably selects its third argument
     when the condition holds -- confirm against its definition.)
     """
     def read_image():
         return fromrgb(downscale2d(images_in, 2**lod), res)

     def recurse_finer():
         return grow(res + 1, lod - 1)

     source = read_image
     if lod > 0:
         source = cset(source, (lod_in < lod), recurse_finer)
     features = block(source(), res)

     # No blend branch is wired in at res <= 2; return the block output.
     if res <= 2:
         return features

     def crossfade():
         coarser = fromrgb(downscale2d(images_in, 2**(lod + 1)), res - 1)
         return tflib.lerp(features, coarser, lod_in - lod)

     return cset(lambda: features, (lod_in > lod), crossfade)()
Example #4
0
File: loss.py Project: johndpope/BA
def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels,
              wgan_lambda=10.0, wgan_epsilon=0.001, wgan_target=1.0):
    """Build the discriminator loss and gradient-penalty regulariser.

    Loss term (per sample)::

        D(fake) - D(real) + wgan_epsilon * D(real)**2

    Regulariser (per sample), evaluated on images obtained by
    ``tflib.lerp`` between the reals and the fakes with a uniform random
    factor per sample::

        (wgan_lambda / wgan_target**2) * (||dD/dx|| - wgan_target)**2

    where the norm is taken over axes 1, 2 and 3 of the gradient.

    Args:
        G, D: Generator / discriminator objects exposing ``get_output_for``;
            ``G.input_shapes[0][1:]`` gives the latent shape.
        opt, training_set: Accepted for signature compatibility; unused.
        minibatch_size: Number of latents / mixing factors to draw.
        reals, labels: Real images and their labels.
        wgan_lambda, wgan_epsilon, wgan_target: Weights as in the formulas
            above.

    Returns:
        The tuple ``(loss, reg)``.
    """
    del opt, training_set  # Unused.

    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fakes = G.get_output_for(z, labels, is_training=True)

    score_real = D.get_output_for(reals, labels, is_training=True)
    score_fake = D.get_output_for(fakes, labels, is_training=True)
    score_real = autosummary('Loss/scores/real', score_real)
    score_fake = autosummary('Loss/scores/fake', score_fake)
    loss = score_fake - score_real

    with tf.name_scope('EpsilonPenalty'):
        real_sq = tf.square(score_real)
        eps_term = autosummary('Loss/epsilon_penalty', real_sq)
    loss = loss + eps_term * wgan_epsilon

    with tf.name_scope('GradientPenalty'):
        alpha = tf.random_uniform(
            [minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fakes.dtype)
        # Cast the reals so both lerp endpoints share the fakes' dtype.
        blended = tflib.lerp(tf.cast(reals, fakes.dtype), fakes, alpha)
        score_mix = D.get_output_for(blended, labels, is_training=True)
        score_mix = autosummary('Loss/scores/mixed', score_mix)
        (grads,) = tf.gradients(tf.reduce_sum(score_mix), [blended])
        grad_norms = tf.sqrt(
            tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        grad_norms = autosummary('Loss/mixed_norms', grad_norms)
        penalty = tf.square(grad_norms - wgan_target)
        reg = penalty * (wgan_lambda / (wgan_target**2))

    return loss, reg
Example #5
0
 def grow(res, lod):
     """Recursively build the feature tensor for resolution level ``res``.

     The input to ``block`` is ``fromrgb`` of ``images_in`` downsampled by
     ``2**lod``; when ``lod > 0`` a ``cset`` branch on ``lod_in < lod``
     swaps in the result of recursing into level ``res + 1``.  The block
     output is then wrapped in a ``cset`` branch on ``lod_in > lod`` that
     lerps it with ``fromrgb`` of the image downsampled by
     ``2**(lod + 1)``, weighted by ``lod_in - lod``.  (``cset`` presumably
     selects its third argument when the condition holds -- confirm
     against its definition.)

     NOTE(review): unlike some sibling variants there is no ``res > 2``
     guard on the blend branch; presumably the recursion never starts
     below level 3 here -- verify against the caller.
     """
     def read_image():
         return fromrgb(naive_downsample_2d(images_in, factor=2**lod), res)

     def recurse_finer():
         return grow(res + 1, lod - 1)

     producer = read_image
     if lod > 0:
         producer = cset(producer, (lod_in < lod), recurse_finer)
     features = block(producer(), res)

     def crossfade():
         coarser = fromrgb(
             naive_downsample_2d(images_in, factor=2**(lod + 1)), res - 1)
         return tflib.lerp(features, coarser, lod_in - lod)

     chooser = cset(lambda: features, (lod_in > lod), crossfade)
     return chooser()
Example #6
0
 def grow(x, res, lod):
     """Recursively build the image output for resolution level ``res``.

     Runs ``block`` for this level, then chains ``cset`` branches
     (presumably ``cset(cur, cond, new)`` picks ``new`` when ``cond``
     holds -- confirm against its definition):

     * default: ``torgb`` of this level, upsampled with
       ``naive_upsample_2d`` by ``2**lod``;
     * when ``lod_in > lod``: a lerp between this level's RGB and
       ``upsample_2d`` of the RGB of level ``res - 1``, weighted by
       ``lod_in - lod``, then ``naive_upsample_2d`` by ``2**lod``;
     * when ``lod_in < lod`` (only wired in if ``lod > 0``): the result of
       recursing into level ``res + 1`` with ``lod - 1``.
     """
     features = block(res, x)

     def plain_output():
         return naive_upsample_2d(torgb(res, features), factor=2**lod)

     def blended_output():
         rgb_here = torgb(res, features)
         rgb_below = upsample_2d(torgb(res - 1, x))
         mixed = tflib.lerp(rgb_here, rgb_below, lod_in - lod)
         return naive_upsample_2d(mixed, factor=2**lod)

     def finer_output():
         return grow(features, res + 1, lod - 1)

     pick = cset(plain_output, (lod_in > lod), blended_output)
     if lod > 0:
         pick = cset(pick, (lod_in < lod), finer_output)
     return pick()
Example #7
0
def G_main(
    latents_in,                                         # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                                          # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi          = 0.5,                      # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff       = None,                     # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val      = None,                     # Value for truncation_psi to use during validation.
    truncation_cutoff_val   = None,                     # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta        = 0.995,                    # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob       = 0.9,                      # Probability of mixing styles during training. None = disable.
    is_training             = False,                    # Network is under training? Enables and disables specific features.
    is_validation           = False,                    # Network is under validation? Chooses which value to use for truncation_psi.
    return_dlatents         = False,                    # Return dlatents in addition to the images?
    is_template_graph       = False,                    # True = template graph constructed by the Network class, False = actual evaluation.
    components              = EasyDict(),        # Container for sub-networks. Retained between calls.
    mapping_func            = 'G_mapping',              # Build func name for the mapping network.
    synthesis_func          = 'G_synthesis_stylegan2',  # Build func name for the synthesis network.
    **kwargs):                                          # Arguments for sub-networks (mapping and synthesis).
    """Compose the mapping and synthesis sub-networks into one generator.

    The function creates (or reuses, if already present in `components`) a
    'synthesis' and a 'mapping' `tflib.Network`, maps `latents_in` to
    dlatents, optionally updates the `dlatent_avg` variable, optionally
    performs style mixing, optionally applies the truncation trick, and
    finally evaluates the synthesis network on the resulting dlatents.

    Which optional stages run is decided up front from `is_training`,
    `is_validation` and the corresponding arguments:
      * truncation: skipped while training, or when `truncation_psi` is
        None or the Python value 1;
      * moving average of W: only while training, skipped when
        `dlatent_avg_beta` is None or the Python value 1;
      * style mixing: only while training, skipped when
        `style_mixing_prob` is None or a Python value <= 0.

    Returns:
        `images_out` (a tensor named 'images_out'), or the tuple
        `(images_out, dlatents)` when `return_dlatents` is true.

    NOTE(review): the default `EasyDict()` for `components` is created once
    at definition time and is shared by every call that omits the argument;
    presumably callers always pass their own container -- confirm.
    """

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, EasyDict)
    # Validation uses its own truncation settings.
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    # The `== 1` / `<= 0` shortcuts below are only taken for plain Python
    # values; TF expressions cannot be compared at graph-construction time.
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    # Build functions are looked up by name in this module's globals.
    # Synthesis is created first because its input shape determines the
    # number of layers the mapping network has to broadcast to.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            # Only layer 0 of the dlatents feeds the average.
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            # Assuming tflib.lerp(a, b, t) = a + (b - a) * t (TODO confirm),
            # this moves dlatent_avg towards batch_avg by (1 - beta).
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            # Tie the update to the dlatents so it runs whenever they are used.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            # Second, independent set of dlatents to mix in.
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            # Two layers are excluded from mixing per integer step of lod.
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # With probability style_mixing_prob draw a random cutoff layer;
            # otherwise the cutoff is cur_layers.
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Layers below the cutoff keep the first dlatents, the rest use the second.
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
            if truncation_cutoff is None:
                # No cutoff: every layer gets the same psi.
                layer_psi *= truncation_psi
            else:
                # Only layers below the cutoff are scaled; the others keep psi = 1.
                layer_psi = tf.where(layer_idx < truncation_cutoff, layer_psi * truncation_psi, layer_psi)
            dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)

    # Evaluate synthesis network.
    deps = []
    # If the synthesis network has its own 'lod' variable, copy ours into it
    # before the network is evaluated.
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(dlatents, is_training=is_training, force_clean_graph=is_template_graph, **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
Example #8
0
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        """Build the per-GPU distance graph, sample it and report the mean.

        For every GPU a clone of ``Gs`` generates image pairs from two
        interpolation points that are ``self.epsilon`` apart (interpolating
        either the mapping output when ``self.space == 'w'`` or the latents
        via ``slerp`` otherwise).  The distance-network output for each pair
        is scaled by ``1 / self.epsilon**2``.  The graph is then run until
        at least ``self.num_samples`` distances have been requested, values
        outside the 1st..99th percentile range are dropped, and the mean of
        the rest is handed to ``self._report_result``.

        Args:
            Gs: Generator network; must support ``clone()``.
            Gs_kwargs: Keyword arguments for the generator's sub-networks;
                copied and then updated with ``self.Gs_overrides``.
            num_gpus: Number of ``/gpu:N`` devices to build the graph on.
        """
        # Work on a private copy so the caller's mapping is left untouched.
        gen_kwargs = dict(Gs_kwargs)
        gen_kwargs.update(self.Gs_overrides)
        per_gpu = self.minibatch_per_gpu
        minibatch_size = num_gpus * per_gpu

        # Construct TensorFlow graph.
        distance_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                generator = Gs.clone()
                synthesis = generator.components.synthesis
                mapping = generator.components.mapping
                noise_vars = [
                    var for name, var in synthesis.vars.items()
                    if name.startswith('noise')
                ]

                # Generate random latents and interpolation t-values.
                # Rows 2k and 2k+1 of the latents form one endpoint pair.
                pair_rows = per_gpu * 2
                lat_t01 = tf.random_normal(
                    [pair_rows] + generator.input_shape[1:])
                t_upper = 1.0 if self.sampling == 'full' else 0.0
                lerp_t = tf.random_uniform([per_gpu], 0.0, t_upper)
                labels = self._get_random_labels_tf(per_gpu)
                labels = tf.tile(labels, [1, 2])
                labels = tf.reshape(labels, [pair_rows, -1])

                # Interpolate in W or Z.
                if self.space == 'w':
                    dlat_t01 = mapping.get_output_for(
                        lat_t01, labels, **gen_kwargs)
                    dlat_t01 = tf.cast(dlat_t01, tf.float32)
                    dlat_start = dlat_t01[0::2]
                    dlat_end = dlat_t01[1::2]
                    dlat_at_t = tflib.lerp(
                        dlat_start, dlat_end,
                        lerp_t[:, np.newaxis, np.newaxis])
                    dlat_at_t_eps = tflib.lerp(
                        dlat_start, dlat_end,
                        lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
                    # Interleave so rows 2k / 2k+1 are the (t, t+eps) pair.
                    stacked = tf.stack([dlat_at_t, dlat_at_t_eps], axis=1)
                    dlat_e01 = tf.reshape(stacked, dlat_t01.shape)
                else:  # space == 'z'
                    lat_start = lat_t01[0::2]
                    lat_end = lat_t01[1::2]
                    lat_at_t = slerp(
                        lat_start, lat_end, lerp_t[:, np.newaxis])
                    lat_at_t_eps = slerp(
                        lat_start, lat_end,
                        lerp_t[:, np.newaxis] + self.epsilon)
                    stacked = tf.stack([lat_at_t, lat_at_t_eps], axis=1)
                    lat_e01 = tf.reshape(stacked, lat_t01.shape)
                    dlat_e01 = mapping.get_output_for(
                        lat_e01, labels, **gen_kwargs)

                # Synthesize images.
                # Use same noise inputs for the entire minibatch.
                noise_inits = [var.initializer for var in noise_vars]
                with tf.control_dependencies(noise_inits):
                    images = synthesis.get_output_for(
                        dlat_e01, randomize_noise=False, **gen_kwargs)
                    images = tf.cast(images, tf.float32)

                # Crop only the face region.
                if self.crop:
                    unit = int(images.shape[2] // 8)
                    images = images[:, :,
                                    unit * 3:unit * 7,
                                    unit * 2:unit * 6]

                # Downsample image to 256x256 if it's larger than that.
                # VGG was built for 224x224 images.
                factor = images.shape[2] // 256
                if factor > 1:
                    pooled_shape = [
                        -1, images.shape[1],
                        images.shape[2] // factor, factor,
                        images.shape[3] // factor, factor,
                    ]
                    images = tf.reshape(images, pooled_shape)
                    images = tf.reduce_mean(images, axis=[3, 5])

                # Scale dynamic range from [-1,1] to [0,255] for VGG.
                images = (images + 1) * (255 / 2)

                # Evaluate perceptual distance.
                img_e0 = images[0::2]
                img_e1 = images[1::2]
                distance_measure = misc.load_pkl(
                    'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl'
                )
                raw_distance = distance_measure.get_output_for(img_e0, img_e1)
                distance_expr.append(raw_distance * (1 / self.epsilon**2))

        # Sampling loop.
        collected = []
        for begin in range(0, self.num_samples, minibatch_size):
            self._report_progress(begin, self.num_samples)
            collected.extend(tflib.run(distance_expr))
        all_distances = np.concatenate(collected, axis=0)

        # Reject outliers.
        lo = np.percentile(all_distances, 1, interpolation='lower')
        hi = np.percentile(all_distances, 99, interpolation='higher')
        in_range = np.logical_and(lo <= all_distances, all_distances <= hi)
        filtered_distances = np.extract(in_range, all_distances)
        self._report_result(np.mean(filtered_distances))