Ejemplo n.º 1
0
def compare_synapses(filename, param_list=None, errors=0):
    """Run simulations with different numbers of (E, I) synapses.

    :param filename: Neuron definition passed to ``load_file``; also the prefix of every saved run
        (``"<filename>_<E>,<I>"``).
    :type filename: str
    :param param_list: Either a list of ``(exc, inh)`` synapse-number pairs, or a ``(skip, syn_list)`` tuple
        where ``skip`` is forwarded to ``save_run(skip_if_exists=...)``. When no pairs are given, the default
        ``[(120, 10), (140, 50), (180, 130)]`` is used and ``skip`` defaults to ``SKIP_IF_EXISTS``.
    :param errors: Number of retries already attempted (used by the internal retry on ``RuntimeError``).
    :type errors: int
    """
    if type(param_list) is tuple:
        skip, syn_list = param_list
    else:
        syn_list = param_list
        skip = SKIP_IF_EXISTS
    if not syn_list:
        syn_list = [(120, 10), (140, 50), (180, 130)]
    print("syn_list:{}".format(syn_list))
    load_file(filename)
    load_file(filename)
    vm, cli = get_base_vm_cli(filename)
    h.v_init = vm
    for i, (exc, inh) in enumerate(syn_list):
        h.newSynapses(exc, inh)
        try:
            save_run("{}_{},{}".format(filename, exc, inh), set_cli=cli, skip_if_exists=skip)
        except RuntimeError as err:
            print("Error {} occured on E={} & I={}".format(err, exc, inh))
            print("leftover params = {}".format(syn_list[i:]))
            if errors == 2:
                print("second time error occured, bailing")
                import sys
                sys.exit(-1)
            # Retry the failed pair and everything after it in a fresh call (which reloads the neuron).
            # Pass `skip` along explicitly so the caller's choice is not reset to SKIP_IF_EXISTS, and
            # return afterwards: the recursive call already covers syn_list[i:], so continuing this loop
            # would simulate those pairs a second time.
            compare_synapses(filename, param_list=(skip, syn_list[i:]), errors=errors + 1)
            return
Ejemplo n.º 2
0
def compare_dynamic_K(filename, param_list=None):
    """Run simulations with dynamic potassium (True) or not (False), plus a control.

    The base filename should be used (e.g. 'distal'). For each entry of ``param_list`` one run is saved:

    - ``'control'``: the base neuron (e.g. 'distal'), saved as ``'<filename>_0'``
    - ``True``: ``'<filename>_KCC2'`` with KCC2/pasghk replaced by KCC2_pot/pas (dynamic potassium),
      saved as ``'<filename>_KCC2_1'``
    - ``False``: ``'<filename>_KCC2'`` as loaded (KCC2 WITHOUT dynamic potassium), saved as
      ``'<filename>_KCC2_0'``

    :param filename: Base neuron name. Names that already contain 'KCC2' are ignored (the function returns early).
    :type filename: str
    :param param_list: Conditions to run, drawn from ``'control'``, ``True`` and ``False``.
        Defaults to ``['control', True, False]``.
    """
    if 'KCC2' in filename:
        print("ignoring {} in 'compare_dynamic_K' because param_list should cover this case".format(filename))
        return
    K_list = param_list
    if not K_list:
        K_list = ['control', True, False]
    print("K_list:{}".format(K_list))

    for dyn_K in K_list:
        fname = filename if dyn_K == 'control' else filename + "_KCC2"
        assert (dyn_K != 'control' and 'KCC2' in fname) or (dyn_K == 'control' and 'KCC2' not in fname)
        load_file(fname)
        # get steady state here as it isn't done in run.hoc anymore with the set_cli arg
        vm, cli = get_base_vm_cli(fname, compartment=h.ldend)
        h.v_init = vm
        if dyn_K == 'control':
            print("control (no KCC2 nor KCC2_pot")
            cmd = "forall{" + """
                              cli={cli}
                              cli0_cl_ion={cli}
                           """.format(cli=cli) +\
                  "}"
            h(cmd)
            save_run("{}_{}".format(fname, 0), set_cli=cli)
            # The control neuron has no KCC2 mechanism, so the KCC2 set-up below does not apply
            # (and int('control') would raise ValueError). Move on to the next condition.
            continue
        if dyn_K:
            # then add KCC2 potassium
            cmd = """forall{
                uninsert pasghk
                uninsert KCC2
                insert pas
                insert KCC2_pot
                g_pas = 0.00021
                }"""
            h(cmd)
            pot = "_pot"
        else:
            # already have KCC2 in the neuron
            pot = ""
        cmd = "forall{" + """
                  cli={cli}
                  cli0_cl_ion={cli}
                  cli0_KCC2{pot}={cli}
               """.format(cli=cli, pot=pot) +\
              "}"
        h(cmd)
        # int(True) -> 1, int(False) -> 0
        save_run("{}_{}".format(fname, int(dyn_K)), set_cli=cli)
Ejemplo n.º 3
0
def compare_pas(filename, param_list=None):
    """Scan input resistance by scaling the passive leak conductance.

    The leak channel has K+, Na+, and Cl- conductances with a set ratio (see config.hoc).
    ``h.changePas(<new passive K conductance>)`` respects that ratio; each scale factor multiplies
    ``h.g_pas_k_config``.

    Each run is saved as ``"<filename>_[<scale>,<Rin>,<cli>,<vm>]"``: the relative change in PAS, the input
    resistance measured at soma(0.5), the steady-state [Cl-]i (at 0 input), and the membrane potential
    (the last three to two decimals).

    :param filename: Neuron definition to load.
    :type filename: str
    :param param_list: Relative scale factors for the passive K conductance. Defaults to ``[0.8, 1, 1.2, 1.5]``.
    """
    scales = param_list if param_list else [0.8, 1, 1.2, 1.5]
    print("rel_gkpbar:{}".format(scales))
    load_file(filename)

    impedance = h.Impedance()
    impedance.loc(0.5, sec=h.soma)

    for scale in scales:
        h.changePas(h.g_pas_k_config * scale)
        vm, cli = get_base_vm_cli(filename, compartment=h.soma)
        # calculate impedance at 0 Hz (i.e. resistance)
        impedance.compute(0, 1)
        rin = impedance.input(0.5, sec=h.soma)  # input resistance at soma(0.5)
        print("when p={}, Rm={}, [Cl]i={}, and Vm={}".format(scale, rin, cli, vm))
        run_name = "{}_[{},{:.2f},{:.2f},{:.2f}]".format(filename, scale, rin, cli, vm)
        save_run(run_name)
Ejemplo n.º 4
0
def pyrun(file_name,
          synapse_type=1,
          synapse_numbers=(100, 100),
          syn_input=None,
          diam=None,
          pa_kcc2=None,
          location='axon',
          trials=1,
          save=True,
          tstop=TSTOP,
          **kwargs):
    """
    Run a NEURON simulation for a neuron specified in ``file_name`` with input specified by other parameters provided.

    :param file_name: Neuron definition (excluding '.hoc').
    :type file_name: str
    :param synapse_type: Type of synapses to use (0 for frequency-based 'f-in', 1 for persistent conductance, 'gclamp').
    :type synapse_type: int
    :param synapse_numbers: Number of (E, I) on the neuron.
    :type synapse_numbers: (int, int)
    :param syn_input: Mapping of excitatory/inhibitory type to input strength, ``{'ex': E, 'in': I}``.
        Defaults to ``{'in': 5, 'ex': 5}``.
    :type syn_input: dict[str, int]
    :param diam: Re-specify diam for specific regions of the neuron. Valid: 'ldend', 'bdend', 'soma', 'axon'.
    :type diam: dict[str: float]
    :param pa_kcc2: Strength of KCC2 (set in every section as ``Pa_KCC2 = <pa_kcc2>e-5``).
    :type pa_kcc2: float
    :param location: Location at which voltage is recorded (and firing rate is recorded). It is also made the
        currently accessed section.
    :type location: str
    :param trials: Number of repeated simulations to run.
    :type trials: int
    :param save: Whether to load/save the results from/to file.
    :type save: bool
    :param tstop: Length of simulation (ms).
    :type tstop: float
    :param kwargs: Other keywords are ignored.

    :return: Pair of DataFrame with results and name of save file (even if not saved, the name is generated).
        The DataFrame is indexed by time, with one ``(trial, 'v')``, ``(trial, 'cli')`` and ``(trial, 'spikes')``
        column per trial.
    :rtype: (pd.DataFrame, str)
    """
    if syn_input is None:
        syn_input = {'in': 5, 'ex': 5}
    save_name = "{}_{}_{}_{}".format(file_name, synapse_type, synapse_numbers,
                                     syn_input)
    save_name += "_{}".format(diam) if diam is not None else ''
    save_name += "_{}".format(pa_kcc2) if pa_kcc2 is not None else ''
    save_name += "_{}_{}".format(location, trials)
    save_name = save_name.replace(" ", "").replace("'", "").replace(
        ":", "=").replace("{", "(").replace("}", ")")
    logger.info(save_name)
    if save:
        loaded = load_from_file(save_name)
        if loaded is not None:
            return loaded, save_name

    load_file(file_name)
    load_file(file_name)
    if diam is not None:
        for seg in diam.keys():
            nrn_seg = get_compartment(seg)
            nrn_seg.diam = diam[seg]
    if pa_kcc2 is not None:
        hoc_cmd = "forall {" + """
            Pa_KCC2 = {pa_kcc2}e-5""".format(pa_kcc2=pa_kcc2) +\
                  " }"
        h(hoc_cmd)
    compartment = get_compartment(location)
    # Choose where [Cl-]i is recorded based on where inhibition is placed.
    if 'distal' in file_name:
        cli_rec_loc = get_compartment('ldend')
    elif 'proximal' in file_name:
        # NOTE(review): this condition previously tested 'proximal' twice; the second alternative may have been
        # meant to match another name — confirm.
        cli_rec_loc = get_compartment('bdend')
    else:
        cli_rec_loc = get_compartment('soma')

    logger.info("recording cli from {}".format(cli_rec_loc.hname()))

    h("access {}".format(location))
    h.changeSynapseType(synapse_type)
    h.newSynapses(synapse_numbers[0], synapse_numbers[1])
    h.inPy(0)
    h.ex(0)
    vm_init, cli = get_base_vm_cli(file_name, compartment=compartment)
    h.v_init = vm_init
    h_str = """
        forall{""" + """
            cli = {0}
            cli0_cl_ion = {0}
            """.format(cli)
    # Use `in` rather than `find(...) > 0`: find returns 0 for a match at the start of the name,
    # which the old check treated as "not found".
    if "KCC2" in file_name:
        h_str += """cli0_KCC2 = {0}
        """.format(cli)
    h_str += "}"
    h(h_str)

    h.tstop = tstop

    if synapse_type == 0:
        # Hz, duration (s), start (ms), noise, weight/channels
        h.inPy(syn_input['in'], h.tstop / 1000, 0, 1, 1)
        h.ex(syn_input['ex'], h.tstop / 1000, 0, 1, 1)
    else:
        h.inPy(syn_input['in'])
        h.ex(syn_input['ex'])

    # create recording vector objects
    t_rec = h.Vector()
    v_rec = h.Vector()
    cli_rec = h.Vector()
    spike_rec = h.Vector()

    t_rec.record(h._ref_t)
    v_rec.record(compartment(0.5)._ref_v)
    cli_rec.record(cli_rec_loc(0.5)._ref_cli)
    spike_rec.record(h.apc._ref_n)

    time_past = current_time('ms')
    logger.info("using {}...".format(file_name))
    assert trials > 0
    logger.info(save_name)
    logger.info("trial # | # spikes")
    df = pd.DataFrame()
    for i in range(trials):
        trial_num = i + 1
        h.run()
        logger.info("{:7} | {:8}".format(trial_num, h.apc.n))

        temp_dict = {
            (trial_num, 'v'): v_rec.as_numpy(),
            (trial_num, 'cli'): cli_rec.as_numpy(),
            (trial_num, 'spikes'): spike_rec.as_numpy()
        }
        recording = pd.DataFrame(temp_dict, index=t_rec.to_python())

        df = pd.concat([df, recording], axis=1)

    logger.info("time taken: {}ms".format(current_time('ms') - time_past))

    if save:
        save_to_file(save_name, df)
    return df, save_name
def figure_v_traces(inh_region="distal",
                    KCC2=True,
                    show_cli=False,
                    base=True,
                    synapse_type=0,
                    synapse_numbers=(100, 100),
                    hz=None,
                    mean=False,
                    savefig=False,
                    **kwargs):
    """
    Display voltage traces for an inhibitory input region (e.g. "distal") and (E, I) synapse pairs.

    Optionally include an axis for [Cl-]i.

    :param inh_region: Region for inhibitory synapses. Only those in hoc_files/cells/ are supported.
    :type inh_region: str
    :param KCC2: Open the '_KCC2' version of the hoc file instead if True.
    :type KCC2: bool
    :param show_cli: Include a plot for [Cl-]i. Only applicable when KCC2 is True.
    :type show_cli: bool
    :param base: Find steady-state values for vm and cli (`True`),
        provide the steady-state values (`(<vm value>, <cli value>)`),
        use defaults of -71 mV and 4.25 mM for vm and cli, respectively(`False`).
    :type base: bool or tuple of float
    :param synapse_type: Use frequency-based synapses (`0`) or persistent synapses (`1`).
    :type synapse_type: int
    :param synapse_numbers: Synapses numbers for the neuron in (E, I) format. A list of (E,I) pairs can be provided
        for multiple traces.
    :type synapse_numbers: list of (tuple of (int)) or tuple of int
    :param hz: Frequency of synapses (if `synapse_type` is `0`) or the relative conductance (if `synapse_type` is `1`).
        Defaults to ``{'in': 5, 'ex': 5}``.
    :type hz: dict
    :param mean: Additionally plot the mean vm and cli values using `plot_v_trace`
    :type mean: bool
    :param savefig: Save the figure to results_plots
    :type savefig: bool
    :param kwargs: Keyword arguments to pass to `get_trace`
    :type kwargs: dict
    :return: Steady-state voltage and [Cl-]i used for this simulation
    :rtype: tuple of (float, float)
    """
    if hz is None:
        hz = {'in': 5, 'ex': 5}
    logger.info(
        f"figure_v_traces(inh_region={inh_region} KCC2={KCC2}, synapse_numbers={synapse_numbers}, hz={hz})"
    )
    # Wrap a single (E, I) pair in a list *before* sizing the colour palette; otherwise len((E, I)) == 2
    # would request two colours for one trace.
    if isinstance(synapse_numbers, tuple):
        synapse_numbers = [synapse_numbers]
    cmap_name = "Blues" if inh_region == "distal" else "Greens"
    cmap = sns.color_palette(cmap_name, n_colors=len(synapse_numbers))
    if mean:
        cmap = None
    filename = inh_region
    if KCC2:
        filename += "_KCC2"
    else:
        show_cli = False
    logger.info("getting base vm cli")
    if base:
        if isinstance(base, tuple):
            vm, cli = base
        else:
            vm, cli = get_base_vm_cli(f"{inh_region}_KCC2", load=True)
    else:
        vm, cli = -71., 4.25

    dynamic_data = []
    spikes = []
    for syn_num in synapse_numbers:
        trace_kwargs = dict(synapse_type=synapse_type,
                            synapse_numbers=syn_num,
                            hz=hz,
                            space_plot=False,
                            **kwargs)
        dynamic_data.append(get_trace(filename, vm=vm, cli=cli,
                                      **trace_kwargs))
        # spike count of the run just performed
        spikes.append(h.apc.n)
    logger.info("plotting voltage")
    fig, ax = plot_v_trace(*dynamic_data,
                           show_cli=show_cli,
                           cmap=cmap,
                           mean=mean)
    if not mean:
        for i, (spike_num, syn_num) in enumerate(zip(spikes, synapse_numbers)):
            # annotate each trace with its spike count at the right edge, and title it "E, I"
            ax[i, 0].text(ax[i, 0].get_xlim()[1],
                          0,
                          f"{spike_num:.0f}",
                          ha="left")
            ax[i, 0].set_title(str(syn_num).replace("(", "").replace(")", ""),
                               fontsize='small',
                               va='top')
    title = filename + str(hz)
    title = title.replace("{",
                          "\n(").replace("}",
                                         ")").replace("'",
                                                      "").replace(": ", "=")
    if synapse_type == 0:
        title = title.replace(",", " Hz,").replace(")", " Hz)")
    fig.suptitle(title, fontsize='medium')
    fig.align_ylabels(ax[:, 0])
    sns.despine(fig, bottom=True, left=True)
    for _ax in ax[:-1, 0]:
        _ax.xaxis.set_visible(False)
        _ax.yaxis.set_visible(False)
    sns.despine(ax=ax[-1, 0])
    fig.subplots_adjust(left=0.15, bottom=0.15)
    if savefig:
        create_dir("results_plots")
        # NOTE(review): `title` contains a newline (from the "{" -> "\n(" replacement), so the saved file names
        # do too; this is not portable (e.g. Windows) — consider sanitising.
        fig.savefig(f"results_plots/figure_trace_{title}.png",
                    bbox_inches='tight',
                    facecolor='None')
        fig.savefig(f"results_plots/figure_trace_{title}.pdf",
                    bbox_inches='tight',
                    facecolor='None')
    return vm, cli
def figure_cli_heatmaps(distal,
                        proximal,
                        freqs=(5, 10, 25, 50),
                        n_trials=1,
                        show_fr=False,
                        vmin=0,
                        vmax=None,
                        savefig=False):
    """
    Heatmaps of [Cl-]i for different frequencies and (E,I) synapse pairs and for both proximal and distal.

    Row 0 simulates 'distal_KCC2' with the ``distal`` pairs; row 1 simulates 'proximal_KCC2' with the ``proximal``
    pairs. Each heatmap has one row per frequency and one column per region (distal dendrite, proximal dendrite,
    soma, axon). Each cell holds the mean [Cl-]i of that region, averaged over ``n_trials`` runs.

    :param distal: List (E,I) synapse number pairs for distal input.
    :type distal: list of (tuple of (float))
    :param proximal: List (E,I) synapse number pairs for proximal input.
    :type proximal: list of (tuple of (float))
    :param freqs: Input frequencies (Hz), applied equally to excitation and inhibition (``{'in': hz, 'ex': hz}``).
    :type freqs: list of Number or tuple of Number
    :param n_trials: Number of repeated runs.
    :type n_trials: int
    :param show_fr: Show the firing rate to the right of the heatmaps when plotting.
        True shows the firing rate as an arrow and text.
        Any value greater than 1 shows the firing rate as a heatmap cell (different colormap)
    :type show_fr: bool or int
    :param vmin: Minimum [Cl-]i for color range. NOTE(review): any falsy value, including the default 0,
        is replaced by the data minimum (see the `vmin = vmin or ...` line).
    :type vmin: float
    :param vmax: Maximum [Cl-]i for color range (data maximum when None/falsy).
    :type vmax: float
    :param savefig: Whether to save the figure to results_plots
    :type savefig: bool
    """
    logger.info(f"figure_cli_heatmaps(distal={distal}, proximal={proximal}, "
                f"n_trials={n_trials}, show_fr={show_fr})")

    # One heatmap column per distal pair, plus a narrow colorbar column ([Cl-]i) and, when show_fr > 1, a second
    # narrow colorbar column for the firing rate.
    # NOTE(review): the column count is based on `distal` only; a longer `proximal` list would index past the
    # available axes — confirm callers pass equal-length lists.
    fig, ax2d = plt.subplots(nrows=2,
                             ncols=len(distal) + 1 + int(show_fr > 1),
                             figsize=(8, 5),
                             gridspec_kw={
                                 'width_ratios': [15] * len(distal) + [2] *
                                 (1 + int(show_fr > 1))
                             })
    cmap = sns.cubehelix_palette(16,
                                 start=2.7,
                                 rot=-.2,
                                 light=0.98,
                                 dark=0.40,
                                 as_cmap=True)
    cmap_fr = sns.color_palette("Reds" if show_fr > 1 else "magma",
                                n_colors=200,
                                desat=1)
    vmax_fr = 5
    # h.hoc_stdout("hoc_output_traces.txt")
    for fdx, (filename, synapse_numbers) in enumerate(
            zip(["distal_KCC2", "proximal_KCC2"], [distal, proximal])):
        ax = ax2d[fdx, :]
        # [Cl-]i colorbar lives in the last column, or the second-to-last when a firing-rate colorbar is shown
        cbar_ax = ax[-1 - int(show_fr > 1)]
        vm, cli = get_base_vm_cli(filename, load=True)

        # all three dicts are keyed by ("E:I" label, hz)
        d_cli = {}
        fr = {}
        stddev = {}  # NOTE(review): computed but not used below
        for i, syn_n in enumerate(synapse_numbers):
            kwargs = dict(synapse_type=0,
                          synapse_numbers=syn_n,
                          space_plot=True)
            for hz in freqs:
                n_spikes = []
                sec_means = [0, 0, 0,
                             0]  # 4 sections [ldend, bdend, soma, axon]
                for t in range(n_trials):
                    logger.info(
                        f"filename={filename} syn_n={syn_n} hz={hz} t={t}")
                    _, _, _, data = get_trace(filename,
                                              vm=vm,
                                              cli=cli,
                                              hz={
                                                  'in': hz,
                                                  'ex': hz
                                              },
                                              **kwargs)
                    n_spikes.append(h.apc.n)
                    x, y = space_from_dict(data)
                    # separate x into regions based on known lengths
                    # (bounds are inclusive on both sides, so points at exactly -50, 0 and 15 fall in two regions)
                    ldend = x <= -50
                    bdend = np.logical_and(-50 <= x, x <= 0)
                    soma = np.logical_and(0 <= x, x <= 15)
                    axon = x >= 15
                    for s, sec in enumerate([ldend, bdend, soma, axon]):
                        sec_means[s] += y[sec].mean()
                sec_means = [m / n_trials for m in sec_means]
                d_cli[(f"{syn_n[0]:>3.0f}:{syn_n[1]:>3.0f}", hz)] = sec_means
                fr[(f"{syn_n[0]:>3.0f}:{syn_n[1]:>3.0f}",
                    hz)] = sum(n_spikes) / n_trials
                stddev[(f"{syn_n[0]:>3.0f}:{syn_n[1]:>3.0f}",
                        hz)] = np.std(n_spikes)

        df = pd.DataFrame.from_dict(
            d_cli,
            orient='index',
            columns=["Distal\nDendrite", "Proximal\nDendrite", "Soma", "Axon"])
        df_fr = pd.DataFrame.from_dict(fr, orient='index', columns=["Output"])
        df = df.reindex(pd.MultiIndex.from_tuples(df.index))
        df_fr = df_fr.reindex(pd.MultiIndex.from_tuples(df_fr.index))
        # NOTE(review): `or` treats 0 as "unset", so the default vmin=0 is replaced by the data minimum. Because
        # vmin/vmax are reassigned here, the distal row's limits also carry over to the proximal row.
        vmin = vmin or df.values.min()
        vmax = vmax or df.values.max()
        vmin_fr = 0
        # NOTE(review): vmax_fr carries over between rows, so on the second row the 10% buffer is applied on top
        # of the first row's already-buffered value.
        vmax_fr = max(
            vmax_fr,
            df_fr.values.max()) * 1.1  # give a 10% buffer for the cmap

        logger.info("plotting cli_heatmaps")
        # presumably in sorted label order (pandas MultiIndex levels), not input order — confirm
        for i, syn_n in enumerate(df.index.levels[0]):
            df_syn = df.loc[syn_n]
            df_fr_syn = df_fr.loc[syn_n]
            sns.heatmap(
                df_syn,
                ax=ax[i],
                annot=False,
                fmt=".1f",
                square=True,
                annot_kws=dict(fontsize='xx-small'),
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                cbar=(i == 0),
                cbar_ax=None if i > 0 else cbar_ax,
            )

            if show_fr > 1:
                # Overlay a second heatmap on the same axis: the [Cl-]i columns are masked out (mask=1),
                # leaving only the appended "Output" (firing-rate) column visible with its own colormap.
                new_df = pd.concat([df_syn, df_fr_syn], axis='columns')
                mask = np.ones(df_syn.shape)
                mask = np.append(mask, np.zeros(df_fr_syn.shape), axis=1)
                cbar_fr_ax = ax[-1]
                sns.heatmap(
                    new_df,
                    ax=ax[i],
                    annot=True,
                    fmt=".1f",
                    square=False,
                    annot_kws=dict(fontsize='xx-small'),
                    vmin=vmin_fr,
                    vmax=vmax_fr,
                    cmap=cmap_fr,
                    mask=mask,
                    cbar=(i == 0),
                    cbar_ax=cbar_fr_ax,
                    cbar_kws={'label': "Firing rate (Hz)"},
                )
            elif show_fr:
                for j, _fr in enumerate(df_fr_syn['Output']):
                    # map the firing rate linearly onto an index into the 200-colour palette
                    idx = (len(cmap_fr) - 1) * (_fr - vmin_fr) // (vmax_fr -
                                                                   vmin_fr)
                    c = cmap_fr[int(idx)]
                    # arrow from just right of the heatmap (x=4, after the 4 region columns) to the text
                    text = ax[i].annotate(f"{_fr:>2.1f}",
                                          xy=(4, j + 0.5),
                                          xytext=(4.6, j + 0.5),
                                          color=c,
                                          alpha=1,
                                          arrowprops={'arrowstyle': '<-'},
                                          fontsize='x-small',
                                          va='center')
                    # path effects can sometimes make text clearer, but it can also make things worse...
                    # text.set_path_effects([path_effects.Stroke(linewidth=0.5, foreground='black', alpha=0.5),
                    #                        path_effects.Normal()])
            ax[i].set_xticklabels(ax[i].get_xticklabels(),
                                  fontsize="small",
                                  rotation=45)
            ax[i].set_title(syn_n, fontsize='small')
            if i == 0:
                ax[i].set(ylabel='Balanced Input (Hz)')
                ax[i].set_yticklabels(df_syn.index, rotation=0)
            else:
                ax[i].set_yticklabels([])
            if i == 0:
                cbar_ax.set_xlabel(f'{settings.CLI} (mM)', ha='center')
                cbar_ax.xaxis.set_label_position('top')
    fig.tight_layout()
    if savefig:
        create_dir("results_plots", timestamp=False)
        fig.savefig(f"results_plots/figure_cli_heatmaps_{int(show_fr)}.png",
                    bbox_inches='tight',
                    facecolor='None')
        fig.savefig(f"results_plots/figure_cli_heatmaps_{int(show_fr)}.pdf",
                    bbox_inches='tight',
                    facecolor='None')
def figure_cli_distribution(
        syn_n_dist=(250, 300), syn_n_prox=(330, 90), savefig=False):
    """
    Plot the distribution of [Cl-]i along a neuron for distal and proximal inhibitory input.

    For each of 'distal_KCC2' (top) and 'proximal_KCC2' (bottom), three runs are plotted:
    inhibition only (0 : 5 Hz), excitation only (5 : 0 Hz) and balanced input (5 : 5 Hz).

    :param syn_n_dist: Number of (E, I) synapses for the distal input simulation.
    :type syn_n_dist: tuple
    :param syn_n_prox: Number of (E, I) synapses for the proximal input simulation.
    :type syn_n_prox: tuple
    :param savefig: Whether to save the figure to results_plots.
    :type savefig: bool
    """
    txt = f"figure_cli_distribution(syn_n_dist={syn_n_dist}, syn_n_prox={syn_n_prox})"
    logger.info(txt)
    fig, ax = plt.subplots(nrows=2,
                           ncols=1,
                           figsize=(8, 3.5),
                           sharey=True,
                           sharex=True)
    fig.subplots_adjust(hspace=0.5)

    for filename, syn_n, _ax in zip(['distal_KCC2', 'proximal_KCC2'],
                                    [syn_n_dist, syn_n_prox], ax):
        # steady state vm and [Cl-]i
        vm, cli = get_base_vm_cli(filename, load=True)
        trace_kwargs = dict(
            synapse_type=0,  # frequency-based
            synapse_numbers=syn_n,  # (E,I)
            space_plot=True)  # compute spatial component for cli

        # run simulations for inhibition-only, excitation-only and balanced input
        _, _, _, inh = get_trace(filename,
                                 vm=vm,
                                 cli=cli,
                                 hz={
                                     'in': 5,
                                     'ex': 0
                                 },
                                 **trace_kwargs)
        _, _, _, exc = get_trace(filename,
                                 vm=vm,
                                 cli=cli,
                                 hz={
                                     'in': 0,
                                     'ex': 5
                                 },
                                 **trace_kwargs)
        _, _, _, equal = get_trace(filename,
                                   vm=vm,
                                   cli=cli,
                                   hz={
                                       'in': 5,
                                       'ex': 5
                                   },
                                   **trace_kwargs)

        logger.info(f"space plot {filename}")
        space_plot(exc, ax=_ax, color=COLOR.E, label='5 : 0')
        space_plot(inh, ax=_ax, color=COLOR.I, label='0 : 5')
        space_plot(equal, ax=_ax, color=COLOR.E_I, label='5 : 5')
        # display some info on figure
        name = filename.split("_")[0].capitalize()
        _ax.set_title(
            f"{name} input (E: I) (# of synapses)\n"
            f"{syn_n[0]}:{syn_n[1]:>3.0f}",
            va='top',
            fontsize='medium')
    # light adjustments
    ax[0].set_xlabel("")
    ax[0].set_xlim(-550, 500)  # 0 is soma(0)
    ax[-1].legend(title='Input (E : I) (Hz)', loc='upper right', frameon=False)
    if savefig:
        create_dir("results_plots", timestamp=False)
        # `frameon` is no longer accepted by Figure.savefig; a transparent facecolor gives the same
        # frameless background and matches the other figures in this module.
        fig.savefig(f"results_plots/{txt}.png",
                    bbox_inches='tight',
                    facecolor='None')
        fig.savefig(f"results_plots/{txt}.pdf",
                    bbox_inches='tight',
                    facecolor='None')