# Example #1
def build_C_fgroup_layers(x,
                          name,
                          n_latents,
                          start_idx,
                          scope_idx,
                          dlatents_in,
                          act,
                          fused_modconv,
                          fmaps=128,
                          return_atts=False,
                          resolution=128,
                          **kwargs):
    '''
    Build continuous latent layers with learned group feature attention.

    x: input feature map (NCHW assumed from the axis=[2, 3] pooling — TODO confirm).
    n_latents: number of latent columns consumed from dlatents_in.
    start_idx: first column of dlatents_in used by this block.
    return_atts: if True, also return the channel attention masks
        [b, n_latents, att_dim, 1, 1].
    act, fused_modconv, fmaps are accepted for signature compatibility
    with the sibling builders but unused here.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_start_end'):
            # Global average pool over spatial dims -> [b, c].
            x_mean = tf.reduce_mean(x, axis=[2, 3])
            att_dim = x_mean.shape[1]
            # Predict soft "start" and "end" logits along the channel axis;
            # cumsum(softmax) yields monotone ramps whose product is a soft
            # contiguous channel interval per latent.
            atts = dense_layer(x_mean, fmaps=n_latents * 2 * att_dim)
            atts = tf.reshape(atts, [-1, n_latents, 2, att_dim, 1, 1
                                     ])  # [b, n_latents, 2, att_dim, 1, 1]
            att_sm = tf.nn.softmax(atts, axis=3)
            att_cs = tf.cumsum(att_sm, axis=3)
            att_cs_starts, att_cs_ends = tf.split(
                att_cs, 2, axis=2)  # [b, n_latents, 1, att_dim, 1, 1]
            att_cs_ends = 1 - att_cs_ends
            atts = att_cs_starts * att_cs_ends  # [b, n_latents, 1, att_dim, 1, 1]
            atts = tf.reshape(atts, [-1, n_latents, att_dim, 1, 1])

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    # BUG FIX: pass a rank-2 latent slice [b, 1] to style_mod
                    # (matching the sibling builders) and use the same rank-4
                    # attention slice atts[:, i] ([b, att_dim, 1, 1]) on BOTH
                    # sides of the blend. The original mixed atts[:, i] with
                    # atts[:, i:i + 1] (rank 5), which broadcasts x to the
                    # wrong rank.
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]

        if return_atts:
            return x, atts
        else:
            return x
# Example #2
def build_C_spgroup_softmax_layers(x,
                                   name,
                                   n_latents,
                                   start_idx,
                                   scope_idx,
                                   dlatents_in,
                                   act,
                                   fused_modconv,
                                   fmaps=128,
                                   return_atts=False,
                                   resolution=128,
                                   **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention with pure softmax.
    Support square images only.

    One attention map per latent is predicted by a 3x3 conv and normalized
    with a softmax over all pixels; each latent then style-modulates x only
    where its map is active. act, fused_modconv, fmaps are accepted for
    signature compatibility but unused here.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            # Square input assumed: height == width == side.
            side = x.shape[2]
            att_maps = conv2d_layer(x, fmaps=n_latents, kernel=3)
            att_maps = tf.reshape(
                att_maps, [-1, n_latents, side * side])  # [b, n_latents, m]
            att_maps = tf.nn.softmax(att_maps, axis=-1)
            atts = tf.reshape(att_maps, [-1, n_latents, 1, side, side])

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for idx in range(n_latents):
                with tf.variable_scope('style_mod-' + str(idx)):
                    styled = style_mod(x_norm, C_global_latents[:, idx:idx + 1])
                    # Blend styled features in only where this latent attends.
                    x = x * (1 - atts[:, idx]) + styled * atts[:, idx]

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            # Upsample the attention maps to the output resolution for viz.
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
        return x, atts
def construct_feat_by_concat_masks_latent(feat_on_masks, masks, dlatents_in):
    '''
    Style-modulate each per-mask feature map with its latent and concatenate
    the results along the channel axis.

    feat_on_masks: [b, n_masks, dim, h, w]
    masks: [b, n_masks, h, w] — UNUSED; kept for interface compatibility
        with construct_feat_by_masks_latent (the mask blend was disabled).
    dlatents_in: [b, n_masks]

    Returns:
        canvas: [b, n_masks * dim, h, w]
    '''
    n_masks, dim, h, w = feat_on_masks.get_shape().as_list()[1:]

    # Instance-normalize each per-mask feature map independently by folding
    # the mask axis into the batch axis.
    feat_on_masks = tf.reshape(feat_on_masks, [-1, dim, h, w])
    feat_on_masks = instance_norm(feat_on_masks)
    feat_on_masks = tf.reshape(feat_on_masks, [-1, n_masks, dim, h, w])
    canvas = []
    for i in range(n_masks):
        with tf.variable_scope('style_mod-' + str(i)):
            feat_styled = style_mod(feat_on_masks[:, i],
                                    dlatents_in[:, i:i + 1])  # [b, dim, h, w]
            canvas.append(feat_styled)
    canvas = tf.concat(canvas, axis=1)
    return canvas
def construct_feat_by_masks_latent(feat_on_masks, masks, dlatents_in):
    '''
    feat_on_masks: [b, n_masks, dim, h, w]
    masks: [b, n_masks, h, w]
    dlatents_in: [b, n_masks]

    Paints each style-modulated per-mask feature onto a learned constant
    canvas, alpha-blending with the corresponding mask. Returns [b, dim, h, w].
    '''
    n_masks, dim, h, w = feat_on_masks.get_shape().as_list()[1:]
    with tf.variable_scope('CanvasConst'):
        # Learned background canvas, tiled across the batch.
        canvas = tf.get_variable('canvas_const',
                                 shape=[1, dim, h, w],
                                 initializer=tf.initializers.random_normal())
        batch = tf.shape(feat_on_masks)[0]
        canvas = tf.tile(tf.cast(canvas, feat_on_masks.dtype),
                         [batch, 1, 1, 1])
    # Insert a broadcastable channel axis: [b, n_masks, 1, h, w].
    masks = masks[:, :, np.newaxis, ...]

    # Instance-normalize each per-mask feature map independently by folding
    # the mask axis into the batch axis.
    flat = tf.reshape(feat_on_masks, [-1, dim, h, w])
    flat = instance_norm(flat)
    feat_on_masks = tf.reshape(flat, [-1, n_masks, dim, h, w])
    for idx in range(n_masks):
        with tf.variable_scope('style_mod-' + str(idx)):
            styled = style_mod(feat_on_masks[:, idx],
                               dlatents_in[:, idx:idx + 1])
            # Composite the styled feature over the canvas under its mask.
            canvas = canvas * (1 - masks[:, idx]) + styled * masks[:, idx]
    return canvas
# Example #5
def build_C_spfgroup_layers(x,
                            name,
                            n_latents,
                            start_idx,
                            scope_idx,
                            dlatents_in,
                            act,
                            fused_modconv,
                            fmaps=128,
                            return_atts=False,
                            resolution=128,
                            **kwargs):
    '''
    Build continuous latent layers with learned group feature-spatial attention.
    Support square images only.

    x: input feature map (NCHW assumed from the axis=[2, 3] pooling — TODO confirm).
    n_latents: number of latent columns consumed from dlatents_in.
    start_idx: first column of dlatents_in used by this block.
    return_atts: if True, also return the spatial attention maps resized to
        (resolution, resolution).
    act, fused_modconv, fmaps are accepted for signature compatibility but
    unused in this builder.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_channel_start_end'):
            # Channel attention: predict soft start/end logits over the
            # channel axis; cumsum(softmax) gives monotone ramps whose
            # product selects a soft contiguous channel interval per latent.
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            att_dim = x_mean.shape[1]
            atts = dense_layer(x_mean, fmaps=n_latents * 2 * att_dim)
            atts = tf.reshape(atts, [-1, n_latents, 2, att_dim, 1, 1
                                     ])  # [b, n_latents, 2, att_dim, 1, 1]
            att_sm = tf.nn.softmax(atts, axis=3)
            att_cs = tf.cumsum(att_sm, axis=3)
            att_cs_starts, att_cs_ends = tf.split(att_cs, 2, axis=2)
            att_cs_ends = 1 - att_cs_ends
            att_channel = att_cs_starts * att_cs_ends  # [b, n_latents, 1, att_dim, 1, 1]
            att_channel = tf.reshape(att_channel,
                                     [-1, n_latents, att_dim, 1, 1])

        with tf.variable_scope('Att_spatial'):
            # Spatial attention: the same start/end trick applied separately
            # to the height and width axes; the outer product of the two
            # 1-D intervals is a soft axis-aligned box per latent.
            x_wh = x.shape[2]
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(
                atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts,
                                         [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends,
                                       [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts,
                                         [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends,
                                       [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            att_sp = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]
            # Combined mask: channel interval x spatial box (broadcast).
            atts = att_channel * att_sp  # [b, n_latents, att_dim, h, w]

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    # Blend styled features in only under this latent's mask.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                # Only the spatial component is returned, upsampled for viz.
                att_sp = tf.reshape(att_sp, [-1, x_wh, x_wh, 1])
                att_sp = tf.image.resize(att_sp, size=(resolution, resolution))
                att_sp = tf.reshape(att_sp,
                                    [-1, n_latents, 1, resolution, resolution])
            return x, att_sp
        else:
            return x
# Example #6
def build_C_spgroup_stn_layers(x,
                               name,
                               n_latents,
                               start_idx,
                               scope_idx,
                               dlatents_in,
                               act,
                               fused_modconv,
                               fmaps=128,
                               return_atts=False,
                               resolution=128,
                               **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention with spatial transform.
    Support square images only.

    Predicts a soft axis-aligned attention box per latent, then warps each
    box with a predicted 2x3 affine transform (spatial transformer) before
    using it to blend style-modulated features into x.
    act, fused_modconv, fmaps are accepted for signature compatibility but
    unused in this builder.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            # Soft start/end logits along each spatial axis; cumsum(softmax)
            # ramps multiply into a soft contiguous interval, and the outer
            # product of the h- and w-intervals is a soft box per latent.
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(
                atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts,
                                         [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends,
                                       [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts,
                                         [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends,
                                       [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            atts = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]

        with tf.variable_scope('trans_matrix'):
            # One 2x3 affine matrix (6 params) per latent, applied to its
            # attention map via the spatial transformer helper.
            # NOTE(review): theta is unconstrained — presumably `transformer`
            # expects a flattened affine; confirm its parameterization.
            theta = apply_bias_act(dense_layer(x_mean, fmaps=n_latents * 6))
            theta = tf.reshape(theta, [-1, 6])  # [b*n_latents, 6]
            atts = tf.reshape(
                atts, [-1, x_wh, x_wh, 1])  # [b*n_latents, x_wh, x_wh, 1]
            atts = transformer(atts, theta)  # [b*n_latents, x_wh, x_wh, 1]
            atts = tf.reshape(atts, [-1, n_latents, 1, x_wh, x_wh])

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    # Blend styled features in only under this latent's mask.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                # Upsample the warped attention maps to output resolution.
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts,
                                  [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x
# Example #7
def build_C_spgroup_lcond_layers(x,
                                 name,
                                 n_latents,
                                 start_idx,
                                 scope_idx,
                                 dlatents_in,
                                 act,
                                 fused_modconv,
                                 fmaps=128,
                                 return_atts=False,
                                 resolution=128,
                                 **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention.
    Support square images only.

    Unlike the unconditional variant, each latent's attention box is
    predicted from pooled features that were first style-modulated by that
    latent ("lcond" = latent-conditioned attention).
    act, fused_modconv, fmaps are accepted for signature compatibility but
    unused in this builder.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]

            atts_ls = []
            for i in range(n_latents):
                with tf.variable_scope('lcond-' + str(i)):
                    # Condition the pooled features on this latent before
                    # predicting its attention box.
                    x_mean_styled = style_mod(x_mean,
                                              C_global_latents[:, i:i + 1])

                    # Soft start/end logits for the h and w axes;
                    # cumsum(softmax) ramps multiply into soft intervals,
                    # and their outer product is a soft box.
                    att_wh = dense_layer(x_mean_styled, fmaps=4 * x_wh)
                    att_wh = tf.reshape(att_wh, [-1, 4, x_wh])  # [b, 4, x_wh]
                    att_wh_sm = tf.nn.softmax(att_wh, axis=-1)
                    att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
                    att_h_cs_start, att_h_cs_end, att_w_cs_start, att_w_cs_end = tf.split(
                        att_wh_cs, 4, axis=1)
                    att_h_cs_end = 1 - att_h_cs_end  # [b, 1, x_wh]
                    att_w_cs_end = 1 - att_w_cs_end  # [b, 1, x_wh]
                    att_h_cs_start = tf.reshape(att_h_cs_start,
                                                [-1, 1, 1, x_wh, 1])
                    att_h_cs_end = tf.reshape(att_h_cs_end,
                                              [-1, 1, 1, x_wh, 1])
                    att_h = att_h_cs_start * att_h_cs_end  # [b, 1, 1, x_wh, 1]
                    att_w_cs_start = tf.reshape(att_w_cs_start,
                                                [-1, 1, 1, 1, x_wh])
                    att_w_cs_end = tf.reshape(att_w_cs_end,
                                              [-1, 1, 1, 1, x_wh])
                    att_w = att_w_cs_start * att_w_cs_end  # [b, 1, 1, 1, x_wh]
                    att = att_h * att_w  # [b, 1, 1, x_wh, x_wh]
                    atts_ls.append(att)
            # Stack per-latent boxes: [b, n_latents, 1, x_wh, x_wh].
            atts = tf.concat(atts_ls, axis=1)

        with tf.variable_scope('Att_apply'):
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    # Blend styled features in only under this latent's mask.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                # Upsample the attention maps to output resolution for viz.
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts,
                                  [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x
# Example #8
def build_C_spgroup_layers_with_latents_ready(x,
                                              name,
                                              n_latents,
                                              scope_idx,
                                              latents_ready,
                                              return_atts=False,
                                              resolution=128,
                                              n_subs=1,
                                              **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention using latents_ready.
    Support square images only.

    Unlike the siblings that slice dlatents_in, this variant consumes
    pre-mapped per-latent codes directly (latents_ready[:, i] is the full
    style vector for latent i). Each latent's attention is the mean of
    n_subs independently predicted soft boxes.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            # Soft start/end logits for the h and w axes, n_subs sub-boxes
            # per latent; cumsum(softmax) ramps multiply into soft intervals
            # whose outer product is a soft axis-aligned box.
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            atts_wh = dense_layer(x_mean, fmaps=n_latents * n_subs * 4 * x_wh)
            atts_wh = tf.reshape(atts_wh,
                                 [-1, n_latents, n_subs, 4, x_wh
                                  ])  # [b, n_latents, n_subs, 4, x_wh]
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=3)
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, n_subs, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, n_subs, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts,
                                         [-1, n_latents, n_subs, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends,
                                       [-1, n_latents, n_subs, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, n_subs, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts,
                                         [-1, n_latents, n_subs, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends,
                                       [-1, n_latents, n_subs, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, n_subs, 1, 1, x_wh]
            atts = att_h * att_w  # [b, n_latents, n_subs, 1, x_wh, x_wh]
            # Average the n_subs sub-boxes into one mask per latent.
            atts = tf.reduce_mean(atts,
                                  axis=2)  # [b, n_latents, 1, x_wh, x_wh]

        with tf.variable_scope('Att_apply'):
            # latents_ready supplies full per-latent style vectors; the
            # [b, n_latents, 512] shape is the author's note — TODO confirm.
            C_global_latents = latents_ready  # [b, n_latents, 512]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i])
                    # Blend styled features in only under this latent's mask.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]

        if return_atts:
            with tf.variable_scope('Reshape_output'):
                # Upsample the attention maps to output resolution for viz.
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts,
                                  [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x