示例#1
0
def make_descr_fits(cellNum,
                    data_path=basePath + data_suff,
                    fit_rvc=1,
                    fit_sf=1,
                    rvcMod=1,
                    sfMod=0,
                    loss_type=2,
                    vecF1=1,
                    onsetCurr=None,
                    rvcName=rvcName,
                    sfName=sfName,
                    toSave=1,
                    fracSig=1,
                    nBoots=0,
                    n_repeats=25,
                    jointSf=0,
                    resp_thresh=(-1e5, 0),
                    veThresh=-1e5,
                    sfModRef=1,
                    phAdvCorr=True):
    ''' Separate fits for DC, F1 
      -- for DC: [maskOnly, mask+base]
      -- for F1: [maskOnly, mask+base {@mask TF}] 
      For asMulti fits (i.e. when done in parallel) we do the following to reduce multiple loading of files/race conditions
      --- we'll pass in the previous fits as fit_rvc and/or fit_sf
      --- we'll pass in [cellNum, cellName] as cellNum
  '''
    rvcFits_curr_toSave = None
    sfFits_curr_toSave = None
    # default to None, in case we don't actually do those fits
    expName = 'sfBB_core'

    if not isinstance(cellNum, int):
        cellNum, unitNm = cellNum
        print('cell %d {%s}' % (cellNum, unitNm))
    else:
        dlName = hf.get_datalist('V1_BB/')
        dataList = hf.np_smart_load(data_path + dlName)
        unitNm = dataList['unitName'][cellNum - 1]
    print('loading cell')
    cell = hf.np_smart_load('%s%s_sfBB.npy' % (data_path, unitNm))
    expInfo = cell[expName]
    byTrial = expInfo['trial']

    if fit_rvc == 1 or fit_rvc is not None:  # load existing rvcFits, if there
        rvcNameFinal = hf.rvc_fit_name(rvcName, rvcMod, None, vecF1)
        if fit_rvc == 1:
            if os.path.isfile(data_path + rvcNameFinal):
                rvcFits = hf.np_smart_load(data_path + rvcNameFinal)
            else:
                rvcFits = dict()
        else:  # otherwise, we have passed it in as fit_sf to avoid race condition during multiprocessing (i.e. multiple threads trying to load the same file)
            rvcFits = fit_rvc
        try:
            rvcFits_curr_toSave = rvcFits[cellNum - 1]
        except:
            rvcFits_curr_toSave = dict()

    if fit_sf == 1 or fit_sf is not None:
        modStr = hf.descrMod_name(sfMod)
        sfNameFinal = hf.descrFit_name(loss_type,
                                       descrBase=sfName,
                                       modelName=modStr,
                                       joint=jointSf)
        # descrLoss order is lsq/sqrt/poiss/sach
        if fit_sf == 1:
            if os.path.isfile(data_path + sfNameFinal):
                sfFits = hf.np_smart_load(data_path + sfNameFinal)
            else:
                sfFits = dict()
        else:  # otherwise, we have passed it in as fit_sf to avoid race condition during multiprocessing (i.e. multiple threads trying to load the same file)
            sfFits = fit_sf
        try:
            sfFits_curr_toSave = sfFits[cellNum - 1]
            print('---prev sf!')
        except:
            sfFits_curr_toSave = dict()
            print('---NO PREV sf!')

    # Set up whether we will bootstrap straight away
    resample = False if nBoots <= 0 else True
    if cross_val is not None:
        # we also set to True if cross_val is not None
        resample = True
    nBoots = 1 if nBoots <= 0 else nBoots
    if nBoots > 1:
        ftol = 5e-9
        #ftol = 1e-6;
    else:
        ftol = 2.220446049250313e-09
        # the default value, per scipy guide (scipy.optimize, for L-BFGS-B)

    #########
    ### Get the responses - base only, mask+base [base F1], mask only (mask F1)
    ### _____ MAKE THIS A FUNCTION???
    #########
    # 1. Get the mask only response (f1 at mask TF)
    _, _, gt_respMatrixDC_onlyMask, gt_respMatrixF1_onlyMask = hf_sf.get_mask_resp(
        expInfo,
        withBase=0,
        maskF1=1,
        vecCorrectedF1=vecF1,
        onsetTransient=onsetCurr,
        returnByTr=1,
        phAdvCorr=phAdvCorr)
    # i.e. get the maskONLY response
    # 2f1. get the mask+base response (but f1 at mask TF)
    _, _, _, gt_respMatrixF1_maskTf = hf_sf.get_mask_resp(
        expInfo,
        withBase=1,
        maskF1=1,
        vecCorrectedF1=vecF1,
        onsetTransient=onsetCurr,
        returnByTr=1,
        phAdvCorr=phAdvCorr)
    # i.e. get the maskONLY response
    # 2dc. get the mask+base response (f1 at base TF)
    _, _, gt_respMatrixDC, _ = hf_sf.get_mask_resp(expInfo,
                                                   withBase=1,
                                                   maskF1=0,
                                                   vecCorrectedF1=vecF1,
                                                   onsetTransient=onsetCurr,
                                                   returnByTr=1,
                                                   phAdvCorr=phAdvCorr)
    # i.e. get the base response for F1

    if cross_val == 2.0:
        nBoots = np.multiply(*gt_respMatrixDC_onlyMask.shape[0:2], )
        print('# boots is %03d' % nBoots)
        resample = False
        # then we DO NOT want to resample

    for boot_i in range(nBoots):
        ######
        # 3a. Ensure we have the right responses (incl. resampling, taking mean)
        ######
        # --- Note that all hf_sf.resample_all_cond(resample, arr, axis=X) calls are X=2, because all arrays are [nSf x nCon x respPerTrial x ...]
        # - o. DC responses are easy - simply resample, then take the mean/s.e.m. across all trials of a given condition
        respMatrixDC_onlyMask_resample = hf_sf.resample_all_cond(
            resample, np.copy(gt_respMatrixDC_onlyMask), axis=2)
        respMatrixDC_onlyMask = np.stack(
            (np.nanmean(respMatrixDC_onlyMask_resample, axis=-1),
             sem(respMatrixDC_onlyMask_resample, axis=-1, nan_policy='omit')),
            axis=-1)
        respMatrixDC_resample = hf_sf.resample_all_cond(
            resample, np.copy(gt_respMatrixDC), axis=2)
        respMatrixDC = np.stack(
            (np.nanmean(respMatrixDC_resample, axis=-1),
             sem(respMatrixDC_resample, axis=-1, nan_policy='omit')),
            axis=-1)
        # - F1 responses are different - vector math (i.e. hf.polar_vec_mean call)
        # --- first, resample the data, then do the vector math
        # - i. F1, only mask
        respMatrixF1_onlyMask_resample = hf_sf.resample_all_cond(
            resample, np.copy(gt_respMatrixF1_onlyMask), axis=2)
        # --- however, polar_vec_mean must be computed by condition, to handle NaN (which might be unequal across conditions):
        respMatrixF1_onlyMask = np.empty(
            respMatrixF1_onlyMask_resample.shape[0:2] +
            respMatrixF1_onlyMask_resample.shape[3:])
        respMatrixF1_onlyMask_phi = np.copy(respMatrixF1_onlyMask)
        for conds in itertools.product(
                *[range(x)
                  for x in respMatrixF1_onlyMask_resample.shape[0:2]]):
            r_mean, phi_mean, r_sem, phi_sem = hf.polar_vec_mean(
                [
                    hf.nan_rm(respMatrixF1_onlyMask_resample[conds +
                                                             (slice(None), 0)])
                ], [
                    hf.nan_rm(respMatrixF1_onlyMask_resample[conds +
                                                             (slice(None), 1)])
                ],
                sem=1)  # return s.e.m. rather than std (default)
            # - and we only care about the R value (after vec. avg.)
            respMatrixF1_onlyMask[conds] = [r_mean[0], r_sem[0]]
            # r...[0] to unpack (it's nested inside array, since f is vectorized
            respMatrixF1_onlyMask_phi[conds] = [phi_mean[0], phi_sem[0]]
        if phAdvCorr and vecF1:
            maskCon, maskSf = expInfo['maskCon'], expInfo['maskSF']
            opt_params, phAdv_model = hf_sf.phase_advance_fit_core(
                respMatrixF1_onlyMask, respMatrixF1_onlyMask_phi, maskCon,
                maskSf)
            for msI, mS in enumerate(maskSf):
                curr_params = opt_params[msI]
                # the phAdv model applies per-SF
                for mcI, mC in enumerate(maskCon):
                    curr_r, curr_phi = respMatrixF1_onlyMask[
                        mcI, msI, 0], respMatrixF1_onlyMask_phi[mcI, msI, 0]
                    refPhi = phAdv_model(*curr_params, curr_r)
                    new_r = np.multiply(
                        curr_r,
                        np.cos(np.deg2rad(refPhi) - np.deg2rad(curr_phi)))
                    respMatrixF1_onlyMask[mcI, msI, 0] = new_r

        # - ii. F1, both (@ maskTF)
        respMatrixF1_maskTf_resample = hf_sf.resample_all_cond(
            resample, np.copy(gt_respMatrixF1_maskTf), axis=2)
        # --- however, polar_vec_mean must be computed by condition, to handle NaN (which might be unequal across conditions):
        respMatrixF1_maskTf = np.empty(
            respMatrixF1_maskTf_resample.shape[0:2] +
            respMatrixF1_maskTf_resample.shape[3:])
        respMatrixF1_maskTf_phi = np.copy(respMatrixF1_maskTf)
        for conds in itertools.product(
                *[range(x) for x in respMatrixF1_maskTf_resample.shape[0:2]]):
            r_mean, phi_mean, r_sem, phi_var = hf.polar_vec_mean(
                [
                    hf.nan_rm(respMatrixF1_maskTf_resample[conds +
                                                           (slice(None), 0)])
                ], [
                    hf.nan_rm(respMatrixF1_maskTf_resample[conds +
                                                           (slice(None), 1)])
                ],
                sem=1)  # return s.e.m. rather than std (default)
            # - and we only care about the R value (after vec. avg.)
            respMatrixF1_maskTf[conds] = [r_mean[0], r_sem[0]]
            # r...[0] is to unpack (it's nested inside array, since func is vectorized
            respMatrixF1_maskTf_phi[conds] = [phi_mean[0], phi_sem[0]]
        if phAdvCorr and vecF1:
            maskCon, maskSf = expInfo['maskCon'], expInfo['maskSF']
            opt_params, phAdv_model = hf_sf.phase_advance_fit_core(
                respMatrixF1_maskTf, respMatrixF1_maskTf_phi, maskCon, maskSf)
            for msI, mS in enumerate(maskSf):
                curr_params = opt_params[msI]
                # the phAdv model applies per-SF
                for mcI, mC in enumerate(maskCon):
                    curr_r, curr_phi = respMatrixF1_maskTf[
                        mcI, msI, 0], respMatrixF1_maskTf_phi[mcI, msI, 0]
                    refPhi = phAdv_model(*curr_params, curr_r)
                    new_r = np.multiply(
                        curr_r,
                        np.cos(np.deg2rad(refPhi) - np.deg2rad(curr_phi)))
                    respMatrixF1_maskTf[mcI, msI, 0] = new_r

        if cross_val == 2.0:
            n_sfs, n_cons = gt_respMatrixF1_maskTf.shape[
                0], gt_respMatrixF1_maskTf.shape[1]
            con_ind, sf_ind = np.floor(np.divide(boot_i,
                                                 n_sfs)).astype('int'), np.mod(
                                                     boot_i,
                                                     n_sfs).astype('int')
            print('holding out con/sf indices %02d/%02d' % (con_ind, sf_ind))

        for measure in [0, 1]:
            if measure == 0:
                baseline = np.nanmean(
                    hf.resample_array(resample, expInfo['blank']['resps']))
                if cross_val == 2.0:
                    # --- so, "nan" out that condition in all of the responses
                    mask_only_ref = np.copy(respMatrixDC_onlyMask)
                    mask_base_ref = np.copy(respMatrixDC)
                    # the above establish the reference values; below, we nan out the current condition
                    mask_only = np.copy(respMatrixDC_onlyMask)
                    mask_base = np.copy(respMatrixDC)
                    mask_only[sf_ind, con_ind] = np.nan
                    mask_base[sf_ind, con_ind] = np.nan
                else:
                    mask_only = respMatrixDC_onlyMask
                    mask_base = respMatrixDC
                mask_only_all = None
                # as of 22.06.15
                fix_baseline = False
            elif measure == 1:
                baseline = 0
                if cross_val == 2.0:
                    # --- so, "nan" out that condition in all of the responses
                    mask_only_ref = np.copy(respMatrixF1_onlyMask)
                    mask_base_ref = np.copy(respMatrixF1_maskTf)
                    # the above establish the reference values; below, we nan out the current condition
                    mask_only = np.copy(respMatrixF1_onlyMask)
                    mask_base = np.copy(respMatrixF1_maskTf)
                    mask_only[sf_ind, con_ind] = np.nan
                    mask_base[sf_ind, con_ind] = np.nan
                else:
                    mask_only = respMatrixF1_onlyMask
                    mask_base = respMatrixF1_maskTf
                mask_only_all = None
                # as of 22.06.15; ignore the above line
                fix_baseline = True
            resp_str = hf_sf.get_resp_str(respMeasure=measure)

            if cross_val is not None:  # set up what is the test data!
                nan_val = -1e3
                # make sure we nan out any existing NaN values in the reference data
                mask_only_ref[np.isnan(mask_only_ref)] = nan_val
                # make sure we nan out any NaN values in the training data
                mask_only_tr = np.copy(mask_only)
                mask_only_tr[np.isnan(mask_only_tr)] = nan_val
                heldout = np.abs(
                    mask_only_tr - mask_only_ref
                ) > 1e-6  # if the idff. is g.t. this, it means they are different values
                test_mask_only = np.nan * np.zeros_like(mask_only_tr)
                test_mask_base = np.nan * np.zeros_like(mask_only_tr)
                test_mask_only[heldout] = mask_only_ref[heldout]
                test_mask_base[heldout] = mask_base_ref[heldout]
                heldouts = [test_mask_only]
                # update to include ...base IF below lines (which*) include both
            else:
                heldouts = [None]

            whichAll = [mask_only_all]
            whichResp = [mask_only]  #, mask_base];
            whichKey = ['mask']  #, 'both'];

            if fit_rvc == 1 or fit_rvc is not None:
                ''' Fit RVCs responses (see helper_fcns.rvc_fit for details) for:
            --- F0: mask alone (7 sfs)
                    mask + base together (7 sfs)
            --- F1: mask alone (7 sfs; at maskTf)
                    mask+ + base together (7 sfs; again, at maskTf)
            NOTE: Assumes only sfBB_core
        '''
                if resp_str not in rvcFits_curr_toSave:
                    rvcFits_curr_toSave[resp_str] = dict()

                cons = expInfo['maskCon']
                # first, mask only; then mask+base
                for wR, wK, wA in zip(whichResp, whichKey, whichAll):
                    # create room for an empty dict, if not already present
                    if wK not in rvcFits_curr_toSave[resp_str]:
                        rvcFits_curr_toSave[resp_str][wK] = dict()

                    adjMeans = np.transpose(wR[:, :, 0])
                    # just the means
                    # --- in new version of code [to allow boot], we can get masked array; do the following to save memory
                    adjMeans = adjMeans.data if isinstance(
                        adjMeans, np.ma.MaskedArray) else adjMeans
                    consRepeat = [cons] * len(adjMeans)

                    # get a previous fit, if present
                    try:
                        rvcFit_curr = rvcFits_curr_toSave[resp_str][
                            wK] if not resample else None
                    except:
                        rvcFit_curr = None
                    # do the fitting!
                    _, all_opts, all_conGains, all_loss = hf.rvc_fit(
                        adjMeans,
                        consRepeat,
                        var=None,
                        mod=rvcMod,
                        fix_baseline=fix_baseline,
                        prevFits=rvcFit_curr,
                        n_repeats=n_repeats)

                    # compute variance explained!
                    varExpl = [
                        hf.var_explained(
                            hf.nan_rm(dat),
                            hf.nan_rm(hf.get_rvcResp(prms, cons, rvcMod)),
                            None) for dat, prms in zip(adjMeans, all_opts)
                    ]
                    # now, package things
                    if resample:
                        if boot_i == 0:  # i.e. first time around
                            # - then we create empty lists to which we append the result of each success iteration
                            # --- note that we do not include adjMeans here (don't want nBoots iterations of response means saved!)
                            rvcFits_curr_toSave[resp_str][wK][
                                'boot_loss'] = []
                            rvcFits_curr_toSave[resp_str][wK][
                                'boot_params'] = []
                            rvcFits_curr_toSave[resp_str][wK][
                                'boot_conGain'] = []
                            rvcFits_curr_toSave[resp_str][wK][
                                'boot_varExpl'] = []
                        # then -- append!
                        rvcFits_curr_toSave[resp_str][wK]['boot_loss'].append(
                            all_loss)
                        rvcFits_curr_toSave[resp_str][wK][
                            'boot_params'].append(all_opts)
                        rvcFits_curr_toSave[resp_str][wK][
                            'boot_conGain'].append(all_conGains)
                        rvcFits_curr_toSave[resp_str][wK][
                            'boot_varExpl'].append(varExpl)
                    else:  # we will never be here more than once, since if not resample, then nBoots = 1
                        rvcFits_curr_toSave[resp_str][wK]['loss'] = all_loss
                        rvcFits_curr_toSave[resp_str][wK]['params'] = all_opts
                        rvcFits_curr_toSave[resp_str][wK][
                            'conGain'] = all_conGains
                        rvcFits_curr_toSave[resp_str][wK][
                            'adjMeans'] = adjMeans
                        rvcFits_curr_toSave[resp_str][wK]['varExpl'] = varExpl
                ########
                # END of rvc fit (for this measure, boot iteration)
                ########

            if fit_sf == 1 or fit_sf is not None:
                ''' Fit SF tuning responses (see helper_fcns.dog_fit for details) for:
            --- F0: mask alone (7 cons)
                    mask + base together (7 cons)
            --- F1: mask alone (7 cons; at maskTf)
                    mask+ + base together (7 cons; again, at maskTf)
            NOTE: Assumes only sfBB_core
        '''
                if resp_str not in sfFits_curr_toSave:
                    sfFits_curr_toSave[resp_str] = dict()

                cons, sfs = expInfo['maskCon'], expInfo['maskSF']
                stimVals = [[0], cons, sfs]
                valConByDisp = [np.arange(0, len(cons))]
                # all cons are valid in sfBB experiment

                for wR, wK, wA, heldout in zip(whichResp, whichKey, whichAll,
                                               heldouts):
                    if wK not in sfFits_curr_toSave[resp_str]:
                        sfFits_curr_toSave[resp_str][wK] = dict()

                    # get a previous fit, if present
                    try:
                        sfFit_curr = sfFits_curr_toSave[resp_str][
                            wK] if not resample else None
                    except:
                        sfFit_curr = None

                    # try to load isolated fits...
                    if jointSf > 0:
                        try:  # load non_joint fits as a reference (see hf.dog_fit or S. Sokol thesis for details)
                            modStr = hf.descrMod_name(sfMod)
                            ref_fits = hf.np_smart_load(
                                data_path + hf.descrFit_name(loss_type,
                                                             descrBase=sfName,
                                                             modelName=modStr,
                                                             joint=0))
                            isolFits = ref_fits[cellNum -
                                                1][resp_str][wK]['params']
                            if sfMod == sfModRef:
                                ref_varExpl = ref_fits[
                                    cellNum - 1][resp_str][wK]['varExpl']
                            else:
                                # otherwise, load the DoG
                                vExp_ref_fits = hf.np_smart_load(
                                    data_path + hf.descrFit_name(
                                        loss_type,
                                        descrBase=sfName,
                                        modelName=hf.descrMod_name(sfModRef),
                                        joint=0))
                                ref_varExpl = vExp_ref_fits[
                                    cellNum - 1][resp_str][wK]['varExpl']
                        except:
                            isolFits = None
                            ref_varExpl = None
                    else:
                        isolFits = None
                        ref_varExpl = None

                    allCurr = np.expand_dims(
                        np.transpose(
                            wA, [1, 0, 2]), axis=0) if wA is not None else None
                    # -- by default, loss_type=2 (meaning sqrt loss); why expand dims and transpose? dog fits assumes the data is in [disp,sf,con] and we just have [con,sf]
                    nll, prms, vExp, pSf, cFreq, totNLL, totPrm, success = hf.dog_fit(
                        [
                            np.expand_dims(np.transpose(wR[:, :, 0]),
                                           axis=0), allCurr,
                            np.expand_dims(np.transpose(wR[:, :, 1]), axis=0),
                            baseline
                        ],
                        sfMod,
                        loss_type=2,
                        disp=0,
                        expInd=None,
                        stimVals=stimVals,
                        validByStimVal=None,
                        valConByDisp=valConByDisp,
                        prevFits=sfFit_curr,
                        noDisp=1,
                        fracSig=fracSig,
                        n_repeats=n_repeats,
                        isolFits=isolFits,
                        ref_varExpl=ref_varExpl,
                        veThresh=veThresh,
                        joint=jointSf,
                        ftol=ftol,
                        resp_thresh=resp_thresh
                    )  # noDisp=1 means that we don't index dispersion when accessins prevFits

                    if resample or cross_val is not None:
                        if boot_i == 0:  # i.e. first time around
                            if cross_val is not None:
                                # first, pre-define empty lists for all of the needed results, if they are not yet defined
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_NLL_cv_test'] = np.empty(
                                        (nBoots, ) + nll.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_vExp_cv_test'] = np.empty(
                                        (nBoots, ) + vExp.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_NLL_cv_train'] = np.empty(
                                        (nBoots, ) + nll.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_vExp_cv_train'] = np.empty(
                                        (nBoots, ) + vExp.shape,
                                        dtype=np.float32)
                                # --- these are all implicitly based on training data
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_cv_params'] = np.empty(
                                        (nBoots, ) + prms.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_cv_prefSf'] = np.empty(
                                        (nBoots, ) + pSf.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_cv_charFreq'] = np.empty(
                                        (nBoots, ) + cFreq.shape,
                                        dtype=np.float32)
                            else:  # otherwise, the things we put only if we didn't have cross-validation
                                # - pre-allocate empty array of length nBoots (save time over appending each time around)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_loss'] = np.empty(
                                        (nBoots, ) + nll.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_params'] = np.empty(
                                        (nBoots, ) + prms.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_varExpl'] = np.empty(
                                        (nBoots, ) + vExp.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_prefSf'] = np.empty(
                                        (nBoots, ) + pSf.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_charFreq'] = np.empty(
                                        (nBoots, ) + cFreq.shape,
                                        dtype=np.float32)
                            # and the below apply whether or not we did cross-validation!
                            if jointSf > 0:
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_totalNLL'] = np.empty(
                                        (nBoots, ) + totNLL.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_paramList'] = np.empty(
                                        (nBoots, ) + totPrm.shape,
                                        dtype=np.float32)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_success'] = np.empty((nBoots, ),
                                                               dtype=np.bool_)
                            else:  # only if joint=0 will success be an array (and not just one value)
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_success'] = np.empty(
                                        (nBoots, ) + success.shape,
                                        dtype=np.bool_)

                        # then -- put in place (we reach here for all boot_i)
                        if cross_val is not None:
                            ### then, we need to compute test loss/vExp!
                            test_nlls = np.nan * np.zeros_like(nll)
                            test_vExps = np.nan * np.zeros_like(vExp)
                            # set up ref_params, ref_rc_val; will only be used IF applicable
                            ref_params = None
                            ref_rc_val = None
                            if sfMod == 3:
                                try:
                                    all_xc = prms[:, 1]
                                    # xc
                                    ref_params = [np.nanmin(all_xc), 1]
                                except:
                                    pass
                            else:
                                try:
                                    ref_params = prms[-1]
                                    # high contrast condition
                                    ref_rc_val = totPrm[
                                        2] if jointSf > 0 else None
                                    # even then, only used for jointSf==5
                                except:
                                    pass

                            for ii, prms_curr in enumerate(prms):
                                # we'll iterate over the parameters, which are fit for each contrast (the final dimension of test_mn)
                                if np.any(np.isnan(prms_curr)):
                                    continue
                                non_nans = np.where(~np.isnan(heldout[:, ii,
                                                                      0]))[0]
                                # check -- from the heldout data at contrast ii, which sfs are not NaN
                                curr_sfs = stimVals[2][non_nans]
                                # get those sf values
                                resps_curr = heldout[non_nans, ii, 0]
                                # and get those responses, to pass into DoG_loss
                                test_nlls[ii] = hf.DoG_loss(
                                    prms_curr,
                                    resps_curr,
                                    curr_sfs,
                                    resps_std=None,
                                    loss_type=loss_type,
                                    DoGmodel=sfMod,
                                    dir=dir,
                                    joint=0,
                                    baseline=baseline,
                                    ref_params=ref_params,
                                    ref_rc_val=ref_rc_val
                                )  # why not enforce max? b/c fewer resps means more varied range of max, don't want to wrongfully penalize
                                test_vExps[ii] = hf.var_explained(
                                    resps_curr,
                                    prms_curr,
                                    curr_sfs,
                                    sfMod,
                                    baseline=baseline,
                                    ref_params=ref_params,
                                    ref_rc_val=ref_rc_val)

                            ### done with computing test loss, so save everything
                            sfFits_curr_toSave[resp_str][wK][
                                'boot_NLL_cv_test'][boot_i] = test_nlls
                            sfFits_curr_toSave[resp_str][wK][
                                'boot_vExp_cv_test'][boot_i] = test_vExps
                            sfFits_curr_toSave[resp_str][wK][
                                'boot_NLL_cv_train'][boot_i] = nll
                            sfFits_curr_toSave[resp_str][wK][
                                'boot_vExp_cv_train'][boot_i] = vExp
                            # --- these are all implicitly based on training data
                            sfFits_curr_toSave[resp_str][wK]['boot_cv_params'][
                                boot_i] = prms
                            sfFits_curr_toSave[resp_str][wK]['boot_cv_prefSf'][
                                boot_i] = pSf
                            sfFits_curr_toSave[resp_str][wK][
                                'boot_cv_charFreq'][boot_i] = cFreq
                        else:
                            sfFits_curr_toSave[resp_str][wK]['boot_loss'][
                                boot_i] = nll
                            sfFits_curr_toSave[resp_str][wK]['boot_params'][
                                boot_i] = prms
                            sfFits_curr_toSave[resp_str][wK]['boot_charFreq'][
                                boot_i] = cFreq
                            sfFits_curr_toSave[resp_str][wK]['boot_varExpl'][
                                boot_i] = vExp
                            sfFits_curr_toSave[resp_str][wK]['boot_prefSf'][
                                boot_i] = pSf
                        # these apply regardless of c-v or not
                        sfFits_curr_toSave[resp_str][wK]['boot_success'][
                            boot_i] = success
                        if jointSf > 0:
                            try:
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_totalNLL'][boot_i] = totNLL
                                sfFits_curr_toSave[resp_str][wK][
                                    'boot_paramList'][boot_i] = totPrm
                            except:
                                pass
                    else:  # otherwise, we'll only be here once
                        sfFits_curr_toSave[resp_str][wK]['NLL'] = nll.astype(
                            np.float32)
                        sfFits_curr_toSave[resp_str][wK][
                            'params'] = prms.astype(np.float32)
                        sfFits_curr_toSave[resp_str][wK][
                            'varExpl'] = vExp.astype(np.float32)
                        sfFits_curr_toSave[resp_str][wK][
                            'prefSf'] = pSf.astype(np.float32)
                        sfFits_curr_toSave[resp_str][wK][
                            'charFreq'] = cFreq.astype(np.float32)
                        sfFits_curr_toSave[resp_str][wK][
                            'success'] = success  #.astype(np.bool_);
                        if jointSf > 0:
                            try:
                                sfFits_curr_toSave[resp_str][wK][
                                    'totalNLL'] = totNLL.astype(np.float32)
                            except:
                                sfFits_curr_toSave[resp_str][wK][
                                    'totalNLL'] = totNLL
                                # if it's None, or [], or...
                            try:
                                sfFits_curr_toSave[resp_str][wK][
                                    'paramList'] = totPrm.astype(np.float32)
                            except:
                                sfFits_curr_toSave[resp_str][wK][
                                    'paramList'] = totPrm
                                # if it's None, or [], or...
                ########
                # END of sf fit (for this measure, boot iteration)
                ########

        ########
        # END of measure (i.e. handled both measures, go back for more boot_iter, if specified)
        ########
    ########
    # END of all boot iters (i.e. handled both measures, go back for more boot_iter, if specified)
    ########

    ###########
    # NOW, save (if saving); otherwise, we return the values
    ###########
    # if we are saving, save; otherwise, return the curr_rvc, curr_sf fits
    if toSave:
        if fit_rvc:
            # load fits again in case some other run has saved/made changes
            if os.path.isfile(data_path + rvcNameFinal):
                print('reloading rvcFits...')
                rvcFits = hf.np_smart_load(data_path + rvcNameFinal)
            if cellNum - 1 not in rvcFits:
                rvcFits[cellNum - 1] = dict()

            # now save
            rvcFits[cellNum - 1] = rvcFits_curr_toSave
            np.save(data_path + rvcNameFinal, rvcFits)
            print('Saving %s, %s @ %s' % (resp_str, wK, rvcNameFinal))

        if fit_sf:

            pass_check = False
            while not pass_check:  # keep saving/reloading until the fit has properly saved everything...

                # load fits again in case some other run has saved/made changes
                if os.path.isfile(data_path + sfNameFinal):
                    print('reloading sfFits...')
                    sfFits = hf.np_smart_load(data_path + sfNameFinal)
                if cellNum - 1 not in sfFits:
                    sfFits[cellNum - 1] = dict()

                sfFits[cellNum - 1] = sfFits_curr_toSave

                # now save
                np.save(data_path + sfNameFinal, sfFits)
                print('Saving %s, %s @ %s' % (resp_str, wK, sfNameFinal))

                # now check...
                check = hf.np_smart_load(data_path + sfNameFinal)
                if resample:  # check that the boot stuff is there
                    if 'boot_params' in check[cellNum -
                                              1]['dc']['mask'].keys():
                        pass_check = True
                else:
                    if 'NLL' in check[cellNum - 1]['dc']['mask'].keys(
                    ):  # just check that any relevant key is there
                        pass_check = True
                # --- and if neither pass_check was triggered, then we go back and reload, etc

        ### End of saving (both RVC and SF)
    else:
        return rvcFits_curr_toSave, sfFits_curr_toSave
示例#2
0
### RVCFITS
# Earlier RVC fit bases, kept for reference (direc flag & '.npy' are added):
#rvcBase = 'rvcFits_200507';
#rvcBase = 'rvcFits_191023';
#rvcBase = 'rvcFits_200714';
#rvcBase = 'rvcFits_200507';
rvcBase = 'rvcFits_210517'
# rvcAdj = -1 means: yes, load the rvcAdj fits, but with the vecF1 correction
# rather than a phase fit. The signed value is kept in rvcAdjSigned, while
# rvcAdj itself is replaced by its absolute value.
rvcAdjSigned = rvcAdj
rvcAdj = np.abs(rvcAdj)

##################
### Spatial frequency
##################

modStr = hf.descrMod_name(descrMod)
fLname = hf.descrFit_name(descrLoss, descrBase=descrBase, modelName=modStr)
descrFits = hf.np_smart_load(data_loc + fLname)
# Random pause of up to 2.5 s -- presumably to stagger parallel jobs that
# touch the same files (TODO confirm).
pause_tm = 2.5 * np.random.rand()
time.sleep(pause_tm)
# set the save directory to save_loc, then create the save directory if needed
subDir = fLname.replace('Fits', '').replace('.npy', '')
save_loc = str(save_loc + subDir + '/')

# exist_ok=True avoids the exists()/makedirs() race when several runs start
# at (nearly) the same time: whoever loses the race no longer crashes.
os.makedirs(save_loc, exist_ok=True)

dataList = hf.np_smart_load(data_loc + expName)

cellName = dataList['unitName'][cellNum - 1]
try:
    cellType = dataList['unitType'][cellNum - 1]
示例#3
0
    print('rvc? %d -- sf? %d -- sfMod,joint %d,%d' %
          (fit_rvc, fit_sf, sf_mod, jointSf))
    fracSig = 1
    # why fracSig =1? For V1 fits, we want to contrasin the upper-half sigma of the two-half gaussian as a fraction of the lower half

    if asMulti:
        from functools import partial
        import multiprocessing as mp
        nCpu = mp.cpu_count()

        # to avoid race conditions, load the previous fits beforehand; and the datalist
        rvcNameFinal = hf.rvc_fit_name(rvcName, rvc_mod, None, vecF1=1)
        # DEFAULT is vecF1 adjustment
        modStr = hf.descrMod_name(sf_mod)
        sfNameFinal = hf.descrFit_name(loss_type,
                                       descrBase=sfName,
                                       modelName=modStr,
                                       joint=jointSf)
        # descrLoss order is lsq/sqrt/poiss/sach

        pass_rvc = hf.np_smart_load(
            '%s%s%s' %
            (basePath, data_suff, rvcNameFinal)) if fit_rvc else None
        pass_sf = hf.np_smart_load(
            '%s%s%s' % (basePath, data_suff, sfNameFinal)) if fit_sf else None

        dataList = hf.np_smart_load(
            '%s%s%s' % (basePath, data_suff, hf.get_datalist('V1_BB/')))

        if nBoots > 1:
            n_repeats = 3 if jointSf > 0 else 5
            # fewer if repeat
示例#4
0
### DESCRLIST
# 'HPC' tag is inserted into the fit-file names when running on the cluster
hpc_str = 'HPC' if isHPC else ''
descrBase = 'descrFits%s_220720vEs' % hpc_str
#descrBase = 'descrFits%s_220410' % hpc_str;
### RVCFITS
rvcBase = 'rvcFits%s_220718' % hpc_str
#rvcBase = 'rvcFits%s_220609' % hpc_str;

##################
### Spatial frequency
##################

modStr = hf.descrMod_name(descrMod)
# phAdj is only requested when the signed phase-adjust flag is exactly 1
fLname = hf.descrFit_name(descrLoss,
                          descrBase=descrBase,
                          modelName=modStr,
                          joint=joint,
                          phAdj=1 if phAdjSigned == 1 else None)
descrFits = hf.np_smart_load(data_loc + fLname)
# Random pause of up to 2 s -- presumably to stagger parallel jobs (TODO confirm)
pause_tm = 2.0 * np.random.rand()
time.sleep(pause_tm)
# set the save directory to save_loc, then create the save directory if needed
subDir = fLname.replace('Fits', '').replace('.npy', '')
save_loc = str(save_loc + subDir + '/')

# exist_ok=True avoids the exists()/makedirs() race between concurrent runs
os.makedirs(save_loc, exist_ok=True)

dataList = hf.np_smart_load(data_loc + expName)

cellName = dataList['unitName'][cellNum - 1]
示例#5
0
# Suffix tag for repeat fits ('_rpt'); empty string otherwise
is_rpt = '_rpt' if rpt_fit else ''

conDig = 3  # round contrast to the 3rd digit

# NOTE: these .npy files store Python objects (dicts, unpacked via .item()),
# i.e. pickled object arrays. numpy >= 1.16.3 defaults to allow_pickle=False
# and raises ValueError on such files, so the flag must be passed explicitly.
# SECURITY: unpickling runs arbitrary code -- only load trusted, locally
# produced files this way.
dataList = np.load(str(dataPath + expName),
                   encoding='latin1',
                   allow_pickle=True).item()

cellStruct = np.load(str(dataPath + dataList['unitName'][which_cell - 1] +
                         '_sfm.npy'),
                     encoding='latin1',
                     allow_pickle=True).item()

# #### Load descriptive model fits, comp. model fits
descrFitName = hf.descrFit_name(descr_fit_type)

modParams = np.load(str(dataPath + fitListName),
                    encoding='latin1',
                    allow_pickle=True).item()
modParamsCurr = modParams[which_cell - 1]['params']

# ### Organize data
# #### determine contrasts, center spatial frequency, dispersions

data = cellStruct['sfm']['exp']['trial']

# only the second return value (model responses) is used here
ignore, modRespAll = mod_resp.SFMGiveBof(modParamsCurr,
                                         cellStruct,
                                         normType=norm_type,
                                         lossType=lossType,
                                         expInd=expInd)
print('norm type %02d' % (norm_type))
示例#6
0
    lossSuf = '_poiss.npy'
elif lossType == 3:
    lossSuf = '_modPoiss.npy'
elif lossType == 4:
    lossSuf = '_chiSq.npy'

fitName = str(fitBase + fitSuf + lossSuf)

# set the save directory to save_loc, then create the save directory if needed
subDir = fitName.replace('fitList', 'fits').replace('.npy', '')
save_loc = str(save_loc + subDir + '/')
# exist_ok=True avoids the exists()/makedirs() race between concurrent runs
os.makedirs(save_loc, exist_ok=True)

# build the descrFits file names (experiment-level and model-level)
descrExpName = descrFit_name(descrLossType)
descrModName = descrFit_name(descrLossType, fitName)

nFam = 5  # presumably number of stimulus families -- TODO confirm
nCon = 2  # presumably number of contrasts -- TODO confirm
plotSteps = 100  # how many steps for plotting descriptive functions?
sfPlot = np.logspace(-1, 1, plotSteps)  # plotSteps log-spaced SFs from 0.1 to 10

# for bandwidth/prefSf descriptive stuff
muLoc = 2  # mu is in location '2' of parameter arrays
height = 1 / 2.  # measure BW at half-height
sf_range = [0.01, 10]  # allowed values of 'mu' for fits - see descr_fit.py for details
示例#7
0
    vecF1 = -1
else:
    rvcMod = 1
    # i.e. Naka-rushton (1)
    dMod_num = 3
    # d-dog-s
    rvcDir = None
    # None if we're doing vec-corrected
    if expDir == 'altExp/':
        vecF1 = 0
    else:
        vecF1 = 1

# Descriptive-fit file name; a phase-adjusted variant is requested only when
# vecF1 is -1.
dFits_mod = hf.descrMod_name(dMod_num)
descrFits_name = hf.descrFit_name(
    lossType=dLoss_num,
    descrBase=dFits_base,
    modelName=dFits_mod,
    phAdj=None if vecF1 != -1 else 1,
)

## now, let it run
# structures/ holds the data; figures/ (and a dated superposition subfolder)
# holds the output plots.
dataPath = '%s%sstructures/' % (basePath, expDir)
save_loc = '%s%sfigures/' % (basePath, expDir)
save_locSuper = '%ssuperposition_220713/' % save_loc
if use_mod_resp == 1:
    # model-response plots get their own sub-folder, named after the fit list
    save_locSuper += '%s/' % fitBase

dataList = hf.np_smart_load(dataPath + dataListNm)
print('Trying to load descrFits at: %s' % (dataPath + descrFits_name))
descrFits = hf.np_smart_load(dataPath + descrFits_name)
if use_mod_resp == 1 or use_mod_resp == 2:
    fitList = hf.np_smart_load(dataPath + fitList_nm)
else:
def plot_save_superposition(which_cell, expDir, use_mod_resp=0, fitType=2, excType=1, useHPCfit=1, conType=None, lgnFrontEnd=None, force_full=1, f1_expCutoff=2, to_save=1):

  if use_mod_resp == 2:
    rvcAdj   = -1; # this means vec corrected F1, not phase adjustment F1...
    _applyLGNtoNorm = 0; # don't apply the LGN front-end to the gain control weights
    recenter_norm = 1;
    newMethod = 1; # yes, use the "new" method for mrpt (not that new anymore, as of 21.03)
    lossType = 1; # sqrt
    _sigmoidSigma = 5;

  basePath = os.getcwd() + '/'
  if 'pl1465' in basePath or useHPCfit:
    loc_str = 'HPC';
  else:
    loc_str = '';

  rvcName = 'rvcFits%s_220531' % loc_str if expDir=='LGN/' else 'rvcFits%s_220609' % loc_str
  rvcFits = None; # pre-define this as None; will be overwritten if available/needed
  if expDir == 'altExp/': # we don't adjust responses there...
    rvcName = None;
  dFits_base = 'descrFits%s_220609' % loc_str if expDir=='LGN/' else 'descrFits%s_220631' % loc_str
  if use_mod_resp == 1:
    rvcName = None; # Use NONE if getting model responses, only
    if excType == 1:
      fitBase = 'fitList_200417';
    elif excType == 2:
      fitBase = 'fitList_200507';
    lossType = 1; # sqrt
    fitList_nm = hf.fitList_name(fitBase, fitType, lossType=lossType);
  elif use_mod_resp == 2:
    rvcName = None; # Use NONE if getting model responses, only
    if excType == 1:
      fitBase = 'fitList%s_210308_dG' % loc_str
      if recenter_norm:
        #fitBase = 'fitList%s_pyt_210312_dG' % loc_str
        fitBase = 'fitList%s_pyt_210331_dG' % loc_str
    elif excType == 2:
      fitBase = 'fitList%s_pyt_210310' % loc_str
      if recenter_norm:
        #fitBase = 'fitList%s_pyt_210312' % loc_str
        fitBase = 'fitList%s_pyt_210331' % loc_str
    fitList_nm = hf.fitList_name(fitBase, fitType, lossType=lossType, lgnType=lgnFrontEnd, lgnConType=conType, vecCorrected=-rvcAdj);

  # ^^^ EDIT rvc/descrFits/fitList names here; 

  ############
  # Before any plotting, fix plotting paramaters
  ############
  plt.style.use('https://raw.githubusercontent.com/paul-levy/SF_diversity/master/paul_plt_style.mplstyle');
  from matplotlib import rcParams
  rcParams['font.size'] = 20;
  rcParams['pdf.fonttype'] = 42 # should be 42, but there are kerning issues
  rcParams['ps.fonttype'] = 42 # should be 42, but there are kerning issues
  rcParams['lines.linewidth'] = 2.5;
  rcParams['axes.linewidth'] = 1.5;
  rcParams['lines.markersize'] = 8; # this is in style sheet, just being explicit
  rcParams['lines.markeredgewidth'] = 0; # no edge, since weird tings happen then

  rcParams['xtick.major.size'] = 15
  rcParams['xtick.minor.size'] = 5; # no minor ticks
  rcParams['ytick.major.size'] = 15
  rcParams['ytick.minor.size'] = 0; # no minor ticks

  rcParams['xtick.major.width'] = 2
  rcParams['xtick.minor.width'] = 2;
  rcParams['ytick.major.width'] = 2
  rcParams['ytick.minor.width'] = 0

  rcParams['font.style'] = 'oblique';
  rcParams['font.size'] = 20;

  ############
  # load everything
  ############
  dataListNm = hf.get_datalist(expDir, force_full=force_full);
  descrFits_f0 = None;
  dLoss_num = 2; # see hf.descrFit_name/descrMod_name/etc for details
  if expDir == 'LGN/':
    rvcMod = 0; 
    dMod_num = 1;
    rvcDir = 1;
    vecF1 = -1;
  else:
    rvcMod = 1; # i.e. Naka-rushton (1)
    dMod_num = 3; # d-dog-s
    rvcDir = None; # None if we're doing vec-corrected
    if expDir == 'altExp/':
      vecF1 = 0;
    else:
      vecF1 = 1;

  dFits_mod = hf.descrMod_name(dMod_num)
  descrFits_name = hf.descrFit_name(lossType=dLoss_num, descrBase=dFits_base, modelName=dFits_mod, phAdj=1 if vecF1==-1 else None);

  ## now, let it run
  dataPath = basePath + expDir + 'structures/'
  save_loc = basePath + expDir + 'figures/'
  save_locSuper = save_loc + 'superposition_220713/'
  if use_mod_resp == 1:
    save_locSuper = save_locSuper + '%s/' % fitBase

  dataList = hf.np_smart_load(dataPath + dataListNm);
  print('Trying to load descrFits at: %s' % (dataPath + descrFits_name));
  descrFits = hf.np_smart_load(dataPath + descrFits_name);
  if use_mod_resp == 1 or use_mod_resp == 2:
    fitList = hf.np_smart_load(dataPath + fitList_nm);
  else:
    fitList = None;

  if not os.path.exists(save_locSuper):
    os.makedirs(save_locSuper)

  cells = np.arange(1, 1+len(dataList['unitName']))

  zr_rm = lambda x: x[x>0];
  # more flexible - only get values where x AND z are greater than some value "gt" (e.g. 0, 1, 0.4, ...)
  zr_rm_pair = lambda x, z, gt: [x[np.logical_and(x>gt, z>gt)], z[np.logical_and(x>gt, z>gt)]];
  # zr_rm_pair = lambda x, z: [x[np.logical_and(x>0, z>0)], z[np.logical_and(x>0, z>0)]] if np.logical_and(x!=[], z!=[])==True else [], [];

  # here, we'll save measures we are going use for analysis purpose - e.g. supperssion index, c50
  curr_suppr = dict();

  ############
  ### Establish the plot, load cell-specific measures
  ############
  nRows, nCols = 6, 2;
  cellName = dataList['unitName'][which_cell-1];
  expInd = hf.get_exp_ind(dataPath, cellName)[0]
  S = hf.np_smart_load(dataPath + cellName + '_sfm.npy')
  expData = S['sfm']['exp']['trial'];

  # 0th, let's load the basic tuning characterizations AND the descriptive fit
  try:
    dfit_curr = descrFits[which_cell-1]['params'][0,-1,:]; # single grating, highest contrast
  except:
    dfit_curr = None;
  # - then the basics
  try:
    basic_names, basic_order = dataList['basicProgName'][which_cell-1], dataList['basicProgOrder']
    basics = hf.get_basic_tunings(basic_names, basic_order);
  except:
    try:
      # we've already put the basics in the data structure... (i.e. post-sorting 2021 data)
      basic_names = ['','','','',''];
      basic_order = ['rf', 'sf', 'tf', 'rvc', 'ori']; # order doesn't matter if they are already loaded
      basics = hf.get_basic_tunings(basic_names, basic_order, preProc=S, reducedSave=True)
    except:
      basics = None;

  ### TEMPORARY: save the "basics" in curr_suppr; should live on its own, though; TODO
  curr_suppr['basics'] = basics;

  try:
    oriBW, oriCV = basics['ori']['bw'], basics['ori']['cv'];
  except:
    oriBW, oriCV = np.nan, np.nan;
  try:
    tfBW = basics['tf']['tfBW_oct'];
  except:
    tfBW = np.nan;
  try:
    suprMod = basics['rfsize']['suprInd_model'];
  except:
    suprMod = np.nan;
  try:
    suprDat = basics['rfsize']['suprInd_data'];
  except:
    suprDat = np.nan;

  try:
    cellType = dataList['unitType'][which_cell-1];
  except:
    # TODO: note, this is dangerous; thus far, only V1 cells don't have 'unitType' field in dataList, so we can safely do this
    cellType = 'V1';


  ############
  ### compute f1f0 ratio, and load the corresponding F0 or F1 responses
  ############
  f1f0_rat = hf.compute_f1f0(expData, which_cell, expInd, dataPath, descrFitName_f0=descrFits_f0)[0];
  curr_suppr['f1f0'] = f1f0_rat;
  respMeasure = 1 if f1f0_rat > 1 else 0;

  if vecF1 == 1:
    # get the correct, adjusted F1 response
    if expInd > f1_expCutoff and respMeasure == 1:
      respOverwrite = hf.adjust_f1_byTrial(expData, expInd);
    else:
      respOverwrite = None;

  if (respMeasure == 1 or expDir == 'LGN/') and expDir != 'altExp/' : # i.e. if we're looking at a simple cell, then let's get F1
    if vecF1 == 1:
      spikes_byComp = respOverwrite
      # then, sum up the valid components per stimulus component
      allCons = np.vstack(expData['con']).transpose();
      blanks = np.where(allCons==0);
      spikes_byComp[blanks] = 0; # just set it to 0 if that component was blank during the trial
    else:
      if rvcName is not None:
        try:
          rvcFits = hf.get_rvc_fits(dataPath, expInd, which_cell, rvcName=rvcName, rvcMod=rvcMod, direc=rvcDir, vecF1=vecF1);
        except:
          rvcFits = None;
      else:
        rvcFits = None
      spikes_byComp = hf.get_spikes(expData, get_f0=0, rvcFits=rvcFits, expInd=expInd);
    spikes = np.array([np.sum(x) for x in spikes_byComp]);
    rates = True if vecF1 == 0 else False; # when we get the spikes from rvcFits, they've already been converted into rates (in hf.get_all_fft)
    baseline = None; # f1 has no "DC", yadig?
  else: # otherwise, if it's complex, just get F0
    respMeasure = 0;
    spikes = hf.get_spikes(expData, get_f0=1, rvcFits=None, expInd=expInd);
    rates = False; # get_spikes without rvcFits is directly from spikeCount, which is counts, not rates!
    baseline = hf.blankResp(expData, expInd)[0]; # we'll plot the spontaneous rate
    # why mult by stimDur? well, spikes are not rates but baseline is, so we convert baseline to count (i.e. not rate, too)
    spikes = spikes - baseline*hf.get_exp_params(expInd).stimDur; 

  #print('###\nGetting spikes (data): rates? %d\n###' % rates);
  _, _, _, respAll = hf.organize_resp(spikes, expData, expInd, respsAsRate=rates); # only using respAll to get variance measures
  resps_data, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, overwriteSpikes=spikes, respsAsRates=rates, modsAsRate=rates);

  if fitList is None:
    resps = resps_data; # otherwise, we'll still keep resps_data for reference
  elif fitList is not None: # OVERWRITE the data with the model spikes!
    if use_mod_resp == 1:
      curr_fit = fitList[which_cell-1]['params'];
      modResp = mod_resp.SFMGiveBof(curr_fit, S, normType=fitType, lossType=lossType, expInd=expInd, cellNum=which_cell, excType=excType)[1];
      if f1f0_rat < 1: # then subtract baseline..
        modResp = modResp - baseline*hf.get_exp_params(expInd).stimDur; 
      # now organize the responses
      resps, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, overwriteSpikes=modResp, respsAsRates=False, modsAsRate=False);
    elif use_mod_resp == 2: # then pytorch model!
      resp_str = hf_sf.get_resp_str(respMeasure)
      curr_fit = fitList[which_cell-1][resp_str]['params'];
      model = mrpt.sfNormMod(curr_fit, expInd=expInd, excType=excType, normType=fitType, lossType=lossType, lgnFrontEnd=lgnFrontEnd, newMethod=newMethod, lgnConType=conType, applyLGNtoNorm=_applyLGNtoNorm)
      ### get the vec-corrected responses, if applicable
      if expInd > f1_expCutoff and respMeasure == 1:
        respOverwrite = hf.adjust_f1_byTrial(expData, expInd);
      else:
        respOverwrite = None;

      dw = mrpt.dataWrapper(expData, respMeasure=respMeasure, expInd=expInd, respOverwrite=respOverwrite); # respOverwrite defined above (None if DC or if expInd=-1)
      modResp = model.forward(dw.trInf, respMeasure=respMeasure, sigmoidSigma=_sigmoidSigma, recenter_norm=recenter_norm).detach().numpy();

      if respMeasure == 1: # make sure the blank components have a zero response (we'll do the same with the measured responses)
        blanks = np.where(dw.trInf['con']==0);
        modResp[blanks] = 0;
        # next, sum up across components
        modResp = np.sum(modResp, axis=1);
      # finally, make sure this fills out a vector of all responses (just have nan for non-modelled trials)
      nTrialsFull = len(expData['num']);
      modResp_full = np.nan * np.zeros((nTrialsFull, ));
      modResp_full[dw.trInf['num']] = modResp;

      if respMeasure == 0: # if DC, then subtract baseline..., as determined from data (why not model? we aren't yet calc. response to no stim, though it can be done)
        modResp_full = modResp_full - baseline*hf.get_exp_params(expInd).stimDur;

      # TODO: This is a work around for which measures are in rates vs. counts (DC vs F1, model vs data...)
      stimDur = hf.get_exp_params(expInd).stimDur;
      asRates = False;
      #divFactor = stimDur if asRates == 0 else 1;
      #modResp_full = np.divide(modResp_full, divFactor);
      # now organize the responses
      resps, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, overwriteSpikes=modResp_full, respsAsRates=asRates, modsAsRate=asRates);

  predResps = resps[2];

  respMean = resps[0]; # equivalent to resps[0];
  respStd = np.nanstd(respAll, -1); # take std of all responses for a given condition
  # compute SEM, too
  findNaN = np.isnan(respAll);
  nonNaN  = np.sum(findNaN == False, axis=-1);
  respSem = np.nanstd(respAll, -1) / np.sqrt(nonNaN);

  ############
  ### first, fit a smooth function to the overall pred V measured responses
  ### --- from this, we can measure how each example superposition deviates from a central tendency
  ### --- i.e. the residual relative to the "standard" input:output relationship
  ############
  all_resps = respMean[1:, :, :].flatten() # all disp>0
  all_preds = predResps[1:, :, :].flatten() # all disp>0
  # a model which allows negative fits
  #         myFit = lambda x, t0, t1, t2: t0 + t1*x + t2*x*x;
  #         non_nan = np.where(~np.isnan(all_preds)); # cannot fit negative values with naka-rushton...
  #         fitz, _ = opt.curve_fit(myFit, all_preds[non_nan], all_resps[non_nan], p0=[-5, 10, 5], maxfev=5000)
  # naka rushton
  myFit = lambda x, g, expon, c50: hf.naka_rushton(x, [0, g, expon, c50]) 
  non_neg = np.where(all_preds>0) # cannot fit negative values with naka-rushton...
  try:
    if use_mod_resp == 1: # the reference will ALWAYS be the data -- redo the above analysis for data
      predResps_data = resps_data[2];
      respMean_data = resps_data[0];
      all_resps_data = respMean_data[1:, :, :].flatten() # all disp>0
      all_preds_data = predResps_data[1:, :, :].flatten() # all disp>0
      non_neg_data = np.where(all_preds_data>0) # cannot fit negative values with naka-rushton...
      fitz, _ = opt.curve_fit(myFit, all_preds_data[non_neg_data], all_resps_data[non_neg_data], p0=[100, 2, 25], maxfev=5000)
    else:
      fitz, _ = opt.curve_fit(myFit, all_preds[non_neg], all_resps[non_neg], p0=[100, 2, 25], maxfev=5000)
    rel_c50 = np.divide(fitz[-1], np.max(all_preds[non_neg]));
  except:
    fitz = None;
    rel_c50 = -99;

  ############
  ### organize stimulus information
  ############
  all_disps = stimVals[0];
  all_cons = stimVals[1];
  all_sfs = stimVals[2];

  nCons = len(all_cons);
  nSfs = len(all_sfs);
  nDisps = len(all_disps);

  maxResp = np.maximum(np.nanmax(respMean), np.nanmax(predResps));
  # by disp
  clrs_d = cm.viridis(np.linspace(0,0.75,nDisps-1));
  lbls_d = ['disp: %s' % str(x) for x in range(nDisps)];
  # by sf
  val_sfs = hf.get_valid_sfs(S, disp=1, con=val_con_by_disp[1][0], expInd=expInd) # pick 
  clrs_sf = cm.viridis(np.linspace(0,.75,len(val_sfs)));
  lbls_sf = ['sf: %.2f' % all_sfs[x] for x in val_sfs];
  # by con
  val_con = all_cons;
  clrs_con = cm.viridis(np.linspace(0,.75,len(val_con)));
  lbls_con = ['con: %.2f' % x for x in val_con];

  ############
  ### create the figure
  ############
  fSuper, ax = plt.subplots(nRows, nCols, figsize=(10*nCols, 8*nRows))
  sns.despine(fig=fSuper, offset=10)

  allMix = [];
  allSum = [];

  ### plot reference tuning [row 1 (i.e. 2nd row)]
  ## on the right, SF tuning (high contrast)
  sfRef = hf.nan_rm(respMean[0, :, -1]); # high contrast tuning
  ax[1, 1].plot(all_sfs, sfRef, 'k-', marker='o', label='ref. tuning (d0, high con)', clip_on=False)
  ax[1, 1].set_xscale('log')
  ax[1, 1].set_xlim((0.1, 10));
  ax[1, 1].set_xlabel('sf (c/deg)')
  ax[1, 1].set_ylabel('response (spikes/s)')
  ax[1, 1].set_ylim((-5, 1.1*np.nanmax(sfRef)));
  ax[1, 1].legend(fontsize='x-small');

  #####
  ## then on the left, RVC (peak SF)
  #####
  sfPeak = np.argmax(sfRef); # stupid/simple, but just get the rvc for the max response
  v_cons_single = val_con_by_disp[0]
  rvcRef = hf.nan_rm(respMean[0, sfPeak, v_cons_single]);
  # now, if possible, let's also plot the RVC fit
  if rvcFits is not None:
    rvcFits = hf.get_rvc_fits(dataPath, expInd, which_cell, rvcName=rvcName, rvcMod=rvcMod);
    rel_rvc = rvcFits[0]['params'][sfPeak]; # we get 0 dispersion, peak SF
    plt_cons = np.geomspace(all_cons[0], all_cons[-1], 50);
    c50, pk = hf.get_c50(rvcMod, rel_rvc), rvcFits[0]['conGain'][sfPeak];
    c50_emp, c50_eval = hf.c50_empirical(rvcMod, rel_rvc); # determine c50 by optimization, numerical approx.
    if rvcMod == 0:
      rvc_mod = hf.get_rvc_model();
      rvcmodResp = rvc_mod(*rel_rvc, plt_cons);
    else: # i.e. mod=1 or mod=2
      rvcmodResp = hf.naka_rushton(plt_cons, rel_rvc);
    if baseline is not None:
      rvcmodResp = rvcmodResp - baseline; 
    ax[1, 0].plot(plt_cons, rvcmodResp, 'k--', label='rvc fit (c50=%.2f, gain=%0f)' %(c50, pk))
    # and save it
    curr_suppr['c50'] = c50; curr_suppr['conGain'] = pk;
    curr_suppr['c50_emp'] = c50_emp; curr_suppr['c50_emp_eval'] = c50_eval
  else:
    curr_suppr['c50'] = np.nan; curr_suppr['conGain'] = np.nan;
    curr_suppr['c50_emp'] = np.nan; curr_suppr['c50_emp_eval'] = np.nan;

  ax[1, 0].plot(all_cons[v_cons_single], rvcRef, 'k-', marker='o', label='ref. tuning (d0, peak SF)', clip_on=False)
  #         ax[1, 0].set_xscale('log')
  ax[1, 0].set_xlabel('contrast (%)');
  ax[1, 0].set_ylabel('response (spikes/s)')
  ax[1, 0].set_ylim((-5, 1.1*np.nanmax(rvcRef)));
  ax[1, 0].legend(fontsize='x-small');

  # plot the fitted model on each axis
  pred_plt = np.linspace(0, np.nanmax(all_preds), 100);
  if fitz is not None:
    ax[0, 0].plot(pred_plt, myFit(pred_plt, *fitz), 'r--', label='fit')
    ax[0, 1].plot(pred_plt, myFit(pred_plt, *fitz), 'r--', label='fit')

  for d in range(nDisps):
    if d == 0: # we don't care about single gratings!
      dispRats = [];
      continue; 
    v_cons = np.array(val_con_by_disp[d]);
    n_v_cons = len(v_cons);

    # plot split out by each contrast [0,1]
    for c in reversed(range(n_v_cons)):
      v_sfs = hf.get_valid_sfs(S, d, v_cons[c], expInd)
      for s in v_sfs:
        mixResp = respMean[d, s, v_cons[c]];
        allMix.append(mixResp);
        sumResp = predResps[d, s, v_cons[c]];
        allSum.append(sumResp);
  #      print('condition: d(%d), c(%d), sf(%d):: pred(%.2f)|real(%.2f)' % (d, v_cons[c], s, sumResp, mixResp))
        # PLOT in by-disp panel
        if c == 0 and s == v_sfs[0]:
          ax[0, 0].plot(sumResp, mixResp, 'o', color=clrs_d[d-1], label=lbls_d[d], clip_on=False)
        else:
          ax[0, 0].plot(sumResp, mixResp, 'o', color=clrs_d[d-1], clip_on=False)
        # PLOT in by-sf panel
        sfInd = np.where(np.array(v_sfs) == s)[0][0]; # will only be one entry, so just "unpack"
        try:
          if d == 1 and c == 0:
            ax[0, 1].plot(sumResp, mixResp, 'o', color=clrs_sf[sfInd], label=lbls_sf[sfInd], clip_on=False);
          else:
            ax[0, 1].plot(sumResp, mixResp, 'o', color=clrs_sf[sfInd], clip_on=False);
        except:
          pass;
          #pdb.set_trace();
        # plot baseline, if f0...
  #       if baseline is not None:
  #         [ax[0, i].axhline(baseline, linestyle='--', color='k', label='spon. rate') for i in range(2)];


    # plot averaged across all cons/sfs (i.e. average for the whole dispersion) [1,0]
    mixDisp = respMean[d, :, :].flatten();
    sumDisp = predResps[d, :, :].flatten();
    mixDisp, sumDisp = zr_rm_pair(mixDisp, sumDisp, 0.5);
    curr_rats = np.divide(mixDisp, sumDisp)
    curr_mn = geomean(curr_rats); curr_std = np.std(np.log10(curr_rats));
  #  curr_rat = geomean(np.divide(mixDisp, sumDisp));
    ax[2, 0].bar(d, curr_mn, yerr=curr_std, color=clrs_d[d-1]);
    ax[2, 0].set_yscale('log')
    ax[2, 0].set_ylim(0.1, 10);
  #  ax[2, 0].yaxis.set_ticks(minorticks)
    dispRats.append(curr_mn);
  #  ax[2, 0].bar(d, np.mean(np.divide(mixDisp, sumDisp)), color=clrs_d[d-1]);

    # also, let's plot the (signed) error relative to the fit
    if fitz is not None:
      errs = mixDisp - myFit(sumDisp, *fitz);
      ax[3, 0].bar(d, np.mean(errs), yerr=np.std(errs), color=clrs_d[d-1])
      # -- and normalized by the prediction output response
      errs_norm = np.divide(mixDisp - myFit(sumDisp, *fitz), myFit(sumDisp, *fitz));
      ax[4, 0].bar(d, np.mean(errs_norm), yerr=np.std(errs_norm), color=clrs_d[d-1])

    # and set some labels/lines, as needed
    if d == 1:
        ax[2, 0].set_xlabel('dispersion');
        ax[2, 0].set_ylabel('suppression ratio (linear)')
        ax[2, 0].axhline(1, ls='--', color='k')
        ax[3, 0].set_xlabel('dispersion');
        ax[3, 0].set_ylabel('mean (signed) error')
        ax[3, 0].axhline(0, ls='--', color='k')
        ax[4, 0].set_xlabel('dispersion');
        ax[4, 0].set_ylabel('mean (signed) error -- as frac. of fit prediction')
        ax[4, 0].axhline(0, ls='--', color='k')

    curr_suppr['supr_disp'] = dispRats;

  ### plot averaged across all cons/disps
  sfInds = []; sfRats = []; sfRatStd = []; 
  sfErrs = []; sfErrsStd = []; sfErrsInd = []; sfErrsIndStd = []; sfErrsRat = []; sfErrsRatStd = [];
  curr_errNormFactor = [];
  for s in range(len(val_sfs)):
    try: # not all sfs will have legitmate values;
      # only get mixtures (i.e. ignore single gratings)
      mixSf = respMean[1:, val_sfs[s], :].flatten();
      sumSf = predResps[1:, val_sfs[s], :].flatten();
      mixSf, sumSf = zr_rm_pair(mixSf, sumSf, 0.5);
      rats_curr = np.divide(mixSf, sumSf); 
      sfInds.append(s); sfRats.append(geomean(rats_curr)); sfRatStd.append(np.std(np.log10(rats_curr)));

      if fitz is not None:
        #curr_NR = myFit(sumSf, *fitz); # unvarnished
        curr_NR = np.maximum(myFit(sumSf, *fitz), 0.5); # thresholded at 0.5...

        curr_err = mixSf - curr_NR;
        sfErrs.append(np.mean(curr_err));
        sfErrsStd.append(np.std(curr_err))

        curr_errNorm = np.divide(mixSf - curr_NR, mixSf + curr_NR);
        sfErrsInd.append(np.mean(curr_errNorm));
        sfErrsIndStd.append(np.std(curr_errNorm))

        curr_errRat = np.divide(mixSf, curr_NR);
        sfErrsRat.append(np.mean(curr_errRat));
        sfErrsRatStd.append(np.std(curr_errRat));

        curr_normFactors = np.array(curr_NR)
        curr_errNormFactor.append(geomean(curr_normFactors[curr_normFactors>0]));
      else:
        sfErrs.append([]);
        sfErrsStd.append([]);
        sfErrsInd.append([]);
        sfErrsIndStd.append([]);
        sfErrsRat.append([]);
        sfErrsRatStd.append([]);
        curr_errNormFactor.append([]);
    except:
      pass

  # get the offset/scale of the ratio so that we can plot a rescaled/flipped version of
  # the high con/single grat tuning for reference...does the suppression match the response?
  offset, scale = np.nanmax(sfRats), np.nanmax(sfRats) - np.nanmin(sfRats);
  sfRef = hf.nan_rm(respMean[0, val_sfs, -1]); # high contrast tuning
  sfRefShift = offset - scale * (sfRef/np.nanmax(sfRef))
  ax[2,1].scatter(all_sfs[val_sfs][sfInds], sfRats, color=clrs_sf[sfInds], clip_on=False)
  ax[2,1].errorbar(all_sfs[val_sfs][sfInds], sfRats, sfRatStd, color='k', linestyle='-', clip_on=False, label='suppression tuning')
  #         ax[2,1].plot(all_sfs[val_sfs][sfInds], sfRats, 'k-', clip_on=False, label='suppression tuning')
  ax[2,1].plot(all_sfs[val_sfs], sfRefShift, 'k--', label='ref. tuning', clip_on=False)
  ax[2,1].axhline(1, ls='--', color='k')
  ax[2,1].set_xlabel('sf (cpd)')
  ax[2,1].set_xscale('log')
  ax[2,1].set_xlim((0.1, 10));
  #ax[2,1].set_xlim((np.min(all_sfs), np.max(all_sfs)));
  ax[2,1].set_ylabel('suppression ratio');
  ax[2,1].set_yscale('log')
  #ax[2,1].yaxis.set_ticks(minorticks)
  ax[2,1].set_ylim(0.1, 10);        
  ax[2,1].legend(fontsize='x-small');
  curr_suppr['supr_sf'] = sfRats;

  ### residuals from fit of suppression
  if fitz is not None:
    # mean signed error: and labels/plots for the error as f'n of SF
    ax[3,1].axhline(0, ls='--', color='k')
    ax[3,1].set_xlabel('sf (cpd)')
    ax[3,1].set_xscale('log')
    ax[3,1].set_xlim((0.1, 10));
    #ax[3,1].set_xlim((np.min(all_sfs), np.max(all_sfs)));
    ax[3,1].set_ylabel('mean (signed) error');
    ax[3,1].errorbar(all_sfs[val_sfs][sfInds], sfErrs, sfErrsStd, color='k', marker='o', linestyle='-', clip_on=False)
    # -- and normalized by the prediction output response + output respeonse
    val_errs = np.logical_and(~np.isnan(sfErrsRat), np.logical_and(np.array(sfErrsIndStd)>0, np.array(sfErrsIndStd) < 2));
    norm_subset = np.array(sfErrsInd)[val_errs];
    normStd_subset = np.array(sfErrsIndStd)[val_errs];
    ax[4,1].axhline(0, ls='--', color='k')
    ax[4,1].set_xlabel('sf (cpd)')
    ax[4,1].set_xscale('log')
    ax[4,1].set_xlim((0.1, 10));
    #ax[4,1].set_xlim((np.min(all_sfs), np.max(all_sfs)));
    ax[4,1].set_ylim((-1, 1));
    ax[4,1].set_ylabel('error index');
    ax[4,1].errorbar(all_sfs[val_sfs][sfInds][val_errs], norm_subset, normStd_subset, color='k', marker='o', linestyle='-', clip_on=False)
    # -- AND simply the ratio between the mixture response and the mean expected mix response (i.e. Naka-Rushton)
    # --- equivalent to the suppression ratio, but relative to the NR fit rather than perfect linear summation
    val_errs = np.logical_and(~np.isnan(sfErrsRat), np.logical_and(np.array(sfErrsRatStd)>0, np.array(sfErrsRatStd) < 2));
    rat_subset = np.array(sfErrsRat)[val_errs];
    ratStd_subset = np.array(sfErrsRatStd)[val_errs];
    #ratStd_subset = (1/np.log(2))*np.divide(np.array(sfErrsRatStd)[val_errs], rat_subset);
    ax[5,1].scatter(all_sfs[val_sfs][sfInds][val_errs], rat_subset, color=clrs_sf[sfInds][val_errs], clip_on=False)
    ax[5,1].errorbar(all_sfs[val_sfs][sfInds][val_errs], rat_subset, ratStd_subset, color='k', linestyle='-', clip_on=False, label='suppression tuning')
    ax[5,1].axhline(1, ls='--', color='k')
    ax[5,1].set_xlabel('sf (cpd)')
    ax[5,1].set_xscale('log')
    ax[5,1].set_xlim((0.1, 10));
    ax[5,1].set_ylabel('suppression ratio (wrt NR)');
    ax[5,1].set_yscale('log', basey=2)
  #         ax[2,1].yaxis.set_ticks(minorticks)
    ax[5,1].set_ylim(np.power(2.0, -2), np.power(2.0, 2));
    ax[5,1].legend(fontsize='x-small');
    # - compute the variance - and put that value on the plot
    errsRatVar = np.var(np.log2(sfErrsRat)[val_errs]);
    curr_suppr['sfRat_VAR'] = errsRatVar;
    ax[5,1].text(0.1, 2, 'var=%.2f' % errsRatVar);

    # compute the unsigned "area under curve" for the sfErrsInd, and normalize by the octave span of SF values considered
    val_errs = np.logical_and(~np.isnan(sfErrsRat), np.logical_and(np.array(sfErrsIndStd)>0, np.array(sfErrsIndStd) < 2));
    val_x = all_sfs[val_sfs][sfInds][val_errs];
    ind_var = np.var(np.array(sfErrsInd)[val_errs]);
    curr_suppr['sfErrsInd_VAR'] = ind_var;
    # - and put that value on the plot
    ax[4,1].text(0.1, -0.25, 'var=%.3f' % ind_var);
  else:
    curr_suppr['sfErrsInd_VAR'] = np.nan
    curr_suppr['sfRat_VAR'] = np.nan

  #########
  ### NOW, let's evaluate the derivative of the SF tuning curve and get the correlation with the errors
  #########
  mod_sfs = np.geomspace(all_sfs[0], all_sfs[-1], 1000);
  mod_resp = hf.get_descrResp(dfit_curr, mod_sfs, DoGmodel=dMod_num);
  deriv = np.divide(np.diff(mod_resp), np.diff(np.log10(mod_sfs)))
  deriv_norm = np.divide(deriv, np.maximum(np.nanmax(deriv), np.abs(np.nanmin(deriv)))); # make the maximum response 1 (or -1)
  # - then, what indices to evaluate for comparing with sfErr?
  errSfs = all_sfs[val_sfs][sfInds];
  mod_inds = [np.argmin(np.square(mod_sfs-x)) for x in errSfs];
  deriv_norm_eval = deriv_norm[mod_inds];
  # -- plot on [1, 1] (i.e. where the data is)
  ax[1,1].plot(mod_sfs, mod_resp, 'k--', label='fit (g)')
  ax[1,1].legend();
  # Duplicate "twin" the axis to create a second y-axis
  ax2 = ax[1,1].twinx();
  ax2.set_xscale('log'); # have to re-inforce log-scale?
  ax2.set_ylim([-1, 1]); # since the g' is normalized
  # make a plot with different y-axis using second axis object
  ax2.plot(mod_sfs[1:], deriv_norm, '--', color="red", label='g\'');
  ax2.set_ylabel("deriv. (normalized)",color="red")
  ax2.legend();
  sns.despine(ax=ax2, offset=10, right=False);
  # -- and let's plot rescaled and shifted version in [2,1]
  offset, scale = np.nanmax(sfRats), np.nanmax(sfRats) - np.nanmin(sfRats);
  derivShift = offset - scale * (deriv_norm/np.nanmax(deriv_norm));
  ax[2,1].plot(mod_sfs[1:], derivShift, 'r--', label='deriv(ref. tuning)', clip_on=False)
  ax[2,1].legend(fontsize='x-small');
  # - then, normalize the sfErrs/sfErrsInd and compute the correlation coefficient
  if fitz is not None:
    norm_sfErr = np.divide(sfErrs, np.nanmax(np.abs(sfErrs)));
    norm_sfErrInd = np.divide(sfErrsInd, np.nanmax(np.abs(sfErrsInd))); # remember, sfErrsInd is normalized per condition; this is overall
    non_nan = np.logical_and(~np.isnan(norm_sfErr), ~np.isnan(deriv_norm_eval))
    corr_nsf, corr_nsfN = np.corrcoef(deriv_norm_eval[non_nan], norm_sfErr[non_nan])[0,1], np.corrcoef(deriv_norm_eval[non_nan], norm_sfErrInd[non_nan])[0,1]
    curr_suppr['corr_derivWithErr'] = corr_nsf;
    curr_suppr['corr_derivWithErrsInd'] = corr_nsfN;
    ax[3,1].text(0.1, 0.25*np.nanmax(sfErrs), 'corr w/g\' = %.2f' % corr_nsf)
    ax[4,1].text(0.1, 0.25, 'corr w/g\' = %.2f' % corr_nsfN)
  else:
    curr_suppr['corr_derivWithErr'] = np.nan;
    curr_suppr['corr_derivWithErrsInd'] = np.nan;

  # make a polynomial fit
  try:
    hmm = np.polyfit(allSum, allMix, deg=1) # returns [a, b] in ax + b 
  except:
    hmm = [np.nan];
  curr_suppr['supr_index'] = hmm[0];

  for j in range(1):
    for jj in range(nCols):
      ax[j, jj].axis('square')
      ax[j, jj].set_xlabel('prediction: sum(components) (imp/s)');
      ax[j, jj].set_ylabel('mixture response (imp/s)');
      ax[j, jj].plot([0, 1*maxResp], [0, 1*maxResp], 'k--')
      ax[j, jj].set_xlim((-5, maxResp));
      ax[j, jj].set_ylim((-5, 1.1*maxResp));
      ax[j, jj].set_title('Suppression index: %.2f|%.2f' % (hmm[0], rel_c50))
      ax[j, jj].legend(fontsize='x-small');

  fSuper.suptitle('Superposition: %s #%d [%s; f1f0 %.2f; szSupr[dt/md] %.2f/%.2f; oriBW|CV %.2f|%.2f; tfBW %.2f]' % (cellType, which_cell, cellName, f1f0_rat, suprDat, suprMod, oriBW, oriCV, tfBW))

  if fitList is None:
    save_name = 'cell_%03d.pdf' % which_cell
  else:
    save_name = 'cell_%03d_mod%s.pdf' % (which_cell, hf.fitType_suffix(fitType))
  pdfSv = pltSave.PdfPages(str(save_locSuper + save_name));
  pdfSv.savefig(fSuper)
  pdfSv.close();

  #########
  ### Finally, add this "superposition" to the newest 
  #########

  if to_save:

    if fitList is None:
      from datetime import datetime
      suffix = datetime.today().strftime('%y%m%d')
      super_name = 'superposition_analysis_%s.npy' % suffix;
    else:
      super_name = 'superposition_analysis_mod%s.npy' % hf.fitType_suffix(fitType);

    pause_tm = 5*np.random.rand();
    print('sleeping for %d secs (#%d)' % (pause_tm, which_cell));
    time.sleep(pause_tm);

    if os.path.exists(dataPath + super_name):
      suppr_all = hf.np_smart_load(dataPath + super_name);
    else:
      suppr_all = dict();
    suppr_all[which_cell-1] = curr_suppr;
    np.save(dataPath + super_name, suppr_all);
  
  return curr_suppr;