예제 #1
0
def make_factor_graph(cpts, evidence):
    """Build a bipartite factor graph from a collection of CPTs.

    Each CPT becomes a factor node, keyed by the tuple of its axis names.
    Each factor node is joined by an edge to every variable name it mentions.
    Every variable also receives a unary potential:
    - all ones when the variable is unobserved;
    - ``one_hot(size, evidence[name])`` when it appears in ``evidence``.

    Parameters
    ----------
    cpts : sequence of DataArrays
        Conditional probability tables. Only ``.names`` and ``.shape`` are read.
    evidence : dict
        Observed variables. The value is passed to ``one_hot`` (presumably the
        observed index -- confirm against ``one_hot``).

    Returns
    -------
    G : networkx.Graph
        Edges connect each factor key (a tuple of names) to its variable names.
    names2factors : dict
        Tuple keys map to the original CPTs. Single variable-name keys map to
        the unary potentials. When several CPTs share a variable, the
        potential built from the last one is kept.
    """
    G = nx.Graph()

    # Factor nodes: one per distinct tuple of axis names.
    names2factors = {tuple(cpt.names): cpt for cpt in cpts}
    G.add_nodes_from(names2factors.keys())
    for factor_key, factor in names2factors.items():
        G.add_edges_from((factor_key, var_name) for var_name in factor.names)

    # Variable potentials: flat unless the variable has been observed.
    unary = {}
    for cpt in cpts:
        for var_name, size in zip(cpt.names, cpt.shape):
            if var_name in evidence:
                values = one_hot(size, evidence[var_name])
            else:
                values = np.ones(size)
            unary[var_name] = DataArray(values, [var_name])
    names2factors.update(unary)

    return G, names2factors
예제 #2
0
def multiply_potentials(*DAs):
    """
    Multiply DataArrays pointwise, the way functions of several variables
    are multiplied, e.g. h(i,j,k,l) = f(i,j,k) g(k,l).

    Axes with the same name are treated as the same variable. The result's
    axes are the union of all input axis names, in order of first appearance.
    The size of a shared axis is taken from the first array that has it.

    Parameters
    ----------
    DA1, DA2, ... : DataArrays with variable names as axis labels

    Returns
    -------
    product : DataArray over the combined axes (or 1 when called with no
        arguments)

    Examples
    --------
    >>> f_of_a = DataArray([1, 2],"a")
    >>> g_of_b = DataArray([1,-1],"b")
    >>> multiply_potentials(f_of_a, g_of_b)
    DataArray([[ 1, -1],
           [ 2, -2]])
    ('a', 'b')
    >>> multiply_potentials(f_of_a, f_of_a)
    DataArray([1, 4])
    ('a',)

    """
    if not DAs:
        return 1

    # Union of axis names, first occurrence wins, with matching sizes.
    full_names, full_shape = [], []
    for DA in DAs:
        for axis, size in zip(DA.axes, DA.shape):
            if axis.name in full_names:
                continue
            full_names.append(axis.name)
            full_shape.append(size)

    def _aligned(DA):
        # Position of each of DA's axes within the combined axis order.
        slots = [full_names.index(axis.name) for axis in DA.axes]
        return match_shape(DA.copy(), full_shape, slots)

    product = _prod(_aligned(DA) for DA in DAs)
    return DataArray(product, axes=full_names)
예제 #3
0
def test_pearl_network():
    """ From Russell and Norvig, "Artificial Intelligence, A Modern Approach,"
    Section 15.1 originally from Pearl.

    "Consider the following situation. You have a new burglar alarm installed
    at home. It is fairly reliable at detecting a burglary, but also responds
    on occasion to minor earthquakes. You also have two neighbors, John and
    Mary, who have promised to call you at work when they hear the alarm. John
    always calls when he hears the alarm, but sometimes confuses the telephone
    ringing with the alarm and calls then, too. Mary on the other hand, likes
    rather loud music and sometimes misses the alarm altogether. Given the
    evidence of who has or has not called, we would like to estimate the
    probability of a burglary.

                    Burglary         Earthquake

                           \         /
                           _\|     |/_

                              Alarm

                            /     \  
                          |/_     _\|

                    Johncalls        Marycalls

    Every variable is binary and encoded as index 0 = false, index 1 = true.
    The test calculates

        P(burglary | johncalls = 1, marycalls = 1)

    with the following algorithms, in increasing order of sophistication:
        1. Simple (calculate joint distribution and marginalize)
        2. Elimination (strategically marginalize over one variable at a time)
        3. Sum-product algorithm on factor graph
        4. Junction tree algorithm (currently disabled, see TODO below)

    It checks that the enabled algorithms agree on both the marginal and the
    likelihood of the evidence.
    """
    # Priors. Index 1 = true, so P(burglary) = .001 and P(earthquake) = .002.
    burglary = DataArray([.999, .001], axes=["burglary"])
    earthquake = DataArray([.998, .002], axes=["earthquake"])

    # P(alarm | burglary, earthquake). Each innermost row is over the alarm
    # axis and sums to 1. With true = 1, the alarm probabilities are:
    #   B=f,E=f: .001   B=f,E=t: .29   B=t,E=f: .94   B=t,E=t: .95
    alarm = DataArray([[[.999, .001], [.71, .29]],
                       [[.06, .94], [.05, .95]]],
                      ["burglary", "earthquake", "alarm"])

    # P(johncalls | alarm): .05 if no alarm, .90 if alarm.
    johncalls = DataArray([[.95, .05], [.10, .90]], ["alarm", "johncalls"])
    # P(marycalls | alarm): .01 if no alarm, .70 if alarm.
    marycalls = DataArray([[.99, .01], [.30, .70]], ["alarm", "marycalls"])

    cpts = [burglary, earthquake, alarm, johncalls, marycalls]

    # Both neighbors called (index 1 = true).
    evidence = {"johncalls": 1, "marycalls": 1}

    margs1, lik1 = calc_marginals_simple(cpts, evidence)
    p_burglary, lik2 = digraph_eliminate(cpts, evidence, ["burglary"])
    margs3, lik3 = calc_marginals_sumproduct(cpts, evidence, 'burglary')

    # TODO: This version is disabled until I can dig up the reference to figure
    # out how it works. -jt
    # margs4,lik4 = calc_marginals_jtree(cpts,evidence)
    # When re-enabled, add (margs4["burglary"], lik4) to the list below.

    # Check that every enabled algorithm matches the brute-force result
    # (margs1, lik1) for p(burglary) and the likelihood, up to numerical error.
    reference_marg = margs1["burglary"]
    for marg, lik in [(p_burglary, lik2), (margs3["burglary"], lik3)]:
        assert_almost_equal(marg, reference_marg)
        assert_almost_equal(lik, lik1)

    # For reference, Russell & Norvig quote P(burglary = true | j, m) ~ 0.284
    # for this network and evidence.
    print("p(burglary) = %s" % margs1["burglary"].__array__())
    print("likelihood of observations = %.3f" % lik1)