Example no. 1
def build_res_conv_layer(x,
                         name,
                         n_layers,
                         scope_idx,
                         act,
                         resample_kernel,
                         fmaps=128,
                         **kwargs):
    # e.g. {'Conv-up': 2}, {'Conv-id': 1}
    sample_type = name.split('-')[-1]
    assert sample_type in ['up', 'down', 'id']
    x_ori = x
    for i in range(n_layers):
        with tf.variable_scope(name + '-' + str(scope_idx) + '-' + str(i)):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=fmaps,
                                            kernel=3,
                                            up=(sample_type == 'up'),
                                            down=(sample_type == 'down'),
                                            resample_kernel=resample_kernel),
                               act=act)
        if sample_type == 'up':
            with tf.variable_scope('Upsampling' + '-' + str(scope_idx) + '-' +
                                   str(i)):
                x_ori = naive_upsample_2d(x_ori)
        elif sample_type == 'down':
            with tf.variable_scope('Downsampling' + '-' + str(scope_idx) +
                                   '-' + str(i)):
                x_ori = naive_downsample_2d(x_ori)

    with tf.variable_scope(name + 'Resampled-' + str(scope_idx)):
        x_ori = apply_bias_act(conv2d_layer(x_ori, fmaps=fmaps, kernel=1),
                               act=act)
        x = x + x_ori
    return x
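The residual branch above keeps an unmodified copy `x_ori` and resamples it with `naive_upsample_2d` / `naive_downsample_2d` before the 1x1 projection. Those helpers are not shown in this listing; as a rough sketch (an assumption, not the source implementation), nearest-neighbour upsampling and average-pool downsampling on NCHW arrays behave like this:

import numpy as np

def naive_upsample_2d_np(x, factor=2):
    # Nearest-neighbour upsampling of an NCHW array by repeating each pixel.
    n, c, h, w = x.shape
    x = x.reshape(n, c, h, 1, w, 1)
    x = np.tile(x, (1, 1, 1, factor, 1, factor))
    return x.reshape(n, c, h * factor, w * factor)

def naive_downsample_2d_np(x, factor=2):
    # Average-pool downsampling of an NCHW array.
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // factor, factor, w // factor, factor)
    return x.mean(axis=(3, 5))

x = np.random.randn(2, 8, 16, 16).astype(np.float32)
assert naive_upsample_2d_np(x).shape == (2, 8, 32, 32)
assert naive_downsample_2d_np(x).shape == (2, 8, 8, 8)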
Example no. 2
 def block(x, res): # res = 2..resolution_log2
     t = x
     # with tf.variable_scope('Conv0'):
         # x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-1), kernel=3), act=act)
     with tf.variable_scope('Conv1_down'):
         x = apply_bias_act(conv2d_layer(x, fmaps=nf(res-2), kernel=3, down=True, resample_kernel=resample_kernel), act=act)
     if architecture == 'resnet':
         with tf.variable_scope('Skip'):
             t = conv2d_layer(t, fmaps=nf(res-2), kernel=1, down=True, resample_kernel=resample_kernel)
             x = (x + t) * (1 / np.sqrt(2))
     return x
 def apply_st(x, st_matrix, idx, up=True):  # idx: 2, 3, 4
     with tf.variable_scope('Transform'):
         x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
         x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
         x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
     with tf.variable_scope('Upconv'):
         x = apply_bias_act(conv2d_layer(x,
                                         fmaps=nf(idx),
                                         kernel=3,
                                         up=up,
                                         resample_kernel=resample_kernel),
                            act=act)
     with tf.variable_scope('Conv'):
         x = apply_bias_act(conv2d_layer(x, fmaps=nf(idx), kernel=3),
                            act=act)
     return x
Example no. 4
def build_noise_layer(x,
                      name,
                      n_layers,
                      scope_idx,
                      act,
                      use_noise,
                      randomize_noise,
                      fmaps=128,
                      **kwargs):
    for i in range(n_layers):
        with tf.variable_scope(name + '-' + str(scope_idx) + '-' + str(i)):
            x = conv2d_layer(x, fmaps=fmaps, kernel=3, up=False)
            if use_noise:
                if randomize_noise:
                    noise = tf.random_normal(
                        [tf.shape(x)[0], 1, x.shape[2], x.shape[3]],
                        dtype=x.dtype)
                else:
                    # noise = tf.get_variable(
                    # 'noise_variable-' + str(scope_idx) + '-' + str(i),
                    # shape=[1, 1, x.shape[2], x.shape[3]],
                    # initializer=tf.initializers.random_normal(),
                    # trainable=False)
                    noise_np = np.random.normal(size=(1, 1, x.shape[2],
                                                      x.shape[3]))
                    noise = tf.constant(noise_np)
                    noise = tf.cast(noise, x.dtype)
                noise_strength = tf.get_variable(
                    'noise_strength-' + str(scope_idx) + '-' + str(i),
                    shape=[],
                    initializer=tf.initializers.zeros())
                x += noise * tf.cast(noise_strength, x.dtype)
            x = apply_bias_act(x, act=act)
    return x
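In the randomize_noise branch the noise tensor has a single channel, shape [batch, 1, H, W], and is added to the [batch, fmaps, H, W] activations, so the same noise plane is broadcast across all feature maps and scaled by the learned scalar noise_strength (initialised to zero). A minimal NumPy illustration of that broadcasting:

import numpy as np

x = np.zeros((4, 128, 8, 8), dtype=np.float32)           # activations [b, fmaps, h, w]
noise = np.random.randn(4, 1, 8, 8).astype(np.float32)   # one noise plane per sample
noise_strength = 0.0                                      # learned scalar, starts at zero
x = x + noise * noise_strength                            # broadcasts over the channel axis
assert x.shape == (4, 128, 8, 8)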
Example no. 5
 def block(x, res):  # res = 3..resolution_log2
     t = x
     with tf.variable_scope('Conv0_up'):
         x, atts_0 = layer(x,
                           layer_idx=res * 2 - 5,
                           fmaps=nf(res - 1),
                           kernel=3,
                           up=True)
     with tf.variable_scope('Conv1'):
         x, atts_1 = layer(x,
                           layer_idx=res * 2 - 4,
                           fmaps=nf(res - 1),
                           kernel=3)
     if architecture == 'resnet':
         with tf.variable_scope('Skip'):
             t = conv2d_layer(t,
                              fmaps=nf(res - 1),
                              kernel=1,
                              up=True,
                              resample_kernel=resample_kernel)
             x = (x + t) * (1 / np.sqrt(2))
     if return_atts:
         atts = tf.concat([atts_0, atts_1], axis=1)
     else:
         atts = None
     return x, atts
Example no. 6
def build_noise_layer(x,
                      name,
                      n_layers,
                      scope_idx,
                      act,
                      use_noise,
                      randomize_noise,
                      noise_inputs=None,
                      fmaps=128,
                      **kwargs):
    # print('in noise_inputs:', noise_inputs)
    for i in range(n_layers):
        if noise_inputs is not None:
            noise_inputs.append(
                tf.get_variable('noise%d' % len(noise_inputs),
                                shape=[1, 1] + x.get_shape().as_list()[2:],
                                initializer=tf.initializers.random_normal(),
                                trainable=False))
        with tf.variable_scope(name + '-' + str(scope_idx) + '-' + str(i)):
            x = conv2d_layer(x, fmaps=fmaps, kernel=3, up=False)
            if use_noise:
                if randomize_noise:
                    noise = tf.random_normal(
                        [tf.shape(x)[0], 1, x.shape[2], x.shape[3]],
                        dtype=x.dtype)
                else:
                    noise = tf.cast(noise_inputs[-1], x.dtype)
                noise_strength = tf.get_variable(
                    'noise_strength-' + str(scope_idx) + '-' + str(i),
                    shape=[],
                    initializer=tf.initializers.zeros())
                x += noise * tf.cast(noise_strength, x.dtype)
            x = apply_bias_act(x, act=act)
    return x
def apply_st(x,
             st_matrix,
             up=True,
             fmaps=128,
             resample_kernel=[1, 3, 3, 1],
             act='lrelu'):
    with tf.variable_scope('Transform'):
        x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
        x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
    with tf.variable_scope('ConvMayUp'):
        x = apply_bias_act(conv2d_layer(x,
                                        fmaps=fmaps,
                                        kernel=3,
                                        up=up,
                                        resample_kernel=resample_kernel),
                           act=act)
    with tf.variable_scope('Conv'):
        x = apply_bias_act(conv2d_layer(x, fmaps=fmaps, kernel=3), act=act)
    return x
Example no. 8
def build_trans_mask_to_feat_encoder_layer(x_mask,
                                           dlatents_in,
                                           name,
                                           n_layers,
                                           scope_idx,
                                           is_training,
                                           wh,
                                           feat_cnn_dim,
                                           construct_feat_by_concat=False,
                                           trans_dim=512,
                                           dff=512,
                                           trans_rate=0.1,
                                           **kwargs):
    '''
    Build the mask-to-feature transformer encoder to predict semantic variation masks.
    x_mask: [b, n_masks, wh * wh]
    dlatents_in: [b, n_masks]
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        b = tf.shape(x_mask)[0]
        n_masks = x_mask.get_shape().as_list()[-2]
        with tf.variable_scope('FeatEncoding'):
            x = apply_bias(dense_layer_last_dim(x_mask, trans_dim))
            feat_logits = get_return_v(
                trans_encoder_basic(x,
                                    is_training,
                                    None,
                                    n_layers,
                                    trans_dim,
                                    num_heads=8,
                                    dff=dff,
                                    rate=trans_rate), 1)  # (b, z_dim, d_model)
            # [b, n_masks, d_model]
        with tf.variable_scope('ConstructFeatMap'):
            assert trans_dim % (wh * wh) == 0
            feat_precnn_dim = trans_dim // (wh * wh)
            feat_logits = tf.reshape(feat_logits,
                                     [-1, feat_precnn_dim, wh, wh])
            feat_on_masks = conv2d_layer(
                feat_logits, fmaps=feat_cnn_dim,
                kernel=3)  # [b*n_masks, feat_cnn_dim, wh, wh]
            feat_on_masks = tf.reshape(feat_on_masks,
                                       [-1, n_masks, feat_cnn_dim, wh, wh])
            if construct_feat_by_concat:
                construct_feat = construct_feat_by_concat_masks_latent(
                    feat_on_masks, tf.reshape(x_mask, [b, n_masks, wh, wh]),
                    dlatents_in)
            else:
                construct_feat = construct_feat_by_masks_latent(
                    feat_on_masks, tf.reshape(x_mask, [b, n_masks, wh, wh]),
                    dlatents_in)
            # [b, feat_cnn_dim, h, w]
        return construct_feat
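The reshape in `ConstructFeatMap` relies on `trans_dim` being divisible by `wh * wh`, so each transformer output vector folds back into a small feature map. A quick arithmetic sketch with illustrative values (wh = 4 and trans_dim = 512 are assumptions, not values from the source):

wh = 4            # spatial size of the mask grid (assumed)
trans_dim = 512   # transformer model dimension (assumed)
assert trans_dim % (wh * wh) == 0
feat_precnn_dim = trans_dim // (wh * wh)
print(feat_precnn_dim)  # 32 channels per mask before the 3x3 conv
# [b, n_masks, trans_dim] -> reshape -> [b * n_masks, feat_precnn_dim, wh, wh]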
Example no. 9
 def get_att_map(latents, x=None):
     with tf.variable_scope('create_att_feats'):
         x_ch, x_h, x_w = x.get_shape().as_list()[1:]
         att_feats = tf.get_variable(
             'att_feats',
             shape=[1, dlatent_size, x_ch, x_h, x_w],
             initializer=tf.initializers.random_normal())
         att_feats = tf.tile(tf.cast(att_feats, dtype),
                             [tf.shape(latents)[0], 1, 1, 1, 1])
         latents = latents[:, tf.newaxis, :]
         latents = tf.tile(latents, [1, dlatent_size, 1])
         latents = tf.reshape(latents, [-1, dlatent_size])
         # att_map = apply_bias_act(modulated_conv2d_layer(att_feats, latents, fmaps=64, kernel=3,
         # demodulate=False, fused_modconv=False),
         # act=act) # shape: [b*dlatent_size, 1, 8, 8]
         # Note: in practice x must not be None here, since its shape is read above,
         # so this branch is effectively unreachable.
         if x is None:
             att_map = att_feats
             att_map = tf.reshape(att_map, [-1, x_ch, x_h, x_w])
             map_ch = x_ch
         else:
             x = tf.reshape(x, [-1, 1, x_ch, x_h, x_w])
             x = tf.tile(x, [1, dlatent_size, 1, 1, 1])
             att_map = tf.concat([x, att_feats], axis=2)
             att_map = tf.reshape(att_map, [-1, 2 * x_ch, x_h, x_w])
             map_ch = 2 * x_ch
         with tf.variable_scope('att_conv_3x3'):
             att_map = apply_bias_act(conv2d_layer(att_map,
                                                   fmaps=map_ch,
                                                   kernel=3),
                                      act=act)
         with tf.variable_scope('att_conv_1x1'):
             att_map = apply_bias_act(
                 conv2d_layer(att_map, fmaps=1, kernel=1))
         att_map = tf.reshape(att_map, [-1, dlatent_size, 1, x_h * x_w])
         att_map = tf.nn.softmax(att_map, axis=-1)
         # att_map = tf.nn.sigmoid(att_map)
         # att_map = tf.reshape(att_map, [-1, dlatent_size, 1, 8, 8])
     return att_map
Example no. 10
 def torgb(x, y, res):  # res = 2..resolution_log2
     with tf.variable_scope('ToRGB'):
         # t = apply_bias_act(modulated_conv2d_layer(x, latents_ready_ls[res*2-3], fmaps=num_channels, kernel=1,
         # demodulate=False, fused_modconv=fused_modconv))
         t, atts = get_return_v(
             build_C_spgroup_layers_with_latents_ready(
                 x,
                 'SP_latents',
                 latent_split_ls_for_std_gen[res * 2 - 3],
                 res * 2 - 3,
                 latents_ready_ls[res * 2 - 3],
                 return_atts=return_atts,
                 resolution=resolution,
                 n_subs=n_subs,
                 **kwargs), 2)
         t = apply_bias_act(conv2d_layer(t, fmaps=num_channels, kernel=1))
         return t if y is None else y + t, atts
 def noised_conv_layer(x, layer_idx, fmaps, kernel, up=False):
     x = conv2d_layer(x,
                      fmaps=fmaps,
                      up=up,
                      kernel=kernel,
                      resample_kernel=resample_kernel)
     if use_noise:
         if randomize_noise:
             noise = tf.random_normal(
                 [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
         else:
             noise = tf.cast(noise_inputs[layer_idx], x.dtype)
         noise_strength = tf.get_variable(
             'noise_strength',
             shape=[],
             initializer=tf.initializers.zeros())
         x += noise * tf.cast(noise_strength, x.dtype)
     return apply_bias_act(x, act=act)
Example no. 12
 def block(x, res):  # res = 3..resolution_log2
     t = x
     with tf.variable_scope('Conv0_up'):
         x = layer(x,
                   layer_idx=res * 2 - 5,
                   fmaps=nf(res - 1),
                   kernel=3,
                   up=True)
     with tf.variable_scope('Conv1'):
         x = layer(x, layer_idx=res * 2 - 4, fmaps=nf(res - 1), kernel=3)
     if architecture == 'resnet':
         with tf.variable_scope('Skip'):
             t = conv2d_layer(t,
                              fmaps=nf(res - 1),
                              kernel=1,
                              up=True,
                              resample_kernel=resample_kernel)
             x = (x + t) * (1 / np.sqrt(2))
     return x
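The block above addresses per-layer latents and noise with `layer_idx = res * 2 - 5` for `Conv0_up` and `res * 2 - 4` for `Conv1`. A short sketch of the resulting schedule, assuming `res` runs from 3 to `resolution_log2` as the comment states:

resolution_log2 = 6  # e.g. a 64x64 generator (assumed)
for res in range(3, resolution_log2 + 1):
    print('%dx%d' % (2 ** res, 2 ** res), 'Conv0_up idx:', res * 2 - 5, 'Conv1 idx:', res * 2 - 4)
# 8x8   -> 1, 2
# 16x16 -> 3, 4
# 32x32 -> 5, 6
# 64x64 -> 7, 8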
Example no. 13
def build_C_spgroup_softmax_layers(x,
                                   name,
                                   n_latents,
                                   start_idx,
                                   scope_idx,
                                   dlatents_in,
                                   act,
                                   fused_modconv,
                                   fmaps=128,
                                   return_atts=False,
                                   resolution=128,
                                   **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention with pure softmax.
    Support square images only.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            x_wh = x.shape[2]
            atts = conv2d_layer(x, fmaps=n_latents, kernel=3)
            atts = tf.reshape(
                atts, [-1, n_latents, x_wh * x_wh])  # [b, n_latents, m]
            atts = tf.nn.softmax(atts, axis=-1)
            atts = tf.reshape(atts, [-1, n_latents, 1, x_wh, x_wh])

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts,
                                  [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x
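The attention blending in `Att_apply` can be read on its own: each of the `n_latents` maps is a softmax over all spatial positions, and the styled features overwrite the original ones only where the map is large. A minimal NumPy sketch of just that blending step (the call to `style_mod` is replaced by a stand-in):

import numpy as np

b, c, wh, n_latents = 2, 16, 8, 3                 # illustrative sizes
x = np.random.randn(b, c, wh, wh)
logits = np.random.randn(b, n_latents, wh * wh)
atts = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over positions
atts = atts.reshape(b, n_latents, 1, wh, wh)
for i in range(n_latents):
    x_styled = 1.1 * x                            # stand-in for style_mod(instance_norm(x), ...)
    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
assert x.shape == (b, c, wh, wh)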
Example no. 14
def build_conv_layer(x,
                     name,
                     n_layers,
                     scope_idx,
                     act,
                     resample_kernel,
                     fmaps=128,
                     **kwargs):
    # e.g. {'Conv-up': 2}, {'Conv-id': 1}
    sample_type = name.split('-')[-1]
    assert sample_type in ['up', 'down', 'id']
    for i in range(n_layers):
        with tf.variable_scope(name + '-' + str(scope_idx) + '-' + str(i)):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=fmaps,
                                            kernel=3,
                                            up=(sample_type == 'up'),
                                            down=(sample_type == 'down'),
                                            resample_kernel=resample_kernel),
                               act=act)
    return x
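As in Example no. 1, the `name` argument doubles as configuration: the suffix after the last '-' picks upsampling, downsampling, or identity, matching the `{'Conv-up': 2}`-style comment. A trivial check of that parsing convention:

for name in ['Conv-up', 'Conv-down', 'Conv-id']:
    sample_type = name.split('-')[-1]
    assert sample_type in ['up', 'down', 'id']
    print(name, '->', sample_type)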
Example no. 15
 def layer(x, layer_idx, fmaps, kernel, up=False):
     x, atts = get_return_v(
         build_C_spgroup_layers_with_latents_ready(
             x,
             'SP_latents',
             latent_split_ls_for_std_gen[layer_idx],
             layer_idx,
             latents_ready_ls[layer_idx],
             return_atts=return_atts,
             resolution=resolution,
             n_subs=n_subs,
             **kwargs), 2)
     x = conv2d_layer(x, fmaps=fmaps, kernel=kernel, up=up)
     if randomize_noise:
         noise = tf.random_normal(
             [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
     else:
         noise = tf.cast(noise_inputs[layer_idx], x.dtype)
     noise_strength = tf.get_variable('noise_strength',
                                      shape=[],
                                      initializer=tf.initializers.zeros())
     x += noise * tf.cast(noise_strength, x.dtype)
     return apply_bias_act(x, act=act), atts
Example no. 16
 def fromrgb(x, y, res):  # res = 2..resolution_log2
     with tf.variable_scope('FromRGB'):
         t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                            act=act)
         return t if x is None else x + t
Example no. 17
def D_info_gan_stylegan2(
    images_in,  # First input: Images [minibatch, channel, height, width].
    labels_in,  # Second input: Labels [minibatch, label_size].
    num_channels=3,  # Number of input color channels. Overridden based on dataset.
    resolution=1024,  # Input resolution. Overridden based on dataset.
    label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=[
        1, 3, 3, 1
    ],  # Low-pass filter to apply when resampling activations. None = no filtering.
    **_kwargs):  # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            hidden = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        x = apply_bias_act(
            dense_layer(hidden, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    hidden = tf.identity(hidden, name='hidden')
    return scores_out, hidden
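The `nf()` helper defines the per-stage feature-map schedule used throughout these discriminators; with the defaults above it saturates at `fmap_max` for the coarse stages and halves thereafter. A standalone check using the same formula and the default hyper-parameters from the signature:

import numpy as np

fmap_base, fmap_decay, fmap_min, fmap_max = 16 << 10, 1.0, 1, 512

def nf(stage):
    return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)

print([nf(stage) for stage in range(9)])
# [512, 512, 512, 512, 512, 512, 256, 128, 64]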
Example no. 18
def vc_head(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: hidden features from z + delta(z) [minibatch, channel, height, width].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 are connected.
        **_kwargs):  # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    elif connect_mode == 'concat':
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(
                dense_layer(x,
                            fmaps=(D_global_size +
                                   (dlatent_size - D_global_size))))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
def G_synthesis_sb_general_dsp(
    dlatents_withl_in,  # Input: Disentangled latents (W) [minibatch, label_size+dlatent_size].
    dlatent_size=7,  # Disentangled latent (W) dimensionality. Including discrete info, rotation, scaling, xy shearing, and xy translation.
    label_size=0,  # Label dimensionality, 0 if no labels.
    D_global_size=3,  # Global D_latents.
    C_global_size=0,  # Global C_latents.
    sb_C_global_size=4,  # Global spatial-biased C_latents.
    C_local_hfeat_size=0,  # Local heatmap*features learned C_latents.
    C_local_heat_size=0,  # Local heatmap learned C_latents.
    num_channels=1,  # Number of output color channels.
    resolution=64,  # Output resolution.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=[
        1, 3, 3, 1
    ],  # Low-pass filter to apply when resampling activations. None = no filtering.
    fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
    use_noise=False,
    randomize_noise=True,  # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
    **_kwargs):  # Ignore unrecognized keyword args.
    '''
    dlatents_withl_in: dims contain: [label, D_global, C_global, sb_C_global,
                                    C_local_hfeat, C_local_heat]
    '''
    resolution_log2 = int(np.log2(resolution))  # == 6 for resolution 64
    assert resolution == 2**resolution_log2 and resolution >= 4
    num_layers = resolution_log2 * 2 - 2  # == 10 for resolution 64

    act = nonlinearity
    images_out = None

    # Primary inputs.
    assert dlatent_size == D_global_size + C_global_size + sb_C_global_size + \
        C_local_hfeat_size + C_local_heat_size
    n_cat = label_size + D_global_size
    dlatents_withl_in.set_shape([None, label_size + dlatent_size])
    dlatents_withl_in = tf.cast(dlatents_withl_in, dtype)
    n_content = label_size + D_global_size + C_global_size

    # Noise inputs.
    noise_inputs = []
    for layer_idx in range(num_layers - 3):
        res = (layer_idx + 7) // 2  # [3, 4, 4, 5, 5, 6, 6]
        shape = [1, 1, 2**res, 2**res]
        noise_inputs.append(
            tf.get_variable('noise%d' % layer_idx,
                            shape=shape,
                            initializer=tf.initializers.random_normal(),
                            trainable=False))

    # Single convolution layer with all the bells and whistles.
    def noised_conv_layer(x, layer_idx, fmaps, kernel, up=False):
        x = conv2d_layer(x,
                         fmaps=fmaps,
                         up=up,
                         kernel=kernel,
                         resample_kernel=resample_kernel)
        if use_noise:
            if randomize_noise:
                noise = tf.random_normal(
                    [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
            else:
                noise = tf.cast(noise_inputs[layer_idx], x.dtype)
            noise_strength = tf.get_variable(
                'noise_strength',
                shape=[],
                initializer=tf.initializers.zeros())
            x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act)

    # Early layers consist of a 4x4 constant layer,
    # label+global discrete latents,
    # and global continuous latents.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, 128, 4, 4],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype),
                        [tf.shape(dlatents_withl_in)[0], 1, 1, 1])
        with tf.variable_scope('Upconv'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=128,
                                            kernel=3,
                                            up=True,
                                            resample_kernel=resample_kernel),
                               act=act)

    with tf.variable_scope('8x8'):
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=128, kernel=3), act=act)
        with tf.variable_scope('Label_Dglobal_control'):
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlatents_withl_in[:, :n_cat],
                fmaps=128,
                kernel=3,
                up=False,
                resample_kernel=resample_kernel,
                fused_modconv=fused_modconv),
                               act=act)
        with tf.variable_scope('After_DiscreteGlobal_noised'):
            x = noised_conv_layer(x, layer_idx=0, fmaps=128, kernel=3)
        with tf.variable_scope('Cglobal_control'):
            start_idx = n_cat
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlatents_withl_in[:, start_idx:start_idx + C_global_size],
                fmaps=128,
                kernel=3,
                up=False,
                resample_kernel=resample_kernel,
                fused_modconv=fused_modconv),
                               act=act)
        with tf.variable_scope('After_ContinuousGlobal_noised'):
            x = noised_conv_layer(x, layer_idx=1, up=True, fmaps=128, kernel=3)

    # Spatial biased layers.
    with tf.variable_scope('16x16'):
        if C_local_hfeat_size > 0:
            with tf.variable_scope('LocalHFeat_C_latents'):
                with tf.variable_scope('ConstFeats'):
                    const_feats = tf.get_variable(
                        'constfeats',
                        shape=[1, C_local_hfeat_size, 32, 1, 1],
                        initializer=tf.initializers.random_normal())
                    const_feats = tf.tile(
                        tf.cast(const_feats,
                                dtype), [tf.shape(const_feats)[0], 1, 1, 1, 1])
                with tf.variable_scope('ControlAttHeat'):
                    hfeat_start_idx = label_size + D_global_size + C_global_size + \
                        sb_C_global_size
                    att_heat = get_att_heat(x,
                                            nheat=C_local_hfeat_size,
                                            act=act)
                    att_heat = tf.reshape(
                        att_heat,
                        [tf.shape(att_heat)[0], C_local_hfeat_size, 1] +
                        att_heat.shape.as_list()[2:4])
                    # C_local_heat latent [-2, 2] --> [0, 1]
                    hfeat_modifier = (2 + dlatents_withl_in[:, hfeat_start_idx:hfeat_start_idx + \
                                                     C_local_hfeat_size]) / 4.
                    hfeat_modifier = get_conditional_modifier(
                        hfeat_modifier,
                        dlatents_withl_in[:, :n_content],
                        act=act)
                    hfeat_modifier = tf.reshape(
                        hfeat_modifier,
                        [tf.shape(x)[0], C_local_hfeat_size, 1, 1, 1])
                    att_heat = att_heat * hfeat_modifier
                    added_feats = const_feats * att_heat
                    added_feats = tf.reshape(added_feats, [
                        tf.shape(att_heat)[0],
                        C_local_hfeat_size * att_heat.shape.as_list()[2]
                    ] + att_heat.shape.as_list()[3:5])
                    x = tf.concat([x, added_feats], axis=1)

        with tf.variable_scope('SpatialBiased_C_global'):
            # Rotation layers.
            start_idx = start_idx + C_global_size
            with tf.variable_scope('Rotation'):
                r_matrix = get_r_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 1],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, r_matrix, up=False, fmaps=128, act=act)
            with tf.variable_scope('After_Rotation_noised'):
                x = noised_conv_layer(x, layer_idx=2, fmaps=128, kernel=3)
            # Scaling layers.
            start_idx = start_idx + 1
            with tf.variable_scope('Scaling'):
                s_matrix = get_s_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 1],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, s_matrix, up=False, fmaps=128, act=act)
            with tf.variable_scope('After_Scaling_noised'):
                x = noised_conv_layer(x,
                                      layer_idx=3,
                                      up=True,
                                      fmaps=128,
                                      kernel=3)

    with tf.variable_scope('32x32'):
        with tf.variable_scope('SpatialBiased_C_global'):
            # Shearing layers.
            with tf.variable_scope('Shearing'):
                start_idx = start_idx + 1
                sh_matrix = get_sh_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 2],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, sh_matrix, up=False, fmaps=128, act=act)
                with tf.variable_scope('After_Shearing_noised'):
                    x = noised_conv_layer(x, layer_idx=4, fmaps=128, kernel=3)
            # Translation layers.
            with tf.variable_scope('Translation'):
                start_idx = start_idx + 2
                t_matrix = get_t_matrix(
                    dlatents_withl_in[:, start_idx:start_idx + 2],
                    dlatents_withl_in[:, :n_content],
                    act=act)
                x = apply_st(x, t_matrix, up=False, fmaps=128, act=act)
                with tf.variable_scope('After_Translation_noised'):
                    if resolution_log2 >= 6:
                        x = noised_conv_layer(x,
                                              layer_idx=5,
                                              up=True,
                                              fmaps=128,
                                              kernel=3)
                    else:
                        x = noised_conv_layer(x,
                                              layer_idx=5,
                                              fmaps=128,
                                              kernel=3)

    with tf.variable_scope('64x64' if resolution_log2 >= 6 else '32x32'):
        with tf.variable_scope('LocalHeat_C_latents'):
            with tf.variable_scope('ControlAttHeat'):
                heat_start_idx = label_size + D_global_size + C_global_size + \
                    sb_C_global_size + C_local_hfeat_size
                att_heat = get_att_heat(x, nheat=C_local_heat_size, act=act)
                # C_local_heat latent [-2, 2] --> [0, 1]
                heat_modifier = (2 + dlatents_withl_in[:, heat_start_idx:heat_start_idx + \
                                                 C_local_heat_size]) / 4.
                heat_modifier = get_conditional_modifier(
                    heat_modifier, dlatents_withl_in[:, :n_content], act=act)
                heat_modifier = tf.reshape(
                    heat_modifier,
                    [tf.shape(heat_modifier)[0], C_local_heat_size, 1, 1])
                att_heat = att_heat * heat_modifier
                x = tf.concat([x, att_heat], axis=1)
            with tf.variable_scope('After_LocalHeat_noised'):
                x = noised_conv_layer(x, layer_idx=6, fmaps=128, kernel=3)
    y = torgb(x, y, num_channels=num_channels)

    # # Tail layers.
    # for res in range(6, resolution_log2 + 1):
    # with tf.variable_scope('%dx%d' % (res * 2, res * 2)):
    # x = apply_bias_act(conv2d_layer(x,
    # fmaps=128,
    # kernel=1,
    # up=True,
    # resample_kernel=resample_kernel),
    # act=act)
    # y = torgb(x, y, num_channels=num_channels)
    images_out = y
    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
def get_att_heat(x, nheat, act):
    with tf.variable_scope('Conv'):
        x = apply_bias_act(conv2d_layer(x, fmaps=128, kernel=3), act=act)
    with tf.variable_scope('ConvAtt'):
        x = apply_bias_act(conv2d_layer(x, fmaps=1, kernel=3), act='sigmoid')
    return x
def torgb(x, y, num_channels):
    with tf.variable_scope('ToRGB'):
        t = apply_bias_act(conv2d_layer(x, fmaps=num_channels, kernel=1))
        return t if y is None else y + t
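`G_synthesis_sb_general_dsp` pre-allocates one fixed noise variable per noised conv layer; the `(layer_idx + 7) // 2` mapping yields the resolution stage for each of the seven noise inputs, matching the `[3, 4, 4, 5, 5, 6, 6]` comment. A quick check of that schedule:

num_layers = 10  # resolution_log2 * 2 - 2 for resolution 64
for layer_idx in range(num_layers - 3):
    res = (layer_idx + 7) // 2
    print('noise%d' % layer_idx, 'res stage:', res, 'shape:', [1, 1, 2 ** res, 2 ** res])
# res stages: 3, 4, 4, 5, 5, 6, 6 -> noise planes from 8x8 up to 64x64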
Example no. 22
def vpex_net(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: hidden features from z + delta(z) [minibatch, channel, height, width].
        latents,  # Ground-truth latent code for fake1.
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 are connected.
        return_atts=False,  # If return I_atts.
        **_kwargs):  # Ignore unrecognized keyword args.

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    latents.set_shape([None, dlatent_size])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    latents = tf.cast(latents, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    elif connect_mode == 'concat':
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # attention features for each latent dimension.
    def get_att_map(latents, x=None):
        with tf.variable_scope('create_att_feats'):
            x_ch, x_h, x_w = x.get_shape().as_list()[1:]
            att_feats = tf.get_variable(
                'att_feats',
                shape=[1, dlatent_size, x_ch, x_h, x_w],
                initializer=tf.initializers.random_normal())
            att_feats = tf.tile(tf.cast(att_feats, dtype),
                                [tf.shape(latents)[0], 1, 1, 1, 1])
            latents = latents[:, tf.newaxis, :]
            latents = tf.tile(latents, [1, dlatent_size, 1])
            latents = tf.reshape(latents, [-1, dlatent_size])
            # att_map = apply_bias_act(modulated_conv2d_layer(att_feats, latents, fmaps=64, kernel=3,
            # demodulate=False, fused_modconv=False),
            # act=act) # shape: [b*dlatent_size, 1, 8, 8]
            # Note: in practice x must not be None here, since its shape is read above,
            # so this branch is effectively unreachable.
            if x is None:
                att_map = att_feats
                att_map = tf.reshape(att_map, [-1, x_ch, x_h, x_w])
                map_ch = x_ch
            else:
                x = tf.reshape(x, [-1, 1, x_ch, x_h, x_w])
                x = tf.tile(x, [1, dlatent_size, 1, 1, 1])
                att_map = tf.concat([x, att_feats], axis=2)
                att_map = tf.reshape(att_map, [-1, 2 * x_ch, x_h, x_w])
                map_ch = 2 * x_ch
            with tf.variable_scope('att_conv_3x3'):
                att_map = apply_bias_act(conv2d_layer(att_map,
                                                      fmaps=map_ch,
                                                      kernel=3),
                                         act=act)
            with tf.variable_scope('att_conv_1x1'):
                att_map = apply_bias_act(
                    conv2d_layer(att_map, fmaps=1, kernel=1))
            att_map = tf.reshape(att_map, [-1, dlatent_size, 1, x_h * x_w])
            att_map = tf.nn.softmax(att_map, axis=-1)
            # att_map = tf.nn.sigmoid(att_map)
            # att_map = tf.reshape(att_map, [-1, dlatent_size, 1, 8, 8])
        return att_map

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 3, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Duplicate for each att.
    with tf.variable_scope('apply_att'):
        att_map = get_att_map(latents, x)
        x_ch, x_h, x_w = x.get_shape().as_list()[1:]
        assert x_h == 8
        x_ori = tf.reshape(x, [-1, 1, x_ch, x_h * x_w])  # [b, 1, ch, h*w]
        x = tf.reshape(x, [-1, 1, x_ch, x_h * x_w])
        x = att_map * x
        x = tf.reduce_sum(x, axis=-1)  # [b, dlatent, ch]
        x = tf.reshape(x, [-1, x_ch, 1, 1])  # [b * dlatent, ch, 1, 1]
        with tf.variable_scope('after_att_conv_1x1'):
            x = apply_bias_act(conv2d_layer(x, fmaps=x_ch, kernel=1))
        x = tf.reshape(x, [-1, dlatent_size, x_ch, 1])  # [b, dlatent, ch, 1]

        x = tf.tile(x, [1, 1, 1, x_h * x_w])
        # x = x + x_ori # [b, dlatent, ch, h * w]
        x = tf.reshape(x, [-1, x_ch, x_h, x_w])
        y_ch, y_h, y_w = y.get_shape().as_list()[1:]
        y = y[:, tf.newaxis, ...]
        y = tf.tile(y, [1, dlatent_size, 1, 1, 1])
        y = tf.reshape(y, [-1, y_ch, y_h, y_w])

    for res in range(3, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(dense_layer(x, fmaps=1))

    with tf.variable_scope('Final_reshape_x'):
        x = tf.reshape(x, [-1, dlatent_size])

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    if return_atts:
        with tf.variable_scope('Reshape_atts'):
            att_map = tf.reshape(att_map, [-1, 8, 8, 1])
            att_map = tf.image.resize(att_map, size=(resolution, resolution))
            att_map = tf.reshape(att_map,
                                 [-1, dlatent_size, 1, resolution, resolution])
        return x, att_map
    else:
        return x
def G_synthesis_spatial_biased_dsp(
        dlatents_in,  # Input: Disentangled latents (W) [minibatch, dlatent_size].
        dlatent_size=7,  # Disentangled latent (W) dimensionality. Including discrete info, rotation, scaling, and xy translation.
        D_global_size=3,  # Discrete latents.
        sb_C_global_size=4,  # Continuous latents.
        label_size=0,  # Label dimensionality, 0 if no labels.
        num_channels=1,  # Number of output color channels.
        resolution=64,  # Output resolution.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
        **_kwargs):  # Ignore unrecognized keyword args.
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity
    images_out = None

    # Primary inputs.
    assert dlatent_size == D_global_size + sb_C_global_size
    n_cat = label_size + D_global_size
    dlatents_in.set_shape([None, label_size + dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)

    # Return rotation matrix
    def get_r_matrix(r_latents, cond_latent):
        # r_latents: [-2., 2.] -> [0, 2*pi]
        with tf.variable_scope('Condition0'):
            cond = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
        with tf.variable_scope('Condition1'):
            cond = apply_bias_act(dense_layer(cond, fmaps=1), act='sigmoid')
        rad = (r_latents + 2) / 4. * 2. * np.pi
        rad = rad * cond
        tt_00 = tf.math.cos(rad)
        tt_01 = -tf.math.sin(rad)
        tt_02 = tf.zeros_like(rad)
        tt_10 = tf.math.sin(rad)
        tt_11 = tf.math.cos(rad)
        tt_12 = tf.zeros_like(rad)
        theta = tf.concat([tt_00, tt_01, tt_02, tt_10, tt_11, tt_12], axis=1)
        return theta

    # Return scaling matrix
    def get_s_matrix(s_latents, cond_latent):
        # s_latents: [-2., 2.] -> [1, 3]
        with tf.variable_scope('Condition0'):
            cond = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
        with tf.variable_scope('Condition1'):
            cond = apply_bias_act(dense_layer(cond, fmaps=1), act='sigmoid')
        scale = (s_latents / 2. + 2.) * cond
        tt_00 = scale
        tt_01 = tf.zeros_like(scale)
        tt_02 = tf.zeros_like(scale)
        tt_10 = tf.zeros_like(scale)
        tt_11 = scale
        tt_12 = tf.zeros_like(scale)
        theta = tf.concat([tt_00, tt_01, tt_02, tt_10, tt_11, tt_12], axis=1)
        return theta

    # Return shear matrix
    def get_sh_matrix(sh_latents, cond_latent):
        # sh_latents[:, 0]: [-2., 2.] -> [-1., 1.]
        # sh_latents[:, 1]: [-2., 2.] -> [-1., 1.]
        with tf.variable_scope('Condition0x'):
            cond_x = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1x'):
            cond_x = apply_bias_act(dense_layer(cond_x, fmaps=1),
                                    act='sigmoid')
        with tf.variable_scope('Condition0y'):
            cond_y = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1y'):
            cond_y = apply_bias_act(dense_layer(cond_y, fmaps=1),
                                    act='sigmoid')
        cond = tf.concat([cond_x, cond_y], axis=1)
        xy_shear = sh_latents / 2. * cond
        tt_00 = tf.ones_like(xy_shear[:, 0:1])
        tt_01 = xy_shear[:, 0:1]
        tt_02 = tf.zeros_like(xy_shear[:, 0:1])
        tt_10 = xy_shear[:, 1:]
        tt_11 = tf.ones_like(xy_shear[:, 1:])
        tt_12 = tf.zeros_like(xy_shear[:, 1:])
        theta = tf.concat([tt_00, tt_01, tt_02, tt_10, tt_11, tt_12], axis=1)
        return theta

    # Return translation matrix
    def get_t_matrix(t_latents, cond_latent):
        # t_latents[:, 0]: [-2., 2.] -> [-0.5, 0.5]
        # t_latents[:, 1]: [-2., 2.] -> [-0.5, 0.5]
        with tf.variable_scope('Condition0x'):
            cond_x = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1x'):
            cond_x = apply_bias_act(dense_layer(cond_x, fmaps=1),
                                    act='sigmoid')
        with tf.variable_scope('Condition0y'):
            cond_y = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                    act=act)
        with tf.variable_scope('Condition1y'):
            cond_y = apply_bias_act(dense_layer(cond_y, fmaps=1),
                                    act='sigmoid')
        cond = tf.concat([cond_x, cond_y], axis=1)
        xy_shift = t_latents / 4. * cond
        tt_00 = tf.ones_like(xy_shift[:, 0:1])
        tt_01 = tf.zeros_like(xy_shift[:, 0:1])
        tt_02 = xy_shift[:, 0:1]
        tt_10 = tf.zeros_like(xy_shift[:, 1:])
        tt_11 = tf.ones_like(xy_shift[:, 1:])
        tt_12 = xy_shift[:, 1:]
        theta = tf.concat([tt_00, tt_01, tt_02, tt_10, tt_11, tt_12], axis=1)
        return theta

    # Apply spatial transform
    def apply_st(x, st_matrix, idx, up=True):  # idx: 2, 3, 4
        with tf.variable_scope('Transform'):
            x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
            x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
            x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
        with tf.variable_scope('Upconv'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(idx),
                                            kernel=3,
                                            up=up,
                                            resample_kernel=resample_kernel),
                               act=act)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(idx), kernel=3),
                               act=act)
        return x

    def upsample(y):
        with tf.variable_scope('Upsample'):
            return upsample_2d(y, k=resample_kernel)

    def torgb(x, y):
        with tf.variable_scope('ToRGB'):
            t = apply_bias_act(conv2d_layer(x, fmaps=num_channels, kernel=1))
            return t if y is None else y + t

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, nf(1), 4, 4],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
        with tf.variable_scope('Upconv8x8'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(1),
                                            kernel=3,
                                            up=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('ModulatedConv'):
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlatents_in[:, :n_cat],
                fmaps=nf(2),
                kernel=3,
                up=False,
                resample_kernel=resample_kernel,
                fused_modconv=fused_modconv),
                               act=act)

        with tf.variable_scope('Conv1'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(2), kernel=3), act=act)

    # Rotation layers.
    with tf.variable_scope('16x16'):
        r_matrix = get_r_matrix(dlatents_in[:, n_cat:n_cat + 1],
                                dlatents_in[:, :n_cat])
        x = apply_st(x, r_matrix, 2)

    # Scaling layers.
    with tf.variable_scope('32x32'):
        s_matrix = get_s_matrix(dlatents_in[:, n_cat + 1:n_cat + 2],
                                dlatents_in[:, :n_cat])
        x = apply_st(x, s_matrix, 3)

    # Shearing layers.
    with tf.variable_scope('32x32_Shear'):
        sh_matrix = get_sh_matrix(dlatents_in[:, n_cat + 2:n_cat + 4],
                                  dlatents_in[:, :n_cat])
        x = apply_st(x, sh_matrix, 3, up=False)

    # Translation layers.
    with tf.variable_scope('64x64'):
        t_matrix = get_t_matrix(dlatents_in[:, n_cat + 4:],
                                dlatents_in[:, :n_cat])
        x = apply_st(x, t_matrix, 4)
    y = torgb(x, y)

    # # Tail layers.
    # for res in range(6, resolution_log2 + 1):
    # with tf.variable_scope('%dx%d' % (res * 2, res * 2)):
    # x = apply_bias_act(conv2d_layer(x,
    # fmaps=nf(res),
    # kernel=1,
    # up=True,
    # resample_kernel=resample_kernel),
    # act=act)
    # if architecture == 'skip':
    # y = upsample(y)
    # if architecture == 'skip' or res == resolution_log2:
    # y = torgb(x, y)
    images_out = y
    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
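All of the `get_*_matrix` helpers above emit a flattened 2x3 affine matrix `theta = [a, b, tx, c, d, ty]` that `apply_st` passes to `transformer()`; assuming `transformer()` is a standard spatial-transformer grid sampler, that matrix maps output grid coordinates to input coordinates. A small NumPy sketch that builds the rotation theta exactly as `get_r_matrix` does (with the conditional sigmoid gate fixed to 1) and reads it back as a 2x3 matrix:

import numpy as np

r_latents = np.array([[0.5]])                     # continuous latent in [-2, 2]
cond = np.array([[1.0]])                          # conditional sigmoid gate, assumed fully open
rad = (r_latents + 2) / 4. * 2. * np.pi * cond    # map [-2, 2] -> [0, 2*pi], then gate
theta = np.concatenate([np.cos(rad), -np.sin(rad), np.zeros_like(rad),
                        np.sin(rad), np.cos(rad), np.zeros_like(rad)], axis=1)
print(theta.reshape(2, 3))  # [[cos, -sin, 0], [sin, cos, 0]] for this sample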