import math

import tensorflow as tf

# `layers` and `utils` are project-local helper modules (assumed import
# paths) providing layers.leaky_routing, utils.activation_summary and
# utils.variable_summaries.
import layers
import utils


def update_em_routing(wx, input_activation, activation_biases, sigma_biases,
                      logit_shape, num_out_atoms, num_routing, output_dim,
                      leaky, final_beta, min_var):
  """Fully connected routing with EM for Mixture of Gaussians."""
  # wx: [batch, indim, outdim, outatom, height, width]
  # logit_shape: [indim, outdim, 1, height, width]
  # input_activation: [batch, indim, 1, 1, 1, 1]
  # activation_biases: [1, 1, outdim, 1, height, width]
  update = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], logit_shape[0], logit_shape[1],
          logit_shape[2], logit_shape[3], logit_shape[4]
      ]), 0.0)
  out_activation = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
          logit_shape[4]
      ]), 0.0)
  out_center = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, num_out_atoms,
          logit_shape[3], logit_shape[4]
      ]), 0.0)

  def _body(i, update, activation, center):
    """Body of the EM while loop."""
    del activation
    # Anneal the inverse temperature towards final_beta over iterations.
    beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
    # E-step: routing posterior over output capsules.
    if leaky:
      posterior = layers.leaky_routing(update, output_dim)
    else:
      posterior = tf.nn.softmax(update, dim=2)
    vote_conf = posterior * input_activation
    # masses: [batch, 1, outdim, 1, height, width]
    masses = tf.reduce_sum(vote_conf, axis=1, keep_dims=True) + 0.00001
    preactivate_unrolled = vote_conf * wx
    # M-step: each Gaussian's center is the confidence-weighted mean of
    # the votes, smoothed with the previous center.
    # center: [batch, 1, outdim, outatom, height, width]
    center = .9 * tf.reduce_sum(
        preactivate_unrolled, axis=1, keep_dims=True) / masses + .1 * center
    noise = (wx - center) * (wx - center)
    variance = min_var + tf.reduce_sum(
        vote_conf * noise, axis=1, keep_dims=True) / masses
    log_variance = tf.log(variance)
    p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
    log_2pi = tf.log(2 * math.pi)
    win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
    logit = beta * (win - activation_biases * 5000)
    # Numerically stable log(sigmoid(logit)).
    activation_update = tf.minimum(0.0, logit) - tf.log(
        1 + tf.exp(-tf.abs(logit)))
    log_det_sigma = tf.reduce_sum(log_variance, axis=3, keep_dims=True)
    sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
    exp_update = tf.reduce_sum(noise / (2 * variance), axis=3, keep_dims=True)
    # Unnormalized log posterior that seeds the next E-step.
    prior_update = activation_update - sigma_update - exp_update
    return (prior_update, logit, center)

  for i in range(num_routing):
    update, out_activation, out_center = _body(i, update, out_activation,
                                               out_center)

  with tf.name_scope('out_activation'):
    utils.activation_summary(tf.sigmoid(out_activation))
  with tf.name_scope('noise'):
    utils.variable_summaries((wx - out_center) * (wx - out_center))
  with tf.name_scope('Wx'):
    utils.variable_summaries(wx)

  return out_activation, out_center
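# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test for update_em_routing with made-up shapes: a
# batch of 2, 8 input capsules, and 4 output capsules with 16-atom
# poses on a 1x1 grid. The bias shapes and hyperparameters below are
# assumptions chosen only so the graph builds; they are not tuned
# values from any experiment.
def _demo_update_em_routing():
  import numpy as np
  batch, indim, outdim, outatom, h, w = 2, 8, 4, 16, 1, 1
  wx = tf.constant(
      np.random.randn(batch, indim, outdim, outatom, h, w), tf.float32)
  input_activation = tf.constant(
      np.random.rand(batch, indim, 1, 1, 1, 1), tf.float32)
  activation_biases = tf.zeros([1, 1, outdim, 1, h, w])
  sigma_biases = tf.constant(1.0)
  out_activation, out_center = update_em_routing(
      wx=wx,
      input_activation=input_activation,
      activation_biases=activation_biases,
      sigma_biases=sigma_biases,
      logit_shape=[indim, outdim, 1, h, w],
      num_out_atoms=outatom,
      num_routing=3,
      output_dim=outdim,
      leaky=False,
      final_beta=1.0,
      min_var=0.0005)
  with tf.Session() as sess:
    act, center = sess.run([out_activation, out_center])
    print(act.shape)     # (2, 1, 4, 1, 1, 1)
    print(center.shape)  # (2, 1, 4, 16, 1, 1)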
def update_conv_routing(wx, input_activation, activation_biases, sigma_biases,
                        logit_shape, num_out_atoms, input_dim, num_routing,
                        output_dim, final_beta, min_var):
  """Convolutional routing with EM for Mixture of Gaussians."""
  # wx: [batch, indim, outdim, outatom, height, width, k, k]
  # logit_shape: [indim, outdim, 1, height, width, k, k]
  # input_activation: [batch, indim, 1, 1, height, width, k, k]
  # activation_biases: [1, 1, outdim, 1, height, width]
  # Start from a uniform routing posterior (softmax over zeros).
  post = tf.nn.softmax(
      tf.fill(
          tf.stack([
              tf.shape(input_activation)[0], logit_shape[0], logit_shape[1],
              logit_shape[2], logit_shape[3], logit_shape[4], logit_shape[5],
              logit_shape[6]
          ]), 0.0),
      dim=2)
  out_activation = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
          logit_shape[4], 1, 1
      ]), 0.0)
  out_center = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, num_out_atoms,
          logit_shape[3], logit_shape[4], 1, 1
      ]), 0.0)
  out_mass = tf.fill(
      tf.stack([
          tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
          logit_shape[4], 1, 1
      ]), 0.0)
  n = logit_shape[3]
  k = logit_shape[5]

  def _body(i, posterior, activation, center, masses):
    """Body of the EM while loop."""
    del activation
    # Anneal the inverse temperature towards final_beta over iterations.
    beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
    vote_conf = posterior * input_activation
    # masses: [batch, 1, outdim, 1, height, width, 1, 1]
    masses = tf.reduce_sum(
        tf.reduce_sum(
            tf.reduce_sum(vote_conf, axis=1, keep_dims=True),
            axis=-1,
            keep_dims=True),
        axis=-2,
        keep_dims=True) + 0.0000001
    preactivate_unrolled = vote_conf * wx
    # M-step: confidence-weighted mean of the votes, smoothed with the
    # previous center.
    # center: [batch, 1, outdim, outatom, height, width, 1, 1]
    center = .9 * tf.reduce_sum(
        tf.reduce_sum(
            tf.reduce_sum(preactivate_unrolled, axis=1, keep_dims=True),
            axis=-1,
            keep_dims=True),
        axis=-2,
        keep_dims=True) / masses + .1 * center
    noise = (wx - center) * (wx - center)
    variance = min_var + tf.reduce_sum(
        tf.reduce_sum(
            tf.reduce_sum(vote_conf * noise, axis=1, keep_dims=True),
            axis=-1,
            keep_dims=True),
        axis=-2,
        keep_dims=True) / masses
    log_variance = tf.log(variance)
    p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
    log_2pi = tf.log(2 * math.pi)
    win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
    logit = beta * (win - activation_biases * 5000)
    # Numerically stable log(sigmoid(logit)).
    activation_update = tf.minimum(0.0, logit) - tf.log(
        1 + tf.exp(-tf.abs(logit)))
    log_det_sigma = -1 * p_i
    sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
    exp_update = tf.reduce_sum(noise / (2 * variance), axis=3, keep_dims=True)
    prior_update = activation_update - sigma_update - exp_update
    # Normalize the posterior across overlapping receptive fields:
    # subtract the per-capsule max for a stable softmax, then sum the
    # exponentiated logits over each spatial window with two
    # extract_image_patches passes and divide.
    max_prior_update = tf.reduce_max(
        tf.reduce_max(
            tf.reduce_max(
                tf.reduce_max(prior_update, axis=-1, keep_dims=True),
                axis=-2,
                keep_dims=True),
            axis=-3,
            keep_dims=True),
        axis=-4,
        keep_dims=True)
    prior_normal = tf.add(prior_update, -1 * max_prior_update)
    prior_exp = tf.exp(prior_normal)
    t_prior = tf.transpose(prior_exp, [0, 1, 2, 3, 4, 6, 5, 7])
    c_prior = tf.reshape(t_prior, [-1, n * k, n * k, 1])
    pad_prior = tf.pad(
        c_prior,
        [[0, 0], [(k - 1) * (k - 1), (k - 1) * (k - 1)],
         [(k - 1) * (k - 1), (k - 1) * (k - 1)], [0, 0]], 'CONSTANT')
    patch_prior = tf.extract_image_patches(
        images=pad_prior,
        ksizes=[1, k, k, 1],
        strides=[1, k, k, 1],
        rates=[1, k - 1, k - 1, 1],
        padding='VALID')
    sum_prior = tf.reduce_sum(patch_prior, axis=-1, keep_dims=True)
    sum_prior_patch = tf.extract_image_patches(
        images=sum_prior,
        ksizes=[1, k, k, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    sum_prior_reshape = tf.reshape(
        sum_prior_patch,
        [-1, input_dim, output_dim, 1, n, n, k, k]) + 0.0000001
    posterior = prior_exp / sum_prior_reshape
    return (posterior, logit, center, masses)

  for i in range(num_routing):
    post, out_activation, out_center, out_mass = _body(i, post, out_activation,
                                                       out_center, out_mass)

  with tf.name_scope('out_activation'):
    utils.activation_summary(tf.sigmoid(out_activation))
  with tf.name_scope('masses'):
    utils.activation_summary(tf.sigmoid(out_mass))
  with tf.name_scope('posterior'):
    utils.activation_summary(post)
  with tf.name_scope('noise'):
    utils.variable_summaries((wx - out_center) * (wx - out_center))
  with tf.name_scope('Wx'):
    utils.variable_summaries(wx)

  return out_activation, out_center
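# --- Illustrative usage sketch (not part of the original module) ---
# Builds update_conv_routing on a 4x4 grid with 3x3 kernels. The patch
# arithmetic in the routing body passes a rate of k - 1 to
# tf.extract_image_patches, so this sketch assumes k >= 2. The 8-dim
# activation_biases shape used here is an assumption made so that
# broadcasting against the 8-dim logits works; the original docstring
# lists a 6-dim shape. All sizes are illustrative.
def _demo_update_conv_routing():
  import numpy as np
  batch, indim, outdim, outatom, n, k = 2, 8, 4, 16, 4, 3
  wx = tf.constant(
      np.random.randn(batch, indim, outdim, outatom, n, n, k, k), tf.float32)
  input_activation = tf.constant(
      np.random.rand(batch, indim, 1, 1, n, n, k, k), tf.float32)
  activation_biases = tf.zeros([1, 1, outdim, 1, 1, 1, 1, 1])
  sigma_biases = tf.constant(1.0)
  out_activation, out_center = update_conv_routing(
      wx=wx,
      input_activation=input_activation,
      activation_biases=activation_biases,
      sigma_biases=sigma_biases,
      logit_shape=[indim, outdim, 1, n, n, k, k],
      num_out_atoms=outatom,
      input_dim=indim,
      num_routing=3,
      output_dim=outdim,
      final_beta=1.0,
      min_var=0.0005)
  with tf.Session() as sess:
    act, center = sess.run([out_activation, out_center])
    print(act.shape)     # (2, 1, 4, 1, 4, 4, 1, 1)
    print(center.shape)  # (2, 1, 4, 16, 4, 4, 1, 1)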