def test_robust_max_multiclass_predict_log_density(
    num_classes, num_points, mock_prob, expected_prediction, tol, epsilon
):
    """Check MultiClass.predict_log_density with RobustMax's inner probability pinned.

    ``prob_is_largest`` is overridden to return the constant ``mock_prob`` for
    every one of the ``num_points`` rows. The likelihood's log predictive density
    is then compared against ``expected_prediction``, with ``tol`` passed for
    both tolerance arguments of ``assert_allclose``.
    """

    class _PinnedRobustMax(RobustMax):
        # Ignores Y, the latent moments and the quadrature nodes/weights, and
        # always returns a (num_points, 1) column filled with mock_prob in the
        # default float dtype.
        def prob_is_largest(self, Y, Fmu, Fvar, gh_x, gh_w):
            column = tf.ones((num_points, 1), dtype=default_float())
            return column * mock_prob

    pinned_invlink = _PinnedRobustMax(num_classes, epsilon)
    likelihood = MultiClass(num_classes, invlink=pinned_invlink)

    # All-ones latent mean, reused below as the latent variance; the mocked
    # prob_is_largest never looks at either.
    latent = tf.ones((num_points, num_classes))
    random_state = np.random.RandomState(1)
    labels = to_default_int(random_state.randint(num_classes, size=(num_points, 1)))

    prediction = likelihood.predict_log_density(latent, latent, labels)
    assert_allclose(prediction, expected_prediction, tol, tol)
def test_softmax_bernoulli_equivalence(num, dimF, dimY):
    """Check that a two-class Softmax likelihood matches a sigmoid Bernoulli.

    With latent functions ``[f, 0]``, the softmax probability of class 0 is
    ``sigmoid(f)``. The first Softmax output column should therefore agree with a
    Bernoulli likelihood (``invlink=tf.sigmoid``) applied to ``f``. The check
    covers the conditional mean and variance, the log-probability, the
    predictive mean/variance and the variational expectations.

    Softmax label 0 corresponds to Bernoulli outcome ``Y == 1``, hence
    ``Ylabel = 1 - Y``.

    The last three latent rows are fixed anchors (``f = -3, 3, 0``) and the last
    three targets are forced to ``True``. The anchor rows are hard-coded with two
    columns, so ``np.vstack`` only succeeds when ``dimF == 2``.
    """
    # Local seeded generator so the test inputs are identical on every run.
    # Drawing from the global np.random state made failures against the tight
    # Monte Carlo tolerances below impossible to replay.
    # NOTE(review): Softmax's Monte Carlo sampling presumably uses TensorFlow's
    # RNG, which this seed does not control.
    rng = np.random.RandomState(1)
    anchors = np.array([[-3.0, 0.0], [3, 0.0], [0.0, 0.0]])
    dF = np.vstack((rng.randn(num - 3, dimF), anchors))
    dY = np.vstack((rng.randn(num - 3, dimY), np.ones((3, dimY)))) > 0

    F = to_default_float(dF)
    # Column 0 of Fvar is exp(dF[:, 1]), the variance of f. The constant second
    # latent gets variance exp(-10), so it is effectively deterministic.
    Fvar = tf.exp(
        tf.stack([F[:, 1], -10.0 + tf.zeros(F.shape[0], dtype=F.dtype)], axis=1)
    )
    # Replace the second latent with the constant 0 so softmax([f, 0])[0] == sigmoid(f).
    F = tf.stack([F[:, 0], tf.zeros(F.shape[0], dtype=F.dtype)], axis=1)
    Y = to_default_int(dY)
    Ylabel = 1 - Y

    softmax_likelihood = Softmax(dimF)
    bernoulli_likelihood = Bernoulli(invlink=tf.sigmoid)
    # Minimum number of points to pass the test on CircleCI
    softmax_likelihood.num_monte_carlo_points = int(0.3e7)
    bernoulli_likelihood.num_gauss_hermite_points = 40

    assert_allclose(
        softmax_likelihood.conditional_mean(F)[:, :1],
        bernoulli_likelihood.conditional_mean(F[:, :1]),
    )
    assert_allclose(
        softmax_likelihood.conditional_variance(F)[:, :1],
        bernoulli_likelihood.conditional_variance(F[:, :1]),
    )
    assert_allclose(
        softmax_likelihood.log_prob(F, Ylabel),
        bernoulli_likelihood.log_prob(F[:, :1], Y.numpy()),
    )

    mean1, var1 = softmax_likelihood.predict_mean_and_var(F, Fvar)
    mean2, var2 = bernoulli_likelihood.predict_mean_and_var(F[:, :1], Fvar[:, :1])
    assert_allclose(mean1[:, 0, None], mean2, rtol=2e-3)
    assert_allclose(var1[:, 0, None], var2, rtol=2e-3)

    ls_ve = softmax_likelihood.variational_expectations(F, Fvar, Ylabel)
    lb_ve = bernoulli_likelihood.variational_expectations(
        F[:, :1], Fvar[:, :1], Y.numpy()
    )
    assert_allclose(ls_ve, lb_ve, rtol=5e-3)