Пример #1
0
    def test_Riehle_et_al_97_UE(self):
        """Reproduce the unitary events of figure 2 of Riehle et al. (1997).

        The reference spike data and the values extracted from the original
        figure are downloaded (checksum-verified) into ``ELEPHANT_TMP_DIR``.
        The windowed joint-surprise analysis is run, only coincidences that
        fall inside significant windows are kept, and their times must
        differ from the extracted unitary events by less than 0.3 in
        absolute value.
        """
        base_url = ("http://raw.githubusercontent.com/ReScience-Archives/"
                    "Rostami-Ito-Denker-Gruen-2017/master/data")
        downloads = (
            ("extracted_data.npy", "c4903666ce8a8a31274d6b11238a5ac3"),
            ("winny131_23.gdf", "cc2958f7b4fb14dbab71e17bba49bd10"),
        )
        for fname, expected_checksum in downloads:
            # Files land in ELEPHANT_TMP_DIR.
            download(url=f"{base_url}/{fname}", checksum=expected_checksum)

        # Spike data of figure 2 of Riehle et al. 1997, aligned on RS_4.
        spiketrain = self.load_gdf2Neo(ELEPHANT_TMP_DIR / "winny131_23.gdf",
                                       trigger='RS_4',
                                       t_pre=1799 * pq.ms,
                                       t_post=300 * pq.ms)

        winsize = 100 * pq.ms
        bin_size = 5 * pq.ms
        winstep = 5 * pq.ms
        first_train = spiketrain[0][0]
        t_winpos = ue._winpos(first_train.t_start, first_train.t_stop,
                              winsize, winstep)

        UE = ue.jointJ_window_analysis(spiketrain,
                                       pattern_hash=[3],
                                       bin_size=bin_size,
                                       win_size=winsize,
                                       win_step=winstep,
                                       method='analytic_TrialAverage')
        # Data extracted from figure 2 of the original article.
        extracted_data = np.load(ELEPHANT_TMP_DIR / 'extracted_data.npy',
                                 encoding='latin1',
                                 allow_pickle=True).item()
        Js_sig = ue.jointJ(0.05)
        sig_idx_win = np.where(UE['Js'] >= Js_sig)[0]

        diff_UE_rep = []
        ref_row = 0  # next row of extracted_data['ue'] to compare against
        for trial_id in range(len(spiketrain)):
            indices_unique = np.unique(UE['indices'][f"trial{trial_id}"])
            if len(indices_unique) == 0:
                continue
            coinc_times = indices_unique * bin_size
            # Keep only coincidences lying inside a significant window.
            kept = []
            for j in sig_idx_win:
                inside = ((coinc_times >= t_winpos[j])
                          & (coinc_times < t_winpos[j] + winsize))
                kept.extend(indices_unique[np.where(inside)])
            x_tmp = np.unique(kept) * bin_size.magnitude
            if len(x_tmp) == 0:
                continue
            ue_trial = np.sort(extracted_data['ue'][ref_row])
            diff_UE_rep = np.append(diff_UE_rep, x_tmp - ue_trial)
            ref_row += 1
        np.testing.assert_array_less(np.abs(diff_UE_rep), 0.3)
Пример #2
0
 def test__winpos(self):
     """Check the window start positions returned by ``ue._winpos``.

     15 ms windows moved in 3 ms steps from 10 ms with a 46 ms stop time
     must start at 10, 13, ..., 31 ms.
     """
     t_start = 10 * pq.ms
     t_stop = 46 * pq.ms
     winsize = 15 * pq.ms
     winstep = 3 * pq.ms
     expected = [10., 13., 16., 19., 22., 25., 28., 31.] * pq.ms
     # assert_allclose reports the mismatching elements on failure, unlike
     # assertTrue(np.allclose(...)); rtol/atol are np.allclose's defaults.
     np.testing.assert_allclose(
         ue._winpos(t_start, t_stop, winsize,
                    winstep).rescale('ms').magnitude,
         expected.rescale('ms').magnitude,
         rtol=1e-05, atol=1e-08)
 def test__winpos(self):
     """``ue._winpos`` yields one start time per 3 ms step from 10 ms.

     With 15 ms windows and a 46 ms stop time the last start is 31 ms.
     """
     step = 3 * pq.ms
     width = 15 * pq.ms
     starts = ue._winpos(10 * pq.ms, 46 * pq.ms, width, step)
     want = [10., 13., 16., 19., 22., 25., 28., 31.] * pq.ms
     got_ms = starts.rescale('ms').magnitude
     want_ms = want.rescale('ms').magnitude
     self.assertTrue(np.allclose(got_ms, want_ms))
    def test_Riehle_et_al_97_UE(self):
        """Reproduce the unitary events of figure 2 of Riehle et al. (1997).

        The reference spike data and the values extracted from the original
        figure are downloaded into a local temporary directory. The windowed
        joint-surprise analysis is run, only coincidences inside significant
        windows are kept, and their times must differ from the extracted
        unitary events by less than 0.3 in absolute value. The temporary
        directory is removed even when a step fails.
        """
        url = "http://raw.githubusercontent.com/ReScience-Archives/Rostami-" \
              "Ito-Denker-Gruen-2017/master/data"
        shortname = "unitary_event_analysis_test_data"
        local_test_dir = create_local_temp_dir(shortname)
        try:
            files_to_download = ["extracted_data.npy", "winny131_23.gdf"]
            # NOTE(review): TLS certificate verification is disabled and the
            # files are not checksummed, yet extracted_data.npy is loaded
            # with allow_pickle=True below; consider verifying a checksum.
            context = ssl._create_unverified_context()
            for filename in files_to_download:
                url_file = "{url}/{filename}".format(url=url,
                                                     filename=filename)
                dist = urlopen(url_file, context=context)
                try:
                    payload = dist.read()
                finally:
                    # Always release the HTTP connection.
                    dist.close()
                localfile = os.path.join(local_test_dir, filename)
                with open(localfile, 'wb') as f:
                    f.write(payload)

            # load spike data of figure 2 of Riehle et al 1997
            spiketrain = self.load_gdf2Neo(
                os.path.join(local_test_dir, "winny131_23.gdf"),
                trigger='RS_4',
                t_pre=1799 * pq.ms,
                t_post=300 * pq.ms)

            # load extracted data from figure 2 of Riehle et al 1997
            extracted_data = np.load(
                os.path.join(local_test_dir, 'extracted_data.npy'),
                encoding='latin1',
                allow_pickle=True).item()
        finally:
            # Clean up on success and on failure alike.
            shutil.rmtree(local_test_dir)

        # calculating UE ...
        winsize = 100 * pq.ms
        binsize = 5 * pq.ms
        winstep = 5 * pq.ms
        pattern_hash = [3]
        t_start = spiketrain[0][0].t_start
        t_stop = spiketrain[0][0].t_stop
        t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
        significance_level = 0.05

        UE = ue.jointJ_window_analysis(spiketrain,
                                       binsize,
                                       winsize,
                                       winstep,
                                       pattern_hash,
                                       method='analytic_TrialAverage')
        Js_sig = ue.jointJ(significance_level)
        sig_idx_win = np.where(UE['Js'] >= Js_sig)[0]
        diff_UE_rep = []
        y_cnt = 0  # next row of extracted_data['ue'] to compare against
        for trial_id in range(len(spiketrain)):
            trial_id_str = "trial{}".format(trial_id)
            indices_unique = np.unique(UE['indices'][trial_id_str])
            if len(indices_unique) > 0:
                # choose only the significant coincidences
                indices_unique_significant = []
                for j in sig_idx_win:
                    significant = indices_unique[np.where(
                        (indices_unique * binsize >= t_winpos[j])
                        & (indices_unique * binsize < t_winpos[j] + winsize))]
                    indices_unique_significant.extend(significant)
                x_tmp = np.unique(indices_unique_significant) * \
                    binsize.magnitude
                if len(x_tmp) > 0:
                    ue_trial = np.sort(extracted_data['ue'][y_cnt])
                    diff_UE_rep = np.append(diff_UE_rep, x_tmp - ue_trial)
                    y_cnt += 1
        np.testing.assert_array_less(np.abs(diff_UE_rep), 0.3)
Пример #5
0
def _plot_UE(data,
             Js_dict,
             sig_level,
             binsize,
             winsize,
             winstep,
             pattern_hash,
             N,
             args,
             add_epochs=None):
    """Plot a unitary-event analysis summary in a single figure.

    Panels, top to bottom:
    1. unitary events (coincidences inside significant windows, red squares);
    2. raw coincidences (cyan squares);
    3. average rates;
    4. empirical vs. expected coincidence counts;
    5. joint surprise;
    6. only when ``add_epochs`` is given, its coincidence and background
       rates.

    Parameters
    ----------
    data : sequence
        Trials; ``data[tr][n]`` is the spike train of neuron ``n`` in trial
        ``tr`` (needs ``t_start``, ``t_stop``, ``rescale`` and
        ``magnitude``).
    Js_dict : dict
        Window-analysis result; the keys read are ``'Js'``, ``'indices'``
        (with keys ``'trial<tr>'``), ``'rate_avg'``, ``'n_emp'`` and
        ``'n_exp'``.
    sig_level : float
        Significance level, converted to a surprise threshold with
        ``ue.jointJ``.
    binsize, winsize, winstep : quantities
        Bin width, window width and window step.
    pattern_hash : list
        Pattern hash, decoded with ``ue.inverse_hash_from_pattern`` for the
        suptitle.
    N : int
        Number of neurons.
    args : dict
        Plot options.

        Required keys: ``'events'`` (dict of label -> iterable of times),
        ``'figsize'``, ``'suptitle'`` and ``'save_fig'``. When saving,
        ``'path_filename_format'`` is also required.

        Optional keys and defaults:
        ``top`` (.90), ``bottom`` (.05), ``right`` (.95), ``left`` (.1),
        ``hspace`` (.5), ``wspace`` (.5), ``fontsize`` (20),
        ``unit_ids`` (1..N), ``ch_ids`` (none), ``showfig`` (False),
        ``linewidth`` (2), ``S_ylim`` ([-3, 3]) and ``marker_size`` (8).
        If ``set_xticks`` is present and false, x ticks are hidden on
        several panels.
    add_epochs : dict or None
        Dict with keys ``'coincrate'`` and ``'backgroundrate'`` drawn in a
        sixth panel. ``None`` or an empty list (the former default)
        disables that panel.

    Raises
    ------
    ValueError
        If ``unit_ids`` or ``ch_ids`` does not have ``N`` entries.

    Examples
    --------
    Example ``args``::

        dict_args = {'events': {'SO': [100 * pq.ms]},
                     'save_fig': True,
                     'path_filename_format': 'UE1.pdf',
                     'showfig': True,
                     'suptitle': True,
                     'figsize': (12, 10),
                     'unit_ids': [10, 19, 20],
                     'ch_ids': [1, 3, 4],
                     'fontsize': 15,
                     'linewidth': 2,
                     'set_xticks': False,
                     'marker_size': 8}
    """
    import matplotlib.pylab as plt
    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop

    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    Js_sig = ue.jointJ(sig_level)
    num_tr = len(data)
    pat = ue.inverse_hash_from_pattern(pattern_hash, N)
    events = args['events']

    # figure format
    figsize = args['figsize']
    top = args.get('top', .90)
    bottom = args.get('bottom', .05)
    right = args.get('right', .95)
    left = args.get('left', .1)
    hspace = args.get('hspace', .5)
    wspace = args.get('wspace', .5)
    fsize = args.get('fontsize', 20)
    if 'unit_ids' in args:
        unit_real_ids = args['unit_ids']
        if len(unit_real_ids) != N:
            raise ValueError(
                'length of unit_ids should be equal to number of neurons!')
    else:
        unit_real_ids = numpy.arange(1, N + 1, 1)
    if 'ch_ids' in args:
        ch_real_ids = args['ch_ids']
        if len(ch_real_ids) != N:
            raise ValueError(
                'length of ch_ids should be equal to number of neurons!')
    else:
        ch_real_ids = []
    showfig = args.get('showfig', False)
    lw = args.get('linewidth', 2)
    S_ylim = args.get('S_ylim', [-3, 3])
    ms = args.get('marker_size', 8)
    hide_xticks = 'set_xticks' in args and not args['set_xticks']

    if add_epochs:
        coincrate = add_epochs['coincrate']
        backgroundrate = add_epochs['backgroundrate']
        num_row = 6
    else:
        num_row = 5
    ls = '-'
    alpha = 0.5
    # Right x limit shared by all panels: end of the last analysis window.
    x_max = (max(t_winpos) + winsize).rescale('ms').magnitude
    # Windows whose surprise reaches the significance threshold.
    sig_idx_win = numpy.where(Js_dict['Js'] >= Js_sig)[0]
    # One y tick per neuron inside its block of trials; "//" keeps the
    # integer-division positions of the original Python 2 code.
    y_ticks_pos = numpy.arange(num_tr // 2 + 1, N * (num_tr + 1), num_tr + 1)

    plt.figure(1, figsize=figsize)
    if args['suptitle']:
        plt.suptitle("Spike Pattern:" + str((pat.T)[0]), fontsize=20)
    print('plotting UEs ...')
    plt.subplots_adjust(top=top,
                        right=right,
                        left=left,
                        bottom=bottom,
                        hspace=hspace,
                        wspace=wspace)
    ax = plt.subplot(num_row, 1, 1)
    ax.set_title('Unitary Events', fontsize=20, color='r')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            plt.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr + n *
                     (num_tr + 1) + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            if len(sig_idx_win) > 0:
                x = numpy.unique(Js_dict['indices']['trial' + str(tr)])
                if len(x) > 0:
                    # Keep coincidence bins lying in a significant window.
                    xx = []
                    for j in sig_idx_win:
                        xx = numpy.append(
                            xx, x[numpy.where((x * binsize >= t_winpos[j]) & (
                                x * binsize < t_winpos[j] + winsize))])
                    plt.plot(numpy.unique(xx) * binsize,
                             numpy.ones_like(numpy.unique(xx)) * tr + n *
                             (num_tr + 1) + 1,
                             ms=ms,
                             marker='s',
                             ls='',
                             mfc='none',
                             mec='r')
        # Separator after neuron n; relies on tr == num_tr - 1 here.
        plt.axhline((tr + 2) * (n + 1), lw=2, color='k')
    plt.yticks(y_ticks_pos)
    plt.gca().set_yticklabels(unit_real_ids, fontsize=fsize)
    for ch_cnt, ch_id in enumerate(ch_real_ids):
        print(ch_id)
        plt.gca().text(x_max,
                       y_ticks_pos[ch_cnt],
                       'CH-' + str(ch_id),
                       fontsize=fsize)

    plt.ylim(0, (tr + 2) * (n + 1) + 1)
    plt.xlim(0, x_max)
    plt.xticks([])
    plt.ylabel('Unit ID', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            plt.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
    if hide_xticks:
        plt.xticks([])

    print('plotting Raw Coincidences ...')
    ax1 = plt.subplot(num_row, 1, 2, sharex=ax)
    ax1.set_title('Raw Coincidences', fontsize=20, color='c')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            plt.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr + n *
                     (num_tr + 1) + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            coinc = numpy.unique(Js_dict['indices']['trial' + str(tr)])
            plt.plot(coinc * binsize,
                     numpy.ones_like(coinc) * tr + n * (num_tr + 1) + 1,
                     ls='',
                     ms=ms,
                     marker='s',
                     markerfacecolor='none',
                     markeredgecolor='c')
        plt.axhline((tr + 2) * (n + 1), lw=2, color='k')
    plt.ylim(0, (tr + 2) * (n + 1) + 1)
    plt.yticks(y_ticks_pos)
    plt.gca().set_yticklabels(unit_real_ids, fontsize=fsize)
    plt.xlim(0, x_max)
    plt.xticks([])
    plt.ylabel('Unit ID', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            plt.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)

    print('plotting PSTH ...')
    plt.subplot(num_row, 1, 3, sharex=ax)
    for n in range(N):
        plt.plot(t_winpos + winsize / 2.,
                 Js_dict['rate_avg'][:, n].rescale('Hz'),
                 label='unit ' + str(unit_real_ids[n]),
                 lw=lw)
    plt.ylabel('Rate [Hz]', fontsize=fsize)
    plt.xlim(0, x_max)
    max_val_psth = plt.gca().get_ylim()[1]
    plt.ylim(0, max_val_psth)
    plt.yticks([0, int(max_val_psth / 2), int(max_val_psth)], fontsize=fsize)
    plt.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    for key in events.keys():
        for e_val in events[key]:
            plt.axvline(e_val, ls=ls, color='r', lw=lw, alpha=alpha)
    if hide_xticks:
        plt.xticks([])

    print('plotting emp. and exp. coincidences rate ...')
    plt.subplot(num_row, 1, 4, sharex=ax)
    plt.plot(t_winpos + winsize / 2.,
             Js_dict['n_emp'],
             label='empirical',
             lw=lw,
             color='c')
    plt.plot(t_winpos + winsize / 2.,
             Js_dict['n_exp'],
             label='expected',
             lw=lw,
             color='m')
    plt.xlim(0, x_max)
    plt.ylabel('# Coinc.', fontsize=fsize)
    plt.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    YTicks = plt.ylim(0, int(max(max(Js_dict['n_emp']),
                                 max(Js_dict['n_exp']))))
    plt.yticks([0, YTicks[1]], fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            plt.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
    if hide_xticks:
        plt.xticks([])

    print('plotting Surprise ...')
    plt.subplot(num_row, 1, 5, sharex=ax)
    plt.plot(t_winpos + winsize / 2., Js_dict['Js'], lw=lw, color='k')
    plt.xlim(0, x_max)
    plt.axhline(Js_sig, ls='-', color='gray')
    plt.axhline(-Js_sig, ls='-', color='gray')
    plt.gca().text(10,
                   Js_sig + 0.2,
                   str(int(sig_level * 100)) + '%',
                   fontsize=fsize - 2,
                   color='gray')
    # About ten x ticks; the slice step is at least 1 so that fewer than
    # ten windows do not raise "slice step cannot be zero".
    plt.xticks(t_winpos.magnitude[::max(1, len(t_winpos) // 10)])
    plt.yticks([-2, 0, 2], fontsize=fsize)
    plt.ylabel('S', fontsize=fsize)
    plt.xlabel('Time [ms]', fontsize=fsize)
    plt.ylim(S_ylim)
    for key in events.keys():
        for e_val in events[key]:
            plt.axvline(e_val, ls=ls, color='r', lw=lw, alpha=alpha)
            plt.gca().text(e_val - 10 * pq.ms,
                           2 * S_ylim[0],
                           key,
                           fontsize=fsize,
                           color='r')
    if hide_xticks:
        plt.xticks([])

    if add_epochs:
        plt.subplot(num_row, 1, 6, sharex=ax)
        plt.plot(coincrate, lw=lw, color='c')
        plt.plot(backgroundrate, lw=lw, color='m')
        plt.xlim(0, x_max)
        plt.ylim(plt.gca().get_ylim()[0] - 2, plt.gca().get_ylim()[1] + 2)
    if args['save_fig']:
        plt.savefig(args['path_filename_format'])
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
Пример #6
0
def plot_figure3(data_list, Js_dict_lst_lst, sig_level, binsize, winsize,
                 winstep, patterns, N, N_comb, plot_params_user):
    """Plot Figure 3 of the manuscript.

    The figure is a 3 x 5 grid. Its rows show spike events, coincident
    events (cyan squares) and unitary events (red squares). Dataset ``k``
    of ``data_list`` is drawn from grid column ``3 * k`` with a column span
    of ``3 - k``, and its events come from ``[events1, events2][k]``, so
    the layout expects at most two datasets.

    Parameters
    ----------
    data_list : list
        Datasets; in each, ``data[tr][n]`` is the spike train of neuron
        ``n`` in trial ``tr``.
    Js_dict_lst_lst : list
        ``Js_dict_lst_lst[k][c]`` is the window-analysis result for dataset
        ``k`` and neuron combination ``N_comb[c]``. The keys read are
        ``'indices'`` (with keys ``'trial<tr>'``) and ``'Js'``.
    sig_level : float
        Significance level, converted to a surprise threshold with
        ``ue.jointJ``.
    binsize, winsize, winstep : quantities
        Bin width, window width and window step.
    patterns : list
        ``patterns[c]`` is the pattern hash analysed for ``N_comb[c]``.
    N : int
        Number of neurons; must equal ``len(unit_real_ids)``.
    N_comb : list of lists
        Neuron-index combinations. Raster dots are drawn only for
        ``N_comb[0]``.
    plot_params_user : dict
        Overrides merged over a copy of ``plot_params_default``.

    Raises
    ------
    ValueError
        If ``len(unit_real_ids) != N``.

    Notes
    -----
    The merged plot parameters are written into this module's globals,
    because the body reads them as bare names (for example ``fsize``,
    ``figsize``, ``top``, ``events1``, ``events2``, ``unit_real_ids``,
    ``ms``, ``save_fig``, ``path_filename_format`` and ``showfig``).
    """
    import matplotlib.pylab as plt
    # figure format; copy the defaults so one call's user overrides are not
    # written back into the shared plot_params_default dict.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    globals().update(plot_params)
    if len(unit_real_ids) != N:
        raise ValueError(
            'length of unit_ids should be equal to number of neurons!')
    plt.rcParams.update({'font.size': fsize})
    plt.rc('legend', fontsize=fsize)
    ms1 = 0.5  # marker size of the raster dots
    plt.figure(1, figsize=figsize)
    plt.subplots_adjust(top=top,
                        right=right,
                        left=left,
                        bottom=bottom,
                        hspace=hspace,
                        wspace=wspace)
    # One event dict per dataset.
    events = [events1, events2]
    num_row, num_col = 3, 5
    Js_sig = ue.jointJ(sig_level)
    for data_cnt, data in enumerate(data_list):
        Js_dict_lst = Js_dict_lst_lst[data_cnt]
        t_start = data[0][0].t_start
        t_stop = data[0][0].t_stop
        t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
        num_tr = len(data)

        # Row 0: spike raster.
        ax0 = plt.subplot2grid((num_row, num_col), (0, data_cnt * 3),
                               rowspan=1,
                               colspan=3 - data_cnt)
        ax0.set_xticks([])
        if data_cnt == 0:
            ax0.set_title('Spike Events', fontsize=fsize)
            ax0.set_yticks([num_tr / 2 + 1, num_tr * 3 / 2., num_tr * 5 / 2])
            ax0.set_yticklabels(['5', '4', '3'], fontsize=fsize)
            ax0.spines['right'].set_visible(False)
            ax0.yaxis.set_ticks_position('left')
            ax0.set_ylabel('Neuron #', fontsize=fsize)
        else:
            ax0.spines['left'].set_visible(False)
            ax0.yaxis.set_ticks_position('right')
            ax0.set_yticks([])
        for n in range(N):
            for tr, data_tr in enumerate(data):
                ax0.plot(data_tr[n].rescale('ms').magnitude,
                         numpy.ones_like(data_tr[n].magnitude) * tr + n *
                         (num_tr + 1) + 1,
                         '.',
                         markersize=ms1,
                         color='k')
            if n < N - 1:
                ax0.axhline((tr + 2) * (n + 1), lw=0.5, color='k')
        ax0.set_ylim(0, (tr + 2) * (n + 1) + 1)
        # NOTE(review): each event value is passed to axvline directly here
        # but iterated as a collection in the text loop at the end; confirm
        # the expected event format.
        for ev in events[data_cnt].keys():
            ax0.axvline(events[data_cnt][ev], color='k')

        # Row 1: all coincidences of each analysed pattern.
        ax1 = plt.subplot2grid((num_row, num_col), (1, data_cnt * 3),
                               rowspan=1,
                               colspan=3 - data_cnt)
        ax1.set_xticks([])
        if data_cnt == 0:
            ax1.set_title('Coincident Events', fontsize=fsize)
            ax1.text(-110, 190, 'Neuron #', rotation=90, fontsize=fsize)
            ax1.set_yticks([num_tr / 2 + 1, num_tr * 3 / 2., num_tr * 5 / 2])
            ax1.set_yticklabels(['5', '4', '3'], fontsize=fsize)
            ax1.spines['right'].set_visible(False)
            ax1.yaxis.set_ticks_position('left')
        else:
            ax1.spines['left'].set_visible(False)
            ax1.yaxis.set_ticks_position('right')
            ax1.set_yticks([])
        for N_cnt, N_comb_sel in enumerate(N_comb):
            patt_tmp = ue.inverse_hash_from_pattern([patterns[N_cnt]],
                                                    len(N_comb_sel))
            for n_cnt, n in enumerate(N_comb_sel):
                for tr, data_tr in enumerate(data):
                    if N_cnt == 0:
                        ax1.plot(data_tr[n].rescale('ms').magnitude,
                                 numpy.ones_like(data_tr[n].magnitude) * tr +
                                 n * (num_tr + 1) + 1,
                                 '.',
                                 markersize=ms1,
                                 color='k')
                    # Mark coincidences only on neurons in the pattern.
                    if patt_tmp[n_cnt][0] == 1:
                        coinc = numpy.unique(
                            Js_dict_lst[N_cnt]['indices']['trial' + str(tr)])
                        ax1.plot(coinc * binsize,
                                 numpy.ones_like(coinc) * tr +
                                 n * (num_tr + 1) + 1,
                                 ls='',
                                 ms=ms,
                                 marker='s',
                                 markerfacecolor='none',
                                 markeredgecolor='c')
                if N_comb_sel == [0, 1, 2] and n < len(N_comb_sel) - 1:
                    ax1.axhline((tr + 2) * (n + 1), lw=0.5, color='k')
        for ev in events[data_cnt].keys():
            ax1.axvline(events[data_cnt][ev], color='k')

        # Row 2: coincidences inside significant windows (unitary events).
        ax2 = plt.subplot2grid((num_row, num_col), (2, data_cnt * 3),
                               rowspan=1,
                               colspan=3 - data_cnt)
        ax2.set_xticks([])
        if data_cnt == 0:
            ax2.set_title('Unitary Events', fontsize=fsize)
            ax2.text(-110, 190, 'Neuron #', rotation=90, fontsize=fsize)
            ax2.set_yticks([num_tr / 2 + 1, num_tr * 3 / 2., num_tr * 5 / 2])
            ax2.set_yticklabels(['5', '4', '3'], fontsize=fsize)
            ax2.spines['right'].set_visible(False)
            ax2.yaxis.set_ticks_position('left')
        else:
            ax2.spines['left'].set_visible(False)
            ax2.yaxis.set_ticks_position('right')
            ax2.set_yticks([])
        for N_cnt, N_comb_sel in enumerate(N_comb):
            patt_tmp = ue.inverse_hash_from_pattern([patterns[N_cnt]],
                                                    len(N_comb_sel))
            for n_cnt, n in enumerate(N_comb_sel):
                for tr, data_tr in enumerate(data):
                    if N_cnt == 0:
                        ax2.plot(data_tr[n].rescale('ms').magnitude,
                                 numpy.ones_like(data_tr[n].magnitude) * tr +
                                 n * (num_tr + 1) + 1,
                                 '.',
                                 markersize=ms1,
                                 color='k')
                    if patt_tmp[n_cnt][0] == 1:
                        sig_idx_win = numpy.where(
                            Js_dict_lst[N_cnt]['Js'] >= Js_sig)[0]
                        if len(sig_idx_win) > 0:
                            x = numpy.unique(
                                Js_dict_lst[N_cnt]['indices']['trial' +
                                                              str(tr)])
                            if len(x) > 0:
                                # Keep bins lying in a significant window.
                                xx = []
                                for j in sig_idx_win:
                                    xx = numpy.append(
                                        xx, x[numpy.where(
                                            (x * binsize >= t_winpos[j])
                                            & (x * binsize < t_winpos[j] +
                                               winsize))])
                                ax2.plot(
                                    numpy.unique(xx) * binsize,
                                    numpy.ones_like(numpy.unique(xx)) * tr +
                                    n * (num_tr + 1) + 1,
                                    ms=ms,
                                    marker='s',
                                    ls='',
                                    mfc='none',
                                    mec='r')
                if N_comb_sel == [0, 1, 2] and n < len(N_comb_sel) - 1:
                    ax2.axhline((tr + 2) * (n + 1), lw=0.5, color='k')
        for ev in events[data_cnt].keys():
            ax2.axvline(events[data_cnt][ev], color='k')
        for key in events[data_cnt].keys():
            for e_val in events[data_cnt][key]:
                ax2.text(e_val - 10 * pq.ms,
                         ax2.get_ylim()[0] - 80,
                         key,
                         fontsize=fsize,
                         color='k')

    if save_fig:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
Пример #7
0
def plot_figure1_2(data, Js_dict, sig_level, binsize, winsize, winstep,
                   pattern_hash, N, plot_params_user):
    """Plot Figure 1 and Figure 2 of the manuscript.

    Six stacked panels, labelled A-F:
    A. spike raster;
    B. average spike rates;
    C. coincident events (cyan squares);
    D. empirical vs. expected coincidence rates;
    E. statistical significance (joint surprise);
    F. unitary events (red squares).

    Parameters
    ----------
    data : sequence
        Trials; ``data[tr][n]`` is the spike train of neuron ``n`` in trial
        ``tr``.
    Js_dict : dict
        Window-analysis result. The keys read are ``'Js'``, ``'indices'``
        (with keys ``'trial<tr>'``), ``'rate_avg'``, ``'n_emp'`` and
        ``'n_exp'``.
    sig_level : float
        Significance level, converted to a surprise threshold with
        ``ue.jointJ``.
    binsize, winsize, winstep : quantities
        Bin width, window width and window step.
    pattern_hash : list
        Not used by this function; kept for interface compatibility.
    N : int
        Number of neurons; must equal ``len(unit_real_ids)``.
    plot_params_user : dict
        Overrides merged over a copy of ``plot_params_default``.

    Raises
    ------
    ValueError
        If ``len(unit_real_ids) != N``.

    Notes
    -----
    The merged plot parameters are written into this module's globals,
    because the body reads them as bare names (for example ``fsize``,
    ``figsize``, ``events``, ``lw``, ``ms``, ``S_ylim``, ``unit_real_ids``,
    ``showfig`` and ``path_filename_format``).

    Some values are hard-coded and appear tailored to one data set: the
    trial tick labels (1, 15, 30), the text placed at ``t_winpos[30]`` and
    the rate limit of 40.
    """
    import matplotlib.pylab as plt
    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop

    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    Js_sig = ue.jointJ(sig_level)
    num_tr = len(data)

    # figure format; copy the defaults so one call's user overrides are not
    # written back into the shared plot_params_default dict.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    globals().update(plot_params)
    if len(unit_real_ids) != N:
        raise ValueError('length of unit_ids should be ' +
                         'equal to number of neurons!')
    plt.rcParams.update({'font.size': fsize})
    plt.rc('legend', fontsize=fsize)

    num_row = 6
    ls = '-'
    alpha = 0.5
    # Right x limit shared by all panels: end of the last analysis window.
    x_max = (max(t_winpos) + winsize).rescale('ms').magnitude
    plt.figure(1, figsize=figsize)
    if 'suptitle' in plot_params.keys():
        plt.suptitle("Trial aligned on " + plot_params['suptitle'],
                     fontsize=20)
    plt.subplots_adjust(top=top,
                        right=right,
                        left=left,
                        bottom=bottom,
                        hspace=hspace,
                        wspace=wspace)

    print('plotting raster plot ...')
    ax0 = plt.subplot(num_row, 1, 1)
    ax0.set_title('Spike Events')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax0.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr + n *
                     (num_tr + 1) + 1,
                     '.',
                     markersize=0.5,
                     color='k')
        if n < N - 1:
            # Separator after neuron n; relies on tr == num_tr - 1 here.
            ax0.axhline((tr + 2) * (n + 1), lw=2, color='k')
    ax0.set_ylim(0, (tr + 2) * (n + 1) + 1)
    ax0.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
    ax0.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax0.set_xlim(0, x_max)
    ax0.set_xticks([])
    ax0.set_ylabel('Trial', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            ax0.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
    Xlim = ax0.get_xlim()
    ax0.text(Xlim[1] - 200, num_tr * 2 + 7, 'Neuron 2')
    ax0.text(Xlim[1] - 200, -12, 'Neuron 3')

    print('plotting Spike Rates ...')
    ax1 = plt.subplot(num_row, 1, 2, sharex=ax0)
    ax1.set_title('Spike Rates')
    for n in range(N):
        ax1.plot(t_winpos + winsize / 2.,
                 Js_dict['rate_avg'][:, n].rescale('Hz'),
                 label='Neuron ' + str(unit_real_ids[n]),
                 lw=lw)
    ax1.set_ylabel('(1/s)', fontsize=fsize)
    ax1.set_xlim(0, x_max)
    max_val_psth = 40
    ax1.set_ylim(0, max_val_psth)
    ax1.set_yticks([0, int(max_val_psth / 2), int(max_val_psth)])
    ax1.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    for key in events.keys():
        for e_val in events[key]:
            ax1.axvline(e_val, ls=ls, color='r', lw=lw, alpha=alpha)
    ax1.set_xticks([])

    print('plotting Raw Coincidences ...')
    ax2 = plt.subplot(num_row, 1, 3, sharex=ax0)
    ax2.set_title('Coincident Events')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax2.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr + n *
                     (num_tr + 1) + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            coinc = numpy.unique(Js_dict['indices']['trial' + str(tr)])
            ax2.plot(coinc * binsize,
                     numpy.ones_like(coinc) * tr + n * (num_tr + 1) + 1,
                     ls='',
                     ms=ms,
                     marker='s',
                     markerfacecolor='none',
                     markeredgecolor='c')
        if n < N - 1:
            ax2.axhline((tr + 2) * (n + 1), lw=2, color='k')
    ax2.set_ylim(0, (tr + 2) * (n + 1) + 1)
    ax2.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
    ax2.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax2.set_xlim(0, x_max)
    ax2.set_xticks([])
    ax2.set_ylabel('Trial', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            ax2.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)

    print('plotting emp. and exp. coincidences rate ...')
    ax3 = plt.subplot(num_row, 1, 4, sharex=ax0)
    ax3.set_title('Coincidence Rates')
    # Counts divided by (window length in s * number of trials).
    ax3.plot(t_winpos + winsize / 2.,
             Js_dict['n_emp'] / (winsize.rescale('s').magnitude * num_tr),
             label='empirical',
             lw=lw,
             color='c')
    ax3.plot(t_winpos + winsize / 2.,
             Js_dict['n_exp'] / (winsize.rescale('s').magnitude * num_tr),
             label='expected',
             lw=lw,
             color='m')
    ax3.set_xlim(0, x_max)
    ax3.set_ylabel('(1/s)', fontsize=fsize)
    ax3.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    YTicks = ax3.get_ylim()
    ax3.set_yticks([0, YTicks[1] / 2, YTicks[1]])
    for key in events.keys():
        for e_val in events[key]:
            ax3.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
    ax3.set_xticks([])

    print('plotting Surprise ...')
    ax4 = plt.subplot(num_row, 1, 5, sharex=ax0)
    ax4.set_title('Statistical Significance')
    ax4.plot(t_winpos + winsize / 2., Js_dict['Js'], lw=lw, color='k')
    ax4.set_xlim(0, x_max)
    ax4.axhline(Js_sig, ls='-', color='r')
    ax4.axhline(-Js_sig, ls='-', color='g')
    ax4.text(t_winpos[30], Js_sig + 0.3, '$\\alpha +$', color='r')
    ax4.text(t_winpos[30], -Js_sig - 0.5, '$\\alpha -$', color='g')
    # Step of at least 1 so fewer than ten windows do not raise.
    ax4.set_xticks(t_winpos.magnitude[::max(1, len(t_winpos) // 10)])
    # Surprise axis labelled with the corresponding probabilities.
    ax4.set_yticks([ue.jointJ(0.99), ue.jointJ(0.5), ue.jointJ(0.01)])
    ax4.set_yticklabels([0.99, 0.5, 0.01])

    ax4.set_ylim(S_ylim)
    for key in events.keys():
        for e_val in events[key]:
            ax4.axvline(e_val, ls=ls, color='r', lw=lw, alpha=alpha)
    ax4.set_xticks([])

    print('plotting UEs ...')
    ax5 = plt.subplot(num_row, 1, 6, sharex=ax0)
    ax5.set_title('Unitary Events')
    # Windows whose surprise reaches the significance threshold.
    sig_idx_win = numpy.where(Js_dict['Js'] >= Js_sig)[0]
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax5.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr + n *
                     (num_tr + 1) + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            if len(sig_idx_win) > 0:
                x = numpy.unique(Js_dict['indices']['trial' + str(tr)])
                if len(x) > 0:
                    # Keep coincidence bins lying in a significant window.
                    xx = []
                    for j in sig_idx_win:
                        xx = numpy.append(
                            xx, x[numpy.where((x * binsize >= t_winpos[j]) & (
                                x * binsize < t_winpos[j] + winsize))])
                    ax5.plot(numpy.unique(xx) * binsize,
                             numpy.ones_like(numpy.unique(xx)) * tr + n *
                             (num_tr + 1) + 1,
                             ms=ms,
                             marker='s',
                             ls='',
                             mfc='none',
                             mec='r')
        if n < N - 1:
            ax5.axhline((tr + 2) * (n + 1), lw=2, color='k')
    ax5.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
    ax5.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax5.set_ylim(0, (tr + 2) * (n + 1) + 1)
    ax5.set_xlim(0, x_max)
    ax5.set_xticks([])
    ax5.set_ylabel('Trial', fontsize=fsize)
    ax5.set_xlabel('Time [ms]', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            ax5.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
            # NOTE(review): y position is derived from the surprise-axis
            # limits (S_ylim) although this is the raster panel.
            ax5.text(e_val - 10 * pq.ms,
                     S_ylim[0] - 35,
                     key,
                     fontsize=fsize,
                     color='r')
    ax5.set_xticks([])

    # Panel letters A-F, one per axes in top-to-bottom order.
    for i, ax in enumerate((ax0, ax1, ax2, ax3, ax4, ax5)):
        ax.text(-0.05,
                1.1,
                string.ascii_uppercase[i],
                transform=ax.transAxes,
                size=fsize + 5,
                weight='bold')
    if plot_params['save_fig']:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
Пример #8
0
def _match_nearest_spikes(reproduced_trains, original_trains, offset=0):
    """Pair spikes of two data sets trial by trial with their nearest neighbour.

    For every trial, the data set with fewer spikes is iterated and each of
    its spikes is matched to the closest spike of the other data set.

    Parameters
    ----------
    reproduced_trains : sequence of array-like
        Spike times (plain magnitudes) of the reproduced data, one entry per
        spike train.
    original_trains : sequence of array-like
        Spike times extracted from the original article; entry ``k`` belongs
        to ``reproduced_trains[k]``.
    offset : float, optional
        Constant added to every original spike time before matching.

    Returns
    -------
    reproduced, original : numpy.ndarray
        Matched spike times; ``reproduced[k]`` is paired with ``original[k]``.
    """
    reproduced_matched = []
    original_matched = []
    for idx, reproduced in enumerate(reproduced_trains):
        reproduced = numpy.asarray(reproduced)
        original = numpy.asarray(original_trains[idx]) + offset
        if len(reproduced) >= len(original):
            for sp_t in original:
                nearest = numpy.argmin(numpy.abs(reproduced - sp_t))
                reproduced_matched.append(reproduced[nearest])
                original_matched.append(sp_t)
        else:
            for sp_t in reproduced:
                nearest = numpy.argmin(numpy.abs(original - sp_t))
                reproduced_matched.append(sp_t)
                original_matched.append(original[nearest])
    return numpy.asarray(reproduced_matched), numpy.asarray(original_matched)


def plot_figure3(sts_lst, UE, extracted_data, significance_level,
                 binsize, winsize, winstep, plot_params_user):
    """Plot Figure 3 of the manuscript.

    Compares the reproduced data with values extracted from figure 2 of the
    original article:

    A. scatter of matched spike times (original vs. reproduced),
    B. scatter of unitary-event times (original vs. reproduced),
    C. histograms of spike-time differences for ``sts_lst[0]`` (black) and
       for ``sts_lst[1]`` (gray, extracted values shifted by 6 ms),
    D. normalized histogram of unitary-event time differences.

    Parameters
    ----------
    sts_lst : sequence
        Two spike-train collections, each indexable as ``[trial][unit]`` and
        supporting ``.T`` (presumably 2-D object arrays of neo.SpikeTrain).
    UE : dict
        Unitary event analysis result; keys 'Js' and 'indices' are read.
    extracted_data : dict
        Values extracted from the article; keys 'spikes' and 'ue' are read.
    significance_level : float
        Converted to a joint-surprise threshold with ``ue.jointJ``.
    binsize, winsize, winstep : quantities.Quantity
        Parameters used for the unitary event analysis.
    plot_params_user : dict
        Overrides for ``plot_params_default``. The merged parameters are also
        published as module globals, from where ``figsize``, ``save_fig``,
        ``showfig`` and ``path_filename_format`` are read.
    """
    import matplotlib.pylab as plt

    # Copy the defaults so user parameters of this call do not leak into
    # later calls through plot_params_default.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    globals().update(plot_params)

    t_start = sts_lst[0][0][0].t_start
    t_stop = sts_lst[0][0][0].t_stop
    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
    f.set_size_inches(figsize)
    plt.subplots_adjust(top=.9, right=.92, left=.12,
                        bottom=.1, hspace=.4, wspace=.4)
    fsize = 12

    # NOTE(review): all differences below are reproduced - original, while
    # the axis labels read "original - reproduced"; confirm the convention.

    # scatter plot of panel A of figure 2 of the original publication
    ax = ax1
    ax.locator_params(nbins=4, axis='x')
    ax.locator_params(nbins=4, axis='y')
    spiketrain = sts_lst[0]
    reproduced_trains = [train.magnitude
                         for unit in spiketrain.T for train in unit]
    reproduced, original = _match_nearest_spikes(
        reproduced_trains, extracted_data['spikes'])
    # one call for all points; x = original, y = reproduced as labelled
    ax.plot(original, reproduced, ',', color='k')
    diff_spikes_rep = reproduced - original
    ax.set_ylabel('spike time [ms] (reproduced)')
    ax.set_xlabel('spike time [ms] (original article)')
    ax.text(-0.1, 1.07, 'A', transform=ax.transAxes, size=fsize + 5,
            weight='bold')

    # scatter plot of UEs of panel E of figure 2
    ax = ax2
    ax.locator_params(nbins=5, axis='x')
    ax.locator_params(nbins=5, axis='y')
    Js_sig = ue.jointJ(significance_level)
    sig_idx_win = numpy.where(UE['Js'] >= Js_sig)[0]
    diff_UE_rep = []
    y_cnt = 0
    for tr in range(len(spiketrain)):
        trial_indices = UE['indices']['trial' + str(tr)]
        # unique coincidence bins, kept in order of first occurrence
        first_seen = numpy.sort(
            numpy.unique(trial_indices, return_index=True)[1])
        x = trial_indices[first_seen]
        if len(x) == 0:
            continue
        # keep only the coincidences that fall into significant windows
        selected = [x[numpy.where((x * binsize >= t_winpos[j]) &
                                  (x * binsize < t_winpos[j] + winsize))]
                    for j in sig_idx_win]
        selected = numpy.concatenate(selected) if selected else numpy.array([])
        x_tmp = numpy.unique(selected) * binsize.magnitude
        if len(x_tmp) > 0:
            ue_trial = numpy.sort(extracted_data['ue'][y_cnt])
            diff_UE_rep.append(x_tmp - ue_trial)
            y_cnt += 1
            ax.plot(ue_trial, x_tmp, 'o', color='r')
    diff_UE_rep = (numpy.concatenate(diff_UE_rep) if diff_UE_rep
                   else numpy.array([]))
    ax.text(-0.1, 1.07, 'B', transform=ax.transAxes, size=fsize + 5,
            weight='bold')
    ax.set_ylabel('time of UEs [ms] (reporduced)')
    ax.set_xlabel('time of UEs [ms] (original article)')

    # histogram of spike time differences between extracted and available data
    ax = ax3
    ax.locator_params(nbins=6, axis='x')
    ax.locator_params(nbins=6, axis='y')
    binwidth = 0.05
    ax.set_adjustable('box')
    ax.hist(diff_spikes_rep, numpy.arange(-0.5, 0.5, binwidth),
            histtype='step', color='k', label="spikes [ms]")
    ax.set_ylabel('count spikes')
    ax.set_xlim(-1.5, 1.5)
    ax.set_xlabel('difference [ms] (original - reproduced)')
    ax.text(-0.1, 1.07, 'C', transform=ax.transAxes, size=fsize + 5,
            weight='bold')

    # Plot the unreproduced distribution. The extracted data differ by 6 ms
    # when aligned on PS_4 instead of RS_4, because the trial length between
    # PS and RS is 1505 ms, not 1500 ms as stated in the original paper.
    # Therefore 6 ms are added to the extracted values before comparing.
    spiketrain = sts_lst[1]
    reproduced_trains = [train.magnitude
                         for unit in spiketrain.T for train in unit]
    reproduced, original = _match_nearest_spikes(
        reproduced_trains, extracted_data['spikes'], offset=6)
    # same sign convention in both matching directions
    diff_spikes = reproduced - original
    ax.hist(diff_spikes, numpy.arange(-1.5, 1.5, binwidth),
            histtype='step', color='gray', label="spikes [ms]")

    ax = ax4
    ax.locator_params(nbins=6, axis='x')
    ax.locator_params(nbins=6, axis='y')
    ax.set_adjustable('box')
    # `normed` was removed from matplotlib; `density` is its replacement
    ax.hist(diff_UE_rep, numpy.arange(-1.5, 1.5, 2 * binwidth),
            histtype='step', density=True, color='r', label="UEs [ms]")
    ax.set_ylim(0, 5)
    ax.set_ylabel('count UEs')
    ax.text(-0.1, 1.07, 'D', transform=ax.transAxes, size=fsize + 5,
            weight='bold')
    ax.set_xlabel('difference [ms] (original - reproduced)')
    ax.set_xlim(-1.5, 1.5)

    if save_fig:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
Пример #9
0
def plot_UE(data, Js_dict, sig_level, binsize, winsize, winstep,
            pattern_hash, N, plot_params_user):
    """Plot Figure 1 and Figure 2 of the manuscript.

    Draws six stacked panels:
    - spike raster;
    - trial-averaged spike rates;
    - raster with coincident events;
    - empirical vs. expected coincidence rates;
    - joint surprise over time;
    - raster with unitary events.

    Parameters
    ----------
    data : list of list of neo.SpikeTrain
        ``data[tr][n]`` is the spike train of neuron ``n`` in trial ``tr``.
    Js_dict : dict
        Unitary event analysis result; keys 'Js', 'indices', 'n_emp',
        'n_exp' and 'rate_avg' are read.
    sig_level : float
        Converted to a joint-surprise threshold with ``ue.jointJ``.
    binsize, winsize, winstep : quantities.Quantity
        Parameters used for the unitary event analysis.
    pattern_hash : list of int
        Hash of the analysed pattern.
    N : int
        Number of neurons; must equal ``len(unit_real_ids)``.
    plot_params_user : dict
        Overrides for ``plot_params_default``. The merged parameters are also
        published as module globals, from where this function reads them
        (figsize, fsize, lw, ms, events, S_ylim, unit_real_ids, ...).

    Raises
    ------
    ValueError
        If ``len(unit_real_ids)`` differs from ``N``.
    """
    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop

    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    Js_sig = ue.jointJ(sig_level)
    num_tr = len(data)
    # NOTE(review): the decoded pattern is unused below; the call is kept in
    # case it validates pattern_hash against N.
    ue.inverse_hash_from_pattern(pattern_hash, N)

    # Copy the defaults so user parameters of this call do not leak into
    # later calls through plot_params_default.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    globals().update(plot_params)
    if len(unit_real_ids) != N:
        raise ValueError('length of unit_ids should be '
                         'equal to number of neurons!')
    plt.rcParams.update({'font.size': fsize})
    plt.rc('legend', fontsize=fsize)

    num_row = 6
    ls = '-'
    alpha = 0.5
    x_max = (max(t_winpos) + winsize).rescale('ms').magnitude
    # one row per trial plus one separator row per neuron
    y_max = (num_tr + 1) * N + 1
    # windows whose joint surprise exceeds the threshold (trial independent)
    sig_idx_win = numpy.where(Js_dict['Js'] >= Js_sig)[0]

    def _raster(ax, mark_trial=None):
        """Plot all spikes as dots; optionally decorate each trial row."""
        for n in range(N):
            for tr, data_tr in enumerate(data):
                row = tr + n * (num_tr + 1) + 1
                ax.plot(data_tr[n].rescale('ms').magnitude,
                        numpy.ones_like(data_tr[n].magnitude) * row,
                        '.', markersize=0.5, color='k')
                if mark_trial is not None:
                    mark_trial(ax, tr, row)
            if n < N - 1:
                # separator between the trial blocks of two neurons
                ax.axhline((num_tr + 1) * (n + 1), lw=2, color='k')

    def _format_raster(ax):
        """Apply the common axis layout of the raster panels."""
        ax.set_ylim(0, y_max)
        ax.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
        ax.set_yticklabels([1, 15, 30], fontsize=fsize)
        ax.set_xlim(0, x_max)
        ax.set_xticks([])
        ax.set_ylabel('Trial', fontsize=fsize)

    def _mark_events(ax, lw_event=2):
        """Draw a vertical line at every event time."""
        for key in events.keys():
            for e_val in events[key]:
                ax.axvline(e_val, ls=ls, color='r', lw=lw_event, alpha=alpha)

    def _mark_coincidences(ax, tr, row):
        """Mark all coincident events of trial ``tr``."""
        coinc = numpy.unique(Js_dict['indices']['trial' + str(tr)])
        ax.plot(coinc * binsize, numpy.ones_like(coinc) * row,
                ls='', ms=ms, marker='s', markerfacecolor='none',
                markeredgecolor='c')

    def _mark_unitary_events(ax, tr, row):
        """Mark coincidences of trial ``tr`` inside significant windows."""
        if len(sig_idx_win) == 0:
            return
        coinc = numpy.unique(Js_dict['indices']['trial' + str(tr)])
        if len(coinc) == 0:
            return
        in_windows = [coinc[numpy.where(
            (coinc * binsize >= t_winpos[j]) &
            (coinc * binsize < t_winpos[j] + winsize))]
            for j in sig_idx_win]
        ue_bins = numpy.unique(numpy.concatenate(in_windows))
        ax.plot(ue_bins * binsize, numpy.ones_like(ue_bins) * row,
                ms=ms, marker='s', ls='', mfc='none', mec='r')

    plt.figure(1, figsize=figsize)
    if 'suptitle' in plot_params.keys():
        plt.suptitle("Trial aligned on " +
                     plot_params['suptitle'], fontsize=20)
    plt.subplots_adjust(top=top, right=right, left=left,
                        bottom=bottom, hspace=hspace, wspace=wspace)

    print('plotting raster plot ...')
    ax0 = plt.subplot(num_row, 1, 1)
    ax0.set_title('Spike Events')
    _raster(ax0)
    _format_raster(ax0)
    _mark_events(ax0)
    Xlim = ax0.get_xlim()
    ax0.text(Xlim[1] - 200, num_tr * 2 + 7, 'Neuron 2')
    ax0.text(Xlim[1] - 200, -12, 'Neuron 3')

    print('plotting Spike Rates ...')
    ax1 = plt.subplot(num_row, 1, 2, sharex=ax0)
    ax1.set_title('Spike Rates')
    for n in range(N):
        ax1.plot(t_winpos + winsize / 2.,
                 Js_dict['rate_avg'][:, n].rescale('Hz'),
                 label='Neuron ' + str(unit_real_ids[n]), lw=lw)
    ax1.set_ylabel('(1/s)', fontsize=fsize)
    ax1.set_xlim(0, x_max)
    max_val_psth = 40
    ax1.set_ylim(0, max_val_psth)
    ax1.set_yticks([0, int(max_val_psth / 2), int(max_val_psth)])
    ax1.legend(
        bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    _mark_events(ax1, lw)
    ax1.set_xticks([])

    print('plotting Raw Coincidences ...')
    ax2 = plt.subplot(num_row, 1, 3, sharex=ax0)
    ax2.set_title('Coincident Events')
    _raster(ax2, _mark_coincidences)
    _format_raster(ax2)
    _mark_events(ax2)

    print('plotting emp. and exp. coincidences rate ...')
    ax3 = plt.subplot(num_row, 1, 4, sharex=ax0)
    ax3.set_title('Coincidence Rates')
    ax3.plot(t_winpos + winsize / 2.,
             Js_dict['n_emp'] / (winsize.rescale('s').magnitude * num_tr),
             label='empirical', lw=lw, color='c')
    ax3.plot(t_winpos + winsize / 2.,
             Js_dict['n_exp'] / (winsize.rescale('s').magnitude * num_tr),
             label='expected', lw=lw, color='m')
    ax3.set_xlim(0, x_max)
    ax3.set_ylabel('(1/s)', fontsize=fsize)
    ax3.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    YTicks = ax3.get_ylim()
    ax3.set_yticks([0, YTicks[1] / 2, YTicks[1]])
    _mark_events(ax3)
    ax3.set_xticks([])

    print('plotting Surprise ...')
    ax4 = plt.subplot(num_row, 1, 5, sharex=ax0)
    ax4.set_title('Statistical Significance')
    ax4.plot(t_winpos + winsize / 2., Js_dict['Js'], lw=lw, color='k')
    ax4.set_xlim(0, x_max)
    ax4.axhline(Js_sig, ls='-', color='r')
    ax4.axhline(-Js_sig, ls='-', color='g')
    # label position: 31st window, clipped for recordings with fewer windows
    label_x = t_winpos[min(30, len(t_winpos) - 1)]
    ax4.text(label_x, Js_sig + 0.3, '$\\alpha +$', color='r')
    ax4.text(label_x, -Js_sig - 0.5, '$\\alpha -$', color='g')
    ax4.set_yticks([ue.jointJ(0.99), ue.jointJ(0.5), ue.jointJ(0.01)])
    ax4.set_yticklabels([0.99, 0.5, 0.01])
    ax4.set_ylim(S_ylim)
    _mark_events(ax4, lw)
    ax4.set_xticks([])

    print('plotting UEs ...')
    ax5 = plt.subplot(num_row, 1, 6, sharex=ax0)
    ax5.set_title('Unitary Events')
    _raster(ax5, _mark_unitary_events)
    _format_raster(ax5)
    ax5.set_xlabel('Time [ms]', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            ax5.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
            ax5.text(e_val - 10 * pq.ms,
                     S_ylim[0] - 35, key, fontsize=fsize, color='r')
    ax5.set_xticks([])

    all_axes = (ax0, ax1, ax2, ax3, ax4, ax5)
    for letter, ax in zip(string.ascii_uppercase[:num_row], all_axes):
        ax.text(-0.05, 1.1, letter,
                transform=ax.transAxes, size=fsize + 5,
                weight='bold')
    if plot_params['save_fig']:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()

    return None
Пример #10
0
def plot_unitary_events(data, joint_surprise_dict, significance_level, binsize,
                        window_size, window_step, **plot_params_user):
    """
    Plots the results of unitary event analysis as a column of six subplots:
    raster plot, peri-stimulus time histogram, coincident event plot,
    coincidence rate plot, significance plot and unitary event plot.

    Parameters
    ----------
    data : list of list of neo.SpikeTrain
        A nested list of trials, neurons and their neo.SpikeTrain objects,
        respectively. This should be identical to the one used to generate
        joint_surprise_dict. It is not modified.
    joint_surprise_dict : dict
        The output of elephant.unitary_event_analysis.jointJ_window_analysis
        function. The values of each key have the shape of
            different pattern hash --> 0-axis
            different window --> 1-axis
        Keys:
        -----
        Js : list of float
            JointSurprise of different given pattern within each window.
        indices : list of list of int
            A list of indices of pattern within each window.
        n_emp : list of int
            The empirical number of each observed pattern.
        n_exp : list of float
            The expected number of each pattern.
        rate_avg : list of float
            The average firing rate of each neuron.
    significance_level : float
        The significance threshold used to determine which coincident events
        are classified as unitary events within a window.
    binsize : quantities.Quantity
        The size of bins for discretizing spike trains. This value should be
        identical to the one used to generate joint_surprise_dict.
    window_size : quantities.Quantity
        The size of the analysis window. This value should be identical to the
        one used to generate joint_surprise_dict.
    window_step : quantities.Quantity
        The size of the window step. This value should be identical to the one
        used to generate joint_surprise_dict.
    **plot_params_user
        Keyword arguments used to update the default plotting parameter
        values.
        Keys:
        -----
        events : dictionary (default: {})
            Epochs to be marked on the time axis
            key: epochs name as string
            value: list of quantities.Quantity
        figsize : tuple of int (default: (10, 12))
            The dimensions for the figure size.
        hspace : float (default: 1)
            The amount of height reserved for white space between subplots.
        wspace : float (default: 0.5)
            The amount of width reserved for white space between subplots.
        top : float (default: 0.9)
        bottom : float (default: 0.1)
        right : float (default: 0.9)
        left : float (default: 0.1)
            The sizes of the respective margin of the subplot in the figure.
        fsize : integer (default: 12)
            The size of the font
        unit_real_ids : list of integers (default: [1, 2])
            The unit ids from the experimental recording.
        lw: float (default: 2)
            The default line width.
        S_ylim : tuple of ints or floats (default: (-3, 3))
            The y-axis limits for the joint surprise plot.
        marker_size : integers (default: 5)
            The marker size for the coincidence and unitary events.
        time_unit : string (default: 'ms')
            The time unit used to rescale the spiketrains.
        frequency_unit : string (default: 'Hz')
            The frequency unit used to rescale the spikerates.

    Returns
    -------
    result : instance of namedtuple()
        The container for Axis objects generated by this function. Individual
        axes can be accessed using the respective identifiers:
        result.identifier
        Identifiers: spike_events_axes, spike_rates_axes,
                     coincidence_events_axes, coincidence_rates_axes,
                     statistical_significance_axes, unitary_events_axes

    Raises
    ------
    ValueError
        If the number of unit_real_ids differs from the number of neurons.
    """
    # update params_dict_default with user input
    params_dict = params_dict_default.copy()
    params_dict.update(plot_params_user)

    # Rescale all spike trains to the common time unit. A new nested list is
    # built so that the caller's data is left untouched.
    data = [[spiketrain.rescale(params_dict['time_unit'])
             for spiketrain in trial] for trial in data]

    # set common variables
    n_neurons = len(data[0])
    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop
    t_winpos = ue._winpos(t_start, t_stop, window_size, window_step)
    center_of_analysis_window = t_winpos + window_size / 2.
    n_trials = len(data)
    joint_surprise_significance = ue.jointJ(significance_level)
    xlim_left = (min(t_winpos)).magnitude
    xlim_right = (max(t_winpos) + window_size).magnitude

    if len(params_dict['unit_real_ids']) != n_neurons:
        raise ValueError(
            'length of unit_ids should be equal to number of neurons! \n'
            f"Unit_Ids: {params_dict['unit_real_ids']} "
            f'not equal number of neurons: {n_neurons}')

    plt.figure(num=1, figsize=params_dict['figsize'])
    plt.subplots_adjust(hspace=params_dict['hspace'],
                        wspace=params_dict['wspace'],
                        top=params_dict['top'],
                        bottom=params_dict['bottom'],
                        left=params_dict['left'],
                        right=params_dict['right'])

    # set y-axis for raster plots with ticks and labels
    y_ticks_list = [n_trials, n_neurons * n_trials + 1]
    y_ticks_labels_list = [n_trials, n_trials]

    def mark_epochs(axes_name, annotate=False):
        """
        Marks epochs on the given axes with a vertical line and, if
        ``annotate`` is true, writes the epoch's name below the axes. Epochs
        are taken from params_dict['events'].

        Parameters
        ----------
        axes_name : matplotlib.axes.Axes
            The axes in which the epochs will be marked.
        annotate : bool
            Whether to write the epoch names (used for the bottom subplot).
        """
        for key, event_timepoints in params_dict['events'].items():
            for event_timepoint in event_timepoints:
                # only mark epochs that lie within the time-axis limits
                if xlim_left <= event_timepoint <= xlim_right:
                    axes_name.axvline(event_timepoint, ls='-',
                                      lw=params_dict['lw'], color='r')
                    if annotate:
                        axes_name.text(x=event_timepoint, y=-54, s=key,
                                       fontsize=12, color='r',
                                       horizontalalignment='center')

    def plot_spike_raster(axes, mark_trial=None):
        """
        Plots all spikes as dots, one row per trial with the neurons stacked
        on top of each other. ``mark_trial(axes, trial, row)`` is called after
        each trial row is drawn, if given.
        """
        for n in range(n_neurons):
            for trial, data_trial in enumerate(data):
                row = trial + n * (n_trials + 1) + 1
                axes.plot(data_trial[n].magnitude,
                          np.full_like(data_trial[n].magnitude, row),
                          ls='none', marker='.', color='k', markersize=0.5)
                if mark_trial is not None:
                    mark_trial(axes, trial, row)

    def format_raster_axes(axes):
        """Applies the common layout of the raster subplots."""
        axes.axhline(n_trials + 1, lw=params_dict['lw'], color='k')
        axes.set_xlim(xlim_left, xlim_right)
        axes.set_ylim(0, (n_trials + 1) * n_neurons + 1)
        axes.xaxis.set_major_locator(MaxNLocator(integer=True))
        axes.set_yticks(y_ticks_list)
        axes.set_yticklabels(y_ticks_labels_list)

    # windows with significant joint surprise; identical for all trials
    indices_of_significant_joint_surprises = np.where(
        joint_surprise_dict['Js'] >= joint_surprise_significance)[0]

    def mark_coincidence_events(axes, trial, row):
        """Marks all coincident events of the given trial."""
        indices_of_coincidence_events = \
            np.unique(joint_surprise_dict['indices']['trial' + str(trial)])
        axes.plot(indices_of_coincidence_events * binsize,
                  np.full_like(indices_of_coincidence_events, row), ls='',
                  markersize=params_dict['marker_size'], marker='s',
                  markerfacecolor='none', markeredgecolor='c')

    def mark_unitary_events(axes, trial, row):
        """Marks coincident events of the trial within significant windows."""
        if len(indices_of_significant_joint_surprises) == 0:
            return
        indices_of_coincidence_events = np.unique(
            joint_surprise_dict['indices']['trial' + str(trial)])
        if len(indices_of_coincidence_events) == 0:
            return
        coincidence_times = indices_of_coincidence_events * binsize
        in_windows = [
            indices_of_coincidence_events[
                (coincidence_times >= t_winpos[j]) &
                (coincidence_times < t_winpos[j] + window_size)]
            for j in indices_of_significant_joint_surprises]
        indices_of_unitary_events = np.unique(np.concatenate(in_windows))
        axes.plot(indices_of_unitary_events * binsize,
                  np.ones_like(indices_of_unitary_events) * row,
                  markersize=params_dict['marker_size'],
                  marker='s', ls='', markerfacecolor='none',
                  markeredgecolor='r')

    print('plotting Unitary Event Analysis ...')

    print('plotting Spike Events ...')
    axes1 = plt.subplot(6, 1, 1)
    axes1.set_title('Spike Events')
    plot_spike_raster(axes1)
    format_raster_axes(axes1)
    axes1.text(1.0, 1.0, f"Unit {params_dict['unit_real_ids'][1]}",
               fontsize=params_dict['fsize']//2,
               horizontalalignment='right',
               verticalalignment='bottom',
               transform=axes1.transAxes)
    axes1.text(1.0, 0, f"Unit {params_dict['unit_real_ids'][0]}",
               fontsize=params_dict['fsize']//2,
               horizontalalignment='right',
               verticalalignment='top',
               transform=axes1.transAxes)
    axes1.set_ylabel('Trial', fontsize=params_dict['fsize'])

    print('plotting Spike Rates ...')
    axes2 = plt.subplot(6, 1, 2, sharex=axes1)
    axes2.set_title('Spike Rates')
    for n in range(n_neurons):
        respective_rate_average = joint_surprise_dict['rate_avg'][:, n].\
            rescale(params_dict['frequency_unit'])
        axes2.plot(center_of_analysis_window, respective_rate_average,
                   label=f"Unit {params_dict['unit_real_ids'][n]}",
                   lw=params_dict['lw'])
    axes2.set_xlim(xlim_left, xlim_right)
    # psth = peristimulus time histogram; y-range from the largest rate
    # (0 when no rate is positive)
    rates = joint_surprise_dict['rate_avg'][:, :n_neurons].rescale(
        params_dict['frequency_unit']).magnitude
    max_val_psth = max(float(np.max(rates)), 0.0) if rates.size else 0.0
    axes2.set_ylim(0, max_val_psth + max_val_psth/10)
    axes2.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes2.set_yticks([0, int(max_val_psth / 2), int(max_val_psth)])
    axes2.legend(fontsize=params_dict['fsize']//2)
    axes2.set_ylabel(f"({params_dict['frequency_unit']})",
                     fontsize=params_dict['fsize'])

    print('plotting Coincident Events ...')
    axes3 = plt.subplot(6, 1, 3, sharex=axes1)
    axes3.set_title('Coincident Events')
    plot_spike_raster(axes3, mark_coincidence_events)
    format_raster_axes(axes3)
    axes3.set_ylabel('Trial', fontsize=params_dict['fsize'])

    print('plotting Coincidence Rates ..')
    axes4 = plt.subplot(6, 1, 4, sharex=axes1)
    axes4.set_title('Coincidence Rates')
    empirical_coincidence_rate = joint_surprise_dict['n_emp'] / \
        (window_size.rescale('s').magnitude * n_trials)
    expected_coincidence_rate = joint_surprise_dict['n_exp'] / \
        (window_size.rescale('s').magnitude * n_trials)
    axes4.plot(center_of_analysis_window, empirical_coincidence_rate,
               label='Empirical', lw=params_dict['lw'], color='c')
    axes4.plot(center_of_analysis_window, expected_coincidence_rate,
               label='Expected', lw=params_dict['lw'], color='m')
    axes4.set_xlim(xlim_left, xlim_right)
    axes4.xaxis.set_major_locator(MaxNLocator(integer=True))
    y_ticks = axes4.get_ylim()
    axes4.set_yticks([0, y_ticks[1] / 2, y_ticks[1]])
    axes4.legend(fontsize=params_dict['fsize']//2)
    axes4.set_ylabel(f"({params_dict['frequency_unit']})",
                     fontsize=params_dict['fsize'])

    print('plotting Statistical Significance ...')
    axes5 = plt.subplot(6, 1, 5, sharex=axes1)
    axes5.set_title('Statistical Significance')
    joint_surprise_values = joint_surprise_dict['Js']
    axes5.plot(center_of_analysis_window, joint_surprise_values,
               lw=params_dict['lw'], color='k')
    axes5.set_xlim(xlim_left, xlim_right)
    axes5.set_ylim(params_dict['S_ylim'])
    axes5.axhline(joint_surprise_significance, ls='-', color='r')
    axes5.axhline(-joint_surprise_significance, ls='-', color='g')
    # label position: 31st window, clipped for recordings with fewer windows
    alpha_label_x = t_winpos[min(30, len(t_winpos) - 1)]
    axes5.text(alpha_label_x, joint_surprise_significance + 0.3,
               '$\\alpha +$', color='r')
    axes5.text(alpha_label_x, -joint_surprise_significance - 0.9,
               '$\\alpha -$', color='g')
    axes5.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes5.set_yticks([ue.jointJ(1-significance_level), ue.jointJ(0.5),
                      ue.jointJ(significance_level)])
    axes5.set_yticklabels([1-significance_level, 0.5, significance_level])

    print('plotting Unitary Events ...')
    axes6 = plt.subplot(6, 1, 6, sharex=axes1)
    axes6.set_title('Unitary Events')
    plot_spike_raster(axes6, mark_unitary_events)
    format_raster_axes(axes6)
    axes6.set_xlabel(f'Time ({params_dict["time_unit"]})',
                     fontsize=params_dict['fsize'])
    axes6.set_ylabel('Trial', fontsize=params_dict['fsize'])

    # mark all epochs on all subplots and annotate all axes-subplots;
    # epoch names are written only below the bottom (unitary event) subplot
    all_axes = (axes1, axes2, axes3, axes4, axes5, axes6)
    for letter, axes in zip('ABCDEF', all_axes):
        mark_epochs(axes, annotate=axes is axes6)
        axes.text(-0.05, 1.1, letter, transform=axes.transAxes,
                  size=params_dict['fsize'] + 5, weight='bold')

    result = FigureUE(axes1, axes2, axes3, axes4, axes5, axes6)
    return result
    def test_Riehle_et_al_97_UE(self):
        """Compare reproduced unitary events with Riehle et al. 1997, fig. 2.

        Downloads the spike data and the values extracted from the original
        figure, runs ``ue.jointJ_window_analysis`` and asserts that every
        significant coincidence lies within 0.3 ms of the extracted unitary
        event. The temporary directory is removed even if the test fails.
        """
        from neo.rawio.tests.tools import (download_test_file,
                                           create_local_temp_dir,
                                           make_all_directories)
        from neo.test.iotest.tools import cleanup_test_file
        url = ("https://raw.githubusercontent.com/ReScience-Archives/"
               "Rostami-Ito-Denker-Gruen-2017/master/data")
        shortname = "unitary_event_analysis_test_data"
        local_test_dir = create_local_temp_dir(
            shortname, os.environ.get("ELEPHANT_TEST_FILE_DIR"))
        files_to_download = ["extracted_data.npy", "winny131_23.gdf"]
        make_all_directories(files_to_download, local_test_dir)
        try:
            for filename in files_to_download:
                download_test_file(filename, local_test_dir, url)

            # load spike data of figure 2 of Riehle et al 1997
            trigger = 'RS_4'
            t_pre = 1799 * pq.ms
            t_post = 300 * pq.ms
            spiketrain = self.load_gdf2Neo(
                os.path.join(local_test_dir, 'winny131_23.gdf'),
                trigger, t_pre, t_post)

            # calculating UE ...
            winsize = 100 * pq.ms
            binsize = 5 * pq.ms
            winstep = 5 * pq.ms
            pattern_hash = [3]
            method = 'analytic_TrialAverage'
            t_start = spiketrain[0][0].t_start
            t_stop = spiketrain[0][0].t_stop
            t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
            significance_level = 0.05

            UE = ue.jointJ_window_analysis(
                spiketrain, binsize, winsize, winstep,
                pattern_hash, method=method)

            # Load data extracted from figure 2 of Riehle et al 1997. The
            # file holds a pickled dict, so allow_pickle is required (NumPy
            # >= 1.16.3 defaults to False); latin1 decoding handles a
            # presumably Python 2 written pickle.
            extracted_data = np.load(
                os.path.join(local_test_dir, 'extracted_data.npy'),
                encoding='latin1', allow_pickle=True).item()

            Js_sig = ue.jointJ(significance_level)
            sig_idx_win = np.where(UE['Js'] >= Js_sig)[0]
            diff_UE_rep = []
            y_cnt = 0
            for tr in range(len(spiketrain)):
                trial_indices = UE['indices']['trial' + str(tr)]
                # unique coincidence bins in order of first occurrence
                x_idx = np.sort(
                    np.unique(trial_indices, return_index=True)[1])
                x = trial_indices[x_idx]
                if len(x) == 0:
                    continue
                # choose only the significant coincidences
                selected = [x[np.where(
                    (x * binsize >= t_winpos[j]) &
                    (x * binsize < t_winpos[j] + winsize))]
                    for j in sig_idx_win]
                selected = (np.concatenate(selected) if selected
                            else np.array([]))
                x_tmp = np.unique(selected) * binsize.magnitude
                if len(x_tmp) > 0:
                    ue_trial = np.sort(extracted_data['ue'][y_cnt])
                    diff_UE_rep.append(x_tmp - ue_trial)
                    y_cnt += 1
            diff_UE_rep = (np.concatenate(diff_UE_rep) if diff_UE_rep
                           else np.array([]))
            np.testing.assert_array_less(np.abs(diff_UE_rep), 0.3)
        finally:
            cleanup_test_file('dir', local_test_dir)