def make_spk(rate, duration_ms, n=1, dt=1 * ms, uniform_freq=True):
    # Set uniform_freq to False to obtain trains whose mean rate matches
    # `rate` instead of trains whose rate is exactly `rate`.
    # Strip brian2 units so the rate can be compared against spike counts.
    if isinstance(rate, b2.units.fundamentalunits.Quantity):
        comp_rate = rate / Hz
    else:
        comp_rate = rate
    trains = brian_poisson(rate, duration_ms, n=n, dt=dt)
    if n > 1:
        if uniform_freq:
            trains = np.array(trains)
            actual_rates = [tr.size for tr in trains]
            rate_match = np.equal(actual_rates, comp_rate)
            trains = trains[rate_match]
            real_N = trains.size
            while real_N < n:
                new_trains = np.array(
                    brian_poisson(rate, duration_ms, n=n, dt=dt))
                new_rates = [tr.size for tr in new_trains]
                rate_match = np.equal(new_rates, comp_rate)
                new_trains = new_trains[rate_match]
                trains = np.append(trains, new_trains)
                real_N = trains.size
        tr_spk = [spk.SpikeTrain(train, duration_ms) for train in trains[0:n]]
    else:
        rate_match = trains.size == comp_rate
        while not rate_match:
            trains = brian_poisson(rate, duration_ms, n=n, dt=dt)
            rate_match = trains.size == comp_rate
        tr_spk = spk.SpikeTrain(trains, duration_ms)
    return tr_spk
Example #2
def calc_stimuli_distance(stimulus_a: np.ndarray, stimulus_b: np.ndarray,
                          stimulus_duration: float) -> float:
    """
    Compute the average distance between matching neurons in two stimuli
    using the SPIKE-distance metric (see http://www.scholarpedia.org/article/SPIKE-distance).

    :param stimulus_a: numpy array where each element holds a single neuron's spike times, in milliseconds
    :param stimulus_b: numpy array where each element holds a single neuron's spike times, in milliseconds
    :param stimulus_duration: duration of the stimulus, in seconds
    """
    # Verify stimuli are comparable
    if stimulus_a.size != stimulus_b.size:
        raise ValueError('Stimuli must consist of the same number of neurons')

    distances = []  # Placeholder for distances between each pair of neurons
    for neuron_a, neuron_b in zip(stimulus_a, stimulus_b):
        # Converting to pyspike SpikeTrain object for calculation
        neuron_a = spk.SpikeTrain(neuron_a,
                                  edges=[0, stimulus_duration * 1000])
        neuron_b = spk.SpikeTrain(neuron_b,
                                  edges=[0, stimulus_duration * 1000])
        # Compute distance
        distance = spk.spike_distance(neuron_a, neuron_b)
        distances.append(distance)
    mean_distance = np.mean(distances)
    return mean_distance
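A quick usage sketch (the spike times below are made up for illustration; numpy and pyspike are assumed imported as np and spk, matching the snippet above):

import numpy as np

stim_a = np.array([np.array([10.0, 250.0, 800.0]),
                   np.array([5.0, 500.0])], dtype=object)
stim_b = np.array([np.array([12.0, 260.0, 790.0]),
                   np.array([480.0])], dtype=object)

# Stimuli last 1 s; spike times are in ms, matching the docstring.
print(calc_stimuli_distance(stim_a, stim_b, stimulus_duration=1.0))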
Example #3
def sync_test():
    import random

    import pyspike

    # NOTE: timeRange, empty_array and getmat are module-level globals
    # assumed to be defined elsewhere in the original script.
    spkt_train = []

    empty_train = [[] for i in range(801)]

    for i in range(len(empty_train)):
        # Spike counts per group: 0, 1, 5 or 50 random spike times.
        if i < 200:
            empty_train[i] = random.sample(range(0, 1500), 0)
        elif 200 <= i < 400:
            empty_train[i] = random.sample(range(0, 1500), 1)
        elif 400 <= i < 600:
            empty_train[i] = random.sample(range(0, 1500), 5)
        else:  # i >= 600
            empty_train[i] = random.sample(range(0, 1500), 50)
        spkt_train.append(pyspike.SpikeTrain(sorted(empty_train[i]), timeRange))

    spike_sync = pyspike.spike_sync_matrix(spkt_train)

    # Fill the top-left 400x400 block of empty_array with ones
    # (the rows for the empty and single-spike trains).
    for i in range(len(spkt_train)):
        if i < 400:
            for v in range(400):
                empty_array[i][v] = 1.0

    getmat['2'] = spike_sync - empty_array
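A smaller, self-contained variant of the same experiment (the time range is assumed; imports are spelled out):

import random
import numpy as np
import pyspike

time_range = (0, 1500)  # assumed recording interval in ms
trains = [pyspike.SpikeTrain(sorted(random.sample(range(0, 1500), k)), time_range)
          for k in (1, 5, 50, 50)]
sync = pyspike.spike_sync_matrix(trains)
print(np.round(sync, 3))  # 4x4 symmetric matrix; the diagonal stays 0
                          # (Example #19 below sets it to 1 by hand)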
Example #4
def distance_spike(spike_train_a, spike_train_b, interval):
    """
    SPIKE-distance (Kreutz) using pyspike
    """
    spike_train_1 = pyspike.SpikeTrain(spike_train_a, interval)
    spike_train_2 = pyspike.SpikeTrain(spike_train_b, interval)
    return pyspike.spike_distance(spike_train_1, spike_train_2, interval)
Example #5
def distance_isi(spike_train_a, spike_train_b, interval):
    """
    ISI-distance (Kreutz) using pyspike
    """
    spike_train_1 = pyspike.SpikeTrain(spike_train_a, interval)
    spike_train_2 = pyspike.SpikeTrain(spike_train_b, interval)
    return pyspike.isi_distance(spike_train_1, spike_train_2, interval)
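Both wrappers in action on random trains (the interval is a hypothetical observation window; numpy and pyspike assumed imported as above):

import numpy as np
import pyspike

interval = (0.0, 1.0)  # assumed observation window in seconds
a = np.sort(np.random.uniform(0.0, 1.0, 20))
b = np.sort(np.random.uniform(0.0, 1.0, 20))
print("SPIKE-distance:", distance_spike(a, b, interval))
print("ISI-distance:  ", distance_isi(a, b, interval))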
Example #6
def Synchro(spikes1, spikes2):
    """Per-trial SPIKE-synchronization between two sets of binary rasters.

    spikes1 and spikes2 hold one raster per trial; bin indices with a
    positive value are treated as spike times.
    """
    synchs = []
    for i in range(len(spikes1)):
        sp1 = pyspike.SpikeTrain(np.where(spikes1[i] > 0)[0], len(spikes1[i]))
        sp2 = pyspike.SpikeTrain(np.where(spikes2[i] > 0)[0], len(spikes2[i]))
        synchs.append(pyspike.spike_sync(sp1, sp2))
    return np.array(synchs)
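Synchro expects paired binary rasters with time bins on the last axis; a minimal sketch with fabricated data (numpy and pyspike assumed imported as above):

import numpy as np
import pyspike

rng = np.random.default_rng(0)
spikes1 = (rng.random((3, 1000)) < 0.02).astype(int)  # 3 trials x 1000 bins
spikes2 = (rng.random((3, 1000)) < 0.02).astype(int)
print(Synchro(spikes1, spikes2))  # one SPIKE-synchronization value per trial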
Example #7
def make_spk(rate, duration_ms, n=1, dt=1 * ms, exact_freq=True):
    """
    Creates a stimulus with n neurons of a desired frequency.
    This function returns each neuron as a PySpike object which can be used
    for evaluating distance between neurons, and uses brian internally for sample generation.

    :param rate: Desired firing rate for neurons in the sample
    :param duration_ms: Length of a "trial" in milliseconds
    :param n: Number of neurons in the sample
    :param dt: The shortest time interval between spikes in the sample
    :param exact_freq: Whether the frequency should be exactly `rate` for each neuron, or `rate` on average
    :return: A list of PySpike SpikeTrain objects, or a single SpikeTrain if n == 1
    """

    # Strip brian2 units so the rate can be compared against spike counts
    if isinstance(rate, b2.units.fundamentalunits.Quantity):
        comp_rate = rate / Hz
    else:
        comp_rate = rate

    # Generating the spike-trains using brian
    trains = brian_poisson(rate, duration_ms, n=n, dt=dt)

    # Handling single or multiple train generation as well as exact or average frequency
    if n > 1:
        if exact_freq:
            trains = np.array(trains)
            actual_rates = [tr.size for tr in trains]
            # TODO: scale the expected spike count by the stimulus duration
            # (the comparison below assumes a 1 s trial, so count == rate)
            rate_match = np.equal(actual_rates, comp_rate)
            trains = trains[rate_match]
            real_N = trains.size
            while real_N < n:
                new_trains = np.array(
                    brian_poisson(rate, duration_ms, n=n, dt=dt))
                new_rates = [tr.size for tr in new_trains]
                rate_match = np.equal(new_rates, comp_rate)
                new_trains = new_trains[rate_match]
                trains = np.append(trains, new_trains)
                real_N = trains.size
        tr_spk = [spk.SpikeTrain(train, duration_ms) for train in trains[0:n]]
    else:
        rate_match = trains.size == comp_rate
        while not rate_match:
            trains = brian_poisson(rate, duration_ms, n=n, dt=dt)
            rate_match = trains.size == comp_rate
        tr_spk = spk.SpikeTrain(trains, duration_ms)
    return tr_spk
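make_spk calls an external brian_poisson helper that is not shown here. Assuming it returns spike-time arrays on a dt grid (one array per neuron), a numpy-only stand-in for experimenting with the function could look like this:

import numpy as np

def brian_poisson(rate, duration_ms, n=1, dt=1.0):
    # Hypothetical stand-in: Bernoulli approximation of a Poisson process
    # on dt-millisecond bins. Takes plain floats (rate in Hz, duration in
    # ms) instead of the brian2 quantities the real helper presumably uses.
    n_bins = int(duration_ms / dt)
    p = rate * dt / 1000.0  # spike probability per bin
    trains = [np.flatnonzero(np.random.rand(n_bins) < p) * dt
              for _ in range(n)]
    return trains if n > 1 else trains[0]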
Example #8
def check_single_spike_train_set(index):
    """ Debuging function """
    np.set_printoptions(precision=15)
    spike_file = "regression_random_spikes.mat"
    spikes_name = "spikes"
    result_name = "Distances"
    result_file = "regression_random_results_cSPIKY.mat"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]

    results_cSPIKY = loadmat(result_file)[result_name]

    spike_train_data = spike_train_sets[index]

    spike_trains = []
    N = 0
    for spikes in spike_train_data[0]:
        N += len(spikes.flatten())
        print("Spikes:", len(spikes.flatten()))
        spikes_array = spikes.flatten()
        if len(spikes_array) > 0 and spikes_array[-1] > 100.0:
            spikes_array[-1] = 100.0
        spike_trains.append(spk.SpikeTrain(spikes_array, 100.0))
        print(spike_trains[-1].spikes)

    print(N)

    print(spk.spike_sync_multi(spike_trains))

    print(spk.spike_sync_profile_multi(spike_trains).integral())
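A related identity: in PySpike, the scalar SPIKE-synchronization is the average of the corresponding profile. A tiny two-train check of this at the bivariate level (times are made up):

import pyspike as spk

st1 = spk.SpikeTrain([10.0, 30.0, 45.0, 70.0], 100.0)
st2 = spk.SpikeTrain([12.0, 31.0, 60.0, 68.0], 100.0)
print(spk.spike_sync(st1, st2))                 # scalar value
print(spk.spike_sync_profile(st1, st2).avrg())  # same number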
Example #9
def cluster_spike_trains(spike_trains, tstart, tend, interval=None, eps=0.01, measure='SPIKE_distance'):
    """Cluster a list of spike trains by the SPIKE-distance measure.

    tstart and tend are the start and end times of the recording (used for
    creating the SpikeTrain objects from PySpike).

    interval is (t0, t1), the time interval over which to compute the
    SPIKE distance. If None, (tstart, tend) is used.

    eps: epsilon parameter for the DBSCAN algorithm. This is the maximum
    distance between two samples for them to be considered in the same
    neighborhood.

    measure is accepted for API compatibility but is currently unused.

    All spike trains should be nonempty.

    Returns (cluster_info, spike_distance_matrix)
    where
    cluster_info: the fitted DBSCAN estimator; its labels_ attribute gives
        the cluster assignment of each spike train.
    spike_distance_matrix: an NxN distance matrix for N spike trains.

    """
    if interval is None:
        interval = (tstart, tend)
    st_list = [spk.SpikeTrain(st, (tstart, tend)) for st in spike_trains]
    print('Interval:', interval)
    print('tstart: {} tend: {}'.format(tstart, tend))
    print('Number of spike trains:', len(st_list))
    
    dist = spk.spike_distance_matrix(st_list, interval=interval)
    clus = skc.DBSCAN(eps=eps, metric='precomputed').fit(dist)
    return clus, dist
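A synthetic usage sketch (assumes numpy as np, pyspike as spk and sklearn.cluster as skc, matching the aliases in the function; the data and eps value are made up):

import numpy as np

tstart, tend = 0.0, 1000.0
rng = np.random.default_rng(1)
template_a = np.sort(rng.uniform(tstart, tend, 30))
template_b = np.sort(rng.uniform(tstart, tend, 30))
# Ten trains: five jittered copies of each template.
trains = [np.clip(np.sort(t + rng.normal(0.0, 2.0, t.size)), tstart, tend)
          for t in [template_a] * 5 + [template_b] * 5]

clus, dist = cluster_spike_trains(trains, tstart, tend, eps=0.2)
print(clus.labels_)  # DBSCAN cluster label per spike train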
Example #10
def test_merge_empty_spike_trains():
    # first load the data
    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
    # take two non-empty trains, and one empty one
    empty = spk.SpikeTrain([], [spike_trains[0].t_start, spike_trains[0].t_end])
    merged_spikes = spk.merge_spike_trains([spike_trains[0], empty, spike_trains[1]])
    # test if result is sorted
    assert((merged_spikes.spikes == np.sort(merged_spikes.spikes)).all())
Example #11
    def shift_main(src):
        # NOTE: shifts, frac_shifts, duration and n are closure variables
        # from an enclosing function that is not shown in this snippet.
        spiketimes = src.spikes[:, np.newaxis]

        num_shifts = len(shifts)
        num_fracshift = len(frac_shifts)
        num_spikes = spiketimes.size

        # Finding edges
        left_edge = np.where(spiketimes < shifts)
        right_edge = np.where((spiketimes + shifts) > duration)

        # Create the shift matrices
        shifts_bottom = np.tile(shifts, reps=(num_spikes, n, num_fracshift, 1))
        shifts_top = shifts_bottom.copy()

        shifts_bottom[left_edge[0], :, :, left_edge[1]] = \
                spiketimes[left_edge[0]][:, np.newaxis]
        shifts_top[right_edge[0], :, :, right_edge[1]] = \
                ((duration) - spiketimes[right_edge[0]][:, np.newaxis])

        shifts_range = shifts_top + shifts_bottom

        # boolean matrix for fractional shifting - MAY REQUIRE CHANGING AFTER FRAC_FIRE IS INTEGRATED
        bool_frac_mat = np.random.rand(num_spikes, n, num_fracshift,
                                       num_shifts)
        bool_frac_mat = bool_frac_mat <= frac_shifts[:, np.newaxis]

        # Draw!
        shiftvals = np.random.rand(num_spikes, n, num_fracshift, num_shifts)

        # Fix edges!
        shiftvals = (
            (shiftvals * shifts_range) - shifts_bottom) * bool_frac_mat
        shiftvals.round(out=shiftvals)

        # Shift!
        shifted = shiftvals + spiketimes[:, np.newaxis, np.newaxis]
        shifted.sort()

        # Find uniques!
        uniques = [
            np.unique(shifted[:, i, j, k]) for i in range(n)
            for j in range(num_fracshift) for k in range(num_shifts)
        ]

        # pyspike!
        shifted_spk = [
            spk.SpikeTrain(shifted_times, duration)
            for shifted_times in uniques
        ]
        return shifted_spk
Example #12
    def _similarity(self, v1, v2, method='isi', **kwargs):
        assert v1.shape == v2.shape
        assert v1.ndim == 1
        if method == 'isi':
            thresh = kwargs.pop('thresh', 10)

            spikes1 = np.diff((v1 > thresh).astype('int'))
            spikes2 = np.diff((v2 > thresh).astype('int'))

            spike_times_1 = np.where(spikes1 > 0)[0]
            spike_times_2 = np.where(spikes2 > 0)[0]

            # spike_train_1 = pyspike.SpikeTrain(spike_times_1, 5500, 14500)
            # spike_train_2 = pyspike.SpikeTrain(spike_times_2, 5500, 14500)

            spike_train_1 = pyspike.SpikeTrain(spike_times_1, 9000 *
                                               .02)  # TODO: need # timebins
            spike_train_2 = pyspike.SpikeTrain(spike_times_2, 9000 * .02)

            return np.abs(pyspike.isi_distance(spike_train_1, spike_train_2))
        elif method == 'efel':
            trace1 = self._make_efel_trace(v1)
            trace2 = self._make_efel_trace(v2)
            efel1, efel2 = efel.getFeatureValues([trace1, trace2],
                                                 ['mean_frequency'])

            return np.abs(efel1['mean_frequency'] - efel2['mean_frequency'])[0]

            # NOTE: everything below is unreachable because of the return
            # just above.
            try:
                y = np.abs(
                    [efel1[feat] - efel2[feat] for feat in EFEL_FEATURES])
            except Exception:
                import matplotlib.pyplot as plt
                plt.plot(v1)
                plt.plot(v2)
                plt.show()
            return y
        else:
            raise ValueError("unknown similarity metric")
Example #13
def make_spike_trains(spike_times, time_range):
    """Window each neuron's spike times to time_range and wrap the result in
    PySpike SpikeTrain objects. Note that spike_times is filtered in place
    as a side effect."""

    spike_trains = [[] for i in range(len(spike_times))]

    for index, times in enumerate(spike_times):

        spike_times[index] = [
            time for time in times
            if (time > time_range[0] and time < time_range[1])
        ]
        spike_trains[index] = spk.SpikeTrain(np.array(spike_times[index]),
                                             edges=time_range)

    return spike_trains
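Usage sketch for the windowing helper above (the spike times are fabricated; numpy and pyspike are assumed imported as np and spk):

spike_times = [[5.0, 120.0, 480.0, 950.0],
               [60.0, 300.0, 1100.0]]
trains = make_spike_trains(spike_times, time_range=(100.0, 1000.0))
print([t.spikes for t in trains])  # only spikes strictly inside (100, 1000)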
Example #14
def calculate_spike_synchrony_pyspike(ts, gids, interval=None):

    if interval is None:
        tmin = np.min(ts)
        tmax = np.max(ts)
        interval = [tmin, tmax]

    s = spike_trains_list_of_list(ts, gids)
    spike_trains = []
    for i in range(len(s)):
        spike_trains.append(spk.SpikeTrain(s[i], interval, is_sorted=False))

    spk_sync = spk.spike_sync(spike_trains, interval=interval, max_tau=0.1)

    return spk_sync
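calculate_spike_synchrony_pyspike leans on a spike_trains_list_of_list helper that is not shown here; assuming it splits flat (time, gid) event arrays into one spike-time array per neuron, a plausible stand-in is:

import numpy as np

def spike_trains_list_of_list(ts, gids):
    # Hypothetical reconstruction: group event times by neuron id,
    # ordered by unique gid.
    ts, gids = np.asarray(ts), np.asarray(gids)
    return [ts[gids == g] for g in np.unique(gids)]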
Example #15
def test_regression_random():

    spike_file = "test/numeric/regression_random_spikes.mat"
    spikes_name = "spikes"
    result_name = "Distances"
    result_file = "test/numeric/regression_random_results_cSPIKY.mat"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]
    results_cSPIKY = loadmat(result_file)[result_name]

    for i, spike_train_data in enumerate(spike_train_sets):
        spike_trains = []
        for spikes in spike_train_data[0]:
            spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

        isi = spk.isi_distance_multi(spike_trains)
        isi_prof = spk.isi_profile_multi(spike_trains).avrg()

        spike = spk.spike_distance_multi(spike_trains)
        spike_prof = spk.spike_profile_multi(spike_trains).avrg()

        spike_sync = spk.spike_sync_multi(spike_trains)
        spike_sync_prof = spk.spike_sync_profile_multi(spike_trains).avrg()

        assert_almost_equal(isi,
                            results_cSPIKY[i][0],
                            decimal=14,
                            err_msg="Index: %d, ISI" % i)
        assert_almost_equal(isi_prof,
                            results_cSPIKY[i][0],
                            decimal=14,
                            err_msg="Index: %d, ISI" % i)

        assert_almost_equal(spike,
                            results_cSPIKY[i][1],
                            decimal=14,
                            err_msg="Index: %d, SPIKE" % i)
        assert_almost_equal(spike_prof,
                            results_cSPIKY[i][1],
                            decimal=14,
                            err_msg="Index: %d, SPIKE" % i)

        assert_almost_equal(spike_sync,
                            spike_sync_prof,
                            decimal=14,
                            err_msg="Index: %d, SPIKE-Sync" % i)
Example #16
def check_regression_dataset(spike_file="benchmark.mat",
                             spikes_name="spikes",
                             result_file="results_cSPIKY.mat",
                             result_name="Distances"):
    """ Debuging function """
    np.set_printoptions(precision=15)

    spike_train_sets = loadmat(spike_file)[spikes_name][0]

    results_cSPIKY = loadmat(result_file)[result_name]

    err_max = 0.0
    err_max_ind = -1
    err_count = 0

    for i, spike_train_data in enumerate(spike_train_sets):
        spike_trains = []
        for spikes in spike_train_data[0]:
            spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

        isi = spk.isi_distance_multi(spike_trains)
        spike = spk.spike_distance_multi(spike_trains)
        # spike_sync = spk.spike_sync_multi(spike_trains)

        if abs(isi - results_cSPIKY[i][0]) > 1E-14:
            print("Error in ISI:", i, isi, results_cSPIKY[i][0])
            print("Spike trains:")
            for st in spike_trains:
                print(st.spikes)

        err = abs(spike - results_cSPIKY[i][1])
        if err > 1E-14:
            err_count += 1
        if err > err_max:
            err_max = err
            err_max_ind = i

    print("Total Errors:", err_count)

    if err_max_ind > -1:
        print("Max SPIKE distance error:", err_max, "at index:", err_max_ind)
        spike_train_data = spike_train_sets[err_max_ind]
        for spikes in spike_train_data[0]:
            print(spikes.flatten())
Example #17
def get_data():

    # NOTE: label_network, stats, gids, spkts, spkinds and timeRange are
    # module-level globals assumed to be defined elsewhere.
    print('Starting analysis of spike times per cell: ' + str(label_network))

    sync_data = [[[[] for i in range(2)] for x in range(len(gids))]
                 for p in range(len(stats))]

    spktsRange = [
        spkt for spkt in spkts if timeRange[0] <= spkt <= timeRange[1]
    ]

    for i, stat in enumerate(stats):
        for ii, subset in enumerate(gids):
            spkmat = [
                pyspike.SpikeTrain([
                    spkt for spkind, spkt in zip(spkinds, spkts)
                    if (spkind == gid and spkt in spktsRange)
                ], timeRange) for gid in set(subset)
            ]

            if stat == 'spike_sync_profile':
                print(str(stat) + ", subset: " + str(ii) +
                      ", number of trains: " + str(len(spkmat)))
                syncMat1 = pyspike.spike_sync_profile(spkmat)
                x, y = syncMat1.get_plottable_data()
                sync_data[i][ii][0] = x
                sync_data[i][ii][1] = y

            elif stat == 'spike_profile':
                print(str(stat) + ", subset: " + str(ii) +
                      ", number of trains: " + str(len(spkmat)))
                syncMat2 = pyspike.spike_profile(spkmat)
                x, y = syncMat2.get_plottable_data()
                sync_data[i][ii][0] = x
                sync_data[i][ii][1] = y

            elif stat == 'isi_profile':
                print(str(stat) + ", subset: " + str(ii) +
                      ", number of trains: " + str(len(spkmat)))
                syncMat3 = pyspike.isi_profile(spkmat)
                x, y = syncMat3.get_plottable_data()
                sync_data[i][ii][0] = x
                sync_data[i][ii][1] = y
Example #18
def check_single_spike_train_set(index):
    """ Debuging function """
    np.set_printoptions(precision=15)
    spike_file = "regression_random_spikes.mat"
    spikes_name = "spikes"
    result_name = "Distances"
    result_file = "regression_random_results_cSPIKY.mat"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]

    results_cSPIKY = loadmat(result_file)[result_name]

    spike_train_data = spike_train_sets[index]

    spike_trains = []
    for spikes in spike_train_data[0]:
        print("Spikes:", spikes.flatten())
        spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

    print(spk.spike_distance_multi(spike_trains))

    print(results_cSPIKY[index][1])

    print(spike_trains[1].spikes)
Example #19
# Fragment of a getopt-style option-parsing loop; the enclosing loop, the
# earlier options, and the except clause for the try below were not
# captured in this snippet.
        g = arg
    if opt in ('-v',):
        v = arg

for s in range(1):
    try:
        d = np.load('save/brunel_inp={}_g={}_seed_{}.npy'.format(inp, g,
                                                                 s)).item()

        # synchronicity
        sp = d['sp']

        spike_list = []
        for train in sp:
            spike_list.append(
                spk.SpikeTrain(list(sp[train]), (0, 50), is_sorted=False))

        sync_dist = spk.spike_sync_matrix(spike_list,
                                          indices=None,
                                          interval=(1, 20))
        spike_dist = spk.spike_distance_matrix(spike_list,
                                               indices=None,
                                               interval=(1, 20))
        for i in range(sync_dist.shape[0]):
            sync_dist[i, i] = 1
        utils.Weight2txt(
            1 - sync_dist,
            'txt/brunel_inp={}_g={}_seed_{}_sync.txt'.format(inp, g, s))
        utils.Weight2txt(
            spike_dist,
            'txt/brunel_inp={}_g={}_seed_{}_dist.txt'.format(inp, g, s))
Example #20
def rate_select(path):
    with open(path, 'r') as fp:
        px = json.load(fp)

    #syll = pd.read_csv('../restoration_syllables.csv')

    song = []
    condition = []
    rate = []
    gapon = {}
    #gapoff = {}
    spikes = []

    for t in range(len(px['pprox'])):
        #For new recordings:
        if px['pprox'][t]['condition'] == 'continuous':
            song.append(px['pprox'][t]['stimulus'] + '-1')
            song.append(px['pprox'][t]['stimulus'] + '-2')
            condition.append(px['pprox'][t]['condition'] + '1')
            condition.append(px['pprox'][t]['condition'] + '2')
            spikes.append(px['pprox'][t]['event'])
            spikes.append(px['pprox'][t]['event'])
        else:
            songid = px['pprox'][t]['stimulus'] + '-' + px['pprox'][t][
                'condition'][-1]
            song.append(px['pprox'][t]['stimulus'] + '-' +
                        px['pprox'][t]['condition'][-1])
            condition.append(px['pprox'][t]['condition'])
            spikes.append(px['pprox'][t]['event'])

        if 'gap_on' in px['pprox'][t].keys():
            # NOTE: songid is only assigned in the else-branch above, so for
            # 'continuous' trials this reuses songid from a prior iteration.
            gapon[songid] = px['pprox'][t]['gap_on']
            #gapoff[songid] = px['pprox'][t]['gap_off']

    x = []
    #y = []
    for s in song:
        x.append(gapon[s])
        #y.append(gapoff[s])

    gapon = x
    #gapoff = y
    lag = []
    train = []
    gaps = difference_psth(song, condition, spikes, gapon)
    songset = np.unique(song)
    for t in range(len(spikes)):
        windowstart = gaps[np.where(songset == song[t])[0][0]][0]
        windowstop = gaps[np.where(songset == song[t])[0][0]][1]
        interval = (windowstop / 1000) - (windowstart / 1000)
        if windowstart != 0:
            numspikes = len([
                spike for spike in spikes[t]
                if spike >= windowstart and spike <= windowstop
            ])
            rate.append(numspikes / interval)
            lag.append(windowstart - gapon[t])
            train.append(
                spk.SpikeTrain([
                    spike for spike in spikes[t]
                    if spike >= windowstart and spike <= windowstop
                ], [windowstart, windowstop]))
        else:
            rate.append(np.nan)
            lag.append(np.nan)
            train.append(np.nan)

    ziplist = list(zip(song, condition, rate, lag))

    df = pd.DataFrame(ziplist, columns=["Song", "Condition", "Rate", "Lag"])
    avgrate = df.groupby(['Song', 'Condition']).mean()
    avgrate = [(avgrate.iloc[x].name + (avgrate.iloc[x].values[0], ))
               for x in range(len(avgrate))]
    sdrate = df.groupby(['Song', 'Condition']).std().values
    lag = df.groupby(['Song', 'Condition']).mean().values
    ri = df[(df['Condition'] == 'continuous1') |
            (df['Condition'] == 'continuous2')].groupby(['Song',
                                                         'Condition']).mean()
    count = 16
    selectivity = (1 - ((np.sum(ri / count))**2) / (np.sum(
        (ri**2) / count))) / (1 - (1 / count))
    selectivity = [selectivity] * len(avgrate)

    bird = [px['bird']] * len(avgrate)
    recording = [px['recording']] * len(avgrate)
    channel = [px['channel']] * len(avgrate)
    cluster = [px['cluster']] * len(avgrate)
    k = 5
    window = [gap for gap in gaps for i in range(k)]

    confusion = spikey(train, song, condition)
    confusion = [dist for dist in confusion for i in range(k)]

    records = list(
        zip(bird, recording, channel, cluster, avgrate, sdrate, selectivity,
            lag, window, confusion))

    return [(a, b, c, d, e, f, g, h[0], i.values[0], j[1], k, l, m, n, o, p, q,
             r, s, t, u, v, w, x, y) for a, b, c, d, (e, f, g), h, i, j, (
                 k, l), [m, n, o, p, q, r, s, t, u, v, w, x, y] in records]
Example #21
import numpy as np
import pyspike as spk
import matplotlib.pyplot as plt

a = np.random.rand(10) * 10
b = spk.SpikeTrain(a, [0, 10], is_sorted=False)
print(b.spikes)

spike_trains = spk.load_spike_trains_from_txt("PySpike_testdata.txt",
                                              edges=(0, 4000))

# compute the bivariate ISI profile of the first two spike trains,
# then the multivariate ISI profile of all trains
f = spk.isi_profile(spike_trains[0], spike_trains[1])
f = spk.isi_profile(spike_trains)

# t = [900, 1100, 2000, 3100]
# print("ISI value at t =", t, ":", f(t))
# print("Average ISI distance:", f.avrg())

# compute the bivariate SPIKE profile of the first two spike trains,
# then the multivariate SPIKE profile of all trains
f = spk.spike_profile(spike_trains[0], spike_trains[1])
f = spk.spike_profile(spike_trains)

# t = [900, 1100, 2000, 3100]
# print("Multivariate SPIKE value at t =", t, ":", f(t))
# print("Average multivariate SPIKE distance:", f.avrg())

# plot the spike times as a raster, one row per train
for (i, spike_train) in enumerate(spike_trains):
    plt.scatter(spike_train.spikes, i * np.ones_like(spike_train.spikes),
                marker='|')
    # print(np.asarray(spike_train.spikes))
def Synchronization_analysis(sim_duration,specify_targets,no_of_groups,exp_specify,spike_plot_parameters,general_plot_parameters):
    
    cell_no_array=[]
    for exp_id in range(0,len(exp_specify[0]) ):
        if exp_specify[1][1]==True:
           for cell_group in range(0,no_of_groups):
               cell_group_positions=np.loadtxt('simulations/%s/Golgi_pop%d.txt'%(exp_specify[0][exp_id],cell_group))
               dim_array=np.shape(cell_group_positions)
               cell_no_array.append(dim_array[0])
           
        else:
           for cell_group in range(0,no_of_groups):
               cell_group_positions=np.loadtxt('simulations/%s/sim0/Golgi_pop%d.txt'%(exp_specify[0][exp_id],cell_group))
               dim_array=np.shape(cell_group_positions)
               cell_no_array.append(dim_array[0])



    n_trials=exp_specify[2]

   

    lines = []
    lines_sep=[]
    experiment_seed=random.sample(range(0,15000),1)[0]
    for exp_id in range(0,len(exp_specify[0])):
        #get target ids
        experiment_parameters=[]
        experiment_parameters.append(exp_specify[0][exp_id])
        experiment_parameters.append(exp_specify[1])
        experiment_parameters.append(exp_specify[2])
       
        target_cell_array=get_cell_ids_for_sync_analysis(specify_targets,no_of_groups,experiment_parameters,experiment_seed)
        test_array=target_cell_array
        
        if exp_id==0:
           
           if "save sync plot to a separate file" in spike_plot_parameters:
              fig_sync, ax_sync=plt.subplots(figsize=(2,1.5),ncols=1,nrows=1)
              
           if spike_plot_parameters[0]=="2D raster plots":
              trial_indicator_target=False
              if n_trials >1:
                 target_cell_array_target_trial=target_cell_array[spike_plot_parameters[1]]
                 no_of_rasters=0
                 if target_cell_array_target_trial != []:
                    test_non_empty_target_array=True
                    trial_indicator_target=True
                    no_of_rasters=len(target_cell_array_target_trial)
                    rows=1+no_of_rasters
                    columns=1
                    fig_stack, ax_stack = plt.subplots(figsize=(4,rows+1),ncols=columns,nrows=rows, sharex=True)
                    ax_stack=ax_stack.ravel()
                 
                    
              else:
                 no_of_rasters=0
                 if target_cell_array != []:
                    trial_indicator=True
                    no_of_rasters=len(target_cell_array)
                    rows=1+no_of_rasters
                    columns=1
                    fig_stack, ax_stack = plt.subplots(figsize=(4,rows),ncols=columns,nrows=rows, sharex=True)
                    ax_stack=ax_stack.ravel()
              
              if "save all trials to separate files" in spike_plot_parameters:
                 raster_fig_array=[]
                 raster_ax_array=[]
                 pop_no_array=[]
                 trial_indicator=False
                 if n_trials >1:
                    non_empty_trial_indices=[]
                    for trial in range(0,n_trials):
                        if target_cell_array[trial] !=[]:
                           non_empty_trial_indices.append(trial)
                    test_indices=non_empty_trial_indices
                    for trial in range(0,len(non_empty_trial_indices)):
                        target_trial=target_cell_array[non_empty_trial_indices[trial]]
                        columns=1
                        rows=len(target_trial)
                        pop_no_array.append(rows)
                        plot_rows=rows+1
                        fig, ax = plt.subplots(figsize=(4,plot_rows),ncols=columns,nrows=rows, sharex=True)
                        if len(target_trial) >1 :
                           ax=ax.ravel()
                        raster_fig_array.append(fig)
                        raster_ax_array.append(ax)
                 else:
                    no_of_rasters=0
                    if target_cell_array != []:
                       trial_indicator=True
                       no_of_rasters=len(target_cell_array)
                       rows=no_of_rasters
                       pop_no_array.append(rows)
                       columns=1
                       plot_rows=rows+1
                       fig_stack_one_trial, ax_stack_one_trial= plt.subplots(figsize=(4,plot_rows),ncols=columns,nrows=rows, sharex=True)
                       if no_of_rasters  >  1:
                          ax_stack_one_trial=ax_stack.ravel()
              if "save all trials to one separate file" in spike_plot_parameters:
                 pop_no_array=[]
                 if n_trials >1:
                    non_empty_trial_indices=[]
                    for trial in range(0,n_trials):
                        if target_cell_array[trial] !=[]:
                           non_empty_trial_indices.append(trial)
                    total_no_of_rows=0
                    for trial in range(0,len(non_empty_trial_indices)):
                        target_trial=target_cell_array[non_empty_trial_indices[trial]]
                        
                        rows=len(target_trial)
                        pop_no_array.append(rows)
                        total_no_of_rows=total_no_of_rows + rows
                    
                    fig_all_trials, ax_all_trials = plt.subplots(figsize=(4,total_no_of_rows),ncols=1,nrows=total_no_of_rows, sharex=True)
                    if total_no_of_rows >1 :
                        ax_all_trials=ax_all_trials.ravel()
                    
                        
                 

                 

                    
        non_empty_non_unitary_trial_counter=0
        distances = []
        
        color = sns.color_palette()[exp_id+1]
        
        row_counter=0
        target_pop_index_array=[]
        if spike_plot_parameters[0]=="2D raster plots":
           if "save all trials to one separate file" in spike_plot_parameters:
              row_array=range(0,total_no_of_rows)
        if spike_plot_parameters[0]=="2D raster plots":
           if "save all trials to separate files" in spike_plot_parameters:
              raster_ax_row_counter=0
            
        for trial in range(0,n_trials):
            sim_dir = 'simulations/' + exp_specify[0][exp_id]+'/sim%d'%trial+'/txt'
            ######   
            if n_trials > 1:
               target_cell_array_per_trial=target_cell_array[trial]
            else:
               target_cell_array_per_trial=target_cell_array
            
            if target_cell_array_per_trial !=[]:
               spike_trains = []
               target_pop_index_array_per_trial=[]
               print(" Trial %d is not empty"%(trial))
               for target_pop in range(0,len(target_cell_array_per_trial)):
                   for pop in range(0,no_of_groups):
                       if ('pop%d'%(pop)) in target_cell_array_per_trial[target_pop]:
                          target_pop_index_array_per_trial.append(pop)
                          target_cells = [x for x in target_cell_array_per_trial[target_pop] if isinstance(x,int)]
                          print(target_cells)
                          for cell in range(0,len(target_cells)):
                              #create target txt file containing spike times
                              if not os.path.isfile('%s/Golgi_pop%d_cell%d.txt'%(sim_dir,pop,target_cells[cell])):
                                 get_spike_times('Golgi_pop%d_cell%d'%(pop,target_cells[cell]),exp_specify[0][exp_id],trial)
                              spikes = np.loadtxt('%s/Golgi_pop%d_cell%d.txt'%(sim_dir,pop,target_cells[cell]))
                              spike_train=pyspike.SpikeTrain(spikes,[0,sim_duration])
                              spike_trains.append(spike_train)
                              print(spike_trains)
                              if spike_plot_parameters[0]=="2D raster plots":
                                 if spike_plot_parameters[1]==trial:
                                    ax_stack[target_pop].scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                                 if "save all trials to separate files" in spike_plot_parameters:
                                    if n_trials >1:
                                       if len(target_cell_array_per_trial) > 1:
                                          raster_ax_array[raster_ax_row_counter][target_pop].scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                                          
                                       else:
                                          raster_ax_array[raster_ax_row_counter].scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                                          
                                    else:
                                       if len(target_cell_array_per_trial) >1:
                                          ax_stack_one_trial[target_pop].scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                                       else:
                                          ax_stack_one_trial.scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                                 if "save all trials to one separate file" in spike_plot_parameters:
                                    if n_trials >1:
                                       if total_no_of_rows >1 :
                                          ax_all_trials[row_array[row_counter]].scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                                          
                                       else:
                                          ax_all_trials.scatter(spikes,np.zeros_like(spikes)+target_cells[cell]+exp_id*(cell_no_array[pop]+1) ,marker='|',s=2,c=color)
                          if spike_plot_parameters[0]=="2D raster plots":
                             if "save all trials to one separate file" in spike_plot_parameters:
                                row_counter=row_counter+1
               if spike_plot_parameters[0]=="2D raster plots":
                  if "save all trials to separate files" in spike_plot_parameters:

                     raster_ax_row_counter+=1
                  

               

               if spike_plot_parameters[1]==trial:
                  target_trial_index=target_pop_index_array.index(target_pop_index_array_per_trial) 
               

               


               target_pop_index_array.append(target_pop_index_array_per_trial)
               ########   
               print("Length of spike_trains is %d"%len(spike_trains))
               if len(spike_trains) >1:
                  print("Trial %d contains more than one cell; Synchrony metric will be computed for this trial"%(trial))
                  non_empty_non_unitary_trial_counter+=1
                  print(non_empty_non_unitary_trial_counter)
                  distances.append(pyspike.spike_profile_multi(spike_trains))
               else:
                  print("Trial %d contains one cell; Synchrony metric will not be computed for this trial"%(trial))
               ######## 
               
        # average synchrony index across "non-empty" trials
        # NOTE: .add() mutates distances[0] in place, so average_distance
        # aliases the first profile; this also assumes distances is nonempty.
        average_distance = distances[0]
        for distance in distances[1:]:
            average_distance.add(distance)
        average_distance.mul_scalar(1./non_empty_non_unitary_trial_counter)

        # below blocks for saving synchrony and spike raster plots
        
        mark_steps=sim_duration/50
        marks=[]
        for mark in range(0,mark_steps+1):
            Mark=50*mark
            marks.append(Mark)

        xmin = 50
        right_shaded_region=50
        if sim_duration >=1000:
            right_shaded_region=100
        xmax = sim_duration-right_shaded_region
        x, y = average_distance.get_plottable_data()
        ximin = np.searchsorted(x, xmin)
        ximax = np.searchsorted(x, xmax)
        if spike_plot_parameters[0]=="2D raster plots":
           if target_cell_array_target_trial != []:
              target_cell_array_target_trial_indicator=True
              lines.append(ax_stack[no_of_rasters].plot(x[ximin:ximax+1], 1-y[ximin:ximax+1], lw=2, c=color)[0])
              ax_stack[no_of_rasters].plot(x[:ximin+1], 1-y[:ximin+1], lw=2, c=color, alpha=0.4)
              ax_stack[no_of_rasters].plot(x[ximax:], 1-y[ximax:], lw=2, c=color, alpha=0.4)
        if "save sync plot to a separate file" in spike_plot_parameters:
           lines_sep.append(ax_sync.plot(x[ximin:ximax+1], 1-y[ximin:ximax+1], lw=2, c=color)[0])
           ax_sync.plot(x[:ximin+1], 1-y[:ximin+1], lw=2, c=color, alpha=0.4)
           ax_sync.plot(x[ximax:], 1-y[ximax:], lw=2, c=color, alpha=0.4)
        
             
              
    
    if "save sync plot to a separate file" in spike_plot_parameters:
       print("save sync plot to a separate file is specified")
       for tick in ax_sync.xaxis.get_major_ticks():
           tick.label.set_fontsize(general_plot_parameters[4]) 
       for tick in ax_sync.yaxis.get_major_ticks():
           tick.label.set_fontsize(general_plot_parameters[5])
       ax_sync.locator_params(axis='y', tight=True, nbins=10)
       ax_sync.set_xlabel('Time (ms)',fontsize=4)
       ax_sync.set_ylabel('Synchrony index',size=4)
       ax_sync.set_ylim(0,1)
       ax_sync.set_xticks(marks)
       ax_sync.set_title('Synchronization between %s'%general_plot_parameters[1],size=4)
       fig_sync.tight_layout()

       fig_sync.subplots_adjust(top=0.90)
       fig_sync.subplots_adjust(bottom=0.55)
       l=fig_sync.legend(lines_sep,general_plot_parameters[3],title=general_plot_parameters[2],loc='center',ncol=len(general_plot_parameters[3]),bbox_to_anchor=(0.52, 0.25),prop={'size':4})
       plt.setp(l.get_title(),fontsize=4)
          
       fig_sync.savefig('simulations/sync_only_%s.%s'%(general_plot_parameters[0],spike_plot_parameters[-1]))
       print("saved %s in simulations"%'sync_only_%s.%s'%(general_plot_parameters[0],spike_plot_parameters[-1]))
       
    if spike_plot_parameters[0]=="2D raster plots":    
       print("Intend to plot a main figure with representative 2D raster plots")
       #create label array
       if trial_indicator_target:
          print("2D raster plot procedures started")
          for pop in range(0,no_of_rasters):
              label_array=[]
              ytick_array=[]
              for exp in range(0,len(general_plot_parameters[3])):
                  label_array.append("%d"%0)
                  label_array.append("%d"%(cell_no_array[target_pop_index_array[target_trial_index][pop]]-1))

                  if exp==0:
                     ytick_array.append(exp)
                     ytick_array.append(cell_no_array[target_pop_index_array[target_trial_index][pop]]-1)
                     left_value=cell_no_array[target_pop_index_array[target_trial_index][pop]]-1
                  else:
                     ytick_array.append(left_value+2)
                     ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[target_trial_index][pop]])
                     left_value=left_value+2+cell_no_array[target_pop_index_array[target_trial_index][pop]]

              print(label_array)
              print(ytick_array)
            
              ax_stack[pop].set_yticks(ytick_array)
              fig_stack.canvas.draw()
              ax_stack[pop].set_ylim(0,(cell_no_array[target_pop_index_array[target_trial_index][pop]]+1)*len(general_plot_parameters[3]) )
              ax_stack[pop].set_ylabel('Cell ids, population %d'%pop,size=4)
              ax_stack[pop].set_yticks([cell_no_array[target_pop_index_array[target_trial_index][pop]]+(cell_no_array[target_pop_index_array[target_trial_index][pop]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
              ax_stack[pop].yaxis.grid(False, which='major')
              ax_stack[pop].yaxis.grid(True, which='minor')
           
              labels = [x.get_text() for x in ax_stack[pop].get_yticklabels()]
           
              for label in range(0,len(labels)):
                   labels[label] =label_array[label]

              ax_stack[pop].set_yticklabels(labels)
              if pop==0:
                 ax_stack[pop].set_title('Raster plots for target Golgi cell populations (trial id=%d)'%spike_plot_parameters[1],size=6)
              # NOTE: this inner loop reuses (and clobbers) the `pop` variable
              # of the enclosing loop, and re-applies the figure-level
              # formatting below on every outer iteration.
              for pop in range(0,no_of_rasters+1):
                  for tick in ax_stack[pop].xaxis.get_major_ticks():
                      tick.label.set_fontsize(general_plot_parameters[4]) 
                  for tick in ax_stack[pop].yaxis.get_major_ticks():
                      tick.label.set_fontsize(general_plot_parameters[5])
              ax_stack[no_of_rasters].locator_params(axis='y', tight=True, nbins=10)
              ax_stack[no_of_rasters].set_xlabel('Time (ms)',fontsize=6)
              ax_stack[no_of_rasters].set_ylabel('Synchrony index',size=4)
              ax_stack[no_of_rasters].set_ylim(0,1)
              ax_stack[no_of_rasters].set_xticks(marks)
              ax_stack[no_of_rasters].set_title('Synchronization between %s'%general_plot_parameters[1],size=6)
              fig_stack.tight_layout()
       
              fig_stack.subplots_adjust(top=0.80)
              fig_stack.subplots_adjust(bottom=0.15)
              fig_stack.subplots_adjust(hspace=.4)
              l=fig_stack.legend(lines,general_plot_parameters[3],title=general_plot_parameters[2], loc='upper center',ncol=len(general_plot_parameters[3]),bbox_to_anchor=(0.55, 1.0))
              plt.setp(l.get_title(),fontsize=6)
              plt.setp(l.get_texts(), fontsize=6)
              fig_stack.savefig('simulations/%s.%s'%(general_plot_parameters[0],spike_plot_parameters[-1]))
       else:
          print("Intended to plot raster plots for trial %d, but specified region-specific selection of cells produces an empty target array.\nThus a main figure with a representative raster will not be produced.\nHowever, synchrony plot can be saved in a separate file.\nAlternatively, plot a non-empty trial"%(spike_plot_parameters[1]))
       if "save all trials to one separate file" in spike_plot_parameters:
          print(target_pop_index_array)
          if total_no_of_rows >1:
             row_counter=0
             for trial in range(0,len(non_empty_trial_indices)):
                 if len(target_pop_index_array[trial]) >1:
                    for pop in range(0,len(target_pop_index_array[trial])):
                        label_array=[]
                        ytick_array=[]
                        for exp in range(0,len(general_plot_parameters[3])):
                            label_array.append("%d"%0)
                            label_array.append("%d"%(cell_no_array[target_pop_index_array[trial][pop] ]-1))
                            if exp==0:
                               ytick_array.append(exp)
                               ytick_array.append(cell_no_array[target_pop_index_array[trial][pop]]-1)
                               left_value=cell_no_array[target_pop_index_array[trial][pop]]-1
                            else:
                               ytick_array.append(left_value+2)
                               ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[trial][pop]]   )
                               left_value=left_value+2+cell_no_array[target_pop_index_array[trial][pop]]
                        ax_all_trials[row_counter].set_yticks(ytick_array)
                        #ax_all_trials[row_counter].canvas.draw()
                        ax_all_trials[row_counter].set_ylim(0,(cell_no_array[target_pop_index_array[trial][pop]]+1)*len(general_plot_parameters[3]) )
                        ax_all_trials[row_counter].set_ylabel('Cell ids for pop %d\ntrial %d'%(target_pop_index_array[trial][pop],non_empty_trial_indices[trial]),size=4)
                        ax_all_trials[row_counter].set_yticks([cell_no_array[target_pop_index_array[trial][pop]]+( cell_no_array[target_pop_index_array[trial][pop]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
                        ax_all_trials[row_counter].yaxis.grid(False, which='major')
                        ax_all_trials[row_counter].yaxis.grid(True, which='minor')
                        if row_counter==total_no_of_rows-1:
                           ax_all_trials[row_counter].set_xlabel('Time (ms)',fontsize=4)
           
                        labels = [x.get_text() for x in  ax_all_trials[row_counter].get_yticklabels()]
           
                        for label in range(0,len(labels)):
                             labels[label] =label_array[label]

                        ax_all_trials[row_counter].set_yticklabels(labels)
                         
                        for tick in ax_all_trials[row_counter].xaxis.get_major_ticks():
                            tick.label.set_fontsize(general_plot_parameters[4]) 
                        for tick in ax_all_trials[row_counter].yaxis.get_major_ticks():
                            tick.label.set_fontsize(general_plot_parameters[5])
                        row_counter+=1
                 else:
                    label_array=[]
                    ytick_array=[]
                    for exp in range(0,len(general_plot_parameters[3])):
                        label_array.append("%d"%0)
                        label_array.append("%d"%(cell_no_array[target_pop_index_array[trial][0]]-1))
                        if exp==0:
                           ytick_array.append(exp)
                           ytick_array.append(cell_no_array[target_pop_index_array[trial][0]]-1)
                           left_value=cell_no_array[target_pop_index_array[trial][0]]-1
                        else:
                           ytick_array.append(left_value+2)
                           ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[trial][0]]   )
                           left_value=left_value+2+cell_no_array[target_pop_index_array[trial][0]]
                    ax_all_trials[row_counter].set_yticks(ytick_array)
                    ax_all_trials[row_counter].set_ylim(0,(cell_no_array[target_pop_index_array[trial][0]]+1)*len(general_plot_parameters[3]) )
                    ax_all_trials[row_counter].set_ylabel('Cell ids for pop %d\ntrial %d'%(target_pop_index_array[trial][0],non_empty_trial_indices[trial]),size=4)
                    ax_all_trials[row_counter].set_yticks([cell_no_array[target_pop_index_array[trial][0]]+( cell_no_array[target_pop_index_array[trial][0]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
                    ax_all_trials[row_counter].yaxis.grid(False, which='major')
                    ax_all_trials[row_counter].yaxis.grid(True, which='minor')
                    if row_counter==total_no_of_rows-1:
                       ax_all_trials[row_counter].set_xlabel('Time (ms)',fontsize=4)
           
                    labels = [x.get_text() for x in  ax_all_trials[row_counter].get_yticklabels()]
           
                    for label in range(0,len(labels)):
                        labels[label] =label_array[label]

                    ax_all_trials[row_counter].set_yticklabels(labels)
                         
                    for tick in ax_all_trials[row_counter].xaxis.get_major_ticks():
                        tick.label.set_fontsize(general_plot_parameters[4]) 
                    for tick in ax_all_trials[row_counter].yaxis.get_major_ticks():
                        tick.label.set_fontsize(general_plot_parameters[5])
                    row_counter+=1

          else:
             label_array=[]
             ytick_array=[]
             for exp in range(0,len(general_plot_parameters[3])):
                 label_array.append("%d"%0)
                 label_array.append("%d"%(cell_no_array[target_pop_index_array[0][0]]-1))
                 if exp==0:
                    ytick_array.append(exp)
                    ytick_array.append(cell_no_array[target_pop_index_array[0][pop]]-1)
                    left_value=cell_no_array[target_pop_index_array[0][0]]-1
                 else:
                    ytick_array.append(left_value+2)
                    ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[0][0]]   )
                    left_value=left_value+2+cell_no_array[target_pop_index_array[0][0]]
             ax_all_trials.set_yticks(ytick_array)
             ax_all_trials.canvas.draw()
             ax_all_trials.set_ylim(0,(cell_no_array[target_pop_index_array[0][0]]+1)*len(general_plot_parameters[3]) )
             ax_all_trials.set_ylabel('Cell ids for pop %d\ntrial %d'%(target_pop_index_array[trial][pop],non_empty_trial_indices[trial]),size=4)
             ax_all_trials.set_yticks([cell_no_array[target_pop_index_array[0][0]]+( cell_no_array[target_pop_index_array[0][0]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
             ax_all_trials.yaxis.grid(False, which='major')
             ax_all_trials.yaxis.grid(True, which='minor')
             ax_all_trials.set_xlabel('Time (ms)',fontsize=4)
           
             labels = [x.get_text() for x in  ax_all_trials.get_yticklabels()]
           
             for label in range(0,len(labels)):
                 labels[label] =label_array[label]

             ax_all_trials.set_yticklabels(labels)
                         
             for tick in ax_all_trials.xaxis.get_major_ticks():
                 tick.label.set_fontsize(general_plot_parameters[4]) 
             for tick in ax_all_trials.yaxis.get_major_ticks():
                 tick.label.set_fontsize(general_plot_parameters[5])
          
          fig_all_trials.subplots_adjust(hspace=0.1)
          fig_all_trials.subplots_adjust(top=0.95)
          fig_all_trials.subplots_adjust(bottom=0.1)
          l=fig_all_trials.legend(lines,general_plot_parameters[3],title=general_plot_parameters[2],loc='center',ncol=len(general_plot_parameters[3]),bbox_to_anchor=(0.52, 0.98),prop={'size':4})
          plt.setp(l.get_title(),fontsize=4)

          fig_all_trials.savefig('simulations/all_trials_%s.%s'%(general_plot_parameters[0],spike_plot_parameters[-1]))
          plt.clf() 
             
       if "save all trials to separate files" in spike_plot_parameters:
          if n_trials >1:
             for trial in range(0,len(non_empty_trial_indices)):
                 print("started ploting non-empty trials")
                 if pop_no_array[trial] >1:
                    for pop in range(0,pop_no_array[trial]):
                        label_array=[]
                        ytick_array=[]
                        for exp in range(0,len(general_plot_parameters[3])):
                            label_array.append("%d"%0)
                            label_array.append("%d"%(cell_no_array[target_pop_index_array[trial][pop]]-1))
                            if exp==0:
                               ytick_array.append(exp)
                               ytick_array.append(cell_no_array[target_pop_index_array[trial][pop]]-1)
                               left_value=cell_no_array[target_pop_index_array[trial][pop]]-1
                            else:
                               ytick_array.append(left_value+2)
                               ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[trial][pop]]   )
                               left_value=left_value+2+cell_no_array[target_pop_index_array[trial][pop]]
                        raster_ax_array[trial][pop].set_yticks(ytick_array)
                        raster_fig_array[trial].canvas.draw()
                        raster_ax_array[trial][pop].set_ylim(0,(cell_no_array[target_pop_index_array[trial][pop]]+1)*len(general_plot_parameters[3]) )
                        raster_ax_array[trial][pop].set_ylabel('Cell ids, population %d'%target_pop_index_array[trial][pop],size=4)
                        raster_ax_array[trial][pop].set_yticks([cell_no_array[target_pop_index_array[trial][pop]]+( cell_no_array[target_pop_index_array[trial][pop]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
                        raster_ax_array[trial][pop].yaxis.grid(False, which='major')
                        raster_ax_array[trial][pop].yaxis.grid(True, which='minor')
           
                        labels = [x.get_text() for x in raster_ax_array[trial][pop].get_yticklabels()]
           
                        for label in range(0,len(labels)):
                            labels[label] =label_array[label]

                        raster_ax_array[trial][pop].set_yticklabels(labels)
                        if pop==0:
                           raster_ax_array[trial][pop].set_title('Raster plot for Golgi cell populations (trial id=%d)'%non_empty_trial_indices[trial],size=6)
                        if pop==pop_no_array[trial]-1:
                           raster_ax_array[trial][pop].set_xlabel('Time (ms)',fontsize=6)
                        for tick in raster_ax_array[trial][pop].xaxis.get_major_ticks():
                            tick.label.set_fontsize(general_plot_parameters[4]) 
                        for tick in raster_ax_array[trial][pop].yaxis.get_major_ticks():
                            tick.label.set_fontsize(general_plot_parameters[5])
                 else:
                    label_array=[]
                    ytick_array=[]
                    for exp in range(0,len(general_plot_parameters[3])):
                        label_array.append("%d"%0)
                        label_array.append("%d"%(cell_no_array[target_pop_index_array[trial][0]]-1))
                        if exp==0:
                           ytick_array.append(exp)
                           ytick_array.append(cell_no_array[target_pop_index_array[trial][0]]-1)
                           left_value=cell_no_array[target_pop_index_array[trial][0]]-1
                        else:
                           ytick_array.append(left_value+2)
                           ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[trial][0]]   )
                           left_value=left_value+2+cell_no_array[target_pop_index_array[trial][0]]
                        raster_ax_array[trial].set_yticks(ytick_array)
                        raster_fig_array[trial].canvas.draw()
                        raster_ax_array[trial].set_ylim(0,(cell_no_array[target_pop_index_array[trial][0]]+1)*len(general_plot_parameters[3]) )
                        raster_ax_array[trial].set_ylabel('Cell ids, population %d'% target_pop_index_array[trial][0],size=4)
                        raster_ax_array[trial].set_yticks([cell_no_array[target_pop_index_array[trial][0]]+( cell_no_array[target_pop_index_array[trial][0]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
                        raster_ax_array[trial].yaxis.grid(False, which='major')
                        raster_ax_array[trial].yaxis.grid(True, which='minor')
           
                        labels = [x.get_text() for x in raster_ax_array[trial].get_yticklabels()]
           
                        for label in range(0,len(labels)):
                            labels[label] =label_array[label]

                        raster_ax_array[trial].set_yticklabels(labels)
                        raster_ax_array[trial].set_title('Raster plot for Golgi cell populations (trial id=%d)'%non_empty_trial_indices[trial],size=6)
                        raster_ax_array[trial].set_xlabel('Time (ms)',fontsize=6)
                        for tick in raster_ax_array[trial].xaxis.get_major_ticks():
                            tick.label.set_fontsize(general_plot_parameters[4]) 
                        for tick in raster_ax_array[trial].yaxis.get_major_ticks():
                            tick.label.set_fontsize(general_plot_parameters[5])
                            
                 raster_fig_array[trial].subplots_adjust(top=0.90)
                 raster_fig_array[trial].subplots_adjust(bottom=0.3)
                 l=raster_fig_array[trial].legend(lines,general_plot_parameters[3],title=general_plot_parameters[2],loc='center',ncol=len(general_plot_parameters[3]),bbox_to_anchor=(0.52, 0.1),prop={'size':4})
                 plt.setp(l.get_title(),fontsize=4)

                 raster_fig_array[trial].savefig('simulations/sim%d_rasters_%s.%s'%(non_empty_trial_indices[trial],general_plot_parameters[0],spike_plot_parameters[-1]))
                   
          else:
             if trial_indicator:
                if pop_no_array[0] >1:
                   for pop in range(0,pop_no_array[0]):
                       label_array=[]
                       ytick_array=[]
                       for exp in range(0,len(general_plot_parameters[3])):
                           label_array.append("%d"%0)
                           label_array.append("%d"%(cell_no_array[target_pop_index_array[0][pop]]-1))
                           if exp==0:
                              ytick_array.append(exp)
                              ytick_array.append(cell_no_array[target_pop_index_array[0][pop]]-1)
                              left_value=cell_no_array[target_pop_index_array[0][pop]]-1
                           else:
                              ytick_array.append(left_value+2)
                              ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[0][pop]]   )
                              left_value=left_value+2+cell_no_array[target_pop_index_array[0][pop]]
                       ax_stack_one_trial[pop].set_yticks(ytick_array)
                       fig_stack_one_trial.canvas.draw()
                       ax_stack_one_trial[pop].set_ylim(0,(cell_no_array[target_pop_index_array[0][pop]]+1)*len(general_plot_parameters[3]) )
                       ax_stack_one_trial[pop].set_ylabel('Cell ids, population %d'%target_pop_index_array[0][pop],size=4)
                       ax_stack_one_trial[pop].set_yticks([cell_no_array[target_pop_index_array[0][pop]]+( cell_no_array[target_pop_index_array[0][pop]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
                       ax_stack_one_trial[pop].yaxis.grid(False, which='major')
                       ax_stack_one_trial[pop].yaxis.grid(True, which='minor')
           
                       labels = [x.get_text() for x in ax_stack_one_trial[pop].get_yticklabels()]
           
                       for label in range(0,len(labels)):
                           labels[label] =label_array[label]

                       ax_stack_one_trial[pop].set_yticklabels(labels)
                       if pop==0:
                          ax_stack_one_trial[pop].set_title('Raster plot for Golgi cell populations (trial id=%d)'%spike_plot_parameters[1],size=6)
                       if pop==pop_no_array[0]-1:
                          ax_stack_one_trial[pop].set_xlabel('Time (ms)',fontsize=6)
                       for tick in ax_stack_one_trial[pop].xaxis.get_major_ticks():
                           tick.label.set_fontsize(general_plot_parameters[4]) 
                       for tick in ax_stack_one_trial[pop].yaxis.get_major_ticks():
                           tick.label.set_fontsize(general_plot_parameters[5])
                else:
                   label_array=[]
                   ytick_array=[]
                   for exp in range(0,len(general_plot_parameters[3])):
                       label_array.append("%d"%0)
                       label_array.append("%d"%(cell_no_array[target_pop_index_array[0][0]]-1))
                       if exp==0:
                          ytick_array.append(exp)
                          ytick_array.append(cell_no_array[target_pop_index_array[0][0]]-1)
                          left_value=cell_no_array[target_pop_index_array[0][0]]-1
                       else:
                          ytick_array.append(left_value+2)
                          ytick_array.append(left_value+1+cell_no_array[target_pop_index_array[0][0]]   )
                          left_value=left_value+2+cell_no_array[target_pop_index_array[0][0]]
                   ax_stack_one_trial.set_yticks(ytick_array)
                   fig_stack_one_trial.canvas.draw()
                   ax_stack_one_trial.set_ylim(0,(cell_no_array[target_pop_index_array[0][0]]+1)*len(general_plot_parameters[3]) )
                   ax_stack_one_trial.set_ylabel('Cell ids, population %d'%target_pop_index_array[0][0],size=4)
                   ax_stack_one_trial.set_yticks([cell_no_array[target_pop_index_array[0][0]]+( cell_no_array[target_pop_index_array[0][0]]+2)*k for k in range(0,len(general_plot_parameters[3]))],minor=True)
                   ax_stack_one_trial.yaxis.grid(False, which='major')
                   ax_stack_one_trial.yaxis.grid(True, which='minor')
           
                   labels = [x.get_text() for x in ax_stack_one_trial.get_yticklabels()]
           
                   for label in range(0,len(labels)):
                       labels[label] =label_array[label]

                   ax_stack_one_trial.set_yticklabels(labels)
                   ax_stack_one_trial.set_title('Raster plot for Golgi cell populations (trial id=%d)'%spike_plot_parameters[1],size=6)
                   ax_stack_one_trial.set_xlabel('Time (ms)',fontsize=6)
                   for tick in ax_stack_one_trial.xaxis.get_major_ticks():
                       tick.label.set_fontsize(general_plot_parameters[4]) 
                   for tick in ax_stack_one_trial.yaxis.get_major_ticks():
                       tick.label.set_fontsize(general_plot_parameters[5])
                fig_stack_one_trial.subplots_adjust(top=0.90)
                fig_stack_one_trial.subplots_adjust(bottom=0.3)
                l=fig_stack_one_trial.legend(lines,general_plot_parameters[3],title=general_plot_parameters[2],loc='center',ncol=len(general_plot_parameters[3]),bbox_to_anchor=(0.52, 0.1),prop={'size':4})
                plt.setp(l.get_title(),fontsize=4)

                fig_stack_one_trial.savefig('simulations/sim%d_rasters_%s.%s'%(spike_plot_parameters[1],general_plot_parameters[0],spike_plot_parameters[-1]))
                plt.clf() 
Example #23
import pickle

import numpy as np
import pyspike

try:
    with open('pickles/membrane_dynamics_balanced_file.p', 'rb') as f:
        mdf1 = pickle.load(f)  # Neo Block; the spike-train loop below reads it
        print(mdf1)
except Exception:
    pass



# first load the data, interval ending time = 4000, start=0 (default)
#spike_trains_txt = spk.load_spike_trains_from_txt("PySpike_testdata.txt", 4000)

wrangled_trains = []
for spiketrain in mdf1.spiketrains:
    # y assigns each spike its source cell id (handy for raster scatter plots)
    y = np.ones_like(spiketrain) * spiketrain.annotations['source_id']
    pspikes = pyspike.SpikeTrain(spiketrain, edges=(0, 4000))
    wrangled_trains.append(pspikes)
    print(pspikes)

    """ Class representing spike trains for the PySpike Module.
    def __init__(self, spike_times, edges, is_sorted=True):
    Constructs the SpikeTrain.

    :param spike_times: ordered array of spike times.
    :param edges: The edges of the spike train. Given as a pair of floats
                  (T0, T1) or a single float T1, where then T0=0 is
                  assumed.
    :param is_sorted: If `False`, the spike times will sorted by `np.sort`.

    """
def multi_shift(source,
                shifts,
                n,
                frac_shifts=[1],
                increase=False,
                jitter=False):
    """Generate surrogate trains by randomly displacing each spike of `source`
    (a pyspike.SpikeTrain) within the windows given in `shifts`, clipped at the
    train edges. `frac_shifts` gives the probability that a spike is moved at
    all, and `increase` (an iterable of counts) optionally adds extra spikes
    before shifting. Returns a flat list of pyspike SpikeTrains."""
    # NOTE: frac_fire is currently ignored here; it is implemented in shift_fwd.
    if not isinstance(frac_shifts, np.ndarray):
        frac_shifts = np.array(frac_shifts)

    if not isinstance(shifts, np.ndarray):
        shifts = np.array(shifts)

    duration = source.t_end

    def shift_main(src):

        spiketimes = src.spikes[:, np.newaxis]

        num_shifts = len(shifts)
        num_fracshift = len(frac_shifts)
        num_spikes = spiketimes.size

        # Finding edges
        left_edge = np.where(spiketimes < shifts)
        right_edge = np.where((spiketimes + shifts) > duration)

        # Create the shift matrices
        shifts_bottom = np.tile(shifts, reps=(num_spikes, n, num_fracshift, 1))
        shifts_top = shifts_bottom.copy()

        shifts_bottom[left_edge[0], :, :, left_edge[1]] = \
                spiketimes[left_edge[0]][:, np.newaxis]
        shifts_top[right_edge[0], :, :, right_edge[1]] = \
                ((duration) - spiketimes[right_edge[0]][:, np.newaxis])

        shifts_range = shifts_top + shifts_bottom

        # boolean matrix for fractional shifting - MAY REQUIRE CHANGING AFTER FRAC_FIRE IS INTEGRATED
        bool_frac_mat = np.random.rand(num_spikes, n, num_fracshift,
                                       num_shifts)
        bool_frac_mat = bool_frac_mat <= frac_shifts[:, np.newaxis]

        # Draw!
        shiftvals = np.random.rand(num_spikes, n, num_fracshift, num_shifts)

        # Fix edges!
        shiftvals = (
            (shiftvals * shifts_range) - shifts_bottom) * bool_frac_mat
        shiftvals.round(out=shiftvals)

        # Shift!
        shifted = shiftvals + spiketimes[:, np.newaxis, np.newaxis]
        shifted.sort()

        # Find uniques!
        uniques = [
            np.unique(shifted[:, i, j, k]) for i in range(n)
            for j in range(num_fracshift) for k in range(num_shifts)
        ]

        # pyspike!
        shifted_spk = [
            spk.SpikeTrain(shifted_times, duration)
            for shifted_times in uniques
        ]
        return shifted_spk

    if increase:
        slots = np.setdiff1d(np.arange(source.t_start, source.t_end),
                             source.spikes,
                             assume_unique=True)
        sources = {0: source}
        for inc in increase:
            new_spikes = np.random.choice(slots, inc)
            sources[inc] = spk.SpikeTrain(
                np.sort(np.append(source.spikes, new_spikes)), duration)

        shifted = []
        for src in sources.values():
            shifted.extend(shift_main(src))

    else:
        shifted = shift_main(source)

    return shifted
Example #25
def get_matrix(select='subset',
               min_spike_number=0,
               save=None,
               analysis=['SPIKE-Sync'],
               network=[0]):
    import pyspike

    load_data(network)

    getmat = {}

    empty_dict_array = {}
    no_empty_dict_array = {}

    spkts = {}
    spkinds = {}
    spktsRange = {}
    spkt_train = {}
    spike_sync = {}

    for f, p in enumerate(data_files):
        if f in network:
            spkts[f] = d[p]['simData']['spkt']  #list
            spkinds[f] = d[p]['simData']['spkid']  #list

            print('Starting analysis of spike times per ' + str(select) +
                  ': ' + str(p))

            for t, y in enumerate(timeRange):

                spktsRange = [
                    spkt for spkt in spkts[f]
                    if timeRange[t][0] <= spkt <= timeRange[t][1]
                ]

                spkt_train[str(f) + str(t)] = []

                if select == 'subset':
                    print('Time Range: ' + str(y))

                    empty_array = np.zeros(
                        ((len(net_labels) * 2), (len(net_labels) * 2)))
                    no_empty_array = np.zeros(
                        ((len(net_labels) * 2), (len(net_labels) * 2)))
                    array_ii = np.zeros(
                        ((len(net_labels) * 2), (len(net_labels) * 2)))

                    empty_gids = []
                    gids_included = []

                    for k, v in enumerate(gids):
                        train = []
                        for i, gid in enumerate(v):
                            for spkind, spkt in zip(spkinds[f], spkts[f]):
                                if (spkind == gid and spkt in spktsRange):
                                    train.append(spkt)

                        spkt_train[str(f) + str(t)].append(
                            pyspike.SpikeTrain(train, timeRange[t]))

                        if len(train) < min_spike_number:
                            empty_gids.append(k)
                        else:
                            gids_included.append(k)

                    for i in range(len(spkt_train[str(f) + str(t)])):
                        if i in gids_included:
                            for k, v in enumerate(gids_included):
                                no_empty_array[i][v] = 1.0

                    for l in range(len(array_ii)):
                        array_ii[l][l] = 1.0

                    no_empty_dict_array[str(f) + str(t)] = no_empty_array

                elif select == 'cell':

                    print('Time Range: ' + str(y))

                    empty_array = np.zeros(
                        ((len(net_labels) * 80), (len(net_labels) * 80)))
                    no_empty_array = np.zeros(
                        ((len(net_labels) * 80), (len(net_labels) * 80)))

                    empty_gids = []
                    spkmat2 = []
                    gids_included = []
                    #sync = np.zeros(((len(net_labels)*80),(len(net_labels)*80)))

                    for ii, subset in enumerate(gids):
                        spkmat = [
                            pyspike.SpikeTrain([
                                spkt
                                for spkind, spkt in zip(spkinds[f], spkts[f])
                                if (spkind == gid and spkt in spktsRange)
                            ], timeRange[t]) for gid in set(subset)
                        ]
                        spkt_train[str(f) + str(t)].extend(spkmat)

                        for gid in set(subset):
                            list_spkt = [
                                spkt
                                for spkind, spkt in zip(spkinds[f], spkts[f])
                                if (spkind == gid and spkt in spktsRange)
                            ]

                            if len(list_spkt) < min_spike_number:
                                empty_gids.append(gid)
                            else:
                                spkmat2.append(
                                    pyspike.SpikeTrain(list_spkt,
                                                       timeRange[t]))
                                gids_included.append(gid)
                        pos_labels.append(len(gids_included))

                    #print gids_included
                    empty_gids[:] = [x - 200 for x in empty_gids]
                    gids_included[:] = [x - 200 for x in gids_included]
                    #print empty_gids
                    for i in range(len(spkt_train[str(f) + str(t)])):
                        if i in empty_gids:
                            for k, v in enumerate(empty_gids):
                                empty_array[i][v] = 1.0

                    for i in range(len(spkt_train[str(f) + str(t)])):
                        if i in gids_included:
                            for k, v in enumerate(gids_included):
                                no_empty_array[i][v] = 1.0

                    #print empty_array
                    empty_dict_array[str(f) + str(t)] = empty_array
                    no_empty_dict_array[str(f) + str(t)] = no_empty_array
                #print spkt_train
                for l, mat in enumerate(mats):
                    #spike_sync
                    if (mat == 'ISI-distance' and mat in analysis):
                        print(str(mat) + ", number of trains: " +
                              str(len(spkt_train[str(f) + str(t)])))
                        isi_distance = pyspike.isi_distance_matrix(
                            spkt_train[str(f) + str(t)])
                        getmat[str(f) + str(t) + str(l)] = isi_distance

                    elif (mat in analysis and mat == 'SPIKE-distance'):
                        print(str(mat) + ", number of trains: " +
                              str(len(spkt_train[str(f) + str(t)])))
                        spike_distance = pyspike.spike_distance_matrix(
                            spkt_train[str(f) + str(t)])
                        getmat[str(f) + str(t) + str(l)] = spike_distance

                    elif (mat in analysis and mat == 'SPIKE-Sync'):
                        print(str(mat) + ", number of trains: " +
                              str(len(spkt_train[str(f) + str(t)])))
                        spike_sync[str(f) +
                                   str(t)] = pyspike.spike_sync_matrix(
                                       spkt_train[str(f) + str(t)])
                        # NOTE: array_ii is only assigned in the
                        # select == 'subset' branch above; with select == 'cell'
                        # this reuses whatever value is left from a prior pass.
                        #if select == 'subset':
                        getmat[str(f) + str(t) + str(l)] = (
                            spike_sync[str(f) + str(t)] *
                            no_empty_dict_array[str(f) + str(t)]) + array_ii
                        #elif select == 'cell':
                        #getmat[str(f)+str(t)+str(l)] = spike_sync[str(f)+str(t)] * no_empty_dict_array[str(f)+str(t)]

                empty_array = np.zeros(
                    ((len(net_labels) * 80), (len(net_labels) * 80)))
        else:
            pass

    if save:
        with open(str(path) + 'data1.pkl', 'wb') as output:
            pickle.dump(getmat, output)

    print('finished getting data for matrix plotting')
    return getmat
Example #26
pl.ylim(0, N + 1)
pl.xlim(0, duration * runs)
pl.show(block=False)

########## measure synchrony ##################################################
# Cut the spike record into per-run slices, re-referenced to t=0 so the times
# fall inside the (0, duration) edges handed to SpikeTrain.
slices = []
for run in range(runs):
    section = []
    for n in range(N):
        subint = [
            x - run * duration for x in spikes[n]
            if run * duration <= x <= (run + 1) * duration
        ]
        section.append(spk.SpikeTrain(subint, (0, duration)))

    slices.append(section)

pl.figure(3)

sync = []
for c in range(len(slices)):
    # sync.append(np.var(spk.spike_sync_matrix(slices[c])))
    sync.append(np.linalg.norm(spk.spike_sync_matrix(slices[c])))
    # sync.append(np.sum(spk.spike_sync_matrix(slices[c])))

pl.plot(sync, linestyle="-", marker="o", markersize=7)
# pl.hlines(15, 0, len(homXS), linewidth=0.3)
pl.grid(which='both', axis='y')
# pl.xlim(xmin=-0.5,xmax=len(homXS)+1.5)
Example #27
def df_to_spike_train(df, t_start, edges):
    """Convert a time-indexed DataFrame to a pyspike.SpikeTrain, expressing
    each index timestamp as seconds elapsed since `t_start`."""
    return spk.SpikeTrain(
        np.unique([(i - t_start).total_seconds() for i in df.index]), edges)
Example #28
def spikey(fp):
    events = glob.glob(os.path.join(fp,'*.pprox'))
    events.sort()
    #For new recordings:
    syll = pd.read_csv('../restoration_syllables.csv')
    
    #For old recordings:
    #syll = pd.read_csv('../syllables.csv')
    
    for eventfile in events:
        
        with open(eventfile) as fl:
            data = json.load(fl)
        song = []
        condition = []
        train = []
        gapon = {}
        gapoff = {}
        spikes = []
        for t in range(len(data['pprox'])):
            #For new recordings:
            if data['pprox'][t]['condition'] == 'continuous':
                song.append(data['pprox'][t]['stimulus']+'-1')
                song.append(data['pprox'][t]['stimulus']+'-2')
                condition.append(data['pprox'][t]['condition']+'1')
                condition.append(data['pprox'][t]['condition']+'2')
                spikes.append(data['pprox'][t]['event'])
                spikes.append(data['pprox'][t]['event'])
            else:
                songid = data['pprox'][t]['stimulus']+'-'+data['pprox'][t]['condition'][-1]
                song.append(data['pprox'][t]['stimulus']+'-'+data['pprox'][t]['condition'][-1])
                condition.append(data['pprox'][t]['condition'])
                spikes.append(data['pprox'][t]['event'])
                
            if 'gap_on' in data['pprox'][t].keys():
                # NOTE: songid is only assigned on the non-continuous branch
                # above, so a continuous trial carrying gap_on/gap_off would
                # reuse a stale songid here.
                gapon[songid] = data['pprox'][t]['gap_on']
                gapoff[songid] = data['pprox'][t]['gap_off']
                
            #For old recordings:
            #song.append(data['pprox'][t]['stimulus'])
            #condition.append(data['pprox'][t]['condition'])
            #spikes.append(data['pprox'][t]['event'])
            #if 'gap_on' in data['pprox'][t].keys():
                #gapon[song[t]] = data['pprox'][t]['gap_on'][0]/40
                #gapoff[song[t]] = data['pprox'][t]['gap_off'][0]/40
                
        songset = np.unique(song)
        x = []
        y = []
        for s in song:
            x.append(gapon[s])
            y.append(gapoff[s])
            
        gapon = x
        gapoff = y
        
        for t in range(len(spikes)):
            #For new recordings: one combined boolean mask replaces the
            #chained selections (same rows, no chained-indexing warnings)
            mask = ((syll['songid'] == song[t][:-2])
                    & (syll['start'] <= gapon[t]/1000+0.001)
                    & (syll['end'] >= gapoff[t]/1000-0.001))
            syllstart = syll['start'][mask].values[0] * 1000
            index = syll[mask].index.values[0] + 1

            #For old recordings:
            #syllstart = syll['start'][syll['songid'] == song[t]][syll['start'] <= gapon[t]/1000+0.001][syll['end'] >= gapoff[t]/1000-0.001].values[0] * 1000
            #index = syll[syll['songid'] == song[t]][syll['start'] <= gapon[t]/1000+0.001][syll['end'] >= gapoff[t]/1000-0.001].index.values[0] + 1

            nextsyllend = syll['end'].at[index] * 1000
            spikes[t] = [spike for spike in spikes[t] if syllstart <= spike <= nextsyllend]
            train.append(spk.SpikeTrain(spikes[t], [syllstart, nextsyllend]))
            
        for s,stim in enumerate(songset):
            pairs = np.zeros((len(train)//len(songset),len(train)//len(songset)))
            subset = np.where(np.asarray(song) == stim)
            trainsub = [train[x] for x in subset[0]]
            for i in range(len(trainsub)):
                for j in range(len(trainsub)):
                    pairs[i,j] = spk.spike_distance(trainsub[i], trainsub[j])
            labels = [condition[x] for x in range(len(condition)) if x in subset[0]]
            df = pd.DataFrame(pairs, columns = labels, index = labels)
            df.to_csv(os.path.splitext(eventfile)[0]+'_'+stim+'.csv')
Example #29
def iter_plot0(md):
    import seaborn as sns
    import pickle
    with open('cell_indexs.p', 'rb') as f:
        returned_list = pickle.load(f)
    index_exc = returned_list[0]
    index_inh = returned_list[1]
    index, mdf1 = md
    #wgf = {0.025:None,0.05:None,0.125:None,0.25:None,0.3:None,0.4:None,0.5:None,1.0:None,1.5:None,2.0:None,2.5:None,3.0:None}
    wgf = {
        0.0025: None,
        0.0125: None,
        0.025: None,
        0.05: None,
        0.125: None,
        0.25: None,
        0.3: None,
        0.4: None,
        0.5: None,
        1.0: None,
        1.5: None,
        2.0: None,
        2.5: None,
        3.0: None
    }

    weight_gain_factors = {k: v for k, v in enumerate(wgf.keys())}
    print(len(weight_gain_factors))
    print(weight_gain_factors.keys())
    #weight_gain_factors = {0:0.5,1:1.0,2:1.5,3:2.0,4:2.5,5:3}
    #weight_gain_factors = {:None,1.0:None,1.5:None,2.0:None,2.5:None}

    k = weight_gain_factors[index]
    #print(len(mdf1.segments),'length of block')

    ass = mdf1.analogsignals[0]

    time_points = ass.times
    avg = np.mean(ass, axis=0)  # Average over signals of Segment
    #maxx = np.max(ass, axis=0)  # Average over signals of Segment
    std = np.std(ass, axis=0)  # Average over signals of Segment
    #avg_minus =
    plt.figure()
    plt.plot([i for i in range(0, len(avg))], avg)
    plt.plot([i for i in range(0, len(std))], std)

    plt.title("Mean and Standard Dev of $V_{m}$ amplitude per neuron ")
    plt.xlabel('time $(ms)$')
    plt.xlabel('Voltage $(mV)$')

    plt.savefig(str(index) + 'prs.png')
    vm_spiking = []
    vm_not_spiking = []
    spike_trains = []
    binary_trains = []
    max_spikes = 0

    vms = np.array(mdf1.analogsignals[0].as_array().T)
    #print(data)
    #for i,vm in enumerate(data):

    cnt = 0
    for spiketrain in mdf1.spiketrains:
        #spiketrain = mdf1.spiketrains[index]
        y = np.ones_like(spiketrain) * spiketrain.annotations['source_id']
        #import sklearn
        #sklearn.decomposition.NMF(y)
        # argument edges is the time interval you want to be considered.
        # NOTE: len(ass) counts samples; it equals t_stop only when dt = 1 ms.
        pspikes = pyspike.SpikeTrain(spiketrain, edges=(0, len(ass)))
        spike_trains.append(pspikes)
        if len(spiketrain) > max_spikes:
            max_spikes = len(spiketrain)

        if np.max(ass[spiketrain.annotations['source_id']]) > 0.0:
            vm_spiking.append(vms[spiketrain.annotations['source_id']])
        else:
            vm_not_spiking.append(vms[spiketrain.annotations['source_id']])
        cnt += 1

    for spiketrain in mdf1.spiketrains:
        x = conv.BinnedSpikeTrain(spiketrain,
                                  binsize=1 * pq.ms,
                                  t_start=0 * pq.s)
        binary_trains.append(x)
    end_floor = np.floor(float(mdf1.t_stop))
    # fractional part of t_stop; NOTE: zero for an integer t_stop, which
    # would make the arange step below invalid
    dt = float(mdf1.t_stop) % end_floor
    #v = mdf1.take_slice_of_analogsignalarray_by_unit()
    t_axis = np.arange(float(mdf1.t_start), float(mdf1.t_stop), dt)
    plt.figure()
    plt.clf()

    plt.figure()
    plt.clf()
    cleaned = []
    data = np.array(mdf1.analogsignals[0].as_array().T)
    #print(data)
    for i, vm in enumerate(data):
        if np.max(vm) > 900.0 or np.min(vm) < -900.0:
            pass
        else:
            plt.plot(ass.times, vm)  #,label='neuron identifier '+str(i)))
            cleaned.append(vm)
            #vm = s#.as_array()[:,i]

    assert len(cleaned) < len(ass)

    print(len(cleaned))
    plt.title('neuron $V_{m}$')
    #plt.legend(loc="upper left")
    plt.savefig(str('weight_') + str(k) + 'analogsignals' + '.png')
    plt.xlabel('Time $(ms)$')
    plt.ylabel('Voltage $(mV)$')

    plt.close()

    #pass

    plt.figure()
    plt.clf()
    plt.title('Single Neuron $V_{m}$ trace')
    plt.plot(ass.times[0:int(len(ass.times) / 10)],
             vm_not_spiking[index_exc[0]][0:int(len(ass.times) / 10)])
    plt.xlabel('Time $(ms)$')
    plt.ylabel('Voltage $(mV)$')
    plt.savefig(str('weight_') + str(k) + 'eespecific_analogsignals' + '.png')
    plt.close()

    plt.figure()
    plt.clf()
    plt.title('Single Neuron $V_{m}$ trace')
    plt.plot(ass.times[0:int(len(ass.times) / 10)],
             vm_not_spiking[index_inh[0]][0:int(len(ass.times) / 10)])
    plt.xlabel('$ms$')
    plt.ylabel('$mV$')

    plt.savefig(str('weight_') + str(k) + 'inhibitory_analogsignals' + '.png')
    plt.close()

    cvsd = {}
    cvs = []
    cvsi = []
    rates = []  # firing rate per cell, in spikes per second (assumes a 2 s run)
    for i, j in enumerate(spike_trains):
        rates.append(len(j) / 2.0)
        cva = cv(j)
        cvs.append(cva)  # one entry per cell; NaN entries are zeroed below
    #import pickle
    #with open(str('weight_')+str(k)+'coefficients_of_variation.p','wb') as f:
    #   pickle.dump([cvs,cvsd],f)
    a = np.asarray(cvs)
    np.savetxt('pickles/' + str('weight_') + str(k) +
               'coefficients_of_variation.csv',
               a,
               delimiter=",")

    a = np.asarray(rates)
    np.savetxt('pickles/' + str('weight_') + str(k) + 'firing_rate.csv',
               a,
               delimiter=",")

    # Zero out NaN CVs instead of dropping entries, so that index_inh and
    # index_exc below still index one CV per cell.
    cvs = [0 if np.isnan(i) else i for i in cvs]
    cells = list(range(len(cvs)))

    plt.clf()
    fig, axes = plt.subplots()
    axes.set_title('Coefficient of Variation Versus Neuron')
    axes.set_xlabel('Neuron number')
    axes.set_ylabel('CV estimate')
    mcv = np.mean(cvs)
    #plt.scatter(cells,cvs)
    cvs = np.array(cvs)
    plt.scatter(index_inh, cvs[index_inh], label="inhibitory cells")
    plt.scatter(index_exc, cvs[index_exc], label="excitatory cells")
    plt.legend(loc="upper left")

    fig.tight_layout()
    plt.savefig(str('weight_') + str(k) + 'cvs_mean_' + str(mcv) + '.png')
    plt.close()

    plt.clf()
    #frequencies, power = elephant.spectral.welch_psd(ass)
    #mfreq = frequencies[np.where(power==np.max(power))[0][0]]
    fig, axes = plt.subplots()  # fresh axes; the previous figure was closed
    axes.set_title('Firing Rate Versus Neuron Number at mean f=' +
                   str(np.mean(rates)) + ' (spikes per second)')
    axes.set_xlabel('Neuron number')
    axes.set_ylabel('Spikes per second')
    rates = np.array(rates)
    plt.scatter(index_inh, rates[index_inh], label="inhibitory cells")
    plt.scatter(index_exc, rates[index_exc], label="excitatory cells")
    plt.legend(loc="upper left")
    fig.tight_layout()
    plt.savefig(str('firing_rates_per_cell_') + str(k) + str(mcv) + '.png')
    plt.close()
    '''
    import pandas as pd
    d = {'coefficent_of_variation': cvs, 'cells': cells}
    df = pd.DataFrame(data=d)

    ax = sns.regplot(x='cells', y='coefficent_of_variation', data=df)#, fit_reg=False)
    plt.savefig(str('weight_')+str(k)+'cvs_regexp_'+str(mcv)+'.png');
    plt.close()
    '''

    spike_trains = []
    ass = mdf1.analogsignals[0]
    tstop = mdf1.t_stop
    # np.max(ass.times) should equal mdf1.t_stop
    #assert tstop == 2000
    tstop = 2000
    vm_spiking = []

    for spiketrain in mdf1.spiketrains:
        vm_spiking.append(
            mdf1.analogsignals[0][spiketrain.annotations['source_id']])
        y = np.ones_like(spiketrain) * spiketrain.annotations['source_id']

        # argument edges is the time interval you want to be considered.
        pspikes = pyspike.SpikeTrain(spiketrain, edges=(0, tstop))
        spike_trains.append(pspikes)

    # plot the spike times

    plt.clf()
    for (i, spike_train) in enumerate(spike_trains):
        plt.scatter(spike_train, i * np.ones_like(spike_train), marker='.')
    plt.xlabel('Time (ms)')
    plt.ylabel('Cell identifier')
    plt.title('Raster Plot for weight strength:' + str(k))

    plt.savefig(str('weight_') + str(k) + 'raster_plot' + '.png')
    plt.close()

    f = spk.isi_profile(spike_trains, indices=[0, 1])
    x, y = f.get_plottable_data()

    #text_file.close()
    text_file = open(str('weight_') + str(index) + 'net_out.txt', 'w')

    plt.figure()
    plt.plot(x, np.abs(y), '--k', label="ISI-profile")
    print("ISI-distance: %.8f" % f.avrg())
    f = spk.spike_profile(spike_trains, indices=[0, 1])
    x, y = f.get_plottable_data()
    plt.plot(x, y, '-b', label="SPIKE-profile")
    #print("SPIKE-distance: %.8f" % f.avrg())
    string_to_write = str("ISI-distance:") + str(f.avrg()) + str("\n\n")
    plt.title(string_to_write)
    plt.xlabel('Time $(ms)$')
    plt.ylabel('ISI distance')
    plt.legend(loc="upper left")
    plt.savefig(str('weight_') + str(k) + 'ISI_distance_bivariate' + '.png')
    plt.close()
    text_file.write(string_to_write)

    #text_file.write("SPIKE-distance: %.8f" % f.avrg())
    #text_file.write("\n\n")

    plt.figure()
    f = spk.spike_sync_profile(spike_trains[0], spike_trains[1])
    x, y = f.get_plottable_data()
    plt.plot(x, y, '--ok', label="SPIKE-SYNC profile")
    print(f, f.avrg())
    print("Average:" + str(f.avrg()))
    #print(len(f.avrg()),f.avrg())
    string_to_write = str("instantaneous synchrony:") + str(
        f.avrg()) + 'weight: ' + str(index)

    plt.title(string_to_write)
    plt.xlabel('Time $(ms)$')
    plt.ylabel('instantaneous synchrony')

    text_file.write(string_to_write)

    #text_file.write(list())

    f = spk.spike_profile(spike_trains[0], spike_trains[1])
    x, y = f.get_plottable_data()

    plt.plot(x, y, '-b', label="SPIKE-profile")
    plt.axis([0, 4000, -0.1, 1.1])
    plt.legend(loc="center right")
    plt.clf()
    plt.figure()
    plt.subplot(211)

    f = spk.spike_sync_profile(spike_trains)
    x, y = f.get_plottable_data()
    plt.plot(x, y, '-b', alpha=0.7, label="SPIKE-Sync profile")
    x1, y1 = f.get_plottable_data(averaging_window_size=50)
    plt.plot(x1, y1, '-k', lw=2.5, label="averaged SPIKE-Sync profile")
    plt.subplot(212)

    f_psth = spk.psth(spike_trains, bin_size=50.0)
    x, y = f_psth.get_plottable_data()
    plt.plot(x, y, '-k', alpha=1.0, label="PSTH")

    plt.xlabel('Time $(ms)$')
    plt.ylabel('Spikes per bin')

    plt.savefig(str('weight_') + str(k) + 'multivariate_PSTH' + '.png')
    plt.close()

    plt.clf()
    plt.figure()

    f_psth = spk.psth(spike_trains, bin_size=50.0)
    x, y = f_psth.get_plottable_data()
    plt.plot(x, y, '-k', alpha=1.0, label="PSTH")

    plt.savefig(str('weight_') + str(k) + 'exclusively_PSTH' + '.png')
    plt.close()

    plt.figure()
    isi_distance = spk.isi_distance_matrix(spike_trains)
    plt.imshow(isi_distance, interpolation='none')
    plt.title('Pairwise ISI distance, T=0-2000')
    plt.xlabel('post-synaptic neuron number')
    plt.ylabel('pre-synaptic neuron number')

    plt.title("ISI-distance")
    plt.savefig(str('weight_') + str(k) + 'ISI_distance' + '.png')
    plt.close()

    #plt.show()

    plt.figure()
    plt.clf()
    import seaborn as sns

    sns.set()
    sns.clustermap(isi_distance)

    plt.savefig(str('weight_') + str(k) + 'cluster_isi_distance' + '.png')
    plt.close()

    plt.figure()
    spike_distance = spk.spike_distance_matrix(spike_trains,
                                               interval=(0, float(tstop)))

    import pickle
    with open('spike_distance_matrix.p', 'wb') as f:
        pickle.dump(spike_distance, f)

    plt.imshow(spike_distance, interpolation='none')
    plt.title("Pairwise SPIKE-distance, T=0-2000")
    plt.xlabel('post-synaptic neuron number')
    plt.ylabel('pre-synaptic neuron number')

    plt.savefig(str('weight_') + str(k) + 'spike_distance_matrix' + '.png')
    plt.close()
    plt.figure()
    plt.clf()
    sns.set()
    sns.clustermap(spike_distance)

    plt.savefig(str('weight_') + str(k) + 'cluster_spike_distance' + '.png')
    plt.close()

    plt.figure()
    spike_sync = spk.spike_sync_matrix(spike_trains,
                                       interval=(0, float(tstop)))
    plt.imshow(spike_sync, interpolation='none')
    plt.title('Pairwise Spike Synchrony, T=0-2000')
    plt.xlabel('post-synaptic neuron number')
    plt.ylabel('pre-synaptic neuron number')

    a = np.asarray(spike_sync)
    np.savetxt("spike_sync_matrix.csv", a, delimiter=",")

    plt.figure()
    plt.clf()
    sns.clustermap(spike_sync)
    plt.savefig(
        str('weight_') + str(k) + 'cluster_spike_sync_distance' + '.png')
    plt.close()
Example #30
def iter_plot1(md):
    index, mdf1 = md
    wgf = {
        0.025: None,
        0.05: None,
        0.125: None,
        0.25: None,
        0.3: None,
        0.4: None,
        0.5: None,
        1.0: None,
        1.5: None,
        2.0: None,
        2.5: None,
        3.0: None
    }
    weight_gain_factors = {k: v for k, v in enumerate(wgf.keys())}
    k = weight_gain_factors[index]
    ass = mdf1.analogsignals[0]
    vm_spiking = []
    vm_not_spiking = []
    spike_trains = []
    binary_trains = []
    max_spikes = 0
    cnt = 0
    for spiketrain in mdf1.spiketrains:
        #spiketrain = mdf1.spiketrains[index]
        y = np.ones_like(spiketrain) * spiketrain.annotations['source_id']
        # argument edges is the time interval you want to be considered.
        # NOTE: len(ass) counts samples; it equals t_stop only when dt = 1 ms.
        pspikes = pyspike.SpikeTrain(spiketrain, edges=(0, len(ass)))
        spike_trains.append(pspikes)
        if len(spiketrain) > max_spikes:
            max_spikes = len(spiketrain)

        if np.max(ass[spiketrain.annotations['source_id']]) > 0.0:
            vm_spiking.append(ass[spiketrain.annotations['source_id']])
        else:
            vm_not_spiking.append(ass[spiketrain.annotations['source_id']])
        cnt += 1

    import elephant
    from scipy.signal import periodogram
    #dt = 0.0025
    #frequencies, power = periodogram(ass,fs=1/dt)
    frequencies, power = elephant.spectral.welch_psd(ass)

    mfreq = frequencies[np.where(power == np.max(power))[0][0]]
    import pickle
    with open(str(k) + '_' + str(mfreq) + '_' + 'mfreq.p', 'wb') as f:
        pickle.dump(mfreq, f)

    def plot_periodogram(frequencies, power):
        plt.figure(figsize=(10, 4))
        sns.heatmap(power)
        plt.xlabel('Frequency ($Hz$)')
        plt.ylabel('Power per neuron ($V^2/Hz$)')  # power is a normalized density
        plt.savefig(
            str('weight_') + str(k) + 'multi_variate_periodogram' + '.png')
        plt.close()

        plt.plot(frequencies, power[0])  # single-neuron PSD
        plt.savefig(
            str('weight_') + str(k) + '_single_neuron_periodogram' + '.png')
        plt.close()

        return

    plot_periodogram(frequencies, power)

    lens = np.shape(ass.as_array())[1]
    coherence_matrix = np.zeros(shape=(lens, lens), dtype=float)
    for i in range(0, lens):
        for j in range(0, lens):
            if i != j:
                # columns of as_array() are the per-neuron signals
                x = ass.as_array()[:, i]
                y = ass.as_array()[:, j]
                # welch_cohere is assumed to come from elephant.spectral,
                # imported elsewhere in the original module.
                coh = welch_cohere(x, y)
                if np.mean(coh) != 0:
                    coherence_matrix[i, j] = np.mean(coh)
    plt.figure()
    plt.clf()
    from matplotlib.colors import LogNorm
    #plt.imshow(coherance_matrix, interpolation='none',norm=cbar_kws)
    sns.heatmap(coherence_matrix)  #, cbar_kws=cbar_kws
    plt.title("Coherence Matrix")
    plt.xlabel('pre-synaptic cell')
    plt.ylabel('post-synaptic cell')

    plt.savefig(
        str('Coherence_matrix_weight_') + str(k) + str('freq_') + str(mfreq) +
        '.png')
    plt.close()

    a = np.asarray(coherence_matrix)
    np.savetxt("coherence_matrix.csv", a, delimiter=",")
    mdf1 = None
    coh = None