def get_inf_jtree(self):
    '''
    Create a libDAI junction tree inference object for this model.

    The object is only constructed here; init()/run() are not called, so
    callers presumably do that before querying it.

    Inputs:
    - self: object exposing get_factor_graph(); presumably returns a
      dai.FactorGraph (TODO confirm)

    Outputs:
    - dai.JTree configured for sum-product inference with HUGIN updates and
      libDAI verbosity level "1"
    '''
    factor_graph = self.get_factor_graph()
    jt_settings = dai.PropertySet()
    # Same keys, values and insertion order as before.
    for name, value in (("inference", "SUMPROD"),
                        ("updates", "HUGIN"),
                        ("verbose", "1")):
        jt_settings[name] = value
    return dai.JTree(factor_graph, jt_settings)
def junction_tree(sg_model, verbose=False):
    '''
    Calculate the exact partition function of a spin glass model using the
    junction tree algorithm.

    Inputs:
    - sg_model (SpinGlassModel): model converted to a libDAI factor graph via
      build_libdaiFactorGraph_from_SpinGlassModel
    - verbose (bool): if True, print factor-graph statistics and the result

    Outputs:
    - ln_Z: natural logarithm of the exact partition function, as reported by
      dai.JTree.logZ()

    Side effects:
    - always writes the factor graph to 'sg_temp.fg' in the current working
      directory
    '''
    sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(
        sg_model, fixed_variables={})

    # Write factorgraph to a file
    sg_FactorGraph.WriteToFile('sg_temp.fg')
    if verbose:
        print('spin glass factor graph written to sg_temp.fg')
        # Output some information about the factorgraph
        print(sg_FactorGraph.nrVars(), 'variables')
        print(sg_FactorGraph.nrFactors(), 'factors')

    maxiter = 10000
    tol = 1e-9
    verb = 0

    # A dedicated PropertySet for this JTree, so that no later mutation can
    # alias the configuration used to build it.
    jtopts = dai.PropertySet()
    jtopts["maxiter"] = str(maxiter)  # Maximum number of iterations
    jtopts["tol"] = str(tol)          # Tolerance for convergence
    jtopts["verbose"] = str(verb)     # Verbosity (amount of output generated)
    jtopts["updates"] = "HUGIN"       # JTree update scheme

    # Only the sum-product pass is needed for logZ.  The former MAXPROD (MAP)
    # pass produced a state that was never used, so it is not run.
    jt = dai.JTree(sg_FactorGraph, jtopts)
    jt.init()
    jt.run()

    ln_Z = jt.logZ()
    # Report log partition sum (normalizing constant) of sg_FactorGraph
    if verbose:
        print()
        print('-' * 80)
        print('Exact log partition sum:', ln_Z)
    return ln_Z
def run_inference(sg_model):
    '''
    Print approximate and exact log partition sums for a spin glass model.

    Runs loopy belief propagation (dai.BP), tree re-weighted belief
    propagation (dai.TRWBP) and the junction tree algorithm (dai.JTree) on
    the factor graph built from sg_model. Prints each log partition sum and
    the elapsed time of the TRWBP run.

    Inputs:
    - sg_model (SpinGlassModel)

    Outputs:
    - None; results are printed to stdout

    Side effects:
    - writes the factor graph to 'sg_temp.fg' in the current working directory
    '''
    sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(
        sg_model, fixed_variables={})

    # Write factorgraph to a file
    sg_FactorGraph.WriteToFile('sg_temp.fg')
    print('spin glass factor graph written to sg_temp.fg')
    # Output some information about the factorgraph
    print(sg_FactorGraph.nrVars(), 'variables')
    print(sg_FactorGraph.nrFactors(), 'factors')

    # Settings shared by every algorithm
    maxiter = 10000
    tol = 1e-9
    verb = 1

    def make_opts(**extra):
        '''Return a fresh PropertySet with the shared settings plus *extra*.'''
        # A new object per algorithm: previously one PropertySet was aliased
        # and mutated, leaking BP/TRWBP keys into the JTree configuration.
        opts = dai.PropertySet()
        opts["maxiter"] = str(maxiter)  # Maximum number of iterations
        opts["tol"] = str(tol)          # Tolerance for convergence
        opts["verbose"] = str(verb)     # Verbosity (amount of output generated)
        for key, value in extra.items():
            opts[key] = value
        return opts

    ##################### Run Loopy Belief Propagation #####################
    print()
    print('-' * 80)
    # Sequential random-order updates, computed in the log domain
    bp = dai.BP(sg_FactorGraph, make_opts(updates="SEQRND", logdomain="1"))
    bp.init()
    bp.run()
    print('Approximate (loopy belief propagation) log partition sum:',
          bp.logZ())

    ############### Run Tree Re-weighted Belief Propagation ###############
    print()
    print('-' * 80)
    trwbp = dai.TRWBP(sg_FactorGraph,
                      make_opts(updates="SEQRND", nrtrees="10",
                                logdomain="1"))
    trwbp.init()
    # Time only the message-passing run itself
    t0 = time.perf_counter()
    trwbp.run()
    t1 = time.perf_counter()
    print(
        'Approximate (tree re-weighted belief propagation) log partition sum:',
        trwbp.logZ())
    print('time =', t1 - t0)

    ##################### Run Junction Tree Algorithm #####################
    # Exact sum-product inference; no MAP (MAXPROD) pass is run because its
    # result was never used.
    jt = dai.JTree(sg_FactorGraph, make_opts(updates="HUGIN"))
    jt.init()
    jt.run()

    # Report log partition sum (normalizing constant) of sg_FactorGraph
    print()
    print('-' * 80)
    print('Exact log partition sum:', jt.logZ())
# } catch( Exception &e ) { # if( e.getCode() == Exception::OUT_OF_MEMORY ) { # do_jt = false; # cout << "Skipping junction tree (need more than " << maxstates << " states)." << endl; # } # else # throw; # } if do_jt: # Construct a JTree (junction tree) object from the FactorGraph fg # using the parameters specified by opts and an additional property # that specifies the type of updates the JTree algorithm should perform jtopts = opts jtopts["updates"] = "HUGIN" jt = dai.JTree( fg, jtopts ) # Initialize junction tree algorithm jt.init() # Run junction tree algorithm jt.run() # Construct another JTree (junction tree) object that is used to calculate # the joint configuration of variables that has maximum probability (MAP state) jtmapopts = opts jtmapopts["updates"] = "HUGIN" jtmapopts["inference"] = "MAXPROD" jtmap = dai.JTree( fg, jtmapopts ) # Initialize junction tree algorithm jtmap.init() # Run junction tree algorithm jtmap.run()