Example #1
# 2D convolution with optional up/downsampling (NCHW layout; weight shape [kernel, kernel, in_channels, out_channels]).
def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
    assert not (up and down)
    kernel = w.shape[0].value
    assert w.shape[1].value == kernel  # Kernel must be square ...
    assert kernel >= 1 and kernel % 2 == 1  # ... with odd size.

    w = tf.cast(w, x.dtype)
    if up:
        x = upsample_conv_2d(x,
                             w,
                             data_format='NCHW',
                             k=resample_kernel,
                             padding=padding)
    elif down:
        x = conv_downsample_2d(x,
                               w,
                               data_format='NCHW',
                               k=resample_kernel,
                               padding=padding)
    else:
        padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]  # Only padding=0 ('SAME') or padding=-(kernel//2) ('VALID') is supported here.
        x = tf.nn.conv2d(x,
                         w,
                         data_format='NCHW',
                         strides=[1, 1, 1, 1],
                         padding=padding_mode)
    return x
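
A minimal usage sketch for the plain path (no up/downsampling), assuming TensorFlow 1.x with NCHW data; the up/down branches additionally need the repo's upsample_conv_2d / conv_downsample_2d helpers, which are not shown in this example. The shapes below are hypothetical and only illustrate the calling convention.

import tensorflow as tf  # TF 1.x assumed (shape .value access, tf.placeholder)

x = tf.placeholder(tf.float32, [4, 64, 32, 32], name='x')    # [NCHW]: batch 4, 64 channels, 32x32
w = tf.get_variable('w', shape=[3, 3, 64, 128],               # [kernel, kernel, in_channels, out_channels]
                    initializer=tf.random_normal_initializer())
y = conv2d(x, w, padding=0)                                   # padding=0 -> 'SAME', output [4, 128, 32, 32]
# padding=-(3 // 2) would select 'VALID' padding instead.
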
# Convolution layer that builds its weight via get_weight and supports optional up/downsampling; same conventions as conv2d above.
def conv2d_layer(x,
                 fmaps,
                 kernel,
                 up=False,
                 down=False,
                 resample_kernel=None,
                 gain=1,
                 use_wscale=True,
                 lrmul=1,
                 weight_var='weight'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1
    w = get_weight([kernel, kernel, x.shape[1].value, fmaps],
                   gain=gain,
                   use_wscale=use_wscale,
                   lrmul=lrmul,
                   weight_var=weight_var)
    if up:
        x = upsample_conv_2d(x,
                             tf.cast(w, x.dtype),
                             data_format='NCHW',
                             k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x,
                               tf.cast(w, x.dtype),
                               data_format='NCHW',
                               k=resample_kernel)
    else:
        x = tf.nn.conv2d(x,
                         tf.cast(w, x.dtype),
                         data_format='NCHW',
                         strides=[1, 1, 1, 1],
                         padding='SAME')
    return x
# Modulated convolution layer: the weight is scaled per sample by a style vector computed from y, with optional demodulation and up/downsampling.
def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, gain=1, use_wscale=True, lrmul=1, fused_modconv=True, weight_var='weight', mod_weight_var='mod_weight', mod_bias_var='mod_bias'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Get weight.
    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
    ww = w[np.newaxis] # [BkkIO] Introduce minibatch dimension.

    # Learned spatial attention mask (applied only when there is no up/downsampling).
    if not up and not down:
        mask = 1 + get_weight([x.shape[2].value, x.shape[3].value], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var='mask')
        x = x * mask * tf.rsqrt(tf.reduce_mean(tf.square(mask)) + 1e-8) # Rescale so the mask has unit RMS.

    # Modulate.
    s = dense_layer(y, fmaps=x.shape[1].value, weight_var=mod_weight_var) # [BI] Transform incoming W to style.
    s = apply_bias_act(s, bias_var=mod_bias_var) + 1 # [BI] Add bias (initially 1).
    ww *= tf.cast(s[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype) # [BkkIO] Scale input feature maps.

    # Demodulate.
    if demodulate:
        d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1,2,3]) + 1e-8) # [BO] Scaling factor.
        ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :] # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups.
        w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
    else:
        x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype) # [BIhw] Not fused => scale input activations.

    # Convolution with optional up/downsampling.
    if up:
        x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    else:
        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME')

    # Reshape/scale output.
    if fused_modconv:
        x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.
    elif demodulate:
        x *= tf.cast(d[:, :, np.newaxis, np.newaxis], x.dtype) # [BOhw] Not fused => scale output activations.
    return x
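
For reference, the modulate/demodulate arithmetic above can be written out in plain NumPy; this is only an illustration of the tensor shapes (the [BkkIO] comments) and of the scaling factors s and d, not part of the original layer.

import numpy as np

B, k, I, O = 2, 3, 4, 5                                            # minibatch, kernel size, input/output channels
w = np.random.randn(k, k, I, O)                                    # shared weight [kkIO]
s = np.random.randn(B, I) + 1                                      # per-sample styles [BI] (bias initialized to 1)

ww = w[np.newaxis] * s[:, np.newaxis, np.newaxis, :, np.newaxis]   # Modulate: [BkkIO], scale input feature maps.
d = 1.0 / np.sqrt(np.sum(np.square(ww), axis=(1, 2, 3)) + 1e-8)    # Demodulation factor [BO].
ww = ww * d[:, np.newaxis, np.newaxis, np.newaxis, :]              # Scale output feature maps.

# Each per-sample weight now has (approximately) unit L2 norm over (k, k, I) for every output channel:
print(np.sum(np.square(ww), axis=(1, 2, 3)))                       # ~1.0 everywhere
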
Example #4
# Modulated convolution layer whose per-sample weight is built from learned factors U, V and a style-dependent diagonal matrix S (w = V * S * U^T per sample).
def decomposition_conv2d_layer(x,
                               y,
                               fmaps,
                               kernel,
                               up=False,
                               down=False,
                               demodulate=True,
                               resample_kernel=None,
                               gain=1,
                               use_wscale=True,
                               lrmul=1,
                               fused_modconv=True,
                               weight_var='weight',
                               mod_weight_var='U_weight',
                               mod_bias_var='mod_bias'):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1
    # Get weight.
    out_channel = fmaps
    in_channel = kernel * kernel * x.shape[1].value
    s_dimension = min(out_channel, in_channel)
    U = get_weight([out_channel, s_dimension],
                   gain=gain,
                   use_wscale=use_wscale,
                   lrmul=lrmul,
                   weight_var='U_' + weight_var)
    V = get_weight([in_channel, s_dimension],
                   gain=gain,
                   use_wscale=use_wscale,
                   lrmul=lrmul,
                   weight_var='V_' + weight_var)
    # Linear transform and normalization to obtain the style vector s and its diagonal matrix S.
    s = dense_layer(
        y, fmaps=s_dimension,
        weight_var=mod_weight_var)  # [BI] Transform incoming W to style.
    s = apply_bias_act(
        s, bias_var=mod_bias_var) + 1  # [BI] Add bias (initially 1).
    s *= tf.rsqrt(tf.reduce_mean(tf.square(s), axis=1, keepdims=True) + 1e-8)
    # Construct the diagonal matrix S from the style vector s.
    S = tf.matrix_diag(s)
    # Use S together with U and V to construct the per-sample weight w = V * diag(s) * U^T.
    w = tf.matmul(tf.reshape(S, [-1, s_dimension]), tf.transpose(V, [1, 0]))
    w = tf.reshape(w, [-1, s_dimension, in_channel])
    w = tf.transpose(w, [0, 2, 1])
    w = tf.matmul(tf.reshape(w, [-1, s_dimension]), tf.transpose(U, [1, 0]))
    w = tf.reshape(w, [-1, kernel, kernel, x.shape[1].value, out_channel])
    # Normalize w, analogous to weight demodulation.
    if demodulate:
        d = tf.rsqrt(tf.reduce_sum(tf.square(w), axis=[1, 2, 3]) +
                     1e-8)  # [BO] Scaling factor.
        w *= d[:, np.newaxis, np.newaxis,
               np.newaxis, :]  # [BkkIO] Scale output feature maps.

    if fused_modconv:
        x = tf.reshape(x,
                       [1, -1, x.shape[2], x.shape[3]
                        ])  # Fused => reshape minibatch to convolution groups.
        w = tf.reshape(tf.transpose(w, [1, 2, 3, 0, 4]),
                       [w.shape[1], w.shape[2], w.shape[3], -1])
    else:
        x *= tf.cast(s[:, :, np.newaxis, np.newaxis],
                     x.dtype)  # [BIhw] Not fused => scale input activations.

    # Convolution with optional up/downsampling.
    if up:
        x = upsample_conv_2d(x,
                             tf.cast(w, x.dtype),
                             data_format='NCHW',
                             k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x,
                               tf.cast(w, x.dtype),
                               data_format='NCHW',
                               k=resample_kernel)
    else:
        x = tf.nn.conv2d(x,
                         tf.cast(w, x.dtype),
                         data_format='NCHW',
                         strides=[1, 1, 1, 1],
                         padding='SAME')

    if fused_modconv:
        x = tf.reshape(
            x, [-1, fmaps, x.shape[2], x.shape[3]
                ])  # Fused => reshape convolution groups back to minibatch.
    elif demodulate:
        x *= tf.cast(d[:, :, np.newaxis, np.newaxis],
                     x.dtype)  # [BOhw] Not fused => scale output activations.
    return x
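
As a sanity check on the reshape/transpose/matmul chain above: the per-sample weight it constructs equals V * diag(s_b) * U^T (shape [in_channel, out_channel]), reshaped to [kernel, kernel, in, out]. A small NumPy sketch, purely illustrative and not part of the original example:

import numpy as np

B, k, C_in, C_out = 2, 3, 4, 8
in_channel = k * k * C_in
s_dim = min(C_out, in_channel)

U = np.random.randn(C_out, s_dim)                                  # [out_channel, s_dimension]
V = np.random.randn(in_channel, s_dim)                             # [in_channel, s_dimension]
s = np.random.randn(B, s_dim) + 1                                  # per-sample styles
s /= np.sqrt(np.mean(np.square(s), axis=1, keepdims=True) + 1e-8)  # same normalization as in the layer

w = np.einsum('is,bs,os->bio', V, s, U)                            # [B, in_channel, out_channel] = V * diag(s_b) * U^T
w = w.reshape(B, k, k, C_in, C_out)                                # [BkkIO] per-sample convolution weights
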