def update(self, new_value):
    if self.logger is not None:
        self.logger({'score': new_value.score})
        if new_value.score - 10 > self.start_score:
            self.logger({'extremely_strong': True})
        if USE_COMPLEXITY_WIDTH:
            width = compute_complexity_width(new_value.td, DOMAIN_SIZES)
            approx_width = compute_complexity_width(new_value.td, DOMAIN_SIZES, approx=True)
            self.logger({'width': width, 'approx_width': approx_width})
    self.value = new_value
def monitor_blip(filename, treewidth, logger: Callable, outfile="temp.res",
                 timeout=10, seed=0, solver="kg", datfile=None, cwidth=0,
                 onlyfilter=False, save_as="", debug=False):
    """
    Run BLIP in monitoring mode, where each new score update is logged

    :param filename:   path to jkl file
    :param treewidth:  treewidth bound (ignored if in CWIDTH_MODE)
    :param logger:     logging function to be used
    :param outfile:    path to .res file containing learned network (volatile)
    :param timeout:    total time limit on blip computation
    :param seed:       random seed passed on to blip
    :param solver:     blip sub-algo to use (for tw: [kg, ka, kmax], cwidth: [old, greedy, max])
    :param datfile:    path to data file (with header rows, for domain sizes)
    :param cwidth:     cwidth bound (if positive, activates CWIDTH_MODE)
    :param onlyfilter: only use pset filtering algo (ignored if not in CWIDTH_MODE)
    :param save_as:    filepath prefix to use for saving checkpoint solutions
    :param debug:      enable debug mode
    """
    CWIDTH_MODE = cwidth > 0
    if CWIDTH_MODE:
        assert solver in ("old", "greedy", "max"), \
            f"invalid solver({solver}) for monitor_blip in CWIDTH_MODE"
        basecmd = ["java", "-jar", os.path.join(SOLVER_DIR, "blip-cw.jar"),
                   "solver.kg.adv", "-v", "1", "-src", f"cwidth-{solver}"]
        args = ["-j", filename, "-d", datfile, "-w", "0", "-cw", str(cwidth),
                "-r", outfile, "-t", str(timeout), "-seed", str(seed)]
        if onlyfilter:
            args.append("-filter")
    else:
        basecmd = ["java", "-jar", os.path.join(SOLVER_DIR, "blip.jar"),
                   f"solver.{solver}", "-v", "1"]
        args = ["-j", filename, "-w", str(treewidth), "-r", outfile,
                "-t", str(timeout), "-seed", str(seed)]
    cmd = basecmd + args
    if debug:
        print("monitoring blip, cmd:", " ".join(cmd))
    if save_as:
        if CWIDTH_MODE:
            bnprovider = lambda: parse_res(filename, 0, outfile,
                                           cwidth=cwidth, datfile=datfile)
        else:
            bnprovider = lambda: parse_res(filename, treewidth, outfile)
        activate_checkpoints(bnprovider, save_as)
    domain_sizes = None if datfile is None else get_domain_sizes(datfile)
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1,
                          universal_newlines=True) as proc:
        for line in proc.stdout:
            if debug:
                print("got line:", line, end='')
            match = SCORE_PATN.match(line)
            if match:
                score = float(match['score'])
                logdata = {"score": score}
                if domain_sizes is not None:
                    try:
                        bn = bnprovider()
                    except IndexError:
                        # todo: not reached anymore, retries, rethink
                        print("bn res file invalid (probably got overwritten)")
                        cw = acw = -1
                    else:
                        tw = bn.td.compute_width()
                        cw = compute_complexity_width(bn.td, domain_sizes)
                        logdata["tw"] = tw
                        logdata["cw"] = cw
                logger(logdata)
    print(f"done returncode: {proc.returncode}")
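# Usage sketch (not part of the module): one way monitor_blip might be invoked
# to stream BLIP score updates into a simple logger. The file paths below are
# hypothetical placeholders; a .jkl score cache (and optionally a matching .dat
# file, for domain sizes) must already exist.
#
#   def print_logger(logdata: dict):
#       print("blip update:", logdata)
#
#   monitor_blip("data/example.jkl", treewidth=5, logger=print_logger,
#                outfile="temp.res", timeout=60, seed=1, solver="kg",
#                datfile="data/example.dat", debug=True)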
def slim(filename: str, start_treewidth: int, budget: int = BUDGET,
         start_with_bn: TWBayesianNetwork = None, sat_timeout: int = TIMEOUT,
         max_passes=MAX_PASSES, max_time: int = MAX_TIME, heuristic=HEURISTIC,
         offset: int = OFFSET, seed=SEED, debug=False):
    global USING_COMPLEXITY_WIDTH, COMPLEXITY_BOUND, CW_TARGET_REACHED, \
        CHECKPOINT_MILESTONES
    start = now()
    if SAVE_AS:
        activate_checkpoints(lambda: SOLUTION.value, SAVE_AS)

    def elapsed():
        return f"(after {now()-start:.1f} s.)"

    heur_proc = outfile = None  # placeholder
    if start_with_bn is not None:
        bn = start_with_bn
    elif START_WITH is not None:
        if not os.path.isfile(START_WITH):
            print("specified start-with file doesn't exist, quitting", file=sys.stderr)
            return
        if debug:
            print(f"starting with {START_WITH}, not running heuristic")
        # todo[safety]: handle case when no heuristic solution so far
        # todo[safety]: make add_extra_tuples a cli option
        add_extra_tuples = heuristic in ("hc", "hcp")
        bn = parse_res(filename, start_treewidth, START_WITH,
                       add_extra_tuples=add_extra_tuples, augfile="augmented.jkl")
    else:
        if MIMIC:
            if debug:
                print("starting heuristic proc for mimicking")
            outfile = "temp-mimic.res"
            heur_proc = start_blip_proc(filename, start_treewidth, outfile=outfile,
                                        timeout=max_time, seed=seed,
                                        solver=heuristic, debug=False)
            if debug:
                print(f"waiting {offset}s")
            sleep(offset)
            # todo[safety]: make more robust by wrapping in try except (race condition)
            bn = parse_res(filename, start_treewidth, outfile)
        else:
            if debug:
                print(f"running initial heuristic for {offset}s")
            bn = run_blip(filename, start_treewidth, timeout=offset,
                          seed=seed, solver=heuristic)
    if __debug__:
        bn.verify()
    # save checkpoint: milestone > start
    if CHECKPOINT_MILESTONES:
        write_res(bn, SAVE_AS.replace(".res", "-start.res"), write_elim_order=True)
    if USE_COMPLEXITY_WIDTH:
        start_cw = compute_complexity_width(bn.td, DOMAIN_SIZES)
        start_acw = compute_complexity_width(bn.td, DOMAIN_SIZES, approx=True)
        # complexity_bound = start_cw // 2  # todo[opt]: maybe use weight as bound?
        if FEASIBLE_CW:
            complexity_bound = FEASIBLE_CW_THRESHOLD
            if LOGGING:
                wandb.log({"infeasible": start_cw > complexity_bound})
        else:
            complexity_bound = min(start_cw - 1, int(start_cw * CW_REDUCTION_FACTOR))
        print(f"start cw: {start_cw}\tacw:{start_acw}")
        print(f"setting complexity bound: "
              f"{complexity_bound}|{weight_from_domain_size(complexity_bound)}")
        COMPLEXITY_BOUND = complexity_bound
    SOLUTION.update(bn)
    if DOMAIN_SIZES:
        log_bag_metrics(bn.td, DOMAIN_SIZES)
    if LAZY_THRESHOLD > 0:
        print(f"lazy threshold: {LAZY_THRESHOLD} i.e. "
              f"minimum delta required: {bn.best_norm_score*LAZY_THRESHOLD}")
    prev_score = bn.score
    print(f"Starting score: {prev_score:.5f}")
    # if debug and DATFILE: print(f"Starting LL: {eval_ll(bn, DATFILE):.6f}")
    SOLUTION.start_score = prev_score
    if USE_COMPLEXITY_WIDTH:
        SOLUTION.start_width = start_cw
    history = Counter(dict.fromkeys(bn.td.decomp.nodes, 0))
    if seed:
        random.seed(seed)
    cw_stop_looping = False
    while max_passes < 0 or SOLUTION.num_passes <= max_passes:
        # if USE_COMPLEXITY_WIDTH and cw_stop_looping:
        #     if debug: print("*** initial bn score matched/surpassed ***\n")
        #     # save checkpoint: milestone > finish
        #     if CHECKPOINT_MILESTONES:
        #         write_res(bn, SAVE_AS.replace(".res", "-finish.res"), write_elim_order=True)
        #         CHECKPOINT_MILESTONES = False
        #     break
        if USE_COMPLEXITY_WIDTH:
            USING_COMPLEXITY_WIDTH = (SOLUTION.num_passes >= 10
                                      or CW_TRAV_STRAT != "tw-max-rand")
        width_bound = complexity_bound if USING_COMPLEXITY_WIDTH else start_treewidth
        replaced = slimpass(bn, budget, sat_timeout, history, width_bound, debug=False)
        if replaced is None:  # no change by slimpass
            # if debug:
            #     print("failed slimpass (no subtree|lazy threshold|no maxsat soln)")
            continue  # don't count this as a pass
        SOLUTION.num_passes += 1
        new_score = bn.score
        if new_score > prev_score:
            print(f"*** New improvement! {new_score:.5f} {elapsed()} ***")
            prev_score = new_score
            SOLUTION.update(bn)
            SOLUTION.num_improvements += 1
            if USE_COMPLEXITY_WIDTH and new_score >= SOLUTION.start_score:
                cw_stop_looping = True
        elif replaced:
            print("*** No improvement, but replacement performed ***")
            prev_score = new_score
            SOLUTION.update(bn)
        if MIMIC:
            heur_score = check_blip_proc(heur_proc, debug=False)
            if heur_score > bn.score:
                if debug:
                    print(f"heuristic solution better {heur_score:.5f} > {bn.score:.5f}, mimicking")
                SOLUTION.restarts += 1
                newbn = parse_res(filename, start_treewidth, outfile)
                new_score = newbn.score
                assert new_score >= heur_score - 1e-5, \
                    f"score exaggerated, reported: {heur_score}\tactual score: {new_score}"
                bn = newbn
                prev_score = new_score
                SOLUTION.update(bn)
                # reset history because fresh tree decomposition
                history = Counter(dict.fromkeys(bn.td.decomp.nodes, 0))
        if USE_COMPLEXITY_WIDTH:
            current_cw = compute_complexity_width(bn.td, DOMAIN_SIZES)
            if current_cw <= width_bound and not CW_TARGET_REACHED:
                if USING_COMPLEXITY_WIDTH and CW_TRAV_STRAT in ["max-min", "max-rand", "tw-max-rand"]:
                    print("*** cw target reached, flipping strategy ***")
                    CW_TARGET_REACHED = True
                # if bn.score >= prev_score: cw_stop_looping = True
                # save checkpoint: milestone > lowpoint
                if CHECKPOINT_MILESTONES:
                    write_res(bn, SAVE_AS.replace(".res", "-lowpoint.res"), write_elim_order=True)
        if debug and USE_COMPLEXITY_WIDTH:
            print("current msss:", current_cw)
        if debug:
            print(f"* Iteration {SOLUTION.num_passes}:\t{bn.score:.5f} {elapsed()}\n")
        if now() - start > max_time:
            if debug:
                print("time limit exceeded, quitting")
            break
    else:
        if debug:
            print(f"{max_passes} passes completed, quitting")
    if MIMIC:
        if debug:
            print("stopping heur proc")
        stop_blip_proc(heur_proc)
    print(f"done {elapsed()}")
    if USE_COMPLEXITY_WIDTH and cw_stop_looping:
        return True
def slimpass(bn: TWBayesianNetwork, budget: int = BUDGET, timeout: int = TIMEOUT,
             history: Counter = None, width_bound: int = None, debug=False):
    td = bn.td
    if USING_COMPLEXITY_WIDTH:
        final_width_bound = weight_from_domain_size(width_bound)
    else:
        final_width_bound = width_bound
    selected, seen = find_subtree(td, budget, history, debug=False)
    history.update(selected)
    prep_tuple = prepare_subtree(bn, selected, seen, debug)
    if prep_tuple is None:
        return
    forced_arcs, forced_cliques, data, pset_acyc = prep_tuple
    # if debug:
    #     print("filtered data:-")
    #     pprint(data)
    old_score = bn.compute_score(seen)
    max_score = compute_max_score(data, bn)
    if RELAXED_PARENTS:
        # too strict
        # assert max_score + EPSILON >= old_score, "max score less than old score"
        assert round(max_score + EPSILON, 4) >= round(old_score, 4), \
            "max score less than old score"
        if max_score < old_score:
            print("#### max score smaller than old score modulo epsilon")
    cur_offset = sum(bn.offsets[node] for node in seen)
    if debug:
        print(f"potential max: {(max_score - cur_offset)/bn.best_norm_score:.5f}", end="")
    if (max_score - cur_offset) / bn.best_norm_score <= LAZY_THRESHOLD:
        if debug:
            print(" skipping because lazy threshold not met")
        SOLUTION.skipped += 1
        return
    pos = dict()  # placeholder layout
    if not CLUSTER and debug:
        pos = pygraphviz_layout(bn.dag, prog='dot')
        nx.draw(bn.dag, pos, with_labels=True)
        plt.suptitle("entire dag")
        plt.show()
        nx.draw(bn.dag.subgraph(seen), pos, with_labels=True)
        plt.suptitle("subdag before improvement")
        plt.show()
    if debug:
        print("old parents:-")
        pprint({node: par for node, par in bn.parents.items() if node in seen})
    domain_sizes = DOMAIN_SIZES if USING_COMPLEXITY_WIDTH else None
    try:
        replbn = solve_bn(data, final_width_bound, bn.input_file, forced_arcs,
                          forced_cliques, pset_acyc, timeout, domain_sizes, debug)
    except NoSolutionException as err:
        SOLUTION.nosolution += 1
        print(f"no solution found by maxsat, skipping (reason: {err})")
        return
    new_score = replbn.compute_score()
    if not CLUSTER and debug:
        nx.draw(replbn.dag, pos, with_labels=True)
        plt.suptitle("replacement subdag")
        plt.show()
    if debug:
        print("new parents:-")
        pprint(replbn.parents)
    if debug:
        print(f"score change: {old_score:.3f} -> {new_score:.3f}")
    if USE_COMPLEXITY_WIDTH:
        old_cw = compute_complexity_width(td, DOMAIN_SIZES, include=selected)
        new_cw = compute_complexity_width(replbn.td, DOMAIN_SIZES)
        old_acw = compute_complexity_width(td, DOMAIN_SIZES, include=selected, approx=True)
        new_acw = compute_complexity_width(replbn.td, DOMAIN_SIZES, approx=True)
        # print(f"old: {old_cw}|{old_acw:.3f}\tnew: {new_cw}|{new_acw:.3f}")
        print(f"msss of local part: {old_cw} -> {new_cw}")
    # replacement criterion
    if USING_COMPLEXITY_WIDTH and old_cw > width_bound:
        if new_cw > width_bound:
            return False
    elif USING_COMPLEXITY_WIDTH and new_score == old_score and new_cw > old_cw:
        return False
    elif new_score < old_score:  # in case not using cw, then this is the only check
        return False
    print(f"score change: {old_score:.3f} -> {new_score:.3f}, replacing ...")
    td.replace(selected, forced_cliques, replbn.td)
    # update bn with new bn
    bn.replace(replbn)
    if __debug__:
        bn.verify(verify_treewidth=not USING_COMPLEXITY_WIDTH)
    return True
ll, maescore, maetime = eval_all(filepath, args.treewidth, DATFILE, startres, SEED)
final_metrics = dict(final_ll=ll, final_maescore=maescore, final_maetime=maetime)
if LOGGING:
    wandb.log(final_metrics)
else:
    print(final_metrics)
success_rate = SOLUTION.num_improvements / (SOLUTION.num_passes - SOLUTION.skipped)
treewidths = dict(start_tw=args.treewidth, final_tw=SOLUTION.value.td.compute_width())
if LOGGING:
    wandb.log({"success_rate": success_rate})
    wandb.log(treewidths)
    if SOLUTION.num_improvements > 0:
        wandb.log({"improved": True})
else:
    print("final metrics:")
    pprint(SOLUTION.data)
    print(f"success_rate: {success_rate:.2%}")
    print(f"final score: {SOLUTION.value.score:.5f}")
    print(treewidths)
if DOMAIN_SIZES:
    # log_bag_metrics(SOLUTION.value.td, DOMAIN_SIZES, append=True)
    print("complexity-width:", compute_complexity_width(SOLUTION.value.td, DOMAIN_SIZES))
print("seed:", SEED) random.seed(SEED) if len(sys.argv) >= 3: use_dd = bool(int(sys.argv[2])) else: use_dd = input("use dd? y/[n]:") == "y" SvEncodingWithComplexity.use_dd = use_dd # g = nx.Graph() # g.add_edges_from("ac af bc bh cd eg eh fg gh".split()) g = nx.fast_gnp_random_graph(20, p=0.3, seed=SEED) ds = {node: random.randint(2, 16) for node in g.nodes} print("ds:", ds) weights = weights_from_domain_sizes(ds) # g.remove_edge(0, 3) # g.add_edge(0, 4) # weights[4] = 2 # weights = {node: 1 for node in g.nodes} # weights['h'] = 2 # weights['a'] = 3 print("weights:", weights) nx.draw(g, with_labels=True) if DRAWING: plt.show() td = solve_graph(g, weights, complexity_width=15, timeout=30, debug=False) td.verify() print(td.elim_order) cw = compute_complexity_width(td, ds) print("final cw:", cw) # td.draw()