Example No. 1
 def get_inf_jtree(self):
     """Construct a libDAI junction-tree inference object for this model.

     The factor graph is obtained from ``self.get_factor_graph()`` and the
     junction tree is configured for sum-product inference ("SUMPROD") with
     HUGIN updates and libDAI verbosity "1". This method only constructs the
     object; it does not call ``init()`` or ``run()`` on it.

     Returns:
         The ``dai.JTree`` instance built from the factor graph.
     """
     graph = self.get_factor_graph()
     settings = dai.PropertySet()
     for key, value in (("inference", "SUMPROD"),
                        ("updates", "HUGIN"),
                        ("verbose", "1")):
         settings[key] = value
     return dai.JTree(graph, settings)
Example No. 2
def junction_tree(sg_model, verbose=False):
    '''
    Calculate the exact partition function of a spin glass model using the
    junction tree algorithm (libDAI ``JTree``, sum-product, HUGIN updates).

    Side effect: the libDAI factor graph built from ``sg_model`` is written
    to ``sg_temp.fg`` in the current working directory.

    Inputs:
    - sg_model (SpinGlassModel): converted to a libDAI factor graph via
      ``build_libdaiFactorGraph_from_SpinGlassModel`` with no fixed variables.
    - verbose (bool): if True, print graph statistics and the result.

    Outputs:
    - ln_Z: natural logarithm of the exact partition function, as reported
      by ``JTree.logZ()``.
    '''
    sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(
        sg_model, fixed_variables={})

    # Write factorgraph to a file (useful for inspecting the model offline)
    sg_FactorGraph.WriteToFile('sg_temp.fg')
    if verbose:
        print('spin glass factor graph written to sg_temp.fg')
        print(sg_FactorGraph.nrVars(), 'variables')
        print(sg_FactorGraph.nrFactors(), 'factors')

    # Algorithm settings, stored as strings as libDAI's PropertySet expects.
    maxiter = 10000  # Maximum number of iterations
    tol = 1e-9  # Tolerance for convergence
    verb = 0  # Verbosity of libDAI's own output (0 = silent)
    opts = dai.PropertySet()
    opts["maxiter"] = str(maxiter)
    opts["tol"] = str(tol)
    opts["verbose"] = str(verb)
    opts["updates"] = "HUGIN"

    ##################### Run Junction Tree Algorithm #####################
    # Only the sum-product tree is needed for log Z. A MAX-PROD (MAP) pass
    # would cost as much again and its result would go unused.
    jt = dai.JTree(sg_FactorGraph, opts)
    jt.init()
    jt.run()
    ln_Z = jt.logZ()

    # Report log partition sum (normalizing constant) of sg_FactorGraph
    if verbose:
        print()
        print('-' * 80)
        print('Exact log partition sum:', ln_Z)
    return ln_Z
Example No. 3
def run_inference(sg_model):
    '''
    Compare approximate and exact log partition sums of a spin glass model.

    Builds a libDAI factor graph from ``sg_model`` (also writing it to
    ``sg_temp.fg`` in the current working directory), then runs and prints
    the log partition sum computed by:
    - loopy belief propagation (``dai.BP``),
    - tree re-weighted belief propagation (``dai.TRWBP``), with its run time,
    - the exact junction tree algorithm (``dai.JTree``).

    Inputs:
    - sg_model (SpinGlassModel)

    Outputs:
    - None; all results are printed.
    '''
    sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(
        sg_model, fixed_variables={})

    # Write factorgraph to a file
    sg_FactorGraph.WriteToFile('sg_temp.fg')
    print('spin glass factor graph written to sg_temp.fg')

    # Output some information about the factorgraph
    print(sg_FactorGraph.nrVars(), 'variables')
    print(sg_FactorGraph.nrFactors(), 'factors')

    # Settings shared by every algorithm below
    maxiter = 10000  # Maximum number of iterations
    tol = 1e-9  # Tolerance for convergence
    verb = 1  # Verbosity (amount of output generated by libDAI)

    def make_opts(**extra):
        # Return a fresh PropertySet holding the shared settings plus `extra`.
        # Plain assignment (`b = a`) would alias one PropertySet, so each
        # algorithm's settings would leak into the next one's options.
        opts = dai.PropertySet()
        opts["maxiter"] = str(maxiter)
        opts["tol"] = str(tol)
        opts["verbose"] = str(verb)
        for key, value in extra.items():
            opts[key] = value
        return opts

    ##################### Run Loopy Belief Propagation #####################
    print()
    print('-' * 80)
    # Sequential random-order updates, computed in the log domain
    bp = dai.BP(sg_FactorGraph, make_opts(updates="SEQRND", logdomain="1"))
    bp.init()
    bp.run()

    # Report log partition sum approximated by belief propagation
    print('Approximate (loopy belief propagation) log partition sum:',
          bp.logZ())

    ############### Run Tree Re-weighted Belief Propagation ###############
    print()
    print('-' * 80)
    # Construct a TRWBP object with sequential random-order updates in the
    # log domain and "nrtrees" = 10
    trwbp = dai.TRWBP(
        sg_FactorGraph,
        make_opts(updates="SEQRND", nrtrees="10", logdomain="1"))
    trwbp.init()
    # Time only the run() call; perf_counter is monotonic and high-resolution
    t0 = time.perf_counter()
    trwbp.run()
    t1 = time.perf_counter()

    # Report log partition sum approximated by tree re-weighted BP
    print(
        'Approximate (tree re-weighted belief propagation) log partition sum:',
        trwbp.logZ())
    print('time =', t1 - t0)

    ##################### Run Junction Tree Algorithm #####################
    print()
    print('-' * 80)
    # Exact sum-product inference with HUGIN updates. A MAP (MAX-PROD) pass
    # is not needed here because only log Z is reported.
    jt = dai.JTree(sg_FactorGraph, make_opts(updates="HUGIN"))
    jt.init()
    jt.run()

    # Report log partition sum (normalizing constant) computed exactly
    print()
    print('-' * 80)
    print('Exact log partition sum:', jt.logZ())
Example No. 4
    #    } catch( Exception &e ) {
    #        if( e.getCode() == Exception::OUT_OF_MEMORY ) {
    #            do_jt = false;
    #            cout << "Skipping junction tree (need more than " << maxstates << " states)." << endl;
    #        }
    #        else
    #            throw;
    #    }

    if do_jt:
        # Construct a JTree (junction tree) object from the FactorGraph fg
        # using the parameters specified by opts and an additional property
        # that specifies the type of updates the JTree algorithm should perform
        jtopts = opts
        jtopts["updates"] = "HUGIN"
        jt = dai.JTree( fg, jtopts )
        # Initialize junction tree algorithm
        jt.init()
        # Run junction tree algorithm
        jt.run()

        # Construct another JTree (junction tree) object that is used to calculate
        # the joint configuration of variables that has maximum probability (MAP state)
        jtmapopts = opts
        jtmapopts["updates"] = "HUGIN"
        jtmapopts["inference"] = "MAXPROD"
        jtmap = dai.JTree( fg, jtmapopts )
        # Initialize junction tree algorithm
        jtmap.init()
        # Run junction tree algorithm
        jtmap.run()