Exemple #1
0
    def test_layer_to_device(self):
        """Check parameter placement of a `ResBlock2d` at construction time
        and on the layer returned by `tk.layers.layer_to_device`.

        A device of `None` is expected to resolve to `T.current_device()`.
        """
        device_choices = [None, T.CPU_DEVICE]

        for src_device in device_choices:
            block = ResBlock2d(3, 4, kernel_size=2, device=src_device)
            for p in tk.layers.iter_parameters(block):
                self.assertEqual(
                    T.get_device(p), src_device or T.current_device())

            # relocate the same layer to every candidate target device
            for dst_device in device_choices:
                relocated = tk.layers.layer_to_device(
                    block, device=dst_device)
                for p in tk.layers.iter_parameters(relocated):
                    self.assertEqual(
                        T.get_device(p), dst_device or T.current_device())
Exemple #2
0
def do_check_log_prob(given, batch_ndims, Z_log_prob_fn, np_log_prob):
    """Compare `Z_log_prob_fn` against a NumPy reference for each
    `group_ndims` in `0..batch_ndims`, then check that
    `group_ndims = batch_ndims + 1` is rejected.

    Args:
        given: Tensor at which the log-prob is evaluated.
        batch_ndims: Largest `group_ndims` expected to be accepted.
        Z_log_prob_fn: Callable invoked as `fn(given, group_ndims=...)`.
        np_log_prob: Element-wise reference log-prob (NumPy array).
    """
    for n_grouped in range(batch_ndims + 1):
        actual = Z_log_prob_fn(given, group_ndims=n_grouped)
        # the result must stay on the same device as the input
        assert T.get_device(actual) == T.get_device(given)

        # reference value: sum the element-wise log-prob over the last
        # `n_grouped` axes
        trailing_axes = tuple(range(-n_grouped, 0))
        expected = np.sum(np_log_prob, axis=trailing_axes)
        assert_allclose(actual, expected, rtol=1e-2, atol=1e-5)

    with pytest.raises(Exception, match='`group_ndims` is too large'):
        _ = Z_log_prob_fn(given, group_ndims=batch_ndims + 1)
Exemple #3
0
    def test_randint(self):
        """Check `T.random.randint`: dtype, device and shape of the samples,
        that every value lies in `range(low, high)`, that the empirical
        frequency of each value is close to uniform, and that `low >= high`
        is rejected."""
        for low, high in [(0, 5), (-3, 4)]:
            for dtype, device in product(number_dtypes, [None, T.CPU_DEVICE]):
                # dtype, device and shape of the samples
                samples = T.random.randint(
                    low=low,
                    high=high,
                    shape=[n_samples, 2, 3, 4],
                    dtype=dtype,
                    device=device,
                )
                self.assertEqual(T.get_dtype(samples), dtype)
                self.assertEqual(
                    T.get_device(samples), device or T.current_device())
                self.assertEqual(T.shape(samples), [n_samples, 2, 3, 4])
                values = T.to_numpy(samples).astype(np.int32)

                # every observed value must be one of the candidates
                candidates = list(range(low, high))
                observed = set(values.reshape([-1]).tolist())
                self.assertTrue(all(int(v) in candidates for v in observed))

                # the empirical frequency of each candidate must be within
                # 5 standard errors of the uniform probability
                uniform_p = 1. / len(candidates)
                total = 1. * np.size(values)
                for candidate in candidates:
                    deviation = abs(
                        np.sum(values == candidate) / total - uniform_p)
                    self.assertLessEqual(
                        deviation,
                        5. * np.sqrt(uniform_p * (1. - uniform_p)) /
                        np.sqrt(total))

        with pytest.raises(Exception, match='`low` < `high` does not hold'):
            _ = T.random.randint(low=2, high=1, shape=[2, 3, 4])
Exemple #4
0
    def test_shuffle_and_random_permutation(self):
        """Check that `T.random.shuffle` and `T.random.random_permutation`
        keep the set of elements intact and do not return the identity
        ordering on every one of 100 trials."""
        n_trials = 100
        base = np.arange(24).reshape([2, 3, 4])

        # shuffle: try every valid axis index, negative and non-negative
        for axis in range(-base.ndim, base.ndim):
            n_unchanged = 0
            for _ in range(n_trials):
                shuffled = T.random.shuffle(T.from_numpy(base), axis=axis)
                if np.all(np.equal(T.to_numpy(shuffled), base)):
                    n_unchanged += 1
                # sorting along the shuffled axis must restore the input
                assert_equal(np.sort(T.to_numpy(shuffled), axis=axis), base)
            self.assertLess(n_unchanged, n_trials)

        # random_permutation
        for dtype, device in product(int_dtypes, [None, T.CPU_DEVICE]):
            for n in [0, 1, 5]:
                identity = np.arange(n)
                n_unchanged = 0
                for _ in range(n_trials):
                    perm = T.random.random_permutation(
                        n, dtype=dtype, device=device)
                    self.assertEqual(T.get_dtype(perm), dtype)
                    self.assertEqual(
                        T.get_device(perm), device or T.current_device())
                    if np.all(np.equal(T.to_numpy(perm), identity)):
                        n_unchanged += 1
                    assert_equal(np.sort(T.to_numpy(perm)), identity)
                # with fewer than two elements only one ordering exists
                if n > 1:
                    self.assertLess(n_unchanged, n_trials)
    def test_TensorStream(self):
        """Check `tk.utils.as_tensor_stream`: the wrapper type and device,
        attributes mirrored from the source stream, the arrays returned by
        `get_arrays()`, per-batch tensor type and device, `copy()`, and the
        `prefetch` option."""
        x = np.random.randn(17, 3, 4)
        y = np.random.randn(17, 5)
        source = mltk.DataStream.arrays(
            [x, y], batch_size=3, random_state=np.random.RandomState())

        def check_arrays(tensor_stream):
            # `get_arrays()` must give back the original arrays
            got_x, got_y = tensor_stream.get_arrays()
            assert_allclose(got_x, x, rtol=1e-4, atol=1e-6)
            assert_allclose(got_y, y, rtol=1e-4, atol=1e-6)

        # test tensor stream
        mirrored_attrs = ('batch_size', 'array_count', 'data_shapes',
                          'data_length', 'random_state')
        for device in [None, T.CPU_DEVICE]:
            stream = tk.utils.as_tensor_stream(source, device=device)
            self.assertIsInstance(stream, tk.utils.TensorStream)
            self.assertEqual(stream.device, device or T.current_device())

            for name in mirrored_attrs:
                self.assertEqual(getattr(stream, name), getattr(source, name))

            check_arrays(stream)

            for batch_x, batch_y in stream:
                for batch_tensor in (batch_x, batch_y):
                    self.assertIsInstance(batch_tensor, T.Tensor)
                    self.assertEqual(T.get_device(batch_tensor),
                                     device or T.current_device())

            # test copy: without an explicit device, the copy is expected
            # to keep the device of the original stream
            for device2 in [None, T.CPU_DEVICE]:
                if device2 is None:
                    duplicate = stream.copy()
                else:
                    duplicate = stream.copy(device=device2)
                self.assertIs(duplicate.source, stream.source)
                self.assertEqual(duplicate.device, device2 or stream.device)

        # test prefetch
        prefetched = tk.utils.as_tensor_stream(source, prefetch=3)
        self.assertIsInstance(prefetched.source, tk.utils.TensorStream)
        check_arrays(prefetched)
Exemple #6
0
    def test_rand(self):
        """Check `T.random.rand`: dtype, device and shape of the samples,
        and that the mean over the first axis is close to 0.5."""
        for dtype, device in product(float_dtypes, [None, T.CPU_DEVICE]):
            # dtype, device and shape of the samples
            samples = T.random.rand(
                [n_samples, 2, 3, 4], dtype=dtype, device=device)
            self.assertEqual(T.get_dtype(samples), dtype)
            self.assertEqual(
                T.get_device(samples), device or T.current_device())
            self.assertEqual(T.shape(samples), [n_samples, 2, 3, 4])

            # the sample mean must be within 3 standard errors of 0.5,
            # using 1/12 as the reference variance
            empirical_mean = np.mean(T.to_numpy(samples), axis=0)
            tolerance = (3. * np.sqrt(1. / 12) / np.sqrt(n_samples) *
                         np.ones_like(empirical_mean))
            np.testing.assert_array_less(
                np.abs(0.5 - empirical_mean), tolerance)
Exemple #7
0
        def do_test_sample(is_one_hot: bool, n_z: Optional[int],
                           dtype: Optional[str], float_dtype: str):
            """Draw categorical (or one-hot categorical) samples and check
            their dtype, device, shape, value set and log-prob.

            Args:
                is_one_hot: Use the one-hot sampler and log-prob function.
                n_z: Value passed as `n_samples`; `None` for no sample axis.
                dtype: Requested sample dtype; when falsy, no `dtype`
                    argument is passed to the sampler.
                float_dtype: Dtype of the `probs` / `logits` tensors.
            """
            probs_t = T.as_tensor(probs, dtype=float_dtype)
            logits_t = T.as_tensor(logits, dtype=float_dtype)

            # pick the sampler / log-prob pair and the matching expectations
            if is_one_hot:
                sample_fn = T.random.one_hot_categorical
                log_prob_fn = T.random.one_hot_categorical_log_prob
                value_shape = [n_classes]
                fallback_dtype = T.int32
            else:
                sample_fn = T.random.categorical
                log_prob_fn = T.random.categorical_log_prob
                value_shape = []
                fallback_dtype = T.categorical_dtype
            expected_dtype = dtype if dtype is not None else fallback_dtype
            leading_shape = [] if n_z is None else [n_z]

            # sample
            dtype_kwargs = {}
            if dtype:
                dtype_kwargs['dtype'] = dtype
            t = sample_fn(probs_t, n_samples=n_z, **dtype_kwargs)
            self.assertEqual(T.get_dtype(t), expected_dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t),
                             leading_shape + [2, 3, 4] + value_shape)

            # check values
            x = T.to_numpy(t)
            observed = set(x.flatten().tolist())
            if is_one_hot:
                self.assertEqual(observed, {0, 1})
            elif n_z is None:
                # without a sample axis, only require valid class indices
                self.assertTrue(observed.issubset(set(range(n_classes))))
            else:
                self.assertEqual(observed, set(range(n_classes)))

            # check log_prob
            do_check_log_prob(
                given=t,
                batch_ndims=len(t.shape) - int(is_one_hot),
                Z_log_prob_fn=partial(log_prob_fn, logits=logits_t),
                np_log_prob=log_prob(x, probs, n_classes, is_one_hot))
Exemple #8
0
        def do_test_sample(n_z, sample_shape, float_dtype, dtype):
            """Draw Bernoulli samples and check their dtype, device, shape,
            value set and log-prob.

            Args:
                n_z: Value passed as `n_samples` to the sampler.
                sample_shape: Expected leading shape of the samples (a list).
                float_dtype: Dtype of the `probs` / `logits` tensors.
                dtype: Requested dtype of the samples.
            """
            probs_t = T.as_tensor(probs, dtype=float_dtype)
            logits_t = T.as_tensor(logits, dtype=float_dtype)

            samples = T.random.bernoulli(
                probs=probs_t, n_samples=n_z, dtype=dtype)
            self.assertEqual(T.get_dtype(samples), dtype)
            self.assertEqual(T.get_device(samples), T.current_device())
            self.assertEqual(T.shape(samples), sample_shape + [2, 3, 4])

            # both outcomes, and nothing else, must appear in the samples
            x = T.to_numpy(samples)
            observed = set(x.flatten().tolist())
            self.assertEqual(observed, {0, 1})

            # check the log prob
            log_prob_fn = partial(T.random.bernoulli_log_prob,
                                  logits=logits_t)
            do_check_log_prob(
                given=samples,
                batch_ndims=len(samples.shape),
                Z_log_prob_fn=log_prob_fn,
                np_log_prob=log_prob(x))
Exemple #9
0
    def test_randn(self):
        """Check `T.random.randn`: dtype, device and shape of the samples,
        that the mean over the first axis is close to zero, and that
        `T.random.randn_log_pdf` matches the unit normal log-density."""
        for dtype, device in product(float_dtypes, [None, T.CPU_DEVICE]):
            # dtype, device and shape of the samples
            samples = T.random.randn(
                [n_samples, 2, 3, 4], dtype=dtype, device=device)
            self.assertEqual(T.get_dtype(samples), dtype)
            self.assertEqual(
                T.get_device(samples), device or T.current_device())
            self.assertEqual(T.shape(samples), [n_samples, 2, 3, 4])

            # the sample mean must be within 3 standard errors of zero
            x = T.to_numpy(samples)
            empirical_mean = np.mean(x, axis=0)
            tolerance = (
                3. / np.sqrt(n_samples) * np.ones_like(empirical_mean))
            np.testing.assert_array_less(np.abs(empirical_mean), tolerance)

            # reference log-density of the unit normal distribution
            reference = np.log(np.exp(-x**2 / 2.) / np.sqrt(2 * np.pi))
            do_check_log_prob(
                given=samples,
                batch_ndims=len(x.shape),
                Z_log_prob_fn=T.random.randn_log_pdf,
                np_log_prob=reference,
            )
Exemple #10
0
    def test_uniform(self):
        """Check `T.random.uniform`: dtype, device and shape of the samples,
        that the mean over the first axis is close to the interval midpoint,
        and that `low >= high` is rejected."""
        for low, high in [(-1., 2.), (3.5, 7.5)]:
            midpoint = 0.5 * (high + low)
            for dtype, device in product(float_dtypes, [None, T.CPU_DEVICE]):
                # dtype, device and shape of the samples
                samples = T.random.uniform(
                    [n_samples, 2, 3, 4],
                    low=low,
                    high=high,
                    dtype=dtype,
                    device=device,
                )
                self.assertEqual(T.get_dtype(samples), dtype)
                self.assertEqual(
                    T.get_device(samples), device or T.current_device())
                self.assertEqual(T.shape(samples), [n_samples, 2, 3, 4])

                # the sample mean must be within 5 standard errors of the
                # midpoint, using (high - low)^2 / 12 as reference variance
                empirical_mean = np.mean(T.to_numpy(samples), axis=0)
                tolerance = (
                    5. * np.sqrt((high - low)**2 / 12) / np.sqrt(n_samples) *
                    np.ones_like(empirical_mean))
                np.testing.assert_array_less(
                    np.abs(empirical_mean - midpoint), tolerance)

        with pytest.raises(Exception, match='`low` < `high` does not hold'):
            _ = T.random.uniform([2, 3, 4], low=2., high=1.)
Exemple #11
0
    def test_check_tensor_arg_types(self):
        """Exercise `check_tensor_arg_types` for every float dtype.

        Covered cases, in order: ordinary mixed arguments (floats, tensors,
        a `StochasticTensor`, and mutually exclusive groups), dtype
        resolution through `dtype` / `default_dtype`, dtype mismatch errors,
        device handling (only when the current device is not the CPU), the
        error for a missing required argument, and the errors for mutually
        exclusive groups where none or all members are given.
        """
        for dtype in float_dtypes:
            # check ordinary usage: mixed floats, numbers, mutual groups
            for specified_dtype in [None, dtype]:
                e_orig = T.as_tensor([1., 2., 3.], dtype=dtype)
                f_orig = StochasticTensor(
                    T.as_tensor([4., 5., 6.], dtype=dtype), UnitNormal([]),
                    None, 0, True)
                # a list argument is a mutual group: `[b, c]` and `[d, e]`
                # each have exactly one member that is not None
                a, [b, c], [d, e], f = check_tensor_arg_types(
                    ('a', 1.0), [('b', 2.0), ('c', None)], [('d', None),
                                                            ('e', e_orig)],
                    ('f', f_orig),
                    dtype=specified_dtype)
                # floats are expected to be converted into tensors holding
                # the same value; tensors are expected to be returned as-is
                # (for the `StochasticTensor`, its `.tensor` attribute)
                for t, v in [(a, 1.0), (b, 2.0), (e, e_orig),
                             (f, f_orig.tensor)]:
                    self.assertIsInstance(t, T.Tensor)
                    self.assertEqual(T.get_dtype(t), dtype)
                    self.assertEqual(T.get_device(t), T.current_device())
                    if isinstance(v, float):
                        assert_equal(t, v)
                    else:
                        self.assertIs(t, v)

            # float dtype determined by `dtype` and `default_dtype`
            for arg_name in ('dtype', 'default_dtype'):
                [a] = check_tensor_arg_types(('a', 123.0), **{arg_name: dtype})
                self.assertIsInstance(a, T.Tensor)
                self.assertEqual(T.get_dtype(a), dtype)
                assert_equal(a, 123.0)

            # tensor dtype will ignore `default_dtype`, but checked against `dtype`.
            a_orig = T.as_tensor([1., 2., 3.], dtype=dtype)
            [a] = check_tensor_arg_types(('a', a_orig),
                                         default_dtype=T.float32)
            self.assertIs(a, a_orig)

            # `a` is still the tensor of dtype `dtype` obtained just above
            if dtype != T.float32:
                with pytest.raises(ValueError,
                                   match=f'`a.dtype` != `dtype`: {dtype} vs '
                                   f'{T.float32}'):
                    _ = check_tensor_arg_types(('a', a), dtype=T.float32)

            # check multiple tensors type mismatch
            if dtype != T.float32:
                a_orig = T.as_tensor([1., 2., 3.], dtype=dtype)
                b_orig = T.as_tensor([4., 5., 6.], dtype=T.float32)

                with pytest.raises(ValueError,
                                   match=f'`b.dtype` != `a.dtype`: '
                                   f'{T.float32} vs {dtype}'):
                    _ = check_tensor_arg_types(('a', a_orig), ('b', b_orig))

            # check `device` and `default_device`
            # (skipped when the current device already is the CPU, since
            # the two devices could not be told apart)
            if T.current_device() != T.CPU_DEVICE:
                [a] = check_tensor_arg_types(('a', [1., 2., 3.]),
                                             device=T.CPU_DEVICE)
                self.assertEqual(T.get_device(a), T.CPU_DEVICE)

                [a] = check_tensor_arg_types(('a', [1., 2., 3.]),
                                             default_device=T.CPU_DEVICE)
                self.assertEqual(T.get_device(a), T.CPU_DEVICE)

                # when both are given, `device` is expected to win over
                # `default_device`
                [a] = check_tensor_arg_types(('a', [1., 2., 3.]),
                                             device=T.CPU_DEVICE,
                                             default_device=T.current_device())
                self.assertEqual(T.get_device(a), T.CPU_DEVICE)

                # an existing tensor on another device must be rejected
                a = T.as_tensor([1., 2., 3.], device=T.current_device())
                with pytest.raises(ValueError,
                                   match=f'`a.device` != `device`'):
                    _ = check_tensor_arg_types(('a', a), device=T.CPU_DEVICE)

                # two tensors on different devices must be rejected
                b = T.as_tensor([1., 2., 3.], device=T.CPU_DEVICE)
                with pytest.raises(ValueError,
                                   match=f'`b.device` != `a.device`'):
                    _ = check_tensor_arg_types(('a', a), ('b', b))

            # check tensor cannot be None
            with pytest.raises(ValueError, match='`a` must be specified.'):
                _ = check_tensor_arg_types(('a', None))

            # check mutual group must specify exactly one tensor
            # (`t` is given to every member, so either none or all of the
            # group members are specified)
            for t in [None, T.as_tensor([1., 2., 3.], dtype=dtype)]:
                with pytest.raises(ValueError,
                                   match="Either `a` or `b` must be "
                                   "specified, but not both"):
                    _ = check_tensor_arg_types([('a', t), ('b', t)])
                with pytest.raises(ValueError,
                                   match="One and exactly one of `a`, `b` and "
                                   "`c` must be specified"):
                    _ = check_tensor_arg_types([('a', t), ('b', t), ('c', t)])
Exemple #12
0
        def do_test_sample(bin_size: float, min_val: Optional[float],
                           max_val: Optional[float], discretize_sample: bool,
                           discretize_given: bool, biased_edges: bool,
                           reparameterized: bool, n_samples: Optional[int],
                           validate_tensors: bool, dtype: str):
            """Draw discretized-logistic samples via `get_samples` and check
            them, and their log-prob, against the naive NumPy references.

            Args:
                bin_size: Forwarded to the sampler, the log-prob function
                    and the references.
                min_val: Forwarded likewise (may be `None`).
                max_val: Forwarded likewise (may be `None`).
                discretize_sample: `discretize` flag used when sampling.
                discretize_given: `discretize` flag used for the log-prob.
                biased_edges: Forwarded to the log-prob and its reference.
                reparameterized: Forwarded to the sampler.
                n_samples: Number of samples; `None` for no sample axis.
                validate_tensors: Forwarded to sampler and log-prob.
                dtype: Dtype of the parameter tensors and of the samples.
            """
            mean_t = T.as_tensor(mean, dtype=dtype)
            log_scale_t = T.as_tensor(log_scale, dtype=dtype)
            param_shape = T.get_broadcast_shape(T.shape(mean_t),
                                                T.shape(log_scale_t))

            # sample
            leading_shape = [] if n_samples is None else [n_samples]
            sampler_kwargs = dict(
                n_samples=n_samples,
                bin_size=bin_size,
                min_val=min_val,
                max_val=max_val,
                discretize=discretize_sample,
                reparameterized=reparameterized,
                epsilon=T.EPSILON,
                validate_tensors=validate_tensors,
            )
            u, t = get_samples(mean_t, log_scale_t, **sampler_kwargs)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), leading_shape + param_shape)

            # check values: feed `u` to the naive reference sampler
            np_mean = mean.astype(dtype)
            np_log_scale = log_scale.astype(dtype)
            reference_sample = naive_discretized_logistic_sample(
                u,
                np_mean,
                np_log_scale,
                bin_size,
                min_val,
                max_val,
                discretize_sample=discretize_sample,
            )
            assert_allclose(t, reference_sample, rtol=1e-4, atol=1e-6)

            # check log_prob
            log_prob_fn = partial(
                T.random.discretized_logistic_log_prob,
                mean=mean_t,
                log_scale=log_scale_t,
                bin_size=bin_size,
                min_val=min_val,
                max_val=max_val,
                biased_edges=biased_edges,
                discretize=discretize_given,
                validate_tensors=validate_tensors,
            )
            reference_log_prob = naive_discretized_logistic_pdf(
                x=T.to_numpy(t),
                mean=np_mean,
                log_scale=np_log_scale,
                bin_size=bin_size,
                min_val=min_val,
                max_val=max_val,
                biased_edges=biased_edges,
                discretize_given=discretize_given,
            )
            do_check_log_prob(given=t,
                              batch_ndims=len(t.shape),
                              Z_log_prob_fn=log_prob_fn,
                              np_log_prob=reference_log_prob)
Exemple #13
0
        def do_test(low, high, dtype):
            """Exercise `T.random.truncated_normal` and
            `T.random.truncated_normal_log_pdf` for one configuration.

            Args:
                low: Lower truncation bound or `None`; the range check
                    compares samples against `low * std + mean`.
                high: Upper truncation bound or `None`; the range check
                    compares samples against `high * std + mean`.
                dtype: Float dtype of the parameter tensors.
            """

            def make_params():
                # fresh (mean, std, logstd) tensors in the requested dtype
                return (T.as_tensor(mean, dtype), T.as_tensor(std, dtype),
                        T.as_tensor(logstd, dtype))

            def check_range_and_log_pdf(samples, mean_t, std_t, logstd_t):
                # samples must lie within the truncation range, with a
                # small numeric slack
                x = T.to_numpy(samples)
                if low is not None:
                    np.testing.assert_array_less(
                        (low * std + mean - 1e-7) * np.ones_like(x), x)
                if high is not None:
                    np.testing.assert_array_less(
                        x,
                        np.ones_like(x) * high * std + mean + 1e-7)

                # log-pdf at the samples themselves ...
                log_pdf_fn = partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                )
                do_check_log_prob(given=samples,
                                  batch_ndims=len(x.shape),
                                  Z_log_prob_fn=log_pdf_fn,
                                  np_log_prob=log_prob(x, low, high))
                # ... and at scaled samples, where the majority is out of
                # the [low, high] range
                do_check_log_prob(given=samples * 10.,
                                  batch_ndims=len(x.shape),
                                  Z_log_prob_fn=log_pdf_fn,
                                  np_log_prob=log_prob(x * 10., low, high))

            # test(n_samples=n_samples)
            mean_t, std_t, logstd_t = make_params()
            t = T.random.truncated_normal(mean_t,
                                          std_t,
                                          n_samples=n_samples,
                                          low=low,
                                          high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])
            check_range_and_log_pdf(t, mean_t, std_t, logstd_t)

            # test(n_samples=None)
            mean_t, std_t, logstd_t = make_params()
            t = T.random.truncated_normal(mean_t, std_t, low=low, high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            check_range_and_log_pdf(t, mean_t, std_t, logstd_t)

            # test reparameterized: gradients must reach `mean` and `std`
            w = np.random.randn(2, 3, 4)

            w_t = T.requires_grad(T.as_tensor(w))
            mean_g = T.requires_grad(T.as_tensor(mean, dtype))
            std_g = T.requires_grad(T.as_tensor(std, dtype))
            weighted = w_t * T.random.truncated_normal(mean_g, std_g)
            [mean_grad, std_grad] = T.grad([weighted], [mean_g, std_g],
                                           [T.ones_like(weighted)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            expected_std_grad = np.sum(
                T.to_numpy((weighted - w_t * mean_g) / std_g), axis=0)
            assert_allclose(std_grad, expected_std_grad, rtol=1e-4)

            # test not reparameterized: both gradients must be null
            w_t = T.requires_grad(T.as_tensor(w))
            mean_g = T.requires_grad(T.as_tensor(mean, dtype))
            std_g = T.requires_grad(T.as_tensor(std, dtype))
            weighted = w_t * T.random.truncated_normal(
                mean_g, std_g, reparameterized=False)
            [mean_grad, std_grad] = T.grad([weighted], [mean_g, std_g],
                                           [T.ones_like(weighted)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_g, mean_grad))
            self.assertTrue(T.is_null_grad(std_g, std_grad))

            # given has lower rank than params, broadcasted to match param
            mean_t, std_t, logstd_t = make_params()
            scalar_log_pdf = T.random.truncated_normal_log_pdf(
                T.float_scalar(0.),
                mean_t,
                std_t,
                logstd_t,
                low=low,
                high=high,
                log_zero=log_zero)
            assert_allclose(scalar_log_pdf,
                            log_prob(0., low=low, high=high),
                            rtol=1e-4)

            # dtype mismatch
            with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
                _ = T.random.truncated_normal(T.as_tensor(mean, T.float32),
                                              T.as_tensor(std, T.float64),
                                              low=low,
                                              high=high)

            # check numerics: a zero `std` is expected to trigger the
            # tensor validation error
            mean_t = T.as_tensor(mean)
            zero_std_t = T.zeros_like(mean_t)
            zero_logstd_t = T.as_tensor(T.log(zero_std_t))
            t = T.random.normal(mean_t, zero_std_t)
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = T.random.truncated_normal_log_pdf(t,
                                                      mean_t,
                                                      zero_std_t,
                                                      zero_logstd_t,
                                                      validate_tensors=True)
Exemple #14
0
    def test_normal(self):
        """Exercise `T.random.normal` and `T.random.normal_log_pdf`.

        Covered cases, in order: sampling with manually expanded parameters,
        sampling with `n_samples`, sampling without `n_samples`, gradients
        with and without reparameterization, a scalar `given` broadcast
        against the parameters, the dtype mismatch error, and the numeric
        validation error.
        """
        mean = np.random.randn(2, 3, 4)
        logstd = np.random.randn(3, 4)
        std = np.exp(logstd)

        def log_prob(given):
            # equivalent to:
            #   np.log(np.exp(-(given - mean) ** 2 / (2. * std ** 2)) /
            #          (np.sqrt(2 * np.pi) * std))
            return (-(given - mean)**2 * (0.5 * np.exp(-2. * logstd)) -
                    np.log(np.sqrt(2 * np.pi)) - logstd)

        def check_sample_mean(x):
            # the mean over the sample axis must be within 5 standard
            # errors of `mean`
            empirical_mean = np.mean(x, axis=0)
            tolerance = np.tile(
                np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                [2, 1, 1])
            np.testing.assert_array_less(
                np.abs(empirical_mean - mean), tolerance)

        def check_log_pdf(samples, x, mean_t, logstd_t):
            # compare `normal_log_pdf` with the NumPy reference at `samples`
            do_check_log_prob(given=samples,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test n_samples by manual expanding the param shape
        for dtype in float_dtypes:
            mean_t = T.cast(T.expand(T.as_tensor(mean), [n_samples, 2, 3, 4]),
                            dtype)
            std_t = T.cast(T.expand(T.as_tensor(std), [n_samples, 1, 3, 4]),
                           dtype)
            logstd_t = T.cast(
                T.expand(T.as_tensor(logstd), [n_samples, 1, 3, 4]), dtype)
            samples = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(samples), dtype)
            self.assertEqual(T.get_device(samples), T.current_device())
            self.assertEqual(T.shape(samples), [n_samples, 2, 3, 4])

            x = T.to_numpy(samples)
            check_sample_mean(x)
            check_log_pdf(samples, x, mean_t, logstd_t)

        # test with n_samples
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            samples = T.random.normal(mean_t, std_t, n_samples=n_samples)
            self.assertEqual(T.get_dtype(samples), dtype)
            self.assertEqual(T.get_device(samples), T.current_device())
            self.assertEqual(T.shape(samples), [n_samples, 2, 3, 4])

            x = T.to_numpy(samples)
            check_sample_mean(x)
            check_log_pdf(samples, x, mean_t, logstd_t)

        # test no n_samples
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            samples = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(samples), dtype)
            self.assertEqual(T.get_device(samples), T.current_device())

            x = T.to_numpy(samples)
            check_log_pdf(samples, x, mean_t, logstd_t)

        # test reparameterized: gradients must reach `mean` and `std`
        w = np.random.randn(2, 3, 4)

        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_g = T.requires_grad(T.as_tensor(mean, dtype))
            std_g = T.requires_grad(T.as_tensor(std, dtype))
            weighted = w_t * T.random.normal(mean_g, std_g)
            [mean_grad, std_grad] = T.grad([weighted], [mean_g, std_g],
                                           [T.ones_like(weighted)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            expected_std_grad = np.sum(
                T.to_numpy((weighted - w_t * mean_g) / std_g), axis=0)
            assert_allclose(std_grad, expected_std_grad, rtol=1e-4)

        # test not reparameterized: both gradients must be null
        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_g = T.requires_grad(T.as_tensor(mean, dtype))
            std_g = T.requires_grad(T.as_tensor(std, dtype))
            weighted = w_t * T.random.normal(
                mean_g, std_g, reparameterized=False)
            [mean_grad, std_grad] = T.grad([weighted], [mean_g, std_g],
                                           [T.ones_like(weighted)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_g, mean_grad))
            self.assertTrue(T.is_null_grad(std_g, std_grad))

        # given has lower rank than params, broadcasted to match param
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            for val in (0., 1., -1.):
                scalar_log_pdf = T.random.normal_log_pdf(
                    T.float_scalar(val), mean_t, logstd_t)
                assert_allclose(scalar_log_pdf, log_prob(val), rtol=1e-4)

        # dtype mismatch
        with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
            _ = T.random.normal(T.as_tensor(mean, T.float32),
                                T.as_tensor(std, T.float64))

        # check numerics: a zero `std` is expected to trigger the tensor
        # validation error
        mean_t = T.as_tensor(mean)
        zero_std_t = T.zeros_like(mean_t)
        zero_logstd_t = T.as_tensor(T.log(zero_std_t))
        samples = T.random.normal(mean_t, zero_std_t)
        with pytest.raises(Exception,
                           match='Infinity or NaN value encountered'):
            _ = T.random.normal_log_pdf(samples,
                                        mean_t,
                                        zero_logstd_t,
                                        validate_tensors=True)