Exemple #1
0
def test_spaun():
    """Check that the Spaun benchmark network is built with the requested size.

    The test is skipped when the optional ``_spaun`` module cannot be
    imported.
    """
    pytest.importorskip("_spaun")

    n_dims = 2
    spaun_net = benchmarks.spaun(dimensions=n_dims)

    # the input size of the mb1 memory output should match the requested
    # dimensionality
    assert spaun_net.mem.mb1_net.output.size_in == n_dims
def compare_optimizations(ctx, dimensions, unroll):
    """Time the Spaun benchmark under several optimization settings and plot it.

    For each combination in ``params`` the network is configured with the
    corresponding graph optimizations, simulated ``reps`` times, and the
    ratio of real time to simulated time is recorded. The collected results
    are drawn as a log-scale bar chart.

    Parameters
    ----------
    ctx
        Context object whose ``obj`` mapping supplies the ``"load"``,
        ``"reps"``, ``"device"`` and ``"save"`` options.
    dimensions
        Passed to ``benchmarks.spaun`` and interpolated (``%d``) into the
        names of the data and figure files.
    unroll
        Value used for ``unroll_simulation`` in the configurations where
        unrolling is enabled; ``1`` is used otherwise.

    Notes
    -----
    * With ``load`` set, previous results are read from
      ``compare_optimizations_<dimensions>_data_saved.pkl``.
    * With ``reps > 0``, results are written to
      ``compare_optimizations_<dimensions>_data.pkl`` (a different name from
      the one that is loaded, so a fresh run does not overwrite saved data).
    * With ``save`` set, the figure is written to
      ``compare_optimizations_<dimensions>.<save>``.
    """
    load = ctx.obj["load"]
    reps = ctx.obj["reps"]
    device = ctx.obj["device"]
    save = ctx.obj["save"]

    # optimizations to apply (simplifications, merging, sorting, unroll)
    params = [
        (False, False, False, False),
        (False, False, False, True),
        (False, True, False, True),
        (False, True, True, True),
        (True, True, True, True),
    ]
    # params = list(itertools.product((False, True), repeat=4))

    if load:
        # NOTE: unpickling can execute arbitrary code; only load result
        # files that were produced by a trusted run of this benchmark.
        with open("compare_optimizations_%d_data_saved.pkl" % dimensions,
                  "rb") as f:
            results = pickle.load(f)
    else:
        results = [{
            "times": [],
            "simplifications": simp,
            "planner": plan,
            "sorting": sort,
            "unroll": unro
        } for simp, plan, sort, unro in params]

    if reps > 0:
        with benchmarks.spaun(dimensions) as net:
            nengo_dl.configure_settings(inference_only=True)

        # build the model once; it is reused by every simulator below
        model = nengo.builder.Model(dt=0.001,
                                    builder=nengo_dl.builder.NengoBuilder())
        model.build(net)

        print("neurons", net.n_neurons)
        print("ensembles", len(net.all_ensembles))
        print("connections", len(net.all_connections))

        for i, (simp, plan, sort, unro) in enumerate(params):
            print("%d/%d: %s %s %s %s" %
                  (i + 1, len(params), simp, plan, sort, unro))

            # each enabled flag selects the optimizing implementation,
            # otherwise the baseline one is used
            config = {
                "simplifications": ([
                    graph_optimizer.remove_constant_copies,
                    graph_optimizer.remove_unmodified_resets,
                    # graph_optimizer.remove_zero_incs,
                    graph_optimizer.remove_identity_muls
                ] if simp else []),
                "planner": (graph_optimizer.tree_planner if plan else
                            graph_optimizer.greedy_planner),
                "sorter": (graph_optimizer.order_signals if sort else
                           graph_optimizer.noop_order_signals),
            }
            with net:
                nengo_dl.configure_settings(**config)

            with nengo_dl.Simulator(None,
                                    model=model,
                                    unroll_simulation=unroll if unro else 1,
                                    device=device) as sim:
                # short untimed run before the measurements start
                sim.run(0.1)

                sim_time = 1.0

                for _ in range(reps):
                    # perf_counter is monotonic and high resolution, so the
                    # measured interval is not affected by system clock
                    # adjustments
                    start = time.perf_counter()
                    sim.run(sim_time)
                    results[i]["times"].append(
                        (time.perf_counter() - start) / sim_time)

            times = results[i]["times"]
            print("   ", min(times), max(times), np.mean(times))

        with open("compare_optimizations_%d_data.pkl" % dimensions, "wb") as f:
            pickle.dump(results, f)

    # one row per result: column 0 is used as the bar height and the
    # remaining columns are converted to error-bar offsets relative to it
    data = np.asarray([bootstrap_ci(x) for x in filter_results(results)])

    plt.figure()

    # same colour for every bar, increasingly opaque from left to right
    alphas = np.linspace(0.5, 1, len(results))
    colour = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
    for i, alpha in enumerate(alphas):
        plt.bar([i], [data[i, 0]],
                yerr=abs(data[i, 1:] - data[i, [0]])[:, None],
                log=True,
                alpha=alpha,
                color=colour)

    # every bar is labelled "merging", plus one line per enabled option
    optional_labels = (
        ("unroll", "unrolling"),
        ("planner", "planning"),
        ("sorting", "sorting"),
        ("simplifications", "simplifications"),
    )
    labels = [
        "\n".join(["merging"] +
                  [text for key, text in optional_labels if r[key]])
        for r in results
    ]
    plt.xticks(np.arange(len(results)), labels, rotation="vertical")
    plt.ylabel("real time / simulated time")

    plt.tight_layout()

    if save:
        plt.savefig("compare_optimizations_%d.%s" % (dimensions, save))