def get_t_matrix(t_latents, cond_latent, act='lrelu'):
    """Build per-sample 2x3 affine matrices describing a gated translation.

    Two independent gates (one for x, one for y) are predicted from
    ``cond_latent`` through a 128-unit hidden layer followed by a sigmoid,
    so each gate lies in (0, 1). The shift is ``t_latents / 4`` multiplied
    by those gates:
        t_latents[:, 0]: [-2., 2.] -> [-0.5, 0.5]
        t_latents[:, 1]: [-2., 2.] -> [-0.5, 0.5]

    Args:
        t_latents: translation codes; column 0 drives x, column 1 drives y
            (assumes shape [b, 2] -- TODO confirm with callers).
        cond_latent: conditioning input from which the gates are predicted.
        act: activation of the hidden gate layers. Previously this name was
            read from an undefined free variable; the default matches the
            sibling ``get_s_matrix`` / ``get_r_matrix`` / ``get_sh_matrix``.

    Returns:
        theta: tensor of six columns laid out row-major as
            [[1, 0, tx], [0, 1, ty]] (shape [b, 6] for [b, 2] t_latents).
    """
    with tf.variable_scope('Condition0x'):
        cond_x = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1x'):
        cond_x = apply_bias_act(dense_layer(cond_x, fmaps=1), act='sigmoid')
    with tf.variable_scope('Condition0y'):
        cond_y = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1y'):
        cond_y = apply_bias_act(dense_layer(cond_y, fmaps=1), act='sigmoid')
    cond = tf.concat([cond_x, cond_y], axis=1)
    xy_shift = t_latents / 4. * cond
    tt_00 = tf.ones_like(xy_shift[:, 0:1])
    tt_01 = tf.zeros_like(xy_shift[:, 0:1])
    tt_02 = xy_shift[:, 0:1]
    tt_10 = tf.zeros_like(xy_shift[:, 1:])
    tt_11 = tf.ones_like(xy_shift[:, 1:])
    tt_12 = xy_shift[:, 1:]
    theta = tf.concat([tt_00, tt_01, tt_02, tt_10, tt_11, tt_12], axis=1)
    return theta
def info_gan_head(
        hidden,  # First input: hidden features [minibatch, n_feat].
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        dtype='float32',  # Data type to use for activations and outputs.
        **_kwargs):  # Ignore unrecognized keyword args.
    """InfoGAN-style prediction head on top of flat hidden features.

    The input width is pinned to the stage-0 feature count
    (``fmap_base`` clipped to [fmap_min, fmap_max]). A 512-unit hidden dense
    layer is followed by an output dense layer of width
    ``D_global_size + 2 * (dlatent_size - D_global_size)``, i.e. one value
    per discrete latent and two per remaining latent (presumably a
    mean/scale pair -- confirm against the loss that consumes it).
    """
    stage0_width = np.clip(int(fmap_base / (2.0**(0 * fmap_decay))), fmap_min, fmap_max)
    out_width = D_global_size + 2 * (dlatent_size - D_global_size)

    hidden.set_shape([None, stage0_width])
    feats = tf.cast(hidden, dtype)
    with tf.variable_scope('InfoGanHead'):
        with tf.variable_scope('Dense_Hidden'):
            feats = apply_bias_act(dense_layer(feats, fmaps=512), act=nonlinearity)
        with tf.variable_scope('Dense_InfoGan'):
            head_out = apply_bias_act(dense_layer(feats, fmaps=out_width))
    return head_out
def build_C_global_layers(x, name, n_latents, start_idx, scope_idx, dlatents_withl_in,
                          n_content, act, fused_modconv, fmaps=128, **kwargs):
    '''
    Build continuous latent layers, e.g. C_global layers.

    The slice ``dlatents_withl_in[:, start_idx:start_idx + n_latents]`` is
    optionally gated elementwise by a sigmoid predicted from the first
    ``n_content`` columns (no gating when ``n_content <= 0``). The result
    modulates a 3x3 ``modulated_conv2d_layer`` applied to ``x``.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        gate = 1.
        if n_content > 0:
            with tf.variable_scope('Condition0'):
                content_codes = dlatents_withl_in[:, :n_content]
                gate = apply_bias_act(dense_layer(content_codes, fmaps=128), act=act)
            with tf.variable_scope('Condition1'):
                gate = apply_bias_act(dense_layer(gate, fmaps=n_latents), act='sigmoid')
        gated_codes = dlatents_withl_in[:, start_idx:start_idx + n_latents] * gate
        conv_out = modulated_conv2d_layer(x, gated_codes, fmaps=fmaps, kernel=3, up=False,
                                          fused_modconv=fused_modconv)
        return apply_bias_act(conv_out, act=act)
def get_conditional_modifier(modifier, cond_latent, act='lrelu'):
    """Scale ``modifier`` elementwise by a sigmoid gate predicted from ``cond_latent``.

    The gate has one unit per column of ``modifier`` (its static dim 1) and
    is produced by a 128-unit hidden dense layer followed by a sigmoid dense
    layer, so every gate value lies in (0, 1).
    """
    n_gates = modifier.shape.as_list()[1]
    with tf.variable_scope('Condition0'):
        hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1'):
        gate = apply_bias_act(dense_layer(hidden, fmaps=n_gates), act='sigmoid')
    return modifier * gate
def build_Cout_genatts_spgroup_layers(x, name, n_latents, scope_idx, act, fmaps=128, resolution=128, **kwargs):
    '''
    Build continuous latent out layers with generating group spatial attention.
    Support square images only.

    For each of ``n_latents`` latents, a soft rectangular spatial mask is
    generated from the channel-mean of ``x``. The mask-weighted average of
    ``x`` is fed through two dense layers to predict one scalar per latent.

    Returns:
        x: the input features, unchanged.
        pred_out: [b, n_latents] predictions.
        atts: [b, n_latents, 1, resolution, resolution] masks resized to
            ``resolution``.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial_gen'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            # 4 logit rows per latent: h-start, h-end, w-start, w-end.
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(
                atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            # softmax + cumsum yields a monotone curve rising from ~0 to 1
            # along the spatial axis (a soft step at the predicted position).
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            # Flip the "end" steps so that start * (1 - end) is a soft window.
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts, [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends, [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts, [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends, [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            # Outer product of the row and column windows -> 2-D soft box.
            atts = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]
        with tf.variable_scope('Latent_pred'):
            x_out_ls = []
            for i in range(n_latents):
                x_tmp = x * atts[:, i]
                x_tmp_2 = tf.reduce_mean(x_tmp, axis=[2, 3])  # [b, in_dim]
                with tf.variable_scope('OutDense-' + str(i)):
                    with tf.variable_scope('Conv0'):
                        x_tmp_2 = apply_bias_act(dense_layer(x_tmp_2, fmaps=fmaps), act=act)  # [b, fmaps]
                    with tf.variable_scope('Conv1'):
                        x_out_tmp = dense_layer(x_tmp_2, fmaps=1)  # [b, 1]
                x_out_ls.append(x_out_tmp)
            pred_out = tf.concat(x_out_ls, axis=1)  # [b, n_latents]
        with tf.variable_scope('Reshape_output'):
            # Upsample the masks to output resolution (e.g. for visualization).
            atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
        return x, pred_out, atts
def get_s_matrix(s_latents, cond_latent, act='lrelu'):
    """Build per-sample 2x3 affine matrices for an isotropic, gated scaling.

    A sigmoid gate in (0, 1) is predicted from ``cond_latent``; the scale is
    ``(s_latents + 2) * gate + 1``, mapping s_latents in [-2, 2] into [1, 3].
    The matrix is laid out row-major as [[s, 0, 0], [0, s, 0]].

    NOTE(review): SOURCE contains a second function named ``get_s_matrix``;
    whichever definition comes later in the module shadows the other.
    """
    with tf.variable_scope('Condition0'):
        gate = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1'):
        gate = apply_bias_act(dense_layer(gate, fmaps=1), act='sigmoid')
    scale = (s_latents + 2.) * gate + 1.
    zero = tf.zeros_like(scale)
    return tf.concat([scale, zero, zero, zero, scale, zero], axis=1)
def get_r_matrix(r_latents, cond_latent, act='lrelu'):
    """Build per-sample 2x3 affine matrices for a gated rotation.

    ``r_latents`` in [-2, 2] is mapped linearly to an angle in [0, 2*pi],
    which is then scaled by a sigmoid gate predicted from ``cond_latent``.
    The matrix is laid out row-major as [[cos, -sin, 0], [sin, cos, 0]].
    """
    with tf.variable_scope('Condition0'):
        gate = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1'):
        gate = apply_bias_act(dense_layer(gate, fmaps=1), act='sigmoid')
    angle = (r_latents + 2) / 4. * 2. * np.pi * gate
    cos_a = tf.math.cos(angle)
    neg_sin_a = -tf.math.sin(angle)
    zero = tf.zeros_like(angle)
    sin_a = tf.math.sin(angle)
    cos_a_2 = tf.math.cos(angle)
    return tf.concat([cos_a, neg_sin_a, zero, sin_a, cos_a_2, zero], axis=1)
def build_C_global_nocond_layers(x, name, n_latents, start_idx, scope_idx, dlatents_withl_in,
                                 act, fused_modconv, fmaps=128, **kwargs):
    '''
    Build continuous latent layers, e.g. C_global layers.

    Unlike ``build_C_global_layers`` there is no content-conditioned gate:
    the latent slice is embedded by a 128-unit dense layer and the
    embedding modulates a 3x3 ``modulated_conv2d_layer`` applied to ``x``.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Conv0'):
            latent_slice = dlatents_withl_in[:, start_idx:start_idx + n_latents]
            style_code = apply_bias_act(dense_layer(latent_slice, fmaps=128), act=act)
        with tf.variable_scope('Modulate'):
            modulated = modulated_conv2d_layer(x, style_code, fmaps=fmaps, kernel=3, up=False,
                                               fused_modconv=fused_modconv)
            out = apply_bias_act(modulated, act=act)
    return out
def hier_out_branch(x, nd_out):
    """Project features to ``nd_out`` outputs with a single dense layer.

    4-D inputs are first averaged over axes 3 then 2 (spatial pooling for
    NCHW features); 2-D inputs are used directly; any other rank raises
    ValueError.
    """
    with tf.variable_scope('Output'):
        rank = len(x.shape)
        if rank not in (2, 4):
            raise ValueError('Not recognized dimension.')
        if rank == 4:
            x = tf.reduce_mean(tf.reduce_mean(x, axis=3), axis=2)
        return apply_bias_act(dense_layer(x, fmaps=nd_out))
def point_wise_feed_forward_network(x, d_model, dff):
    """Transformer-style position-wise feed-forward block.

    Applies dense(dff) + ReLU to every position independently (by folding
    the sequence axis into the batch), then a bias-only-activated dense layer
    back to ``d_model`` over the last dim.
    Output shape: (batch_size, seq_len, d_model).
    """
    seq_len, x_dim = x.get_shape().as_list()[-2:]
    with tf.variable_scope('ffn_0_'):
        per_position = tf.reshape(x, [-1, x_dim])
        expanded = apply_bias_act(dense_layer(per_position, dff), act='relu')
        expanded = tf.reshape(expanded, [-1, seq_len, dff])  # (batch_size, seq_len, dff)
    with tf.variable_scope('ffn_1_'):
        projected = apply_bias(dense_layer_last_dim(expanded, d_model))  # (batch_size, seq_len, d_model)
    return projected
def get_sh_matrix(sh_latents, cond_latent, act='lrelu'):
    """Build per-sample 2x3 affine matrices for a gated shear.

    Independent sigmoid gates for x and y are predicted from ``cond_latent``;
    the shear is ``sh_latents / 2`` scaled by the gates:
        sh_latents[:, 0]: [-2., 2.] -> [-1., 1.]
        sh_latents[:, 1]: [-2., 2.] -> [-1., 1.]
    The matrix is laid out row-major as [[1, shx, 0], [shy, 1, 0]].
    """
    gates = []
    for hidden_scope, out_scope in (('Condition0x', 'Condition1x'),
                                    ('Condition0y', 'Condition1y')):
        with tf.variable_scope(hidden_scope):
            gate = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
        with tf.variable_scope(out_scope):
            gate = apply_bias_act(dense_layer(gate, fmaps=1), act='sigmoid')
        gates.append(gate)
    shear = sh_latents / 2. * tf.concat(gates, axis=1)
    shear_x = shear[:, 0:1]
    shear_y = shear[:, 1:]
    return tf.concat([
        tf.ones_like(shear_x), shear_x, tf.zeros_like(shear_x),
        shear_y, tf.ones_like(shear_y), tf.zeros_like(shear_y),
    ], axis=1)
def build_Cout_spgroup_layers(x, name, n_latents, start_idx, scope_idx, atts_in, act, fmaps=128,
                              resolution=128, **kwargs):
    '''
    Build continuous latent out layers with learned group spatial attention.
    Support square images only.

    ``atts_in`` holds externally produced masks of shape
    [b, all_n_latents, 1, resolution, resolution]. The ``n_latents`` masks
    starting at ``start_idx`` are resized to ``x``'s spatial size. Each one
    pools ``x`` into a vector, and two dense layers map that vector to one
    scalar prediction.

    Returns ``(x, pred_out)``, with ``x`` unchanged and ``pred_out`` of
    shape [b, n_latents].

    NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    source. Verify that the 'OutDense-*' variables live directly under the
    named scope and not under 'Att_spatial', against existing checkpoints.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            side = x.shape[2]
            masks = atts_in[:, start_idx:start_idx + n_latents]
            masks = tf.reshape(masks, [-1, resolution, resolution, 1])
            masks = tf.image.resize(masks, size=(side, side))
            masks = tf.reshape(masks, [-1, n_latents, 1, side, side])
        predictions = []
        for i in range(n_latents):
            pooled = tf.reduce_mean(x * masks[:, i], axis=[2, 3])  # [b, in_dim]
            with tf.variable_scope('OutDense-' + str(i)):
                with tf.variable_scope('Conv0'):
                    hidden = apply_bias_act(dense_layer(pooled, fmaps=fmaps), act=act)  # [b, fmaps]
                with tf.variable_scope('Conv1'):
                    predictions.append(dense_layer(hidden, fmaps=1))  # [b, 1]
        pred_out = tf.concat(predictions, axis=1)  # [b, n_latents]
    return x, pred_out
def net_M(
        latents_in,
        C_global_size=10,
        D_global_size=0,
        latent_size=512,  # Latent vector (Z) dimensionality.
        mapping_layers=4,  # Number of mapping layers.
        mapping_lrmul=0.1,  # Learning rate multiplier for the mapping layers.
        mapping_fmaps=512,  # Number of activations in the mapping layers.
        mapping_nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        use_std_in_m=False,  # If output prior std.
        dtype='float32',  # Data type to use for activations and outputs.
        **_kwargs):  # Ignore unrecognized keyword args.
    """MLP mapping network from [C_global | D_global] codes to a latent vector.

    Hidden layers have ``mapping_fmaps`` units with ``mapping_nonlinearity``.
    The last layer is linear, with ``latent_size`` units, or
    ``2 * latent_size`` when ``use_std_in_m`` is set. All layers use
    ``mapping_lrmul``. The output is named 'to_latent_out'.
    """
    latents_in.set_shape([None, C_global_size + D_global_size])
    final_width = 2 * latent_size if use_std_in_m else latent_size
    last_idx = mapping_layers - 1

    x = latents_in
    for layer_idx in range(mapping_layers):
        is_last = layer_idx == last_idx
        width = final_width if is_last else mapping_fmaps
        layer_act = 'linear' if is_last else mapping_nonlinearity
        with tf.variable_scope('Dense%d' % layer_idx):
            x = apply_bias_act(dense_layer(x, fmaps=width, lrmul=mapping_lrmul),
                               act=layer_act, lrmul=mapping_lrmul)

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return tf.identity(x, name='to_latent_out')
def build_C_fgroup_layers(x, name, n_latents, start_idx, scope_idx, dlatents_in, act,
                          fused_modconv, fmaps=128, return_atts=False, resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group feature attention.

    For each of ``n_latents`` latents a soft channel window is predicted
    from the spatial mean of ``x``. The window comes from a softmax followed
    by a cumulative sum, giving a start step and an (inverted) end step. It
    blends ``x`` with a ``style_mod`` of the instance-normalized ``x``,
    latent by latent.

    Fixes relative to the previous version:
      * the blend weight for ``x_styled`` used ``atts[:, i:i + 1]``
        (rank 5, [b, 1, C, 1, 1]). Against the 4-D ``x_styled`` this
        broadcast to [b, b, C, H, W] and turned ``x`` 5-D. Both terms now
        use ``atts[:, i]`` ([b, C, 1, 1]), as the spatial variants do.
      * ``style_mod`` now receives the rank-2 slice ``[:, i:i + 1]``
        ([b, 1]) instead of the rank-1 ``[:, i]``. This matches
        ``build_C_spgroup_stn_layers``, which slices scalar dlatents the
        same way.

    Returns:
        x, or (x, atts) when ``return_atts``. ``atts`` has shape
        [b, n_latents, C, 1, 1].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_start_end'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])
            att_dim = x_mean.shape[1]
            atts = dense_layer(x_mean, fmaps=n_latents * 2 * att_dim)
            atts = tf.reshape(atts, [-1, n_latents, 2, att_dim, 1, 1])  # [b, n_latents, 2, att_dim, 1, 1]
            # softmax + cumsum along channels -> monotone soft step per row.
            att_sm = tf.nn.softmax(atts, axis=3)
            att_cs = tf.cumsum(att_sm, axis=3)
            att_cs_starts, att_cs_ends = tf.split(att_cs, 2, axis=2)  # [b, n_latents, 1, att_dim, 1, 1]
            # start * (1 - end) is a soft [start, end] channel window.
            att_cs_ends = 1 - att_cs_ends
            atts = att_cs_starts * att_cs_ends  # [b, n_latents, 1, att_dim, 1, 1]
            atts = tf.reshape(atts, [-1, n_latents, att_dim, 1, 1])
        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    att_i = atts[:, i]  # [b, att_dim, 1, 1], broadcasts over H, W
                    x = x * (1 - att_i) + x_styled * att_i
    if return_atts:
        return x, atts
    return x
def net_M_vc(
        latents_in,
        C_global_size=10,
        D_global_size=0,
        latent_size=512,  # Latent vector (Z) dimensionality.
        mapping_lrmul=0.1,  # Learning rate multiplier for the mapping layers.
        use_std_in_m=False,  # If output prior std.
        dtype='float32',  # Data type to use for activations and outputs.
        **_kwargs):  # Ignore unrecognized keyword args.
    """Single-layer mapping from C_global codes to a ``latent_size`` vector.

    One dense layer with lrelu and ``mapping_lrmul``. The output is named
    'to_latent_out'. ``D_global_size`` and ``use_std_in_m`` are accepted for
    signature compatibility with ``net_M`` but are not used here.
    """
    latents_in.set_shape([None, C_global_size])
    projected = dense_layer(latents_in, fmaps=latent_size, lrmul=mapping_lrmul)
    mapped = apply_bias_act(projected, act='lrelu', lrmul=mapping_lrmul)

    # Output.
    assert mapped.dtype == tf.as_dtype(dtype)
    return tf.identity(mapped, name='to_latent_out')
def get_s_matrix(s_latents, cond_latent, act='lrelu'):
    """Build per-sample 2x3 affine matrices for a gated (an)isotropic scaling.

    The behavior depends on the static width of ``s_latents``:
      * one column: a single sigmoid gate gives one isotropic scale.
      * otherwise: separate x/y gates, with column 0 scaling x and the
        remaining columns scaling y.
    In both cases the scale is ``(s_latents + 2) * gate + 1``:
        s_latents[:, 0]: [-2., 2.] -> [1., 3.]
        s_latents[:, 1]: [-2., 2.] -> [1., 3.]
    The matrix is laid out row-major as [[sx, 0, 0], [0, sy, 0]].

    NOTE(review): SOURCE contains a second function named ``get_s_matrix``;
    whichever definition comes later in the module shadows the other.
    """
    if s_latents.shape.as_list()[1] == 1:
        with tf.variable_scope('Condition0'):
            gate = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
        with tf.variable_scope('Condition1'):
            gate = apply_bias_act(dense_layer(gate, fmaps=1), act='sigmoid')
        scale = (s_latents + 2.) * gate + 1.
        scale_x = scale
        scale_y = scale
    else:
        gates = []
        for hidden_scope, out_scope in (('Condition0x', 'Condition1x'),
                                        ('Condition0y', 'Condition1y')):
            with tf.variable_scope(hidden_scope):
                gate = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
            with tf.variable_scope(out_scope):
                gate = apply_bias_act(dense_layer(gate, fmaps=1), act='sigmoid')
            gates.append(gate)
        scale = (s_latents + 2.) * tf.concat(gates, axis=1) + 1.
        scale_x = scale[:, 0:1]
        scale_y = scale[:, 1:]
    zero_x = tf.zeros_like(scale_x)
    zero_y = tf.zeros_like(scale_y)
    return tf.concat([scale_x, zero_x, zero_x, zero_y, scale_y, zero_y], axis=1)
def vid_naive_cluster_head(
        fake_in,  # First input: generated image from z [minibatch, channel, n_frames, height, width].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        **_kwargs):  # Ignore unrecognized keyword args.
    """3-D-convolutional video head predicting ``dlatent_size - D_global_size`` values.

    Structure mirrors a StyleGAN2 discriminator using 3-D convs. Each block
    downsamples via ``conv3d_layer(..., down=True)`` until 4x4, then a 3-D
    conv runs, frames are averaged (axis 2), and two dense layers follow.

    NOTE(review): ``mbstd_group_size``, ``mbstd_num_features`` and
    ``resample_kernel`` are accepted but not used in this body.
    ``'resnet'`` has no dedicated code path here (no residual skip), so it
    behaves like ``'orig'``.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    # Frame count (axis 2) is left dynamic.
    fake_in.set_shape([None, num_channels, None, resolution, resolution])
    fake_in = tf.cast(fake_in, dtype)
    vid_in = fake_in

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = conv3d_layer(y, fmaps=nf(res - 1), kernel=1)
            t = apply_bias_act_3d(t, act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        with tf.variable_scope('Conv3D_0'):
            x = conv3d_layer(x, fmaps=nf(res - 1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Conv1_down'):
            x = conv3d_layer(x, fmaps=nf(res - 2), kernel=3, down=True)
            x = apply_bias_act_3d(x, act=act)
        return x

    # Main layers.
    x = None
    y = vid_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('I_%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample_3d(y)

    # Final layers.
    with tf.variable_scope('I_4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = conv3d_layer(x, fmaps=nf(1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Global_temporal_pool'):
            # Mean over axis 2 -- the frame axis, assuming conv3d_layer keeps
            # the [b, c, t, h, w] layout of the input (TODO confirm).
            x = tf.reduce_mean(x, axis=2)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output.
    with tf.variable_scope('I_Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(
                dense_layer(x, fmaps=dlatent_size - D_global_size))
    assert x.dtype == tf.as_dtype(dtype)
    return x
def vc_head(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: image-shaped tensor from z + delta(z) [minibatch, channel, height, width].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, <= 1 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 are combined: 'diff' or 'concat'.
        **_kwargs):  # Ignore unrecognized keyword args.
    """StyleGAN2-discriminator-style head that regresses ``dlatent_size`` values
    from a pair of images.

    The pair is combined by ``fake1 - fake2`` ('diff') or by concatenation
    along channels ('concat'). It then goes through downsampling conv
    blocks to 4x4, an optional minibatch-stddev layer, a conv, and two dense
    layers.

    Returns:
        Output of the final dense layer with ``dlatent_size`` units.

    Raises:
        ValueError: if ``connect_mode`` is not 'diff' or 'concat'. Previously
            an unknown mode left ``images_in`` unbound and failed later with
            an UnboundLocalError.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    if connect_mode not in ('diff', 'concat'):
        raise ValueError("Unknown connect_mode %r; expected 'diff' or 'concat'." % (connect_mode,))
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    else:  # 'concat': doubles the input channel count.
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1), act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True,
                                            resample_kernel=resample_kernel), act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True,
                                 resample_kernel=resample_kernel)
                # Scale keeps the variance of the residual sum roughly constant.
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer: one linear unit per latent dimension.
    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            # Equals D_global_size + (dlatent_size - D_global_size).
            x = apply_bias_act(dense_layer(x, fmaps=dlatent_size))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
def build_std_gen(x, name, n_latents, start_idx, scope_idx, dlatents_in, act, fused_modconv,
                  fmaps=128, resolution=512, fmap_base=2 << 8, fmap_min=1, fmap_max=512,
                  fmap_decay=1, architecture='skip', randomize_noise=True,
                  resample_kernel=[1, 3, 3, 1], num_channels=3,
                  latent_split_ls_for_std_gen=[5, 5, 5, 5], **kwargs):
    '''
    Build standard disentanglement generator with similar architecture to stylegan2.

    ``dlatents_in`` is split column-wise into the segments listed in
    ``latent_split_ls_for_std_gen``. Each segment goes through two dense
    layers and drives exactly one synthesis layer, in order: the 4x4 conv,
    then Conv0_up/Conv1 of each block, and finally ToRGB. Hence the
    assertion that there are ``resolution_log2 * 2 - 2`` segments. With the
    default four segments this only holds for resolution 8.

    Only ``x.dtype`` is read from ``x``; the synthesis starts from a learned
    4x4 constant.

    NOTE(review):
      * ``name``, ``start_idx``, ``scope_idx`` and ``fmaps`` are not used.
        Segment slicing starts at column 0 of ``dlatents_in``, not at
        ``start_idx``. Confirm this is intended.
      * ``architecture == 'skip'`` is accepted, but RGB is produced only at
        the final resolution and ``upsample`` is never called, so no skip
        outputs are accumulated.
    '''
    # with tf.variable_scope(name + '-' + str(scope_idx)):
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    num_layers = resolution_log2 * 2 - 2
    images_out = None
    dtype = x.dtype
    assert n_latents == sum(latent_split_ls_for_std_gen)
    assert num_layers == len(latent_split_ls_for_std_gen)

    # One two-layer dense embedding per latent segment (one per synthesis layer).
    latents_ready_ls = []
    start_code = 0
    for i, seg in enumerate(latent_split_ls_for_std_gen):
        with tf.variable_scope('PreConvDense-' + str(i) + '-0'):
            x_tmp0 = dense_layer(dlatents_in[:, start_code:start_code + seg], fmaps=nf(1))
        with tf.variable_scope('PreConvDense-' + str(i) + '-1'):
            x_tmp1 = dense_layer(x_tmp0, fmaps=nf(1))
        start_code += seg
        latents_ready_ls.append(x_tmp1)

    # Noise inputs.
    noise_inputs = []
    for layer_idx in range(num_layers - 1):
        res = (layer_idx + 5) // 2
        shape = [1, 1, 2**res, 2**res]
        noise_inputs.append(
            tf.get_variable('noise%d' % layer_idx, shape=shape,
                            initializer=tf.initializers.random_normal(),
                            trainable=False))

    # Single convolution layer with all the bells and whistles.
    def layer(x, layer_idx, fmaps, kernel, up=False):
        # Modulated conv driven by this layer's latent embedding, plus
        # scaled noise (fresh per call or the fixed noise variable) and
        # bias + activation.
        x = modulated_conv2d_layer(x, latents_ready_ls[layer_idx], fmaps=fmaps, kernel=kernel, up=up,
                                   resample_kernel=resample_kernel, fused_modconv=fused_modconv)
        if randomize_noise:
            noise = tf.random_normal(
                [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
        else:
            noise = tf.cast(noise_inputs[layer_idx], x.dtype)
        noise_strength = tf.get_variable('noise_strength', shape=[],
                                         initializer=tf.initializers.zeros())
        x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act)

    # Building blocks for main layers.
    def block(x, res):  # res = 3..resolution_log2
        t = x
        with tf.variable_scope('Conv0_up'):
            x = layer(x, layer_idx=res * 2 - 5, fmaps=nf(res - 1), kernel=3, up=True)
        with tf.variable_scope('Conv1'):
            x = layer(x, layer_idx=res * 2 - 4, fmaps=nf(res - 1), kernel=3)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 1), kernel=1, up=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def upsample(y):
        # Not called anywhere in this function (see NOTE in the docstring).
        with tf.variable_scope('Upsample'):
            return upsample_2d(y, k=resample_kernel)

    def torgb(x, y, res):  # res = 2..resolution_log2
        # Uses the last latent segment when res == resolution_log2.
        with tf.variable_scope('ToRGB'):
            t = apply_bias_act(
                modulated_conv2d_layer(x, latents_ready_ls[res * 2 - 3], fmaps=num_channels,
                                       kernel=1, demodulate=False,
                                       fused_modconv=fused_modconv))
            return t if y is None else y + t

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const', shape=[1, nf(1), 4, 4],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
        with tf.variable_scope('Conv'):
            x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)

    # Main layers.
    for res in range(3, resolution_log2 + 1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            x = block(x, res)
            if res == resolution_log2:
                y = torgb(x, y, res)
    images_out = y

    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
def build_C_spgroup_layers_with_latents_ready(x, name, n_latents, scope_idx, latents_ready,
                                              return_atts=False, resolution=128, n_subs=1,
                                              **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention using latents_ready.
    Support square images only.

    For each latent, ``n_subs`` soft rectangular masks are generated from
    the channel-mean of ``x`` and averaged into one mask. ``x`` is then
    blended, latent by latent, with ``style_mod(instance_norm(x),
    latents_ready[:, i])`` inside that mask.

    Returns:
        x, or (x, atts) when ``return_atts``. ``atts`` is resized to
        [b, n_latents, 1, resolution, resolution].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            # 4 logit rows (h-start, h-end, w-start, w-end) per latent per sub-mask.
            atts_wh = dense_layer(x_mean, fmaps=n_latents * n_subs * 4 * x_wh)
            atts_wh = tf.reshape(atts_wh, [-1, n_latents, n_subs, 4, x_wh
                                           ])  # [b, n_latents, n_subs, 4, x_wh]
            # softmax + cumsum -> monotone soft step along the spatial axis.
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=3)
            # start * (1 - end) gives a soft window per axis.
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, n_subs, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, n_subs, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts, [-1, n_latents, n_subs, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends, [-1, n_latents, n_subs, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, n_subs, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts, [-1, n_latents, n_subs, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends, [-1, n_latents, n_subs, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, n_subs, 1, 1, x_wh]
            atts = att_h * att_w  # [b, n_latents, n_subs, 1, x_wh, x_wh]
            # Merge the sub-masks of each latent by averaging.
            atts = tf.reduce_mean(atts, axis=2)  # [b, n_latents, 1, x_wh, x_wh]
        with tf.variable_scope('Att_apply'):
            C_global_latents = latents_ready  # [b, n_latents, 512] (per original note; TODO confirm)
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i])
                    # Inside the mask take the styled features, outside keep x.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x
def build_C_spgroup_stn_layers(x, name, n_latents, start_idx, scope_idx, dlatents_in, act,
                               fused_modconv, fmaps=128, return_atts=False, resolution=128,
                               **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention with spatial transform.
    Support square images only.

    One soft rectangular mask per latent is generated from the channel-mean
    of ``x``. Each mask is warped by ``transformer`` using 6 affine
    parameters predicted from that same mean. ``x`` is then blended, latent
    by latent, with ``style_mod(instance_norm(x), dlatent_i)`` inside the
    warped mask.

    NOTE(review): ``act``, ``fused_modconv`` and ``fmaps`` are not used in
    this body. The scope nesting ('trans_matrix' as a sibling of
    'Att_spatial') was reconstructed from whitespace-collapsed source;
    verify variable names against existing checkpoints.

    Returns:
        x, or (x, atts) when ``return_atts``. ``atts`` is resized to
        [b, n_latents, 1, resolution, resolution].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            # 4 logit rows per latent: h-start, h-end, w-start, w-end.
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(
                atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            # softmax + cumsum -> monotone soft step along the spatial axis.
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            # start * (1 - end) gives a soft window per axis.
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts, [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends, [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts, [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends, [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            atts = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]
        with tf.variable_scope('trans_matrix'):
            # 6 affine parameters per latent; no explicit act is passed here.
            theta = apply_bias_act(dense_layer(x_mean, fmaps=n_latents * 6))
            theta = tf.reshape(theta, [-1, 6])  # [b*n_latents, 6]
            atts = tf.reshape(
                atts, [-1, x_wh, x_wh, 1])  # [b*n_latents, x_wh, x_wh, 1]
            atts = transformer(atts, theta)  # [b*n_latents, x_wh, x_wh, 1]
            atts = tf.reshape(atts, [-1, n_latents, 1, x_wh, x_wh])
        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    # Rank-2 [b, 1] slice of the i-th latent.
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    # Inside the mask take the styled features, outside keep x.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x
def vpex_net(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: image-shaped tensor from z + delta(z) [minibatch, channel, height, width].
        latents,  # Ground-truth latent code for fake1.
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 are combined: 'diff' or 'concat'.
        return_atts=False,  # If return I_atts.
        **_kwargs):  # Ignore unrecognized keyword args.
    """Predict one scalar per latent dimension from an image pair using
    per-latent spatial attention at 8x8.

    Pipeline:
      1. Combine the pair ('diff': ``fake1 - fake2``; 'concat': channels).
      2. Run StyleGAN2-discriminator-style blocks down to an 8x8 feature map.
      3. For each of ``dlatent_size`` latents, concatenate a learned feature
         bank ('att_feats') with the features, then apply conv3x3 and
         conv1x1. A softmax over the 64 positions gives an attention map.
      4. Attention-pool the features per latent, apply a 1x1 conv, and
         broadcast back to 8x8 (the batch becomes b * dlatent_size).
      5. Run the 8x8 block, the 4x4 conv, Dense0, and a 1-unit output;
         reshape to [b, dlatent_size].

    ``latents`` is only used for its batch size (to tile 'att_feats').
    ``D_global_size``, ``mbstd_group_size`` and ``mbstd_num_features`` are
    accepted but unused.

    NOTE(review): scope nesting ('apply_att' ending before the 8x8 stage)
    was reconstructed from whitespace-collapsed source. Verify variable
    names against existing checkpoints.

    Returns:
        x of shape [b, dlatent_size]. With ``return_atts``, returns
        ``(x, att_map)``, where ``att_map`` is resized to
        [b, dlatent_size, 1, resolution, resolution].

    Raises:
        ValueError: unknown ``connect_mode``. Previously this left
            ``images_in`` unbound and failed later with an UnboundLocalError.
        AssertionError: ``resolution`` is not a power of two >= 16. At 4 or
            8 the first loop never runs, and the previous version crashed
            inside the attention stage on a None feature map.
    """
    resolution_log2 = int(np.log2(resolution))
    # The attention stage needs the 8x8 map produced by at least one block.
    assert resolution == 2**resolution_log2 and resolution >= 16

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    if connect_mode not in ('diff', 'concat'):
        raise ValueError("Unknown connect_mode %r; expected 'diff' or 'concat'." % (connect_mode,))
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    latents.set_shape([None, dlatent_size])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    latents = tf.cast(latents, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    else:  # 'concat'
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1), act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True,
                                            resample_kernel=resample_kernel), act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 2), kernel=1, down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Attention features for each latent dimension.
    def get_att_map(latents, x):
        """Return a softmax attention map [b, dlatent_size, 1, h*w] over x's positions."""
        if x is None:
            raise ValueError('get_att_map requires the feature map x.')
        with tf.variable_scope('create_att_feats'):
            x_ch, x_h, x_w = x.get_shape().as_list()[1:]
            att_feats = tf.get_variable(
                'att_feats', shape=[1, dlatent_size, x_ch, x_h, x_w],
                initializer=tf.initializers.random_normal())
            att_feats = tf.tile(tf.cast(att_feats, dtype), [tf.shape(latents)[0], 1, 1, 1, 1])
        # Pair every latent's learned feature bank with a copy of x.
        x = tf.reshape(x, [-1, 1, x_ch, x_h, x_w])
        x = tf.tile(x, [1, dlatent_size, 1, 1, 1])
        att_map = tf.concat([x, att_feats], axis=2)
        att_map = tf.reshape(att_map, [-1, 2 * x_ch, x_h, x_w])
        map_ch = 2 * x_ch
        with tf.variable_scope('att_conv_3x3'):
            att_map = apply_bias_act(conv2d_layer(att_map, fmaps=map_ch, kernel=3), act=act)
        with tf.variable_scope('att_conv_1x1'):
            att_map = apply_bias_act(conv2d_layer(att_map, fmaps=1, kernel=1))
        att_map = tf.reshape(att_map, [-1, dlatent_size, 1, x_h * x_w])
        # Normalize over spatial positions: each latent's weights sum to 1.
        att_map = tf.nn.softmax(att_map, axis=-1)
        return att_map

    # Main layers (full resolution down to 8x8).
    x = None
    y = images_in
    for res in range(resolution_log2, 3, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Duplicate for each att.
    with tf.variable_scope('apply_att'):
        att_map = get_att_map(latents, x)
        x_ch, x_h, x_w = x.get_shape().as_list()[1:]
        assert x_h == 8
        x = tf.reshape(x, [-1, 1, x_ch, x_h * x_w])  # [b, 1, ch, h*w]
        x = att_map * x  # [b, dlatent, ch, h*w]
        x = tf.reduce_sum(x, axis=-1)  # [b, dlatent, ch]
        x = tf.reshape(x, [-1, x_ch, 1, 1])  # [b * dlatent, ch, 1, 1]
        with tf.variable_scope('after_att_conv_1x1'):
            x = apply_bias_act(conv2d_layer(x, fmaps=x_ch, kernel=1))
        x = tf.reshape(x, [-1, dlatent_size, x_ch, 1])  # [b, dlatent, ch, 1]
        # Broadcast the pooled vector back over the 8x8 grid.
        x = tf.tile(x, [1, 1, 1, x_h * x_w])  # [b, dlatent, ch, h * w]
        x = tf.reshape(x, [-1, x_ch, x_h, x_w])  # [b * dlatent, ch, h, w]

    if architecture == 'skip':
        # Only the skip architecture reads y again; match the b * dlatent batch.
        y_ch, y_h, y_w = y.get_shape().as_list()[1:]
        y = y[:, tf.newaxis, ...]
        y = tf.tile(y, [1, dlatent_size, 1, 1, 1])
        y = tf.reshape(y, [-1, y_ch, y_h, y_w])

    # 8x8 -> 4x4 block (res == 3 < resolution_log2, so fromrgb only for 'skip').
    with tf.variable_scope('8x8'):
        if architecture == 'skip':
            x = fromrgb(x, y, 3)
        x = block(x, 3)
        if architecture == 'skip':
            y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)
    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(dense_layer(x, fmaps=1))
    with tf.variable_scope('Final_reshape_x'):
        x = tf.reshape(x, [-1, dlatent_size])

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    if return_atts:
        with tf.variable_scope('Reshape_atts'):
            att_map = tf.reshape(att_map, [-1, 8, 8, 1])
            att_map = tf.image.resize(att_map, size=(resolution, resolution))
            att_map = tf.reshape(att_map, [-1, dlatent_size, 1, resolution, resolution])
        return x, att_map
    return x
def build_C_spgroup_regW_layers(x,
                                name,
                                n_latents,
                                start_idx,
                                scope_idx,
                                dlatents_in,
                                act,
                                fused_modconv,
                                fmaps=128,
                                resolution=128,
                                n_subs=1,
                                **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention,
    modulating with `style_mod_with_regW`.
    Support square images only.

    For every latent, n_subs soft boxes are predicted from the spatially pooled
    features. Each box is the product of a soft row interval and a soft column
    interval, both obtained by cumulative-summing a softmax. The n_subs boxes
    are averaged into one mask, which blends a style-modulated copy of the
    instance-normalized input into x.

    Returns:
        x: the blended feature map.
        atts: the masks resized to [-1, n_latents, 1, resolution, resolution].
        z_w: the second outputs of `style_mod_with_regW`, concatenated on axis 0.

    `act`, `fused_modconv` and `fmaps` are not used by this builder.
    '''

    def _band(lo_cum, hi_cum, target_shape):
        # Soft interval: P(interval started) * P(interval not yet ended).
        return tf.reshape(lo_cum, target_shape) * tf.reshape(1 - hi_cum, target_shape)

    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            logits = dense_layer(pooled, fmaps=n_latents * n_subs * 4 * side)
            logits = tf.reshape(logits, [-1, n_latents, n_subs, 4, side])
            cum = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
            h_lo, h_hi, w_lo, w_hi = tf.split(cum, 4, axis=3)
            rows = _band(h_lo, h_hi, [-1, n_latents, n_subs, 1, side, 1])
            cols = _band(w_lo, w_hi, [-1, n_latents, n_subs, 1, 1, side])
            # Average the n_subs boxes of each latent -> [b, n_latents, 1, side, side].
            atts = tf.reduce_mean(rows * cols, axis=2)

        with tf.variable_scope('Att_apply'):
            c_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            z_w_parts = []
            for k in range(n_latents):
                with tf.variable_scope('style_mod-' + str(k)):
                    x_styled, z_w_k = style_mod_with_regW(x_norm, c_latents[:, k:k + 1])
                    mask = atts[:, k]
                    x = x * (1 - mask) + x_styled * mask
                z_w_parts.append(z_w_k)

        with tf.variable_scope('Reshape_output'):
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
            z_w = tf.concat(z_w_parts, axis=0)
    return x, atts, z_w
def build_zpos_to_mat_layer(x,
                            name,
                            n_layers,
                            scope_idx,
                            is_training,
                            wh,
                            feat_cnn_dim,
                            resolution=128,
                            trans_dim=512,
                            dff=512,
                            trans_rate=0.1,
                            ncut_maxval=5,
                            post_trans_mat=16,
                            **kwargs):
    '''
    Build zpos_to_mat forwarding transformer to extract features per z.

    Each latent is added to a learned positional embedding and encoded by
    `trans_encoder_basic` under a group mask. Every position is projected to a
    post_trans_mat x post_trans_mat matrix. The matrices found at index
    split_idx - 1 are multiplied together into one matrix per sample, which a
    dense layer then maps to a CNN feature map.

    Args:
        x: latent codes. They are indexed as x[:, :, np.newaxis] and added to a
            [b, n_lat, trans_dim] tensor, so presumably [b, n_lat] -- TODO confirm.
        is_training: used in a Python `if`, so it must be a Python bool here.
        wh, feat_cnn_dim: spatial size and channel count of the returned feature map.
        resolution: side length the group mask is resized to in `atts`.
        ncut_maxval: exclusive upper bound of the random cut count drawn in training.
        post_trans_mat: side length of the per-position matrices.

    Returns:
        feat: [-1, feat_cnn_dim, wh, wh] features mapped from the aggregated matrix.
        atts: the group mask resized to resolution x resolution and tiled n_lat
            times along axis 1.
        mat_agg_final_out: the aggregated matrix flattened to [b, post_trans_mat ** 2].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('PosConstant'):
            n_lat = x.get_shape().as_list()[-1]
            # Learned positional embedding: one trans_dim vector per latent, tiled over the batch.
            pos = tf.get_variable('const',
                                  shape=[1, n_lat, trans_dim],
                                  initializer=tf.initializers.random_normal())
            pos = tf.tile(tf.cast(pos, x.dtype), [tf.shape(x)[0], 1, 1])
        # Each scalar latent is broadcast across the embedding axis and added.
        zpos = pos + x[:, :, np.newaxis]

        with tf.variable_scope('MaskEncoding'):
            if is_training:
                # Random number of cuts in [0, ncut_maxval) for this step.
                ncut = tf.random.uniform(shape=[], maxval=ncut_maxval, dtype=tf.int32)
                split_masks_mul, split_idx = create_split_mask(n_lat, ncut)
            else:
                # Evaluation: an all-ones mask and a single split point at n_lat.
                split_masks_mul = tf.ones(shape=[n_lat, n_lat], dtype=tf.float32)
                split_idx = tf.constant([n_lat])
            # Make sure n_lat is a split point, then drop duplicate entries.
            split_idx = tf.concat([split_idx, [n_lat]], axis=0)
            split_idx, _ = tf.unique(split_idx)
            mask_logits = get_return_v(
                trans_encoder_basic(zpos,
                                    is_training,
                                    split_masks_mul,
                                    n_layers,
                                    trans_dim,
                                    num_heads=8,
                                    dff=dff,
                                    rate=trans_rate), 1)  # (b, n_lat, d_model)
            # Presumably projects the last axis to post_trans_mat ** 2 entries
            # (one flattened matrix per position) -- TODO confirm dense_layer_last_dim.
            mask_groups = dense_layer_last_dim(mask_logits,
                                               post_trans_mat * post_trans_mat)

        with tf.variable_scope('GatherSubgroups'):
            b = tf.shape(mask_groups)[0]
            len_group = tf.shape(split_idx)[0]
            # Pick the matrix at position split_idx - 1 for each split point.
            gathered_groups = tf.reshape(
                tf.gather(mask_groups, split_idx - 1, axis=1),
                [b, len_group] + [post_trans_mat, post_trans_mat
                                  ])  # (b, len(split_idx), mat, mat)
            mat_agg = tf.eye(post_trans_mat, batch_shape=[b])

            def cond(i, mats):
                return tf.less(i, len_group)

            def bod(i, mats):
                # Left-multiply, so the result is G[k-1] @ ... @ G[1] @ G[0] @ I.
                mats = tf.matmul(gathered_groups[:, i, ...], mats)
                i += 1
                return (i, mats)

            i_mats = (0, mat_agg)
            _, mat_agg_final = tf.while_loop(cond, bod,
                                             i_mats)  # (b, mat, mat)
            mat_agg_final_out = tf.reshape(
                mat_agg_final, [b, post_trans_mat * post_trans_mat])

        with tf.variable_scope('MaskMapping'):
            # Same flattening as mat_agg_final_out above.
            mat_agg_final = tf.reshape(mat_agg_final,
                                       [b, post_trans_mat * post_trans_mat])
            feat = apply_bias_act(
                dense_layer(mat_agg_final, feat_cnn_dim * wh * wh))
            feat = tf.reshape(feat, [-1, feat_cnn_dim, wh, wh])

        with tf.variable_scope('ReshapeAttns'):
            # NOTE(review): the purpose of this 1e-4 offset is not evident from this block.
            split_masks_mul -= 1e-4
            atts = tf.reshape(split_masks_mul, [-1, n_lat, n_lat, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            # NOTE(review): when split_masks_mul is [n_lat, n_lat] (as in the eval
            # branch), atts has a leading dim of 1 rather than the minibatch size.
            atts = tf.tile(
                tf.reshape(atts, [-1, 1, 1, resolution, resolution]),
                [1, n_lat, 1, 1, 1])
    return feat, atts, mat_agg_final_out
def build_C_spfgroup_layers(x,
                            name,
                            n_latents,
                            start_idx,
                            scope_idx,
                            dlatents_in,
                            act,
                            fused_modconv,
                            fmaps=128,
                            return_atts=False,
                            resolution=128,
                            **kwargs):
    '''
    Build continuous latent layers with learned group feature-spatial attention.
    Support square images only.

    Each latent gets a soft channel interval over the input channels and a soft
    spatial box (a row interval times a column interval). Each interval comes
    from a cumulative-summed softmax over pooled features. The product of the
    two gates a style-modulated copy of the instance-normalized input into x.

    Returns:
        x, or (x, spatial attention resized to
        [-1, n_latents, 1, resolution, resolution]) when return_atts is True.

    `act`, `fused_modconv` and `fmaps` are not used by this builder.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_channel_start_end'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            n_ch = pooled.shape[1]
            ch_logits = tf.reshape(dense_layer(pooled, fmaps=n_latents * 2 * n_ch),
                                   [-1, n_latents, 2, n_ch, 1, 1])
            ch_cum = tf.cumsum(tf.nn.softmax(ch_logits, axis=3), axis=3)
            ch_lo, ch_hi = tf.split(ch_cum, 2, axis=2)
            # Soft channel interval: started * not-yet-ended -> [b, n_latents, n_ch, 1, 1].
            att_channel = tf.reshape(ch_lo * (1 - ch_hi), [-1, n_latents, n_ch, 1, 1])

        with tf.variable_scope('Att_spatial'):
            side = x.shape[2]
            sp_logits = tf.reshape(dense_layer(pooled, fmaps=n_latents * 4 * side),
                                   [-1, n_latents, 4, side])
            sp_cum = tf.cumsum(tf.nn.softmax(sp_logits, axis=-1), axis=-1)
            h_lo, h_hi, w_lo, w_hi = tf.split(sp_cum, 4, axis=2)
            row_shape = [-1, n_latents, 1, side, 1]
            col_shape = [-1, n_latents, 1, 1, side]
            rows = tf.reshape(h_lo, row_shape) * tf.reshape(1 - h_hi, row_shape)
            cols = tf.reshape(w_lo, col_shape) * tf.reshape(1 - w_hi, col_shape)
            att_sp = rows * cols  # [b, n_latents, 1, side, side]
            gates = att_channel * att_sp  # [b, n_latents, n_ch, side, side]

        with tf.variable_scope('Att_apply'):
            c_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for k in range(n_latents):
                with tf.variable_scope('style_mod-' + str(k)):
                    x_styled = style_mod(x_norm, c_latents[:, k:k + 1])
                    gate = gates[:, k]
                    x = x * (1 - gate) + x_styled * gate

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            att_sp = tf.reshape(att_sp, [-1, side, side, 1])
            att_sp = tf.image.resize(att_sp, size=(resolution, resolution))
            att_sp = tf.reshape(att_sp, [-1, n_latents, 1, resolution, resolution])
        return x, att_sp
def vid_head(
        fake_in,  # First input: generated video [minibatch, channel, n_frames, height, width].
        C_delta_idxes,  # Second input: encoding of the varied latent, [minibatch, dlatent_size].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,  # Width of C_delta_idxes.
        D_global_size=0,  # Unused here; kept for interface compatibility.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # 'orig', 'skip', 'resnet'. This head has no residual path, so 'resnet' acts like 'orig'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Unused here; kept for interface compatibility.
        mbstd_num_features=1,  # Unused here; kept for interface compatibility.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Unused here: downsample_3d is called without it.
        **_kwargs):  # Ignore unrecognized keyword args.
    """3D-conv head scoring a generated video together with the varied-latent encoding.

    The video is encoded by a stack of 3D conv blocks down to 4x4, averaged
    over axis 2 (the frame axis of the input layout), and projected to 64
    features. These are concatenated with a 32-unit embedding of
    `C_delta_idxes`, and two dense layers reduce the result to a single output
    unit per sample.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    # The frame count is left dynamic.
    fake_in.set_shape([None, num_channels, None, resolution, resolution])
    fake_in = tf.cast(fake_in, dtype)
    C_delta_idxes.set_shape([None, dlatent_size])
    C_delta_idxes = tf.cast(C_delta_idxes, dtype)
    vid_in = fake_in

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = conv3d_layer(y, fmaps=nf(res - 1), kernel=1)
            t = apply_bias_act_3d(t, act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        with tf.variable_scope('Conv3D_0'):
            x = conv3d_layer(x, fmaps=nf(res - 1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Conv1_down'):
            x = conv3d_layer(x, fmaps=nf(res - 2), kernel=3, down=True)
            x = apply_bias_act_3d(x, act=act)
        return x

    # Main layers.
    x = None
    y = vid_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('I_%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample_3d(y)

    # Final layers.
    with tf.variable_scope('I_4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = conv3d_layer(x, fmaps=nf(1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Global_temporal_pool'):
            # Presumably the frame axis, given fake_in's [b, ch, frames, h, w] layout.
            x = tf.reduce_mean(x, axis=2)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=64), act=act)

    # From C_delta_idxes
    with tf.variable_scope('I_From_C_Delta_Idx'):
        x_from_delta = apply_bias_act(dense_layer(C_delta_idxes, fmaps=32), act=act)
        x = tf.concat([x, x_from_delta], axis=1)

    # For MINE
    with tf.variable_scope('I_Output'):
        with tf.variable_scope('Dense_T_0'):
            x = apply_bias_act(dense_layer(x, fmaps=128), act=act)
        with tf.variable_scope('Dense_T_1'):
            x = apply_bias_act(dense_layer(x, fmaps=1))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
def D_info_gan_stylegan2(
        images_in,  # First input: Images [minibatch, channel, height, width].
        labels_in,  # Second input: Labels [minibatch, label_size].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer; values <= 1 disable it.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
        **_kwargs):  # Ignore unrecognized keyword args.
    """StyleGAN2 discriminator that also exposes its penultimate features.

    Returns:
        scores_out: one score per sample, label-conditioned when
            labels_in has a non-zero second dimension. Named 'scores_out'.
        hidden: the activations of the 4x4 'Dense0' layer (nf(0) units),
            named 'hidden'.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    assert architecture in ['orig', 'skip', 'resnet']

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

    act = nonlinearity
    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            feat = conv2d_layer(y, fmaps=nf(res - 1), kernel=1)
            feat = apply_bias_act(feat, act=act)
        return feat if x is None else x + feat

    def block(x, res):  # res = 2..resolution_log2
        shortcut = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3), act=act)
        with tf.variable_scope('Conv1_down'):
            x = conv2d_layer(x, fmaps=nf(res - 2), kernel=3, down=True,
                             resample_kernel=resample_kernel)
            x = apply_bias_act(x, act=act)
        if architecture != 'resnet':
            return x
        with tf.variable_scope('Skip'):
            shortcut = conv2d_layer(shortcut, fmaps=nf(res - 2), kernel=1, down=True,
                                    resample_kernel=resample_kernel)
        # Scale the residual sum to keep its variance in check.
        return (x + shortcut) * (1 / np.sqrt(2))

    # Main layers: halve the resolution until 4x4.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                with tf.variable_scope('Downsample'):
                    y = downsample_2d(y, k=resample_kernel)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            hidden = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    n_labels = labels_in.shape[1]
    with tf.variable_scope('Output'):
        scores_out = apply_bias_act(dense_layer(hidden, fmaps=max(n_labels, 1)))
        if n_labels > 0:
            # Keep only the logit(s) selected by the label vector.
            scores_out = tf.reduce_sum(scores_out * labels_in, axis=1, keepdims=True)

    assert scores_out.dtype == tf.as_dtype(dtype)
    return tf.identity(scores_out, name='scores_out'), tf.identity(hidden, name='hidden')
def build_C_spgroup_lcond_layers(x,
                                 name,
                                 n_latents,
                                 start_idx,
                                 scope_idx,
                                 dlatents_in,
                                 act,
                                 fused_modconv,
                                 fmaps=128,
                                 return_atts=False,
                                 resolution=128,
                                 **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention.
    Support square images only.

    Each latent's soft box (a row interval times a column interval) is
    predicted from the spatially pooled features after they are style-modulated
    by that same latent. The box then gates a style-modulated copy of the
    instance-normalized input into x.

    Returns:
        x, or (x, attention resized to [-1, n_latents, 1, resolution, resolution])
        when return_atts is True.

    `act`, `fused_modconv` and `fmaps` are not used by this builder.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            c_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            row_shape = [-1, 1, 1, side, 1]
            col_shape = [-1, 1, 1, 1, side]
            box_masks = []
            for k in range(n_latents):
                with tf.variable_scope('lcond-' + str(k)):
                    pooled_k = style_mod(pooled, c_latents[:, k:k + 1])
                    logits = tf.reshape(dense_layer(pooled_k, fmaps=4 * side), [-1, 4, side])
                    cum = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
                    h_lo, h_hi, w_lo, w_hi = tf.split(cum, 4, axis=1)
                    rows = tf.reshape(h_lo, row_shape) * tf.reshape(1 - h_hi, row_shape)
                    cols = tf.reshape(w_lo, col_shape) * tf.reshape(1 - w_hi, col_shape)
                    box_masks.append(rows * cols)  # [b, 1, 1, side, side]
            atts = tf.concat(box_masks, axis=1)  # [b, n_latents, 1, side, side]

        with tf.variable_scope('Att_apply'):
            x_norm = instance_norm(x)
            for k in range(n_latents):
                with tf.variable_scope('style_mod-' + str(k)):
                    x_styled = style_mod(x_norm, c_latents[:, k:k + 1])
                    mask = atts[:, k]
                    x = x * (1 - mask) + x_styled * mask

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
        return x, atts