def pull_msg_from(self, node_from, node_to):
     """Pull the message on edge (node_from, node_to) toward cluster node_to.

     Returns:
         A factorSet holding the edge's ``'msg'`` attribute, or None when that
         set turns out to be empty (falsy).

     Raises:
         MessageGraphError: if ``self.message_graph`` is falsy.
     """
     graph = self.message_graph
     if not graph:
         raise MessageGraphError
     result = factorSet()
     result.add(graph.edge[node_from][node_to]['msg'])
     return result or None
    def add_factors(self, factor_list):
        """Register factors/valuations and update the variable bookkeeping.

        A deep copy of ``factor_list`` is what gets stored, so the caller's
        objects are not shared with this model. For every copied factor:
          * it is added to ``self.factors``;
          * ``self.variables`` / ``self.factors_by_var`` are grown so each
            variable label in its scope has a slot (new slots start as a
            placeholder ``gm.Var(label, 0)`` / an empty ``gm.factorSet()``);
          * a placeholder variable (0 states) is replaced by the real one, and
            the number of states is checked against the stored variable;
          * the factor is recorded under each of its variables, and its scope
            is appended to ``self.scope_list``.
        """
        import copy
        copied_factors = copy.deepcopy(factor_list)
        # scope_list follows factor_list order // factors follows the same order
        self.factors.update(copied_factors)

        for factor in copied_factors:
            # Walk the scope backwards; presumably VarSet is sorted by label, so the
            # first variable seen has the largest label (TODO confirm ordering).
            for var in reversed(factor.vars):
                if var.label >= len(self.variables):
                    upper = var.label + 1
                    placeholders = [gm.Var(label, 0)
                                    for label in range(len(self.variables), upper)]
                    self.variables.extend(placeholders)
                    empty_sets = [gm.factorSet()
                                  for _ in range(len(self.factors_by_var), upper)]
                    self.factors_by_var.extend(empty_sets)
                known = self.variables[var]
                if known.states == 0:
                    self.variables[var] = var
                # NOTE: assert is stripped under `python -O`.
                assert self.variables[var].states == var.states, \
                    '# states for a variable should match'
                # The same factor object is shared with self.factors.
                self.factors_by_var[var].add(factor)
            self.scope_list.append(factor.vars)
    def combine_factors(self, node, factor_set, is_log=True):
        """Combine the factors stored at ``node`` with those in ``factor_set``.

        Args:
            node: key of a node in ``self.message_graph``; its ``'fs'``
                attribute holds the node's own factors.
            factor_set: extra factors (a factorSet) to fold in. It may be None
                or empty, in which case only the node's factors are combined.
            is_log: if True the factors are combined with ``+`` (log domain),
                otherwise with ``*``.

        Returns:
            The combined factor, or None if there is nothing to combine.
        """
        import operator
        all_factors = factorSet(self.message_graph.node[node]['fs'])
        if factor_set:
            all_factors = all_factors | factor_set  # both are factorSet()
        if not all_factors:
            return None

        combine = operator.add if is_log else operator.mul
        # Copy the first factor so that a single-factor result never aliases
        # the factor stored on the node.
        combined_factor = all_factors[0].copy()
        for f in all_factors[1:]:
            combined_factor = combine(combined_factor, f)
        return combined_factor
 def pull_msg(self, node, next_node):
     """Collect messages on edges incident to ``node``, excluding the one shared with ``next_node``.

     For a directed message graph the incoming edges of ``node`` are used;
     otherwise every edge touching ``node``. Edges whose ``'msg'`` is falsy
     are skipped. This method can be used in CTE/IJGP algorithms, which use a
     directed message passing graph.

     Returns:
         A factorSet of the collected messages, or None if none were found.

     Raises:
         MessageGraphError: if ``self.message_graph`` is falsy.
     """
     if not self.message_graph:
         raise MessageGraphError
     if self.is_directed:
         incident = self.message_graph.in_edges([node], data=True)
     else:
         incident = self.message_graph.edges([node], data=True)

     pulled_msg = factorSet()
     for u, v, attr in incident:
         # The neighbour is whichever endpoint is not `node`. For in_edges it
         # is the source `u`, but undirected edges([node]) yields (node, nbr, attr),
         # so there it is `v`. Checking only `u` never excluded next_node in
         # the undirected case.
         neighbor = v if u == node else u
         if neighbor == next_node:
             continue
         msg = attr['msg']  # attr is the edge's data dict
         if msg:
             pulled_msg.add(msg)
     return pulled_msg if pulled_msg else None
 def __init__(self, factor_list, weight_list, is_log=False, elim_order=None):
     """Build the model from a list of factors and a list of weights.

     Args:
         factor_list: factors/valuations, registered through ``add_factors``.
         weight_list: forwarded to ``set_weights``.
         is_log: stored as ``self.is_log``. When True, ``to_log()`` is called
             once the factors and weights are in place.
         elim_order: stored unchanged as ``self.elim_order``.
     """
     # Empty containers that add_factors / set_weights fill in.
     self.variables, self.factors = [], gm.factorSet()
     # factors_by_var is useful when building graph data structures.
     self.factors_by_var, self.scope_list, self.weights = [], [], []

     self.add_factors(factor_list)
     self.set_weights(weight_list)

     self.elim_order, self.is_log = elim_order, is_log
     if self.is_log:
         self.to_log()
def add_mg_attr_to_edges(MG, variables, is_log, is_valuation=False):
    """Attach an initial message and its size metadata to every region-graph edge.

    For each edge (n1, n2) of ``MG.region_graph``, one constant factor (or
    valuation, when ``is_valuation`` is True) is built per variable in the
    edge's ``'sc'`` attribute. These are combined with ``+`` when ``is_log``
    is True and with ``*`` otherwise. The following attributes are set on the edge:

        msg        the combined factor
        msg_shape  ``msg.vars.dims()``, the shape of that single factor
        msg_len    product of ``msg_shape`` (number of entries), for use with scipy

    NOTE(review): an edge whose ``'sc'`` is empty is not supported; indexing
    the empty factor set fails.
    """
    # `reduce` is a builtin only on Python 2, so import it explicitly.
    from functools import reduce
    import operator

    combine = operator.add if is_log else operator.mul
    make_const = (get_const_valuation_with_v if is_valuation
                  else get_const_factor_with_v)

    for n1, n2 in MG.region_graph.edges_iter():
        edge_attr = MG.region_graph.edge[n1][n2]  # edge data dict, updated in place
        f_temp = factorSet()
        for v in edge_attr['sc']:
            f_temp.add(make_const(v, variables, is_log))
        combined_factor = reduce(combine, f_temp[1:], f_temp[0])

        dims = combined_factor.vars.dims()
        edge_attr['msg'] = combined_factor
        edge_attr['msg_shape'] = dims
        # The initializer 1 makes the product well-defined for an empty shape.
        edge_attr['msg_len'] = reduce(operator.mul, dims, 1)
 def factors_with_any(self, vars):
     """Return a factorSet of all factors registered under any variable in ``vars``.

     Each variable is looked up in ``self.factors_by_var``, and the
     per-variable factor sets are unioned into a fresh ``gm.factorSet()``.
     """
     collected = gm.factorSet()
     for bucket in (self.factors_by_var[var] for var in vars):
         collected.update(bucket)
     return collected
 def set_factor_rg(self, node, f):
     """Replace the factor set at region-graph ``node`` with one containing only ``f``."""
     singleton = factorSet({f})
     self.region_graph.node[node]['fs'] = singleton
 def set_factor(self, node, f):
     """Replace the factor set at message-graph ``node`` with one holding only the factor/valuation ``f``."""
     singleton = factorSet({f})
     self.message_graph.node[node]['fs'] = singleton