Ejemplo n.º 1
0
def infer_message_passing(model: PairWiseFiniteModel,
                          max_iter=None) -> InferenceResult:
    """Inference with Message Passing.

    For acyclic graph returns exact partition function and marginal
        probabilities. For graph with loops may return good approximation to
        the true marginal probabilities, but partition function will be a
        useless number.
    This is an iterative algorithm which terminates when it converged or when
        `max_iter` iterations were made.

    :param model: Pairwise model for which to perform inference.
    :param max_iter: How many iterations without convergence should happen for
        algorithm to terminate. Defaults to maximal diameter of connected
        component (1 if the model has no edges).
    :return: InferenceResult object.

    Reference
        [1] Wainwright, Jordan. Graphical Models, Exponential Families, and
        Variational Inference. 2008. Section 2.5.1 (p. 26).
    """
    edges = model.get_edges_connected()

    if max_iter is None:
        graph = networkx.Graph()
        graph.add_edges_from(edges)
        # networkx.diameter raises on a disconnected (or empty) graph, so take
        # the largest diameter over the connected components instead.
        max_iter = max(
            (networkx.diameter(graph.subgraph(component))
             for component in networkx.connected_components(graph)),
            default=1)

    # Build list of directed edges (each undirected edge in both directions).
    dir_edges = np.concatenate([edges, np.flip(edges, axis=1)])

    # Sort edges by end vertex (ties broken by start vertex). This ensures that
    # edges ending with the same vertex are sequential, which allows for
    # efficient lookup. lexsort does not depend on the integer dtype of the
    # edge array, unlike reinterpreting the rows through an 'i4,i4' view.
    dir_edges = dir_edges[np.lexsort((dir_edges[:, 0], dir_edges[:, 1]))]

    # Compact representation of interactions.
    intrn = model.get_interactions_for_edges(dir_edges)

    # Main algorithm.
    lmu = _message_passing(dir_edges, intrn, model.field, max_iter)

    # Restore partition function for fixed values in nodes: add every incoming
    # message to the log-field of its end vertex. np.add.at accumulates
    # repeated indices sequentially, exactly like a per-edge loop.
    log_marg_pf = np.array(model.field, dtype=np.float64)
    np.add.at(log_marg_pf, dir_edges[:, 1], lmu)

    log_pf = scipy.special.logsumexp(log_marg_pf, axis=-1)
    marg_prob = scipy.special.softmax(log_marg_pf, axis=-1)
    marg_prob /= np.sum(marg_prob, axis=-1).reshape(-1, 1)
    # NOTE(review): for a forest each vertex only receives messages from its
    # own component, so per-vertex log_pf values may differ; confirm that
    # np.min is the intended way to aggregate them.
    return InferenceResult(np.min(log_pf), marg_prob)
Ejemplo n.º 2
0
def max_likelihood_tree_dp(model: PairWiseFiniteModel):
    """Max Likelihood for the pairwise model.

    Performs dynamic programming on tree.

    Applicable only if the interaction graph is a tree or a forest. Otherwise
    throws exception.

    :param model: Model for which to find most likely state.
    :return: Most likely state. np.array of ints.
    :raises AssertionError: If the interaction graph has cycles.
    """
    dfs_result = model.get_dfs_result()
    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    # (otherwise tree DP would silently run on a cyclic graph). AssertionError
    # is kept so existing callers that catch it keep working.
    if dfs_result.had_cycles:
        raise AssertionError("Graph has cycles.")

    field = model.field.astype(np.float64, copy=False)
    dfs_edges = dfs_result.dfs_edges
    ints = model.get_interactions_for_edges(dfs_edges)

    return _max_likelihood_internal(field, dfs_edges, ints)
Ejemplo n.º 3
0
def infer_tree_dp(model: PairWiseFiniteModel,
                  subtree_mp=False) -> InferenceResult:
    """Inference using DP on tree.

    Performs dynamic programming on tree.

    Applicable only if the interaction graph is a tree or a forest. Otherwise
    throws exception.

    :param model: Model for which to perform inference.
    :param subtree_mp: If true, will return marginal probabilities for
        subtrees, i.e. for each node will return probability of it having
        different values if we leave only it and its subtree.
    :return: InferenceResult object.
    :raises AssertionError: If the interaction graph has cycles.
    """
    dfs_result = model.get_dfs_result()
    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    # (otherwise tree DP would silently run on a cyclic graph). AssertionError
    # is kept so existing callers that catch it keep working.
    if dfs_result.had_cycles:
        raise AssertionError("Graph has cycles.")

    dfs_edges = dfs_result.dfs_edges
    dfs_j = model.get_interactions_for_edges(dfs_edges)

    lz = model.field.astype(dtype=np.float64, copy=True)  # log(z)
    lzc = np.zeros_like(lz)  # log(zc)

    _dfs1(lz, lzc, dfs_edges, dfs_j)
    log_pf = logsumexp(lz[0, :])

    if subtree_mp:
        # NOTE(review): ``lz`` holds log-values (it is fed to logsumexp
        # above), not normalized probabilities as the docstring describes;
        # confirm what callers of subtree_mp expect.
        return InferenceResult(log_pf, lz)

    # Log(z_r). z_r is partition function for all tree except subtree of given
    # vertex, when value of given vertex is fixed. Only needed for the full
    # marginals, so it is allocated after the subtree_mp early return.
    lzr = np.zeros((model.gr_size, model.al_size))
    _dfs2(lz, lzc, lzr, dfs_edges, dfs_j)

    marg_proba = np.exp(lz + lzr - log_pf)
    return InferenceResult(log_pf, marg_proba)