def test_reuse_model(self, rank, is_training, special_type):
        """Checks that a second call of the model reuses the first call's variables.

        Builds a 3-layer SimpleNetwork (conv2d, `special_type`, conv2d) with batch
        norm disabled. It calls the network twice on the same images and asserts:
          * the number of global variables does not grow after the second call;
          * the two logits tensors evaluate to the same values (summed absolute
            difference equal to 0 to 9 decimal places).

        Args:
          rank: stored into `config.rank`. Presumably supplied by a
            parameterization decorator that is not visible here; verify.
          is_training: passed as the second positional argument of each model call.
          special_type: layer type placed between the two 'conv2d' layers.
        """
        batch_size = 5
        num_classes = 10
        num_channels = 3
        images = _construct_images(batch_size)
        config = simple_model_config.get_config()
        config.num_classes = num_classes
        config.num_channels = num_channels
        config.batch_norm = False
        config.kernel_size_list = [3, 3, 3]
        config.num_filters_list = [64, 64, 64]
        config.strides_list = [2, 2, 1]
        config.layer_types = ['conv2d', special_type, 'conv2d']
        config.rank = rank

        model = simple_model.SimpleNetwork(config)

        # Build once. Use tf.global_variables(), not the deprecated
        # tf.all_variables(), so the count is taken from the same collection
        # that tf.global_variables_initializer() below initializes.
        logits1, _ = model(images, is_training)
        num_params = len(tf.global_variables())
        # Build twice.
        logits2, _ = model(images, is_training)
        # Ensure variables are reused: the second call must create no new ones.
        self.assertLen(tf.global_variables(), num_params)
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            # Ensure operations are the same after reuse.
            err_logits = np.abs(sess.run(logits1 - logits2)).sum()
            self.assertAlmostEqual(err_logits, 0, 9)
    def test_build_model(self, rank, is_training, special_type):
        """Checks that a single forward pass yields logits of shape (batch, classes).

        Builds a 3-layer SimpleNetwork (conv2d, `special_type`, conv2d) with batch
        norm enabled. It runs the network once on a batch of 5 images and compares
        the evaluated logits' shape against (5, 10).

        Args:
          rank: stored into `config.rank`. Presumably supplied by a
            parameterization decorator that is not visible here; verify.
          is_training: passed as the second positional argument of the model call.
          special_type: layer type placed between the two 'conv2d' layers.
        """
        batch_size, num_classes, num_channels = 5, 10, 3
        images = _construct_images(batch_size)

        # Overrides are applied in insertion order, identical to setting each
        # attribute explicitly one after another.
        overrides = {
            'num_classes': num_classes,
            'num_channels': num_channels,
            'rank': rank,
            'batch_norm': True,
            'kernel_size_list': [3, 3, 3],
            'num_filters_list': [64, 64, 64],
            'strides_list': [2, 2, 1],
            'layer_types': ['conv2d', special_type, 'conv2d'],
        }
        config = simple_model_config.get_config()
        for field_name, field_value in overrides.items():
            setattr(config, field_name, field_value)

        network = simple_model.SimpleNetwork(config)
        output_logits, _unused = network(images, is_training)

        expected_shape = (batch_size, num_classes)
        init_op = tf.global_variables_initializer()
        with self.test_session() as session:
            session.run(init_op)
            evaluated = session.run(output_logits)
            self.assertEqual(expected_shape, evaluated.shape)