Example #1
import math
import jittor as jt
from jittor.distributions import Normal, Categorical, OneHotCategorical, Uniform, Geometric

def kl_divergence(cur_dist, old_dist):
    # KL(cur_dist || old_dist); both arguments must be the same distribution type.
    assert isinstance(cur_dist, type(old_dist))
    if isinstance(cur_dist, Normal):
        # Closed form for two Gaussians:
        # 0.5 * (s1^2/s2^2 + (m1 - m2)^2/s2^2 - 1 - log(s1^2/s2^2))
        vr = (cur_dist.sigma / old_dist.sigma)**2
        t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
        return 0.5 * (vr + t1 - 1 - jt.safe_log(vr))
    if isinstance(cur_dist, (Categorical, OneHotCategorical)):
        # sum_i p_i * (log p_i - log q_i); the logits here are normalized log-probs.
        t = cur_dist.probs * (cur_dist.logits - old_dist.logits)
        return t.sum(-1)
    if isinstance(cur_dist, Uniform):
        res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
        # KL is infinite unless cur_dist's support lies inside old_dist's.
        if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
            res = math.inf
        return res
    if isinstance(cur_dist, Geometric):
        # -H(p) - log(1 - q)/p - logits(q), the standard geometric-geometric KL.
        return -cur_dist.entropy() - jt.safe_log(-old_dist.prob + 1) / cur_dist.prob - old_dist.logits
    raise NotImplementedError(f'kl_divergence not implemented for {type(cur_dist).__name__}')
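
A quick sanity check of the Normal branch; a minimal sketch, assuming Normal(mu, sigma) stores Jittor arrays as its mu/sigma attributes the way the code above reads them:

import jittor as jt
from jittor.distributions import Normal, kl_divergence

p = Normal(jt.array([0.0]), jt.array([1.0]))  # N(0, 1)
q = Normal(jt.array([1.0]), jt.array([2.0]))  # N(1, 2)
# Closed form: 0.5 * (0.25 + 0.25 - 1 - log(0.25)) is about 0.4431
print(kl_divergence(p, q))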
Example #2
def __init__(self, p=None, logits=None):
    # Geometric distribution over {0, 1, 2, ...}, parameterized by the
    # success probability p or by logits = log(p / (1 - p)).
    assert (p is not None) or (logits is not None)
    if p is not None:
        # Only validate p when it was actually passed; comparing None
        # against 0 and 1 would raise a TypeError.
        assert 0 < p < 1
    if p is None:
        self.prob = jt.sigmoid(logits)
        self.logits = logits
    elif logits is None:
        self.prob = p
        # Inverse of the sigmoid above: logits = log(p / (1 - p)).
        self.logits = -jt.safe_log(1. / p - 1)
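
The two parameterizations are inverses of each other; a plain-Python check of the identities used above:

import math

p = 0.3
logits = -math.log(1.0 / p - 1.0)          # log(p / (1 - p))
p_back = 1.0 / (1.0 + math.exp(-logits))   # sigmoid(logits)
print(round(p_back, 6))                    # 0.3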
Example #3
def __init__(self, probs=None, logits=None):
    assert not (probs is None and logits is None)
    if probs is None:
        # Cannot align to PyTorch here: PyTorch derives probs via
        # softmax(logits), while sigmoid plus the renormalization
        # below yields different values.
        probs = jt.sigmoid(logits)
    # Normalize along the last (category) dimension.
    probs = probs / probs.sum(-1, True)
    if logits is None:
        logits = jt.safe_log(probs)
    with jt.no_grad():
        self.probs = probs
        self.logits = logits
        # Prefix sums give the CDF; the shifted left/right views bracket
        # each category's probability mass for inverse-CDF sampling.
        self.cum_probs = simple_presum(self.probs)
        self.cum_probs_l = self.cum_probs[..., :-1]
        self.cum_probs_r = self.cum_probs[..., 1:]
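
The left/right cumulative views set up inverse-CDF sampling: a uniform draw u picks the category whose bracket [cum_probs_l, cum_probs_r) contains it. A NumPy sketch of the same idea, assuming simple_presum prepends a leading zero (which the slicing above suggests):

import numpy as np

probs = np.array([0.2, 0.5, 0.3])
cum = np.concatenate([[0.0], np.cumsum(probs)])  # stand-in for simple_presum
cum_l, cum_r = cum[:-1], cum[1:]

u = np.random.rand()
category = int(np.argmax((cum_l <= u) & (u < cum_r)))
print(category)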
Example #4
def log_prob(self, x):
    # Gaussian log-density:
    # -(x - mu)^2 / (2 * sigma^2) - log(sigma) - log(sqrt(2 * pi))
    var = self.sigma**2
    log_scale = jt.safe_log(self.sigma)
    return -((x - self.mu)**2) / (2 * var) - log_scale - np.log(np.sqrt(2 * np.pi))
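
The expression is the standard Gaussian log-density; a quick check against scipy.stats.norm:

import numpy as np
from scipy.stats import norm

mu, sigma, x = 0.5, 2.0, 1.0
manual = -((x - mu)**2) / (2 * sigma**2) - np.log(sigma) - np.log(np.sqrt(2 * np.pi))
print(np.isclose(manual, norm.logpdf(x, loc=mu, scale=sigma)))  # True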
Example #5
def log_prob(self, x):
    # Gather log(probs) at the categories given by x: one index per
    # dimension of probs, using symbolic indices ('i1', 'i2', ...) for
    # the batch dimensions and x itself for the final category dimension.
    a = self.probs.ndim
    b = x.ndim
    indexes = tuple(f'i{i}' for i in range(b - a + 1, b))
    indexes = indexes + (x,)
    return jt.safe_log(self.probs).getitem(indexes)
Example #6
def log_prob(self, x):
    # P(X = k) = (1 - p)^k * p, so log P = k * log(1 - p) + log(p).
    return x * jt.safe_log(-self.prob + 1) + jt.safe_log(self.prob)
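
Summing exp(log_prob(k)) over the support confirms the distribution is normalized; in plain Python:

import math

p = 0.3
total = sum(math.exp(k * math.log(1 - p) + math.log(p)) for k in range(200))
print(round(total, 6))  # ~1.0 (the tail beyond k = 200 is negligible)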
Example #7
def sample(self, sample_shape):
    # Inverse-CDF sampling: for U ~ Uniform(0, 1),
    # floor(log(U) / log(1 - p)) is Geometric(p) on {0, 1, 2, ...}.
    u = jt.rand(sample_shape)
    return (jt.safe_log(u) / jt.safe_log(-self.prob + 1)).floor()
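
The same transform in NumPy; the empirical mean should match E[X] = (1 - p) / p:

import numpy as np

p = 0.25
u = np.random.rand(100_000)
samples = np.floor(np.log(u) / np.log(1.0 - p))
print(samples.mean())  # ~3.0 for p = 0.25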
Example #8
def entropy(self):
    # Differential entropy of Uniform(low, high) is log(high - low).
    return jt.safe_log(self.high - self.low)
Example #9
def log_prob(self, x):
    # The density is 0 outside the support, so the log-density there is -inf.
    if x < self.low or x >= self.high:
        return -math.inf
    return -jt.safe_log(self.high - self.low)
Example #10
def entropy(self):
    # Gaussian entropy: 0.5 * (1 + log(2 * pi)) + log(sigma).
    return 0.5 + 0.5 * np.log(2 * np.pi) + jt.safe_log(self.sigma)
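
This matches the closed-form Gaussian entropy 0.5 * log(2 * pi * e * sigma^2); checked against scipy.stats.norm:

import numpy as np
from scipy.stats import norm

sigma = 2.0
manual = 0.5 + 0.5 * np.log(2 * np.pi) + np.log(sigma)
print(np.isclose(manual, norm.entropy(loc=0.0, scale=sigma)))  # True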