Exemple #1
0
def compare_single_parameter(folders,
                             analyses,
                             parameter,
                             filter=False,
                             pareto=False):
    """Collect one parameter from all saved networks of several folders.

    For every folder, all ``*.net`` files below it are loaded one after the
    other into a network built from that folder's ``config.yaml``. For each
    child module of the network that has an attribute called ``parameter``,
    the attribute's values are gathered over all loaded files.

    Args:
        folders: iterable of paths; each must contain a ``config.yaml`` and
            (possibly nested) ``*.net`` state-dict files.
        analyses: iterable of paths, paired by position with ``folders``;
            only read when ``filter`` is true. Each must then contain
            ``mask.npy`` (or ``mask_pareto.npy`` when ``pareto`` is true).
        parameter: name of the attribute to collect from the child modules.
            The attribute is expected to be a tensor (``detach`` is called).
        filter: when true, keep only the files selected by the mask that was
            saved in the matching analysis folder.
        pareto: when true (and ``filter`` is true), use the Pareto mask.

    Returns:
        dict mapping the folder index to a dict that maps child-module name
        to a flat numpy array holding the parameter values of all loaded
        networks, concatenated in sorted file order.
    """
    folders = [Path(f) for f in folders]
    analyses = [Path(a) for a in analyses]

    # Glob all network files in subfolders
    files = [sorted(f.rglob("*.net")) for f in folders]

    # Optional (Pareto) filter; both variants only differ in the mask file
    if filter:
        mask_name = "mask_pareto.npy" if pareto else "mask.npy"
        masks = [np.load(a / mask_name) for a in analyses]
        files = [np.array(f)[m].tolist() for f, m in zip(files, masks)]

    # Build network placeholders, one per folder
    networks = []
    for folder in folders:
        with open(folder / "config.yaml", "r") as cf:
            config = yaml.full_load(cf)
        networks.append(build_network(config))

    # Dicts to hold everything in an orderly manner
    values = {i: {} for i in range(len(files))}

    # Go over networks
    for i, (net_files, network) in enumerate(zip(files, networks)):
        for net_file in net_files:
            # Load parameters
            network.load_state_dict(torch.load(net_file))
            network.reset_state()

            # Add values to dict; a layer's list is created on first use
            for name, child in network.named_children():
                if hasattr(child, parameter):
                    values[i].setdefault(name, []).append(
                        getattr(child, parameter).detach().clone().view(1, -1))

    # Convert to single numpy arrays (only existing keys are reassigned, so
    # iterating while assigning is safe)
    for case in values.values():
        for layer, val in case.items():
            case[layer] = torch.cat(val, 0).view(-1).numpy()

    return values
def plot_transient(folder, parameters):
    """Run a saved network 100 times and plot action versus first observation.

    The output directory ``<folder>/transient_new+<id>`` is recreated from
    scratch on every call, where ``<id>`` is built from the last two
    ``/``-separated components of ``parameters`` with ``.net`` removed. One
    ``run<i>.csv`` per run (sorted first observation component against a
    rolling mean of the clipped action) and a ``raw_points.csv`` with all raw
    samples are written there, and the figure is shown.

    Args:
        folder: directory containing ``config.yaml``.
        parameters: path string of the ``.net`` state-dict file to load.
    """
    folder = Path(folder)
    tail = parameters.split("/")[-2:]
    individual_id = "_".join(part.replace(".net", "") for part in tail)
    save_folder = folder / ("transient_new+" + individual_id)
    # Always start from an empty output directory
    if os.path.exists(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)

    # Configuration, environment and network
    with open(folder / "config.yaml", "r") as cf:
        config = yaml.full_load(cf)
    env = build_environment(config)
    network = build_network(config)
    network.load_state_dict(torch.load(parameters))
    is_snn = isinstance(network, SNNNetwork)
    if is_snn:
        network.reset_state()

    # 100 runs, each in a freshly randomized environment
    action_list = []
    obs_list = []
    for _ in range(100):
        env = randomize_env(env, config)
        if is_snn:
            network.reset_state()
        obs = env.reset(h0=config["env"]["h0"][1])
        bounds = config["env"]["g bounds"]

        # The first entry of both logs is a placeholder taken right after the
        # reset; it is dropped again when the run is stored
        actions = [np.clip(env.action, *bounds)]
        observations = [obs.copy()]

        done = False
        while not done:
            net_in = torch.from_numpy(obs).view(1, 1, -1)
            action = network.forward(net_in).numpy()
            obs, _, done, _ = env.step(action)

            # Only log once the environment's settle time has passed
            if env.t >= env.settle:
                actions.append(np.clip(env.action[0], *bounds))
                observations.append(obs.copy())

        action_list.append(actions[1:])
        obs_list.append(observations[1:])

    # Visualize
    all_x = []
    all_y = []
    fig, ax = plt.subplots(1, 1)
    for run_idx, (act, ob) in enumerate(zip(action_list, obs_list)):
        first_obs = np.array(ob)[:, 0]
        order = np.argsort(first_obs)
        sorted_obs = first_obs[order]
        # Rolling mean of the actions, ordered by the first observation
        smoothed = (pd.Series(np.array(act)[order]).rolling(
            window=40, min_periods=1).mean().values)
        ax.plot(sorted_obs, smoothed, "r", alpha=0.5)
        all_x.extend(first_obs.tolist())
        all_y.extend(act)
        run_frame = pd.DataFrame({"x": sorted_obs, "y": smoothed})
        run_frame.to_csv(save_folder / f"run{run_idx}.csv",
                         index=False,
                         sep=",")

    ax.scatter(all_x, all_y, c="b", alpha=0.5)
    ax.set_xlim([-10, 10])
    ax.set_ylim([-0.9, 0.9])
    raw_points = pd.DataFrame({"x": all_x, "y": all_y})
    raw_points.to_csv(save_folder / "raw_points.csv", index=False, sep=",")
    fig.tight_layout()

    plt.show()
Exemple #3
0
def _write_header_lines(path, lines):
    """Write every entry of ``lines`` to ``path``, one entry per line."""
    with open(path, "w") as f:
        f.writelines(f"{line}\n" for line in lines)


def _c_array_body(values):
    """Return ``values`` as the comma-separated body of a C initializer."""
    return ", ".join(str(v) for v in values)


def _write_connection_header(log_location, layer, tag):
    """Write ``connection_conf_<tag>.h`` with the flattened weights of ``layer``."""
    weight = layer.weight
    weights = weight.view(-1).tolist()
    post = weight.size(0)
    pre = weight.size(1)
    lines = [
        "//Auto-generated",
        '#include "Connection.h"',
        f"float const w_{tag}[] = {{{_c_array_body(weights)}}};",
        f"ConnectionConf const conf_{tag} = {{{post}, {pre}, w_{tag}}};",
    ]
    _write_header_lines(f"{log_location}connection_conf_{tag}.h", lines)


def _write_neuron_header(log_location, neuron, tag):
    """Write ``neuron_conf_<tag>.h`` with the parameters of ``neuron``."""
    adaptive = isinstance(neuron, AdaptiveLIFNeuron)
    neuron_type = 1 if adaptive else 0
    a_v = neuron.alpha_v.view(-1).tolist()
    a_t = neuron.alpha_t.view(-1).tolist()
    d_v = neuron.tau_v.view(-1).tolist()
    d_t = neuron.tau_t.view(-1).tolist()
    if adaptive:
        a_th = neuron.alpha_thresh.view(-1).tolist()
        d_th = neuron.tau_thresh.view(-1).tolist()
        # Resting threshold is thresh_center broadcast to the threshold shape
        th_rest = (torch.ones_like(neuron.thresh) *
                   neuron.thresh_center).view(-1).tolist()
    else:
        # Non-adaptive neurons get zeroed threshold constants and use their
        # threshold itself as the resting threshold
        a_th = torch.zeros_like(neuron.alpha_v).view(-1).tolist()
        d_th = torch.zeros_like(neuron.tau_v).view(-1).tolist()
        th_rest = neuron.thresh.view(-1).tolist()
    v_rest = neuron.v_rest.item()
    size = neuron.spikes.size(-1)
    lines = [
        "//Auto-generated",
        '#include "Neuron.h"',
        f"float const a_v_{tag}[] = {{{_c_array_body(a_v)}}};",
        f"float const a_th_{tag}[] = {{{_c_array_body(a_th)}}};",
        f"float const a_t_{tag}[] = {{{_c_array_body(a_t)}}};",
        f"float const d_v_{tag}[] = {{{_c_array_body(d_v)}}};",
        f"float const d_th_{tag}[] = {{{_c_array_body(d_th)}}};",
        f"float const d_t_{tag}[] = {{{_c_array_body(d_t)}}};",
        f"float const th_rest_{tag}[] = {{{_c_array_body(th_rest)}}};",
        f"NeuronConf const conf_{tag} = {{{neuron_type}, {size}, a_v_{tag}, a_th_{tag}, a_t_{tag}, d_v_{tag}, d_th_{tag}, d_t_{tag}, {v_rest}, th_rest_{tag}}};",
    ]
    _write_header_lines(f"{log_location}neuron_conf_{tag}.h", lines)


def model_to_header(config, in_file, verbose=2):
    """Export a saved network as a set of C header files.

    The network is built from ``config`` and its state dict is loaded from
    ``in_file``. When ``verbose`` is truthy, header files are written using
    ``config['log location']`` as a plain string prefix:

    * with a hidden layer (``network.neuron1`` is not ``None``):
      ``connection_conf_inhid.h``, ``neuron_conf_hid.h`` and
      ``connection_conf_hidout.h``;
    * without a hidden layer: ``connection_conf_inout.h``;
    * always: ``neuron_conf_out.h`` and ``network_conf.h``.

    Args:
        config: configuration dict; read for ``'log location'`` and passed to
            ``build_network``.
        in_file: path of the state-dict file handed to ``torch.load``.
        verbose: nothing is written when falsy.
    """
    # Build network
    network = build_network(config)
    # Load network parameters
    network.load_state_dict(torch.load(in_file))

    if not verbose:
        return

    log_location = config["log location"]
    has_hidden = network.neuron1 is not None

    # Connection and hidden-neuron headers
    if has_hidden:
        _write_connection_header(log_location, network.fc1, "inhid")
        _write_neuron_header(log_location, network.neuron1, "hid")
        _write_connection_header(log_location, network.fc2, "hidout")
    else:
        _write_connection_header(log_location, network.fc2, "inout")

    # Output neuron header
    _write_neuron_header(log_location, network.neuron2, "out")

    # Network header
    centers = network.in_centers.view(-1).tolist()
    encoding_type = 1 if "place" in network.encoding else 0
    decoding_scale = network.out_scale
    in_size = 2
    in_enc_size = network.neuron0.spikes.size(-1)
    out_size = network.neuron2.spikes.size(-1)
    centers_line = f"float const centers[] = {{{_c_array_body(centers)}}};"
    if has_hidden:
        hid_size = network.neuron1.spikes.size(-1)
        lines = [
            "//Auto-generated",
            '#include "Network.h"',
            '#include "connection_conf_inhid.h"',
            '#include "connection_conf_hidout.h"',
            '#include "neuron_conf_hid.h"',
            '#include "neuron_conf_out.h"',
            centers_line,
            f"NetworkConf const conf = {{{encoding_type}, {decoding_scale}, centers, {in_size}, {in_enc_size}, {hid_size}, {out_size}, &conf_inhid, &conf_hid, &conf_hidout, &conf_out}};",
        ]
    else:
        lines = [
            "//Auto-generated",
            '#include "Network2.h"',
            '#include "connection_conf_inout.h"',
            '#include "neuron_conf_out.h"',
            centers_line,
            f"NetworkConf const conf = {{{encoding_type}, {decoding_scale}, centers, {in_size}, {in_enc_size}, {out_size}, &conf_inout, &conf_out}};",
        ]
    _write_header_lines(f"{log_location}network_conf.h", lines)
Exemple #4
0
def plot_ss(folder, parameters, runs):
    """Plot the steady-state response of a saved network over a 2-D input grid.

    The network is fed constant inputs on a 101 x 201 grid (first input in
    [-10, 10], second in [-20, 20]) for 100 steps each; the mean output over
    the last 50 steps is the response. The raw response is saved as
    ``ss_raw.csv`` in ``<folder>/steadystate+<id>`` (recreated on each call),
    and a spline-interpolated image of it is shown.

    Args:
        folder: directory containing ``config.yaml``.
        parameters: path string of the ``.net`` state-dict file to load.
        runs: ``None``, or a directory searched recursively for ``run*.csv``
            files; the first one found (in sorted order) is overlaid using
            its ``div_gt`` and ``divdot_gt`` columns.
    """
    folder = Path(folder)
    tail = parameters.split("/")[-2:]
    individual_id = "_".join(part.replace(".net", "") for part in tail)
    save_folder = folder / ("steadystate+" + individual_id)
    # Always start from an empty output directory
    if os.path.exists(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)

    # Expand the runs directory into a sorted list of run files
    if runs is not None:
        runs = sorted(Path(runs).rglob("run*.csv"))

    # Configuration and network
    with open(folder / "config.yaml", "r") as cf:
        config = yaml.full_load(cf)
    network = build_network(config)
    network.load_state_dict(torch.load(parameters))
    is_snn = isinstance(network, SNNNetwork)
    if is_snn:
        network.reset_state()

    # Input grid: D in [-10, 10], Ddot in [-20, 20]
    div_lim = [-10.0, 10.0]
    divdot_lim = [-20.0, 20.0]
    div = np.linspace(*div_lim, 101)
    divdot = np.linspace(*divdot_lim, 201)

    # Batch size is 1, so every grid point is presented sequentially for
    # n_steps steps, starting from a reset network
    n_steps = 100
    response = np.zeros((div.shape[0], divdot.shape[0], n_steps))
    for i, d in enumerate(div):
        for j, dd in enumerate(divdot):
            if is_snn:
                network.reset_state()
            for k in range(n_steps):
                net_in = torch.from_numpy(np.array([d, dd])).float()
                response[i, j, k] = network.forward(
                    net_in.view(1, 1, -1)).item()

    # Average the last 50 steps, then interpolate onto a finer grid
    response = response[:, :, -50:].mean(-1)
    fine_x = np.linspace(*div_lim, 401)
    fine_y = np.linspace(*divdot_lim, 801)
    xi, yi = np.meshgrid(fine_x, fine_y)
    spline = RectBivariateSpline(div, divdot, response)
    zi = spline.ev(xi, yi)

    # Save raw response; the transpose matches meshgrid's row/column order
    divi, divdoti = np.meshgrid(div, divdot)
    raw = pd.DataFrame({
        "x": divi.flatten(),
        "y": divdoti.flatten(),
        "z": response.T.flatten(),
    })
    raw.to_csv(str(save_folder) + "/ss_raw.csv", index=False, sep=",")

    # Visualize, with the color range bounded by the configured g bounds
    fig, ax = plt.subplots(1, 1)
    g_low, g_high = config["env"]["g bounds"][0], config["env"]["g bounds"][1]
    im = ax.imshow(
        zi,
        vmin=g_low,
        vmax=g_high,
        cmap="viridis",
        extent=[*div_lim, *divdot_lim],
        aspect=0.5,
        origin="lower",
    )
    if runs is not None:
        # Only the first run file is overlaid
        for run_file in runs[:1]:
            run_data = pd.read_csv(run_file, sep=",")
            ax.plot(run_data["div_gt"], run_data["divdot_gt"], "r")
    ax.set_title("steady state response (bounded)")
    ax.set_ylabel("divergence dot [1/s2]")
    ax.set_xlabel("divergence [1/s]")
    ax.set_xlim(div_lim)
    ax.set_ylim(divdot_lim)
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("thrust command [g]")
    fig.tight_layout()

    plt.show()
def vis_sensitivity_complete_4m(config, parameters, verbose=2):
    """Evaluate all saved networks over 250 randomized runs and plot the spread.

    Every ``*.net`` file below ``parameters`` is evaluated in 250 runs. Per
    run, four values are recorded: time after the settle period, final
    ``env.state[0]``, absolute final ``env.state[1]``, and spikes per unit of
    that time. The 25th/50th/75th percentiles and standard deviations over
    the runs are computed, and networks whose median time is below 10 and
    median velocity below 1 are plotted.

    To combine multiple evolution runs, put them as subdirectories in one big
    folder and pass that folder as ``parameters``.

    Args:
        config: configuration dict; ``config['log location']`` is used as a
            plain string prefix for all output files.
        parameters: directory searched recursively for ``*.net`` files.
        verbose: when truthy, CSV/npy/png results are saved; when greater
            than 1, the figures are also shown. ``ids.csv`` is always written.
    """

    def log_path(name):
        # All outputs share the configured prefix
        return f"{config['log location']}{name}"

    def new_sensitivity_axes():
        # Both figures share title, labels, limits and grid
        figure, axes = plt.subplots(1, 1, dpi=200)
        axes.set_title("Performance sensitivity")
        axes.set_xlabel(config["evo"]["objectives"][0])
        axes.set_ylabel(config["evo"]["objectives"][2])
        axes.set_xlim([0.0, 10.0])
        axes.set_ylim([0.0, 1.0])
        axes.grid()
        return figure, axes

    # Expand to all parameter files and number them; the id table allows
    # good controllers to be identified later
    parameters = sorted(Path(parameters).rglob("*.net"))
    ids = np.arange(0, len(parameters))
    id_table = pd.DataFrame(list(zip(ids, parameters)),
                            columns=["id", "location"])
    id_table.to_csv(log_path("ids.csv"), index=False, sep=",")

    env = build_environment(config)
    network = build_network(config)
    is_snn = isinstance(network, SNNNetwork)

    # Performance over 250 runs: time to land, final height, final velocity
    # and spikes per second
    n_runs = 250
    performance = np.zeros((len(parameters), n_runs, 4))

    for j in range(n_runs):
        # Randomize once per run so that all networks face the same
        # conditions within that run
        env = randomize_env(env, config)

        for i, param in enumerate(parameters):
            network.load_state_dict(torch.load(param))

            # Reset env and net; reseed the env so that noise is equal across
            # the networks of this run. Only test from h0[1].
            obs = env.reset(h0=config["env"]["h0"][1])
            env.seed(env.seeds)
            if is_snn:
                network.reset_state()

            done = False
            spikes = 0
            while not done:
                net_in = torch.from_numpy(obs).view(1, 1, -1)
                action = network.forward(net_in).numpy()
                obs, _, done, _ = env.step(action)
                if is_snn:
                    # Hidden-layer spikes only count when that layer exists
                    if network.neuron1 is not None:
                        spikes += (network.neuron1.spikes.sum().item() +
                                   network.neuron2.spikes.sum().item())
                    else:
                        spikes += network.neuron2.spikes.sum().item()

            flight_time = env.t - config["env"]["settle"]
            performance[i, j, :] = [
                flight_time,
                env.state[0],
                abs(env.state[1]),
                spikes / flight_time,
            ]

    # Process results: 25th, 50th (median) and 75th percentiles over runs
    percentiles = np.percentile(performance, [25, 50, 75], 1)
    stds = np.std(performance, 1)
    mean_std = stds.mean(0)
    run_id = config['log location'].split('/')[-2]
    print(
        f"ID: {run_id}, mean sigmas for time: {mean_std[0]:.3f}, height: {mean_std[1]:.3f}, velocity: {mean_std[2]:.3f}, spikes: {mean_std[3]:.3f}"
    )

    # Save results (before filtering!)
    if verbose:
        std_columns = [
            "time to land", "final height", "final velocity", "spikes"
        ]
        pd.DataFrame(stds,
                     columns=std_columns).to_csv(
                         log_path("sensitivity_stds.csv"),
                         index=False,
                         sep=",")
        # Columns: 25th_*, 50th_*, 75th_* for each metric, then the id
        percentile_columns = [
            f"{p}th_{metric}" for p in (25, 50, 75)
            for metric in ("ttl", "fh", "fv", "s")
        ] + ["id"]
        percentile_table = np.concatenate([*percentiles, ids[:, None]],
                                          axis=1)
        pd.DataFrame(percentile_table, columns=percentile_columns).to_csv(
            log_path("sensitivity.csv"), index=False, sep=",")
        # Also save raw performance (np.save appends ".npy")
        np.save(log_path("raw_performance"), performance)

    # Filter on the medians; the Pareto mask is computed on unfiltered data
    mask = (percentiles[1, :, 0] < 10.0) & (percentiles[1, :, 2] < 1.0)
    efficient = is_pareto_efficient(percentiles[1, :, :])
    mask_pareto = mask & efficient
    percentiles = percentiles[:, mask, :]
    ids = ids[mask]

    # Also save filters/masks as npy for later use
    if verbose:
        np.save(log_path("mask"), mask)
        np.save(log_path("mask_pareto"), mask_pareto)

    # First figure: medians with 25th/75th percentile error bars
    fig1, ax1 = new_sensitivity_axes()
    median_time = percentiles[1, :, 0]
    median_velocity = percentiles[1, :, 2]
    ax1.errorbar(
        median_time,
        median_velocity,
        xerr=np.abs(percentiles[[0, 2], :, 0] - median_time),
        yerr=np.abs(percentiles[[0, 2], :, 2] - median_velocity),
        linestyle="",
        marker="",
        color="k",
        elinewidth=0.5,
        capsize=1,
        capthick=0.5,
        zorder=10,
    )
    # Color is the median spike rate, marker size its interquartile range
    cb = ax1.scatter(
        median_time,
        median_velocity,
        marker=".",
        c=percentiles[1, :, 3],
        cmap="coolwarm",
        s=np.abs(percentiles[2, :, 3] - percentiles[0, :, 3]),
        linewidths=0.5,
        edgecolors="k",
        vmin=None,
        vmax=None,
        zorder=100,
    )
    fig1.colorbar(cb, ax=ax1)
    fig1.tight_layout()

    # Second figure: the id of each network at its median position
    fig2, ax2 = new_sensitivity_axes()
    for col, individual in enumerate(ids):
        ax2.text(percentiles[1, col, 0],
                 percentiles[1, col, 2],
                 str(individual),
                 fontsize=5)
    fig2.colorbar(cb, ax=ax2)
    fig2.tight_layout()

    # Save figures
    if verbose:
        fig1.savefig(log_path("sensitivity.png"))
        fig2.savefig(log_path("ids.png"))

    # Show figures
    if verbose > 1:
        plt.show()
def vis_disturbance(config, parameters, verbose=2):
    """Run one episode with two injected disturbances and plot the result.

    A network is built from ``config`` and loaded from ``parameters``. During
    a single run from ``config['env']['h0'][1]``, a disturbance
    ``(200.0, 0.0)`` is applied for the first step at or after ``t = 1.5``
    and a disturbance ``(0.0, -2000.0)`` for one later step at or after
    ``t = 2.5``. State, observations and (for spiking networks) neuron
    variables and the input encoding are logged and plotted.

    Args:
        config: configuration dict; ``config['log location']`` is used as a
            plain string prefix for the saved figures.
        parameters: path of the state-dict file handed to ``torch.load``.
        verbose: when truthy, figures are saved; when greater than 1, they
            are also shown.
    """
    # Build environment
    env = build_environment(config)
    env = randomize_env(env, config)

    # Load network
    network = build_network(config)
    network.load_state_dict(torch.load(parameters))
    is_snn = isinstance(network, SNNNetwork)

    # Reset network and env
    if is_snn:
        network.reset_state()
    obs = env.reset(h0=config["env"]["h0"][1])
    done = False

    # Indicators whether disturbance already happened
    dist_1 = False
    dist_2 = False

    # For plotting
    state_list = []
    obs_gt_list = []
    obs_list = []
    time_list = []
    encoding_list = []

    # For neuron visualization: one set of recordings per neuron layer
    neuron_dict = OrderedDict([(name, {
        "trace": [],
        "volt": [],
        "spike": [],
        "thresh": []
    }) for name, child in network.named_children()
                               if isinstance(child, BaseNeuron)])
    # Recording key -> attribute name; these are only logged when present
    optional_recordings = (("volt", "v_cell"), ("spike", "spikes"),
                           ("thresh", "thresh"))

    while not done:
        # Log performance
        state_list.append(env.state.copy())
        obs_gt_list.append(env.div_ph.copy())
        obs_list.append(obs.copy())
        time_list.append(env.t)

        # Log neurons
        for name, child in network.named_children():
            if name not in neuron_dict:
                continue
            recordings = neuron_dict[name]
            recordings["trace"].append(
                child.trace.detach().clone().view(-1).numpy())
            for key, attr in optional_recordings:
                if hasattr(child, attr):
                    recordings[key].append(
                        getattr(child, attr).detach().clone().view(-1).numpy())

        # Step the environment; each disturbance is active for one step only
        obs = torch.from_numpy(obs)
        action = network.forward(obs.view(1, 1, -1))
        action = action.numpy()
        if env.t >= 1.5 and not dist_1:
            disturbance = (200.0, 0.0)
            dist_1 = True
        elif env.t >= 2.5 and not dist_2:
            disturbance = (0.0, -2000.0)
            dist_2 = True
        else:
            disturbance = None

        if disturbance is None:
            obs, _, done, _ = env.step(action)
        else:
            env.set_disturbance(*disturbance)
            obs, _, done, _ = env.step(action)
            env.unset_disturbance()

        # Log encoding as well. Tensor.numpy() shares memory with the tensor,
        # so take a snapshot like the other logs do; otherwise later in-place
        # changes of network.input would rewrite the logged values.
        if is_snn:
            encoding_list.append(
                network.input.detach().clone().view(-1).numpy())

    # Convert logs once instead of once per plotted line
    states = np.array(state_list)
    obs_gt = np.array(obs_gt_list)
    obs_arr = np.array(obs_list)

    # Plot
    fig_p, axs_p = plt.subplots(5, 1, sharex=True, figsize=(10, 10))
    # Height
    axs_p[0].plot(time_list, states[:, 0], label="Height")
    axs_p[0].set_ylabel("height [m]")
    # Velocity
    axs_p[1].plot(time_list, states[:, 1], label="Velocity")
    axs_p[1].set_ylabel("velocity [m/s]")
    # Acceleration/thrust
    axs_p[2].plot(time_list, states[:, 2], label="Thrust")
    axs_p[2].set_ylabel("acceleration [m/s2]")
    # Divergence
    axs_p[3].plot(time_list, obs_gt[:, 0], label="GT divergence")
    axs_p[3].plot(time_list, obs_arr[:, 0], label="Divergence")
    if is_snn and network.encoding == "both":
        encoding = np.array(encoding_list)
        axs_p[3].plot(time_list, encoding[:, 0], label="Encoded +D")
        axs_p[3].plot(time_list, encoding[:, 2], label="Encoded -D")
    elif is_snn and network.encoding == "divergence":
        encoding = np.array(encoding_list)
        axs_p[3].plot(time_list, encoding[:, 0], label="Encoded +D")
        axs_p[3].plot(time_list, encoding[:, 1], label="Encoded -D")
    axs_p[3].set_ylabel("divergence [1/s]")
    # Divergence dot
    axs_p[4].plot(time_list, obs_gt[:, 1], label="GT div dot")
    axs_p[4].plot(time_list, obs_arr[:, 1], label="Div dot")
    if is_snn and network.encoding == "both":
        encoding = np.array(encoding_list)
        axs_p[4].plot(time_list, encoding[:, 1], label="Encoded +Ddot")
        axs_p[4].plot(time_list, encoding[:, 3], label="Encoded -Ddot")
    axs_p[4].set_ylabel("divergence dot [1/s2]")
    axs_p[4].set_xlabel("time [s]")

    for ax in axs_p:
        ax.grid()
        ax.legend()
    fig_p.tight_layout()

    if verbose:
        fig_p.savefig(f"{config['log location']}disturbance+performance.png")

    # Plot neurons: one column per neuron layer, one row per neuron
    if is_snn:
        dpi = 50 if config["net"]["hidden size"] > 10 else 100
        fig, ax = plt.subplots(config["net"]["hidden size"],
                               2,
                               figsize=(20, 20),
                               dpi=dpi)
        for i, (name, recordings) in enumerate(neuron_dict.items()):
            for var, vals in recordings.items():
                if not vals:
                    continue
                vals_arr = np.array(vals)
                for j in range(vals_arr.shape[1]):
                    ax[j, i].plot(time_list, vals_arr[:, j], label=var)
                    ax[j, i].grid(True)

        fig.tight_layout()

        if verbose:
            fig.savefig(f"{config['log location']}disturbance+neurons.png")

    if verbose > 1:
        plt.show()
Exemple #7
0
def vis_steadystate(config, parameters, verbose=2):
    """Plot corner responses and the steady-state response map of a network.

    A network is built from ``config`` and loaded from ``parameters``. One
    closed-loop run is performed to record the trajectory of ``env.div_ph``;
    it is overlaid on the response maps. The network is then fed constant
    inputs on a 101 x 101 grid (first input in [-10, 10], second in
    [-20, 20]) for 100 steps each; the mean of the last 50 outputs, scaled by
    ``config['env']['g']``, is the response.

    Args:
        config: configuration dict; ``config['log location']`` is used as a
            plain string prefix for the saved figures.
        parameters: path of the state-dict file handed to ``torch.load``.
        verbose: when truthy, figures are saved; when greater than 1, they
            are also shown.
    """
    # Build environment
    env = build_environment(config)
    env = randomize_env(env, config)

    # Load network
    network = build_network(config)
    network.load_state_dict(torch.load(parameters))
    is_snn = isinstance(network, SNNNetwork)

    # Do one run, starting halfway between the first and last h0 entries
    if is_snn:
        network.reset_state()
    h0 = config["env"]["h0"]
    obs = env.reset(h0=(h0[0] + h0[-1]) / 2)

    # Trajectory of env.div_ph, logged before every step
    obs_gt_list = []
    done = False
    while not done:
        obs_gt_list.append(env.div_ph.copy())
        net_in = torch.from_numpy(obs).view(1, 1, -1)
        action = network.forward(net_in).numpy()
        obs, _, done, _ = env.step(action)
    obs_gt = np.array(obs_gt_list)

    # Input grid: D in [-10, 10], Ddot in [-20, 20]
    div_lim = [-10.0, 10.0]
    divdot_lim = [-20.0, 20.0]
    div = np.linspace(*div_lim, 101)
    divdot = np.linspace(*divdot_lim, 101)

    # Batch size is 1, so every grid point is presented sequentially for
    # n_steps steps, starting from a reset network
    n_steps = 100
    response = np.zeros((div.shape[0], divdot.shape[0], n_steps))
    for i, d in enumerate(div):
        for j, dd in enumerate(divdot):
            if is_snn:
                network.reset_state()
            for k in range(n_steps):
                net_in = torch.from_numpy(np.array([d, dd])).float()
                out = network.forward(net_in.view(1, 1, -1))
                response[i, j, k] = out.item() * config["env"]["g"]

    # Steady-state corner plot: one color per corner of the input range
    colors = [
        "xkcd:neon red", "xkcd:neon blue", "xkcd:neon green",
        "xkcd:neon purple"
    ]
    corners = [(d, dd) for d in div_lim for dd in divdot_lim]
    fig_c, ax_c = plt.subplots()
    ax_c.set_xlabel("steps")
    ax_c.set_ylabel("action")
    ax_c.set_title("corner responses")
    for color, (d, dd) in zip(colors, corners):
        if is_snn:
            network.reset_state()
        corner = []
        corner_noise = []
        for _ in range(n_steps):
            clean = np.array([d, dd])
            # Additive noise plus noise proportional to the input magnitude
            noisy = (clean + np.random.normal(0.0, env.noise_std) +
                     abs(clean) * np.random.normal(0.0, env.noise_p_std))
            clean = torch.from_numpy(clean).float()
            noisy = torch.from_numpy(noisy).float()
            # Clean and noisy inputs pass through the same network instance
            # back to back
            action = network.forward(clean.view(1, 1, -1))
            action_noise = network.forward(noisy.view(1, 1, -1))
            corner.append(action.item() * config["env"]["g"])
            corner_noise.append(action_noise.item() * config["env"]["g"])
        ax_c.plot(corner, label=f"div: {d}, divdot: {dd}", color=color)
        ax_c.plot(corner_noise, color=color)

    ax_c.legend()
    ax_c.grid()
    fig_c.tight_layout()

    if verbose:
        fig_c.savefig(f"{config['log location']}ss_corners.png")

    # Average the last 50 steps, then interpolate onto a finer grid
    fine_x = np.linspace(*div_lim, 2100)
    fine_y = np.linspace(*divdot_lim, 2100)
    xi, yi = np.meshgrid(fine_x, fine_y)
    response = response[:, :, -50:].mean(-1)
    spline = RectBivariateSpline(div, divdot, response)
    zi = spline.ev(xi, yi)

    # Visualize: the same map twice, with different color ranges
    fig_ss, ax_ss = plt.subplots(1, 2, figsize=(10, 5))
    g = config["env"]["g"]
    g_bounds = config["env"]["g bounds"]
    panels = [
        # Color range spans the measured response
        (ax_ss[0], response.min(), response.max(),
         "steady state response (non-bounded)"),
        # Color range spans the configured thrust bounds
        (ax_ss[1], g_bounds[0] * g, g_bounds[1] * g,
         "steady state response (bounded)"),
    ]
    for panel, low, high, title in panels:
        im = panel.imshow(
            zi,
            vmin=low,
            vmax=high,
            cmap=parula_map,
            extent=[*div_lim, *divdot_lim],
            aspect=0.5,
            origin="lower",
        )
        panel.plot(obs_gt[:, 0], obs_gt[:, 1], "r")
        panel.set_title(title)
        panel.set_ylabel("divergence dot [1/s2]")
        panel.set_xlabel("divergence [1/s]")
        panel.set_xlim(div_lim)
        panel.set_ylim(divdot_lim)
        cbar = fig_ss.colorbar(im, ax=panel)
        cbar.set_label("thrust command [m/s2]")
    fig_ss.tight_layout()

    if verbose:
        fig_ss.savefig(f"{config['log location']}ss_response.png")

    if verbose > 1:
        plt.show()
Exemple #8
0
def _collect_gene_values(folder, analysis, genes, use_mask=False,
                         pareto=False):
    """Gather, per gene and per layer, the values of all networks in a folder.

    Every ``*.net`` file found recursively under ``folder`` is loaded into a
    network built from ``folder / "config.yaml"``. For each child module of
    that network that has an attribute named after a gene, the attribute is
    flattened and collected.

    Args:
        folder: ``Path`` holding ``config.yaml`` and the ``*.net`` files.
        analysis: ``Path`` holding ``mask.npy`` / ``mask_pareto.npy``.
        genes: attribute names to collect.
        use_mask: when true, index the sorted file list with the saved mask.
        pareto: when true (and ``use_mask``), use ``mask_pareto.npy`` instead
            of ``mask.npy``.

    Returns:
        ``{gene: {layer_name: 1-D numpy array}}``. Layers that were blanked
        on purpose (see below) map to an empty array.
    """
    # Glob all network files in subfolders
    files = sorted(folder.rglob("*.net"))

    # Optional (Pareto) filter
    if use_mask:
        mask_file = "mask_pareto.npy" if pareto else "mask.npy"
        mask = np.load(analysis / mask_file)
        files = np.array(files)[mask].tolist()

    # Build network placeholder
    with open(folder / "config.yaml", "r") as cf:
        config = yaml.full_load(cf)
    network = build_network(config)

    # Dict to hold everything in an orderly manner
    params = {gene: {} for gene in genes}

    # Go over networks
    for file in files:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints you trust.
        network.load_state_dict(torch.load(file))
        network.reset_state()

        # Add values to dict
        for gene in genes:
            for name, child in network.named_children():
                if hasattr(child, gene):
                    params[gene].setdefault(name, []).append(
                        getattr(child, gene).detach().clone().view(1, -1))

    # Remove unwanted ones, such as thresh for AdaptiveLIFNeuron
    if "thresh" in params:
        for name, child in network.named_children():
            if isinstance(child, AdaptiveLIFNeuron):
                params["thresh"][name] = None

    # If no hidden layer, add empty neuron1 / fc1 so both sides share keys
    if network.neuron1 is None:
        for gene in genes:
            params[gene]["fc1" if gene == "weight" else "neuron1"] = None

    # Convert to single numpy arrays (blanked layers become empty arrays)
    return {
        gene: {
            layer: (torch.cat(values, 0).view(-1).numpy()
                    if values is not None else np.array([]))
            for layer, values in layers.items()
        }
        for gene, layers in params.items()
    }


def compare_parameters(folder1,
                       folder2,
                       analysis1,
                       analysis2,
                       filter=False,
                       pareto=False):
    """Compare the parameter distributions of two sets of saved networks.

    For each gene and each layer, overlaid histograms of both sets are drawn,
    the raw values are written to ``parameters+{gene}+{layer}.csv`` in the
    respective analysis folder, and the median plus a two-sided Mann-Whitney U
    test are printed and stored in ``stats.csv`` in both analysis folders.
    Finally all figures are shown with ``plt.show()``.

    Args:
        folder1, folder2: folders with ``config.yaml`` and ``*.net`` files.
        analysis1, analysis2: folders with the masks; also receive the CSVs.
        filter: when true, select networks with the saved mask.
        pareto: when true (and ``filter``), use the Pareto mask.
    """
    folder1 = Path(folder1)
    folder2 = Path(folder2)
    analysis1 = Path(analysis1)
    analysis2 = Path(analysis2)

    # Genes we're going to compare
    genes = [
        "weight",
        "alpha_v",
        "alpha_t",
        "alpha_thresh",
        "tau_v",
        "tau_t",
        "tau_thresh",
        "thresh",
    ]

    params1 = _collect_gene_values(folder1, analysis1, genes, filter, pareto)
    params2 = _collect_gene_values(folder2, analysis2, genes, filter, pareto)

    # Rows are gathered in plain lists and turned into frames once at the
    # end; DataFrame.append no longer exists in pandas >= 2.0.
    stat_columns = ["gene", "layer", "median", "MWU stat", "MWU p two-sided"]
    rows1 = []
    rows2 = []
    no_values = np.array([])

    # Plot for each gene
    for gene in genes:
        # Compute overall min/max for a certain gene
        populated = [
            arr for params in (params1, params2)
            for arr in params[gene].values() if arr.shape[0] > 0
        ]
        if not populated:
            # Nothing to bin: min()/max() of an empty sequence would raise
            print(f"No values found for {gene}, skipping")
            continue
        min_gene = min(arr.min().item() for arr in populated)
        max_gene = max(arr.max().item() for arr in populated)
        bins = np.linspace(min_gene, max_gene, 15)

        fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)

        if gene != "weight":
            layer_names = ["neuron0", "neuron1", "neuron2"]
        else:
            layer_names = ["fc1", "fc2"]

        for ax, layer in zip(axs, layer_names):
            print()
            print(gene, layer)
            if layer not in params1[gene] and layer not in params2[gene]:
                continue
            # A layer may exist on one side only; treat the other as empty
            values1 = params1[gene].get(layer, no_values)
            values2 = params2[gene].get(layer, no_values)
            ax.set_title(f"{gene}: {layer}")
            ax.grid()
            ax.hist(
                values1,
                bins,
                density=True,
                edgecolor="k",
                alpha=0.5,
                label="1",
                zorder=10,
            )
            ax.hist(
                values2,
                bins,
                density=True,
                edgecolor="k",
                alpha=0.5,
                label="2",
                zorder=100,
            )
            data1 = pd.DataFrame({"values": values1})
            data2 = pd.DataFrame({"values": values2})
            data1.to_csv(analysis1 / f"parameters+{gene}+{layer}.csv",
                         index=False,
                         sep=",")
            data2.to_csv(analysis2 / f"parameters+{gene}+{layer}.csv",
                         index=False,
                         sep=",")

            # The test is undefined for an empty sample (and recent SciPy
            # raises), so report NaN instead of aborting the comparison.
            if len(values1) > 0 and len(values2) > 0:
                stat, p = mannwhitneyu(values1,
                                       values2,
                                       alternative="two-sided")
            else:
                stat, p = np.nan, np.nan
            median1 = np.median(values1) if len(values1) > 0 else np.nan
            median2 = np.median(values2) if len(values2) > 0 else np.nan

            rows1.append({
                "gene": gene,
                "layer": layer,
                "median": median1,
                "MWU stat": stat,
                "MWU p two-sided": p,
            })
            rows2.append({
                "gene": gene,
                "layer": layer,
                "median": median2,
                "MWU stat": stat,
                "MWU p two-sided": p,
            })
            print(f"Median 1: {median1}; median 2: {median2}")
            print(f"MWU test two-sided {gene}-{layer}: {stat}, {p}")
        axs[0].legend()
        fig.tight_layout()

    stats1 = pd.DataFrame(rows1, columns=stat_columns)
    stats2 = pd.DataFrame(rows2, columns=stat_columns)
    stats1.to_csv(analysis1 / "stats.csv", index=False, sep=",")
    stats2.to_csv(analysis2 / "stats.csv", index=False, sep=",")
    plt.show()
Exemple #9
0
def _moving_average(values, window=20):
    """Return the trailing rolling mean of ``values`` as a numpy array.

    Windows shorter than ``window`` at the start are averaged over the
    samples available so far (``min_periods=1``).
    """
    return pd.Series(values).rolling(window=window,
                                     min_periods=1).mean().values


def _log_neuron_state(network, neuron_dict):
    """Append the current trace (and spikes, if any) of each tracked layer.

    Only children of ``network`` whose name is a key of ``neuron_dict`` are
    logged; values are detached, cloned and flattened to 1-D numpy arrays.
    """
    for name, child in network.named_children():
        if name not in neuron_dict:
            continue
        neuron_dict[name]["trace"].append(
            child.trace.detach().clone().view(-1).numpy())
        if hasattr(child, "spikes"):
            neuron_dict[name]["spike"].append(
                child.spikes.detach().clone().view(-1).numpy())


def _tikz_edges(weight, first_source, target_offset):
    """Build the edge table of one fully-connected layer.

    Rows follow the row-major order of ``weight`` (target index outer, source
    index inner) so they line up with ``weight.view(-1)``. Source ids start
    at ``first_source``, target ids at ``target_offset``. Positive weights
    are coloured magenta, all others cyan; ``lw`` is the absolute weight.
    """
    n_out, n_in = weight.shape
    edges = pd.DataFrame({
        "u": [first_source + j for _ in range(n_out) for j in range(n_in)],
        "v": [target_offset + i for i in range(n_out) for _ in range(n_in)],
        "lw_raw": weight.view(-1).tolist(),
    })
    edges["color"] = np.where(edges["lw_raw"] > 0, "magenta", "cyan")
    edges["lw"] = edges["lw_raw"].abs()
    return edges


def plot_performance(folder, parameters):
    """Run one saved network five times and plot/export its behaviour.

    The network described by ``folder / "config.yaml"`` is loaded with the
    weights in ``parameters`` and rolled out five times in a re-randomized
    environment. Results go to ``folder / "test+<individual id>"`` (deleted
    and recreated on every call):

    * ``run{i}.csv``: per-step log of each run, including per-neuron
      spikes, traces and 20-step moving averages;
    * ``rates.csv``: mean/std over the runs of each neuron's spike rate;
    * ``network_edges.csv`` / ``network_vertices.csv``: tikz-network tables.

    Two figures (performance and per-neuron recordings) are shown at the end.

    Args:
        folder: folder containing ``config.yaml``.
        parameters: path to the ``.net`` state dict. Its last two
            "/"-separated components form the individual id.
    """
    folder = Path(folder)
    # str() so that a Path works as well as a plain string
    individual_id = "_".join(
        [s.replace(".net", "") for s in str(parameters).split("/")[-2:]])
    save_folder = folder / ("test+" + individual_id)
    if os.path.exists(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)

    # Load config
    with open(folder / "config.yaml", "r") as cf:
        config = yaml.full_load(cf)

    # Build environment
    env = build_environment(config)

    # Load network
    # NOTE(review): torch.load unpickles the file; only use trusted files.
    network = build_network(config)
    network.load_state_dict(torch.load(parameters))

    # Create plot for performance
    fig_p, axs_p = plt.subplots(6, 1, sharex=True, figsize=(10, 10))
    axs_p[0].set_ylabel("height [m]")
    axs_p[1].set_ylabel("velocity [m/s]")
    axs_p[2].set_ylabel("thrust setpoint [g]")
    axs_p[3].set_ylabel("divergence [1/s]")
    axs_p[4].set_ylabel("divergence dot [1/s2]")
    axs_p[5].set_ylabel("spikes [?]")
    axs_p[5].set_xlabel("time [s]")

    # Create plot for neurons (7 x 3 grid, so at most 21 neurons fit)
    fig_n, axs_n = plt.subplots(7, 3, sharex=True, figsize=(10, 10))
    axs_n = axs_n.flatten()

    # Create list to hold spike rates per neuron
    rates = []

    # 5 runs
    for i in range(5):
        # With different properties
        # Randomizing here means that another run of this file will get
        # different envs, but so be it. Not easy to change
        env = randomize_env(env, config)
        # Reset network and env
        if isinstance(network, SNNNetwork):
            network.reset_state()
        obs = env.reset(h0=config["env"]["h0"][1])
        done = False
        spikes = 0

        # For plotting
        action_list = []
        state_list = []
        obs_gt_list = []
        obs_list = []
        time_list = []
        spike_list = []

        # For neuron visualization
        neuron_dict = OrderedDict([(name, {
            "trace": [],
            "spike": []
        }) for name, child in network.named_children()
                                   if isinstance(child, BaseNeuron)])

        # Log initial performance
        action_list.append(np.clip(env.action, *config["env"]["g bounds"]))
        state_list.append(env.state.copy())
        obs_gt_list.append(env.div_ph.copy())
        obs_list.append(obs.copy())
        time_list.append(env.t)
        spike_list.append([0, 0])
        # Log initial neuron state
        _log_neuron_state(network, neuron_dict)

        while not done:
            # Step the environment
            obs = torch.from_numpy(obs)
            action = network.forward(obs.view(1, 1, -1))
            action = action.numpy()
            obs, _, done, _ = env.step(action)

            # Log performance
            action_list.append(
                np.clip(env.action[0], *config["env"]["g bounds"]))
            state_list.append(env.state.copy())
            obs_gt_list.append(env.div_ph.copy())
            obs_list.append(obs.copy())
            time_list.append(env.t)
            if isinstance(network, SNNNetwork):
                # Spikes emitted this step: output layer plus, when present,
                # the hidden layer
                step_spikes = network.neuron2.spikes.sum().item()
                if network.neuron1 is not None:
                    step_spikes += network.neuron1.spikes.sum().item()
                spikes += step_spikes
                spike_list.append([spikes, step_spikes])

            # Log neurons
            _log_neuron_state(network, neuron_dict)

        # Convert the logs once; they are reused for plots and the CSV
        states = np.array(state_list)
        actions = np.array(action_list)
        observations = np.array(obs_list)
        observations_gt = np.array(obs_gt_list)
        spike_counts = np.array(spike_list)

        # Plot data
        # Height
        axs_p[0].plot(time_list, states[:, 0], "C0", label=f"run {i}")
        # Velocity
        axs_p[1].plot(time_list, states[:, 1], "C0", label=f"run {i}")
        # Thrust
        axs_p[2].plot(time_list, actions, "C0", label=f"run {i}")
        # Divergence
        axs_p[3].plot(time_list, observations[:, 0], "C0", label=f"run {i}")
        axs_p[3].plot(time_list,
                      observations_gt[:, 0],
                      "C1",
                      label=f"run {i} GT")
        # Divergence dot
        axs_p[4].plot(time_list, observations[:, 1], "C0", label=f"run {i}")
        axs_p[4].plot(time_list,
                      observations_gt[:, 1],
                      "C1",
                      label=f"run {i} GT")
        # Spikes: cumulative count over elapsed time, and per-step count
        # smoothed with a moving average
        axs_p[5].plot(
            time_list,
            spike_counts[:, 0] / np.array(time_list),
            "C0",
            label=f"run {i}",
        )
        axs_p[5].plot(
            time_list,
            _moving_average(spike_counts[:, 1]),
            "C1",
            label=f"run {i}",
        )

        # Plot neurons
        neurons = OrderedDict()
        k = 0
        # Go over layers
        for recordings in neuron_dict.values():
            layer_spikes = np.array(recordings["spike"])
            layer_traces = np.array(recordings["trace"])
            # Go over neurons in layer
            for j in range(layer_spikes.shape[1]):
                spike_ma = _moving_average(layer_spikes[:, j])
                neurons[f"n{k}_spike"] = layer_spikes[:, j].astype(float)
                neurons[f"n{k}_trace"] = layer_traces[:, j]
                neurons[f"n{k}_ma"] = spike_ma
                axs_n[k].plot(time_list, layer_traces[:, j], "C0")
                axs_n[k].plot(time_list, layer_spikes[:, j], "C1")
                axs_n[k].plot(time_list, spike_ma, "C2")
                axs_n[k].set_title(f"{k}")
                k += 1

        # Save run
        # Per neuron: spikes per unit time and spikes per step, both with
        # the configured settle period subtracted from the denominator
        rates.append([[
            v.sum() / (time_list[-1] - config["env"]["settle"]),
            v.sum() / (len(time_list) - config["env"]["settle"] // env.dt + 1),
        ] for key, v in neurons.items() if "spike" in key])
        data = pd.DataFrame({
            "time": time_list,
            "pos_z": states[:, 0],
            "vel_z": states[:, 1],
            "thrust": states[:, 2],
            "tsp": actions,
            "tsp_lp": _moving_average(action_list),
            "div": observations[:, 0],
            "div_gt": observations_gt[:, 0],
            "divdot": observations[:, 1],
            "divdot_gt": observations_gt[:, 1],
            "spike_count": spike_counts[:, 0],
            "spike_step": spike_counts[:, 1],
        })
        neurons = pd.DataFrame(neurons)
        # axis must be given by keyword: pandas >= 2.0 rejects it positionally
        data = pd.concat([data, neurons], axis=1)
        data.to_csv(str(save_folder) + f"/run{i}.csv", index=False, sep=",")

    # Compute rates (mean and std over the runs, per neuron)
    rate_array = np.array(rates)
    rates = pd.DataFrame({
        "mean_time": rate_array.mean(0)[:, 0],
        "mean_steps": rate_array.mean(0)[:, 1],
        "std_time": rate_array.std(0)[:, 0],
        "std_steps": rate_array.std(0)[:, 1],
    })
    rates.to_csv(str(save_folder) + "/rates.csv", index=False, sep=",")

    # Write network to tikz-network-compatible file
    has_hidden = network.neuron1 is not None

    # Edges: inputs get negative ids, hidden neurons start at 0 and output
    # ids follow the hidden ones
    if has_hidden:
        n_hidden = network.fc2.weight.shape[1]
        edges = pd.concat(
            [
                # First layer
                _tikz_edges(network.fc1.weight,
                            -network.fc1.weight.shape[1], 0),
                # Second layer
                _tikz_edges(network.fc2.weight, 0, n_hidden),
            ],
            axis=0,
        )
    else:
        # Only layer
        edges = _tikz_edges(network.fc2.weight, -network.fc2.weight.shape[1],
                            0)
    edges.to_csv(str(save_folder) + "/network_edges.csv",
                 index=False,
                 sep=",")

    # Vertices
    vertices = []
    k = 0
    # Input layer (fed by fc1 when there is a hidden layer, else by fc2)
    first_fc = network.fc1 if has_hidden else network.fc2
    n_in = first_fc.weight.shape[1]
    for j in range(-n_in, 0):
        vertices.append({
            "id": j,
            "x": 0.0,
            "y": n_in / 4 - 0.25 - 0.5 * (n_in + j),
            "color": "cyan",
        })
    # Hidden layer
    if has_hidden:
        n_hidden = network.fc2.weight.shape[1]
        for j in range(n_hidden):
            vertices.append({
                "id": j,
                "x": 2.0,
                "y": n_hidden / 4 - 0.25 - 0.5 * j,
                "color":
                f"cyan!{100 - 3.333 * rates['mean_time'][k]}!magenta",
            })
            k += 1
    # Output layer
    vertices.append({
        "id": k,
        "x": 4.0 if has_hidden else 2.0,
        "y": 0.0,
        "color": f"cyan!{100 - 3.333 * rates['mean_time'][k]}!magenta",
    })
    vertices = pd.DataFrame(vertices, columns=["id", "x", "y", "color"])
    vertices.to_csv(str(save_folder) + "/network_vertices.csv",
                    index=False,
                    sep=",")

    # Add grid
    for ax in axs_p:
        ax.grid()
    fig_p.tight_layout()
    fig_n.tight_layout()

    plt.show()