def G_style(
        latents_in,
        labels_in,
        truncation_psi=0.7,
        truncation_cutoff=8,
        truncation_psi_val=None,
        truncation_cutoff_val=None,
        dlatent_avg_beta=0.995,
        style_mixing_prob=0.9,
        is_training=False,
        is_validation=False,
        is_template_graph=False,
        components=dnnlib.EasyDict(),
        **kwargs):
    """Build the style-based generator: mapping network followed by synthesis.

    Args:
        latents_in: First input, latent vectors (Z) [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        truncation_psi: Style strength multiplier for the truncation trick.
            None disables it.
        truncation_cutoff: Number of layers the truncation trick is applied
            to. None disables it.
        truncation_psi_val: Replacement for truncation_psi during validation.
        truncation_cutoff_val: Replacement for truncation_cutoff during
            validation.
        dlatent_avg_beta: Decay of the moving average of W tracked during
            training. None disables it.
        style_mixing_prob: Probability of mixing styles during training.
            None disables it.
        is_training: Whether the network is being trained.
        is_validation: Whether the network is being validated; selects the
            *_val truncation settings.
        is_template_graph: Forwarded to the synthesis network as
            force_clean_graph.
        components: EasyDict holding the 'synthesis' and 'mapping'
            sub-networks; missing entries are created here and kept in it.
        **kwargs: Extra arguments forwarded to the sub-networks.

    Returns:
        The synthesis output passed through an identity op named
        'images_out'.
    """

    def _plain_value(value):
        # A setting that is present and is not a TensorFlow expression.
        return value is not None and not tflib.is_tf_expression(value)

    # Validate arguments.
    assert not (is_training and is_validation)
    assert isinstance(components, dnnlib.EasyDict)

    if is_validation:
        truncation_psi, truncation_cutoff = truncation_psi_val, truncation_cutoff_val

    # Truncation is dropped while training or when its constant value is a no-op.
    if is_training or (_plain_value(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (_plain_value(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None

    # W averaging and style mixing are dropped outside training or when no-ops.
    if not is_training or (_plain_value(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (_plain_value(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=G_synthesis, **kwargs)
    synthesis_net = components.synthesis
    num_layers = synthesis_net.input_shape[1]
    dlatent_size = synthesis_net.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network(
            'G_mapping',
            func_name=G_mapping,
            dlatent_broadcast=num_layers,
            **kwargs)
    mapping_net = components.mapping

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable(
        'dlatent_avg',
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate mapping network.
    dlatents = mapping_net.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            first_layer_w = dlatents[:, 0]
            batch_avg = tf.reduce_mean(first_layer_w, axis=0)
            blended_avg = tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)
            update_op = tf.assign(dlatent_avg, blended_avg)
            # Route dlatents through an identity that depends on the update so
            # the assignment runs whenever the generator output is evaluated.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = mapping_net.get_output_for(latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2

            def _random_cutoff():
                return tf.random_uniform([], 1, cur_layers, dtype=tf.int32)

            def _no_mixing():
                return cur_layers

            do_mix = tf.random_uniform([], 0.0, 1.0) < style_mixing_prob
            mixing_cutoff = tf.cond(do_mix, _random_cutoff, _no_mixing)

            # Layers below the cutoff keep the first latent's styles; the
            # remaining layers take the styles of the second latent.
            keep_first = layer_idx < mixing_cutoff
            keep_first = tf.broadcast_to(keep_first, tf.shape(dlatents))
            dlatents = tf.where(keep_first, dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # Per-layer coefficient: truncation_psi below the cutoff, 1 elsewhere.
            truncated = truncation_psi * ones
            coefs = tf.where(layer_idx < truncation_cutoff, truncated, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network, copying this network's lod into it first.
    sync_lod = tf.assign(synthesis_net.find_var('lod'), lod_in)
    with tf.control_dependencies([sync_lod]):
        images_out = synthesis_net.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
def G_style(
        latents_in,
        labels_in,
        truncation_psi=0.7,
        truncation_cutoff=8,
        truncation_psi_val=None,
        truncation_cutoff_val=None,
        dlatent_avg_beta=0.995,
        style_mixing_prob=0.9,
        is_training=False,
        is_validation=False,
        is_template_graph=False,
        components=dnnlib.EasyDict(),
        **kwargs):
    """Assemble the style-based generator from its mapping and synthesis parts.

    Args:
        latents_in: First input, Z latent vectors [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        truncation_psi: Style strength multiplier of the truncation trick.
            None disables it.
        truncation_cutoff: Number of layers the truncation trick applies to.
            None disables it.
        truncation_psi_val: Value used for truncation_psi during validation.
        truncation_cutoff_val: Value used for truncation_cutoff during
            validation.
        dlatent_avg_beta: Decay rate of the moving average of W tracked while
            training. None disables it.
        style_mixing_prob: Probability of mixing styles while training.
            None disables it.
        is_training: Whether the network is being trained.
        is_validation: Whether the network is being validated; selects the
            *_val truncation settings (None by default).
        is_template_graph: Passed to the synthesis network as
            force_clean_graph.
        components: EasyDict that holds the 'synthesis' and 'mapping'
            sub-networks; entries that are missing are created and stored.
        **kwargs: Arguments passed on to the sub-networks.

    Returns:
        The generated images tensor, as an identity op named 'images_out'.
    """

    def _enabled(setting, is_noop):
        # A setting counts as enabled when it is given and is either a
        # TensorFlow expression or a constant that is not a no-op value.
        if setting is None:
            return False
        if tflib.is_tf_expression(setting):
            return True
        return not is_noop(setting)

    # Argument validation: training and validation are mutually exclusive.
    assert not (is_training and is_validation)
    assert isinstance(components, dnnlib.EasyDict)

    # Validation swaps in its own truncation settings.
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val

    # Decide once which optional stages are built.
    apply_truncation = (
        not is_training
        and _enabled(truncation_psi, lambda v: v == 1)
        and _enabled(truncation_cutoff, lambda v: v <= 0))
    track_w_average = is_training and _enabled(dlatent_avg_beta, lambda v: v == 1)
    mix_styles = is_training and _enabled(style_mixing_prob, lambda v: v <= 0)

    # Set up the sub-networks; the synthesis input shape fixes the number of
    # layers the mapping output is broadcast to.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=G_synthesis, **kwargs)
    synthesis_input_shape = components.synthesis.input_shape
    num_layers = synthesis_input_shape[1]
    dlatent_size = synthesis_input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network(
            'G_mapping',
            func_name=G_mapping,
            dlatent_broadcast=num_layers,
            **kwargs)

    # Set up variables: lod starts at 0, dlatent_avg starts at zeros.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable(
        'dlatent_avg',
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate the mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update the moving average of W.
    if track_w_average:
        with tf.variable_scope('DlatentAvg'):
            # Mean over the batch of the dlatents at layer index 0.
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            new_avg = tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)
            update_op = tf.assign(dlatent_avg, new_avg)
            # tf.identity adds a node that can only run after update_op, so
            # evaluating the output also performs the average update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if mix_styles:
        with tf.name_scope('StyleMix'):
            # Second, random set of latents whose styles are mixed in.
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, **kwargs)
            # Layer indices shaped [1, num_layers, 1].
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2

            def _pick_random_layer():
                return tf.random_uniform([], 1, cur_layers, dtype=tf.int32)

            def _keep_all_layers():
                return cur_layers

            # With probability style_mixing_prob draw a random cutoff layer,
            # otherwise use cur_layers as the cutoff.
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                _pick_random_layer,
                _keep_all_layers)
            # Layers below the cutoff keep dlatents; the rest use dlatents2.
            below_cutoff = tf.broadcast_to(
                layer_idx < mixing_cutoff, tf.shape(dlatents))
            dlatents = tf.where(below_cutoff, dlatents, dlatents2)

    # Apply the truncation trick.
    if apply_truncation:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # Coefficient is truncation_psi for truncated layers, 1 otherwise.
            coefs = tf.where(
                layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            # Interpolate from the average W towards each dlatent by coefs.
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate the synthesis network after copying lod into it.
    lod_assign = tf.assign(components.synthesis.find_var('lod'), lod_in)
    with tf.control_dependencies([lod_assign]):
        images_out = components.synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
def G_main(
        latents_in,
        labels_in,
        latmask,
        dconst,
        truncation_psi=0.5,
        truncation_cutoff=None,
        truncation_psi_val=None,
        truncation_cutoff_val=None,
        dlatent_avg_beta=0.995,
        style_mixing_prob=0.9,
        is_training=False,
        is_validation=False,
        return_dlatents=False,
        is_template_graph=False,
        components=dnnlib.EasyDict(),
        mapping_func='G_mapping',
        synthesis_func='G_synthesis_stylegan2',
        **kwargs):
    """Build the main generator graph: mapping network plus synthesis network.

    Args:
        latents_in: First input, latent vectors (Z) [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        latmask: Mask for split-frame latents blending; forwarded to the
            synthesis network.
        dconst: Initial (const) layer displacement; forwarded to the
            synthesis network.
        truncation_psi: Style strength multiplier for the truncation trick.
            None disables it.
        truncation_cutoff: Number of layers to apply the truncation trick to.
            None applies truncation_psi to every layer.
        truncation_psi_val: Value for truncation_psi during validation.
        truncation_cutoff_val: Value for truncation_cutoff during validation.
        dlatent_avg_beta: Decay for tracking the moving average of W during
            training. None disables it.
        style_mixing_prob: Probability of mixing styles during training.
            None disables it.
        is_training: Whether the network is under training.
        is_validation: Whether the network is under validation; selects the
            *_val truncation settings.
        return_dlatents: Return the dlatents in addition to the images.
        is_template_graph: Forwarded to the synthesis network as
            force_clean_graph.
        components: EasyDict holding the 'synthesis' and 'mapping'
            sub-networks; missing entries are created here and kept in it.
        mapping_func: Name of the module-level build function for the
            mapping network.
        synthesis_func: Name of the module-level build function for the
            synthesis network.
        **kwargs: Arguments for the sub-networks (mapping and synthesis).

    Returns:
        The images tensor (identity op named 'images_out'), or the tuple
        (images_out, dlatents) when return_dlatents is true.
    """
    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components. Build functions are looked up by name in this module.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network(
            'G_mapping',
            func_name=globals()[mapping_func],
            dlatent_broadcast=num_layers,
            **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable(
        'dlatent_avg',
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(
        latents_in, labels_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(
                dlatent_avg,
                tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            # The identity depends on update_op, so evaluating the output
            # also performs the moving-average update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # Original version: one cutoff shared by the whole minibatch.
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Diff Augment version (per-sample cutoff), kept for reference:
            # mixing_cutoff = tf.where_v2(
            #     tf.random_uniform([tf.shape(dlatents)[0]], 0.0, 1.0) < style_mixing_prob,
            #     tf.random_uniform([tf.shape(dlatents)[0]], 1, cur_layers, dtype=tf.int32),
            #     cur_layers[np.newaxis])[:, np.newaxis, np.newaxis]
            # Apply the cutoff: layers below it keep dlatents, the remaining
            # layers take dlatents2. Without this statement the second
            # mapping pass and the cutoff are built but never used.
            dlatents = tf.where(
                tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)),
                dlatents,
                dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
            if truncation_cutoff is None:
                # No cutoff: every layer is scaled by truncation_psi.
                layer_psi *= truncation_psi
            else:
                layer_psi = tf.where(
                    layer_idx < truncation_cutoff,
                    layer_psi * truncation_psi,
                    layer_psi)
            dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)

    # Evaluate synthesis network. The lod variable is synchronized only when
    # the synthesis network actually has one.
    deps = []
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(
            dlatents,
            latmask,
            dconst,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
def G_style(
        latents_in,
        labels_in,
        truncation_psi=0.7,
        truncation_cutoff=8,
        truncation_psi_val=None,
        truncation_cutoff_val=None,
        dlatent_avg_beta=0.995,
        is_training=False,
        is_validation=False,
        components=dnnlib.EasyDict(),
        **kwargs):
    """Build the style-based generator without lod handling or style mixing.

    Args:
        latents_in: First input, latent vectors (Z) [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        truncation_psi: Style strength multiplier for the truncation trick.
            None disables it.
        truncation_cutoff: Number of layers the truncation trick is applied
            to. None disables it.
        truncation_psi_val: Replacement for truncation_psi during validation.
        truncation_cutoff_val: Replacement for truncation_cutoff during
            validation.
        dlatent_avg_beta: Decay of the moving average of W tracked during
            training. None disables it.
        is_training: Whether the network is being trained.
        is_validation: Whether the network is being validated; selects the
            *_val truncation settings.
        components: EasyDict holding the "synthesis" and "mapping"
            sub-networks; missing entries are created here and kept in it.
        **kwargs: Extra arguments forwarded to the sub-networks.

    Returns:
        The output of the synthesis network.
    """

    def _plain_value(value):
        # A setting that is present and is not a TensorFlow expression.
        return value is not None and not tflib.is_tf_expression(value)

    # Validate arguments.
    assert not (is_training and is_validation)
    assert isinstance(components, dnnlib.EasyDict)

    if is_validation:
        truncation_psi, truncation_cutoff = truncation_psi_val, truncation_cutoff_val

    # Truncation is dropped while training or when its constant value is a no-op.
    if is_training or (_plain_value(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (_plain_value(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None
    # W averaging is dropped outside training or when the decay is exactly 1.
    if not is_training or (_plain_value(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None

    # Setup components.
    if "synthesis" not in components:
        components.synthesis = tflib.Network(
            num_inputs=1, name="G_synthesis", func_name=G_synthesis, **kwargs)
    synthesis_net = components.synthesis
    num_layers = synthesis_net.input_shape[1]
    dlatent_size = synthesis_net.input_shape[2]
    if "mapping" not in components:
        components.mapping = tflib.Network(
            num_inputs=2,
            name="G_mapping",
            func_name=G_mapping,
            dlatent_broadcast=num_layers,
            **kwargs)

    # Setup variables.
    dlatent_avg = tf.get_variable(
        "dlatent_avg",
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope("DlatentAvg"):
            first_layer_w = dlatents[:, 0]
            batch_avg = tf.reduce_mean(first_layer_w, axis=0)
            blended_avg = tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)
            update_op = tf.assign(dlatent_avg, blended_avg)
            # The identity depends on update_op, so evaluating the output
            # also performs the moving-average update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope("Truncation"):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # Per-layer coefficient: truncation_psi below the cutoff, 1 elsewhere.
            truncated = truncation_psi * ones
            coefs = tf.where(layer_idx < truncation_cutoff, truncated, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    images_out = synthesis_net.get_output_for(dlatents, **kwargs)
    return images_out