def main():
    """
    Print the one-node potentials of the quantum WetGrass net as computed
    by three different inference engines, side by side.

    A fresh net is built with ``QuWetGrass.build_bnet()`` for each engine
    (EnumerationEngine, MCMC_Engine, JoinTreeEngine, all with
    ``is_quantum=True``). After the engine is constructed, the evidence is
    set by restricting node "WetGrass" to active state 1. Each engine's
    ``get_unipot_list`` is then called on its own ordered node list. The
    printout is labelled with the node names of the last (join tree) net.

    Returns
    -------
    None

    """
    def with_evidence(engine_cls):
        # Build an independent copy of the net for this engine, then
        # set the evidence after the engine has been created (same order
        # as the other engines use).
        net = QuWetGrass.build_bnet()
        eng = engine_cls(net, is_quantum=True)
        net.get_node_named("WetGrass").active_states = [1]
        return eng

    brute = with_evidence(EnumerationEngine)
    nodes = brute.bnet_ord_nodes
    brute_pots = brute.get_unipot_list(nodes)

    monte = with_evidence(MCMC_Engine)
    num_cycles = 1000
    warmup = 200
    nodes = monte.bnet_ord_nodes
    monte_pots = monte.get_unipot_list(nodes, num_cycles, warmup)

    jtree = with_evidence(JoinTreeEngine)
    nodes = jtree.bnet_ord_nodes
    jtree_pots = jtree.get_unipot_list(nodes)

    for k, node in enumerate(nodes):
        print(node.name)
        print("brute engine:", brute_pots[k])
        print("monte engine:", monte_pots[k])
        print("jtree engine:", jtree_pots[k])
        print("\n")
def main():
    """
    Print the one-node potentials of the classical WetGrass net as computed
    by three different inference engines, side by side, then save the net.

    A fresh net is built with ``WetGrass.build_bnet()`` for each engine
    (EnumerationEngine, MCMC_Engine, JoinTreeEngine). After the engine is
    constructed, the evidence is set by restricting node "WetGrass" to
    active state 1. Each engine's ``get_unipot_list`` is then called on its
    own ordered node list. Finally, the last net built (the one handed to
    JoinTreeEngine) is written to ``../examples_cbnets/WetGrass.bif`` and
    ``../examples_cbnets/WetGrass.dot``. These paths are relative to the
    current working directory.

    Returns
    -------
    None

    """
    def with_evidence(engine_cls):
        # Build an independent copy of the net for this engine, then
        # set the evidence after the engine has been created.
        net = WetGrass.build_bnet()
        eng = engine_cls(net)
        net.get_node_named("WetGrass").active_states = [1]
        return net, eng

    _, brute = with_evidence(EnumerationEngine)
    nodes = brute.bnet_ord_nodes
    brute_pots = brute.get_unipot_list(nodes)

    _, monte = with_evidence(MCMC_Engine)
    num_cycles = 1000
    warmup = 200
    nodes = monte.bnet_ord_nodes
    monte_pots = monte.get_unipot_list(nodes, num_cycles, warmup)

    jtree_net, jtree = with_evidence(JoinTreeEngine)
    nodes = jtree.bnet_ord_nodes
    jtree_pots = jtree.get_unipot_list(nodes)

    for k, node in enumerate(nodes):
        print(node.name)
        print("brute engine:", brute_pots[k])
        print("monte engine:", monte_pots[k])
        print("jtree engine:", jtree_pots[k])
        print("\n")

    # Persist the join-tree copy of the net (evidence already applied to it).
    jtree_net.write_bif('../examples_cbnets/WetGrass.bif', False)
    jtree_net.write_dot('../examples_cbnets/WetGrass.dot')
# introduce some evidence bnet.get_node_named("WetGrass").active_states = [1] node_list = brute_eng.bnet_ord_nodes brute_pot_list = brute_eng.get_unipot_list(node_list) bnet = WetGrass.build_bnet() monte_eng = MCMC_Engine(bnet) # introduce some evidence bnet.get_node_named("WetGrass").active_states = [1] num_cycles = 1000 warmup = 200 node_list = monte_eng.bnet_ord_nodes monte_pot_list = monte_eng.get_unipot_list(node_list, num_cycles, warmup) bnet = WetGrass.build_bnet() jtree_eng = JoinTreeEngine(bnet) # introduce some evidence bnet.get_node_named("WetGrass").active_states = [1] node_list = jtree_eng.bnet_ord_nodes jtree_pot_list = jtree_eng.get_unipot_list(node_list) for k in range(len(node_list)): print(node_list[k].name) print("brute engine:", brute_pot_list[k]) print("monte engine:", monte_pot_list[k]) print("jtree engine:", jtree_pot_list[k]) print("\n") bnet.write_bif('../examples_cbnets/WetGrass.bif', False) bnet.write_dot('../examples_cbnets/WetGrass.dot')
# print(node.name) # print(node.potential, "\n") monte_eng = MCMC_Engine(bnet) num_cycles = 1000 warmup = 200 monte_pot_list = monte_eng.get_unipot_list( node_list, num_cycles, warmup) # print("bnet pots after monte") # # check that bnet pots are not modified by engine # for node in node_list: # print(node.name) # print(node.potential, "\n") jtree_eng = JoinTreeEngine(bnet) jtree_pot_list = jtree_eng.get_unipot_list(node_list) # print("bnet pots after jtree") # # check that bnet pots are not modified by engine # for node in node_list: # print(node.name) # print(node.potential, "\n") for k in range(len(node_list)): print(node_list[k].name) print("brute engine:", brute_pot_list[k]) print("monte engine:", monte_pot_list[k]) print("jtree engine:", jtree_pot_list[k]) print("\n")
# for node in node_list: # print(node.name) # print(node.potential, "\n") monte_eng = MCMC_Engine(bnet, is_quantum=True) num_cycles = 1000 warmup = 200 monte_pot_list = monte_eng.get_unipot_list(node_list, num_cycles, warmup) # print("bnet pots after monte") # # check that bnet pots are not modified by engine # for node in node_list: # print(node.name) # print(node.potential, "\n") jtree_eng = JoinTreeEngine(bnet, is_quantum=True) jtree_pot_list = jtree_eng.get_unipot_list(node_list) # print("bnet pots after jtree") # # check that bnet pots are not modified by engine # for node in node_list: # print(node.name) # print(node.potential, "\n") for k in range(len(node_list)): print(node_list[k].name) print("brute engine:", brute_pot_list[k]) print("monte engine:", monte_pot_list[k]) print("jtree engine:", jtree_pot_list[k]) print("\n")
def run_gui(bnet):
    """
    Generates and runs a widgets gui (graphical user interface) for doing
    inferences (more specifically, for displaying a probability bar plot
    for each node) conditioned on evidence entered by user into gui.

    The gui displays:

    * one multi-select widget per node (sorted by name) for choosing that
      node's active states. 'All States' means no restriction.
    * a multi-select widget for choosing which nodes get a plot
      ('All Nodes' plots every node, in the engine's node order).
    * a button that recomputes the one-node potentials with a
      JoinTreeEngine and draws one horizontal bar plot per chosen node.

    Parameters
    ----------
    bnet : BayesNet

    Returns
    -------
    None

    """
    engine = JoinTreeEngine(bnet)
    nd_names = sorted(nd.name for nd in bnet.nodes)

    display(widgets.Label(value="Active states for each node:"))
    active_wdg_list = []
    for vtx in nd_names:
        st_names = bnet.get_node_named(vtx).state_names
        # widget value -1 stands for 'All States'; value i >= 0 is state i
        st_names1 = ['All States'] + st_names
        sel_wdg = widgets.SelectMultiple(
            options=dict(zip(st_names1, range(-1, len(st_names)))),
            value=(-1, ),
            description=vtx + ":")
        active_wdg_list.append(sel_wdg)
        display(sel_wdg)

    display(widgets.Label(value="Desired Node Prob Plots:"))
    plotted_nds_wdg = widgets.SelectMultiple(
        options=['All Nodes'] + nd_names,
        value=['All Nodes'],
    )
    display(plotted_nds_wdg)

    run_wdg = widgets.Button(description='Refresh Node Prob Plots', )
    run_wdg.layout.width = '40%'
    run_wdg.button_style = 'danger'
    display(run_wdg)

    # initialize each time run cell: no evidence, plot every node
    for nd in bnet.nodes:
        nd.active_states = range(nd.size)
    plotted_nds = engine.bnet_ord_nodes

    def active_wdg_do(title, change):
        # Copy the states selected in the widget into node `title`.
        nd = bnet.get_node_named(title)
        if -1 in change['new']:
            nd.active_states = range(nd.size)
        else:
            nd.active_states = list(change['new'])

    for active_wdg in active_wdg_list:
        # strip the trailing ':' that was appended to the description
        title = active_wdg.description[:-1]
        # must store 'title' each time or all functions will use
        # value of 'title' at end of loop
        # Thanks to Jason Grout for pointing this out
        fun = (lambda x, title=title: active_wdg_do(title, x))
        active_wdg.observe(fun, names='value')

    def plotted_nds_wdg_do(change):
        # plotted_nds is a local of run_gui that run_wdg_do reads through
        # its closure, so it must be rebound with `nonlocal`. A `global`
        # declaration would write to an unrelated module-level name and
        # the selection would never reach the plots.
        nonlocal plotted_nds
        if 'All Nodes' in change['new']:
            plotted_nds = engine.bnet_ord_nodes
        else:
            plotted_nds = [
                bnet.get_node_named(name) for name in sorted(change['new'])
            ]

    plotted_nds_wdg.observe(plotted_nds_wdg_do, names='value')

    def single_pd(ax, node_name, pd_df):
        # Draw one horizontal bar per row of pd_df on ax, with each bar's
        # value written next to it and the x axis fixed to [0, 1].
        plt.sca(ax)
        ax.invert_yaxis()
        y_pos = np.arange(len(pd_df.index)) + .5
        plt.yticks(y_pos, pd_df.index)
        ax.set_xticks([0, .25, .5, .75, 1])
        ax.set_xlim(0, 1)
        for row in range(len(y_pos)):
            val = pd_df.values[row]
            if isinstance(val, np.ndarray):
                val = val[0]
            ax.text(val, y_pos[row], '{:.3f}'.format(val))

        ax.grid(True)
        ax.set_title(node_name)

        # new version of python/matplotlib has bug here.
        # The following used to work but no longer does.
        # ax.barh(y_pos, pd_df.values, align='center')
        # work around
        for b in range(len(y_pos)):
            ax.barh(y_pos[b], pd_df.values[b], align='center', color='blue')

    def run_wdg_do(b):
        # Recompute and redraw the probability plots for plotted_nds.
        # clear_output()
        plt.close('all')
        if not plotted_nds:
            # Nothing selected: avoid dividing by zero in h_scale and
            # asking plt.subplots for zero rows.
            print("No nodes selected for plotting.")
            return
        num_ax = len(plotted_nds)

        # h_scale is a height scale factor for
        # figure size to compensate for nodes having more than
        # 2 states. h_scale=1 if all plotted nodes have 2 states.
        h_scale = sum(len(nd.state_names) for nd in plotted_nds) / (2 * num_ax)
        fig, ax_list = plt.subplots(nrows=num_ax, ncols=1,
                                    figsize=(4, num_ax * h_scale))
        if num_ax == 1:
            # a single subplot comes back as one Axes, not an array
            ax_list = [ax_list]
        jtree_pot_list = engine.get_unipot_list(plotted_nds)
        for ax, nd, pot in zip(ax_list, plotted_nds, jtree_pot_list):
            vtx = nd.name
            print(vtx)
            print('Active States:', list(nd.active_states))
            print(pot)
            print("\n")
            df = pd.DataFrame(pot.pot_arr, index=nd.state_names)
            single_pd(ax, vtx, df)
        plt.tight_layout()
        plt.show()

    run_wdg.on_click(run_wdg_do)