コード例 #1
0
def make_param_dict(param_dir):
    """Merge every parameter file in a directory into one parameter dict.

    Files are read with ``paramrw.read`` in sorted filename order. The
    parameter dict of the first file is the base. Every key whose value
    differs in at least one later file is replaced by a list of that key's
    value from every file, in file order. A file lacking such a key
    contributes None.

    Args:
        param_dir: directory that contains only parameter files.

    Returns:
        (dict_array, array_keys): the merged parameter dict and the sorted
        list of keys whose values vary across files.

    Raises:
        IndexError: if ``param_dir`` contains no files.
    """
    # os.listdir() order is arbitrary; sort it so the base file and the
    # order of values inside each list are reproducible.
    file_names = sorted(os.listdir(param_dir))
    if not file_names:
        raise IndexError("no parameter files found in %s" % param_dir)

    # read every file exactly once (paramrw.read returns (gid_dict, p))
    param_dicts = [paramrw.read(os.path.join(param_dir, name))[1]
                   for name in file_names]
    dict_array = param_dicts[0]

    # Keys that change across files, compared against the first file.
    # A key absent from the first file counts as changed.
    changed = set()
    for p in param_dicts[1:]:
        for key in p:
            if key not in dict_array or dict_array[key] != p[key]:
                changed.add(key)
    array_keys = sorted(changed)

    # Build each varying key's list explicitly. Appending into the existing
    # value would mutate parameters that are already lists instead of
    # wrapping them.
    for key in array_keys:
        dict_array[key] = [d.get(key) for d in param_dicts]

    return dict_array, array_keys
コード例 #2
0
def example_analysis_for_simulation():
    """Example analysis: plot the dipoles of every trial of one simulation.

    Reads the simulation ``2016-02-03/beta-sweep-000`` inside the directory
    returned by ``fio.return_data_dir()``. For each experimental group,
    every param file is paired with its raw dipole file, and a PNG showing
    the 'agg', 'L2' and 'L5' dipoles is saved.
    """
    # root data directory and the simulation directory inside it
    droot = fio.return_data_dir()
    dsim = os.path.join(droot, '2016-02-03/beta-sweep-000')

    # create the SimulationPaths() object ddata and read the simulation
    ddata = fio.SimulationPaths()
    ddata.read_sim(droot, dsim)

    # iterate through experimental groups and analyze the individual files
    for expmt_group in ddata.expmt_groups:
        # a single-argument print() prints the same on Python 2 and 3
        print("experiment group is: %s" % (expmt_group))
        flist_param = ddata.file_match(expmt_group, 'param')
        flist_dpl = ddata.file_match(expmt_group, 'rawdpl')

        # zip() pairs files positionally; presumably file_match returns both
        # lists in the same order -- confirm.
        for fparam, fdpl in zip(flist_param, flist_dpl):
            _, p_tr = paramrw.read(fparam)

            # create and load dipole data structure
            d = dipolefn.Dipole(fdpl)

            # more or less analysis goes here.
            # generate a filename for a dipole plot
            fname_png = ddata.return_filename_example('figdpl',
                                                      expmt_group,
                                                      p_tr['Sim_No'],
                                                      tr=p_tr['Trial'])

            # example figure for this pair of files
            fig = PT_example.FigExample()
            for layer in ('agg', 'L2', 'L5'):
                fig.ax['dipole'].plot(d.t, d.dpl[layer])
            fig.savepng(fname_png)
            fig.close()
コード例 #3
0
    def __init__(self,
                 tvec,
                 tsvec,
                 fparam,
                 f_max=None,
                 p_dict=None,
                 tmin=50.0,
                 f_min=1.):
        """Set up and run a Morlet wavelet time-frequency analysis.

        Args:
            tvec: time vector (ms) of the signal.
            tsvec: signal values at ``tvec``.
            fparam: param file. It is read with ``paramrw.read`` only when
                ``p_dict`` is None.
            f_max: highest analysis frequency. Falls back to
                ``p_dict['f_max_spec']`` when falsy. One is added so the
                frequency range includes it.
            p_dict: an already-loaded parameter dict (optional).
            tmin: samples before this time (ms) are dropped.
            f_min: lowest analysis frequency.

        If ``p_dict['tstop']`` is not greater than ``tmin``, a message is
        printed and no analysis is run (``f``, ``width``, ``fs`` and ``TFR``
        are then left unset).
        """
        self.tvec = tvec
        self.tsvec = tsvec

        # paramrw.read() returns (gid_dict, p_dict); only p_dict is needed
        self.p_dict = paramrw.read(fparam)[1] if p_dict is None else p_dict

        self.f_min = f_min

        # +1 so the analysis range is inclusive of the maximum frequency
        self.f_max = (f_max if f_max else self.p_dict['f_max_spec']) + 1

        # cutoff time in ms
        self.tmin = tmin

        if not (self.p_dict['tstop'] > self.tmin):
            print(
                "tstop not greater than %4.2f ms. Skipping wavelet analysis." %
                self.tmin)
            return

        # Keep only t >= tmin. The mask comes from the untruncated tvec, so
        # it is built once and applied to the time series and tvec alike.
        keep = self.tvec >= self.tmin
        self.tsvec = self.tsvec[keep]
        self.tvec = self.tvec[keep]

        # frequencies to analyze and wavelet width in cycles (>5 advisable)
        self.f = np.arange(self.f_min, self.f_max)
        self.width = 7.

        # sampling frequency: 1000 / dt (Hz when dt is in ms)
        self.fs = 1000. / self.p_dict['dt']

        self.TFR = self.__traces2TFR()
コード例 #4
0
def freqpwr_with_hist(ddata, dsim):
    """Run frequency-power analysis on every saved spec file and plot it.

    Args:
        ddata: simulation paths object; ``ddata.dsim`` and ``ddata.fparam``
            are used.
        dsim: unused (``ddata.dsim`` is used instead); kept for backward
            compatibility.

    If no ``-spec.npz`` files exist yet, ``exec_spec_regenerate`` is run
    first and the spec files are searched for again.
    """
    fspec_list = fio.file_match(ddata.dsim, '-spec.npz')
    spk_list = fio.file_match(ddata.dsim, '-spk.txt')
    fparam_list = fio.file_match(ddata.dsim, '-param.txt')

    p_exp = paramrw.ExpParams(ddata.fparam)
    key_types = p_exp.get_key_types()

    # If no saved spec results exist, redo the spec analysis
    if not fspec_list:
        # print() calls replace the Python 2 print statements (a SyntaxError
        # on Python 3)
        print("No saved spec data found. Performing spec analysis...")
        exec_spec_regenerate(ddata)
        fspec_list = fio.file_match(ddata.dsim, '-spec.npz')

        print("now doing spec freq-pwr analysis")

    # perform freqpwr analysis
    freqpwr_results_list = [specfn.freqpwr_analysis(fspec) for fspec in fspec_list]

    # zip() pairs files positionally; presumably fio.file_match returns all
    # three lists in the same order -- confirm.
    for freqpwr_result, f_spk, fparam in zip(freqpwr_results_list, spk_list, fparam_list):
        gid_dict, p_dict = paramrw.read(fparam)
        # NOTE(review): the same name is used on every iteration, so each
        # plot overwrites the previous one unless pfreqpwr_with_hist()
        # makes it unique -- confirm.
        file_name = 'freqpwr.png'

        specfn.pfreqpwr_with_hist(file_name, freqpwr_result, f_spk, gid_dict, p_dict, key_types)
コード例 #5
0
    def add_row_labels(self, param_list, key):
        """Write a vertical text label to the left of each row of axes.

        Args:
            param_list: param files in row-major order (``N_cols`` per row).
                The first file of each row is read.
            key: either a parameter name, in which case each label is
                ``"<key>: <value>"`` from that row's param file, or a
                sequence of exactly ``N_rows`` labels used verbatim.

        Returns:
            0 if ``key`` is not usable as a parameter name and its length
            differs from ``N_rows`` (the remaining rows stay unlabeled).
            Otherwise None.
        """
        gap = (0.9 - self.top_margin) / self.N_rows + self.gap_height
        text_kwargs = dict(fontsize=36, rotation='vertical',
                           horizontalalignment='left',
                           verticalalignment='center')

        for i in range(self.N_rows):
            p_dict = paramrw.read(param_list[self.N_cols * i])[1]

            # place text in middle of each row of axes
            x = self.left_margin / 2.
            y = 1. - self.top_margin - self.gap_height - gap / 2. - gap * i

            # Try key as a key in the param dict. Only a missing key
            # (KeyError) or a non-string / unhashable key (TypeError) falls
            # back; the bare except used before also hid plotting errors.
            # (value,) keeps tuple-valued parameters from breaking '%'.
            try:
                label = key + ': %s' % (p_dict[key],)
            except (KeyError, TypeError):
                # use individual parts of key as labels, if there are enough
                if len(key) == self.N_rows:
                    label = key[i]
                else:
                    print("Dude, the number of labels don't match the number of rows. I can't do nothing now.")
                    return 0

            self.f.text(x, y, label, **text_kwargs)
コード例 #6
0
ファイル: plotfn.py プロジェクト: SMPugliese/hnn_calcium
def pkernel(dfig,
            f_param,
            f_spk,
            f_dpl,
            f_spec,
            key_types,
            xlim=None,
            ylim=None):
    """Produce the raster, dipole and spectrogram figures of one trial.

    Args:
        dfig: dict of output directories; the 'figdpl', 'figspec' and
            'figspk' entries are used.
        f_param, f_spk, f_dpl, f_spec: param, spike, dipole and spec files.
        key_types: forwarded to the dipole and spec plotting functions.
        xlim, ylim: optional axis limits for the dipole and spec plots.

    Returns:
        0
    """
    _, p_dict = paramrw.read(f_param)
    tstop = p_dict['tstop']

    # output directory of each kind of figure
    dfig_dpl, dfig_spec, dfig_spk = (
        dfig['figdpl'], dfig['figspec'], dfig['figspk'])

    # spike raster
    praster(f_param, tstop, f_spk, dfig_spk)

    # dipole, with the optional axis limits
    dipolefn.pdipole(f_dpl, dfig_dpl, dict(xlim=xlim, ylim=ylim), f_param,
                     key_types)

    # spectrogram of the dipole
    pspec.pspec_dpl(f_spec, f_dpl, dfig_spec, p_dict, key_types, xlim, ylim,
                    f_param)
    return 0
コード例 #7
0
    def set_title(self, fparam, key_types):
        """Set the figure suptitle from the parameters stored in ``fparam``.

        The text is built by ``create_title(p_dict, key_types)``.
        """
        _, p_dict = paramrw.read(fparam)
        self.f.suptitle(create_title(p_dict, key_types))
コード例 #8
0
def pmaxpwr(file_name, results_list, fparam_list):
    """Scatter the frequency of max power against the proximal input freq.

    Args:
        file_name: output path passed to the figure's ``save``.
        results_list: analysis results. Each ``result['freq_at_max']`` is a
            sequence of frequencies.
        fparam_list: param files matching ``results_list`` positionally.
            ``'f_input_prox'`` from each is the x value.

    A least-squares trendline is added when at least two distinct x values
    are present.
    """
    f = ac.FigStd()
    # Axes.hold() was removed in Matplotlib 3.0; overplotting is the default.

    # lists for storing x and y data (kept the same length)
    x_data = []
    y_data = []

    # plot points
    for result, fparam in zip(results_list, fparam_list):
        p = paramrw.read(fparam)[1]
        freqs = list(result['freq_at_max'])

        # One x per y, so that x_data and y_data stay aligned even when a
        # result holds zero or several frequencies.
        x_vals = [p['f_input_prox']] * len(freqs)
        x_data.extend(x_vals)
        y_data.extend(freqs)

        f.ax0.plot(x_vals, freqs, 'kx')

    # add trendline (a degree-1 fit needs two distinct x values)
    if len(set(x_data)) >= 2:
        fit_fn = np.poly1d(np.polyfit(x_data, y_data, 1))
        f.ax0.plot(x_data, fit_fn(x_data), 'k-')

    # Axis stuff
    f.ax0.set_xlabel('Proximal/Distal Input Freq (Hz)')
    f.ax0.set_ylabel('Freq at which max avg power occurs (Hz)')

    f.save(file_name)
コード例 #9
0
def spec_dpl_kernel(fparam, fts, fspec, f_max):
    """Compute Morlet spectrograms and a Welch periodogram for one dipole.

    Loads the dipole file ``fts`` and builds a MorletSpec of the 'agg', 'L2'
    and 'L5' traces. It also computes the aggregate's max spectral power and
    a Welch periodogram of the aggregate, then writes everything to
    ``fspec`` with ``np.savez_compressed``.
    """
    dpl = dipolefn.Dipole(fts)
    # the units label is set without calling convert_fAm_to_nAm()
    dpl.units = 'nAm'

    specs = {}
    for layer in ('agg', 'L2', 'L5'):
        specs[layer] = MorletSpec(dpl.t, dpl.dpl[layer], fparam, f_max)
    agg, l2, l5 = specs['agg'], specs['L2'], specs['L5']

    # max spectral power, for the aggregate only
    max_agg = agg.max()

    # periodogram of the aggregate, using dt from the param file
    dt = paramrw.read(fparam)[1]['dt']
    pgram = Welch(dpl.t, dpl.dpl['agg'], dt)

    np.savez_compressed(fspec,
                        time=agg.t, freq=agg.f, TFR=agg.TFR,
                        max_agg=max_agg,
                        t_L2=l2.t, f_L2=l2.f, TFR_L2=l2.TFR,
                        t_L5=l5.t, f_L5=l5.f, TFR_L5=l5.TFR,
                        pgram_p=pgram.P, pgram_f=pgram.f)
コード例 #10
0
def pdipole_with_hist(f_dpl, f_spk, dfig, f_param, key_types, plot_dict):
    """Plot the aggregate dipole above histograms of the external inputs.

    Args:
        f_dpl: dipole file. It is baseline-renormalized with ``f_param`` and
            converted with ``convert_fAm_to_nAm`` before plotting.
        f_spk: spike file used for the proximal/distal input histograms.
        dfig: output directory. The figure is saved as
            ``<dfig>/<prefix of f_dpl>.png``.
        f_param: param file.
        key_types: forwarded to ``ac.create_title``.
        plot_dict: must contain 'xmin' and 'xmax'. None means 0 and
            ``p_dict['tstop']`` respectively.
    """
    # dpl is an obj of Dipole() class
    dpl = Dipole(f_dpl)
    dpl.baseline_renormalize(f_param)
    dpl.convert_fAm_to_nAm()

    # file prefix: basename up to the first '.'
    file_prefix = f_dpl.split('/')[-1].split('.')[0]

    _, p_dict = paramrw.read(f_param)

    xmin = 0. if plot_dict['xmin'] is None else plot_dict['xmin']
    xmax = p_dict['tstop'] if plot_dict['xmax'] is None else plot_dict['xmax']

    # truncate tvec and dpl data with a single logical mask
    in_range = (dpl.t >= xmin) & (dpl.t <= xmax)
    t_range = dpl.t[in_range]
    dpl_range = dpl.dpl['agg'][in_range]

    f = ac.FigDplWithHist()
    f.ax['dipole'].plot(t_range, dpl_range)
    # new xlim based on the dipole plot
    xlim_new = f.ax['dipole'].get_xlim()

    # Get extinput data and account for delays
    extinputs = spikefn.ExtInputs(f_spk, f_param)
    extinputs.add_delay_times()

    # 150 bins per 1000 ms. Histogram bin counts must be integers, and
    # ceil() may return a float (Python 2 math.ceil, numpy.ceil).
    bins = int(ceil(150. * (xlim_new[1] - xlim_new[0]) / 1000.))

    hist = {}
    hist['feed_prox'] = extinputs.plot_hist(f.ax['feed_prox'], 'prox', dpl.t, bins, xlim_new, color='red')
    hist['feed_dist'] = extinputs.plot_hist(f.ax['feed_dist'], 'dist', dpl.t, bins, xlim_new, color='green')
    # Invert dist histogram
    f.ax['feed_dist'].invert_yaxis()

    # give every panel the dipole plot's x-range
    for name in ('dipole', 'feed_prox', 'feed_dist'):
        f.ax[name].set_xlim(xlim_new)

    f.set_hist_props(hist)

    # Add legend to histogram axes
    for key in f.ax.keys():
        if 'feed' in key:
            f.ax[key].legend()

    # finally force the requested window on the histograms
    f.ax['feed_prox'].set_xlim((xmin, xmax))
    f.ax['feed_dist'].set_xlim((xmin, xmax))

    f.f.suptitle(ac.create_title(p_dict, key_types))

    fig_name = os.path.join(dfig, file_prefix + '.png')
    # save this figure explicitly instead of pyplot's "current" figure
    f.f.savefig(fig_name)
    f.close()
コード例 #11
0
    def add_column_labels(self, param_list, key):
        """Write a label above each column of axes.

        Column i is labeled ``"<key> :<value>"``, with value
        ``p_dict[key]`` read from ``param_list[i]``. Numeric values use
        ``%2.1f``. Other values (e.g. strings) fall back to ``%s`` instead
        of raising TypeError.
        """
        gap = (0.85 - self.left_margin) / self.N_cols + self.gap_width
        y = 1 - self.top_margin / 2.

        for i in range(self.N_cols):
            value = paramrw.read(param_list[i])[1][key]

            try:
                label = key + ' :%2.1f' % value
            except TypeError:
                # non-numeric parameter value
                label = key + ' :%s' % (value,)

            x = self.left_margin + gap / 2. + gap * i
            self.f.text(x, y, label, fontsize=36, horizontalalignment='center', verticalalignment='top')
コード例 #12
0
 def __init__(self, fspk, fparam, evoked=False):
   """Load input gids and input spike times for one simulation.

   Args:
     fspk: spike file passed to ``__get_extinput_times``.
     fparam: param file read with ``paramrw.read`` (sets ``gid_dict`` and
       ``p_dict``).
     evoked: stored as ``self.evoked``.

   Sets the evoked prox/dist gids, the ongoing prox/dist gids, the poisson
   input gids, and ``self.inputs`` (input times keyed 'prox' and 'dist').
   """
   gid_dict, p_dict = paramrw.read(fparam)
   self.gid_dict = gid_dict
   self.p_dict = p_dict
   self.evoked = evoked

   # evoked proximal / distal input gids
   evprox, evdist = self.__get_evokedinput_gids()
   self.gid_evprox = evprox
   self.gid_evdist = evdist

   # ongoing proximal / distal input gids
   prox, dist = self.__get_extinput_gids()
   self.gid_prox = prox
   self.gid_dist = dist

   # poisson input gids
   self.gid_pois = self.__get_poisinput_gids()

   # input times for the external inputs
   self.inputs = self.__get_extinput_times(fspk)
コード例 #13
0
def spikes_from_file(fparam, fspikes):
    """Load a spike file and split it into Spikes objects per source.

    Args:
        fparam: param file; its gid dict defines the sources.
        fspikes: spike file readable by ``np.loadtxt``. An empty file gives
            an empty float64 array.

    Returns:
        dict with the following entries:
        - each cell-type key (starting with 'L2_' or 'L5_') mapped to a
          Spikes object, plus 'extgauss_<key>' and 'extpois_<key>' from
          ``split_extrand``;
        - every other gid-dict key except 'extinput' mapped to a Spikes
          object;
        - either 'alpha_feed_prox' / 'alpha_feed_dist' (when 'extinput' has
          more than one gid) or 'extinput' (a list of Spikes, one per gid).
    """
    gid_dict, _ = paramrw.read(fparam)

    # Sort the gid-dict keys into cell types and other unique sources.
    # 'extinput' is handled separately at the end.
    src_list = []
    src_unique_list = []
    for key in gid_dict.keys():
        if key.startswith(('L2_', 'L5_')):
            src_list.append(key)
        elif key != 'extinput':
            src_unique_list.append(key)

    # Check whether there are spikes in the file; otherwise use an empty
    # array. Passing the path lets np.loadtxt close the file (open() leaked).
    if os.stat(fspikes).st_size:
        s = np.loadtxt(fspikes)
    else:
        s = np.array([], dtype='float64')

    # Iterate over src_list, not s_dict.keys(). New keys are added inside the
    # loop, which raises "dictionary changed size during iteration" on
    # Python 3, where keys() is a live view.
    s_dict = dict.fromkeys(src_list)
    for key in src_list:
        s_dict[key] = Spikes(s, gid_dict[key])
        # its extgauss and extpois feeds
        s_dict['extgauss_' + key] = split_extrand(s, gid_dict, key, 'extgauss')
        s_dict['extpois_' + key] = split_extrand(s, gid_dict, key, 'extpois')

    for key in src_unique_list:
        s_dict[key] = Spikes(s, gid_dict[key])

    # Alpha feeds (extinputs). Order is guaranteed by the order of inputs in
    # p_ext in paramrw and by gid creation in class_net. A missing
    # 'extinput' entry is treated as having no external inputs.
    extinput_gids = gid_dict.get('extinput', [])
    if len(extinput_gids) > 1:
        s_dict['alpha_feed_prox'] = Spikes(s, [extinput_gids[0]])
        s_dict['alpha_feed_dist'] = Spikes(s, [extinput_gids[1]])
    else:
        # handle the extinput: this is a LIST!
        s_dict['extinput'] = [Spikes(s, [gid]) for gid in extinput_gids]
    return s_dict
コード例 #14
0
ファイル: plotfn.py プロジェクト: SMPugliese/hnn_calcium
def aggregate_spec_with_hist(ddir, p_exp, labels, runtype='debug'):
    """Draw spectrogram + histogram panels of every file into one figure.

    Rows are experimental groups and columns are the param files of a group
    (the column count is taken from the first group). The figure is saved
    as ``<ddir.dsim>/aggregate_hist.png``.

    Args:
        ddir: simulation paths object.
        p_exp: unused; kept for backward compatibility.
        labels: (row_key, column_key) passed to ``add_row_labels`` and
            ``add_column_labels``.
        runtype: 'debug' (default; panels drawn serially) or 'parallel'.
            Previously this was a local misspelled as ``untype``, so the
            function always failed with NameError.

    Raises:
        ValueError: if ``runtype`` is neither 'debug' nor 'parallel'.
    """
    if runtype not in ('parallel', 'debug'):
        raise ValueError("runtype must be 'parallel' or 'debug', got %r" % (runtype,))

    param_list = []
    dpl_list = []
    spec_list = []
    spk_list = []

    # Get dimensions for aggregate fig
    N_rows = len(ddir.expmt_groups)
    N_cols = len(ddir.file_match(ddir.expmt_groups[0], 'param'))
    f = ac.FigAggregateSpecWithHist(N_rows, N_cols)

    # Grab all necessary data in aggregated lists
    for expmt_group in ddir.expmt_groups:
        # these should be equivalent lengths
        param_list.extend(ddir.file_match(expmt_group, 'param'))
        dpl_list.extend(ddir.file_match(expmt_group, 'rawdpl'))
        spec_list.extend(ddir.file_match(expmt_group, 'rawspec'))
        spk_list.extend(ddir.file_match(expmt_group, 'rawspk'))

    file_sets = zip(param_list, spk_list, dpl_list, spec_list, f.ax_list)

    # compare strings with ==; 'is' relies on string interning
    if runtype == 'parallel':
        # NOTE(review): Pool workers receive pickled copies of f and ax, so
        # their drawing does not reach the figure saved below, and
        # apply_async errors are never retrieved -- confirm this path is
        # still wanted.
        pl = Pool()
        for f_param, f_spk, f_dpl, fspec, ax in file_sets:
            _, p_dict = paramrw.read(f_param)
            pl.apply_async(specfn.aggregate_with_hist,
                           (f, ax, fspec, f_dpl, f_spk, f_param, p_dict))
        pl.close()
        pl.join()
    else:
        for f_param, f_spk, f_dpl, fspec, ax in file_sets:
            pspec.aggregate_with_hist(f, ax, fspec, f_dpl, f_spk, f_param)

    f.add_row_labels(param_list, labels[0])
    f.add_column_labels(param_list, labels[1])
    fig_name = os.path.join(ddir.dsim, 'aggregate_hist.png')
    f.save(fig_name)
    f.close()
コード例 #15
0
def exec_spike_rates(ddata, opts):
    """Print spike-count and spike-rate statistics for one cell type.

    Args:
        ddata: simulation paths object.
        opts: dict with 'expmt_group' (group to analyze) and 'celltype'
            (e.g. 'L5_pyramidal').

    Statistics cover the whole simulation, from 0 to ``tstop``. The first
    50 ms are NOT discarded. Rates assume ``tstop`` is in ms and are printed
    in Hz. Trials whose spikes lack ``celltype`` are skipped.
    """
    expmt_group = opts['expmt_group']
    celltype = opts['celltype']

    list_f_spk = ddata.file_match(expmt_group, 'rawspk')
    list_f_param = ddata.file_match(expmt_group, 'param')

    for fspk, fparam in zip(list_f_spk, list_f_param):
        s_all = spikefn.spikes_from_file(fparam, fspk)
        _, p_dict = paramrw.read(fparam)
        T = p_dict['tstop']

        if celltype not in s_all:
            continue

        # spike count of every cell of this type
        sp_counts = np.array([len(spikes_cell) for spikes_cell in s_all[celltype].spike_list])

        # calc mean and stdev
        sp_count_mean = np.mean(sp_counts)
        sp_count_stdev = np.std(sp_counts)

        # calc rate in Hz, assume T in ms
        sp_rates = sp_counts * 1000. / T
        sp_rate_mean = np.mean(sp_rates)
        sp_rate_stdev = np.std(sp_rates)

        # rate computed directly from the mean count
        sp_rate = sp_count_mean * 1000. / T

        # Single-argument print() calls give identical output on Python 2
        # and 3; the print statements were a SyntaxError on Python 3.
        print("Sim No. %i, Trial %i, celltype is %s:" % (p_dict['Sim_No'], p_dict['Trial'], celltype))
        print("  spike count mean is: %4.3f" % sp_count_mean)
        print("  spike count stdev is: %4.3f" % sp_count_stdev)
        print("  spike rate over %4.3f ms is %4.3f Hz +/- %4.3f" % (T, sp_rate_mean, sp_rate_stdev))
        print("  spike rate over %4.3f ms is %4.3f Hz" % (T, sp_rate))
コード例 #16
0
def pspecpwr_ax(ax_specpwr, specpwr_list, fparam_list, key_types):
    """Plot each average spectral power curve on one axis; build legends.

    Args:
        ax_specpwr: matplotlib axis to draw on.
        specpwr_list: results, each with 'freq' and 'p_avg' entries.
        fparam_list: param files matching ``specpwr_list`` positionally.
        key_types: dict whose 'dynamic_keys' entry lists the parameters
            shown in each legend entry.

    Returns:
        list with one legend string per plotted curve, such as
        ``"key1: 1.0, key2: 2.0"`` (an empty string when there are no
        dynamic keys).
    """
    # Axes.hold() was removed in Matplotlib 3.0; overplotting is the default.
    legend_list = []

    for result, fparam in zip(specpwr_list, fparam_list):
        ax_specpwr.plot(result['freq'], result['p_avg'])

        p = paramrw.read(fparam)[1]
        # str.join replaces reduce(). reduce is not a builtin on Python 3,
        # and it raised TypeError when there were no dynamic keys.
        legend_list.append(', '.join(
            key + ': %2.1f' % p[key] for key in key_types['dynamic_keys']))

    # the axis is modified in place; only the legend strings are returned
    return legend_list
コード例 #17
0
    def __init__(self, tsarray1, tsarray2, fparam, f_max=60.):
        """Store two time series and compute ``self.data`` from them.

        Args:
            tsarray1, tsarray2: the time-series arrays, stored in
                ``self.ts`` under the keys 1 and 2.
            fparam: param file. ``self.p`` is its parameter dict, and
                ``p['dt']`` sets the sampling frequency.
            f_max: analysis frequencies are ``1. + np.arange(0., f_max, 1.)``
                (1, 2, ..., f_max for an integer f_max).
        """
        # NOTE: 1-indexed dict keys, kept for compatibility with callers
        self.ts = {1: tsarray1, 2: tsarray2}

        _, self.p = paramrw.read(fparam)

        self.f = np.arange(0., f_max, 1.) + 1.

        # Morlet wavelet width in cycles (>= 5 suggested)
        self.width = 7.

        # sampling frequency: 1000 / dt
        self.fs = 1000. / self.p['dt']

        self.data = self.__traces2PLS()
コード例 #18
0
import numpy as np
import os
import fileio, paramrw, params_default, param_gen_utils
import fnmatch
import datalad
from pathlib import Path

# Template param file that every generated parameter set starts from
base_param_path = os.path.abspath(
    'param/standard_params/jyrki_good_3trials_opt_flipped_input.param')
gid_dict, p = paramrw.read(base_param_path)

# Random number generator seeds must be ints; HNN won't recognize floats.
p.update({
    'prng_seedcore_input_prox': 13,
    'prng_seedcore_input_dist': 14,
    'prng_seedcore_extpois': 4,
    'prng_seedcore_extgauss': 4,
    'prng_seedcore_evprox_1': 4,
    'prng_seedcore_evdist_1': 4,
    'prng_seedcore_evprox_2': 4,
    'prng_seedcore_evdist_2': 0,
})

# Copy of the parameters that will hold the arrays of values to sweep over
p_array = p.copy()

# Names of all parameters matching the glob pattern
sweep_params = fnmatch.filter(list(p.keys()), '*gbar_ev*1*Pyr*ampa')
#Update array dict and unpack into a list of each permutation
コード例 #19
0
def pspec_with_hist(f_spec,
                    f_dpl,
                    f_spk,
                    dfig,
                    f_param,
                    key_types,
                    xlim=None,
                    ylim=None):
    """Plot a spectrogram, the dipole and the external-input histograms.

    Args:
        f_spec: spec file loaded with ``specfn.Spec``.
        f_dpl: dipole file. It is baseline-renormalized with ``f_param`` and
            converted with ``convert_fAm_to_nAm`` before plotting.
        f_spk: spike file used for the proximal/distal input histograms.
        dfig: output directory. The figure is saved as
            ``<dfig>/<prefix of f_spec>.png``.
        f_param: param file.
        key_types: forwarded to ``ac.create_title``.
        xlim, ylim: optional spectrogram limits. All panels then share the
            spectrogram's resulting x-range.
    """
    # file prefix: basename up to the first '.'
    fprefix = f_spec.split('/')[-1].split('.')[0]
    fig_name = os.path.join(dfig, fprefix + '.png')

    _, p_dict = paramrw.read(f_param)
    f = ac.FigSpecWithHist()

    # Plot TFR data and add colorbar
    spec = specfn.Spec(f_spec)
    pc = spec.plot_TFR(f.ax['spec'], 'agg', xlim, ylim)
    f.f.colorbar(pc, ax=f.ax['spec'])
    # every other panel follows the TFR plot's x-range
    xlim_new = f.ax['spec'].get_xlim()

    # dipole data
    dpl = dipolefn.Dipole(f_dpl)
    dpl.baseline_renormalize(f_param)
    dpl.convert_fAm_to_nAm()
    dpl.plot(f.ax['dipole'], xlim_new, 'agg')

    # Get extinput data and account for delays
    extinputs = spikefn.ExtInputs(f_spk, f_param)
    extinputs.add_delay_times()
    extinputs.get_envelope(dpl.t, feed='dist')

    # 150 bins per 1000 ms. Histogram bin counts must be integers, and
    # ceil() may return a float (Python 2 math.ceil, numpy.ceil).
    bins = int(ceil(150. * (xlim_new[1] - xlim_new[0]) / 1000.))

    hist = {}
    hist['feed_prox'] = extinputs.plot_hist(f.ax['feed_prox'],
                                            'prox',
                                            dpl.t,
                                            bins=bins,
                                            xlim=xlim_new,
                                            color='red')
    hist['feed_dist'] = extinputs.plot_hist(f.ax['feed_dist'],
                                            'dist',
                                            dpl.t,
                                            bins=bins,
                                            xlim=xlim_new,
                                            color='green')
    f.ax['feed_dist'].invert_yaxis()

    # force the same x-range on every panel
    for name in ('dipole', 'spec', 'feed_prox', 'feed_dist'):
        f.ax[name].set_xlim(xlim_new)

    f.set_hist_props(hist)

    f.ax['spec'].set_xlabel('Time (ms)')
    f.ax['spec'].set_ylabel('Frequency (Hz)')

    # Add legend to histogram axes
    for key in f.ax.keys():
        if 'feed' in key:
            f.ax[key].legend()

    f.f.suptitle(ac.create_title(p_dict, key_types))
    f.savepng(fig_name)
    f.close()
コード例 #20
0
def aggregate_with_hist(f, ax, f_spec, f_dpl, f_spk, f_param):
    """Draw one spectrogram/dipole/input-histogram panel of an aggregate figure.

    Args:
        f: the aggregate figure object (``f.f`` and ``set_hist_props`` are
            used).
        ax: dict of this panel's axes with keys 'spec', 'dipole',
            'feed_prox' and 'feed_dist'.
        f_spec: spec .npz file containing a 'time' array.
        f_dpl: dipole file.
        f_spk: spike file holding the alpha-feed spikes.
        f_param: param file. 'tstop' is the right edge of the x-range.
    """
    import math

    import numpy as np

    _, p_dict = paramrw.read(f_param)

    # load spec data from file
    spec = specfn.Spec(f_spec)

    # The x-range starts at the first spec time point. ``timevec`` was
    # referenced here after its loading code had been commented out
    # (NameError).
    with np.load(f_spec) as data_spec:
        xmin = data_spec['time'][0]
    xmax = p_dict['tstop']
    x = (xmin, xmax)

    pc = spec.plot_TFR(ax['spec'], layer='agg', xlim=x)
    # The colorbar takes its colormap and norm from pc. Matplotlib ignores
    # cmap/norm arguments when a mappable is passed, so the previous
    # norm=/cmap= arguments had no effect. Use pc.set_norm()/pc.set_cmap()
    # if a fixed color scale is wanted.
    f.f.colorbar(pc, ax=ax['spec'])

    # dipole data
    dpl = dipolefn.Dipole(f_dpl)
    dpl.plot(ax['dipole'], x, layer='agg')

    # Alpha feed data: check that the feed keys exist, then account for
    # possible delays.
    s_dict = spikefn.spikes_from_file(f_param, f_spk)
    s_dict = spikefn.alpha_feed_verify(s_dict, p_dict)
    s_dict = spikefn.add_delay_times(s_dict, p_dict)

    # 150 bins per 1000 ms; hist() needs an integer bin count of at least 1
    bins = max(1, int(math.ceil(150. * (xmax - xmin) / 1000.)))

    hist = {}

    # Proximal feed
    hist['feed_prox'] = ax['feed_prox'].hist(
        s_dict['alpha_feed_prox'].spike_list,
        bins,
        range=[xmin, xmax],
        color='red',
        label='Proximal feed')

    # Distal feed
    hist['feed_dist'] = ax['feed_dist'].hist(
        s_dict['alpha_feed_dist'].spike_list,
        bins,
        range=[xmin, xmax],
        color='green',
        label='Distal feed')

    # force the same x-range on every panel
    for name in ('dipole', 'spec', 'feed_prox', 'feed_dist'):
        ax[name].set_xlim(x)

    f.set_hist_props(ax, hist)

    ax['spec'].set_xlabel('Time (ms)')
    ax['spec'].set_ylabel('Frequency (Hz)')

    # Add legend to histogram axes
    for key in ax.keys():
        if 'feed' in key:
            ax[key].legend()
コード例 #21
0
def exec_show(ddata, dict_opts):
    """Print parameter values from the param file of one run/trial.

    Args:
        ddata: simulation paths object.
        dict_opts: options checked against the defaults by ``args_check``
            (which is expected to update the defaults in place -- confirm):
            - 'run' and 'trial' select the param file;
            - 'expmt_group' selects the group, falling back to the first
              group when not found;
            - 'key' selects what to show: 'changed' (the variables from
              ``paramrw.changed_vars``), an exact parameter name, or a
              substring matched against all names.

    Returns:
        0 if no matching keys were found; otherwise None.
    """
    dict_opts_default = {
        'run': 0,
        'trial': 0,
        'expmt_group': '',
        'key': 'changed',
        'var_list': [],
    }

    exclude_list = [
        'sim_prefix',
        'N_trials',
        'Run_Date',
    ]

    args_check(dict_opts_default, dict_opts)
    if dict_opts_default['expmt_group'] not in ddata.expmt_groups:
        dict_opts_default['expmt_group'] = ddata.expmt_groups[0]

    # output the expmt group used. Single-argument print() calls replace
    # the Python 2 print statements, which are a SyntaxError on Python 3.
    print("expmt_group: %s" % dict_opts_default['expmt_group'])

    # find the params
    p_exp = paramrw.ExpParams(ddata.fparam)

    if dict_opts_default['key'] == 'changed':
        print("Showing changed ... \n")
        var_list = [val[0] for val in paramrw.changed_vars(ddata.fparam)]

    elif dict_opts_default['key'] in p_exp.keys():
        # a list with just this element
        var_list = [dict_opts_default['key']]

    else:
        key_part = dict_opts_default['key']
        var_list = [key for key in p_exp.keys() if key_part in key]

    if not var_list:
        print("Keys were not found by exec_show()")
        return 0

    # files
    fprefix = ddata.trial_prefix_str % (dict_opts_default['run'], dict_opts_default['trial'])
    fparam = ddata.create_filename(dict_opts_default['expmt_group'], 'param', fprefix)

    list_param = ddata.file_match(dict_opts_default['expmt_group'], 'param')

    if fparam in list_param:
        # this version of read returns the gid dict as well ...
        _, p = paramrw.read(fparam)

        # use var_list to print values
        for key in var_list:
            if key in exclude_list:
                continue
            try:
                value = p[key]
            except KeyError:
                print("Value %s not found in file %s!" % (key, fparam))
            else:
                print('%s: %s' % (key, value))
コード例 #22
0
def pdipole(f_dpl, dfig, plot_dict, f_param=None, key_types=None):
    """Plot one aggregate dipole and save it as ``<dfig>/<prefix>.png``.

    Args:
        f_dpl: dipole file. It is converted with ``convert_fAm_to_nAm`` and,
            when ``f_param`` is given, baseline-renormalized first.
        dfig: output directory.
        plot_dict: dict with 'xlim' and 'ylim' entries, each None or a
            (min, max) pair. A negative xlim minimum is clamped to 0; a
            negative xlim maximum means "up to the last time point".
        f_param: optional param file.
        key_types: optional. When given together with ``f_param``, a title
            is built with ``ac.create_title``. Defaults to None instead of a
            shared mutable ``{}``.
    """
    # dpl is an obj of Dipole() class
    dpl = Dipole(f_dpl)

    if f_param:
        dpl.baseline_renormalize(f_param)

    dpl.convert_fAm_to_nAm()

    # file prefix: basename up to the first '.'
    file_prefix = f_dpl.split('/')[-1].split('.')[0]

    # parse xlim from plot_dict
    if plot_dict['xlim'] is None:
        xmin = dpl.t[0]
        xmax = dpl.t[-1]
    else:
        xmin, xmax = plot_dict['xlim']

        if xmin < 0.:
            xmin = 0.

        if xmax < 0.:
            # previously ``self.f[-1]``: a NameError, as there is no self in
            # this function. Use the end of the data, like the branch above.
            xmax = dpl.t[-1]

    # truncate with a single logical mask
    in_range = (dpl.t >= xmin) & (dpl.t <= xmax)
    t_range = dpl.t[in_range]
    dpl_range = dpl.dpl['agg'][in_range]

    f = ac.FigStd()
    f.ax0.plot(t_range, dpl_range)

    if plot_dict['ylim'] is not None:
        f.ax0.set_ylim(plot_dict['ylim'][0], plot_dict['ylim'][1])

    # Title creation
    if f_param and key_types:
        _, p_dict = paramrw.read(f_param)
        f.f.suptitle(ac.create_title(p_dict, key_types))

    fig_name = os.path.join(dfig, file_prefix + '.png')

    # save this figure explicitly instead of pyplot's "current" figure
    f.f.savefig(fig_name, dpi=300)
    f.close()
コード例 #23
0
def calc_avgdpl_stimevoked(ddata):
    """ Average the raw dipoles of each experiment group, aligned to the evoked input.

        For every expmt_group in ddata:
          1. each raw dipole is truncated to start just after t0, where t0 is
             the first 'evprox0' spike time minus a 50 ms buffer (or the
             spike time itself when it is <= 50 ms);
          2. each dipole is then shortened at its end by how much later the
             latest-starting dipole begins, so the vectors line up;
          3. the dipoles are summed and divided by the number of dipoles.
        The mean is written as "time<TAB>dipole" rows to
        <dexpmt_dict[expmt_group]>/<sim_prefix>-<expmt_group>-dpl.txt.
        Groups with no usable dipole/spike/param files are skipped.

        Arguments:
        ddata -- simulation paths object providing expmt_groups, dexpmt_dict,
                 sim_prefix and file_match()
    """
    for expmt_group in ddata.expmt_groups:
        # create the filename
        dexp = ddata.dexpmt_dict[expmt_group]
        fname_short = '%s-%s-dpl' % (ddata.sim_prefix, expmt_group)
        fname_data = os.path.join(dexp, fname_short + '.txt')

        # grab the list of raw data dipoles and assoc params in this expmt
        fdpl_list = ddata.file_match(expmt_group, 'rawdpl')
        param_list = ddata.file_match(expmt_group, 'param')
        spk_list = ddata.file_match(expmt_group, 'rawspk')

        # actual list of Dipole() objects
        dpl_list = [Dipole(fdpl) for fdpl in fdpl_list]
        if not dpl_list:
            # nothing to average in this group
            continue

        t_truncated = []

        # iterate through the lists, grab the spike time, phase align the signals,
        # cut them to length, and then mean the dipoles
        for dpl, f_spk, f_param in zip(dpl_list, spk_list, param_list):
            _, p = paramrw.read(f_param)

            # grab the corresponding relevant starting spike time
            s = spikefn.spikes_from_file(f_param, f_spk)
            s = spikefn.alpha_feed_verify(s, p)
            s = spikefn.add_delay_times(s, p)

            # t_evoked is the same for all of the cells in these simulations
            t_evoked = s['evprox0'].spike_list[0][0]

            # attempt to give a 50 ms buffer
            if t_evoked > 50.:
                t0 = t_evoked - 50.
            else:
                t0 = t_evoked

            # build the mask BEFORE truncating t so it also matches the
            # (still full-length) dipole vector
            keep = dpl.t > t0
            dpl.t = dpl.t[keep]
            dpl.dpl['agg'] = dpl.dpl['agg'][keep]
            t_truncated.append(dpl.t[0])

        if not t_truncated:
            # no spike/param files matched the dipoles in this group
            continue

        # offsets relative to the latest start time (all values are <= 0)
        t_adjust = np.asarray(t_truncated) - np.max(t_truncated)

        dpl_total = None
        for dpl, t_adj in zip(dpl_list, t_adjust):
            # negative numbers mean that this vector needs to be shortened by that many ms
            T_new = dpl.t[-1] + t_adj
            keep = dpl.t < T_new
            dpl.dpl['agg'] = dpl.dpl['agg'][keep]
            dpl.t = dpl.t[keep]

            if dpl_total is None:
                # copy so the running sum does not alias the first dipole's data
                dpl_total = dpl.dpl['agg'].copy()
            else:
                dpl_total += dpl.dpl['agg']

        dpl_mean = dpl_total / len(dpl_list)
        t_dpl = dpl_list[0].t

        # write this data to the file
        with open(fname_data, 'w') as f:
            for t, x in zip(t_dpl, dpl_mean):
                f.write("%03.3f\t%5.4f\n" % (t, x))
コード例 #24
0
def exec_plotaverages(ddata, ylim=None):
    """ Plot the trial-averaged dipoles and spectrograms of every experiment group.

        Creates each group's 'figavgdpl' / 'figavgspec' figure directories if
        missing, collects the averaged dipole ('-dplavg.txt') and spectrogram
        ('-specavg.npz') files plus their matching param files, then plots
        them with dipolefn.pdipole and pspec.pspec_dpl.

        Arguments:
        ddata -- simulation paths object providing expmt_groups, dfig and fparam
        ylim  -- optional (ymin, ymax) pair applied to the dipole plots

        Returns 0 when no averaged dipole or spectrogram data is found,
        otherwise None.
    """
    # runtype = 'parallel'
    runtype = 'debug'

    # this is a qnd check to create the fig dir if it doesn't already exist
    # backward compatibility check for sims that didn't auto-create these dirs
    for expmt_group in ddata.expmt_groups:
        fio.dir_create(ddata.dfig[expmt_group]['figavgdpl'])
        fio.dir_create(ddata.dfig[expmt_group]['figavgspec'])

    # presumably globally true information
    p_exp = paramrw.ExpParams(ddata.fparam)
    key_types = p_exp.get_key_types()

    # lists accumulated over ALL expmt groups, kept index-aligned
    dpl_list = []
    spec_list = []
    dfig_dpl_list = []
    dfig_spec_list = []
    fparam_list = []
    pdict_list = []

    # by doing all file operations sequentially by expmt_group in this iteration
    # trying to guarantee order better than before
    for expmt_group in ddata.expmt_groups:
        dfig = ddata.dfig[expmt_group]

        # avgdpl and avgspec data paths (fio.file_match() returns lists sorted)
        dpl_list_expmt = fio.file_match(dfig['avgdpl'], '-dplavg.txt')
        dpl_list += dpl_list_expmt
        spec_list += fio.file_match(dfig['avgspec'], '-specavg.npz')

        # one output fig dir entry per averaged dipole in this group
        n_dpl = len(dpl_list_expmt)
        dfig_dpl_list += [dfig['figavgdpl']] * n_dpl
        dfig_spec_list += [dfig['figavgspec']] * n_dpl

        # param list to match avg data lists; accumulate across groups so it
        # stays aligned with dpl_list/spec_list
        fparam_list_expmt = fio.fparam_match_minimal(dfig['param'], p_exp)
        fparam_list += fparam_list_expmt
        pdict_list += [paramrw.read(f_param)[1] for f_param in fparam_list_expmt]

    if not dpl_list:
        print("No avg dipole data found.")
        return 0

    # plot_dict input for dipolefn.pdipole, which reads 'xlim' and 'ylim'
    pdipole_dict = {
        'xlim': None,
        'ylim': None,
    }

    # if there is a length, assume it's 2 (it should be!)
    if ylim is not None and len(ylim):
        pdipole_dict['ylim'] = (ylim[0], ylim[1])

    if runtype == 'debug':
        for f_dpl, f_param, dfig_dpl in zip(dpl_list, fparam_list, dfig_dpl_list):
            dipolefn.pdipole(f_dpl, dfig_dpl, pdipole_dict, f_param, key_types)

    elif runtype == 'parallel':
        pl = Pool()
        for f_dpl, f_param, dfig_dpl in zip(dpl_list, fparam_list, dfig_dpl_list):
            # same argument order as the debug branch / pdipole's signature
            pl.apply_async(dipolefn.pdipole, (f_dpl, dfig_dpl, pdipole_dict, f_param, key_types))

        pl.close()
        pl.join()

    # if avg spec data exists
    if not spec_list:
        print("No averaged spec data found. Run avgtrials().")
        return 0

    spec_args = zip(spec_list, dpl_list, fparam_list, dfig_spec_list, pdict_list)

    if runtype == 'debug':
        for f_spec, f_dpl, f_param, dfig_spec, pdict in spec_args:
            pspec.pspec_dpl(f_spec, f_dpl, dfig_spec, pdict, key_types, f_param=f_param)

    elif runtype == 'parallel':
        pl = Pool()
        for f_spec, f_dpl, f_param, dfig_spec, pdict in spec_args:
            pl.apply_async(pspec.pspec_dpl, (f_spec, f_dpl, dfig_spec, pdict, key_types), {'f_param': f_param})

        pl.close()
        pl.join()
コード例 #25
0
def pdipole_evoked(dfig, f_dpl, f_spk, f_param, ylim=None):
    """ Plot one simulation/trial's aggregate dipole with evoked-input markers.

        A red vertical line is drawn at the first unique spike time of every
        spike-dict key starting with 'evprox' or 'evdist'. The figure is
        saved as <dfig>/<file prefix>.png at 300 dpi.

        Arguments:
        dfig    -- output directory
        f_dpl   -- dipole file, loaded with Dipole()
        f_spk   -- raw spike file
        f_param -- param file used to read the spikes
        ylim    -- optional y-limits passed to set_ylim()
    """
    # get the spike dict from the files
    s_dict = spikefn.spikes_from_file(f_param, f_spk)

    # unique spike times of the evoked proximal/distal feeds only
    spk_unique = {
        key: s_dict[key].unique_all(0)
        for key in s_dict
        if key.startswith(('evprox', 'evdist'))
    }

    dpl = Dipole(f_dpl)

    # file prefix: base file name up to its first '.'
    file_prefix = os.path.basename(f_dpl).split('.')[0]

    f = ac.FigStd()
    f.ax0.plot(dpl.t, dpl.dpl['agg'])

    # draw a vertical line at the first unique spike of each evoked feed;
    # draw on f.ax0 explicitly rather than on pyplot's implicit current axes
    for t_unique in spk_unique.values():
        if len(t_unique):
            f.ax0.axvline(x=t_unique[0], linewidth=0.5, color='r')

    # placeholder title
    title_txt = 'test'
    f.ax0.set_title(title_txt)

    if ylim is not None and len(ylim):
        f.ax0.set_ylim(ylim)

    fig_name = os.path.join(dfig, file_prefix + '.png')

    plt.savefig(fig_name, dpi=300)
    f.close()
コード例 #26
0
ファイル: ppsth.py プロジェクト: SMPugliese/hnn_calcium
def ppsth(simpaths):
    """ Plot per-experiment PSTHs of L2/L5 pyramidal spikes plus a sample raster.

        For each experiment name, spikes of 'L2_pyramidal' and 'L5_pyramidal'
        are pooled over all trials and histogrammed with bin counts from
        spikefn.hist_bin_opt(). The first trial's spikes are split by layer
        and input type (extgauss/extpois) and drawn as rasters. Each figure
        is saved as <dsim>/<expname>.eps, then fio.epscompress() is run on
        the .eps files in dsim.

        Arguments:
        simpaths -- simulation paths object providing exp_files_of_type(),
                    expnames and dsim
    """
    # get filename lists in dictionaries of experiments
    dict_exp_param = simpaths.exp_files_of_type('param')
    dict_exp_spk = simpaths.exp_files_of_type('rawspk')

    # assumes a match between expnames and the keys of the previous dicts
    for expname in simpaths.expnames:
        exp_param_list = dict_exp_param[expname]
        exp_spk_list = dict_exp_spk[expname]

        # get representative spikes from the first trial
        gid_dict, _ = paramrw.read(exp_param_list[0])
        s_dict = spikefn.spikes_from_file(gid_dict, exp_spk_list[0])

        s_dict_L2 = {}
        s_dict_L5 = {}
        s_dict_L2_extgauss = {}
        s_dict_L2_extpois = {}
        s_dict_L5_extgauss = {}
        s_dict_L5_extpois = {}

        # clean out s_dict destructively (borrowed from praster);
        # iterate over a snapshot of the keys because entries are popped
        for key in list(s_dict.keys()):
            # do this first to remove all extgauss feeds
            if 'extgauss' in key:
                if 'L2_' in key:
                    s_dict_L2_extgauss[key] = s_dict.pop(key)

                elif 'L5_' in key:
                    s_dict_L5_extgauss[key] = s_dict.pop(key)

            elif 'extpois' in key:
                if 'L2_' in key:
                    s_dict_L2_extpois[key] = s_dict.pop(key)

                elif 'L5_' in key:
                    s_dict_L5_extpois[key] = s_dict.pop(key)

            # L2 next
            elif 'L2_' in key:
                s_dict_L2[key] = s_dict.pop(key)

            elif 'L5_' in key:
                s_dict_L5[key] = s_dict.pop(key)

        # per-trial arrays of all pyramidal spike times for this experiment
        s_L2Pyr_list = []
        s_L5Pyr_list = []

        # iterate through params and spikes for a given experiment
        for fparam, fspk in zip(exp_param_list, exp_spk_list):
            gid_dict, _ = paramrw.read(fparam)
            s_trial = spikefn.spikes_from_file(gid_dict, fspk)

            # add a new entry to list for each different file assoc with an experiment
            s_L2Pyr_list.append(
                np.array(list(it.chain.from_iterable(s_trial['L2_pyramidal'].spike_list))))
            s_L5Pyr_list.append(
                np.array(list(it.chain.from_iterable(s_trial['L5_pyramidal'].spike_list))))

        # now aggregate over all spikes
        s_L2Pyr = np.array(list(it.chain.from_iterable(s_L2Pyr_list)))
        s_L5Pyr = np.array(list(it.chain.from_iterable(s_L5Pyr_list)))

        # optimize bins using the number of trials actually pooled
        # (fixed alternative for comparisons: bin_L2 = bin_L5 = 120)
        N_trials = len(s_L2Pyr_list)
        bin_L2 = spikefn.hist_bin_opt(s_L2Pyr, N_trials)
        bin_L5 = spikefn.hist_bin_opt(s_L5Pyr, N_trials)

        # create standard fig and axes
        f = ac.FigPSTH(400.)
        f.ax['L2_psth'].hist(s_L2Pyr, bin_L2, facecolor='g', alpha=0.75)
        f.ax['L5_psth'].hist(s_L5Pyr, bin_L5, facecolor='g', alpha=0.75)

        # report the autoscaled y-limits (useful for normalizing by hand)
        y_L2 = f.ax['L2_psth'].get_ylim()
        y_L5 = f.ax['L5_psth'].get_ylim()

        print("%s %s" % (y_L2, y_L5))

        spikefn.spike_png(f.ax['L2'], s_dict_L2)
        spikefn.spike_png(f.ax['L5'], s_dict_L5)
        spikefn.spike_png(f.ax['L2_extpois'], s_dict_L2_extpois)
        spikefn.spike_png(f.ax['L2_extgauss'], s_dict_L2_extgauss)
        spikefn.spike_png(f.ax['L5_extpois'], s_dict_L5_extpois)
        spikefn.spike_png(f.ax['L5_extgauss'], s_dict_L5_extgauss)

        fig_name = os.path.join(simpaths.dsim, expname + '.eps')

        plt.savefig(fig_name)
        f.close()

    # run the compression
    fio.epscompress(simpaths.dsim, '.eps', 1)
コード例 #27
0
def pdipole_evoked_aligned(ddata):
    """ over ALL trials in ALL conditions in EACH experiment
        appears to be iteration over pdipole_exp2()

        A control simulation is chosen interactively, unless runtype is
        'debug'. Its 'mu_low' and 'mu_high' expmt groups supply:
          - the control dipoles, drawn on the 'dpl_mu' axis;
          - the alpha-feed input histograms, drawn on the 'input' axis and
            its twin;
          - a spectrogram, drawn on the 'spec' axis.
        Then, for each mu_low/mu_high expmt group of ddata, the first
        aggregate dipole is drawn on the 'spk' axis. That dipole is computed
        with calc_aggregate_dipole() if it is missing. A red vertical line
        marks each trial's first 'evprox0' spike. The figure is saved as
        <ddata.dsim>/<sim_prefix>_dpl_align.png.

        Arguments:
        ddata -- simulation paths object providing expmt_groups, file_match(),
                 find_aggregate_file(), sim_prefix and dsim

        Raises ValueError if the control sim lacks a mu_low or mu_high group.
    """
    # grab the original dipole from a specific dir
    dproj = fio.return_data_dir()

    runtype = 'somethingotherthandebug'
    # runtype = 'debug'

    if runtype == 'debug':
        ddate = '2013-04-08'
        dsim = 'mubaseline-04-000'
        i_ctrl = 0
    else:
        ddate = raw_input('Short date directory? ')
        dsim = raw_input('Sim name? ')
        i_ctrl = ast.literal_eval(raw_input('Sim number: '))
    dcheck = os.path.join(dproj, ddate, dsim)

    # create a blank ddata structure for the control sim
    ddata_ctrl = fio.SimulationPaths()
    ddata_ctrl.read_sim(dproj, dcheck)

    # find the mu_low and mu_high in the expmtgroup names
    # this means the group names must be well formed
    mu_low_group = None
    mu_high_group = None
    for expmt_group in ddata_ctrl.expmt_groups:
        if 'mu_low' in expmt_group:
            mu_low_group = expmt_group
        elif 'mu_high' in expmt_group:
            mu_high_group = expmt_group

    if mu_low_group is None or mu_high_group is None:
        raise ValueError("control sim %s needs both a 'mu_low' and a 'mu_high' expmt group" % dcheck)

    # choose the i_ctrl-th file matches for mu_low
    fdpl_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawdpl')[i_ctrl]
    fparam_mu_low = ddata_ctrl.file_match(mu_low_group, 'param')[i_ctrl]
    fspk_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawspk')[i_ctrl]
    fspec_mu_low = ddata_ctrl.file_match(mu_low_group, 'rawspec')[i_ctrl]

    # choose the i_ctrl-th file matches for mu_high
    fdpl_mu_high = ddata_ctrl.file_match(mu_high_group, 'rawdpl')[i_ctrl]
    fparam_mu_high = ddata_ctrl.file_match(mu_high_group, 'param')[i_ctrl]

    # grab the relevant dipoles and renormalize them
    dpl_mu_low = Dipole(fdpl_mu_low)
    dpl_mu_low.baseline_renormalize(fparam_mu_low)

    dpl_mu_high = Dipole(fdpl_mu_high)
    dpl_mu_high.baseline_renormalize(fparam_mu_high)

    # input feed information
    s = spikefn.spikes_from_file(fparam_mu_low, fspk_mu_low)
    _, p_ctrl = paramrw.read(fparam_mu_low)
    s = spikefn.alpha_feed_verify(s, p_ctrl)
    s = spikefn.add_delay_times(s, p_ctrl)

    # find tstop, assume same over all. grab the first param file, get the tstop
    tstop = paramrw.find_param(fparam_mu_low, 'tstop')

    # hard coded bin count for now
    n_bins = spikefn.bin_count(150., tstop)

    # create the figure name
    fname_exp = '%s_dpl_align' % (ddata.sim_prefix)
    fname_exp_fig = os.path.join(ddata.dsim, fname_exp + '.png')

    # create one figure comparing across all
    ax_handles = [
        'spec',
        'input',
        'dpl_mu',
        'spk',
    ]
    f_exp = ac.FigDipoleExp(ax_handles)

    # plot the ctrl dipoles (aggregate component, as elsewhere in this file)
    f_exp.ax['dpl_mu'].plot(dpl_mu_low.t, dpl_mu_low.dpl['agg'], color='k')
    f_exp.ax['dpl_mu'].plot(dpl_mu_high.t, dpl_mu_high.dpl['agg'])

    # function creates an f_exp.ax_twinx entry for the 'input' axis
    f_exp.create_axis_twinx('input')

    # input hist information: predicated on the fact that the input histograms
    # should be identical for *all* of the inputs represented in this figure
    # places 2 histograms on two axes (meant to be one axis flipped)
    hists = spikefn.pinput_hist(f_exp.ax['input'], f_exp.ax_twinx['input'], s['alpha_feed_prox'].spike_list, s['alpha_feed_dist'].spike_list, n_bins)

    # grab the max counts for both hists; the [0] item of hist are the counts
    max_hist = np.max([np.max(hists[key][0]) for key in hists])
    ymax = 2 * max_hist

    # plot the spec here
    specfn.pspec_ax(f_exp.ax['spec'], fspec_mu_low)

    # mirror the input axes so the two histograms face each other
    f_exp.ax['input'].set_ylim((0, ymax))
    f_exp.ax_twinx['input'].set_ylim((ymax, 0))

    # time limits taken from the mu_low aggregate dipole, if one is plotted
    t0 = T = None

    # go through each expmt; calculation is extremely redundant
    for expmt_group in ddata.expmt_groups:
        is_mu_low = 'mu_low' in expmt_group
        is_mu_high = 'mu_high' in expmt_group

        if not (is_mu_low or is_mu_high):
            continue

        # a little sloppy, just find the param file: used for the baseline
        # renormalization and the spikes, assumed the same in this expmt_group
        fparam = ddata.file_match(expmt_group, 'param')[0]

        # general check to see if the aggregate dipole data exists
        flist = ddata.find_aggregate_file(expmt_group, 'dpl')

        # if no file exists, then create one
        if not len(flist):
            calc_aggregate_dipole(ddata)
            flist = ddata.find_aggregate_file(expmt_group, 'dpl')

        # mark each trial's first evoked proximal spike on both dipole axes
        list_spk = ddata.file_match(expmt_group, 'rawspk')
        list_evoked = [spikefn.spikes_from_file(fparam, fspk)['evprox0'].spike_list[0][0] for fspk in list_spk]
        for x_val in list_evoked:
            f_exp.ax['dpl_mu'].axvline(x=x_val, linewidth=0.5, color='r')
            f_exp.ax['spk'].axvline(x=x_val, linewidth=0.5, color='r')

        dpl_ev = Dipole(flist[0])
        dpl_ev.baseline_renormalize(fparam)

        # handle mu_low and mu_high separately
        if is_mu_low:
            f_exp.ax['spk'].plot(dpl_ev.t, dpl_ev.dpl['agg'], color='k')

            # get xlim stuff
            t0 = dpl_ev.t[0]
            T = dpl_ev.t[-1]

        else:
            f_exp.ax['spk'].plot(dpl_ev.t, dpl_ev.dpl['agg'], color='b')

    f_exp.ax['input'].set_xlim(50., tstop)

    # align the time axes of the dipole panels to the mu_low aggregate dipole
    if t0 is not None:
        for ax_name in ax_handles[2:]:
            f_exp.ax[ax_name].set_xlim((t0, T))

    f_exp.savepng(fname_exp_fig)
    f_exp.close()
コード例 #28
0
ファイル: ppsth.py プロジェクト: SMPugliese/hnn_calcium
def ppsth_grid(simpaths):
    """ Plot a grid of L2 pyramidal PSTHs, one panel per experiment.

        Grid rows follow the values of 'L2Pyr_Gauss_A_weight' and columns
        the values of 'L2Basket_Pois_lamtha' from the ExpParams of
        simpaths.fparam[0] (a non-list value gives a single row/column).
        Each panel pools 'L2_pyramidal' spikes over all trials of one
        experiment into a 250-bin histogram with y-limits (0, 250).
        The figure is saved as <dsim>/aggregate.eps, then fio.epscompress()
        is run on the .eps files in dsim.

        Arguments:
        simpaths -- simulation paths object providing exp_files_of_type(),
                    fparam, expnames and dsim
    """
    # get filename lists in dictionaries of experiments
    dict_exp_param = simpaths.exp_files_of_type('param')
    dict_exp_spk = simpaths.exp_files_of_type('rawspk')

    # recreate the ExpParams object used in the simulation
    p_exp = paramrw.ExpParams(simpaths.fparam[0])

    # need number of sigma vals (rows) and number of lambda vals (cols);
    # a scalar (non-swept) value has no len() and yields a single row/col
    try:
        N_rows = len(p_exp.p_all['L2Pyr_Gauss_A_weight'])
    except TypeError:
        N_rows = 1

    try:
        N_cols = len(p_exp.p_all['L2Basket_Pois_lamtha'])
    except TypeError:
        N_cols = 1

    tstop = p_exp.p_all['tstop']

    print("%s %s %s" % (N_rows, N_cols, tstop))

    f = ac.FigGrid(N_rows, N_cols, tstop)

    # (row, col) coordinates with the row index varying fastest;
    # this is backward-looking for a reason!
    axes_coords = [
        (j, i) for i, j in it.product(np.arange(N_cols), np.arange(N_rows))
    ]

    if len(simpaths.expnames) != len(axes_coords):
        print("um ... see ppsth.py")

    # assumes a match between expnames and the keys of the previous dicts
    for expname, (r, c) in zip(simpaths.expnames, axes_coords):
        exp_param_list = dict_exp_param[expname]
        exp_spk_list = dict_exp_spk[expname]

        # panel labels come from the experiment's first param file
        _, p = paramrw.read(exp_param_list[0])
        lamtha = p['L2Basket_Pois_lamtha']
        sigma = p['L2Pyr_Gauss_A_weight']

        # per-trial arrays of all L2 pyramidal spike times
        s_L2Pyr_list = []

        # iterate through params and spikes for a given experiment
        for fparam, fspk in zip(exp_param_list, exp_spk_list):
            gid_dict, _ = paramrw.read(fparam)
            s_dict = spikefn.spikes_from_file(gid_dict, fspk)

            s_L2Pyr_list.append(
                np.array(list(it.chain.from_iterable(s_dict['L2_pyramidal'].spike_list))))

        # now aggregate over all spikes
        s_L2Pyr = np.array(list(it.chain.from_iterable(s_L2Pyr_list)))

        # fixed bin count so panels are comparable; the optimized alternative
        # would be spikefn.hist_bin_opt(s_L2Pyr, len(s_L2Pyr_list))
        bin_L2 = 250

        f.ax[r][c].hist(s_L2Pyr, bin_L2, facecolor='g', alpha=0.75)

        if r == 0:
            f.ax[r][c].set_title(r'$\lambda_i$ = %d' % lamtha)

        if c == 0:
            f.ax[r][c].set_ylabel(r'$A_{gauss}$ = %.3e' % sigma)

        # report the autoscaled upper y-limit before fixing a common scale
        y_L2 = f.ax[r][c].get_ylim()

        print("%s %s %s %s %s %s" % (expname, lamtha, sigma, r, c, y_L2[1]))

        f.ax[r][c].set_ylim((0, 250.))

    fig_name = os.path.join(simpaths.dsim, 'aggregate.eps')

    plt.savefig(fig_name)
    f.close()

    # run the compression
    fio.epscompress(simpaths.dsim, '.eps', 1)