Ejemplo n.º 1
0
def test_robust_max_multiclass_symmetric(num_classes, num_points, tol, epsilon):
    """
    Check the robust-max MultiClass likelihood on fully symmetric inputs.

    Every latent function is given the same mean and variance (all ones), so
    each class must be predicted with the same probability ``1 / num_classes``.
    The predictive mean, predictive log density and variational expectations
    are compared against closed-form values built from that probability and
    ``epsilon``.
    """
    uniform_prob = 1.0 / num_classes

    # Closed-form reference values for the symmetric case.
    mean_value = (
        uniform_prob * (1.0 - epsilon) + (1.0 - uniform_prob) * epsilon / (num_classes - 1)
    )
    expected_mu = np.full((num_points, 1), mean_value)
    expected_log_density = np.log(expected_mu)
    var_exp_value = uniform_prob * np.log(1.0 - epsilon) + (1.0 - uniform_prob) * np.log(
        epsilon / (num_classes - 1)
    )
    expected_var_exp = np.full((num_points,), var_exp_value)

    # Identical latent values for every class; labels are drawn with a fixed seed.
    latent = tf.ones((num_points, num_classes), dtype=default_float())
    random_state = np.random.RandomState(1)
    labels = tf.convert_to_tensor(
        random_state.randint(num_classes, size=(num_points, 1)), dtype=default_float()
    )

    likelihood = MultiClass(num_classes)
    likelihood.invlink.epsilon = tf.convert_to_tensor(epsilon, dtype=default_float())

    # The same tensor is passed as both the latent mean and the latent variance.
    mu, _ = likelihood.predict_mean_and_var(latent, latent)
    log_density = likelihood.predict_log_density(latent, latent, labels)
    var_exp = likelihood.variational_expectations(latent, latent, labels)

    # np.allclose broadcasts; assert_allclose() would complain about shape mismatch
    assert np.allclose(mu, expected_mu, rtol=tol, atol=tol)
    assert np.allclose(log_density, expected_log_density, rtol=1e-3, atol=1e-3)

    assert_allclose(var_exp, expected_var_exp, rtol=tol, atol=tol)
Ejemplo n.º 2
0
def test_robust_max_multiclass_predict_log_density(
    num_classes, num_points, mock_prob, expected_prediction, tol, epsilon
):
    """
    Check ``MultiClass.predict_log_density`` with a stubbed robust-max link.

    ``prob_is_largest`` is replaced by a stub that returns ``mock_prob`` for
    every point and ignores its inputs, so the latent values passed in do not
    influence the stubbed probability. The resulting log density is compared
    against ``expected_prediction`` with ``tol`` used as both the relative and
    the absolute tolerance.
    """

    class MockRobustMax(RobustMax):
        def prob_is_largest(self, Y, Fmu, Fvar, gh_x, gh_w):
            # Constant probability for every data point, independent of the inputs.
            return tf.ones((num_points, 1), dtype=default_float()) * mock_prob

    likelihood = MultiClass(num_classes, invlink=MockRobustMax(num_classes, epsilon))

    # Use the default float dtype explicitly: tf.ones() would otherwise produce
    # float32, which disagrees with the dtype of the stub's return value above.
    F = tf.ones((num_points, num_classes), dtype=default_float())
    rng = np.random.RandomState(1)
    Y = to_default_int(rng.randint(num_classes, size=(num_points, 1)))

    # The same tensor is passed as both the latent mean and the latent variance.
    prediction = likelihood.predict_log_density(F, F, Y)

    assert_allclose(prediction, expected_prediction, rtol=tol, atol=tol)