def test_ndims_equals_zero_and_negative(self):
        """Zero `ndims` is an identity; negative `ndims` is rejected.

        Both `batch_ndims_to_value` and `expand_value_ndims` are checked.
        """
        dist = Normal(mean=tf.zeros([3, 4]), logstd=0.)
        converters = (dist.batch_ndims_to_value, dist.expand_value_ndims)
        negative_msg = '`ndims` must be non-negative integers'

        # With ndims == 0 each converter must hand back the very same object.
        for convert in converters:
            self.assertIs(convert(0), dist)

        # A negative ndims must raise ValueError with the expected message.
        for convert in converters:
            with pytest.raises(ValueError, match=negative_msg):
                _ = convert(-1)

    def test_ndims_exceed_limit(self):
        """`expand_value_ndims` must reject an `ndims` larger than allowed.

        The distribution is built from a (3, 4) mean, presumably giving it a
        rank-2 batch shape, so asking for 3 value dimensions should fail.
        """
        dist = Normal(mean=tf.zeros([3, 4]), logstd=0.)
        # NOTE(review): "then" (sic) presumably mirrors the library's error
        # text; it is used as a regex, so fix both together if ever changed.
        too_many_msg = ('`distribution.batch_shape.ndims` '
                        'is less then `ndims`')

        with pytest.raises(ValueError, match=too_many_msg):
            _ = dist.expand_value_ndims(3)