def testIntegrity(self):
    """Verifies that multi_gpu on CapsuleModel builds the expected graph.

    When the inference graph is set up correctly, multi_gpu can invoke
    inference once per tower without growing the number of trainable
    variables and without hitting a duplication error.
    """
    num_towers = 3
    graph = tf.Graph()
    with graph.as_default():
        model = capsule_model.CapsuleModel(self.hparams)
        pixels = np.arange(32 * 32).reshape((1, 1, 32, 32))
        images = tf.constant(pixels, dtype=tf.float32)
        labels = tf.one_hot([2], 10)
        features = {
            'height': 32,
            'depth': 1,
            'images': images,
            'labels': labels,
            'num_classes': 10,
            'num_targets': 1,
        }
        # Every tower receives the very same feature dictionary.
        tower_features = [features] * num_towers
        _, tower_output = model.multi_gpu(tower_features, num_towers)

        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        # Building three towers must still leave six trainable variables.
        self.assertEqual(len(trainable_vars), 6)

        first_tower_logits = tower_output[0].logits
        _, classes = first_tower_logits.get_shape()
        self.assertEqual(10, classes.value)
def testInferenceWithRemake(self):
    """Verifies remake shapes and the variable count when remake is enabled.

    Every reconstruction should have the same shape as the input. Each
    remake network should declare 6 sets of variables (weight and bias),
    and different targets should share those variables.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Turn on the reconstruction branch before the model is created.
        self.hparams.parse('remake=True,verbose=True')
        model = capsule_model.CapsuleModel(self.hparams)
        pixels = np.arange(32 * 32).reshape((1, 1, 32, 32))
        images = tf.constant(pixels, dtype=tf.float32)
        recons_label = tf.constant([2])
        spare_label = tf.constant([2])
        features = {
            'height': 32,
            'depth': 1,
            'images': images,
            'recons_image': images,
            'spare_image': images,
            'recons_label': recons_label,
            'spare_label': spare_label,
            'num_targets': 2,
            'num_classes': 10,
        }
        output = model.inference(features)

        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.assertEqual(len(trainable_vars), 12)

        # Exactly two remakes are expected, one per target.
        first_remake, second_remake = output.remakes
        for remake in (first_remake, second_remake):
            # Dimension 1 of each remake must hold all 32 * 32 input pixels.
            self.assertEqual(32 * 32, remake.get_shape()[1].value)
def testInference(self):
    """Verifies the logit shape and the variable count of plain inference.

    The output logits should have shape [batch, 10]. Each layer should
    declare 2 sets of variables (weight and bias), so a single call to
    inference declares 6 variables across the 3 layers.
    """
    graph = tf.Graph()
    with graph.as_default():
        model = capsule_model.CapsuleModel(self.hparams)
        pixels = np.arange(32 * 32).reshape((1, 1, 32, 32))
        images = tf.constant(pixels, dtype=tf.float32)
        features = {
            'height': 32,
            'depth': 1,
            'num_classes': 10,
            'images': images,
        }
        output = model.inference(features)

        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.assertEqual(len(trainable_vars), 6)

        # Unpacking into two names also requires the logits to be rank 2.
        logits_shape = output.logits.get_shape()
        _, classes = logits_shape
        self.assertEqual(10, classes.value)
def testBuildCapsule(self):
    """Verifies the capsule output shape and the capsule variable count.

    The output should have shape [batch, 10, 16]. Each capsule layer should
    declare 2 sets of variables (weight and bias), so a single call to
    _build_capsule declares 4 variables across the 2 capsule layers.
    """
    num_classes = 10
    graph = tf.Graph()
    with graph.as_default():
        model = capsule_model.CapsuleModel(self.hparams)
        raw_input = np.arange(256 * 14 * 14).reshape((1, 1, 256, 14, 14))
        capsule_input = tf.constant(raw_input, dtype=tf.float32)
        output = model._build_capsule(capsule_input, num_classes)

        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.assertEqual(len(trainable_vars), 4)

        # Unpacking into three names also requires the output to be rank 3.
        _, capsules, atoms = output.get_shape()
        self.assertListEqual([10, 16], [capsules.value, atoms.value])