def test_invalid_value_shape(self):
    """A MeanTensor rejects values whose shape differs from the first call."""
    metric = metrics.MeanTensor(dtype=tf.float64)
    # The first update fixes the shape every later input must match.
    metric([1])
    expected_msg = 'MeanTensor input values must always have the same shape'
    with self.assertRaisesRegex(ValueError, expected_msg):
        metric([1, 5])
Beispiel #2
0
 def __init__(self):
     """Create two dense layers and a MeanTensor metric used by this object."""
     super().__init__()
     # 3-unit ReLU layer; the all-ones kernel makes its weights deterministic.
     self.dense1 = layers.Dense(
         3, activation="relu", kernel_initializer="ones")
     # 1-unit sigmoid layer, also with an all-ones kernel.
     self.dense2 = layers.Dense(
         1, activation="sigmoid", kernel_initializer="ones")
     # Element-wise running-mean metric (default dtype).
     self.mean_tensor = metrics.MeanTensor()
 def __init__(self):
     """Create two dense layers and a MeanTensor metric used by this object."""
     # Zero-argument super() is the Python 3 idiom: it does not hard-code the
     # enclosing class name, so it keeps working if the class is renamed.
     super().__init__()
     # 3-unit ReLU layer; the all-ones kernel makes its weights deterministic.
     self.dense1 = layers.Dense(3,
                                activation='relu',
                                kernel_initializer='ones')
     # 1-unit sigmoid layer, also with an all-ones kernel.
     self.dense2 = layers.Dense(1,
                                activation='sigmoid',
                                kernel_initializer='ones')
     # Element-wise running-mean metric (default dtype).
     self.mean_tensor = metrics.MeanTensor()
Beispiel #4
0
    def test_build_in_tf_function(self):
        """MeanTensor must create its variables correctly inside tf.function."""
        metric = metrics.MeanTensor(dtype=tf.float64)

        @tf.function
        def run_metric(values):
            return metric(values)

        with self.test_session():
            # First call builds the metric; the mean equals the input.
            first = self.evaluate(run_metric([100, 40]))
            self.assertAllClose(first, [100, 40])
            self.assertAllClose(self.evaluate(metric.total), [100, 40])
            self.assertAllClose(self.evaluate(metric.count), [1, 1])
            # Second call accumulates: ([100, 40] + [20, 2]) / 2.
            second = self.evaluate(run_metric([20, 2]))
            self.assertAllClose(second, [60, 21])
Beispiel #5
0
    def test_weighted(self):
        """Weighted MeanTensor updates with several ``sample_weight`` layouts.

        Each section checks, element-wise, the running ``total`` (sum of
        weight * value), the running ``count`` (sum of weights) and the
        result (``total / count``) for scalar, same-rank, broadcast, squeezed
        and expanded weights.
        """
        with self.test_session():
            m = metrics.MeanTensor(dtype=tf.float64)
            self.assertEqual(m.dtype, tf.float64)

            # check scalar weight
            result_t = m([100, 30], sample_weight=0.5)
            self.assertAllClose(self.evaluate(result_t), [100, 30])
            self.assertAllClose(self.evaluate(m.total), [50, 15])
            self.assertAllClose(self.evaluate(m.count), [0.5, 0.5])

            # check weights not scalar and weights rank matches values rank
            result_t = m([1, 5], sample_weight=[1, 0.2])
            result = self.evaluate(result_t)
            # Default tolerances on purpose: assertAllClose's third positional
            # argument is ``rtol`` (not a decimal-places count), so passing 2
            # would accept a 200% relative error.
            self.assertAllClose(result, [51 / 1.5, 16 / 0.7])
            self.assertAllClose(self.evaluate(m.total), [51, 16])
            self.assertAllClose(self.evaluate(m.count), [1.5, 0.7])

            # check weights broadcast
            result_t = m([1, 2], sample_weight=0.5)
            self.assertAllClose(self.evaluate(result_t), [51.5 / 2, 17 / 1.2])
            self.assertAllClose(self.evaluate(m.total), [51.5, 17])
            self.assertAllClose(self.evaluate(m.count), [2, 1.2])

            # check weights squeeze
            result_t = m([1, 5], sample_weight=[[1], [0.2]])
            self.assertAllClose(self.evaluate(result_t), [52.5 / 3, 18 / 1.4])
            self.assertAllClose(self.evaluate(m.total), [52.5, 18])
            self.assertAllClose(self.evaluate(m.count), [3, 1.4])

            # check weights expand
            m = metrics.MeanTensor(dtype=tf.float64)
            # NOTE(review): a freshly constructed MeanTensor appears to have no
            # variables until its first call, so this initializer likely covers
            # an empty list and is a no-op; kept to preserve the original test.
            self.evaluate(tf.compat.v1.variables_initializer(m.variables))
            result_t = m([[1], [5]], sample_weight=[1, 0.2])
            self.assertAllClose(self.evaluate(result_t), [[1], [5]])
            self.assertAllClose(self.evaluate(m.total), [[1], [1]])
            self.assertAllClose(self.evaluate(m.count), [[1], [0.2]])
Beispiel #6
0
    def test_config(self):
        """MeanTensor builds lazily and its config round-trips to a fresh metric."""

        def check_unbuilt(metric_obj):
            # A new (or config-restored) metric keeps its name, is stateful,
            # defaults to float32 and owns no variables yet.
            self.assertEqual(metric_obj.name, "mean_by_element")
            self.assertTrue(metric_obj.stateful)
            self.assertEqual(metric_obj.dtype, tf.float32)
            self.assertEmpty(metric_obj.variables)

        with self.test_session():
            metric = metrics.MeanTensor(name="mean_by_element")
            check_unbuilt(metric)

            # Requesting a result before any update is an error.
            with self.assertRaisesRegex(
                ValueError, "does not have any value yet"
            ):
                metric.result()

            # The first update builds the metric with the input's shape.
            self.evaluate(metric([[3], [5], [3]]))
            self.assertAllEqual(metric._shape, [3, 1])

            restored = metrics.MeanTensor.from_config(metric.get_config())
            check_unbuilt(restored)
    def test_unweighted(self):
        """Unweighted MeanTensor accumulates, reports and resets element-wise."""
        with self.test_session():
            metric = metrics.MeanTensor(dtype=tf.float64)

            # __call__ updates the state and returns the running mean.
            first_mean = metric([100, 40])
            self.assertAllClose(self.evaluate(first_mean), [100, 40])
            self.assertAllClose(self.evaluate(metric.total), [100, 40])
            self.assertAllClose(self.evaluate(metric.count), [1, 1])

            # update_state() with tensor inputs, then result(): the state
            # keeps accumulating across calls.
            tensor_values = [tf.convert_to_tensor(v) for v in (1, 5)]
            self.evaluate(metric.update_state(tensor_values))
            self.assertAllClose(self.evaluate(metric.result()), [50.5, 22.5])
            self.assertAllClose(self.evaluate(metric.total), [101, 45])
            self.assertAllClose(self.evaluate(metric.count), [2, 2])

            # reset_state() zeroes both accumulators.
            metric.reset_state()
            self.assertAllClose(self.evaluate(metric.total), [0, 0])
            self.assertAllClose(self.evaluate(metric.count), [0, 0])