Example 1
 def next_state(self, state, action, t):
     # Step-t linear dynamics: mean A[t] @ [state; action], noise covariance Q[t].
     A, Q = self.get_dynamics()
     leading_dim = T.shape(state)[:-1]
     state_action = T.concatenate([state, action], -1)
     return stats.Gaussian([
         # Tile Q[t] over the leading batch dimensions of `state`.
         T.tile(Q[t][None], T.concatenate([leading_dim, [1, 1]])),
         T.einsum('ab,nb->na', A[t], state_action)
     ])
Example 2
 def forward(self, q_Xt, q_At):
     # Expected state and action under the current variational factors.
     Xt, At = q_Xt.expected_value(), q_At.expected_value()
     batch_size = T.shape(Xt)[0]
     XAt = T.concatenate([Xt, At], -1)
     A, Q = self.get_dynamics()
     p_Xt1 = stats.Gaussian([
         T.tile(Q[None], [batch_size, 1, 1, 1]),
         # Apply the timestep-h dynamics matrix to each state-action vector.
         T.einsum('nhs,hxs->nhx', XAt, A)
     ])
     return p_Xt1
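
A standalone NumPy check (illustrative, not library code) of what the 'nhs,hxs->nhx' contraction above computes: the timestep-h dynamics matrix applied to each state-action vector in the batch.

import numpy as np

rng = np.random.default_rng(0)
n, h, s, x = 2, 4, 5, 3
XAt = rng.normal(size=(n, h, s))   # state-action vectors, [batch, horizon, ds + da]
A = rng.normal(size=(h, x, s))     # one dynamics matrix per timestep

out = np.einsum('nhs,hxs->nhx', XAt, A)
ref = np.array([[A[j] @ XAt[i, j] for j in range(h)] for i in range(n)])
assert np.allclose(out, ref)
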
Example 3
 def sufficient_statistics(self):
     A, Q = self.get_dynamics()
     Q_inv = T.matrix_inverse(Q)
     Q_inv_A = T.matrix_solve(Q, A)
     # Per-timestep blocks: [-1/2 Q^{-1}, Q^{-1} A, -1/2 A^T Q^{-1} A, -1/2 log|Q|].
     return [
         -0.5 * Q_inv,
         Q_inv_A,
         -0.5 * T.einsum('hba,hbc->hac', A, Q_inv_A),
         -0.5 * T.logdet(Q)
     ]
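
If the four blocks above are read as the coefficient blocks of a conditional linear-Gaussian log-density p(y | x) = N(Ax, Q), an interpretation suggested by the code but not stated in it, they satisfy the identity checked by this self-contained NumPy sketch (all names below are illustrative):

import numpy as np

rng = np.random.default_rng(0)
d = 3
A = rng.normal(size=(d, d))
L = rng.normal(size=(d, d))
Q = L @ L.T + d * np.eye(d)        # symmetric positive-definite noise covariance
x, y = rng.normal(size=d), rng.normal(size=d)

Q_inv = np.linalg.inv(Q)
blocks = [-0.5 * Q_inv, Q_inv @ A, -0.5 * A.T @ Q_inv @ A,
          -0.5 * np.linalg.slogdet(Q)[1]]

# Both sides equal log N(y; Ax, Q) up to the same additive constant.
lhs = y @ blocks[0] @ y + y @ blocks[1] @ x + x @ blocks[2] @ x + blocks[3]
rhs = -0.5 * (y - A @ x) @ Q_inv @ (y - A @ x) - 0.5 * np.linalg.slogdet(Q)[1]
assert np.allclose(lhs, rhs)
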
Example 4
    def posterior_dynamics(self,
                           q_X,
                           q_A,
                           data_strength=1.0,
                           max_iter=200,
                           tol=1e-3):
        if self.smooth:
            if self.time_varying:
                prior_dyn = stats.MNIW(
                    self.A_variational.get_parameters('natural'), 'natural')
            else:
                natparam = self.A_variational.get_parameters('natural')
                prior_dyn = stats.MNIW([
                    T.tile(natparam[0][None], [self.horizon - 1, 1, 1]),
                    T.tile(natparam[1][None], [self.horizon - 1, 1, 1]),
                    T.tile(natparam[2][None], [self.horizon - 1, 1, 1]),
                    T.tile(natparam[3][None], [self.horizon - 1]),
                ], 'natural')
            state_prior = stats.Gaussian([T.eye(self.ds), T.zeros(self.ds)])
            aaT, a = stats.Gaussian.unpack(
                q_A.expected_sufficient_statistics())
            aaT, a = aaT[:, :-1], a[:, :-1]
            ds, da = self.ds, self.da

            initial_dyn_natparam = prior_dyn.get_parameters('natural')
            initial_X_natparam = stats.LDS(
                (self.sufficient_statistics(), state_prior, q_X,
                 q_A.expected_value(), self.horizon),
                'internal').get_parameters('natural')

            def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
                # One EM sweep: recompute expected sufficient statistics under the
                # current state posterior, refit the MNIW dynamics posterior, then
                # re-run LDS smoothing and update the convergence objective.
                q_X_ = stats.LDS(q_X_natparam, 'natural')
                ess = q_X_.expected_sufficient_statistics()
                batch_size = T.shape(ess)[0]
                yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
                xxT = ess[..., :-1, :ds, :ds]
                yxT = ess[..., :-1, ds:2 * ds, :ds]
                x = ess[..., :-1, -1, :ds]
                y = ess[..., :-1, -1, ds:2 * ds]
                xaT = T.outer(x, a)
                yaT = T.outer(y, a)
                xaxaT = T.concatenate([
                    T.concatenate([xxT, xaT], -1),
                    T.concatenate([T.matrix_transpose(xaT), aaT], -1),
                ], -2)
                ess = [
                    yyT,
                    T.concatenate([yxT, yaT], -1), xaxaT,
                    T.ones([batch_size, self.horizon - 1])
                ]
                q_dyn_natparam = [
                    T.sum(s, [0]) * data_strength + p
                    for s, p in zip(ess, initial_dyn_natparam)
                ]
                q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
                q_stats = q_dyn_.expected_sufficient_statistics()
                p_X = stats.LDS((q_stats, state_prior, None,
                                 q_A.expected_value(), self.horizon))
                q_X_ = stats.LDS((q_stats, state_prior, q_X,
                                  q_A.expected_value(), self.horizon))
                elbo = (T.sum(stats.kl_divergence(q_X_, p_X)) +
                        T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
                return i + 1, q_dyn_.get_parameters(
                    'natural'), q_X_.get_parameters('natural'), curr_elbo, elbo

            def cond(i, _, __, prev_elbo, curr_elbo):
                # Iterate until the objective change is within tol or max_iter is
                # reached; the print is attached only as a logging side effect.
                with T.core.control_dependencies([T.core.print(curr_elbo)]):
                    prev_elbo = T.core.identity(prev_elbo)
                return T.logical_and(
                    T.abs(curr_elbo - prev_elbo) > tol, i < max_iter)

            result = T.while_loop(
                cond,
                em, [
                    0, initial_dyn_natparam, initial_X_natparam,
                    T.constant(-np.inf),
                    T.constant(0.)
                ],
                back_prop=False)
            pd = stats.MNIW(result[1], 'natural')
            sigma, mu = pd.expected_value()
            q_X = stats.LDS(result[2], 'natural')
            return ((mu, sigma), pd.expected_sufficient_statistics()), (q_X,
                                                                        q_A)
        else:
            q_Xt = q_X.__class__([
                q_X.get_parameters('regular')[0][:, :-1],
                q_X.get_parameters('regular')[1][:, :-1],
            ])
            q_At = q_A.__class__([
                q_A.get_parameters('regular')[0][:, :-1],
                q_A.get_parameters('regular')[1][:, :-1],
            ])
            q_Xt1 = q_X.__class__([
                q_X.get_parameters('regular')[0][:, 1:],
                q_X.get_parameters('regular')[1][:, 1:],
            ])
            (XtAt_XtAtT, XtAt), (Xt1_Xt1T,
                                 Xt1) = self.get_statistics(q_Xt, q_At, q_Xt1)
            batch_size = T.shape(XtAt)[0]
            ess = [
                Xt1_Xt1T,
                T.einsum('nha,nhb->nhba', XtAt, Xt1), XtAt_XtAtT,
                T.ones([batch_size, self.horizon - 1])
            ]
            if self.time_varying:
                posterior = stats.MNIW([
                    T.sum(a, [0]) * data_strength + b for a, b in zip(
                        ess, self.A_variational.get_parameters('natural'))
                ], 'natural')
            else:
                posterior = stats.MNIW([
                    T.sum(a, [0]) * data_strength + b[None] for a, b in zip(
                        ess, self.A_variational.get_parameters('natural'))
                ], 'natural')
            Q, A = posterior.expected_value()
            return (A, Q), q_X
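
In the smooth branch above, the T.while_loop alternates a closed-form refit of the MNIW dynamics posterior with an LDS smoothing pass until the objective stops changing. A plain-Python sketch of that control flow; the callables are hypothetical stand-ins, not library functions.

def em_fixed_point(init_dyn, init_x, refit_dynamics, smooth_states, objective,
                   max_iter=200, tol=1e-3):
    # Alternate the two coordinate updates until the objective stops moving.
    dyn, x, prev = init_dyn, init_x, float('-inf')
    for _ in range(max_iter):
        dyn = refit_dynamics(x)      # MNIW posterior from expected sufficient statistics
        x = smooth_states(dyn)       # LDS smoothing under the refreshed dynamics
        curr = objective(dyn, x)
        if abs(curr - prev) <= tol:
            break
        prev = curr
    return dyn, x
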
Example 5
 def kl_gradients(self, q_X, q_A, _, num_data):
     if self.smooth:
         ds = self.ds
         ess = q_X.expected_sufficient_statistics()
         yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
         xxT = ess[..., :-1, :ds, :ds]
         yxT = ess[..., :-1, ds:2 * ds, :ds]
         aaT, a = stats.Gaussian.unpack(
             q_A.expected_sufficient_statistics())
         aaT, a = aaT[:, :-1], a[:, :-1]
         x = ess[..., :-1, -1, :ds]
         y = ess[..., :-1, -1, ds:2 * ds]
         xaT = T.outer(x, a)
         yaT = T.outer(y, a)
         xaxaT = T.concatenate([
             T.concatenate([xxT, xaT], -1),
             T.concatenate([T.matrix_transpose(xaT), aaT], -1),
         ], -2)
         batch_size = T.shape(ess)[0]
         num_batches = T.to_float(num_data) / T.to_float(batch_size)
         ess = [
             yyT,
             T.concatenate([yxT, yaT], -1), xaxaT,
             T.ones([batch_size, self.horizon - 1])
         ]
     else:
         q_Xt = q_X.__class__([
             q_X.get_parameters('regular')[0][:, :-1],
             q_X.get_parameters('regular')[1][:, :-1],
         ])
         q_At = q_A.__class__([
             q_A.get_parameters('regular')[0][:, :-1],
             q_A.get_parameters('regular')[1][:, :-1],
         ])
         q_Xt1 = q_X.__class__([
             q_X.get_parameters('regular')[0][:, 1:],
             q_X.get_parameters('regular')[1][:, 1:],
         ])
         (XtAt_XtAtT, XtAt), (Xt1_Xt1T,
                              Xt1) = self.get_statistics(q_Xt, q_At, q_Xt1)
         batch_size = T.shape(XtAt)[0]
         num_batches = T.to_float(num_data) / T.to_float(batch_size)
         ess = [
             Xt1_Xt1T,
             T.einsum('nha,nhb->nhba', XtAt, Xt1), XtAt_XtAtT,
             T.ones([batch_size, self.horizon - 1])
         ]
     if self.time_varying:
         ess = [
             T.sum(ess[0], [0]),
             T.sum(ess[1], [0]),
             T.sum(ess[2], [0]),
             T.sum(ess[3], [0]),
         ]
     else:
         ess = [
             T.sum(ess[0], [0, 1]),
             T.sum(ess[1], [0, 1]),
             T.sum(ess[2], [0, 1]),
             T.sum(ess[3], [0, 1]),
         ]
     return [
         -(eta_0 + num_batches * s - eta_q) / T.to_float(num_data)
         for eta_0, s, eta_q in zip(
             self.A_prior.get_parameters('natural'),
             ess,
             self.A_variational.get_parameters('natural'),
         )
     ]
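
Written out, each returned block is

\[
-\frac{1}{N}\Bigl(\eta_0 + \frac{N}{B}\,s - \eta_q\Bigr),
\]

where N is num_data, B the batch size (num_batches = N / B), s the batch sum of expected sufficient statistics, and \eta_0, \eta_q the natural parameters of A_prior and A_variational. This matches the usual form of a stochastic natural-gradient step on the KL term in an SVI-style objective; that reading is an interpretation of the snippet, not something it states.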
Example 6
 def evaluate(self, states):
     return (
         0.5 * T.einsum('nia,ab,nib->ni', states, self.C, states)
         + T.einsum('nia,a->ni', states, self.c)
     )
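
A small NumPy check (illustrative, not library code) that the two einsums above evaluate the quadratic cost 0.5 * s^T C s + c^T s independently for every state in the [batch, horizon, dim] array:

import numpy as np

rng = np.random.default_rng(0)
n, h, d = 2, 4, 3
states = rng.normal(size=(n, h, d))
C = rng.normal(size=(d, d))
c = rng.normal(size=d)

fast = (0.5 * np.einsum('nia,ab,nib->ni', states, C, states)
        + np.einsum('nia,a->ni', states, c))
slow = np.array([[0.5 * s @ C @ s + c @ s for s in row] for row in states])
assert np.allclose(fast, slow)
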
Example 7
    T.outer(X, X),
    X,
    T.ones([batch_size]),
    T.ones([batch_size]),
])
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

num_batches = N / T.to_float(batch_size)
nat_scale = 10.0

parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
q_z = Categorical(new_z - T.logsumexp(new_z, -1)[..., None],
                  parameter_type='natural')
p_z = Categorical(parent_z - T.logsumexp(parent_z, -1),
                  parameter_type='natural')
l_z = T.sum(kl_divergence(q_z, p_z))
z_pmessage = q_z.expected_sufficient_statistics()

pi_stats = T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
current_pi = q_pi.get_parameters('natural')
pi_gradient = nat_scale / N * (parent_pi + num_batches * pi_stats - current_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

theta_stats = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
parent_theta = p_theta.get_parameters('natural')[None]
Example 8
    X,
    T.ones(N),
    T.ones(N),
])
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

new_pi = p_pi.get_parameters('natural') + T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
pi_update = T.assign(q_pi.get_parameters('natural'), new_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

new_theta = T.einsum('ia,ibc->abc', z_pmessage,
                     x_tmessage) + p_theta.get_parameters('natural')[None]
parent_theta = p_theta.get_parameters('natural')
theta_update = T.assign(q_theta.get_parameters('natural'), new_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage,
                 theta_cmessage) + q_pi.expected_sufficient_statistics()[None]
new_z = new_z - T.logsumexp(new_z, -1)[..., None]
z_update = T.assign(q_z.get_parameters('natural'), new_z)
l_z = T.sum(kl_divergence(q_z, Categorical(parent_z,
                                           parameter_type='natural')))

x_param = T.einsum('ia,abc->ibc', q_z.expected_sufficient_statistics(),
                   q_theta.expected_sufficient_statistics())
q_x = Gaussian(x_param, parameter_type='natural')
Example 9
yt, yt1 = yt.reshape([-1, D]), yt1.reshape([-1, D])

transition_net = Tanh(D, 500) >> Tanh(500) >> nn.Gaussian(D)
transition_net.initialize()

rec_net = Tanh(D, 500) >> Tanh(500) >> nn.Gaussian(D)
rec_net.initialize()

Yt = T.placeholder(T.floatx(), [None, D])
Yt1 = T.placeholder(T.floatx(), [None, D])
batch_size = T.shape(Yt)[0]
num_batches = N / T.to_float(batch_size)

Yt_message = Gaussian.pack([
    T.tile(T.eye(D)[None] * noise, [batch_size, 1, 1]),
    T.einsum('ab,ib->ia',
             T.eye(D) * noise, Yt)
])
Yt1_message = Gaussian.pack([
    T.tile(T.eye(D)[None] * noise, [batch_size, 1, 1]),
    T.einsum('ab,ib->ia',
             T.eye(D) * noise, Yt1)
])
transition = Gaussian(transition_net(Yt)).expected_value()

max_iter = 1000
tol = 1e-5


def cond(i, prev_elbo, elbo, qxt, qxt1):
    return T.logical_and(i < max_iter, abs(prev_elbo - elbo) >= tol)
Example 10
 def _sample(self, num_samples):
     tensor_samples = self.tensor.sample(num_samples)
     weight_samples = self.weights.sample(num_samples)
     return T.einsum('ia,iab->ib', weight_samples, tensor_samples)
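
Illustrative NumPy check (not library code): the 'ia,iab->ib' contraction mixes each tensor sample with its own weight vector, i.e. out[i] = weight_samples[i] @ tensor_samples[i].

import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(5, 3))        # [num_samples, num_components]
t = rng.normal(size=(5, 3, 4))     # [num_samples, num_components, dim]

out = np.einsum('ia,iab->ib', w, t)
assert np.allclose(out, np.stack([w[i] @ t[i] for i in range(5)]))
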
Example 11
num_batches = T.to_float(N / batch_size)

with T.initialization('xavier'):
    # stats_net = Relu(D + 1, 20) >> Relu(20) >> GaussianLayer(D)
    stats_net = GaussianLayer(D + 1, D)
net_out = stats_net(T.concat([x, y[..., None]], -1))
stats = T.sum(net_out.get_parameters('natural'), 0)[None]

# Stochastic natural-gradient direction: prior natural parameters plus the
# rescaled batch statistics, minus the current variational parameters, over N.
natural_gradient = (p_w.get_parameters('natural') + num_batches * stats -
                    q_w.get_parameters('natural')) / N
next_w = Gaussian(q_w.get_parameters('natural') + lr * natural_gradient,
                  parameter_type='natural')

l_w = kl_divergence(q_w, p_w)[0]

p_y = Bernoulli(T.sigmoid(T.einsum('jw,iw->ij', next_w.expected_value(), x)))
l_y = T.sum(p_y.log_likelihood(y[..., None]))
elbo = l_w + l_y

nat_op = T.assign(q_w.get_parameters('natural'),
                  next_w.get_parameters('natural'))
grad_op = tf.train.RMSPropOptimizer(1e-4).minimize(-elbo)
train_op = tf.group(nat_op, grad_op)
sess = T.interactive_session()

predictions = T.cast(
    T.sigmoid(T.einsum('jw,iw->i', q_w.expected_value(), T.to_float(X))) + 0.5,
    np.int32)
accuracy = T.mean(
    T.to_float(T.equal(predictions, T.constant(Y.astype(np.int32)))))
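
The prediction line rounds each sigmoid output to {0, 1} by adding 0.5 before the integer cast; a tiny NumPy illustration of that idiom (not library code):

import numpy as np

probs = np.array([0.12, 0.49, 0.51, 0.97])
labels = (probs + 0.5).astype(np.int32)    # same rounding as the T.cast above
assert (labels == np.array([0, 0, 1, 1])).all()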