예제 #1
0
def test_metrics_equivalence():
    """Check that the TF metrics agree with their NumPy reference versions.

    Two random batches are compared with every pair of
    (TF metric, NumPy metric) for several ``keep_axis`` settings. The
    evaluated TF result must match the NumPy result to within ``atol=1e-7``.
    The VGG network used by ``vgg_cosine_distance`` is built under the
    ``'vgg'`` variable scope and its weights are assigned before any
    metric is evaluated.
    """
    import numpy as np

    # Fixed seed: with such a tight tolerance, a failure must be
    # reproducible rather than depend on whatever the global RNG produced.
    rng = np.random.RandomState(0)
    shape = (10, 16, 64, 64, 3)
    a = rng.random_sample(shape)
    b = rng.random_sample(shape)
    metrics = [
        mean_squared_error, peak_signal_to_noise_ratio, structural_similarity,
        vgg_cosine_distance
    ]
    metrics_np = [
        mean_squared_error_np, peak_signal_to_noise_ratio_np,
        structural_similarity_np, vgg_cosine_distance_np
    ]
    # The context manager guarantees the session is closed (releasing its
    # resources) even when an assertion below fails.
    with tf.Session() as sess:
        # this initializes the vgg network for the pure tf (i.e. non-np) metrics
        with tf.variable_scope('vgg'):
            vgg_network.vgg16(tf.placeholder(tf.float32, shape=[None] * 4))
        vgg_network.vgg_assign_from_values_fn(var_name_prefix='vgg/')(sess)

        for keep_axis in (None, 0, 1, (0, 1)):
            for metric, metric_np in zip(metrics, metrics_np):
                m = metric(a, b, keep_axis=keep_axis)
                m_np = metric_np(a, b, keep_axis=keep_axis)
                m_tf = sess.run(m)
                assert np.allclose(m_tf, m_np, atol=1e-7), (
                    '%s disagrees with its numpy version for keep_axis=%r'
                    % (getattr(metric, '__name__', metric), keep_axis))
    print('The test metrics returned the same values.')
예제 #2
0
 def initialize_graph(self):
     """Build the VGG16 feature graph and the pretrained-weight loader.

     Creates a float32 image placeholder (outside any variable scope).
     Feeds it through ``vgg_network.vgg16`` wrapped by ``_with_flat_batch``
     inside the ``'vgg'`` variable scope, and keeps the second element of
     the result as the feature tensor. Finally stores a callable, built by
     ``vgg_network.vgg_assign_from_values_fn``, that targets variables
     prefixed with ``'vgg/'``.
     """
     images = tf.placeholder(dtype=tf.float32)
     self._image_placeholder = images
     # AUTO_REUSE: variables that already exist under 'vgg' are reused
     # instead of raising on duplicate creation.
     with tf.variable_scope('vgg', reuse=tf.AUTO_REUSE):
         # NOTE(review): _with_flat_batch presumably folds extra leading
         # batch dims into one before calling vgg16 -- confirm.
         vgg_fn = _with_flat_batch(vgg_network.vgg16)
         _unused_output, features = vgg_fn(images)
     self._feature_op = features
     self._assign_from_values_fn = vgg_network.vgg_assign_from_values_fn(
         var_name_prefix='vgg/')