Example #1
def log_prob_fn(t):
    # log-prob of the pre-flow sample under the base distribution,
    # without any group_ndims reduction
    log_px = distribution.log_prob(t.transform_origin.tensor,
                                   group_ndims=0)
    y, log_det = flow(t.transform_origin.tensor)  # y and log |det dy/dx|
    assert_allclose(y, t.tensor, rtol=1e-4, atol=1e-6)
    ctx.assertEqual(
        T.rank(log_det),
        T.rank(log_px) -
        (flow.get_x_event_ndims() - distribution.event_ndims))
    # change of variables: log p(y) = log p(x) - log |det dy/dx|, where the
    # event dimensions newly consumed by the flow are summed out of log p(x)
    return -log_det + T.reduce_sum(
        log_px,
        T.int_range(
            -(flow.get_x_event_ndims() - distribution.event_ndims), 0))
Example #2
def check_x(layer):
    y = layer(x)
    # mean and variance over the first (batch) axis and the spatial axes
    y_mean, y_var = T.calculate_mean_and_var(
        y, axis=[-T.rank(x)] + get_spatial_axis(spatial_ndims))
    # the output should have (roughly) zero mean when a bias is learned,
    # and unit variance in either case
    if use_bias:
        assert_allclose(y_mean,
                        T.zeros_like(y_mean),
                        atol=1e-6,
                        rtol=1e-4)
    assert_allclose(y_var,
                    T.ones_like(y_var),
                    atol=1e-6,
                    rtol=1e-4)
Example #3
def check_x(layer):
    y = layer(x)
    # mean and variance over every axis except the last (feature) axis
    y_mean, y_var = T.calculate_mean_and_var(
        y, axis=T.int_range(-T.rank(x), -1))
    if use_bias:
        assert_allclose(y_mean,
                        T.zeros_like(y_mean),
                        atol=1e-6,
                        rtol=1e-4)
    assert_allclose(y_var,
                    T.ones_like(y_var),
                    atol=1e-6,
                    rtol=1e-4)
Example #4
def check_scale(ctx,
                scale: Scale,
                x,
                pre_scale,
                expected_y,
                expected_log_det):
    # `expected_log_det` must be given per element, with the same shape as `x`
    assert T.shape(x) == T.shape(expected_log_det)

    # dimension error
    with pytest.raises(Exception,
                       match=r'`rank\(input\) >= event_ndims` does not hold'):
        _ = scale(T.random.randn([1]), T.random.randn([1]), event_ndims=2)

    with pytest.raises(Exception,
                       match=r'`rank\(input\) >= rank\(pre_scale\)` does not hold'):
        _ = scale(T.random.randn([1]), T.random.randn([1, 2]), event_ndims=1)

    with pytest.raises(Exception,
                       match=r'The shape of `input_log_det` is not expected'):
        _ = scale(T.random.randn([2, 3]),
                  T.random.randn([2, 3]),
                  event_ndims=1,
                  input_log_det=T.random.randn([3]))

    with pytest.raises(Exception,
                       match=r'The shape of `input_log_det` is not expected'):
        _ = scale(T.random.randn([2, 3]),
                  T.random.randn([2, 3]),
                  event_ndims=2,
                  input_log_det=T.random.randn([2]))

    # check call
    for event_ndims in range(T.rank(pre_scale), T.rank(x)):
        this_expected_log_det = T.reduce_sum(
            expected_log_det, axis=T.int_range(-event_ndims, 0))
        input_log_det = T.random.randn(T.shape(this_expected_log_det))

        # check no compute log_det
        y, log_det = scale(x, pre_scale, event_ndims=event_ndims,
                           compute_log_det=False)
        assert_allclose(y, expected_y, rtol=1e-4, atol=1e-6)
        ctx.assertIsNone(log_det)

        # check compute log_det
        y, log_det = scale(x, pre_scale, event_ndims=event_ndims)
        assert_allclose(y, expected_y, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, this_expected_log_det, rtol=1e-4, atol=1e-6)

        # check compute log_det with input_log_det
        y, log_det = scale(
            x, pre_scale, event_ndims=event_ndims, input_log_det=input_log_det)
        assert_allclose(y, expected_y, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, input_log_det + this_expected_log_det,
                        rtol=1e-4, atol=1e-6)

        # check inverse, no compute log_det
        inv_x, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                               inverse=True, compute_log_det=False)
        assert_allclose(inv_x, x, rtol=1e-4, atol=1e-6)
        ctx.assertIsNone(log_det)

        # check inverse, compute log_det
        inv_x, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                               inverse=True)
        assert_allclose(inv_x, x, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, -this_expected_log_det, rtol=1e-4, atol=1e-6)

        # check inverse, compute log_det with input_log_det
        inv_x, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                               inverse=True, input_log_det=input_log_det)
        assert_allclose(inv_x, x, rtol=1e-4, atol=1e-6)
        assert_allclose(log_det, input_log_det - this_expected_log_det,
                        rtol=1e-4, atol=1e-6)
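
A minimal invocation sketch (illustrative, not taken from the original test suite), assuming `ctx` behaves like a `unittest.TestCase` and that `ExpScale` multiplies its input by `exp(pre_scale)`, so the per-element log-determinant is simply `pre_scale`:

x = T.random.randn([2, 3, 4])
pre_scale = T.random.randn([4])
check_scale(ctx,
            ExpScale(),
            x,
            pre_scale,
            # under the assumed ExpScale semantics, y = x * exp(pre_scale)
            expected_y=x * T.exp(pre_scale),
            # `check_scale` asserts that `expected_log_det` has the same shape
            # as `x`, so broadcast the per-element log-det before passing it in
            expected_log_det=pre_scale * T.ones_like(x))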
Example #5
def check_distribution_instance(ctx,
                                d,
                                event_ndims,
                                batch_shape,
                                min_event_ndims,
                                max_event_ndims,
                                log_prob_fn,
                                transform_origin_distribution=None,
                                transform_origin_group_ndims=None,
                                **expected_attrs):
    ctx.assertLessEqual(max_event_ndims - event_ndims, d.batch_ndims)

    event_shape = expected_attrs.get('event_shape', None)
    ctx.assertEqual(d.min_event_ndims, min_event_ndims)
    ctx.assertEqual(d.value_ndims, len(batch_shape) + event_ndims)
    if event_shape is not None:
        ctx.assertEqual(d.value_shape, batch_shape + event_shape)
    ctx.assertEqual(d.batch_shape, batch_shape)
    ctx.assertEqual(d.batch_ndims, len(batch_shape))
    ctx.assertEqual(d.event_ndims, event_ndims)
    ctx.assertEqual(d.event_shape, event_shape)

    for attr, val in expected_attrs.items():
        ctx.assertEqual(getattr(d, attr), val)
    ctx.assertEqual(
        d.validate_tensors,
        expected_attrs.get('validate_tensors', settings.validate_tensors))

    # check sample
    for n_samples in (None, 5):
        for group_ndims in (None, 0, -(event_ndims - min_event_ndims),
                            max_event_ndims - event_ndims):
            for reparameterized2 in (None, True, False):
                if reparameterized2 and not d.reparameterized:
                    continue

                # sample()
                sample_kwargs = {}
                if n_samples is not None:
                    sample_kwargs['n_samples'] = n_samples
                    sample_shape = [n_samples]
                else:
                    sample_shape = []

                if group_ndims is not None:
                    sample_kwargs['group_ndims'] = group_ndims
                else:
                    group_ndims = 0

                if reparameterized2 is not None:
                    sample_kwargs['reparameterized'] = reparameterized2
                else:
                    reparameterized2 = d.reparameterized

                t = d.sample(**sample_kwargs)
                ctx.assertEqual(t.group_ndims, group_ndims)
                ctx.assertEqual(t.reparameterized, reparameterized2)
                ctx.assertEqual(T.rank(t.tensor),
                                d.value_ndims + len(sample_shape))
                ctx.assertEqual(
                    T.shape(t.tensor)[:(d.batch_ndims + len(sample_shape))],
                    sample_shape + d.batch_shape)

                if transform_origin_distribution is not None:
                    ctx.assertIsInstance(t.transform_origin, StochasticTensor)
                    ctx.assertIs(t.transform_origin.distribution,
                                 transform_origin_distribution)
                    ctx.assertIs(t.transform_origin.group_ndims,
                                 transform_origin_group_ndims)

                # log_prob()
                expected_log_prob = log_prob_fn(t)
                for group_ndims2 in (None, 0, -(event_ndims - min_event_ndims),
                                     max_event_ndims - event_ndims):
                    if group_ndims2 is not None:
                        log_prob_kwargs = {'group_ndims': group_ndims2}
                    else:
                        log_prob_kwargs = {}
                        group_ndims2 = group_ndims

                    log_prob = t.log_prob(**log_prob_kwargs)
                    ctx.assertEqual(
                        T.shape(log_prob),
                        T.shape(t.tensor)[:T.rank(t.tensor) -
                                          (group_ndims2 + event_ndims)])

                    assert_allclose(
                        log_prob,
                        T.reduce_sum(
                            expected_log_prob,
                            T.int_range(
                                -(group_ndims2 +
                                  (event_ndims - min_event_ndims)), 0)),
                        rtol=1e-4,
                        atol=1e-6,
                    )

                    prob = t.prob(**log_prob_kwargs)
                    assert_allclose(prob,
                                    T.exp(log_prob),
                                    rtol=1e-4,
                                    atol=1e-6)

                    if transform_origin_distribution is not None:
                        for p in (log_prob, prob):
                            ctx.assertIsInstance(p.transform_origin,
                                                 StochasticTensor)
                            ctx.assertIs(p.transform_origin.distribution,
                                         transform_origin_distribution)
                            ctx.assertIs(p.transform_origin.group_ndims,
                                         transform_origin_group_ndims)
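
In short, `check_distribution_instance` first verifies the static attributes of `d` (batch/event shapes and ndims, plus any `expected_attrs`), then exercises `sample()` over combinations of `n_samples`, `group_ndims` and `reparameterized`. For each sample it checks that `log_prob()` equals the per-element log-probability returned by `log_prob_fn`, summed over the trailing `group_ndims + (event_ndims - min_event_ndims)` axes, that `prob()` matches its exponential, and that the `transform_origin` bookkeeping is preserved for flow-derived distributions.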
Example #6
            def f(modules, self_module, self_weight, use_bias, normalizer,
                  activation, merge_mode):
                if n_partitions == 1:
                    adj = in_adj[0]
                else:
                    adj = in_adj[:len(modules)]

                out_shape = list(add_out_shape)
                if merge_mode == 'concat':
                    out_shape[feature_axis] *= len(modules) + int(
                        self_module is not None)

                bias_store = (SimpleParamStore(out_shape,
                                               initializer=tk.init.normal)
                              if use_bias else None)
                layer_kwargs = dict(self_module=self_module,
                                    self_weight=self_weight,
                                    bias_store=bias_store,
                                    normalizer=normalizer,
                                    activation=activation,
                                    merge_mode=merge_mode)

                layer = jit_compile(
                    cls(module=modules[0], **layer_kwargs)
                    if n_partitions == 1
                    else cls(modules=modules, **layer_kwargs))
                if isinstance(activation, type):
                    activation_layer = activation()
                else:
                    activation_layer = activation

                for x in inputs:
                    # test errors
                    if len(modules) > 1:
                        with pytest.raises(
                                Exception,
                                match=r'`adj` is expected to have .* element'
                                r'\(s\), but got .*'):
                            _ = layer(x, in_adj[:len(modules) - 1])

                    if T.rank(x) == value_ndims + 1:
                        with pytest.raises(
                                Exception,
                                match='`input` is expected to be at least .*d'
                        ):
                            _ = layer(x[0], adj)

                    # obtain valid output
                    y = layer(x, adj)
                    self.assertEqual(T.shape(y),
                                     T.shape(x)[:-value_ndims] + out_shape)

                    # compute the expected output
                    def g(m, x):
                        m_out, m_front = T.flatten_to_ndims(x, value_ndims + 1)
                        m_out = m(m_out)
                        m_out = T.unflatten_from_ndims(m_out, m_front)
                        return m_out

                    outs = []
                    for m, a in zip(modules, in_adj):
                        m_out = T.as_tensor(
                            np.reshape(
                                np.dot(T.sparse.to_numpy(a),
                                       T.to_numpy(x).reshape([50, -1])),
                                x.shape))
                        outs.append(g(m, m_out))

                    if self_module is not None:
                        outs.append(g(self_module, x))

                    if merge_mode == 'add':
                        out = T.add_n(outs)
                    elif merge_mode == 'concat':
                        out = T.concat(outs, axis=feature_axis)

                    if bias_store is not None:
                        out = out + bias_store()
                    if normalizer is not None:
                        out = normalizer(out)
                    if activation is not None:
                        out = activation_layer(out)

                    # assert the output is expected
                    assert_allclose_(y, out)
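
Example #6 appears to target a partitioned, graph-convolution style layer: it checks the error messages for a wrong number of adjacency matrices and for an input of insufficient rank, then recomputes the expected output by hand, multiplying each adjacency matrix into the flattened node features, applying the corresponding module (plus the optional `self_module`), merging the partial results by `'add'` or `'concat'`, and finally applying bias, normalizer and activation.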
Example #7
    def do_check(batch_shape, scale_type, initialized, dtype):
        x = T.random.randn(make_conv_shape(batch_shape, num_features,
                                           [6, 7, 8][:spatial_ndims]),
                           dtype=dtype)

        # check construct
        flow = cls(num_features,
                   scale=scale_type,
                   initialized=initialized,
                   dtype=dtype)
        ctx.assertIn(f'num_features={num_features}', repr(flow))
        ctx.assertIn(f'axis={-(spatial_ndims + 1)}', repr(flow))
        ctx.assertIn(f'scale_type={scale_type!r}', repr(flow))
        flow = tk.layers.jit_compile(flow)

        # check initialize
        if not initialized:
            # must initialize with sufficient data
            with pytest.raises(Exception, match='with at least .* dimensions'):
                _ = flow(
                    T.random.randn(make_conv_shape([], num_features,
                                                   [6, 7, 8][:spatial_ndims]),
                                   dtype=dtype))

            # must initialize with inverse = False
            with pytest.raises(Exception,
                               match='`ActNorm` must be initialized with '
                               '`inverse = False`'):
                _ = flow(x, inverse=True)

            # do initialize
            y, _ = flow(x, compute_log_det=False)
            y_mean, y_var = T.calculate_mean_and_var(
                y, axis=[a for a in range(-T.rank(y), 0) if a != channel_axis])
            assert_allclose(y_mean,
                            T.zeros([num_features]),
                            rtol=1e-4,
                            atol=1e-6)
            assert_allclose(y_var,
                            T.ones([num_features]),
                            rtol=1e-4,
                            atol=1e-6)
        else:
            y, _ = flow(x, compute_log_det=False)
            assert_allclose(y, x, rtol=1e-4, atol=1e-6)

        # prepare for the expected result
        scale_obj = ExpScale() if scale_type == 'exp' else LinearScale()

        if T.IS_CHANNEL_LAST:
            aligned_shape = [num_features]
        else:
            aligned_shape = [num_features] + [1] * spatial_ndims
        bias = T.reshape(flow.bias, aligned_shape)
        pre_scale = T.reshape(flow.pre_scale, aligned_shape)

        expected_y, expected_log_det = scale_obj(
            x + bias,
            pre_scale,
            event_ndims=spatial_ndims + 1,
            compute_log_det=True)

        flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                            T.random.randn(T.shape(expected_log_det)))
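
`do_check` covers both initialization branches: with `initialized=False`, the first forward pass performs data-dependent initialization so that the output has zero mean and unit variance per feature, whereas with `initialized=True` the fresh flow should act as the identity. The expected forward result is then reconstructed from the flow's own `bias` and `pre_scale` via `scale_obj(x + bias, pre_scale, event_ndims=spatial_ndims + 1)`, with `scale_obj` an `ExpScale` or `LinearScale` depending on `scale_type`, and handed to `flow_standard_check` together with a random `input_log_det`.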