Ejemplo n.º 1
0
 def update_mixture(self, mixture, increase_tolerance=1e-6):
     """Run the coordinate-wise variational updates on ``mixture`` until convergence.

     Each iteration performs these steps, in order:

     1. Update ``mixture.q_z``.
     2. Reorder components via ``mixture.reorder_components()``.
     3. Update ``mixture.q_pi``, only if the mixture has that attribute.
     4. Update ``mixture.q_eta``.

     After every step:

     - the piecewise variational bound is recomputed;
     - it is passed to ``check_LL_increased`` together with the previous value;
     - it is asserted to be finite;
     - ``mixture._check_shapes()`` is called.

     The loop runs at most ``self.options.max_iter`` iterations. It stops early
     once at least ``self.options.min_iter`` iterations have completed and the
     convergence test on the summed piecewise bound returns a true value.

     :param mixture: object exposing ``q_z``, ``q_eta`` (and optionally ``q_pi``),
         ``reorder_components()``, ``variational_bound()``,
         ``variational_bound_piecewise()`` and ``_check_shapes()``. It is updated
         in place.
     :param increase_tolerance: forwarded as ``tolerance`` to
         ``check_LL_increased``. Presumably this is how much the bound may drop
         before a step is flagged; confirm against that helper.
     :returns: None.
     """
     # With use_absolute_difference=False, eps is presumably a relative-change
     # threshold; confirm in LlConvergenceTest.
     convergence_test = LlConvergenceTest(eps=1e-4, should_increase=True, use_absolute_difference=False)

     def _checked_bound(previous_LL, tag, raise_error=True):
         """Recompute the piecewise bound after an update step and validate it.

         Returns whatever ``check_LL_increased`` returns for the new bound, after
         asserting it is finite and re-checking the mixture's shapes.
         """
         current_LL = check_LL_increased(previous_LL, mixture.variational_bound_piecewise(), tag=tag, tolerance=increase_tolerance, raise_error=raise_error)
         assert np.isfinite(current_LL).all(), str(current_LL)
         mixture._check_shapes()
         return current_LL

     last_LL = None
     logging.info('Variational bound at start = %f', mixture.variational_bound())
     for iteration in range(self.options.max_iter):
         # Must be `is None`, not `None ==`. Once last_LL holds an array,
         # `None == last_LL` compares element by element and using the result
         # in `if` raises ValueError (ambiguous truth value).
         if last_LL is None:
             # First pass only: compute the starting bound and feed it to the
             # convergence test.
             last_LL = mixture.variational_bound_piecewise()
             convergence_test(last_LL.sum())
         mixture.q_z.update(mixture)
         last_LL = _checked_bound(last_LL, "Update z")
         mixture.reorder_components()
         last_LL = _checked_bound(last_LL, "Reorder components")
         if hasattr(mixture, 'q_pi'):
             mixture.q_pi.update(mixture)
             # Sanity check: q_pi.E() should sum to ~1. These are presumably the
             # expected mixing proportions.
             assert .99 < mixture.q_pi.E().sum() < 1.01
             last_LL = _checked_bound(last_LL, "Update pi")
         mixture.q_eta.update(mixture)
         # Unlike the other steps, the eta step calls check_LL_increased with
         # raise_error=False. Presumably a decrease here is tolerated; confirm.
         last_LL = _checked_bound(last_LL, "Update eta", raise_error=False)
         logging.info('Iteration %d: variational bound = %f', iteration + 1, mixture.variational_bound())
         # NOTE(review): because `and` short-circuits, convergence_test is not
         # fed any value until min_iter iterations have run. If LlConvergenceTest
         # compares against the last value it saw, its first check here is
         # against the starting bound, not the previous iteration. Confirm this
         # is intended.
         if iteration + 1 >= self.options.min_iter and convergence_test(last_LL.sum()):
             logging.info('Variational bound has converged : stopping.')
             break