def test_invalid_value_shape(self):
    """A MeanTensor must reject a value whose shape differs from the first."""
    mean_metric = metrics.MeanTensor(dtype=tf.float64)
    # The first update fixes the expected value shape to (1,).
    mean_metric([1])
    expected_message = 'MeanTensor input values must always have the same shape'
    with self.assertRaisesRegex(ValueError, expected_message):
        # Shape (2,) no longer matches the recorded (1,) shape.
        mean_metric([1, 5])
def __init__(self):
    """Create two ones-initialized dense layers and a MeanTensor metric."""
    super().__init__()
    # Deterministic "ones" kernels keep the model's outputs reproducible.
    init = "ones"
    self.dense1 = layers.Dense(3, activation="relu", kernel_initializer=init)
    self.dense2 = layers.Dense(
        1,
        activation="sigmoid",
        kernel_initializer=init,
    )
    self.mean_tensor = metrics.MeanTensor()
def __init__(self):
    """Set up ModelWithMetric: two dense layers plus a MeanTensor metric."""
    super(ModelWithMetric, self).__init__()
    # Both layers use "ones" kernels so results are deterministic.
    kernel_init = 'ones'
    self.dense1 = layers.Dense(
        3, activation='relu', kernel_initializer=kernel_init)
    self.dense2 = layers.Dense(
        1, activation='sigmoid', kernel_initializer=kernel_init)
    self.mean_tensor = metrics.MeanTensor()
def test_build_in_tf_function(self):
    """Ensure that variables are created correctly in a tf function."""
    metric = metrics.MeanTensor(dtype=tf.float64)

    @tf.function
    def call_metric(values):
        return metric(values)

    with self.test_session():
        # First call builds total/count inside the traced function.
        first_result = self.evaluate(call_metric([100, 40]))
        self.assertAllClose(first_result, [100, 40])
        self.assertAllClose(self.evaluate(metric.total), [100, 40])
        self.assertAllClose(self.evaluate(metric.count), [1, 1])
        # Second call accumulates: ([100, 40] + [20, 2]) / 2.
        second_result = self.evaluate(call_metric([20, 2]))
        self.assertAllClose(second_result, [60, 21])
def test_weighted(self):
    """Check MeanTensor's weighted accumulation, broadcasting and reshaping.

    As the expected values below show, each update adds ``value * weight``
    to ``total`` and ``weight`` to ``count`` element-wise, and the result
    is ``total / count``.
    """
    with self.test_session():
        m = metrics.MeanTensor(dtype=tf.float64)
        self.assertEqual(m.dtype, tf.float64)

        # check scalar weight
        result_t = m([100, 30], sample_weight=0.5)
        self.assertAllClose(self.evaluate(result_t), [100, 30])
        self.assertAllClose(self.evaluate(m.total), [50, 15])
        self.assertAllClose(self.evaluate(m.count), [0.5, 0.5])

        # check weights not scalar and weights rank matches values rank
        result_t = m([1, 5], sample_weight=[1, 0.2])
        result = self.evaluate(result_t)
        # Use the default tolerances: a positional third argument to
        # assertAllClose is `rtol`, and a large value would make this vacuous.
        self.assertAllClose(result, [51 / 1.5, 16 / 0.7])
        self.assertAllClose(self.evaluate(m.total), [51, 16])
        self.assertAllClose(self.evaluate(m.count), [1.5, 0.7])

        # check weights broadcast
        result_t = m([1, 2], sample_weight=0.5)
        self.assertAllClose(self.evaluate(result_t), [51.5 / 2, 17 / 1.2])
        self.assertAllClose(self.evaluate(m.total), [51.5, 17])
        self.assertAllClose(self.evaluate(m.count), [2, 1.2])

        # check weights squeeze
        result_t = m([1, 5], sample_weight=[[1], [0.2]])
        self.assertAllClose(self.evaluate(result_t), [52.5 / 3, 18 / 1.4])
        self.assertAllClose(self.evaluate(m.total), [52.5, 18])
        self.assertAllClose(self.evaluate(m.count), [3, 1.4])

        # check weights expand
        m = metrics.MeanTensor(dtype=tf.float64)
        self.evaluate(tf.compat.v1.variables_initializer(m.variables))
        result_t = m([[1], [5]], sample_weight=[1, 0.2])
        self.assertAllClose(self.evaluate(result_t), [[1], [5]])
        self.assertAllClose(self.evaluate(m.total), [[1], [1]])
        self.assertAllClose(self.evaluate(m.count), [[1], [0.2]])
def test_config(self):
    """Verify MeanTensor attributes and a get_config/from_config round trip."""
    with self.test_session():
        metric = metrics.MeanTensor(name="mean_by_element")

        # check config
        self.assertEqual(metric.name, "mean_by_element")
        self.assertTrue(metric.stateful)
        self.assertEqual(metric.dtype, tf.float32)
        # No variables exist until the first value is seen.
        self.assertEmpty(metric.variables)

        with self.assertRaisesRegex(ValueError, "does not have any value yet"):
            metric.result()

        # The first update records the value shape.
        self.evaluate(metric([[3], [5], [3]]))
        self.assertAllEqual(metric._shape, [3, 1])

        # A metric rebuilt from config starts out fresh, without variables.
        restored = metrics.MeanTensor.from_config(metric.get_config())
        self.assertEqual(restored.name, "mean_by_element")
        self.assertTrue(restored.stateful)
        self.assertEqual(restored.dtype, tf.float32)
        self.assertEmpty(restored.variables)
def test_unweighted(self):
    """Check unweighted accumulation, result() and reset_state()."""
    with self.test_session():
        metric = metrics.MeanTensor(dtype=tf.float64)

        # check __call__()
        self.assertAllClose(self.evaluate(metric([100, 40])), [100, 40])
        self.assertAllClose(self.evaluate(metric.total), [100, 40])
        self.assertAllClose(self.evaluate(metric.count), [1, 1])

        # check update_state() and result() + state accumulation + tensor input
        tensor_values = [tf.convert_to_tensor(1), tf.convert_to_tensor(5)]
        update_op = metric.update_state(tensor_values)
        self.evaluate(update_op)
        self.assertAllClose(self.evaluate(metric.result()), [50.5, 22.5])
        self.assertAllClose(self.evaluate(metric.total), [101, 45])
        self.assertAllClose(self.evaluate(metric.count), [2, 2])

        # check reset_state()
        metric.reset_state()
        self.assertAllClose(self.evaluate(metric.total), [0, 0])
        self.assertAllClose(self.evaluate(metric.count), [0, 0])