# Mean-field coordinate-ascent updates for what appears to be a mixture model
# with per-point assignments z, mixing weights pi and per-component
# parameters theta (theta's statistics use the NIW packed layout).
# For each factor this section builds three things:
#   new_*     -- the updated natural parameters,
#   *_update  -- an assign op writing them into the variational factor,
#   l_*       -- a summed KL divergence between a variational factor and its
#                prior.
# NOTE(review): these look like graph ops that only take effect when run in a
# session (as the sibling code does with sess.run) -- confirm for the T backend.

# Expected sufficient statistics of q(z); used below as per-point,
# per-component weights (index 'ia' in the theta einsum).
z_pmessage = q_z.expected_sufficient_statistics()
# Per-point statistics packed in NIW layout: [x x^T, x, 1, 1].
x_tmessage = NIW.pack([
    T.outer(X, X),
    X,
    T.ones(N),
    T.ones(N),
])
# The same statistics in Gaussian layout ([x x^T, x]); not used within this
# section -- presumably consumed elsewhere in the file.
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
# Expected sufficient statistics of q(theta). Taken before any update here;
# contracted against x_tmessage in the z update below.
theta_cmessage = q_theta.expected_sufficient_statistics()

# pi: prior natural parameters plus assignment statistics summed over axis 0.
new_pi = p_pi.get_parameters('natural') + T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')  # not used within this section
pi_update = T.assign(q_pi.get_parameters('natural'), new_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

# theta: for each component a, the z-weighted sum over points i of the
# per-point statistics, plus the prior's natural parameters broadcast over a
# new leading (component) axis via [None].
new_theta = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage) + p_theta.get_parameters('natural')[None]
parent_theta = p_theta.get_parameters('natural')  # not used within this section
theta_update = T.assign(q_theta.get_parameters('natural'), new_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

# z: q(pi)'s expected statistics act as the natural parameters of z's
# Categorical prior (see l_z below).
parent_z = q_pi.expected_sufficient_statistics()[None]
# Score for point i and component j: <t(x_i), E[t(theta_j)]> plus q(pi)'s
# expected statistics (the same quantity as parent_z).
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + q_pi.expected_sufficient_statistics()[None]
# Normalise in log space: after this, logsumexp over the last axis is 0.
new_z = new_z - T.logsumexp(new_z, -1)[..., None]
z_update = T.assign(q_z.get_parameters('natural'), new_z)
l_z = T.sum(kl_divergence(q_z, Categorical(parent_z, parameter_type='natural')))
# Stochastic variational inference for Bayesian logistic regression with
# Gaussian weights w:
#   * a recognition network (stats_net) maps each (x, y) pair to Gaussian
#     natural-parameter statistics;
#   * q(w) takes a natural-gradient step toward the prior plus those
#     statistics, scaled up from the minibatch by num_batches;
#   * the resulting ELBO is maximised with RMSProp alongside the
#     natural-gradient assignment.

stats_net = GaussianLayer(D + 1, D)
# The network sees each input row concatenated with its label (hence D + 1
# inputs).
net_out = stats_net(T.concat([x, y[..., None]], -1))
# Sum per-example natural parameters over the minibatch (axis 0). [None]
# re-adds a leading axis, presumably to match the shape of q_w's parameters
# -- confirm.
stats = T.sum(net_out.get_parameters('natural'), 0)[None]
# Natural-gradient direction: prior + num_batches * minibatch statistics
# - current q(w), divided by N.
natural_gradient = (p_w.get_parameters('natural') + num_batches * stats
                    - q_w.get_parameters('natural')) / N
next_w = Gaussian(q_w.get_parameters('natural') + lr * natural_gradient,
                  parameter_type='natural')

# ELBO = log-likelihood term - KL(q(w) || p(w)).
# l_w is a KL divergence, i.e. a penalty, so it must be subtracted. Adding it
# made the minimised objective (-elbo) reward moving q(w) away from the prior
# and misreported the bound.
# l_y is a plug-in likelihood at next_w's mean. It covers only the current
# minibatch and is not rescaled by num_batches.
l_w = kl_divergence(q_w, p_w)[0]
p_y = Bernoulli(T.sigmoid(T.einsum('jw,iw->ij', next_w.expected_value(), x)))
l_y = T.sum(p_y.log_likelihood(y[..., None]))
elbo = l_y - l_w

# One training step runs both the natural-gradient assignment to q(w) and an
# RMSProp step on -elbo.
nat_op = T.assign(q_w.get_parameters('natural'), next_w.get_parameters('natural'))
grad_op = tf.train.RMSPropOptimizer(1e-4).minimize(-elbo)
train_op = tf.group(nat_op, grad_op)

sess = T.interactive_session()

# Hard 0/1 predictions on the full dataset from q(w)'s mean.
# sigmoid(.) + 0.5 lies in (0.5, 1.5), so an integer cast that truncates
# gives 1 exactly when sigmoid(.) >= 0.5.
predictions = T.cast(
    T.sigmoid(T.einsum('jw,iw->i', q_w.expected_value(), T.to_float(X))) + 0.5,
    np.int32)
accuracy = T.mean(
    T.to_float(T.equal(predictions, T.constant(Y.astype(np.int32)))))


# NOTE: this name shadows the builtin ``iter`` at module scope; it is kept
# as-is because callers use it.
def iter(num_iter=1, b=100):
    """Run ``num_iter`` training steps, each on a fresh random minibatch.

    Args:
        num_iter: number of ``train_op`` executions.
        b: minibatch size. Rows are drawn without replacement from the first
            axis of ``X``/``Y`` (at most ``N`` rows).
    """
    for _ in range(num_iter):
        # A random permutation truncated to b gives distinct row indices.
        idx = np.random.permutation(N)[:b]
        sess.run(train_op, {x: X[idx], y: Y[idx]})