Esempio n. 1
0
 def grow(x, res, lod):
     """Recursively build the output image for resolution 2**res and above.

     Builds block(res, x), then returns the tensor chosen at run time by
     nested conditionals (created through cset) on lod_in:
       * lod > 0 and lod_in < lod: recurse into the next resolution;
       * lod_in > lod: interpolate (tflib.lerp, weight lod_in - lod) between
         ToRGB at this resolution and the upsampled ToRGB of the previous one;
       * otherwise: ToRGB at this resolution.
     Both non-recursive branches naive-upsample their result by 2**lod.
     """
     y = block(res, x)
     scale = 2**lod

     def current_only():
         return naive_upsample_2d(torgb(res, y), factor=scale)

     def blended_with_previous():
         # Same call order as before: this-res ToRGB, previous-res ToRGB, upsample, weight.
         mixed = tflib.lerp(torgb(res, y), upsample_2d(torgb(res - 1, x)), lod_in - lod)
         return naive_upsample_2d(mixed, factor=scale)

     # The condition tensors are built here, eagerly; cset (defined elsewhere)
     # presumably defers the actual tf.cond until img() is invoked below.
     img = cset(current_only, (lod_in > lod), blended_with_previous)
     if lod > 0:
         def go_deeper():
             return grow(y, res + 1, lod - 1)
         img = cset(img, (lod_in < lod), go_deeper)
     return img()
Esempio n. 2
0
 def upsample(y):
     """Upsample `y` inside an 'Upsample' variable scope, filtering with `resample_kernel`."""
     with tf.variable_scope('Upsample'):
         upsampled = upsample_2d(y, k=resample_kernel)
     return upsampled
def G_synthesis_stylegan_revised(
    dlatents_in,                        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
    dlatent_size        = 512,          # Disentangled latent (W) dimensionality.
    num_channels        = 3,            # Number of output color channels.
    resolution          = 1024,         # Output resolution.
    fmap_base           = 16 << 10,     # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_min            = 1,            # Minimum number of feature maps in any layer.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    randomize_noise     = True,         # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    resample_kernel     = [1,3,3,1],    # Low-pass filter to apply when resampling activations. None = no filtering.
    fused_modconv       = True,         # Implement modulated_conv2d_layer() as a single fused op?
    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    force_clean_graph   = False,        # True = construct a clean graph that looks nice in TensorBoard, False = default behavior.
    **_kwargs):                         # Ignore unrecognized keyword args.
    """Build the synthesis network: per-layer dlatents -> image, via modulated convolutions.

    A learned 4x4 constant is refined by one conv, then each further
    resolution adds an upsampling conv and a plain conv.
    Progressive growing is driven by the non-trainable 'lod' variable.
    It is only honoured by the 'linear' and 'recursive' structures; 'fixed'
    always renders at full resolution.

    Returns:
        A tensor named 'images_out' with dtype `dtype`, produced by a ToRGB
        layer with `num_channels` output feature maps.
        NOTE(review): presumably laid out as
        [minibatch, num_channels, resolution, resolution] (NCHW, matching the
        const input and noise tensors) — confirm.

    Raises:
        AssertionError: if `resolution` is not a power of two >= 4 (assert;
            skipped under `python -O`).
        ValueError: if `structure` is not 'auto', 'fixed', 'linear' or
            'recursive'.
    """

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    # Feature maps at a stage: fmap_base / 2**(stage * fmap_decay), clipped to [fmap_min, fmap_max].
    def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
    if is_template_graph: force_clean_graph = True
    if force_clean_graph: randomize_noise = False
    if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive'
    # Validate before creating any graph ops. An unknown value would otherwise
    # skip every structure branch below and leave images_out as None. The only
    # symptom would then be an opaque AttributeError at the final dtype check.
    if structure not in ('fixed', 'linear', 'recursive'):
        raise ValueError("Unknown structure %r; expected 'auto', 'fixed', 'linear' or 'recursive'" % (structure,))
    act = nonlinearity
    # One dlatent per conv layer (1 at 4x4, 2 per higher resolution) plus one for
    # the final ToRGB. Intermediate ToRGB layers reuse index res*2-3, which is also
    # the next resolution's Conv0_up index.
    num_layers = resolution_log2 * 2 - 2
    images_out = None

    # Primary inputs.
    dlatents_in.set_shape([None, num_layers, dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)

    # Noise inputs (one per conv layer, hence num_layers - 1).
    noise_inputs = []
    for layer_idx in range(num_layers - 1):
        # Inverts the conv indexing: layer 0 -> res 2 (4x4); res*2-5 and res*2-4 -> res.
        res = (layer_idx + 5) // 2
        shape = [1, 1, 2**res, 2**res]
        noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape, initializer=tf.initializers.random_normal(), trainable=False))

    # Single convolution layer with all the bells and whistles.
    def layer(x, layer_idx, fmaps, kernel, up=False):
        x = modulated_conv2d_layer(x, dlatents_in[:, layer_idx], fmaps=fmaps, kernel=kernel, up=up, resample_kernel=resample_kernel, fused_modconv=fused_modconv)
        if randomize_noise:
            noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
        else:
            noise = tf.cast(noise_inputs[layer_idx], x.dtype)
        # Per-layer scalar noise scale, zero-initialised.
        noise_strength = tf.get_variable('noise_strength', shape=[], initializer=tf.initializers.zeros())
        x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act)

    # Early layers.
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const', shape=[1, nf(1), 4, 4], initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
        with tf.variable_scope('Conv'):
            x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)

    # Building blocks for remaining layers.
    def block(res, x): # res = 3..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            with tf.variable_scope('Conv0_up'):
                x = layer(x, layer_idx=res*2-5, fmaps=nf(res-1), kernel=3, up=True)
            with tf.variable_scope('Conv1'):
                x = layer(x, layer_idx=res*2-4, fmaps=nf(res-1), kernel=3)
            return x
    def torgb(res, x): # res = 2..resolution_log2
        with tf.variable_scope('ToRGB_lod%d' % (resolution_log2 - res)):
            return apply_bias_act(modulated_conv2d_layer(x, dlatents_in[:, res*2-3], fmaps=num_channels, kernel=1, demodulate=False, fused_modconv=fused_modconv))

    # Fixed structure: simple and efficient, but does not support progressive growing.
    if structure == 'fixed':
        for res in range(3, resolution_log2 + 1):
            x = block(res, x)
        images_out = torgb(resolution_log2, x)

    # Linear structure: simple but inefficient.
    if structure == 'linear':
        images_out = torgb(2, x)
        for res in range(3, resolution_log2 + 1):
            lod = resolution_log2 - res
            x = block(res, x)
            img = torgb(res, x)
            with tf.variable_scope('Upsample_lod%d' % lod):
                images_out = upsample_2d(images_out)
            with tf.variable_scope('Grow_lod%d' % lod):
                images_out = tflib.lerp_clip(img, images_out, lod_in - lod)

    # Recursive structure: complex but efficient.
    if structure == 'recursive':
        # Wrap a pending image thunk in a lazily-built tf.cond: new_lambda if new_cond else cur_lambda.
        def cset(cur_lambda, new_cond, new_lambda):
            return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
        def grow(x, res, lod):
            y = block(res, x)
            img = lambda: naive_upsample_2d(torgb(res, y), factor=2**lod)
            img = cset(img, (lod_in > lod), lambda: naive_upsample_2d(tflib.lerp(torgb(res, y), upsample_2d(torgb(res - 1, x)), lod_in - lod), factor=2**lod))
            if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1))
            return img()
        images_out = grow(x, 3, resolution_log2 - 3)

    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
Esempio n. 4
0
    def block(x, res):  # res = 2..resolution_log2
        """Downsampling block with an optional label-conditioned attention residual.

        Applies a 3x3 conv (nf(res - 1) maps), then a 3x3 conv with
        down=True (nf(res - 2) maps). For 3 < res < 8 it adds an attention
        residual conditioned on `dlabel`, scaled by a zero-initialised scalar
        `gamma`. When `architecture == 'resnet'`, it also adds a 1x1 skip
        path from the block input.

        Reads from the enclosing scope (not visible here): nf, act,
        resample_kernel, dlabel, cutoff_layer, architecture.

        Args:
            x: Input activations. NOTE(review): shape[1] is used as channels
                and shape[2]/shape[3] as height/width, which suggests NCHW —
                confirm.
            res: Resolution stage on a log2 scale (see the comment on the def line).

        Returns:
            The block output tensor. Exception: when the attention branch runs
            and res == cutoff_layer, returns the 4-tuple
            (v_x_upsample, gamma, x, attention_map_h_multiply_reshape)
            instead, before the resnet skip is applied.
        """
        # Keep the block input for the optional resnet skip path below.
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        # Attention is only added at these intermediate stages.
        if 3 < res < 8:
            # Attention runs on a downsampled copy of x; the attention map below
            # is (height*width) x (height*width) of this downsampled tensor.
            with tf.variable_scope('Downsample'):
                x_downsample = downsample_2d(x)
            height = x_downsample.shape[2]
            width = x_downsample.shape[3]
            c_reduced = 1  # channels of the query/key/value projections
            label_mapping_fmaps = 16  # width of each label embedding

            # Query path: 1x1 conv to one channel, flattened to [-1, height*width].
            with tf.variable_scope('F_Attention'):
                f_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
                f_x = tf.reshape(f_x, [-1, height * width])
            # Label embedding (dense to 16 features, bias/act, then +1).
            with tf.variable_scope('Label_F_Attention'):
                label_f = dense_layer(dlabel, fmaps=label_mapping_fmaps)
                label_f = apply_bias_act(label_f) + 1
            # Mix spatial features with the label; result is [-1, height*width, c_reduced].
            with tf.variable_scope('F_concat_Attention'):
                f_x_s = dense_layer(tf.concat([f_x, label_f], axis=-1),
                                    fmaps=c_reduced * height * width)
                f_x_s = tf.reshape(f_x_s, [-1, c_reduced, height * width])
                f_x_s = tf.transpose(f_x_s, perm=[0, 2, 1])

            # Key path: same construction as the query, but left as [-1, c_reduced, height*width].
            with tf.variable_scope('G_Attention'):
                g_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
                g_x = tf.reshape(g_x, [-1, height * width])
            with tf.variable_scope('Label_G_Attention'):
                label_g = dense_layer(dlabel, fmaps=label_mapping_fmaps)
                label_g = apply_bias_act(label_g) + 1
            with tf.variable_scope('G_concat_Attention'):
                g_x_s = dense_layer(tf.concat([g_x, label_g], axis=-1),
                                    fmaps=c_reduced * height * width)
                g_x_s = tf.reshape(g_x_s, [-1, c_reduced, height * width])

            # Value path: 1x1 conv, flattened to [-1, c_reduced, height*width].
            with tf.variable_scope('H_Attention'):
                h_x = conv2d_layer(x_downsample, fmaps=c_reduced, kernel=1)
                h_x = tf.reshape(h_x, [-1, c_reduced, height * width])

            # [-1, HW, c_reduced] @ [-1, c_reduced, HW] -> [-1, HW, HW]; softmax over the last axis.
            f_g_multiply = tf.matmul(f_x_s, g_x_s)
            attention_map = tf.nn.softmax(f_g_multiply, axis=-1)
            # [-1, c_reduced, HW] @ [-1, HW, HW] -> [-1, c_reduced, HW], then back to a spatial map.
            attention_map_h_multiply = tf.matmul(
                h_x, tf.transpose(attention_map, [0, 2, 1]))
            attention_map_h_multiply_reshape = tf.reshape(
                attention_map_h_multiply, [-1, c_reduced, height, width])
            # Project back to x_downsample's channel count.
            with tf.variable_scope('V_Attention'):
                v_x = conv2d_layer(attention_map_h_multiply_reshape,
                                   fmaps=x_downsample.shape[1],
                                   kernel=1)
            # Expected to restore x's spatial size so the residual add below lines up.
            with tf.variable_scope('Upsample'):
                v_x_upsample = upsample_2d(v_x)
            # Scalar gate initialised to zero, so the residual starts out as a no-op.
            # NOTE(review): gamma uses the default variable dtype and is not cast
            # to x's dtype; confirm x is float32 here.
            with tf.variable_scope('Gamma_Attention'):
                gamma = tf.get_variable(shape=[],
                                        initializer=tf.initializers.zeros(),
                                        name='attention_gamma')
            x = x + v_x_upsample * gamma

            # Debug/inspection exit: returns a tuple instead of a tensor and skips the resnet path.
            if res == cutoff_layer:
                return v_x_upsample, gamma, x, attention_map_h_multiply_reshape

        # Resnet skip: 1x1 conv with down=True on the block input; the sum is scaled by 1/sqrt(2).
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x