def make_image():
    """Show the mean image of the first file with ROI, scale bar and targets.

    Reads module-level settings (``files_to_plot``, ``data_path``,
    ``roi_ind``, ``microns_per_pixel``, ``on_target_centers``,
    ``off_target_centers``). Black circles mark the first two on-target
    centers, red circles the first two off-target centers.
    """
    figure = plt.figure(figsize=(1, 2))
    image_ax = figure.add_subplot(111)
    plt.set_cmap('viridis')

    recording = fh.open_pickle(data_path + files_to_plot[0] + '.pck')
    image_ax.imshow(recording['im_mean'], origin='lower')

    # display_roi() takes no axes argument, so make image_ax current first
    # (presumably it draws on the current axes -- confirm in the ROI class).
    plt.sca(image_ax)
    recording['roi'][roi_ind].display_roi()

    image_ax.set_ylim(10, 70)
    image_ax.set_xlim(20, 90)

    # 10 micron scale bar, converted to pixels
    scale_bar_length = 10 / microns_per_pixel
    plt.plot([30, 30 + scale_bar_length], [20, 20], 'r')

    # 5 micron radius circles around each stimulation target
    target_groups = ((on_target_centers, 'k'), (off_target_centers, 'r'))
    for centers, edge in target_groups:
        for target_ind in (0, 1):
            center = centers[target_ind]
            marker = plt.Circle((center[0], center[1]),
                                radius=5 / microns_per_pixel,
                                edgecolor=edge,
                                facecolor='None')
            image_ax.add_patch(marker)

    fpl.adjust_spines(image_ax, [])
예제 #2
0
    def plot_permute_hist(self, ax, permuted_dist, **kwargs):
        """Draw a normalised histogram of ``permuted_dist`` and format ``ax``.

        Parameters
        ----------
        ax : axes object that receives the histogram and the formatting.
        permuted_dist : values passed unchanged to ``twplt.plot_hist``.
        **kwargs : only ``color`` is read; it defaults to ``'k'``.
        """
        # dict.get replaces the former bare try/except around kwargs['color']
        crcol = kwargs.get('color', 'k')
        twplt.plot_hist(ax, permuted_dist, norm=True, hst_bnds=[0, np.pi],
                        col=crcol, num_bins=200, linewidth=0.5)

        fpl.adjust_spines(ax, ['left', 'bottom'])
        ax.set_ylim([0, 0.075])

        # x axis covers 0-180 degrees, expressed in radians
        ax.set_xlim([calc.deg_to_rad(0), calc.deg_to_rad(180)])
        xticks = [calc.deg_to_rad(0), calc.deg_to_rad(90), calc.deg_to_rad(180)]
        xticklabels = [0, 90, 180]
        yticks = [0, 0.075]
        yticklabels = [0, 0.075]
        ax.get_xaxis().set_ticks(xticks)
        ax.get_xaxis().set_ticklabels(xticklabels, fontsize=6)
        ax.get_yaxis().set_ticks(yticks)
        ax.get_yaxis().set_ticklabels(yticklabels, fontsize=6)

        # '\\circ' yields the same runtime text as the original '\circ' but
        # without relying on an invalid escape sequence.
        ax.set_xlabel('heading difference\n($^\\circ$)', fontsize=6)
        ax.set_ylabel('probability', fontsize=6)
        ax.xaxis.labelpad = 0
        ax.yaxis.labelpad = -12
예제 #3
0
    def plot_heading_hist(self, ax):
        """Overlay mean-heading histograms for every fly type and stimulus.

        For each entry of ``FLY_TYPE_TO_PLOT`` and each of the 'sun' and
        'stripe' stimuli, headings whose vector strength exceeds
        ``VEC_STRENGTH_THRESH`` are re-centred on zero and histogrammed on
        ``ax``. Each (fly type, stimulus) pair takes the next ``COLVLS`` entry.
        """
        color_index = 0
        for flytype in FLY_TYPE_TO_PLOT:
            for stim_type in ('sun', 'stripe'):
                headings = np.array(self.adt[flytype][stim_type]['mnrad'])
                strengths = np.array(
                    self.adt[flytype][stim_type]['vec_strength'])

                keep = np.where(strengths > VEC_STRENGTH_THRESH)[0]
                selected = headings[keep]

                # shift values above pi down by a full turn
                wrapped = np.where(selected > np.pi)
                selected[wrapped] = -(2 * np.pi - selected[wrapped])

                twplt.plot_hist(
                    ax, selected,
                    hst_bnds=[-np.pi - np.pi / 2, np.pi + np.pi / 2],
                    num_bins=12,
                    col=COLVLS[color_index])

                color_index += 1

        fpl.adjust_spines(ax, ['left', 'bottom'])

        ax.set_xlim([-np.pi, np.pi])
        ax.xaxis.set_ticks([-np.pi, 0, np.pi])
        ax.xaxis.set_ticklabels(['-180', '0', '180'], fontsize=6)
        ax.yaxis.set_ticks([0, 0.7])
        ax.yaxis.set_ticklabels(['0', '0.7'], fontsize=6)
        ax.set_xlabel('mean heading', fontsize=6)
        ax.set_ylabel('probability', fontsize=6)
        ax.set_aspect(7)
def plot_sum_for_each_animal(axsum, mn_by_animal):
    """Plot paired per-animal on/off-target means and their group means.

    Parameters
    ----------
    axsum : axes object that receives all drawing.
    mn_by_animal : mapping with keys 'on' and 'off'; each value is a sequence
        with one number per animal, paired by position.

    Returns
    -------
    dict
        ``{'rel': ..., 'off_vs_zero': ...}`` holding the results of the paired
        t-test (on vs. off) and of the one-sample t-test of the off-target
        values against zero. Earlier versions computed these and discarded
        them; callers that ignore the return value are unaffected.
    """
    on_vals = mn_by_animal['on']
    off_vals = mn_by_animal['off']

    kcol = ml.colors.colorConverter.to_rgba('k', alpha=.5)
    rcol = ml.colors.colorConverter.to_rgba('r', alpha=.5)

    # 'on' values sit at x=0, 'off' values at x=0.5
    axsum.scatter(np.zeros(len(on_vals)),
                  on_vals,
                  s=15,
                  facecolor='none',
                  edgecolor=kcol)
    axsum.scatter(np.ones(len(off_vals)) - .5,
                  off_vals,
                  s=15,
                  facecolor='none',
                  edgecolor=rcol)

    # thin connecting line for each animal's on/off pair
    for cr_on, cr_off in zip(on_vals, off_vals):
        axsum.plot([0, 0.5], [cr_on, cr_off], color='k', linewidth=0.2)

    mn_on = np.mean(on_vals)
    mn_off = np.mean(off_vals)
    rel_stats = scipy.stats.ttest_rel(on_vals, off_vals)
    off_target_rel_to_zero_stats = scipy.stats.ttest_1samp(off_vals, 0)

    # Group means are drawn on axsum itself (previously via plt.plot, which
    # targets whatever axes happens to be current). Property names are
    # lower-case: recent matplotlib rejects 'Markersize'/'MarkerEdgeColor'.
    axsum.plot(1.1, mn_on, '<', markersize=3, markeredgecolor=None, color='k')
    axsum.plot(1.1, mn_off, '<', markersize=3, markeredgecolor=None,
               color='r')

    axsum.set_ylim(-.25, 1.25)
    axsum.set_aspect(4)
    axsum.set_yticks([-.25, 0, .25, .5, .75])
    axsum.set_xlim(-.2, 1.2)
    fpl.adjust_spines(axsum, ['left'])
    # NOTE(review): limits are re-applied and the tick list is extended after
    # adjust_spines, exactly as before; whether the shorter tick list above
    # is intentional depends on what adjust_spines derives from it -- confirm.
    axsum.set_ylim(-.25, 1.25)
    axsum.set_aspect(4)
    axsum.set_yticks([-.25, 0, .25, .5, .75, 1.0, 1.25])
    axsum.set_xlim(-.2, 1.2)

    return {'rel': rel_stats, 'off_vs_zero': off_target_rel_to_zero_stats}
예제 #5
0
def plot_raw(ax, fvldt, norm_flag=True, max_value=0, line_flag=False):
    """Scatter fluorescence values against depth for each entry of ``fvldt``.

    Parameters
    ----------
    ax : axes object to draw on.
    fvldt : mapping whose values each provide 'depth' plus 'fvl_norm' and/or
        'fvl_raw'. The first entry is drawn in black, the second in red.
    norm_flag : read 'fvl_norm' when true, otherwise 'fvl_raw'.
    max_value : when truthy, the values are divided by it before plotting.
    line_flag : when true, also join the points with a thin line.
    """
    value_key = 'fvl_norm' if norm_flag else 'fvl_raw'
    palette = ['k', 'r']

    for position, group in enumerate(fvldt):
        entry = fvldt[group]
        values = entry[value_key] / max_value if max_value else entry[value_key]

        base_color = palette[position]
        edge_rgba = ml.colors.colorConverter.to_rgba(base_color, alpha=.5)
        ax.scatter(entry['depth'],
                   values,
                   s=15,
                   facecolor='none',
                   edgecolor=edge_rgba)

        if not line_flag:
            continue
        ax.plot(entry['depth'], values, linewidth=0.2, color=base_color)

    fpl.adjust_spines(ax, ['bottom', 'left'])
예제 #6
0
def make_traces():
    """Plot pre/post-stimulus ROI traces for on- and off-target stimulation.

    Builds a 2x2 grid: the left column holds 'on' traces, the right column
    'off' traces, with one row per entry of ``files_to_plot`` (only two rows
    exist, so a third file would raise IndexError). All timing and selection
    settings are read from module-level names.

    Changes from the earlier version: the leftover ``pdb.set_trace()`` that
    halted every file iteration is gone, and the key is compared with ``==``
    instead of ``is``.
    """
    fig = plt.figure(figsize=(1.5, 1.5))
    gs = GridSpec(2, 2, figure=fig)
    ax = {
        'on': [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[1, 0])],
        'off': [fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 1])],
    }
    trace_colors = {'on': 'k', 'off': 'r'}

    # loop-invariant conversions from seconds to frames
    pre_time_in_frames = int(np.ceil(pre_time_in_sec / TIME_BETWEEN_FRAMES))
    post_time_in_frames = int(np.ceil(post_time_in_sec / TIME_BETWEEN_FRAMES))
    stim_duration_in_frames = stim_duration / TIME_BETWEEN_FRAMES

    for crind, crfile in enumerate(files_to_plot):
        dt = fh.open_pickle(data_path + crfile + '.pck')
        roi_trace = dt['mn_roi'][roi_ind]

        for key in ['on', 'off']:
            col = trace_colors[key]
            crax = ax[key][crind]

            st_time = stim_time_in_sec[key][crind]
            st_time_frames = st_time / TIME_BETWEEN_FRAMES
            st_frame = int(np.ceil(st_time_frames) - 1)
            # remainder of the stimulus onset relative to whole frames
            modvl = np.mod(st_time_frames, np.floor(st_time_frames))

            # pre-stimulus segment
            crax.plot(np.arange(pre_time_in_frames),
                      roi_trace[st_frame - pre_time_in_frames:st_frame],
                      color=col,
                      linewidth=0.5)

            # cyan bar marking the stimulus period
            st_time_in_frames = pre_time_in_frames - 1 + modvl
            crax.plot([
                st_time_in_frames, st_time_in_frames + stim_duration_in_frames
            ], [700, 700],
                      color='c',
                      linewidth=2)

            # markers at the stored pre/post fluorescence values
            pre_f = dt['deltaf_vls']['pre_f'][roi_ind][stim_ind[key]]
            pst_f = dt['deltaf_vls']['pst_f'][roi_ind][stim_ind[key]]
            crax.plot(39, pre_f, '<')
            crax.plot(50, pst_f, '<')

            # post-stimulus segment; three frames after onset are skipped
            # (presumably the stimulation itself -- confirm).
            crax.plot(np.arange(post_time_in_frames) + pre_time_in_frames + 3,
                      roi_trace[st_frame + 3:st_frame + 3 +
                                post_time_in_frames],
                      color=col,
                      linewidth=0.5)

            fpl.adjust_spines(crax, ['left'])
            crax.set_ylim(300, 900)
            crax.set_yticks([300, 900])

            # one-second scale bar, only on the first 'on' panel
            if crind == 0 and key == 'on':
                crax.plot([10, 10 + 1 / TIME_BETWEEN_FRAMES], [350, 350],
                          'k')
예제 #7
0
def plot_raw_vert(ax, fvldt, norm_flag=True, max_value=0):
    """Scatter depth (y) against fluorescence (x) for each entry of ``fvldt``.

    Parameters
    ----------
    ax : axes object to draw on.
    fvldt : mapping whose values each provide 'depth' plus 'fvl_norm' and/or
        'fvl_raw'. The first entry is drawn in black, the second in red.
    norm_flag : read 'fvl_norm' when true, otherwise 'fvl_raw'.
    max_value : when truthy, the values are divided by it before plotting.
    """
    value_key = 'fvl_norm' if norm_flag else 'fvl_raw'
    palette = ['k', 'r']

    for position, group in enumerate(fvldt):
        entry = fvldt[group]
        values = entry[value_key] / max_value if max_value else entry[value_key]
        ax.scatter(values,
                   entry['depth'],
                   s=20,
                   facecolor='none',
                   edgecolor=palette[position])

    fpl.adjust_spines(ax, ['bottom', 'left'])
예제 #8
0
    def set_up_plot_ax(self, crflytype):
        """Collect the layout axes used for the raw-data figure.

        For each of ``NROW`` rows the axes are looked up by name in
        ``self.layout[crflytype][layout_index].axes``, where rows wrap onto
        the next layout every ``MAXROW`` rows.

        Returns
        -------
        dict
            'tmdt': list of 'ex<row>' axes, one per row found;
            'hst':  list (one per row) of lists of 'hst<row><col>' axes;
            'txt':  list of 'txt<row>' axes (only when ``PLOT_EX_TEXT_STR``).

        Lookups remain best-effort: a missing axis is reported with ``print``
        and skipped. A failed text-axis lookup no longer goes on to style
        ``ax['txt'][-1]``, which previously raised IndexError on the first
        row or re-styled the previous row's axis.
        """
        ax = {'tmdt': [], 'hst': [], 'txt': []}

        for cr_row in np.arange(NROW):
            ax['hst'].append([])

            # position within the current layout and index of that layout
            axrow = np.mod(cr_row, MAXROW)
            layout_ind = int(cr_row // MAXROW)

            if PLOT_EX_TEXT_STR:
                hst_3_str = 'txt%s' % (str(axrow))
                try:
                    txt_ax = self.layout[crflytype][layout_ind].axes[hst_3_str]
                except Exception:
                    print('text append error')
                else:
                    ax['txt'].append(txt_ax)
                    fpl.adjust_spines(txt_ax, ['none'])

            for crcol in np.arange(NHISTCOL):
                crhststr = 'hst%s%s' % (str(axrow), str(crcol))
                try:
                    ax['hst'][cr_row].append(
                        self.layout[crflytype][layout_ind].axes[crhststr])
                except Exception:
                    print('layout error')

            crtmstr = 'ex%s' % str(axrow)
            try:
                ax['tmdt'].append(
                    self.layout[crflytype][layout_ind].axes[crtmstr])
            except Exception:
                print('layout error')

        return ax
예제 #9
0
    def plot_hist_dir(self, ax):
        """Plot cumulative histograms of cos(mean heading) on three axes.

        For each fly type in ``FLY_TYPE_TO_PLOT`` and each data index 0-2,
        headings with vector strength above ``VEC_STRENGTH_THRESH`` are
        cosine-transformed and drawn on ``ax[index]`` in that fly type's
        ``COLVLS`` colour, with the number of retained entries as text. A grey
        reference curve built from evenly spaced angles is added to each axis.
        """
        panel_inds = [0, 1, 2]

        for fly_type_ind, flytype in enumerate(FLY_TYPE_TO_PLOT):
            for indnum in panel_inds:
                headings = np.array(self.adt[flytype][indnum]['mnrad'])
                strengths = np.array(
                    self.adt[flytype][indnum]['vec_strength'])

                thresh_inds = np.where(strengths > VEC_STRENGTH_THRESH)[0]
                cos_heading = np.cos(headings[thresh_inds])

                # one colour per fly type, shared by its three panels
                n_kept = np.count_nonzero(~np.isnan(thresh_inds))
                ax[indnum].text(0.2, 0.2 + fly_type_ind * .2, str(n_kept),
                                color=COLVLS[fly_type_ind])

                twplt.plot_hist(ax[indnum], cos_heading, hst_bnds=[-1, 1],
                                plt_type='cumulative', num_bins=100,
                                col=COLVLS[fly_type_ind])

        # reference: cosine of angles evenly spaced over a full turn
        reference = np.cos(np.linspace(0, 2 * np.pi, 10000))
        for indnum in panel_inds:
            twplt.plot_hist(ax[indnum], reference, hst_bnds=[-1, 1],
                            num_bins=100, plt_type='cumulative',
                            col=[0.5, 0.5, 0.5])

        for indnum in panel_inds:
            panel = ax[indnum]
            fpl.adjust_spines(panel, ['left', 'bottom'])

            panel.set_xlim([-1, 1])
            panel.set_ylim([0, 1])
            panel.xaxis.set_ticks([-1, 0, 1])
            panel.xaxis.set_ticklabels(['-1', '0', '1'], fontsize=6)
            panel.yaxis.set_ticks([0, 0.5, 1])
            panel.yaxis.set_ticklabels(['0', '0.5', 1], fontsize=6)
            panel.set_xlabel('cosine of heading', fontsize=6)
            panel.set_ylabel('cumulative probability', fontsize=6)
            panel.set_aspect(2)
        ax[crind].add_patch(circ)
        cr_center = on_target_centers[1]
        circ = plt.Circle((cr_center[0], cr_center[1]),
                          radius=5 / microns_per_pixel,
                          edgecolor='g',
                          facecolor='None')
        ax[crind].add_patch(circ)
    if crind == 1:
        cr_center = off_target_centers[0]
        circ = plt.Circle((cr_center[0], cr_center[1]),
                          radius=5 / microns_per_pixel,
                          edgecolor='c',
                          facecolor='None')
        ax[crind].add_patch(circ)

        cr_center = off_target_centers[1]
        circ = plt.Circle((cr_center[0], cr_center[1]),
                          radius=5 / microns_per_pixel,
                          edgecolor='c',
                          facecolor='None')
        ax[crind].add_patch(circ)

    xvls = np.array(stim_region['xlist'][1])
    yvls = np.array(stim_region['ylist'][1])
    y = np.zeros(np.shape(dt['tifstack']))
    y[yvls, xvls] = 1
    y = np.ma.masked_where(y == 0, y)

    ax[crind].plot([200, 200 + 10 / microns_per_pixel], [170, 170], 'c')
    fpl.adjust_spines(ax[crind], [])
        circ = plt.Circle((cr_center[0], cr_center[1]),
                          radius=5 / microns_per_pixel,
                          edgecolor='c',
                          facecolor='None')
        ax[crind].add_patch(circ)
    if crind == 1:
        cr_center = off_target_centers[0]
        circ = plt.Circle((cr_center[0], cr_center[1]),
                          radius=5 / microns_per_pixel,
                          edgecolor='g',
                          facecolor='None')
        ax[crind].add_patch(circ)

        cr_center = off_target_centers[1]
        circ = plt.Circle((cr_center[0], cr_center[1]),
                          radius=5 / microns_per_pixel,
                          edgecolor='g',
                          facecolor='None')
        ax[crind].add_patch(circ)

    xvls = np.array(stim_region['xlist'][1])
    yvls = np.array(stim_region['ylist'][1])
    y = np.zeros(np.shape(dt['tifstack']))
    y[yvls, xvls] = 1
    y = np.ma.masked_where(y == 0, y)
    ax[crind].set_xlim(80, 220)
    ax[crind].set_ylim(120, 240)
    ax[crind].plot([200, 200 + 10 / microns_per_pixel], [190, 190], 'c')

    fpl.adjust_spines(ax[crind], '')
# One figure per tif stack: a mean projection over axis 0, a mean projection
# over axis 1, and the profile of the latter averaged over its first axis.
for crind, cr_file_name in enumerate(file_names):
    fig = plt.figure(figsize=(2, 4))
    gs = GridSpec(8, 14, figure=fig)
    ax = fig.add_subplot(gs[0:5, 0:5])
    ax_rthist = fig.add_subplot(gs[5:7, 0:4])
    ax2 = fig.add_subplot(gs[1:3, 6:12])
    ax2_rthist = fig.add_subplot(gs[1:3, 12:])
    ax2_bthist = fig.add_subplot(gs[7:, 6:12])
    plt.set_cmap('viridis')

    dt = util.read_in_tif(data_path + cr_file_name)
    mean_xy = np.mean(dt['tifstack'], axis=0)
    mean_z = np.mean(dt['tifstack'], axis=1)

    # --- first projection, cropped, with a 20 micron white scale bar ---
    imshowobj = ax.imshow(mean_xy)
    fpl.adjust_spines(ax, [])
    ax.set_xlim(600, 1200)
    ax.set_ylim(600, 1200)
    ax.plot([970, 970 + 20 / microns_per_pixel], [775, 775], 'w')

    # --- second projection, aspect derived from its width in microns ---
    microns_mean_z = np.shape(mean_z)[1] * microns_per_pixel
    aspect_ratio = microns_mean_z / 100.
    imshowobj2 = ax2.imshow(mean_z)
    imshowobj2.set_clim(14.6, 16.5)
    fpl.adjust_spines(ax2, [])
    ax2.set_aspect(aspect_ratio)
    ax2.set_xlim(600, 1200)
    ax2.set_ylim(0, 100)
    ax2.plot([970, 970 + 20 / microns_per_pixel], [20, 20], 'w')

    # --- profile below the second projection ---
    mean_meanz = np.mean(mean_z, axis=0)
    ax2_bthist.plot(mean_meanz)
    fpl.adjust_spines(ax2_bthist, [])
예제 #13
0
    def make_raw_plot(self, flyindnum, exp_type_to_plot):
        """Draw the time trace and heading histogram for each flight.

        Iterates over ``self.crdt`` (flight number -> flight data), skipping
        empty flights. Each flight's motor trace goes on the time axis of row
        ``self.crplotnum`` and its heading histogram on that row's histogram
        axis for the flight. ``self.crplotnum`` is advanced once at the end,
        and only when ``self.crdt`` is non-empty.

        Parameters
        ----------
        flyindnum : flights are only drawn when this is smaller than the
            number of available time axes.
        exp_type_to_plot : key into ``self.axraw``.

        Notes
        -----
        A failed histogram-axis lookup is still reported with ``print``, but
        the histogram for that flight is now skipped. Previously the code went
        on to use ``crax``, raising NameError on the first flight or drawing
        on the previous flight's axis.
        """
        ADD_VEC_TEXT = True

        if not self.crdt:
            return

        for cr_fltnum in self.crdt.keys():
            flight = self.crdt[cr_fltnum]
            if not flight:
                continue

            # centre the mean heading on zero
            mnvl_in_rad = flight['mnrad_360']
            if mnvl_in_rad > np.pi:
                mnvl_in_rad = -(2 * np.pi - mnvl_in_rad)
            halt_flag = False

            # later flights are offset by the end time of the previous one
            offset_time = 0
            if cr_fltnum == 1:
                offset_time = self.crdt[cr_fltnum - 1]['time_in_min'][-1]
            elif cr_fltnum > 1:
                offset_time = (self.crdt[cr_fltnum - 1]['time_in_min'][-1]
                               - TIME_GAP)

            row_axes = self.axraw[exp_type_to_plot]
            if flyindnum >= len(row_axes['tmdt']):
                continue

            tm_ax = row_axes['tmdt'][self.crplotnum]
            try:
                fpb.plot_motor(flight, tm_ax, plot_vector=False, plot_split=1,
                               plot_start_angle=0, subtract_zero_time=True,
                               offset_time=offset_time,
                               plot_vert_line_at_end=True,
                               halt_flag=halt_flag, center_on_zero_flag=True,
                               withhold_bottom_axis=True)
            except Exception:
                print('plot_motor error')
            tm_ax.set_xlim([0, 15.5])

            if ADD_VEC_TEXT:
                # vector strength = 1 - circular variance
                crvec = 1 - flight['circvar_mod360']
                tm_ax.text(1.5 * 5 * cr_fltnum, 30, str(crvec))

            # heading histogram in 10 degree bins over [-180, 180]
            deg_per_bin = 10
            mot_deg = calc.center_deg_on_zero(
                calc.rad_to_deg(flight['mot_rad']))
            degbins = np.arange(-180, 190, deg_per_bin)
            tst = calc.make_hist_calculations(mot_deg, degbins)
            tst['rad_per_bin'] = deg_per_bin * np.pi / 180
            tst['xrad'] = calc.deg_to_rad(degbins)

            try:
                crax = row_axes['hst'][self.crplotnum][cr_fltnum]
            except Exception:
                print('ax assignment error')
                crax = None

            if PLOT_EX_TEXT_STR:
                textax = row_axes['txt'][self.crplotnum]
                txtfile = flight['fname'].split('/')[-1]
                textax.text(0, 0, txtfile, fontsize=4, rotation='vertical')
                textax.set_ylim([-12, 2])

            if crax is None:
                continue

            # histogram drawn sideways: probability on x, heading on y
            crax.step(tst['normhst'],
                      tst['xrad'][0:-1] + tst['rad_per_bin'] / 2,
                      color='k', linewidth=0.5)
            fpl.adjust_spines(crax, [])
            crax.set_ylim([-calc.deg_to_rad(180.), calc.deg_to_rad(180.0)])
            crax.plot(0.21, mnvl_in_rad, 'r<', clip_on=False)
            crax.set_xlim([0, 0.24])

        self.crplotnum = self.crplotnum + 1
예제 #14
0
    def plot_vec_strength_and_dir(self, axvec, axhist, axvec_nonfrontal):
        """Plot cumulative vector-strength and heading distributions.

        For every fly type in ``FLY_TYPE_TO_PLOT``:

        * data indices 0 and 1: vector strength on ``axvec[i]``, folded
          headings (values above pi mapped to ``2*pi - value``) of
          above-threshold entries on ``axhist[i]``, and vector strength of
          entries inside the ``FRONTAL_THRESH`` window on
          ``axvec_nonfrontal[i]``;
        * index 2 of each axis list: the same quantities with data indices 0
          and 1 concatenated;
        * index 3 of ``axvec``/``axhist``: data index 2 (labelled 'stripe' in
          the loop below).

        Side effects: fills ``self.comb_head_dt`` and
        ``self.concat_dt_head_in_cols`` per fly type, and formats all four
        axes of each list.

        Fixes relative to the earlier version:

        * ``flytype is 'tb_flies'`` compared identity; it now uses ``==``.
        * The stripe heading data was replaced by only the folded subset;
          folding now happens in place, as for data indices 0 and 1, so every
          above-threshold heading is plotted.
        * Five x ticks were given four labels; the missing '1' is added.
        * Bare ``except:`` clauses are now ``except Exception:``.
        """
        concat_dt_vec = {}
        dt_vec_in_cols = {}
        concat_dt_vec_front = {}
        concat_dt_head = {}
        self.concat_dt_head_in_cols = {}
        dir_dt = {}
        dir_dt_all = {}
        all_dt_dict = {}
        self.comb_head_dt = {}

        for fly_type_ind, flytype in enumerate(FLY_TYPE_TO_PLOT):
            horiz_pos = 0.2
            vert_pos = 0.2
            # one colour and one text offset per fly type
            vert_offset = fly_type_ind
            crcol = COLVLS[fly_type_ind]

            for stim_type in ['sun']:
                frontal_dt = {}
                for indnum in [0, 1]:
                    # --- vector strength ---
                    vec_strength = np.array(
                        self.adt[flytype][indnum]['vec_strength'])
                    # index 1 is not drawn for 'tb_flies'
                    if not (indnum == 1 and flytype == 'tb_flies'):
                        self.make_hist_cumulative(
                            axvec[indnum], vec_strength, horiz_pos, vert_pos,
                            vert_offset, crcol, [0, 1])

                    # --- heading ---
                    indt = np.array(self.adt[flytype][indnum]['mnrad'])
                    thresh_inds = np.where(
                        vec_strength > VEC_STRENGTH_THRESH)[0]
                    low_vec_inds = np.where(
                        vec_strength < VEC_STRENGTH_THRESH)[0]

                    # fold above-threshold headings onto [0, pi]
                    inds = np.where(indt[thresh_inds] > np.pi)[0]
                    dir_dt[indnum] = indt[thresh_inds]
                    dir_dt[indnum][inds] = 2 * np.pi - indt[thresh_inds][inds]

                    # unfolded headings with low-strength entries set to NaN
                    dt_to_add = np.copy(indt)
                    dt_to_add[low_vec_inds] = np.nan
                    all_dt_dict[indnum] = dt_to_add

                    # folded headings for every entry, regardless of strength
                    dir_dt_all[indnum] = np.copy(indt)
                    alldt_inds = np.where(indt > np.pi)[0]
                    dir_dt_all[indnum][alldt_inds] = (
                        2 * np.pi - indt[alldt_inds])

                    self.make_hist_cumulative(
                        axhist[indnum], dir_dt[indnum], horiz_pos, vert_pos,
                        vert_offset, crcol, [0, np.pi])

                    # --- vector strength of entries whose heading lies
                    # between FRONTAL_THRESH and 2*pi - FRONTAL_THRESH ---
                    nonfrontal_inds_high = np.where(indt > FRONTAL_THRESH)[0]
                    nonfrontal_inds_low = np.where(
                        indt < (2 * np.pi - FRONTAL_THRESH))[0]
                    nonfrontal_inds = np.intersect1d(nonfrontal_inds_low,
                                                     nonfrontal_inds_high)
                    nonfrontal_vec = vec_strength[nonfrontal_inds]
                    frontal_dt[indnum] = np.reshape(nonfrontal_vec,
                                                    len(nonfrontal_vec))
                    self.make_hist_cumulative(
                        axvec_nonfrontal[indnum], nonfrontal_vec, horiz_pos,
                        vert_pos, vert_offset, crcol, [0, 1])

                # --- data indices 0 and 1 combined, drawn on index 2 ---
                indnum = 2
                vec_first = np.array(self.adt[flytype][0]['vec_strength'])
                vec_second = np.array(self.adt[flytype][1]['vec_strength'])
                comb_vec = np.concatenate((vec_first, vec_second), axis=0)
                # axis=1 requires the stored arrays to be two-dimensional
                crdt_in_cols = np.concatenate((vec_first, vec_second), axis=1)
                self.make_hist_cumulative(
                    axvec[indnum], comb_vec, horiz_pos, vert_pos, vert_offset,
                    crcol, [0, 1])
                concat_dt_vec[flytype] = np.reshape(comb_vec, len(comb_vec))
                dt_vec_in_cols[flytype] = crdt_in_cols

                front_dt_combined = np.concatenate(
                    (frontal_dt[0], frontal_dt[1]))
                concat_dt_vec_front[flytype] = front_dt_combined
                self.make_hist_cumulative(
                    axvec_nonfrontal[indnum], front_dt_combined, horiz_pos,
                    vert_pos, vert_offset, crcol, [0, 1])

                comb_head = np.concatenate((dir_dt[0], dir_dt[1]))
                self.comb_head_dt[flytype] = np.concatenate(
                    (all_dt_dict[0], all_dt_dict[1]))
                concat_dt_head[flytype] = comb_head

                # per-column headings with low-strength entries set to NaN
                cr_hd_dt_in_cols = np.concatenate(
                    (dir_dt_all[0], dir_dt_all[1]), axis=1)
                for colind in [0, 1]:
                    low_vec_strength_inds = np.where(
                        crdt_in_cols[:, colind] < VEC_STRENGTH_THRESH)[0]
                    cr_hd_dt_in_cols[low_vec_strength_inds, colind] = np.nan
                self.concat_dt_head_in_cols[flytype] = cr_hd_dt_in_cols

                axhist[indnum].text(0.2, fly_type_ind * .1,
                                    str(len(comb_head)), color=crcol)
                axhist[indnum].plot(np.mean(comb_head), 0.1, 'v', color=crcol)
                twplt.plot_hist(axhist[indnum], comb_head,
                                hst_bnds=[0, np.pi], plt_type='cumulative',
                                num_bins=100, col=crcol)

            for stim_type in ['stripe']:
                stripe_vec = np.array(self.adt[flytype][2]['vec_strength'])
                try:
                    axvec[3].text(
                        0.2, fly_type_ind * .1,
                        str(np.count_nonzero(~np.isnan(stripe_vec))),
                        color=crcol)
                    twplt.plot_hist(axvec[3], stripe_vec, hst_bnds=[0, 1],
                                    num_bins=100, plt_type='cumulative',
                                    col=crcol)

                    stripe_head = np.array(self.adt[flytype][2]['mnrad'])
                    thresh_inds = np.where(
                        stripe_vec > VEC_STRENGTH_THRESH)[0]

                    # fold onto [0, pi] in place, keeping values already
                    # at or below pi
                    pltdt = stripe_head[thresh_inds]
                    inds = np.where(pltdt > np.pi)[0]
                    pltdt[inds] = 2 * np.pi - pltdt[inds]

                    axhist[3].text(0.2, 0.2 + fly_type_ind * .2,
                                   str(len(thresh_inds)), color=crcol)
                    twplt.plot_hist(axhist[3], pltdt, hst_bnds=[0, np.pi],
                                    plt_type='cumulative', num_bins=100,
                                    col=crcol)
                except Exception:
                    print('stripe error')

        # Permutation tests between the first two fly types.
        # NOTE(review): the returned values are not used further in this
        # method; the calls are kept as they were -- confirm whether the
        # results should be stored or displayed.
        try:
            if PERMUTE_BY_INDIVIDUALS_NOT_FLIGHTS:
                sd, pctile_head = calc.permute_test_for_mean_diff_sampling_from_rows(
                    self.concat_dt_head_in_cols[FLY_TYPE_TO_PLOT[0]],
                    self.concat_dt_head_in_cols[FLY_TYPE_TO_PLOT[1]],
                    num_permutations=1000)
            else:
                sd, pctile_front = calc.permute_test_for_mean_diff_between_two_groups(
                    concat_dt_vec_front[FLY_TYPE_TO_PLOT[0]],
                    concat_dt_vec_front[FLY_TYPE_TO_PLOT[1]],
                    num_permutations=1000)
                sd, pctile_head = calc.permute_test_for_mean_diff_between_two_groups(
                    concat_dt_head[FLY_TYPE_TO_PLOT[0]],
                    concat_dt_head[FLY_TYPE_TO_PLOT[1]],
                    num_permutations=1000)
        except Exception:
            print ('permutation error')

        for ind in [0, 1, 2, 3]:
            fpl.adjust_spines(axvec[ind], ['left', 'bottom'])
            fpl.adjust_spines(axhist[ind], ['left', 'bottom'])
            fpl.adjust_spines(axvec_nonfrontal[ind], ['left', 'bottom'])

            for crax in [axvec[ind], axvec_nonfrontal[ind]]:
                crax.set_xlim([0, 1.0])
                crax.set_ylim([0, 1])
                crax.get_xaxis().set_ticks([0, 0.25, 0.5, 0.75, 1.0])
                # one label per tick (the last one used to be missing)
                crax.get_xaxis().set_ticklabels(
                    ['0', '0.25', '0.5', '0.75', '1'], fontsize=6)
                crax.get_yaxis().set_ticks([0, 0.5, 1])
                crax.get_yaxis().set_ticklabels(['0', '0.5', '1'], fontsize=6)
                crax.set_xlabel('vector strength', fontsize=6)
                crax.set_ylabel('probability', fontsize=6)
                crax.set_aspect(1)

            axhist[ind].set_xlim([0, np.pi])
            axhist[ind].set_ylim([0, 1])
            axhist[ind].get_xaxis().set_ticks([0, np.pi / 2, np.pi])
            axhist[ind].get_xaxis().set_ticklabels(['0', '90', '180'],
                                                   fontsize=6)
            axhist[ind].get_yaxis().set_ticks([0, 0.5, 1])
            axhist[ind].get_yaxis().set_ticklabels(['0', '0.5', '1'],
                                                   fontsize=6)
            axhist[ind].set_xlabel('mean heading (deg)', fontsize=6)
            axhist[ind].set_ylabel('probability', fontsize=6)
예제 #15
0
    def plot_correlation(self, ax):
        """Plot permutation histograms of heading differences per fly type.

        For each fly type, headings and vector strengths are read for the
        data indices in ``CORR_PLOT_COLS`` (the first two are compared). When
        ``PERMUTE_FLAG`` is set, the permuted distribution of heading
        differences is drawn on ``ax[flytype]['corr_hist']`` twice: once for
        entries where both vector strengths exceed ``VEC_STRENGTH_THRESH``
        (cyan) and once for all entries (black), each with a marker at the
        observed mean difference and its percentile as text. The
        ``ax[flytype]['corr_scatter']`` axis only receives axis formatting and
        two diagonal reference lines.

        Fixes relative to the earlier version:

        * headings are copied before being centred on zero, so ``self.adt``
          is never modified through a reshape view;
        * ``crtype is 'thresh'`` (identity comparison) is gone;
        * the percentile text is only drawn when the percentile was computed;
        * x ticks are set before their labels;
        * bare ``except:`` clauses are now ``except Exception:``.
        """
        for flytype in FLY_TYPE_TO_PLOT:
            rad_list = []
            vec_list = []
            crax = ax[flytype]['corr_scatter']
            craxhist = ax[flytype]['corr_hist']

            for ind in CORR_PLOT_COLS:
                cr_rad = self.adt[flytype][ind]['mnrad']
                cr_vec = self.adt[flytype][ind]['vec_strength']

                plt_vec = np.reshape(cr_vec, len(cr_vec))
                # np.array() copies, so the centring below stays local
                plt_rad = np.array(cr_rad).reshape(len(cr_rad))
                if center_on_zero_flag:
                    highinds = np.where(plt_rad > np.pi)[0]
                    plt_rad[highinds] = -(2 * np.pi - plt_rad[highinds])
                rad_list.append(plt_rad)
                vec_list.append(plt_vec)

            # entries whose vector strength passes threshold in both columns
            threshinds = np.intersect1d(
                np.where(vec_list[0] > VEC_STRENGTH_THRESH)[0],
                np.where(vec_list[1] > VEC_STRENGTH_THRESH)[0])

            if PERMUTE_FLAG:
                # (indices, colour, text height, marker height)
                subsets = (
                    (threshinds, 'c', .02, .01),
                    (np.arange(len(rad_list[0])), 'k', .04, 0),
                )
                for crinds, crcol, ht, markht in subsets:
                    try:
                        return_dict = calc.permute_diffs(
                            np.column_stack((rad_list[0][crinds],
                                             rad_list[1][crinds])),
                            max_diff=np.pi)
                        actual_diff = calc.calc_heading_diff(
                            np.column_stack((rad_list[0][crinds],
                                             rad_list[1][crinds])),
                            calc_abs_diff=True, max_diff=np.pi)

                        self.plot_permute_hist(craxhist,
                                               return_dict['permuted_dist'],
                                               color=crcol)
                        craxhist.plot(np.nanmean(actual_diff), markht,
                                      marker='v', color=crcol, clip_on=False)

                        try:
                            prctile = stats.percentileofscore(
                                return_dict['permuted_dist'],
                                np.nanmean(actual_diff))
                        except Exception:
                            print('percentile error')
                        else:
                            craxhist.text(np.pi / 6, ht,
                                          format(prctile / 100., '.2f'),
                                          fontsize=6, color=crcol)
                    except Exception:
                        print('hist plot error')

            fpl.adjust_spines(crax, ['left', 'bottom'])
            crax.get_xaxis().set_ticks([-np.pi, 0, np.pi])
            crax.get_xaxis().set_ticklabels(['-180', '0', '180'], fontsize=6)
            crax.get_yaxis().set_ticks(
                [-np.pi, 0, np.pi, 2 * np.pi, 3 * np.pi])
            crax.get_yaxis().set_ticklabels(
                ['-180', '0', '180', '360', '180'], fontsize=6)
            crax.set_xlabel('first heading', fontsize=6)
            crax.set_ylabel('second heading', fontsize=6)
            crax.set_xlim([-np.pi - .5, np.pi + .5])
            crax.set_ylim([-np.pi - .5, 3 * np.pi + .5])
            # identity line and the same line shifted up by one full turn
            crax.plot([-np.pi, np.pi], [-np.pi, np.pi], 'b', linewidth=0.5)
            crax.plot([-np.pi, np.pi], [np.pi, 3 * np.pi], 'b', linewidth=0.5)
            crax.set_aspect(1)
def main():
    """Average fluorescence traces per animal and plot the summaries.

    For every file in ``summary_data_files_to_load`` the 'on' and 'off'
    target traces of all experiments are averaged (raw, and normalised by the
    experiment's ``max_value``), plotted against depth on ``axnorm``, and
    reduced to one number per animal via
    ``calc_mean_for_animal_over_depth_range``. The per-animal numbers are then
    handed to ``plot_sum_for_each_animal`` on ``axsum``.

    Changes from the earlier version:

    * the interactive ``pdb.set_trace()`` breakpoints are removed; a failure
      while averaging the raw traces now propagates instead of dropping into
      the debugger;
    * traces longer than the first experiment's are truncated *before* being
      normalised, so the raw and normalised stacks keep matching lengths;
    * ``plot_raw`` runs once per file after both targets are assembled;
      inside the target loop it drew the 'on' trace twice;
    * the fallback from 'fvl' to 'fvl_raw' only triggers on ``KeyError``.
    """
    mn_by_animal = {}
    for crfile in summary_data_files_to_load:
        dt = fh.open_pickle(summary_data_location + crfile)
        delta_f = dt['delta_f']

        fvls_array = {}

        for target_key in ['on', 'off']:
            mn_by_animal.setdefault(target_key, [])
            fvls_sum = {'raw': [], 'norm': []}
            fvls_array[target_key] = {}

            for expind, exp_key in enumerate(delta_f.keys()):
                exp_data = delta_f[exp_key][target_key]
                try:
                    fvls_to_add = exp_data['fvl']
                except KeyError:
                    fvls_to_add = exp_data['fvl_raw']

                # clip to the length of the first experiment
                if expind > 0:
                    crlen = len(fvls_sum['raw'][0])
                    if len(fvls_to_add) > crlen:
                        fvls_to_add = fvls_to_add[0:crlen]

                fvls_sum['norm'].append(
                    fvls_to_add / dt['max_value'][exp_key])
                fvls_sum['raw'].append(fvls_to_add)

            # average across experiments
            fvls_array[target_key]['fvl_norm'] = np.mean(
                np.array(fvls_sum['norm']), axis=0)
            fvls_array[target_key]['fvl_raw'] = np.mean(
                np.array(fvls_sum['raw']), axis=0)

            # NOTE(review): depth is always read from key 0 of delta_f,
            # not from the experiments iterated above -- confirm intended.
            tmp_depth = delta_f[0][target_key]['depth']
            if ROUND_TO_NEAREST:
                fvls_array[target_key]['depth'] = myround(
                    np.array(tmp_depth), base=ROUND_TO_NEAREST)
            else:
                fvls_array[target_key]['depth'] = np.array(tmp_depth)

            mn_by_animal[target_key].append(
                calc_mean_for_animal_over_depth_range(fvls_array[target_key],
                                                      target_key))

        # raw (not normalised) values, joined by lines
        plt_util.plot_raw(axnorm, fvls_array, False, line_flag=True)
        fpl.adjust_spines(axnorm, ['left', 'bottom'])
        axnorm.set_ylim(-.25, 0.75)
        axnorm.set_yticks([-.25, 0, 0.25, 0.5, 0.75])
        axnorm.set_xlim(-62, 42)
    plot_sum_for_each_animal(axsum, mn_by_animal)