def measure(valid, invalid, suffix):
    """Score the discriminator on equal-sized valid/invalid samples and
    register the resulting metrics under `suffix`.

    Both arrays are truncated to the shorter length so that positives and
    negatives are balanced. Scores from `combined_sd` are rounded and
    clipped to {0,1}; recall, specificity, F-measure, precision and
    accuracy are then recorded through the module-level `reg`.

    Relies on module globals: combined_sd, sae, cae, discriminator, reg.
    """
    minlen = min(len(valid), len(invalid))
    valid_tmp = valid[:minlen]
    invalid_tmp = invalid[:minlen]
    # round() maps raw scores to 0/1; clip guards against scores outside [0,1]
    tp = np.clip(combined_sd(valid_tmp, sae, cae, discriminator,
                             batch_size=1000).round(), 0, 1)   # true positive
    fp = np.clip(combined_sd(invalid_tmp, sae, cae, discriminator,
                             batch_size=1000).round(), 0, 1)   # false positive
    tn = 1 - fp
    reg([suffix, "minlen"], minlen)
    # recall / sensitivity / power / true positive rate out of condition positive
    recall = np.mean(tp)
    # specificity / true negative rate out of condition negative
    specificity = np.mean(tn)
    reg([suffix, "recall"], recall)
    reg([suffix, "specificity"], specificity)
    reg([suffix, "f"], (2 * recall * specificity) / (recall + specificity))
    # BUGFIX: dividing NumPy scalars by zero yields nan/inf with a
    # RuntimeWarning — it never raises ZeroDivisionError, so the original
    # try/except handlers were dead code. Guard the denominators explicitly
    # so a degenerate input still records nan as intended.
    denom = np.sum(tp) + np.sum(fp)
    reg([suffix, "precision"], np.sum(tp) / denom if denom > 0 else float('nan'))
    reg([suffix, "accuracy"],
        (np.sum(tp) + np.sum(tn)) / (2 * minlen) if minlen > 0 else float('nan'))
    return
def decide_pruning_method():
    """Select the global list of candidate-pruning methods.

    Ad-hoc improvement: if the state discriminator's type-1 error is very
    high (which is not cheating because it can be verified from the
    training dataset), don't include SD pruning. The threshold
    misclassification rate is arbitrarily set as 0.25.

    Mutates the module global `pruning_methods`.
    """
    global pruning_methods
    print("verifying SD type-1 error")
    # states.csv holds known-valid states, so any row the discriminator
    # scores as invalid is a type-1 (false negative) misclassification.
    states_valid = np.loadtxt(sae.local("states.csv"), dtype=np.int8)
    type1_d = combined_sd(states_valid, sae, cae, sd3)
    type1_error = np.sum(1 - type1_d) / len(states_valid)
    # The original duplicated the whole list in both branches, differing
    # only in the trailing SD filter; build it once and append conditionally.
    pruning_methods = [
        action_reconstruction_filtering,  # if applied, this should be the first method
        # state_reconstruction_from_aae_filtering,
        # inflate_actions,
        action_discriminator_filtering,
        state_reconstruction_filtering,
    ]
    if type1_error <= 0.25:
        # discriminator is trustworthy enough to prune with
        pruning_methods.append(state_discriminator3_filtering)
def test():
    """Evaluate the action discriminator on held-out valid/invalid actions.

    Loads valid/invalid action samples from the AAE's directory, measures
    classification metrics three ways (raw; after state-discriminator
    pruning of invalids; after domain-level state validation of invalids),
    and dumps the collected metrics to performance.json next to the
    discriminator.

    Relies on module globals: aae, sae, cae, sd3, discriminator,
    combined_sd, latplan, np, random.
    """
    valid = np.loadtxt(aae.local("valid_actions.csv"), dtype=np.int8)
    random.shuffle(valid)
    invalid = np.loadtxt(aae.local("invalid_actions.csv"), dtype=np.int8)
    random.shuffle(invalid)
    # NOTE(review): each row appears to be a concatenated (pre, suc) state
    # pair, so N is the boundary of the successor half — confirm upstream.
    N = int(valid.shape[1] // 2)
    performance = {}

    def reg(names, value, d=performance):
        # Recursively store `value` in the nested dict `performance`
        # under the key path `names`, printing the leaf as a side effect.
        name = names[0]
        if len(names) > 1:
            try:
                tmp = d[name]
            except KeyError:
                tmp = {}
                d[name] = tmp
            reg(names[1:], value, tmp)
        else:
            d[name] = float(value)
            print(name, ": ", value)

    reg(["valid"], len(valid))
    reg(["invalid"], len(invalid))

    def measure(valid, invalid, suffix):
        # Balance the two classes, score both with the discriminator,
        # and register the standard confusion-matrix metrics under `suffix`.
        minlen = min(len(valid), len(invalid))
        valid_tmp = valid[:minlen]
        invalid_tmp = invalid[:minlen]
        tp = np.clip(
            discriminator.discriminate(valid_tmp, batch_size=1000).round(),
            0, 1)  # true positive
        fp = np.clip(
            discriminator.discriminate(invalid_tmp, batch_size=1000).round(),
            0, 1)  # false positive
        tn = 1 - fp
        fn = 1 - tp
        reg([suffix, "minlen"], minlen)
        recall = np.mean(
            tp
        )  # recall / sensitivity / power / true positive rate out of condition positive
        specificity = np.mean(
            tn)  # specificity / true negative rate out of condition negative
        reg([suffix, "recall"], recall)
        reg([suffix, "specificity"], specificity)
        reg([suffix, "f"], (2 * recall * specificity) / (recall + specificity))
        # NOTE(review): NumPy scalar division by zero yields nan/inf with a
        # warning rather than raising, so these handlers likely never fire.
        try:
            reg([suffix, "precision"], np.sum(tp) / (np.sum(tp) + np.sum(fp)))
        except ZeroDivisionError:
            reg([suffix, "precision"], float('nan'))
        try:
            reg([suffix, "accuracy"], (np.sum(tp) + np.sum(tn)) / (2 * minlen))
        except ZeroDivisionError:
            reg([suffix, "accuracy"], float('nan'))
        return

    # 1) raw discriminator performance
    measure(valid, invalid, "raw")
    # 2) performance after discarding invalids whose successor state the
    #    state discriminator already rejects (score <= 0.5)
    measure(
        valid,
        invalid[np.where(
            np.squeeze(
                combined_sd(invalid[:, N:], sae, cae, sd3, batch_size=1000)) > 0.5)[0]],
        "sd")
    # 3) performance after keeping only invalids whose decoded successor is a
    #    valid state according to the domain's own validator
    p = latplan.util.puzzle_module(sae.local(""))
    measure(
        valid,
        invalid[p.validate_states(sae.decode(invalid[:, N:], batch_size=1000),
                                  verbose=False,
                                  batch_size=1000)],
        "validated")
    import json
    with open(discriminator.local('performance.json'), 'w') as f:
        json.dump(performance, f)
def main(network_dir, problem_dir, searcher, first_solution=True, heuristics="goalcount", _aae="_aae"):
    """Run the AMA2 planner: load networks, enumerate available actions,
    search from init to goal, plot and validate each found plan, and
    report searcher statistics to a JSON file.

    Parameters
    ----------
    network_dir : directory of the trained state autoencoder (SAE).
    problem_dir : directory containing the planning problem instance.
    searcher : name of a searcher class, resolved with eval().
    first_solution : stop after the first plan found when True.
    heuristics : name of a heuristic function, resolved with eval().
    _aae : subdirectory suffix of the action autoencoder.

    Mutates the module globals sae, aae, ad, sd3, cae, available_actions.
    """
    global sae, aae, ad, sd3, cae, available_actions

    def search(path):
        # prefix a file name with the searcher name: foo.png -> <searcher>_foo.png
        root, ext = os.path.splitext(path)
        return "{}_{}{}".format(searcher, root, ext)

    def heur(path):
        # prefix a file name with the heuristics name
        root, ext = os.path.splitext(path)
        return "{}_{}{}".format(heuristics, root, ext)

    # --- load all networks and the puzzle domain module ---
    p = latplan.util.puzzle_module(network_dir)
    log("loaded puzzle")
    sae = latplan.model.load(network_dir)
    aae = latplan.model.load(sae.local(_aae))
    ad = latplan.model.load(sae.local(_aae + "_ad/"))
    sd3 = latplan.model.load(sae.local("_sd3/"))
    cae = latplan.model.load(sae.local("_cae/"), allow_failure=True)
    setup_planner_utils(sae, problem_dir, network_dir, "ama2")
    log("loaded sae")
    decide_pruning_method()
    init, goal = init_goal_misc(p)
    log("loaded init/goal")

    # --- determine which discrete actions the AAE actually uses ---
    # NOTE(review): "transisitons" is a typo for "transitions"; kept as-is
    # since only this function uses the name.
    known_transisitons = np.loadtxt(sae.local("actions.csv"), dtype=np.int8)
    actions = aae.encode_action(known_transisitons, batch_size=1000).round()
    histogram = np.squeeze(actions.sum(axis=0, dtype=int))
    print(histogram)
    print(np.count_nonzero(histogram), "actions valid")
    print("valid actions:")
    print(np.where(histogram > 0)[0])
    identified, total = np.squeeze(histogram.sum()), len(actions)
    if total != identified:
        print("network does not explain all actions: only {} out of {} ({}%)".
              format(identified, total, identified * 100 // total))
    # one one-hot action vector per action label that occurred at least once
    available_actions = np.zeros(
        (np.count_nonzero(histogram), actions.shape[1], actions.shape[2]),
        dtype=int)
    for i, pos in enumerate(np.where(histogram > 0)[0]):
        available_actions[i][0][pos] = 1
    log("initialized actions")

    # --- search ---
    log("start planning")
    # NOTE(review): eval() on the searcher/heuristics names assumes trusted
    # command-line input; do not expose this to untrusted callers.
    _searcher = eval(searcher)()
    _searcher.stats["aae"] = _aae
    _searcher.stats["heuristics"] = heuristics
    _searcher.stats["search"] = searcher
    _searcher.stats["network"] = network_dir
    # problem_dir layout assumed to be .../<noise>/<domain>/<problem>
    _searcher.stats["problem"] = os.path.normpath(problem_dir).split("/")[-1]
    _searcher.stats["domain"] = os.path.normpath(problem_dir).split("/")[-2]
    _searcher.stats["noise"] = os.path.normpath(problem_dir).split("/")[-3]
    _searcher.stats["plan_count"] = 0
    try:
        # the searcher yields goal states one by one; iterate until
        # exhausted (StopIteration) or until a validated plan is returned
        for i, found_goal_state in enumerate(
                _searcher.search(init, goal, eval(heuristics))):
            log("plan found")
            _searcher.stats["found"] = True
            _searcher.stats["exhausted"] = False
            _searcher.stats["plan_count"] += 1
            plan = np.array(found_goal_state.path())
            _searcher.stats["statistics"]["cost"] = len(plan) - 1
            _searcher.stats["statistics"]["length"] = len(plan) - 1
            print(plan)
            # plot the decoded plan; number the file when collecting
            # multiple solutions
            if first_solution:
                plot_grid(sae.decode(plan),
                          path=problem(
                              ama(network(search(heur("problem.png"))))),
                          verbose=True)
            else:
                plot_grid(
                    sae.decode(plan),
                    path=problem(
                        ama(network(search(heur(
                            "problem_{}.png".format(i)))))),
                    verbose=True)
            log("plotted the plan")
            # ground-truth validation of every transition in the plan
            validation = p.validate_transitions(
                [sae.decode(plan[0:-1]), sae.decode(plan[1:])])
            print(validation)
            print(
                ad.discriminate(np.concatenate((plan[0:-1], plan[1:]),
                                               axis=-1)).flatten())
            print(p.validate_states(sae.decode(plan)))
            print(combined_sd(plan, sae, cae, sd3).flatten())
            log("validated plan")
            if np.all(validation):
                _searcher.stats["valid"] = True
                return
            _searcher.stats["valid"] = False
            if first_solution:
                return
    except StopIteration:
        # the searcher ran out of candidates without reaching the goal
        _searcher.stats["found"] = False
        _searcher.stats["exhausted"] = True
    finally:
        # always record timings and write the report, even on early return
        _searcher.stats["times"] = times
        _searcher.report(problem(ama(network(search(heur("problem.json"))))))
def state_discriminator3_filtering(y):
    """Filter candidate transitions by the state discriminator.

    Each row of `y` is assumed to be a concatenated (pre, suc) state pair;
    the successor half (columns N onward) is scored by `combined_sd`, and
    only rows scoring strictly above 0.5 are kept.
    """
    half = y.shape[1] // 2
    successor_scores = np.squeeze(combined_sd(y[:, half:], sae, cae, sd3))
    keep = np.where(successor_scores > 0.5)[0]
    return y[keep]