コード例 #1
0
ファイル: lds.py プロジェクト: yuchen8807/parasol
 def sufficient_statistics(self):
     A, Q = self.get_dynamics()
     Q_inv = T.matrix_inverse(Q)
     Q_inv_A = T.matrix_solve(Q, A)
     return [
         -0.5 * Q_inv,
         Q_inv_A,
         -0.5 * T.einsum('hba,hbc->hac', A, Q_inv_A),
         -0.5 * T.logdet(Q)
     ]
コード例 #2
0
ファイル: stats.py プロジェクト: sharadmv/nvmp
def get_stat(x, name, feed_dict={}):
    node = get_current_graph().get_node(x)
    print(x, name)
    if node is not None:
        return node.get_stat(name, feed_dict=feed_dict)
    if name == 'x':
        return x
    elif name == 'xxT':
        return T.outer(x, x)
    elif name == '-0.5S^-1':
        return -0.5 * T.matrix_inverse(x)
    elif name == '-0.5log|S|':
        return -0.5 * T.logdet(x)
    raise Exception()
コード例 #3
0
ファイル: gaussian.py プロジェクト: sharadmv/nvmp
 def log_z(self, parameter_type='regular', stop_gradient=False):
     if parameter_type == 'regular':
         sigma, mu = self.get_parameters('regular', stop_gradient=stop_gradient)
         d = T.to_float(self.shape()[-1])
         hsi, hlds = Stats.HSI(sigma), Stats.HLDS(sigma)
         mmT = Stats.XXT(mu)
         return (
             - T.sum(hsi * mmT, [-1, -2]) - hlds
             + d / 2. * np.log(2 * np.pi)
         )
     else:
         natparam = self.get_parameters('natural', stop_gradient=stop_gradient)
         d = T.to_float(self.shape()[-1])
         J, m = natparam[Stats.XXT], natparam[Stats.X]
         return (
             - 0.25 * (m[..., None, :]@T.matrix_inverse(J)@m[..., None])[..., 0, 0]
             - 0.5 * T.logdet(-2 * J)
             + d / 2. * np.log(2 * np.pi)
         )
コード例 #4
0
 def compute(self, A):
     return -0.5 * T.logdet(A)