Beispiel #1
0
 def test_invhash_base_not_two(self):
     """inverse_hash_from_pattern decodes hashes computed in base 3.

     Column j of the result is the base-3 digit pattern of h[j] with the
     most significant digit in row 0 (e.g. 13 == 1*9 + 1*3 + 1).
     """
     N = 3
     h = np.array([1, 4, 13])
     base = 3
     expected = np.array([[0, 0, 1], [0, 1, 1], [1, 1, 1]])
     m = ue.inverse_hash_from_pattern(h, N, base)
     # assert_array_equal also checks the shape and reports the mismatching
     # entries; np.all(expected == m) would broadcast instead.
     np.testing.assert_array_equal(m, expected)
Beispiel #2
0
 def test_invhash_shape_mat(self):
     """inverse_hash_from_pattern returns one row per neuron and one
     column per hash value."""
     N = 8
     h = np.array([178, 212, 232])
     m = ue.inverse_hash_from_pattern(h, N)
     # Check the complete shape, not only the number of rows.
     self.assertEqual(np.shape(m), (N, len(h)))
 def test_invhash_base_not_two(self):
     """inverse_hash_from_pattern decodes hashes computed in base 3.

     Column j of the result is the base-3 digit pattern of h[j] with the
     most significant digit in row 0 (e.g. 13 == 1*9 + 1*3 + 1).
     """
     N = 3
     h = np.array([1, 4, 13])
     base = 3
     expected = np.array([[0, 0, 1], [0, 1, 1], [1, 1, 1]])
     m = ue.inverse_hash_from_pattern(h, N, base)
     # assert_array_equal also checks the shape and reports the mismatching
     # entries; np.all(expected == m) would broadcast instead.
     np.testing.assert_array_equal(m, expected)
Beispiel #4
0
 def test_invhash_default_base(self):
     """inverse_hash_from_pattern decodes base-2 hashes by default.

     Column j of the result is the binary pattern of h[j] with the most
     significant bit in row 0 (e.g. 4 -> [1, 0, 0], 6 -> [1, 1, 0]).
     """
     N = 3
     h = np.array([0, 4, 2, 1, 6, 5, 3, 7])
     expected = np.array([[0, 1, 0, 0, 1, 1, 0, 1],
                          [0, 0, 1, 0, 1, 0, 1, 1],
                          [0, 0, 0, 1, 0, 1, 1, 1]])
     m = ue.inverse_hash_from_pattern(h, N)
     # Shape-checking comparison with a readable failure message.
     np.testing.assert_array_equal(m, expected)
Beispiel #5
0
 def test_hash_inverse_longpattern(self):
     """hash_from_pattern / inverse_hash_from_pattern round-trip patterns
     of 100 neurons.

     A 100-bit pattern can hash up to 2**100 - 1, well beyond the int64
     range -- presumably what this "long pattern" test is meant to guard.
     """
     n_patterns = 100
     # Seeded so that a failing pattern can be reproduced.
     rng = np.random.RandomState(0)
     m = rng.randint(low=0, high=2, size=(n_patterns, 2))
     # Always include the all-zeros and all-ones patterns; the latter yields
     # the largest possible hash.
     extremes = np.array([[0, 1]] * n_patterns)
     m = np.hstack([m, extremes])
     h = ue.hash_from_pattern(m)
     m_inv = ue.inverse_hash_from_pattern(h, N=n_patterns)
     assert_array_equal(m, m_inv)
Beispiel #6
0
 def test_hash_invhash_consistency(self):
     """inverse_hash_from_pattern undoes hash_from_pattern."""
     m = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0],
                   [1, 0, 1], [0, 1, 1], [1, 1, 1]])
     h = ue.hash_from_pattern(m)
     # Take N from the pattern matrix (one row per neuron) rather than
     # repeating the literal row count.
     m1 = ue.inverse_hash_from_pattern(h, N=m.shape[0])
     np.testing.assert_array_equal(m, m1)
Beispiel #7
0
def _plot_UE(data,
             Js_dict,
             sig_level,
             binsize,
             winsize,
             winstep,
             pattern_hash,
             N,
             args,
             add_epochs=None):
    """
    Plot the unitary-event analysis of one pattern as stacked panels.

    Panels, top to bottom: unitary events marked on the raster, raw
    coincidences marked on the raster, average rate per unit, empirical and
    expected coincidence counts, surprise, and -- only when ``add_epochs``
    is given -- the ``'coincrate'`` and ``'backgroundrate'`` traces.

    Parameters
    ----------
    data : list
        One entry per trial; ``data[tr][n]`` is the spike train of unit
        ``n`` in trial ``tr``.  ``data[0][0]`` provides ``t_start`` and
        ``t_stop``.
    Js_dict : dict
        Analysis result; the keys ``'Js'`` (compared with the significance
        threshold), ``'indices'`` (``'trial<k>'`` -> coincidence bin
        indices), ``'rate_avg'``, ``'n_emp'`` and ``'n_exp'`` are read.
    sig_level : float
        Significance level; converted to a surprise threshold with
        ``ue.jointJ`` and printed as a percentage in the surprise panel.
    binsize, winsize, winstep : quantities
        Bin width, window width and window step.
    pattern_hash : array-like
        Hash of the analysed pattern, decoded for the optional title.
    N : int
        Number of units.
    args : dict
        Plot options.  Recognised keys (default in parentheses):
        ``'events'`` ({}) mapping an event name to a list of times,
        ``'figsize'`` (matplotlib default), ``'top'`` (.90),
        ``'bottom'`` (.05), ``'right'`` (.95), ``'left'`` (.1),
        ``'hspace'`` (.5), ``'wspace'`` (.5), ``'fontsize'`` (20),
        ``'unit_ids'`` (1..N), ``'ch_ids'`` (none), ``'showfig'`` (False),
        ``'linewidth'`` (2), ``'S_ylim'`` ([-3, 3]), ``'marker_size'`` (8),
        ``'suptitle'`` (False), ``'set_xticks'`` (True), ``'save_fig'``
        (False) and ``'path_filename_format'`` (needed when saving).
    add_epochs : dict, optional
        Dict with keys ``'coincrate'`` and ``'backgroundrate'``; when given
        (and non-empty) a sixth panel plots both.

    Raises
    ------
    ValueError
        If ``'unit_ids'`` or ``'ch_ids'`` does not have ``N`` entries.

    Examples
    --------
    ::

        dict_args = {'events': {'SO': [100 * pq.ms]},
                     'save_fig': True,
                     'path_filename_format': 'UE1.pdf',
                     'showfig': True,
                     'suptitle': True,
                     'figsize': (12, 10),
                     'unit_ids': [10, 19, 20],
                     'ch_ids': [1, 3, 4],
                     'fontsize': 15,
                     'linewidth': 2,
                     'set_xticks': False,
                     'marker_size': 8}
    """
    import matplotlib.pylab as plt

    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop

    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    Js_sig = ue.jointJ(sig_level)
    num_tr = len(data)
    pat = ue.inverse_hash_from_pattern(pattern_hash, N)
    events = args.get('events', {})

    # figure format
    figsize = args.get('figsize')
    top = args.get('top', .90)
    bottom = args.get('bottom', .05)
    right = args.get('right', .95)
    left = args.get('left', .1)
    hspace = args.get('hspace', .5)
    wspace = args.get('wspace', .5)
    fsize = args.get('fontsize', 20)
    showfig = args.get('showfig', False)
    lw = args.get('linewidth', 2)
    S_ylim = args.get('S_ylim', [-3, 3])
    ms = args.get('marker_size', 8)
    hide_xticks = 'set_xticks' in args and not args['set_xticks']

    if 'unit_ids' in args:
        unit_real_ids = args['unit_ids']
        if len(unit_real_ids) != N:
            raise ValueError(
                'length of unit_ids should be equal to number of neurons!')
    else:
        unit_real_ids = numpy.arange(1, N + 1, 1)
    if 'ch_ids' in args:
        ch_real_ids = args['ch_ids']
        if len(ch_real_ids) != N:
            raise ValueError(
                'length of ch_ids should be equal to number of neurons!')
    else:
        ch_real_ids = []

    # An omitted (None) or empty add_epochs means "no extra panel".
    if add_epochs:
        coincrate = add_epochs['coincrate']
        backgroundrate = add_epochs['backgroundrate']
        num_row = 6
    else:
        num_row = 5
    ls = '-'
    alpha = 0.5

    # Raster layout: unit n occupies rows n*row_height+1 .. n*row_height+num_tr
    # (one per trial), followed by a separator line at (n+1)*row_height.
    row_height = num_tr + 1
    y_top = N * row_height + 1
    y_ticks_pos = numpy.arange(num_tr // 2 + 1, N * row_height, row_height)
    x_max = (max(t_winpos) + winsize).rescale('ms').magnitude
    t_center = t_winpos + winsize / 2.

    def mark_events(linewidth):
        """Draw a vertical line at every event time on the current axes."""
        for key in events:
            for e_val in events[key]:
                plt.axvline(e_val, ls=ls, color='r', lw=linewidth,
                            alpha=alpha)

    # Per trial: the unique coincidence bins lying inside at least one
    # significant window (None when there is nothing to mark).  Neither the
    # significant windows nor the coincidence indices depend on the unit,
    # so this is computed once instead of once per unit.
    sig_idx_win = numpy.where(Js_dict['Js'] >= Js_sig)[0]
    ue_bins = [None] * num_tr
    if len(sig_idx_win) > 0:
        for tr in range(num_tr):
            x = numpy.unique(Js_dict['indices']['trial' + str(tr)])
            if len(x) == 0:
                continue
            xx = []
            for j in sig_idx_win:
                in_win = numpy.where((x * binsize >= t_winpos[j]) &
                                     (x * binsize < t_winpos[j] + winsize))
                xx = numpy.append(xx, x[in_win])
            ue_bins[tr] = numpy.unique(xx)

    plt.figure(1, figsize=figsize)
    if args.get('suptitle', False):
        plt.suptitle("Spike Pattern:" + str((pat.T)[0]), fontsize=20)
    print('plotting UEs ...')
    plt.subplots_adjust(top=top,
                        right=right,
                        left=left,
                        bottom=bottom,
                        hspace=hspace,
                        wspace=wspace)
    ax = plt.subplot(num_row, 1, 1)
    ax.set_title('Unitary Events', fontsize=20, color='r')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            plt.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr +
                     n * row_height + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            if ue_bins[tr] is not None:
                plt.plot(ue_bins[tr] * binsize,
                         numpy.ones_like(ue_bins[tr]) * tr +
                         n * row_height + 1,
                         ms=ms,
                         marker='s',
                         ls='',
                         mfc='none',
                         mec='r')
        plt.axhline((n + 1) * row_height, lw=2, color='k')
    plt.yticks(y_ticks_pos)
    plt.gca().set_yticklabels(unit_real_ids, fontsize=fsize)
    for ch_cnt, ch_id in enumerate(ch_real_ids):
        print(ch_id)
        plt.gca().text(x_max,
                       y_ticks_pos[ch_cnt],
                       'CH-' + str(ch_id),
                       fontsize=fsize)

    plt.ylim(0, y_top)
    plt.xlim(0, x_max)
    plt.xticks([])
    plt.ylabel('Unit ID', fontsize=fsize)
    mark_events(2)

    print('plotting Raw Coincidences ...')
    ax1 = plt.subplot(num_row, 1, 2, sharex=ax)
    ax1.set_title('Raw Coincidences', fontsize=20, color='c')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            plt.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr +
                     n * row_height + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            coinc_bins = numpy.unique(Js_dict['indices']['trial' + str(tr)])
            plt.plot(coinc_bins * binsize,
                     numpy.ones_like(coinc_bins) * tr + n * row_height + 1,
                     ls='',
                     ms=ms,
                     marker='s',
                     markerfacecolor='none',
                     markeredgecolor='c')
        plt.axhline((n + 1) * row_height, lw=2, color='k')
    plt.ylim(0, y_top)
    plt.yticks(y_ticks_pos)
    plt.gca().set_yticklabels(unit_real_ids, fontsize=fsize)
    plt.xlim(0, x_max)
    plt.xticks([])
    plt.ylabel('Unit ID', fontsize=fsize)
    mark_events(2)

    print('plotting PSTH ...')
    plt.subplot(num_row, 1, 3, sharex=ax)
    for n in range(N):
        plt.plot(t_center,
                 Js_dict['rate_avg'][:, n].rescale('Hz'),
                 label='unit ' + str(unit_real_ids[n]),
                 lw=lw)
    plt.ylabel('Rate [Hz]', fontsize=fsize)
    plt.xlim(0, x_max)
    max_val_psth = plt.gca().get_ylim()[1]
    plt.ylim(0, max_val_psth)
    plt.yticks([0, int(max_val_psth / 2), int(max_val_psth)], fontsize=fsize)
    plt.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    mark_events(lw)
    if hide_xticks:
        plt.xticks([])

    print('plotting emp. and exp. coincidences rate ...')
    plt.subplot(num_row, 1, 4, sharex=ax)
    plt.plot(t_center,
             Js_dict['n_emp'],
             label='empirical',
             lw=lw,
             color='c')
    plt.plot(t_center,
             Js_dict['n_exp'],
             label='expected',
             lw=lw,
             color='m')
    plt.xlim(0, x_max)
    plt.ylabel('# Coinc.', fontsize=fsize)
    plt.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    YTicks = plt.ylim(0, int(max(max(Js_dict['n_emp']),
                                 max(Js_dict['n_exp']))))
    plt.yticks([0, YTicks[1]], fontsize=fsize)
    mark_events(2)
    if hide_xticks:
        plt.xticks([])

    print('plotting Surprise ...')
    plt.subplot(num_row, 1, 5, sharex=ax)
    plt.plot(t_center, Js_dict['Js'], lw=lw, color='k')
    plt.xlim(0, x_max)
    plt.axhline(Js_sig, ls='-', color='gray')
    plt.axhline(-Js_sig, ls='-', color='gray')
    plt.gca().text(10,
                   Js_sig + 0.2,
                   str(int(sig_level * 100)) + '%',
                   fontsize=fsize - 2,
                   color='gray')
    # Integer step (works on Python 3) and never 0, which would raise for
    # recordings with fewer than 10 windows.
    plt.xticks(t_winpos.magnitude[::max(1, len(t_winpos) // 10)])
    plt.yticks([-2, 0, 2], fontsize=fsize)
    plt.ylabel('S', fontsize=fsize)
    plt.xlabel('Time [ms]', fontsize=fsize)
    plt.ylim(S_ylim)
    for key in events:
        for e_val in events[key]:
            plt.axvline(e_val, ls=ls, color='r', lw=lw, alpha=alpha)
            plt.gca().text(e_val - 10 * pq.ms,
                           2 * S_ylim[0],
                           key,
                           fontsize=fsize,
                           color='r')
    if hide_xticks:
        plt.xticks([])

    if add_epochs:
        plt.subplot(num_row, 1, 6, sharex=ax)
        plt.plot(coincrate, lw=lw, color='c')
        plt.plot(backgroundrate, lw=lw, color='m')
        plt.xlim(0, x_max)
        plt.ylim(plt.gca().get_ylim()[0] - 2, plt.gca().get_ylim()[1] + 2)
    if args.get('save_fig', False):
        plt.savefig(args['path_filename_format'])
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
def plot_figure3(data_list, Js_dict_lst_lst, sig_level, binsize, winsize,
                 winstep, patterns, N, N_comb, plot_params_user):
    """Plot Figure 3 of the manuscript.

    Each data set in ``data_list`` gets a group of grid columns (the first
    spans columns 0-2, the second columns 3-4) with three rows: spike
    rasters, coincident events and unitary events.  Event markers come from
    ``events1`` / ``events2``, so at most two data sets are expected.

    Parameters
    ----------
    data_list : list
        Data sets; ``data[tr][n]`` is the spike train of neuron ``n`` in
        trial ``tr``.
    Js_dict_lst_lst : list
        ``Js_dict_lst_lst[d][k]`` is the analysis result for data set ``d``
        and neuron combination ``N_comb[k]``; its ``'Js'`` and
        ``'indices'`` entries are read.
    sig_level : float
        Significance level, converted to a surprise threshold with
        ``ue.jointJ``.
    binsize, winsize, winstep : quantities
        Bin width, window width and window step.
    patterns : sequence
        ``patterns[k]`` is the pattern hash belonging to ``N_comb[k]``.
    N : int
        Number of neurons.
    N_comb : sequence of sequences
        Neuron index combinations; coincidences are only marked for neurons
        whose entry in the decoded pattern is 1.
    plot_params_user : dict
        Overrides for ``plot_params_default``.

    Raises
    ------
    ValueError
        If ``unit_real_ids`` does not have ``N`` entries.
    """

    import matplotlib.pylab as plt
    # Merge into a copy: updating plot_params_default in place would turn
    # one call's overrides into the defaults of every later call.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    # The options (fsize, figsize, top, right, left, bottom, hspace, wspace,
    # events1, events2, ms, save_fig, showfig, path_filename_format,
    # unit_real_ids) are read below as module-level names, so they are
    # published into the module globals.
    globals().update(plot_params)
    if len(unit_real_ids) != N:
        raise ValueError(
            'length of unit_ids should be equal to number of neurons!')
    plt.rcParams.update({'font.size': fsize})
    plt.rc('legend', fontsize=fsize)
    ms1 = 0.5  # marker size of the raster dots
    plt.figure(1, figsize=figsize)
    plt.subplots_adjust(top=top,
                        right=right,
                        left=left,
                        bottom=bottom,
                        hspace=hspace,
                        wspace=wspace)
    events = [events1, events2]
    num_row, num_col = 3, 5

    def mark_events(ax, data_cnt):
        # Each event maps to a list of times (the label loop at the end of
        # the figure iterates them too): one vertical line per time.
        for e_vals in events[data_cnt].values():
            for e_val in e_vals:
                ax.axvline(e_val, color='k')

    for data_cnt, data in enumerate(data_list):
        Js_dict_lst = Js_dict_lst_lst[data_cnt]
        t_start = data[0][0].t_start
        t_stop = data[0][0].t_stop
        t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
        Js_sig = ue.jointJ(sig_level)
        num_tr = len(data)
        # Neuron n occupies raster rows n*row_height+1 .. n*row_height+num_tr.
        row_height = num_tr + 1
        grid_col = data_cnt * 3
        col_span = 3 - data_cnt

        ax0 = plt.subplot2grid((num_row, num_col), (0, grid_col),
                               rowspan=1,
                               colspan=col_span)
        ax0.set_xticks([])
        if data_cnt == 0:
            ax0.set_title('Spike Events', fontsize=fsize)
            ax0.set_yticks([num_tr / 2 + 1, num_tr * 3 / 2., num_tr * 5 / 2])
            ax0.set_yticklabels(['5', '4', '3'], fontsize=fsize)
            ax0.spines['right'].set_visible(False)
            ax0.yaxis.set_ticks_position('left')
            ax0.set_ylabel('Neuron #', fontsize=fsize)
        else:
            ax0.spines['left'].set_visible(False)
            ax0.yaxis.set_ticks_position('right')
            ax0.set_yticks([])
        for n in range(N):
            for tr, data_tr in enumerate(data):
                ax0.plot(data_tr[n].rescale('ms').magnitude,
                         numpy.ones_like(data_tr[n].magnitude) * tr +
                         n * row_height + 1,
                         '.',
                         markersize=ms1,
                         color='k')
            if n < N - 1:
                ax0.axhline((n + 1) * row_height, lw=0.5, color='k')
        ax0.set_ylim(0, N * row_height + 1)
        mark_events(ax0, data_cnt)

        ax1 = plt.subplot2grid((num_row, num_col), (1, grid_col),
                               rowspan=1,
                               colspan=col_span)
        ax1.set_xticks([])
        if data_cnt == 0:
            ax1.set_title('Coincident Events', fontsize=fsize)
            ax1.text(-110, 190, 'Neuron #', rotation=90, fontsize=fsize)
            ax1.set_yticks([num_tr / 2 + 1, num_tr * 3 / 2., num_tr * 5 / 2])
            ax1.set_yticklabels(['5', '4', '3'], fontsize=fsize)
            ax1.spines['right'].set_visible(False)
            ax1.yaxis.set_ticks_position('left')
        else:
            ax1.spines['left'].set_visible(False)
            ax1.yaxis.set_ticks_position('right')
            ax1.set_yticks([])
        for N_cnt, N_comb_sel in enumerate(N_comb):
            patt_tmp = ue.inverse_hash_from_pattern([patterns[N_cnt]],
                                                    len(N_comb_sel))
            # list(...) so that tuples / arrays compare as a whole rather
            # than element-wise.
            draw_separators = list(N_comb_sel) == [0, 1, 2]
            for n_cnt, n in enumerate(N_comb_sel):
                for tr, data_tr in enumerate(data):
                    if N_cnt == 0:
                        ax1.plot(data_tr[n].rescale('ms').magnitude,
                                 numpy.ones_like(data_tr[n].magnitude) * tr +
                                 n * row_height + 1,
                                 '.',
                                 markersize=ms1,
                                 color='k')
                    if patt_tmp[n_cnt][0] == 1:
                        coinc_bins = numpy.unique(
                            Js_dict_lst[N_cnt]['indices']['trial' + str(tr)])
                        ax1.plot(coinc_bins * binsize,
                                 numpy.ones_like(coinc_bins) * tr +
                                 n * row_height + 1,
                                 ls='',
                                 ms=ms,
                                 marker='s',
                                 markerfacecolor='none',
                                 markeredgecolor='c')
                if draw_separators and n < len(N_comb_sel) - 1:
                    ax1.axhline((n + 1) * row_height, lw=0.5, color='k')
        mark_events(ax1, data_cnt)

        ax2 = plt.subplot2grid((num_row, num_col), (2, grid_col),
                               rowspan=1,
                               colspan=col_span)
        ax2.set_xticks([])
        if data_cnt == 0:
            ax2.set_title('Unitary Events', fontsize=fsize)
            ax2.text(-110, 190, 'Neuron #', rotation=90, fontsize=fsize)
            ax2.set_yticks([num_tr / 2 + 1, num_tr * 3 / 2., num_tr * 5 / 2])
            ax2.set_yticklabels(['5', '4', '3'], fontsize=fsize)
            ax2.spines['right'].set_visible(False)
            ax2.yaxis.set_ticks_position('left')
        else:
            ax2.spines['left'].set_visible(False)
            ax2.yaxis.set_ticks_position('right')
            ax2.set_yticks([])
        for N_cnt, N_comb_sel in enumerate(N_comb):
            patt_tmp = ue.inverse_hash_from_pattern([patterns[N_cnt]],
                                                    len(N_comb_sel))
            draw_separators = list(N_comb_sel) == [0, 1, 2]
            for n_cnt, n in enumerate(N_comb_sel):
                for tr, data_tr in enumerate(data):
                    if N_cnt == 0:
                        ax2.plot(data_tr[n].rescale('ms').magnitude,
                                 numpy.ones_like(data_tr[n].magnitude) * tr +
                                 n * row_height + 1,
                                 '.',
                                 markersize=ms1,
                                 color='k')
                    if patt_tmp[n_cnt][0] == 1:
                        sig_idx_win = numpy.where(
                            Js_dict_lst[N_cnt]['Js'] >= Js_sig)[0]
                        if len(sig_idx_win) > 0:
                            x = numpy.unique(
                                Js_dict_lst[N_cnt]['indices']['trial' +
                                                              str(tr)])
                            if len(x) > 0:
                                # Coincidence bins inside any significant
                                # window.
                                xx = []
                                for j in sig_idx_win:
                                    xx = numpy.append(
                                        xx, x[numpy.where(
                                            (x * binsize >= t_winpos[j])
                                            & (x * binsize < t_winpos[j] +
                                               winsize))])
                                ax2.plot(
                                    numpy.unique(xx) * binsize,
                                    numpy.ones_like(numpy.unique(xx)) * tr +
                                    n * row_height + 1,
                                    ms=ms,
                                    marker='s',
                                    ls='',
                                    mfc='none',
                                    mec='r')
                if draw_separators and n < len(N_comb_sel) - 1:
                    ax2.axhline((n + 1) * row_height, lw=0.5, color='k')
        mark_events(ax2, data_cnt)
        for key, e_vals in events[data_cnt].items():
            for e_val in e_vals:
                ax2.text(e_val - 10 * pq.ms,
                         ax2.get_ylim()[0] - 80,
                         key,
                         fontsize=fsize,
                         color='k')

    if save_fig:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
def plot_figure1_2(data, Js_dict, sig_level, binsize, winsize, winstep,
                   pattern_hash, N, plot_params_user):
    """Plot Figure 1 and Figure 2 of the manuscript.

    Six stacked panels sharing the time axis, labelled A-F: spike raster,
    spike rates, raster with coincidences, empirical vs. expected
    coincidence rates, statistical significance (surprise) and raster with
    unitary events.

    Parameters
    ----------
    data : list
        One entry per trial; ``data[tr][n]`` is the spike train of neuron
        ``n`` in trial ``tr``.  ``data[0][0]`` provides ``t_start`` and
        ``t_stop``.
    Js_dict : dict
        Analysis result; the keys ``'Js'``, ``'indices'`` (``'trial<k>'``
        -> coincidence bin indices), ``'rate_avg'``, ``'n_emp'`` and
        ``'n_exp'`` are read.
    sig_level : float
        Significance level, converted to a surprise threshold with
        ``ue.jointJ``.
    binsize, winsize, winstep : quantities
        Bin width, window width and window step.
    pattern_hash : array-like
        Hash of the analysed pattern.
    N : int
        Number of neurons.
    plot_params_user : dict
        Overrides for ``plot_params_default``.

    Raises
    ------
    ValueError
        If ``unit_real_ids`` does not have ``N`` entries.
    """

    import matplotlib.pylab as plt
    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop

    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    Js_sig = ue.jointJ(sig_level)
    num_tr = len(data)
    # The decoded pattern is not used for plotting; the call is kept so any
    # error it raises for an inconsistent pattern_hash / N still surfaces.
    # NOTE(review): drop it if that check is not wanted here.
    ue.inverse_hash_from_pattern(pattern_hash, N)

    # Merge into a copy: updating plot_params_default in place would turn
    # one call's overrides into the defaults of every later call.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    # The options (fsize, figsize, top, ..., events, lw, ms, S_ylim,
    # showfig, path_filename_format, unit_real_ids) are read below as
    # module-level names, so they are published into the module globals.
    globals().update(plot_params)
    if len(unit_real_ids) != N:
        raise ValueError('length of unit_ids should be ' +
                         'equal to number of neurons!')
    plt.rcParams.update({'font.size': fsize})
    plt.rc('legend', fontsize=fsize)

    num_row = 6
    ls = '-'
    alpha = 0.5
    # Raster layout: neuron n occupies rows n*row_height+1 ..
    # n*row_height+num_tr, with a separator line at (n+1)*row_height.
    row_height = num_tr + 1
    y_top = N * row_height + 1
    x_max = (max(t_winpos) + winsize).rescale('ms').magnitude
    t_center = t_winpos + winsize / 2.
    trial_ticks = [num_tr + 1, num_tr + 16, num_tr + 31]
    # Converts coincidence counts per window into a rate in 1/s per trial.
    rate_norm = winsize.rescale('s').magnitude * num_tr

    def mark_events(ax, linewidth):
        """Draw a vertical line at every event time."""
        for key in events.keys():
            for e_val in events[key]:
                ax.axvline(e_val, ls=ls, color='r', lw=linewidth, alpha=alpha)

    plt.figure(1, figsize=figsize)
    if 'suptitle' in plot_params.keys():
        plt.suptitle("Trial aligned on " + plot_params['suptitle'],
                     fontsize=20)
    plt.subplots_adjust(top=top,
                        right=right,
                        left=left,
                        bottom=bottom,
                        hspace=hspace,
                        wspace=wspace)

    print('plotting raster plot ...')
    ax0 = plt.subplot(num_row, 1, 1)
    ax0.set_title('Spike Events')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax0.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr +
                     n * row_height + 1,
                     '.',
                     markersize=0.5,
                     color='k')
        if n < N - 1:
            ax0.axhline((n + 1) * row_height, lw=2, color='k')
    ax0.set_ylim(0, y_top)
    ax0.set_yticks(trial_ticks)
    ax0.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax0.set_xlim(0, x_max)
    ax0.set_xticks([])
    ax0.set_ylabel('Trial', fontsize=fsize)
    mark_events(ax0, 2)
    Xlim = ax0.get_xlim()
    ax0.text(Xlim[1] - 200, num_tr * 2 + 7, 'Neuron 2')
    ax0.text(Xlim[1] - 200, -12, 'Neuron 3')

    print('plotting Spike Rates ...')
    ax1 = plt.subplot(num_row, 1, 2, sharex=ax0)
    ax1.set_title('Spike Rates')
    for n in range(N):
        ax1.plot(t_center,
                 Js_dict['rate_avg'][:, n].rescale('Hz'),
                 label='Neuron ' + str(unit_real_ids[n]),
                 lw=lw)
    ax1.set_ylabel('(1/s)', fontsize=fsize)
    ax1.set_xlim(0, x_max)
    max_val_psth = 40
    ax1.set_ylim(0, max_val_psth)
    ax1.set_yticks([0, int(max_val_psth / 2), int(max_val_psth)])
    ax1.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    mark_events(ax1, lw)
    ax1.set_xticks([])

    print('plotting Raw Coincidences ...')
    ax2 = plt.subplot(num_row, 1, 3, sharex=ax0)
    ax2.set_title('Coincident Events')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax2.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr +
                     n * row_height + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            coinc_bins = numpy.unique(Js_dict['indices']['trial' + str(tr)])
            ax2.plot(coinc_bins * binsize,
                     numpy.ones_like(coinc_bins) * tr + n * row_height + 1,
                     ls='',
                     ms=ms,
                     marker='s',
                     markerfacecolor='none',
                     markeredgecolor='c')
        if n < N - 1:
            ax2.axhline((n + 1) * row_height, lw=2, color='k')
    ax2.set_ylim(0, y_top)
    ax2.set_yticks(trial_ticks)
    ax2.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax2.set_xlim(0, x_max)
    ax2.set_xticks([])
    ax2.set_ylabel('Trial', fontsize=fsize)
    mark_events(ax2, 2)

    print('plotting emp. and exp. coincidences rate ...')
    ax3 = plt.subplot(num_row, 1, 4, sharex=ax0)
    ax3.set_title('Coincidence Rates')
    ax3.plot(t_center,
             Js_dict['n_emp'] / rate_norm,
             label='empirical',
             lw=lw,
             color='c')
    ax3.plot(t_center,
             Js_dict['n_exp'] / rate_norm,
             label='expected',
             lw=lw,
             color='m')
    ax3.set_xlim(0, x_max)
    ax3.set_ylabel('(1/s)', fontsize=fsize)
    ax3.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    YTicks = ax3.get_ylim()
    ax3.set_yticks([0, YTicks[1] / 2, YTicks[1]])
    mark_events(ax3, 2)
    ax3.set_xticks([])

    print('plotting Surprise ...')
    ax4 = plt.subplot(num_row, 1, 5, sharex=ax0)
    ax4.set_title('Statistical Significance')
    ax4.plot(t_center, Js_dict['Js'], lw=lw, color='k')
    ax4.set_xlim(0, x_max)
    ax4.axhline(Js_sig, ls='-', color='r')
    ax4.axhline(-Js_sig, ls='-', color='g')
    # Threshold labels sit at window 30, or at the last window when the
    # recording has fewer windows.
    label_x = t_winpos[min(30, len(t_winpos) - 1)]
    ax4.text(label_x, Js_sig + 0.3, '$\\alpha +$', color='r')
    ax4.text(label_x, -Js_sig - 0.5, '$\\alpha -$', color='g')
    ax4.set_yticks([ue.jointJ(0.99), ue.jointJ(0.5), ue.jointJ(0.01)])
    ax4.set_yticklabels([0.99, 0.5, 0.01])

    ax4.set_ylim(S_ylim)
    mark_events(ax4, lw)
    ax4.set_xticks([])

    print('plotting UEs ...')
    ax5 = plt.subplot(num_row, 1, 6, sharex=ax0)
    ax5.set_title('Unitary Events')
    # The significant windows depend neither on the neuron nor the trial.
    sig_idx_win = numpy.where(Js_dict['Js'] >= Js_sig)[0]
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax5.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) * tr +
                     n * row_height + 1,
                     '.',
                     markersize=0.5,
                     color='k')
            if len(sig_idx_win) > 0:
                x = numpy.unique(Js_dict['indices']['trial' + str(tr)])
                if len(x) > 0:
                    # Coincidence bins inside any significant window.
                    xx = []
                    for j in sig_idx_win:
                        xx = numpy.append(
                            xx, x[numpy.where((x * binsize >= t_winpos[j]) & (
                                x * binsize < t_winpos[j] + winsize))])
                    ax5.plot(numpy.unique(xx) * binsize,
                             numpy.ones_like(numpy.unique(xx)) * tr +
                             n * row_height + 1,
                             ms=ms,
                             marker='s',
                             ls='',
                             mfc='none',
                             mec='r')
        if n < N - 1:
            ax5.axhline((n + 1) * row_height, lw=2, color='k')
    ax5.set_yticks(trial_ticks)
    ax5.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax5.set_ylim(0, y_top)
    ax5.set_xlim(0, x_max)
    ax5.set_ylabel('Trial', fontsize=fsize)
    ax5.set_xlabel('Time [ms]', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            ax5.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
            ax5.text(e_val - 10 * pq.ms,
                     S_ylim[0] - 35,
                     key,
                     fontsize=fsize,
                     color='r')
    ax5.set_xticks([])

    # Panel letters A-F in the top-left corner of each axes.
    for i, ax in enumerate([ax0, ax1, ax2, ax3, ax4, ax5]):
        ax.text(-0.05,
                1.1,
                string.ascii_uppercase[i],
                transform=ax.transAxes,
                size=fsize + 5,
                weight='bold')
    if plot_params['save_fig']:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()
 def test_hash_invhash_consistency(self):
     """inverse_hash_from_pattern undoes hash_from_pattern."""
     m = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0],
                   [1, 0, 1], [0, 1, 1], [1, 1, 1]])
     # Take N from the pattern matrix (one row per neuron) rather than
     # repeating the literal row count.
     N = m.shape[0]
     h = ue.hash_from_pattern(m, N=N)
     m1 = ue.inverse_hash_from_pattern(h, N=N)
     np.testing.assert_array_equal(m, m1)
 def test_invhash_shape_mat(self):
     """inverse_hash_from_pattern returns one row per neuron and one
     column per hash value."""
     N = 8
     h = np.array([178, 212, 232])
     m = ue.inverse_hash_from_pattern(h, N)
     # Check the complete shape, not only the number of rows.
     self.assertEqual(np.shape(m), (N, len(h)))
 def test_invhash_default_base(self):
     """inverse_hash_from_pattern decodes base-2 hashes by default.

     Column j of the result is the binary pattern of h[j] with the most
     significant bit in row 0 (e.g. 4 -> [1, 0, 0], 6 -> [1, 1, 0]).
     """
     N = 3
     h = np.array([0, 4, 2, 1, 6, 5, 3, 7])
     expected = np.array([[0, 1, 0, 0, 1, 1, 0, 1],
                          [0, 0, 1, 0, 1, 0, 1, 1],
                          [0, 0, 0, 1, 0, 1, 1, 1]])
     m = ue.inverse_hash_from_pattern(h, N)
     # Shape-checking comparison with a readable failure message.
     np.testing.assert_array_equal(m, expected)
def plot_UE(data, Js_dict, sig_level, binsize, winsize, winstep,
            pattern_hash, N, plot_params_user):
    """Plot Figure 1 and Figure 2 of the manuscript.

    Builds one figure with six stacked panels sharing the time axis:
    (A) spike raster, (B) spike rates, (C) raster with coincident events
    marked, (D) empirical vs. expected coincidence rates, (E) statistical
    significance (surprise) and (F) raster with the unitary events, i.e.
    coincidences that fall inside significant windows, marked.

    Parameters
    ----------
    data : sequence
        ``data[tr][n]`` is the spike train of neuron ``n`` in trial ``tr``.
        Spike trains must expose ``t_start``, ``t_stop``, ``magnitude`` and
        ``rescale`` (presumably ``neo.SpikeTrain`` -- confirm with callers).
    Js_dict : dict
        Result of the unitary event analysis. Keys read here: ``'rate_avg'``,
        ``'indices'`` (with per-trial keys ``'trial<i>'``), ``'n_emp'``,
        ``'n_exp'`` and ``'Js'``.
    sig_level : float
        Significance level, turned into a surprise threshold by
        ``ue.jointJ``.
    binsize, winsize, winstep : quantities
        Bin size, analysis window length and window step.
    pattern_hash : array-like
        Hash value(s) of the analysed pattern for ``N`` neurons.
    N : int
        Number of neurons.
    plot_params_user : dict
        Overrides for the module-level ``plot_params_default``. The merged
        settings are written into the module globals, and this function
        reads ``unit_real_ids``, ``events``, ``fsize``, ``figsize``, ``top``,
        ``right``, ``left``, ``bottom``, ``hspace``, ``wspace``, ``lw``,
        ``ms``, ``S_ylim``, ``showfig``, ``path_filename_format`` and
        ``plot_params['save_fig']`` from them.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``len(unit_real_ids) != N``.
    """
    t_start = data[0][0].t_start
    t_stop = data[0][0].t_stop

    t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
    Js_sig = ue.jointJ(sig_level)
    num_tr = len(data)
    # The decoded pattern is not drawn; the call is kept so pattern_hash
    # still goes through ue.inverse_hash_from_pattern as before.
    ue.inverse_hash_from_pattern(pattern_hash, N)

    # figure format
    # Copy the defaults: updating plot_params_default in place would leak
    # this call's user settings into every later call.
    plot_params = dict(plot_params_default)
    plot_params.update(plot_params_user)
    # The code below reads the settings (fsize, lw, events, ...) as globals.
    globals().update(plot_params)
    if len(unit_real_ids) != N:
        raise ValueError('length of unit_ids should be ' +
                         'equal to number of neurons!')
    plt.rcParams.update({'font.size': fsize})
    plt.rc('legend', fontsize=fsize)

    num_row = 6
    ls = '-'
    alpha = 0.5
    # In the raster panels each neuron occupies num_tr rows plus one
    # separator row.
    rows_per_neuron = num_tr + 1
    y_max = rows_per_neuron * N + 1
    t_max = (max(t_winpos) + winsize).rescale('ms').magnitude

    def _mark_events(ax, width):
        """Draw a vertical red line at every event time on ``ax``."""
        for key in events.keys():
            for e_val in events[key]:
                ax.axvline(e_val, ls=ls, color='r', lw=width, alpha=alpha)

    plt.figure(1, figsize=figsize)
    if 'suptitle' in plot_params:
        plt.suptitle("Trial aligned on " +
                     plot_params['suptitle'], fontsize=20)
    plt.subplots_adjust(top=top, right=right, left=left,
                        bottom=bottom, hspace=hspace, wspace=wspace)

    print('plotting raster plot ...')
    ax0 = plt.subplot(num_row, 1, 1)
    ax0.set_title('Spike Events')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax0.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) *
                     tr + n * rows_per_neuron + 1,
                     '.', markersize=0.5, color='k')
        if n < N - 1:
            ax0.axhline(rows_per_neuron * (n + 1), lw=2, color='k')
    ax0.set_ylim(0, y_max)
    # NOTE(review): tick positions/labels look tuned for 30 trials.
    ax0.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
    ax0.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax0.set_xlim(0, t_max)
    ax0.set_xticks([])
    ax0.set_ylabel('Trial', fontsize=fsize)
    _mark_events(ax0, 2)
    Xlim = ax0.get_xlim()
    ax0.text(Xlim[1] - 200, num_tr * 2 + 7, 'Neuron 2')
    ax0.text(Xlim[1] - 200, -12, 'Neuron 3')

    print('plotting Spike Rates ...')
    ax1 = plt.subplot(num_row, 1, 2, sharex=ax0)
    ax1.set_title('Spike Rates')
    for n in range(N):
        ax1.plot(t_winpos + winsize / 2.,
                 Js_dict['rate_avg'][:, n].rescale('Hz'),
                 label='Neuron ' + str(unit_real_ids[n]), lw=lw)
    ax1.set_ylabel('(1/s)', fontsize=fsize)
    ax1.set_xlim(0, t_max)
    max_val_psth = 40
    ax1.set_ylim(0, max_val_psth)
    ax1.set_yticks([0, int(max_val_psth / 2), int(max_val_psth)])
    ax1.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    _mark_events(ax1, lw)
    ax1.set_xticks([])

    print('plotting Raw Coincidences ...')
    ax2 = plt.subplot(num_row, 1, 3, sharex=ax0)
    ax2.set_title('Coincident Events')
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax2.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) *
                     tr + n * rows_per_neuron + 1,
                     '.', markersize=0.5, color='k')
            coinc_bins = numpy.unique(Js_dict['indices']['trial' + str(tr)])
            ax2.plot(coinc_bins * binsize,
                     numpy.ones_like(coinc_bins) *
                     tr + n * rows_per_neuron + 1,
                     ls='', ms=ms, marker='s', markerfacecolor='none',
                     markeredgecolor='c')
        if n < N - 1:
            ax2.axhline(rows_per_neuron * (n + 1), lw=2, color='k')
    ax2.set_ylim(0, y_max)
    ax2.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
    ax2.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax2.set_xlim(0, t_max)
    ax2.set_xticks([])
    ax2.set_ylabel('Trial', fontsize=fsize)
    _mark_events(ax2, 2)

    print('plotting emp. and exp. coincidences rate ...')
    ax3 = plt.subplot(num_row, 1, 4, sharex=ax0)
    ax3.set_title('Coincidence Rates')
    ax3.plot(t_winpos + winsize / 2.,
             Js_dict['n_emp'] / (winsize.rescale('s').magnitude * num_tr),
             label='empirical', lw=lw, color='c')
    ax3.plot(t_winpos + winsize / 2.,
             Js_dict['n_exp'] / (winsize.rescale('s').magnitude * num_tr),
             label='expected', lw=lw, color='m')
    ax3.set_xlim(0, t_max)
    ax3.set_ylabel('(1/s)', fontsize=fsize)
    ax3.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)
    YTicks = ax3.get_ylim()
    ax3.set_yticks([0, YTicks[1] / 2, YTicks[1]])
    _mark_events(ax3, 2)
    ax3.set_xticks([])

    print('plotting Surprise ...')
    ax4 = plt.subplot(num_row, 1, 5, sharex=ax0)
    ax4.set_title('Statistical Significance')
    ax4.plot(t_winpos + winsize / 2., Js_dict['Js'], lw=lw, color='k')
    ax4.set_xlim(0, t_max)
    ax4.axhline(Js_sig, ls='-', color='r')
    ax4.axhline(-Js_sig, ls='-', color='g')
    # Anchor the alpha labels at the 31st window, or at the last window
    # for shorter recordings (a fixed t_winpos[30] raised IndexError).
    label_t = t_winpos[min(30, len(t_winpos) - 1)]
    ax4.text(label_t, Js_sig + 0.3, '$\\alpha +$', color='r')
    ax4.text(label_t, -Js_sig - 0.5, '$\\alpha -$', color='g')
    # No window ticks here: the shared time axis has its ticks cleared
    # below, and slicing t_winpos with step len(t_winpos) // 10 fails with
    # a zero step when there are fewer than 10 windows.
    ax4.set_yticks([ue.jointJ(0.99), ue.jointJ(0.5), ue.jointJ(0.01)])
    ax4.set_yticklabels([0.99, 0.5, 0.01])

    ax4.set_ylim(S_ylim)
    _mark_events(ax4, lw)
    ax4.set_xticks([])

    print('plotting UEs ...')
    ax5 = plt.subplot(num_row, 1, 6, sharex=ax0)
    ax5.set_title('Unitary Events')
    # Indices of the significant windows; identical for every neuron and
    # trial, so computed once.
    sig_idx_win = numpy.where(Js_dict['Js'] >= Js_sig)[0]
    for n in range(N):
        for tr, data_tr in enumerate(data):
            ax5.plot(data_tr[n].rescale('ms').magnitude,
                     numpy.ones_like(data_tr[n].magnitude) *
                     tr + n * rows_per_neuron + 1, '.',
                     markersize=0.5, color='k')
            if len(sig_idx_win) > 0:
                x = numpy.unique(Js_dict['indices']['trial' + str(tr)])
                if len(x) > 0:
                    # Collect the coincidence bins lying inside any of the
                    # significant windows.
                    xx = []
                    for j in sig_idx_win:
                        xx = numpy.append(xx, x[numpy.where(
                            (x * binsize >= t_winpos[j]) &
                            (x * binsize < t_winpos[j] + winsize))])
                    ue_bins = numpy.unique(xx)
                    ax5.plot(ue_bins * binsize,
                             numpy.ones_like(ue_bins) *
                             tr + n * rows_per_neuron + 1,
                             ms=ms, marker='s', ls='', mfc='none', mec='r')
        if n < N - 1:
            ax5.axhline(rows_per_neuron * (n + 1), lw=2, color='k')
    ax5.set_yticks([num_tr + 1, num_tr + 16, num_tr + 31])
    ax5.set_yticklabels([1, 15, 30], fontsize=fsize)
    ax5.set_ylim(0, y_max)
    ax5.set_xlim(0, t_max)
    ax5.set_xticks([])
    ax5.set_ylabel('Trial', fontsize=fsize)
    ax5.set_xlabel('Time [ms]', fontsize=fsize)
    for key in events.keys():
        for e_val in events[key]:
            ax5.axvline(e_val, ls=ls, color='r', lw=2, alpha=alpha)
            ax5.text(e_val - 10 * pq.ms,
                     S_ylim[0] - 35, key, fontsize=fsize, color='r')
    ax5.set_xticks([])

    # Panel letters A-F in the top-left corner of each panel.
    for letter, ax in zip(string.ascii_uppercase,
                          (ax0, ax1, ax2, ax3, ax4, ax5)):
        ax.text(-0.05, 1.1, letter,
                transform=ax.transAxes, size=fsize + 5,
                weight='bold')
    if plot_params['save_fig']:
        plt.savefig(path_filename_format)
        if not showfig:
            plt.cla()
            plt.close()

    if showfig:
        plt.show()