def apply_st(x, st_matrix, idx, up=True):
    # idx: 2, 3, 4.
    # Note: 'nf', 'resample_kernel', and 'act' are free variables here; this
    # variant assumes it is defined inside a network-builder scope where they
    # are available (as in StyleGAN2-style synthesis builders).
    with tf.variable_scope('Transform'):
        x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
        x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
    with tf.variable_scope('Upconv'):
        x = apply_bias_act(conv2d_layer(x, fmaps=nf(idx), kernel=3, up=up,
                                        resample_kernel=resample_kernel), act=act)
    with tf.variable_scope('Conv'):
        x = apply_bias_act(conv2d_layer(x, fmaps=nf(idx), kernel=3), act=act)
    return x
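
# A minimal sketch of the enclosing scope the variant above assumes. The
# builder name 'G_synthesis_sketch' and the hyperparameter values are
# illustrative assumptions, not the original code; 'nf' follows the common
# StyleGAN2 feature-map schedule.
import numpy as np

def G_synthesis_sketch(fmap_base=8192, fmap_decay=1.0, fmap_max=512,
                       act='lrelu', resample_kernel=[1, 3, 3, 1]):
    def nf(stage):
        # Feature maps per resolution stage, clipped to [1, fmap_max].
        return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), 1, fmap_max)

    def apply_st(x, st_matrix, idx, up=True):
        ...  # as defined above; 'nf', 'resample_kernel', 'act' close over this scope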
def apply_st(x, st_matrix, up=True, fmaps=128, resample_kernel=[1, 3, 3, 1],
             act='lrelu'):
    with tf.variable_scope('Transform'):
        x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
        x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
    with tf.variable_scope('ConvMayUp'):
        x = apply_bias_act(conv2d_layer(x, fmaps=fmaps, kernel=3, up=up,
                                        resample_kernel=resample_kernel), act=act)
    with tf.variable_scope('Conv'):
        x = apply_bias_act(conv2d_layer(x, fmaps=fmaps, kernel=3), act=act)
    return x
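
# Usage sketch for the standalone variant above. The shapes and the way
# 'theta' is produced are illustrative assumptions: a small dense head
# predicts the 6 affine parameters per sample, following the STN row-major
# layout [sx, shx, tx, shy, sy, ty].
x_in = tf.random.normal([4, 128, 32, 32])              # NCHW feature maps
latents = tf.random.normal([4, 64])
with tf.variable_scope('theta_head'):
    theta = apply_bias_act(dense_layer(latents, fmaps=6))  # [4, 6]
x_out = apply_st(x_in, theta, up=True)                 # -> [4, 128, 64, 64]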
def build_C_spgroup_stn_layers(x, name, n_latents, start_idx, scope_idx,
                               dlatents_in, act, fused_modconv, fmaps=128,
                               return_atts=False, resolution=128, **kwargs):
    '''
    Build continuous latent layers with learned group spatial attention
    and a spatial transform. Supports square images only.
    '''
    with tf.variable_scope(name + '-' + str(scope_idx)):
        with tf.variable_scope('Att_spatial'):
            x_mean = tf.reduce_mean(x, axis=[2, 3])  # [b, in_dim]
            x_wh = x.shape[2]
            # Predict per-latent 1D boundary distributions for height and width.
            atts_wh = dense_layer(x_mean, fmaps=n_latents * 4 * x_wh)
            atts_wh = tf.reshape(atts_wh, [-1, n_latents, 4, x_wh])  # [b, n_latents, 4, x_wh]
            att_wh_sm = tf.nn.softmax(atts_wh, axis=-1)
            att_wh_cs = tf.cumsum(att_wh_sm, axis=-1)
            att_h_cs_starts, att_h_cs_ends, att_w_cs_starts, att_w_cs_ends = tf.split(
                att_wh_cs, 4, axis=2)
            att_h_cs_ends = 1 - att_h_cs_ends  # [b, n_latents, 1, x_wh]
            att_w_cs_ends = 1 - att_w_cs_ends  # [b, n_latents, 1, x_wh]
            att_h_cs_starts = tf.reshape(att_h_cs_starts, [-1, n_latents, 1, x_wh, 1])
            att_h_cs_ends = tf.reshape(att_h_cs_ends, [-1, n_latents, 1, x_wh, 1])
            att_h = att_h_cs_starts * att_h_cs_ends  # [b, n_latents, 1, x_wh, 1]
            att_w_cs_starts = tf.reshape(att_w_cs_starts, [-1, n_latents, 1, 1, x_wh])
            att_w_cs_ends = tf.reshape(att_w_cs_ends, [-1, n_latents, 1, 1, x_wh])
            att_w = att_w_cs_starts * att_w_cs_ends  # [b, n_latents, 1, 1, x_wh]
            atts = att_h * att_w  # [b, n_latents, 1, x_wh, x_wh]

        with tf.variable_scope('trans_matrix'):
            # Predict a 6-parameter affine transform per latent and warp the
            # attention maps with it.
            theta = apply_bias_act(dense_layer(x_mean, fmaps=n_latents * 6))
            theta = tf.reshape(theta, [-1, 6])  # [b*n_latents, 6]
            atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])  # [b*n_latents, x_wh, x_wh, 1]
            atts = transformer(atts, theta)  # [b*n_latents, x_wh, x_wh, 1]
            atts = tf.reshape(atts, [-1, n_latents, 1, x_wh, x_wh])

        with tf.variable_scope('Att_apply'):
            # Blend per-latent styled features into x under each attention mask.
            C_global_latents = dlatents_in[:, start_idx:start_idx + n_latents]
            x_norm = instance_norm(x)
            for i in range(n_latents):
                with tf.variable_scope('style_mod-' + str(i)):
                    x_styled = style_mod(x_norm, C_global_latents[:, i:i + 1])
                    x = x * (1 - atts[:, i]) + x_styled * atts[:, i]

        if return_atts:
            with tf.variable_scope('Reshape_output'):
                atts = tf.reshape(atts, [-1, x_wh, x_wh, 1])
                atts = tf.image.resize(atts, size=(resolution, resolution))
                atts = tf.reshape(atts, [-1, n_latents, 1, resolution, resolution])
            return x, atts
        else:
            return x
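
# Usage sketch with hypothetical shapes and latent counts: inject 4 continuous
# factors into a 32x32 feature map and retrieve the attention maps upsampled
# to the output resolution. The tensors below are stand-ins, not real inputs.
x = tf.random.normal([4, 128, 32, 32])        # [b, fmaps, H, W]
dlatents_in = tf.random.normal([4, 10])       # [b, n_total_continuous_latents]
x, atts = build_C_spgroup_stn_layers(
    x, 'SP_latents', n_latents=4, start_idx=0, scope_idx=3,
    dlatents_in=dlatents_in, act='lrelu', fused_modconv=False,
    return_atts=True, resolution=128)
# x: [4, 128, 32, 32] modulated features; atts: [4, 4, 1, 128, 128].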
def apply_st(x, st_matrix):
    with tf.variable_scope('Transform'):
        x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        x = transformer(x, st_matrix, out_dims=x.shape.as_list()[1:3])
        x = tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
    return x
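
# Usage sketch (assumed shapes) for the transform-only variant: with the
# identity affine matrix [1, 0, 0, 0, 1, 0], the warp should return the input
# feature map up to the interpolation of the STN bilinear sampler.
x = tf.random.normal([4, 128, 32, 32])                        # NCHW
theta_id = tf.tile(tf.constant([[1., 0., 0., 0., 1., 0.]]), [4, 1])
x_same = apply_st(x, theta_id)                                # ~ x, same shape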