def __init__(self, shape, seed=0):
    """Set up a seeded Gumbel distribution for the sampling test cell.

    Args:
        shape: stored on ``self.shape``; presumably the sample shape used by
            ``construct`` (not visible here — confirm).
        seed: forwarded as the ``seed`` argument of ``msd.Gumbel``.
    """
    super().__init__()
    # A single loc value paired with three scale values; presumably
    # broadcast by Gumbel to a batch of three — TODO confirm.
    loc = np.array([0.0])
    scale = np.array([1.0, 2.0, 3.0])
    self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32, seed=seed)
    self.shape = shape
def __init__(self):
    """Build the float32 Gumbel distribution used by the KL test cell."""
    super().__init__()
    # One loc value with two scale values; presumably broadcast to a
    # batch of two — TODO confirm against Gumbel's broadcasting rules.
    loc = np.array([0.0])
    scale = np.array([1.0, 2.0])
    self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)
def __init__(self):
    """Build the float32 Gumbel distribution used by the LogProb test cell."""
    super().__init__()
    # loc has shape (1,) and scale has shape (2, 1).
    loc = np.array([0.0])
    scale = np.array([[1.0], [2.0]])
    self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)
def __init__(self):
    """Build the float32 Gumbel distribution used by the CrossEntropy test cell."""
    super().__init__()
    # loc has shape (1,) and scale has shape (2, 1).
    loc = np.array([0.0])
    scale = np.array([[1.0], [2.0]])
    self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)
def __init__(self):
    """Build a scalar float32 Gumbel (loc 3.0, scale 4.0) for the probability cell."""
    super().__init__()
    loc, scale = 3.0, 4.0
    self.gumbel = msd.Gumbel(loc, scale, dtype=dtype.float32)
def __init__(self):
    """Build a scalar Gumbel (loc 3.0, scale 4.0) with the library's default dtype."""
    super().__init__()
    loc, scale = 3.0, 4.0
    self.gumbel = msd.Gumbel(loc, scale)
def test_scale():
    """A zero or negative scale must make the Gumbel constructor raise ValueError."""
    # Each bad value gets its own raises-context, in the original order.
    for bad_scale in (0., -1.):
        with pytest.raises(ValueError):
            msd.Gumbel(0., bad_scale)
def test_arguments():
    """Construct a Gumbel from list-valued arguments and check its base type.

    Passes single-element lists for loc and scale with an explicit float32
    dtype, and asserts the result is an ``msd.Distribution``.
    """
    # Renamed from the ambiguous single-letter ``l`` (PEP 8 / E741).
    gumbel_dist = msd.Gumbel([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(gumbel_dist, msd.Distribution)
def test_seed():
    """Passing a string as the seed must raise TypeError."""
    bad_seed = 'seed'
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., seed=bad_seed)
def test_name():
    """Passing a float as the distribution name must raise TypeError."""
    bad_name = 1.0
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., name=bad_name)
def test_type():
    """Requesting an int32 dtype must raise TypeError."""
    bad_dtype = dtype.int32
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., dtype=bad_dtype)
def test_gumbel_shape_errpr():
    """Mismatched loc/scale shapes must raise ValueError.

    loc has shape (2, 1) and scale has shape (3, 1). These are not
    broadcast-compatible, which is presumably why construction is rejected.
    """
    loc = [[2.], [1.]]
    scale = [[2.], [3.], [4.]]
    with pytest.raises(ValueError):
        msd.Gumbel(loc, scale, dtype=dtype.float32)
def __init__(self):
    """Build a scalar Gumbel (loc 3.0, scale 4.0) for the construct-path test cell."""
    super().__init__()
    loc, scale = 3.0, 4.0
    self.gumbel = msd.Gumbel(loc, scale)