Example #1
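A test of iwae_estimator: the negative, reduction, and keepdims options are checked against manually reduced costs, and the reparameterized gradient is verified against the self-normalized importance weights wk_hat computed from the synthetic payload.
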
    def test_iwae(self):
        assert_allclose_ = functools.partial(assert_allclose,
                                             rtol=1e-5,
                                             atol=1e-6)

        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        wk_hat = f / T.reduce_sum(f, axis=[0], keepdims=True)
        cost = iwae_estimator(log_f, axis=[0])
        assert_allclose_(-cost, iwae_estimator(log_f, axis=[0], negative=True))
        assert_allclose_(T.reduce_mean(cost),
                         iwae_estimator(log_f, axis=[0], reduction='mean'))
        assert_allclose_(T.reduce_sum(cost),
                         iwae_estimator(log_f, axis=[0], reduction='sum'))
        cost_shape = T.shape(cost)
        assert_allclose_(
            T.grad([T.reduce_sum(cost)], [y])[0],
            T.reduce_sum(wk_hat * (2 * x * y), axis=[0]))

        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        wk_hat = f / T.reduce_sum(f, axis=[0], keepdims=True)
        cost_k = iwae_estimator(log_f, axis=[0], keepdims=True)
        assert_allclose_(
            T.reduce_mean(cost_k),
            iwae_estimator(log_f, axis=[0], keepdims=True, reduction='mean'))
        assert_allclose_(
            T.reduce_sum(cost_k),
            iwae_estimator(log_f, axis=[0], keepdims=True, reduction='sum'))
        assert_allclose_(
            -cost_k,
            iwae_estimator(log_f, axis=[0], keepdims=True, negative=True))
        self.assertListEqual([1] + cost_shape, T.shape(cost_k))
        assert_allclose_(
            T.grad([T.reduce_sum(cost_k)], [y])[0],
            T.reduce_sum(wk_hat * (2 * x * y), axis=[0]))
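
A minimal usage sketch for context, assuming the iwae_estimator signature exercised above and hypothetical tensors log_p and log_q of shape [n_samples, batch_size]:

    # Hypothetical log-densities log p(x, z) and log q(z | x), evaluated on
    # n_samples latent samples per data point.
    log_w = log_p - log_q

    # Reduce the sampling axis 0; negative=True turns the IWAE lower bound
    # into a minimizable loss, and reduction='mean' averages over the batch.
    loss = iwae_estimator(log_w, axis=[0], negative=True, reduction='mean')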
Example #2
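A test of the LooseInvertibleMatrix and StrictInvertibleMatrix layers: both should start from an (approximately) orthogonal initialization, stay invertible, and provide gradients through the matrix and its log-determinant.
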
    def test_invertible_matrices(self):
        for cls in (LooseInvertibleMatrix, StrictInvertibleMatrix):
            for n in [1, 3, 5]:
                m = cls(np.random.randn(n, n))
                self.assertEqual(repr(m), f'{cls.__qualname__}(size={n})')
                self.assertEqual(m.size, n)

                m = tk.layers.jit_compile(m)

                # check the initial value is an orthogonal matrix
                matrix, _ = m(inverse=False, compute_log_det=False)
                inv_matrix, _ = m(inverse=True, compute_log_det=False)
                assert_allclose(np.eye(n), T.matmul(matrix, inv_matrix),
                                rtol=1e-4, atol=1e-6)
                assert_allclose(np.eye(n), T.matmul(inv_matrix, matrix),
                                rtol=1e-4, atol=1e-6)

                # check the invertibility
                check_invertible_matrix(self, m, n)

                # check the gradient
                matrix, log_det = m(inverse=False, compute_log_det=True)
                params = list(tk.layers.iter_parameters(m))
                grads = T.grad(
                    [T.reduce_sum(matrix), T.reduce_sum(log_det)], params)

                # update with gradient, then check the invertibility
                if cls is StrictInvertibleMatrix:
                    for param, grad in zip(params, grads):
                        with T.no_grad():
                            T.assign(param, param + 0.001 * grad)
                    check_invertible_matrix(self, m, n)
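
Only StrictInvertibleMatrix is re-checked after the gradient step, presumably because its parameterization keeps the matrix invertible under arbitrary updates, while a perturbed LooseInvertibleMatrix carries no such guarantee.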
Example #3
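A test of gradients through sparse tensors, including the interaction with T.sparse.stop_grad.
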
    def test_grad(self):
        def f():
            indices = T.as_tensor([[0, 1], [1, 0]])
            values = T.requires_grad(T.as_tensor([0.1, 0.2]))
            shape = [2, 2]
            x = T.sparse.make_sparse(indices, values, shape=shape)
            return values, x

        values, x = f()
        y = T.sparse.reduce_sum(x * x)
        [grad] = T.grad([y], [values])
        assert_allclose(grad, 2 * values, atol=1e-4, rtol=1e-6)

        values, x = f()
        y = T.sparse.reduce_sum(T.sparse.stop_grad(x) * x)
        [grad] = T.grad([y], [values])
        assert_allclose(grad, values, atol=1e-4, rtol=1e-6)
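
Both expected gradients follow directly: for y = sum(x * x) we have dy/dvalues = 2 * values, while stop_grad blocks one factor, so y = sum(stop_grad(x) * x) gives dy/dvalues = values.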
Example #4
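A test of sgvb_estimator, mirroring the IWAE test above: the negative, reduction, axis, and keepdims options are checked, together with the reparameterized gradient (the division by 7 matches the apparent 7-element sampling axis of the payload).
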
    def test_sgvb(self):
        assert_allclose_ = functools.partial(assert_allclose,
                                             rtol=1e-5,
                                             atol=1e-6)

        # default
        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        cost = sgvb_estimator(f)
        assert_allclose_(-cost, sgvb_estimator(f, negative=True))
        assert_allclose_(T.reduce_mean(cost),
                         sgvb_estimator(f, reduction='mean'))
        assert_allclose_(T.reduce_sum(cost),
                         sgvb_estimator(f, reduction='sum'))
        cost_shape = T.shape(cost)
        assert_allclose_(
            T.grad([T.reduce_sum(cost)], [y])[0],
            T.reduce_sum(2 * x * y * f, axis=[0]))

        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        cost_r = sgvb_estimator(f, axis=[0])
        assert_allclose_(-cost_r, sgvb_estimator(f, axis=[0], negative=True))
        self.assertListEqual(cost_shape[1:], T.shape(cost_r))
        assert_allclose_(
            T.grad([T.reduce_sum(cost_r)], [y])[0],
            T.reduce_sum(2 * x * y * f, axis=[0]) / 7)

        x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
        cost_rk = sgvb_estimator(f, axis=[0], keepdims=True)
        assert_allclose_(T.reduce_mean(cost_rk),
                         sgvb_estimator(f, axis=[0], reduction='mean'))
        assert_allclose_(T.reduce_sum(cost_rk),
                         sgvb_estimator(f, axis=[0], reduction='sum'))
        assert_allclose_(
            -cost_rk, sgvb_estimator(f, axis=[0], keepdims=True,
                                     negative=True))
        self.assertListEqual([1] + cost_shape[1:], T.shape(cost_rk))
        assert_allclose_(
            T.grad([T.reduce_sum(cost_rk)], [y])[0],
            T.reduce_sum(2 * x * y * f, axis=[0]) / 7)
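
For comparison, a minimal SGVB sketch under the same assumptions (hypothetical log_p and log_q over a reparameterized sample axis; the test applies the estimator to a raw tensor, so it is assumed to average its input over axis):

    # SGVB averages the integrand directly; negative=True again yields a loss.
    loss = sgvb_estimator(log_p - log_q, axis=[0], negative=True,
                          reduction='mean')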
Example #5
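A test of the Uniform distribution: construction from scalar or broadcastable array bounds, sample/log_prob round trips across dtypes, event_ndims and group_ndims, the log_zero floor for out-of-range values, and reparameterized versus non-reparameterized gradients.
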
    def test_sample_and_log_prob(self):
        array_low = np.random.randn(2, 1)
        array_high = np.exp(np.random.randn(1, 3)) + 1.
        log_zero = -1e6

        def log_prob(x, low, high, group_ndims=0):
            if low is None and high is None:
                low, high = 0., 1.
            log_pdf = -np.log(np.ones_like(x) * (high - low))
            log_pdf = np.where(np.logical_and(low <= x, x <= high), log_pdf,
                               log_zero)
            log_pdf = np.sum(log_pdf, axis=tuple(range(-group_ndims, 0)))
            return log_pdf

        for shape, dtype, (low, high), event_ndims in product(
                [None, [], [5, 4]],
                float_dtypes,
                [(None, None), (-1., 2.), (array_low, array_high)],
                range(5)):
            low_rank = len(np.shape(low)) if low is not None else 0
            if event_ndims > len(shape or []) + low_rank:
                continue

            if isinstance(low, np.ndarray):
                low_t = T.as_tensor(low, dtype=dtype)
                high_t = T.as_tensor(high, dtype=dtype)
                uniform = Uniform(shape=shape,
                                  low=low_t,
                                  high=high_t,
                                  event_ndims=event_ndims,
                                  log_zero=log_zero)
                value_shape = (shape or []) + [2, 3]
                self.assertIs(uniform.low, low_t)
                self.assertIs(uniform.high, high_t)
            else:
                uniform = Uniform(shape=shape,
                                  low=low,
                                  high=high,
                                  dtype=dtype,
                                  event_ndims=event_ndims,
                                  log_zero=log_zero)
                value_shape = shape or []
                self.assertEqual(uniform.low, low)
                self.assertEqual(uniform.high, high)
            self.assertEqual(uniform.log_zero, log_zero)
            self.assertEqual(uniform.value_shape, value_shape)

            # sample(n_samples=None)
            t = uniform.sample()
            x = T.to_numpy(t.tensor)
            sample_shape = value_shape
            self.assertIsInstance(t, StochasticTensor)
            self.assertIs(t.distribution, uniform)
            self.assertEqual(T.get_dtype(t.tensor), dtype)
            self.assertEqual(t.n_samples, None)
            self.assertEqual(t.group_ndims, 0)
            self.assertEqual(t.reparameterized, True)
            self.assertIsInstance(t.tensor, T.Tensor)
            self.assertEqual(T.shape(t.tensor), sample_shape)

            for log_pdf in [t.log_prob(), uniform.log_prob(t)]:
                self.assertEqual(T.get_dtype(log_pdf), dtype)
                assert_allclose(log_pdf,
                                log_prob(x, low, high, event_ndims),
                                rtol=1e-4)

            # test log-prob on out-of-range values
            assert_allclose(
                uniform.log_prob(t.tensor * 10.),
                log_prob(x * 10., low, high, event_ndims),
                rtol=1e-4,
            )

            # sample(n_samples=7)
            if event_ndims >= 1:
                t = uniform.sample(n_samples=7,
                                   group_ndims=-1,
                                   reparameterized=False)
                x = T.to_numpy(t.tensor)
                sample_shape = [7] + value_shape
                self.assertIsInstance(t, StochasticTensor)
                self.assertIs(t.distribution, uniform)
                self.assertEqual(T.get_dtype(t.tensor), dtype)
                self.assertEqual(t.n_samples, 7)
                self.assertEqual(t.group_ndims, -1)
                self.assertEqual(t.reparameterized, False)
                self.assertIsInstance(t.tensor, T.Tensor)
                self.assertEqual(T.shape(t.tensor), sample_shape)
                reduce_ndims = event_ndims - 1

                for log_pdf in [
                        t.log_prob(),
                        uniform.log_prob(t, group_ndims=-1)
                ]:
                    self.assertEqual(T.get_dtype(log_pdf), dtype)
                    assert_allclose(log_pdf,
                                    log_prob(x, low, high, reduce_ndims),
                                    rtol=1e-4)

        # test reparameterized
        low_t = T.requires_grad(T.as_tensor(array_low))
        high_t = T.requires_grad(T.as_tensor(array_high))
        uniform = Uniform(low=low_t, high=high_t)

        t = uniform.sample()
        self.assertTrue(t.reparameterized)
        u = (T.to_numpy(t.tensor) - array_low) / (array_high - array_low)
        [low_grad, high_grad] = T.grad([T.reduce_sum(t.tensor)],
                                       [low_t, high_t])
        assert_allclose(low_grad,
                        np.sum(1. - u, axis=-1, keepdims=True),
                        rtol=1e-4)
        assert_allclose(high_grad, np.sum(u, axis=0, keepdims=True), rtol=1e-4)

        t = uniform.sample(reparameterized=False)
        w_t = T.requires_grad(T.as_tensor(np.random.randn(2, 3)))
        self.assertFalse(t.reparameterized)
        [low_grad, high_grad] = T.grad([T.reduce_sum(w_t * t.tensor)],
                                       [low_t, high_t],
                                       allow_unused=True)
        self.assertTrue(T.is_null_grad(low_t, low_grad))
        self.assertTrue(T.is_null_grad(high_t, high_grad))
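
The gradient assertions follow from the reparameterization t = low + u * (high - low) with u held fixed: dt/dlow = 1 - u and dt/dhigh = u, each summed over the axis along which the (2, 1)-shaped low and the (1, 3)-shaped high were broadcast.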
Example #6
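A helper from a truncated-normal test: for clipping bounds low and high (given in units of std around mean), it checks the sample range, log_prob both inside and outside the bounds, reparameterized and non-reparameterized gradients, broadcasting of a lower-rank given value, dtype validation, and numeric validation.
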
        def do_test(low, high, dtype):
            # test(n_samples=n_samples)
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.truncated_normal(mean_t,
                                          std_t,
                                          n_samples=n_samples,
                                          low=low,
                                          high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample value range
            x = T.to_numpy(t)
            if low is not None:
                np.testing.assert_array_less(
                    (low * std + mean - 1e-7) * np.ones_like(x), x)
            if high is not None:
                np.testing.assert_array_less(
                    x,
                    np.ones_like(x) * high * std + mean + 1e-7)

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(
                                  T.random.truncated_normal_log_pdf,
                                  mean=mean_t,
                                  std=std_t,
                                  logstd=logstd_t,
                                  low=low,
                                  high=high,
                                  log_zero=log_zero,
                              ),
                              np_log_prob=log_prob(x, low, high))
            do_check_log_prob(
                given=t * 10.,  # most values now fall outside [low, high]
                batch_ndims=len(x.shape),
                batch_ndims=len(x.shape),
                Z_log_prob_fn=partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                ),
                np_log_prob=log_prob(x * 10., low, high))

            # test(n_samples=None)
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.truncated_normal(mean_t, std_t, low=low, high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())

            # test sample value range
            x = T.to_numpy(t)
            if low is not None:
                np.testing.assert_array_less(
                    (low * std + mean - 1e-7) * np.ones_like(x), x)
            if high is not None:
                np.testing.assert_array_less(
                    x,
                    np.ones_like(x) * high * std + mean + 1e-7)

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(
                                  T.random.truncated_normal_log_pdf,
                                  mean=mean_t,
                                  std=std_t,
                                  logstd=logstd_t,
                                  low=low,
                                  high=high,
                                  log_zero=log_zero,
                              ),
                              np_log_prob=log_prob(x, low, high))
            do_check_log_prob(
                given=t * 10.,  # most values now fall outside [low, high]
                batch_ndims=len(x.shape),
                batch_ndims=len(x.shape),
                Z_log_prob_fn=partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                ),
                np_log_prob=log_prob(x * 10., low, high))

            # test reparameterized
            w = np.random.randn(2, 3, 4)

            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.truncated_normal(mean_t, std_t)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            assert_allclose(std_grad,
                            np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                                   axis=0),
                            rtol=1e-4)

            # test not reparameterized
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.truncated_normal(
                mean_t, std_t, reparameterized=False)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

            # given has lower rank than the params; it is broadcast to match them
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            assert_allclose(T.random.truncated_normal_log_pdf(
                T.float_scalar(0.),
                mean_t,
                std_t,
                logstd_t,
                low=low,
                high=high,
                log_zero=log_zero),
                            log_prob(0., low=low, high=high),
                            rtol=1e-4)

            # dtype mismatch
            with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
                _ = T.random.truncated_normal(T.as_tensor(mean, T.float32),
                                              T.as_tensor(std, T.float64),
                                              low=low,
                                              high=high)

            # check numerics
            mean_t = T.as_tensor(mean)
            std_t = T.zeros_like(mean_t)
            logstd_t = T.as_tensor(T.log(std_t))
            t = T.random.normal(mean_t, std_t)
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = T.random.truncated_normal_log_pdf(t,
                                                      mean_t,
                                                      std_t,
                                                      logstd_t,
                                                      validate_tensors=True)
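
The final numerics check builds a degenerate input on purpose: std = 0 gives logstd = log(0) = -inf, so validate_tensors=True should trigger the 'Infinity or NaN value encountered' guard.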
Example #7
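A test of T.random.normal: sample dtype, device, and shape, sample means within a 5 * std / sqrt(n_samples) bound, normal_log_pdf against a NumPy reference, reparameterized and non-reparameterized gradients, and the same dtype and numerics guards as the truncated-normal helper.
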
    def test_normal(self):
        mean = np.random.randn(2, 3, 4)
        logstd = np.random.randn(3, 4)
        std = np.exp(logstd)

        def log_prob(given):
            # np.log(np.exp(-(given - mean) ** 2 / (2. * std ** 2)) /
            #        (np.sqrt(2 * np.pi) * std))
            return (-(given - mean)**2 * (0.5 * np.exp(-2. * logstd)) -
                    np.log(np.sqrt(2 * np.pi)) - logstd)

        # test n_samples by manually expanding the param shape
        for dtype in float_dtypes:
            # test sample dtype and shape
            mean_t = T.cast(T.expand(T.as_tensor(mean), [n_samples, 2, 3, 4]),
                            dtype)
            std_t = T.cast(T.expand(T.as_tensor(std), [n_samples, 1, 3, 4]),
                           dtype)
            logstd_t = T.cast(
                T.expand(T.as_tensor(logstd), [n_samples, 1, 3, 4]), dtype)
            t = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample mean
            x = T.to_numpy(t)
            x_mean = np.mean(x, axis=0)
            np.testing.assert_array_less(
                np.abs(x_mean - mean),
                np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                        [2, 1, 1]))

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test with n_samples
        for dtype in float_dtypes:
            # test sample dtype and shape
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.normal(mean_t, std_t, n_samples=n_samples)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample mean
            x = T.to_numpy(t)
            x_mean = np.mean(x, axis=0)
            np.testing.assert_array_less(
                np.abs(x_mean - mean),
                np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                        [2, 1, 1]))

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test without n_samples
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())

            # test log_prob
            x = T.to_numpy(t)
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test reparameterized
        w = np.random.randn(2, 3, 4)

        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.normal(mean_t, std_t)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            assert_allclose(std_grad,
                            np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                                   axis=0),
                            rtol=1e-4)

        # test not reparameterized
        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.normal(mean_t, std_t, reparameterized=False)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

        # given has lower rank than the params; it is broadcast to match them
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            for val in (0., 1., -1.):
                assert_allclose(T.random.normal_log_pdf(
                    T.float_scalar(val), mean_t, logstd_t),
                                log_prob(val),
                                rtol=1e-4)

        # dtype mismatch
        with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
            _ = T.random.normal(T.as_tensor(mean, T.float32),
                                T.as_tensor(std, T.float64))

        # check numerics
        mean_t = T.as_tensor(mean)
        std_t = T.zeros_like(mean_t)
        logstd_t = T.as_tensor(T.log(std_t))
        t = T.random.normal(mean_t, std_t)
        with pytest.raises(Exception,
                           match='Infinity or NaN value encountered'):
            _ = T.random.normal_log_pdf(t,
                                        mean_t,
                                        logstd_t,
                                        validate_tensors=True)
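
A minimal sampling sketch under the same assumptions (the T backend shim plus the normal and normal_log_pdf signatures exercised above):

    logstd = np.zeros([3])
    mean_t = T.as_tensor(np.zeros([3]))
    std_t = T.as_tensor(np.exp(logstd))
    logstd_t = T.as_tensor(logstd)

    # Draw 100 reparameterized samples of shape [100, 3], then score them.
    t = T.random.normal(mean_t, std_t, n_samples=100)
    log_p = T.random.normal_log_pdf(t, mean_t, logstd_t)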
Example #8
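A test of the Embedding layers: weight shape and initialization scale, lookups against T.embedding, the scatter-style gradient into the embedding table, constructor validation, and the freeze=True (no-gradient) mode.
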
    def test_embedding(self):
        n_channels = 3
        n_embeddings = n_samples

        for spatial_ndims in (0, 1, 2, 3):
            w_shape = make_conv_shape(
                [], n_channels, [4, 5, 6][:spatial_ndims])
            w_size = int(np.prod(w_shape))

            layer = getattr(tk.layers,
                            (f'Embedding{spatial_ndims}d' if spatial_ndims > 0
                             else 'Embedding'))(n_embeddings, w_shape)
            weight = layer.weight

            # check the weight
            self.assertEqual(T.shape(weight), [n_embeddings] + w_shape)
            reduce_axis = list(range(len(w_shape) + 1))
            reduce_axis.pop(-1 if T.IS_CHANNEL_LAST else 1)
            w_mean = np.average(T.to_numpy(weight), axis=tuple(reduce_axis))
            np.testing.assert_array_less(
                w_mean, 3. / np.sqrt(n_samples * w_size / n_channels))

            # check the output
            layer = jit_compile(layer)
            weight_array = T.to_numpy(weight)

            for in_shape in ([7], [7, 8]):
                indices = T.random.randint(0, n_samples, in_shape)
                indices = T.concat([indices, indices[:3]], axis=0)

                # check the output
                output = layer(indices)
                assert_allclose(output, T.embedding(weight, indices))

                # check the grad
                if spatial_ndims in (0, 1):
                    out_sum = T.reduce_sum(output**2)
                    [grad] = T.grad([out_sum], [weight])
                    expected_grad = np.zeros(T.shape(weight))
                    for idx in T.to_numpy(indices).reshape([-1]):
                        expected_grad[idx] += 2. * weight_array[idx]
                    assert_allclose(grad, expected_grad)

            # test the constructor error
            if spatial_ndims > 0:
                with pytest.raises(
                        ValueError,
                        match=f'`embedding_size` must be a int list '
                        f'with {spatial_ndims + 1} elements'):
                    _ = getattr(tk.layers,
                                f'Embedding{spatial_ndims}d')(n_embeddings,
                                                              w_shape[:-1])

        # test no grad
        layer = Embedding(n_embeddings, n_channels, freeze=True)
        weight = layer.weight
        self.assertEqual(T.shape(weight), [n_embeddings, n_channels])

        layer = jit_compile(layer)
        indices = T.random.randint(0, n_samples, [7, 8])
        output = layer(indices)
        assert_allclose(output, T.embedding(weight, indices))

        out_sum = T.reduce_sum(output**2)
        try:
            [grad] = T.grad([out_sum], [weight])
        except Exception:
            pass
        else:
            self.assertTrue(T.is_null_grad(weight, grad))

        # test errors
        with pytest.raises(ValueError,
                           match='`embedding_size` must not be empty'):
            _ = Embedding(n_embeddings, [])
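
A minimal lookup sketch using only the calls shown above (sizes are arbitrary):

    layer = Embedding(10, 4)                # 10 embeddings, 4 channels each
    indices = T.random.randint(0, 10, [7])  # 7 random ids
    output = layer(indices)                 # shape [7, 4]
    reference = T.embedding(layer.weight, indices)  # same lookup, functional form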