def testReconstructionMasking(self):
    """Verifies that layers.reconstruction honours the capsule mask.

    The mask is one-hot at capsule index 2. Three [1, 10, 4] embeddings are
    fed through the same reconstruction network:

      * embedding 1 is non-zero only at capsule 1,
      * embedding 2 is non-zero only at capsule 2 (the masked-in capsule),
      * embedding 3 is non-zero only at capsule 3 (with a larger value).

    Embeddings 1 and 3 carry no signal at the masked-in capsule, so their
    reconstructions are expected to match (compared through their sums),
    while embedding 2 is expected to give a different reconstruction. The
    test also checks that the three calls together leave exactly 6
    trainable variables in the graph.
    """
    def make_embedding(active_capsule, fill_value):
        # All-zero [1, 10, 4] array with a single capsule row filled in.
        embedding = np.zeros((1, 10, 4), dtype=np.float32)
        embedding[:, active_capsule, :] = fill_value
        return embedding

    image = tf.zeros([1, 9], dtype=tf.float32)
    capsule_mask = tf.one_hot([2], 10)
    embeddings = (
        make_embedding(1, 1),
        make_embedding(2, 1),
        make_embedding(3, 10),
    )

    reconstructions = []
    for index, embedding in enumerate(embeddings):
        reconstructions.append(
            layers.reconstruction(
                capsule_mask=capsule_mask,
                num_atoms=4,
                capsule_embedding=embedding,
                layer_sizes=(5, 10),
                num_pixels=9,
                # reuse is False for the first call and True afterwards.
                reuse=index > 0,
                image=image,
                balance_factor=0.1))

    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(trainable), 6)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        out_1, out_2, out_3 = sess.run(reconstructions)

    total_1 = np.sum(out_1)
    total_2 = np.sum(out_2)
    total_3 = np.sum(out_3)
    # Off-target embeddings agree; the on-target embedding stands apart.
    self.assertAlmostEqual(total_1, total_3)
    self.assertNotAlmostEqual(total_2, total_1)
def _remake(self, features, capsule_embedding, route):
    """Builds the reconstruction subnetwork, one remake per target.

    The first target creates the reconstruction variables; any further
    target is built with reuse=True so the weights are shared.

    Args:
      features: A dictionary of input data. This method reads the keys
        'depth', 'height', 'num_targets', 'recons_label', 'recons_image'
        and, when 'num_targets' is 2, 'spare_label' and 'spare_image'.
      capsule_embedding: A 3D tensor of shape
        [batch, num_latent_capsules, 16] containing network embeddings.
      route: The routing coefficients for each latent capsule,
        [batch, num_latent_capsules].

    Returns:
      A list of network remakes of the targets.
    """
    # NOTE(review): 'height' is used for both spatial dimensions, i.e. the
    # pixel count assumes square images.
    side = features['height']
    num_pixels = features['depth'] * side * side

    target_pairs = [(features['recons_label'], features['recons_image'])]
    if features['num_targets'] == 2:
        target_pairs.append(
            (features['spare_label'], features['spare_image']))

    remakes = []
    with tf.name_scope('recons'):
        for target_index in range(features['num_targets']):
            label, image = target_pairs[target_index]
            # Mask selecting which capsules feed this reconstruction.
            mask = self._create_capsule_mask(label, capsule_embedding, route)
            recon_kwargs = dict(
                capsule_mask=mask,
                num_atoms=self._hparams.num_latent_atoms,
                capsule_embedding=capsule_embedding,
                layer_sizes=[512, 1024],
                num_pixels=num_pixels,
                # Only the first target builds fresh variables.
                reuse=(target_index > 0),
                image=image,
                balance_factor=self._hparams.balance_factor,
                unsupervised=self._hparams.unsupervised)
            remakes.append(layers.reconstruction(**recon_kwargs))

    if self._hparams.verbose:
        self._summarize_remakes(features, remakes)

    return remakes
def _remake(self, features, capsule_embedding):
    """Adds the reconstruction subnetwork to build the remakes.

    This subnetwork shares the variables between different target remakes.
    It adds the subnetwork for the first target and reuses the weight
    variables for the second one.

    Args:
      features: A dictionary of input data containing the dimension
        information and the input images and labels. This method reads the
        keys 'depth', 'height', 'num_targets', 'num_classes',
        'recons_label', 'recons_image' and, when 'num_targets' is 2,
        'spare_label' and 'spare_image'.
      capsule_embedding: A 3D tensor of shape [batch, 10, 16] containing
        network embeddings for each digit in the image if present.

    Returns:
      A list of network remakes of the targets.
    """
    # NOTE(review): 'height' is used twice, so the pixel count assumes
    # square images.
    num_pixels = features['depth'] * features['height'] * features['height']

    remakes = []
    targets = [(features['recons_label'], features['recons_image'])]
    if features['num_targets'] == 2:
        targets.append((features['spare_label'], features['spare_image']))

    with tf.name_scope('recons'):
        # `range` instead of the Python 2-only `xrange`, so the loop runs
        # on both Python 2 and Python 3.
        for i in range(features['num_targets']):
            label, image = targets[i]
            remakes.append(
                layers.reconstruction(
                    # One-hot mask over classes built from the target label.
                    capsule_mask=tf.one_hot(label, features['num_classes']),
                    num_atoms=self._hparams.digit_capsule_dim,
                    capsule_embedding=capsule_embedding,
                    layer_sizes=[512, 1024],
                    num_pixels=num_pixels,
                    # Only the first target builds fresh variables.
                    reuse=(i > 0),
                    image=image,
                    balance_factor=0.0005))

    if self._hparams.verbose:
        self._summarize_remakes(features, remakes)

    return remakes
def _remake(self, features, capsule_embedding):
    """Adds the reconstruction subnetwork to build the remakes.

    This subnetwork shares the variables between different target remakes.
    It adds the subnetwork for the first target and reuses the weight
    variables for the second one.

    Args:
      features: A dictionary of input data containing the dimension
        information and the input images and labels. This method reads the
        keys 'depth', 'height', 'num_targets', 'num_classes',
        'recons_label', 'recons_image' and, when 'num_targets' is 2,
        'spare_label' and 'spare_image'.
      capsule_embedding: A 3D tensor of shape [batch, 10, 16] containing
        network embeddings for each digit in the image if present.

    Returns:
      A list of network remakes of the targets.
    """
    # NOTE(review): 'height' is used twice, so the pixel count assumes
    # square images.
    num_pixels = features['depth'] * features['height'] * features['height']

    remakes = []
    targets = [(features['recons_label'], features['recons_image'])]
    if features['num_targets'] == 2:
        targets.append((features['spare_label'], features['spare_image']))

    with tf.name_scope('recons'):
        # `range` instead of the Python 2-only `xrange`, so the loop runs
        # on both Python 2 and Python 3.
        for i in range(features['num_targets']):
            label, image = targets[i]
            remakes.append(
                layers.reconstruction(
                    # One-hot mask over classes built from the target label.
                    capsule_mask=tf.one_hot(label, features['num_classes']),
                    # Atom count is fixed at 16 here, matching the last
                    # dimension documented for capsule_embedding.
                    num_atoms=16,
                    capsule_embedding=capsule_embedding,
                    layer_sizes=[512, 1024],
                    num_pixels=num_pixels,
                    # Only the first target builds fresh variables.
                    reuse=(i > 0),
                    image=image,
                    balance_factor=0.0005))

    if self._hparams.verbose:
        self._summarize_remakes(features, remakes)

    return remakes
def testReconstruction(self):
    """Checks the output shape and variable count of layers.reconstruction.

    A batch of two 9-pixel images is reconstructed from a [2, 10, 4]
    embedding masked by one-hot targets. The result must have the same
    static shape as the image, and the call must leave 6 entries in the
    trainable variable collection; the reconstruction layer is expected to
    add three fully connected layers, i.e. 3 weights and 3 biases.
    """
    batch_size, num_pixels = 2, 9

    image = tf.random_uniform([batch_size, num_pixels])
    target = [2, 7]
    embedding = tf.random_uniform([batch_size, 10, 4])
    capsule_mask = tf.one_hot(target, 10)

    reconstruction = layers.reconstruction(
        capsule_mask=capsule_mask,
        num_atoms=4,
        capsule_embedding=embedding,
        layer_sizes=(5, 10),
        num_pixels=num_pixels,
        reuse=False,
        image=image,
        balance_factor=0.1)

    static_shape = reconstruction.get_shape().as_list()
    self.assertListEqual([batch_size, num_pixels], static_shape)

    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(trainable), 6)