def eager_mvn(loc, scale_tril, value):
    """Eagerly build a multivariate-normal log density as a Gaussian funsor.

    A Gaussian is constructed over a freshly named real-vector variable
    (zero information vector, precision derived from ``scale_tril``), offset
    by the normalizing constant, and then evaluated at ``value - loc``.

    :param loc: mean; must have exactly one output dimension and be affine.
    :param scale_tril: funsor whose ``.data`` holds the scale factor; must have
        two output dimensions. Presumably a lower-triangular (Cholesky-style)
        factor with positive diagonal -- TODO confirm against callers.
    :param value: evaluation point; its output must equal ``loc.output`` and it
        must be affine.
    :returns: the evaluated funsor, or ``None`` (leaving the expression lazy)
        when ``loc`` or ``value`` is not affine.
    """
    # Shape invariants the construction below relies on.
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    raw_tril = scale_tril.data
    # Zero information vector: the mean enters through the `value - loc`
    # substitution at the end rather than through the Gaussian itself.
    zero_info = ops.new_zeros(raw_tril, raw_tril.shape[:-1])
    # Presumably (L @ L^T)^{-1}, i.e. the precision matrix for scale factor L.
    prec = ops.cholesky_inverse(raw_tril)
    diag = Tensor(ops.diagonal(raw_tril, -1, -2), scale_tril.inputs)
    dim = diag.shape[0]

    # Normalizer: -dim/2 * log(2*pi) minus the summed log-diagonal of the
    # factor (its log-determinant when the factor is triangular).
    normalizer = -0.5 * dim * math.log(2 * math.pi) - ops.log(diag).sum()

    fresh = gensym('value')
    gaussian_inputs = scale_tril.inputs.copy()
    gaussian_inputs[fresh] = Reals[dim]
    log_density = normalizer + Gaussian(zero_info, prec, gaussian_inputs)
    return log_density(**{fresh: value - loc})
def _log_det_tri(x):
    """Return the log-determinant of (batched) triangular matrices ``x``.

    Extracts the diagonal over the last two axes, takes its elementwise log,
    and sums over the final axis. For a triangular matrix the determinant is
    the product of its diagonal entries, so this assumes that diagonal is
    positive (e.g. a Cholesky factor).
    """
    diagonal_entries = ops.diagonal(x, -1, -2)
    return ops.log(diagonal_entries).sum(-1)