Example #1
    def _transform(self, input: Tensor, input_log_det: Optional[Tensor],
                   inverse: bool,
                   compute_log_det: bool) -> Tuple[Tensor, Optional[Tensor]]:
        if inverse:
            # inverse direction: x = (y - 1) / 2
            output = (input - 1.0) * 0.5
            event_ndims = self.x_event_ndims
        else:
            # forward direction: y = 2 * x + 1
            output = input * 2.0 + 1.0
            event_ndims = self.y_event_ndims

        if compute_log_det:
            if inverse:
                output_log_det = float_scalar_like(-math.log(2.), output)
            else:
                output_log_det = float_scalar_like(math.log(2.), output)

            # the per-element log-det is constant, so summing it over the
            # event dimensions equals multiplying by the event size
            for axis in int_range(-event_ndims, 0):
                output_log_det = output_log_det * output.shape[axis]

            if input_log_det is not None:
                output_log_det = output_log_det + input_log_det
        else:
            output_log_det: Optional[Tensor] = None

        return output, output_log_det
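
The `_transform` above belongs to a flow class from the surrounding library (hence `float_scalar_like`, `int_range`, and the `x_event_ndims`/`y_event_ndims` attributes). As a self-contained sanity check of the same arithmetic, here is a minimal NumPy sketch (all names hypothetical) showing that the forward map y = 2x + 1 and its inverse round-trip, and that their log-determinants, summed over the event axis, cancel:

import math
import numpy as np

def affine_transform(x, inverse=False):
    # hypothetical stand-in for the flow above: y = 2 * x + 1
    if inverse:
        output = (x - 1.0) * 0.5
        log_det = -math.log(2.0)  # per-element log |dx/dy|
    else:
        output = x * 2.0 + 1.0
        log_det = math.log(2.0)   # per-element log |dy/dx|
    # a constant per-element log-det, summed over one event axis,
    # is just the constant times the size of that axis
    return output, log_det * output.shape[-1]

x = np.random.randn(4, 3)
y, fwd_log_det = affine_transform(x)
x2, inv_log_det = affine_transform(y, inverse=True)
assert np.allclose(x, x2)
assert np.isclose(fwd_log_det + inv_log_det, 0.0)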
Example #2
def log_prob_fn(t):
    log_px = distribution.log_prob(t.transform_origin.tensor,
                                   group_ndims=0)
    y, log_det = flow(t.transform_origin.tensor)  # y and log |dy/dx|
    assert_allclose(y, t.tensor, atol=1e-4, rtol=1e-6)
    ctx.assertEqual(
        T.rank(log_det),
        T.rank(log_px) -
        (flow.get_x_event_ndims() - distribution.event_ndims))
    # change of variables: log p_y(y) = log p_x(x) - log |dy/dx|, with the
    # base log-density summed over the extra event dimensions of the flow
    return -log_det + T.reduce_sum(
        log_px,
        T.int_range(
            -(flow.get_x_event_ndims() - distribution.event_ndims), 0))
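
`log_prob_fn` encodes the change-of-variables identity log p_Y(y) = log p_X(x) - log |dy/dx|, with the extra `reduce_sum` folding the base density down to the flow's event rank. A self-contained numeric sketch of the same identity, assuming nothing beyond NumPy (a unit Gaussian base pushed through the affine flow y = 2x + 1, so Y ~ N(1, 2^2)):

import math
import numpy as np

def std_normal_log_prob(x):
    # elementwise log density of N(0, 1)
    return -0.5 * (x ** 2) - 0.5 * math.log(2.0 * math.pi)

x = np.random.randn(5)
y = x * 2.0 + 1.0        # the same affine flow as in Example #1
log_det = math.log(2.0)  # log |dy/dx| per element

# change of variables: log p_Y(y) = log p_X(x) - log |dy/dx|
log_py = std_normal_log_prob(x) - log_det

# cross-check against the closed form of N(1, 2^2)
expected = (-0.5 * ((y - 1.0) / 2.0) ** 2
            - math.log(2.0) - 0.5 * math.log(2.0 * math.pi))
assert np.allclose(log_py, expected)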
Example #3
def check_x(layer):
    y = layer(x)
    # reduce over every axis except the last (feature) axis
    y_mean, y_var = T.calculate_mean_and_var(
        y, axis=T.int_range(-T.rank(x), -1))
    if use_bias:
        assert_allclose(y_mean, T.zeros_like(y_mean), atol=1e-6, rtol=1e-4)
    assert_allclose(y_var, T.ones_like(y_var), atol=1e-6, rtol=1e-4)
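
`check_x` asserts that a freshly built layer maps the given `x` to an output with zero mean (when a bias is present) and unit variance over all axes except the last, which is the defining property of data-dependent initialization. A self-contained NumPy sketch of that property, with a hypothetical stand-in for the layer:

import numpy as np

# hypothetical data-dependent initialization: rescale and shift a linear
# layer's output so that each feature has mean 0 and variance 1 over the
# batch axes -- the property the check above verifies
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 5))
w = rng.normal(size=(5, 3))
b = rng.normal(size=(3,))

h = x @ w + b
mean = h.mean(axis=0)  # reduce over every axis but the last
std = h.std(axis=0)
y = (h - mean) / std   # fold 1/std into the weights, -mean/std into the bias

assert np.allclose(y.mean(axis=0), 0.0, atol=1e-6)
assert np.allclose(y.var(axis=0), 1.0, atol=1e-4)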
Example #4
def check_scale(ctx,
                scale: Scale,
                x,
                pre_scale,
                expected_y,
                expected_log_det):
    assert T.shape(x) == T.shape(expected_log_det)

    # check dimension errors
    with pytest.raises(Exception,
                       match=r'`rank\(input\) >= event_ndims` does not hold'):
        _ = scale(T.random.randn([1]), T.random.randn([1]), event_ndims=2)

    with pytest.raises(Exception,
                       match=r'`rank\(input\) >= rank\(pre_scale\)` does not hold'):
        _ = scale(T.random.randn([1]), T.random.randn([1, 2]), event_ndims=1)

    with pytest.raises(Exception,
                       match=r'The shape of `input_log_det` is not expected'):
        _ = scale(T.random.randn([2, 3]),
                  T.random.randn([2, 3]),
                  event_ndims=1,
                  input_log_det=T.random.randn([3]))

    with pytest.raises(Exception,
                       match=r'The shape of `input_log_det` is not expected'):
        _ = scale(T.random.randn([2, 3]),
                  T.random.randn([2, 3]),
                  event_ndims=2,
                  input_log_det=T.random.randn([2]))

    # check call
    for event_ndims in range(T.rank(pre_scale), T.rank(x)):
        this_expected_log_det = T.reduce_sum(
            expected_log_det, axis=T.int_range(-event_ndims, 0))
        input_log_det = T.random.randn(T.shape(this_expected_log_det))

        # check no compute log_det
        y, log_det = scale(x, pre_scale, event_ndims=event_ndims,
                           compute_log_det=False)
        assert_allclose(y, expected_y, rtol=1e-4, atol=1e-6)
        ctx.assertIsNone(log_det)

        # check compute log_det
        y, log_det = scale(x, pre_scale, event_ndims=event_ndims)
        assert_allclose(y, expected_y, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, this_expected_log_det, rtol=1e-4, atol=1e-6)

        # check compute log_det with input_log_det
        y, log_det = scale(
            x, pre_scale, event_ndims=event_ndims, input_log_det=input_log_det)
        assert_allclose(y, expected_y, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, input_log_det + this_expected_log_det,
                        rtol=1e-4, atol=1e-6)

        # check inverse, no compute log_det
        inv_x, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                               inverse=True, compute_log_det=False)
        assert_allclose(inv_x, x, rtol=1e-4, atol=1e-6)
        ctx.assertIsNone(log_det)

        # check inverse, compute log_det
        inv_x, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                               inverse=True)
        assert_allclose(inv_x, x, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, -this_expected_log_det, rtol=1e-4, atol=1e-6)

        # check inverse, compute log_det with input_log_det
        inv_x, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                               inverse=True, input_log_det=input_log_det)
        assert_allclose(inv_x, x, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, input_log_det - this_expected_log_det,
                        rtol=1e-4, atol=1e-6)
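
`check_scale` exercises three behaviours of an element-wise scale flow: the per-element log |scale| summed over the trailing `event_ndims` axes, the additive `input_log_det` accumulator, and the sign flip under inversion. A minimal NumPy sketch of the same bookkeeping (names hypothetical):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
scale = np.exp(rng.normal(size=(4,)))  # element-wise scale, broadcast over x

y = x * scale
per_elem_log_det = np.broadcast_to(np.log(np.abs(scale)), x.shape)

# with event_ndims=2 the log-det is summed over the last two axes,
# leaving a single batch axis of shape (2,)
log_det = per_elem_log_det.sum(axis=(-2, -1))
assert log_det.shape == (2,)

# an incoming input_log_det of the same batch shape is simply added;
# this is how chained flows accumulate their Jacobian terms
input_log_det = rng.normal(size=(2,))
assert (input_log_det + log_det).shape == (2,)

# the inverse transform divides by the scale, so its log-det is negated
inv = np.broadcast_to(np.log(np.abs(1.0 / scale)), x.shape).sum(axis=(-2, -1))
assert np.allclose(inv, -log_det)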
Example #5
def check_distribution_instance(ctx,
                                d,
                                event_ndims,
                                batch_shape,
                                min_event_ndims,
                                max_event_ndims,
                                log_prob_fn,
                                transform_origin_distribution=None,
                                transform_origin_group_ndims=None,
                                **expected_attrs):
    ctx.assertLessEqual(max_event_ndims - event_ndims, d.batch_ndims)

    event_shape = expected_attrs.get('event_shape', None)
    ctx.assertEqual(d.min_event_ndims, min_event_ndims)
    ctx.assertEqual(d.value_ndims, len(batch_shape) + event_ndims)
    if event_shape is not None:
        ctx.assertEqual(d.value_shape, batch_shape + event_shape)
    ctx.assertEqual(d.batch_shape, batch_shape)
    ctx.assertEqual(d.batch_ndims, len(batch_shape))
    ctx.assertEqual(d.event_ndims, event_ndims)
    ctx.assertEqual(d.event_shape, event_shape)

    for attr, val in expected_attrs.items():
        ctx.assertEqual(getattr(d, attr), val)
    ctx.assertEqual(
        d.validate_tensors,
        expected_attrs.get('validate_tensors', settings.validate_tensors))

    # check sample
    for n_samples in (None, 5):
        for group_ndims in (None, 0, -(event_ndims - min_event_ndims),
                            max_event_ndims - event_ndims):
            for reparameterized2 in (None, True, False):
                if reparameterized2 and not d.reparameterized:
                    continue

                # sample()
                sample_kwargs = {}
                if n_samples is not None:
                    sample_kwargs['n_samples'] = n_samples
                    sample_shape = [n_samples]
                else:
                    sample_shape = []

                if group_ndims is not None:
                    sample_kwargs['group_ndims'] = group_ndims
                else:
                    group_ndims = 0

                if reparameterized2 is not None:
                    sample_kwargs['reparameterized'] = reparameterized2
                else:
                    reparameterized2 = d.reparameterized

                t = d.sample(**sample_kwargs)
                ctx.assertEqual(t.group_ndims, group_ndims)
                ctx.assertEqual(t.reparameterized, reparameterized2)
                ctx.assertEqual(T.rank(t.tensor),
                                d.value_ndims + len(sample_shape))
                ctx.assertEqual(
                    T.shape(t.tensor)[:(d.batch_ndims + len(sample_shape))],
                    sample_shape + d.batch_shape)

                if transform_origin_distribution is not None:
                    ctx.assertIsInstance(t.transform_origin, StochasticTensor)
                    ctx.assertIs(t.transform_origin.distribution,
                                 transform_origin_distribution)
                    ctx.assertIs(t.transform_origin.group_ndims,
                                 transform_origin_group_ndims)

                # log_prob()
                expected_log_prob = log_prob_fn(t)
                for group_ndims2 in (None, 0, -(event_ndims - min_event_ndims),
                                     max_event_ndims - event_ndims):
                    if group_ndims2 is not None:
                        log_prob_kwargs = {'group_ndims': group_ndims2}
                    else:
                        log_prob_kwargs = {}
                        group_ndims2 = group_ndims

                    log_prob = t.log_prob(**log_prob_kwargs)
                    ctx.assertEqual(
                        T.shape(log_prob),
                        T.shape(t.tensor)[:T.rank(t.tensor) -
                                          (group_ndims2 + event_ndims)])

                    assert_allclose(
                        log_prob,
                        T.reduce_sum(
                            expected_log_prob,
                            T.int_range(
                                -(group_ndims2 +
                                  (event_ndims - min_event_ndims)), 0)),
                        rtol=1e-4,
                        atol=1e-6,
                    )

                    prob = t.prob(**log_prob_kwargs)
                    assert_allclose(prob,
                                    T.exp(log_prob),
                                    rtol=1e-4,
                                    atol=1e-6)

                    if transform_origin_distribution is not None:
                        for p in (log_prob, prob):
                            ctx.assertIsInstance(p.transform_origin,
                                                 StochasticTensor)
                            ctx.assertIs(p.transform_origin.distribution,
                                         transform_origin_distribution)
                            ctx.assertIs(p.transform_origin.group_ndims,
                                         transform_origin_group_ndims)
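
Most of the shape assertions in `check_distribution_instance` revolve around `group_ndims`: `t.log_prob(group_ndims=g)` folds the `g` rightmost batch dimensions into the event and sums the per-element densities over them. A self-contained sketch of that convention, using a unit Gaussian with event_ndims = 0 (all names hypothetical):

import math
import numpy as np

def std_normal_log_prob(x):
    # elementwise log density of N(0, 1)
    return -0.5 * (x ** 2) - 0.5 * math.log(2.0 * math.pi)

rng = np.random.default_rng(0)
samples = rng.normal(size=(5, 2, 3))  # [n_samples] + batch_shape

elementwise = std_normal_log_prob(samples)

# group_ndims=0: one log-prob per batch element (event_ndims == 0 here)
assert elementwise.shape == (5, 2, 3)

# group_ndims=2: the two rightmost batch dimensions are folded into the
# event, so the per-element densities are summed over them
grouped = elementwise.sum(axis=(-2, -1))
assert grouped.shape == (5,)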