# Example 1
def update_em_routing(wx, input_activation, activation_biases, sigma_biases,
                      logit_shape, num_out_atoms, num_routing, output_dim,
                      leaky, final_beta, min_var):
    """Fully connected routing with EM for a Mixture of Gaussians.

    Alternates an E-step (softmax posterior over output capsules from the
    current routing logits) and an M-step (re-estimating each output
    capsule's Gaussian mean/variance and activation) for `num_routing`
    rounds, then returns the final activations and centers.

    Args:
      wx: votes, [batch, indim, outdim, outatom, height, width].
      input_activation: [batch, indim, 1, 1, 1, 1].
      activation_biases: [1, 1, outdim, 1, height, width].
      sigma_biases: bias multiplying the variance penalty in the
        activation logit (broadcastable to the win tensor).
      logit_shape: Python list [indim, outdim, 1, height, width].
      num_out_atoms: number of atoms (pose dimensions) per output capsule.
      num_routing: number of EM iterations to unroll.
      output_dim: number of output capsules.
      leaky: if True, use layers.leaky_routing instead of a plain softmax.
      final_beta: inverse temperature approached over the iterations.
      min_var: floor added to the estimated variance for stability.

    Returns:
      (out_activation, out_center) from the last routing iteration, with
      shapes [batch, 1, outdim, 1, height, width] and
      [batch, 1, outdim, outatom, height, width] respectively.
    """
    batch = tf.shape(input_activation)[0]
    # Routing logits start at zero, i.e. a uniform posterior after softmax.
    update = tf.fill(
        tf.stack([
            batch, logit_shape[0], logit_shape[1], logit_shape[2],
            logit_shape[3], logit_shape[4]
        ]), 0.0)
    out_activation = tf.fill(
        tf.stack([batch, 1, output_dim, 1, logit_shape[3], logit_shape[4]]),
        0.0)
    out_center = tf.fill(
        tf.stack([
            batch, 1, output_dim, num_out_atoms, logit_shape[3],
            logit_shape[4]
        ]), 0.0)

    def _body(i, update, activation, center):
        """One EM iteration: E-step posterior, then M-step Gaussian refit."""
        del activation  # Recomputed from scratch every iteration.
        # Anneal the inverse temperature towards final_beta over iterations.
        beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
        if leaky:
            posterior = layers.leaky_routing(update, output_dim)
        else:
            posterior = tf.nn.softmax(update, dim=2)
        # Weight each input's posterior by its activation.
        vote_conf = posterior * input_activation
        # masses: [batch, 1, outdim, 1, height, width]; epsilon avoids /0.
        masses = tf.reduce_sum(vote_conf, axis=1, keep_dims=True) + 0.00001
        preactivate_unrolled = vote_conf * wx
        # Damped mean update: 90% new weighted mean, 10% previous center.
        # center: [batch, 1, outdim, outatom, height, width]
        center = .9 * tf.reduce_sum(preactivate_unrolled,
                                    axis=1,
                                    keep_dims=True) / masses + .1 * center

        noise = (wx - center) * (wx - center)
        variance = min_var + tf.reduce_sum(
            vote_conf * noise, axis=1, keep_dims=True) / masses
        log_variance = tf.log(variance)
        # p_i = -log det(Sigma) for a diagonal covariance (sum over atoms).
        p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
        log_2pi = tf.log(2 * math.pi)
        win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
        logit = beta * (win - activation_biases * 5000)
        # Numerically stable log(sigmoid(logit)).
        activation_update = tf.minimum(
            0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
        # log det(Sigma) == -p_i; reuse it rather than re-reducing
        # log_variance (matches update_conv_routing).
        log_det_sigma = -1 * p_i
        sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
        exp_update = tf.reduce_sum(noise / (2 * variance),
                                   axis=3,
                                   keep_dims=True)
        # Unnormalized log posterior feeding the next E-step.
        prior_update = activation_update - sigma_update - exp_update
        return (prior_update, logit, center)

    # Unroll the EM loop in the graph rather than using tf.while_loop.
    for i in range(num_routing):
        update, out_activation, out_center = _body(i, update, out_activation,
                                                   out_center)
    with tf.name_scope('out_activation'):
        utils.activation_summary(tf.sigmoid(out_activation))
    with tf.name_scope('noise'):
        utils.variable_summaries((wx - out_center) * (wx - out_center))
    with tf.name_scope('Wx'):
        utils.variable_summaries(wx)

    return out_activation, out_center
# Example 2
def update_conv_routing(wx, input_activation, activation_biases, sigma_biases,
                        logit_shape, num_out_atoms, input_dim, num_routing,
                        output_dim, final_beta, min_var):
    """Convolutional Routing with EM for Mixture of Gaussians.

    Like update_em_routing, but each input position votes into overlapping
    k x k receptive fields, so the E-step posterior must be renormalized
    across all output positions that share a given input. That
    renormalization is done with a transpose/reshape into an image layout
    followed by two tf.extract_image_patches passes.

    Args:
      wx: votes, [batch, indim, outdim, outatom, height, width, k, k].
      input_activation: [batch, indim, 1, 1, height, width, k, k].
      activation_biases: [1, 1, outdim, 1, height, width].
      sigma_biases: bias multiplying the variance penalty in the
        activation logit (broadcastable to the win tensor).
      logit_shape: Python list [indim, outdim, 1, height, width, k, k].
      num_out_atoms: number of atoms (pose dimensions) per output capsule.
      input_dim: number of input capsule types (indim).
      num_routing: number of EM iterations to unroll.
      output_dim: number of output capsule types.
      final_beta: inverse temperature approached over the iterations.
      min_var: floor added to the estimated variance for stability.

    Returns:
      (out_activation, out_center) from the last routing iteration.
    """
    # Posterior starts uniform: softmax over outdim of an all-zero tensor.
    post = tf.nn.softmax(tf.fill(
        tf.stack([
            tf.shape(input_activation)[0], logit_shape[0], logit_shape[1],
            logit_shape[2], logit_shape[3], logit_shape[4], logit_shape[5],
            logit_shape[6]
        ]), 0.0),
                         dim=2)
    out_activation = tf.fill(
        tf.stack([
            tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
            logit_shape[4], 1, 1
        ]), 0.0)
    out_center = tf.fill(
        tf.stack([
            tf.shape(input_activation)[0], 1, output_dim, num_out_atoms,
            logit_shape[3], logit_shape[4], 1, 1
        ]), 0.0)
    out_mass = tf.fill(
        tf.stack([
            tf.shape(input_activation)[0], 1, output_dim, 1, logit_shape[3],
            logit_shape[4], 1, 1
        ]), 0.0)
    # n: spatial extent (assumes height == width here — TODO confirm);
    # k: kernel size of the receptive field.
    n = logit_shape[3]
    k = logit_shape[5]

    def _body(i, posterior, activation, center, masses):
        """One EM iteration with convolutional posterior renormalization."""
        del activation  # Recomputed from scratch every iteration.
        # Anneal the inverse temperature towards final_beta over iterations.
        beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
        # Weight each input's posterior by its activation.
        vote_conf = posterior * input_activation
        # Sum over the input-capsule axis and both kernel axes.
        # masses: [batch, 1, outdim, 1, height, width, 1, 1]
        masses = tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(
            vote_conf, axis=1, keep_dims=True),
                                             axis=-1,
                                             keep_dims=True),
                               axis=-2,
                               keep_dims=True) + 0.0000001
        preactivate_unrolled = vote_conf * wx
        # Damped mean update: 90% new weighted mean, 10% previous center.
        # center: [batch, 1, outdim, outatom, height, width, 1, 1]
        center = .9 * tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(
            preactivate_unrolled, axis=1, keep_dims=True),
                                                  axis=-1,
                                                  keep_dims=True),
                                    axis=-2,
                                    keep_dims=True) / masses + .1 * center

        noise = (wx - center) * (wx - center)
        variance = min_var + tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(
            vote_conf * noise, axis=1, keep_dims=True),
                                                         axis=-1,
                                                         keep_dims=True),
                                           axis=-2,
                                           keep_dims=True) / masses
        log_variance = tf.log(variance)
        # p_i = -log det(Sigma) for a diagonal covariance (sum over atoms).
        p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
        log_2pi = tf.log(2 * math.pi)
        win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
        logit = beta * (win - activation_biases * 5000)
        # Numerically stable log(sigmoid(logit)).
        activation_update = tf.minimum(
            0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
        # log det(Sigma) is just -p_i.
        log_det_sigma = -1 * p_i
        sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
        exp_update = tf.reduce_sum(noise / (2 * variance),
                                   axis=3,
                                   keep_dims=True)
        # Unnormalized log posterior for this input/output/position pairing.
        prior_update = activation_update - sigma_update - exp_update
        # Subtract the max over (height, width, k, k) before exp for
        # numerical stability of the softmax-style normalization below.
        max_prior_update = tf.reduce_max(tf.reduce_max(tf.reduce_max(
            tf.reduce_max(prior_update, axis=-1, keep_dims=True),
            axis=-2,
            keep_dims=True),
                                                       axis=-3,
                                                       keep_dims=True),
                                         axis=-4,
                                         keep_dims=True)
        prior_normal = tf.add(prior_update, -1 * max_prior_update)
        prior_exp = tf.exp(prior_normal)
        # Lay the (height, k) and (width, k) axes out as one big image so the
        # overlapping receptive fields can be summed with patch extraction.
        t_prior = tf.transpose(prior_exp, [0, 1, 2, 3, 4, 6, 5, 7])
        c_prior = tf.reshape(t_prior, [-1, n * k, n * k, 1])
        pad_prior = tf.pad(c_prior,
                           [[0, 0], [(k - 1) * (k - 1), (k - 1) * (k - 1)],
                            [(k - 1) * (k - 1),
                             (k - 1) * (k - 1)], [0, 0]], 'CONSTANT')
        # First pass: gather, for each input position, the contributions of
        # all output positions whose k x k window covers it.
        patch_prior = tf.extract_image_patches(images=pad_prior,
                                               ksizes=[1, k, k, 1],
                                               strides=[1, k, k, 1],
                                               rates=[1, k - 1, k - 1, 1],
                                               padding='VALID')
        # Per-input-position normalizer: total unnormalized posterior mass.
        sum_prior = tf.reduce_sum(patch_prior, axis=-1, keep_dims=True)
        # Second pass: broadcast each input position's normalizer back to
        # every (output position, kernel offset) pair that uses it.
        sum_prior_patch = tf.extract_image_patches(images=sum_prior,
                                                   ksizes=[1, k, k, 1],
                                                   strides=[1, 1, 1, 1],
                                                   rates=[1, 1, 1, 1],
                                                   padding='VALID')
        # Epsilon guards against division by zero.
        sum_prior_reshape = tf.reshape(
            sum_prior_patch,
            [-1, input_dim, output_dim, 1, n, n, k, k]) + 0.0000001
        posterior = prior_exp / sum_prior_reshape
        return (posterior, logit, center, masses)

    # Unroll the EM loop in the graph rather than using tf.while_loop.
    for i in range(num_routing):
        post, out_activation, out_center, out_mass = _body(
            i, post, out_activation, out_center, out_mass)
    with tf.name_scope('out_activation'):
        utils.activation_summary(tf.sigmoid(out_activation))
    with tf.name_scope('masses'):
        utils.activation_summary(tf.sigmoid(out_mass))
    with tf.name_scope('posterior'):
        utils.activation_summary(post)
    with tf.name_scope('noise'):
        utils.variable_summaries((wx - out_center) * (wx - out_center))
    with tf.name_scope('Wx'):
        utils.variable_summaries(wx)

    return out_activation, out_center