def layer(x, layer_idx, fmaps, kernel, up=False):
    '''
    Inject the spatial-group latents for `layer_idx`, convolve, add noise.

    NOTE(review): this function reads names it does not define
    (`latent_split_ls_for_std_gen`, `latents_ready_ls`, `return_atts`,
    `resolution`, `n_subs`, `kwargs`, `randomize_noise`, `noise_inputs`,
    `act`). They must come from an enclosing scope that is not visible
    here -- confirm against the surrounding function.

    Returns:
        Tuple `(features, atts)`: the bias+activation output of the noised
        convolution, and the second value produced by the latent-injection
        helper.
    '''
    injected = build_C_spgroup_layers_with_latents_ready(
        x,
        'SP_latents',
        latent_split_ls_for_std_gen[layer_idx],
        layer_idx,
        latents_ready_ls[layer_idx],
        return_atts=return_atts,
        resolution=resolution,
        n_subs=n_subs,
        **kwargs)
    feats, atts = get_return_v(injected, 2)

    feats = conv2d_layer(feats, fmaps=fmaps, kernel=kernel, up=up)

    # Either read the stored per-layer noise or draw a fresh one-channel map.
    if not randomize_noise:
        noise = tf.cast(noise_inputs[layer_idx], feats.dtype)
    else:
        noise_shape = [tf.shape(feats)[0], 1, feats.shape[2], feats.shape[3]]
        noise = tf.random_normal(noise_shape, dtype=feats.dtype)

    # Scalar noise gain, initialised to zero.
    noise_strength = tf.get_variable('noise_strength',
                                     shape=[],
                                     initializer=tf.initializers.zeros())
    feats = feats + noise * tf.cast(noise_strength, feats.dtype)
    return apply_bias_act(feats, act=act), atts
def build_Cout_spgroup_layers(x,
                              name,
                              n_latents,
                              start_idx,
                              scope_idx,
                              atts_in,
                              act,
                              fmaps=128,
                              resolution=128,
                              **kwargs):
    '''
    Build continuous latent out layers with learned group spatial attention.
    Support square images only.

    For each of the `n_latents` attention maps selected from `atts_in`, the
    features `x` are weighted by the (resized) map, averaged over axes 2 and
    3, and passed through two dense layers to produce one scalar per sample.

    Args:
        x: feature tensor; axes 2 and 3 are treated as the spatial axes.
        name, scope_idx: combined as `name + '-' + str(scope_idx)` to form
            the outer variable scope.
        n_latents: number of attention maps / predicted scalars.
        start_idx: first index along axis 1 of `atts_in` to use.
        atts_in: attention maps, per the original comment
            [b, all_n_latents, 1, resolution, resolution].
        act: activation name for the hidden dense layer.
        fmaps: width of the hidden dense layer.
        resolution: spatial size of the maps in `atts_in`.
        **kwargs: accepted and ignored.

    Returns:
        Tuple `(x, pred_out)`: the unchanged input features and the
        per-latent predictions concatenated on axis 1 ([b, n_latents]).
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            side = x.shape[2]

            # Bring the selected attention maps to the feature resolution.
            group_atts = atts_in[:, start_idx:start_idx + n_latents]
            group_atts = tf.image.resize(
                tf.reshape(group_atts, [-1, resolution, resolution, 1]),
                size=(side, side))
            group_atts = tf.reshape(group_atts,
                                    [-1, n_latents, 1, side, side])

            predictions = []
            for latent_i in range(n_latents):
                # Attention-weighted features pooled over axes 2 and 3.
                pooled = tf.reduce_mean(x * group_atts[:, latent_i],
                                        axis=[2, 3])
                with tf.variable_scope('OutDense-' + str(latent_i)):
                    with tf.variable_scope('Conv0'):
                        hidden = apply_bias_act(
                            dense_layer(pooled, fmaps=fmaps), act=act)
                    with tf.variable_scope('Conv1'):
                        predictions.append(dense_layer(hidden, fmaps=1))
            pred_out = tf.concat(predictions, axis=1)
    return x, pred_out
def get_s_matrix(s_latents, cond_latent, act='lrelu'):
    '''
    Build the six entries of a 2x3 scaling matrix from scaling latents.

    The scale factor is `(s_latents + 2.) * gate + 1.`, where `gate` is a
    sigmoid output of a two-layer dense network applied to `cond_latent`.
    With a single latent column the same factor is used on both diagonal
    entries; otherwise column 0 feeds entry (0, 0) and the remaining
    column(s) feed entry (1, 1), each with its own gate.

    Returns:
        Tensor `[s_x, 0, 0, 0, s_y, 0]` concatenated along axis 1.
    '''
    def gate(suffix):
        # Two dense layers; variable scopes are 'Condition0<suffix>' and
        # 'Condition1<suffix>'.
        with tf.variable_scope('Condition0' + suffix):
            hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1' + suffix):
            return apply_bias_act(dense_layer(hidden, fmaps=1),
                                  act='sigmoid')

    isotropic = s_latents.shape.as_list()[1] == 1
    if isotropic:
        cond = gate('')
    else:
        cond = tf.concat([gate('x'), gate('y')], axis=1)

    scale = (s_latents + 2.) * cond + 1.
    if isotropic:
        scale_x = scale_y = scale
    else:
        scale_x, scale_y = scale[:, 0:1], scale[:, 1:]

    entries = [
        scale_x,
        tf.zeros_like(scale_x),
        tf.zeros_like(scale_x),
        tf.zeros_like(scale_y),
        scale_y,
        tf.zeros_like(scale_y),
    ]
    return tf.concat(entries, axis=1)
def get_t_matrix(t_latents, cond_latent, act='lrelu'):
    '''
    Build the six entries of a 2x3 translation matrix from shift latents.

    The shift is `t_latents / 4. * gate`, where `gate` is a sigmoid output
    of a two-layer dense network applied to `cond_latent`. With a single
    latent column the same shift is used for both rows; otherwise column 0
    is the first-row shift and the remaining column(s) the second-row
    shift, each with its own gate.

    Returns:
        Tensor `[1, 0, shift_0, 0, 1, shift_1]` concatenated along axis 1.
    '''
    def gate(suffix):
        # Two dense layers; variable scopes are 'Condition0<suffix>' and
        # 'Condition1<suffix>'.
        with tf.variable_scope('Condition0' + suffix):
            hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1' + suffix):
            return apply_bias_act(dense_layer(hidden, fmaps=1),
                                  act='sigmoid')

    shared = t_latents.shape.as_list()[1] == 1
    if shared:
        # The single-latent case also uses the 'x'-suffixed scopes.
        cond = gate('x')
    else:
        cond = tf.concat([gate('x'), gate('y')], axis=1)

    xy_shift = t_latents / 4. * cond
    if shared:
        shift_0 = shift_1 = xy_shift
    else:
        shift_0, shift_1 = xy_shift[:, 0:1], xy_shift[:, 1:]

    entries = [
        tf.ones_like(shift_0),
        tf.zeros_like(shift_0),
        shift_0,
        tf.zeros_like(shift_1),
        tf.ones_like(shift_1),
        shift_1,
    ]
    return tf.concat(entries, axis=1)
def vc_head(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: hidden features from z + delta(z) [minibatch, channel, height, width].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 connected.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Convolutional head mapping an image pair to `dlatent_size` outputs.

    `fake1` and `fake2` are combined according to `connect_mode`
    ('diff': `fake1 - fake2`; 'concat': channel concatenation), passed
    through a stack of downsampling blocks down to the '4x4' stage, and
    finally through two dense layers. The last dense layer has
    `D_global_size + (dlatent_size - D_global_size)` units, i.e.
    `dlatent_size`.

    Returns:
        Tensor of dtype `dtype` produced by the final dense layer.

    Raises:
        ValueError: if `connect_mode` is neither 'diff' nor 'concat'.
    '''
    # Validate first: an unknown mode used to leave `images_in` unbound and
    # only failed later with an unhelpful UnboundLocalError.
    if connect_mode not in ('diff', 'concat'):
        raise ValueError(
            "connect_mode must be 'diff' or 'concat', got %r" %
            (connect_mode,))

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))),
                       fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    else:  # 'concat'
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                # Average the two branches with a 1/sqrt(2) factor.
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3),
                               act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer: one unit per latent dimension.
    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(
                dense_layer(x,
                            fmaps=(D_global_size +
                                   (dlatent_size - D_global_size))))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
def G_synthesis_sb_general_dsp(
        dlatents_withl_in,  # Input: Disentangled latents (W) [minibatch, label_size+dlatent_size].
        dlatent_size=7,  # Disentangled latent (W) dimensionality. Including discrete info, rotation, scaling, xy shearing, and xy translation.
        label_size=0,  # Label dimensionality, 0 if no labels.
        D_global_size=3,  # Global D_latents.
        C_global_size=0,  # Global C_latents.
        sb_C_global_size=4,  # Global spatial-biased C_latents.
        C_local_hfeat_size=0,  # Local heatmap*features learned C_latents.
        C_local_heat_size=0,  # Local heatmap learned C_latents.
        num_channels=1,  # Number of output color channels.
        resolution=64,  # Output resolution.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
        use_noise=False,
        randomize_noise=True,  # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Synthesis network with global, spatial-biased and local latent controls.

    dlatents_withl_in: dims contain: [label, D_global, C_global, sb_C_global,
                                      C_local_hfeat, C_local_feat]

    The spatial-biased segment is consumed as one rotation column, one
    scaling column, two shearing columns and two translation columns
    (six columns in total, starting right after the C_global segment).

    NOTE(review): in the `C_local_hfeat_size > 0` branch the output of
    `get_att_heat` is reshaped to `[b, C_local_hfeat_size, 1, h, w]`, so it
    must provide `C_local_hfeat_size` channels -- confirm against
    `get_att_heat`.

    Returns:
        The ToRGB output tensor, wrapped in `tf.identity` with the name
        'images_out'.
    '''
    resolution_log2 = int(np.log2(resolution))  # == 6 for resolution 64
    assert resolution == 2**resolution_log2 and resolution >= 4
    num_layers = resolution_log2 * 2 - 2  # == 10 for resolution 64

    act = nonlinearity

    # Primary inputs.
    assert dlatent_size == D_global_size + C_global_size + sb_C_global_size + \
        C_local_hfeat_size + C_local_heat_size
    n_cat = label_size + D_global_size
    dlatents_withl_in.set_shape([None, label_size + dlatent_size])
    dlatents_withl_in = tf.cast(dlatents_withl_in, dtype)
    n_content = label_size + D_global_size + C_global_size

    # Noise inputs.
    noise_inputs = []
    for layer_idx in range(num_layers - 3):
        res = (layer_idx + 7) // 2  # [3, 4, 4, 5, 5, 6, 6]
        shape = [1, 1, 2**res, 2**res]
        noise_inputs.append(
            tf.get_variable('noise%d' % layer_idx,
                            shape=shape,
                            initializer=tf.initializers.random_normal(),
                            trainable=False))

    # Single convolution layer with all the bells and whistles.
    def noised_conv_layer(x, layer_idx, fmaps, kernel, up=False):
        x = conv2d_layer(x,
                         fmaps=fmaps,
                         up=up,
                         kernel=kernel,
                         resample_kernel=resample_kernel)
        if use_noise:
            if randomize_noise:
                noise = tf.random_normal(
                    [tf.shape(x)[0], 1, x.shape[2], x.shape[3]],
                    dtype=x.dtype)
            else:
                noise = tf.cast(noise_inputs[layer_idx], x.dtype)
            noise_strength = tf.get_variable(
                'noise_strength',
                shape=[],
                initializer=tf.initializers.zeros())
            x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act)

    # Early layers consists of 4x4 constant layer,
    # label+global discrete latents,
    # and global continuous latents.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, 128, 4, 4],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype),
                        [tf.shape(dlatents_withl_in)[0], 1, 1, 1])
        with tf.variable_scope('Upconv'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=128,
                                            kernel=3,
                                            up=True,
                                            resample_kernel=resample_kernel),
                               act=act)

    with tf.variable_scope('8x8'):
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=128, kernel=3), act=act)
        with tf.variable_scope('Label_Dglobal_control'):
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlatents_withl_in[:, :n_cat],
                fmaps=128,
                kernel=3,
                up=False,
                resample_kernel=resample_kernel,
                fused_modconv=fused_modconv),
                               act=act)
        with tf.variable_scope('After_DiscreteGlobal_noised'):
            x = noised_conv_layer(x, layer_idx=0, fmaps=128, kernel=3)
        with tf.variable_scope('Cglobal_control'):
            start_idx = n_cat
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlatents_withl_in[:, start_idx:start_idx + C_global_size],
                fmaps=128,
                kernel=3,
                up=False,
                resample_kernel=resample_kernel,
                fused_modconv=fused_modconv),
                               act=act)
        with tf.variable_scope('After_ContinuousGlobal_noised'):
            x = noised_conv_layer(x,
                                  layer_idx=1,
                                  up=True,
                                  fmaps=128,
                                  kernel=3)

    # Spatial biased layers.
    with tf.variable_scope('16x16'):
        if C_local_hfeat_size > 0:
            # Number of constant feature channels attached to each heatmap.
            hfeat_depth = 32
            with tf.variable_scope('LocalHFeat_C_latents'):
                with tf.variable_scope('ConstFeats'):
                    const_feats = tf.get_variable(
                        'constfeats',
                        shape=[1, C_local_hfeat_size, hfeat_depth, 1, 1],
                        initializer=tf.initializers.random_normal())
                    # Tile over the batch of the latents. The previous code
                    # tiled by const_feats' own leading dim (always 1),
                    # which was a no-op.
                    const_feats = tf.tile(
                        tf.cast(const_feats, dtype),
                        [tf.shape(dlatents_withl_in)[0], 1, 1, 1, 1])
                with tf.variable_scope('ControlAttHeat'):
                    hfeat_start_idx = label_size + D_global_size + C_global_size + \
                        sb_C_global_size
                    att_heat = get_att_heat(x,
                                            nheat=C_local_hfeat_size,
                                            act=act)
                    att_heat = tf.reshape(
                        att_heat,
                        [tf.shape(att_heat)[0], C_local_hfeat_size, 1] +
                        att_heat.shape.as_list()[2:4])
                    # C_local_heat latent [-2, 2] --> [0, 1]
                    hfeat_modifier = (2 + dlatents_withl_in[:, hfeat_start_idx:hfeat_start_idx + \
                                                            C_local_hfeat_size]) / 4.
                    hfeat_modifier = get_conditional_modifier(
                        hfeat_modifier,
                        dlatents_withl_in[:, :n_content],
                        act=act)
                    hfeat_modifier = tf.reshape(
                        hfeat_modifier,
                        [tf.shape(x)[0], C_local_hfeat_size, 1, 1, 1])
                    att_heat = att_heat * hfeat_modifier
                    # [b, n, depth, 1, 1] * [b, n, 1, h, w]
                    #   -> [b, n, depth, h, w]
                    added_feats = const_feats * att_heat
                    # Fold (n, depth) into one channel axis. The previous
                    # code multiplied by att_heat's singleton axis (1)
                    # instead of the feature depth, so the requested shape
                    # did not match the element count.
                    added_feats = tf.reshape(added_feats, [
                        tf.shape(att_heat)[0],
                        C_local_hfeat_size * hfeat_depth
                    ] + att_heat.shape.as_list()[3:5])
                    x = tf.concat([x, added_feats], axis=1)

        with tf.variable_scope('SpatialBiased_C_global'):
            # Rotation layers.
            start_idx = start_idx + C_global_size
            with tf.variable_scope('Rotation'):
                r_matrix = get_r_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 1],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, r_matrix, up=False, fmaps=128, act=act)
            with tf.variable_scope('After_Rotation_noised'):
                x = noised_conv_layer(x, layer_idx=2, fmaps=128, kernel=3)

            # Scaling layers.
            start_idx = start_idx + 1
            with tf.variable_scope('Scaling'):
                s_matrix = get_s_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 1],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, s_matrix, up=False, fmaps=128, act=act)
            with tf.variable_scope('After_Scaling_noised'):
                x = noised_conv_layer(x,
                                      layer_idx=3,
                                      up=True,
                                      fmaps=128,
                                      kernel=3)

    with tf.variable_scope('32x32'):
        with tf.variable_scope('SpatialBiased_C_global'):
            # Shearing layers.
            with tf.variable_scope('Shearing'):
                start_idx = start_idx + 1
                sh_matrix = get_sh_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 2],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, sh_matrix, up=False, fmaps=128, act=act)
            with tf.variable_scope('After_Shearing_noised'):
                x = noised_conv_layer(x, layer_idx=4, fmaps=128, kernel=3)

            # Translation layers.
            with tf.variable_scope('Translation'):
                start_idx = start_idx + 2
                t_matrix = get_t_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 2],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, t_matrix, up=False, fmaps=128, act=act)
            with tf.variable_scope('After_Translation_noised'):
                # Only upsample to 64x64 when the target resolution needs it.
                if resolution_log2 >= 6:
                    x = noised_conv_layer(x,
                                          layer_idx=5,
                                          up=True,
                                          fmaps=128,
                                          kernel=3)
                else:
                    x = noised_conv_layer(x,
                                          layer_idx=5,
                                          fmaps=128,
                                          kernel=3)

    with tf.variable_scope('64x64' if resolution_log2 >= 6 else '32x32'):
        with tf.variable_scope('LocalHeat_C_latents'):
            with tf.variable_scope('ControlAttHeat'):
                heat_start_idx = label_size + D_global_size + C_global_size + \
                    sb_C_global_size + C_local_hfeat_size
                att_heat = get_att_heat(x, nheat=C_local_heat_size, act=act)
                # C_local_heat latent [-2, 2] --> [0, 1]
                heat_modifier = (2 + dlatents_withl_in[:, heat_start_idx:heat_start_idx + \
                                                       C_local_heat_size]) / 4.
                heat_modifier = get_conditional_modifier(
                    heat_modifier,
                    dlatents_withl_in[:, :n_content],
                    act=act)
                heat_modifier = tf.reshape(
                    heat_modifier,
                    [tf.shape(heat_modifier)[0], C_local_heat_size, 1, 1])
                att_heat = att_heat * heat_modifier
                x = tf.concat([x, att_heat], axis=1)
            with tf.variable_scope('After_LocalHeat_noised'):
                x = noised_conv_layer(x, layer_idx=6, fmaps=128, kernel=3)
        y = torgb(x, y, num_channels=num_channels)

    # # Tail layers.
    # for res in range(6, resolution_log2 + 1):
    #     with tf.variable_scope('%dx%d' % (res * 2, res * 2)):
    #         x = apply_bias_act(conv2d_layer(x,
    #                                         fmaps=128,
    #                                         kernel=1,
    #                                         up=True,
    #                                         resample_kernel=resample_kernel),
    #                            act=act)
    #         y = torgb(x, y, num_channels=num_channels)
    images_out = y
    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
def get_att_heat(x, nheat, act):
    '''
    Compute a sigmoid attention heatmap from the features `x`.

    Two 3x3 convolutions are applied: a 128-feature-map layer with
    activation `act`, then a single-feature-map layer with a sigmoid.

    NOTE(review): `nheat` is accepted but never used; the output always has
    one feature map. Confirm whether `fmaps=nheat` was intended.
    '''
    with tf.variable_scope('Conv'):
        hidden = conv2d_layer(x, fmaps=128, kernel=3)
        hidden = apply_bias_act(hidden, act=act)
    with tf.variable_scope('ConvAtt'):
        heat = conv2d_layer(hidden, fmaps=1, kernel=3)
        heat = apply_bias_act(heat, act='sigmoid')
    return heat
def torgb(x, y, num_channels):
    '''
    Project features `x` to `num_channels` maps with a 1x1 convolution and
    add the result to `y` when `y` is not None.
    '''
    with tf.variable_scope('ToRGB'):
        rgb = apply_bias_act(conv2d_layer(x, fmaps=num_channels, kernel=1))
        if y is None:
            return rgb
        return y + rgb
def D_simple_dsp(
        images_in,  # First input: Images [minibatch, channel, height, width].
        labels_in,  # Second input: Labels [minibatch, label_size].
        num_channels=1,  # Number of input color channels. Overridden based on dataset.
        label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
        nonlinearity='relu',  # Activation function: 'relu', 'lrelu', etc.
        dtype='float32',  # Data type to use for activations and outputs.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Small convolutional discriminator for 64x64 inputs.

    Four stride-2 4x4 convolutions (32, 32, 64, 64 feature maps) are
    followed by a 128-unit dense layer and a single-unit score layer.
    `labels_in` is shape-checked and cast but does not otherwise
    contribute to the score.

    Returns:
        The score tensor, wrapped in `tf.identity` with the name
        'scores_out'.
    '''
    act = nonlinearity

    images_in.set_shape([None, num_channels, 64, 64])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    def fully_connected(feats, units):
        # Plain matmul against a freshly created weight matrix.
        weight = tf.cast(get_weight([feats.shape[1].value, units]),
                         feats.dtype)
        return tf.matmul(feats, weight)

    # (variable scope named after the output resolution, feature maps)
    conv_stages = (
        ('32x32', 32),
        ('16x16', 32),
        ('8x8', 64),
        ('4x4', 64),
    )

    feats = images_in
    for scope_name, out_maps in conv_stages:
        with tf.variable_scope(scope_name):
            kernel = get_weight([4, 4, feats.shape[1].value, out_maps])
            feats = tf.nn.conv2d(feats,
                                 tf.cast(kernel, feats.dtype),
                                 data_format='NCHW',
                                 strides=[1, 1, 2, 2],
                                 padding='SAME')
            feats = apply_bias_act(feats, act=act)

    with tf.variable_scope('output'):
        # Flatten everything but the batch axis.
        flat_size = np.prod([d.value for d in feats.shape[1:]])
        feats = tf.reshape(feats, [-1, flat_size])
        feats = apply_bias_act(fully_connected(feats, 128), act=act)

    # Final single-unit score layer (bias only, default activation).
    with tf.variable_scope('Score'):
        feats = apply_bias_act(fully_connected(feats, 1))
    scores_out = feats

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    return tf.identity(scores_out, name='scores_out')
def G_synthesis_spatial_biased_dsp(
        dlatents_in,  # Input: Disentangled latents (W) [minibatch, dlatent_size].
        dlatent_size=7,  # Disentangled latent (W) dimensionality. Including discrete info, rotation, scaling, and xy translation.
        D_global_size=3,  # Discrete latents.
        sb_C_global_size=4,  # Continuous latents.
        label_size=0,  # Label dimensionality, 0 if no labels.
        num_channels=1,  # Number of output color channels.
        resolution=64,  # Output resolution.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Synthesis network applying a chain of gated spatial transforms.

    The first `n_cat = label_size + D_global_size` latent columns modulate
    an early convolution and condition every transform gate. The columns
    after them are consumed in order: one for rotation, one for scaling,
    two for shearing, and all remaining columns for translation.

    Returns:
        The ToRGB output tensor, wrapped in `tf.identity` with the name
        'images_out'.
    '''
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))),
                       fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    # Primary inputs.
    assert dlatent_size == D_global_size + sb_C_global_size
    n_cat = label_size + D_global_size
    dlatents_in.set_shape([None, label_size + dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)

    def sigmoid_gate(cond_latent, suffix=''):
        # Two dense layers ending in a sigmoid; variable scopes are
        # 'Condition0<suffix>' and 'Condition1<suffix>'.
        with tf.variable_scope('Condition0' + suffix):
            hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1' + suffix):
            return apply_bias_act(dense_layer(hidden, fmaps=1),
                                  act='sigmoid')

    def xy_gate(cond_latent):
        # Independent gates for the two axes, concatenated on axis 1.
        gate_x = sigmoid_gate(cond_latent, 'x')
        gate_y = sigmoid_gate(cond_latent, 'y')
        return tf.concat([gate_x, gate_y], axis=1)

    # Return rotation matrix
    def get_r_matrix(r_latents, cond_latent):
        # r_latents: [-2., 2.] -> [0, 2*pi]
        gate = sigmoid_gate(cond_latent)
        angle = (r_latents + 2) / 4. * 2. * np.pi
        angle = angle * gate
        entries = [
            tf.math.cos(angle),
            -tf.math.sin(angle),
            tf.zeros_like(angle),
            tf.math.sin(angle),
            tf.math.cos(angle),
            tf.zeros_like(angle),
        ]
        return tf.concat(entries, axis=1)

    # Return scaling matrix
    def get_s_matrix(s_latents, cond_latent):
        # s_latents: [-2., 2.] -> [1, 3], then multiplied by the gate.
        gate = sigmoid_gate(cond_latent)
        factor = (s_latents / 2. + 2.) * gate
        entries = [
            factor,
            tf.zeros_like(factor),
            tf.zeros_like(factor),
            tf.zeros_like(factor),
            factor,
            tf.zeros_like(factor),
        ]
        return tf.concat(entries, axis=1)

    # Return shear matrix
    def get_sh_matrix(sh_latents, cond_latent):
        # sh_latents[:, 0]: [-2., 2.] -> [-1., 1.]
        # sh_latents[:, 1]: [-2., 2.] -> [-1., 1.]
        xy_shear = sh_latents / 2. * xy_gate(cond_latent)
        shear_0, shear_1 = xy_shear[:, 0:1], xy_shear[:, 1:]
        entries = [
            tf.ones_like(shear_0),
            shear_0,
            tf.zeros_like(shear_0),
            shear_1,
            tf.ones_like(shear_1),
            tf.zeros_like(shear_1),
        ]
        return tf.concat(entries, axis=1)

    # Return translation matrix
    def get_t_matrix(t_latents, cond_latent):
        # t_latents[:, 0]: [-2., 2.] -> [-0.5, 0.5]
        # t_latents[:, 1]: [-2., 2.] -> [-0.5, 0.5]
        xy_shift = t_latents / 4. * xy_gate(cond_latent)
        shift_0, shift_1 = xy_shift[:, 0:1], xy_shift[:, 1:]
        entries = [
            tf.ones_like(shift_0),
            tf.zeros_like(shift_0),
            shift_0,
            tf.zeros_like(shift_1),
            tf.ones_like(shift_1),
            shift_1,
        ]
        return tf.concat(entries, axis=1)

    # Apply spatial transform
    def apply_st(x, st_matrix, idx, up=True):  # idx: 2, 3, 4
        with tf.variable_scope('Transform'):
            nhwc = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
            nhwc = transformer(nhwc,
                               st_matrix,
                               out_dims=nhwc.shape.as_list()[1:3])
            x = tf.transpose(nhwc, [0, 3, 1, 2])  # NHWC -> NCHW
        with tf.variable_scope('Upconv'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(idx),
                                            kernel=3,
                                            up=up,
                                            resample_kernel=resample_kernel),
                               act=act)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(idx), kernel=3),
                               act=act)
        return x

    def torgb(x, y):
        with tf.variable_scope('ToRGB'):
            t = apply_bias_act(conv2d_layer(x, fmaps=num_channels, kernel=1))
            return t if y is None else y + t

    # Latent slices used below.
    cond_latents = dlatents_in[:, :n_cat]

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, nf(1), 4, 4],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype),
                        [tf.shape(dlatents_in)[0], 1, 1, 1])
        with tf.variable_scope('Upconv8x8'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(1),
                                            kernel=3,
                                            up=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3),
                               act=act)
        with tf.variable_scope('ModulatedConv'):
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlatents_in[:, :n_cat],
                fmaps=nf(2),
                kernel=3,
                up=False,
                resample_kernel=resample_kernel,
                fused_modconv=fused_modconv),
                               act=act)
        with tf.variable_scope('Conv1'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(2), kernel=3),
                               act=act)

    # Rotation layers.
    with tf.variable_scope('16x16'):
        r_matrix = get_r_matrix(dlatents_in[:, n_cat:n_cat + 1],
                                cond_latents)
        x = apply_st(x, r_matrix, 2)

    # Scaling layers.
    with tf.variable_scope('32x32'):
        s_matrix = get_s_matrix(dlatents_in[:, n_cat + 1:n_cat + 2],
                                cond_latents)
        x = apply_st(x, s_matrix, 3)

    # Shearing layers (no upsampling).
    with tf.variable_scope('32x32_Shear'):
        sh_matrix = get_sh_matrix(dlatents_in[:, n_cat + 2:n_cat + 4],
                                  cond_latents)
        x = apply_st(x, sh_matrix, 3, up=False)

    # Translation layers.
    with tf.variable_scope('64x64'):
        t_matrix = get_t_matrix(dlatents_in[:, n_cat + 4:], cond_latents)
        x = apply_st(x, t_matrix, 4)
        y = torgb(x, y)

    images_out = y
    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
def vpex_net(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: hidden features from z + delta(z) [minibatch, channel, height, width].
        latents,  # Ground-truth latent code for fake1.
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 connected.
        return_atts=False,  # If return I_atts.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Head producing one output per latent dimension via per-latent attention.

    `fake1` and `fake2` are combined according to `connect_mode`
    ('diff': `fake1 - fake2`; 'concat': channel concatenation) and
    downsampled to 8x8 features. For every latent dimension a softmax
    attention map over the 8x8 positions pools those features; the pooled
    vectors are broadcast back to 8x8, processed by the remaining blocks
    independently per latent dimension, and reduced to one scalar each.

    Returns:
        `x` of shape [-1, dlatent_size], or `(x, att_map)` when
        `return_atts` is true, where `att_map` is resized and reshaped to
        [-1, dlatent_size, 1, resolution, resolution].

    Raises:
        ValueError: if `connect_mode` is neither 'diff' nor 'concat', or if
            `resolution` is below 16 (no 8x8 feature stage would exist).
    '''
    # Validate first: an unknown mode used to leave `images_in` unbound and
    # only failed later with an unhelpful UnboundLocalError.
    if connect_mode not in ('diff', 'concat'):
        raise ValueError(
            "connect_mode must be 'diff' or 'concat', got %r" %
            (connect_mode,))

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    # The attention stage needs at least one block before the 8x8 stage;
    # otherwise `x` would still be None when attention is applied.
    if resolution_log2 < 4:
        raise ValueError('vpex_net requires resolution >= 16, got %r' %
                         (resolution,))

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))),
                       fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    latents.set_shape([None, dlatent_size])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    latents = tf.cast(latents, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    else:  # 'concat'
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                # Average the two branches with a 1/sqrt(2) factor.
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Attention map for each latent dimension.
    def get_att_map(latents, x):
        # `latents` only supplies the batch size for tiling the learned
        # per-latent attention features.
        with tf.variable_scope('create_att_feats'):
            x_ch, x_h, x_w = x.get_shape().as_list()[1:]
            att_feats = tf.get_variable(
                'att_feats',
                shape=[1, dlatent_size, x_ch, x_h, x_w],
                initializer=tf.initializers.random_normal())
            att_feats = tf.tile(tf.cast(att_feats, dtype),
                                [tf.shape(latents)[0], 1, 1, 1, 1])
        # Pair every latent's learned features with a copy of x.
        x = tf.reshape(x, [-1, 1, x_ch, x_h, x_w])
        x = tf.tile(x, [1, dlatent_size, 1, 1, 1])
        att_map = tf.concat([x, att_feats], axis=2)
        att_map = tf.reshape(att_map, [-1, 2 * x_ch, x_h, x_w])
        with tf.variable_scope('att_conv_3x3'):
            att_map = apply_bias_act(conv2d_layer(att_map,
                                                  fmaps=2 * x_ch,
                                                  kernel=3),
                                     act=act)
        with tf.variable_scope('att_conv_1x1'):
            att_map = apply_bias_act(
                conv2d_layer(att_map, fmaps=1, kernel=1))
        # Normalise over the flattened spatial positions.
        att_map = tf.reshape(att_map, [-1, dlatent_size, 1, x_h * x_w])
        return tf.nn.softmax(att_map, axis=-1)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 3, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Duplicate for each att.
    with tf.variable_scope('apply_att'):
        att_map = get_att_map(latents, x)
        x_ch, x_h, x_w = x.get_shape().as_list()[1:]
        assert x_h == 8
        x = tf.reshape(x, [-1, 1, x_ch, x_h * x_w])  # [b, 1, ch, h*w]
        x = att_map * x  # [b, dlatent, ch, h*w]
        x = tf.reduce_sum(x, axis=-1)  # [b, dlatent, ch]
        x = tf.reshape(x, [-1, x_ch, 1, 1])  # [b * dlatent, ch, 1, 1]
        with tf.variable_scope('after_att_conv_1x1'):
            x = apply_bias_act(conv2d_layer(x, fmaps=x_ch, kernel=1))
        x = tf.reshape(x, [-1, dlatent_size, x_ch, 1])  # [b, dlatent, ch, 1]
        x = tf.tile(x, [1, 1, 1, x_h * x_w])  # [b, dlatent, ch, h * w]
        x = tf.reshape(x, [-1, x_ch, x_h, x_w])

        # Replicate the image pyramid input once per latent dimension.
        y_ch, y_h, y_w = y.get_shape().as_list()[1:]
        y = y[:, tf.newaxis, ...]
        y = tf.tile(y, [1, dlatent_size, 1, 1, 1])
        y = tf.reshape(y, [-1, y_ch, y_h, y_w])

    for res in range(3, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3),
                               act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(dense_layer(x, fmaps=1))

    # One scalar per (sample, latent) pair -> [b, dlatent_size].
    with tf.variable_scope('Final_reshape_x'):
        x = tf.reshape(x, [-1, dlatent_size])

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    if return_atts:
        with tf.variable_scope('Reshape_atts'):
            # x_h == x_w == 8 here (asserted above).
            att_map = tf.reshape(att_map, [-1, x_h, x_w, 1])
            att_map = tf.image.resize(att_map, size=(resolution, resolution))
            att_map = tf.reshape(
                att_map, [-1, dlatent_size, 1, resolution, resolution])
        return x, att_map
    else:
        return x
def vid_naive_cluster_head(
        fake_in,  # Video batch [minibatch, channel, n_frames, height, width].
        num_channels=3,  # Channel count enforced on fake_in via set_shape.
        resolution=1024,  # Spatial size of fake_in; must be a power of two >= 4.
        dlatent_size=10,  # Total number of latents.
        D_global_size=0,  # Subtracted from dlatent_size to size the output layer.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Lower clip for the number of feature maps in any layer.
        fmap_max=512,  # Upper clip for the number of feature maps in any layer.
        architecture='resnet',  # One of 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation name handed to the bias/act helpers.
        mbstd_group_size=4,  # Not referenced by this head.
        mbstd_num_features=1,  # Not referenced by this head.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Not referenced by this head.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Build a 3D-convolutional head over a video batch.

    The video is downsampled stage by stage with 3D convolutions until the
    4x4 stage, averaged over axis 2 (the frame axis of the declared input
    layout) and fed through two dense layers. The last dense layer has
    `dlatent_size - D_global_size` units.

    Only 'skip' changes the wiring here (per-stage FromRGB plus a downsampled
    copy of the input); 'orig' and 'resnet' build the same graph.

    Returns the output of the final dense layer, asserted to be of `dtype`.
    '''
    # NOTE(review): scope nesting was reconstructed from a whitespace-collapsed
    # source; 'Dense0' is assumed to live inside 'I_4x4' -- confirm.
    log2_res = int(np.log2(resolution))
    assert resolution == 2**log2_res and resolution >= 4
    assert architecture in ['orig', 'skip', 'resnet']
    use_skip = architecture == 'skip'

    def fmaps_at(stage):
        raw = int(fmap_base / (2.0**(stage * fmap_decay)))
        return np.clip(raw, fmap_min, fmap_max)

    def conv3d_act(inputs, fmaps, kernel, **conv_kwargs):
        # 3D convolution followed by bias + activation.
        out = conv3d_layer(inputs, fmaps=fmaps, kernel=kernel, **conv_kwargs)
        return apply_bias_act_3d(out, act=nonlinearity)

    def from_rgb(feats, frames, res):  # res = 2..log2_res
        with tf.variable_scope('FromRGB'):
            projected = conv3d_act(frames, fmaps_at(res - 1), 1)
            if feats is None:
                return projected
            return feats + projected

    def down_block(feats, res):  # res = 2..log2_res
        with tf.variable_scope('Conv3D_0'):
            feats = conv3d_act(feats, fmaps_at(res - 1), 3)
        with tf.variable_scope('Conv1_down'):
            feats = conv3d_act(feats, fmaps_at(res - 2), 3, down=True)
        return feats

    fake_in.set_shape([None, num_channels, None, resolution, resolution])
    video = tf.cast(fake_in, dtype)

    # Main layers: one downsampling stage per resolution above 4x4.
    feats, frames = None, video
    for res in range(log2_res, 2, -1):
        side = 2**res
        with tf.variable_scope('I_%dx%d' % (side, side)):
            if use_skip or res == log2_res:
                feats = from_rgb(feats, frames, res)
            feats = down_block(feats, res)
            if use_skip:
                frames = downsample_3d(frames)

    # Final layers.
    with tf.variable_scope('I_4x4'):
        if use_skip:
            feats = from_rgb(feats, frames, 2)
        with tf.variable_scope('Conv'):
            feats = conv3d_act(feats, fmaps_at(1), 3)
        with tf.variable_scope('Global_temporal_pool'):
            # Collapse axis 2 so the 2D dense layers can take over.
            feats = tf.reduce_mean(feats, axis=2)
        with tf.variable_scope('Dense0'):
            feats = apply_bias_act(dense_layer(feats, fmaps=fmaps_at(0)),
                                   act=nonlinearity)

    # Output.
    n_out = dlatent_size - D_global_size
    with tf.variable_scope('I_Output'):
        with tf.variable_scope('Dense_VC'):
            feats = apply_bias_act(dense_layer(feats, fmaps=n_out))

    assert feats.dtype == tf.as_dtype(dtype)
    return feats
def vid_head(
        fake_in,  # Video batch [minibatch, channel, n_frames, height, width].
        C_delta_idxes,  # Index encoding of the varied latent [minibatch, dlatent_size].
        num_channels=3,  # Channel count enforced on fake_in via set_shape.
        resolution=1024,  # Spatial size of fake_in; must be a power of two >= 4.
        dlatent_size=10,  # Width enforced on C_delta_idxes via set_shape.
        D_global_size=0,  # Not referenced by this head.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Lower clip for the number of feature maps in any layer.
        fmap_max=512,  # Upper clip for the number of feature maps in any layer.
        architecture='resnet',  # One of 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation name handed to the bias/act helpers.
        mbstd_group_size=4,  # Not referenced by this head.
        mbstd_num_features=1,  # Not referenced by this head.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Not referenced by this head.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Build a 3D-convolutional head that scores a video jointly with the
    encoding of the varied latent.

    The video goes through the same downsampling 3D-conv stack as the other
    video heads, is averaged over axis 2 and reduced to a 64-unit dense
    feature. `C_delta_idxes` is embedded with a 32-unit dense layer, the two
    are concatenated along axis 1, and a 128-unit dense layer followed by a
    1-unit dense layer produces the result (the original comment marks this
    part as "For MINE").

    Only 'skip' changes the wiring; 'orig' and 'resnet' build the same graph.
    '''
    # NOTE(review): scope nesting was reconstructed from a whitespace-collapsed
    # source; 'Dense0' is assumed to live inside 'I_4x4' -- confirm.
    log2_res = int(np.log2(resolution))
    assert resolution == 2**log2_res and resolution >= 4
    assert architecture in ['orig', 'skip', 'resnet']
    use_skip = architecture == 'skip'

    def fmaps_at(stage):
        raw = int(fmap_base / (2.0**(stage * fmap_decay)))
        return np.clip(raw, fmap_min, fmap_max)

    def conv3d_act(inputs, fmaps, kernel, **conv_kwargs):
        # 3D convolution followed by bias + activation.
        out = conv3d_layer(inputs, fmaps=fmaps, kernel=kernel, **conv_kwargs)
        return apply_bias_act_3d(out, act=nonlinearity)

    def from_rgb(feats, frames, res):  # res = 2..log2_res
        with tf.variable_scope('FromRGB'):
            projected = conv3d_act(frames, fmaps_at(res - 1), 1)
            if feats is None:
                return projected
            return feats + projected

    def down_block(feats, res):  # res = 2..log2_res
        with tf.variable_scope('Conv3D_0'):
            feats = conv3d_act(feats, fmaps_at(res - 1), 3)
        with tf.variable_scope('Conv1_down'):
            feats = conv3d_act(feats, fmaps_at(res - 2), 3, down=True)
        return feats

    fake_in.set_shape([None, num_channels, None, resolution, resolution])
    video = tf.cast(fake_in, dtype)
    C_delta_idxes.set_shape([None, dlatent_size])
    C_delta_idxes = tf.cast(C_delta_idxes, dtype)

    # Main layers: one downsampling stage per resolution above 4x4.
    x, frames = None, video
    for res in range(log2_res, 2, -1):
        side = 2**res
        with tf.variable_scope('I_%dx%d' % (side, side)):
            if use_skip or res == log2_res:
                x = from_rgb(x, frames, res)
            x = down_block(x, res)
            if use_skip:
                frames = downsample_3d(frames)

    # Final layers.
    with tf.variable_scope('I_4x4'):
        if use_skip:
            x = from_rgb(x, frames, 2)
        with tf.variable_scope('Conv'):
            x = conv3d_act(x, fmaps_at(1), 3)
        with tf.variable_scope('Global_temporal_pool'):
            x = tf.reduce_mean(x, axis=2)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=64), act=nonlinearity)

    # Graph-construction-time shape dump (left in place from the original).
    print('before from C_delta_idxes, x.get_shape:', x.get_shape().as_list())
    print('before from C_delta_idxes, x.shape:', x.shape)
    print('before from C_delta_idxes, C_delta_idxes.shape:',
          C_delta_idxes.shape)

    # Embed the latent-index encoding and append it to the video feature.
    with tf.variable_scope('I_From_C_Delta_Idx'):
        delta_embedding = apply_bias_act(
            dense_layer(C_delta_idxes, fmaps=32), act=nonlinearity)
        x = tf.concat([x, delta_embedding], axis=1)

    # For MINE.
    with tf.variable_scope('I_Output'):
        with tf.variable_scope('Dense_T_0'):
            x = apply_bias_act(dense_layer(x, fmaps=128), act=nonlinearity)
        with tf.variable_scope('Dense_T_1'):
            x = apply_bias_act(dense_layer(x, fmaps=1))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
def G_synthesis_simple_dsp(
        dlatents_in,  # Input: Disentangled latents (W) [minibatch, label_size + dlatent_size].
        dlatent_size=7,  # Disentangled latent (W) dimensionality. Including discrete info, rotation, scaling, and xy translation.
        D_global_size=3,  # Discrete latents.
        sb_C_global_size=4,  # Continuous latents.
        label_size=0,  # Label dimensionality, 0 if no labels.
        num_channels=1,  # Number of output color channels.
        nonlinearity='relu',  # Activation function: 'relu', 'lrelu', etc.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
        fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Build a small 64x64 synthesis network with explicit spatial transforms.

    Latent layout along axis 1 of `dlatents_in`:
      [0, n_cat)            conditioning part, n_cat = label_size + D_global_size
      [n_cat, n_cat + 1)    rotation latent
      [n_cat + 1, n_cat + 2) scaling latent
      [n_cat + 2, end)      translation latents (x, y)

    Pipeline: learned 4x4 constant -> 3x3 conv -> modulated conv (up to 8x8,
    modulated by the conditioning part) -> transposed conv to 16x16 ->
    rotation, scaling and translation applied through `transformer`, each
    gated by a sigmoid MLP of the conditioning part -> transposed convs to
    32x32 and 64x64.

    Changes relative to the previous revision (all no-ops with the defaults):
      * transposed-conv weights are cast to the activation dtype, matching
        what the 4x4 conv already did, so a non-float32 `dtype` no longer
        mixes dtypes inside `tf.nn.conv2d_transpose`;
      * the last layer emits `num_channels` maps instead of a hard-coded 1.

    Returns:
      Tensor named 'images_out', [minibatch, num_channels, 64, 64].
    '''
    # NOTE(review): scope nesting was reconstructed from a whitespace-collapsed
    # source; '4x4Conv' and later scopes are assumed to be siblings of '4x4'
    # rather than nested inside it -- confirm against existing checkpoints.
    act = nonlinearity

    # Primary inputs.
    assert dlatent_size == D_global_size + sb_C_global_size
    n_cat = label_size + D_global_size
    dlatents_in.set_shape([None, label_size + dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)
    batch_size = tf.shape(dlatents_in)[0]
    cond_latents = dlatents_in[:, :n_cat]

    def condition_gate(cond_latent):
        # Two dense layers ending in a sigmoid; scales a transform's strength.
        with tf.variable_scope('Condition0'):
            cond = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
        with tf.variable_scope('Condition1'):
            cond = apply_bias_act(dense_layer(cond, fmaps=1), act='sigmoid')
        return cond

    def affine_rows(t00, t01, t02, t10, t11, t12):
        # Flattened 2x3 affine matrix, one row of six values per sample.
        return tf.concat([t00, t01, t02, t10, t11, t12], axis=1)

    # Return rotation matrix
    def get_r_matrix(r_latents, cond_latent):
        # r_latents: [-2., 2.] -> [0, 2*pi], then scaled by the gate.
        cond = condition_gate(cond_latent)
        rad = (r_latents + 2) / 4. * 2. * np.pi
        rad = rad * cond
        cos, sin = tf.math.cos(rad), tf.math.sin(rad)
        zeros = tf.zeros_like(rad)
        return affine_rows(cos, -sin, zeros, sin, cos, zeros)

    # Return scaling matrix
    def get_s_matrix(s_latents, cond_latent):
        # s_latents: [-2., 2.] -> [1, 3], then scaled by the gate.
        cond = condition_gate(cond_latent)
        scale = (s_latents / 2. + 2.) * cond
        zeros = tf.zeros_like(scale)
        return affine_rows(scale, zeros, zeros, zeros, scale, zeros)

    # Return translation matrix
    def get_t_matrix(t_latents, cond_latent):
        # t_latents[:, 0]: [-2., 2.] -> [-0.5, 0.5]
        # t_latents[:, 1]: [-2., 2.] -> [-0.5, 0.5]
        cond = condition_gate(cond_latent)
        xy_shift = t_latents / 4. * cond
        shift_x, shift_y = xy_shift[:, 0:1], xy_shift[:, 1:]
        return affine_rows(tf.ones_like(shift_x), tf.zeros_like(shift_x),
                           shift_x, tf.zeros_like(shift_y),
                           tf.ones_like(shift_y), shift_y)

    # Apply spatial transform
    def apply_st(x, st_matrix):
        with tf.variable_scope('Transform'):
            x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
            x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
            x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
        return x

    def upsample_conv(x, fmaps, out_res):
        # Stride-2 4x4 transposed convolution in NCHW.
        w = get_weight([4, 4, x.shape[1].value, fmaps])
        # conv2d_transpose wants [kh, kw, out_channels, in_channels].
        w = tf.transpose(w, [0, 1, 3, 2])
        return tf.nn.conv2d_transpose(
            x,
            tf.cast(w, x.dtype),
            output_shape=[batch_size, fmaps, out_res, out_res],
            strides=[1, 1, 2, 2],
            padding='SAME',
            data_format='NCHW')

    # Early layers.
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, 64, 4, 4],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [batch_size, 1, 1, 1])

    with tf.variable_scope('4x4Conv'):
        w = get_weight([3, 3, x.shape[1].value, 64])
        x = tf.nn.conv2d(x,
                         tf.cast(w, x.dtype),
                         data_format='NCHW',
                         strides=[1, 1, 1, 1],
                         padding='SAME')
        x = apply_bias_act(x, act=act)

    with tf.variable_scope('8x8ModulatedConv'):
        x = apply_bias_act(modulated_conv2d_layer(
            x,
            cond_latents,
            fmaps=64,
            kernel=3,
            up=True,
            resample_kernel=resample_kernel,
            fused_modconv=fused_modconv),
                           act=act)

    with tf.variable_scope('16x16'):
        x = apply_bias_act(upsample_conv(x, 32, 16), act=act)

    # Explicit spatial transforms, each driven by its own latent slice.
    with tf.variable_scope('rotation'):
        r_matrix = get_r_matrix(dlatents_in[:, n_cat:n_cat + 1], cond_latents)
        x = apply_st(x, r_matrix)

    with tf.variable_scope('scale'):
        s_matrix = get_s_matrix(dlatents_in[:, n_cat + 1:n_cat + 2],
                                cond_latents)
        x = apply_st(x, s_matrix)

    with tf.variable_scope('translation'):
        t_matrix = get_t_matrix(dlatents_in[:, n_cat + 2:], cond_latents)
        x = apply_st(x, t_matrix)

    with tf.variable_scope('32x32'):
        x = apply_bias_act(upsample_conv(x, 32, 32), act=act)

    with tf.variable_scope('64x64'):
        # Bias only on the image output; apply_bias_act is called without act.
        x = apply_bias_act(upsample_conv(x, num_channels, 64))

    images_out = x
    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
def D_info_gan_stylegan2(
        images_in,  # First input: Images [minibatch, channel, height, width].
        labels_in,  # Second input: Labels [minibatch, label_size].
        num_channels=3,  # Channel count enforced on images_in via set_shape.
        resolution=1024,  # Spatial size of images_in; must be a power of two >= 4.
        label_size=0,  # Dimensionality of the labels, 0 if no labels.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Lower clip for the number of feature maps in any layer.
        fmap_max=512,  # Upper clip for the number of feature maps in any layer.
        architecture='resnet',  # One of 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation name handed to apply_bias_act.
        mbstd_group_size=4,  # Group size for the minibatch stddev layer; <= 1 disables it.
        mbstd_num_features=1,  # Number of features for the minibatch stddev layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Low-pass filter used when downsampling. None = no filtering.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Build a StyleGAN2-style discriminator that also exposes its last hidden
    feature.

    Returns:
      (scores_out, hidden): the tensor named 'scores_out' and the 'Dense0'
      activation named 'hidden'. When `labels_in` has columns, the output
      layer has one unit per label and the score is the label-weighted sum
      (kept as a single column).
    '''
    # NOTE(review): scope nesting was reconstructed from a whitespace-collapsed
    # source; 'Dense0' is assumed to live inside '4x4' -- confirm.
    log2_res = int(np.log2(resolution))
    assert resolution == 2**log2_res and resolution >= 4
    assert architecture in ['orig', 'skip', 'resnet']
    use_skip = architecture == 'skip'
    use_resnet = architecture == 'resnet'

    def fmaps_at(stage):
        raw = int(fmap_base / (2.0**(stage * fmap_decay)))
        return np.clip(raw, fmap_min, fmap_max)

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Building blocks for main layers.
    def from_rgb(feats, image, res):  # res = 2..log2_res
        with tf.variable_scope('FromRGB'):
            projected = apply_bias_act(
                conv2d_layer(image, fmaps=fmaps_at(res - 1), kernel=1),
                act=nonlinearity)
            if feats is None:
                return projected
            return feats + projected

    def down_block(feats, res):  # res = 2..log2_res
        shortcut = feats
        with tf.variable_scope('Conv0'):
            feats = apply_bias_act(
                conv2d_layer(feats, fmaps=fmaps_at(res - 1), kernel=3),
                act=nonlinearity)
        with tf.variable_scope('Conv1_down'):
            feats = apply_bias_act(
                conv2d_layer(feats,
                             fmaps=fmaps_at(res - 2),
                             kernel=3,
                             down=True,
                             resample_kernel=resample_kernel),
                act=nonlinearity)
        if use_resnet:
            with tf.variable_scope('Skip'):
                shortcut = conv2d_layer(shortcut,
                                        fmaps=fmaps_at(res - 2),
                                        kernel=1,
                                        down=True,
                                        resample_kernel=resample_kernel)
                # Rescale the sum of the two branches by 1/sqrt(2).
                feats = (feats + shortcut) * (1 / np.sqrt(2))
        return feats

    def shrink_image(image):
        with tf.variable_scope('Downsample'):
            return downsample_2d(image, k=resample_kernel)

    # Main layers.
    x, y = None, images_in
    for res in range(log2_res, 2, -1):
        side = 2**res
        with tf.variable_scope('%dx%d' % (side, side)):
            if use_skip or res == log2_res:
                x = from_rgb(x, y, res)
            x = down_block(x, res)
            if use_skip:
                y = shrink_image(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if use_skip:
            x = from_rgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=fmaps_at(1), kernel=3),
                               act=nonlinearity)
        with tf.variable_scope('Dense0'):
            hidden = apply_bias_act(dense_layer(x, fmaps=fmaps_at(0)),
                                    act=nonlinearity)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    n_labels = labels_in.shape[1]
    with tf.variable_scope('Output'):
        x = apply_bias_act(dense_layer(hidden, fmaps=max(n_labels, 1)))
        if n_labels > 0:
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    hidden = tf.identity(hidden, name='hidden')
    return scores_out, hidden
def build_C_spgroup_stn_layers(x,
                               name,
                               n_latents,
                               start_idx,
                               scope_idx,
                               dlatents_in,
                               act,
                               fused_modconv,
                               fmaps=128,
                               return_atts=False,
                               resolution=128,
                               **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention
    followed by a learned spatial transform of each attention map.
    Support square images only.

    For every latent in dlatents_in[:, start_idx:start_idx + n_latents] a
    soft box-shaped attention map is predicted from the spatial mean of x,
    warped by a predicted 6-parameter transform, and used to blend x with a
    style-modulated copy of instance-normalised x.

    `act`, `fused_modconv`, `fmaps` and `kwargs` are accepted but not
    referenced in this function.

    Returns x, or (x, atts) when return_atts is set, with atts resized to
    [b, n_latents, 1, resolution, resolution].
    '''
    # NOTE(review): scope nesting was reconstructed from a whitespace-collapsed
    # source; 'trans_matrix' is assumed to be a sibling of 'Att_spatial' --
    # confirm against existing checkpoints.
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            edge_logits = dense_layer(pooled, fmaps=n_latents * 4 * x_wh)
            edge_logits = tf.reshape(
                edge_logits, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            # Cumulative softmax gives a monotone 0 -> 1 ramp along each axis.
            edge_ramps = tf.cumsum(tf.nn.softmax(edge_logits, axis=-1),
                                   axis=-1)
            h_rise, h_fall, w_rise, w_fall = tf.split(edge_ramps, 4, axis=2)
            # Flip the second ramp of each pair so it falls from 1 to 0.
            h_fall = 1 - h_fall  # [b, n_latents, 1, x_wh]
            w_fall = 1 - w_fall  # [b, n_latents, 1, x_wh]

            col_shape = [-1, n_latents, 1, x_wh, 1]
            h_rise = tf.reshape(h_rise, col_shape)
            h_fall = tf.reshape(h_fall, col_shape)
            att_h = h_rise * h_fall  # [b, n_latents, 1, x_wh, 1]

            row_shape = [-1, n_latents, 1, 1, x_wh]
            w_rise = tf.reshape(w_rise, row_shape)
            w_fall = tf.reshape(w_fall, row_shape)
            att_w = w_rise * w_fall  # [b, n_latents, 1, 1, x_wh]

            # Outer product of the two 1D profiles.
            atts = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]

        with tf.variable_scope('trans_matrix'):
            theta = apply_bias_act(dense_layer(pooled, fmaps=n_latents * 6))
            theta = tf.reshape(theta, [-1, 6])  # [b*n_latents, 6]
            flat_atts = tf.reshape(
                atts, [-1, x_wh, x_wh, 1])  # [b*n_latents, x_wh, x_wh, 1]
            flat_atts = transformer(flat_atts, theta)
            atts = tf.reshape(flat_atts, [-1, n_latents, 1, x_wh, x_wh])

        with tf.variable_scope('Att_apply'):
            group_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for idx in range(n_latents):
                with tf.variable_scope('style_mod-' + str(idx)):
                    styled = style_mod(x_norm, group_latents[:, idx:idx + 1])
                    # Attention-weighted blend of current and styled features.
                    x = x * (1 - atts[:, idx]) + styled * atts[:, idx]

        if not return_atts:
            return x

        with tf.variable_scope('Reshape_output'):
            atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts,
                              [-1, n_latents, 1, resolution, resolution])
        return x, atts
def fromrgb(x, y, res):  # res = 2..resolution_log2
    '''
    Project image `y` to nf(res - 1) feature maps with a 1x1 conv and add the
    result to `x`; when `x` is None the projection alone is returned.
    Relies on `nf` and `act` from the surrounding scope.
    '''
    with tf.variable_scope('FromRGB'):
        projected = conv2d_layer(y, fmaps=nf(res - 1), kernel=1)
        projected = apply_bias_act(projected, act=act)
        if x is None:
            return projected
        return x + projected
def build_zpos_to_mat_layer(x,
                            name,
                            n_layers,
                            scope_idx,
                            is_training,
                            wh,
                            feat_cnn_dim,
                            resolution=128,
                            trans_dim=512,
                            dff=512,
                            trans_rate=0.1,
                            ncut_maxval=5,
                            post_trans_mat=16,
                            **kwargs):
    '''
    Build zpos_to_mat forwarding transformer to extract features per z.

    Each latent is offset by a learned positional constant, encoded by
    trans_encoder_basic under a split mask, and projected to a flattened
    post_trans_mat x post_trans_mat matrix. The matrices at the group
    boundaries (split_idx - 1) are multiplied together, each new one on the
    left, and the product is mapped by a dense layer to a
    [b, feat_cnn_dim, wh, wh] feature.

    `is_training` is evaluated with a Python `if`: when true a random number
    of cuts (below ncut_maxval) drives create_split_mask; otherwise an
    all-ones mask with a single group is used.

    Returns (feat, atts, mat_agg_final_out): the feature map, the split mask
    resized to resolution x resolution and tiled n_lat times along axis 1,
    and the flattened aggregated matrix [b, post_trans_mat ** 2].
    '''
    # NOTE(review): scope nesting was reconstructed from a whitespace-collapsed
    # source; the encoder and its projection are assumed to live inside
    # 'MaskEncoding' -- confirm against existing checkpoints.
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('PosConstant'):
            n_lat = x.get_shape().as_list()[-1]
            pos = tf.get_variable('const',
                                  shape=[1, n_lat, trans_dim],
                                  initializer=tf.initializers.random_normal())
            pos = tf.tile(tf.cast(pos, x.dtype), [tf.shape(x)[0], 1, 1])
            # Broadcast each scalar latent over its trans_dim-wide position.
            zpos = pos + x[:, :, np.newaxis]

        with tf.variable_scope('MaskEncoding'):
            if is_training:
                ncut = tf.random.uniform(shape=[],
                                         maxval=ncut_maxval,
                                         dtype=tf.int32)
                split_masks_mul, split_idx = create_split_mask(n_lat, ncut)
            else:
                split_masks_mul = tf.ones(shape=[n_lat, n_lat],
                                          dtype=tf.float32)
                split_idx = tf.constant([n_lat])
            # Always close the last group at n_lat, without duplicates.
            split_idx = tf.concat([split_idx, [n_lat]], axis=0)
            split_idx, _ = tf.unique(split_idx)

            encoded = trans_encoder_basic(zpos,
                                          is_training,
                                          split_masks_mul,
                                          n_layers,
                                          trans_dim,
                                          num_heads=8,
                                          dff=dff,
                                          rate=trans_rate)
            mask_logits = get_return_v(encoded, 1)  # (b, n_lat, d_model)
            mask_groups = dense_layer_last_dim(
                mask_logits, post_trans_mat * post_trans_mat)

        with tf.variable_scope('GatherSubgroups'):
            b = tf.shape(mask_groups)[0]
            len_group = tf.shape(split_idx)[0]
            group_rows = tf.gather(mask_groups, split_idx - 1, axis=1)
            gathered_groups = tf.reshape(
                group_rows,
                [b, len_group] + [post_trans_mat, post_trans_mat])
            identity = tf.eye(post_trans_mat, batch_shape=[b])

            def keep_going(step, acc):
                return tf.less(step, len_group)

            def left_multiply(step, acc):
                acc = tf.matmul(gathered_groups[:, step, ...], acc)
                return (step + 1, acc)

            _, mat_agg_final = tf.while_loop(keep_going, left_multiply,
                                             (0, identity))  # (b, mat, mat)
            mat_agg_final_out = tf.reshape(
                mat_agg_final, [b, post_trans_mat * post_trans_mat])

        with tf.variable_scope('MaskMapping'):
            flat_mat = tf.reshape(mat_agg_final,
                                  [b, post_trans_mat * post_trans_mat])
            feat = apply_bias_act(
                dense_layer(flat_mat, feat_cnn_dim * wh * wh))
            feat = tf.reshape(feat, [-1, feat_cnn_dim, wh, wh])

        with tf.variable_scope('ReshapeAttns'):
            shifted_mask = split_masks_mul - 1e-4
            atts = tf.reshape(shifted_mask, [-1, n_lat, n_lat, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, 1, 1, resolution, resolution])
            atts = tf.tile(atts, [1, n_lat, 1, 1, 1])

        return feat, atts, mat_agg_final_out