Beispiel #1
0
def _add_sufficient_recall(cid: CID, dec1: str, dec2: str,
                           utility_node: str) -> None:
    """Add edges to `cid` until `dec2` has sufficient recall of `dec1`.

    A working copy of the graph gets an auxiliary node ``'pi'`` with the edge
    ``'pi' -> dec1``. While ``'pi'`` still has an active trail to
    `utility_node` when the parents of `dec2` and `dec2` itself are observed,
    an active path is found. An edge ``path[i] -> dec2`` is then added, to
    both `cid` and the copy, from a randomly chosen intermediate node on that
    path. The chosen node must not be a collider on the path, and `dec2` must
    not be among its ancestors, so that no cycle is created.

    Args:
        cid: graph modified in place.
        dec1: the decision that `dec2` must recall.
        dec2: the decision that receives new incoming edges.
        utility_node: the utility node whose trail from ``'pi'`` is blocked.

    Raises:
        ValueError: if `dec2` is an ancestor of `dec1`, or (from
            ``random.randrange``) if a found path has no intermediate node.
        RuntimeError: if no active path is returned despite an active trail,
            or if no intermediate node on the path can receive an edge into
            `dec2`.
    """
    if dec2 in cid._get_ancestors_of(dec1):
        raise ValueError('{} is an ancestor of {}'.format(dec2, dec1))

    # The auxiliary 'pi' node lives only in the copy, never in `cid`.
    # NOTE(review): assumes `cid` has no node already named 'pi' — confirm.
    cid2 = cid.copy()
    cid2.add_edge('pi', dec1)

    while cid2.is_active_trail('pi',
                               utility_node,
                               observed=cid.get_parents(dec2) + [dec2]):
        path = find_active_path(cid2, 'pi', utility_node,
                                cid.get_parents(dec2) + [dec2])
        if path is None:
            raise RuntimeError(
                "couldn't find path even though there should be an active trail"
            )
        # Candidate positions are the interior nodes path[1] .. path[-2].
        # Their validity cannot change until an edge is added, so once every
        # position has been tried unsuccessfully, retrying can never succeed.
        n_candidates = len(path) - 2
        tried = set()
        while True:
            i = random.randrange(1, len(path) - 1)
            if i in tried:
                continue
            tried.add(i)
            collider = ((path[i - 1], path[i]) in cid2.edges) and (
                (path[i + 1], path[i]) in cid2.edges)
            # Only a non-collider can be "opened" into dec2, and the new edge
            # must not close a cycle (dec2 must not already be an ancestor).
            if not collider and dec2 not in cid2._get_ancestors_of(path[i]):
                cid.add_edge(path[i], dec2)
                cid2.add_edge(path[i], dec2)
                break
            if len(tried) >= n_candidates:
                raise RuntimeError(
                    "no node on the active path from 'pi' to {} can get an "
                    "edge into {} without being a collider or creating a "
                    "cycle".format(utility_node, dec2))
Beispiel #2
0
def add_sufficient_recalls(cid: CID) -> None:
    """Add edges to `cid` until every decision pair satisfies sufficient recall.

    For every utility node, each unordered pair of decisions is visited in
    the order given by ``cid.all_decision_nodes``. Each pair is oriented by
    ancestry. If the first-listed decision is an ancestor of the second, the
    second must recall the first. Otherwise, the first must recall the second.
    Ancestry is re-evaluated for every pair, so edges added earlier in the
    loop affect how later pairs are oriented. The graph is modified in place.
    """
    # NOTE: the original author rejected cid._get_valid_order for ordering
    # decisions here ("cannot be trusted"), hence the per-pair ancestry test.
    for utility_node in cid.all_utility_nodes:
        for idx, first in enumerate(cid.all_decision_nodes):
            for second in cid.all_decision_nodes[idx + 1:]:
                if first in cid._get_ancestors_of(second):
                    recalled, recaller = first, second
                else:
                    recalled, recaller = second, first
                _add_sufficient_recall(cid, recalled, recaller, utility_node)
Beispiel #3
0
def random_cid(n_all: int,
               n_decisions: int,
               n_utilities: int,
               edge_density: float = 0.4,
               add_sr_edges: bool = True,
               add_cpds: bool = True,
               seed: int = None) -> CID:
    """Build a random CID.

    Node names come from ``get_node_names(n_all, n_decisions, n_utilities)``.
    Edges come from ``get_edges`` with the given `edge_density`, and edges
    out of utility nodes are disallowed. The density controls how many edges
    are generated; no exact edge count is specified here.

    Args:
        n_all, n_decisions, n_utilities: forwarded to ``get_node_names``.
        edge_density: forwarded to ``get_edges``.
        add_sr_edges: if true, call ``add_sufficient_recalls`` on the result.
        add_cpds: if true, attach a CPD to every node. Decisions get a
            ``DecisionDomain`` over ``[0, 1]``. Parentless nodes get a
            ``UniformRandomCPD`` over ``[0, 1]``. Every other node gets a
            ``RandomlySampledFunctionCPD`` of its parents.
        seed: forwarded only to ``get_edges``. The sufficient-recall step
            draws from the global ``random`` module and is not seeded here.

    Returns:
        The generated CID.
    """
    all_names, decision_names, utility_names = get_node_names(
        n_all, n_decisions, n_utilities)
    edges = get_edges(all_names, utility_names, edge_density,
                      seed=seed, allow_u_edges=False)
    cid = CID(edges, decision_names, utility_names)

    # Sanity checks on what get_edges produced: no edge starts at a utility
    # node, and no later-listed decision is an ancestor of an earlier one.
    assert all(uname != edge[0] for uname in utility_names for edge in edges)
    assert all(later not in cid._get_ancestors_of(earlier)
               for pos, earlier in enumerate(decision_names)
               for later in decision_names[pos + 1:])

    if add_sr_edges:
        add_sufficient_recalls(cid)

    if not add_cpds:
        return cid

    for node in cid.nodes:
        if node in cid.all_decision_nodes:
            cid.add_cpds(DecisionDomain(node, [0, 1]))
            continue
        parents = cid.get_parents(node)
        if parents:
            cid.add_cpds(RandomlySampledFunctionCPD(node, parents))
        else:
            # Root node: no parents to build a function from.
            cid.add_cpds(UniformRandomCPD(node, [0, 1]))
    return cid