Пример #1
0
 def test_invalid_properties(self, bij_params):
     """Checks that `SplitCoupling` rejects the supplied constructor arguments.

     Args:
       bij_params: Mapping of keyword arguments for
         `split_coupling.SplitCoupling`. Presumably provided by a
         parameterized-test decorator that is not visible here -- confirm
         against the enclosing class.

     The test passes when constructing the bijector from `bij_params` plus an
     identity conditioner and an identity inner bijector raises `ValueError`.
     """
     # Build a fresh dict instead of calling `bij_params.update(...)`: the
     # caller's dict may be a shared fixture, and mutating it in place would
     # leak these lambdas into any later use of the same object.
     ctor_kwargs = {
         **bij_params,
         'conditioner': lambda x: x,
         'bijector': lambda _: lambda x: x,
     }
     with self.assertRaises(ValueError):
         split_coupling.SplitCoupling(**ctor_kwargs)
Пример #2
0
 def test_raises_on_bijector_with_different_event_ndims(self):
     """Expects `forward_and_log_det` to raise `ValueError` for this setup.

     The coupling layer is created with `event_ndims=1`, while its inner
     bijector factory always returns `DummyBijector(1, 0, False, False)`.
     NOTE(review): presumably the dummy's positional arguments give it an
     event-ndims value that differs from the layer's -- confirm against
     `DummyBijector`.
     """

     def make_inner_bijector(_):
         # The conditioner output is ignored; the same dummy is always built.
         return DummyBijector(1, 0, False, False)

     coupling = split_coupling.SplitCoupling(
         split_index=0,
         event_ndims=1,
         conditioner=lambda x: x,
         bijector=make_inner_bijector,
     )
     # Construction happens outside the context manager, so only the forward
     # call below can satisfy the expected error.
     with self.assertRaises(ValueError):
         coupling.forward_and_log_det(jnp.zeros((4, 3)))
Пример #3
0
def _create_split_coupling_bijector(split_index,
                                    split_axis=-1,
                                    swap=False,
                                    event_ndims=2):
    """Builds a `SplitCoupling` with a fixed conditioner and inner bijector.

    Args:
      split_index: Forwarded unchanged as `split_index`.
      split_axis: Forwarded unchanged as `split_axis`; defaults to -1.
      swap: Forwarded unchanged as `swap`; defaults to False.
      event_ndims: Forwarded unchanged as `event_ndims`; defaults to 2.

    Returns:
      The `split_coupling.SplitCoupling` instance built from these arguments,
      using `x**2` as the conditioner and a factory that always produces the
      function `2 * x + 3` as the inner bijector.
    """

    def square(x):
        return x**2

    def affine_factory(_):
        # The conditioner output is ignored: every call yields `2 * x + 3`.
        return lambda x: 2. * x + 3.

    # Keyword order mirrors the constructor call this helper wraps.
    coupling_kwargs = dict(
        split_index=split_index,
        split_axis=split_axis,
        event_ndims=event_ndims,
        swap=swap,
        conditioner=square,
        bijector=affine_factory,
    )
    return split_coupling.SplitCoupling(**coupling_kwargs)
Пример #4
0
 def test_raises_on_invalid_input_shape(self):
     """Expects each listed bijector method to raise on a shape-(3,) input.

     The bijector is built with `event_ndims=2` (taken from the event shape
     `(2, 3)`) and `split_index=1`; every method below is then called with a
     zero array of shape `(3,)` and must raise `ValueError`.
     """
     event_shape = (2, 3)
     coupling = split_coupling.SplitCoupling(
         split_index=event_shape[-1] // 2,
         event_ndims=len(event_shape),
         conditioner=lambda x: x,
         bijector=lambda _: lambda x: x,
     )
     methods_under_test = (
         coupling.forward,
         coupling.inverse,
         coupling.forward_log_det_jacobian,
         coupling.inverse_log_det_jacobian,
         coupling.forward_and_log_det,
         coupling.inverse_and_log_det,
     )
     for method in methods_under_test:
         # Each method gets its own context so every one must raise.
         with self.assertRaises(ValueError):
             method(jnp.zeros((3, )))