Ejemplo n.º 1
0
 def sufficient_statistics(self):
     A, Q = self.get_dynamics()
     Q_inv = T.matrix_inverse(Q)
     Q_inv_A = T.matrix_solve(Q, A)
     return [
         -0.5 * Q_inv,
         Q_inv_A,
         -0.5 * T.einsum('hba,hbc->hac', A, Q_inv_A),
         -0.5 * T.logdet(Q)
     ]
Ejemplo n.º 2
0
def get_stat(x, name, feed_dict={}):
    node = get_current_graph().get_node(x)
    print(x, name)
    if node is not None:
        return node.get_stat(name, feed_dict=feed_dict)
    if name == 'x':
        return x
    elif name == 'xxT':
        return T.outer(x, x)
    elif name == '-0.5S^-1':
        return -0.5 * T.matrix_inverse(x)
    elif name == '-0.5log|S|':
        return -0.5 * T.logdet(x)
    raise Exception()
Ejemplo n.º 3
0
 def log_z(self, parameter_type='regular', stop_gradient=False):
     if parameter_type == 'regular':
         sigma, mu = self.get_parameters('regular', stop_gradient=stop_gradient)
         d = T.to_float(self.shape()[-1])
         hsi, hlds = Stats.HSI(sigma), Stats.HLDS(sigma)
         mmT = Stats.XXT(mu)
         return (
             - T.sum(hsi * mmT, [-1, -2]) - hlds
             + d / 2. * np.log(2 * np.pi)
         )
     else:
         natparam = self.get_parameters('natural', stop_gradient=stop_gradient)
         d = T.to_float(self.shape()[-1])
         J, m = natparam[Stats.XXT], natparam[Stats.X]
         return (
             - 0.25 * (m[..., None, :]@T.matrix_inverse(J)@m[..., None])[..., 0, 0]
             - 0.5 * T.logdet(-2 * J)
             + d / 2. * np.log(2 * np.pi)
         )
Ejemplo n.º 4
0
 def natural_to_regular(cls, natural_parameters):
     J, m = natural_parameters[Stats.XXT], natural_parameters[Stats.X]
     sigma = -0.5 * T.matrix_inverse(J)
     mu = T.matmul(sigma, m[..., None])[..., 0]
     return [sigma, mu]
Ejemplo n.º 5
0
 def compute(self, A):
     return -0.5 * T.matrix_inverse(A)