def plot_trajectory(trajec, block_index=3):
    """Plot the x/y path around the first saccade following one odor block.

    Contiguous runs of frames where ``trajec.odor`` exceeds a fixed threshold
    (10) are found with ``hf.find_continuous_blocks``. The run at
    ``block_index`` is selected, and the first saccade whose start frame comes
    after the run's first frame is located. The x/y positions from 20 frames
    before that saccade's onset up to (but excluding) 20 frames after it,
    clipped to the trajectory, are plotted. The first plotted frame is marked
    with a green dot. Debug values are printed in degrees: the pi-rotated
    heading at onset, the heading at the saccade's last frame, and the
    saccade angle.

    Args:
        trajec: trajectory object. This reads ``odor``, ``saccades``,
            ``heading_smooth``, ``positions`` and ``length``.
        block_index: index of the odor block to inspect. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        The saccade that was plotted (an element of ``trajec.saccades``), or
        None if the selected block is shorter than 5 frames or no saccade
        starts after it.

    Raises:
        IndexError: if there are not more than ``block_index`` odor blocks.
    """
    threshold_odor = 10

    def _say(*parts):
        # Reproduces the output of the Python 2 statement ``print a, b``
        # (items joined by a single space) on both Python 2 and 3.
        print(' '.join(str(p) for p in parts))

    fig = plt.figure()
    ax = fig.add_subplot(111)

    frames_in_odor = np.where(trajec.odor > threshold_odor)[0]
    odor_blocks = hf.find_continuous_blocks(frames_in_odor, 5, return_longest_only=False)
    block = odor_blocks[block_index]

    next_sac = None
    if len(block) >= 5:
        # First saccade that starts after the odor block begins.
        next_sac = next((sac for sac in trajec.saccades if sac[0] > block[0]), None)

    if next_sac is not None:
        angle_of_saccade = tac.get_angle_of_saccade(trajec, next_sac)
        heading_prior_to_saccade = trajec.heading_smooth[next_sac[0]]
        # Rotate the heading by pi. For inputs in [-pi, pi] the result stays in
        # that range (assumes heading_smooth is wrapped to [-pi, pi]; confirm).
        if heading_prior_to_saccade < 0:
            heading_prior_to_saccade += np.pi
        else:
            heading_prior_to_saccade -= np.pi

        frame0 = max(next_sac[0] - 20, 0)
        frame1 = min(next_sac[0] + 20, trajec.length - 1)
        frames = np.arange(frame0, frame1)

        ax.plot(trajec.positions[frames, 0], trajec.positions[frames, 1])
        ax.plot(trajec.positions[frames[0], 0], trajec.positions[frames[0], 1], '.', color='green')

        _say('raw heading prior: ', heading_prior_to_saccade * 180 / np.pi)
        _say('raw heading after: ', trajec.heading_smooth[next_sac[-1]] * 180 / np.pi)
        _say('raw angle of sac: ', angle_of_saccade * 180 / np.pi)

    ax.set_aspect('equal')
    return next_sac
def plot_odor_heading_book(pp, threshold_odor, path, config, dataset, keys=None):
    """Plot saccade angle against heading for the first saccade after mid-odor-block.

    For each trajectory, contiguous runs of frames with
    ``trajec.odor > threshold_odor`` are found. For each run, the first
    saccade starting after the run's mean frame index is taken. Both the
    heading at saccade onset and the saccade angle are rotated by pi. Two
    pages are written to ``pp``:

    1. a scatter of saccade angle against heading (degrees);
    2. a normalized histogram of saccade angles.

    Args:
        pp: object whose ``savefig()`` receives each page (presumably a
            matplotlib PdfPages; confirm).
        threshold_odor: odor value above which a frame counts as in odor.
        path, config: unused; kept for interface compatibility.
        dataset: object with a ``trajecs`` mapping indexed by key.
        keys: trajectory keys to include. Defaults to every key of
            ``dataset.trajecs`` (previously ``None`` raised TypeError).

    Raises:
        ValueError: if there are no keys to plot. The title is taken from a
            trajectory's ``odor_stimulus``, so at least one is required.
    """
    keys = list(dataset.trajecs.keys()) if keys is None else list(keys)
    if not keys:
        raise ValueError('plot_odor_heading_book: no trajectory keys to plot')

    def _rotate_by_pi(angle):
        # Shift by pi. Values in [-pi, pi] stay in [-pi, pi].
        return angle + np.pi if angle < 0 else angle - np.pi

    saccade_angles_after_odor = []
    heading_at_saccade_initiation = []
    for key in keys:
        trajec = dataset.trajecs[key]
        frames_in_odor = np.where(trajec.odor > threshold_odor)[0]
        odor_blocks = hf.find_continuous_blocks(frames_in_odor, 5, return_longest_only=False)

        for block in odor_blocks:
            middle_of_block = int(np.mean(block))
            next_sac = next((sac for sac in trajec.saccades if sac[0] > middle_of_block), None)
            if next_sac is None:
                continue
            angle_of_saccade = _rotate_by_pi(tac.get_angle_of_saccade(trajec, next_sac))
            heading_prior_to_saccade = _rotate_by_pi(trajec.heading_smooth[next_sac[0]])
            saccade_angles_after_odor.append(angle_of_saccade)
            heading_at_saccade_initiation.append(heading_prior_to_saccade)

    saccade_angles_after_odor = np.array(saccade_angles_after_odor)
    heading_at_saccade_initiation = np.array(heading_at_saccade_initiation)

    # Page 1: saccade angle vs. heading at saccade onset.
    fig = plt.figure(figsize=(4,4))
    ax = fig.add_subplot(111)
    ax.plot(heading_at_saccade_initiation*180./np.pi, saccade_angles_after_odor*180./np.pi, '.')

    xticks = [-180, -90, 0, 90, 180]
    yticks = [-180, -90, 0, 90, 180]
    fpl.adjust_spines(ax, ['left', 'bottom'], xticks=xticks, yticks=yticks)
    ax.set_xlabel('Heading before saccade')
    ax.set_ylabel('Angle of saccade')

    # Title comes from the last trajectory processed.
    title_text = 'Odor: ' + trajec.odor_stimulus.title()
    ax.set_title(title_text)

    ax.text(0,-180, 'Upwind', horizontalalignment='center', verticalalignment='top')
    ax.text(90,-180, 'Starboard', horizontalalignment='center', verticalalignment='top')
    ax.text(-90,-180, 'Port', horizontalalignment='center', verticalalignment='top')

    ax.text(-180,90, 'Starboard', horizontalalignment='left', verticalalignment='center', rotation='vertical')
    ax.text(-180,-90, 'Port', horizontalalignment='left', verticalalignment='center', rotation='vertical')

    pp.savefig()
    plt.close('all')

    # Page 2: histogram of saccade angles.
    fig = plt.figure()
    ax = fig.add_subplot(111)

    fpl.histogram_stack(ax, [saccade_angles_after_odor*180./np.pi], bins=20, bin_width_ratio=0.9, colors=['red'], edgecolor='none', normed=True)

    ax.set_xlabel('Angle of Saccade')
    ax.set_ylabel('Occurences, normalized')
    xticks = [-180, -90, 0, 90, 180]
    fpl.adjust_spines(ax, ['left', 'bottom'], xticks=xticks)

    ax.set_title(title_text)

    pp.savefig()
    plt.close('all')
def plot_odor_heading_book(pp, threshold_odor, path, config, dataset, odor_stimulus, keys=None, axis='xy'):
    """Scatter saccade angle against onset heading for the first saccade after each odor block.

    For every trajectory, contiguous runs of frames with
    ``trajec.odor > threshold_odor`` are found. For each run of at least
    5 frames, the first saccade is selected that

    - starts after the run's first frame,
    - has its onset position inside a fixed x/y/z window, and
    - starts no more than 0.5 (in ``trajec.time_fly`` units) after the run's
      last frame.

    The saccade angle is plotted against the heading at onset (degrees),
    with a red ``y = -x`` reference line, and the page is written to ``pp``.

    Args:
        pp: object whose ``savefig()`` receives the page (presumably a
            matplotlib PdfPages; confirm).
        threshold_odor: odor value above which a frame counts as in odor.
        path, config: unused; kept for interface compatibility.
        dataset: object with a ``trajecs`` mapping indexed by key.
        odor_stimulus: label string used in the title and in the debug print.
        keys: trajectory keys to include. Defaults to every key of
            ``dataset.trajecs`` (previously ``None`` raised TypeError).
        axis: ``'xy'`` uses ``saccades``, ``heading_smooth`` and
            ``tac.get_angle_of_saccade``. ``'altitude'`` uses ``saccades_z``,
            ``heading_altitude_smooth`` and ``tac.get_angle_of_saccade_z``.

    Raises:
        ValueError: if ``axis`` is not ``'xy'`` or ``'altitude'`` (previously
            this surfaced as a NameError), or if there are no keys to plot.
    """
    if axis not in ('xy', 'altitude'):
        raise ValueError("axis must be 'xy' or 'altitude', got %r" % (axis,))
    keys = list(dataset.trajecs.keys()) if keys is None else list(keys)
    if not keys:
        raise ValueError('plot_odor_heading_book: no trajectory keys to plot')

    if axis == 'xy':
        saccade_attr, heading_attr = 'saccades', 'heading_smooth'
    else:
        saccade_attr, heading_attr = 'saccades_z', 'heading_altitude_smooth'

    max_delay_after_block = 0.5  # max time_fly gap between block end and saccade onset

    def _onset_in_window(trajec, frame):
        # Spatial window (units of trajec.positions) for saccade onsets. It is
        # written as the negation of the original exclusion tests so NaN
        # coordinates are still kept.
        x = trajec.positions[frame, 0]
        y = trajec.positions[frame, 1]
        z = trajec.positions[frame, 2]
        return not (x < -0.1 or x > 0.9 or np.abs(y) > 0.05 or z > 0.05 or z < -0.01)

    def _angle_of(trajec, sac):
        if axis == 'xy':
            return tac.get_angle_of_saccade(trajec, sac)
        return tac.get_angle_of_saccade_z(trajec, sac)

    saccade_angles_after_odor = []
    heading_at_saccade_initiation = []
    for key in keys:
        trajec = dataset.trajecs[key]
        frames_in_odor = np.where(trajec.odor > threshold_odor)[0]
        odor_blocks = hf.find_continuous_blocks(frames_in_odor, 5, return_longest_only=False)

        for block in odor_blocks:
            if len(block) < 5:
                continue
            saccades = getattr(trajec, saccade_attr)
            next_sac = next((sac for sac in saccades
                             if _onset_in_window(trajec, sac[0]) and sac[0] > block[0]), None)
            if next_sac is None:
                continue
            # Ignore saccades that begin too long after the odor block ended.
            if trajec.time_fly[next_sac[0]] - trajec.time_fly[block[-1]] > max_delay_after_block:
                continue
            saccade_angles_after_odor.append(_angle_of(trajec, next_sac))
            heading_at_saccade_initiation.append(getattr(trajec, heading_attr)[next_sac[0]])

    saccade_angles_after_odor = np.array(saccade_angles_after_odor)
    heading_at_saccade_initiation = np.array(heading_at_saccade_initiation)

    # Same text the Python 2 statement ``print a, b`` produced; valid on Python 2 and 3.
    print('%s %s' % (odor_stimulus, saccade_angles_after_odor.shape))

    fig = plt.figure(figsize=(4,4))
    ax = fig.add_subplot(111)
    fpl.scatter(ax, heading_at_saccade_initiation*180./np.pi, saccade_angles_after_odor*180./np.pi, color='black', radius=3, colornorm=[0,5])

    # Reference line y = -x, drawn behind the data.
    xpts = np.linspace(-180,180, 100)
    ax.plot(xpts, -1*xpts, color='red', zorder=-10)

    xticks = [-180, -90, 0, 90, 180]
    yticks = [-180, -90, 0, 90, 180]
    fpl.adjust_spines(ax, ['left', 'bottom'], xticks=xticks, yticks=yticks)
    ax.set_xlabel('Heading before saccade')
    ax.set_ylabel('Angle of saccade')

    # Visual stimulus label comes from the last trajectory processed.
    title_text = 'Odor: ' + odor_stimulus + ' Visual Stim: ' + trajec.visual_stimulus
    ax.set_title(title_text)

    ax.text(0,-180, 'Upwind', horizontalalignment='center', verticalalignment='top')
    ax.text(90,-180, 'Starboard', horizontalalignment='center', verticalalignment='top')
    ax.text(-90,-180, 'Port', horizontalalignment='center', verticalalignment='top')

    ax.text(-180,90, 'Starboard', horizontalalignment='left', verticalalignment='center', rotation='vertical')
    ax.text(-180,-90, 'Port', horizontalalignment='left', verticalalignment='center', rotation='vertical')

    pp.savefig()
    plt.close('all')
def plot_odor_heading_book(pp, threshold_odor, path, config, dataset, keys=None):
    """Scatter saccade angle against pi-rotated onset heading for saccades after odor onset.

    For every trajectory, contiguous runs of frames with
    ``trajec.odor > threshold_odor`` are found. For each run of at least
    5 frames, the first saccade is taken that starts after the run's first
    frame and whose onset position lies in a fixed x/y/z window. The heading
    at onset is rotated by pi, and the saccade angle is plotted against it
    (degrees). The page is written to ``pp``.

    Several functions named ``plot_odor_heading_book`` are defined in this
    module. At import time, the last definition of the name is the one bound.

    Args:
        pp: object whose ``savefig()`` receives the page (presumably a
            matplotlib PdfPages; confirm).
        threshold_odor: odor value above which a frame counts as in odor.
        path, config: unused; kept for interface compatibility.
        dataset: object with a ``trajecs`` mapping indexed by key.
        keys: trajectory keys to include. Defaults to every key of
            ``dataset.trajecs`` (previously ``None`` raised TypeError).

    Raises:
        ValueError: if there are no keys to plot. The title is taken from a
            trajectory's ``odor_stimulus``, so at least one is required.
    """
    keys = list(dataset.trajecs.keys()) if keys is None else list(keys)
    if not keys:
        raise ValueError('plot_odor_heading_book: no trajectory keys to plot')

    def _onset_in_window(trajec, frame):
        # Spatial window (units of trajec.positions) for saccade onsets. It is
        # written as the negation of the original exclusion tests so NaN
        # coordinates are still kept.
        x = trajec.positions[frame, 0]
        y = trajec.positions[frame, 1]
        z = trajec.positions[frame, 2]
        return not (x < -0.15 or x > 0.9 or np.abs(y) > 0.08 or np.abs(z) > 0.08)

    def _rotate_by_pi(angle):
        # Shift by pi. Values in [-pi, pi] stay in [-pi, pi].
        return angle + np.pi if angle < 0 else angle - np.pi

    # These accumulators were previously never initialized, so the function
    # always failed with NameError.
    saccade_angles_after_odor = []
    heading_at_saccade_initiation = []
    for key in keys:
        trajec = dataset.trajecs[key]
        frames_in_odor = np.where(trajec.odor > threshold_odor)[0]
        odor_blocks = hf.find_continuous_blocks(frames_in_odor, 5, return_longest_only=False)

        for block in odor_blocks:
            if len(block) < 5:
                continue
            next_sac = next((sac for sac in trajec.saccades
                             if _onset_in_window(trajec, sac[0]) and sac[0] > block[0]), None)
            if next_sac is None:
                continue
            angle_of_saccade = tac.get_angle_of_saccade(trajec, next_sac)
            heading_prior_to_saccade = _rotate_by_pi(trajec.heading_smooth[next_sac[0]])
            saccade_angles_after_odor.append(angle_of_saccade)
            heading_at_saccade_initiation.append(heading_prior_to_saccade)

    saccade_angles_after_odor = np.array(saccade_angles_after_odor)
    heading_at_saccade_initiation = np.array(heading_at_saccade_initiation)

    fig = plt.figure(figsize=(4,4))
    ax = fig.add_subplot(111)
    ax.plot(heading_at_saccade_initiation*180./np.pi, saccade_angles_after_odor*180./np.pi, '.', markersize=4)

    xticks = [-180, -90, 0, 90, 180]
    yticks = [-180, -90, 0, 90, 180]
    fpl.adjust_spines(ax, ['left', 'bottom'], xticks=xticks, yticks=yticks)
    ax.set_xlabel('Heading before saccade')
    ax.set_ylabel('Angle of saccade')

    # Title comes from the last trajectory processed.
    title_text = 'Odor: ' + trajec.odor_stimulus.title()
    ax.set_title(title_text)

    ax.text(0,-180, 'Upwind', horizontalalignment='center', verticalalignment='top')
    ax.text(90,-180, 'Starboard', horizontalalignment='center', verticalalignment='top')
    ax.text(-90,-180, 'Port', horizontalalignment='center', verticalalignment='top')

    ax.text(-180,90, 'Starboard', horizontalalignment='left', verticalalignment='center', rotation='vertical')
    ax.text(-180,-90, 'Port', horizontalalignment='left', verticalalignment='center', rotation='vertical')

    pp.savefig()
    plt.close('all')