Example #1
0
def plotcheckersurround(exp_name, stim_nr, filename=None, spikecutoff=1000,
                        ratingcutoff=4, staqualcutoff=0, inner_b=2,
                        outer_b=4):

    """
    Divides each cluster's checkerflicker STA into center and surround by
    fitting a 2D Gaussian, and plots the temporal component of each region.

    One SVG per plotted cluster is written to
    <exp_dir>/data_analysis/<stimname>/surroundplots[_<label>]/, where
    <label> is `filename` without a trailing '.npz'.

    spikecutoff:
        Clusters must have more than this many spikes to be included.

    ratingcutoff:
        Maximum spike sorting rating to include; clusters whose rating
        (column 2 of `clusters`) is less than or equal to this are kept.

    staqualcutoff:
        Clusters must have an STA quality (as measured by z-score) above
        this value to be included.

    inner_b:
        Defined limit between receptive field center and surround
        in units of sigma.

    outer_b:
        Defined limit of the end of receptive field surround.
    """

    exp_dir = iof.exp_dir_fixer(exp_name)
    stim_nr = str(stim_nr)
    if filename:
        filename = str(filename)

    if not filename:
        savefolder = 'surroundplots'
        label = ''
    else:
        # Drop only a literal '.npz' extension. str.strip('.npz') would
        # remove any of the characters '.', 'n', 'p', 'z' from both ends,
        # e.g. 'zoom.npz' -> 'oom'.
        label = filename
        if label.endswith('.npz'):
            label = label[:-len('.npz')]
        savefolder = 'surroundplots_' + label

    _, metadata = asc.read_spikesheet(exp_name)
    px_size = metadata['pixel_size(um)']

    data = iof.load(exp_name, stim_nr, fname=filename)

    clusters = data['clusters']
    stas = data['stas']
    stx_h = data['stx_h']
    exp_name = data['exp_name']
    stimname = data['stimname']
    max_inds = data['max_inds']
    frame_duration = data['frame_duration']
    filter_length = data['filter_length']
    # Only the last row of the quality array is used for the cutoff.
    quals = data['quals'][-1, :]

    spikenrs = data['spikenrs']

    c1 = np.where(spikenrs > spikecutoff)[0]
    c2 = np.where(clusters[:, 2] <= ratingcutoff)[0]
    c3 = np.where(quals > staqualcutoff)[0]

    # Indices passing all three criteria, sorted ascending. The set
    # intersection replaces per-index `i in c1` scans, which were
    # quadratic in the number of clusters.
    choose = np.intersect1d(np.intersect1d(c1, c2), c3)
    clusters = clusters[choose]
    stas = list(np.array(stas)[choose])
    max_inds = list(np.array(max_inds)[choose])

    clusterids = plf.clusters_to_ids(clusters)

    t = np.arange(filter_length)*frame_duration*1000

    # Determine frame size so that the total frame covers
    # an area large enough i.e. 2*700um
    f_size = int(700/(stx_h*px_size))

    plotpath = os.path.join(exp_dir, 'data_analysis', stimname, savefolder)

    del data

    for i in range(clusters.shape[0]):

        try:
            sta, max_i = mf.cut_around_center(stas[i], max_inds[i], f_size)
        except ValueError:
            # Clusters whose STA cannot be cut to f_size are skipped.
            continue

        # Spatial frame at the peak time point, used for the Gaussian fit
        fit_frame = sta[:, :, max_i[2]]

        # -1 if the strongest pixel is negative (OFF), +1 otherwise; the
        # fit is done on the sign-corrected frame.
        onoroff = -1 if np.max(fit_frame) != np.max(np.abs(fit_frame)) else 1

        Y, X = np.meshgrid(np.arange(fit_frame.shape[1]),
                           np.arange(fit_frame.shape[0]))

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    '.*divide by zero*.', RuntimeWarning)
            pars = gfit.gaussfit(fit_frame*onoroff)
            f = gfit.twodgaussian(pars)
            Z = f(X, Y)

        # Convert the fitted Gaussian's values to Mahalanobis distance from
        # its center: sqrt(-2*ln((Z - pars[0])/pars[1])).
        # NOTE(review): assumes pars[0] is the baseline and pars[1] the
        # amplitude of the fit - confirm against gfit.gaussfit.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    '.*divide by zero*.', RuntimeWarning)
            Zm = np.log((Z-pars[0])/pars[1])
        Zm[np.isinf(Zm)] = np.nan
        Zm = np.sqrt(Zm*-2)

        ax = plt.subplot(1, 2, 1)

        plf.stashow(fit_frame, ax)
        ax.set_aspect('equal')

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=UserWarning)
            warnings.filterwarnings('ignore', '.*invalid value encountered*.')
            ax.contour(Y, X, Zm, [inner_b, outer_b],
                       cmap=plf.RFcolormap(('C0', 'C1')))

        barsize = 100/(stx_h*px_size)
        scalebar = AnchoredSizeBar(ax.transData,
                                   barsize, '100 µm',
                                   'lower left',
                                   pad=1,
                                   color='k',
                                   frameon=False,
                                   size_vertical=.2)
        ax.add_artist(scalebar)

        # Masked-array masks are True where values are EXCLUDED: the center
        # keeps pixels with Zm < inner_b, the surround keeps
        # inner_b < Zm < outer_b. NaN distances fail both comparisons and
        # are therefore excluded from both regions.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    '.*invalid value encountered in*.',
                                    RuntimeWarning)
            center_mask = np.logical_not(Zm < inner_b)
            center_mask_3d = np.broadcast_arrays(sta,
                                                 center_mask[..., None])[1]
            surround_mask = np.logical_not(np.logical_and(Zm > inner_b,
                                                          Zm < outer_b))
            surround_mask_3d = np.broadcast_arrays(sta,
                                                   surround_mask[..., None])[1]

        sta_center = np.ma.array(sta, mask=center_mask_3d)
        sta_surround = np.ma.array(sta, mask=surround_mask_3d)

        # Average over the two spatial axes: one time course per region
        sta_center_temporal = np.mean(sta_center, axis=(0, 1))
        sta_surround_temporal = np.mean(sta_surround, axis=(0, 1))

        ax1 = plt.subplot(1, 2, 2)
        l1 = ax1.plot(t, sta_center_temporal,
                      label='Center\n(<{}σ)'.format(inner_b),
                      color='C0')
        sct_max = np.max(np.abs(sta_center_temporal))
        ax1.set_ylim(-sct_max, sct_max)
        ax2 = ax1.twinx()
        l2 = ax2.plot(t, sta_surround_temporal,
                      label='Surround\n({}σ<x<{}σ)'.format(inner_b, outer_b),
                      color='C1')
        sst_max = np.max(np.abs(sta_surround_temporal))
        ax2.set_ylim(-sst_max, sst_max)
        plf.spineless(ax1)
        plf.spineless(ax2)
        ax1.tick_params('y', colors='C0')
        ax2.tick_params('y', colors='C1')
        # twinx() hides the twin's x axis, so the x label must be set on
        # ax1 (pyplot's current axes at this point is ax2).
        ax1.set_xlabel('Time[ms]')
        ax2.axhline(0, linestyle='dashed', linewidth=1)

        lines = l1+l2
        labels = [line.get_label() for line in lines]
        ax2.legend(lines, labels, fontsize=7)
        ax2.set_title('Temporal components')
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')

        plt.subplots_adjust(wspace=.5, top=.85)

        os.makedirs(plotpath, exist_ok=True)

        plt.savefig(os.path.join(plotpath, clusterids[i])+'.svg',
                    format='svg', dpi=300)
        plt.close()
    print(f'Plotted checkerflicker surround for {stimname}')
Example #2
0
def stripesurround_SVD(exp_name, stimnrs, nrcomponents=5):
    """
    Fit a one-dimensional center-surround model to the noise-reduced
    stripeflicker STA of every cluster, plot each fit and save the
    center-surround indices.

    For each stimulus number, one SVG per analyzed cluster is written to
    <exp_dir>/data_analysis/<stimname>/stripesurrounds_SVD/. The loaded
    data, extended with 'cs_inds', 'polarities' and 'included', is saved
    as <stimnr>_data_SVD.npz in <exp_dir>/data_analysis/<stimname>/.

    exp_name:
        Experiment name or directory, passed to iof.exp_dir_fixer.

    stimnrs:
        A single stimulus number (int) or a list of them.

    nrcomponents:
        first N components of singular value decomposition (SVD)
        will be used to reduce noise.

    Clusters that could not be analyzed are marked False in 'included'.
    Their 'cs_inds' entry, and their 'polarities' entry if it was never
    computed, is NaN.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stimnrs, int):
        stimnrs = [stimnrs]

    # The spike sheet depends only on the experiment; read it once.
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    for stimnr in stimnrs:
        data = iof.load(exp_name, stimnr)

        clusters = data['clusters']
        stas = data['stas']
        max_inds = data['max_inds']
        filter_length = data['filter_length']
        stx_w = data['stx_w']
        exp_name = data['exp_name']
        stimname = data['stimname']
        frame_duration = data['frame_duration']
        quals = data['quals']

        # Record which clusters are ignored during analysis
        try:
            included = data['included']
        except KeyError:
            included = [True] * clusters.shape[0]

        # Average STA values 100 ms around the brightest frame to
        # minimize noise
        cut_time = int(100 / (frame_duration * 1000) / 2)

        # Tolerance for distance between center and surround
        # distributions 60 μm
        dtol = int((60 / px_size) / 2)

        clusterids = plf.clusters_to_ids(clusters)

        fsize = int(700 / (stx_w * px_size))
        t = np.arange(filter_length) * frame_duration * 1000
        vscale = fsize * stx_w * px_size

        # NaN marks values never computed for skipped clusters (np.empty
        # would store arbitrary memory contents in the saved file).
        cs_inds = np.full(clusters.shape[0], np.nan)
        polarities = np.full(clusters.shape[0], np.nan)

        savepath = os.path.join(exp_dir, 'data_analysis', stimname)
        plotdir = os.path.join(savepath, 'stripesurrounds_SVD')

        for i in range(clusters.shape[0]):
            sta = stas[i]
            max_i = max_inds[i]

            # From this point on, use the low-rank approximation
            # version
            sta_reduced = sumcomponent(nrcomponents, sta)

            try:
                sta_reduced, max_i = msc.cutstripe(sta_reduced, max_i,
                                                   fsize * 2)
            except ValueError as e:
                if str(e) == 'Cutting outside the STA range.':
                    included[i] = False
                    continue
                print(f'Error while analyzing {stimname}\n' +
                      f'Index:{i}    Cluster:{clusterids[i]}')
                raise

            # Isolate the time point from which the fit will
            # be obtained; shift it so the averaging window does not
            # start before the first frame.
            if max_i[1] < cut_time:
                max_i[1] = cut_time + 1
            fitv = np.mean(sta_reduced[:, max_i[1] - cut_time:max_i[1] +
                                       cut_time + 1],
                           axis=1)

            # Make a space vector
            s = np.arange(fitv.shape[0])

            if np.max(fitv) != np.max(np.abs(fitv)):
                onoroff = -1
            else:
                onoroff = 1
            polarities[i] = onoroff
            # Determine the peak values for center and surround
            # to give as initial parameters for curve fitting
            centerpeak = onoroff * np.max(fitv * onoroff)
            surroundpeak = onoroff * np.max(fitv * -onoroff)

            # Define initial guesses for the center and surround gaussians
            # First set of values are for center, second for surround.
            # Amplitudes share the cell's sign; the surround position is
            # kept within dtol of the peak and its width between 4 and 20.
            p_initial = [centerpeak, max_i[0], 2, surroundpeak, max_i[0], 8]
            if onoroff == 1:
                bounds = ([0, -np.inf, -np.inf, 0, max_i[0] - dtol, 4], [
                    np.inf, np.inf, np.inf, np.inf, max_i[0] + dtol, 20
                ])
            else:
                bounds = ([
                    -np.inf, -np.inf, -np.inf, -np.inf, max_i[0] - dtol, 4
                ], [0, np.inf, np.inf, 0, max_i[0] + dtol, 20])

            try:
                popt, _ = curve_fit(centersurround_onedim,
                                    s,
                                    fitv,
                                    p0=p_initial,
                                    bounds=bounds)
            except (ValueError, RuntimeError) as e:
                er = str(e)
                if (er == "`x0` is infeasible."
                        or er.startswith("Optimal parameters not found")):
                    # Fall back to a center-only Gaussian with a zero
                    # surround amplitude.
                    popt, _ = curve_fit(onedgauss, s, fitv, p0=p_initial[:3])
                    popt = np.append(popt, [0, popt[1], popt[2]])
                elif er == "array must not contain infs or NaNs":
                    included[i] = False
                    continue
                else:
                    print(f'Error while analyzing {stimname}\n' +
                          f'Index:{i}    Cluster:{clusterids[i]}')
                    raise

            fit = centersurround_onedim(s, *popt)
            # Express both amplitudes as magnitudes
            popt[0] = popt[0] * onoroff
            popt[3] = popt[3] * onoroff

            csi = popt[3] / popt[0]
            cs_inds[i] = csi

            plt.figure(figsize=(10, 9))
            ax = plt.subplot(121)
            plf.stashow(sta_reduced, ax, extent=[0, t[-1], -vscale, vscale])
            ax.set_xlabel('Time [ms]')
            ax.set_ylabel('Distance [µm]')
            ax.set_title(f'Using first {nrcomponents} components of SVD',
                         fontsize='small')

            ax = plt.subplot(122)
            plf.spineless(ax)
            ax.set_yticks([])
            # We need to flip the vertical axis to match
            # with the STA next to it
            plt.plot(onoroff * fitv, -s, label='Data')
            plt.plot(onoroff * fit, -s, label='Fit')
            plt.axvline(0, linestyle='dashed', alpha=.5)
            plt.title(f'Center: a: {popt[0]:4.2f}, μ: {popt[1]:4.2f},' +
                      f' σ: {popt[2]:4.2f}\n' +
                      f'Surround: a: {popt[3]:4.2f}, μ: {popt[4]:4.2f},' +
                      f' σ: {popt[5]:4.2f}' + f'\n CS index: {csi:4.2f}')
            plt.subplots_adjust(top=.85)
            plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]} ' +
                         f'Q: {quals[i]:4.2f}')
            os.makedirs(plotdir, exist_ok=True)
            plt.savefig(os.path.join(plotdir, clusterids[i] + '.svg'),
                        bbox_inches='tight')
            plt.close()

        data.update({
            'cs_inds': cs_inds,
            'polarities': polarities,
            'included': included
        })
        np.savez(os.path.join(savepath, f'{stimnr}_data_SVD.npz'), **data)
        print(f'Surround plotted and saved for {stimname}.')
Example #3
0
        Z = f(X, Y)

    # Correcting for Mahalonobis dist.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore',
                                '.*divide by zero*.', RuntimeWarning)
        Zm = np.log((Z-pars[0])/pars[1])
    Zm[np.isinf(Zm)] = np.nan
    Zm = np.sqrt(Zm*-2)

    ax1 = plt.subplot(rows, columns, 1)
    plf.subplottext('A', ax1)

    vmax = np.abs(fit_frame).max()
    vmin = -vmax
    im = plf.stashow(fit_frame, ax1)
    ax1.set_aspect('equal')
    plf.spineless(ax1)
    ax1.set_xticks([])
    ax1.set_yticks([])

    checkercolors = ['black', 'orange']

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=UserWarning)
        warnings.filterwarnings('ignore', '.*invalid value encountered*.')
        ax1.contour(Y, X, Zm, [inner_b, outer_b], linewidths=.5,
                   cmap=plf.RFcolormap(checkercolors))

    barsize_set_checker = 100 # micrometers
    checker_scalebarsize = barsize_set_checker/(stx_h*px_size)
Example #4
0
        f = gfit.twodgaussian(pars)
        Z = f(X, Y)

    # Correcting for Mahalonobis dist.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', '.*divide by zero*.', RuntimeWarning)
        Zm = np.log((Z - pars[0]) / pars[1])
    Zm[np.isinf(Zm)] = np.nan
    Zm = np.sqrt(Zm * -2)

    ax = plt.subplot(1, 2, 1)
    plf.subplottext('A', ax, x=0, y=1.3)

    vmax = np.abs(fit_frame).max()
    vmin = -vmax
    im = plf.stashow(fit_frame, ax)
    ax.set_aspect('equal')
    plf.spineless(ax)
    ax.set_axis_off()

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=UserWarning)
        warnings.filterwarnings('ignore', '.*invalid value encountered*.')
        ax.contour(Y,
                   X,
                   Zm, [inner_b, outer_b],
                   cmap=plf.RFcolormap(('C0', 'C1')))

    barsize = 100 / (stx_h * px_size)
    scalebar = AnchoredSizeBar(ax.transData,
                               barsize,
Example #5
0
def sumcomponent(nr, u, s, v):
    """
    Accumulate the SVD components with indices 0 through `nr` (inclusive).

    Each term comes from `component(index, u, s, v)`. The result is a
    float array of shape (u.shape[0], v.shape[-1]).
    """
    n_rows, n_cols = u.shape[0], v.shape[-1]
    accumulated = np.zeros((n_rows, n_cols))
    for index in range(nr + 1):
        # In-place add, same as `accumulated += ...`
        np.add(accumulated, component(index, u, s, v), out=accumulated)
    return accumulated


for i, clusterid in enumerate(clusterids):
    sta = stas[i]
    rows = 2
    cols = 5

    plt.figure(figsize=(12, 12))
    ax1 = plt.subplot(rows, cols, 1)
    plf.stashow(sta, ax1)

    u, s, v = np.linalg.svd(sta)

    #componentnr = 1
    comp_range = 9
    sta_dn = np.zeros(sta.shape)
    for componentnr in range(comp_range):
        #u = u[:, :componentnr]
        #v = v[:componentnr, :]

        #
        #for i in range(componentnr):
        #    sta_dn += np.dot(u[:, i], v[i, :])

        sta_dn += component(componentnr, u, s, v)
Example #6
0
@author: ycan
"""

import plotfuncs as plf
import matplotlib.pyplot as plt
import miscfuncs as msc
import iofuncs as iof

# Load one stripe STA and cut it around its peak for display.
data = iof.load('20180124', 12)
index = 5
sta = data['stas'][index]
max_i = data['max_inds'][index]
sta, max_i = msc.cutstripe(sta, max_i, 30)

# Diverging colormaps (and their reversed versions) to compare side by
# side. To survey every registered, non-reversed colormap instead, use:
#     c = [cm for cm in plt.colormaps() if not cm.endswith('_r')]
c = ['bwr_r', 'RdBu', 'seismic_r', 'bwr', 'RdBu_r', 'seismic']
dims = plf.numsubplots(len(c))
plt.figure(figsize=(20, 20))
for i, cm in enumerate(c):
    ax = plt.subplot(dims[0], dims[1], i + 1)
    im = plf.stashow(sta, ax, cmap=cm, ticks=[])
    plt.axis('off')
    im.axes.get_xaxis().set_visible(False)
    im.axes.get_yaxis().set_visible(False)
    ax.set_title(cm, size='x-small')
plt.savefig('cmaps.svg', bbox_inches='tight')
plt.close()
# Binarize the uniform random numbers to a +1/-1 stimulus.
randnrs = [1 if i > .5 else -1 for i in randnrs]
# Initial stimulus history: sy rows x filter_length columns, filled
# column-wise (Fortran order).
stim = np.reshape(randnrs, (sy, filter_length), order='F')

for frame in range(total_frames):
    # Draw the next frame and shift it into the history: drop column 0 and
    # append the new frame as the last column, so the last column is always
    # the most recent frame.
    randnrs, seed = randpy.ran1(seed, sy)
    randnrs = [1 if i > .5 else -1 for i in randnrs]
    stim = np.hstack((stim[:, 1:], np.array(randnrs)[..., None]))
    spike = np.random.poisson()
    # NOTE(review): the Poisson draw above is immediately overridden, so
    # every frame counts as exactly one spike - confirm this is intended.
    spike = 1
    if spike != 0:
        sta += stim * spike
        spikect += spike
sta /= spikect
# %%
plt.figure(figsize=(6, 14))
plf.stashow(sta, plt.gca())
plt.show()

#%%
# STA column for the most recent frame of the history (see hstack above)
bar = sta[:, -1]
clusters = data['clusters']
stas = data['stas']
clusterids = plf.clusters_to_ids(clusters)

rows = 1
columns = 2

for i in range(clusters.shape[0]):
    orig_sta = stas[i]

    plt.figure(figsize=(14, 14))
Example #8
0
import matplotlib.pyplot as plt
import plotfuncs as plf

# Spike-triggered average of a random +1/-1 stripe stimulus with spikes
# that are independent of the stimulus, to see the noise level of the STA.
multiplier = 100

sy = 160
total_frames = 80000 * multiplier
filter_length = 40
seed = -1000
length = sy * total_frames

randnrs, seed = randpy.ran1(seed, length)
# Binarize with NumPy and store as int8. `length` is ~1.3e9 here, so
# building a Python list of that size (8 bytes per reference alone) is very
# slow and memory hungry.
stimulus = np.where(np.asarray(randnrs) > .5, np.int8(1), np.int8(-1))
del randnrs
stimulus = np.reshape(stimulus, (sy, total_frames), order='F')

sta = np.zeros((sy, filter_length))
spikecounter = 0

# Each window holds the filter_length frames ending at frame k, reversed so
# that column 0 is frame k itself. The first complete window ends at frame
# filter_length - 1 and the last one at the final frame.
for k in range(filter_length - 1, total_frames):
    stim_small = stimulus[:, k - filter_length + 1:k + 1][:, ::-1]
    # Spike count for this frame, Poisson with mean 1 (NumPy's default lam)
    spike = np.random.poisson()
    if spike != 0:
        sta += spike * stim_small
        spikecounter += spike
del stimulus
sta = sta / spikecounter
ax = plt.subplot(111)
plf.stashow(sta, ax)
plt.show()
        # Changed width from 700 micrometer to 400 to zoom in on the
        # region of interest. This shifts where the fit is drawn,
        # it's fixed when plotting.
        fsize_original = int(700/(stx_w*px_size))
        fsize = int(400/(stx_w*px_size))
        fsize_diff = fsize_original - fsize
        t = np.arange(filter_length)*frame_duration*1000
        vscale = fsize * stx_w*px_size

        sta, max_i = msc.cutstripe(sta, max_i, fsize*2)

        ax1 = axes[2*j]
        plf.subplottext(['A', 'C'][j], ax1, x=-.4)
        plf.subplottext(['Mesopic', 'Photopic'][j],
                        ax1, x=-.5, y=.5, rotation=90, va='center')
        plf.stashow(sta, ax1, extent=[0, t[-1], -vscale, vscale])
        ax1.set_xlabel('Time [ms]')
#        ax1.set_ylabel(r'Distance [$\upmu$m]')
        ax1.set_ylabel(r'Distance [μm]')

        fitv = np.mean(sta[:, max_i[1]-cut_time:max_i[1]+cut_time+1],
                       axis=1)

        s = np.arange(fitv.shape[0])

        ax2 = axes[2*j+1]
        plf.subplottext(['B', 'D'][j], ax2, x=-.1)
        plf.subplottext(f'Center-Surround Index: {csi:4.2f}',
                        ax2, x=.95, y=.15, fontsize=8, fontweight='normal')
        plf.spineless(ax2)
        ax2.set_yticks([])
Example #10
0
        # Changed width from 700 micrometer to 400 to zoom in on the
        # region of interest. This shifts where the fit is drawn,
        # it's fixed when plotting.
        fsize_original = int(700/(stx_w*px_size))
        fsize = int(400/(stx_w*px_size))
        fsize_diff = fsize_original - fsize
        t = np.arange(filter_length)*frame_duration*1000
        vscale = fsize * stx_w*px_size

        sta, max_i = msc.cutstripe(sta, max_i, fsize*2)

        ax1 = axes[2*j]
        plf.subplottext(['A', 'C'][j], ax1, x=-.4)
        plf.subplottext(['Mesopic', 'Photopic'][j],
                        ax1, x=-.5, y=.5, rotation=90, va='center')
        plf.stashow(sta, ax1, extent=[0, t[-1], -vscale, vscale],
                    cmap=texplot.cmap)
        ax1.set_xlabel('Time [ms]')
        ax1.set_ylabel(r'Distance [$\upmu$m]')

        fitv = np.mean(sta[:, max_i[1]-cut_time:max_i[1]+cut_time+1],
                       axis=1)

        s = np.arange(fitv.shape[0])

        ax2 = axes[2*j+1]
        plf.subplottext(['B', 'D'][j], ax2, x=-.1)
        plf.subplottext(f'Center-Surround Index: {csi:4.2f}',
                        ax2, x=.95, y=.15, fontsize=8, fontweight='normal')
        plf.spineless(ax2)
        ax2.set_yticks([])
        ax2.set_xticks([])
Example #11
0
#%%
# Fit 2D Gaussians to `sp` and to versions of it with earlier fits removed,
# plotting each map with a contour of its fit.
# NOTE(review): getfit and mahalonobis_convert are defined elsewhere; getfit
# presumably returns (fitted function, fit parameters, polarity) - confirm.
sp = sp1c

rows = 2
columns = 2
plt.figure(figsize=(12, 10))

# First fit, on the original map
f, pars0, pol0 = getfit(sp)

# NOTE(review): the fitted functions are evaluated as f(Y, X), i.e. with
# the meshgrid outputs swapped - presumably to match getfit's axis order.
X, Y = np.meshgrid(np.arange(sp.shape[0]), np.arange(sp.shape[1]))
Z0 = f(Y, X)
Z0m = mahalonobis_convert(Z0, pars0)

ax0 = plt.subplot(rows, columns, 1)
plf.stashow(sp, ax0)
ax0.contour(X, Y, Z0m, [2])

# Second fit, on the map with the polarity-signed first fit subtracted
d1 = sp - Z0 * pol0
f1, pars1, pol1 = getfit(d1)
Z1 = f1(Y, X)
Z1m = mahalonobis_convert(Z1, pars1)

ax1 = plt.subplot(rows, columns, 2)
plf.stashow(d1, ax1)
ax1.contour(X, Y, Z1m, [1.4])

# Third fit.
# NOTE(review): this subtracts the second fit from the original `sp`, not
# from the residual `d1` - confirm this is intended.
d2 = sp - Z1 * pol1
f2, pars2, pol2 = getfit(d2)
Z2 = f2(Y, X)
Z2m = mahalonobis_convert(Z2, pars2)
def stripesurround(exp_name, stimnrs):
    """
    Fit a one-dimensional center-surround model to the stripeflicker STA
    of every cluster, plot each fit and save center-surround indices.

    For each stimulus number, one SVG per cluster is written to
    <exp_dir>/data_analysis/<stimname>/stripesurrounds/. The loaded data,
    extended with 'cs_inds' and 'polarities', is saved as
    <stimnr>_data.npz in <exp_dir>/data_analysis/<stimname>/.

    exp_name:
        Experiment name or directory, passed to iof.exp_dir_fixer.

    stimnrs:
        A single stimulus number (int) or a list of them.

    The spatial profile at the STA's peak time is multiplied by the cell's
    polarity (+1 ON, -1 OFF) before fitting. The center lobe is therefore
    positive and both fitted amplitudes are non-negative magnitudes.
    NOTE(review): this assumes centersurround_onedim subtracts the surround
    Gaussian from the center one, as the non-negative surround bound
    implies - confirm.

    ValueErrors from msc.cutstripe are not caught.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stimnrs, int):
        stimnrs = [stimnrs]

    # The spike sheet depends only on the experiment; read it once.
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    for stimnr in stimnrs:
        data = iof.load(exp_name, stimnr)

        clusters = data['clusters']
        stas = data['stas']
        max_inds = data['max_inds']
        stx_w = data['stx_w']
        exp_name = data['exp_name']
        stimname = data['stimname']

        clusterids = plf.clusters_to_ids(clusters)

        fsize = int(700 / (stx_w * px_size))

        cs_inds = np.empty(clusters.shape[0])
        polarities = np.empty(clusters.shape[0])

        savepath = os.path.join(exp_dir, 'data_analysis', stimname)
        plotdir = os.path.join(savepath, 'stripesurrounds')

        for i in range(clusters.shape[0]):
            sta, max_i = msc.cutstripe(stas[i], max_inds[i], fsize * 2)
            plt.figure(figsize=(12, 10))
            ax = plt.subplot(121)
            plf.stashow(sta, ax)

            # Spatial profile at the time point of the STA peak
            fitv = sta[:, max_i[1]]
            # Make a space vector
            s = np.arange(fitv.shape[0])

            onoroff = -1 if np.max(fitv) != np.max(np.abs(fitv)) else 1
            polarities[i] = onoroff
            # Flip OFF cells so the center lobe is positive, matching the
            # non-negative amplitude bounds below.
            fitv_norm = onoroff * fitv

            # Initial amplitude guesses: the center peak, and the
            # magnitude of the opposite-sign lobe (clipped at zero so the
            # guess stays inside the bounds).
            centerpeak = np.max(fitv_norm)
            surroundpeak = max(np.max(-fitv_norm), 0)

            # Define initial guesses for the center and surround gaussians
            # First set of values are for center, second for surround.
            p_initial = [centerpeak, max_i[0], 2, surroundpeak, max_i[0], 4]
            bounds = ([0, -np.inf, -np.inf, 0, -np.inf, -np.inf], np.inf)

            try:
                popt, _ = curve_fit(centersurround_onedim,
                                    s,
                                    fitv_norm,
                                    p0=p_initial,
                                    bounds=bounds)
            except ValueError as e:
                if str(e) == "`x0` is infeasible.":
                    print(e)
                    # Fall back to a center-only Gaussian with a zero
                    # surround amplitude.
                    popt, _ = curve_fit(onedgauss,
                                        s,
                                        fitv_norm,
                                        p0=p_initial[:3])
                    popt = np.append(popt, [0, popt[1], popt[2]])
                else:
                    raise
            # The fit is in the polarity-normalized domain, like fitv_norm
            fit = centersurround_onedim(s, *popt)

            # Avoid dividing by zero when calculating center-surround index
            csi = popt[0] / popt[3] if popt[3] > 0 else 0
            cs_inds[i] = csi
            ax = plt.subplot(122)
            plf.spineless(ax)
            ax.set_yticks([])

            # We need to flip the vertical axis to match
            # with the STA next to it
            plt.plot(fitv_norm, -s, label='Data')
            plt.plot(fit, -s, label='Fit')
            plt.axvline(0, linestyle='dashed', alpha=.5)
            plt.title(f'Center: a: {popt[0]:4.2f}, μ: {popt[1]:4.2f},' +
                      f' σ: {popt[2]:4.2f}\n' +
                      f'Surround: a: {popt[3]:4.2f}, μ: {popt[4]:4.2f},' +
                      f' σ: {popt[5]:4.2f}' + f'\n CS index: {csi:4.2f}')
            plt.subplots_adjust(top=.82)
            plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')
            os.makedirs(plotdir, exist_ok=True)
            plt.savefig(os.path.join(plotdir, clusterids[i] + '.svg'))
            plt.close()

        data.update({'cs_inds': cs_inds, 'polarities': polarities})
        np.savez(os.path.join(savepath, f'{stimnr}_data.npz'), **data)
        f = gfit.twodgaussian(pars)
        Z = f(X, Y)

    # Correcting for Mahalonobis dist.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', '.*divide by zero*.', RuntimeWarning)
        Zm = np.log((Z - pars[0]) / pars[1])
    Zm[np.isinf(Zm)] = np.nan
    Zm = np.sqrt(Zm * -2)

    ax = plt.subplot(1, 2, 1)
    plf.subplottext('A', ax, x=0, y=1.3)

    vmax = np.abs(fit_frame).max()
    vmin = -vmax
    im = plf.stashow(fit_frame, ax, cmap=texplot.cmap)
    ax.set_aspect('equal')
    plf.spineless(ax)
    ax.set_axis_off()

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=UserWarning)
        warnings.filterwarnings('ignore', '.*invalid value encountered*.')
        ax.contour(Y,
                   X,
                   Zm, [inner_b, outer_b],
                   cmap=plf.RFcolormap(('C0', 'C1')))

    barsize = 100 / (stx_h * px_size)
    scalebar = AnchoredSizeBar(ax.transData,
                               barsize,