def update_moving_average_of_w(w_broadcasted, w_avg, w_ema_decay):
    """Refresh the running average ``w_avg`` and return a gated copy of the input.

    The graph built here, under the ``'wAvg'`` variable scope:
      1. takes ``w_broadcasted[:, 0]`` and averages it over axis 0,
      2. blends that mean with the current ``w_avg`` through ``lerp`` using
         ``w_ema_decay`` as the interpolation weight,
      3. assigns the blended value back into ``w_avg``,
      4. returns ``tf.identity(w_broadcasted)`` created under a control
         dependency on that assignment.

    Args:
        w_broadcasted: tensor indexed as ``[:, 0]`` here
            (assumes layout (batch, n_broadcast, w_dim) -- TODO confirm).
        w_avg: target of ``tf.assign``; must be something ``tf.assign`` can
            write to (presumably a non-trainable ``tf.Variable`` -- confirm
            against the caller).
        w_ema_decay: interpolation weight forwarded to ``lerp``.

    Returns:
        An identity of ``w_broadcasted`` whose evaluation also triggers the
        ``w_avg`` update.
    """
    # NOTE(review): lerp comes from elsewhere in the project; presumably
    # lerp(a, b, t) = a + (b - a) * t -- confirm before relying on the
    # direction of the decay.
    with tf.variable_scope('wAvg'):
        first_layer_mean = tf.reduce_mean(w_broadcasted[:, 0], axis=0)
        blended_avg = lerp(first_layer_mean, w_avg, w_ema_decay)
        assign_avg = tf.assign(w_avg, blended_avg)

        # Tie the output to the assignment so that the update is executed
        # whenever the returned tensor is evaluated.
        with tf.control_dependencies([assign_avg]):
            gated_w = tf.identity(w_broadcasted)

    return gated_w
def truncation_trick(n_broadcast, w_broadcasted, w_avg, truncation_psi, truncation_cutoff):
    """Interpolate between ``w_avg`` and ``w_broadcasted`` with per-layer weights.

    A coefficient array of shape ``(1, n_broadcast, 1)`` is built: entries
    whose layer index is below ``truncation_cutoff`` get ``truncation_psi``,
    all remaining entries get ``1.0``. The result is
    ``lerp(w_avg, w_broadcasted, coefs)``, constructed under the
    ``'Truncation'`` variable scope.

    Args:
        n_broadcast: number of layer slots; passed to ``np.arange``.
        w_broadcasted: tensor to be truncated
            (assumes layout (batch, n_broadcast, w_dim) -- TODO confirm).
        w_avg: value interpolated from (presumably the running average of w
            -- verify against the caller).
        truncation_psi: coefficient applied to layers below the cutoff.
        truncation_cutoff: layers with index strictly less than this value
            use ``truncation_psi``.

    Returns:
        The output of ``lerp`` for the per-layer coefficients.
    """
    # NOTE(review): presumably lerp(a, b, t) = a + (b - a) * t, in which case
    # a coefficient of 1.0 leaves the layer untouched -- confirm.
    with tf.variable_scope('Truncation'):
        # Shape (1, n_broadcast, 1) so it lines up with axis 1 of the input.
        layer_idx = np.arange(n_broadcast)[np.newaxis, :, np.newaxis]
        unit_coefs = np.ones_like(layer_idx, dtype=np.float32)

        below_cutoff = layer_idx < truncation_cutoff
        truncated_coefs = truncation_psi * unit_coefs
        per_layer_coefs = tf.where(below_cutoff, truncated_coefs, unit_coefs)

        truncated_w = lerp(w_avg, w_broadcasted, per_layer_coefs)

    return truncated_w
def smooth_crossfade(images, alpha):
    """Blend ``images`` with a 2x2 block-averaged copy of themselves.

    Axes 2 and 3 of ``images`` are each split into pairs, every 2x2 block is
    replaced by its mean, and the blocks are tiled back to the original size.
    The result is then combined with the input via
    ``lerp(images, averaged, alpha)``.

    Args:
        images: rank-4 tensor whose axis-2 and axis-3 sizes are divisible by 2
            (assumes NCHW layout -- TODO confirm).
        alpha: interpolation weight forwarded to ``lerp``.

    Returns:
        The output of ``lerp`` for the original and block-averaged images.
    """
    # NOTE(review): presumably lerp(a, b, t) = a + (b - a) * t, i.e. alpha = 0
    # keeps the input and alpha = 1 yields the averaged copy -- confirm.
    dyn_shape = tf.shape(images)

    # Expose each 2x2 neighbourhood as its own pair of size-2 axes (3 and 5).
    paired_shape = [-1, dyn_shape[1], dyn_shape[2] // 2, 2, dyn_shape[3] // 2, 2]
    paired = tf.reshape(images, paired_shape)

    # Average every neighbourhood, then repeat the mean to refill the block.
    block_mean = tf.reduce_mean(paired, axis=[3, 5], keepdims=True)
    refilled = tf.tile(block_mean, [1, 1, 1, 2, 1, 2])

    # Collapse the paired axes back into the original 4-D layout.
    restored_shape = [-1, dyn_shape[1], dyn_shape[2], dyn_shape[3]]
    averaged = tf.reshape(refilled, restored_shape)

    return lerp(images, averaged, alpha)
def test8():
    """Manual smoke check of the per-layer truncation coefficients.

    Builds a constant ``w_broadcasted`` of shape ``[1, 18, 5]`` filled with
    1.0 and a constant ``w_avg`` of shape ``[5]`` filled with 0.5, applies the
    same coefficient construction as the truncation step (psi = 0.7 for layer
    indices below 8, 1.0 otherwise), then prints the index array, the ones
    array, the evaluated coefficients and the evaluated interpolation result.
    Nothing is asserted; the output is meant to be inspected by eye.
    """
    from network.common_ops import lerp

    # Fixed configuration for the check.
    w_dim, n_broadcast = 5, 18
    truncation_psi, truncation_cutoff = 0.7, 8

    w_shape = [1, n_broadcast, w_dim]
    w_broadcasted = tf.constant(1.0, dtype=tf.float32, shape=w_shape)
    w_avg = tf.constant(0.5, dtype=tf.float32, shape=[w_dim])

    # Per-layer coefficients, shape (1, n_broadcast, 1).
    layer_idx = np.arange(n_broadcast)[np.newaxis, :, np.newaxis]
    ones = np.ones(layer_idx.shape, dtype=np.float32)
    use_psi = layer_idx < truncation_cutoff
    coefs = tf.where(use_psi, truncation_psi * ones, ones)

    w_broadcasted = lerp(w_avg, w_broadcasted, coefs)

    print('layer_idx: {}'.format(layer_idx))
    print('ones: {}'.format(ones))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coefs_out, out = sess.run([coefs, w_broadcasted])
        print(coefs_out)
        print(out)