def sufficient_statistics(self):
    """Return the four statistic terms derived from ``self.get_dynamics()``.

    ``get_dynamics()`` supplies a pair ``(A, Q)``. Presumably ``A`` is a
    transition matrix and ``Q`` a covariance; TODO confirm against the
    caller. The einsum pattern below fixes both ``A`` and ``Q^{-1} A`` to
    rank 3, with a shared leading axis.

    Returns:
        A list, in this order:
        ``-0.5 * Q^{-1}``,
        ``Q^{-1} A`` (via ``T.matrix_solve``, not the explicit inverse),
        ``-0.5 * A^T Q^{-1} A`` (per leading index),
        ``-0.5 * log|Q|``.
    """
    a_mat, q_mat = self.get_dynamics()
    q_inv = T.matrix_inverse(q_mat)
    q_inv_a = T.matrix_solve(q_mat, a_mat)
    # Contract over axis 'b': result[h, a, c] = sum_b A[h, b, a] * (Q^-1 A)[h, b, c],
    # i.e. A^T (Q^-1 A) for every leading index h.
    quadratic = T.einsum('hba,hbc->hac', a_mat, q_inv_a)
    log_det = T.logdet(q_mat)
    return [
        -0.5 * q_inv,
        q_inv_a,
        -0.5 * quadratic,
        -0.5 * log_det,
    ]
def get_stat(x, name, feed_dict=None):
    """Return the statistic ``name`` for ``x``.

    If ``x`` is registered as a node in the current graph, the lookup is
    delegated to that node's ``get_stat``. Otherwise the statistic is computed
    directly from ``x``.

    Args:
        x: The value whose statistic is requested. It is first looked up via
            ``get_current_graph().get_node(x)``.
        name: Statistic identifier. When ``x`` is not a graph node, it must be
            one of ``'x'`` (returns ``x``), ``'xxT'`` (``T.outer(x, x)``),
            ``'-0.5S^-1'`` (``-0.5 * T.matrix_inverse(x)``) or
            ``'-0.5log|S|'`` (``-0.5 * T.logdet(x)``).
        feed_dict: Forwarded to the node's ``get_stat``. Defaults to a fresh
            empty dict on each call, so no state leaks between calls.

    Returns:
        The requested statistic.

    Raises:
        ValueError: If ``x`` is not a graph node and ``name`` is not one of
            the supported identifiers.
    """
    # A literal ``{}`` default would be a single dict shared by every call;
    # create a fresh one instead, since it is handed to ``node.get_stat``.
    if feed_dict is None:
        feed_dict = {}
    node = get_current_graph().get_node(x)
    if node is not None:
        return node.get_stat(name, feed_dict=feed_dict)
    if name == 'x':
        return x
    if name == 'xxT':
        return T.outer(x, x)
    if name == '-0.5S^-1':
        return -0.5 * T.matrix_inverse(x)
    if name == '-0.5log|S|':
        return -0.5 * T.logdet(x)
    # ValueError subclasses Exception, so existing ``except Exception``
    # handlers still catch it.
    raise ValueError('Unknown statistic {!r}'.format(name))
def log_z(self, parameter_type='regular', stop_gradient=False):
    """Return the log-normalizer computed from this object's parameters.

    NOTE(review): the formulas match the Gaussian log-partition function;
    confirm that is the intended distribution.

    Args:
        parameter_type: ``'regular'`` uses the ``(sigma, mu)`` pair from
            ``get_parameters('regular')``. Any other value falls through to
            the natural-parameter branch (``get_parameters('natural')``).
        stop_gradient: Forwarded unchanged to ``get_parameters``.

    Returns:
        With ``d = T.to_float(self.shape()[-1])``:

        * regular:
          ``-sum(HSI(sigma) * XXT(mu), [-1, -2]) - HLDS(sigma) + d/2*log(2*pi)``
        * natural, with ``J = natparam[Stats.XXT]`` and
          ``m = natparam[Stats.X]``:
          ``-0.25 * m^T J^{-1} m - 0.5 * logdet(-2 J) + d/2*log(2*pi)``
    """
    if parameter_type != 'regular':
        natparam = self.get_parameters('natural', stop_gradient=stop_gradient)
        dim = T.to_float(self.shape()[-1])
        J, m = natparam[Stats.XXT], natparam[Stats.X]
        # Turn m into a row and a column so the quadratic form is a 1x1
        # matrix, then take its single entry.
        quad = (m[..., None, :] @ T.matrix_inverse(J) @ m[..., None])[..., 0, 0]
        return (
            -0.25 * quad
            - 0.5 * T.logdet(-2 * J)
            + dim / 2. * np.log(2 * np.pi)
        )

    sigma, mu = self.get_parameters('regular', stop_gradient=stop_gradient)
    dim = T.to_float(self.shape()[-1])
    hsi = Stats.HSI(sigma)
    hlds = Stats.HLDS(sigma)
    mmT = Stats.XXT(mu)
    # Elementwise product summed over the two trailing axes.
    quad = T.sum(hsi * mmT, [-1, -2])
    return -quad - hlds + dim / 2. * np.log(2 * np.pi)
def compute(self, A):
    """Return ``-0.5 * T.logdet(A)``.

    Args:
        A: Passed straight to ``T.logdet``. It is presumably a square matrix,
            or a batch of them; TODO confirm.
    """
    log_det = T.logdet(A)
    return -0.5 * log_det