def test_ndims_equals_zero_and_negative(self):
    """`ndims=0` is an identity transform; a negative `ndims` is rejected.

    Both `batch_ndims_to_value` and `expand_value_ndims` are checked:
    with 0 they must hand back the very same distribution object, and
    with -1 they must raise `ValueError`.
    """
    dist = Normal(mean=tf.zeros([3, 4]), logstd=0.)

    # Zero requested dimensions: the same object (identity, not a copy).
    self.assertIs(dist.batch_ndims_to_value(0), dist)
    self.assertIs(dist.expand_value_ndims(0), dist)

    # Negative requested dimensions: both methods share one error message.
    negative_msg = '`ndims` must be non-negative integers'
    for reshape_method in (dist.batch_ndims_to_value,
                           dist.expand_value_ndims):
        with pytest.raises(ValueError, match=negative_msg):
            _ = reshape_method(-1)
def test_ndims_exceed_limit(self):
    """Moving more dimensions than the batch shape holds raises `ValueError`.

    The distribution is built from a `[3, 4]` mean and asked to move 3
    dimensions via `expand_value_ndims`, which must fail.
    """
    dist = Normal(mean=tf.zeros([3, 4]), logstd=0.)

    # NOTE(review): presumably the batch rank here is 2 (from the [3, 4]
    # mean), so 3 exceeds it -- confirm against Normal's batch_shape.
    # The "less then" spelling is kept verbatim; it likely mirrors the
    # library's own error text, which `match` searches for.
    expected_msg = '`distribution.batch_shape.ndims` is less then `ndims`'
    with pytest.raises(ValueError, match=expected_msg):
        _ = dist.expand_value_ndims(3)