Example #1
0
  def testReconstructionMasking(self):
    """Tests the masking mechanism of layers.reconstruction.

    The capsule mask selects a single target capsule (index 2), so only the
    embedding values at that target should affect the reconstruction.
    embedding1 and embedding3 are zero at the target and non-zero only at other
    capsules, so they must produce the same reconstruction. embedding2 is the
    only one with non-zero values at the target, so its reconstruction must
    differ from the others.
    """
    image = tf.zeros([1, 9], dtype=tf.float32)
    capsule_mask = tf.one_hot([2], 10)

    # (capsule index, fill value) used to build each test embedding.
    fill_specs = [(1, 1), (2, 1), (3, 10)]
    embeddings = []
    for capsule_index, fill_value in fill_specs:
      embedding = np.zeros((1, 10, 4), dtype=np.float32)
      embedding[:, capsule_index, :] = fill_value
      embeddings.append(embedding)

    # The first call creates the subnetwork variables; later calls reuse them.
    reconstructions = [
        layers.reconstruction(
            capsule_mask=capsule_mask,
            num_atoms=4,
            capsule_embedding=embedding,
            layer_sizes=(5, 10),
            num_pixels=9,
            reuse=index > 0,
            image=image,
            balance_factor=0.1)
        for index, embedding in enumerate(embeddings)
    ]

    # Because of reuse, the total stays at the variables of a single call.
    self.assertEqual(
        len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)), 6)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      r_1, r_2, r_3 = sess.run(reconstructions)
    self.assertAlmostEqual(np.sum(r_1), np.sum(r_3))
    self.assertNotAlmostEqual(np.sum(r_2), np.sum(r_1))
    def _remake(self, features, capsule_embedding, route):
        """Adds the reconstruction subnetwork to build the remakes.

        This subnetwork shares the variables between different target remakes.
        It adds the subnetwork for the first target and reuses the weight
        variables for the second one.

        Args:
          features: A dictionary of input data. It is read for 'depth',
            'height' and (optionally) 'width' to size the output, for
            'num_targets', for 'recons_label'/'recons_image', and, when
            'num_targets' is 2, for 'spare_label'/'spare_image'.
          capsule_embedding: A 3D tensor of shape
            [batch, num_latent_capsules, 16] containing network embeddings.
          route: The routing coefficients for each latent capsule
            [batch, num_latent_capsules]
        Returns:
          A list of network remakes of the targets, one per target.
        """
        # Previously the image was assumed square (height * height). Honour an
        # explicit 'width' when the input pipeline provides one; fall back to
        # the square assumption otherwise so existing inputs behave the same.
        height = features['height']
        width = features.get('width', height)
        num_pixels = features['depth'] * height * width

        remakes = []
        targets = [(features['recons_label'], features['recons_image'])]
        if features['num_targets'] == 2:
            targets.append((features['spare_label'], features['spare_image']))

        with tf.name_scope('recons'):
            for i in range(features['num_targets']):
                label, image = targets[i]

                # Reconstruction Mask
                capsule_mask = self._create_capsule_mask(
                    label, capsule_embedding, route)

                remakes.append(
                    layers.reconstruction(
                        capsule_mask=capsule_mask,
                        num_atoms=self._hparams.num_latent_atoms,
                        capsule_embedding=capsule_embedding,
                        layer_sizes=[512, 1024],
                        num_pixels=num_pixels,
                        # Only the first target creates the variables; the
                        # following targets share them.
                        reuse=(i > 0),
                        image=image,
                        balance_factor=self._hparams.balance_factor,
                        unsupervised=self._hparams.unsupervised))

        if self._hparams.verbose:
            self._summarize_remakes(features, remakes)

        return remakes
Example #3
0
    def _remake(self, features, capsule_embedding):
        """Adds the reconstruction subnetwork to build the remakes.

        This subnetwork shares the variables between different target remakes.
        It adds the subnetwork for the first target and reuses the weight
        variables for the second one.

        Args:
          features: A dictionary of input data. It is read for 'depth',
            'height' and (optionally) 'width' to size the output, for
            'num_classes' and 'num_targets', for 'recons_label'/'recons_image',
            and, when 'num_targets' is 2, for 'spare_label'/'spare_image'.
          capsule_embedding: A 3D tensor of shape [batch, 10, 16] containing
            network embeddings for each digit in the image if present.
        Returns:
          A list of network remakes of the targets, one per target.
        """
        # Previously the image was assumed square (height * height). Honour an
        # explicit 'width' when provided; otherwise keep the square assumption.
        height = features['height']
        width = features.get('width', height)
        num_pixels = features['depth'] * height * width

        remakes = []
        targets = [(features['recons_label'], features['recons_image'])]
        if features['num_targets'] == 2:
            targets.append((features['spare_label'], features['spare_image']))

        with tf.name_scope('recons'):
            # `range` works on both Python 2 and 3, unlike `xrange`.
            for i in range(features['num_targets']):
                label, image = targets[i]
                remakes.append(
                    layers.reconstruction(
                        capsule_mask=tf.one_hot(label,
                                                features['num_classes']),
                        num_atoms=self._hparams.digit_capsule_dim,
                        capsule_embedding=capsule_embedding,
                        layer_sizes=[512, 1024],
                        num_pixels=num_pixels,
                        # Only the first target creates the variables.
                        reuse=(i > 0),
                        image=image,
                        balance_factor=0.0005))

        if self._hparams.verbose:
            self._summarize_remakes(features, remakes)

        return remakes
Example #4
0
  def _remake(self, features, capsule_embedding):
    """Adds the reconstruction subnetwork to build the remakes.

    This subnetwork shares the variables between different target remakes. It
    adds the subnetwork for the first target and reuses the weight variables
    for the second one.

    Args:
      features: A dictionary of input data. It is read for 'depth', 'height'
        and (optionally) 'width' to size the output, for 'num_classes' and
        'num_targets', for 'recons_label'/'recons_image', and, when
        'num_targets' is 2, for 'spare_label'/'spare_image'.
      capsule_embedding: A 3D tensor of shape [batch, 10, 16] containing
        network embeddings for each digit in the image if present.
    Returns:
      A list of network remakes of the targets, one per target.
    """
    # Previously the image was assumed square (height * height). Honour an
    # explicit 'width' when provided; otherwise keep the square assumption.
    height = features['height']
    width = features.get('width', height)
    num_pixels = features['depth'] * height * width

    remakes = []
    targets = [(features['recons_label'], features['recons_image'])]
    if features['num_targets'] == 2:
      targets.append((features['spare_label'], features['spare_image']))

    with tf.name_scope('recons'):
      # `range` works on both Python 2 and 3, unlike `xrange`.
      for i in range(features['num_targets']):
        label, image = targets[i]
        remakes.append(
            layers.reconstruction(
                capsule_mask=tf.one_hot(label, features['num_classes']),
                # Matches the 16-dimensional embeddings documented above.
                num_atoms=16,
                capsule_embedding=capsule_embedding,
                layer_sizes=[512, 1024],
                num_pixels=num_pixels,
                # Only the first target creates the variables.
                reuse=(i > 0),
                image=image,
                balance_factor=0.0005))

    if self._hparams.verbose:
      self._summarize_remakes(features, remakes)

    return remakes
Example #5
0
  def testReconstruction(self):
    """Tests layers.reconstruction output shape and variable declaration.

    Checks that the expected number of variables is added to the trainable
    collection and that the output has the same size as the image.

    The reconstruction layer adds 3 fully connected layers, so it should add
    6 variables (3 weights and 3 biases) to the trainable variable collection.
    """
    batch_size, num_pixels = 2, 9
    image = tf.random_uniform([batch_size, num_pixels])
    embedding = tf.random_uniform([batch_size, 10, 4])
    # One target class per batch element.
    capsule_mask = tf.one_hot([2, 7], 10)
    reconstruction = layers.reconstruction(
        capsule_mask=capsule_mask,
        num_atoms=4,
        capsule_embedding=embedding,
        layer_sizes=(5, 10),
        num_pixels=num_pixels,
        reuse=False,
        image=image,
        balance_factor=0.1)

    output_shape = reconstruction.get_shape().as_list()
    self.assertListEqual([batch_size, num_pixels], output_shape)
    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(trainable), 6)