def log_z(self, parameter_type='regular', stop_gradient=False):
    # Log partition function: log Z = logsumexp(eta) over the last axis.
    if parameter_type == 'regular':
        pi = self.get_parameters('regular', stop_gradient=stop_gradient)
        eta = Stats.LogX(pi)
    else:
        eta = self.get_parameters('natural', stop_gradient=stop_gradient)[Stats.X]
    return T.logsumexp(eta, -1)
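# As a sanity check (a minimal sketch in plain NumPy, not part of the library,
# assuming the natural parameters eta are unnormalized log probabilities as in
# the Categorical case), log_z matches direct normalization:
import numpy as np

eta = np.log(np.array([2.0, 1.0, 1.0]))   # unnormalized log probabilities
log_Z = np.log(np.sum(np.exp(eta)))       # logsumexp(eta) computed directly
probs = np.exp(eta - log_Z)               # exponential-family normalization
assert np.isclose(probs.sum(), 1.0)       # probabilities now sum to one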
# Pack the sufficient statistics of each data point: [xx^T, x, 1, 1].
x_tmessage = Gaussian.pack([
    T.outer(X, X),
    X,
    T.ones([batch_size]),
    T.ones([batch_size]),
])
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

num_batches = N / T.to_float(batch_size)
nat_scale = 10.0

# Local step: update q(z) from the data and the current q(theta) and q(pi).
parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
q_z = Categorical(new_z - T.logsumexp(new_z, -1)[..., None], parameter_type='natural')
p_z = Categorical(parent_z - T.logsumexp(parent_z, -1), parameter_type='natural')
l_z = T.sum(kl_divergence(q_z, p_z))
z_pmessage = q_z.expected_sufficient_statistics()

# Natural gradient for q(pi): prior plus rescaled batch statistics minus the
# current natural parameters.
pi_stats = T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
current_pi = q_pi.get_parameters('natural')
pi_gradient = nat_scale / N * (parent_pi + num_batches * pi_stats - current_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

# Statistics for the corresponding q(theta) natural gradient.
theta_stats = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
parent_theta = p_theta.get_parameters('natural')[None]
current_theta = q_theta.get_parameters('natural')
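# The theta update presumably mirrors the pi natural gradient above; this is a
# hypothetical continuation (theta_gradient, learning_rate, and the assign-based
# ascent step are assumptions, not the original code), reusing T.assign as in
# the coordinate-ascent updates below.
theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

learning_rate = 0.1
pi_update = T.assign(q_pi.get_parameters('natural'),
                     current_pi + learning_rate * pi_gradient)
theta_update = T.assign(q_theta.get_parameters('natural'),
                        current_theta + learning_rate * theta_gradient)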
# Coordinate-ascent updates: each conjugate update sets the variational natural
# parameters to the prior's plus the expected sufficient statistics received
# from the other factors.
parent_pi = p_pi.get_parameters('natural')
new_pi = parent_pi + T.sum(z_pmessage, 0)
pi_update = T.assign(q_pi.get_parameters('natural'), new_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

parent_theta = p_theta.get_parameters('natural')
new_theta = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage) + parent_theta[None]
theta_update = T.assign(q_theta.get_parameters('natural'), new_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
new_z = new_z - T.logsumexp(new_z, -1)[..., None]
z_update = T.assign(q_z.get_parameters('natural'), new_z)
l_z = T.sum(kl_divergence(q_z, Categorical(parent_z, parameter_type='natural')))

# Reconstruction term: the likelihood of the data under the Gaussian implied
# by q(z) and q(theta).
x_param = T.einsum('ia,abc->ibc', q_z.expected_sufficient_statistics(),
                   q_theta.expected_sufficient_statistics())
q_x = Gaussian(x_param, parameter_type='natural')
l_x = T.sum(q_x.log_likelihood(X))

elbo = l_theta + l_pi + l_z + l_x

elbos = []
sess = T.interactive_session()
draw()
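# A hypothetical training loop (num_epochs is an assumption; draw() is the
# plotting helper called above): run the conjugate updates in z -> pi -> theta
# order and record the ELBO after each sweep.
num_epochs = 100
for _ in range(num_epochs):
    sess.run(z_update)
    sess.run(pi_update)
    sess.run(theta_update)
    elbos.append(sess.run(elbo))
draw()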