Example #1
0
 def measure(valid, invalid, suffix):
     """Score valid/invalid samples with the combined SD and register metrics.

     Truncates both sets to a common length, classifies them with
     combined_sd, then records minlen, recall, specificity, F-measure,
     precision and accuracy under `suffix` via reg().
     """
     minlen = min(len(valid), len(invalid))

     valid_tmp   = valid[:minlen]
     invalid_tmp = invalid[:minlen]

     # Discriminator outputs are rounded to {0,1}; clip guards against any
     # rounding artifact outside that range.
     tp = np.clip(combined_sd(valid_tmp,   sae, cae, discriminator, batch_size=1000).round(), 0, 1)  # true positive
     fp = np.clip(combined_sd(invalid_tmp, sae, cae, discriminator, batch_size=1000).round(), 0, 1)  # false positive
     tn = 1 - fp

     reg([suffix, "minlen"], minlen)
     recall      = np.mean(tp)  # recall / sensitivity / true positive rate out of condition positive
     specificity = np.mean(tn)  # specificity / true negative rate out of condition negative
     reg([suffix, "recall"], recall)
     reg([suffix, "specificity"], specificity)
     reg([suffix, "f"], (2 * recall * specificity) / (recall + specificity))
     # NOTE: numpy scalar division by zero does NOT raise ZeroDivisionError
     # (it warns and yields nan/inf), so the former try/except was dead code.
     # Guard the denominators explicitly instead.
     positives = np.sum(tp) + np.sum(fp)
     reg([suffix, "precision"],
         np.sum(tp) / positives if positives > 0 else float('nan'))
     reg([suffix, "accuracy"],
         (np.sum(tp) + np.sum(tn)) / (2 * minlen) if minlen > 0 else float('nan'))
     return
def decide_pruning_method():
    """Select the global pruning methods from the SD's type-1 error rate.

    If the state discriminator misclassifies more than 25% of the
    known-valid states (a type-1 error rate measurable from the training
    data, hence not cheating), SD-based pruning is excluded. The 0.25
    threshold is arbitrary.
    """
    global pruning_methods
    print("verifying SD type-1 error")
    states_valid = np.loadtxt(sae.local("states.csv"), dtype=np.int8)
    scores = combined_sd(states_valid, sae, cae, sd3)
    type1_error = np.sum(1 - scores) / len(states_valid)
    # Common prefix of both configurations; when applied,
    # action_reconstruction_filtering must stay the first method.
    pruning_methods = [
        action_reconstruction_filtering,
        action_discriminator_filtering,
        state_reconstruction_filtering,
    ]
    if not (type1_error > 0.25):
        # SD is trustworthy enough: also prune with it.
        pruning_methods.append(state_discriminator3_filtering)
Example #3
0
def test():
    """Evaluate the action discriminator and dump metrics to performance.json.

    Loads valid/invalid action datasets, measures raw discriminator
    performance, then re-measures after pruning invalid actions with the
    combined state discriminator ("sd") and with the ground-truth domain
    validator ("validated"). Results go to performance.json next to the
    discriminator.
    """
    valid = np.loadtxt(aae.local("valid_actions.csv"), dtype=np.int8)
    # BUGFIX: use np.random.shuffle, not random.shuffle — Python's
    # random.shuffle corrupts 2-D numpy arrays because row indexing returns
    # views and its tuple-swap idiom aliases them (rows get duplicated).
    np.random.shuffle(valid)

    invalid = np.loadtxt(aae.local("invalid_actions.csv"), dtype=np.int8)
    np.random.shuffle(invalid)

    N = int(valid.shape[1] // 2)  # boundary between pre-/post-state halves

    performance = {}

    def reg(names, value, d=performance):
        # Register `value` in the nested performance dict under path `names`.
        name = names[0]
        if len(names) > 1:
            try:
                tmp = d[name]
            except KeyError:
                tmp = {}
                d[name] = tmp
            reg(names[1:], value, tmp)
        else:
            d[name] = float(value)
            print(name, ": ", value)

    reg(["valid"], len(valid))
    reg(["invalid"], len(invalid))

    def measure(valid, invalid, suffix):
        # Score equal-sized valid/invalid subsets and register the metrics.
        minlen = min(len(valid), len(invalid))

        valid_tmp = valid[:minlen]
        invalid_tmp = invalid[:minlen]

        tp = np.clip(
            discriminator.discriminate(valid_tmp, batch_size=1000).round(), 0,
            1)  # true positive
        fp = np.clip(
            discriminator.discriminate(invalid_tmp, batch_size=1000).round(),
            0, 1)  # false positive
        tn = 1 - fp

        reg([suffix, "minlen"], minlen)
        recall = np.mean(
            tp
        )  # recall / sensitivity / power / true positive rate out of condition positive
        specificity = np.mean(
            tn)  # specificity / true negative rate out of condition negative
        reg([suffix, "recall"], recall)
        reg([suffix, "specificity"], specificity)
        reg([suffix, "f"], (2 * recall * specificity) / (recall + specificity))
        # NOTE: numpy scalar division by zero does not raise
        # ZeroDivisionError (it warns and yields nan), so the former
        # try/except was dead code; guard denominators explicitly.
        positives = np.sum(tp) + np.sum(fp)
        reg([suffix, "precision"],
            np.sum(tp) / positives if positives > 0 else float('nan'))
        reg([suffix, "accuracy"],
            (np.sum(tp) + np.sum(tn)) / (2 * minlen)
            if minlen > 0 else float('nan'))
        return

    measure(valid, invalid, "raw")
    # "sd": keep only the invalid actions whose successor half passes the
    # combined state discriminator (score > 0.5).
    measure(
        valid, invalid[np.where(
            np.squeeze(
                combined_sd(invalid[:, N:], sae, cae, sd3, batch_size=1000)) >
            0.5)[0]], "sd")

    # "validated": keep only the invalid actions whose decoded successor
    # state is accepted by the ground-truth domain validator.
    p = latplan.util.puzzle_module(sae.local(""))
    measure(
        valid, invalid[p.validate_states(sae.decode(invalid[:, N:],
                                                    batch_size=1000),
                                         verbose=False,
                                         batch_size=1000)], "validated")

    import json
    with open(discriminator.local('performance.json'), 'w') as f:
        json.dump(performance, f)
def main(network_dir,
         problem_dir,
         searcher,
         first_solution=True,
         heuristics="goalcount",
         _aae="_aae"):
    """Run AMA2 latent-space planning: load networks, search, plot, validate.

    Parameters
    ----------
    network_dir    : directory of the trained state autoencoder (SAE).
    problem_dir    : directory of the planning problem instance.
    searcher       : name of a searcher class; resolved via eval().
    first_solution : if True, stop after the first plan (valid or not).
    heuristics     : name of a heuristic function; resolved via eval().
    _aae           : subdirectory name of the action autoencoder.

    Side effects: sets the module globals listed below, writes plan images
    and a JSON stats report into the problem directory.
    """
    global sae, aae, ad, sd3, cae, available_actions

    def search(path):
        # Prefix a filename with the searcher name, preserving the extension.
        root, ext = os.path.splitext(path)
        return "{}_{}{}".format(searcher, root, ext)

    def heur(path):
        # Prefix a filename with the heuristics name, preserving the extension.
        root, ext = os.path.splitext(path)
        return "{}_{}{}".format(heuristics, root, ext)

    p = latplan.util.puzzle_module(network_dir)
    log("loaded puzzle")

    # Load the networks: state AE, action AE, action discriminator,
    # state discriminator, and (optionally) the cycle AE.
    sae = latplan.model.load(network_dir)
    aae = latplan.model.load(sae.local(_aae))
    ad = latplan.model.load(sae.local(_aae + "_ad/"))
    sd3 = latplan.model.load(sae.local("_sd3/"))
    cae = latplan.model.load(sae.local("_cae/"), allow_failure=True)
    setup_planner_utils(sae, problem_dir, network_dir, "ama2")
    log("loaded sae")

    decide_pruning_method()

    init, goal = init_goal_misc(p)
    log("loaded init/goal")

    # Enumerate which discrete action labels the AAE actually uses on the
    # known transitions. (sic: "transisitons" is a typo kept as-is.)
    known_transisitons = np.loadtxt(sae.local("actions.csv"), dtype=np.int8)
    actions = aae.encode_action(known_transisitons, batch_size=1000).round()
    histogram = np.squeeze(actions.sum(axis=0, dtype=int))
    print(histogram)
    print(np.count_nonzero(histogram), "actions valid")
    print("valid actions:")
    print(np.where(histogram > 0)[0])
    identified, total = np.squeeze(histogram.sum()), len(actions)
    if total != identified:
        print("network does not explain all actions: only {} out of {} ({}%)".
              format(identified, total, identified * 100 // total))
    # One-hot matrix of the usable action labels; indexed as
    # available_actions[i][0][label] = 1 (middle axis matches the AAE's
    # action-encoding shape — presumably (batch, 1, num_labels); confirm).
    available_actions = np.zeros(
        (np.count_nonzero(histogram), actions.shape[1], actions.shape[2]),
        dtype=int)

    for i, pos in enumerate(np.where(histogram > 0)[0]):
        available_actions[i][0][pos] = 1
    log("initialized actions")

    log("start planning")
    # SECURITY NOTE: searcher/heuristics names are resolved with eval();
    # callers must not pass untrusted strings.
    _searcher = eval(searcher)()
    _searcher.stats["aae"] = _aae
    _searcher.stats["heuristics"] = heuristics
    _searcher.stats["search"] = searcher
    _searcher.stats["network"] = network_dir
    # problem_dir is expected to end in .../noise/domain/problem .
    _searcher.stats["problem"] = os.path.normpath(problem_dir).split("/")[-1]
    _searcher.stats["domain"] = os.path.normpath(problem_dir).split("/")[-2]
    _searcher.stats["noise"] = os.path.normpath(problem_dir).split("/")[-3]
    _searcher.stats["plan_count"] = 0
    try:
        # Iterate over goal states found by the searcher; each yields a plan.
        for i, found_goal_state in enumerate(
                _searcher.search(init, goal, eval(heuristics))):
            log("plan found")
            _searcher.stats["found"] = True
            _searcher.stats["exhausted"] = False
            _searcher.stats["plan_count"] += 1
            plan = np.array(found_goal_state.path())
            # path() returns states, so cost/length = transitions = len-1.
            _searcher.stats["statistics"]["cost"] = len(plan) - 1
            _searcher.stats["statistics"]["length"] = len(plan) - 1
            print(plan)
            if first_solution:
                plot_grid(sae.decode(plan),
                          path=problem(
                              ama(network(search(heur("problem.png"))))),
                          verbose=True)
            else:
                plot_grid(
                    sae.decode(plan),
                    path=problem(
                        ama(network(search(heur(
                            "problem_{}.png".format(i)))))),
                    verbose=True)
            log("plotted the plan")

            # Ground-truth check of each consecutive decoded state pair.
            validation = p.validate_transitions(
                [sae.decode(plan[0:-1]),
                 sae.decode(plan[1:])])
            print(validation)
            print(
                ad.discriminate(np.concatenate((plan[0:-1], plan[1:]),
                                               axis=-1)).flatten())
            print(p.validate_states(sae.decode(plan)))
            print(combined_sd(plan, sae, cae, sd3).flatten())
            log("validated plan")
            if np.all(validation):
                _searcher.stats["valid"] = True
                return
            _searcher.stats["valid"] = False
            if first_solution:
                return
    except StopIteration:
        # Searcher exhausted the space without reaching the goal.
        _searcher.stats["found"] = False
        _searcher.stats["exhausted"] = True
    finally:
        _searcher.stats["times"] = times
        _searcher.report(problem(ama(network(search(heur("problem.json"))))))
def state_discriminator3_filtering(y):
    """Keep the rows of y whose successor half passes the combined SD.

    The second half of each row (columns N:) is the successor state; rows
    scoring > 0.5 under combined_sd survive.
    """
    half = y.shape[1] // 2
    scores = np.squeeze(combined_sd(y[:, half:], sae, cae, sd3))
    keep = np.where(scores > 0.5)[0]
    return y[keep]