コード例 #1
0
 def test_shape_2(self):
     """Check `helpers.broadcast` result shapes for `axis=[0, 1]`, `broadcast_b=False`.

     The first result is expected to have shape (10, 10, 15, 1, 15), while
     the second is expected to keep its input shape (1, 10, 1, 15, 3).
     """
     with self.test_session():
         first = tf.convert_to_tensor(np.random.rand(10, 1, 15, 1, 15))
         second = tf.convert_to_tensor(np.random.rand(1, 10, 1, 15, 3))
         result = helpers.broadcast(first, second, axis=[0, 1], broadcast_b=False)

         # Only the shapes matter here; the values themselves are random.
         a = result[0].eval()
         b = result[1].eval()

         self.assertAllCloseAccordingToType(a.shape, [10, 10, 15, 1, 15])
         self.assertAllCloseAccordingToType(b.shape, [1, 10, 1, 15, 3])
コード例 #2
0
def _routing(predictions, coupling_logits, routing_iterations):
    """
    Performs capsule routing.

    Parameters
    ----------
    predictions: tf.Tensor
        Lower level capsule predictions `(batch, num_caps[l], num_caps[l-1], 1, units[l])`
    coupling_logits: tf.Tensor
        Starting coupling logits `b` `(batch, num_caps[l], num_caps[l-1], 1)`
    routing_iterations: int
        Number of routing iterations

    Returns
    -------
    tf.Tensor
        Resulting capsule values. `(batch, num_caps[l], units[l])`

    """

    with tf.name_scope("routing"):
        for i in range(routing_iterations):
            coupling_coeffs = _coupling_coefficients(coupling_logits)
            out = tf.reduce_sum(tf.multiply(
                predictions, tf.expand_dims(coupling_coeffs, -1)),
                                axis=-3)
            out = helpers.squash(out)  # (batch, num_caps[l], 1, units[l])
            if i < routing_iterations - 1:
                # expand to (batch, num_caps[l], 1, units[l], 1)
                out = tf.expand_dims(out, -1)
                # broadcast to (batch, num_caps[l], num_caps[l-1], units[l], 1)
                out, _ = helpers.broadcast(out,
                                           predictions,
                                           axis=-3,
                                           broadcast_b=False)
                # get logit update (batch, num_caps[l], num_caps[l-1], 1, 1)
                logits_update = tf.matmul(predictions, out)
                logits_update = tf.squeeze(
                    logits_update,
                    [-1])  # (batch, num_caps[l], num_caps[l-1], 1)
                # update coupling_logits
                coupling_logits = tf.add(coupling_logits, logits_update)

        return tf.squeeze(
            out, axis=-2)  # squeeze to get (batch, num_caps[l], units[l])
コード例 #3
0
def _prediction_vectors(inputs, weights):
    """
    Builds the prediction vectors u_hat.

    For every capsule of layer `l` one prediction vector is produced per
    capsule of layer `l-1`.

    Expected shapes:

    * inputs: `(batch, num_caps[l-1], units[l-1])`
    * weights: `(num_caps[l], num_caps[l-1], units[l-1], units[l])`
    * result: `(batch, num_caps[l], num_caps[l-1], 1, units[l])`

    Parameters
    ----------
    inputs: tf.Tensor
        Tensor of the lower level capsules.
    weights: tf.Tensor
        Weights the predictions are computed with.

    Returns
    -------
    tf.Tensor
        The predictions u_hat.

    """
    with tf.name_scope("prediction_vectors"):
        # (batch, num_caps[l-1], units[l-1])
        #   -> (batch, num_caps[l-1], 1, units[l-1])     axis before units
        #   -> (batch, 1, num_caps[l-1], 1, units[l-1])  axis after batch
        lower_caps = tf.expand_dims(tf.expand_dims(inputs, -2), 1)

        # (num_caps[l], num_caps[l-1], units[l-1], units[l])
        #   -> (1, num_caps[l], num_caps[l-1], units[l-1], units[l])
        kernel = tf.expand_dims(weights, 0)

        # after broadcasting over the two leading axes:
        # lower_caps (batch, num_caps[l], num_caps[l-1], 1, units[l-1])
        # kernel     (batch, num_caps[l], num_caps[l-1], units[l-1], units[l])
        lower_caps, kernel = helpers.broadcast(lower_caps, kernel, axis=[0, 1])

        return tf.matmul(lower_caps, kernel)
コード例 #4
0
 def test_shape_legacy(self):
     """Check `helpers.broadcast` result shapes when `use_legacy=True`.

     The first result is expected to keep shape (10, 1, 15); the second is
     expected to come back with shape (10, 15, 3).
     """
     with self.test_session():
         left = np.random.rand(10, 1, 15)
         right = np.random.rand(1, 15, 3)
         result = helpers.broadcast(left, right, use_legacy=True)

         # Only the shapes matter here; the values themselves are random.
         a = result[0].eval()
         b = result[1].eval()

         self.assertAllCloseAccordingToType(a.shape, [10, 1, 15])
         self.assertAllCloseAccordingToType(b.shape, [10, 15, 3])