def test_hidden_layer_rsample(non_linearity, include_hidden_bias, B=2, D=3, H=4, N=900000):
    """Compare the two HiddenLayer sampling modes via empirical sample moments.

    Two ``HiddenLayer`` distributions are built from identical random inputs,
    differing only in ``weight_space_sampling`` (``True`` vs ``False``).
    ``N`` reparameterized samples are drawn from each, and the per-element
    mean and variance taken over the sample dimension (dim 0) must agree
    to within ``prec=0.003``.

    Args:
        non_linearity: forwarded unchanged to ``HiddenLayer``.
        include_hidden_bias: forwarded unchanged to ``HiddenLayer``.
        B: number of rows of the random input ``X`` (shape ``(B, D)``).
        D: number of rows of the weight mean/scale (shape ``(D, H)``).
        H: number of columns of the weight mean/scale.
        N: number of samples drawn from each distribution.
    """
    inputs = torch.randn(B, D)
    weight_mean = torch.rand(D, H)
    weight_scale = 0.3 * torch.exp(0.3 * torch.rand(D, H))

    # Everything except the sampling strategy is shared by both distributions.
    shared_kwargs = dict(
        X=inputs,
        A_mean=weight_mean,
        A_scale=weight_scale,
        non_linearity=non_linearity,
        include_hidden_bias=include_hidden_bias,
    )

    # Naive weight space sampling is the reference; sampling in
    # pre-activation space is checked against it.
    weight_space_dist = HiddenLayer(weight_space_sampling=True, **shared_kwargs)
    pre_activation_dist = HiddenLayer(weight_space_sampling=False, **shared_kwargs)

    weight_space_draws = weight_space_dist.rsample(sample_shape=(N,))
    weight_space_mean = weight_space_draws.mean(0)
    weight_space_var = weight_space_draws.var(0)

    pre_activation_draws = pre_activation_dist.rsample(sample_shape=(N,))
    pre_activation_mean = pre_activation_draws.mean(0)
    pre_activation_var = pre_activation_draws.var(0)

    assert_equal(weight_space_mean, pre_activation_mean, prec=0.003)
    assert_equal(weight_space_var, pre_activation_var, prec=0.003)
def test_hidden_layer_log_prob(non_linearity, include_hidden_bias, B=2, D=3, H=2):
    """Check ``HiddenLayer.KL`` against a directly computed KL divergence.

    The reference value is the KL divergence from ``Normal(A_mean, A_scale)``
    to a zero-mean, unit-scale ``Normal`` of the same ``(D, H)`` shape,
    summed over all elements. The layer's ``KL`` attribute must match it to
    within ``prec=0.01``.

    Args:
        non_linearity: forwarded unchanged to ``HiddenLayer``.
        include_hidden_bias: forwarded unchanged to ``HiddenLayer``.
        B: number of rows of the random input ``X`` (shape ``(B, D)``).
        D: number of rows of the weight mean/scale (shape ``(D, H)``).
        H: number of columns of the weight mean/scale.
    """
    inputs = torch.randn(B, D)
    weight_mean = torch.rand(D, H)
    weight_scale = 0.3 * torch.exp(0.3 * torch.rand(D, H))

    layer = HiddenLayer(
        X=inputs,
        A_mean=weight_mean,
        A_scale=weight_scale,
        non_linearity=non_linearity,
        include_hidden_bias=include_hidden_bias,
    )

    # Reference: element-wise KL between the weight distribution and a
    # standard normal, reduced to a scalar by summation.
    weight_dist = Normal(weight_mean, weight_scale)
    standard_normal = Normal(torch.zeros(D, H), torch.ones(D, H))
    elementwise_kl = torch.distributions.kl.kl_divergence(weight_dist, standard_normal)
    expected_kl = elementwise_kl.sum()

    assert_equal(expected_kl, layer.KL, prec=0.01)