def kl_divergence(self, q_X, q_A, _):
    # q_Xt - [N, H, ds]
    # q_At - [N, H, da]
    if (q_X, q_A) not in self.cache:
        q_Xt = q_X.__class__([
            q_X.get_parameters('regular')[0][:, :-1],
            q_X.get_parameters('regular')[1][:, :-1],
        ])
        q_At = q_A.__class__([
            q_A.get_parameters('regular')[0][:, :-1],
            q_A.get_parameters('regular')[1][:, :-1],
        ])
        p_Xt1 = self.forward(q_Xt, q_At)
        q_Xt1 = q_X.__class__([
            q_X.get_parameters('regular')[0][:, 1:],
            q_X.get_parameters('regular')[1][:, 1:],
        ])
        rmse = T.sqrt(
            T.sum(T.square(
                q_Xt1.get_parameters('regular')[1]
                - p_Xt1.get_parameters('regular')[1]), axis=-1))
        model_stdev = T.sqrt(p_Xt1.get_parameters('regular')[0])
        encoding_stdev = T.sqrt(q_Xt1.get_parameters('regular')[0])
        self.cache[(q_X, q_A)] = T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=-1), {
            'rmse': rmse,
            'encoding-stdev': encoding_stdev,
            'model-stdev': model_stdev
        }
    return self.cache[(q_X, q_A)]
def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
    q_X_ = stats.LDS(q_X_natparam, 'natural')
    ess = q_X_.expected_sufficient_statistics()
    batch_size = T.shape(ess)[0]
    yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
    xxT = ess[..., :-1, :ds, :ds]
    yxT = ess[..., :-1, ds:2 * ds, :ds]
    x = ess[..., :-1, -1, :ds]
    y = ess[..., :-1, -1, ds:2 * ds]
    xaT = T.outer(x, a)
    yaT = T.outer(y, a)
    xaxaT = T.concatenate([
        T.concatenate([xxT, xaT], -1),
        T.concatenate([T.matrix_transpose(xaT), aaT], -1),
    ], -2)
    ess = [
        yyT,
        T.concatenate([yxT, yaT], -1),
        xaxaT,
        T.ones([batch_size, self.horizon - 1])
    ]
    q_dyn_natparam = [
        T.sum(a, [0]) * data_strength + b
        for a, b in zip(ess, initial_dyn_natparam)
    ]
    q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
    q_stats = q_dyn_.expected_sufficient_statistics()
    p_X = stats.LDS((q_stats, state_prior, None,
                     q_A.expected_value(), self.horizon))
    q_X_ = stats.LDS((q_stats, state_prior, q_X,
                      q_A.expected_value(), self.horizon))
    elbo = (T.sum(stats.kl_divergence(q_X_, p_X))
            + T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
    return i + 1, q_dyn_.get_parameters('natural'), \
        q_X_.get_parameters('natural'), curr_elbo, elbo
def log_z(self, parameter_type='regular', stop_gradient=False):
    if parameter_type == 'regular':
        alpha = self.get_parameters('regular', stop_gradient=stop_gradient)
    else:
        alpha = self.natural_to_regular(
            self.get_parameters('natural', stop_gradient=stop_gradient))
    return T.sum(T.gammaln(alpha), -1) - T.gammaln(T.sum(alpha, -1))
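# Hedged sanity check (not from the source): the Dirichlet log-partition above,
#     log Z(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i),
# verified against scipy's Dirichlet log-density in plain NumPy.
import numpy as np
from scipy.special import gammaln
from scipy.stats import dirichlet

alpha = np.array([2.0, 3.0, 1.5])
x = np.array([0.2, 0.5, 0.3])                      # a point on the simplex

log_z = gammaln(alpha).sum() - gammaln(alpha.sum())
log_pdf = ((alpha - 1.0) * np.log(x)).sum() - log_z
assert np.allclose(log_pdf, dirichlet.logpdf(x, alpha))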
def log_normal(x, mu, sigma, D, dim=1):
    # sigma is a variance: a shared scalar when dim == 1, a per-dimension
    # vector (diagonal covariance) when dim == 2.
    if dim == 1:
        pre_term = -(D * 0.5 * np.log(2 * np.pi) + 0.5 * D * T.log(sigma))
        delta = T.sum((x - mu) ** 2, axis=1) * 1.0 / sigma
        return pre_term - 0.5 * delta
    elif dim == 2:
        pre_term = -(D * 0.5 * np.log(2 * np.pi) + 0.5 * T.sum(T.log(sigma), axis=1))
        delta = T.sum((x - mu) * 1.0 / sigma * (x - mu), axis=1)
        return pre_term - 0.5 * delta
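# Hedged sanity check (not from the source): the dim == 2 branch of log_normal,
# re-implemented in NumPy and compared against scipy's multivariate normal with a
# diagonal covariance built from the per-dimension variances.
import numpy as np
from scipy.stats import multivariate_normal

np.random.seed(0)
D = 3
x = np.random.randn(5, D)
mu = np.random.randn(5, D)
sigma = np.random.rand(5, D) + 0.1                 # variances, not stdevs

pre_term = -(D * 0.5 * np.log(2 * np.pi) + 0.5 * np.sum(np.log(sigma), axis=1))
delta = np.sum((x - mu) ** 2 / sigma, axis=1)
log_p = pre_term - 0.5 * delta

reference = np.array([
    multivariate_normal.logpdf(x[i], mean=mu[i], cov=np.diag(sigma[i]))
    for i in range(5)
])
assert np.allclose(log_p, reference)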
def get_child_message(x, y, hidden={}, visible={}):
    with graph_context({**hidden, **visible}):
        data = context(y)
        log_likelihood = y.log_likelihood(data)
        stats = x.statistics()
        param = T.grad(T.sum(log_likelihood), [x.get_statistic(s) for s in stats])
        return {s: param[i] for i, s in enumerate(stats)}
def message_passing(hidden, visible):
    elbo = 0.0
    for var in top_sort(hidden)[::-1]:
        child_messages = [
            get_child_message(
                var, c,
                hidden={k: v for k, v in hidden.items() if k != var},
                visible=visible)
            for c in var.children()
        ]
        stats = var.statistics()
        parent_message = var.get_parameters('natural')
        e_p = var.__class__(parent_message, 'natural', graph=False)
        natparam = {
            s: parent_message[s] + sum([child_message[s]
                                        for child_message in child_messages])
            for s in stats
        }
        q = var.__class__(natparam, 'natural', graph=False)
        elbo -= kl_divergence(q, e_p)
        hidden[var] = q
    for var in visible:
        with graph_context(hidden):
            elbo += T.sum(var.log_likelihood(visible[var]))
    return hidden, elbo
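# Hedged sketch (not from the source): the natparam = parent + sum(child messages)
# combination above, written out for a single 1-D Gaussian node x ~ N(0, 1) with one
# observation y | x ~ N(x, s2). The numbers and variable names here are illustrative
# assumptions, not values from the codebase.
import numpy as np

y, s2 = 0.7, 0.5
parent = np.array([0.0, -0.5])            # (mu/sigma2, -1/(2 sigma2)) for N(0, 1)
child = np.array([y / s2, -0.5 / s2])     # likelihood message in natural form
natparam = parent + child                 # combined natural parameters of q(x)

post_var = -0.5 / natparam[1]
post_mean = natparam[0] * post_var
# matches the conjugate posterior N(y / (1 + s2), s2 / (1 + s2))
assert np.allclose(post_mean, y / (1 + s2))
assert np.allclose(post_var, s2 / (1 + s2))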
def log_likelihood(self, x):
    natparam = self.get_parameters('natural')
    stats = self.sufficient_statistics(x)
    return (sum(
        T.sum(stats[stat] * natparam[stat], list(range(-stat.out_dim(), 0)))
        for stat in self.statistics()) - self.log_z())
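# Hedged sanity check (not from the source): the identity computed above,
#     log p(x) = <eta, t(x)> - log Z(eta),
# written out for a 1-D Gaussian in natural parameters and checked against scipy.
import numpy as np
from scipy.stats import norm

mu, var = 0.3, 2.0
eta = np.array([mu / var, -0.5 / var])             # coefficients of t(x) = (x, x^2)
log_z = 0.5 * mu ** 2 / var + 0.5 * np.log(2 * np.pi * var)

x = 1.7
t_x = np.array([x, x ** 2])
assert np.allclose(eta @ t_x - log_z, norm.logpdf(x, loc=mu, scale=np.sqrt(var)))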
def kl_divergence(p, q):
    assert p.statistics() == q.statistics()
    param_dim = p.get_param_dim()
    dist = p.__class__
    p_param, q_param = p.get_parameters('natural'), q.get_parameters('natural')
    stats = p.statistics()
    p_stats = p.expected_sufficient_statistics()
    return (sum([
        T.sum((p_param[s] - q_param[s]) * p_stats[s], list(range(-param_dim[s], 0)))
        for s in stats
    ]) - p.log_z() + q.log_z())
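# Hedged sanity check (not from the source): the exponential-family identity above,
#     KL(p || q) = <eta_p - eta_q, E_p[t(x)]> - log Z(eta_p) + log Z(eta_q),
# checked for two 1-D Gaussians against the usual closed form.
import numpy as np

def natural(mu, var):
    return np.array([mu / var, -0.5 / var])        # t(x) = (x, x^2)

def log_z(mu, var):
    return 0.5 * mu ** 2 / var + 0.5 * np.log(2 * np.pi * var)

mu_p, var_p, mu_q, var_q = 0.0, 1.0, 1.0, 2.0
e_t_p = np.array([mu_p, var_p + mu_p ** 2])        # E_p[x], E_p[x^2]

kl = ((natural(mu_p, var_p) - natural(mu_q, var_q)) @ e_t_p
      - log_z(mu_p, var_p) + log_z(mu_q, var_q))
closed_form = (np.log(np.sqrt(var_q / var_p))
               + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q) - 0.5)
assert np.allclose(kl, closed_form)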
def step(i, prev_elbo, elbo, qxt_param, qxt1_param):
    qxt, qxt1 = Gaussian(qxt_param, 'natural'), Gaussian(qxt1_param, 'natural')
    qxt_message = Gaussian.regular_to_natural(transition_net(qxt.sample()[0]))
    qxt1 = Gaussian(qxt_message + Yt1_message, 'natural')
    qxt1_message = Gaussian.regular_to_natural(rec_net(qxt1.sample()[0]))
    qxt = Gaussian(qxt1_message + Yt_message, 'natural')
    prev_elbo, elbo = elbo, T.sum(
        kl_divergence(qxt1, Gaussian(qxt_message, 'natural')))
    return i + 1, prev_elbo, elbo, qxt.get_parameters('natural'), \
        qxt1.get_parameters('natural')
def kl_divergence(self, q_X, q_A, _):
    # q_Xt - [N, H, ds]
    # q_At - [N, H, da]
    if (q_X, q_A) not in self.cache:
        info = {}
        if self.smooth:
            state_prior = stats.GaussianScaleDiag([
                T.ones(self.ds), T.zeros(self.ds)
            ])
            p_X = stats.LDS(
                (self.sufficient_statistics(), state_prior, None,
                 q_A.expected_value(), self.horizon), 'internal')
            kl = T.mean(stats.kl_divergence(q_X, p_X), axis=0)
            Q = self.get_dynamics()[1]
            info['model-stdev'] = T.sqrt(T.matrix_diag_part(Q))
        else:
            q_Xt = q_X.__class__([
                q_X.get_parameters('regular')[0][:, :-1],
                q_X.get_parameters('regular')[1][:, :-1],
            ])
            q_At = q_A.__class__([
                q_A.get_parameters('regular')[0][:, :-1],
                q_A.get_parameters('regular')[1][:, :-1],
            ])
            p_Xt1 = self.forward(q_Xt, q_At)
            q_Xt1 = q_X.__class__([
                q_X.get_parameters('regular')[0][:, 1:],
                q_X.get_parameters('regular')[1][:, 1:],
            ])
            rmse = T.sqrt(T.sum(T.square(
                q_Xt1.get_parameters('regular')[1]
                - p_Xt1.get_parameters('regular')[1]), axis=-1))
            kl = T.mean(T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=-1), axis=0)
            Q = self.get_dynamics()[1]
            model_stdev = T.sqrt(T.matrix_diag_part(Q))
            info['rmse'] = rmse
            info['model-stdev'] = model_stdev
        self.cache[(q_X, q_A)] = kl, info
    return self.cache[(q_X, q_A)]
def log_z(self, parameter_type='regular', stop_gradient=False):
    if parameter_type == 'regular':
        sigma, mu = self.get_parameters('regular', stop_gradient=stop_gradient)
        d = T.to_float(self.shape()[-1])
        hsi, hlds = Stats.HSI(sigma), Stats.HLDS(sigma)
        mmT = Stats.XXT(mu)
        return (
            - T.sum(hsi * mmT, [-1, -2])
            - hlds
            + d / 2. * np.log(2 * np.pi)
        )
    else:
        natparam = self.get_parameters('natural', stop_gradient=stop_gradient)
        d = T.to_float(self.shape()[-1])
        J, m = natparam[Stats.XXT], natparam[Stats.X]
        return (
            - 0.25 * (m[..., None, :] @ T.matrix_inverse(J) @ m[..., None])[..., 0, 0]
            - 0.5 * T.logdet(-2 * J)
            + d / 2. * np.log(2 * np.pi)
        )
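# Hedged sanity check (not from the source): the natural-parameter branch above.
# With J = -1/2 Sigma^{-1} and m = Sigma^{-1} mu, the log-partition
#     log Z = -1/4 m^T J^{-1} m - 1/2 logdet(-2J) + d/2 log(2 pi)
# should make <eta, t(x)> - log Z equal the Gaussian log-density.
import numpy as np
from scipy.stats import multivariate_normal

np.random.seed(0)
d = 3
A = np.random.randn(d, d)
sigma = A @ A.T + d * np.eye(d)
mu = np.random.randn(d)

J = -0.5 * np.linalg.inv(sigma)
m = np.linalg.inv(sigma) @ mu
log_z = (-0.25 * m @ np.linalg.inv(J) @ m
         - 0.5 * np.linalg.slogdet(-2 * J)[1]
         + d / 2.0 * np.log(2 * np.pi))

x = np.random.randn(d)
inner = np.sum(J * np.outer(x, x)) + m @ x          # <eta, (xx^T, x)>
assert np.allclose(inner - log_z, multivariate_normal.logpdf(x, mean=mu, cov=sigma))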
def kl_gradients(self, q_X, q_A, _, num_data):
    if self.smooth:
        ds = self.ds
        ess = q_X.expected_sufficient_statistics()
        yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
        xxT = ess[..., :-1, :ds, :ds]
        yxT = ess[..., :-1, ds:2 * ds, :ds]
        aaT, a = stats.Gaussian.unpack(
            q_A.expected_sufficient_statistics())
        aaT, a = aaT[:, :-1], a[:, :-1]
        x = ess[..., :-1, -1, :ds]
        y = ess[..., :-1, -1, ds:2 * ds]
        xaT = T.outer(x, a)
        yaT = T.outer(y, a)
        xaxaT = T.concatenate([
            T.concatenate([xxT, xaT], -1),
            T.concatenate([T.matrix_transpose(xaT), aaT], -1),
        ], -2)
        batch_size = T.shape(ess)[0]
        num_batches = T.to_float(num_data) / T.to_float(batch_size)
        ess = [
            yyT,
            T.concatenate([yxT, yaT], -1),
            xaxaT,
            T.ones([batch_size, self.horizon - 1])
        ]
    else:
        q_Xt = q_X.__class__([
            q_X.get_parameters('regular')[0][:, :-1],
            q_X.get_parameters('regular')[1][:, :-1],
        ])
        q_At = q_A.__class__([
            q_A.get_parameters('regular')[0][:, :-1],
            q_A.get_parameters('regular')[1][:, :-1],
        ])
        q_Xt1 = q_X.__class__([
            q_X.get_parameters('regular')[0][:, 1:],
            q_X.get_parameters('regular')[1][:, 1:],
        ])
        (XtAt_XtAtT, XtAt), (Xt1_Xt1T, Xt1) = self.get_statistics(q_Xt, q_At, q_Xt1)
        batch_size = T.shape(XtAt)[0]
        num_batches = T.to_float(num_data) / T.to_float(batch_size)
        ess = [
            Xt1_Xt1T,
            T.einsum('nha,nhb->nhba', XtAt, Xt1),
            XtAt_XtAtT,
            T.ones([batch_size, self.horizon - 1])
        ]
    if self.time_varying:
        ess = [
            T.sum(ess[0], [0]),
            T.sum(ess[1], [0]),
            T.sum(ess[2], [0]),
            T.sum(ess[3], [0]),
        ]
    else:
        ess = [
            T.sum(ess[0], [0, 1]),
            T.sum(ess[1], [0, 1]),
            T.sum(ess[2], [0, 1]),
            T.sum(ess[3], [0, 1]),
        ]
    return [
        -(a + num_batches * b - c) / T.to_float(num_data)
        for a, b, c in zip(
            self.A_prior.get_parameters('natural'),
            ess,
            self.A_variational.get_parameters('natural'),
        )
    ]
q_w = make_variable(
    Gaussian([T.to_float(np.eye(D))[None], T.to_float(np.zeros(D))[None]]))
x, y = T.matrix(), T.vector()
lr = 1e-4
batch_size = T.shape(x)[0]
num_batches = T.to_float(N / batch_size)
with T.initialization('xavier'):
    # stats_net = Relu(D + 1, 20) >> Relu(20) >> GaussianLayer(D)
    stats_net = GaussianLayer(D + 1, D)
net_out = stats_net(T.concat([x, y[..., None]], -1))
stats = T.sum(net_out.get_parameters('natural'), 0)[None]
natural_gradient = (p_w.get_parameters('natural') + num_batches * stats
                    - q_w.get_parameters('natural')) / N
next_w = Gaussian(q_w.get_parameters('natural') + lr * natural_gradient,
                  parameter_type='natural')
l_w = kl_divergence(q_w, p_w)[0]
p_y = Bernoulli(T.sigmoid(T.einsum('jw,iw->ij', next_w.expected_value(), x)))
l_y = T.sum(p_y.log_likelihood(y[..., None]))
elbo = l_w + l_y
nat_op = T.assign(q_w.get_parameters('natural'), next_w.get_parameters('natural'))
grad_op = tf.train.RMSPropOptimizer(1e-4).minimize(-elbo)
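# Hedged sketch (not from the source): the scaled natural-gradient step used above,
#     prior + num_batches * batch_stats - current   (divided by N and a step size),
# illustrated on a Beta-Bernoulli pair in plain NumPy. The data and pseudo-counts
# below are illustrative assumptions. With full-data statistics and a unit step,
# the update lands exactly on the conjugate posterior.
import numpy as np

np.random.seed(0)
N = 1000
data = (np.random.rand(N) < 0.3).astype(np.float64)

prior = np.array([1.0, 1.0])              # Beta(1, 1) pseudo-counts
current = np.array([5.0, 5.0])            # current variational pseudo-counts
batch = data[:100]
num_batches = N / len(batch)
batch_stats = np.array([batch.sum(), len(batch) - batch.sum()])

natural_gradient = prior + num_batches * batch_stats - current
updated = current + 1.0 * natural_gradient    # one stochastic natural-gradient step

full_stats = np.array([data.sum(), N - data.sum()])
exact_posterior = prior + full_stats          # target the iterates converge to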
def initialize(self):
    self.graph = T.core.Graph()
    with self.graph.as_default():
        prior_params = self.prior_params.copy()
        prior_type = prior_params.pop('prior_type')
        self.prior = PRIOR_MAP[prior_type](self.ds, self.da, self.horizon,
                                           **prior_params)
        cost_params = self.cost_params.copy()
        cost_type = cost_params.pop('cost_type')
        self.cost = COST_MAP[cost_type](self.ds, self.da, **cost_params)
        self.O = T.placeholder(T.floatx(), [None, None, self.do])
        self.U = T.placeholder(T.floatx(), [None, None, self.du])
        self.C = T.placeholder(T.floatx(), [None, None])
        self.S = T.placeholder(T.floatx(), [None, None, self.ds])
        self.A = T.placeholder(T.floatx(), [None, None, self.da])
        self.t = T.placeholder(T.int32, [])
        self.state, self.action = (
            T.placeholder(T.floatx(), [None, self.ds]),
            T.placeholder(T.floatx(), [None, self.da]))
        if self.prior.has_dynamics():
            self.next_state = self.prior.next_state(self.state, self.action, self.t)
            self.prior_dynamics = self.prior.get_dynamics()
        self.num_data = T.scalar()
        self.beta = T.placeholder(T.floatx(), [])
        self.learning_rate = T.placeholder(T.floatx(), [])
        self.model_learning_rate = T.placeholder(T.floatx(), [])
        self.S_potentials = util.map_network(self.state_encoder)(self.O)
        self.A_potentials = util.map_network(self.action_encoder)(self.U)
        if self.prior.is_dynamics_prior():
            self.data_strength = T.placeholder(T.floatx(), [])
            self.max_iter = T.placeholder(T.int32, [])
            posterior_dynamics, (encodings, actions) = \
                self.prior.posterior_dynamics(self.S_potentials, self.A_potentials,
                                              data_strength=self.data_strength,
                                              max_iter=self.max_iter)
            self.posterior_dynamics_ = posterior_dynamics, (
                encodings.expected_value(), actions.expected_value())
        if self.prior.is_filtering_prior():
            self.prior_dynamics_stats = self.prior.sufficient_statistics()
            self.dynamics_stats = (
                T.placeholder(T.floatx(), [None, self.ds, self.ds]),
                T.placeholder(T.floatx(), [None, self.ds, self.ds + self.da]),
                T.placeholder(T.floatx(), [None, self.ds + self.da, self.ds + self.da]),
                T.placeholder(T.floatx(), [None]),
            )
            S_natparam = self.S_potentials.get_parameters('natural')
            num_steps = T.shape(S_natparam)[1]
            self.padded_S = stats.Gaussian(T.core.pad(
                self.S_potentials.get_parameters('natural'),
                [[0, 0], [0, self.horizon - num_steps], [0, 0], [0, 0]]
            ), 'natural')
            self.padded_A = stats.GaussianScaleDiag([
                T.core.pad(self.A_potentials.get_parameters('regular')[0],
                           [[0, 0], [0, self.horizon - num_steps], [0, 0]]),
                T.core.pad(self.A_potentials.get_parameters('regular')[1],
                           [[0, 0], [0, self.horizon - num_steps], [0, 0]])
            ], 'regular')
            self.q_S_padded, self.q_A_padded = self.prior.encode(
                self.padded_S, self.padded_A, dynamics_stats=self.dynamics_stats
            )
            self.q_S_filter = self.q_S_padded.filter(max_steps=num_steps)
            self.q_A_filter = self.q_A_padded.__class__(
                self.q_A_padded.get_parameters('natural')[:, :num_steps], 'natural')
            self.e_q_S_filter = self.q_S_filter.expected_value()
            self.e_q_A_filter = self.q_A_filter.expected_value()
        (self.q_S, self.q_A), self.prior_kl, self.kl_grads, self.info = \
            self.prior.posterior_kl_grads(
                self.S_potentials, self.A_potentials, self.num_data
            )
        self.q_S_sample = self.q_S.sample()[0]
        self.q_A_sample = self.q_A.sample()[0]
        self.q_O = util.map_network(self.state_decoder)(self.q_S_sample)
        self.q_U = util.map_network(self.action_decoder)(self.q_A_sample)
        self.q_O_sample = self.q_O.sample()[0]
        self.q_U_sample = self.q_U.sample()[0]
        self.q_O_ = util.map_network(self.state_decoder)(self.S)
        self.q_U_ = util.map_network(self.action_decoder)(self.A)
        self.q_O__sample = self.q_O_.sample()[0]
        self.q_U__sample = self.q_U_.sample()[0]
        self.cost_likelihood = self.cost.log_likelihood(self.q_S_sample, self.C)
        if self.cost.is_cost_function():
            self.evaluated_cost = self.cost.evaluate(self.S)
        self.log_likelihood = T.sum(self.q_O.log_likelihood(self.O), axis=1)
        self.elbo = T.mean(self.log_likelihood + self.cost_likelihood - self.prior_kl)
        train_elbo = T.mean(self.log_likelihood
                            + self.beta * (self.cost_likelihood - self.prior_kl))
        T.core.summary.scalar("encoder-stdev",
                              T.mean(self.S_potentials.get_parameters('regular')[0]))
        T.core.summary.scalar("log-likelihood", T.mean(self.log_likelihood))
        T.core.summary.scalar("cost-likelihood", T.mean(self.cost_likelihood))
        T.core.summary.scalar("prior-kl", T.mean(self.prior_kl))
        T.core.summary.scalar("beta", self.beta)
        T.core.summary.scalar("elbo", self.elbo)
        T.core.summary.scalar("beta-elbo", train_elbo)
        for k, v in self.info.items():
            T.core.summary.scalar(k, T.mean(v))
        self.summary = T.core.summary.merge_all()
        neural_params = (
            self.state_encoder.get_parameters()
            + self.state_decoder.get_parameters()
            + self.action_encoder.get_parameters()
            + self.action_decoder.get_parameters()
        )
        cost_params = self.cost.get_parameters()
        if len(neural_params) > 0:
            optimizer = T.core.train.AdamOptimizer(self.learning_rate)
            gradients, variables = zip(*optimizer.compute_gradients(
                -train_elbo, var_list=neural_params))
            gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
            self.neural_op = optimizer.apply_gradients(zip(gradients, variables))
        else:
            self.neural_op = T.core.no_op()
        if len(cost_params) > 0:
            self.cost_op = T.core.train.AdamOptimizer(self.learning_rate).minimize(
                -self.elbo, var_list=cost_params)
        else:
            self.cost_op = T.core.no_op()
        if len(self.kl_grads) > 0:
            if self.prior.is_dynamics_prior():
                # opt = lambda x: T.core.train.MomentumOptimizer(x, 0.5)
                opt = lambda x: T.core.train.GradientDescentOptimizer(x)
            else:
                opt = T.core.train.AdamOptimizer
            self.dynamics_op = opt(self.model_learning_rate).apply_gradients([
                (b, a) for a, b in self.kl_grads
            ])
        else:
            self.dynamics_op = T.core.no_op()
        self.train_op = T.core.group(self.neural_op, self.dynamics_op, self.cost_op)
    self.session = T.interactive_session(graph=self.graph,
                                         allow_soft_placement=True,
                                         log_device_placement=False)
def kl_divergence(self, q_X, q_A, num_data):
    mu_shape = T.shape(q_X.get_parameters('regular')[1])
    p_X = stats.GaussianScaleDiag([T.ones(mu_shape), T.zeros(mu_shape)])
    return T.mean(T.sum(stats.kl_divergence(q_X, p_X), -1), 0), {}
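# Hedged sanity check (not from the source): the KL above is against a standard-normal
# prior; for a diagonal Gaussian q = N(mu, diag(var)) it has the familiar closed form
#     KL(q || N(0, I)) = 1/2 * sum(var + mu^2 - 1 - log var),
# checked here against a Monte Carlo estimate. Values below are illustrative.
import numpy as np
from scipy.stats import norm

np.random.seed(0)
mu = np.array([0.3, -1.2, 0.0])
var = np.array([0.5, 2.0, 1.0])

kl_closed = 0.5 * np.sum(var + mu ** 2 - 1.0 - np.log(var))

samples = mu + np.sqrt(var) * np.random.randn(200000, 3)
log_q = norm.logpdf(samples, loc=mu, scale=np.sqrt(var)).sum(-1)
log_p = norm.logpdf(samples, loc=0.0, scale=1.0).sum(-1)
assert np.allclose(kl_closed, np.mean(log_q - log_p), atol=0.05)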
def _sample(self, num_samples):
    alpha = self.get_parameters('regular')
    d = self.shape()[-1]
    gammas = T.random_gamma([num_samples], alpha, beta=1)
    return gammas / T.sum(gammas, -1)[..., None]
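# Hedged sanity check (not from the source): sampling a Dirichlet by normalizing
# independent Gamma(alpha_i, 1) draws, as in _sample above, reproduces the
# Dirichlet mean alpha / sum(alpha) in plain NumPy.
import numpy as np

np.random.seed(0)
alpha = np.array([2.0, 3.0, 5.0])
num_samples = 200000

gammas = np.random.gamma(alpha, 1.0, size=(num_samples, alpha.shape[0]))
samples = gammas / gammas.sum(-1, keepdims=True)
assert np.allclose(samples.mean(0), alpha / alpha.sum(), atol=5e-3)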
pi_cmessage = q_pi.expected_sufficient_statistics()
z_pmessage = q_z.expected_sufficient_statistics()
x_tmessage = NIW.pack([
    T.outer(X, X),
    X,
    T.ones(N),
    T.ones(N),
])
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

new_pi = p_pi.get_parameters('natural') + T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
pi_update = T.assign(q_pi.get_parameters('natural'), new_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

new_theta = (T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
             + p_theta.get_parameters('natural')[None])
parent_theta = p_theta.get_parameters('natural')
theta_update = T.assign(q_theta.get_parameters('natural'), new_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = (T.einsum('iab,jab->ij', x_tmessage, theta_cmessage)
         + q_pi.expected_sufficient_statistics()[None])
new_z = new_z - T.logsumexp(new_z, -1)[..., None]
z_update = T.assign(q_z.get_parameters('natural'), new_z)
np.set_printoptions(suppress=True)
x_tmessage = net(X)
# x_tmessage = NIW.pack([
#     T.outer(X, X),
#     X,
#     T.ones([batch_size]),
#     T.ones([batch_size]),
# ])
# q_theta = make_variable(NIW(map(lambda x: np.array(x).astype(T.floatx()),
#     [np.eye(D),
#      np.random.multivariate_normal(mean=np.zeros([D]), cov=np.eye(D) * 20),
#      1.0, 1.0])))
num_batches = N / T.to_float(batch_size)
nat_scale = 1.0
theta_stats = T.sum(x_tmessage, 0)
parent_theta = p_theta.get_parameters('natural')
q_theta = NIW(parent_theta + theta_stats, parameter_type='natural')
sigma, mu = Gaussian(q_theta.expected_sufficient_statistics(),
                     parameter_type='natural').get_parameters('regular')
# theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))
x_param = q_theta.expected_sufficient_statistics()[None]
q_x = Gaussian(T.tile(x_param, [batch_size, 1, 1]), parameter_type='natural')
l_x = T.sum(q_x.log_likelihood(X))
elbo = l_theta + l_x
elbos = []
l_thetas = []
l_xs = []
def log_likelihood(self, states, costs):
    return T.sum(util.map_network(self.network)(states).log_likelihood(
        costs[..., None]), axis=-1)
def kl_divergence(self, q_X, q_A, num_data):
    if (q_X, q_A) not in self.cache:
        if self.smooth:
            state_prior = stats.GaussianScaleDiag(
                [T.ones(self.ds), T.zeros(self.ds)])
            self.p_X = stats.LDS(
                (self.sufficient_statistics(), state_prior, None,
                 q_A.expected_value(), self.horizon), 'internal')
            local_kl = stats.kl_divergence(q_X, self.p_X)
            if self.time_varying:
                global_kl = T.sum(
                    stats.kl_divergence(self.A_variational, self.A_prior))
            else:
                global_kl = stats.kl_divergence(self.A_variational, self.A_prior)
            prior_kl = T.mean(local_kl, axis=0) + global_kl / T.to_float(num_data)
            A, Q = self.get_dynamics()
            model_stdev = T.sqrt(T.matrix_diag_part(Q))
            self.cache[(q_X, q_A)] = prior_kl, {
                'local-kl': local_kl,
                'global-kl': global_kl,
                'model-stdev': model_stdev,
            }
        else:
            q_Xt = q_X.__class__([
                q_X.get_parameters('regular')[0][:, :-1],
                q_X.get_parameters('regular')[1][:, :-1],
            ])
            q_At = q_A.__class__([
                q_A.get_parameters('regular')[0][:, :-1],
                q_A.get_parameters('regular')[1][:, :-1],
            ])
            p_Xt1 = self.forward(q_Xt, q_At)
            q_Xt1 = q_X.__class__([
                q_X.get_parameters('regular')[0][:, 1:],
                q_X.get_parameters('regular')[1][:, 1:],
            ])
            num_data = T.to_float(num_data)
            rmse = T.sqrt(
                T.sum(T.square(
                    q_Xt1.get_parameters('regular')[1]
                    - p_Xt1.get_parameters('regular')[1]), axis=-1))
            A, Q = self.get_dynamics()
            model_stdev = T.sqrt(T.matrix_diag_part(Q))
            local_kl = T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=1)
            if self.time_varying:
                global_kl = T.sum(
                    stats.kl_divergence(self.A_variational, self.A_prior))
            else:
                global_kl = stats.kl_divergence(self.A_variational, self.A_prior)
            self.cache[(q_X, q_A)] = (
                T.mean(local_kl, axis=0) + global_kl / T.to_float(num_data), {
                    'rmse': rmse,
                    'model-stdev': model_stdev,
                    'local-kl': local_kl,
                    'global-kl': global_kl
                })
    return self.cache[(q_X, q_A)]
def posterior_dynamics(self, q_X, q_A, data_strength=1.0, max_iter=200, tol=1e-3):
    if self.smooth:
        if self.time_varying:
            prior_dyn = stats.MNIW(
                self.A_variational.get_parameters('natural'), 'natural')
        else:
            natparam = self.A_variational.get_parameters('natural')
            prior_dyn = stats.MNIW([
                T.tile(natparam[0][None], [self.horizon - 1, 1, 1]),
                T.tile(natparam[1][None], [self.horizon - 1, 1, 1]),
                T.tile(natparam[2][None], [self.horizon - 1, 1, 1]),
                T.tile(natparam[3][None], [self.horizon - 1]),
            ], 'natural')
        state_prior = stats.Gaussian([T.eye(self.ds), T.zeros(self.ds)])
        aaT, a = stats.Gaussian.unpack(
            q_A.expected_sufficient_statistics())
        aaT, a = aaT[:, :-1], a[:, :-1]
        ds, da = self.ds, self.da
        initial_dyn_natparam = prior_dyn.get_parameters('natural')
        initial_X_natparam = stats.LDS(
            (self.sufficient_statistics(), state_prior, q_X,
             q_A.expected_value(), self.horizon),
            'internal').get_parameters('natural')

        def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
            q_X_ = stats.LDS(q_X_natparam, 'natural')
            ess = q_X_.expected_sufficient_statistics()
            batch_size = T.shape(ess)[0]
            yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
            xxT = ess[..., :-1, :ds, :ds]
            yxT = ess[..., :-1, ds:2 * ds, :ds]
            x = ess[..., :-1, -1, :ds]
            y = ess[..., :-1, -1, ds:2 * ds]
            xaT = T.outer(x, a)
            yaT = T.outer(y, a)
            xaxaT = T.concatenate([
                T.concatenate([xxT, xaT], -1),
                T.concatenate([T.matrix_transpose(xaT), aaT], -1),
            ], -2)
            ess = [
                yyT,
                T.concatenate([yxT, yaT], -1),
                xaxaT,
                T.ones([batch_size, self.horizon - 1])
            ]
            q_dyn_natparam = [
                T.sum(a, [0]) * data_strength + b
                for a, b in zip(ess, initial_dyn_natparam)
            ]
            q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
            q_stats = q_dyn_.expected_sufficient_statistics()
            p_X = stats.LDS((q_stats, state_prior, None,
                             q_A.expected_value(), self.horizon))
            q_X_ = stats.LDS((q_stats, state_prior, q_X,
                              q_A.expected_value(), self.horizon))
            elbo = (T.sum(stats.kl_divergence(q_X_, p_X))
                    + T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
            return i + 1, q_dyn_.get_parameters('natural'), \
                q_X_.get_parameters('natural'), curr_elbo, elbo

        def cond(i, _, __, prev_elbo, curr_elbo):
            with T.core.control_dependencies([T.core.print(curr_elbo)]):
                prev_elbo = T.core.identity(prev_elbo)
            return T.logical_and(
                T.abs(curr_elbo - prev_elbo) > tol, i < max_iter)

        result = T.while_loop(
            cond, em, [
                0, initial_dyn_natparam, initial_X_natparam,
                T.constant(-np.inf), T.constant(0.)
            ],
            back_prop=False)
        pd = stats.MNIW(result[1], 'natural')
        sigma, mu = pd.expected_value()
        q_X = stats.LDS(result[2], 'natural')
        return ((mu, sigma), pd.expected_sufficient_statistics()), (q_X, q_A)
    else:
        q_Xt = q_X.__class__([
            q_X.get_parameters('regular')[0][:, :-1],
            q_X.get_parameters('regular')[1][:, :-1],
        ])
        q_At = q_A.__class__([
            q_A.get_parameters('regular')[0][:, :-1],
            q_A.get_parameters('regular')[1][:, :-1],
        ])
        q_Xt1 = q_X.__class__([
            q_X.get_parameters('regular')[0][:, 1:],
            q_X.get_parameters('regular')[1][:, 1:],
        ])
        (XtAt_XtAtT, XtAt), (Xt1_Xt1T, Xt1) = self.get_statistics(q_Xt, q_At, q_Xt1)
        batch_size = T.shape(XtAt)[0]
        ess = [
            Xt1_Xt1T,
            T.einsum('nha,nhb->nhba', XtAt, Xt1),
            XtAt_XtAtT,
            T.ones([batch_size, self.horizon - 1])
        ]
        if self.time_varying:
            posterior = stats.MNIW([
                T.sum(a, [0]) * data_strength + b
                for a, b in zip(ess, self.A_variational.get_parameters('natural'))
            ], 'natural')
        else:
            posterior = stats.MNIW([
                T.sum(a, [0]) * data_strength + b[None]
                for a, b in zip(ess, self.A_variational.get_parameters('natural'))
            ], 'natural')
        Q, A = posterior.expected_value()
        return (A, Q), q_X
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()
num_batches = N / T.to_float(batch_size)
nat_scale = 10.0

parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
q_z = Categorical(new_z - T.logsumexp(new_z, -1)[..., None],
                  parameter_type='natural')
p_z = Categorical(parent_z - T.logsumexp(parent_z, -1),
                  parameter_type='natural')
l_z = T.sum(kl_divergence(q_z, p_z))
z_pmessage = q_z.expected_sufficient_statistics()

pi_stats = T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
current_pi = q_pi.get_parameters('natural')
pi_gradient = nat_scale / N * (parent_pi + num_batches * pi_stats - current_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

theta_stats = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
parent_theta = p_theta.get_parameters('natural')[None]
current_theta = q_theta.get_parameters('natural')
theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats
                                  - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))
def get_child_message(x, y, data={}):
    y_ = data[y]
    stats = x.statistics()
    log_likelihood = y.log_likelihood(y_)
    param = T.grad(T.sum(log_likelihood), [x.get_statistic(s) for s in stats])
    return {s: param[i] for i, s in enumerate(stats)}
    0.0,
    Gaussian.regular_to_natural(
        [T.eye(D, batch_shape=[batch_size]) * 1e-4, Yt]),
    Gaussian.regular_to_natural(
        [T.eye(D, batch_shape=[batch_size]) * 1e-4, Yt1]),
])
qxt = Gaussian(qxt_param, 'natural')
qxt1 = Gaussian(qxt1_param, 'natural')
pyt = Gaussian(
    [T.tile(T.eye(D)[None] * noise, [batch_size, 1, 1]), qxt.expected_value()])
pyt1 = Gaussian([
    T.tile(T.eye(D)[None] * noise, [batch_size, 1, 1]), qxt1.expected_value()
])
log_likelihood = T.sum(pyt.log_likelihood(Yt) + pyt1.log_likelihood(Yt1))
elbo = (log_likelihood - kl) / T.to_float(batch_size)
grads = T.gradients(-elbo,
                    transition_net.get_parameters() + rec_net.get_parameters())
grads, _ = T.core.clip_by_global_norm(grads, 1)  # gradient clipping
grads_and_vars = list(
    zip(grads, transition_net.get_parameters() + rec_net.get_parameters()))
train_op = T.core.train.AdamOptimizer(1e-4).apply_gradients(grads_and_vars)
# train_op = T.core.train.AdamOptimizer(1e-5).minimize(
#     -elbo, var_list=transition_net.get_parameters() + rec_net.get_parameters())
sess = T.interactive_session()
plt.figure()
plt.ion()