예제 #1
0
 def __init__(self, shape, seed=0):
     """Attach a seeded float32 Gumbel distribution and remember ``shape``.

     Args:
         shape: stored as ``self.shape`` (presumably the sample shape used by
             this cell's construct method — not visible here; confirm).
         seed: forwarded unchanged to ``msd.Gumbel``.
     """
     super().__init__()
     # loc has shape (1,), scale has shape (3,).
     loc = np.array([0.0])
     scale = np.array([1.0, 2.0, 3.0])
     self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32, seed=seed)
     self.shape = shape
예제 #2
0
 def __init__(self):
     """Attach a float32 Gumbel distribution as ``self.gum``.

     loc is a length-1 array and scale a length-2 array.
     """
     super().__init__()
     loc = np.array([0.0])
     scale = np.array([1.0, 2.0])
     self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)
예제 #3
0
 def __init__(self):
     """Attach a float32 Gumbel distribution as ``self.gum``.

     loc has shape (1,) while scale is a (2, 1) column array.
     """
     super().__init__()
     loc = np.array([0.0])
     scale = np.array([[1.0], [2.0]])
     self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)
예제 #4
0
 def __init__(self):
     """Attach a float32 Gumbel distribution as ``self.gum``.

     The location is a shape-(1,) array; the scale is a shape-(2, 1) array.
     """
     super().__init__()
     location = np.array([0.0])
     scale_column = np.array([[1.0], [2.0]])
     self.gum = msd.Gumbel(location, scale_column, dtype=dtype.float32)
예제 #5
0
 def __init__(self):
     """Attach a scalar float32 Gumbel(3.0, 4.0) distribution as ``self.gumbel``."""
     super().__init__()
     loc, scale = 3.0, 4.0
     self.gumbel = msd.Gumbel(loc, scale, dtype=dtype.float32)
예제 #6
0
 def __init__(self):
     """Attach a scalar Gumbel(3.0, 4.0) as ``self.gumbel`` (dtype left to the library default)."""
     super().__init__()
     loc, scale = 3.0, 4.0
     self.gumbel = msd.Gumbel(loc, scale)
예제 #7
0
def test_scale():
    """A zero or negative scale must make ``msd.Gumbel`` raise ``ValueError``."""
    # Each invalid scale is checked in its own ``pytest.raises`` context,
    # in the same order as before: 0. first, then -1.
    for bad_scale in (0., -1.):
        with pytest.raises(ValueError):
            msd.Gumbel(0., bad_scale)
예제 #8
0
def test_arguments():
    """
    Constructing Gumbel from list arguments yields an ``msd.Distribution``.
    """
    # Descriptive name instead of the ambiguous single-letter ``l`` (PEP 8 E741).
    gumbel_dist = msd.Gumbel([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(gumbel_dist, msd.Distribution)
예제 #9
0
def test_seed():
    """Passing a string as ``seed`` must raise ``TypeError``."""
    bad_seed = 'seed'
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., seed=bad_seed)
예제 #10
0
def test_name():
    """Passing a float as ``name`` must raise ``TypeError``."""
    bad_name = 1.0
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., name=bad_name)
예제 #11
0
def test_type():
    """Requesting an int32 dtype must raise ``TypeError``."""
    integer_dtype = dtype.int32
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., dtype=integer_dtype)
예제 #12
0
def test_gumbel_shape_errpr():
    """
    Incompatible loc/scale shapes must raise ``ValueError``.

    loc is (2, 1) and scale is (3, 1); under NumPy broadcasting rules these
    leading dimensions (2 vs 3) cannot be broadcast together.
    """
    loc = [[2.], [1.]]
    scale = [[2.], [3.], [4.]]
    with pytest.raises(ValueError):
        msd.Gumbel(loc, scale, dtype=dtype.float32)
예제 #13
0
 def __init__(self):
     """Attach a scalar Gumbel(3.0, 4.0) as ``self.gumbel`` using the default dtype."""
     super().__init__()
     loc = 3.0
     scale = 4.0
     self.gumbel = msd.Gumbel(loc, scale)