def build_C_fgroup_layers(x, name, n_latents, start_idx, scope_idx, dlatents_in,
                          act, fused_modconv, fmaps=128, return_atts=False,
                          resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group feature (channel) attention.

    Args:
        x: Input feature maps, NCHW layout [b, c, h, w].
        name: Scope-name prefix; combined with scope_idx for the variable scope.
        n_latents: Number of continuous latents handled by this group.
        start_idx: Offset of this group's first latent in dlatents_in.
        scope_idx: Index appended to `name` to form the variable scope.
        dlatents_in: Disentangled latents, [b, total_n_latents].
        act, fused_modconv, fmaps, resolution: Unused here; kept for a uniform
            builder interface shared with the sibling build_C_* functions.
        return_atts: If True, also return the channel attention masks.

    Returns:
        Modulated x; and atts of shape [b, n_latents, att_dim, 1, 1] when
        return_atts is True.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_start_end'):
            # Predict a soft [start, end) channel interval per latent:
            # cumsum-of-softmax yields monotone start/end masks whose
            # product is a soft band over the channel axis.
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            att_dim = x_mean.shape[1]
            atts = dense_layer(x_mean, fmaps=n_latents * 2 * att_dim)
            atts = tf.reshape(atts, [-1, n_latents, 2, att_dim, 1, 1])  # [b, n_latents, 2, att_dim, 1, 1]
            att_sm = tf.nn.softmax(atts, axis=3)
            att_cs = tf.cumsum(att_sm, axis=3)
            att_cs_starts, att_cs_ends = tf.split(att_cs, 2, axis=2)  # each [b, n_latents, 1, att_dim, 1, 1]
            att_cs_ends = 1 - att_cs_ends
            atts = att_cs_starts * att_cs_ends  # [b, n_latents, 1, att_dim, 1, 1]
            atts = tf.reshape(atts, [-1, n_latents, att_dim, 1, 1])

        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    # NOTE(review): sibling builders pass a 2-D slice
                    # (C_global_latents[:, i:i + 1]); confirm style_mod
                    # accepts the 1-D slice used here.
                    x_styled = style_mod(x_norm, C_global_latents[:, i])
                # BUGFIX: the two blend terms previously used mismatched
                # indexing (atts[:, i] vs atts[:, i:i + 1]); the 5-D
                # atts[:, i:i + 1] term broadcast against 4-D x_styled into a
                # malformed 5-D tensor. Use atts[:, i] for both terms, as all
                # sibling build_C_* functions do.
                x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            return x, atts
        else:
            return x
def build_C_spgroup_softmax_layers(x, name, n_latents, start_idx, scope_idx,
                                   dlatents_in, act, fused_modconv, fmaps=128,
                                   return_atts=False, resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention with
    pure softmax. Support square images only.

    One spatial softmax map is predicted per latent; each latent then styles
    the normalized features only inside its attended region.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            # A 3x3 conv predicts one attention logit map per latent, then a
            # softmax over all spatial positions turns it into a distribution.
            side = x.shape[2]
            att_maps = conv2d_layer(x, fmaps=n_latents, kernel=3)
            att_maps = tf.reshape(att_maps, [-1, n_latents, side * side])  # [b, n_latents, m]
            att_maps = tf.nn.softmax(att_maps, axis=-1)
            att_maps = tf.reshape(att_maps, [-1, n_latents, 1, side, side])

        with tf.variable_scope('Att_apply'):
            group_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for idx in range(n_latents):
                with tf.variable_scope('style_mod-' + str(idx)):
                    styled = style_mod(x_norm, group_latents[:, idx:idx + 1])
                # Convex blend: latent idx only affects its attended pixels.
                x = x * (1 - att_maps[:, idx]) + styled * att_maps[:, idx]

        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            # Upsample attention maps to the output resolution for viewing.
            att_maps = tf.reshape(att_maps, [-1, side, side, 1])
            att_maps = tf.image.resize(att_maps, size=(resolution, resolution))
            att_maps = tf.reshape(att_maps,
                                  [-1, n_latents, 1, resolution, resolution])
        return x, att_maps
def construct_feat_by_concat_masks_latent(feat_on_masks, masks, dlatents_in):
    '''
    Style each per-mask feature map with its latent and concatenate the results
    along the channel axis.

    feat_on_masks: [b, n_masks, dim, h, w]
    masks: [b, n_masks, h, w] (unused in this concat variant; kept for a
        uniform interface with construct_feat_by_masks_latent)
    dlatents_in: [b, n_masks]

    Returns: [b, n_masks * dim, h, w]
    '''
    n_masks, dim, h, w = feat_on_masks.get_shape().as_list()[1:]
    # Instance-normalize each per-mask feature map independently.
    flat = tf.reshape(feat_on_masks, [-1, dim, h, w])
    flat = instance_norm(flat)
    normed = tf.reshape(flat, [-1, n_masks, dim, h, w])
    styled_feats = []
    for idx in range(n_masks):
        with tf.variable_scope('style_mod-' + str(idx)):
            styled_feats.append(
                style_mod(normed[:, idx], dlatents_in[:, idx:idx + 1]))  # [b, dim, h, w]
    return tf.concat(styled_feats, axis=1)
def construct_feat_by_masks_latent(feat_on_masks, masks, dlatents_in):
    '''
    Paint styled per-mask features onto a learned canvas, masked region by
    masked region.

    feat_on_masks: [b, n_masks, dim, h, w]
    masks: [b, n_masks, h, w]
    dlatents_in: [b, n_masks]

    Returns: [b, dim, h, w]
    '''
    n_masks, dim, h, w = feat_on_masks.get_shape().as_list()[1:]
    with tf.variable_scope('CanvasConst'):
        # Learned initial canvas, broadcast to the batch size.
        canvas = tf.get_variable('canvas_const', shape=[1, dim, h, w],
                                 initializer=tf.initializers.random_normal())
        batch = tf.shape(feat_on_masks)[0]
        canvas = tf.tile(tf.cast(canvas, feat_on_masks.dtype),
                         [batch, 1, 1, 1])
    masks = masks[:, :, np.newaxis, ...]  # [b, n_masks, 1, h, w]
    # Instance-normalize each per-mask feature map independently.
    normed = instance_norm(tf.reshape(feat_on_masks, [-1, dim, h, w]))
    normed = tf.reshape(normed, [-1, n_masks, dim, h, w])
    for idx in range(n_masks):
        with tf.variable_scope('style_mod-' + str(idx)):
            styled = style_mod(normed[:, idx], dlatents_in[:, idx:idx + 1])
        # Overwrite the canvas where mask idx is active.
        canvas = canvas * (1 - masks[:, idx]) + styled * masks[:, idx]
    return canvas
def build_C_spfgroup_layers(x, name, n_latents, start_idx, scope_idx, dlatents_in, act, fused_modconv, fmaps=128, return_atts=False, resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group feature-spatial attention.
    Support square images only.

    x: input features, NCHW [b, in_dim, h, w]; h == w is assumed.
    dlatents_in: [b, total_n_latents]; this group uses
        dlatents_in[:, start_idx:start_idx + n_latents].
    act, fused_modconv, fmaps: unused here; kept for a uniform builder interface.
    return_atts: if True, also return the spatial attention maps upsampled to
        [b, n_latents, 1, resolution, resolution].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_channel_start_end'):
            # Soft channel interval per latent: cumsum-of-softmax gives
            # monotone "start" and "end" masks whose product is a soft band.
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            att_dim = x_mean.shape[1]
            atts = dense_layer(x_mean, fmaps=n_latents * 2 * att_dim)
            atts = tf.reshape(atts, [-1, n_latents, 2, att_dim, 1, 1])  # [b, n_latents, 2, att_dim, 1, 1]
            att_sm = tf.nn.softmax(atts, axis=3)
            att_cs = tf.cumsum(att_sm, axis=3)
            att_cs_starts, att_cs_ends = tf.split(att_cs, 2, axis=2)
            att_cs_ends = 1 - att_cs_ends
            att_channel = att_cs_starts * att_cs_ends  # [b, n_latents, 1, att_dim, 1, 1]
            att_channel = tf.reshape(att_channel, [-1, n_latents, att_dim, 1, 1])
        with tf.variable_scope('Att_spatial'):
            # Same cumsum trick along h and w: each latent gets a soft
            # axis-aligned rectangle att_h * att_w.
            x_wh = x.shape[2]
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts, [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends, [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts, [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends, [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            att_sp = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]
            # Combined channel-and-spatial attention.
            atts = att_channel * att_sp  # [b, n_latents, att_dim, h, w]
        with tf.variable_scope('Att_apply'):
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    # Blend: latent i modulates only its attended region.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                # Only the spatial part is returned for visualization.
                att_sp = tf.reshape(att_sp, [-1, x_wh, x_wh, 1])
                att_sp = tf.image.resize(att_sp, size=(resolution, resolution))
                att_sp = tf.reshape(att_sp, [-1, n_latents, 1, resolution, resolution])
            return x, att_sp
        else:
            return x
def build_C_spgroup_stn_layers(x, name, n_latents, start_idx, scope_idx,
                               dlatents_in, act, fused_modconv, fmaps=128,
                               return_atts=False, resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention with
    spatial transform. Support square images only.

    A soft axis-aligned rectangle is predicted per latent, then warped by a
    learned per-latent affine via a spatial transformer before being used to
    blend each latent's styled features into x.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            # Four monotone boundary profiles per latent (h-start, h-end,
            # w-start, w-end) via cumsum of softmax.
            raw = dense_layer(pooled, fmaps=n_latents * 4 * side)
            raw = tf.reshape(raw, [-1, n_latents, 4, side])  # [b, n_latents, 4, side]
            cums = tf.cumsum(tf.nn.softmax(raw, axis=-1), axis=-1)
            h_start, h_end, w_start, w_end = tf.split(cums, 4, axis=2)
            h_end = 1 - h_end  # [b, n_latents, 1, side]
            w_end = 1 - w_end  # [b, n_latents, 1, side]
            h_start = tf.reshape(h_start, [-1, n_latents, 1, side, 1])
            h_end = tf.reshape(h_end, [-1, n_latents, 1, side, 1])
            w_start = tf.reshape(w_start, [-1, n_latents, 1, 1, side])
            w_end = tf.reshape(w_end, [-1, n_latents, 1, 1, side])
            # Soft rectangle = (h band) x (w band).
            att_maps = (h_start * h_end) * (w_start * w_end)  # [b, n_latents, 1, side, side]
        with tf.variable_scope('trans_matrix'):
            # A 2x3 affine per latent, applied to its attention map by an STN.
            theta = apply_bias_act(dense_layer(pooled, fmaps=n_latents * 6))
            theta = tf.reshape(theta, [-1, 6])  # [b*n_latents, 6]
            att_maps = tf.reshape(att_maps, [-1, side, side, 1])  # [b*n_latents, side, side, 1]
            att_maps = transformer(att_maps, theta)  # [b*n_latents, side, side, 1]
            att_maps = tf.reshape(att_maps, [-1, n_latents, 1, side, side])
        with tf.variable_scope('Att_apply'):
            group_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for idx in range(n_latents):
                with tf.variable_scope('style_mod-' + str(idx)):
                    styled = style_mod(x_norm, group_latents[:, idx:idx + 1])
                # Convex blend: latent idx only affects its attended region.
                x = x * (1 - att_maps[:, idx]) + styled * att_maps[:, idx]
        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            # Upsample attention maps to the output resolution for viewing.
            att_maps = tf.reshape(att_maps, [-1, side, side, 1])
            att_maps = tf.image.resize(att_maps, size=(resolution, resolution))
            att_maps = tf.reshape(att_maps,
                                  [-1, n_latents, 1, resolution, resolution])
        return x, att_maps
def build_C_spgroup_lcond_layers(x, name, n_latents, start_idx, scope_idx,
                                 dlatents_in, act, fused_modconv, fmaps=128,
                                 return_atts=False, resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention.
    Support square images only.

    Unlike the plain spgroup variant, each latent's attention rectangle is
    conditioned on the latent value itself (via style_mod on the pooled
    features) before the boundary profiles are predicted.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            pooled = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            side = x.shape[2]
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            per_latent_atts = []
            for idx in range(n_latents):
                with tf.variable_scope('lcond-' + str(idx)):
                    # Condition the attention predictor on latent idx.
                    cond = style_mod(pooled, C_global_latents[:, idx:idx + 1])
                    bounds = dense_layer(cond, fmaps=4 * side)
                    bounds = tf.reshape(bounds, [-1, 4, side])  # [b, 4, side]
                    # Monotone start/end profiles via cumsum of softmax.
                    cums = tf.cumsum(tf.nn.softmax(bounds, axis=-1), axis=-1)
                    h_start, h_end, w_start, w_end = tf.split(cums, 4, axis=1)
                    h_end = 1 - h_end  # [b, 1, side]
                    w_end = 1 - w_end  # [b, 1, side]
                    h_start = tf.reshape(h_start, [-1, 1, 1, side, 1])
                    h_end = tf.reshape(h_end, [-1, 1, 1, side, 1])
                    w_start = tf.reshape(w_start, [-1, 1, 1, 1, side])
                    w_end = tf.reshape(w_end, [-1, 1, 1, 1, side])
                    # Soft rectangle = (h band) x (w band).
                    rect = (h_start * h_end) * (w_start * w_end)  # [b, 1, 1, side, side]
                    per_latent_atts.append(rect)
            atts = tf.concat(per_latent_atts, axis=1)
        with tf.variable_scope('Att_apply'):
            x_norm = instance_norm(x)
            for idx in range(n_latents):
                with tf.variable_scope('style_mod-' + str(idx)):
                    styled = style_mod(x_norm, C_global_latents[:, idx:idx + 1])
                # Convex blend: latent idx only affects its attended region.
                x = x * (1 - atts[:, idx]) + styled * atts[:, idx]
        if not return_atts:
            return x
        with tf.variable_scope('Reshape_output'):
            # Upsample attention maps to the output resolution for viewing.
            atts = tf.reshape(atts, [-1, side, side, 1])
            atts = tf.image.resize(atts, size=(resolution, resolution))
            atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
        return x, atts
def build_C_spgroup_layers_with_latents_ready(x, name, n_latents, scope_idx, latents_ready, return_atts=False, resolution=128, n_subs=1, **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention using
    latents_ready. Support square images only.

    x: input features, NCHW [b, c, h, w]; h == w is assumed.
    latents_ready: pre-mapped per-latent styles, [b, n_latents, 512].
    n_subs: number of soft sub-rectangles averaged into each latent's
        attention map.
    return_atts: if True, also return the attention maps upsampled to
        [b, n_latents, 1, resolution, resolution].
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            # Predict n_subs rectangles per latent: cumsum-of-softmax yields
            # monotone start/end boundary profiles along each axis.
            atts_wh = dense_layer(x_mean, fmaps=n_latents * n_subs * 4 * x_wh)
            atts_wh = tf.reshape(atts_wh, [-1, n_latents, n_subs, 4, x_wh])  # [b, n_latents, n_subs, 4, x_wh]
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=3)
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, n_subs, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, n_subs, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts, [-1, n_latents, n_subs, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends, [-1, n_latents, n_subs, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, n_subs, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts, [-1, n_latents, n_subs, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends, [-1, n_latents, n_subs, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, n_subs, 1, 1, x_wh]
            atts = att_h * att_w  # [b, n_latents, n_subs, 1, x_wh, x_wh]
            # Average (not sum) over the n_subs rectangles.
            atts = tf.reduce_mean(atts, axis=2)  # [b, n_latents, 1, x_wh, x_wh]
        with tf.variable_scope('Att_apply'):
            C_global_latents = latents_ready  # [b, n_latents, 512]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i])
                    # Blend: latent i modulates only its attended region.
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]
        if return_atts:
            with tf.variable_scope('Reshape_output'):
                # Upsample attention maps to the output resolution for viewing.
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x