def test_config_gmean():
    """GeometricMean keeps its configuration across a get_config/from_config round trip."""

    def _assert_expected_config(metric, expected_name):
        # Identical checks are applied to the original and the rebuilt metric.
        assert metric.name == expected_name
        assert metric.dtype == tf.float32
        assert metric.stateful
        # Two tracked variables (presumably ``total`` and ``count``).
        assert len(metric.variables) == 2

    metric_name = "my_gmean"
    original = GeometricMean(name=metric_name)
    _assert_expected_config(original, metric_name)

    rebuilt = GeometricMean.from_config(original.get_config())
    _assert_expected_config(rebuilt, metric_name)
def test_reset_states():
    """reset_states() returns both accumulators to zero after an update.

    The count is also checked *before* the reset. Without that precondition
    the test would pass even if ``update_state`` silently did nothing, so it
    could not catch a broken reset.
    """
    obj = GeometricMean()
    obj.update_state([1, 2, 3, 4, 5])
    # Precondition: five unweighted values were accumulated (count tracks one
    # unit per value, as test_call_gmean also asserts).
    assert obj.count.numpy() == 5.0
    obj.reset_states()
    assert obj.total.numpy() == 0.0
    assert obj.count.numpy() == 0.0
def test_call_gmean(values, expected):
    """Calling the metric directly yields ``expected`` and counts every value.

    ``values`` / ``expected`` are supplied by the caller (presumably a pytest
    parametrization that is not visible here).
    """
    metric = GeometricMean()
    inputs = tf.constant(values, tf.float32)
    output = metric(inputs)
    seen = metric.count.numpy()
    assert_result(expected, output)
    np.testing.assert_equal(len(values), seen)
def test_vector_update_state_gmean(values, expected):
    """A single update_state call on the whole float32 tensor yields ``expected``."""
    metric = GeometricMean()
    batch = tf.constant(values, tf.float32)
    metric.update_state(batch)
    # check_result is a shared helper; it presumably compares the result and
    # the accumulated count against the given expectations.
    check_result(metric, expected, len(batch))
def test_scalar_update_state_gmean(values, expected):
    """Feeding the values one element at a time yields ``expected``."""
    metric = GeometricMean()
    batch = tf.constant(values, tf.float32)
    # Iterate the tensor so each update_state call receives a single element.
    for element in batch:
        metric.update_state(element)
    check_result(metric, expected, len(batch))
def test_init_states_gmean():
    """A freshly constructed GeometricMean starts with zeroed float32 accumulators."""
    metric = GeometricMean()
    accumulators = (metric.total, metric.count)
    # Check values first, then dtypes.
    for variable in accumulators:
        assert variable.numpy() == 0.0
    for variable in accumulators:
        assert variable.dtype == tf.float32
def test_sample_weight_gmean(values, sample_weight, expected):
    """A weighted update_state produces a result matching ``expected``."""
    metric = GeometricMean()
    metric.update_state(values, sample_weight=sample_weight)
    actual = metric.result().numpy()
    assert_result(expected, actual)