Example #1
0
def s_cone_weights(d, celllist, celldat, params):
    '''Plot the S-cone weight of LM midgets against distance from an S-cone.

    Two figures are produced and saved as EPS files in the directory
    returned by ``util.get_save_dirname(params, check_randomized=True)``:

    * ``s_weight_scatter.eps`` -- distance from S-cone (arcmin) versus
      S / (L+M+S) weight, one point per cell.
    * ``s_weight_hist.eps`` -- outline histogram (15 bins) of the
      S / (L+M+S) weights, expressed as a percentage of cells.

    Parameters
    ----------
    d, celllist, celldat, params
        Forwarded to ``an.compute_s_dist_cone_weight``; ``params`` is also
        used to locate the save directory.
    '''
    # NOTE(review): column 0 is plotted as distance and column 1 as the
    # S / (L+M+S) weight -- assumes an (ncells, >=2) array is returned.
    lm_midgets = an.compute_s_dist_cone_weight(d, celldat, celllist, params)

    # plotting routines
    ax, fig1 = pf.get_axes(1, 1, nticks=[4, 5], return_fig=True)
    ax[0].plot(lm_midgets[:, 0], lm_midgets[:, 1], 'ko')

    ax[0].set_xlabel('distance from S-cone (arcmin)')
    ax[0].set_ylabel('S / (L+M+S)')

    # histogram of same data
    ax2, fig2 = pf.get_axes(1, 1, nticks=[4, 5], return_fig=True)
    ax2[0].spines['bottom'].set_smart_bounds(False)

    count, bins = np.histogram(lm_midgets[:, 1], bins=15)
    # np.histogram returns integer counts. Multiply by a float before
    # dividing so the normalisation is not truncated by integer division
    # (Python 2 '/' on integer arrays floors).
    count = 100.0 * count / count.sum()
    bins, count = pf.histOutline(count, bins)
    ax2[0].plot(bins, count, 'k-')

    ax2[0].set_xlim([0, 1])
    ax2[0].set_ylabel('% of cells')
    ax2[0].set_xlabel('S / (L+M+S)')

    # Save plots
    savedir = util.get_save_dirname(params, check_randomized=True)
    fig1.savefig(savedir + 's_weight_scatter.eps', edgecolor='none')
    fig2.savefig(savedir + 's_weight_hist.eps', edgecolor='none')
Example #2
0
def HueScaling(cm, lPeak=559):
    '''Plot red, green, blue and yellow hue-scaling curves vs wavelength.

    The curves come from ``cm.getHueScalingData`` with a fixed cone ratio
    (fracLvM 0.70, S fraction 0.05) and peak sensitivities of ``lPeak``
    (L), 530.0 (M) and 417.0 (S) nm. The plot spans 390-750 nm and is
    shown with ``plt.show()``.

    Parameters
    ----------
    cm : object
        Provides ``getHueScalingData``; its result is indexed by
        'lambdas', 'red', 'green', 'blue' and 'yellow'.
    lPeak : number
        Peak sensitivity of the L cone in nm (default 559).
    '''
    cone_ratio = {'fracLvM': 0.70, 's': 0.05, }
    peak_sens = {'l': lPeak, 'm': 530.0, 's': 417.0, }
    hues = cm.getHueScalingData(ConeRatio=cone_ratio, maxSens=peak_sens)

    axis = pf.get_axes()[0]
    wavelengths = hues['lambdas']
    # (hue key, matplotlib colour format) in the original drawing order
    for hue_name, line_fmt in (('red', 'r'), ('green', 'g'),
                               ('blue', 'b'), ('yellow', 'y')):
        axis.plot(wavelengths, hues[hue_name], line_fmt)

    axis.set_xlabel('wavelength (nm)')
    axis.set_ylabel('percentage')
    axis.set_xlim([390, 750])

    plt.show()
Example #3
0
def s_cone_dist_analysis(rg, results, model_name):
    '''
    '''
    # checkout proximity to s cone and rg metric
    ax, fig = pf.get_axes(1, 1, nticks=[3, 3], return_fig=True)
    ax = ax[0]
    inds = results[:, 4] < 0.4 # eliminate S-cone center cells
    ax.plot(results[inds, 4], np.abs(rg[inds]), 'ko')
    
    ax.set_xlabel('S-cone weight')
    ax.set_ylabel('absolute value rg response')

    # See if there is a relationship between color names and distance to S-cone
    X = sm.add_constant(results[inds, 4], prepend=True)
    OLSmodel = sm.OLS(rg[inds], X)
    res = OLSmodel.fit()
    print '\n\n\n'
    print res.rsquared

    savedir = util.get_save_dirname(params, check_randomized=True)
    fig.savefig(savedir + 's_cone_distance.svg', edgecolor='none')    
Example #4
0
def fit_linear_model(rg, results, params, opponent=True):
    # Switch depending upon opponent option
    if not opponent:
        predictors = results[:, [2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 
                                19, 20, 21, 22, 23, 24, 25, 26, 27]] 
    else:
        # organize in opponent fashion
        N = 2 # center cone and nearest x
        predictors = np.zeros((len(results[:, 0]), N + 1))
        _inds = [2, 10, 13, 16, 19, 22, 25, 28, 31]
        for i in range(N + 1):
            predictors[:, i] = ((results[:, _inds[i]] +
                                 results[:, _inds[i] + 2])
                                - results[:, _inds[i] + 1])
    # cone weights in L, M, S order
    inds = [2, 10, 13]
    X = sm.add_constant(predictors, prepend=True)
    OLSmodel = sm.OLS(rg, X) #GLM(rg, X)
    res = OLSmodel.fit()
    #print res.rsquared
    print '\n\n\n'
    print (res.summary())

    linear_reg = True
    ridge_reg = False

    ## Linear regression on training data
    clfs = []
    if linear_reg:
        clfs.append(LinearRegression())
    if ridge_reg:
        clfs.append(Ridge())

    ax, fig = pf.get_axes(1, 1, nticks=[3, 3], return_fig=True)
    ax = ax[0]
    Ntrials = 100
    mae = np.zeros((Ntrials, 1))
    for i in range(Ntrials):
        x_train, x_test, y_train, y_test = model_selection.train_test_split(
            predictors, rg, test_size=0.15)

        for clf in clfs:

            clf.fit(x_train, y_train) 
            mae[i] = mean_absolute_error(y_test, clf.predict(x_test))    

            ax.plot(y_test, clf.predict(x_test), 'ko', alpha=0.5)

    print '\n\n'
    print mae.mean(), mae.std()

    ax.set_xlabel('observed')
    ax.set_ylabel('predicted')
    ax.set_aspect('equal')

    if y_train.min() < 0:
        ax.plot([-1, 1], [-1, 1], 'k-')
        ax.set_ylim([-1, 1])
        ax.set_xlim([-1, 1])
    else:
        ax.plot([0, 1], [0, 1], 'k-')
    savedir = util.get_save_dirname(params, check_randomized=True)
    fig.savefig(params + 'cone_inputs_model_error.svg',
                edgecolor='none')
Example #5
0
def classify_analysis(d, params, purity_thresh=0.0, nseeds=10):
    '''
    '''
    # --- params --- #
    rg_metric = False
    background = 'white'
    cmdscaling = True
    # kernel options=linear, rbf, poly
    kernel = 'linear'

    # MDScaling options
    dims = [0, 1]
    test_size = 0.15
    # if kernel is set to linear only first entry of 'C' is used, gamma ignored
    param_grid = {'C': [1e9],
                  'gamma': [0.0001, 0.001], }
    # -------------- #
    human_subjects = ['wt', 'bps']
    if params['model_name'].lower() not in human_subjects:
        raise('Model must be a human subject with psychopysics data')

    # no need to run a bunch of times since lms is easy to classify    
    if not params['color_cats_switch']:
        nseeds = 1 

    # get some info about the cones
    nn_dat = util.get_nn_dat(params['model_name'])
    celllist = util.get_cell_list(d)
    ncells = len(celllist)
    
    # put responses into a matrix for easy processing
    data_matrix = an.get_data_matrix(d, params['cell_type'])

    # compute distance matrix
    corrmat = an.compute_corr_matrix(data_matrix)

    if cmdscaling:
        # compute the classical multi-dimensional scaling
        config_mat, eigen = dat.cmdscale(corrmat)
    else:
        pca = PCA(n_components=len(dims)).fit(data_matrix)
        config_mat = pca.transform(data_matrix)

    # get the location and cone type of each cone
    xy_lms = an.get_cone_xy_lms(d, nn_dat, celllist)

    # threshold rg that separates red, white, green
    if params['color_cats_switch']:
        # get response
        cone_contrast=params['cone_contrast']
        r = an.response(d, params)
        output = an.associate_cone_color_resp(r, nn_dat, celllist, params['model_name'], 
                                              bkgd=background, 
                                              randomized=params['randomized'])
        stim_cone_ids = output[:, -1]
        stim_cone_inds = np.zeros((1, len(stim_cone_ids)), dtype='int')
        for cone in range(len(stim_cone_ids)):
            stim_cone_inds[0, cone] = np.where(nn_dat[:, 0] == 
                                               stim_cone_ids[cone])[0]

        if rg_metric:
            # break rg into three categories: red, green, white
            rg, by, high_purity = an.get_rgby_from_naming(output[:, 5:10], 
                                                          purity_thresh)
            rgby_thresh = 0.5
            red = rg < -rgby_thresh
            green = rg > rgby_thresh
            blue = by < -rgby_thresh
            yellow = by > rgby_thresh
            color_categories = (yellow * 4 + blue * 3 + red * 2 + green).T[0]

        else: # use dom response category
            max_cat = np.argmax(output[:, 5:10], axis=1)
            red = max_cat == 1
            green = max_cat == 2
            blue = max_cat == 3
            yellow = max_cat == 4
            color_categories = (yellow * 4 + blue * 3 + red * 2 + green).T

        class_cats = color_categories.copy()
        config_mat = config_mat[stim_cone_inds, :][0]
        data_matrix = data_matrix[stim_cone_inds, :][0]
        ncells = len(class_cats)
    else:
        class_cats = xy_lms[:, 2]

    # SVM Classify
    print 'running SVM'
    print '\tCMDScaling=' + str(cmdscaling)
    print '\tkernel=' + str(kernel)
    print '\tRGmetric=' + str(rg_metric)
    print '\tbackground=' + background

    target_names, class_cats = get_target_names_categories(params['color_cats_switch'], 
                                                           class_cats)
    clf, report = an.svm_classify(data_matrix, class_cats, param_grid, target_names, 
                          cmdscaling, dims=dims, display_verbose=True, 
                          rand_seed=2264235, Nseeds=nseeds, test_size=test_size,
                          kernel=kernel)
    print report

    # --------------------------------------------------- #
    print 'plotting results from SVM'

    # undo category shift of plotting 
    if background == 'blue' and params['color_cats_switch']:
        class_cats[class_cats > 0] += 1

    # need to order corrmat based on color category
    sort_inds = np.argsort(class_cats)
    sort_data_matrix = data_matrix[sort_inds, :]
    sort_corrmat = an.compute_corr_matrix(sort_data_matrix)
    # plot correlation matrix
    ax, fig1 = pf.get_axes(1, 1, nticks=[3, 3], return_fig=True)    
    ax[0].imshow(sort_corrmat)
    # change axes to indicate location of category boundaries

    # plot MDS configuration matrix in 2 and 3dim
    ax, fig2 = pf.get_axes(1, 1, nticks=[3, 3], return_fig=True)
    # add decision function
    xx, yy = np.meshgrid(np.linspace(-1.5, 1.5, 200), np.linspace(-1.5, 1.5, 200))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    ax[0].contourf(xx, yy, -Z, cmap=plt.cm.Paired, alpha=0.8)
                         
    for cone in range(ncells):
        if (class_cats[cone] == 0 and params['color_cats_switch'] is False 
            or class_cats[cone] == 3):
            color = [0, 0, 1]
        elif class_cats[cone] == 0 and params['color_cats_switch'] is True:  
            color = [0.7, 0.7, 0.7]
        elif class_cats[cone] == 1:
            color = [0, 1, 0]
        elif class_cats[cone] == 2:
            color = [1, 0, 0]
        elif class_cats[cone] == 4:
            color = [0.8, 0.8, 0] # yellow
        else:
            raise TypeError('category type must be int [0, 4]')

        ax[0].plot(config_mat[cone, 0], config_mat[cone, 1], 'o', markersize=8,
                   alpha=0.8, color=color, markeredgecolor='k')


    # save output
    savedir = util.get_save_dirname(params, check_randomized=True)
    # save txt file
    fhandle = open(savedir + 'classification_report.txt', 'w')
    fhandle.write(report)
    fhandle.close()

    # save figs
    #fig1.savefig(savename + '_corr_matrix.eps', edgecolor='none')
    fig2.savefig(savedir + 'low_dim_rep.eps', edgecolor='none')
    #fig3.savefig(savename + '_3dplot.eps', edgecolor='none')

    if cmdscaling:
        ax, fig4 = pf.get_axes(1, 1, nticks=[3, 3], return_fig=True)
        ax[0].plot(eigen, 'ko')
        fig4.savefig(savedir + 'eigenvals.eps', edgecolor='none')
    plt.show(block=params['block_plots'])
Example #6
0
def plotModel(cm, plotModel=True, plotCurveFamily=False,
              plotUniqueHues=False, savefigs=False, 
              fracLvM=0.25, SHOW=True, OD=None, age=None,
              maxSens=None):
    """Plot cone spectral sensitivies and first stage predictions.
    """
    if maxSens is None:
        maxSens = {'l': 559.0, 'm': 530.0, 's': 417.0, }

    if plotCurveFamily:
        model = cm.colorModel(age=age)
        model.genModel(ConeRatio={'fracLvM': fracLvM, 's': 0.05, },
            maxSens=maxSens, OD=OD)

        FirstStage = model.returnFirstStage()   
        SecondStage = model.returnSecondStage()
        
        fig = plt.figure(figsize=(8.5, 8))
        fig.set_tight_layout(True)
        ax1 = fig.add_subplot(211)
        ax2 = fig.add_subplot(212)
        pf.AxisFormat()

        pf.TufteAxis(ax1, ['left', ], Nticks=[5, 5])
        pf.TufteAxis(ax2, ['left', 'bottom'], Nticks=[5, 5])

        ax1.plot(FirstStage['lambdas'], 
                 np.zeros((len(FirstStage['lambdas']))), 'k', linewidth=1.0)
        ax2.plot(FirstStage['lambdas'], 
                 np.zeros((len(FirstStage['lambdas']))), 'k', linewidth=1.0)
        
        sortedlist = []
        for key in SecondStage['percent']:
            sortedlist.append(SecondStage['percent'][key])
            #print SecondStage['percent'][key]
        sortedlist = sorted(sortedlist, key=itemgetter('probSurround'), 
                            reverse=True)
        thresh = sortedlist[15]['probSurround']

        for i in SecondStage['lmsV_L']:
            if i % 2 == 0 or SecondStage['percent'][i][
                    'probSurround'] >= thresh:
                if SecondStage['percent'][i]['probSurround'] >= thresh:
                    print SecondStage['percent'][i]
                    ax1.plot(FirstStage['lambdas'], 
                            SecondStage['lmsV_M'][i][1],
                            c=(1,0,0), linewidth=1, alpha=0.25)
                    ax2.plot(FirstStage['lambdas'], 
                            SecondStage['lmsV_L'][i][1],
                            c=(0,0,1), linewidth=1, alpha=0.25)
                else:
                    ax1.plot(FirstStage['lambdas'], 
                            SecondStage['lmsV_M'][i][1],
                            c=(0,0,0), linewidth=1, alpha=0.10)
                    ax2.plot(FirstStage['lambdas'], 
                            SecondStage['lmsV_L'][i][1],
                            c=(0,0,0), linewidth=1, alpha=0.10)
                
        ax1.set_ylim([-0.4, 0.4])

        ax1.set_xlim([FirstStage['wavelen']['startWave'],
                      FirstStage['wavelen']['endWave']])
        ax2.set_xlim([FirstStage['wavelen']['startWave'],
                      FirstStage['wavelen']['endWave']])

        ax1.set_ylabel('sensitivity')
        ax2.set_ylabel('sensitivity')
        ax2.set_xlabel('wavelength (nm)')
        
        if savefigs:
            plt.savefig('familyLMS_' + str(int(fracLvM * 100)) + 'L.eps')
        plt.show()

    if plotModel:

        model = cm.colorModel(age=age)
        ax = pf.get_axes(1, 1, nticks=[5, 4])[0]
        ax.spines['left'].set_smart_bounds(False)
        
        style = ['-', '--', '-.']
        for i, LvM in enumerate(np.array([0.0, 0.2, 0.4]) + fracLvM):
            model.genModel(ConeRatio={'fracLvM': LvM, 's': 0.05, },
                           maxSens=maxSens, OD=OD)
            FirstStage = model.returnFirstStage() 
            ThirdStage = model.returnThirdStage()  
        
            # get colors
            blue = ThirdStage['lCenter'].clip(0, 1000)
            yellow = ThirdStage['lCenter'].clip(-1000, 0)
            red = ThirdStage['mCenter'].clip(0, 1000)
            green = ThirdStage['mCenter'].clip(-1000, 0)

            # plot
            ax.plot(FirstStage['lambdas'], blue,
                    'b' + style[i], label=str(int(LvM * 100)) + "%L")
            ax.plot(FirstStage['lambdas'], yellow,
                    'y' + style[i])
            ax.plot(FirstStage['lambdas'], green, 'g' + style[i])
            ax.plot(FirstStage['lambdas'], red, 'r' + style[i])

            # add black line at zero
            ax.plot(FirstStage['lambdas'],
                    np.zeros((len(FirstStage['lambdas']))), 'k', linewidth=2.0)

        ax.set_ylim([-0.28, 0.28])
        ax.set_xlim([FirstStage['wavelen']['startWave'],
                         FirstStage['wavelen']['endWave']])

        ax.legend(loc='upper right', fontsize=18)

        ax.set_ylabel('sensitivity')
        ax.set_xlabel('wavelength (nm)')
        
        if savefigs:
            plt.savefig('percent_L.eps')
            
        plt.show()      
    
    if plotUniqueHues:
        model = cm.colorModel(age=age)
        ax = pf.get_axes(1, 1, nticks=[4, 5])[0]
        ax.spines['bottom'].set_smart_bounds(False)

        style = ['-', '--', '-.']
        i = 0
        for lPeak in [559.0, 557.25, 555.5]:

            model.genModel(maxSens={'l': lPeak, 'm': 530.0, 's': 417.0, }, 
                OD=OD)
            model.findUniqueHues()

            UniqueHues = model.returnUniqueHues()

            ax.plot(UniqueHues['LMratio'], UniqueHues['green'],
                    'g' + style[i], label=str(int(lPeak)))
            ax.plot(UniqueHues['LMratio'], UniqueHues['blue'],
                    'b' + style[i], label=str(int(lPeak)))
            ax.plot(UniqueHues['LMratio'], UniqueHues['yellow'],
                    'y' + style[i], label=str(int(lPeak)))
            i += 1

        ax.set_xlim([20, 100])
        ax.set_ylim([460, 600])

        ax.set_ylabel('wavelength (nm)')
        ax.set_xlabel('% L')

        if savefigs:
            plt.savefig('unique_hues.eps')
        if SHOW:
            plt.show()
        else:
            return ax