def block(x, res): # res = 2..resolution_log2
    attention_map = tf.constant(0)
    t = x
    with tf.variable_scope('Conv0'):
        x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
    with tf.variable_scope('Conv1_down'):
        x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
    if res == 4:
        with tf.variable_scope('fmap_attention'):
            fmap_attention = conv2d_layer(x, fmaps=1, kernel=1)
            fmap_attention = tf.reshape(fmap_attention, [-1, x.shape[2] * x.shape[3]])
        with tf.variable_scope('label_attention'):
            label_attention = dense_layer(dlabel, fmaps=label_mapping_fmaps)
        with tf.variable_scope('combine_attention'):
            attention_map = dense_layer(tf.concat([fmap_attention, label_attention], axis=-1), fmaps=x.shape[2] * x.shape[3])
        with tf.variable_scope('x_reduced_channels'):
            x_reduced_channels = conv2d_layer(x, fmaps=1, kernel=1)
        attention_map = tf.nn.softmax(attention_map, axis=-1)
        attention_map = tf.reshape(attention_map, [-1, 1, x.shape[2], x.shape[3]])
        combine = x_reduced_channels * attention_map
        with tf.variable_scope('x_increase_channels'):
            x_increase_channels = conv2d_layer(combine, fmaps=x.shape[1], kernel=1)
        with tf.variable_scope('Gamma_Attention'):
            gamma = tf.get_variable(shape=[], initializer=tf.initializers.zeros(), name='attention_gamma')
        x = x + x_increase_channels * gamma
    if architecture == 'resnet':
        with tf.variable_scope('Skip'):
            t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True, resample_kernel=resample_kernel)
            x = (x + t) * (1 / np.sqrt(2))
    return x, attention_map
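# --- Illustrative sketch (not part of the network code above) ----------------
# The res == 4 branch builds a single 4x4 spatial attention map conditioned on
# both the feature maps and the mapped label (dlabel): feature and label
# projections are concatenated, mapped to one logit per pixel, softmaxed, and
# applied as a residual scaled by a zero-initialised gamma.  The NumPy toy
# below mimics only the data flow; random projections stand in for
# conv2d_layer / dense_layer, and all names and sizes here are illustrative.
import numpy as np

def _softmax(a, axis=-1):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

B, C, H, W, D, LBL_FMAPS = 2, 8, 4, 4, 16, 32
rng = np.random.default_rng(0)
x = rng.standard_normal((B, C, H, W)).astype(np.float32)            # Conv1_down output.
dlabel = rng.standard_normal((B, D)).astype(np.float32)             # Mapped label.

fmap_att = x.mean(axis=1).reshape(B, H * W)                         # stand-in for the 1x1 conv to 1 fmap
label_att = dlabel @ rng.standard_normal((D, LBL_FMAPS)).astype(np.float32)
logits = np.concatenate([fmap_att, label_att], axis=-1) @ \
         rng.standard_normal((H * W + LBL_FMAPS, H * W)).astype(np.float32)
att = _softmax(logits, axis=-1).reshape(B, 1, H, W)                 # one attention weight per pixel
x_reduced = x.mean(axis=1, keepdims=True)                           # stand-in for the 1x1 conv to 1 channel
combine = x_reduced * att                                           # attention-weighted reduced map
x_increase = np.repeat(combine, C, axis=1)                          # stand-in for the 1x1 conv back to C channels
gamma = 0.0                                                         # learned scalar, zero at init
x_out = x + gamma * x_increase                                      # residual attention; identity at init
# ------------------------------------------------------------------------------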
def D_mapping_label(
    labels_in,                              # Second input: Conditioning labels [minibatch, label_size].
    label_size              = 0,            # Label dimensionality, 0 if no labels.
    dlabel_size             = 512,          # Mapped label (dlabel) dimensionality.
    dlabel_broadcast        = None,         # Output mapped label as [minibatch, dlabel_size] or [minibatch, dlabel_broadcast, dlabel_size].
    mapping_layers          = 8,            # Number of mapping layers.
    mapping_fmaps           = 128,          # Number of activations in the mapping layers.
    mapping_lrmul           = 0.01,         # Learning rate multiplier for the mapping layers.
    mapping_nonlinearity    = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    dtype                   = 'float32',    # Data type to use for activations and outputs.
    **_kwargs):                             # Ignore unrecognized keyword args.

    act = mapping_nonlinearity

    # Inputs.
    labels_in.set_shape([None, label_size])
    labels_in = tf.cast(labels_in, dtype)
    x = labels_in

    # Mapping layers.
    for layer_idx in range(mapping_layers):
        with tf.variable_scope('Dense%d' % layer_idx):
            fmaps = dlabel_size if layer_idx == mapping_layers - 1 else mapping_fmaps
            x = apply_bias_act(dense_layer(x, fmaps=fmaps, lrmul=mapping_lrmul), act=act, lrmul=mapping_lrmul)

    # Broadcast.
    if dlabel_broadcast is not None:
        with tf.variable_scope('Broadcast'):
            x = tf.tile(x[:, np.newaxis], [1, dlabel_broadcast, 1])

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return tf.identity(x, name='dlabel_out')
def G_mapping(
    latents_in,                             # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                              # Second input: Conditioning labels [minibatch, label_size].
    latent_size             = 512,          # Latent vector (Z) dimensionality.
    label_size              = 0,            # Label dimensionality, 0 if no labels.
    dlatent_size            = 512,          # Disentangled latent (W) dimensionality.
    dlatent_broadcast       = None,         # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size].
    mapping_layers          = 8,            # Number of mapping layers.
    mapping_fmaps           = 512,          # Number of activations in the mapping layers.
    mapping_lrmul           = 0.01,         # Learning rate multiplier for the mapping layers.
    mapping_nonlinearity    = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    normalize_latents       = True,         # Normalize latent vectors (Z) before feeding them to the mapping layers?
    dtype                   = 'float32',    # Data type to use for activations and outputs.
    **_kwargs):                             # Ignore unrecognized keyword args.

    act = mapping_nonlinearity

    # Inputs.
    latents_in.set_shape([None, latent_size])
    labels_in.set_shape([None, label_size])
    latents_in = tf.cast(latents_in, dtype)
    labels_in = tf.cast(labels_in, dtype)
    x = latents_in

    # Embed labels and concatenate them with latents.
    if label_size:
        with tf.variable_scope('LabelConcat'):
            w = tf.get_variable('weight', shape=[label_size, latent_size], initializer=tf.initializers.random_normal())
            y = tf.matmul(labels_in, tf.cast(w, dtype))
            x = tf.concat([x, y], axis=1)

    # Normalize latents.
    if normalize_latents:
        with tf.variable_scope('Normalize'):
            x *= tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + 1e-8)

    # Mapping layers.
    for layer_idx in range(mapping_layers):
        with tf.variable_scope('Dense%d' % layer_idx):
            fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps
            x = apply_bias_act(dense_layer(x, fmaps=fmaps, lrmul=mapping_lrmul), act=act, lrmul=mapping_lrmul)

    # Broadcast.
    if dlatent_broadcast is not None:
        with tf.variable_scope('Broadcast'):
            x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return tf.identity(x, name='dlatents_out')
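# --- Illustrative sketch (not part of the network code above) ----------------
# Before the mapping MLP, G_mapping optionally rescales each latent to unit
# RMS: x *= rsqrt(mean(x^2) + 1e-8).  A minimal NumPy equivalent of that
# Normalize step (sizes here are illustrative):
import numpy as np

z = np.random.default_rng(1).standard_normal((4, 512)).astype(np.float32)
z_norm = z * (1.0 / np.sqrt(np.mean(np.square(z), axis=1, keepdims=True) + 1e-8))
print(np.sqrt(np.mean(np.square(z_norm), axis=1)))   # ~1.0 for every sample
# ------------------------------------------------------------------------------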
def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None,
                           gain=1, use_wscale=True, lrmul=1, fused_modconv=True, weight_var='weight',
                           mod_weight_var='mod_weight', mod_bias_var='mod_bias'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Modulate.
    num_fmaps = kernel * kernel * x.shape[1].value * fmaps
    s = dense_layer(y, fmaps=num_fmaps, weight_var=mod_weight_var) # [BI] Transform incoming W to style.
    s = apply_bias_act(s, bias_var=mod_bias_var) + 1 / (kernel ** 2) # [BI] Add bias (initially 1).
    ww = tf.reshape(s, [-1, kernel, kernel, x.shape[1].value, fmaps])

    # Demodulate.
    if demodulate:
        d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1, 2, 3]) + 1e-8) # [BO] Scaling factor.
        ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :] # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups.
        w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
    else:
        x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype) # [BIhw] Not fused => scale input activations.

    # Convolution with optional up/downsampling.
    if up:
        x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    else:
        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')

    # Reshape/scale output.
    if fused_modconv:
        x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.
    elif demodulate:
        x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype) # [BOhw] Not fused => scale output activations.
    return x
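# --- Illustrative sketch (not part of the network code above) ----------------
# In this variant the per-sample convolution weights ww come directly from the
# mapped input y, then the demodulation step rescales them so that each output
# feature map has (approximately) unit-norm weights.  A small NumPy check of
# that step, with random weights standing in for the learned ones:
import numpy as np

B, k, I, O = 2, 3, 4, 5
ww = np.random.default_rng(2).standard_normal((B, k, k, I, O)).astype(np.float32)
d = 1.0 / np.sqrt(np.sum(np.square(ww), axis=(1, 2, 3)) + 1e-8)      # [B, O] scaling factors
ww_demod = ww * d[:, np.newaxis, np.newaxis, np.newaxis, :]
print(np.sum(np.square(ww_demod), axis=(1, 2, 3)))                   # ~1.0 per (sample, output fmap)
# ------------------------------------------------------------------------------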
def D_stylegan2(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 3,            # Number of input color channels. Overridden based on dataset.
    resolution          = 1024,         # Input resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 16 << 10,     # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_min            = 1,            # Minimum number of feature maps in any layer.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    resample_kernel     = [1, 3, 3, 1], # Low-pass filter to apply when resampling activations. None = no filtering.
    **_kwargs):                         # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)
    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Building blocks for main layers.
    def fromrgb(x, y, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1), act=act)
            return t if x is None else x + t
    def block(x, res): # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True, resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x
    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            # Ignore interpolated labels [1, 0, 0, 0.3, 0.7] -> [1, 0, 0, 0, 0]
            x = tf.reduce_sum(x * tf.floor(labels_in), axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out

#----------------------------------------------------------------------------
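# --- Illustrative sketch (not part of the network code above) ----------------
# The Output layer produces one score per class and then selects the score of
# the ground-truth class via sum(x * floor(labels)).  tf.floor zeroes out any
# interpolated (soft) label entries, e.g. [1, 0, 0, 0.3, 0.7] -> [1, 0, 0, 0, 0],
# so mixed label components contribute nothing.  NumPy illustration:
import numpy as np

scores = np.array([[0.5, -1.0, 2.0, 0.1, -0.2]], dtype=np.float32)    # per-class scores for one sample
labels = np.array([[1.0, 0.0, 0.0, 0.3, 0.7]], dtype=np.float32)      # partly interpolated label
selected = np.sum(scores * np.floor(labels), axis=1, keepdims=True)   # -> [[0.5]]: only the hard '1' counts
# ------------------------------------------------------------------------------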
def D_stylegan2(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 3,            # Number of input color channels. Overridden based on dataset.
    resolution          = 256,          # Input resolution. Overridden based on dataset.
    label_size          = 127,          # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 16 << 10,     # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_min            = 1,            # Minimum number of feature maps in any layer.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    resample_kernel     = [1, 3, 3, 1], # Low-pass filter to apply when resampling activations. None = no filtering.
    dlabel_size         = 32,           # Dimensionality of the mapped label (dlabel).
    cutoff_layer        = 7,            # Resolution (log2) at which the attention intermediates are returned early.
    **_kwargs):                         # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)
    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    dlabel = D_mapping_label(labels_in=labels_in, label_size=label_size, dlabel_size=dlabel_size)

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    dlabel.set_shape([None, dlabel_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)
    dlabel = tf.cast(dlabel, dtype)

    # Building blocks for main layers.
    def fromrgb(x, y, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1), act=act)
            return t if x is None else x + t

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    def block(x, res): # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
        if 3 < res < 8:
            with tf.variable_scope('Downsample'):
                x_downsample = downsample_2d(x)
            height = x_downsample.shape[2]
            width = x_downsample.shape[3]
            c_reduced = 1
            label_mapping_fmaps = 16
            with tf.variable_scope('F_Attention'):
                f_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
                f_x = tf.reshape(f_x, [-1, height * width])
            with tf.variable_scope('Label_F_Attention'):
                label_f = dense_layer(dlabel, fmaps=label_mapping_fmaps)
                label_f = apply_bias_act(label_f) + 1
            with tf.variable_scope('F_concat_Attention'):
                f_x_s = dense_layer(tf.concat([f_x, label_f], axis=-1), fmaps=c_reduced * height * width)
                f_x_s = tf.reshape(f_x_s, [-1, c_reduced, height * width])
                f_x_s = tf.transpose(f_x_s, perm=[0, 2, 1])
            with tf.variable_scope('G_Attention'):
                g_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
                g_x = tf.reshape(g_x, [-1, height * width])
            with tf.variable_scope('Label_G_Attention'):
                label_g = dense_layer(dlabel, fmaps=label_mapping_fmaps)
                label_g = apply_bias_act(label_g) + 1
            with tf.variable_scope('G_concat_Attention'):
                g_x_s = dense_layer(tf.concat([g_x, label_g], axis=-1), fmaps=c_reduced * height * width)
                g_x_s = tf.reshape(g_x_s, [-1, c_reduced, height * width])
            with tf.variable_scope('H_Attention'):
                h_x = conv2d_layer(x_downsample, fmaps=c_reduced, kernel=1)
                h_x = tf.reshape(h_x, [-1, c_reduced, height * width])
            f_g_multiply = tf.matmul(f_x_s, g_x_s)
            attention_map = tf.nn.softmax(f_g_multiply, axis=-1)
            attention_map_h_multiply = tf.matmul(h_x, tf.transpose(attention_map, [0, 2, 1]))
            attention_map_h_multiply_reshape = tf.reshape(attention_map_h_multiply, [-1, c_reduced, height, width])
            with tf.variable_scope('V_Attention'):
                v_x = conv2d_layer(attention_map_h_multiply_reshape, fmaps=x_downsample.shape[1], kernel=1)
            with tf.variable_scope('Upsample'):
                v_x_upsample = upsample_2d(v_x)
            with tf.variable_scope('Gamma_Attention'):
                gamma = tf.get_variable(shape=[], initializer=tf.initializers.zeros(), name='attention_gamma')
            x = x + v_x_upsample * gamma
            if res == cutoff_layer:
                return v_x_upsample, gamma, x, attention_map_h_multiply_reshape
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True, resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            if res == cutoff_layer:
                return block(x, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out

#----------------------------------------------------------------------------
def block(x, res): # res = 2..resolution_log2
    t = x
    with tf.variable_scope('Conv0'):
        x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
    with tf.variable_scope('Conv1_down'):
        x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
    if 3 < res < 8:
        with tf.variable_scope('Downsample'):
            x_downsample = downsample_2d(x)
        height = x_downsample.shape[2]
        width = x_downsample.shape[3]
        c_reduced = 1
        label_mapping_fmaps = 16
        with tf.variable_scope('F_Attention'):
            f_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
            f_x = tf.reshape(f_x, [-1, height * width])
        with tf.variable_scope('Label_F_Attention'):
            label_f = dense_layer(dlabel, fmaps=label_mapping_fmaps)
            label_f = apply_bias_act(label_f) + 1
        with tf.variable_scope('F_concat_Attention'):
            f_x_s = dense_layer(tf.concat([f_x, label_f], axis=-1), fmaps=c_reduced * height * width)
            f_x_s = tf.reshape(f_x_s, [-1, c_reduced, height * width])
            f_x_s = tf.transpose(f_x_s, perm=[0, 2, 1])
        with tf.variable_scope('G_Attention'):
            g_x = conv2d_layer(x_downsample, fmaps=1, kernel=1)
            g_x = tf.reshape(g_x, [-1, height * width])
        with tf.variable_scope('Label_G_Attention'):
            label_g = dense_layer(dlabel, fmaps=label_mapping_fmaps)
            label_g = apply_bias_act(label_g) + 1
        with tf.variable_scope('G_concat_Attention'):
            g_x_s = dense_layer(tf.concat([g_x, label_g], axis=-1), fmaps=c_reduced * height * width)
            g_x_s = tf.reshape(g_x_s, [-1, c_reduced, height * width])
        with tf.variable_scope('H_Attention'):
            h_x = conv2d_layer(x_downsample, fmaps=c_reduced, kernel=1)
            h_x = tf.reshape(h_x, [-1, c_reduced, height * width])
        f_g_multiply = tf.matmul(f_x_s, g_x_s)
        attention_map = tf.nn.softmax(f_g_multiply, axis=-1)
        attention_map_h_multiply = tf.matmul(h_x, tf.transpose(attention_map, [0, 2, 1]))
        attention_map_h_multiply_reshape = tf.reshape(attention_map_h_multiply, [-1, c_reduced, height, width])
        with tf.variable_scope('V_Attention'):
            v_x = conv2d_layer(attention_map_h_multiply_reshape, fmaps=x_downsample.shape[1], kernel=1)
        with tf.variable_scope('Upsample'):
            v_x_upsample = upsample_2d(v_x)
        with tf.variable_scope('Gamma_Attention'):
            gamma = tf.get_variable(shape=[], initializer=tf.initializers.zeros(), name='attention_gamma')
        x = x + v_x_upsample * gamma
        if res == cutoff_layer:
            return v_x_upsample, gamma, x, attention_map_h_multiply_reshape
    if architecture == 'resnet':
        with tf.variable_scope('Skip'):
            t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True, resample_kernel=resample_kernel)
            x = (x + t) * (1 / np.sqrt(2))
    return x
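# --- Illustrative sketch (not part of the network code above) ----------------
# The block above is a SAGAN-style self-attention module whose query (f) and
# key (g) paths are additionally conditioned on the mapped label dlabel before
# the softmax attention map is formed and applied to the value path (h).  The
# NumPy toy below reproduces only the shape bookkeeping, with random matrices
# standing in for conv2d_layer / dense_layer; every name and size here is
# illustrative, and the residual is applied before the upsample for brevity.
import numpy as np

def _softmax(a, axis=-1):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

B, C, H, W, D, LBL_FMAPS, C_RED = 2, 8, 4, 4, 16, 16, 1
rng = np.random.default_rng(3)
x_down = rng.standard_normal((B, C, H, W)).astype(np.float32)        # downsampled feature maps
dlabel = rng.standard_normal((B, D)).astype(np.float32)
HW = H * W

def labelled_path(x, dlabel):
    feat = x.mean(axis=1).reshape(B, HW)                              # stand-in for the 1x1 conv to 1 fmap
    lab = dlabel @ rng.standard_normal((D, LBL_FMAPS)).astype(np.float32) + 1
    mix = np.concatenate([feat, lab], axis=-1)
    return mix @ rng.standard_normal((HW + LBL_FMAPS, C_RED * HW)).astype(np.float32)

f = labelled_path(x_down, dlabel).reshape(B, C_RED, HW).transpose(0, 2, 1)    # [B, HW, c] query
g = labelled_path(x_down, dlabel).reshape(B, C_RED, HW)                       # [B, c, HW] key
h = x_down.mean(axis=1, keepdims=True).reshape(B, C_RED, HW)                  # [B, c, HW] value (stand-in 1x1 conv)
att = _softmax(f @ g, axis=-1)                                                # [B, HW, HW] attention map
out = (h @ att.transpose(0, 2, 1)).reshape(B, C_RED, H, W)                    # attended features
gamma = 0.0                                                                   # learned scalar, zero at init
x_out = x_down + gamma * np.repeat(out, C, axis=1)                            # residual attention; identity at init
# ------------------------------------------------------------------------------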
def D_stylegan2(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 3,            # Number of input color channels. Overridden based on dataset.
    resolution          = 256,          # Input resolution. Overridden based on dataset.
    label_size          = 127,          # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 16 << 10,     # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_min            = 1,            # Minimum number of feature maps in any layer.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    resample_kernel     = [1, 3, 3, 1], # Low-pass filter to apply when resampling activations. None = no filtering.
    components          = dnnlib.EasyDict(), # Container for sub-networks. Retained between calls.
    mapping_label_func  = 'D_mapping_label', # Build function for the label mapping network.
    dlabel_size         = 64,           # Dimensionality of the mapped label (dlabel).
    use_attention_downsampling = False,
    label_mapping_fmaps = 32,           # Number of activations when projecting dlabel inside the attention block.
    output_fmap_res     = 3,            # Resolution (log2) whose feature maps and attention map are returned.
    **_kwargs):                         # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)
    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    # dlabel = D_mapping_label(labels_in=labels_in, label_size=label_size, dlabel_size=dlabel_size)
    if 'mapping_label' not in components:
        components.mapping_label = tflib.Network('D_mapping_label', func_name=globals()[mapping_label_func],
                                                 label_size=label_size, dlabel_size=dlabel_size)
    dlabel = components.mapping_label.get_output_for(labels_in)
    dlabel = tf.cast(dlabel, dtype)

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Building blocks for main layers.
    def fromrgb(x, y, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1), act=act)
            return t if x is None else x + t

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    def block(x, res): # res = 2..resolution_log2
        attention_map = tf.constant(0)
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
        if res == 4:
            with tf.variable_scope('fmap_attention'):
                fmap_attention = conv2d_layer(x, fmaps=1, kernel=1)
                fmap_attention = tf.reshape(fmap_attention, [-1, x.shape[2] * x.shape[3]])
            with tf.variable_scope('label_attention'):
                label_attention = dense_layer(dlabel, fmaps=label_mapping_fmaps)
            with tf.variable_scope('combine_attention'):
                attention_map = dense_layer(tf.concat([fmap_attention, label_attention], axis=-1), fmaps=x.shape[2] * x.shape[3])
            with tf.variable_scope('x_reduced_channels'):
                x_reduced_channels = conv2d_layer(x, fmaps=1, kernel=1)
            attention_map = tf.nn.softmax(attention_map, axis=-1)
            attention_map = tf.reshape(attention_map, [-1, 1, x.shape[2], x.shape[3]])
            combine = x_reduced_channels * attention_map
            with tf.variable_scope('x_increase_channels'):
                x_increase_channels = conv2d_layer(combine, fmaps=x.shape[1], kernel=1)
            with tf.variable_scope('Gamma_Attention'):
                gamma = tf.get_variable(shape=[], initializer=tf.initializers.zeros(), name='attention_gamma')
            x = x + x_increase_channels * gamma
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True, resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x, attention_map

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x, attention_map = block(x, res)
            if res == output_fmap_res:
                fmap_output = x
                attention_map_out = attention_map
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
    return x, fmap_output, attention_map_out

    if labels_in.shape[1] > 0:
        x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out

#----------------------------------------------------------------------------
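# --- Illustrative sketch (not part of the network code above) ----------------
# nf(stage) decides how many feature maps each resolution gets.  With the
# defaults used in the discriminators above (fmap_base=16<<10, fmap_decay=1.0,
# fmap_min=1, fmap_max=512) the schedule works out as printed below:
import numpy as np

fmap_base, fmap_decay, fmap_min, fmap_max = 16 << 10, 1.0, 1, 512

def nf(stage):
    return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)

for res in range(8, 1, -1):                        # 256x256 ... 4x4
    print('%4dx%-4d -> nf(res-1) = %3d fmaps' % (2 ** res, 2 ** res, nf(res - 1)))
# 256x256 -> 128, 128x128 -> 256, 64x64 and below -> 512 (capped by fmap_max)
# ------------------------------------------------------------------------------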