예제 #1
0
def test_config_gmean():
    """A named GeometricMean survives a get_config/from_config round trip.

    Both the freshly built metric and the one rebuilt from its config must
    carry the requested name, use float32, be stateful, and own exactly two
    variables.
    """
    metric_name = "my_gmean"

    def _verify(metric):
        # Properties expected on every GeometricMean built with `metric_name`.
        assert metric.name == metric_name
        assert metric.dtype == tf.float32
        assert metric.stateful
        assert len(metric.variables) == 2

    original = GeometricMean(name=metric_name)
    _verify(original)

    restored = GeometricMean.from_config(original.get_config())
    _verify(restored)
예제 #2
0
def test_reset_states():
    """After an update, reset_states() zeroes both accumulator variables."""
    metric = GeometricMean()
    metric.update_state([1, 2, 3, 4, 5])
    metric.reset_states()
    # `total` is checked first, then `count`, matching the original order.
    for accumulator in (metric.total, metric.count):
        assert accumulator.numpy() == 0.0
예제 #3
0
def test_call_gmean(values, expected):
    """Calling the metric directly returns `expected` and counts every input.

    `values` and `expected` are presumably supplied by a pytest parametrize
    decorator that is not visible here.
    """
    metric = GeometricMean()
    inputs = tf.constant(values, tf.float32)
    output = metric(inputs)
    num_seen = metric.count.numpy()
    assert_result(expected, output)
    np.testing.assert_equal(len(values), num_seen)
예제 #4
0
def test_vector_update_state_gmean(values, expected):
    """One update_state call over the whole float32 batch gives `expected`."""
    metric = GeometricMean()
    batch = tf.constant(values, tf.float32)
    metric.update_state(batch)
    # The expected count is the length of the batch tensor that was fed in.
    check_result(metric, expected, len(batch))
예제 #5
0
def test_scalar_update_state_gmean(values, expected):
    """Feeding the values one element at a time gives the same `expected`."""
    metric = GeometricMean()
    tensor = tf.constant(values, tf.float32)
    for element in tensor:
        metric.update_state(element)
    check_result(metric, expected, len(tensor))
예제 #6
0
def test_init_states_gmean():
    """A new GeometricMean starts with float32 accumulators equal to zero."""
    metric = GeometricMean()
    for variable in (metric.total, metric.count):
        assert variable.numpy() == 0.0
        assert variable.dtype == tf.float32
예제 #7
0
def test_sample_weight_gmean(values, sample_weight, expected):
    """A weighted update produces the parametrized `expected` result."""
    metric = GeometricMean()
    metric.update_state(values, sample_weight=sample_weight)
    weighted_gmean = metric.result().numpy()
    assert_result(expected, weighted_gmean)