def test_no_arguments():
    """Build Normal and Bernoulli with no constructor arguments.

    Both objects must still be instances of ``nn.Distribution``.
    """
    normal = nn.Normal()
    assert isinstance(normal, nn.Distribution)
    bernoulli = nn.Bernoulli()
    assert isinstance(bernoulli, nn.Distribution)
def test_with_arguments():
    """Build Normal and Bernoulli with explicit parameters and dtypes.

    Both objects must be instances of ``nn.Distribution``.
    """
    normal = nn.Normal([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(normal, nn.Distribution)
    bernoulli = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32)
    assert isinstance(bernoulli, nn.Distribution)
def test_basics():
    """Check Bernoulli mean / standard deviation (via Net3) and ``probs()``.

    Net3 is constructed with a Bernoulli of probabilities [0.5, 0.5], for
    which mean = p = 0.5 and sd = sqrt(p * (1 - p)) = 0.5. Calling ``Net3()``
    is expected to return ``(mean, sd)`` tensors; its ``construct`` is not
    shown here, so that contract is assumed (TODO confirm).

    All results are device tensors computed in floating point. Every
    comparison therefore uses the same absolute tolerance instead of exact
    ``==`` equality, which could fail on backends with approximate sqrt.
    """
    tol = 1e-6

    basics = Net3()
    mean, sd = basics()
    expect_mean = [0.5, 0.5]
    expect_sd = [0.5, 0.5]
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()

    b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
    probs = b.probs()
    expect_probs = [0.7, 0.5]
    assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()
def __init__(self, shape, seed=0):
    """Set up a cell holding a seeded Bernoulli with probabilities [0.7, 0.5].

    Args:
        shape: stored as ``self.shape``. Presumably the sample shape used
            elsewhere in the class (TODO confirm against ``construct``).
        seed: forwarded to ``nn.Bernoulli`` as its ``seed`` keyword.
    """
    super(Net4, self).__init__()
    probs = [0.7, 0.5]
    self.b = nn.Bernoulli(probs, seed=seed, dtype=dtype.int32)
    self.shape = shape
def __init__(self):
    """Set up a cell holding a Bernoulli with probabilities [0.5, 0.5] (dtype int32)."""
    super(Net3, self).__init__()
    probs = [0.5, 0.5]
    self.b = nn.Bernoulli(probs, dtype=dtype.int32)
def __init__(self):
    """Set up a cell holding a scalar Bernoulli with probability 0.7 (dtype int32)."""
    super(Net2, self).__init__()
    dist = nn.Bernoulli(0.7, dtype=dtype.int32)
    self.b = dist
def __init__(self):
    """Set up a cell holding both a Normal(3.0, 4.0) and a Bernoulli(0.5).

    The Normal uses dtype float32 and the Bernoulli uses dtype int32.
    """
    super(NormalBernoulli, self).__init__()
    normal = nn.Normal(3.0, 4.0, dtype=dtype.float32)
    self.n = normal
    bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)
    self.b = bernoulli
def __init__(self):
    """Set up a cell holding a Bernoulli built without a probability.

    Only ``dtype=int32`` is given. Presumably the probability is supplied
    at call time in ``construct`` (TODO confirm).
    """
    super(BernoulliKlNoArgs, self).__init__()
    dist = nn.Bernoulli(dtype=dtype.int32)
    self.b = dist
def __init__(self):
    """Set up a cell holding a Bernoulli(0.5) with the constructor's default dtype."""
    super(BernoulliLogProb2, self).__init__()
    dist = nn.Bernoulli(0.5)
    self.bernoulli = dist
def __init__(self):
    """Set up a cell holding a Bernoulli built entirely from constructor defaults."""
    super(BernoulliProb1, self).__init__()
    dist = nn.Bernoulli()
    self.bernoulli = dist
def __init__(self):
    """Set up a cell holding a scalar Bernoulli with probability 0.5 (dtype int32)."""
    super(BernoulliLogProb, self).__init__()
    dist = nn.Bernoulli(0.5, dtype=dtype.int32)
    self.bernoulli = dist