def kl_divergence(cur_dist, old_dist):
    assert isinstance(cur_dist, type(old_dist))
    if isinstance(cur_dist, Normal):
        # Closed-form KL(cur || old) between two Gaussians.
        vr = (cur_dist.sigma / old_dist.sigma)**2
        t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
        return 0.5 * (vr + t1 - 1 - jt.safe_log(vr))
    if isinstance(cur_dist, Categorical) or isinstance(cur_dist, OneHotCategorical):
        # Discrete KL: sum_i p_i * (logit_p_i - logit_q_i) over the last axis.
        t = cur_dist.probs * (cur_dist.logits - old_dist.logits)
        return t.sum(-1)
    if isinstance(cur_dist, Uniform):
        res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
        if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
            # KL is infinite when old's support does not cover cur's.
            res = math.inf
        return res
    if isinstance(cur_dist, Geometric):
        return -cur_dist.entropy() - jt.safe_log(-old_dist.prob + 1) / cur_dist.prob - old_dist.logits
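# Minimal usage sketch (assumes `import jittor as jt` as at the top of this module,
# and that the Normal class is constructed as Normal(mu, sigma), storing .mu and
# .sigma as used in log_prob below; that constructor signature is an assumption here).
# KL(N(0,1) || N(1,2)) has the closed form log(2) + (1 + 1) / (2 * 4) - 0.5.
p = Normal(jt.array([0.0]), jt.array([1.0]))
q = Normal(jt.array([1.0]), jt.array([2.0]))
print(kl_divergence(p, q).item())   # ~= 0.4431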
def __init__(self, p=None, logits=None):
    assert (p is not None) or (logits is not None)
    if p is None:
        # Derive the success probability from logits.
        self.prob = jt.sigmoid(logits)
        self.logits = logits
    elif logits is None:
        # p must be a valid success probability in (0, 1).
        assert 0 < p and p < 1
        self.prob = p
        self.logits = -jt.safe_log(1. / p - 1)
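# Sketch of the two parameterizations (assumes `import math` and `import jittor as jt`,
# and that this __init__ belongs to the Geometric class referenced in kl_divergence):
# passing p directly or passing logits = -log(1/p - 1) should give the same .prob.
g_from_p = Geometric(p=0.3)
g_from_logits = Geometric(logits=jt.array(-math.log(1.0 / 0.3 - 1.0)))
print(g_from_p.prob, g_from_logits.prob.item())   # both ~= 0.3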
def __init__(self, probs=None, logits=None):
    assert not (probs is None and logits is None)
    if probs is None:
        # Convert logits to probabilities via sigmoid + renormalization;
        # this does not match PyTorch, which uses a softmax here.
        probs = jt.sigmoid(logits)
        probs = probs / probs.sum(-1, True)
    if logits is None:
        logits = jt.safe_log(probs)
    with jt.no_grad():
        self.probs = probs
        self.logits = logits
        # Cache cumulative probabilities and their left/right edges for sampling.
        self.cum_probs = simple_presum(self.probs)
        self.cum_probs_l = self.cum_probs[..., :-1]
        self.cum_probs_r = self.cum_probs[..., 1:]
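# Illustrative sketch of the "does not match PyTorch" note above (only assumes
# `import jittor as jt`): sigmoid followed by renormalization is not the same
# logits-to-probabilities mapping as softmax.
from jittor import nn

logits = jt.array([1.0, 2.0, 3.0])
sig = jt.sigmoid(logits)
sig = sig / sig.sum(-1, True)
soft = nn.softmax(logits, dim=-1)
print(sig.numpy(), soft.numpy())   # the two probability vectors differ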
def log_prob(self, x):
    var = self.sigma**2
    log_scale = jt.safe_log(self.sigma)
    return -((x - self.mu)**2) / (2 * var) - log_scale - np.log(np.sqrt(2 * np.pi))
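# Verification sketch (assumes `import numpy as np`, `import jittor as jt`, and the
# hypothetical Normal(mu, sigma) constructor noted earlier): exp(log_prob(x)) should
# match the Gaussian density evaluated by hand.
n = Normal(jt.array([0.0]), jt.array([2.0]))
density = np.exp(-1.0 / (2 * 4.0)) / (2.0 * np.sqrt(2 * np.pi))
print(n.log_prob(jt.array([1.0])).exp().item(), density)   # both ~= 0.176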
def log_prob(self, x):
    a = self.probs.ndim
    b = x.ndim
    # Build an index tuple of symbolic indexes ('i...') that align the trailing
    # batch dimensions of x with the batch dimensions of probs, then use x itself
    # to pick the category along the last axis.
    indexes = tuple(f'i{i}' for i in range(b - a + 1, b))
    indexes = indexes + (x,)
    return jt.safe_log(self.probs).getitem(indexes)
def log_prob(self, x):
    # log P(X = x) = x * log(1 - p) + log(p), counting failures before the first success.
    return x * jt.safe_log(-self.prob + 1) + jt.safe_log(self.prob)
def sample(self, sample_shape):
    # Inverse-CDF sampling: floor(log(u) / log(1 - p)) with u ~ Uniform(0, 1).
    u = jt.rand(sample_shape)
    return (jt.safe_log(u) / jt.safe_log(-self.prob + 1)).floor()
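# Sampling sketch (assumes the Geometric class above): floor(log(u) / log(1 - p))
# is the standard inverse-CDF transform, so the empirical mean should approach
# (1 - p) / p, the expected number of failures before the first success.
g = Geometric(p=0.25)
samples = g.sample((100000,))
print(samples.mean().item())   # ~= 3.0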
def entropy(self):
    return jt.safe_log(self.high - self.low)
def log_prob(self, x):
    # Outside the support the density is zero, so the log-probability is -inf.
    if x < self.low or x >= self.high:
        return -math.inf
    return -jt.safe_log(self.high - self.low)
def entropy(self):
    # Differential entropy of a Gaussian: 0.5 * log(2 * pi * e * sigma^2).
    return 0.5 + 0.5 * np.log(2 * np.pi) + jt.safe_log(self.sigma)