Esempio n. 1
0
    def project_mean_field(
        self,
        model_dist: MeanField,
        delta: float = 1.,
        status: Optional[Status] = None,
    ) -> "Tuple[FactorApproximation, Status]":
        """
        Project a fitted model distribution back onto this factor.

        The new factor message is ``model_dist / self.cavity_dist``. When
        ``delta < 1`` the update is damped towards the current factor
        message as ``new**delta * old**(1 - delta)``, and the log
        normalisation is interpolated linearly with the same weights.

        If the resulting message is not valid, the status is marked
        unsuccessful, a message is appended, and the result is passed
        through ``update_invalid(self.factor_dist)``.

        Parameters
        ----------
        model_dist
            The fitted mean-field approximation of the model.
        delta
            Damping factor; ``1`` takes the full update.
        status
            Incoming status; a fresh ``Status()`` is used when ``None``.

        Returns
        -------
        The new FactorApproximation and the updated Status.
        """
        # NOTE(review): assumes Status unpacks to exactly (success, messages)
        # in this version of the library — confirm.
        success, messages = Status() if status is None else status

        factor_dist = (model_dist / self.cavity_dist)
        if delta < 1:
            log_norm = factor_dist.log_norm
            factor_dist = (factor_dist**delta * self.factor_dist**(1 - delta))
            # Damp the normalisation with the same weights as the message.
            factor_dist.log_norm = (delta * log_norm +
                                    (1 - delta) * self.factor_dist.log_norm)

        if not factor_dist.is_valid:
            success = False
            messages += (f"model projection for {self} is invalid", )
            factor_dist = factor_dist.update_invalid(self.factor_dist)

        new_approx = FactorApproximation(
            self.factor,
            self.cavity_dist,
            factor_dist=factor_dist,
            model_dist=model_dist,
        )
        return new_approx, Status(success, messages)
Esempio n. 2
0
    def update_factor_mean_field(
        self,
        cavity_dist: "MeanField",
        last_dist: Optional["MeanField"] = None,
        delta: float = 1.0,
        status: Status = Status(),
    ) -> Tuple["MeanField", Status]:
        """
        Compute the factor message implied by this mean field and a cavity.

        The new factor message is ``self / cavity_dist``. When ``delta < 1``
        it is damped towards ``last_dist`` as
        ``new**delta * last_dist**(1 - delta)``, and the log normalisation is
        interpolated linearly with the same weights. Warnings raised while
        computing the message are captured with ``LogWarnings`` and appended
        to the status messages.

        If the result is invalid, the status is marked unsuccessful and
        flagged ``StatusFlag.BAD_PROJECTION``, and the result is passed
        through ``update_invalid(last_dist)``. If ``exc.MessageException`` is
        raised, it is logged and ``last_dist`` is returned.

        Parameters
        ----------
        cavity_dist
            The cavity distribution to divide out.
        last_dist
            The previous factor message. It is required when ``delta < 1``
            or when the projection turns out to be invalid.
        delta
            Damping factor; ``1`` takes the full update.
        status
            Incoming status, unpacked as ``(success, messages, _, flag)``.

        Returns
        -------
        The new factor message and a Status. ``updated`` is True when the
        projection is valid, or when ``check_valid()`` reports any valid
        entries after ``update_invalid``.
        """
        success, messages, _, flag = status
        updated = False
        try:
            with LogWarnings(logger=_log_projection_warnings, action='always') as caught_warnings:
                factor_dist = self / cavity_dist
                if delta < 1:
                    log_norm = factor_dist.log_norm
                    factor_dist = factor_dist ** delta * last_dist ** (1 - delta)
                    # Damp the normalisation with the same weights.
                    factor_dist.log_norm = (
                        delta * log_norm + (1 - delta) * last_dist.log_norm
                    )

            for m in caught_warnings.messages:
                messages += (f"project_mean_field warning: {m}",)

            if not factor_dist.is_valid:
                success = False
                messages += (f"model projection for {self} is invalid",)
                factor_dist = factor_dist.update_invalid(last_dist)
                valid = factor_dist.check_valid()
                if valid.any():
                    # Only part of the parameters could be updated.
                    updated = True
                    n_valid = valid.sum()
                    n_total = valid.size
                    # "%.0f%%" prints the fraction as a whole-number
                    # percentage. The previous "%.0%%" was a malformed spec
                    # that made logging fail to format the record.
                    logger.debug(
                        "meanfield with variables: %r, "
                        "partially updated %d parameters "
                        "out of %d total, %.0f%%",
                        tuple(self.variables),
                        n_valid,
                        n_total,
                        100 * n_valid / n_total,
                    )

                flag = StatusFlag.BAD_PROJECTION
            else:
                updated = True

        except exc.MessageException as e:
            logger.exception(e)
            factor_dist = last_dist

        return factor_dist, Status(
            success=success, messages=messages, updated=updated, flag=flag
        )
Esempio n. 3
0
    def __call__(self,
                 factor: Factor,
                 approx: EPMeanField,
                 status: Status = Status()) -> bool:
        """
        Record a new approximation for ``factor`` and decide whether to stop.

        The approximation and status are first forwarded to ``self[factor]``
        (the per-factor history). On an unsuccessful status the method
        returns False immediately. Otherwise every registered callback is
        invoked, and True is returned if any of them returned a truthy value.
        Failing that, the result of ``self.is_converged(factor)`` is returned.

        Parameters
        ----------
        factor
            A factor in the optimisation.
        approx
            A mean field produced by optimisation of the factor.
        status
            A status describing whether the optimisation was successful.

        Returns
        -------
        True when optimisation should terminate, i.e. a callback requested it
        or ``is_converged`` reports convergence; False otherwise.
        """
        self[factor](approx, status)
        if not status.success:
            return False

        # The list is built eagerly so every callback runs, even after one
        # has already asked to stop.
        callback_results = [
            callback(factor, approx, status) for callback in self._callbacks
        ]
        if any(callback_results):
            return True
        return self.is_converged(factor)
Esempio n. 4
0
    def optimise(self,
                 factor: Factor,
                 model_approx: EPMeanField,
                 status: Optional[Status] = Status(),
                 **kwargs) -> Tuple[EPMeanField, Status]:
        """
        Optimise the approximation of one factor and fold it into the model.

        Builds the factor approximation from ``model_approx`` and runs
        ``self.optimise_approx`` on it, forwarding ``**kwargs``. It then
        passes the new model distribution and the status returned by
        ``optimise_approx`` to ``self.update_model_approx``.

        NOTE(review): the incoming ``status`` argument is not used; the
        status handed on comes from ``optimise_approx``.
        """
        factor_approx = model_approx.factor_approximation(factor)
        optimised = self.optimise_approx(factor_approx, **kwargs)
        new_model_dist, status = optimised
        return self.update_model_approx(
            new_model_dist, factor_approx, model_approx, status
        )
Esempio n. 5
0
 def refine(
     self,
     factor: Factor,
     model_approx: EPMeanField,
     status: Optional[Status] = Status(),
     n_refine=None,
 ) -> Tuple[EPMeanField, Status]:
     """
     Refine the approximation of one factor and fold it into the model.

     Builds the factor approximation from ``model_approx``, calls
     ``self.refine_approx`` on it with ``n_refine``, and passes the result to
     ``self.update_model_approx``.

     The caller's ``status`` is forwarded to ``update_model_approx``, as
     ``optimise`` does. Previously it was accepted but silently dropped. A
     ``None`` status is replaced with a fresh ``Status()``.

     Returns
     -------
     Whatever ``update_model_approx`` returns (annotated as the updated
     EPMeanField and a Status).
     """
     factor_approx = model_approx.factor_approximation(factor)
     new_model_dist = self.refine_approx(factor_approx, n_refine=n_refine)
     status = Status() if status is None else status
     return self.update_model_approx(new_model_dist, factor_approx,
                                     model_approx, status)
Esempio n. 6
0
    def __call__(self,
                 factor: Factor,
                 approx: EPMeanField,
                 status: Status = Status()) -> bool:
        """
        Store ``approx`` and ``status`` for ``factor`` and decide whether to stop.

        The next index from ``self.factor_count[factor]`` keys the entries in
        ``self.history`` and ``self.statuses``. If any callback returns a
        truthy value, True is returned. Otherwise, from the second recorded
        entry onward (non-zero index), the result of comparing against the
        previous approximation with ``self._check_convergence`` is returned.
        For the first entry the method returns False.
        """
        index = next(self.factor_count[factor])
        self.history[index, factor] = approx
        self.statuses[index, factor] = status

        # Evaluate every callback before checking whether any requested a stop.
        results = [
            callback(factor, approx, status) for callback in self._callbacks
        ]
        if any(results):
            return True
        if not index:
            return False

        previous = self.history[index - 1, factor]
        return self._check_convergence(approx, previous)
Esempio n. 7
0
    def project_mean_field(
        self,
        model_dist: MeanField,
        delta: float = 1.0,
        status: Status = Status(),
    ) -> Tuple["FactorApproximation", Status]:
        """
        Project ``model_dist`` onto this factor, returning a new approximation.

        The new factor message is computed by
        ``model_dist.update_factor_mean_field`` against this approximation's
        cavity, with the current factor message as the damping/fallback
        target.

        Returns
        -------
        A new FactorApproximation for the same factor and cavity, and the
        Status from ``update_factor_mean_field``.
        """
        projection = model_dist.update_factor_mean_field(
            self.cavity_dist,
            last_dist=self.factor_dist,
            delta=delta,
            status=status,
        )
        factor_dist, new_status = projection
        return (
            FactorApproximation(
                self.factor,
                self.cavity_dist,
                factor_dist=factor_dist,
                model_dist=model_dist,
            ),
            new_status,
        )
Esempio n. 8
0
    def project_mean_field(
            self,
            new_dist: MeanField,
            factor_approx: FactorApproximation,
            delta: float = 1.0,
            status: Status = Status(),
    ) -> Tuple["EPMeanField", Status]:
        """
        Return a new instance with the message of ``factor_approx.factor``
        replaced by the projection of ``new_dist``.

        The projection is computed by ``new_dist.update_factor_mean_field``
        using the factor approximation's cavity and its current factor
        message. NOTE(review): if ``self.factor_mean_field`` returns the live
        mapping rather than a copy, the assignment below also mutates
        ``self``. Confirm.

        Returns
        -------
        The new instance of ``type(self)`` and the Status from the projection.
        """
        projected, new_status = new_dist.update_factor_mean_field(
            factor_approx.cavity_dist,
            factor_approx.factor_dist,
            delta=delta,
            status=status,
        )
        mean_fields = self.factor_mean_field
        mean_fields[factor_approx.factor] = projected

        rebuilt = type(self)(
            factor_graph=self._factor_graph,
            factor_mean_field=mean_fields,
        )
        return rebuilt, new_status
Esempio n. 9
0
    def project_mean_field(
            self,
            new_dist: MeanField,
            factor_approx: FactorApproximation,
            delta: float = 1.0,
            status: Status = Status(),
    ) -> Tuple["EPMeanField", Status]:
        """
        Return a new instance with the message of ``factor_approx.factor``
        replaced by the projection of ``new_dist``.

        The currently stored message for the factor is popped and rescaled
        per variable by ``1 - scale`` (scales taken from
        ``self._factor_rescale``). The projected message is multiplied by
        that rescaled remainder before being stored. NOTE(review): presumably
        the remainder is the part of the full factor not covered by the
        subset factor. Confirm.

        Returns
        -------
        A new ``type(self)`` sharing this instance's graph, rescale, subset,
        EP mean-field and plates configuration, and the Status from
        ``update_factor_mean_field``.
        """
        mean_fields = self.factor_mean_field
        factor = factor_approx.factor

        # We're fitting the full factor_dist, not the subset factor_dist,
        # so keep the complementary share (1 - scale) of the old message.
        complement = {
            var: 1 - scale
            for var, scale in self._factor_rescale[factor].items()
        }
        previous_dist = mean_fields.pop(factor)
        remainder = previous_dist.rescale(complement)

        projected, new_status = new_dist.update_factor_mean_field(
            factor_approx.cavity_dist,
            factor_approx.factor_dist,
            delta=delta,
            status=status,
        )
        mean_fields[factor] = projected * remainder

        rebuilt = type(self)(
            factor_graph=self._factor_graph,
            factor_mean_field=mean_fields,
            factor_rescale=self._factor_rescale,
            factor_subset_factor=self._factor_subset_factor,
            ep_mean_field=self._ep_mean_field,
            plates_index=self._plates_index,
        )
        return rebuilt, new_status
Esempio n. 10
0
def optimise_quasi_newton(
        state: OptimisationState,
        old_state: Optional[OptimisationState] = None,
        *,
        max_iter=100,
        search_direction=newton_direction,
        calc_line_search=line_search,
        quasi_newton_update=bfgs_update,
        stop_conditions=stop_conditions,
        search_direction_kws: Optional[Dict[str, Any]] = None,
        line_search_kws: Optional[Dict[str, Any]] = None,
        quasi_newton_kws: Optional[Dict[str, Any]] = None,
        stop_kws: Optional[Dict[str, Any]] = None,
        callback: Optional[_OPT_CALLBACK] = None,
        **kwargs,
) -> Tuple[OptimisationState, Status]:
    """
    Run up to ``max_iter`` quasi-Newton steps starting from ``state``.

    ``check_stop_conditions`` is consulted before each step; a truthy result
    is unpacked as ``(success, message)`` and ends the loop. Each step is
    taken with ``take_quasi_newton_step``, and warnings raised during the
    step are captured and added to the status messages. A ``None`` step size
    is treated as a failed line search and ends the loop unsuccessfully.

    Parameters
    ----------
    state, old_state
        Current and previous optimisation states.
    max_iter
        Maximum number of steps. With ``0`` the input ``state`` is returned
        with ``iter=0``.
    search_direction, calc_line_search, quasi_newton_update
        Passed to ``take_quasi_newton_step``.
    search_direction_kws, line_search_kws, quasi_newton_kws
        Passed to ``take_quasi_newton_step``.
    stop_conditions, stop_kws
        Passed to ``check_stop_conditions``.
    callback
        Called as ``callback(state, old_state)`` after each accepted step.
    **kwargs
        Accepted for interface compatibility; not used.

    Returns
    -------
    The final state and a Status. The last message records the stopping
    reason followed by ``iter=<number of accepted steps>``.
    """
    success = True
    updated = False
    messages = ()
    message = "max iterations reached"
    stepsize = 0.0
    # Count accepted steps explicitly. Previously the loop variable was
    # mutated and read after the loop, which raised UnboundLocalError when
    # max_iter == 0.
    n_iter = 0
    for _ in range(max_iter):
        stop = check_stop_conditions(
            stepsize, state, old_state, stop_conditions, **(stop_kws or {})
        )
        if stop:
            success, message = stop
            break

        with LogWarnings(logger=_log_projection_warnings, action='always') as caught_warnings:
            stepsize, state1 = take_quasi_newton_step(
                state,
                old_state,
                search_direction=search_direction,
                calc_line_search=calc_line_search,
                quasi_newton_update=quasi_newton_update,
                search_direction_kws=search_direction_kws,
                line_search_kws=line_search_kws,
                quasi_newton_kws=quasi_newton_kws,
            )
        for m in caught_warnings.messages:
            messages += (f"optimise_quasi_newton warning: {m}",)

        if stepsize is None:
            success = False
            message = "Line search failed"
            break

        updated = True
        state, old_state = state1, state
        n_iter += 1
        if callback:
            callback(state, old_state)

    message += f", iter={n_iter}"
    messages += (message,)
    status = Status(
        success,
        messages=messages,
        updated=updated,
        flag=StatusFlag.get_flag(success, n_iter),
    )
    return state, status
Esempio n. 11
0
 def project(
     self, factor_approx: FactorApproximation, status: Status = Status()
 ) -> Tuple[FactorApproximation, Status]:
     """
     Placeholder for projecting ``factor_approx``; it has no implementation here.

     NOTE(review): the body is empty, so this returns ``None`` despite the
     annotated ``Tuple[FactorApproximation, Status]`` return type. It is
     presumably overridden by subclasses. Confirm.
     """
Esempio n. 12
0
def project_on_to_factor_approx(
        factor_approx: "FactorApproximation",
        model_dist: Dict[str, AbstractMessage],
        delta: float = 1.,
        status: Optional[Status] = None
) -> Tuple["FactorApproximation", Status]:
    """
    For a passed FactorApproximation, calculate the factor messages such that

        model_dist = factor_dist * cavity_dist

    For each variable:

    * a ``FixedMessage`` fit is kept unchanged;
    * a valid fit is divided by the cavity message (if there is one), and is
      optionally damped towards the previous factor message when
      ``delta != 1``. Invalid values are passed through ``update_invalid``
      using the previous factor message;
    * an invalid fit keeps the previous factor message, and the status is
      marked unsuccessful.

    Parameters
    ----------
    factor_approx
        The current approximation for the factor; supplies ``cavity_dist``,
        ``factor_dist`` and ``model_dist``.
    model_dist
        The fitted message for each variable. NOTE: for invalid fits, the
        entry is overwritten in place with a damped model message when that
        message is valid.
    delta
        Damping factor in ``(0, 1]``; ``1`` takes the full update.
    status
        Incoming status; a fresh ``Status()`` is used when ``None``.

    Returns
    -------
    The projected FactorApproximation and the updated Status.
    """
    success, messages = Status() if status is None else status
    assert 0 < delta <= 1

    factor_projection = {}
    for v, q_fit in model_dist.items():
        q_cavity = factor_approx.cavity_dist.get(v)
        if isinstance(q_fit, FixedMessage):
            factor_projection[v] = q_fit
        elif q_fit.is_valid:
            # The previous factor message for this variable is the damping
            # target and the fallback for invalid values. It is looked up for
            # every variable: previously it was only bound when a cavity
            # existed, so it could be undefined or left over from an earlier
            # loop iteration.
            q_f0 = factor_approx.factor_dist[v]
            if q_cavity:
                q_f1 = (q_fit / q_cavity)
            else:
                # In the case that q_cavity does not exist the model fit
                # equals the factor approximation
                q_f1 = q_fit

            # weighted update
            if delta != 1:
                q_f1 = (q_f1**delta).sum_natural_parameters(q_f0**(1 - delta))

            if not q_f1.is_valid:
                # partial updating of values
                q_f1 = q_f1.update_invalid(q_f0)
                messages += (
                    f"factor projection for {v} with {factor_approx.factor} contained "
                    "invalid values", )

            if not q_f1.is_valid:
                success = False
                messages += (
                    f"factor projection for {v} with {factor_approx.factor} is invalid",
                )

            factor_projection[v] = q_f1
        else:
            success = False
            messages += (
                f"model projection for {v} with {factor_approx.factor} is invalid",
            )

            factor_projection[v] = factor_approx.factor_dist[v]
            q_model = (q_fit**delta).sum_natural_parameters(
                factor_approx.model_dist[v]**(1 - delta))
            if q_model.is_valid:
                # Replacing the value of an existing key is safe while
                # iterating over .items().
                model_dist[v] = q_model

    projection = FactorApproximation(
        factor_approx.factor,
        factor_approx.cavity_dist,
        factor_dist=MeanField(factor_projection),
        model_dist=MeanField(model_dist),
    )
    status = Status(success, messages)

    return projection, status