def test_reuse_model(self, policy, is_training, classnet_type):
    """Checks that calling the model a second time reuses its variables.

    The same model is called twice on the same images: first with
    ``num_times`` and then with ``num_times + 1``. The test verifies that:
      * the second call creates no new global variables,
      * both calls yield logits of shape ``(batch_size, num_classes)``,
      * the regularization loss is (almost) unchanged by the second call.

    Args:
      policy: Forwarded to both model calls as ``policy`` (presumably
        supplied by a parameterized-test decorator outside this view).
      is_training: Used for both ``is_training_saccader`` and
        ``is_training_classnet``.
      classnet_type: Stored in ``config.classnet_type``.
    """
    config = saccader_classnet_config.get_config()
    num_times = 2
    image_shape = (100, 100, 3)
    num_classes = 10
    config.num_classes = num_classes
    config.num_times = num_times
    config.classnet_type = classnet_type
    batch_size = 3
    images = tf.constant(
        np.random.rand(*((batch_size,) + image_shape)), dtype=tf.float32)
    model = saccader_classnet.SaccaderClassNet(config)
    logits1 = model(images, images, num_times=num_times,
                    is_training_saccader=is_training,
                    is_training_classnet=is_training,
                    policy=policy)[0]
    # tf.all_variables() is a deprecated alias of tf.global_variables().
    num_params = len(tf.global_variables())
    l2_loss1 = tf.losses.get_regularization_loss()
    # Build twice with different num_times.
    logits2 = model(images, images, num_times=num_times + 1,
                    is_training_saccader=is_training,
                    is_training_classnet=is_training,
                    policy=policy)[0]
    l2_loss2 = tf.losses.get_regularization_loss()

    # Ensure variables are reused: the second call must not add variables.
    self.assertLen(tf.global_variables(), num_params)
    self.evaluate(tf.global_variables_initializer())
    logits1, logits2 = self.evaluate((logits1, logits2))
    # The evaluated logits were previously discarded; check that both builds
    # still produce one row of class logits per image (cf. test_build).
    self.assertEqual((batch_size, num_classes), logits1.shape)
    self.assertEqual((batch_size, num_classes), logits2.shape)
    l2_loss1, l2_loss2 = self.evaluate((l2_loss1, l2_loss2))
    # Reused variables => identical regularization terms in both builds.
    np.testing.assert_almost_equal(l2_loss1, l2_loss2, decimal=5)
Example #2
0
 def test_build(self, policy, is_training, classnet_type):
     """Builds the model once and checks the shape of its first output.

     Args:
       policy: Forwarded to the model call as ``policy`` (presumably
         supplied by a parameterized-test decorator outside this view).
       is_training: Used for both ``is_training_saccader`` and
         ``is_training_classnet``.
       classnet_type: Stored in ``config.classnet_type``.
     """
     batch_size = 3
     num_times = 2
     num_classes = 10
     image_shape = (100, 100, 3)

     config = saccader_classnet_config.get_config()
     config.num_classes = num_classes
     config.num_times = num_times
     config.classnet_type = classnet_type

     # Random pixels; only the output shape is checked below.
     pixels = np.random.rand(batch_size, *image_shape)
     images = tf.constant(pixels, dtype=tf.float32)

     model = saccader_classnet.SaccaderClassNet(config)
     outputs = model(
         images,
         images,
         num_times=num_times,
         is_training_saccader=is_training,
         is_training_classnet=is_training,
         policy=policy,
     )
     logits = outputs[0]

     # init_op is read only after the model has been called, as before.
     self.evaluate(model.init_op)
     logits_value = self.evaluate(logits)
     self.assertEqual((batch_size, num_classes), logits_value.shape)