# Convolution with optional up/downsampling, given an explicit weight tensor w.
def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
    assert not (up and down)
    kernel = w.shape[0].value
    assert w.shape[1].value == kernel
    assert kernel >= 1 and kernel % 2 == 1
    w = tf.cast(w, x.dtype)
    if up:
        x = upsample_conv_2d(x, w, data_format='NCHW', k=resample_kernel, padding=padding)
    elif down:
        x = conv_downsample_2d(x, w, data_format='NCHW', k=resample_kernel, padding=padding)
    else:
        # padding=0 maps to 'SAME'; padding=-(kernel // 2) maps to 'VALID'.
        padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
        x = tf.nn.conv2d(x, w, data_format='NCHW', strides=[1, 1, 1, 1], padding=padding_mode)
    return x
# Convolution layer with its own trainable weight and optional up/downsampling.
def conv2d_layer(x, fmaps, kernel, up=False, down=False, resample_kernel=None, gain=1, use_wscale=True, lrmul=1, weight_var='weight'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1
    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
    if up:
        x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    else:
        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')
    return x
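# Hypothetical usage sketch (an assumption, not part of the original layers in this
# file): how conv2d_layer could be wired into a residual downsampling block, reusing
# the NCHW layout used above and assuming apply_bias_act accepts an act= argument as
# in the StyleGAN2 codebase. The block structure and the name example_residual_block
# are illustrative only.
def example_residual_block(x, fmaps, resample_kernel=[1, 3, 3, 1]):
    t = x
    x = apply_bias_act(conv2d_layer(x, fmaps=fmaps, kernel=3), act='lrelu')
    x = apply_bias_act(conv2d_layer(x, fmaps=fmaps, kernel=3, down=True, resample_kernel=resample_kernel), act='lrelu')
    t = conv2d_layer(t, fmaps=fmaps, kernel=1, down=True, resample_kernel=resample_kernel)
    return (x + t) * (1 / np.sqrt(2))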
# Modulated convolution layer (StyleGAN2-style weight modulation/demodulation).
def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, gain=1, use_wscale=True, lrmul=1, fused_modconv=True, weight_var='weight', mod_weight_var='mod_weight', mod_bias_var='mod_bias'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Get weight.
    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
    ww = w[np.newaxis]  # [BkkIO] Introduce minibatch dimension.

    # Attention mask: learned per-pixel mask applied to the input when no resampling
    # is performed, rescaled so the mask has unit root-mean-square magnitude.
    if not up and not down:
        mask = 1 + get_weight([x.shape[2].value, x.shape[3].value], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var='mask')
        x = x * mask * tf.rsqrt(tf.reduce_mean(tf.square(mask)) + 1e-8)

    # Modulate.
    s = dense_layer(y, fmaps=x.shape[1].value, weight_var=mod_weight_var)  # [BI] Transform incoming W to style.
    s = apply_bias_act(s, bias_var=mod_bias_var) + 1  # [BI] Add bias (initially 1).
    ww *= tf.cast(s[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype)  # [BkkIO] Scale input feature maps.

    # Demodulate.
    if demodulate:
        d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1, 2, 3]) + 1e-8)  # [BO] Scaling factor.
        ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :]  # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]])  # Fused => reshape minibatch to convolution groups.
        w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
    else:
        x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype)  # [BIhw] Not fused => scale input activations.

    # Convolution with optional up/downsampling.
    if up:
        x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    else:
        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')

    # Reshape/scale output.
    if fused_modconv:
        x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]])  # Fused => reshape convolution groups back to minibatch.
    elif demodulate:
        x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype)  # [BOhw] Not fused => scale output activations.
    return x
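# Illustrative NumPy mirror of the modulate/demodulate bookkeeping above (an added
# sketch, not used by the network): it shows how the shared [kkIO] weight becomes a
# per-sample [BkkIO] weight, and how the fused path folds the minibatch into the
# group dimension of a single grouped convolution. Sizes are arbitrary examples.
def _modconv_numpy_sketch():
    B, k, I, O = 4, 3, 8, 16                                          # minibatch, kernel, in/out channels
    w = np.random.randn(k, k, I, O)                                   # shared weight [kkIO]
    s = np.random.randn(B, I) + 1                                     # per-sample styles [BI]

    ww = w[np.newaxis] * s[:, None, None, :, None]                    # [BkkIO] modulate input channels
    d = 1 / np.sqrt(np.sum(np.square(ww), axis=(1, 2, 3)) + 1e-8)     # [BO] demodulation factors
    ww = ww * d[:, None, None, None, :]                               # [BkkIO] demodulate output channels

    # Fused path: fold the minibatch into the output channels so one grouped conv
    # (groups == minibatch size) applies a different weight to every sample.
    w_grouped = ww.transpose(1, 2, 3, 0, 4).reshape(k, k, I, B * O)   # [k, k, I, B*O]
    return w_grouped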
# Decomposition convolution layer: the per-sample weight is built from learned
# factors U and V with a style-dependent diagonal in between, then demodulated.
def decomposition_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, gain=1, use_wscale=True, lrmul=1, fused_modconv=True, weight_var='weight', mod_weight_var='U_weight', mod_bias_var='mod_bias'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Get weight factors U [out_channel, S] and V [in_channel, S].
    out_channel = fmaps
    in_channel = kernel * kernel * x.shape[1].value
    s_dimension = min(out_channel, in_channel)
    U = get_weight([out_channel, s_dimension], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var='U_' + weight_var)
    V = get_weight([in_channel, s_dimension], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var='V_' + weight_var)

    # Linear layer and normalization to obtain the style vector s and its diagonal matrix S.
    s = dense_layer(y, fmaps=s_dimension, weight_var=mod_weight_var)  # [BS] Transform incoming W to style.
    s = apply_bias_act(s, bias_var=mod_bias_var) + 1  # [BS] Add bias (initially 1).
    s *= tf.rsqrt(tf.reduce_mean(tf.square(s), axis=1, keepdims=True) + 1e-8)

    # Construct diagonal matrix from style s.
    S = tf.matrix_diag(s)  # [BSS]

    # Use S to construct the controllable per-sample weight w = V * diag(s) * U^T.
    w = tf.matmul(tf.reshape(S, [-1, s_dimension]), tf.transpose(V, [1, 0]))  # [B*S, in_channel]
    w = tf.reshape(w, [-1, s_dimension, in_channel])
    w = tf.transpose(w, [0, 2, 1])  # [B, in_channel, S]
    w = tf.matmul(tf.reshape(w, [-1, s_dimension]), tf.transpose(U, [1, 0]))  # [B*in_channel, out_channel]
    w = tf.reshape(w, [-1, kernel, kernel, x.shape[1].value, out_channel])  # [BkkIO]

    # Normalize w, similar to weight demodulation.
    if demodulate:
        d = tf.rsqrt(tf.reduce_sum(tf.square(w), axis=[1, 2, 3]) + 1e-8)  # [BO] Scaling factor.
        w *= d[:, np.newaxis, np.newaxis, np.newaxis, :]  # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]])  # Fused => reshape minibatch to convolution groups.
        w = tf.reshape(tf.transpose(w, [1, 2, 3, 0, 4]), [w.shape[1], w.shape[2], w.shape[3], -1])
    else:
        x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype)  # [BIhw] Not fused => scale input activations.

    # Convolution with optional up/downsampling.
    if up:
        x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    else:
        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')

    # Reshape/scale output.
    if fused_modconv:
        x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]])  # Fused => reshape convolution groups back to minibatch.
    elif demodulate:
        x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype)  # [BOhw] Not fused => scale output activations.
    return x
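# Illustrative NumPy check of the weight construction above (an added sketch, not
# used by the network): the chain of reshapes/matmuls builds, for each sample b,
# the matrix V @ diag(s_b) @ U^T of shape [in_channel, out_channel], which is then
# reshaped to [BkkIO]. Array sizes below are arbitrary examples.
def _decomposition_numpy_sketch():
    B, k, I, O = 4, 3, 8, 16
    in_channel = k * k * I
    s_dim = min(O, in_channel)

    U = np.random.randn(O, s_dim)                                     # [O, S]
    V = np.random.randn(in_channel, s_dim)                            # [in_channel, S]
    s = np.random.randn(B, s_dim) + 1                                 # per-sample styles [B, S]
    s = s / np.sqrt(np.mean(np.square(s), axis=1, keepdims=True) + 1e-8)

    w = np.einsum('is,bs,os->bio', V, s, U)                           # [B, in_channel, O] == V @ diag(s_b) @ U.T
    w = w.reshape(B, k, k, I, O)                                      # [BkkIO], matching the TF code path
    return w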