# Example no. 1
# 0
def G_style(
        latents_in,
        labels_in,
        truncation_psi=0.7,
        truncation_cutoff=8,
        truncation_psi_val=None,
        truncation_cutoff_val=None,
        dlatent_avg_beta=0.995,
        style_mixing_prob=0.9,
        is_training=False,
        is_validation=False,
        is_template_graph=False,
        components=dnnlib.EasyDict(),
        **kwargs):
    """Build the style-based generator: mapping network followed by synthesis.

    Args:
        latents_in: First input, latent vectors (Z) [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        truncation_psi: Style strength multiplier for the truncation trick.
            None = disable.
        truncation_cutoff: Number of layers for which to apply the truncation
            trick. None = disable.
        truncation_psi_val: Value for truncation_psi to use during validation.
        truncation_cutoff_val: Value for truncation_cutoff to use during
            validation.
        dlatent_avg_beta: Decay for tracking the moving average of W during
            training. None = disable.
        style_mixing_prob: Probability of mixing styles during training.
            None = disable.
        is_training: Network is under training? Enables and disables specific
            features.
        is_validation: Network is under validation? Chooses which value to
            use for truncation_psi.
        is_template_graph: True = template graph constructed by the Network
            class, False = actual evaluation.
        components: Container for sub-networks. Retained between calls.
        **kwargs: Arguments for sub-networks (G_mapping and G_synthesis).

    Returns:
        The synthesis network output, passed through tf.identity under the
        name 'images_out'.
    """

    def _is_static(value):
        # A "static" setting is a concrete value: neither None nor a TF expression.
        return value is not None and not tflib.is_tf_expression(value)

    # Validate arguments.
    assert not (is_training and is_validation)
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi, truncation_cutoff = (truncation_psi_val,
                                             truncation_cutoff_val)

    # Truncation is switched off while training; W averaging and style mixing
    # are switched off whenever the network is not training.
    if is_training:
        truncation_psi = truncation_cutoff = None
    else:
        dlatent_avg_beta = style_mixing_prob = None

    # A static setting that sits at its neutral value disables the feature too.
    if _is_static(truncation_psi) and truncation_psi == 1:
        truncation_psi = None
    if _is_static(truncation_cutoff) and truncation_cutoff <= 0:
        truncation_cutoff = None
    if _is_static(dlatent_avg_beta) and dlatent_avg_beta == 1:
        dlatent_avg_beta = None
    if _is_static(style_mixing_prob) and style_mixing_prob <= 0:
        style_mixing_prob = None

    # Setup components (created once, then reused through `components`).
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=G_synthesis, **kwargs)
    synthesis_shape = components.synthesis.input_shape
    num_layers = synthesis_shape[1]
    dlatent_size = synthesis_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network(
            'G_mapping',
            func_name=G_mapping,
            dlatent_broadcast=num_layers,
            **kwargs)

    # Setup variables.
    lod_in = tf.get_variable(
        'lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable(
        'dlatent_avg',
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(
        latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            blended_avg = tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)
            update_op = tf.assign(dlatent_avg, blended_avg)
            # Route dlatents through an identity op that depends on the
            # assignment, so using dlatents also runs the update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[None, :, None]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            should_mix = tf.random_uniform([], 0.0, 1.0) < style_mixing_prob
            mixing_cutoff = tf.cond(
                should_mix,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Layers below the cutoff keep the first set of dlatents.
            keep_first = tf.broadcast_to(layer_idx < mixing_cutoff,
                                         tf.shape(dlatents))
            dlatents = tf.where(keep_first, dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[None, :, None]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = tf.where(layer_idx < truncation_cutoff,
                             truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network, copying lod into its own 'lod' variable first.
    sync_lod = tf.assign(components.synthesis.find_var('lod'), lod_in)
    with tf.control_dependencies([sync_lod]):
        images_out = components.synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
def G_style(
        latents_in, labels_in,
        truncation_psi=0.7, truncation_cutoff=8,
        truncation_psi_val=None, truncation_cutoff_val=None,
        dlatent_avg_beta=0.995, style_mixing_prob=0.9,
        is_training=False, is_validation=False, is_template_graph=False,
        components=dnnlib.EasyDict(),
        **kwargs):
    """Style-based generator: runs the mapping network, then the synthesis network.

    Args:
        latents_in: First input, Z latent vectors [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        truncation_psi: Style strength multiplier for the truncation trick.
            None = disable.
        truncation_cutoff: Number of layers to apply the truncation trick to.
            None = disable.
        truncation_psi_val: Value of truncation_psi to use during validation.
        truncation_cutoff_val: Value of truncation_cutoff to use during
            validation.
        dlatent_avg_beta: Decay rate for tracking the moving average of W
            during training. None = disable.
        style_mixing_prob: Probability of mixing styles during training.
            None = disable.
        is_training: Whether the network is being trained; enables and
            disables specific features.
        is_validation: Whether the network is being validated; selects which
            truncation values are used.
        is_template_graph: True = template graph constructed by the Network
            class, False = actual evaluation.
        components: Container for the sub-networks, retained between calls.
        **kwargs: Arguments for the sub-networks (G_mapping and G_synthesis).

    Returns:
        The generated images tensor, named 'images_out'.
    """

    def _fixed(value, test):
        # True only for a concrete setting (not None, not a TF expression)
        # that satisfies `test`.
        return (value is not None
                and not tflib.is_tf_expression(value)
                and test(value))

    # Argument validation.
    # Training and validation states cannot be active at the same time.
    assert not (is_training and is_validation)
    # `components` is an EasyDict that later holds the synthesis and mapping networks.
    assert isinstance(components, dnnlib.EasyDict)

    if is_validation:
        # Use the validation-time truncation values. Both default to None,
        # in which case truncation is not applied during validation.
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val

    not_training = not is_training
    # No truncation during training or when psi is exactly 1.
    if is_training or _fixed(truncation_psi, lambda v: v == 1):
        truncation_psi = None
    # No truncation during training or when the cutoff covers no layers.
    if is_training or _fixed(truncation_cutoff, lambda v: v <= 0):
        truncation_cutoff = None
    # No W averaging outside training or when the decay rate is 1.
    if not_training or _fixed(dlatent_avg_beta, lambda v: v == 1):
        dlatent_avg_beta = None
    # No style mixing outside training or when its probability is <= 0.
    if not_training or _fixed(style_mixing_prob, lambda v: v <= 0):
        style_mixing_prob = None

    # Set up the sub-networks.
    if 'synthesis' not in components:  # Create the synthesis network.
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=G_synthesis, **kwargs)
    synthesis_input_shape = components.synthesis.input_shape
    # Number of synthesis layers (second axis of the synthesis input);
    # presumably 18 for a 1024x1024 model -- confirm against G_synthesis.
    num_layers = synthesis_input_shape[1]
    # Size of a single dlatent vector (third axis of the synthesis input);
    # presumably 512 -- confirm against G_synthesis.
    dlatent_size = synthesis_input_shape[2]
    if 'mapping' not in components:  # Create the mapping network.
        components.mapping = tflib.Network(
            'G_mapping', func_name=G_mapping,
            dlatent_broadcast=num_layers, **kwargs)

    # Set up variables.
    # `lod` is initialised to 0. NOTE(review): presumably lod equals
    # resolution_log2 - res as defined by the synthesis network -- confirm there.
    lod_in = tf.get_variable(
        'lod', initializer=np.float32(0), trainable=False)
    # Running average of the dlatents (W), initialised to zeros.
    dlatent_avg = tf.get_variable(
        'dlatent_avg', shape=[dlatent_size],
        initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate the mapping network.
    dlatents = components.mapping.get_output_for(
        latents_in, labels_in, **kwargs)

    # Update the moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            # Mean dlatent of the new batch, taken from layer 0.
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            # Blend the batch mean with the running average using
            # dlatent_avg_beta. NOTE(review): assumes
            # tflib.lerp(a, b, t) == a + (b - a) * t -- confirm in tflib.
            new_avg = tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)
            update_op = tf.assign(dlatent_avg, new_avg)
            # Ops created inside tf.control_dependencies run only after
            # update_op. tf.identity adds a new node returning the same
            # tensor, so using dlatents downstream forces the update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            # Second, random set of intermediate latents used for mixing.
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, **kwargs)
            # Layer indices shaped [1, num_layers, 1]: [[[0], [1], ..., [num_layers - 1]]].
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            # Layers currently in use: total minus twice lod_in (presumably
            # because each resolution level has two layers -- confirm).
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # With probability style_mixing_prob pick a random cutoff in
            # [1, cur_layers); otherwise use cur_layers itself.
            mix_this_batch = (
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob)
            mixing_cutoff = tf.cond(
                mix_this_batch,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Layers below the cutoff keep dlatents; the rest take dlatents2.
            below_cutoff = tf.broadcast_to(
                layer_idx < mixing_cutoff, tf.shape(dlatents))
            dlatents = tf.where(below_cutoff, dlatents, dlatents2)

    # Apply the truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            # Layer indices shaped [1, num_layers, 1].
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            # Matching array of float32 ones.
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # Per-layer coefficient: truncation_psi for truncated layers, else 1.
            coefs = tf.where(
                layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            # Interpolate from dlatent_avg towards dlatents by `coefs`;
            # the result replaces dlatents.
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate the synthesis network, copying lod into its own 'lod' variable first.
    assign_lod = tf.assign(components.synthesis.find_var('lod'), lod_in)
    with tf.control_dependencies([assign_lod]):
        images_out = components.synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    # Return the generated images.
    return tf.identity(images_out, name='images_out')
# Example no. 3
# 0
def G_main(
    latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,  # Second input: Conditioning labels [minibatch, label_size].
    latmask,  # mask for split-frame latents blending
    dconst,  # initial (const) layer displacement
    truncation_psi=0.5,  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=None,  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,  # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
    is_training=False,  # Network is under training? Enables and disables specific features.
    is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
    return_dlatents=False,  # Return dlatents in addition to the images?
    is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(
    ),  # Container for sub-networks. Retained between calls.
    mapping_func='G_mapping',  # Build func name for the mapping network.
    synthesis_func='G_synthesis_stylegan2',  # Build func name for the synthesis network.
    **kwargs):  # Arguments for sub-networks (mapping and synthesis).
    """Build the main generator graph: mapping network followed by synthesis.

    The mapping and synthesis build functions are looked up by name in this
    module's globals (`mapping_func`, `synthesis_func`) and cached in
    `components` on first use.

    Returns:
        The images tensor named 'images_out', or the tuple
        `(images_out, dlatents)` when `return_dlatents` is true.
    """

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    # Truncation is disabled during training, or when psi is a static 1.
    if is_training or (truncation_psi is not None
                       and not tflib.is_tf_expression(truncation_psi)
                       and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    # W averaging and style mixing are disabled outside training, or when
    # their static setting sits at its neutral value.
    if not is_training or (dlatent_avg_beta is not None
                           and not tflib.is_tf_expression(dlatent_avg_beta)
                           and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None
                           and not tflib.is_tf_expression(style_mixing_prob)
                           and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping',
                                           func_name=globals()[mapping_func],
                                           dlatent_broadcast=num_layers,
                                           **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg',
                                  shape=[dlatent_size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in,
                                                 labels_in,
                                                 is_training=is_training,
                                                 **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(
                dlatent_avg,
                tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            # Route dlatents through an identity op that depends on the
            # assignment, so using dlatents also runs the update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # One scalar cutoff is drawn for the whole minibatch: with
            # probability style_mixing_prob a random layer in
            # [1, cur_layers), otherwise cur_layers itself.
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Actually apply the mix: layers below the cutoff keep the first
            # set of dlatents, the remaining layers take the second set.
            # Without this step mixing_cutoff and dlatents2 would be unused
            # and style mixing would silently do nothing.
            dlatents = tf.where(
                tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)),
                dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
            if truncation_cutoff is None:
                # No cutoff given: truncate every layer.
                layer_psi *= truncation_psi
            else:
                # Only layers below the cutoff are truncated.
                layer_psi = tf.where(layer_idx < truncation_cutoff,
                                     layer_psi * truncation_psi, layer_psi)
            dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)

    # Evaluate synthesis network.
    # If the synthesis network has its own 'lod' variable, sync it first.
    deps = []
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(
            dlatents,
            latmask,
            dconst,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
# Example no. 4
# 0
def G_style(
        latents_in,
        labels_in,
        truncation_psi=0.7,
        truncation_cutoff=8,
        truncation_psi_val=None,
        truncation_cutoff_val=None,
        dlatent_avg_beta=0.995,
        is_training=False,
        is_validation=False,
        components=dnnlib.EasyDict(),
        **kwargs):
    """Build the style-based generator: mapping network followed by synthesis.

    Args:
        latents_in: First input, latent vectors (Z) [minibatch, latent_size].
        labels_in: Second input, conditioning labels [minibatch, label_size].
        truncation_psi: Style strength multiplier for the truncation trick.
            None = disable.
        truncation_cutoff: Number of layers for which to apply the truncation
            trick. None = disable.
        truncation_psi_val: Value for truncation_psi to use during validation.
        truncation_cutoff_val: Value for truncation_cutoff to use during
            validation.
        dlatent_avg_beta: Decay for tracking the moving average of W during
            training. None = disable.
        is_training: Network is under training? Enables and disables specific
            features.
        is_validation: Network is under validation? Chooses which value to
            use for truncation_psi.
        components: Container for sub-networks. Retained between calls.
        **kwargs: Arguments for sub-networks (G_mapping and G_synthesis).

    Returns:
        The output of the synthesis network for the (optionally truncated)
        dlatents.
    """

    def _concrete(setting):
        # Concrete = usable in a plain Python comparison (not None, not a TF expression).
        return setting is not None and not tflib.is_tf_expression(setting)

    # Validate arguments.
    assert not (is_training and is_validation)
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi, truncation_cutoff = (truncation_psi_val,
                                             truncation_cutoff_val)

    if is_training:
        # Truncation is never applied while training.
        truncation_psi = None
        truncation_cutoff = None
        # A decay of exactly 1 disables the moving average.
        if _concrete(dlatent_avg_beta) and dlatent_avg_beta == 1:
            dlatent_avg_beta = None
    else:
        # Outside training the moving average of W is not updated.
        dlatent_avg_beta = None
        # psi == 1 or cutoff <= 0 disables the corresponding setting.
        if _concrete(truncation_psi) and truncation_psi == 1:
            truncation_psi = None
        if _concrete(truncation_cutoff) and truncation_cutoff <= 0:
            truncation_cutoff = None

    # Setup components (created once, then reused through `components`).
    if "synthesis" not in components:
        components.synthesis = tflib.Network(
            num_inputs=1, name="G_synthesis", func_name=G_synthesis, **kwargs)
    synthesis_shape = components.synthesis.input_shape
    num_layers = synthesis_shape[1]
    dlatent_size = synthesis_shape[2]
    if "mapping" not in components:
        components.mapping = tflib.Network(
            num_inputs=2,
            name="G_mapping",
            func_name=G_mapping,
            dlatent_broadcast=num_layers,
            **kwargs)

    # Setup variables.
    dlatent_avg = tf.get_variable(
        "dlatent_avg",
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(
        latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope("DlatentAvg"):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            blended_avg = tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)
            update_op = tf.assign(dlatent_avg, blended_avg)
            # Route dlatents through an identity op that depends on the
            # assignment, so using dlatents also runs the update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope("Truncation"):
            layer_idx = np.arange(num_layers)[None, :, None]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # truncation_psi for layers below the cutoff, 1 for the rest.
            coefs = tf.where(layer_idx < truncation_cutoff,
                             truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    return components.synthesis.get_output_for(dlatents, **kwargs)