def get_t_matrix(t_latents, cond_latent, act='lrelu'):
    """Build a batch of flattened 2x3 affine translation matrices.

    Each translation component is gated by a learned sigmoid computed from
    ``cond_latent`` (one small two-layer dense head per axis), then scaled so
    that latents in [-2., 2.] map to offsets in [-0.5, 0.5] when the gate is 1.

    Args:
        t_latents: translation latents; the column slicing below assumes
            shape [batch, 2] (x, y) — confirm with callers.
        cond_latent: input to the gate heads.
        act: activation of the hidden (128-unit) dense layer of each gate
            head. It was previously read from an undefined name; the default
            matches the sibling ``get_*_matrix`` helpers.

    Returns:
        Tensor laid out per row as [1, 0, tx, 0, 1, ty] (row-major 2x3).
    """
    # t_latents[:, 0]: [-2., 2.] -> [-0.5, 0.5]
    # t_latents[:, 1]: [-2., 2.] -> [-0.5, 0.5]
    with tf.variable_scope('Condition0x'):
        cond_x = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1x'):
        cond_x = apply_bias_act(dense_layer(cond_x, fmaps=1), act='sigmoid')
    with tf.variable_scope('Condition0y'):
        cond_y = apply_bias_act(dense_layer(cond_latent, fmaps=128), act=act)
    with tf.variable_scope('Condition1y'):
        cond_y = apply_bias_act(dense_layer(cond_y, fmaps=1), act='sigmoid')
    cond = tf.concat([cond_x, cond_y], axis=1)
    xy_shift = t_latents / 4. * cond
    tt_00 = tf.ones_like(xy_shift[:, 0:1])
    tt_01 = tf.zeros_like(xy_shift[:, 0:1])
    tt_02 = xy_shift[:, 0:1]
    tt_10 = tf.zeros_like(xy_shift[:, 1:])
    tt_11 = tf.ones_like(xy_shift[:, 1:])
    tt_12 = xy_shift[:, 1:]
    theta = tf.concat([tt_00, tt_01, tt_02, tt_10, tt_11, tt_12], axis=1)
    return theta
# Example no. 2
def info_gan_head(
    hidden,  # First input: hidden features [minibatch, n_feat].
    dlatent_size=10,
    D_global_size=0,
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    dtype='float32',  # Data type to use for activations and outputs.
    **_kwargs):  # Ignore unrecognized keyword args.
    """InfoGAN prediction head over flat hidden features.

    ``hidden`` is pinned to a static width of nf(0) (the clipped
    ``fmap_base``), cast to ``dtype``, passed through a 512-unit hidden dense
    layer, and then through a linear dense layer with
    ``D_global_size + 2 * (dlatent_size - D_global_size)`` outputs.
    """

    def _n_feat(stage):
        # Clipped feature-map count for a given stage, StyleGAN-style.
        raw = int(fmap_base / (2.0**(stage * fmap_decay)))
        return np.clip(raw, fmap_min, fmap_max)

    n_continuous = dlatent_size - D_global_size
    n_out = D_global_size + 2 * n_continuous

    hidden.set_shape([None, _n_feat(0)])
    feats = tf.cast(hidden, dtype)
    with tf.variable_scope('InfoGanHead'):
        with tf.variable_scope('Dense_Hidden'):
            feats = apply_bias_act(dense_layer(feats, fmaps=512),
                                   act=nonlinearity)
        with tf.variable_scope('Dense_InfoGan'):
            feats = apply_bias_act(dense_layer(feats, fmaps=n_out))
    return feats
# Example no. 3
def build_C_global_layers(x,
                          name,
                          n_latents,
                          start_idx,
                          scope_idx,
                          dlatents_withl_in,
                          n_content,
                          act,
                          fused_modconv,
                          fmaps=128,
                          **kwargs):
    '''
    Build continuous latent layers, e.g. C_global layers.

    The latent slice [start_idx, start_idx + n_latents) of
    ``dlatents_withl_in`` modulates a 3x3 convolution of ``x``. When
    ``n_content > 0``, the first ``n_content`` columns drive a sigmoid gate
    (one value per latent) that scales that slice first. Otherwise the slice
    is used as-is (multiplied by 1.).
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        gate = 1.
        if n_content > 0:
            with tf.variable_scope('Condition0'):
                content = dlatents_withl_in[:, :n_content]
                gate_hidden = apply_bias_act(dense_layer(content, fmaps=128),
                                             act=act)
            with tf.variable_scope('Condition1'):
                gate = apply_bias_act(dense_layer(gate_hidden,
                                                  fmaps=n_latents),
                                      act='sigmoid')
        gated_latents = dlatents_withl_in[:, start_idx:start_idx +
                                          n_latents] * gate
        conv_out = modulated_conv2d_layer(x,
                                          gated_latents,
                                          fmaps=fmaps,
                                          kernel=3,
                                          up=False,
                                          fused_modconv=fused_modconv)
        x = apply_bias_act(conv_out, act=act)
    return x
def get_conditional_modifier(modifier, cond_latent, act='lrelu'):
    """Scale ``modifier`` element-wise by a learned sigmoid gate.

    The gate comes from a two-layer dense head on ``cond_latent``. Its output
    width equals the static size of ``modifier``'s second dimension.
    """
    with tf.variable_scope('Condition0'):
        gate_hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                     act=act)
    with tf.variable_scope('Condition1'):
        n_gate = modifier.shape.as_list()[1]
        gate = apply_bias_act(dense_layer(gate_hidden, fmaps=n_gate),
                              act='sigmoid')
    return modifier * gate
# Example no. 5
def build_Cout_genatts_spgroup_layers(x,
                                      name,
                                      n_latents,
                                      scope_idx,
                                      act,
                                      fmaps=128,
                                      resolution=128,
                                      **kwargs):
    '''
    Build continuous latent out layers with generating group spatial attention.
    Support square images only.

    For each of the ``n_latents`` latents, a soft rectangular spatial mask is
    generated from the spatial mean of ``x``. A dense layer emits four logit
    vectors of length x_wh (height-start, height-end, width-start,
    width-end). Each vector is softmaxed and cumulatively summed, so every
    curve is non-decreasing and ends at 1. The "start" curve is used directly
    and the "end" curve as (1 - curve). Their product along each axis, and
    the outer product of the two axes, forms the mask. The masked,
    spatially averaged features then go through a two-layer dense head per
    latent to predict that latent.

    Args:
        x: feature map. The mean over axes [2, 3] and the use of x.shape[2]
            as the side length suggest NCHW with H == W (assumption —
            confirm with callers).
        name, scope_idx: combined into the enclosing variable scope name.
        n_latents: number of latents (and masks) to predict.
        act: activation of the hidden dense layer in each prediction head.
        fmaps: width of that hidden dense layer.
        resolution: side length the returned masks are resized to.
        **kwargs: ignored.

    Returns:
        Tuple (x, pred_out, atts):
        - x: the input ``x``, returned unchanged.
        - pred_out: [b, n_latents]. Each column comes from dense_layer
          without apply_bias_act.
        - atts: the masks resized to [b, n_latents, 1, resolution, resolution].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial_gen'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            # 4 logit vectors of length x_wh per latent.
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(
                atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            # softmax + cumsum turns each logit vector into a soft step curve.
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            # Flip the "end" curves so they fall toward 0 past the end point.
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts,
                                         [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends,
                                       [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts,
                                         [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends,
                                       [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            # Broadcast outer product of the height and width profiles.
            atts = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]

        with tf.variable_scope('Latent_pred'):
            x_out_ls = []
            for i in range(n_latents):
                # atts[:, i] is [b, 1, x_wh, x_wh] and broadcasts over channels.
                x_tmp = x * atts[:, i]
                x_tmp_2 = tf.reduce_mean(x_tmp, axis=[2, 3])  # [b, in_dim]
                with tf.variable_scope('OutDense-' + str(i)):
                    with tf.variable_scope('Conv0'):
                        x_tmp_2 = apply_bias_act(dense_layer(x_tmp_2,
                                                             fmaps=fmaps),
                                                 act=act)  # [b, fmaps]
                    with tf.variable_scope('Conv1'):
                        x_out_tmp = dense_layer(x_tmp_2, fmaps=1)  # [b, 1]
                        x_out_ls.append(x_out_tmp)
            pred_out = tf.concat(x_out_ls, axis=1)  # [b, n_latents]

        with tf.variable_scope('Reshape_output'):
            # Fold latents into the batch so tf.image.resize sees NHWC images.
            atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
        return x, pred_out, atts
def get_s_matrix(s_latents, cond_latent, act='lrelu'):
    """Build flattened 2x3 isotropic scaling matrices.

    scale = (s_latents + 2) * gate + 1, where ``gate`` is a learned sigmoid
    of ``cond_latent``. For s_latents in [-2., 2.] this gives a scale in
    [1, 1 + 4 * gate], i.e. up to [1, 5] as gate -> 1.
    NOTE(review): an older comment claimed the range was [1, 3]. Confirm
    which one is intended.

    Returns:
        Tensor laid out per row as [s, 0, 0, 0, s, 0].
    """
    with tf.variable_scope('Condition0'):
        gate_hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                     act=act)
    with tf.variable_scope('Condition1'):
        gate = apply_bias_act(dense_layer(gate_hidden, fmaps=1),
                              act='sigmoid')
    scale = (s_latents + 2.) * gate + 1.
    zero = tf.zeros_like(scale)
    return tf.concat([scale, zero, zero, zero, scale, zero], axis=1)
def get_r_matrix(r_latents, cond_latent, act='lrelu'):
    """Build flattened 2x3 rotation matrices.

    r_latents in [-2., 2.] are mapped linearly to an angle in [0, 2*pi].
    The angle is then multiplied by a learned sigmoid gate of ``cond_latent``.

    Returns:
        Tensor laid out per row as [cos, -sin, 0, sin, cos, 0].
    """
    with tf.variable_scope('Condition0'):
        gate_hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                     act=act)
    with tf.variable_scope('Condition1'):
        gate = apply_bias_act(dense_layer(gate_hidden, fmaps=1),
                              act='sigmoid')
    angle = (r_latents + 2) / 4. * 2. * np.pi
    angle = angle * gate
    cos_a = tf.math.cos(angle)
    sin_a = tf.math.sin(angle)
    zero = tf.zeros_like(angle)
    return tf.concat([cos_a, -sin_a, zero, sin_a, cos_a, zero], axis=1)
# Example no. 8
def build_C_global_nocond_layers(x,
                                 name,
                                 n_latents,
                                 start_idx,
                                 scope_idx,
                                 dlatents_withl_in,
                                 act,
                                 fused_modconv,
                                 fmaps=128,
                                 **kwargs):
    '''
    Build continuous latent layers, e.g. C_global layers.

    Unconditioned variant. The latent slice [start_idx, start_idx + n_latents)
    is first embedded by a 128-unit dense layer ('Conv0'). The embedding then
    modulates a 3x3 convolution of ``x`` ('Modulate').
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Conv0'):
            latent_group = dlatents_withl_in[:, start_idx:start_idx + n_latents]
            style = apply_bias_act(dense_layer(latent_group, fmaps=128),
                                   act=act)
        with tf.variable_scope('Modulate'):
            conv_out = modulated_conv2d_layer(x,
                                              style,
                                              fmaps=fmaps,
                                              kernel=3,
                                              up=False,
                                              fused_modconv=fused_modconv)
            x = apply_bias_act(conv_out, act=act)
    return x
# Example no. 9
 def hier_out_branch(x, nd_out):
     """Project ``x`` to ``nd_out`` outputs with a linear dense layer.

     Rank-4 inputs are first averaged over axes 3 and then 2 (spatial axes if
     ``x`` is NCHW — assumption). Rank-2 inputs are used as-is. Any other rank
     raises ValueError.
     """
     with tf.variable_scope('Output'):
         rank = len(x.shape)
         if rank not in (2, 4):
             raise ValueError('Not recognized dimension.')
         if rank == 4:
             pooled_w = tf.reduce_mean(x, axis=3)
             x = tf.reduce_mean(pooled_w, axis=2)
         x = apply_bias_act(dense_layer(x, fmaps=nd_out))
     return x
# Example no. 10
def point_wise_feed_forward_network(x, d_model, dff):
    """Two-layer position-wise feed-forward network.

    Applies dense(dff) + ReLU to every position, then dense(d_model) + bias.
    The reshapes appear to assume ``x`` is [batch, seq_len, x_dim] with static
    seq_len and x_dim (assumption — confirm with callers).

    Returns:
        Tensor of shape [batch, seq_len, d_model].
    """
    seq_len, x_dim = x.get_shape().as_list()[-2:]
    with tf.variable_scope('ffn_0_'):
        flat = tf.reshape(x, [-1, x_dim])
        inner = apply_bias_act(dense_layer(flat, dff), act='relu')
        inner = tf.reshape(inner, [-1, seq_len, dff])
    with tf.variable_scope('ffn_1_'):
        projected = apply_bias(dense_layer_last_dim(inner, d_model))
    return projected
def get_sh_matrix(sh_latents, cond_latent, act='lrelu'):
    """Build flattened 2x3 shear matrices.

    Each shear component (x, then y) is gated by its own learned sigmoid of
    ``cond_latent``. Latents in [-2., 2.] map to shears in [-1., 1.] when the
    gate is 1. The column slicing assumes ``sh_latents`` is [batch, 2].

    Returns:
        Tensor laid out per row as [1, shx, 0, shy, 1, 0].
    """
    gates = []
    for hidden_scope, gate_scope in (('Condition0x', 'Condition1x'),
                                     ('Condition0y', 'Condition1y')):
        with tf.variable_scope(hidden_scope):
            gate_hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                         act=act)
        with tf.variable_scope(gate_scope):
            gates.append(
                apply_bias_act(dense_layer(gate_hidden, fmaps=1),
                               act='sigmoid'))
    gate = tf.concat(gates, axis=1)
    xy_shear = sh_latents / 2. * gate
    shear_x = xy_shear[:, 0:1]
    shear_y = xy_shear[:, 1:]
    return tf.concat([
        tf.ones_like(shear_x), shear_x, tf.zeros_like(shear_x),
        shear_y, tf.ones_like(shear_y), tf.zeros_like(shear_y)
    ], axis=1)
# Example no. 12
def build_Cout_spgroup_layers(x,
                              name,
                              n_latents,
                              start_idx,
                              scope_idx,
                              atts_in,
                              act,
                              fmaps=128,
                              resolution=128,
                              **kwargs):
    '''
    Build continuous latent out layers with learned group spatial attention.
    Support square images only.

    ``atts_in`` holds precomputed masks of shape
    [b, all_n_latents, 1, resolution, resolution]. The masks for latents
    [start_idx, start_idx + n_latents) are resized to x's spatial size. Each
    mask weights ``x``, the result is spatially averaged, and a two-layer
    dense head per latent predicts that latent.

    Returns:
        Tuple (x, pred_out): ``x`` unchanged and ``pred_out`` of shape
        [b, n_latents].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            side = x.shape[2]
            group_atts = atts_in[:, start_idx:start_idx + n_latents]
            # Fold latents into the batch so tf.image.resize sees NHWC images.
            group_atts = tf.reshape(group_atts, [-1, resolution, resolution, 1])
            group_atts = tf.image.resize(group_atts, size=(side, side))
            group_atts = tf.reshape(group_atts, [-1, n_latents, 1, side, side])
            predictions = []
            for i in range(n_latents):
                pooled = tf.reduce_mean(x * group_atts[:, i], axis=[2, 3])
                with tf.variable_scope('OutDense-' + str(i)):
                    with tf.variable_scope('Conv0'):
                        hidden = apply_bias_act(dense_layer(pooled,
                                                            fmaps=fmaps),
                                                act=act)
                    with tf.variable_scope('Conv1'):
                        predictions.append(dense_layer(hidden, fmaps=1))
            pred_out = tf.concat(predictions, axis=1)
            return x, pred_out
# Example no. 13
def net_M(
    latents_in,
    C_global_size=10,
    D_global_size=0,
    latent_size=512,  # Latent vector (Z) dimensionality.
    mapping_layers=4,  # Number of mapping layers.
    mapping_lrmul=0.1,  # Learning rate multiplier for the mapping layers.
    mapping_fmaps=512,  # Number of activations in the mapping layers.
    mapping_nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    use_std_in_m=False,  # If output prior std.
    dtype='float32',  # Data type to use for activations and outputs.
    **_kwargs):  # Ignore unrecognized keyword args.
    """Map the concatenated C/D global latents through an MLP.

    Hidden layers have ``mapping_fmaps`` units and use
    ``mapping_nonlinearity``. The last layer is linear and has
    ``latent_size`` outputs, or ``2 * latent_size`` when ``use_std_in_m`` is
    set (presumably mean and std halves — confirm with the consumer).
    """
    latents_in.set_shape([None, C_global_size + D_global_size])
    final_fmaps = 2 * latent_size if use_std_in_m else latent_size

    x = latents_in
    for layer_idx in range(mapping_layers):
        is_last = layer_idx == mapping_layers - 1
        n_units = final_fmaps if is_last else mapping_fmaps
        layer_act = 'linear' if is_last else mapping_nonlinearity
        with tf.variable_scope('Dense%d' % layer_idx):
            x = apply_bias_act(dense_layer(x, fmaps=n_units,
                                           lrmul=mapping_lrmul),
                               act=layer_act,
                               lrmul=mapping_lrmul)

    assert x.dtype == tf.as_dtype(dtype)
    return tf.identity(x, name='to_latent_out')
# Example no. 14
def build_C_fgroup_layers(x,
                          name,
                          n_latents,
                          start_idx,
                          scope_idx,
                          dlatents_in,
                          act,
                          fused_modconv,
                          fmaps=128,
                          return_atts=False,
                          resolution=128,
                          **kwargs):
    '''
    Build continuous latent layers with learned group feature attention.

    For each of the ``n_latents`` latents, a soft contiguous channel mask is
    generated from the spatial mean of ``x``. Start and end logits are
    softmaxed and cumsummed over channels, and the mask is
    start_curve * (1 - end_curve). Each latent then style-modulates the
    instance-normalised ``x`` and is blended into ``x`` under its mask:
        x = x * (1 - mask_i) + styled_i * mask_i

    Args:
        x: feature map. reduce_mean over axes [2, 3] and a mask of shape
            [b, C, 1, 1] suggest NCHW (assumption — confirm with callers).
        dlatents_in: latents; columns [start_idx, start_idx + n_latents) are
            used, one column per latent.
        return_atts: if True, also return the masks.
        act, fused_modconv, fmaps, resolution, **kwargs: accepted but not
            used in this block.

    Returns:
        ``x``, or ``(x, atts)`` with atts of shape [b, n_latents, C, 1, 1].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_start_end'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])
            att_dim = x_mean.shape[1]
            atts = dense_layer(x_mean, fmaps=n_latents * 2 * att_dim)
            atts = tf.reshape(atts, [-1, n_latents, 2, att_dim, 1, 1
                                     ])  # [b, n_latents, 2, att_dim, 1, 1]
            # softmax + cumsum over channels turns logits into soft step curves.
            att_sm = tf.nn.softmax(atts, axis=3)
            att_cs = tf.cumsum(att_sm, axis=3)
            att_cs_starts, att_cs_ends = tf.split(
                att_cs, 2, axis=2)  # [b, n_latents, 1, att_dim, 1, 1]
            att_cs_ends = 1 - att_cs_ends
            atts = att_cs_starts * att_cs_ends  # [b, n_latents, 1, att_dim, 1, 1]
            atts = tf.reshape(atts, [-1, n_latents, att_dim, 1, 1])

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    # NOTE(review): style_mod receives a rank-1 [b] latent
                    # column here; confirm style_mod accepts that.
                    x_styled = style_mod(x_norm, C_global_latents[:, i])
                    # Both blend terms must use the same [b, C, 1, 1] mask.
                    # The previous `atts[:, i:i + 1]` is [b, 1, C, 1, 1] and
                    # broadcast x to rank 5 ([b, b, C, H, W]).
                    att_i = atts[:, i]
                    x = x * (1 - att_i) + x_styled * att_i

        if return_atts:
            return x, atts
        else:
            return x
# Example no. 15
def net_M_vc(
    latents_in,
    C_global_size=10,
    D_global_size=0,
    latent_size=512,  # Latent vector (Z) dimensionality.
    mapping_lrmul=0.1,  # Learning rate multiplier for the mapping layers.
    use_std_in_m=False,  # If output prior std.
    dtype='float32',  # Data type to use for activations and outputs.
    **_kwargs):  # Ignore unrecognized keyword args.
    """Single-layer mapping network.

    ``latents_in`` is pinned to static width ``C_global_size`` and mapped by
    one dense layer with ``latent_size`` units and lrelu activation.
    ``D_global_size`` and ``use_std_in_m`` are accepted but unused here.
    """
    latents_in.set_shape([None, C_global_size])
    projected = dense_layer(latents_in, fmaps=latent_size, lrmul=mapping_lrmul)
    mapped = apply_bias_act(projected, act='lrelu', lrmul=mapping_lrmul)

    assert mapped.dtype == tf.as_dtype(dtype)
    return tf.identity(mapped, name='to_latent_out')
# Example no. 16
def get_s_matrix(s_latents, cond_latent, act='lrelu'):
    """Build flattened 2x3 scaling matrices (isotropic or per-axis).

    If ``s_latents`` has one column, a single gated scale is used for both
    axes. Otherwise columns 0 and 1:, each with its own gate, scale x and y.
    In both cases scale = (s + 2) * gate + 1, so s in [-2., 2.] yields
    [1, 1 + 4 * gate], i.e. up to [1, 5] as gate -> 1.
    NOTE(review): an older comment claimed [1., 3.]. Confirm the intent.

    Returns:
        Tensor laid out per row as [sx, 0, 0, 0, sy, 0].
    """
    if s_latents.shape.as_list()[1] == 1:
        with tf.variable_scope('Condition0'):
            gate_hidden = apply_bias_act(dense_layer(cond_latent, fmaps=128),
                                         act=act)
        with tf.variable_scope('Condition1'):
            gate = apply_bias_act(dense_layer(gate_hidden, fmaps=1),
                                  act='sigmoid')
        scale = (s_latents + 2.) * gate + 1.
        scale_x = scale
        scale_y = scale
    else:
        gates = []
        for hidden_scope, gate_scope in (('Condition0x', 'Condition1x'),
                                         ('Condition0y', 'Condition1y')):
            with tf.variable_scope(hidden_scope):
                gate_hidden = apply_bias_act(dense_layer(cond_latent,
                                                         fmaps=128),
                                             act=act)
            with tf.variable_scope(gate_scope):
                gates.append(
                    apply_bias_act(dense_layer(gate_hidden, fmaps=1),
                                   act='sigmoid'))
        scale = (s_latents + 2.) * tf.concat(gates, axis=1) + 1.
        scale_x = scale[:, 0:1]
        scale_y = scale[:, 1:]
    return tf.concat([
        scale_x, tf.zeros_like(scale_x), tf.zeros_like(scale_x),
        tf.zeros_like(scale_y), scale_y, tf.zeros_like(scale_y)
    ], axis=1)
# Example no. 17
def vid_naive_cluster_head(
    fake_in,  # First input: generated image from z [minibatch, channel, n_frames, height, width].
    num_channels=3,  # Number of input color channels. Overridden based on dataset.
    resolution=1024,  # Input resolution. Overridden based on dataset.
    dlatent_size=10,
    D_global_size=0,
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=[
        1, 3, 3, 1
    ],  # Low-pass filter to apply when resampling activations. None = no filtering.
    **_kwargs):  # Ignore unrecognized keyword args.
    """3D-convolutional head that predicts latents from a video clip.

    ``fake_in`` is pinned to [None, num_channels, None, resolution,
    resolution] and cast to ``dtype``. It then passes through 3D conv blocks
    (each halving the spatial size via ``down=True``) down to 4x4, a final
    3D conv, a mean over axis 2 (the frame axis per the set_shape above), a
    dense layer with nf(0) units, and a linear dense layer with
    ``dlatent_size - D_global_size`` outputs.

    NOTE(review): ``mbstd_group_size``, ``mbstd_num_features`` and
    ``resample_kernel`` are not used in this function. ``architecture`` is
    only checked for 'skip'; 'resnet' has no residual path here and behaves
    like 'orig'.
    """

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake_in.set_shape([None, num_channels, None, resolution, resolution])
    fake_in = tf.cast(fake_in, dtype)

    vid_in = fake_in

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        # 1x1x1 conv projecting the raw video into nf(res - 1) channels.
        with tf.variable_scope('FromRGB'):
            t = conv3d_layer(y, fmaps=nf(res - 1), kernel=1)
            t = apply_bias_act_3d(t, act=act)
        return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        # 3D conv, then strided 3D conv that reduces to nf(res - 2) channels.
        with tf.variable_scope('Conv3D_0'):
            x = conv3d_layer(x, fmaps=nf(res - 1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Conv1_down'):
            x = conv3d_layer(x, fmaps=nf(res - 2), kernel=3, down=True)
            x = apply_bias_act_3d(x, act=act)
        return x

    # Main layers.
    x = None
    y = vid_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('I_%dx%d' % (2**res, 2**res)):
            # fromrgb runs only at the top resolution unless 'skip' is chosen.
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample_3d(y)

    # Final layers.
    with tf.variable_scope('I_4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = conv3d_layer(x, fmaps=nf(1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Global_temporal_pool'):
            # Average over the frame axis (axis 2 per the set_shape above).
            x = tf.reduce_mean(x, axis=2)
        with tf.variable_scope('Dense0'):
            # x is still rank 4 here; dense_layer presumably flattens it
            # (assumption — verify against dense_layer).
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output.
    with tf.variable_scope('I_Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(
                dense_layer(x, fmaps=dlatent_size - D_global_size))

    assert x.dtype == tf.as_dtype(dtype)
    return x
# Example no. 18
def vc_head(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: hidden features from z + delta(z) [minibatch, channel, height, width].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 <<
        10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer; values <= 1 disable it.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 are combined: 'diff' or 'concat'.
        **_kwargs):  # Ignore unrecognized keyword args.
    """Conv head that predicts latent variation from a pair of images.

    ``fake1`` and ``fake2`` are combined either as ``fake1 - fake2``
    ('diff') or by channel concatenation ('concat'). The result passes
    through a StyleGAN2-discriminator-like stack: fromrgb, per-resolution
    conv + downsampling conv blocks (optional residual skip), an optional
    minibatch-stddev layer, a 4x4 conv, and a dense layer with nf(0) units.
    A final linear dense layer has ``dlatent_size`` outputs.

    Raises:
        ValueError: if ``connect_mode`` is neither 'diff' nor 'concat'.
            Without this check the function would later fail with an
            UnboundLocalError on ``images_in``.
    """

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    if connect_mode not in ('diff', 'concat'):
        raise ValueError("connect_mode must be 'diff' or 'concat', got %r" %
                         (connect_mode,))
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    else:  # 'concat': stack along the channel axis (2 * num_channels).
        images_in = tf.concat([fake1, fake2], axis=1)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                # Scale the sum to keep activation variance roughly constant.
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Linear output layer (no label conditioning).
    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            # D_global_size + (dlatent_size - D_global_size) == dlatent_size.
            x = apply_bias_act(
                dense_layer(x,
                            fmaps=(D_global_size +
                                   (dlatent_size - D_global_size))))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
# Example no. 19
def build_std_gen(x,
                  name,
                  n_latents,
                  start_idx,
                  scope_idx,
                  dlatents_in,
                  act,
                  fused_modconv,
                  fmaps=128,
                  resolution=512,
                  fmap_base=2 << 8,
                  fmap_min=1,
                  fmap_max=512,
                  fmap_decay=1,
                  architecture='skip',
                  randomize_noise=True,
                  resample_kernel=[1, 3, 3, 1],
                  num_channels=3,
                  latent_split_ls_for_std_gen=[5, 5, 5, 5],
                  **kwargs):
    '''
    Build standard disentanglement generator with similar architecture to stylegan2.

    ``dlatents_in`` is split column-wise by ``latent_split_ls_for_std_gen``.
    Each segment is mapped by two dense layers (no activation applied between
    them in this block) to an nf(1)-wide style vector, one per synthesis
    layer. The synthesis network starts from a learned 4x4 constant. Each
    resolution 8..``resolution`` has an upsampling modulated conv and a plain
    modulated conv, each with additive noise. A final modulated 1x1 ToRGB
    layer runs only at the last resolution.

    Layer count: num_layers = 2 * log2(resolution) - 2 (including ToRGB).
    ``latent_split_ls_for_std_gen`` must have exactly num_layers entries
    summing to ``n_latents``.
    NOTE(review): with the defaults (resolution=512, split of 4 entries) the
    second assertion fails, because num_layers is 16. Callers must pass a
    consistent pair.

    NOTE(review): ``x`` is only used for its dtype. ``name``, ``scope_idx``,
    ``start_idx`` and ``fmaps`` are not used (no enclosing scope is opened).
    ToRGB runs only at the final resolution and ``upsample`` is never called,
    so 'skip' and 'orig' build the same graph here. Only 'resnet' adds a path.

    Returns:
        Output tensor named 'images_out' with ``num_channels`` channels.
    '''
    # with tf.variable_scope(name + '-' + str(scope_idx)):
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature-map count for a stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    # One 4x4 conv + two convs per resolution 8..res + one ToRGB.
    num_layers = resolution_log2 * 2 - 2
    images_out = None
    dtype = x.dtype
    assert n_latents == sum(latent_split_ls_for_std_gen)
    assert num_layers == len(latent_split_ls_for_std_gen)
    # Map each latent segment to one per-layer style vector.
    latents_ready_ls = []
    start_code = 0
    for i, seg in enumerate(latent_split_ls_for_std_gen):
        with tf.variable_scope('PreConvDense-' + str(i) + '-0'):
            x_tmp0 = dense_layer(dlatents_in[:, start_code:start_code + seg],
                                 fmaps=nf(1))
        with tf.variable_scope('PreConvDense-' + str(i) + '-1'):
            x_tmp1 = dense_layer(x_tmp0, fmaps=nf(1))
        start_code += seg
        latents_ready_ls.append(x_tmp1)

    # Noise inputs.
    # One fixed (non-trainable) noise map per conv layer, excluding ToRGB;
    # layer_idx 0 -> 4x4, 1-2 -> 8x8, 3-4 -> 16x16, ...
    noise_inputs = []
    for layer_idx in range(num_layers - 1):
        res = (layer_idx + 5) // 2
        shape = [1, 1, 2**res, 2**res]
        noise_inputs.append(
            tf.get_variable('noise%d' % layer_idx,
                            shape=shape,
                            initializer=tf.initializers.random_normal(),
                            trainable=False))

    # Single convolution layer with all the bells and whistles.
    # Modulated conv using style vector `layer_idx`, plus noise scaled by a
    # learned scalar (initialised to 0), then bias + activation.
    def layer(x, layer_idx, fmaps, kernel, up=False):
        # start_idx_layer = sum(latent_split_ls_for_std_gen[:layer_idx])
        # for i in range(start_idx_layer, start_idx_layer + latent_split_ls_for_std_gen[layer_idx]):
        # x = modulated_conv2d_layer(x, latents_ready_spl_ls[i], fmaps=fmaps, kernel=kernel, up=up,
        # resample_kernel=resample_kernel, fused_modconv=fused_modconv)
        x = modulated_conv2d_layer(x,
                                   latents_ready_ls[layer_idx],
                                   fmaps=fmaps,
                                   kernel=kernel,
                                   up=up,
                                   resample_kernel=resample_kernel,
                                   fused_modconv=fused_modconv)
        if randomize_noise:
            noise = tf.random_normal(
                [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
        else:
            noise = tf.cast(noise_inputs[layer_idx], x.dtype)
        noise_strength = tf.get_variable('noise_strength',
                                         shape=[],
                                         initializer=tf.initializers.zeros())
        x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act)

    # Building blocks for main layers.
    # Layer indices: res 3 -> (1, 2), res 4 -> (3, 4), ...
    def block(x, res):  # res = 3..resolution_log2
        t = x
        with tf.variable_scope('Conv0_up'):
            x = layer(x,
                      layer_idx=res * 2 - 5,
                      fmaps=nf(res - 1),
                      kernel=3,
                      up=True)
        with tf.variable_scope('Conv1'):
            x = layer(x, layer_idx=res * 2 - 4, fmaps=nf(res - 1), kernel=3)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 1),
                                 kernel=1,
                                 up=True,
                                 resample_kernel=resample_kernel)
                # Scale the sum to keep activation variance roughly constant.
                x = (x + t) * (1 / np.sqrt(2))
        return x

    # NOTE(review): defined but never called in this function.
    def upsample(y):
        with tf.variable_scope('Upsample'):
            return upsample_2d(y, k=resample_kernel)

    # ToRGB uses style vector res * 2 - 3; at the final resolution that is
    # num_layers - 1, the last one.
    def torgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('ToRGB'):
            t = apply_bias_act(
                modulated_conv2d_layer(x,
                                       latents_ready_ls[res * 2 - 3],
                                       fmaps=num_channels,
                                       kernel=1,
                                       demodulate=False,
                                       fused_modconv=fused_modconv))
            return t if y is None else y + t

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, nf(1), 4, 4],
                                initializer=tf.initializers.random_normal())
            # Tile the learned constant to the batch size of dlatents_in.
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
        with tf.variable_scope('Conv'):
            x = layer(x, layer_idx=0, fmaps=nf(1), kernel=3)

    # Main layers.
    for res in range(3, resolution_log2 + 1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            x = block(x, res)
            if res == resolution_log2:
                y = torgb(x, y, res)
    images_out = y

    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
# Example no. 20
def build_C_spgroup_layers_with_latents_ready(x,
                                              name,
                                              n_latents,
                                              scope_idx,
                                              latents_ready,
                                              return_atts=False,
                                              resolution=128,
                                              n_subs=1,
                                              **kwargs):
    '''
    Blend style-modulated copies of `x` into `x` through learned box masks,
    using already-prepared per-latent style vectors (`latents_ready`).

    For every latent, `n_subs` box masks are predicted from the spatially
    pooled features: cumulative softmaxes over the rows and the columns give
    soft "start" and "end" edges whose product is a box; the sub-boxes are
    averaged into one mask per latent. Square feature maps only.

    Args:
        x: feature map; the pooling over axes [2, 3] and the square masks
            imply an NCHW layout [b, c, side, side].
        name, scope_idx: combined into the enclosing variable scope name.
        n_latents: number of latents (one mask and one style_mod each).
        latents_ready: indexed as latents_ready[:, i] for latent i; its
            exact shape depends on the caller (TODO confirm).
        return_atts: also return the masks resized to `resolution`.
        resolution: output side length of the returned masks.
        n_subs: number of sub-boxes averaged per latent.

    Returns:
        x, or (x, atts) with atts of shape
        [b, n_latents, 1, resolution, resolution] when `return_atts`.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            logits = dense_layer(pooled, fmaps=n_latents * n_subs * 4 * side)
            logits = tf.reshape(logits, [-1, n_latents, n_subs, 4, side])
            # Monotone soft edges: cumsum of a softmax rises from ~0 to 1.
            cdf = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
            h_lo, h_hi, w_lo, w_hi = tf.split(cdf, 4, axis=3)
            col_shape = [-1, n_latents, n_subs, 1, side, 1]
            row_shape = [-1, n_latents, n_subs, 1, 1, side]
            # "start" edge times an inverted "end" edge gives a soft band.
            mask_h = tf.reshape(h_lo, col_shape) * tf.reshape(1 - h_hi, col_shape)
            mask_w = tf.reshape(w_lo, row_shape) * tf.reshape(1 - w_hi, row_shape)
            # Outer product of row/column bands, averaged over the sub-boxes.
            atts = tf.reduce_mean(mask_h * mask_w,
                                  axis=2)  # [b, n_latents, 1, side, side]

        with tf.variable_scope('Att_apply'):
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    styled = style_mod(x_norm, latents_ready[:, i])
                    mask = atts[:, i]
                    x = x * (1 - mask) + styled * mask

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts,
                              [-1, n_latents, 1, resolution, resolution])
        return x, atts
Esempio n. 21
0
def build_C_spgroup_stn_layers(x,
                               name,
                               n_latents,
                               start_idx,
                               scope_idx,
                               dlatents_in,
                               act,
                               fused_modconv,
                               fmaps=128,
                               return_atts=False,
                               resolution=128,
                               **kwargs):
    '''
    Group spatial attention whose box masks are warped by a learned affine
    spatial transform before being used to blend in style-modulated features.

    One box mask per latent is built from cumulative softmaxes over rows and
    columns of the pooled features. The masks are warped by `transformer`
    with a 6-parameter `theta` predicted from the same pooled features. Each
    latent then blends its style_mod output into `x` through its mask.
    Square feature maps only.

    Args:
        x: feature map; the pooling over axes [2, 3] and the square masks
            imply an NCHW layout [b, c, side, side].
        name, scope_idx: combined into the enclosing variable scope name.
        n_latents: number of latents handled by this layer.
        start_idx: first column of `dlatents_in` used by this layer.
        dlatents_in: latent codes; columns [start_idx, start_idx + n_latents)
            are consumed.
        act, fused_modconv, fmaps: accepted for interface compatibility; not
            used in this body.
        return_atts: also return the warped masks resized to `resolution`.
        resolution: output side length of the returned masks.

    Returns:
        x, or (x, atts) with atts of shape
        [b, n_latents, 1, resolution, resolution] when `return_atts`.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            logits = dense_layer(pooled, fmaps=n_latents * 4 * side)
            logits = tf.reshape(logits, [-1, n_latents, 4, side])
            cdf = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
            h_lo, h_hi, w_lo, w_hi = tf.split(cdf, 4, axis=2)
            col_shape = [-1, n_latents, 1, side, 1]
            row_shape = [-1, n_latents, 1, 1, side]
            mask_h = tf.reshape(h_lo, col_shape) * tf.reshape(1 - h_hi, col_shape)
            mask_w = tf.reshape(w_lo, row_shape) * tf.reshape(1 - w_hi, row_shape)
            atts = mask_h * mask_w  # [b, n_latents, 1, side, side]

        with tf.variable_scope('trans_matrix'):
            theta = apply_bias_act(dense_layer(pooled, fmaps=n_latents * 6))
            theta = tf.reshape(theta, [-1, 6])  # [b*n_latents, 6]
            # Fold latents into the batch so each mask gets its own theta.
            flat_masks = tf.reshape(atts, [-1, side, side, 1])
            warped = transformer(flat_masks, theta)
            atts = tf.reshape(warped, [-1, n_latents, 1, side, side])

        with tf.variable_scope('Att_apply'):
            codes = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    styled = style_mod(x_norm, codes[:, i:i + 1])
                    mask = atts[:, i]
                    x = x * (1 - mask) + styled * mask

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts,
                              [-1, n_latents, 1, resolution, resolution])
        return x, atts
Esempio n. 22
0
def vpex_net(
        fake1,  # First input: generated image from z [minibatch, channel, height, width].
        fake2,  # Second input: generated image from z + delta(z) [minibatch, channel, height, width].
        latents,  # Ground-truth latent code for fake1.
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        dlatent_size=10,
        D_global_size=0,
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        connect_mode='concat',  # How fake1 and fake2 are combined: 'concat' or 'diff'.
        return_atts=False,  # If return I_atts.
        **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Predict one scalar per latent dimension from a pair of generated images.

    The pair (fake1, fake2) is combined according to `connect_mode`. It is
    then downsampled with a StyleGAN2-style discriminator trunk to 8x8. A
    learned spatial attention map per latent dimension pools the 8x8 features
    into one vector per latent. Each vector is broadcast back to 8x8 and
    pushed through the remaining 8x8 -> 4x4 layers, with latents folded into
    the batch. A final dense layer emits a single value per (sample, latent)
    pair.

    Args:
        connect_mode: 'concat' stacks the two images on the channel axis;
            'diff' uses fake1 - fake2.
        return_atts: also return the attention maps resized to `resolution`.
        (Other arguments are described inline in the signature.)

    Returns:
        x of shape [minibatch, dlatent_size], or (x, att_map) with att_map of
        shape [minibatch, dlatent_size, 1, resolution, resolution] when
        `return_atts`.

    Raises:
        ValueError: if `connect_mode` is neither 'concat' nor 'diff'.
    '''
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    # The attention stage runs on 8x8 features, so at least one downsampling
    # block (16x16 -> 8x8) must run before it; otherwise x would still be None.
    assert resolution_log2 >= 4, 'vpex_net requires resolution >= 16'

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake1.set_shape([None, num_channels, resolution, resolution])
    fake2.set_shape([None, num_channels, resolution, resolution])
    latents.set_shape([None, dlatent_size])
    fake1 = tf.cast(fake1, dtype)
    fake2 = tf.cast(fake2, dtype)
    latents = tf.cast(latents, dtype)
    if connect_mode == 'diff':
        images_in = fake1 - fake2
    elif connect_mode == 'concat':
        images_in = tf.concat([fake1, fake2], axis=1)
    else:
        # Previously fell through with `images_in` unbound (NameError later).
        raise ValueError(
            "Unknown connect_mode %r; expected 'concat' or 'diff'." %
            (connect_mode,))

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    def get_att_map(latents, x):
        '''
        Compute a softmax-normalised spatial attention map per latent dimension.

        Each latent dimension owns a learned feature map (`att_feats`). That
        map is concatenated with the shared features `x` and reduced to a
        single channel by a 3x3 and a 1x1 convolution.

        Args:
            latents: used only to read the batch size. The latent-conditioned
                modulated conv of an earlier version is disabled.
            x: feature map [minibatch, ch, h, w] with a fully known static
                shape.

        Returns:
            [minibatch, dlatent_size, 1, h * w], softmax over the last axis.
        '''
        with tf.variable_scope('create_att_feats'):
            x_ch, x_h, x_w = x.get_shape().as_list()[1:]
            att_feats = tf.get_variable(
                'att_feats',
                shape=[1, dlatent_size, x_ch, x_h, x_w],
                initializer=tf.initializers.random_normal())
            att_feats = tf.tile(tf.cast(att_feats, dtype),
                                [tf.shape(latents)[0], 1, 1, 1, 1])
            # Pair the shared features with every per-latent learned map.
            x = tf.reshape(x, [-1, 1, x_ch, x_h, x_w])
            x = tf.tile(x, [1, dlatent_size, 1, 1, 1])
            att_map = tf.concat([x, att_feats], axis=2)
            att_map = tf.reshape(att_map, [-1, 2 * x_ch, x_h, x_w])
            with tf.variable_scope('att_conv_3x3'):
                att_map = apply_bias_act(conv2d_layer(att_map,
                                                      fmaps=2 * x_ch,
                                                      kernel=3),
                                         act=act)
            with tf.variable_scope('att_conv_1x1'):
                att_map = apply_bias_act(
                    conv2d_layer(att_map, fmaps=1, kernel=1))
            att_map = tf.reshape(att_map, [-1, dlatent_size, 1, x_h * x_w])
            att_map = tf.nn.softmax(att_map, axis=-1)
        return att_map

    # Main layers: resolution -> 8x8.
    x = None
    y = images_in
    for res in range(resolution_log2, 3, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Attention pooling: one feature vector per latent dimension.
    with tf.variable_scope('apply_att'):
        att_map = get_att_map(latents, x)
        x_ch, x_h, x_w = x.get_shape().as_list()[1:]
        assert x_h == 8
        x = tf.reshape(x, [-1, 1, x_ch, x_h * x_w])  # [b, 1, ch, h*w]
        x = tf.reduce_sum(att_map * x, axis=-1)  # [b, dlatent, ch]
        x = tf.reshape(x, [-1, x_ch, 1, 1])  # [b * dlatent, ch, 1, 1]
        with tf.variable_scope('after_att_conv_1x1'):
            x = apply_bias_act(conv2d_layer(x, fmaps=x_ch, kernel=1))
        x = tf.reshape(x, [-1, dlatent_size, x_ch, 1])  # [b, dlatent, ch, 1]
        # Broadcast each pooled vector back over the 8x8 grid.
        x = tf.tile(x, [1, 1, 1, x_h * x_w])
        x = tf.reshape(x, [-1, x_ch, x_h, x_w])  # [b * dlatent, ch, h, w]
        # Replicate y once per latent so its batch axis matches x.
        y_ch, y_h, y_w = y.get_shape().as_list()[1:]
        y = tf.tile(y[:, tf.newaxis, ...], [1, dlatent_size, 1, 1, 1])
        y = tf.reshape(y, [-1, y_ch, y_h, y_w])

    # Last downsampling block (8x8 -> 4x4), now per (sample, latent).
    # res == 3 can never equal resolution_log2 here (asserted >= 4 above).
    with tf.variable_scope('8x8'):
        if architecture == 'skip':
            x = fromrgb(x, y, 3)
        x = block(x, 3)
        if architecture == 'skip':
            y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    with tf.variable_scope('Output'):
        with tf.variable_scope('Dense_VC'):
            x = apply_bias_act(dense_layer(x, fmaps=1))

    with tf.variable_scope('Final_reshape_x'):
        # One scalar per (sample, latent) pair -> [minibatch, dlatent_size].
        x = tf.reshape(x, [-1, dlatent_size])

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    if return_atts:
        with tf.variable_scope('Reshape_atts'):
            att_map = tf.reshape(att_map, [-1, x_h, x_w, 1])
            att_map = tf.image.resize(att_map, size=(resolution, resolution))
            att_map = tf.reshape(att_map,
                                 [-1, dlatent_size, 1, resolution, resolution])
        return x, att_map
    else:
        return x
Esempio n. 23
0
def build_C_spgroup_regW_layers(x,
                                name,
                                n_latents,
                                start_idx,
                                scope_idx,
                                dlatents_in,
                                act,
                                fused_modconv,
                                fmaps=128,
                                resolution=128,
                                n_subs=1,
                                **kwargs):
    '''
    Group spatial attention layer that also collects the auxiliary outputs
    of `style_mod_with_regW` (one per latent).

    For each latent, `n_subs` soft box masks are predicted from the pooled
    features and averaged. The latent's style-modulated features are blended
    into `x` through that mask. Square feature maps only.

    Args:
        x: feature map; the pooling over axes [2, 3] and the square masks
            imply an NCHW layout [b, c, side, side].
        name, scope_idx: combined into the enclosing variable scope name.
        n_latents: number of latents handled by this layer.
        start_idx: first column of `dlatents_in` used by this layer.
        dlatents_in: latent codes; columns [start_idx, start_idx + n_latents)
            are consumed.
        act, fused_modconv, fmaps: accepted for interface compatibility; not
            used in this body.
        resolution: output side length of the returned masks.
        n_subs: number of sub-boxes averaged per latent.

    Returns:
        (x, atts, z_w). atts has shape
        [b, n_latents, 1, resolution, resolution]. z_w is the second output
        of every `style_mod_with_regW` call, concatenated along axis 0.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            logits = dense_layer(pooled, fmaps=n_latents * n_subs * 4 * side)
            logits = tf.reshape(logits, [-1, n_latents, n_subs, 4, side])
            cdf = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
            h_lo, h_hi, w_lo, w_hi = tf.split(cdf, 4, axis=3)
            col_shape = [-1, n_latents, n_subs, 1, side, 1]
            row_shape = [-1, n_latents, n_subs, 1, 1, side]
            mask_h = tf.reshape(h_lo, col_shape) * tf.reshape(1 - h_hi, col_shape)
            mask_w = tf.reshape(w_lo, row_shape) * tf.reshape(1 - w_hi, row_shape)
            atts = tf.reduce_mean(mask_h * mask_w,
                                  axis=2)  # [b, n_latents, 1, side, side]

        with tf.variable_scope('Att_apply'):
            codes = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            z_w_parts = []
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    styled, z_w_i = style_mod_with_regW(
                        x_norm, codes[:, i:i + 1])
                    mask = atts[:, i]
                    x = x * (1 - mask) + styled * mask
                    z_w_parts.append(z_w_i)

        with tf.variable_scope('Reshape_output'):
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
            z_w = tf.concat(z_w_parts, axis=0)
        return x, atts, z_w
Esempio n. 24
0
def build_zpos_to_mat_layer(x,
                            name,
                            n_layers,
                            scope_idx,
                            is_training,
                            wh,
                            feat_cnn_dim,
                            resolution=128,
                            trans_dim=512,
                            dff=512,
                            trans_rate=0.1,
                            ncut_maxval=5,
                            post_trans_mat=16,
                            **kwargs):
    '''
    Build a zpos_to_mat forwarding transformer to extract features per z.

    Steps:
      1. Add a learned positional embedding to each latent (x[:, j] is added
         to a [trans_dim] vector).
      2. Run `trans_encoder_basic` with `split_masks_mul`. In training the
         mask and split indices come from `create_split_mask` with a random
         cut count in [0, ncut_maxval). In eval the mask is all ones and the
         only split is at n_lat.
      3. Map each token to a post_trans_mat x post_trans_mat matrix, gather
         the matrices at the split positions, and left-multiply them together
         in order (starting from identity) inside a tf.while_loop.
      4. Map the flattened product to a [feat_cnn_dim, wh, wh] feature map.

    Args:
        x: latent codes, rank 2 [b, n_lat]; n_lat is read from its static
            last dimension.
        is_training: Python bool that selects the random-split branch.
        (Other arguments configure the transformer and output sizes.)

    Returns:
        feat: [b, feat_cnn_dim, wh, wh].
        atts: `split_masks_mul - 1e-4` resized to `resolution`, shaped
            [-1, n_lat, 1, resolution, resolution]. The leading dim follows
            the mask's own batch, not b (TODO confirm intent).
        mat_agg_final_out: flattened matrix product [b, post_trans_mat**2].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('PosConstant'):
            n_lat = x.get_shape().as_list()[-1]
            pos = tf.get_variable('const',
                                  shape=[1, n_lat, trans_dim],
                                  initializer=tf.initializers.random_normal())
            batch_pos = tf.tile(tf.cast(pos, x.dtype), [tf.shape(x)[0], 1, 1])
            zpos = batch_pos + x[:, :, np.newaxis]

        with tf.variable_scope('MaskEncoding'):
            if is_training:
                ncut = tf.random.uniform(shape=[],
                                         maxval=ncut_maxval,
                                         dtype=tf.int32)
                split_masks_mul, split_idx = create_split_mask(n_lat, ncut)
            else:
                split_masks_mul = tf.ones(shape=[n_lat, n_lat],
                                          dtype=tf.float32)
                split_idx = tf.constant([n_lat])
            # Always include the final position, without duplicating it.
            split_idx, _ = tf.unique(tf.concat([split_idx, [n_lat]], axis=0))
            encoded = trans_encoder_basic(zpos,
                                          is_training,
                                          split_masks_mul,
                                          n_layers,
                                          trans_dim,
                                          num_heads=8,
                                          dff=dff,
                                          rate=trans_rate)
            mask_logits = get_return_v(encoded, 1)  # (b, n_lat, d_model)
            mask_groups = dense_layer_last_dim(mask_logits,
                                               post_trans_mat * post_trans_mat)

        with tf.variable_scope('GatherSubgroups'):
            b = tf.shape(mask_groups)[0]
            len_group = tf.shape(split_idx)[0]
            # Split indices are end positions, so the group's token is idx - 1.
            picked = tf.gather(mask_groups, split_idx - 1, axis=1)
            gathered_groups = tf.reshape(
                picked, [b, len_group, post_trans_mat, post_trans_mat])
            # Left-multiply the gathered matrices in order, starting from I.
            _, mat_agg_final = tf.while_loop(
                lambda i, mats: tf.less(i, len_group),
                lambda i, mats: (i + 1,
                                 tf.matmul(gathered_groups[:, i, ...], mats)),
                (0, tf.eye(post_trans_mat, batch_shape=[b])))  # (b, mat, mat)
            mat_agg_final_out = tf.reshape(
                mat_agg_final, [b, post_trans_mat * post_trans_mat])

        with tf.variable_scope('MaskMapping'):
            flat_mat = tf.reshape(mat_agg_final,
                                  [b, post_trans_mat * post_trans_mat])
            feat = apply_bias_act(
                dense_layer(flat_mat, feat_cnn_dim * wh * wh))
            feat = tf.reshape(feat, [-1, feat_cnn_dim, wh, wh])

        with tf.variable_scope('ReshapeAttns'):
            shifted_masks = split_masks_mul - 1e-4
            atts = tf.reshape(shifted_masks, [-1, n_lat, n_lat, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, 1, 1, resolution, resolution])
            atts = tf.tile(atts, [1, n_lat, 1, 1, 1])
        return feat, atts, mat_agg_final_out
Esempio n. 25
0
def build_C_spfgroup_layers(x,
                            name,
                            n_latents,
                            start_idx,
                            scope_idx,
                            dlatents_in,
                            act,
                            fused_modconv,
                            fmaps=128,
                            return_atts=False,
                            resolution=128,
                            **kwargs):
    '''
    Group attention over both channels and space.

    Per latent, a soft channel band and a soft spatial box are predicted from
    the pooled features with the cumsum-of-softmax "start/end" construction.
    Their product gates how much of the latent's style-modulated features is
    blended into `x`. Square feature maps only.

    Args:
        x: feature map; the pooling over axes [2, 3] and the square masks
            imply an NCHW layout [b, c, side, side].
        name, scope_idx: combined into the enclosing variable scope name.
        n_latents: number of latents handled by this layer.
        start_idx: first column of `dlatents_in` used by this layer.
        dlatents_in: latent codes; columns [start_idx, start_idx + n_latents)
            are consumed.
        act, fused_modconv, fmaps: accepted for interface compatibility; not
            used in this body.
        return_atts: also return the spatial masks resized to `resolution`.
        resolution: output side length of the returned masks.

    Returns:
        x, or (x, att_sp) with att_sp of shape
        [b, n_latents, 1, resolution, resolution] when `return_atts`. Only
        the spatial part of the attention is returned.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_channel_start_end'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            n_ch = pooled.shape[1]
            ch_logits = dense_layer(pooled, fmaps=n_latents * 2 * n_ch)
            ch_logits = tf.reshape(ch_logits, [-1, n_latents, 2, n_ch, 1, 1])
            ch_cdf = tf.cumsum(tf.nn.softmax(ch_logits, axis=3), axis=3)
            ch_lo, ch_hi = tf.split(ch_cdf, 2, axis=2)
            att_channel = tf.reshape(ch_lo * (1 - ch_hi),
                                     [-1, n_latents, n_ch, 1, 1])

        with tf.variable_scope('Att_spatial'):
            side = x.shape[2]
            logits = dense_layer(pooled, fmaps=n_latents * 4 * side)
            logits = tf.reshape(logits, [-1, n_latents, 4, side])
            cdf = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
            h_lo, h_hi, w_lo, w_hi = tf.split(cdf, 4, axis=2)
            col_shape = [-1, n_latents, 1, side, 1]
            row_shape = [-1, n_latents, 1, 1, side]
            mask_h = tf.reshape(h_lo, col_shape) * tf.reshape(1 - h_hi, col_shape)
            mask_w = tf.reshape(w_lo, row_shape) * tf.reshape(1 - w_hi, row_shape)
            att_sp = mask_h * mask_w  # [b, n_latents, 1, side, side]
            atts = att_channel * att_sp  # [b, n_latents, n_ch, side, side]

        with tf.variable_scope('Att_apply'):
            codes = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    styled = style_mod(x_norm, codes[:, i:i + 1])
                    mask = atts[:, i]
                    x = x * (1 - mask) + styled * mask

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            att_sp = tf.reshape(att_sp, [-1, side, side, 1])
            att_sp = tf.image.resize(att_sp, size=(resolution, resolution))
            att_sp = tf.reshape(att_sp,
                                [-1, n_latents, 1, resolution, resolution])
        return x, att_sp
Esempio n. 26
0
def vid_head(
    fake_in,  # First input: generated video from z [minibatch, channel, n_frames, height, width].
    C_delta_idxes,  # Second input: the index of the varied latent.
    num_channels=3,  # Number of input color channels. Overridden based on dataset.
    resolution=1024,  # Input resolution. Overridden based on dataset.
    dlatent_size=10,
    D_global_size=0,
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=[
        1, 3, 3, 1
    ],  # Low-pass filter to apply when resampling activations. None = no filtering.
    **_kwargs):  # Ignore unrecognized keyword args.
    '''
    Score a generated video together with the index of the varied latent.

    A 3D-convolutional trunk downsamples the video spatially to 4x4. It is
    average-pooled over the frame axis (axis 2) and flattened by a dense
    layer. The result is concatenated with an embedding of `C_delta_idxes`
    and reduced to one scalar by a two-layer head (marked "For MINE").

    Note: `resample_kernel`, `mbstd_*` and `D_global_size` are accepted for
    interface compatibility but are not used by the 3D trunk. With
    architecture='resnet' no residual skip is built here; it behaves like
    'orig'.

    Args:
        fake_in: [minibatch, num_channels, n_frames, resolution, resolution].
        C_delta_idxes: [minibatch, dlatent_size]; presumably a one-hot
            indicator of the varied latent (TODO confirm against caller).

    Returns:
        [minibatch, 1] tensor of dtype `dtype`.
    '''
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    fake_in.set_shape([None, num_channels, None, resolution, resolution])
    fake_in = tf.cast(fake_in, dtype)
    C_delta_idxes.set_shape([None, dlatent_size])
    C_delta_idxes = tf.cast(C_delta_idxes, dtype)

    vid_in = fake_in

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = conv3d_layer(y, fmaps=nf(res - 1), kernel=1)
            t = apply_bias_act_3d(t, act=act)
        return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        with tf.variable_scope('Conv3D_0'):
            x = conv3d_layer(x, fmaps=nf(res - 1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Conv1_down'):
            x = conv3d_layer(x, fmaps=nf(res - 2), kernel=3, down=True)
            x = apply_bias_act_3d(x, act=act)
        return x

    # Main layers.
    x = None
    y = vid_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('I_%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample_3d(y)

    # Final layers.
    with tf.variable_scope('I_4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        with tf.variable_scope('Conv'):
            x = conv3d_layer(x, fmaps=nf(1), kernel=3)
            x = apply_bias_act_3d(x, act=act)
        with tf.variable_scope('Global_temporal_pool'):
            # Average over the frame axis.
            x = tf.reduce_mean(x, axis=2)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=64), act=act)

    # From C_delta_idxes
    with tf.variable_scope('I_From_C_Delta_Idx'):
        x_from_delta = apply_bias_act(dense_layer(C_delta_idxes, fmaps=32),
                                      act=act)
        x = tf.concat([x, x_from_delta], axis=1)

    # For MINE
    with tf.variable_scope('I_Output'):
        with tf.variable_scope('Dense_T_0'):
            x = apply_bias_act(dense_layer(x, fmaps=128), act=act)
        with tf.variable_scope('Dense_T_1'):
            x = apply_bias_act(dense_layer(x, fmaps=1))

    # Output.
    assert x.dtype == tf.as_dtype(dtype)
    return x
Esempio n. 27
0
def D_info_gan_stylegan2(
    images_in,  # First input: Images [minibatch, channel, height, width].
    labels_in,  # Second input: Labels [minibatch, label_size].
    num_channels=3,  # Number of input color channels. Overridden based on dataset.
    resolution=1024,  # Input resolution. Overridden based on dataset.
    label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    mbstd_group_size=4,  # Group size for the minibatch standard deviation layer; values <= 1 disable it.
    mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=(
        1, 3, 3, 1
    ),  # Low-pass filter to apply when resampling activations. None = no filtering.
    **_kwargs):  # Ignore unrecognized keyword args.
    """Build a StyleGAN2-style discriminator that also exposes its hidden features.

    The image is processed from full resolution down to 4x4 by a stack of
    downsampling conv blocks ('orig', 'skip' or 'resnet' wiring). At 4x4 an
    optional minibatch-stddev layer, a 3x3 conv and a dense layer with nf(0)
    units produce ``hidden``. A final dense layer then maps ``hidden`` to a
    score, optionally projected onto the labels.

    Args:
        images_in: Images, [minibatch, num_channels, resolution, resolution].
        labels_in: Labels, [minibatch, label_size].
        resample_kernel: 1-D low-pass filter taps used when downsampling.
            It is only read, so it defaults to an immutable tuple rather
            than a shared list.
        (other arguments are described inline in the signature)

    Returns:
        (scores_out, hidden), both wrapped in tf.identity.
            scores_out: [minibatch, 1] discriminator scores.
            hidden: output of the 4x4 'Dense0' layer (nf(0) features).

    Raises:
        AssertionError: if ``resolution`` is not a power of two >= 4, if
            ``architecture`` is not one of 'orig'/'skip'/'resnet', or if the
            output dtype does not match ``dtype``.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    def nf(stage):
        # Feature maps at a given stage: fmap_base halved (per fmap_decay) each
        # stage, clipped to [fmap_min, fmap_max].
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Building blocks for main layers.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        # 1x1 conv from image space; added to the running features when present.
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(conv2d_layer(y, fmaps=nf(res - 1), kernel=1),
                               act=act)
            return t if x is None else x + t

    def block(x, res):  # res = 2..resolution_log2
        # Two 3x3 convs, the second one halving the spatial resolution.
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(res - 1), kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(conv2d_layer(x,
                                            fmaps=nf(res - 2),
                                            kernel=3,
                                            down=True,
                                            resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                # Scale the two-branch sum by 1/sqrt(2), presumably to keep its
                # magnitude comparable to a single branch.
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers: resolution_log2 down to 8x8 (res = 3).
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            # Only 'skip' re-injects the (downsampled) image at every scale.
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(conv2d_layer(x, fmaps=nf(1), kernel=3), act=act)
        with tf.variable_scope('Dense0'):
            # Kept as a separate tensor: returned to callers as `hidden`.
            hidden = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        x = apply_bias_act(
            dense_layer(hidden, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            # Pick the score(s) for the given label(s) via a dot product.
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    hidden = tf.identity(hidden, name='hidden')
    return scores_out, hidden
Esempio n. 28
0
def build_C_spgroup_lcond_layers(x,
                                 name,
                                 n_latents,
                                 start_idx,
                                 scope_idx,
                                 dlatents_in,
                                 act,
                                 fused_modconv,
                                 fmaps=128,
                                 return_atts=False,
                                 resolution=128,
                                 **kwargs):
    '''
    Modulate `x` with a group of continuous latents, each restricted to its
    own learned rectangular spatial attention window.

    For each latent i in dlatents_in[:, start_idx:start_idx + n_latents]:
      * the spatially averaged features of `x` are style-modulated by the
        latent and fed to a dense layer with 4 * side outputs;
      * these are reshaped to four rows of `side` logits, each softmaxed and
        cumulatively summed;
      * rows 0/1 give a soft start/end cut along height, rows 2/3 along width,
        and the outer product of the two 1-D windows is the attention map.
    `x` is then blended, latent by latent, between its current value and a
    style-modulated copy of instance_norm(x), weighted by that latent's map.

    Support square images only: the side length is read from x.shape[2] and
    reused for both spatial axes. The layout is assumed to be NCHW
    (spatial axes 2 and 3); TODO confirm against callers.

    Args:
        x: feature map to modulate.
        name, scope_idx: joined as '<name>-<scope_idx>' for the variable scope.
        n_latents: number of latents in this group.
        start_idx: first column of dlatents_in belonging to the group.
        dlatents_in: latent codes, sliced along axis 1.
        act, fused_modconv, fmaps, **kwargs: accepted but unused here.
        return_atts: also return the attention maps when True.
        resolution: side length the returned attention maps are resized to.

    Returns:
        x, or (x, atts) where atts is reshaped to
        [-1, n_latents, 1, resolution, resolution] when return_atts is True.
    '''

    def _interval(cs_start, cs_end, shape):
        # Soft 1-D window: the start cumsum rises towards 1 and (1 - end cumsum)
        # falls towards 0, so their product is large only between the two cuts.
        return tf.reshape(cs_start, shape) * tf.reshape(1 - cs_end, shape)

    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            group_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            h_shape = [-1, 1, 1, side, 1]
            w_shape = [-1, 1, 1, 1, side]

            att_maps = []
            for idx in range(n_latents):
                with tf.variable_scope('lcond-' + str(idx)):
                    styled = style_mod(pooled, group_latents[:, idx:idx + 1])
                    logits = dense_layer(styled, fmaps=4 * side)
                    logits = tf.reshape(logits, [-1, 4, side])  # [b, 4, side]
                    cdf = tf.cumsum(tf.nn.softmax(logits, axis=-1), axis=-1)
                    h_start, h_end, w_start, w_end = tf.split(cdf, 4, axis=1)
                    att_h = _interval(h_start, h_end, h_shape)  # [b, 1, 1, side, 1]
                    att_w = _interval(w_start, w_end, w_shape)  # [b, 1, 1, 1, side]
                    att_maps.append(att_h * att_w)  # [b, 1, 1, side, side]
            atts = tf.concat(att_maps, axis=1)  # [b, n_latents, 1, side, side]

        with tf.variable_scope('Att_apply'):
            # Every latent modulates the same normalised copy of the input.
            normed = instance_norm(x)
            for idx in range(n_latents):
                with tf.variable_scope('style_mod-' + str(idx)):
                    modulated = style_mod(normed, group_latents[:, idx:idx + 1])
                    mask = atts[:, idx]
                    x = x * (1 - mask) + modulated * mask

        if not return_atts:
            return x

        with tf.variable_scope('Reshape_output'):
            # Fold latents into the batch axis, resize as single-channel NHWC
            # images, then restore the [.., n_latents, 1, H, W] layout.
            maps = tf.reshape(atts, [-1, side, side, 1])
            maps = tf.image.resize(maps, size=(resolution, resolution))
            atts = tf.reshape(maps,
                              [-1, n_latents, 1, resolution, resolution])
        return x, atts