Example #1
def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
    # One EM step. Free variables (ds, a, aaT, data_strength,
    # initial_dyn_natparam, state_prior, q_X, q_A, prior_dyn, self) are
    # captured from the enclosing scope.
    # E-step: rebuild the variational LDS posterior over states from its
    # natural parameters and take its expected sufficient statistics.
    q_X_ = stats.LDS(q_X_natparam, 'natural')
    ess = q_X_.expected_sufficient_statistics()
    batch_size = T.shape(ess)[0]
    # Slice out second moments and means per transition
    # (x = state at t, y = state at t + 1, a = action at t).
    yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
    xxT = ess[..., :-1, :ds, :ds]
    yxT = ess[..., :-1, ds:2 * ds, :ds]
    x = ess[..., :-1, -1, :ds]
    y = ess[..., :-1, -1, ds:2 * ds]
    xaT = T.outer(x, a)
    yaT = T.outer(y, a)
    # Joint second moment of the state-action vector [x; a].
    xaxaT = T.concatenate([
        T.concatenate([xxT, xaT], -1),
        T.concatenate([T.matrix_transpose(xaT), aaT], -1),
    ], -2)
    ess = [
        yyT,
        T.concatenate([yxT, yaT], -1),
        xaxaT,
        T.ones([batch_size, self.horizon - 1]),
    ]
    # M-step: conjugate update of the MNIW dynamics posterior, i.e. the
    # prior natural parameters plus the (weighted) summed statistics.
    q_dyn_natparam = [
        T.sum(s, [0]) * data_strength + p
        for s, p in zip(ess, initial_dyn_natparam)
    ]
    q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
    q_stats = q_dyn_.expected_sufficient_statistics()
    p_X = stats.LDS((q_stats, state_prior, None,
                     q_A.expected_value(), self.horizon))
    q_X_ = stats.LDS((q_stats, state_prior, q_X,
                      q_A.expected_value(), self.horizon))
    elbo = (T.sum(stats.kl_divergence(q_X_, p_X)) +
            T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
    return (i + 1, q_dyn_.get_parameters('natural'),
            q_X_.get_parameters('natural'), curr_elbo, elbo)
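A minimal sketch of the driver this step function implies, assuming it can be called eagerly (in the original it is presumably run as a symbolic loop in the backend); `run_em` and its tolerances are hypothetical:

    def run_em(em, dyn_params, x_params, tol=1e-4, max_iters=100):
        # `em` returns (i + 1, dyn_params, x_params, prev_elbo, elbo);
        # iterate until the ELBO improvement falls below the tolerance.
        i, prev_elbo, elbo = 0, -float('inf'), 0.0
        while i < max_iters and abs(elbo - prev_elbo) > tol:
            i, dyn_params, x_params, prev_elbo, elbo = em(
                i, dyn_params, x_params, prev_elbo, elbo)
        return dyn_params, x_params, elbo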
Example #2
def map_fn(data):
    # Flatten all leading (batch/time) axes, run the network on the last
    # axis, then restore the leading axes on the output distribution's
    # parameters. `network` is captured from the enclosing scope.
    data_shape = T.shape(data)
    leading = data_shape[:-1]
    dim_in = data_shape[-1]
    flattened = T.reshape(data, [-1, dim_in])
    net_out = network(flattened)
    if isinstance(net_out, stats.GaussianScaleDiag):
        scale_diag, mu = net_out.get_parameters('regular')
        dim_out = T.shape(mu)[-1]
        return stats.GaussianScaleDiag([
            T.reshape(scale_diag, T.concatenate([leading, [dim_out]])),
            T.reshape(mu, T.concatenate([leading, [dim_out]])),
        ])
    elif isinstance(net_out, stats.Gaussian):
        sigma, mu = net_out.get_parameters('regular')
        dim_out = T.shape(mu)[-1]
        return stats.Gaussian([
            # The full covariance restores an extra [dim_out, dim_out] block.
            T.reshape(sigma, T.concatenate([leading, [dim_out, dim_out]])),
            T.reshape(mu, T.concatenate([leading, [dim_out]])),
        ])
    elif isinstance(net_out, stats.Bernoulli):
        params = net_out.get_parameters('natural')
        dim_out = T.shape(params)[-1]
        return stats.Bernoulli(
            T.reshape(params, T.concatenate([leading, [dim_out]])),
            'natural')
    else:
        raise NotImplementedError("Unimplemented distribution")
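The flatten-apply-restore idiom above is independent of the backend; a NumPy check of the shape bookkeeping (illustrative, with `f` standing in for the network):

    import numpy as np

    def map_last_axis(f, data):
        # Collapse all leading axes, apply f row-wise, restore them.
        leading, dim_in = data.shape[:-1], data.shape[-1]
        out = f(data.reshape(-1, dim_in))
        return out.reshape(*leading, out.shape[-1])

    W = np.random.randn(4, 2)
    x = np.random.randn(3, 5, 4)              # (batch, time, features)
    assert map_last_axis(lambda r: r @ W, x).shape == (3, 5, 2)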
Example #3
def next_state(self, state, action, t):
    # Time-varying linear-Gaussian dynamics: the next state is Gaussian
    # with mean A[t] @ [state; action] and covariance Q[t], with Q[t]
    # tiled over the batch dimension.
    A, Q = self.get_dynamics()
    leading_dim = T.shape(state)[:-1]
    state_action = T.concatenate([state, action], -1)
    return stats.Gaussian([
        T.tile(Q[t][None], T.concatenate([leading_dim, [1, 1]])),
        T.einsum('ab,nb->na', A[t], state_action)
    ])
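In NumPy terms a single transition under these dynamics is just a matrix-vector product plus Gaussian noise (illustrative; dimensions assumed):

    import numpy as np

    ds, da = 4, 2
    A_t = np.random.randn(ds, ds + da)        # dynamics at time t
    Q_t = 0.1 * np.eye(ds)                    # process noise covariance
    x, u = np.random.randn(ds), np.random.randn(da)

    mean = A_t @ np.concatenate([x, u])       # A[t] @ [state; action]
    x_next = np.random.multivariate_normal(mean, Q_t)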
Example #4
def get_statistics(self, q_Xt, q_At, q_Xt1):
    # First and second moments of the next state x_{t+1}.
    Xt1_Xt1T, Xt1 = stats.Gaussian.unpack(
        q_Xt1.expected_sufficient_statistics())
    # Moments of the current state and action.
    Xt_XtT, Xt = stats.Gaussian.unpack(q_Xt.expected_sufficient_statistics())
    At_AtT, At = stats.Gaussian.unpack(q_At.expected_sufficient_statistics())
    # Joint moments of the concatenated vector [x_t; a_t]. The cross
    # blocks use E[x]E[a]^T, i.e. x_t and a_t are treated as independent
    # under the factorized posterior.
    XtAt = T.concatenate([Xt, At], -1)
    XtAt_XtAtT = T.concatenate([
        T.concatenate([Xt_XtT, T.outer(Xt, At)], -1),
        T.concatenate([T.outer(At, Xt), At_AtT], -1),
    ], -2)
    return (XtAt_XtAtT, XtAt), (Xt1_Xt1T, Xt1)
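The block construction implements E[[x; a][x; a]^T] = [[E[xx^T], E[x]E[a]^T], [E[a]E[x]^T, E[aa^T]]], which is exact when x_t and a_t are independent under the factorized posterior. A NumPy check (illustrative):

    import numpy as np

    ds, da = 3, 2
    mx, ma = np.random.randn(ds), np.random.randn(da)
    Sx, Sa = np.eye(ds), np.eye(da)           # marginal covariances
    Exx = Sx + np.outer(mx, mx)               # E[xx^T]
    Eaa = Sa + np.outer(ma, ma)               # E[aa^T]
    joint = np.block([[Exx,              np.outer(mx, ma)],
                      [np.outer(ma, mx), Eaa]])
    assert joint.shape == (ds + da, ds + da)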
Example #5
def initialize_objective(self):
    H, ds, da = self.horizon, self.ds, self.da
    # Initialize the dynamics near A = [I | 0] (identity on the state,
    # zero on the action) plus a small random perturbation. Q is
    # parameterized through its log-diagonal so that it stays positive
    # definite during optimization.
    A = T.concatenate([T.eye(ds), T.zeros([ds, da])], -1)
    if self.time_varying:
        self.A = T.variable(A[None] + 1e-2 * T.random_normal([H - 1, ds, ds + da]))
        self.Q_log_diag = T.variable(T.random_normal([H - 1, ds]) + 1)
    else:
        self.A = T.variable(A + 1e-2 * T.random_normal([ds, ds + da]))
        self.Q_log_diag = T.variable(T.random_normal([ds]) + 1)
    self.Q = T.matrix_diag(T.exp(self.Q_log_diag))
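The two ingredients, a near-identity [I | 0] initialization and a log-diagonal covariance, are easy to see in plain NumPy (illustrative):

    import numpy as np

    ds, da = 3, 2
    A0 = np.concatenate([np.eye(ds), np.zeros((ds, da))], axis=-1)
    A = A0 + 1e-2 * np.random.randn(ds, ds + da)   # near-identity dynamics
    Q = np.diag(np.exp(np.random.randn(ds) + 1))   # positive definite by construction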
Example #6
def next_state(self, state, action, t):
    # The network predicts a *delta* over the current state; adding the
    # state back in gives a residual parameterization of the dynamics.
    state_action = T.concatenate([state, action], -1)
    sigma, delta_mu = self.network(state_action).get_parameters('regular')
    return stats.Gaussian([
        sigma,
        delta_mu + state,
    ])
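Predicting a state delta rather than the next state directly keeps the learned dynamics near identity at initialization, matching the [I | 0] initialization used elsewhere. A toy version (illustrative; `net` is hypothetical):

    import numpy as np

    def next_state_residual(net, x, u):
        # `net` maps [x; u] to (covariance, predicted state delta).
        sigma, delta = net(np.concatenate([x, u]))
        return sigma, x + delta               # mean is x_t plus the delta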
Example #7
def forward(self, q_Xt, q_At):
    # Push the expected state and action through the time-varying linear
    # dynamics: for each batch element n and step h, the mean of x_{t+1}
    # is A[h] @ [x; a] and the covariance is Q[h].
    Xt, At = q_Xt.expected_value(), q_At.expected_value()
    batch_size = T.shape(Xt)[0]
    XAt = T.concatenate([Xt, At], -1)
    A, Q = self.get_dynamics()
    p_Xt1 = stats.Gaussian([
        T.tile(Q[None], [batch_size, 1, 1, 1]),
        T.einsum('nhs,hxs->nhx', XAt, A)
    ])
    return p_Xt1
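The einsum 'nhs,hxs->nhx' is a per-timestep batched matrix-vector product; a NumPy check (illustrative):

    import numpy as np

    n, h, ds, da = 2, 5, 3, 2
    XAt = np.random.randn(n, h, ds + da)      # [x; a] per batch and step
    A = np.random.randn(h, ds, ds + da)       # one dynamics matrix per step

    out = np.einsum('nhs,hxs->nhx', XAt, A)
    ref = np.stack([[A[t] @ XAt[i, t] for t in range(h)] for i in range(n)])
    assert np.allclose(out, ref)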
Example #8
def initialize_objective(self):
    H, ds, da = self.horizon, self.ds, self.da
    # MNIW prior and variational posterior over the dynamics, centered at
    # A = [I | 0] (identity on the state, zero on the action). The four
    # regular parameters play the roles of scale matrix, mean, input
    # covariance, and degrees of freedom; the variational copy gets a
    # small random perturbation and is stored in natural form as
    # trainable variables.
    if self.time_varying:
        A = T.concatenate(
            [T.eye(ds, batch_shape=[H - 1]),
             T.zeros([H - 1, ds, da])], -1)
        self.A_prior = stats.MNIW([
            2 * T.eye(ds, batch_shape=[H - 1]),
            A,
            T.eye(ds + da, batch_shape=[H - 1]),
            T.to_float(ds + 2) * T.ones([H - 1]),
        ], parameter_type='regular')
        self.A_variational = stats.MNIW(
            list(map(T.variable, stats.MNIW.regular_to_natural([
                2 * T.eye(ds, batch_shape=[H - 1]),
                A + 1e-2 * T.random_normal([H - 1, ds, ds + da]),
                T.eye(ds + da, batch_shape=[H - 1]),
                T.to_float(ds + 2) * T.ones([H - 1]),
            ]))), parameter_type='natural')
    else:
        A = T.concatenate([T.eye(ds), T.zeros([ds, da])], -1)
        self.A_prior = stats.MNIW(
            [2 * T.eye(ds), A, T.eye(ds + da), T.to_float(ds + 2)],
            parameter_type='regular')
        self.A_variational = stats.MNIW(
            list(map(T.variable, stats.MNIW.regular_to_natural([
                2 * T.eye(ds),
                A + 1e-2 * T.random_normal([ds, ds + da]),
                T.eye(ds + da),
                T.to_float(ds + 2),
            ]))), parameter_type='natural')
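For intuition, drawing a dynamics matrix from an MNIW with these regular parameters might look as follows in SciPy (illustrative only; MNIW conventions vary, and the column parameter is treated as a precision here):

    import numpy as np
    from scipy.stats import invwishart, matrix_normal

    ds, da = 3, 2
    Psi = 2 * np.eye(ds)                                  # IW scale
    M = np.concatenate([np.eye(ds), np.zeros((ds, da))], -1)
    V, nu = np.eye(ds + da), ds + 2

    Sigma = invwishart(df=nu, scale=Psi).rvs()            # noise covariance
    A = matrix_normal(mean=M, rowcov=Sigma,
                      colcov=np.linalg.inv(V)).rvs()      # dynamics sample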
Example #9
def kl_gradients(self, q_X, q_A, _, num_data):
    # Natural gradient of the dynamics KL term. Assemble the expected
    # sufficient statistics of each transition, either from the smoothed
    # joint posterior or from the factorized per-step posteriors.
    if self.smooth:
        ds = self.ds
        ess = q_X.expected_sufficient_statistics()
        # Second moments and means per transition
        # (x = state at t, y = state at t + 1, a = action at t).
        yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
        xxT = ess[..., :-1, :ds, :ds]
        yxT = ess[..., :-1, ds:2 * ds, :ds]
        aaT, a = stats.Gaussian.unpack(
            q_A.expected_sufficient_statistics())
        aaT, a = aaT[:, :-1], a[:, :-1]
        x = ess[..., :-1, -1, :ds]
        y = ess[..., :-1, -1, ds:2 * ds]
        xaT = T.outer(x, a)
        yaT = T.outer(y, a)
        # Joint second moment of the state-action vector [x; a].
        xaxaT = T.concatenate([
            T.concatenate([xxT, xaT], -1),
            T.concatenate([T.matrix_transpose(xaT), aaT], -1),
        ], -2)
        batch_size = T.shape(ess)[0]
        num_batches = T.to_float(num_data) / T.to_float(batch_size)
        ess = [
            yyT,
            T.concatenate([yxT, yaT], -1),
            xaxaT,
            T.ones([batch_size, self.horizon - 1]),
        ]
    else:
        # Slice the factorized posteriors into (x_t, a_t, x_{t+1}) parts.
        q_Xt = q_X.__class__([
            q_X.get_parameters('regular')[0][:, :-1],
            q_X.get_parameters('regular')[1][:, :-1],
        ])
        q_At = q_A.__class__([
            q_A.get_parameters('regular')[0][:, :-1],
            q_A.get_parameters('regular')[1][:, :-1],
        ])
        q_Xt1 = q_X.__class__([
            q_X.get_parameters('regular')[0][:, 1:],
            q_X.get_parameters('regular')[1][:, 1:],
        ])
        (XtAt_XtAtT, XtAt), (Xt1_Xt1T, Xt1) = self.get_statistics(
            q_Xt, q_At, q_Xt1)
        batch_size = T.shape(XtAt)[0]
        num_batches = T.to_float(num_data) / T.to_float(batch_size)
        ess = [
            Xt1_Xt1T,
            T.einsum('nha,nhb->nhba', XtAt, Xt1),
            XtAt_XtAtT,
            T.ones([batch_size, self.horizon - 1]),
        ]
    # Sum the statistics over the batch (and also over time when the
    # dynamics are shared across time steps).
    if self.time_varying:
        ess = [T.sum(s, [0]) for s in ess]
    else:
        ess = [T.sum(s, [0, 1]) for s in ess]
    # Stochastic natural gradient: prior plus rescaled statistics minus
    # the current variational natural parameters, normalized by num_data.
    return [
        -(a + num_batches * b - c) / T.to_float(num_data)
        for a, b, c in zip(
            self.A_prior.get_parameters('natural'),
            ess,
            self.A_variational.get_parameters('natural'),
        )
    ]
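The return value is a stochastic natural-gradient direction for a conjugate exponential-family update: prior natural parameters plus minibatch statistics rescaled to the full dataset, minus the current variational natural parameters. Schematically (illustrative):

    def nat_gradient(prior_np, batch_ess, var_np, num_data, batch_size):
        # Rescale minibatch statistics to the full dataset, then take the
        # conjugate target minus the current variational parameters.
        num_batches = num_data / batch_size
        return [-(p + num_batches * s - v) / num_data
                for p, s, v in zip(prior_np, batch_ess, var_np)]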
Example #10
def forward(self, q_Xt, q_At):
    # Draw a single (state, action) sample from each posterior and push
    # it through the network, mapping over any leading axes.
    Xt, At = q_Xt.sample()[0], q_At.sample()[0]
    return util.map_network(self.network)(T.concatenate([Xt, At], -1))