def D_stylegan(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 3,            # Number of input color channels. Overridden based on dataset.
    resolution          = 1024,         # Input resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 16 << 10,     # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_min            = 1,            # Minimum number of feature maps in any layer.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    resample_kernel     = [1,3,3,1],    # Low-pass filter to apply when resampling activations. None = no filtering.
    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    **_kwargs):                         # Ignore unrecognized keyword args.
    """Build the StyleGAN discriminator graph and return its output scores.

    The images go through a FromRGB layer and a chain of per-resolution
    blocks (two 3x3 convolutions, the second one downsampling) until they
    reach 4x4, followed by an optional minibatch-stddev layer, one more
    convolution, a dense layer and the output layer. When labels are present
    the output layer produces one value per label dimension and the result is
    the sum of those values weighted by `labels_in`.

    How the blocks are wired together depends on `structure`:
      'fixed'     - plain chain starting at full resolution.
      'linear'    - after every block, blend (lerp_clip) with a FromRGB of
                    the downsampled image, driven by the 'lod' variable.
      'recursive' - same idea expressed as nested tf.cond calls.
      'auto'      - 'linear' if `is_template_graph` else 'recursive'.

    See the inline comments on the signature for the individual arguments.

    Returns:
        Tensor of dtype `dtype` named 'scores_out' (presumably shaped
        [minibatch, 1] - depends on `dense_layer`, which is defined elsewhere).

    Raises:
        AssertionError: if `resolution` is not a power of two >= 4, or the
            output dtype does not match `dtype`.
        ValueError: if `structure` is not one of 'fixed', 'linear',
            'recursive' or 'auto'.
    """

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
    if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive'
    # Fail fast, before any graph ops or variables are created. Without this
    # check an unknown value skips every structure branch below and surfaces
    # much later as an unrelated "x referenced before assignment" error.
    if structure not in ('fixed', 'linear', 'recursive'):
        raise ValueError("Unknown structure %r; expected 'fixed', 'linear', 'recursive' or 'auto'." % (structure,))
    act = nonlinearity

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)
    # Non-trainable 'lod' variable (initially 0.0); the 'linear' and
    # 'recursive' structures compare against it and use it as blend weight.
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)

    # Building blocks for spatial layers.
    def fromrgb(x, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
            return apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=1), act=act)
    def block(x, res): # res = 2..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            with tf.variable_scope('Conv0'):
                x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=3), act=act)
            with tf.variable_scope('Conv1_down'):
                x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
            return x

    # Fixed structure: simple and efficient, but does not support progressive growing.
    if structure == 'fixed':
        x = fromrgb(images_in, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            x = block(x, res)

    # Linear structure: simple but inefficient.
    elif structure == 'linear':
        img = images_in
        x = fromrgb(img, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            lod = resolution_log2 - res
            x = block(x, res)
            with tf.variable_scope('Downsample_lod%d' % lod):
                img = downsample_2d(img)
            y = fromrgb(img, res - 1)
            with tf.variable_scope('Grow_lod%d' % lod):
                x = tflib.lerp_clip(x, y, lod_in - lod)

    # Recursive structure: complex but efficient.
    elif structure == 'recursive':
        if resolution_log2 == 2:
            # grow() always starts by applying the 8x8 block, which does not
            # exist for a 4x4 input (it would be called with lod = -1 and a
            # downsampling factor of 2**-1). The other structures reduce to a
            # bare FromRGB in this case, so do the same here.
            x = fromrgb(images_in, resolution_log2)
        else:
            def cset(cur_lambda, new_cond, new_lambda):
                # Wrap `cur_lambda` so that `new_lambda` is taken when `new_cond` holds.
                return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
            def grow(res, lod):
                # Input for this block: FromRGB of the image downsampled by 2**lod,
                # or (when lod_in < lod) the output of the next-higher resolution.
                x = lambda: fromrgb(naive_downsample_2d(images_in, factor=2**lod), res)
                if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
                x = block(x(), res)
                y = lambda: x
                # When lod_in > lod, blend the block output with a FromRGB of the
                # image at the next-lower resolution.
                y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(naive_downsample_2d(images_in, factor=2**(lod+1)), res - 1), lod_in - lod))
                return y()
            x = grow(3, resolution_log2 - 3)

    # Final layers at 4x4 resolution.
    with tf.variable_scope('4x4'):
        # Note: the layer is skipped for any group size <= 1, not only for 0.
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        # One output per label dimension (at least one); with labels present,
        # the outputs are weighted by the labels and summed into one column.
        x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out
示例#2
0
 def downsample(y):
     """Run `downsample_2d` on `y` inside a 'Downsample' variable scope.

     The filter passed as `k` is `resample_kernel`, taken from the enclosing
     scope (not visible in this snippet).
     """
     with tf.variable_scope('Downsample'):
         result = downsample_2d(y, k=resample_kernel)
     return result
示例#3
0
    def block(x, res):  # res = 2..resolution_log2
        """Apply one resolution block: two 3x3 convs, optional attention, optional skip.

        Steps, in order:
          1. 'Conv0' and 'Conv1_down' (the latter with down=True).
          2. For 3 < res < 8 only: an attention branch computed on a further
             downsampled copy of the activations. The F and G projections are
             each concatenated with a dense mapping of `dlabel` before their
             final dense layer. The branch output is upsampled, scaled by the
             zero-initialised scalar 'attention_gamma' and added to `x`.
          3. If `architecture == 'resnet'`: a 1x1 downsampling conv of the
             block input is added and the sum is scaled by 1/sqrt(2).

        `nf`, `act`, `resample_kernel`, `dlabel`, `cutoff_layer` and
        `architecture` come from the enclosing scope (not visible here).

        Returns:
            The block output tensor. Exception: when the attention branch ran
            and `res == cutoff_layer`, returns the 4-tuple
            (upsampled attention output, gamma, x after the attention add,
            reshaped attention-weighted H projection) and skips step 3.
        """
        block_input = x  # Kept for the 'resnet' skip connection.

        with tf.variable_scope('Conv0'):
            conv0_out = conv2d_layer(x, fmaps=nf(res - 1), kernel=3)
            x = apply_bias_act(conv0_out, act=act)
        with tf.variable_scope('Conv1_down'):
            conv1_out = conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True,
                                     resample_kernel=resample_kernel)
            x = apply_bias_act(conv1_out, act=act)

        attention_enabled = 3 < res < 8
        if attention_enabled:
            with tf.variable_scope('Downsample'):
                pooled = downsample_2d(x)
            # Assumes NCHW layout (axes 2 and 3 are spatial) - TODO confirm.
            height, width = pooled.shape[2], pooled.shape[3]
            num_positions = height * width
            c_reduced = 1
            label_mapping_fmaps = 16

            # F projection, conditioned on the label.
            with tf.variable_scope('F_Attention'):
                f_flat = conv2d_layer(pooled, fmaps=1, kernel=1)
                f_flat = tf.reshape(f_flat, [-1, num_positions])
            with tf.variable_scope('Label_F_Attention'):
                f_label = apply_bias_act(
                    dense_layer(dlabel, fmaps=label_mapping_fmaps)) + 1
            with tf.variable_scope('F_concat_Attention'):
                f_joint = tf.concat([f_flat, f_label], axis=-1)
                f_proj = dense_layer(f_joint, fmaps=c_reduced * num_positions)
                f_proj = tf.reshape(f_proj, [-1, c_reduced, num_positions])
                f_proj = tf.transpose(f_proj, perm=[0, 2, 1])

            # G projection, conditioned on the label (kept untransposed).
            with tf.variable_scope('G_Attention'):
                g_flat = conv2d_layer(pooled, fmaps=1, kernel=1)
                g_flat = tf.reshape(g_flat, [-1, num_positions])
            with tf.variable_scope('Label_G_Attention'):
                g_label = apply_bias_act(
                    dense_layer(dlabel, fmaps=label_mapping_fmaps)) + 1
            with tf.variable_scope('G_concat_Attention'):
                g_joint = tf.concat([g_flat, g_label], axis=-1)
                g_proj = dense_layer(g_joint, fmaps=c_reduced * num_positions)
                g_proj = tf.reshape(g_proj, [-1, c_reduced, num_positions])

            # H projection (no label input).
            with tf.variable_scope('H_Attention'):
                h_proj = conv2d_layer(pooled, fmaps=c_reduced, kernel=1)
                h_proj = tf.reshape(h_proj, [-1, c_reduced, num_positions])

            # Softmax over the last axis of F x G, then apply it to H.
            scores = tf.matmul(f_proj, g_proj)
            attention_map = tf.nn.softmax(scores, axis=-1)
            attention_map_t = tf.transpose(attention_map, [0, 2, 1])
            attended = tf.matmul(h_proj, attention_map_t)
            attended_map = tf.reshape(attended,
                                      [-1, c_reduced, height, width])

            with tf.variable_scope('V_Attention'):
                v_out = conv2d_layer(attended_map, fmaps=pooled.shape[1],
                                     kernel=1)
            with tf.variable_scope('Upsample'):
                v_upsampled = upsample_2d(v_out)
            with tf.variable_scope('Gamma_Attention'):
                gamma = tf.get_variable(name='attention_gamma',
                                        shape=[],
                                        initializer=tf.initializers.zeros())
            x = x + v_upsampled * gamma

            if res == cutoff_layer:
                # Early exit exposing the attention intermediates; the skip
                # connection below is not applied on this path.
                return v_upsampled, gamma, x, attended_map

        if architecture != 'resnet':
            return x

        with tf.variable_scope('Skip'):
            shortcut = conv2d_layer(block_input, fmaps=nf(res - 2), kernel=1,
                                    down=True, resample_kernel=resample_kernel)
            x = (x + shortcut) * (1 / np.sqrt(2))
        return x