def corrector(sy, total_frames, filter_length, seed):
    """
    Estimate the per-stripe bias of a binary stripe stimulus.

    Regenerates binary (+1/-1) stripe frames with ``randpy.ran1`` and
    averages them over all frames, i.e. computes the most recent time bin of
    an STA in which every frame is treated as having triggered exactly one
    spike.

    Parameters
    ----------
    sy : int
        Number of stripes per frame.
    total_frames : int
        Number of frames to draw and average.
    filter_length : int
        Temporal filter length in frames. It only determines how many
        random numbers are consumed before the averaged frames start
        (``sy*filter_length``).
    seed : int
        Seed handed to ``randpy.ran1``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(sy,)`` holding the mean value of each stripe over
        the drawn frames. With ``total_frames == 0`` the result is NaN
        (0/0), as NumPy defines it.
    """
    print(f'corrector parameters: {sy}, {total_frames}, '
          f'{filter_length}, {seed}')

    # These draws used to fill an initial stimulus window. The window itself
    # never contributes to the returned column, but the numbers are still
    # consumed so that the following frames match the same ran1 sequence.
    _, seed = randpy.ran1(seed, sy*filter_length)

    # Only the newest frame of each step ends up in the returned column, so
    # it is enough to accumulate that single column.
    column_sum = np.zeros(sy)
    for _ in range(total_frames):
        randnrs, seed = randpy.ran1(seed, sy)
        column_sum += np.where(np.asarray(randnrs) > .5, 1, -1)

    # Every frame counts as one spike, hence the spike count equals the
    # number of frames.
    return column_sum / total_frames
def randomizestripes(label, exp_name='20180124', stim_nrs=6):
    """
    Plot stripe flicker STAs obtained from randomized spike times.

    For every stimulus the recorded spikes of each cluster are replaced by
    uniformly distributed random spike times (1000 times as many as were
    recorded, between 0 and the latest recorded spike time). STAs are
    computed from these surrogate spikes using the first quarter of the
    frames, the output of ``corrector`` is subtracted from every time bin
    and the result is saved as one SVG per cluster in
    ``<exp_dir>/data_analysis/<stimname>/STAs_randomized/<label>``.
    Finally the SVGs of the last processed stimulus are combined into
    ``animated_<label>.gif`` with the external ``convert`` program
    (presumably ImageMagick).

    Parameters
    ----------
    label : str
        Name of the subfolder for the figures and suffix of the GIF name.
    exp_name : str
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stim_nrs : int or iterable of int
        Stimulus number(s) to process.

    Raises
    ------
    ValueError
        If the stixel height is less than twice the stixel width, or if
        ``Nblinks`` is neither 1 nor 2.
    """
    # Imported locally because they are needed for the GIF step only.
    import glob
    import subprocess

    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stim_nrs, int):
        stim_nrs = [stim_nrs]

    # Output folders of the most recently processed stimulus; they stay
    # None when no stimulus was processed at all.
    savepath = None
    svgpath = None

    for stim_nr in stim_nrs:
        stimname = iof.getstimname(exp_name, stim_nr)

        clusters, metadata = asc.read_spikesheet(exp_dir)

        parameters = asc.read_parameters(exp_dir, stim_nr)

        scr_width = metadata['screen_width']
        px_size = metadata['pixel_size(um)']

        stx_w = parameters['stixelwidth']
        stx_h = parameters['stixelheight']

        if (stx_h/stx_w) < 2:
            raise ValueError('Make sure the stimulus is stripeflicker.')

        # Number of stripes; a fractional remainder is truncated.
        sy = int(scr_width/stx_w)

        nblinks = parameters['Nblinks']
        bw = parameters.get('blackwhite', False)
        seed = parameters.get('seed', -10000)
        # The stimulus generation below advances `seed`; the untouched value
        # is needed again for the correction.
        initialseed = seed

        if nblinks == 1:
            ft_on, ft_off = asc.readframetimes(exp_dir, stim_nr,
                                               returnoffsets=True)
            # Interleave onsets and offsets into one array that is twice as
            # long as either of them.
            frametimings = np.empty(ft_on.shape[0]*2, dtype=float)
            frametimings[::2] = ft_on
            frametimings[1::2] = ft_off
            # Set filter length so that temporal filter is ~600 ms.
            # The unit here is number of frames.
            filter_length = 40
        elif nblinks == 2:
            frametimings = asc.readframetimes(exp_dir, stim_nr)
            filter_length = 20
        else:
            raise ValueError('Unexpected value for nblinks.')

        # Omit everything that happens before the first 10 seconds
        cut_time = 10

        frame_duration = np.average(np.ediff1d(frametimings))
        # Only the first quarter of the frames is analyzed.
        total_frames = int(frametimings.shape[0]/4)

        all_spiketimes = []
        # Store spike triggered averages in a list containing correct
        # shaped arrays
        stas = []

        for i in range(clusters.shape[0]):
            spikes_orig = asc.read_raster(exp_dir, stim_nr,
                                          clusters[i, 0], clusters[i, 1])
            spikesneeded = spikes_orig.shape[0]*1000

            # Surrogate spikes: uniform between 0 and the last real spike.
            spiketimes = (np.random.random_sample(spikesneeded)
                          * spikes_orig.max())
            spiketimes = np.sort(spiketimes)
            all_spiketimes.append(asc.binspikes(spiketimes, frametimings))
            stas.append(np.zeros((sy, filter_length)))

        if bw:
            randnrs, seed = randpy.ran1(seed, sy*total_frames)
            randnrs = np.where(np.asarray(randnrs) > .5, 1, -1)
        else:
            randnrs, seed = randpy.gasdev(seed, sy*total_frames)

        stimulus = np.reshape(randnrs, (sy, total_frames), order='F')
        del randnrs

        for k in range(filter_length, total_frames-filter_length+1):
            # Frames before the cut time do not contribute for any cluster.
            if not frametimings[k] > cut_time:
                continue
            # Stimulus history preceding frame k, most recent frame first.
            stim_small = stimulus[:, k-filter_length+1:k+1][:, ::-1]
            for j in range(clusters.shape[0]):
                nspikes = all_spiketimes[j][k]
                if nspikes != 0:
                    stas[j] += nspikes*stim_small

        spikenrs = np.array([spikearr.sum() for spikearr in all_spiketimes])

        for i in range(clusters.shape[0]):
            stas[i] = stas[i]/spikenrs[i]

        clusterids = plf.clusters_to_ids(clusters)

        # NOTE(review): corrector always regenerates a binary stimulus, also
        # when bw is False - confirm that this is intended.
        correction = corrector(sy, total_frames, filter_length, initialseed)
        # Apply the same per-stripe correction to every time bin.
        correction = np.outer(correction, np.ones(filter_length))

        t = np.arange(filter_length)*frame_duration*1000
        vscale = int(sy * stx_w*px_size/1000)

        savepath = os.path.join(exp_dir, 'data_analysis',
                                stimname, 'STAs_randomized')
        svgpath = pjoin(savepath, label)
        if clusters.shape[0] > 0:
            os.makedirs(svgpath, exist_ok=True)

        vmax = 0.03
        vmin = -vmax
        for i in range(clusters.shape[0]):
            sta = stas[i]-correction

            plt.figure(figsize=(6, 15))
            ax = plt.subplot(111)
            im = ax.imshow(sta, cmap='RdBu', vmin=vmin, vmax=vmax,
                           extent=[0, t[-1], -vscale, vscale], aspect='auto')
            plt.xlabel('Time [ms]')
            plt.ylabel('Distance [mm]')

            plf.spineless(ax)
            plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f', size='2%')
            plt.suptitle('{}\n{}\n'
                         '{} Rating: {}\n'
                         'nrofspikes {:5.0f}'.format(exp_name,
                                                       stimname,
                                                       clusterids[i],
                                                       clusters[i][2],
                                                       spikenrs[i]))
            plt.subplots_adjust(top=.90)
            plt.savefig(os.path.join(svgpath, clusterids[i]+'.svg'),
                        bbox_inches='tight')
            plt.close()

    # Combine the figures of the last processed stimulus into a GIF.
    if svgpath is None:
        return
    frames = sorted(glob.glob(os.path.join(glob.escape(svgpath), '*svg')))
    if not frames:
        return
    gifpath = os.path.join(savepath, f'animated_{label}.gif')
    try:
        # An argument list keeps paths and label away from shell parsing.
        subprocess.run(['convert', '-delay', '25', *frames, gifpath],
                       check=False)
    except FileNotFoundError:
        # Creating the GIF is best effort; the SVGs are already on disk.
        print('The "convert" program was not found, no GIF was created.')
def stripeflickeranalysis(exp_name, stim_nrs):
    """
    Compute spike triggered averages for stripe flicker stimuli and save them.

    For each stimulus number the stimulus is regenerated from its seed
    (binary via ``randpy.ran1`` when ``blackwhite`` is set, otherwise via
    ``randpy.gasdev``), an STA is accumulated for every cluster, and the
    results are written with ``np.savez`` to
    ``<exp_dir>/data_analysis/<stimname>/<stim_nr>_data``.

    Parameters
    ----------
    exp_name : str
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stim_nrs : int or iterable of int
        Stimulus number(s) to analyze.

    Raises
    ------
    ValueError
        If the stixel height is less than twice the stixel width, if the
        screen width is not a multiple of the stixel width, or if
        ``Nblinks`` is neither 1 nor 2.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stim_nrs, int):
        stim_nrs = [stim_nrs]

    for stim_nr in stim_nrs:
        stimname = iof.getstimname(exp_name, stim_nr)
        clusters, metadata = asc.read_spikesheet(exp_dir)
        parameters = asc.read_parameters(exp_dir, stim_nr)

        scr_width = metadata['screen_width']
        # The value is not used below; the lookup is kept so that a missing
        # entry still raises at this point.
        px_size = metadata['pixel_size(um)']

        stx_w = parameters['stixelwidth']
        stx_h = parameters['stixelheight']
        if (stx_h / stx_w) < 2:
            raise ValueError('Make sure the stimulus is stripeflicker.')

        # Number of stripes; it has to be a whole number.
        stripe_count = scr_width / stx_w
        if stripe_count % 1 != 0:
            raise ValueError('sy is not an integer')
        sy = int(stripe_count)

        nblinks = parameters['Nblinks']
        bw = parameters.get('blackwhite', False)
        seed = parameters.get('seed', -10000)

        if nblinks == 1:
            onsets, offsets = asc.readframetimes(exp_dir,
                                                 stim_nr,
                                                 returnoffsets=True)
            # Interleave onsets and offsets: even positions hold onsets,
            # odd positions hold offsets.
            frametimings = np.empty(onsets.shape[0] * 2, dtype=float)
            frametimings[::2] = onsets
            frametimings[1::2] = offsets
            # Filter length in frames, chosen so that the temporal filter
            # is ~600 ms.
            filter_length = 40
        elif nblinks == 2:
            frametimings = asc.readframetimes(exp_dir, stim_nr)
            filter_length = 20
        else:
            raise ValueError('Unexpected value for nblinks.')

        # Frames up to this time (first 10 seconds) are left out.
        cut_time = 10

        frame_duration = np.average(np.ediff1d(frametimings))
        total_frames = frametimings.shape[0]

        # One binned spike train and one empty STA per cluster.
        all_spiketimes = []
        stas = []
        for row in range(clusters.shape[0]):
            raster = asc.read_raster(exp_dir, stim_nr, clusters[row, 0],
                                     clusters[row, 1])
            all_spiketimes.append(asc.binspikes(raster, frametimings))
            stas.append(np.zeros((sy, filter_length)))

        if bw:
            randnrs, seed = randpy.ran1(seed, sy * total_frames)
            randnrs = np.where(np.asarray(randnrs) > .5, 1, -1)
        else:
            randnrs, seed = randpy.gasdev(seed, sy * total_frames)

        stimulus = np.reshape(randnrs, (sy, total_frames), order='F')
        del randnrs

        for k in range(filter_length, total_frames - filter_length + 1):
            # The cut time test does not depend on the cluster.
            if not frametimings[k] > cut_time:
                continue
            # Stimulus history preceding frame k, most recent frame first.
            window = stimulus[:, k - filter_length + 1:k + 1][:, ::-1]
            for sta_sum, binned in zip(stas, all_spiketimes):
                count = binned[k]
                if count != 0:
                    sta_sum += count * window

        spikenrs = np.array([binned.sum() for binned in all_spiketimes])

        max_inds = []
        quals = np.array([])
        for idx in range(clusters.shape[0]):
            stas[idx] = stas[idx] / spikenrs[idx]

            # Locate the element(s) with the largest absolute value.
            magnitude = np.abs(stas[idx])
            max_i = np.squeeze(np.where(magnitude == np.max(magnitude)))
            if max_i.shape != (2, ):
                # Several candidates: take the first one. No candidate at
                # all: fall back to zeros.
                try:
                    max_i = max_i[:, 0]
                except IndexError:
                    max_i = np.array([0, 0])
            max_inds.append(max_i)

            quals = np.append(quals, asc.staquality(stas[idx]))

        savedir = pjoin(exp_dir, 'data_analysis', stimname)
        os.makedirs(savedir, exist_ok=True)
        savepath = os.path.join(savedir, str(stim_nr) + '_data')

        # From here on the name of the experiment folder is used.
        exp_name = os.path.basename(exp_dir)

        data_in_dict = {
            'stas': stas,
            'max_inds': max_inds,
            'clusters': clusters,
            'sy': sy,
            'frame_duration': frame_duration,
            'all_spiketimes': all_spiketimes,
            'stimname': stimname,
            'total_frames': total_frames,
            'stx_w': stx_w,
            'spikenrs': spikenrs,
            'bw': bw,
            'quals': quals,
            'nblinks': nblinks,
            'filter_length': filter_length,
            'exp_name': exp_name,
        }

        np.savez(savepath, **data_in_dict)
        print(f'Analysis of {stimname} completed.')
Beispiel #4
0
@author: ycan
"""
from randpy import randpy
import numpy as np
import matplotlib.pyplot as plt
import plotfuncs as plf

# Simulation size: the frame count is a base of 80000 frames scaled by
# `multiplier`.
multiplier = 100

sy = 160
total_frames = 80000 * multiplier
filter_length = 40
seed = -1000
length = sy * total_frames

# Binary stimulus: draws above 0.5 become 1, all others -1.
randnrs, seed = randpy.ran1(seed, length)
randnrs = np.where(np.asarray(randnrs) > .5, 1, -1)

# Consecutive runs of sy numbers form the columns (frames) of the stimulus.
stimulus = randnrs.reshape((sy, total_frames), order='F')
del randnrs

sta = np.zeros((sy, filter_length))
spikecounter = 0

for k in range(filter_length, total_frames - filter_length + 1):
    # Stimulus history preceding frame k, most recent frame first.
    stim_small = stimulus[:, k - filter_length + 1:k + 1][:, ::-1]
    # Spike count of this frame, drawn from NumPy's default Poisson
    # distribution.
    spike = np.random.poisson()
    if spike:
        sta += spike * stim_small
        spikecounter += spike
del stimulus
"""
from randpy import randpy
import numpy as np
import matplotlib.pyplot as plt
import plotfuncs as plf

sy = 160
total_frames = 77510
filter_length = 40
seed = -1000

sta = np.zeros((sy, filter_length))
spikect = 0

# Initialize the stimulus window with binary values: draws above 0.5 become
# 1, all others -1.
randnrs, seed = randpy.ran1(seed, sy * filter_length)
randnrs = np.where(np.asarray(randnrs) > .5, 1, -1)
stim = np.reshape(randnrs, (sy, filter_length), order='F')

# Every frame counts as exactly one spike.
spike = 1
for frame in range(total_frames):
    randnrs, seed = randpy.ran1(seed, sy)
    randnrs = np.where(np.asarray(randnrs) > .5, 1, -1)
    # Slide the window: drop the oldest column, append the new frame.
    stim = np.hstack((stim[:, 1:], randnrs[:, None]))
    sta += stim * spike
    spikect += spike
sta /= spikect
# %%
plt.figure(figsize=(6, 14))
Beispiel #6
0
    spikenrs = np.zeros(clusters.shape[0]).astype('int')
    for i in range(len(clusters[:, 0])):
        spiketimes = asc.read_raster(exp_dir, stimulusnr,
                                     clusters[i, 0], clusters[i, 1])

        spikes = asc.binspikes(spiketimes, frametimings)
        all_spiketimes.append(spikes)
        stas.append(np.zeros((sx, sy, filter_length)))

    # Length of the chunks (specified in number of frames)
    chunklength = 5000
    chunksize = chunklength*sx*sy
    nrofchunks = int(np.ceil(total_frames/chunklength))
    time = startime = datetime.datetime.now()
    for i in range(nrofchunks):
        randnrs, seed = randpy.ran1(seed, chunksize)
        randnrs = [1 if i > .5 else -1 for i in randnrs]
        stimulus = np.reshape(randnrs, (sx, sy, chunklength), order='F')
        del randnrs
        # Range of indices we are interested in for the current chunk
        if (i+1)*chunklength < total_frames:
            chunkind = slice(i*chunklength, (i+1)*chunklength)
            chunkend = chunklength
        else:
            chunkind = slice(i*chunklength, None)
            chunkend = total_frames - i*chunklength

        for k in range(filter_length, chunkend-filter_length+1):
            stim_small = stimulus[:, :, k-filter_length+1:k+1][:, :, ::-1]
            for j in range(clusters.shape[0]):
                spikes = all_spiketimes[j][chunkind]
def saccadegratingsanalyzer(exp_name, stim_nr):
    """
    Analyze and save responses to saccadegratings stimulus.

    The sequence of grating positions and transition types is regenerated
    with ``randpy.ran1`` from a hard-coded seed. For every cluster the
    spikes following each transition are binned in 10 ms steps and averaged
    separately for every (start position, target position) pair and for
    both transition types. One 4x4 figure per cluster is saved as SVG and
    all results are written with ``np.savez`` to
    ``<exp_dir>/data_analysis/<stimname>/<stim_nr>_data``.

    Parameters
    ----------
    exp_name : str
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stim_nr : int
        Number of the stimulus to analyze.

    Raises
    ------
    ValueError
        If the ``stimulus_type`` parameter is not ``'saccadegrating'``.
    """

    exp_dir = iof.exp_dir_fixer(exp_name)
    exp_name = os.path.split(exp_dir)[-1]
    stimname = iof.getstimname(exp_dir, stim_nr)
    clusters, metadata = asc.read_spikesheet(exp_dir)
    clusterids = plf.clusters_to_ids(clusters)

    refresh_rate = metadata['refresh_rate']

    # The experiment directory is passed, as in the other analysis functions
    # of this file.
    parameters = asc.read_parameters(exp_dir, stim_nr)
    if parameters['stimulus_type'] != 'saccadegrating':
        raise ValueError('Unexpected stimulus type: '
                         f'{parameters["stimulus_type"]}')
    fixfr = parameters.get('fixationframes', 80)
    sacfr = parameters.get('saccadeframes', 10)
    barwidth = parameters.get('barwidth', 40)
    averageshift = parameters.get('averageshift', 2)
    # The seed is hard-coded in the Stimulator
    seed = -10000

    ftimes = asc.readframetimes(exp_dir, stim_nr)
    # Arrange the frame times in rows of two consecutive entries, dropping
    # a trailing unpaired entry. Slicing and reshaping (instead of an
    # in-place resize) also works when the array is referenced elsewhere.
    # Assumes ftimes is one-dimensional - TODO confirm.
    npairs = ftimes.shape[0] // 2
    ftimes = ftimes[:2 * npairs].reshape(npairs, 2)
    nfr = ftimes.size
    # Re-generate the stimulus
    # Amplitude of the shift and the transition type (saccade or grey) are
    # determined based on the output of ran1
    randnrs = np.array(randpy.ran1(seed, nfr)[0])

    # Separate the amplitude and transitions into two arrays
    stimpos = (4 * randnrs[::2]).astype(int)

    # Transition variable, determines whether grating is moving during
    # the transition or only a grey screen is presented.
    trans = np.array(randnrs[1::2] > 0.5)

    # Record before and after positions in a single array and remove
    # the first element because there is no before value
    stimposx = np.append(0, stimpos)[:-1]
    stimtr = np.stack((stimposx, stimpos), axis=1)[1:]
    trans = trans[:-1]

    saccadetr = stimtr[trans, :]
    greytr = stimtr[~trans, :]

    # Create a time vector with defined temporal bin size
    tstep = 0.01  # Bin size is defined here, unit is seconds
    trialduration = (fixfr + sacfr) / refresh_rate
    nrsteps = int(trialduration / tstep) + 1
    t = np.linspace(0, trialduration, num=nrsteps)

    # Collect saccade beginning time for each trial
    trials = ftimes[1:, 0]
    sacftimes = trials[trans]
    greyftimes = trials[~trans]

    # Both arrays are filled completely in the loop below.
    sacspikes = np.empty((clusters.shape[0], sacftimes.shape[0], t.shape[0]))
    greyspikes = np.empty((clusters.shape[0], greyftimes.shape[0], t.shape[0]))
    # Collect all the psth in one array. The order is
    # transition type, cluster index, start pos, target pos, time
    psth = np.zeros((2, clusters.shape[0], 4, 4, t.size))

    for i, (chid, clid, _) in enumerate(clusters):
        spiketimes = asc.read_raster(exp_dir, stim_nr, chid, clid)
        for j, onset in enumerate(sacftimes):
            sacspikes[i, j, :] = asc.binspikes(spiketimes, onset + t)
        for k, onset in enumerate(greyftimes):
            greyspikes[i, k, :] = asc.binspikes(spiketimes, onset + t)

    # Sort trials according to the transition type
    # nton[i][j] contains the indexes of trials where saccade was i to j
    nton_sac = [[[] for _ in range(4)] for _ in range(4)]
    for i, (start, target) in enumerate(saccadetr):
        nton_sac[start][target].append(i)
    nton_grey = [[[] for _ in range(4)] for _ in range(4)]
    for i, (start, target) in enumerate(greytr):
        nton_grey[start][target].append(i)

    savedir = os.path.join(exp_dir, 'data_analysis', stimname)
    os.makedirs(savedir, exist_ok=True)
    for i in range(clusters.shape[0]):
        fig, axes = plt.subplots(4,
                                 4,
                                 sharex=True,
                                 sharey=True,
                                 figsize=(8, 8))
        for j in range(4):
            for k in range(4):
                # Start from bottom left corner
                ax = axes[3 - j][k]
                # Average all transitions of one type. A start/target pair
                # that never occurred gives an average over an empty
                # selection, i.e. NaN.
                psth_sac = sacspikes[i, nton_sac[j][k], :].mean(axis=0)
                psth_grey = greyspikes[i, nton_grey[j][k], :].mean(axis=0)
                # Convert to spikes per second
                psth_sac = psth_sac / tstep
                psth_grey = psth_grey / tstep
                psth[0, i, j, k, :] = psth_sac
                psth[1, i, j, k, :] = psth_grey
                # Mark the duration of the saccade frames (in ms)
                ax.axvline(sacfr / refresh_rate * 1000,
                           color='red',
                           linestyle='dashed',
                           linewidth=.5)
                ax.plot(t * 1000, psth_sac, label='Saccadic trans.')
                ax.plot(t * 1000, psth_grey, label='Grey trans.')
                ax.set_yticks([])
                ax.set_xticks([])
                # Cosmetics
                plf.spineless(ax)
                if j == k:
                    # Highlight panels where start and target are the same
                    ax.set_facecolor((1, 1, 0, 0.15))
                if j == 0:
                    ax.set_xlabel(f'{k}')
                    if k == 3:
                        ax.legend(fontsize='xx-small', loc=0)
                if k == 0:
                    ax.set_ylabel(f'{j}')

        # Add an encompassing label for starting and target positions
        ax0 = fig.add_axes([0.08, 0.08, .86, .86])
        plf.spineless(ax0)
        ax0.patch.set_alpha(0)
        ax0.set_xticks([])
        ax0.set_yticks([])
        ax0.set_ylabel('Start position')
        ax0.set_xlabel('Target position')
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')
        plt.savefig(os.path.join(savedir, f'{clusterids[i]}.svg'))
        plt.close()

    # Save results
    data_in_dict = {
        'fixfr': fixfr,
        'sacfr': sacfr,
        't': t,
        'averageshift': averageshift,
        'barwidth': barwidth,
        'seed': seed,
        'trans': trans,
        'saccadetr': saccadetr,
        'greytr': greytr,
        'nton_sac': nton_sac,
        'nton_grey': nton_grey,
        'stimname': stimname,
        'sacspikes': sacspikes,
        'greyspikes': greyspikes,
        'psth': psth,
        'nfr': nfr,
        'parameters': parameters,
    }

    np.savez(os.path.join(savedir, str(stim_nr) + '_data'), **data_in_dict)
    print(f'Analysis of {stimname} completed.')