Example #1
0
# Compare the approximations y2, y3 and y4 with the reference signal y1.
# Two relative errors are reported for each approximation:
#   * at sample ``max_sig_idx``, normalised by |y1| at that sample;
#   * the largest absolute deviation inside the window
#     [t0_plot_idx, t1_plot_idx), normalised by the window's peak |y1|.
_ref_at_peak = y1[max_sig_idx]
_plot_window = slice(t0_plot_idx, t1_plot_idx)
_ref_in_window = y1[_plot_window]
_ref_window_peak = np.max(np.abs(_ref_in_window))

error_at_max_1 = np.abs(_ref_at_peak - y2[max_sig_idx]) / np.abs(_ref_at_peak)
error_at_max_2 = np.abs(_ref_at_peak - y3[max_sig_idx]) / np.abs(_ref_at_peak)
error_at_max_3 = np.abs(_ref_at_peak - y4[max_sig_idx]) / np.abs(_ref_at_peak)

max_error_1 = np.max(np.abs(_ref_in_window - y2[_plot_window]) / _ref_window_peak)
max_error_2 = np.max(np.abs(_ref_in_window - y3[_plot_window]) / _ref_window_peak)
max_error_3 = np.max(np.abs(_ref_in_window - y4[_plot_window]) / _ref_window_peak)

_error_report = (
    ("Single dipole: error at sig max (t={:1.3f} ms): {:1.4f}. Max relative error: {:1.4f}",
     error_at_max_1, max_error_1),
    ("Single dipole opt pos: error at sig max (t={:1.3f} ms): {:1.4f}. Max relative error: {:1.4f}",
     error_at_max_2, max_error_2),
    ("Pop dipoles: error at sig max (t={:1.3f} ms): {:1.4f}. Max relative error: {:1.4f}",
     error_at_max_3, max_error_3),
)
for _template, _err_at_max, _max_err in _error_report:
    print(_template.format(tvec[max_sig_idx], _err_at_max, _max_err))


# y3 is only reported above, not plotted.
ax1.plot(tvec, y1, c="k", lw=2., label="Sum")
ax1.plot(tvec, y2, ":", c="gray", lw=2., label="Single combined dipole")
ax1.plot(tvec, y4, "--", c="r", lw=1., label="Population dipoles")

simplify_axes(ax1)
fig.legend(frameon=False, ncol=3, fontsize=8)
plt.savefig(join(sim_folder, "Figure_combined_EEG.png"))
plt.savefig(join(sim_folder, "Figure_combined_EEG.pdf"))

Example #2
0
# Figure 4: potential versus polar angle for a sphere in which all four
# conductivities (brain, csf, skull, scalp) are equal. It is compared with
# Nunez & Srinivasan (2006), Srinivasan (1998) and the present analytical
# results.
#
# Keep column 0 of the 180 x 180 grid and its first 90 samples.
phi_sphere = phi_sphere.reshape(180, 180)[:, 0][0:90]
theta = params.theta.reshape(180, 180)[:, 0][0:90]

fig = plt.figure(figsize=[5, 4])
fig.subplots_adjust(left=0.2, bottom=0.14)
plt.subplot(111)
plt.xlim([0, 80])
plt.ylim([-15, 500.])

plt.plot(theta, phi_sphere, 'r', label='Homogeneous sphere')
plt.plot(theta, phi_nunsri06, 'k+', label='Nunez & Srinivasan (2006)')
plt.plot(theta, phi_sri98, 'g*', label='Srinivasan (1998)')
plt.plot(theta, phi_correct, 'b.', label='Present results - analytical')

plt.xlabel('Polar angle (degrees)', fontsize=14)
plt.ylabel(r'Potential ($\mathrm{\mu V}$)', fontsize=14.)
plt.tick_params(labelsize=15.)
# Raw string: '\s' and '\m' are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+). The resulting text is unchanged.
plt.title(
    r'$\sigma_{\mathrm{skull}} = \sigma_{\mathrm{brain}} = \sigma_{\mathrm{csf}} = \sigma_{\mathrm{scalp}}$',
    fontsize=19)
plt.legend(frameon=False, bbox_to_anchor=(1, 0.9), fontsize=11)
simplify_axes(fig.axes)

# plt.savefig(os.path.join(args.results, 'figure4_scaled.png'), dpi=150)
plt.savefig(os.path.join(args.results, 'figure4.eps'), dpi=150)
# plt.show()
def make_results_figure():
    """Plot the FEM results figure and an animation of the xz-plane potential.

    Loads the per-time-step potentials that were written to ``out_folder``:
    the FEM xz-plane map, the FEM electrode values and the analytic (MoI)
    electrode values. It then

    * prints the relative FEM/MoI error for the free electrode,
    * saves one animation frame per time step into ``fem_fig_folder``,
    * replaces the time-resolved contour map with a map of where the peak
      |phi| over time is below or above 10 µV, and saves the final figure
      (PNG and PDF) into ``root_folder``,
    * stores the underlying plot data as ``.npy`` files in
      ``root_folder/figure_data``.

    Relies on module-level names defined elsewhere: ``x0``, ``x1``, ``nx``,
    ``z0``, ``z1``, ``nz``, ``tvec``, ``num_tsteps``, ``mea_x_positions``,
    ``out_folder``, ``structure_radius``, ``structure_height``,
    ``tunnel_radius``, ``source_line_pos``, ``source_pos``,
    ``cell_plot_idxs``, ``plot_idx_clrs``, ``vmem``, ``fem_fig_folder``,
    ``root_folder`` and ``sim_name``.
    """

    def _remove_contour_set(contour_set):
        """Remove a contour set from its axes on old and new matplotlib."""
        try:
            # matplotlib >= 3.8: ContourSet is itself an artist, and the
            # per-level ``collections`` attribute is deprecated/removed.
            contour_set.remove()
        except AttributeError:
            # Older matplotlib without ContourSet.remove().
            for collection in contour_set.collections:
                collection.remove()

    x = np.linspace(x0, x1, nx)
    z = np.linspace(z0, z1, nz)

    # Grid index closest to each requested electrode x-position.
    mea_x_plot_pos = np.array([np.argmin(np.abs(x - x_))
                               for x_ in mea_x_positions])

    mea_analytic = np.zeros((len(mea_x_plot_pos), num_tsteps))
    mea_fem = np.zeros((len(mea_x_plot_pos), num_tsteps))
    phi_plane_xz = np.zeros((len(x), len(z), len(tvec)))
    for t_idx in range(num_tsteps):
        # Factor 1000: presumably mV -> µV, matching the µV axis labels below.
        phi_plane_xz[:, :, t_idx] = 1000 * np.load(join(
            out_folder, "phi_xz_t_vec_{}.npy".format(t_idx)))
        mea_fem[:, t_idx] = 1000 * np.load(join(
            out_folder, "phi_mea_t_vec_{}.npy".format(t_idx)))[mea_x_plot_pos]
        mea_analytic[:, t_idx] = 1000 * np.load(join(
            out_folder,
            "phi_mea_analytic_t_vec_{}.npy".format(t_idx)))[mea_x_plot_pos]

    noise_level = 10  # uV
    soma_height = 65
    soma_diam = 8
    soma_xpos = -200
    z_idx = np.argmin(np.abs(z - soma_height))
    x_idxs = (-150 + soma_diam > x) & (x > soma_xpos + soma_diam / 2)

    # Peak |phi| over time at every x along the soma-height row.
    eap_amp = np.max(np.abs(phi_plane_xz[:, z_idx, :]), axis=-1)

    plt.close("all")
    # Figure size given in mm (117 x 48), converted to inches.
    fig = plt.figure(figsize=[117 * 0.03937, 48 * 0.03937])
    fig.subplots_adjust(hspace=0.45, bottom=0.17, top=0.99,
                        left=0.16, wspace=1.0, right=0.96)

    ax_h = 0.25
    ax_w = 0.11
    ax_left = 0.10

    ax_setup = fig.add_axes([0.28, 0.13, 0.69, 0.85],
                            xlim=[-240, 145], ylim=[-15, 117])

    ax_vmem = fig.add_axes([ax_left, 0.13 + 2 * (ax_h + 0.04), ax_w, ax_h],
                           xlim=[0, tvec[-1]], ylim=[-110, 40])

    ax_mea_free = fig.add_axes([ax_left, 0.13 + ax_h + 0.04, ax_w, ax_h],
                               xlim=[0, tvec[-1]], ylim=[-8, 4])
    ax_mea_tunnel = fig.add_axes([ax_left, 0.13, ax_w, ax_h],
                                 xlim=[0, tvec[-1]], ylim=[-600, 400])
    ax_EAP_decay = fig.add_axes([0.57, 0.70, 0.08, 0.17],
                                xlim=[0, 50], facecolor='none')

    ax_setup.set_xlabel('X (µm)', labelpad=0)
    ax_setup.set_ylabel('Z (µm)', labelpad=-1)

    ax_vmem.set_ylabel('Membrane\npotential (mV)', labelpad=0)
    ax_vmem.set_xticklabels(["", ""])

    ax_mea_free.set_ylabel('OME (µV)', labelpad=9)
    ax_mea_free.set_xticklabels(["", ""])

    ax_mea_tunnel.set_ylabel('CME (µV)', labelpad=2)
    ax_mea_tunnel.set_xlabel('Time (ms)', labelpad=-0.1)

    ax_EAP_decay.set_ylabel('Peak\namplitude (µV)', labelpad=1)
    ax_EAP_decay.set_xlabel('Distance from\nsoma (µm)', labelpad=-0.0)

    ax_setup.text(139, -7, "Substrate", ha='right', va='top')

    dist_from_soma = x[x_idxs] - soma_xpos

    ax_EAP_decay.plot(dist_from_soma, eap_amp[x_idxs], lw=1, c='b')

    ax_setup.plot(x[x_idxs], np.ones(len(x[x_idxs])) * soma_height,
                  ls=':', c='blue', lw=1)

    # Distance at which the peak amplitude is closest to the noise level.
    noise_level_dist_idx = np.argmin(np.abs(eap_amp[x_idxs] - noise_level))
    noise_level_dist = dist_from_soma[noise_level_dist_idx]
    ax_EAP_decay.axhline(noise_level, ls='--', c='gray', lw=0.5)
    ax_EAP_decay.axvline(noise_level_dist, ls='--', c='gray', lw=0.5)

    # Draw the setup: tunnel structure and substrate.
    rect = mpatches.Rectangle([-structure_radius - 2, tunnel_radius],
                              2 * structure_radius + 4, structure_height,
                              ec="k", fc='0.8', linewidth=0.3)
    ax_setup.add_patch(rect)

    rect_bottom = mpatches.Rectangle([-1000, 0], 2000, -1000,
                                     ec="k", fc='0.7', linewidth=0.3)
    ax_setup.add_patch(rect_bottom)

    for source_idx in range(len(source_line_pos)):
        xstart, xend = source_line_pos[source_idx, :, 0]
        zstart, zend = source_line_pos[source_idx, :, 2]

        ax_setup.plot([xstart, xend], [zstart, zend], c='#18e10c', lw=1,
                      solid_capstyle="round", solid_joinstyle="round")
    ax_setup.plot(source_pos[0, 0], source_pos[0, 2], c='#18e10c', marker='o',
                  ms=14)

    for counter, p_idx in enumerate(cell_plot_idxs):
        ax_setup.plot(source_pos[p_idx, 0], source_pos[p_idx, 2],
                      c=plot_idx_clrs[counter], marker='o', ms=5)
        if p_idx > 0:
            # Electrode pad on the substrate below this source position.
            rect_elec = mpatches.Rectangle([source_pos[p_idx, 0] - 5, 0],
                                           10, -5, ec="k", fc='k',
                                           linewidth=0.5)
            ax_setup.add_patch(rect_elec)
            ax_setup.text(source_pos[p_idx, 0], - 6,
                          ["OME", "CME"][counter - 1], va='top', ha="center")

    for counter, p_idx in enumerate(cell_plot_idxs):
        ax_vmem.plot(tvec, vmem[p_idx, :], c=plot_idx_clrs[counter], lw=1)

    num = 11
    levels = np.logspace(-2., 0, num=num)
    scale_max = 500

    # Symmetric logarithmic levels: +-scale_max * [0.01 ... 1].
    levels_norm = scale_max * np.concatenate((-levels[::-1], levels))
    # plt.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap is the
    # portable equivalent.
    bwr_cmap = plt.get_cmap('bwr')

    # np.int was removed in NumPy 1.24; the builtin int gives the same value.
    color_step = int(255 / (len(levels_norm) - 2))
    colors_from_map = [bwr_cmap(i * color_step)
                       for i in range(len(levels_norm) - 1)]
    # Paint the band around zero ([-0.01, 0.01] * scale_max) white.
    colors_from_map[num - 1] = (1.0, 1.0, 1.0, 1.0)

    xz_masked = np.ma.masked_where(np.isnan(phi_plane_xz), phi_plane_xz)

    # phi_plane_xz is indexed (x, z, t), but contourf(x, z, Z) expects Z with
    # shape (len(z), len(x)), hence the transpose (as in every later call).
    ep_intervals = ax_setup.contourf(x, z, xz_masked[:, :, 0].T,
                                     zorder=-2, colors=colors_from_map,
                                     levels=levels_norm, extend='both')

    ep_intervals_ = ax_setup.contour(x, z, xz_masked[:, :, 0].T, colors='k',
                                     linewidths=(0.3), zorder=-2,
                                     levels=levels_norm)

    cax = fig.add_axes([0.85, 0.5, 0.01, 0.35])

    cbar = plt.colorbar(ep_intervals, cax=cax, label=r"$\phi$ (µV)")
    cbar.set_ticks(np.array([-1, -0.1, -0.02, 0.02, 0.1, 1]) * scale_max)

    ax_mea_free.plot(tvec, mea_fem[0], lw=1, c='k')
    ax_mea_tunnel.plot(tvec, mea_fem[1], lw=1, c='k', ls="-")

    # Max deviation between the analytic (MoI) and FEM free-electrode traces,
    # relative to the peak FEM amplitude.
    rel_error = np.max(np.abs((mea_analytic[0] - mea_fem[0])) / np.max(
        np.abs(mea_fem[0])))
    print("Relative error between FEM and MoI (free elec): {:1.4f}".format(
        rel_error))

    # Vertical time markers that move with the animation frames.
    t_markers = (ax_vmem.axvline(tvec[0], c='gray', ls="--"),
                 ax_mea_free.axvline(tvec[0], c='gray', ls="--"),
                 ax_mea_tunnel.axvline(tvec[0], c='gray', ls="--"))

    simplify_axes([ax_setup, ax_mea_free, ax_mea_tunnel, ax_vmem, ax_EAP_decay])

    # Animation frames. This loop dominates the run time and can be commented
    # out if only the final figure is needed.
    for t_idx in range(num_tsteps):
        _remove_contour_set(ep_intervals)
        ep_intervals = ax_setup.contourf(x, z, xz_masked[:, :, t_idx].T,
                                         zorder=-2, colors=colors_from_map,
                                         levels=levels_norm, extend='both')
        _remove_contour_set(ep_intervals_)
        ep_intervals_ = ax_setup.contour(x, z, xz_masked[:, :, t_idx].T,
                                         colors='k', linewidths=(1), zorder=-2,
                                         levels=levels_norm)

        # set_xdata requires a sequence (scalars deprecated in matplotlib 3.7).
        for t_marker in t_markers:
            t_marker.set_xdata([tvec[t_idx], tvec[t_idx]])

        plt.savefig(join(fem_fig_folder, 'anim_results_{}_t_idx_{:04d}.png'.format(
            sim_name, t_idx)), dpi=300)

    cax.clear()
    # Park the time markers at t = 100 (presumably outside the plotted time
    # range, so they disappear from the final figure; confirm with tvec).
    for t_marker in t_markers:
        t_marker.set_xdata([100, 100])

    # Two-band map of the peak |phi| over time: below / above 10 µV.
    levels_norm = [0, 10.0, 1e9]
    colors_from_map = ['0.95', '#ffbbbb', (0.5, 0.5, 0.5, 1)]

    # Peak |phi| over time at every grid point, transposed to (z, x).
    xz_crossmax = np.array(np.max(np.abs(xz_masked[:, :, :]), axis=-1).T)

    _remove_contour_set(ep_intervals)
    ep_intervals = ax_setup.contourf(x, z, xz_crossmax,
                                     zorder=-2, colors=colors_from_map,
                                     levels=levels_norm, extend='both')

    _remove_contour_set(ep_intervals_)

    cbar = plt.colorbar(ep_intervals, cax=cax)
    # One tick in the middle of each of the two bands.
    cbar.set_ticks(np.array([5, 1e9/2]))
    cbar.set_ticklabels(np.array(["<10 µV", ">10 µV"]))
    plt.savefig(join(root_folder, 'Fig_{}_4.png'.format(sim_name)), dpi=300)
    plt.savefig(join(root_folder, 'Fig_{}_4.pdf'.format(sim_name)), dpi=300)

    # Save the data behind the figure so it can be re-plotted elsewhere.
    plot_data_folder = join(root_folder, 'figure_data')
    os.makedirs(plot_data_folder, exist_ok=True)
    np.save(join(plot_data_folder, "xz_max_amp.npy"), xz_crossmax)
    np.save(join(plot_data_folder, "xz_x_values.npy"), x)
    np.save(join(plot_data_folder, "xz_z_values.npy"), z)
    np.save(join(plot_data_folder, "mea_phi_values.npy"), mea_fem)
    np.save(join(plot_data_folder, "source_line_pos.npy"), source_line_pos)
    np.save(join(plot_data_folder, "t_vec.npy"), tvec)
    np.save(join(plot_data_folder, "memb_pot.npy"), vmem[cell_plot_idxs, :])
    np.save(join(plot_data_folder, "eap_amp.npy"), eap_amp[x_idxs])
    np.save(join(plot_data_folder, "eap_amp_dist.npy"), dist_from_soma)
Example #4
0
def simulate_laminar_LFP():
    """Simulate and plot laminar LFP profiles for repeated cell simulations.

    DEPRECATED, and not updated for LFPy2.2 and newer.

    Runs ``run_cell_simulation_distributed_input`` ``num_sims`` times. For
    each run it saves a figure to ``figures/`` showing the morphology,
    membrane potentials, the extracellular potential on a vertical column of
    contacts, and column 2 of the current dipole moment. A final figure
    shows the LFP and the dipole moment summed over all runs.
    """
    dt = 2**-5
    cell_name = 'hay'

    # Vertical column of 15 contacts at x = 50, y = 0 (units presumably µm).
    elec_z = np.linspace(-200, 1200, 15)
    elec_x = np.ones(len(elec_z)) * 50
    elec_y = np.zeros(len(elec_z))

    # Contact spacing, used to offset the stacked LFP traces vertically.
    h = np.abs(elec_z[1] - elec_z[0])

    electrode_parameters = {
        'sigma': 0.3,  # extracellular conductivity
        'x': elec_x,  # x,y,z-coordinates of contact points
        'y': elec_y,
        'z': elec_z
    }

    def elec_clr(idx):
        """Return the viridis colour for electrode contact ``idx``."""
        return plt.cm.viridis(idx / len(elec_z))

    num_sims = 10

    cells = []

    summed_LFP = []
    summed_cdm = []

    for sim in range(num_sims):
        print(sim + 1, "/", num_sims)
        cell_wca, idx_clr, plot_idxs = run_cell_simulation_distributed_input(
            dt=dt, cell_name=cell_name)
        cells.append([cell_wca, idx_clr, plot_idxs])
    for sim in range(num_sims):
        cell_wca, idx_clr, plot_idxs = cells[sim]
        fig = plt.figure(figsize=[7, 7])
        fig.subplots_adjust(hspace=0.5,
                            left=0.0,
                            wspace=0.4,
                            right=0.96,
                            top=0.97,
                            bottom=0.1)

        ax_m = fig.add_axes([-0.01, 0.05, 0.3, 0.97],
                            aspect=1,
                            frameon=False,
                            xlim=[-350, 350],
                            xticks=[],
                            yticks=[])
        # Morphology, highlighted segments and synapse locations.
        for idx in range(cell_wca.totnsegs):
            ax_m.plot([cell_wca.xstart[idx], cell_wca.xend[idx]],
                      [cell_wca.zstart[idx], cell_wca.zend[idx]],
                      c='k')
        for idx in plot_idxs:
            ax_m.plot(cell_wca.xmid[idx],
                      cell_wca.zmid[idx],
                      'o',
                      c=idx_clr[idx],
                      ms=13)
        for idx in cell_wca.synidx:
            ax_m.plot(cell_wca.xmid[idx], cell_wca.zmid[idx], 'rd')

        ax_top = 0.98
        ax_h = 0.25
        ax_w = 0.17
        ax_left = 0.4
        cell = cell_wca

        elec = LFPy.RecExtElectrode(cell, **electrode_parameters)
        elec.calc_lfp()

        ax_vm = fig.add_axes([ax_left, ax_top - ax_h, ax_w, ax_h],
                             ylim=[-80, 50],
                             xlim=[0, 80],
                             xlabel="Time (ms)")

        ax_eap = fig.add_axes([ax_left + 0.3, 0.1, ax_w, 0.8],
                              xlim=[0, 80],
                              xlabel="Time (ms)")
        ax_cdm = fig.add_axes(
            [ax_left, 0.2, ax_w, ax_h],
            xlabel="Time (ms)",
            ylim=[-0.5, 1],
            xlim=[0, 80],
        )

        # Backslashes in the mathtext are escaped: '\m' and '\c' are invalid
        # string escape sequences (SyntaxWarning on Python 3.12+).
        ax_vm.set_ylabel("Membrane\npotential\n(mV)", labelpad=-3)
        ax_eap.set_ylabel(r"Extracellular potential ($\mu$V)", labelpad=-3)
        ax_cdm.set_ylabel("Current dipole\nmoment\n($\\mu$A$\\cdot \\mu$m)",
                          labelpad=-3)

        for idx in plot_idxs:
            ax_vm.plot(cell.tvec, cell.vmem[idx], c=idx_clr[idx])

        # Subtract the initial value so that only changes are shown and summed.
        elec.LFP -= elec.LFP[:, 0, None]
        cell.current_dipole_moment -= cell.current_dipole_moment[0, :]
        summed_LFP.append(elec.LFP)
        summed_cdm.append(cell.current_dipole_moment)

        normalize = np.max(np.abs(elec.LFP))
        for idx in range(len(elec_x)):
            ax_eap.plot(cell.tvec,
                        elec.LFP[idx] / normalize * h + elec_z[idx],
                        c=elec_clr(idx))
            ax_m.plot(elec_x[idx], elec_z[idx], c=elec_clr(idx), marker='D')

        ax_cdm.plot(cell.tvec, 1e-3 * cell.current_dipole_moment[:, 2], c='k')

        mark_subplots([ax_m], xpos=0.1, ypos=0.95)
        simplify_axes(fig.axes)

        plt.savefig(
            join("figures",
                 'laminar_LFP_ca_spike_{}_{}.png'.format(cell_name, sim)))
        plt.close("all")

    summed_LFP = np.sum(summed_LFP, axis=0)
    summed_cdm = np.sum(summed_cdm, axis=0)
    normalize = np.max(np.abs(summed_LFP))
    plt.subplot(121)
    for idx in range(len(elec_x)):
        plt.plot(cell.tvec,
                 summed_LFP[idx] / normalize * h + elec_z[idx],
                 c=elec_clr(idx))

    plt.subplot(122)
    plt.plot(cell.tvec, summed_cdm[:, 2])

    # NOTE(review): ':' in the file name is not allowed on Windows.
    plt.savefig(
        join("figures",
             'summed_LFP_CDM_{}_num:{}.png'.format(cell_name, num_sims)))
Example #5
0
def simulate_spike_current_dipole_moment():
    """Make Figure 5: spike-related dipole moments without and with a Ca spike.

    Runs ``run_cell_simulation`` twice for the 'hay' cell, with
    ``make_ca_spike`` False and then True. For each run it plots the
    membrane potentials, the extracellular potential at one electrode,
    row 2 of the current dipole moment (presumably the z-component), its
    jittered sum over ``num_trials`` trials (from ``sum_jittered_cdm``),
    and a snapshot of the potential on an xz grid. Saves
    ``figures/Figure5.pdf``.
    """
    dt = 2**-5
    cell_name = 'hay'

    jitter_std = 10
    num_trials = 1000

    # Create a grid of measurement locations, in (µm)
    grid_x, grid_z = np.mgrid[-750:751:25, -750:1301:25]

    grid_y = np.zeros(grid_x.shape)

    # Define electrode parameters
    grid_elec_params = {
        'sigma': 0.3,  # extracellular conductivity
        'x': grid_x.flatten(),  # electrode requires 1d vector of positions
        'y': grid_y.flatten(),
        'z': grid_z.flatten()
    }

    # Single electrode contact.
    elec_x = np.array([30])
    elec_y = np.array([0])
    elec_z = np.array([0])

    elec_params = {
        'sigma': 0.3,  # extracellular conductivity
        'x': elec_x,  # x,y,z-coordinates of contact points
        'y': elec_y,
        'z': elec_z
    }

    elec_clr = 'r'

    cell_woca_data, idx_clr, plot_idxs = run_cell_simulation(
        make_ca_spike=False,
        dt=dt,
        cell_name=cell_name,
        grid_elec_params=grid_elec_params,
        elec_params=elec_params)
    cell_wca_data, idx_clr, plot_idxs = run_cell_simulation(
        make_ca_spike=True,
        dt=dt,
        cell_name=cell_name,
        grid_elec_params=grid_elec_params,
        elec_params=elec_params)

    fig = plt.figure(figsize=[8, 7])
    fig.subplots_adjust(hspace=0.5,
                        left=0.0,
                        wspace=0.4,
                        right=0.96,
                        top=0.97,
                        bottom=0.1)

    ax_m = fig.add_axes([-0.01, 0.05, 0.25, 0.97],
                        aspect=1,
                        frameon=False,
                        xticks=[],
                        yticks=[])
    ax_m.plot(cell_wca_data["cell_x"].T, cell_wca_data["cell_z"].T, c='k')

    for idx in plot_idxs:
        ax_m.plot(cell_wca_data["cell_x"][idx].mean(),
                  cell_wca_data["cell_z"][idx].mean(),
                  'o',
                  c=idx_clr[idx],
                  ms=13)
    ax_m.plot(elec_x, elec_z, elec_clr, marker='D')

    ax_top = 0.98
    ax_h = 0.15
    h_space = 0.1
    ax_w = 0.17
    ax_left = 0.37

    num = 11
    levels = np.logspace(-2.5, 0, num=num)
    scale_max = 100.

    # Symmetric logarithmic contour levels: +-scale_max * [10**-2.5 ... 1].
    levels_norm = scale_max * np.concatenate((-levels[::-1], levels))
    # plt.cm.get_cmap was removed in matplotlib 3.9.
    bwr_cmap = plt.get_cmap('bwr_r')

    # np.int was removed in NumPy 1.24; the builtin int gives the same value.
    color_step = int(255 / (len(levels_norm) - 2))
    colors_from_map = [
        bwr_cmap(i * color_step)
        for i in range(len(levels_norm) - 1)
    ]
    # Paint the band around zero white.
    colors_from_map[num - 1] = (1.0, 1.0, 1.0, 1.0)

    # Grid snapshot time index for each row (without / with Ca spike).
    spike_plot_time_idxs = [1030, 1151]
    summed_cdm_max = np.zeros(2)
    for plot_row, cell_data in enumerate([cell_woca_data, cell_wca_data]):
        ax_left += plot_row * 0.25

        grid_LFP = cell_data["grid_LFP"]
        cdm = cell_data["cdm"]
        tvec = cell_data["tvec"]
        vmem = cell_data["vmem"]
        elec_LFP = cell_data["elec_LFP"]

        time_idx = spike_plot_time_idxs[plot_row]
        print(tvec[time_idx])

        grid_LFP_ = grid_LFP[:, time_idx].reshape(grid_x.shape)
        ax_ = fig.add_axes([0.75, 0.55 - plot_row * 0.46, 0.3, 0.45],
                           xticks=[],
                           yticks=[],
                           aspect=1,
                           frameon=False)
        mark_subplots(ax_, [["D"], ["E"]][plot_row], ypos=0.95, xpos=0.35)

        ax_.plot(cell_data["cell_x"].T, cell_data["cell_z"].T, c='k')

        ep_intervals = ax_.contourf(grid_x,
                                    grid_z,
                                    grid_LFP_,
                                    zorder=-2,
                                    colors=colors_from_map,
                                    levels=levels_norm,
                                    extend='both')

        ax_.contour(grid_x,
                    grid_z,
                    grid_LFP_,
                    colors='k',
                    linewidths=(1),
                    zorder=-2,
                    levels=levels_norm)

        if plot_row == 1:
            cax = fig.add_axes([0.82, 0.12, 0.16, 0.01])
            cbar = fig.colorbar(ep_intervals,
                                cax=cax,
                                orientation='horizontal',
                                format='%.0E')

            cbar.set_ticks(
                np.array([-1, -0.1, -0.01, 0, 0.01, 0.1, 1]) * scale_max)
            cax.set_xticklabels(np.array(
                np.array([-1, -0.1, -0.01, 0, 0.01, 0.1, 1]) * scale_max,
                dtype=int),
                                fontsize=11,
                                rotation=45)
            cbar.set_label(r'$\phi$ (µV)', labelpad=-5)

        sum_tvec, summed_cdm = sum_jittered_cdm(cdm[2, :], dt, jitter_std,
                                                num_trials)

        summed_cdm_max[plot_row] = np.max(np.abs(summed_cdm))

        ax_vm = fig.add_axes([ax_left, ax_top - ax_h, ax_w, ax_h],
                             ylim=[-80, 50],
                             xlim=[0, 100])

        ax_eap = fig.add_axes(
            [ax_left, ax_top - 2 * ax_h - h_space, ax_w, ax_h],
            ylim=[-120, 40],
            xlim=[0, 100])
        ax_cdm = fig.add_axes(
            [ax_left, ax_top - 3 * ax_h - 2 * h_space, ax_w, ax_h],
            ylim=[-0.5, 1],
            xlim=[0, 100],
        )

        ax_cdm_sum = fig.add_axes(
            [ax_left, ax_top - 4 * ax_h - 3 * h_space, ax_w, ax_h],
            ylim=[-250, 100],
            xlabel="Time (ms)",
            xlim=[0, 140])
        if plot_row == 0:
            # '\\cdot' is escaped: '\c' is an invalid string escape sequence.
            ax_vm.set_ylabel("Membrane\npotential\n(mV)", labelpad=-3)
            ax_eap.set_ylabel("Extracellular\npotential\n(µV)", labelpad=-3)
            ax_cdm.set_ylabel("Current dipole\nmoment\n(µA$\\cdot$µm)",
                              labelpad=-3)
            ax_cdm_sum.set_ylabel("Jittered sum\n(µA$\\cdot$µm)", labelpad=-3)
        elif plot_row == 1:
            ax_vm.text(65, -5, "Ca$^{2+}$ spike", fontsize=11, color='orange')
            ax_vm.arrow(80, -10, -10, -18, color='orange', head_width=4)

        mark_subplots(ax_vm, [["B1"], ["C1"]][plot_row], xpos=-0.3, ypos=0.93)
        mark_subplots(ax_eap, [["B2"], ["C2"]][plot_row], xpos=-0.3, ypos=0.93)
        mark_subplots(ax_cdm, [["B3"], ["C3"]][plot_row],
                      xpos=-0.35,
                      ypos=0.97)
        mark_subplots(ax_cdm_sum, [["B4"], ["C4"]][plot_row],
                      xpos=-0.3,
                      ypos=0.93)
        for idx in plot_idxs:
            ax_vm.plot(tvec, vmem[idx], c=idx_clr[idx])
        ax_vm.axvline(tvec[time_idx], ls='--', c='gray')

        for idx in range(len(elec_x)):
            ax_eap.plot(tvec, elec_LFP[idx], c=elec_clr)

        ax_cdm.plot(tvec, 1e-3 * cdm[2, :], c='k')
        ax_cdm_sum.plot(sum_tvec, 1e-3 * summed_cdm, c='k')

    # The message promises a ratio, so report it: with-Ca / without-Ca peak.
    print(
        "Summed CDM max (abs), ratio",
        summed_cdm_max * 1e-3,
        summed_cdm_max[1] / summed_cdm_max[0],
    )

    mark_subplots([ax_m], xpos=0.1, ypos=0.95)
    simplify_axes(fig.axes)

    plt.savefig(join("figures", 'Figure5.pdf'))
Example #6
0
def animate_ca_spike():
    """Save animation frames of the extracellular potential around a cell.

    Deprecated, and not updated for LFPy2.2 and newer.

    Runs ``run_cell_simulation(make_ca_spike=False)`` for the 'hay' cell and
    computes the potential on an xz grid (scaled by 1e3 and baseline
    subtracted). For every time step from index 2000 on, it writes a frame
    to ``anim_no_ca/`` with the morphology, the contour map, the membrane
    potentials and column 2 of the current dipole moment.
    """
    dt = 2**-5
    cell_name = 'hay'

    # Create a grid of measurement locations, in (µm)
    grid_x, grid_z = np.mgrid[-750:751:25, -750:1301:25]
    grid_y = np.zeros(grid_x.shape)

    # Define electrode parameters
    grid_elec_params = {
        'sigma': 0.3,  # extracellular conductivity
        'x': grid_x.flatten(),  # electrode requires 1d vector of positions
        'y': grid_y.flatten(),
        'z': grid_z.flatten()
    }

    num = 11
    levels = np.logspace(-2.5, 0, num=num)
    scale_max = 100.

    # Symmetric logarithmic contour levels: +-scale_max * [10**-2.5 ... 1].
    levels_norm = scale_max * np.concatenate((-levels[::-1], levels))
    # plt.cm.get_cmap was removed in matplotlib 3.9.
    bwr_cmap = plt.get_cmap('bwr_r')

    # np.int was removed in NumPy 1.24; the builtin int gives the same value.
    color_step = int(255 / (len(levels_norm) - 2))
    colors_from_map = [
        bwr_cmap(i * color_step)
        for i in range(len(levels_norm) - 1)
    ]
    # Paint the band around zero white.
    colors_from_map[num - 1] = (1.0, 1.0, 1.0, 1.0)

    cell, idx_clr, plot_idxs = run_cell_simulation(make_ca_spike=False,
                                                   dt=dt,
                                                   cell_name=cell_name)
    grid_electrode = LFPy.RecExtElectrode(cell, **grid_elec_params)
    grid_electrode.calc_lfp()
    grid_LFP = 1e3 * grid_electrode.LFP

    # Subtract each grid point's initial value.
    grid_LFP -= grid_LFP[:, 0, None]

    # The output folder does not change between frames, so create it once.
    anim_fig_folder = "anim_no_ca"
    os.makedirs(anim_fig_folder, exist_ok=True)

    for time_idx in range(2000, len(cell.tvec)):
        plt.close("all")
        fig = plt.figure(figsize=[5, 4])
        fig.text(0.5,
                 0.95,
                 "T={:1.1f} ms".format(cell.tvec[time_idx]),
                 ha="center")
        ax_ = fig.add_axes([0.4, 0.14, 0.6, 0.88],
                           xticks=[],
                           yticks=[],
                           aspect=1,
                           frameon=False)
        cax = fig.add_axes([0.5, 0.15, 0.4, 0.01])
        ax_vm = fig.add_axes([0.25, 0.65, 0.2, 0.3],
                             ylabel="membrane\npotential\n(mV)",
                             xlabel="time (ms)")
        # '\\cdot' is escaped: '\c' is an invalid string escape sequence.
        ax_cdm = fig.add_axes([0.25, 0.2, 0.2, 0.3],
                              ylabel="current dipole\nmoment\n(µA$\\cdot$µm)",
                              xlabel="time (ms)")
        grid_LFP_ = grid_LFP[:, time_idx].reshape(grid_x.shape)
        for idx in range(cell.totnsegs):
            ax_.plot([cell.xstart[idx], cell.xend[idx]],
                     [cell.zstart[idx], cell.zend[idx]],
                     c='gray')
        # 100 µm scale bar.
        ax_.plot([350, 350], [-540, -640], c='k', lw=2)
        ax_.text(360, -590, r"100 $\mu$m", va='center')
        ep_intervals = ax_.contourf(grid_x,
                                    grid_z,
                                    grid_LFP_,
                                    zorder=-2,
                                    colors=colors_from_map,
                                    levels=levels_norm,
                                    extend='both')

        ax_.contour(grid_x,
                    grid_z,
                    grid_LFP_,
                    colors='k',
                    linewidths=(1),
                    zorder=-2,
                    levels=levels_norm)

        cbar = fig.colorbar(ep_intervals,
                            cax=cax,
                            orientation='horizontal',
                            format='%.0E',
                            extend='max')

        cbar.set_ticks(
            np.array([-1, -0.1, -0.01, 0, 0.01, 0.1, 1]) * scale_max)
        cax.set_xticklabels(np.array(
            np.array([-1, -0.1, -0.01, 0, 0.01, 0.1, 1]) * scale_max,
            dtype=int),
                            fontsize=7,
                            rotation=45)
        cbar.set_label(r'$\phi$ (µV)', labelpad=-5)
        for idx in plot_idxs:
            ax_vm.plot(cell.tvec, cell.vmem[idx], c=idx_clr[idx])
        ax_cdm.plot(cell.tvec, cell.current_dipole_moment[:, 2], c='k')
        ax_vm.axvline(cell.tvec[time_idx], c='gray', lw=1, ls='--')
        ax_cdm.axvline(cell.tvec[time_idx], c='gray', lw=1, ls='--')
        simplify_axes([ax_vm, ax_cdm])
        plt.savefig(
            join(anim_fig_folder, "cell_ca_cont_{:04d}.png".format(time_idx)))
            if idx < 100:
                ax1.plot(tvec, cdm[:, 0], lw=0.5, c="0.7")
                ax2.plot(tvec, cdm[:, 1], lw=0.5, c="0.7")
                ax3.plot(tvec, cdm[:, 2], lw=0.5, c="0.7")

    ax1.plot(tvec, summed_cdm[:, 0] / len(files), lw=2, c="k")
    ax2.plot(tvec, summed_cdm[:, 1] / len(files), lw=2, c="k")
    ax3.plot(tvec, summed_cdm[:, 2] / len(files), lw=2, c="k")
    return summed_cdm

# For each population, plot the current-dipole-moment components of its
# sub-populations and save the figure (PNG + PDF). Also store the summed
# dipole moment returned by plot_and_return_subgroup_cdms as
# <sim_folder>/cdm/summed_cdm_<pop_name>.npy.
# ax1/ax2/ax3 are assigned at module level because that helper presumably
# draws into them directly (confirm against its definition).
for pop_name, subpops in sub_pop_groups_dict.items():

    plt.close("all")
    fig = plt.figure(figsize=(18, 9))
    fig.suptitle("{}".format(pop_name))
    fig.subplots_adjust(hspace=0.5, left=0.3)

    # One row per component of the dipole moment P.
    ax1 = fig.add_subplot(3, 1, 1, title="$P_x$", **ax_dict)
    ax2 = fig.add_subplot(3, 1, 2, title="$P_y$", **ax_dict)
    ax3 = fig.add_subplot(3, 1, 3, title="$P_z$", xlabel="Time (ms)",
                          **ax_dict)

    summed_cdm = plot_and_return_subgroup_cdms(subpops)

    simplify_axes([ax1, ax2, ax3])
    plt.savefig(join(sim_folder, "cdm_{}.png".format(pop_name)))
    plt.savefig(join(sim_folder, "cdm_{}.pdf".format(pop_name)))

    np.save(join(sim_folder, "cdm", "summed_cdm_{}.npy".format(pop_name)),
            summed_cdm)