def lstsq_laplace_factor_approx(
    model_approx: EPMeanField,
    factor: Factor,
    delta: float = 0.5,
    opt_kws: Optional[Dict[str, Any]] = None,
):
    """
    Update one factor's approximation from a least-squares fit.

    The factor approximation for ``factor`` is taken from ``model_approx``
    and handed to ``LeastSquaresOpt``. Its ``least_squares`` method yields
    a mode, a covariance mapping and the raw optimiser result. Each
    variable's mode/covariance pair is turned into a message via
    ``from_mode`` on the matching factor distribution. The resulting mean
    field is projected onto the factor approximation, and that projection
    is finally pushed back into ``model_approx``.

    Parameters
    ----------
    model_approx
        The current EP mean-field approximation of the model.
    factor
        The factor whose approximation is being updated.
    delta
        Forwarded unchanged to ``factor_approx.project``.
    opt_kws
        Extra keyword arguments for ``LeastSquaresOpt``; ``None`` means
        no extra arguments.

    Returns
    -------
    The result of ``model_approx.project(projection, status=status)`` —
    presumably the updated approximation together with a ``Status``
    (TODO confirm against the ``project`` implementation).
    """
    factor_approx = model_approx.factor_approximation(factor)

    extra_options = opt_kws if opt_kws is not None else {}
    optimiser = LeastSquaresOpt(factor_approx, **extra_options)
    mode, covar, result = optimiser.least_squares()

    # Human-readable summary of the optimiser run.
    # NOTE(review): the "lsq_sq" prefix does not match this function's name.
    summary = (
        "optimise.lsq_sq_laplace_factor_approx: "
        f"nfev={result.nfev}, njev={result.njev}, "
        f"optimality={result.optimality}, "
        f"cost={result.cost}, "
        f"status={result.status}, message={result.message}"
    )
    # The summary is wrapped in a one-element tuple; presumably Status
    # stores a sequence of messages — confirm.
    status = Status(result.success, (summary,))

    fitted = {}
    for variable in mode:
        # covar.get(...) yields None for variables with no covariance entry.
        fitted[variable] = factor_approx.factor_dist[variable].from_mode(
            mode[variable], covar.get(variable)
        )

    projection, status = factor_approx.project(
        MeanField(fitted), delta=delta, status=status
    )
    return model_approx.project(projection, status=status)
def factor_approximation(self, factor: Factor) -> FactorApproximation:
    """
    Build the approximation objects for a single factor.

    The returned object bundles:
    - the factor itself;
    - the factor's own distribution over its variables;
    - the cavity distribution: a unit (1.0) mean field over the factor's
      variables multiplied by the distributions of every *other* factor;
    - the model distribution: the factor's distribution multiplied by the
      cavity distribution, i.e. the contributions of all factors.

    Parameters
    ----------
    factor
        The factor to build the approximation for. It must be a key of
        ``self._factor_mean_field`` (otherwise ``pop`` raises KeyError).

    Returns
    -------
    A ``FactorApproximation(factor, cavity_dist, factor_dist, model_dist)``.
    """
    # Work on a copy so popping the target factor leaves self untouched.
    remaining = self._factor_mean_field.copy()
    factor_dist = remaining.pop(factor)

    unit_field = MeanField(
        {variable: 1.0 for variable in factor_dist.all_variables}
    )
    cavity_dist = unit_field.prod(*remaining.values())

    model_dist = factor_dist.prod(cavity_dist)
    return FactorApproximation(factor, cavity_dist, factor_dist, model_dist)
def project_factor_approx_sample(
    factor_approx: FactorApproximation, sample: SamplingResult
) -> Dict[str, AbstractMessage]:
    """
    Project a weighted sample onto the factor's variable distributions.

    The sample's log weights are collapsed (by summation) onto each cavity
    variable via ``factor_approx.factor.collapse``. The per-sample total
    log weight (the sum over every axis after the first) is used to compute
    ``log_norm`` as the log of the mean importance weight. Every sampled
    and deterministic variable is then projected with its factor
    distribution's ``project`` method.

    Parameters
    ----------
    factor_approx
        The factor approximation whose ``factor_dist`` messages are used
        for projection.
    sample
        Sampling result providing ``log_weights``, ``samples`` and
        ``det_variables``.

    Returns
    -------
    A ``MeanField`` of projected messages carrying ``log_norm``.
    """
    raw_log_weights = sample.log_weights

    # Per-variable log weights, collapsed onto each cavity variable's shape.
    collapsed_log_weights = {}
    for variable in factor_approx.cavity_dist:
        collapsed_log_weights[variable] = factor_approx.factor.collapse(
            variable, raw_log_weights, agg_func=np.sum
        )

    # Total log weight per sample: reduce over all axes except the first.
    trailing_axes = tuple(range(1, raw_log_weights.ndim))
    per_sample_log_weights = raw_log_weights.sum(trailing_axes)

    # Log of the mean weight, shifted by the max for numerical stability.
    shift = np.max(per_sample_log_weights)
    weights = np.exp(per_sample_log_weights - shift)
    log_norm = np.log(weights.mean(0)) + shift

    projected = {}
    for variable, values in chain(
        sample.samples.items(), sample.det_variables.items()
    ):
        projected[variable] = factor_approx.factor_dist[variable].project(
            values, collapsed_log_weights.get(variable)
        )

    return MeanField(projected, log_norm=log_norm)
def calc_exact_update(self, mean_field) -> "MeanField":
    """
    Compute the exact (closed-form) update for this factor, if available.

    Raises
    ------
    NotImplementedError
        If ``self._calc_exact_update`` is falsy (no exact update defined).

    Returns
    -------
    A ``MeanField`` built from ``self._calc_exact_update`` applied to the
    resolved arguments/output. The result is filtered by ``nested_filter``
    against ``self.args`` plus ``self.factor_out`` using ``is_variable``.
    """
    if not self._calc_exact_update:
        raise NotImplementedError

    # Local import, as in the original implementation.
    from autofit.graphical.mean_field import MeanField

    resolved = self.resolve_args_and_out(mean_field)
    projection = self._calc_exact_update(*resolved)
    variables = self.args + (self.factor_out,)
    return MeanField(nested_filter(is_variable, variables, projection))
def factor_approximation(self, factor: Factor) -> FactorApproximation:
    """
    Build the approximation objects for a single factor.

    The factor's distribution is removed from a copy of
    ``self._factor_mean_field``. The cavity distribution is the product of
    a unit (1.0) mapping over the factor's variables with the distributions
    of all remaining factors. The model distribution is the factor's
    distribution multiplied by that cavity.

    Returns
    -------
    ``FactorApproximation(factor, cavity_dist, factor_dist, model_dist)``.
    """
    others = self._factor_mean_field.copy()
    factor_dist = others.pop(factor)

    # NOTE(review): MeanField.prod is called unbound with a plain dict as its
    # first argument, so prod must accept a dict-like ``self`` — confirm.
    unit = {variable: 1. for variable in factor_dist.all_variables}
    cavity_dist = MeanField.prod(unit, *others.values())

    model_dist = factor_dist.prod(cavity_dist)
    return FactorApproximation(factor, cavity_dist, factor_dist, model_dist)
def refine_approx(
    self,
    factor_approx: FactorApprox,
    mean_field: MeanField = None,
    params: VariableData = None,
    n_refine=None,
) -> MeanField:
    """
    Refine the optimisation state for a factor approximation.

    An optimisation state is prepared from ``factor_approx``,
    ``mean_field`` and ``params``. It is refined with
    ``self.refine_state`` (using ``mean_field.sample``) and converted back
    with ``mean_field.from_opt_state``.

    Parameters
    ----------
    factor_approx
        The factor approximation being refined.
    mean_field
        Mean field to start from; defaults to ``factor_approx.model_dist``
        when ``None``.
    params
        Optional starting parameters forwarded to ``prepare_state``.
    n_refine
        Forwarded to ``refine_state``.

    Returns
    -------
    The value of ``mean_field.from_opt_state(...)`` only — no ``Status``
    is returned. The previous ``Tuple[MeanField, Status]`` annotation did
    not match the actual return value.
    """
    # Explicit None check: a provided mean field is never replaced just
    # because its truthiness (e.g. dict-like emptiness) is False.
    if mean_field is None:
        mean_field = factor_approx.model_dist
    state = self.prepare_state(factor_approx, mean_field, params)
    refined_state = self.refine_state(
        state, mean_field.sample, n_refine=n_refine
    )
    return mean_field.from_opt_state(refined_state)
def from_approx_dists(
    cls,
    factor_graph: FactorGraph,
    approx_dists: Dict[Variable, AbstractMessage],
) -> "EPMeanField":
    """
    Construct an instance by giving every factor a copy of shared messages.

    For each factor in ``factor_graph.factors``, a ``MeanField`` is built
    that maps each of the factor's variables to a ``copy()`` of that
    variable's message in ``approx_dists``. Each factor therefore owns
    independent message objects.

    Raises
    ------
    KeyError
        If a factor's variable is missing from ``approx_dists``.
    """
    factor_mean_field = {}
    for factor in factor_graph.factors:
        copied = {
            variable: approx_dists[variable].copy()
            for variable in factor.all_variables
        }
        factor_mean_field[factor] = MeanField(copied)
    return cls(factor_graph, factor_mean_field)
def optimise_approx(
    self,
    factor_approx: FactorApprox,
    mean_field: MeanField = None,
    params: VariableData = None,
    **kwargs,
) -> Tuple[MeanField, Status]:
    """
    Optimise a factor approximation and return the projected mean field.

    An optimisation state is prepared and optimised with
    ``self.optimise_state``. The better of the starting and optimised
    states (by ``value``) is kept, refined with ``self.refine_state``, and
    converted back via ``mean_field.from_opt_state``.

    Parameters
    ----------
    factor_approx
        The factor approximation being optimised.
    mean_field
        Mean field to start from; defaults to ``factor_approx.model_dist``
        when ``None``.
    params
        Optional starting parameters forwarded to ``prepare_state``.
    **kwargs
        Forwarded to ``optimise_state``. ``kwargs["n_refine"]`` (if
        present) is additionally passed to ``refine_state``.

    Returns
    -------
    ``(projection, status)`` where ``status`` comes from ``optimise_state``.
    """
    # Explicit None check instead of truthiness of a dict-like mean field.
    if mean_field is None:
        mean_field = factor_approx.model_dist
    state = self.prepare_state(factor_approx, mean_field, params)
    next_state, status = self.optimise_state(state, **kwargs)

    # Always keep whichever state has the larger ``value``; on a tie ``max``
    # returns the first argument, i.e. the starting state.
    next_state = max(state, next_state, key=lambda x: x.value)
    next_state = self.refine_state(
        next_state, mean_field.sample, n_refine=kwargs.get("n_refine")
    )
    projection = mean_field.from_opt_state(next_state)
    return projection, status
def project_mean_field(
    self,
    new_dist: MeanField,
    factor_approx: FactorApproximation,
    delta: float = 1.0,
    status: Status = Status(),
) -> Tuple["EPMeanField", Status]:
    """
    Replace one factor's mean field with an update derived from ``new_dist``.

    ``new_dist.update_factor_mean_field`` is given the cavity and factor
    distributions of ``factor_approx`` (plus ``delta`` and ``status``). Its
    result becomes the factor's entry in a new instance of ``type(self)``
    that shares ``self._factor_graph``.

    Returns
    -------
    ``(new_approx, status)`` with ``status`` as returned by
    ``update_factor_mean_field``.
    """
    # NOTE(review): the ``Status()`` default is created once at definition
    # time and shared by all calls; this is safe only if Status instances
    # are never mutated in place — confirm.
    updated_factor_dist, status = new_dist.update_factor_mean_field(
        factor_approx.cavity_dist,
        factor_approx.factor_dist,
        delta=delta,
        status=status,
    )

    # NOTE(review): assumes ``self.factor_mean_field`` returns a fresh
    # mapping, so the assignment below does not mutate ``self`` — confirm.
    mean_fields = self.factor_mean_field
    mean_fields[factor_approx.factor] = updated_factor_dist

    new_approx = type(self)(
        factor_graph=self._factor_graph,
        factor_mean_field=mean_fields,
    )
    return new_approx, status
def project_mean_field(
    self,
    new_dist: MeanField,
    factor_approx: FactorApproximation,
    delta: float = 1.0,
    status: Status = Status(),
) -> Tuple["EPMeanField", Status]:
    """
    Update one factor's mean field when it was fitted as a subset.

    The full factor distribution is fitted, not the subset one, so the
    previous entry for the factor is removed and rescaled by
    ``1 - scale`` (using ``self._factor_rescale[factor]``) to form the
    subset cavity. The update from ``new_dist.update_factor_mean_field`` is
    multiplied by that subset cavity and stored back under the factor. A
    new instance of ``type(self)`` carrying over all rescale/subset
    bookkeeping is returned.

    Returns
    -------
    ``(new_approx, status)`` with ``status`` as returned by
    ``update_factor_mean_field``.
    """
    factor = factor_approx.factor
    mean_fields = self.factor_mean_field

    # Complementary scale for every variable of this factor.
    complement_scale = {}
    for variable, scale in self._factor_rescale[factor].items():
        complement_scale[variable] = 1 - scale

    previous_factor_dist = mean_fields.pop(factor)
    subset_cavity_dist = previous_factor_dist.rescale(complement_scale)

    updated_factor_dist, status = new_dist.update_factor_mean_field(
        factor_approx.cavity_dist,
        factor_approx.factor_dist,
        delta=delta,
        status=status,
    )
    mean_fields[factor] = updated_factor_dist * subset_cavity_dist

    new_approx = type(self)(
        factor_graph=self._factor_graph,
        factor_mean_field=mean_fields,
        factor_rescale=self._factor_rescale,
        factor_subset_factor=self._factor_subset_factor,
        ep_mean_field=self._ep_mean_field,
        plates_index=self._plates_index,
    )
    return new_approx, status
def make_posdef_hessian(mean_field, variables):
    """
    Return ``MeanField.precision`` evaluated for ``variables``.

    This is a thin wrapper that delegates directly to the unbound
    ``MeanField.precision`` with ``mean_field`` as its first argument.
    """
    precision = MeanField.precision(mean_field, variables)
    return precision
def mean_field(self) -> MeanField:
    """
    Combine every factor's mean field into a single model mean field.

    A unit (1.0) mean field over ``self.all_variables`` is multiplied by
    the mean field of each factor in ``self._factor_mean_field``.
    """
    unit_field = MeanField({variable: 1.0 for variable in self.all_variables})
    return unit_field.prod(*self._factor_mean_field.values())