def test_invalid_properties(self, bij_params):
    """Checks that `SplitCoupling` raises ValueError for invalid arguments.

    Args:
      bij_params: keyword arguments for `split_coupling.SplitCoupling` that are
        expected to be rejected (presumably supplied by a parameterized-test
        decorator -- confirm). A trivial `conditioner` and `bijector` are
        added on top. The caller's dict is NOT modified.
    """
    # Merge into a new dict rather than calling `bij_params.update(...)`: the
    # dict passed in may be shared across parameterized test runs, and
    # mutating it would leak `conditioner`/`bijector` entries into them.
    params = {
        **bij_params,
        'conditioner': lambda x: x,
        'bijector': lambda _: lambda x: x,
    }
    with self.assertRaises(ValueError):
        split_coupling.SplitCoupling(**params)
def test_raises_on_bijector_with_different_event_ndims(self):
    """`forward_and_log_det` must raise ValueError for a mismatched inner bijector."""

    def make_inner_bijector(unused_params):
        # NOTE(review): DummyBijector(1, 0, False, False) presumably has
        # differing input/output event ndims (1 vs 0); confirm against
        # DummyBijector's signature.
        return DummyBijector(1, 0, False, False)

    coupling = split_coupling.SplitCoupling(
        split_index=0,
        event_ndims=1,
        conditioner=lambda x: x,
        bijector=make_inner_bijector,
    )
    with self.assertRaises(ValueError):
        coupling.forward_and_log_det(jnp.zeros((4, 3)))
def _create_split_coupling_bijector(split_index,
                                    split_axis=-1,
                                    swap=False,
                                    event_ndims=2):
    """Builds a `SplitCoupling` with a fixed toy conditioner and inner bijector.

    The conditioner squares its input. The inner bijector factory ignores
    the conditioner output and returns the map x -> 2x + 3.

    Args:
      split_index: forwarded to `SplitCoupling` as `split_index`.
      split_axis: forwarded to `SplitCoupling` as `split_axis`.
      swap: forwarded to `SplitCoupling` as `swap`.
      event_ndims: forwarded to `SplitCoupling` as `event_ndims`.

    Returns:
      The constructed `split_coupling.SplitCoupling` instance.
    """

    def square(x):
        return x**2

    def make_affine(unused_params):
        # A fresh function per call, matching a nested lambda.
        def affine(x):
            return 2. * x + 3.
        return affine

    return split_coupling.SplitCoupling(
        split_index=split_index,
        split_axis=split_axis,
        event_ndims=event_ndims,
        swap=swap,
        conditioner=square,
        bijector=make_affine,
    )
def test_raises_on_invalid_input_shape(self):
    """Every transform method must raise ValueError on a too-low-rank input."""
    event_shape = (2, 3)
    coupling = split_coupling.SplitCoupling(
        split_index=event_shape[-1] // 2,
        event_ndims=len(event_shape),
        conditioner=lambda x: x,
        bijector=lambda _: lambda x: x,
    )
    # The coupling is built with event_ndims=2, but each call below receives
    # a rank-1 array of shape (3,).
    method_names = (
        'forward',
        'inverse',
        'forward_log_det_jacobian',
        'inverse_log_det_jacobian',
        'forward_and_log_det',
        'inverse_and_log_det',
    )
    for method_name in method_names:
        method = getattr(coupling, method_name)
        with self.assertRaises(ValueError):
            method(jnp.zeros((3,)))