def plot_trajectory(dataset, config, n_to_plot=50):
    """Write top-down (x-y) body-orientation trajectory plots to a multi-page PDF.

    One page is produced per trajectory, for at most ``n_to_plot``
    trajectories taken from ``get_keys_with_orientation_and_odor(dataset)``.
    Trajectories with no ``frames_with_orientation`` attribute (or a None
    value), or with fewer than 5 such frames, are skipped. The output file is
    ``<config.path>/<config.figure_path>/odor_traces/body_orientation_trajectories.pdf``.

    Parameters
    ----------
    dataset : object with a ``trajecs`` mapping of key -> trajectory
    config : object providing ``path`` and ``figure_path``
    n_to_plot : int, optional
        Maximum number of trajectories (pages) to plot. Defaults to 50.
    """
    keys = get_keys_with_orientation_and_odor(dataset, keys=None)

    figure_path = os.path.join(config.path, config.figure_path)
    save_figure_path = os.path.join(figure_path, 'odor_traces/')
    pdf_name_with_path = os.path.join(save_figure_path, 'body_orientation_trajectories.pdf')
    pp = PdfPages(pdf_name_with_path)

    # Number of lead-in / lead-out frames drawn (as a thin black line) around
    # the segment that has orientation data.
    context_frames = 10

    n_plotted = 0
    try:
        for key in keys:
            trajec = dataset.trajecs[key]

            # Only a missing attribute means "no orientation data"; other
            # errors should surface rather than be silently skipped.
            frames = getattr(trajec, 'frames_with_orientation', None)
            if frames is None or len(frames) < 5:
                continue

            if n_plotted >= n_to_plot:
                break
            n_plotted += 1
            print(key)

            fig = plt.figure(figsize=(4, 4))
            ax = fig.add_subplot(111)
            ax.set_xlim(-.1, .3)
            ax.set_ylim(-.15, .15)
            ax.set_aspect('equal')
            ax.set_title(key.replace('_', '-'))

            # Clamp the start index: a negative slice start would wrap around
            # to the end of the array and draw the wrong (or an empty) path.
            # Slicing past the end is harmless, so the stop needs no clamp.
            start = max(frames[0] - context_frames, 0)
            stop = frames[-1] + context_frames
            ax.plot(trajec.positions[start:stop, 0], trajec.positions[start:stop, 1],
                    'black', zorder=-100, linewidth=0.25)

            fpl.colorline_with_heading(ax, trajec.positions[frames, 0], trajec.positions[frames, 1], trajec.odor[frames], orientation=trajec.orientation, colormap='jet', alpha=1, colornorm=[0, .8], size_radius=0.15-np.abs(trajec.positions[frames, 2]), size_radius_range=[.02, .02], deg=False, nskip=0, center_point_size=0.01, flip=False)

            pp.savefig(fig)
            plt.close('all')
    finally:
        # Always finalize the PDF, even if plotting one trajectory fails.
        pp.close()
def plot_odor_trace_on_ax(path, config, dataset, keys=None, axis='xy', show_saccades=False, frames_to_show_before_odor='all', frames_to_show_after_odor='all', ax=None, odor_multiplier=1, show_post=True, save=False, frameranges=None):
    """Draw odor-coloured, heading-annotated trajectories (and the post) on an axis.

    For every key, the whole trajectory (frames ``0 .. trajec.length - 1``) is
    projected onto the plane named by ``axis`` and drawn with
    ``fpl.colorline_with_heading``. It is coloured by
    ``trajec.odor * odor_multiplier`` on a fixed (0, 100) colour scale.
    ``tac.calc_heading_for_axes`` is called on each trajectory first, and the
    resulting ``heading_smooth_<axis>`` attribute supplies the orientations.

    Parameters
    ----------
    path : unused; kept for call compatibility.
    config : provides ``post_center``, ``post_radius`` and ``ticks['z']``.
    dataset : object with a ``trajecs`` mapping of key -> trajectory.
    keys : iterable of trajectory keys; defaults to all of ``dataset.trajecs``.
    axis : {'xy', 'xz', 'yz'}
        Projection plane. Any other value raises ``ValueError``.
    show_saccades : bool
        If True, mark with a red dot the middle frame of each entry of
        ``trajec.saccades`` whose first and last frames are in the plotted range.
    frames_to_show_before_odor, frames_to_show_after_odor, frameranges :
        Currently ignored. The frame-range selection that used them was
        disabled, so the whole trajectory is always drawn.
    ax : matplotlib axes to draw on; a new figure and axes are created if None.
    odor_multiplier : scale factor applied to the odor values before colouring.
    show_post : bool
        If True, draw the post: a circle for 'xy', a rectangle otherwise.
    save : bool
        If True, set fixed axis limits and save the axes' figure to
        ``odor_traces.pdf`` (relative to the current working directory).
    """
    if keys is None:
        keys = dataset.trajecs.keys()

    # Validate early. An unknown axis previously surfaced later as a NameError
    # on ``axes`` / ``post``.
    if axis not in ('xy', 'xz', 'yz'):
        raise ValueError("axis must be 'xy', 'xz' or 'yz', got %r" % (axis,))

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    colormap = 'jet'
    alpha = 1
    zorder = 1
    norm = (0, 100)
    show_start = False  # hard-coded switch: mark each trajectory's first frame
    color_attribute = 'odor'

    post_x = config.post_center[0]
    post_y = config.post_center[1]
    z_floor = config.ticks['z'][0]
    height = config.post_center[2] - z_floor

    # ``axes``: the two position columns that are plotted.
    # ``depth``: the remaining column, used to scale the heading markers.
    if axis == 'xy':
        axes = [0, 1]
        depth = 2
        post = patches.Circle(config.post_center[0:2], config.post_radius, color='black')
    elif axis == 'xz':
        axes = [0, 2]
        depth = 1
        post = patches.Rectangle([post_x - config.post_radius, z_floor], config.post_radius*2, height, color='black')
    else:  # 'yz'
        axes = [1, 2]
        depth = 0
        post = patches.Rectangle([post_y - config.post_radius, z_floor], config.post_radius*2, height, color='black')

    ##########################
    # Plot trajectory
    ##########################

    for key in keys:
        trajec = dataset.trajecs[key]
        c = getattr(trajec, color_attribute)
        frames = np.arange(0, trajec.length)

        # Sets trajec.heading_smooth_<axis>, which is read back just below.
        tac.calc_heading_for_axes(trajec, axis=axis)
        orientation = getattr(trajec, 'heading_smooth_' + axis)

        fpl.colorline_with_heading(ax, trajec.positions[frames, axes[0]], trajec.positions[frames, axes[1]], c[frames]*odor_multiplier, orientation=orientation[frames], colormap=colormap, alpha=alpha, colornorm=norm, size_radius=0.15-np.abs(trajec.positions[frames, depth]), size_radius_range=[0.003, .025], deg=False, nskip=2, center_point_size=0.01)

        if show_start:
            start = patches.Circle((trajec.positions[frames[0], axes[0]], trajec.positions[frames[0], axes[1]]), radius=0.004, facecolor='green', edgecolor='none', linewidth=0, alpha=1, zorder=zorder+1)
            ax.add_artist(start)

        if show_saccades:
            for sac_range in trajec.saccades:
                if sac_range[0] in frames and sac_range[-1] in frames:
                    middle_saccade_frame = sac_range[int(len(sac_range)/2.)]
                    saccade = patches.Circle((trajec.positions[middle_saccade_frame, axes[0]], trajec.positions[middle_saccade_frame, axes[1]]), radius=0.004, facecolor='red', edgecolor='none', linewidth=0, alpha=1, zorder=zorder+1)
                    ax.add_artist(saccade)

    ############################
    # Add post, make plot pretty
    ############################

    if show_post:
        ax.add_artist(post)

    if save:
        ax.set_xlim(-.4, 1.)
        ax.set_ylim(-.2, .2)
        # Use the axes' own figure. ``fig`` only exists when the axes were
        # created above, so a caller-supplied ``ax`` used to raise NameError.
        ax.figure.savefig('odor_traces.pdf', format='pdf')