def test_monto_carlo_objective(self):
        """Check `monte_carlo_objective` against `T.log_mean_exp(log_p - log_q)`.

        Covers ``axis=[0]`` with and without ``keepdims`` together with the
        ``'mean'`` / ``'sum'`` reductions of each variant.
        """
        log_p, log_q = prepare_test_payload()
        log_w = log_p - log_q
        reductions = (('mean', T.reduce_mean), ('sum', T.reduce_sum))

        # keepdims=False: reduction='...' must match reducing the raw result
        obj = monte_carlo_objective(log_p, log_q, axis=[0])
        for reduction, reduce_fn in reductions:
            assert_allclose(
                reduce_fn(obj),
                monte_carlo_objective(log_p, log_q, axis=[0],
                                      reduction=reduction),
                rtol=1e-4, atol=1e-6
            )
        obj_shape = T.shape(obj)
        assert_allclose(obj, T.log_mean_exp(log_w, axis=[0]))

        # keepdims=True: same values, with a leading axis of size 1
        obj_k = monte_carlo_objective(log_p, log_q, axis=[0], keepdims=True)
        for reduction, reduce_fn in reductions:
            assert_allclose(
                reduce_fn(obj_k),
                monte_carlo_objective(log_p, log_q, axis=[0], keepdims=True,
                                      reduction=reduction)
            )
        self.assertListEqual([1] + obj_shape, T.shape(obj_k))
        assert_allclose(obj_k, T.log_mean_exp(log_w, axis=[0], keepdims=True))
Exemple #2
0
    def test_invertible_matrices(self):
        """Exercise the loose and strict invertible-matrix layers.

        For each class and for sizes 1, 3 and 5, this test checks:

        - ``repr`` and ``size``;
        - that the initial matrix and its inverse multiply to the identity
          in both orders;
        - invertibility via ``check_invertible_matrix``;
        - that gradients of the matrix and log-det sums can be computed.

        For the strict variant only, it then applies a small gradient step
        and re-checks invertibility.
        """
        tol = {'rtol': 1e-4, 'atol': 1e-6}

        for matrix_cls in (LooseInvertibleMatrix, StrictInvertibleMatrix):
            for size in [1, 3, 5]:
                layer = matrix_cls(np.random.randn(size, size))
                self.assertEqual(repr(layer),
                                 f'{matrix_cls.__qualname__}(size={size})')
                self.assertEqual(layer.size, size)

                layer = tk.layers.jit_compile(layer)

                # the initial matrix and its inverse must multiply to identity
                forward, _ = layer(inverse=False, compute_log_det=False)
                backward, _ = layer(inverse=True, compute_log_det=False)
                identity = np.eye(size)
                assert_allclose(identity, T.matmul(forward, backward), **tol)
                assert_allclose(identity, T.matmul(backward, forward), **tol)

                check_invertible_matrix(self, layer, size)

                # gradients of the matrix and its log-det w.r.t. the parameters
                forward, log_det = layer(inverse=False, compute_log_det=True)
                params = list(tk.layers.iter_parameters(layer))
                grads = T.grad(
                    [T.reduce_sum(forward), T.reduce_sum(log_det)], params)

                # strict variant only: after a small gradient step the matrix
                # must still pass the invertibility check
                if matrix_cls is StrictInvertibleMatrix:
                    with T.no_grad():
                        for param, grad in zip(params, grads):
                            T.assign(param, param + 0.001 * grad)
                    check_invertible_matrix(self, layer, size)
Exemple #3
0
    def test_elbo(self):
        """Check `elbo_objective` against direct computations on `log_p - log_q`.

        Covers three configurations, each with the ``'mean'`` / ``'sum'``
        reductions where the original test checks them:

        - no reduction axis (element-wise result);
        - ``axis=[0]``;
        - ``axis=[0], keepdims=True``.
        """
        log_p, log_q = prepare_test_payload()
        log_w = log_p - log_q
        reductions = (('mean', T.reduce_mean), ('sum', T.reduce_sum))

        # no axis: the objective is element-wise log_p - log_q
        obj = elbo_objective(log_p, log_q)
        for reduction, reduce_fn in reductions:
            assert_allclose(
                reduce_fn(obj),
                elbo_objective(log_p, log_q, reduction=reduction)
            )
        obj_shape = T.shape(obj)
        assert_allclose(obj, log_w)

        # axis=[0]: mean over the first axis, which is dropped
        obj_r = elbo_objective(log_p, log_q, axis=[0])
        self.assertListEqual(obj_shape[1:], T.shape(obj_r))
        assert_allclose(obj_r, T.reduce_mean(log_w, axis=[0]))

        # axis=[0], keepdims=True: the first axis is kept with size 1
        obj_rk = elbo_objective(log_p, log_q, axis=[0], keepdims=True)
        for reduction, reduce_fn in reductions:
            assert_allclose(
                reduce_fn(obj_rk),
                elbo_objective(log_p, log_q, axis=[0], keepdims=True,
                               reduction=reduction)
            )
        self.assertListEqual([1] + obj_shape[1:], T.shape(obj_rk))
        assert_allclose(obj_rk, T.reduce_mean(log_w, axis=[0], keepdims=True))
Exemple #4
0
    def test_iwae(self):
        """Check `iwae_estimator` values, reductions, sign flip and gradients.

        The expected gradient w.r.t. ``y`` is
        ``reduce_sum(wk_hat * (2 * x * y), axis=[0])``, where ``wk_hat`` is
        ``f`` normalized to sum to one over axis 0.
        """
        close = functools.partial(assert_allclose, rtol=1e-5, atol=1e-6)

        def normalized_weights(values):
            # normalize over axis 0, keeping that axis for broadcasting
            return values / T.reduce_sum(values, axis=[0], keepdims=True)

        # axis=[0], keepdims=False
        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        wk_hat = normalized_weights(f)
        cost = iwae_estimator(log_f, axis=[0])
        close(-cost, iwae_estimator(log_f, axis=[0], negative=True))
        close(T.reduce_mean(cost),
              iwae_estimator(log_f, axis=[0], reduction='mean'))
        close(T.reduce_sum(cost),
              iwae_estimator(log_f, axis=[0], reduction='sum'))
        cost_shape = T.shape(cost)
        [grad_y] = T.grad([T.reduce_sum(cost)], [y])
        close(grad_y, T.reduce_sum(wk_hat * (2 * x * y), axis=[0]))

        # axis=[0], keepdims=True (on a fresh payload)
        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        wk_hat = normalized_weights(f)
        cost_k = iwae_estimator(log_f, axis=[0], keepdims=True)
        close(T.reduce_mean(cost_k),
              iwae_estimator(log_f, axis=[0], keepdims=True, reduction='mean'))
        close(T.reduce_sum(cost_k),
              iwae_estimator(log_f, axis=[0], keepdims=True, reduction='sum'))
        negated = iwae_estimator(log_f, axis=[0], keepdims=True, negative=True)
        close(-cost_k, T.to_numpy(negated))
        self.assertListEqual([1] + cost_shape, T.shape(cost_k))
        [grad_y] = T.grad([T.reduce_sum(cost_k)], [y])
        close(grad_y, T.reduce_sum(wk_hat * (2 * x * y), axis=[0]))
Exemple #5
0
    def test_monto_carlo_objective(self):
        """Check `importance_sampling_log_likelihood` against `T.log_mean_exp`.

        Covers ``axis=[0]`` with the ``'mean'`` / ``'sum'`` reductions and the
        ``keepdims=True`` shape and values.

        NOTE(review): ``assert_allclose_`` is not defined in this method;
        presumably it is a module-level ``functools.partial`` of
        ``assert_allclose`` — confirm.
        """
        log_p, log_q = prepare_test_payload()
        log_w = log_p - log_q

        # keepdims=False
        ll = importance_sampling_log_likelihood(log_p, log_q, axis=[0])
        ll_shape = T.shape(ll)
        assert_allclose_(ll, T.log_mean_exp(log_w, axis=[0]))
        for reduction, reduce_fn in (('mean', T.reduce_mean),
                                     ('sum', T.reduce_sum)):
            assert_allclose_(
                reduce_fn(ll),
                importance_sampling_log_likelihood(log_p, log_q, axis=[0],
                                                   reduction=reduction))

        # keepdims=True: the reduced axis is kept with size 1
        ll_k = importance_sampling_log_likelihood(log_p, log_q, axis=[0],
                                                  keepdims=True)
        self.assertListEqual([1] + ll_shape, T.shape(ll_k))
        assert_allclose_(ll_k, T.log_mean_exp(log_w, axis=[0], keepdims=True))
Exemple #6
0
 def log_prob_fn(t):
     """Reference log-probability of a flow-transformed sample ``t``.

     Pushes ``t.transform_origin.tensor`` through ``flow`` and checks the
     result matches ``t.tensor``. It returns ``-log_det`` plus the origin's
     element-wise log-prob summed over the extra event axes that ``flow``
     adds on top of ``distribution.event_ndims``.
     """
     origin = t.transform_origin.tensor
     log_px = distribution.log_prob(origin, group_ndims=0)
     y, log_det = flow(origin)  # y and log |dy/dx|
     assert_allclose(y, t.tensor, atol=1e-4, rtol=1e-6)

     # number of trailing axes of log_px that log_det has already reduced
     extra_ndims = flow.get_x_event_ndims() - distribution.event_ndims
     ctx.assertEqual(T.rank(log_det), T.rank(log_px) - extra_ndims)
     summed_log_px = T.reduce_sum(log_px, T.int_range(-extra_ndims, 0))
     return -log_det + summed_log_px
Exemple #7
0
    def test_sgvb(self):
        """Check `sgvb_estimator` values, reductions, sign flip and gradients.

        Three configurations are exercised, each on a fresh payload:

        - no reduction axis;
        - ``axis=[0]``;
        - ``axis=[0], keepdims=True``.

        The ``keepdims=True`` reductions are compared against the estimator
        called with the same ``keepdims=True`` setting, matching the other
        estimator/objective tests, so that combination is actually covered.
        """
        assert_allclose_ = functools.partial(assert_allclose,
                                             rtol=1e-5,
                                             atol=1e-6)

        # default (no reduction axis)
        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        cost = sgvb_estimator(f)
        assert_allclose_(-cost, sgvb_estimator(f, negative=True))
        assert_allclose_(T.reduce_mean(cost),
                         sgvb_estimator(f, reduction='mean'))
        assert_allclose_(T.reduce_sum(cost), sgvb_estimator(f,
                                                            reduction='sum'))
        cost_shape = T.shape(cost)
        assert_allclose_(
            T.grad([T.reduce_sum(cost)], [y])[0],
            T.reduce_sum(2 * x * y * f, axis=[0]))

        # axis=[0]: the first axis is averaged out.
        # NOTE(review): the `/ 7` presumably is the size of axis 0 produced by
        # `prepare_test_payload` — confirm.
        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        cost_r = sgvb_estimator(f, axis=[0])
        assert_allclose_(-cost_r, sgvb_estimator(f, axis=[0], negative=True))
        self.assertListEqual(cost_shape[1:], T.shape(cost_r))
        assert_allclose_(
            T.grad([T.reduce_sum(cost_r)], [y])[0],
            T.reduce_sum(2 * x * y * f, axis=[0]) / 7)

        # axis=[0], keepdims=True: the first axis is kept with size 1
        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        cost_rk = sgvb_estimator(f, axis=[0], keepdims=True)
        assert_allclose_(
            T.reduce_mean(cost_rk),
            sgvb_estimator(f, axis=[0], keepdims=True, reduction='mean'))
        assert_allclose_(
            T.reduce_sum(cost_rk),
            sgvb_estimator(f, axis=[0], keepdims=True, reduction='sum'))
        assert_allclose_(
            -cost_rk, sgvb_estimator(f, axis=[0], keepdims=True,
                                     negative=True))
        self.assertListEqual([1] + cost_shape[1:], T.shape(cost_rk))
        assert_allclose_(
            T.grad([T.reduce_sum(cost_rk)], [y])[0],
            T.reduce_sum(2 * x * y * f, axis=[0]) / 7)
Exemple #8
0
def check_invertible_linear(ctx,
                            spatial_ndims: int,
                            invertible_linear_factory,
                            linear_factory,
                            strict: bool,):
    """Compare an invertible linear flow with an equivalent plain linear layer.

    The flow is built via ``invertible_linear_factory(num_features,
    strict=strict)``. Its matrix and log-det are used to build the expected
    outputs:

    - ``linear_factory`` is initialized with that matrix (reshaped to a
      kernel of size 1 along every spatial axis when ``spatial_ndims > 0``).
    - The matrix log-det is summed over all spatial positions and broadcast
      over the batch.

    Everything is then handed to ``flow_standard_check``.

    Args:
        ctx: The test case.
        spatial_ndims: Number of spatial dimensions; the spatial shape is
            ``[5, 6, 7][:spatial_ndims]``.
        invertible_linear_factory: Builds the flow under test.
        linear_factory: Builds the reference (non-invertible) linear layer.
        strict: Forwarded to ``invertible_linear_factory``.
    """
    batch_shape = [2]
    num_features = 4
    spatial_shape = [5, 6, 7][:spatial_ndims]
    input_shape = make_conv_shape(batch_shape, num_features, spatial_shape)
    x = T.random.randn(input_shape)

    # the layer under test
    flow = invertible_linear_factory(num_features, strict=strict)
    ctx.assertIn(f'num_features={num_features}', repr(flow))
    flow = tk.layers.jit_compile(flow)

    # reference layer sharing the flow's matrix as its weight
    weight, log_det = flow.invertible_matrix(
        inverse=False, compute_log_det=True)
    extra_kwargs = {'kernel_size': 1} if spatial_ndims > 0 else {}
    kernel = T.reshape(weight, T.shape(weight) + [1] * spatial_ndims)
    linear = linear_factory(
        num_features, num_features,
        weight_init=kernel,
        use_bias=False,
        **extra_kwargs
    )

    # expected outputs
    x_flat, front_shape = T.flatten_to_ndims(x, spatial_ndims + 2)
    expected_y = T.unflatten_from_ndims(linear(x_flat), front_shape)
    log_det_per_sample = T.reduce_sum(T.expand(log_det, spatial_shape))
    expected_log_det = T.expand(log_det_per_sample, batch_shape)

    flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                        T.random.randn(batch_shape))
Exemple #9
0
def check_scale(ctx,
                scale: Scale,
                x,
                pre_scale,
                expected_y,
                expected_log_det):
    """Check a `Scale` layer's argument validation and forward/inverse results.

    Args:
        ctx: The test case.
        scale: The layer under test.
        x: The input tensor.
        pre_scale: The pre-scale tensor passed to ``scale``.
        expected_y: The expected forward output for ``x``.
        expected_log_det: Element-wise expected log-det with the same shape as
            ``x``. For each tested ``event_ndims`` it is summed over the
            trailing ``event_ndims`` axes.

    ``event_ndims`` ranges over ``range(T.rank(pre_scale), T.rank(x))``. For
    each value, both directions are checked:

    - without computing the log-det;
    - computing the log-det;
    - accumulating onto a random ``input_log_det``.
    """
    assert T.shape(x) == T.shape(expected_log_det)

    # argument validation: each call must raise with a matching message;
    # lambdas keep the random inputs being drawn inside `pytest.raises`
    invalid_calls = [
        (r'`rank\(input\) >= event_ndims` does not hold',
         lambda: scale(T.random.randn([1]), T.random.randn([1]),
                       event_ndims=2)),
        (r'`rank\(input\) >= rank\(pre_scale\)` does not hold',
         lambda: scale(T.random.randn([1]), T.random.randn([1, 2]),
                       event_ndims=1)),
        (r'The shape of `input_log_det` is not expected',
         lambda: scale(T.random.randn([2, 3]),
                       T.random.randn([2, 3]),
                       event_ndims=1,
                       input_log_det=T.random.randn([3]))),
        (r'The shape of `input_log_det` is not expected',
         lambda: scale(T.random.randn([2, 3]),
                       T.random.randn([2, 3]),
                       event_ndims=2,
                       input_log_det=T.random.randn([2]))),
    ]
    for message, bad_call in invalid_calls:
        with pytest.raises(Exception, match=message):
            _ = bad_call()

    def assert_close(actual, desired):
        assert_allclose(actual, desired, rtol=1e-4, atol=1e-6)

    for event_ndims in range(T.rank(pre_scale), T.rank(x)):
        reduced_log_det = T.reduce_sum(
            expected_log_det, axis=T.int_range(-event_ndims, 0))
        input_log_det = T.random.randn(T.shape(reduced_log_det))

        # forward pass first, then the inverse pass
        for inverse in (False, True):
            direction = {'inverse': True} if inverse else {}
            source, target = (expected_y, x) if inverse else (x, expected_y)

            # no log-det computed
            out, log_det = scale(source, pre_scale, event_ndims=event_ndims,
                                 compute_log_det=False, **direction)
            assert_close(out, target)
            ctx.assertIsNone(log_det)

            # log-det computed
            out, log_det = scale(source, pre_scale, event_ndims=event_ndims,
                                 **direction)
            assert_close(out, target)
            assert_close(log_det,
                         -reduced_log_det if inverse else reduced_log_det)

            # log-det accumulated onto input_log_det
            out, log_det = scale(source, pre_scale, event_ndims=event_ndims,
                                 input_log_det=input_log_det, **direction)
            assert_close(out, target)
            if inverse:
                assert_close(log_det, input_log_det - reduced_log_det)
            else:
                assert_close(log_det, input_log_det + reduced_log_det)
Exemple #10
0
    def test_sample_and_log_prob(self):
        """Check `Uniform` sampling, log-densities and reparameterized gradients.

        Iterates over combinations of `shape`, dtype, `(low, high)` and
        `event_ndims`, comparing `log_prob` with a NumPy reference for
        in-range and scaled (`* 10.`) values. It then checks gradients of
        samples w.r.t. tensor-valued `low` / `high`, for reparameterized and
        non-reparameterized sampling.
        """
        # (2, 1) and (1, 3) parameters broadcast to a (2, 3) grid
        array_low = np.random.randn(2, 1)
        array_high = np.exp(np.random.randn(1, 3)) + 1.
        log_zero = -1e6

        def log_prob(x, low, high, group_ndims=0):
            # NumPy reference: -log(high - low) inside [low, high] and
            # `log_zero` outside, summed over the last `group_ndims` axes
            # (defaults to the unit interval when both bounds are None).
            if low is None and high is None:
                low, high = 0., 1.
            log_pdf = -np.log(np.ones_like(x) * (high - low))
            log_pdf = np.where(np.logical_and(low <= x, x <= high), log_pdf,
                               log_zero)
            log_pdf = np.sum(log_pdf, axis=tuple(range(-group_ndims, 0)))
            return log_pdf

        for shape, dtype, (low, high), event_ndims in product(
            [None, [], [5, 4]], float_dtypes, [(None, None), (-1., 2.),
                                               (array_low, array_high)],
                range(5)):
            # skip combinations whose event_ndims exceeds the value rank
            low_rank = len(np.shape(low)) if low is not None else 0
            if event_ndims > len(shape or []) + low_rank:
                continue

            if isinstance(low, np.ndarray):
                # tensor parameters: the value shape gains the (2, 3) grid
                low_t = T.as_tensor(low, dtype=dtype)
                high_t = T.as_tensor(high, dtype=dtype)
                uniform = Uniform(shape=shape,
                                  low=low_t,
                                  high=high_t,
                                  event_ndims=event_ndims,
                                  log_zero=log_zero)
                value_shape = (shape or []) + [2, 3]
                self.assertIs(uniform.low, low_t)
                self.assertIs(uniform.high, high_t)
            else:
                # scalar (or None) parameters with an explicit dtype
                uniform = Uniform(shape=shape,
                                  low=low,
                                  high=high,
                                  dtype=dtype,
                                  event_ndims=event_ndims,
                                  log_zero=log_zero)
                value_shape = shape or []
                self.assertEqual(uniform.low, low)
                self.assertEqual(uniform.high, high)
            self.assertEqual(uniform.log_zero, log_zero)
            self.assertEqual(uniform.value_shape, value_shape)

            # sample(n_samples=None)
            t = uniform.sample()
            x = T.to_numpy(t.tensor)
            sample_shape = value_shape
            self.assertIsInstance(t, StochasticTensor)
            self.assertIs(t.distribution, uniform)
            self.assertEqual(T.get_dtype(t.tensor), dtype)
            self.assertEqual(t.n_samples, None)
            self.assertEqual(t.group_ndims, 0)
            self.assertEqual(t.reparameterized, True)
            self.assertIsInstance(t.tensor, T.Tensor)
            self.assertEqual(T.shape(t.tensor), sample_shape)

            # with group_ndims=0 the log-prob is summed over event_ndims axes
            for log_pdf in [t.log_prob(), uniform.log_prob(t)]:
                self.assertEqual(T.get_dtype(log_pdf), dtype)
                assert_allclose(log_pdf,
                                log_prob(x, low, high, event_ndims),
                                rtol=1e-4)

            # test log-prob on out-of-range values
            assert_allclose(
                uniform.log_prob(t.tensor * 10.),
                log_prob(x * 10., low, high, event_ndims),
                rtol=1e-4,
            )

            # sample(n_samples=7)
            if event_ndims >= 1:
                t = uniform.sample(n_samples=7,
                                   group_ndims=-1,
                                   reparameterized=False)
                x = T.to_numpy(t.tensor)
                sample_shape = [7] + value_shape
                self.assertIsInstance(t, StochasticTensor)
                self.assertIs(t.distribution, uniform)
                self.assertEqual(T.get_dtype(t.tensor), dtype)
                self.assertEqual(t.n_samples, 7)
                self.assertEqual(t.group_ndims, -1)
                self.assertEqual(t.reparameterized, False)
                self.assertIsInstance(t.tensor, T.Tensor)
                self.assertEqual(T.shape(t.tensor), sample_shape)
                # group_ndims=-1 un-groups one of the event axes
                reduce_ndims = event_ndims - 1

                for log_pdf in [
                        t.log_prob(),
                        uniform.log_prob(t, group_ndims=-1)
                ]:
                    self.assertEqual(T.get_dtype(log_pdf), dtype)
                    assert_allclose(log_pdf,
                                    log_prob(x, low, high, reduce_ndims),
                                    rtol=1e-4)

        # test reparameterized
        low_t = T.requires_grad(T.as_tensor(array_low))
        high_t = T.requires_grad(T.as_tensor(array_high))
        uniform = Uniform(low=low_t, high=high_t)

        # with u = (x - low) / (high - low), d x / d low = 1 - u and
        # d x / d high = u; summed over the axes each parameter broadcasts on
        t = uniform.sample()
        self.assertTrue(t.reparameterized)
        u = (T.to_numpy(t.tensor) - array_low) / (array_high - array_low)
        [low_grad, high_grad] = T.grad([T.reduce_sum(t.tensor)],
                                       [low_t, high_t])
        assert_allclose(low_grad,
                        np.sum(1. - u, axis=-1, keepdims=True),
                        rtol=1e-4)
        assert_allclose(high_grad, np.sum(u, axis=0, keepdims=True), rtol=1e-4)

        # non-reparameterized samples must not propagate gradients to the
        # parameters
        t = uniform.sample(reparameterized=False)
        w_t = T.requires_grad(T.as_tensor(np.random.randn(2, 3)))
        self.assertFalse(t.reparameterized)
        [low_grad, high_grad] = T.grad([T.reduce_sum(w_t * t.tensor)],
                                       [low_t, high_t],
                                       allow_unused=True)
        self.assertTrue(T.is_null_grad(low_t, low_grad))
        self.assertTrue(T.is_null_grad(high_t, high_grad))
Exemple #11
0
def check_distribution_instance(ctx,
                                d,
                                event_ndims,
                                batch_shape,
                                min_event_ndims,
                                max_event_ndims,
                                log_prob_fn,
                                transform_origin_distribution=None,
                                transform_origin_group_ndims=None,
                                **expected_attrs):
    """Run a generic battery of checks on a distribution instance ``d``.

    Args:
        ctx: The test case, used for its ``assert*`` methods.
        d: The distribution under test.
        event_ndims: Expected ``d.event_ndims``.
        batch_shape: Expected ``d.batch_shape`` (a list).
        min_event_ndims: Expected ``d.min_event_ndims``.
        max_event_ndims: Largest event ndims exercised through
            ``group_ndims``.
        log_prob_fn: Callable mapping a sampled tensor ``t`` to a reference
            log-probability. Judging by the reduction below, it is expected
            at ``min_event_ndims`` granularity.
        transform_origin_distribution: If not None, the expected
            ``transform_origin.distribution`` of samples and of their
            ``log_prob()`` / ``prob()`` results.
        transform_origin_group_ndims: The expected
            ``transform_origin.group_ndims`` in that case.
        **expected_attrs: Extra attributes of ``d`` and their expected
            values.
    """
    ctx.assertLessEqual(max_event_ndims - event_ndims, d.batch_ndims)

    # static attributes
    event_shape = expected_attrs.get('event_shape', None)
    ctx.assertEqual(d.min_event_ndims, min_event_ndims)
    ctx.assertEqual(d.value_ndims, len(batch_shape) + event_ndims)
    if event_shape is not None:
        ctx.assertEqual(d.value_shape, batch_shape + event_shape)
    ctx.assertEqual(d.batch_shape, batch_shape)
    ctx.assertEqual(d.batch_ndims, len(batch_shape))
    ctx.assertEqual(d.event_ndims, event_ndims)
    ctx.assertEqual(d.event_shape, event_shape)

    for attr, val in expected_attrs.items():
        ctx.assertEqual(getattr(d, attr), val)
    ctx.assertEqual(
        d.validate_tensors,
        expected_attrs.get('validate_tensors', settings.validate_tensors))

    # `None` means "do not pass the argument; rely on the default"
    group_ndims_options = (None, 0, -(event_ndims - min_event_ndims),
                           max_event_ndims - event_ndims)

    # check sample
    for n_samples in (None, 5):
        for group_ndims_arg in group_ndims_options:
            for reparameterized_arg in (None, True, False):
                if reparameterized_arg and not d.reparameterized:
                    continue

                # sample(): the effective values go into separate names so
                # the loop variables are never overwritten. Overwriting
                # `group_ndims_arg` would make later `reparameterized_arg`
                # iterations pass `group_ndims=0` instead of using the default.
                sample_kwargs = {}
                if n_samples is not None:
                    sample_kwargs['n_samples'] = n_samples
                    sample_shape = [n_samples]
                else:
                    sample_shape = []

                if group_ndims_arg is not None:
                    sample_kwargs['group_ndims'] = group_ndims_arg
                    group_ndims = group_ndims_arg
                else:
                    group_ndims = 0

                if reparameterized_arg is not None:
                    sample_kwargs['reparameterized'] = reparameterized_arg
                    reparameterized = reparameterized_arg
                else:
                    reparameterized = d.reparameterized

                t = d.sample(**sample_kwargs)
                ctx.assertEqual(t.group_ndims, group_ndims)
                ctx.assertEqual(t.reparameterized, reparameterized)
                ctx.assertEqual(T.rank(t.tensor),
                                d.value_ndims + len(sample_shape))
                ctx.assertEqual(
                    T.shape(t.tensor)[:(d.batch_ndims + len(sample_shape))],
                    sample_shape + d.batch_shape)

                if transform_origin_distribution is not None:
                    ctx.assertIsInstance(t.transform_origin, StochasticTensor)
                    ctx.assertIs(t.transform_origin.distribution,
                                 transform_origin_distribution)
                    # ints are compared by value, not identity
                    ctx.assertEqual(t.transform_origin.group_ndims,
                                    transform_origin_group_ndims)

                # log_prob()
                expected_log_prob = log_prob_fn(t)
                for group_ndims2_arg in group_ndims_options:
                    if group_ndims2_arg is not None:
                        log_prob_kwargs = {'group_ndims': group_ndims2_arg}
                        group_ndims2 = group_ndims2_arg
                    else:
                        # without an argument, the sample's own group_ndims
                        # is expected to apply
                        log_prob_kwargs = {}
                        group_ndims2 = group_ndims

                    log_prob = t.log_prob(**log_prob_kwargs)
                    ctx.assertEqual(
                        T.shape(log_prob),
                        T.shape(t.tensor)[:T.rank(t.tensor) -
                                          (group_ndims2 + event_ndims)])

                    assert_allclose(
                        log_prob,
                        T.reduce_sum(
                            expected_log_prob,
                            T.int_range(
                                -(group_ndims2 +
                                  (event_ndims - min_event_ndims)), 0)),
                        rtol=1e-4,
                        atol=1e-6,
                    )

                    prob = t.prob(**log_prob_kwargs)
                    assert_allclose(prob,
                                    T.exp(log_prob),
                                    rtol=1e-4,
                                    atol=1e-6)

                    if transform_origin_distribution is not None:
                        for p in (log_prob, prob):
                            ctx.assertIsInstance(p.transform_origin,
                                                 StochasticTensor)
                            ctx.assertIs(p.transform_origin.distribution,
                                         transform_origin_distribution)
                            ctx.assertEqual(p.transform_origin.group_ndims,
                                            transform_origin_group_ndims)
Exemple #12
0
    def test_embedding(self):
        """Check the ``Embedding`` / ``Embedding{1,2,3}d`` layers.

        For spatial ndims 0-3, this test checks:

        - the weight shape and a bound on its per-channel mean;
        - that outputs match ``T.embedding``;
        - for 0/1 spatial dims, gradients, including repeated indices;
        - the constructor error for a too-short ``embedding_size``.

        It then checks that ``freeze=True`` yields no gradient, and the
        empty-size error.
        """
        n_channels = 3
        n_embeddings = n_samples

        for spatial_ndims in (0, 1, 2, 3):
            w_shape = make_conv_shape([], n_channels,
                                      [4, 5, 6][:spatial_ndims])
            w_size = int(np.prod(w_shape))
            layer_name = (f'Embedding{spatial_ndims}d' if spatial_ndims > 0
                          else 'Embedding')

            layer = getattr(tk.layers, layer_name)(n_embeddings, w_shape)
            weight = layer.weight

            # weight shape, and a bound on the mean over all non-channel axes
            self.assertEqual(T.shape(weight), [n_embeddings] + w_shape)
            kept_axes = list(range(len(w_shape) + 1))
            kept_axes.pop(-1 if T.IS_CHANNEL_LAST else 1)
            w_mean = np.average(T.to_numpy(weight), axis=tuple(kept_axes))
            bound = 3. / np.sqrt(n_embeddings * w_size / n_channels)
            np.testing.assert_array_less(w_mean, bound)

            layer = jit_compile(layer)
            weight_array = T.to_numpy(weight)

            for in_shape in ([7], [7, 8]):
                # append a few repeated ids so gradient accumulation is tested
                ids = T.random.randint(0, n_embeddings, in_shape)
                ids = T.concat([ids, ids[:3]], axis=0)

                output = layer(ids)
                assert_allclose(output, T.embedding(weight, ids))

                if spatial_ndims in (0, 1):
                    out_sum = T.reduce_sum(output**2)
                    [grad] = T.grad([out_sum], [weight])
                    # d sum(w[id]^2) / d w[id] = 2 * w[id], once per occurrence
                    expected_grad = np.zeros(T.shape(weight))
                    for idx in T.to_numpy(ids).reshape([-1]):
                        expected_grad[idx] += 2. * weight_array[idx]
                    assert_allclose(grad, expected_grad)

            if spatial_ndims > 0:
                with pytest.raises(
                        ValueError,
                        match=f'`embedding_size` must be a int list '
                        f'with {spatial_ndims + 1} elements'):
                    _ = getattr(tk.layers, layer_name)(n_embeddings,
                                                       w_shape[:-1])

        # freeze=True: the weight must not receive a gradient
        layer = Embedding(n_embeddings, n_channels, freeze=True)
        weight = layer.weight
        self.assertEqual(T.shape(weight), [n_embeddings, n_channels])

        layer = jit_compile(layer)
        ids = T.random.randint(0, n_embeddings, [7, 8])
        output = layer(ids)
        assert_allclose(output, T.embedding(weight, ids))

        out_sum = T.reduce_sum(output**2)
        try:
            [grad] = T.grad([out_sum], [weight])
        except Exception:
            # deliberately tolerated: presumably some backends refuse to
            # differentiate w.r.t. a frozen weight
            pass
        else:
            self.assertTrue(T.is_null_grad(weight, grad))

        with pytest.raises(ValueError,
                           match='`embedding_size` must not be empty'):
            _ = Embedding(n_embeddings, [])
Exemple #13
0
    def test_basic_interface(self):
        """Check the basic `StochasticTensor` interface.

        Covers construction from a `UnitNormal` sample, attribute access,
        ``repr``/``hash``/equality, the ``continuous`` flag, and the caching
        behavior of ``log_prob()`` / ``prob()`` for the default ``group_ndims``.
        """
        # reference values computed directly from the distribution
        normal = UnitNormal(shape=[2, 3])
        samples = normal.sample(n_samples=5)
        samples_0 = samples.tensor[0]
        samples_no_grad = T.stop_grad(samples.tensor)
        log_prob = normal.log_prob(samples.tensor, group_ndims=0)
        log_prob_reduce_1 = T.reduce_sum(log_prob, axis=[-1])

        ##
        # case 1: non-reparameterized, group_ndims=0, no cached values
        t = StochasticTensor(
            tensor=samples_no_grad,
            distribution=normal,
            n_samples=5,
            group_ndims=0,
            reparameterized=False,
        )

        self.assertIs(t.tensor, samples_no_grad)
        self.assertIs(t.distribution, normal)
        self.assertEqual(t.n_samples, 5)
        self.assertEqual(t.group_ndims, 0)
        self.assertEqual(t.reparameterized, False)
        self.assertIsNone(t.transform_origin)
        self.assertIsNone(t._cached_log_prob)
        self.assertIsNone(t._cached_prob)

        self.assertEqual(repr(t), f'StochasticTensor({t.tensor!r})')
        self.assertEqual(hash(t), hash(t))
        self.assertEqual(t, t)
        # a tensor wrapping a different underlying tensor must compare unequal
        self.assertNotEqual(
            t,
            StochasticTensor(
                tensor=samples_0,
                distribution=normal,
                n_samples=5,
                group_ndims=0,
                reparameterized=False,
            ))
        self.assertEqual(t.continuous, True)

        # log_prob(): cached only for the tensor's own group_ndims
        this_log_prob = t.log_prob()
        self.assertIs(t._cached_log_prob, this_log_prob)
        self.assertIs(t.log_prob(), t._cached_log_prob)
        self.assertIs(t.log_prob(group_ndims=0), t._cached_log_prob)
        assert_allclose(this_log_prob, log_prob, rtol=1e-4)

        this_log_prob = t.log_prob(group_ndims=1)
        self.assertIsNot(this_log_prob, t._cached_log_prob)
        assert_allclose(this_log_prob, log_prob_reduce_1, rtol=1e-4)

        # prob(): cached only for the tensor's own group_ndims
        this_prob = t.prob()
        self.assertIs(t._cached_prob, this_prob)
        self.assertIs(t.prob(), t._cached_prob)
        self.assertIs(t.prob(group_ndims=0), t._cached_prob)
        assert_allclose(this_prob, np.exp(T.to_numpy(log_prob)), rtol=1e-4)

        this_prob = t.prob(group_ndims=1)
        self.assertIsNot(this_prob, t._cached_prob)
        assert_allclose(this_prob,
                        np.exp(T.to_numpy(log_prob_reduce_1)),
                        rtol=1e-4)

        ##
        # case 2: the `continuous` flag is taken from the distribution; a
        # pre-computed log_prob and a transform_origin are supplied
        normal.continuous = False
        t = StochasticTensor(
            tensor=samples_0,
            distribution=normal,
            n_samples=None,
            group_ndims=1,
            reparameterized=True,
            log_prob=log_prob_reduce_1,
            transform_origin=samples,
        )
        self.assertEqual(t.continuous, False)

        self.assertIs(t.tensor, samples_0)
        self.assertIs(t.distribution, normal)
        self.assertEqual(t.n_samples, None)
        self.assertEqual(t.group_ndims, 1)
        self.assertEqual(t.reparameterized, True)
        self.assertIs(t.transform_origin, samples)
        # the supplied log_prob is stored as the cached value
        self.assertIs(t._cached_log_prob, log_prob_reduce_1)
        self.assertIsNone(t._cached_prob)