def update_mixture(self, mixture, increase_tolerance=1e-6):
    """Iteratively update the variational factors of *mixture* until convergence.

    Each iteration performs, in order: ``mixture.q_z.update``,
    ``mixture.reorder_components``, ``mixture.q_pi.update`` (only when
    *mixture* has a ``q_pi`` attribute) and ``mixture.q_eta.update``.
    After every step the piecewise variational bound is re-evaluated and
    passed, together with the previous value, to ``check_LL_increased``;
    the value it returns is asserted finite and ``mixture._check_shapes()``
    is called.

    The loop runs for at most ``self.options.max_iter`` iterations and stops
    early once at least ``self.options.min_iter`` iterations have completed
    and the convergence test on the summed bound reports convergence.

    Args:
        mixture: the model being fitted; its factors are updated in place.
        increase_tolerance: tolerance forwarded to ``check_LL_increased``.

    Returns:
        None.
    """
    # NOTE(review): eps is presumably a relative tolerance here, given
    # use_absolute_difference=False — confirm against LlConvergenceTest.
    convergence_test = LlConvergenceTest(
        eps=1e-4, should_increase=True, use_absolute_difference=False)

    def _checked_bound(previous_LL, tag, raise_error=True):
        # Re-evaluate the bound after an update step, compare it with the
        # previous value via check_LL_increased, and sanity-check the
        # result and the model's shapes. Returns the new bound.
        new_LL = check_LL_increased(
            previous_LL, mixture.variational_bound_piecewise(),
            tag=tag, tolerance=increase_tolerance, raise_error=raise_error)
        assert np.isfinite(new_LL).all(), str(new_LL)
        mixture._check_shapes()
        return new_LL

    last_LL = None
    logging.info('Variational bound at start = %f', mixture.variational_bound())
    for _i in xrange(self.options.max_iter):
        # Must be an identity test: if last_LL is a numpy array (as the
        # np.isfinite(...).all() checks suggest), ``None == last_LL`` is an
        # elementwise comparison and using it in ``if`` raises ValueError.
        if last_LL is None:
            last_LL = mixture.variational_bound_piecewise()
            # Seed the convergence test with the starting bound.
            convergence_test(last_LL.sum())

        mixture.q_z.update(mixture)
        last_LL = _checked_bound(last_LL, "Update z")

        mixture.reorder_components()
        last_LL = _checked_bound(last_LL, "Reorder components")

        if hasattr(mixture, 'q_pi'):
            mixture.q_pi.update(mixture)
            # The expected mixing weights should sum to (approximately) one.
            assert .99 < mixture.q_pi.E().sum() < 1.01
            last_LL = _checked_bound(last_LL, "Update pi")

        mixture.q_eta.update(mixture)
        # NOTE(review): raise_error=False means a bound decrease after the
        # eta step is not treated as fatal; presumably because this update
        # is not guaranteed to increase the bound — confirm.
        last_LL = _checked_bound(last_LL, "Update eta", raise_error=False)

        logging.info('Iteration %d: variational bound = %f', _i + 1,
                     mixture.variational_bound())
        if _i + 1 >= self.options.min_iter and convergence_test(last_LL.sum()):
            logging.info('Variational bound has converged : stopping.')
            break