def project_mean_field(
    self,
    model_dist: MeanField,
    delta: float = 1.,
    status: Optional[Status] = None,
) -> "Tuple[FactorApproximation, Status]":
    """
    Project ``model_dist`` back on to this factor approximation.

    The new factor distribution is ``model_dist / self.cavity_dist``. When
    ``delta < 1`` it is damped by geometric interpolation with the current
    ``self.factor_dist``, and its ``log_norm`` is interpolated the same way.
    Invalid values are replaced using ``update_invalid(self.factor_dist)``.

    Parameters
    ----------
    model_dist
        The fitted model mean field.
    delta
        Damping factor; ``1`` takes the new projection unchanged.
    status
        Incoming status whose ``success`` and ``messages`` are carried over.
        A fresh ``Status()`` is used when ``None``.

    Returns
    -------
    new_approx
        ``FactorApproximation`` built from the projected factor distribution.
    status
        ``Status(success, messages)``; ``success`` is set to ``False`` when
        the projection was invalid.
    """
    if status is None:
        status = Status()
    # Read by attribute: ``Status`` carries more than two fields, so
    # unpacking it into exactly two names raises ValueError.
    success, messages = status.success, status.messages

    factor_dist = model_dist / self.cavity_dist
    if delta < 1:
        log_norm = factor_dist.log_norm
        factor_dist = factor_dist ** delta * self.factor_dist ** (1 - delta)
        factor_dist.log_norm = (
            delta * log_norm + (1 - delta) * self.factor_dist.log_norm
        )

    if not factor_dist.is_valid:
        success = False
        messages += (f"model projection for {self} is invalid", )
        # Fall back to the previous factor values where the new ones are invalid.
        factor_dist = factor_dist.update_invalid(self.factor_dist)

    new_approx = FactorApproximation(
        self.factor,
        self.cavity_dist,
        factor_dist=factor_dist,
        model_dist=model_dist,
    )
    return new_approx, Status(success, messages)
def update_factor_mean_field(
    self,
    cavity_dist: "MeanField",
    last_dist: Optional["MeanField"] = None,
    delta: float = 1.0,
    status: Status = Status(),
) -> Tuple["MeanField", Status]:
    """
    Compute the factor mean field ``self / cavity_dist``, optionally damped.

    Parameters
    ----------
    cavity_dist
        Cavity distribution divided out of ``self``.
    last_dist
        Previous factor distribution. It is used for damping when
        ``delta < 1``, for replacing invalid values, and as the fallback
        result on ``exc.MessageException``.
    delta
        Damping factor; ``1`` takes the new projection unchanged.
    status
        Incoming status; its ``success``, ``messages`` and ``flag`` are
        carried over.

    Returns
    -------
    factor_dist
        The new (possibly partially updated) factor distribution, or
        ``last_dist`` if a ``MessageException`` was raised.
    status
        Status with ``updated`` set when at least part of the distribution was
        updated, and ``flag`` set to ``StatusFlag.BAD_PROJECTION`` when the
        projection was invalid.
    """
    success, messages, _, flag = status
    updated = False
    try:
        with LogWarnings(
            logger=_log_projection_warnings, action='always'
        ) as caught_warnings:
            factor_dist = self / cavity_dist
            if delta < 1:
                log_norm = factor_dist.log_norm
                factor_dist = factor_dist ** delta * last_dist ** (1 - delta)
                factor_dist.log_norm = (
                    delta * log_norm + (1 - delta) * last_dist.log_norm
                )

        # Surface any warnings raised during the projection in the status.
        for m in caught_warnings.messages:
            messages += (f"project_mean_field warning: {m}",)

        if not factor_dist.is_valid:
            success = False
            messages += (f"model projection for {self} is invalid",)
            factor_dist = factor_dist.update_invalid(last_dist)
            # May want to check another way
            # e.g. factor_dist.check_valid().sum() / factor_dist.check_valid().size
            valid = factor_dist.check_valid()
            if valid.any():
                updated = True
                n_valid = valid.sum()
                n_total = valid.size
                # Report the valid fraction as a percentage. The previous
                # format ("%.0%%") was not a valid %-format string.
                logger.debug(
                    "meanfield with variables: %r, "
                    "partially updated %d parameters "
                    "out of %d total, %.0f%%",
                    tuple(self.variables),
                    n_valid,
                    n_total,
                    100 * n_valid / n_total,
                )
            flag = StatusFlag.BAD_PROJECTION
        else:
            updated = True
    except exc.MessageException as e:
        logger.exception(e)
        factor_dist = last_dist

    return factor_dist, Status(
        success=success, messages=messages, updated=updated, flag=flag
    )
def __call__(self, factor: Factor, approx: EPMeanField,
             status: Status = Status()) -> bool:
    """
    Record an optimisation result for ``factor`` and decide whether to stop.

    Parameters
    ----------
    factor
        The factor that was optimised.
    approx
        The mean field produced by optimising ``factor``.
    status
        Outcome of that optimisation step.

    Returns
    -------
    ``False`` when ``status.success`` is falsy. Otherwise ``True`` if any
    registered callback returns a truthy value, else the result of
    ``self.is_converged(factor)``.
    """
    self[factor](approx, status)
    if not status.success:
        return False

    # Build the full list so every callback runs, even after one returns True.
    callback_votes = [
        check(factor, approx, status) for check in self._callbacks
    ]
    if any(callback_votes):
        return True
    return self.is_converged(factor)
def optimise(
    self,
    factor: Factor,
    model_approx: EPMeanField,
    status: Optional[Status] = Status(),
    **kwargs,
) -> Tuple[EPMeanField, Status]:
    """
    Optimise the approximation of ``factor`` and merge it into ``model_approx``.

    ``kwargs`` are forwarded to ``self.optimise_approx``. The status passed to
    ``self.update_model_approx`` is the one returned by ``optimise_approx``.

    NOTE(review): the ``status`` argument is accepted but not used; confirm
    that discarding the caller's status is intended.
    """
    approx_for_factor = model_approx.factor_approximation(factor)
    optimised_dist, optimise_status = self.optimise_approx(
        approx_for_factor, **kwargs
    )
    return self.update_model_approx(
        optimised_dist, approx_for_factor, model_approx, optimise_status
    )
def refine(
    self,
    factor: Factor,
    model_approx: EPMeanField,
    status: Optional[Status] = Status(),
    n_refine=None,
) -> Tuple[EPMeanField, Status]:
    """
    Refine the approximation of ``factor`` and merge it into ``model_approx``.

    ``n_refine`` is passed through to ``self.refine_approx``.
    ``self.update_model_approx`` is called without a status argument.

    NOTE(review): the ``status`` argument is accepted but not used (``optimise``
    likewise ignores its incoming status); confirm this is intended.
    """
    approx_for_factor = model_approx.factor_approximation(factor)
    refined_dist = self.refine_approx(approx_for_factor, n_refine=n_refine)
    return self.update_model_approx(
        refined_dist, approx_for_factor, model_approx
    )
def __call__(self, factor: Factor, approx: EPMeanField,
             status: Status = Status()) -> bool:
    """
    Store ``approx`` and ``status`` for ``factor`` and report whether to stop.

    The storage index comes from ``next(self.factor_count[factor])``
    (presumably a per-factor counter; confirm). Returns ``True`` if any
    callback returns a truthy value. Otherwise it returns ``False`` on index 0,
    and on later indices compares ``approx`` with the previously stored
    approximation via ``self._check_convergence``.
    """
    index = next(self.factor_count[factor])
    self.history[index, factor] = approx
    self.statuses[index, factor] = status

    # Every callback is invoked; the list is built before ``any`` inspects it.
    callback_votes = [
        check(factor, approx, status) for check in self._callbacks
    ]
    if any(callback_votes):
        return True
    if not index:
        return False

    previous_approx = self.history[index - 1, factor]
    return self._check_convergence(approx, previous_approx)
def project_mean_field(
    self,
    model_dist: MeanField,
    delta: float = 1.0,
    status: Status = Status(),
) -> Tuple["FactorApproximation", Status]:
    """
    Project ``model_dist`` on to this factor approximation.

    Delegates to ``model_dist.update_factor_mean_field``. The current
    ``self.factor_dist`` is passed as ``last_dist``. The returned
    ``FactorApproximation`` keeps this factor and cavity, with the new factor
    distribution and ``model_dist``.

    Returns the new approximation and the status produced by the projection.
    """
    projected_factor, projection_status = model_dist.update_factor_mean_field(
        self.cavity_dist,
        last_dist=self.factor_dist,
        delta=delta,
        status=status,
    )
    return (
        FactorApproximation(
            self.factor,
            self.cavity_dist,
            factor_dist=projected_factor,
            model_dist=model_dist,
        ),
        projection_status,
    )
def project_mean_field(
    self,
    new_dist: MeanField,
    factor_approx: FactorApproximation,
    delta: float = 1.0,
    status: Status = Status(),
) -> Tuple["EPMeanField", Status]:
    """
    Project ``new_dist`` on to ``factor_approx`` and build an updated mean field.

    The projected factor distribution replaces the entry for
    ``factor_approx.factor`` in ``self.factor_mean_field``. A new instance of
    ``type(self)`` is built over the same factor graph.

    NOTE(review): assumes ``self.factor_mean_field`` returns a copy; if it is
    the live mapping, ``self`` is mutated as well. Confirm.

    Returns the new approximation and the status from the projection.
    """
    projected_factor, projection_status = new_dist.update_factor_mean_field(
        factor_approx.cavity_dist,
        factor_approx.factor_dist,
        delta=delta,
        status=status,
    )
    mean_fields = self.factor_mean_field
    mean_fields[factor_approx.factor] = projected_factor

    updated_approx = type(self)(
        factor_graph=self._factor_graph,
        factor_mean_field=mean_fields,
    )
    return updated_approx, projection_status
def project_mean_field(
    self,
    new_dist: MeanField,
    factor_approx: FactorApproximation,
    delta: float = 1.0,
    status: Status = Status(),
) -> Tuple["EPMeanField", Status]:
    """
    Project ``new_dist`` on to ``factor_approx`` for a rescaled/subset factor.

    Steps:
    1. Remove the previous message for the factor from a fetched
       ``factor_mean_field``.
    2. Rescale that message by ``1 - scale`` per variable, with scales taken
       from ``self._factor_rescale[factor]``.
    3. Multiply the rescaled message onto the newly projected factor
       distribution.

    All other construction arguments are copied from ``self``.

    NOTE(review): assumes ``self.factor_mean_field`` returns a copy, since it
    is ``pop``-ed and reassigned here. Confirm.

    Returns the new approximation and the status from the projection.
    """
    factor = factor_approx.factor
    mean_fields = self.factor_mean_field

    # We're fitting the full factor_dist, not the subset factor_dist.
    complement_scale = {
        var: 1 - scale
        for var, scale in self._factor_rescale[factor].items()
    }
    previous_factor_dist = mean_fields.pop(factor)
    subset_cavity = previous_factor_dist.rescale(complement_scale)

    projected_factor, projection_status = new_dist.update_factor_mean_field(
        factor_approx.cavity_dist,
        factor_approx.factor_dist,
        delta=delta,
        status=status,
    )
    mean_fields[factor] = projected_factor * subset_cavity

    updated_approx = type(self)(
        factor_graph=self._factor_graph,
        factor_mean_field=mean_fields,
        factor_rescale=self._factor_rescale,
        factor_subset_factor=self._factor_subset_factor,
        ep_mean_field=self._ep_mean_field,
        plates_index=self._plates_index,
    )
    return updated_approx, projection_status
def optimise_quasi_newton(
    state: OptimisationState,
    old_state: Optional[OptimisationState] = None,
    *,
    max_iter=100,
    search_direction=newton_direction,
    calc_line_search=line_search,
    quasi_newton_update=bfgs_update,
    stop_conditions=stop_conditions,
    search_direction_kws: Optional[Dict[str, Any]] = None,
    line_search_kws: Optional[Dict[str, Any]] = None,
    quasi_newton_kws: Optional[Dict[str, Any]] = None,
    stop_kws: Optional[Dict[str, Any]] = None,
    callback: Optional[_OPT_CALLBACK] = None,
    **kwargs,
) -> Tuple[OptimisationState, Status]:
    """
    Run up to ``max_iter`` quasi-Newton steps starting from ``state``.

    Before each step, ``check_stop_conditions`` is consulted. A truthy result
    is unpacked as ``(success, message)`` and ends the loop. Each step is taken
    with ``take_quasi_newton_step``, and any warnings it logs are added to the
    status messages. A ``None`` step size marks a failed line search and ends
    the loop unsuccessfully. ``callback(state, old_state)`` is called after
    every successful step.

    Returns
    -------
    state
        The final optimisation state.
    status
        ``Status`` with the accumulated messages (the last one has the
        number of completed steps appended). ``updated`` is true if at least
        one step was taken, and ``flag`` comes from
        ``StatusFlag.get_flag(success, n_steps)``.
    """
    success = True
    updated = False
    messages = ()
    message = "max iterations reached"
    stepsize = 0.0
    stop_kws = stop_kws or {}
    # Number of completed steps. Initialised up front so max_iter <= 0 no
    # longer leaves the count unbound when building the final message.
    n_steps = 0
    for _ in range(max_iter):
        stop = check_stop_conditions(
            stepsize, state, old_state, stop_conditions, **stop_kws
        )
        if stop:
            success, message = stop
            break

        with LogWarnings(
            logger=_log_projection_warnings, action='always'
        ) as caught_warnings:
            stepsize, state1 = take_quasi_newton_step(
                state,
                old_state,
                search_direction=search_direction,
                calc_line_search=calc_line_search,
                quasi_newton_update=quasi_newton_update,
                search_direction_kws=search_direction_kws,
                line_search_kws=line_search_kws,
                quasi_newton_kws=quasi_newton_kws,
            )
        for m in caught_warnings.messages:
            messages += (f"optimise_quasi_newton warning: {m}",)

        if stepsize is None:
            success = False
            message = "Line search failed"
            break

        updated = True
        state, old_state = state1, state
        n_steps += 1
        if callback:
            callback(state, old_state)

    message += f", iter={n_steps}"
    messages += (message,)
    status = Status(
        success,
        messages=messages,
        updated=updated,
        flag=StatusFlag.get_flag(success, n_steps),
    )
    return state, status
def project(
    self,
    factor_approx: FactorApproximation,
    status: Status = Status(),
) -> Tuple[FactorApproximation, Status]:
    """
    Project ``factor_approx``. This body is a placeholder.

    As written it performs no work and returns ``None``, despite the annotated
    ``(FactorApproximation, Status)`` return type. It is presumably overridden
    elsewhere. NOTE(review): consider making it abstract or raising
    ``NotImplementedError``.
    """
    return None
def project_on_to_factor_approx(
    factor_approx: "FactorApproximation",
    model_dist: Dict[str, AbstractMessage],
    delta: float = 1.,
    status: Optional[Status] = None
) -> Tuple["FactorApproximation", Status]:
    """
    For a passed FactorApproximation this calculates the factor messages
    such that

        model_dist = factor_dist * cavity_dist

    Parameters
    ----------
    factor_approx
        Current approximation; its ``cavity_dist``, ``factor_dist`` and
        ``model_dist`` are read.
    model_dist
        Fitted model messages keyed by variable. Entries may be overwritten
        in place (invalid-fit branch), and this same mapping is wrapped in
        the returned approximation's ``MeanField``.
    delta
        Damping factor in ``(0, 1]``; ``1`` means no damping.
    status
        Incoming status whose ``success``/``messages`` are extended. A fresh
        ``Status()`` is used when ``None``.

    Returns
    -------
    projection
        New ``FactorApproximation`` holding the projected factor messages.
    status
        ``Status(success, messages)`` describing the projection.
    """
    if status is None:
        status = Status()
    # Read by attribute: ``Status`` may carry more than two fields, in which
    # case unpacking into exactly two names raises ValueError.
    success, messages = status.success, status.messages
    assert 0 < delta <= 1

    factor_projection = {}
    # log_norm = 0.
    for v, q_fit in model_dist.items():
        q_cavity = factor_approx.cavity_dist.get(v)
        if isinstance(q_fit, FixedMessage):
            factor_projection[v] = q_fit
        elif q_fit.is_valid:
            # Previous factor message for v. The damping and partial-update
            # steps below need it, so fetch it whether or not a cavity exists.
            # It used to be bound only when a cavity existed, leaving it
            # undefined or stale from an earlier variable otherwise.
            q_f0 = factor_approx.factor_dist[v]
            if q_cavity is not None:
                q_f1 = q_fit / q_cavity
            else:
                # In the case that q_cavity does not exist the model fit
                # equals the factor approximation
                q_f1 = q_fit

            # weighted update
            if delta != 1:
                q_f1 = (q_f1**delta).sum_natural_parameters(q_f0**(1 - delta))

            if not q_f1.is_valid:
                # partial updating of values
                q_f1 = q_f1.update_invalid(q_f0)
                messages += (
                    f"factor projection for {v} with {factor_approx.factor} contained "
                    "invalid values",
                )
                if not q_f1.is_valid:
                    success = False
                    messages += (
                        f"factor projection for {v} with {factor_approx.factor} is invalid",
                    )

            factor_projection[v] = q_f1
        else:
            success = False
            messages += (
                f"model projection for {v} with {factor_approx.factor} is invalid",
            )
            # Keep the previous factor message for this variable.
            factor_projection[v] = factor_approx.factor_dist[v]
            # Try a damped model update instead; it is kept only if valid.
            # Only values (not keys) of model_dist change, so iterating over
            # it stays safe.
            # NOTE(review): only this branch damps model_dist; for a valid fit
            # with delta < 1 the returned model keeps the undamped q_fit.
            # Confirm that is intended.
            q_model = (q_fit**delta).sum_natural_parameters(
                factor_approx.model_dist[v]**(1 - delta))
            if q_model.is_valid:
                model_dist[v] = q_model

    projection = FactorApproximation(
        factor_approx.factor,
        factor_approx.cavity_dist,
        factor_dist=MeanField(factor_projection),
        model_dist=MeanField(model_dist),
        # log_norm=log_norm
    )
    status = Status(success, messages)
    return projection, status