# StyleGAN discriminator with optional progressive growing. Helper ops
# (conv2d_layer, apply_bias_act, dense_layer, downsample_2d, upsample_2d,
# naive_downsample_2d, minibatch_stddev_layer) are assumed to be provided by
# the surrounding StyleGAN2 codebase.

import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib

def D_stylegan(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 3,            # Number of input color channels. Overridden based on dataset.
    resolution          = 1024,         # Input resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 16 << 10,     # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_min            = 1,            # Minimum number of feature maps in any layer.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    resample_kernel     = [1,3,3,1],    # Low-pass filter to apply when resampling activations. None = no filtering.
    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    **_kwargs):                         # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)
    if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive'
    act = nonlinearity

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)

    # Building blocks for spatial layers.
    def fromrgb(x, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
            return apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=1), act=act)
    def block(x, res): # res = 2..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            with tf.variable_scope('Conv0'):
                x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=3), act=act)
            with tf.variable_scope('Conv1_down'):
                x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
            return x

    # Fixed structure: simple and efficient, but does not support progressive growing.
    if structure == 'fixed':
        x = fromrgb(images_in, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            x = block(x, res)

    # Linear structure: simple but inefficient.
    if structure == 'linear':
        img = images_in
        x = fromrgb(img, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            lod = resolution_log2 - res
            x = block(x, res)
            with tf.variable_scope('Downsample_lod%d' % lod):
                img = downsample_2d(img)
            y = fromrgb(img, res - 1)
            with tf.variable_scope('Grow_lod%d' % lod):
                x = tflib.lerp_clip(x, y, lod_in - lod)

    # Recursive structure: complex but efficient.
    if structure == 'recursive':
        def cset(cur_lambda, new_cond, new_lambda):
            return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
        def grow(res, lod):
            x = lambda: fromrgb(naive_downsample_2d(images_in, factor=2**lod), res)
            if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
            x = block(x(), res); y = lambda: x
            y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(naive_downsample_2d(images_in, factor=2**(lod+1)), res - 1), lod_in - lod))
            return y()
        x = grow(3, resolution_log2 - 3)

    # Final layers at 4x4 resolution.
    with tf.variable_scope('4x4'):
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out
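# Usage sketch (an assumption for illustration, not part of the original
# file): in the StyleGAN training scripts, a network function such as
# D_stylegan is normally wrapped in dnnlib.tflib.Network, which builds the
# template graph and routes the keyword arguments above. The func_name path
# and the 256x256 unconditional setup below are hypothetical placeholders.
def _example_build_D():
    tflib.init_tf()
    D = tflib.Network('D', func_name='training.networks.D_stylegan',  # hypothetical module path
                      num_channels=3, resolution=256, label_size=0, structure='fixed')
    images = tf.random.normal([4, 3, 256, 256])  # NCHW minibatch
    labels = tf.zeros([4, 0])                    # unconditional: label_size = 0
    scores = D.get_output_for(images, labels)    # realism scores, shape [4, 1]
    return tflib.run(scores)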
    # Like `block` below, this appears to be a nested helper of a second,
    # attention-augmented discriminator; `resample_kernel` is taken from the
    # enclosing scope.
    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)
    def block(x, res): # res = 2..resolution_log2
        # Nested helper of the attention-augmented discriminator. `dlabel`
        # (the label embedding), `architecture`, and `cutoff_layer` are free
        # variables assumed to be defined in the enclosing scope.
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)

        # Label-conditioned self-attention, applied only at intermediate
        # resolutions (res 4..7, i.e. 16x16 up to 128x128). Attention is
        # computed on a 2x-downsampled copy of the features to save memory.
        if 3 < res < 8:
            with tf.variable_scope('Downsample'):
                x_downsample = downsample_2d(x)
            height = x_downsample.shape[2]
            width = x_downsample.shape[3]
            c_reduced = 1               # channel bottleneck for the attention projections
            label_mapping_fmaps = 16    # width of the label embedding mixed into f and g

            # f branch: 1x1 conv -> flatten -> fuse with a label embedding.
            with tf.variable_scope('F_Attention'):
                f_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
                f_x = tf.reshape(f_x, [-1, height * width])
            with tf.variable_scope('Label_F_Attention'):
                label_f = dense_layer(dlabel, fmaps=label_mapping_fmaps)
                label_f = apply_bias_act(label_f) + 1
            with tf.variable_scope('F_concat_Attention'):
                f_x_s = dense_layer(tf.concat([f_x, label_f], axis=-1), fmaps=c_reduced * height * width)
                f_x_s = tf.reshape(f_x_s, [-1, c_reduced, height * width])
                f_x_s = tf.transpose(f_x_s, perm=[0, 2, 1])  # [N, H*W, c_reduced]

            # g branch: same construction, without the final transpose.
            with tf.variable_scope('G_Attention'):
                g_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
                g_x = tf.reshape(g_x, [-1, height * width])
            with tf.variable_scope('Label_G_Attention'):
                label_g = dense_layer(dlabel, fmaps=label_mapping_fmaps)
                label_g = apply_bias_act(label_g) + 1
            with tf.variable_scope('G_concat_Attention'):
                g_x_s = dense_layer(tf.concat([g_x, label_g], axis=-1), fmaps=c_reduced * height * width)
                g_x_s = tf.reshape(g_x_s, [-1, c_reduced, height * width])  # [N, c_reduced, H*W]

            # h branch: channel-reduced copy of the features to be reweighted.
            with tf.variable_scope('H_Attention'):
                h_x = conv2d_layer(x_downsample, fmaps=c_reduced, kernel=1)
                h_x = tf.reshape(h_x, [-1, c_reduced, height * width])

            # Attention map over spatial positions: [N, H*W, H*W].
            f_g_multiply = tf.matmul(f_x_s, g_x_s)
            attention_map = tf.nn.softmax(f_g_multiply, axis=-1)
            attention_map_h_multiply = tf.matmul(h_x, tf.transpose(attention_map, [0, 2, 1]))
            attention_map_h_multiply_reshape = tf.reshape(attention_map_h_multiply, [-1, c_reduced, height, width])

            # Project back to the full channel count, upsample to the block's
            # resolution, and add as a residual scaled by a learned gamma
            # (initialized to zero, so attention starts out as a no-op).
            with tf.variable_scope('V_Attention'):
                v_x = conv2d_layer(attention_map_h_multiply_reshape, fmaps=x_downsample.shape[1], kernel=1)
            with tf.variable_scope('Upsample'):
                v_x_upsample = upsample_2d(v_x)
            with tf.variable_scope('Gamma_Attention'):
                gamma = tf.get_variable(shape=[], initializer=tf.initializers.zeros(), name='attention_gamma')
                x = x + v_x_upsample * gamma
            if res == cutoff_layer:
                # Expose the attention internals at the chosen layer
                # (presumably for visualization); note the non-standard
                # return signature for this case.
                return v_x_upsample, gamma, x, attention_map_h_multiply_reshape

        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True, resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x
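# Shape sanity check for the label-conditioned self-attention above: a
# standalone numpy sketch under assumed sizes, not part of the original
# graph. It mirrors the tensor algebra only; random arrays stand in for the
# conv/dense outputs. The f/g embeddings of the downsampled feature map are
# fused with a label embedding, their product forms an (H*W x H*W) attention
# map, and that map reweights the h projection.
def _attention_shape_check(n=2, c_reduced=1, h=8, w=8, label_fmaps=16):
    hw = h * w
    f_x   = np.random.randn(n, hw)              # conv2d_layer(fmaps=1), flattened
    lab_f = np.random.randn(n, label_fmaps)     # dense label embedding (would be concatenated with f_x)
    f_x_s = np.random.randn(n, c_reduced, hw)   # stands in for dense over concat([f_x, lab_f])
    f_x_s = np.transpose(f_x_s, (0, 2, 1))      # [N, H*W, c_reduced]
    g_x_s = np.random.randn(n, c_reduced, hw)   # [N, c_reduced, H*W]
    att   = np.matmul(f_x_s, g_x_s)             # [N, H*W, H*W]
    att  -= att.max(axis=-1, keepdims=True)     # numerically stable softmax over last axis
    att   = np.exp(att)
    att  /= att.sum(axis=-1, keepdims=True)
    h_x   = np.random.randn(n, c_reduced, hw)   # h projection
    out   = np.matmul(h_x, np.transpose(att, (0, 2, 1)))  # [N, c_reduced, H*W]
    assert out.shape == (n, c_reduced, hw)
    return out.reshape(n, c_reduced, h, w)      # back to spatial layout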