def test_iaf(input_dim, hidden_dims):
    """Check inverse, Jacobian structure and log-det of the IAF transform.

    Builds an ``AutoregressiveNN`` with ``param_dims=[1, 1]``, wraps it in an
    ``InverseAutoregressiveTransform`` and verifies that:

    * ``iaf.inv(iaf(x))`` recovers ``x`` on a batch of inputs;
    * the Jacobian of ``iaf`` at a single sample, with rows and columns
      reordered by ``arn.permutation``, has zeros strictly above the diagonal;
    * ``iaf.log_abs_det_jacobian`` matches the sum of the logs of the
      diagonal of that reordered Jacobian.

    :param input_dim: dimensionality of the transformed vector.
    :param hidden_dims: hidden layer sizes passed to ``AutoregressiveNN``.
    """
    arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])
    rng = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn.init_fun(rng, input_shape)

    iaf = InverseAutoregressiveTransform(arn, init_params)

    # test inverse is correct
    x = onp.random.rand(*input_shape)
    y = iaf(x)
    inv = iaf.inv(y)
    assert_allclose(x, inv, atol=1e-5)

    # test jacobian on a single (unbatched) sample
    x = onp.random.rand(*input_shape[-1:])
    jac = jacfwd(iaf)(x)

    # permute jacobian as necessary:
    # permuted_jac[..., j, k] == jac[..., perm[j], perm[k]]
    # (float64 matches the buffer dtype the loop-based version filled in)
    perm = onp.asarray(arn.permutation)
    permuted_jac = onp.asarray(jac, dtype=onp.float64)[..., perm[:, None], perm[None, :]]

    # make sure jacobian is triangular
    assert onp.sum(onp.abs(onp.triu(permuted_jac, 1))) == 0.00

    # make sure iaf.log_abs_det_jacobian is correct; the output must belong to
    # the sample used for the Jacobian, not to the earlier batched input
    y = iaf(x)
    ldj = iaf.log_abs_det_jacobian(x, y)
    assert_allclose(ldj, onp.sum(onp.log(onp.diag(permuted_jac))), atol=1e-5)
def _make_iaf_args(input_dim, hidden_dims):
    """Build the constructor arguments for an inverse autoregressive transform.

    Creates an ``AutoregressiveNN`` with ``param_dims=[1, 1]`` and a random
    input permutation, initialises its parameters for inputs of shape
    ``(input_dim,)`` and binds them to the network callable.

    :param input_dim: dimensionality of the network input.
    :param hidden_dims: hidden layer sizes passed to ``AutoregressiveNN``.
    :return: a 1-tuple containing ``partial(arn, init_params)``.
    """
    _, rng_key_perm = random.split(random.PRNGKey(0))
    # `random.permutation` replaces the deprecated `random.shuffle`
    perm = random.permutation(rng_key_perm, onp.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1],
                                     permutation=perm)
    # NOTE(review): the parent key PRNGKey(0) is used again here after being
    # split above; left as is so the initial parameters do not change.
    _, init_params = arn_init(random.PRNGKey(0), (input_dim, ))
    # trailing comma: callers receive a tuple of positional arguments
    return partial(arn, init_params),
def _make_iaf_args(input_dim, hidden_dims):
    """Build the constructor arguments for an inverse autoregressive transform.

    Creates an ``AutoregressiveNN`` with ``param_dims=[1, 1]``, a random input
    permutation and an Elu nonlinearity, initialises its parameters for inputs
    of shape ``(input_dim,)`` and binds them to the network callable.

    :param input_dim: dimensionality of the network input.
    :param hidden_dims: hidden layer sizes passed to ``AutoregressiveNN``.
    :return: a 1-tuple containing ``partial(arn, init_params)``.
    """
    _, rng_perm = random.split(random.PRNGKey(0))
    # `random.permutation` replaces the deprecated `random.shuffle`
    perm = random.permutation(rng_perm, onp.arange(input_dim))
    # we use the Elu nonlinearity because the default one, Relu, masks out negative
    # hidden values, which in turn creates some zero entries in the lower triangular
    # part of the Jacobian.
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1],
                                     permutation=perm, nonlinearity=stax.Elu)
    # NOTE(review): the parent key PRNGKey(0) is used again here after being
    # split above; left as is so the initial parameters do not change.
    _, init_params = arn_init(random.PRNGKey(0), (input_dim, ))
    # trailing comma: callers receive a tuple of positional arguments
    return partial(arn, init_params),
def test_auto_reg_nn(input_dim, hidden_dims, param_dims, skip_connections):
    """Check output shapes and the autoregressive structure of AutoregressiveNN.

    Builds the network with a random input permutation, checks the shape of
    its output on a batch for each supported ``param_dims`` layout, then takes
    the Jacobian of the (first) output at a single sample and verifies that,
    once rows and columns are reordered by the permutation, every entry on or
    above the diagonal is exactly zero.

    :param input_dim: dimensionality of the network input.
    :param hidden_dims: hidden layer sizes passed to ``AutoregressiveNN``.
    :param param_dims: one of ``[1]``, ``[1, 1]``, ``[2]`` or ``[2, 3]``.
    :param skip_connections: forwarded to ``AutoregressiveNN``.
    :raises ValueError: if ``param_dims`` is not one of the layouts above.
    """
    rng_key, rng_key_perm = random.split(random.PRNGKey(0))
    perm = random.permutation(rng_key_perm, onp.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=param_dims,
                                     skip_connections=skip_connections, permutation=perm)

    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn_init(rng_key, input_shape)

    output = arn(init_params, onp.random.rand(*input_shape))

    if param_dims == [1]:
        assert output.shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [1, 1]:
        assert output[0].shape == (batch_size, input_dim)
        assert output[1].shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(onp.random.rand(input_dim))
    elif param_dims == [2]:
        assert output.shape == (2, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [2, 3]:
        assert output[0].shape == (2, batch_size, input_dim)
        assert output[1].shape == (3, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(onp.random.rand(input_dim))
    else:
        # fail with a clear message instead of an unbound `jac` further down
        raise ValueError("unsupported param_dims for this test: {}".format(param_dims))

    # permute jacobian as necessary:
    # permuted_jac[..., j, k] == jac[..., perm[j], perm[k]]
    # (float64 matches the buffer dtype the loop-based version filled in)
    perm_idx = onp.asarray(perm)
    permuted_jac = onp.asarray(jac, dtype=onp.float64)[..., perm_idx[:, None], perm_idx[None, :]]

    # make sure jacobians are triangular
    assert onp.sum(onp.abs(onp.triu(permuted_jac))) == 0.0
def _make_iaf(input_dim, hidden_dims, rng):
    """Create an InverseAutoregressiveTransform backed by a fresh network.

    An ``AutoregressiveNN`` with ``param_dims=[1, 1]`` is initialised with
    ``rng`` for inputs of shape ``(input_dim,)``; the resulting parameters are
    bound to the network callable, which is then wrapped in the transform.

    :param input_dim: dimensionality of the network input.
    :param hidden_dims: hidden layer sizes passed to ``AutoregressiveNN``.
    :param rng: random key handed to the network's init function.
    :return: the constructed ``InverseAutoregressiveTransform``.
    """
    init_fn, apply_fn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])
    sample_shape = (input_dim,)
    _, params = init_fn(rng, sample_shape)
    bound_network = partial(apply_fn, params)
    return InverseAutoregressiveTransform(bound_network)
def _make_iaf_args(input_dim, hidden_dims):
    """Return the network callable and its initial parameters.

    An ``AutoregressiveNN`` with ``param_dims=[1, 1]`` is initialised with the
    fixed key ``random.PRNGKey(0)`` for inputs of shape ``(input_dim,)``.

    :param input_dim: dimensionality of the network input.
    :param hidden_dims: hidden layer sizes passed to ``AutoregressiveNN``.
    :return: the pair ``(arn, init_params)``.
    """
    init_fn, apply_fn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])
    seed_key = random.PRNGKey(0)
    sample_shape = (input_dim,)
    _, params = init_fn(seed_key, sample_shape)
    return apply_fn, params