def get_dynamics(self):
    if self.time_varying:
        return self.A, self.Q
    else:
        return (
            T.tile(self.A[None], [self.horizon - 1, 1, 1]),
            T.tile(self.Q[None], [self.horizon - 1, 1, 1])
        )
def sufficient_statistics(self):
    if self.time_varying:
        return self.A_variational.expected_sufficient_statistics()
    else:
        stats = self.A_variational.expected_sufficient_statistics()
        return [
            T.tile(stats[0][None], [self.horizon - 1, 1, 1]),
            T.tile(stats[1][None], [self.horizon - 1, 1, 1]),
            T.tile(stats[2][None], [self.horizon - 1, 1, 1]),
            T.tile(stats[3][None], [self.horizon - 1]),
        ]
def get_dynamics(self):
    if self.time_varying:
        Q, A = self.A_variational.expected_value()
        return A, Q
    else:
        Q, A = self.A_variational.expected_value()
        return (
            T.tile(A[None], [self.horizon - 1, 1, 1]),
            T.tile(Q[None], [self.horizon - 1, 1, 1]),
        )
def _sample(self, num_samples):
    # Recover the regular parameterization (sigma, mu) and sample via the
    # Cholesky reparameterization: mu + L @ noise with noise ~ N(0, I).
    sigma, mu = self.natural_to_regular(
        self.regular_to_natural(self.get_parameters('regular')))
    L = T.cholesky(sigma)
    sample_shape = T.concat([[num_samples], T.shape(mu)], 0)
    noise = T.random_normal(sample_shape)
    L = T.tile(L[None], T.concat([[num_samples],
                                  T.ones([T.rank(sigma)], dtype=np.int32)], 0))
    return mu[None] + T.matmul(L, noise[..., None])[..., 0]
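# Illustration (not part of the original class): _sample above is the standard
# Cholesky reparameterization, mu + L @ eps with eps ~ N(0, I). A minimal
# NumPy sketch of the same idea, assuming a single (d, d) covariance; the
# function name and signature here are hypothetical:
#
#     import numpy as np
#
#     def sample_mvn(mu, sigma, num_samples, rng=np.random.default_rng()):
#         L = np.linalg.cholesky(sigma)                         # sigma = L @ L.T
#         eps = rng.standard_normal((num_samples, mu.shape[0]))
#         return mu[None] + eps @ L.T                           # (num_samples, d)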
def next_state(self, state, action, t):
    A, Q = self.get_dynamics()
    leading_dim = T.shape(state)[:-1]
    state_action = T.concatenate([state, action], -1)
    return stats.Gaussian([
        T.tile(Q[t][None], T.concatenate([leading_dim, [1, 1]])),
        T.einsum('ab,nb->na', A[t], state_action)
    ])
def forward(self, q_Xt, q_At):
    Xt, At = q_Xt.expected_value(), q_At.expected_value()
    batch_size = T.shape(Xt)[0]
    XAt = T.concatenate([Xt, At], -1)
    A, Q = self.get_dynamics()
    p_Xt1 = stats.Gaussian([
        T.tile(Q[None], [batch_size, 1, 1, 1]),
        T.einsum('nhs,hxs->nhx', XAt, A)
    ])
    return p_Xt1
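# Illustration (not part of the original class): forward/next_state push a
# state-action pair through linear-Gaussian dynamics, p(x_{t+1}) =
# N(A [x_t; a_t], Q). A minimal NumPy sketch for a single step, assuming A is
# (dx, dx + da) and Q is (dx, dx); all names here are hypothetical:
#
#     import numpy as np
#
#     def linear_gaussian_step(A, Q, x, a, rng=np.random.default_rng()):
#         xa = np.concatenate([x, a])                  # stacked state-action
#         mean = A @ xa                                # predicted next-state mean
#         noise = np.linalg.cholesky(Q) @ rng.standard_normal(len(mean))
#         return mean + noise                          # one sample of x_{t+1}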
def posterior_dynamics(self, q_X, q_A, data_strength=1.0, max_iter=200, tol=1e-3):
    if self.smooth:
        if self.time_varying:
            prior_dyn = stats.MNIW(
                self.A_variational.get_parameters('natural'), 'natural')
        else:
            natparam = self.A_variational.get_parameters('natural')
            prior_dyn = stats.MNIW([
                T.tile(natparam[0][None], [self.horizon - 1, 1, 1]),
                T.tile(natparam[1][None], [self.horizon - 1, 1, 1]),
                T.tile(natparam[2][None], [self.horizon - 1, 1, 1]),
                T.tile(natparam[3][None], [self.horizon - 1]),
            ], 'natural')
        state_prior = stats.Gaussian([T.eye(self.ds), T.zeros(self.ds)])
        aaT, a = stats.Gaussian.unpack(
            q_A.expected_sufficient_statistics())
        aaT, a = aaT[:, :-1], a[:, :-1]
        ds, da = self.ds, self.da
        initial_dyn_natparam = prior_dyn.get_parameters('natural')
        initial_X_natparam = stats.LDS(
            (self.sufficient_statistics(), state_prior, q_X,
             q_A.expected_value(), self.horizon),
            'internal').get_parameters('natural')

        def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
            q_X_ = stats.LDS(q_X_natparam, 'natural')
            ess = q_X_.expected_sufficient_statistics()
            batch_size = T.shape(ess)[0]
            # Slice the expected sufficient statistics into the blocks needed
            # for the MNIW dynamics update.
            yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
            xxT = ess[..., :-1, :ds, :ds]
            yxT = ess[..., :-1, ds:2 * ds, :ds]
            x = ess[..., :-1, -1, :ds]
            y = ess[..., :-1, -1, ds:2 * ds]
            xaT = T.outer(x, a)
            yaT = T.outer(y, a)
            xaxaT = T.concatenate([
                T.concatenate([xxT, xaT], -1),
                T.concatenate([T.matrix_transpose(xaT), aaT], -1),
            ], -2)
            ess = [
                yyT,
                T.concatenate([yxT, yaT], -1),
                xaxaT,
                T.ones([batch_size, self.horizon - 1])
            ]
            q_dyn_natparam = [
                T.sum(s, [0]) * data_strength + p
                for s, p in zip(ess, initial_dyn_natparam)
            ]
            q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
            q_stats = q_dyn_.expected_sufficient_statistics()
            p_X = stats.LDS(
                (q_stats, state_prior, None, q_A.expected_value(), self.horizon))
            q_X_ = stats.LDS(
                (q_stats, state_prior, q_X, q_A.expected_value(), self.horizon))
            elbo = (T.sum(stats.kl_divergence(q_X_, p_X))
                    + T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
            return (i + 1, q_dyn_.get_parameters('natural'),
                    q_X_.get_parameters('natural'), curr_elbo, elbo)

        def cond(i, _, __, prev_elbo, curr_elbo):
            with T.core.control_dependencies([T.core.print(curr_elbo)]):
                prev_elbo = T.core.identity(prev_elbo)
            return T.logical_and(
                T.abs(curr_elbo - prev_elbo) > tol, i < max_iter)

        result = T.while_loop(
            cond, em,
            [0, initial_dyn_natparam, initial_X_natparam,
             T.constant(-np.inf), T.constant(0.)],
            back_prop=False)
        pd = stats.MNIW(result[1], 'natural')
        sigma, mu = pd.expected_value()
        q_X = stats.LDS(result[2], 'natural')
        return ((mu, sigma), pd.expected_sufficient_statistics()), (q_X, q_A)
    else:
        q_Xt = q_X.__class__([
            q_X.get_parameters('regular')[0][:, :-1],
            q_X.get_parameters('regular')[1][:, :-1],
        ])
        q_At = q_A.__class__([
            q_A.get_parameters('regular')[0][:, :-1],
            q_A.get_parameters('regular')[1][:, :-1],
        ])
        q_Xt1 = q_X.__class__([
            q_X.get_parameters('regular')[0][:, 1:],
            q_X.get_parameters('regular')[1][:, 1:],
        ])
        (XtAt_XtAtT, XtAt), (Xt1_Xt1T, Xt1) = self.get_statistics(q_Xt, q_At, q_Xt1)
        batch_size = T.shape(XtAt)[0]
        ess = [
            Xt1_Xt1T,
            T.einsum('nha,nhb->nhba', XtAt, Xt1),
            XtAt_XtAtT,
            T.ones([batch_size, self.horizon - 1])
        ]
        if self.time_varying:
            posterior = stats.MNIW([
                T.sum(s, [0]) * data_strength + p
                for s, p in zip(ess, self.A_variational.get_parameters('natural'))
            ], 'natural')
        else:
            posterior = stats.MNIW([
                T.sum(s, [0]) * data_strength + p[None]
                for s, p in zip(ess, self.A_variational.get_parameters('natural'))
            ], 'natural')
        Q, A = posterior.expected_value()
        return (A, Q), q_X
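# Illustration (not the original implementation): posterior_dynamics alternates
# MNIW dynamics updates with LDS message passing inside T.while_loop, stopping
# once the ELBO change drops below tol. Stripped of the graph ops, the control
# flow is roughly the plain-Python loop below; update_dynamics, update_states,
# and compute_elbo are hypothetical stand-ins for the MNIW/LDS updates:
#
#     def em_loop(dyn, states, update_dynamics, update_states, compute_elbo,
#                 max_iter=200, tol=1e-3):
#         prev_elbo = float('-inf')
#         for _ in range(max_iter):
#             dyn = update_dynamics(states)            # M-step-like update
#             states = update_states(dyn)              # E-step-like update
#             elbo = compute_elbo(dyn, states)
#             if abs(elbo - prev_elbo) <= tol:         # converged
#                 break
#             prev_elbo = elbo
#         return dyn, states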
# q_theta = make_variable(NIW(map(lambda x: np.array(x).astype(T.floatx()),
#     [np.eye(D), np.random.multivariate_normal(mean=np.zeros([D]), cov=np.eye(D) * 20),
#      1.0, 1.0])))
num_batches = N / T.to_float(batch_size)
nat_scale = 1.0
theta_stats = T.sum(x_tmessage, 0)
parent_theta = p_theta.get_parameters('natural')
# Conjugate update: prior natural parameters plus summed sufficient statistics.
q_theta = NIW(parent_theta + theta_stats, parameter_type='natural')
sigma, mu = Gaussian(q_theta.expected_sufficient_statistics(),
                     parameter_type='natural').get_parameters('regular')
# theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))
x_param = q_theta.expected_sufficient_statistics()[None]
q_x = Gaussian(T.tile(x_param, [batch_size, 1, 1]), parameter_type='natural')
l_x = T.sum(q_x.log_likelihood(X))
elbo = l_theta + l_x
elbos = []
l_thetas = []
l_xs = []
# natgrads = [(theta_gradient, q_theta.get_parameters('natural'))]
# nat_op = tf.group(*[T.assign(b, a + b) for a, b in natgrads])
# nat_opt = tf.train.GradientDescentOptimizer(1e-2)
# nat_op = nat_opt.apply_gradients([(-a, b) for a, b in natgrads])
grad_op = tf.train.AdamOptimizer(1e-2).minimize(-l_x, var_list=net.get_parameters())
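# Illustration (not part of the original script): the q_theta line above is a
# conjugate exponential-family update, i.e. prior natural parameters plus the
# (optionally batch-scaled) summed sufficient statistics. A schematic NumPy
# version, assuming flat natural-parameter vectors; the name is hypothetical:
#
#     import numpy as np
#
#     def conjugate_update(prior_natparam, suff_stats, num_batches=1.0):
#         # posterior natural params = prior natural params + scaled data statistics
#         return prior_natparam + num_batches * np.sum(suff_stats, axis=0)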
# Split the sequence into (y_t, y_{t+1}) pairs and flatten across time.
yt, yt1 = data[:, :-1], data[:, 1:]
yt, yt1 = yt.reshape([-1, D]), yt1.reshape([-1, D])
transition_net = Tanh(D, 500) >> Tanh(500) >> nn.Gaussian(D)
transition_net.initialize()
rec_net = Tanh(D, 500) >> Tanh(500) >> nn.Gaussian(D)
rec_net.initialize()
Yt = T.placeholder(T.floatx(), [None, D])
Yt1 = T.placeholder(T.floatx(), [None, D])
batch_size = T.shape(Yt)[0]
num_batches = N / T.to_float(batch_size)
# Gaussian messages from the observations at t and t + 1.
Yt_message = Gaussian.pack([
    T.tile(T.eye(D)[None] * noise, [batch_size, 1, 1]),
    T.einsum('ab,ib->ia', T.eye(D) * noise, Yt)
])
Yt1_message = Gaussian.pack([
    T.tile(T.eye(D)[None] * noise, [batch_size, 1, 1]),
    T.einsum('ab,ib->ia', T.eye(D) * noise, Yt1)
])
transition = Gaussian(transition_net(Yt)).expected_value()
max_iter = 1000
tol = 1e-5

def cond(i, prev_elbo, elbo, qxt, qxt1):