def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
    # E-step: expected sufficient statistics of the current state posterior.
    q_X_ = stats.LDS(q_X_natparam, 'natural')
    ess = q_X_.expected_sufficient_statistics()
    batch_size = T.shape(ess)[0]
    yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
    xxT = ess[..., :-1, :ds, :ds]
    yxT = ess[..., :-1, ds:2 * ds, :ds]
    x = ess[..., :-1, -1, :ds]
    y = ess[..., :-1, -1, ds:2 * ds]
    xaT = T.outer(x, a)
    yaT = T.outer(y, a)
    xaxaT = T.concatenate([
        T.concatenate([xxT, xaT], -1),
        T.concatenate([T.matrix_transpose(xaT), aaT], -1),
    ], -2)
    ess = [
        yyT,
        T.concatenate([yxT, yaT], -1),
        xaxaT,
        T.ones([batch_size, self.horizon - 1]),
    ]
    # M-step: conjugate update of the MNIW posterior over the dynamics.
    q_dyn_natparam = [
        T.sum(a, [0]) * data_strength + b
        for a, b in zip(ess, initial_dyn_natparam)
    ]
    q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
    q_stats = q_dyn_.expected_sufficient_statistics()
    p_X = stats.LDS((q_stats, state_prior, None, q_A.expected_value(), self.horizon))
    q_X_ = stats.LDS((q_stats, state_prior, q_X, q_A.expected_value(), self.horizon))
    # ELBO: local KL over the states plus global KL over the dynamics parameters.
    elbo = (T.sum(stats.kl_divergence(q_X_, p_X))
            + T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
    return (i + 1, q_dyn_.get_parameters('natural'),
            q_X_.get_parameters('natural'), curr_elbo, elbo)
def kl_divergence(self, q_X, q_A, _):
    # q_Xt - [N, H, ds]
    # q_At - [N, H, da]
    if (q_X, q_A) not in self.cache:
        # Drop the last time step to form q(x_t) and q(a_t).
        q_Xt = q_X.__class__([
            q_X.get_parameters('regular')[0][:, :-1],
            q_X.get_parameters('regular')[1][:, :-1],
        ])
        q_At = q_A.__class__([
            q_A.get_parameters('regular')[0][:, :-1],
            q_A.get_parameters('regular')[1][:, :-1],
        ])
        # Predict p(x_{t+1} | x_t, a_t) and compare against the encoding q(x_{t+1}).
        p_Xt1 = self.forward(q_Xt, q_At)
        q_Xt1 = q_X.__class__([
            q_X.get_parameters('regular')[0][:, 1:],
            q_X.get_parameters('regular')[1][:, 1:],
        ])
        rmse = T.sqrt(T.sum(T.square(
            q_Xt1.get_parameters('regular')[1]
            - p_Xt1.get_parameters('regular')[1]), axis=-1))
        model_stdev = T.sqrt(p_Xt1.get_parameters('regular')[0])
        encoding_stdev = T.sqrt(q_Xt1.get_parameters('regular')[0])
        self.cache[(q_X, q_A)] = T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=-1), {
            'rmse': rmse,
            'encoding-stdev': encoding_stdev,
            'model-stdev': model_stdev,
        }
    return self.cache[(q_X, q_A)]
def kl_divergence(self, q_X, q_A, _):
    # q_Xt - [N, H, ds]
    # q_At - [N, H, da]
    if (q_X, q_A) not in self.cache:
        info = {}
        if self.smooth:
            # Smoothing: KL between q(x_{1:H}) and the LDS prior induced by the dynamics.
            state_prior = stats.GaussianScaleDiag([
                T.ones(self.ds),
                T.zeros(self.ds),
            ])
            p_X = stats.LDS(
                (self.sufficient_statistics(), state_prior, None,
                 q_A.expected_value(), self.horizon), 'internal')
            kl = T.mean(stats.kl_divergence(q_X, p_X), axis=0)
            Q = self.get_dynamics()[1]
            info['model-stdev'] = T.sqrt(T.matrix_diag_part(Q))
        else:
            # Filtering: one-step prediction KL, summed over time.
            q_Xt = q_X.__class__([
                q_X.get_parameters('regular')[0][:, :-1],
                q_X.get_parameters('regular')[1][:, :-1],
            ])
            q_At = q_A.__class__([
                q_A.get_parameters('regular')[0][:, :-1],
                q_A.get_parameters('regular')[1][:, :-1],
            ])
            p_Xt1 = self.forward(q_Xt, q_At)
            q_Xt1 = q_X.__class__([
                q_X.get_parameters('regular')[0][:, 1:],
                q_X.get_parameters('regular')[1][:, 1:],
            ])
            rmse = T.sqrt(T.sum(T.square(
                q_Xt1.get_parameters('regular')[1]
                - p_Xt1.get_parameters('regular')[1]), axis=-1))
            kl = T.mean(T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=-1), axis=0)
            Q = self.get_dynamics()[1]
            model_stdev = T.sqrt(T.matrix_diag_part(Q))
            info['rmse'] = rmse
            info['model-stdev'] = model_stdev
        self.cache[(q_X, q_A)] = kl, info
    return self.cache[(q_X, q_A)]
def kl_divergence(self, q_X, q_A, num_data):
    if (q_X, q_A) not in self.cache:
        if self.smooth:
            state_prior = stats.GaussianScaleDiag(
                [T.ones(self.ds), T.zeros(self.ds)])
            self.p_X = stats.LDS(
                (self.sufficient_statistics(), state_prior, None,
                 q_A.expected_value(), self.horizon), 'internal')
            # Local KL over latent states; global KL over the dynamics parameters,
            # amortized over the dataset size.
            local_kl = stats.kl_divergence(q_X, self.p_X)
            if self.time_varying:
                global_kl = T.sum(
                    stats.kl_divergence(self.A_variational, self.A_prior))
            else:
                global_kl = stats.kl_divergence(self.A_variational, self.A_prior)
            prior_kl = T.mean(local_kl, axis=0) + global_kl / T.to_float(num_data)
            A, Q = self.get_dynamics()
            model_stdev = T.sqrt(T.matrix_diag_part(Q))
            self.cache[(q_X, q_A)] = prior_kl, {
                'local-kl': local_kl,
                'global-kl': global_kl,
                'model-stdev': model_stdev,
            }
        else:
            q_Xt = q_X.__class__([
                q_X.get_parameters('regular')[0][:, :-1],
                q_X.get_parameters('regular')[1][:, :-1],
            ])
            q_At = q_A.__class__([
                q_A.get_parameters('regular')[0][:, :-1],
                q_A.get_parameters('regular')[1][:, :-1],
            ])
            p_Xt1 = self.forward(q_Xt, q_At)
            q_Xt1 = q_X.__class__([
                q_X.get_parameters('regular')[0][:, 1:],
                q_X.get_parameters('regular')[1][:, 1:],
            ])
            num_data = T.to_float(num_data)
            rmse = T.sqrt(T.sum(T.square(
                q_Xt1.get_parameters('regular')[1]
                - p_Xt1.get_parameters('regular')[1]), axis=-1))
            A, Q = self.get_dynamics()
            model_stdev = T.sqrt(T.matrix_diag_part(Q))
            local_kl = T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=1)
            if self.time_varying:
                global_kl = T.sum(
                    stats.kl_divergence(self.A_variational, self.A_prior))
            else:
                global_kl = stats.kl_divergence(self.A_variational, self.A_prior)
            self.cache[(q_X, q_A)] = (
                T.mean(local_kl, axis=0) + global_kl / num_data, {
                    'rmse': rmse,
                    'model-stdev': model_stdev,
                    'local-kl': local_kl,
                    'global-kl': global_kl,
                })
    return self.cache[(q_X, q_A)]
def kl_divergence(self, q_X, q_A, num_data):
    # KL against a standard-normal prior over the latent states.
    mu_shape = T.shape(q_X.get_parameters('regular')[1])
    p_X = stats.GaussianScaleDiag([T.ones(mu_shape), T.zeros(mu_shape)])
    return T.mean(T.sum(stats.kl_divergence(q_X, p_X), -1), 0), {}
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

num_batches = N / T.to_float(batch_size)
nat_scale = 10.0

# Local update for the cluster assignments z, normalized in log space.
parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
q_z = Categorical(new_z - T.logsumexp(new_z, -1)[..., None],
                  parameter_type='natural')
p_z = Categorical(parent_z - T.logsumexp(parent_z, -1)[..., None],
                  parameter_type='natural')
l_z = T.sum(kl_divergence(q_z, p_z))
z_pmessage = q_z.expected_sufficient_statistics()

# Stochastic natural gradients for the global variables pi and theta,
# rescaled for minibatch updates.
pi_stats = T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
current_pi = q_pi.get_parameters('natural')
pi_gradient = nat_scale / N * (parent_pi + num_batches * pi_stats - current_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

theta_stats = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
parent_theta = p_theta.get_parameters('natural')[None]
current_theta = q_theta.get_parameters('natural')
theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))
x_tmessage = NIW.pack([
    T.outer(X, X),
    X,
    T.ones(N),
    T.ones(N),
])
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

# Coordinate-ascent update for pi: prior natural parameters plus summed assignment messages.
new_pi = p_pi.get_parameters('natural') + T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
pi_update = T.assign(q_pi.get_parameters('natural'), new_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

# Coordinate-ascent update for theta: prior plus responsibility-weighted sufficient statistics.
new_theta = (T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
             + p_theta.get_parameters('natural')[None])
parent_theta = p_theta.get_parameters('natural')
theta_update = T.assign(q_theta.get_parameters('natural'), new_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

# Local update for the assignments z, normalized in log space.
parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
new_z = new_z - T.logsumexp(new_z, -1)[..., None]
z_update = T.assign(q_z.get_parameters('natural'), new_z)
l_z = T.sum(kl_divergence(q_z, Categorical(parent_z, parameter_type='natural')))
# T.ones([batch_size]),
# T.ones([batch_size]),
# ])
# q_theta = make_variable(NIW(map(lambda x: np.array(x).astype(T.floatx()),
#                                 [np.eye(D),
#                                  np.random.multivariate_normal(mean=np.zeros([D]), cov=np.eye(D) * 20),
#                                  1.0, 1.0])))
num_batches = N / T.to_float(batch_size)
nat_scale = 1.0

# Conjugate update for the NIW posterior over (mu, Sigma).
theta_stats = T.sum(x_tmessage, 0)
parent_theta = p_theta.get_parameters('natural')
q_theta = NIW(parent_theta + theta_stats, parameter_type='natural')
sigma, mu = Gaussian(q_theta.expected_sufficient_statistics(),
                     parameter_type='natural').get_parameters('regular')
# theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

# Expected data likelihood under the posterior-expected Gaussian.
x_param = q_theta.expected_sufficient_statistics()[None]
q_x = Gaussian(T.tile(x_param, [batch_size, 1, 1]), parameter_type='natural')
l_x = T.sum(q_x.log_likelihood(X))
elbo = l_theta + l_x

elbos = []
l_thetas = []
l_xs = []
# natgrads = [(theta_gradient, q_theta.get_parameters('natural'))]
# nat_op = tf.group(*[T.assign(b, a + b) for a, b in natgrads])
# nat_opt = tf.train.GradientDescentOptimizer(1e-2)
# nat_op = nat_opt.apply_gradients([(-a, b) for a, b in natgrads])
lr = 1e-4
batch_size = T.shape(x)[0]
num_batches = T.to_float(N / batch_size)

# Recognition network producing per-example natural-parameter statistics for w.
with T.initialization('xavier'):
    # stats_net = Relu(D + 1, 20) >> Relu(20) >> GaussianLayer(D)
    stats_net = GaussianLayer(D + 1, D)

net_out = stats_net(T.concat([x, y[..., None]], -1))
stats = T.sum(net_out.get_parameters('natural'), 0)[None]

# Stochastic natural-gradient step on q(w), rescaled by the number of minibatches.
natural_gradient = (p_w.get_parameters('natural') + num_batches * stats
                    - q_w.get_parameters('natural')) / N
next_w = Gaussian(q_w.get_parameters('natural') + lr * natural_gradient,
                  parameter_type='natural')
l_w = kl_divergence(q_w, p_w)[0]

p_y = Bernoulli(T.sigmoid(T.einsum('jw,iw->ij', next_w.expected_value(), x)))
l_y = T.sum(p_y.log_likelihood(y[..., None]))
elbo = l_w + l_y

# Natural-gradient update for q(w) combined with a gradient step on the network parameters.
nat_op = T.assign(q_w.get_parameters('natural'), next_w.get_parameters('natural'))
grad_op = tf.train.RMSPropOptimizer(1e-4).minimize(-elbo)
train_op = tf.group(nat_op, grad_op)

sess = T.interactive_session()

predictions = T.cast(
    T.sigmoid(T.einsum('jw,iw->i', q_w.expected_value(), T.to_float(X))) + 0.5,
    np.int32)
accuracy = T.mean(