Exemplo n.º 1
0
def test_tensorflow_model(num_classes):
    """Smoke-test ``TensorFlowModel`` on a tiny graph with one logit per channel.

    The network averages a ``(batch, 5, 5, channels)`` input over axes 1 and
    2, so it emits exactly ``channels`` logits per image; ``channels`` is
    therefore set equal to ``num_classes``.

    Checks performed: the session is attached, batch/single predictions have
    the expected shapes, the gradient has the input's shape, the combined
    ``predictions_and_gradient`` call agrees with the separate calls, and
    ``num_classes()`` reports the right count.

    Args:
        num_classes: number of output classes, which is also the number of
            input channels (presumably >= 1 — a zero-channel input is not
            meaningful here).
    """
    bounds = (0, 255)
    # mean_brightness_net emits one logit per channel, so the class count
    # must equal the channel count.
    channels = num_classes

    def mean_brightness_net(images):
        # Average over the two spatial axes: (batch, 5, 5, C) -> (batch, C).
        logits = tf.reduce_mean(images, axis=(1, 2))
        return logits

    g = tf.Graph()
    with g.as_default():
        images = tf.placeholder(tf.float32, (None, 5, 5, channels))
        logits = mean_brightness_net(images)

    with tf.Session(graph=g):
        model = TensorFlowModel(
            images,
            logits,
            bounds=bounds)

        assert model.session is not None

        test_images = np.random.rand(2, 5, 5, channels).astype(np.float32)
        # Label 7 used to be hard-coded, which is not a valid class index
        # when num_classes <= 7; clamp it to the last valid class instead.
        test_label = min(7, num_classes - 1)

        assert model.batch_predictions(test_images).shape \
            == (2, num_classes)

        test_logits = model.predictions(test_images[0])
        assert test_logits.shape == (num_classes,)

        test_gradient = model.gradient(test_images[0], test_label)
        assert test_gradient.shape == test_images[0].shape

        # The combined call must agree with the two separate calls above.
        # Evaluate it once and compare both elements of its result.
        combined = model.predictions_and_gradient(test_images[0], test_label)
        np.testing.assert_almost_equal(combined[0], test_logits)
        np.testing.assert_almost_equal(combined[1], test_gradient)

        assert model.num_classes() == num_classes
# Build the ResNet graph and restore its weights from the newest checkpoint
# under resnet18/checkpoints/model, then report accuracy on one test batch.
graph, saver, images, logits = adv_model_resnet()

# Deliberately not closed here: later notebook cells may still use `sess`.
sess = tf.Session(graph=graph)

path = os.path.join('resnet18', 'checkpoints', 'model')
checkpoint = tf.train.latest_checkpoint(path)
if checkpoint is None:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # fail with a message naming the directory (same exception type that
    # Saver.restore raises for a None save_path) instead of an opaque error.
    raise ValueError(
        "No checkpoint found in %r (working directory: %r)"
        % (path, os.getcwd()))
saver.restore(sess, checkpoint)

# NOTE(review): normalize=False presumably yields raw pixel values, matching
# bounds=(0, 255) below — confirm against the dataset loader.
data = dataset('../Defense_Model/tiny-imagenet-200/', normalize=False)
batch_size = 256
x_test, y_test = data.next_test_batch(batch_size)

with sess.as_default():
    model = TensorFlowModel(images, logits, bounds=(0, 255))
    y_logits = model.batch_predictions(x_test)
    y_pred = np.argmax(y_logits, axis=1)
    # y_test is argmax'd along axis 1, so it is expected to be one-hot.
    y_label = np.argmax(y_test, axis=1)
    # Mean of the boolean match vector is a float accuracy on both Python 2
    # and 3; the previous int / int form truncated to 0 or 1 on Python 2.
    print(np.mean(np.equal(y_pred, y_label)))

# In[ ]:
# Disabled snippet, kept as a no-op string literal so it never runs. When
# enabled, it prints every tensor stored in the latest checkpoint under
# `path` (name, dtype and shape, per the example line inside). To use it,
# remove the surrounding triple quotes.
"""from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=tf.train.latest_checkpoint(path), tensor_name='', all_tensors=True)"""

# In[ ]: