Example #1
0
def run(db_name, save_fig_name=None):
    """Plot the average raw signal of each recording point in a 4x4 grid.

    Opens the OpenElectrophy sqlite database ``db_name`` and takes the first
    block named "Raw Data". For (at most) its first 16 recording points, it
    averages every AnalogSignal attached to that recording point. Each
    average is drawn in its own subplot, titled with the recording point's
    name, against time in milliseconds.

    db_name: path to the OpenElectrophy sqlite db file.
    save_fig_name: if None, the figure is shown with plt.show(). Otherwise it
        is saved to this path and closed. A leading '~' is expanded, so a
        value such as '~/test.png' works.

    Recording points with no AnalogSignals are skipped, leaving their
    subplot slot empty, instead of raising IndexError.
    """
    import os

    # Load the raw data
    OE.open_db(url=('sqlite:///%s' % db_name))
    id_blocks, = OE.sql(
        'SELECT block.id FROM block WHERE block.name="Raw Data"')
    id_block = id_blocks[0]

    id_recordingpoints, rp_names = OE.sql("SELECT \
        recordingpoint.id, recordingpoint.name \
        FROM recordingpoint \
        WHERE recordingpoint.id_block = :id_block",
                                          id_block=id_block)

    f = plt.figure(figsize=(10, 10))

    # Process each recording point separately (16 panels fit the 4x4 grid)
    for n, (id_rp, tt) in enumerate(zip(id_recordingpoints[:16],
                                        rp_names[:16])):
        # Load all signals from all segments with this recording point
        id_sigs, = OE.sql(
            'SELECT analogsignal.id FROM analogsignal '
            'WHERE analogsignal.id_recordingpoint = :id_recordingpoint',
            id_recordingpoint=id_rp)
        if len(id_sigs) == 0:
            # Nothing to average for this recording point
            continue

        # Average the signals, loading each one exactly once. The running
        # sum starts as a float64 copy of the first signal, as the original
        # zeros-plus-signal did, so integer samples cannot overflow.
        sig = OE.AnalogSignal().load(id_sigs[0])
        avgsig = np.array(sig.signal, dtype=float)
        for id_sig in id_sigs[1:]:
            sig = OE.AnalogSignal().load(id_sig)
            avgsig = avgsig + sig.signal
        avgsig = old_div(avgsig, len(id_sigs))

        # Plot the average signal of this recording point, time in ms.
        # NOTE(review): uses the sampling rate of the last signal loaded,
        # i.e. assumes all signals of this recording point share one rate.
        ax = f.add_subplot(4, 4, n + 1)
        ax.plot(
            old_div(np.arange(len(avgsig)), sig.sampling_rate) * 1000, avgsig)
        ax.set_title(tt)

    if save_fig_name is None:
        plt.show()
    else:
        # savefig does not expand '~' itself
        plt.savefig(os.path.expanduser(save_fig_name))
        plt.close()
def select_segments_by_trial_number(id_block, trial_numbers):
    """Return the ids of the segments in ``id_block`` whose trial is listed.

    Each segment's ``info`` field is expected to hold its behavioral trial
    number as an integer string (e.g. '12'). That is the format written by
    the trial-syncing step that sets ``seg.info = '%d' % b_trial``.

    id_block: id of the block whose segments are searched.
    trial_numbers: array-like of trial numbers to keep.

    Returns the subset of the segment ids returned by OE.sql whose trial
    number appears in ``trial_numbers``.
    """
    # Load all segments from the block. The id is passed as a bound
    # parameter (like the other queries in this module) rather than being
    # formatted into the SQL text.
    id_segs, info_segs = OE.sql(
        'SELECT segment.id, segment.info FROM segment '
        'WHERE segment.id_block = :id_block', id_block=int(id_block))

    # The info field holds the trial number as text; convert it to int
    b_trial_numbers = info_segs.astype(int)

    # Keep the segment ids whose trial number is one of those requested
    return id_segs[np.in1d(b_trial_numbers, trial_numbers)]
Example #3
0
def run(db_name, save_fig_name=None):
    """Plot the average raw signal of each recording point in a 4x4 grid.

    Opens the OpenElectrophy sqlite database ``db_name`` and takes the first
    block named "Raw Data". For (at most) its first 16 recording points, it
    averages every AnalogSignal attached to that recording point. Each
    average is drawn in its own subplot, titled with the recording point's
    name, against time in milliseconds.

    db_name: path to the OpenElectrophy sqlite db file.
    save_fig_name: if None, the figure is shown with plt.show(). Otherwise it
        is saved to this path and closed. A leading '~' is expanded, so a
        value such as '~/test.png' works.

    Recording points with no AnalogSignals are skipped, leaving their
    subplot slot empty, instead of raising IndexError.
    """
    import os

    # Load the raw data
    OE.open_db(url=('sqlite:///%s' % db_name))
    id_blocks, = OE.sql('SELECT block.id FROM block WHERE block.name="Raw Data"')
    id_block = id_blocks[0]

    id_recordingpoints, rp_names = OE.sql("SELECT \
        recordingpoint.id, recordingpoint.name \
        FROM recordingpoint \
        WHERE recordingpoint.id_block = :id_block", id_block=id_block)

    f = plt.figure(figsize=(10, 10))

    # Process each recording point separately (16 panels fit the 4x4 grid)
    for n, (id_rp, tt) in enumerate(zip(id_recordingpoints[:16], rp_names[:16])):
        # Load all signals from all segments with this recording point
        id_sigs, = OE.sql(
            'SELECT analogsignal.id FROM analogsignal '
            'WHERE analogsignal.id_recordingpoint = :id_recordingpoint',
            id_recordingpoint=id_rp)
        if len(id_sigs) == 0:
            # Nothing to average for this recording point
            continue

        # Average the signals, loading each one exactly once. The running
        # sum starts as a float64 copy of the first signal, as the original
        # zeros-plus-signal did, so integer samples cannot overflow.
        sig = OE.AnalogSignal().load(id_sigs[0])
        avgsig = np.array(sig.signal, dtype=float)
        for id_sig in id_sigs[1:]:
            sig = OE.AnalogSignal().load(id_sig)
            avgsig = avgsig + sig.signal
        avgsig = avgsig / len(id_sigs)

        # Plot the average signal of this recording point, time in ms.
        # NOTE(review): uses the sampling rate of the last signal loaded,
        # i.e. assumes all signals of this recording point share one rate.
        ax = f.add_subplot(4, 4, n + 1)
        ax.plot(np.arange(len(avgsig)) / sig.sampling_rate * 1000, avgsig)
        ax.set_title(tt)

    if save_fig_name is None:
        plt.show()
    else:
        # savefig does not expand '~' itself
        plt.savefig(os.path.expanduser(save_fig_name))
        plt.close()
def select_audio_signals_by_stimulus_number(id_block, stim_num,
                                            TRIALS_INFO, side='L'):
    """Load the speaker traces of every trial that played a given stimulus.

    id_block: id of the block whose segments define the candidate trials.
    stim_num: stimulus number compared with TRIALS_INFO['STIM_NUMBER'].
    TRIALS_INFO: indexable by 'TRIAL_NUMBER' and 'STIM_NUMBER'. It is
        presumably a record array whose two columns are aligned
        element-wise (TODO confirm).
    side: speaker prefix. Signals whose name matches '<side> Speaker %'
        (SQL LIKE) are used, e.g. 'L' or 'R'.

    Returns np.array of the loaded ``.signal`` arrays, one per matching
    AnalogSignal. It is a 2-D array when all traces have the same length.
    """
    # Find trials with stimulus number stim_num
    trial_numbers = TRIALS_INFO['TRIAL_NUMBER'][
        TRIALS_INFO['STIM_NUMBER'] == stim_num]
    keep_id_segs = select_segments_by_trial_number(
        id_block, trial_numbers=trial_numbers)

    # Load all analogsignals of this speaker. The LIKE pattern is a bound
    # parameter, so `side` is never spliced into the SQL text (the old
    # double-quoted literal also relied on SQLite's identifier fallback).
    id_sigs, id_segs = OE.sql(
        'SELECT analogsignal.id, analogsignal.id_segment '
        'FROM analogsignal WHERE analogsignal.name LIKE :pattern',
        pattern=side + ' Speaker %')

    # Keep the analog signals whose segment is one of keep_id_segs
    keep_id_sigs = id_sigs[np.in1d(id_segs.astype(int),
                                   keep_id_segs.astype(int))]
    speaker_traces = [OE.AnalogSignal().load(id_sig).signal
                      for id_sig in keep_id_sigs]

    return np.array(speaker_traces)
def execute(control_params):
    """Plot speaker waveforms and left/right stimulus power per stimulus.

    control_params must provide:
        'behavior_filename': Bcontrol file holding TRIALS_INFO.
        'db_name': OpenElectrophy sqlite db file.
        'pre_slice': seconds; multiplied by 30 kHz to give the sample index
            treated as stimulus onset (presumably the pre-stimulus length of
            each stored trace — TODO confirm).

    Figures 1 and 2 show, per stimulus number, every left / right speaker
    trace in a 60-sample window around onset. Figure 3 scatters left vs.
    right stimulus power (10*log10 of the summed squared samples over 250 ms
    after onset) for stimuli 1-4 ('Pure'), 5-8 ('PB') and 9-12 ('LB').

    Returns None. A leftover debug ``return`` at stimulus 6 used to make
    the power summary unreachable; it has been removed.
    """
    # Sampling rate (Hz) assumed for the speaker signals
    sampling_rate = 30000.

    # Load TRIALS_INFO
    bcl = bcontrol.Bcontrol_Loader(filename=control_params['behavior_filename'],
                                   v2_behavior=True)
    bcl.load()
    TRIALS_INFO = bcl.data['TRIALS_INFO']

    # Open database
    OE.open_db('sqlite:///%s' % control_params['db_name'])
    id_blocks, = OE.sql('select block.id from block where block.name = "Raw Data"')
    id_block = id_blocks[0]

    pre_stim_len = int(control_params['pre_slice'] * sampling_rate)
    stim_len = int(.250 * sampling_rate)
    # 60 samples centred on onset for the waveform panels, and the 250 ms
    # stretch after onset used for the power measurement
    onset_window = pre_stim_len + np.arange(-30, 30)
    stim_slice = slice(pre_stim_len, pre_stim_len + stim_len)

    f1 = plt.figure()
    f2 = plt.figure()
    l_sums = dict()
    r_sums = dict()
    for sn in np.unique(TRIALS_INFO['STIM_NUMBER']):
        # Get all signals with certain stim number
        l_speaker_traces = select_audio_signals_by_stimulus_number(
            id_block, sn, TRIALS_INFO, 'L')
        r_speaker_traces = select_audio_signals_by_stimulus_number(
            id_block, sn, TRIALS_INFO, 'R')

        # add_subplot needs an integer panel index
        panel = int(sn)

        ax = f1.add_subplot(3, 4, panel)
        ax.plot(l_speaker_traces[:, onset_window].transpose())
        ax.set_title('L %d' % sn)

        ax = f2.add_subplot(3, 4, panel)
        ax.plot(r_speaker_traces[:, onset_window].transpose())
        ax.set_title('R %d' % sn)

        # Per-trial power in dB. The builtin float replaces np.float, an
        # alias of it that NumPy has removed.
        slices = l_speaker_traces[:, stim_slice]
        l_sums[sn] = 10 * np.log10((slices.astype(float) ** 2).sum(axis=1))

        slices = r_speaker_traces[:, stim_slice]
        r_sums[sn] = 10 * np.log10((slices.astype(float) ** 2).sum(axis=1))

    plt.show()

    # Now plot powers: (subplot position, stimulus numbers, legend, title)
    power_panels = [
        (131, [1, 2, 3, 4], ['lo', 'hi', 'le', 'ri'], 'Pure'),
        (132, [5, 6, 7, 8], ['le-hi', 'ri-hi', 'le-lo', 'ri-lo'], 'PB'),
        (133, [9, 10, 11, 12], ['le-hi', 'ri-hi', 'le-lo', 'ri-lo'], 'LB'),
    ]
    plt.figure()
    for position, stim_numbers, labels, title in power_panels:
        plt.subplot(position)
        for sn in stim_numbers:
            plt.plot(l_sums[sn], r_sums[sn], '.')
        plt.xlabel('left')
        plt.ylabel('right')
        plt.legend(labels, loc='best')
        plt.title(title)
    plt.show()
Example #6
0
# Grabs spike times from db that were calculated from within OE

import OpenElectrophy as OE
import numpy as np
import matplotlib.pyplot as plt

# Script: load every neuron of the "CAR Tetrode Data" block and collect its
# spike times, each taken relative to the t_start of its own spike train.
# Spike times are kept per neuron (big_spiketimes) and pooled across all
# neurons (bigger_spiketimes).
# NOTE(review): the database path is hard-coded; edit db_name to analyze a
# different session (two alternatives are kept commented out below).
db_name = '/home/chris/Public/20110401_CR13A_audresp_data/0327_002/datafile_CR_CR13A_110327_002.db'
#db_name = '/home/chris/Public/20110401_CR13A_audresp_data/0403_002/datafile_CR_CR13A_110403_002.db'
#db_name = '/home/chris/Public/20110401_CR13A_audresp_data/0329_002/datafile_CR_CR13A_110329_002.db'
OE.open_db(url=('sqlite:///%s' % db_name))

# Load neurons. [0][0] is the first id of the single result column, i.e.
# the first block named "CAR Tetrode Data".
id_block = OE.sql('select block.id from block where block.name = \
    "CAR Tetrode Data"')[0][0]
# Ids of all neurons attached to that block
id_neurons, = OE.sql('select neuron.id from neuron where neuron.id_block = \
    :id_block',
                     id_block=id_block)

# NOTE(review): this figure is not drawn into in the visible part of the
# script; the remainder may have been truncated.
plt.figure()
# Pooled spike times from every neuron, grown one neuron at a time
bigger_spiketimes = np.array([])
for id_neuron in id_neurons:
    n = OE.Neuron().load(id_neuron)

    # Grab spike times from all trials (segments), each shifted so that 0
    # is the start of its own spike train.
    # NOTE(review): relies on the private `_spiketrains` attribute, and
    # np.concatenate raises ValueError if a neuron has no spike trains.
    big_spiketimes = np.concatenate(\
        [spiketrain.spike_times - spiketrain.t_start \
        for spiketrain in n._spiketrains])
    bigger_spiketimes = np.concatenate([bigger_spiketimes, big_spiketimes])

    # Compute a 100-bin histogram of this neuron's spike times.
    # nh holds the counts; x holds the 101 bin edges.
    nh, x = np.histogram(big_spiketimes, bins=100)
Example #7
0
def run(db_name, CAR=True, smooth_spikes=True):
    """Filters the data for spike extraction.

    Builds a new block named 'CAR Tetrode Data' from the existing
    'Raw Data' block. It creates one RecordingPoint per good channel and,
    for every raw segment, a segment of the same name in the new block.
    The good channels are the ones listed in the TETRODE_CHANNELS file next
    to the database. Each good channel has the common-average reference
    (CAR) subtracted and is filtered with the spike filter. The result is
    stored in the new segment. Signals whose name matches '% Speaker %'
    are copied over unchanged.

    db_name: Name of the OpenElectrophy db file
    CAR: If True, subtract the common-average of every channel.
    smooth_spikes: If True, add an additional low-pass filtering step to
        the spike filter.

    Returns None. If a 'CAR Tetrode Data' block already exists, prints a
    message and returns without writing anything.
    """
    # Open connection to the database
    OE.open_db(url=("sqlite:///%s" % db_name))

    # Check that I haven't already run
    id_blocks, = OE.sql("SELECT block.id FROM block WHERE block.name='CAR Tetrode Data'")
    if len(id_blocks) > 0:
        # Function-call print works under both Python 2 and Python 3
        print("CAR Tetrode Data already exists, no need to recompute")
        return

    # Find the block
    id_blocks, = OE.sql("SELECT block.id FROM block WHERE block.name='Raw Data'")
    assert len(id_blocks) == 1
    id_block = id_blocks[0]

    # Define spike filter
    # TODO: fix so that doesn't assume all sampling rates the same!
    # NOTE(review): the rate is taken from AnalogSignal id 1.
    fixed_sampling_rate = OE.AnalogSignal().load(1).sampling_rate
    FILTER_B, FILTER_A = define_spike_filter(fixed_sampling_rate)

    # If requested, define second spike filter
    if smooth_spikes is True:
        FILTER_B2, FILTER_A2 = define_spike_filter_2(fixed_sampling_rate)

    # Find TETRODE_CHANNELS file in data directory of db
    data_dir = path.split(db_name)[0]
    TETRODE_CHANNELS = get_tetrode_channels(path.join(data_dir, "TETRODE_CHANNELS"))

    # Flatten TETRODE_CHANNELS into the worthwhile channels. Only membership
    # is tested below, so a set gives constant-time lookups.
    GOOD_CHANNELS = set(ch for ch_list in TETRODE_CHANNELS for ch in ch_list)

    # Create a new block for referenced data, and save to db.
    car_block = OE.Block(
        name="CAR Tetrode Data", info="Raw neural data, now referenced and ordered by tetrode", fileOrigin=db_name
    )
    id_car_block = car_block.save()

    # Make RecordingPoint for each channel, linked to tetrode number with `group`
    # Also keep track of link between channel and RP with ch2rpid dict
    ch2rpid = dict()
    for tn, ch_list in enumerate(TETRODE_CHANNELS):
        for ch in ch_list:
            rp = OE.RecordingPoint(
                name=("RP%d" % ch), id_block=id_car_block, trodness=len(ch_list), channel=float(ch), group=tn
            )
            ch2rpid[ch] = rp.save()

    # Find all segments in the block of raw data
    id_segments, = OE.sql("SELECT segment.id FROM segment " + "WHERE segment.id_block = :id_block", id_block=id_block)

    # For each segment in this block, average the AnalogSignals of the good
    # channels to compute the CAR, then subtract it from each of them.
    for id_segment in id_segments:
        # Create a new segment in the new block with the same name
        old_seg = OE.Segment().load(id_segment)
        car_seg = OE.Segment(name=old_seg.name, id_block=id_car_block)
        id_car_seg = car_seg.save()

        # Find all AnalogSignals in this segment
        id_sigs, = OE.sql(
            "SELECT analogsignal.id FROM analogsignal " + "WHERE analogsignal.id_segment = :id_segment",
            id_segment=id_segment,
        )

        # Sum the good channels; divided by n_summed below to get the CAR
        running_car = 0
        n_summed = 0
        for id_sig in id_sigs:
            sig = OE.AnalogSignal().load(id_sig)
            if sig.channel not in GOOD_CHANNELS:
                continue
            running_car = running_car + sig.signal
            n_summed = n_summed + 1

        # Zero out CAR if CAR is not wanted
        # TODO: eliminate the actual calculation of CAR above in this case
        # For now, just want to avoid weird bugs
        if CAR is False:
            running_car = np.zeros(running_car.shape)

        # Save the CAR signal.
        # NOTE(review): id_segment is the RAW segment, so the CAR signal is
        # stored in the raw block, not the new block — confirm intended.
        # Channel, t_start and sampling_rate are not assigned.
        car_sig = OE.AnalogSignal(
            name="CAR",
            signal=running_car / n_summed,
            info="CAR calculated from good channels for this segment",
            id_segment=id_segment,
        )
        car_sig.save()

        # Put all the subtractions in id_car_seg
        for id_sig in id_sigs:
            # Load the raw signal (skip bad channels)
            sig = OE.AnalogSignal().load(id_sig)
            if sig.channel not in GOOD_CHANNELS:
                continue

            # Subtract the CAR
            referenced_signal = sig.signal - car_sig.signal

            # Filter (zero-phase, forward and backward)
            filtered_signal = scipy.signal.filtfilt(FILTER_B, FILTER_A, referenced_signal)
            if smooth_spikes is True:
                filtered_signal = scipy.signal.filtfilt(FILTER_B2, FILTER_A2, filtered_signal)

            # Check for infs or nans (reported, not fatal)
            if np.isnan(filtered_signal).any():
                print("ERROR: Filtered signal contains NaN!")
            if np.isinf(filtered_signal).any():
                print("ERROR: Filtered signal contains Inf!")

            # Store in db
            new_sig = OE.AnalogSignal(
                name=sig.name,
                signal=filtered_signal,
                info="CAR has been subtracted",
                id_segment=id_car_seg,
                id_recordingpoint=ch2rpid[sig.channel],
                channel=sig.channel,
                t_start=sig.t_start,
                sampling_rate=sig.sampling_rate,
            )
            new_sig.save()

        # Finally, copy the audio channel over from the old block
        id_audio_sigs, = OE.sql(
            "SELECT analogsignal.id FROM analogsignal "
            + "WHERE analogsignal.id_segment = :id_segment AND "
            + "analogsignal.name LIKE '% Speaker %'",
            id_segment=id_segment,
        )
        for id_audio_sig in id_audio_sigs:
            old_sig = OE.AnalogSignal().load(id_audio_sig)
            OE.AnalogSignal(
                name=old_sig.name,
                signal=old_sig.signal,
                id_segment=id_car_seg,
                channel=old_sig.channel,
                t_start=old_sig.t_start,
                sampling_rate=old_sig.sampling_rate,
            ).save()
Example #8
0
def run(control_params, auto_validate=True, v2_behavior=False):
    """Write each segment's behavioral trial number into ``Segment.info``.

    Neural trial onsets are read from the TIMESTAMPS file in
    control_params['data_dir']. Behavioral onsets come from the Bcontrol
    file control_params['behavior_filename']. The two are aligned with
    DataSession.BehavingSyncer. Then every segment in the OpenElectrophy db
    control_params['db_name'] whose name contains 'trial<N>' gets the
    matching TRIALS_INFO['TRIAL_NUMBER'], stored as a decimal string.

    When the mapping lookup raises IndexError:
    - neural trial 0 is labeled with the first behavioral trial (its onset
      was dropped before syncing);
    - any other trial is labeled -99.

    auto_validate, v2_behavior: forwarded to bcontrol.Bcontrol_Loader.

    Raises ValueError if a segment name does not contain 'trial<N>'.
    """
    # Location of data
    data_dir = control_params['data_dir']

    # Location of the Bcontrol file
    bdata_filename = control_params['behavior_filename']

    # Location of TIMESTAMPS
    timestamps_filename = os.path.join(data_dir, 'TIMESTAMPS')

    # Location of OE db
    db_filename = control_params['db_name']

    # Load timestamps calculated from audio onsets in ns5 file.
    # The builtin int is the same dtype as the removed alias np.int.
    ns5_times = np.loadtxt(timestamps_filename, dtype=int)

    # Load bcontrol data (will also validate)
    bcl = bcontrol.Bcontrol_Loader(filename=bdata_filename,
                                   auto_validate=auto_validate,
                                   v2_behavior=v2_behavior)
    bcl.load()

    # Grab timestamps from behavior file
    b_onsets = bcl.data['onsets']

    # Try to convert this stuff into the format expected by the syncer:
    # minimal objects exposing only an `audio_onsets` attribute.
    class fake_bcl(object):
        def __init__(self, onsets):
            self.audio_onsets = onsets

    class fake_rdl(object):
        def __init__(self, onsets):
            self.audio_onsets = onsets

    # Convert into desired format, also throwing away first behavior onset.
    # This is corrected below (the "+ 1" index and the n_trial == 0 case).
    fb = fake_bcl(b_onsets[1:])
    fr = fake_rdl(ns5_times)

    # Sync. Will write CORR files to disk.
    # Also produces bs.map_n_to_b_masked and vice versa for trial mapping
    bs = DataSession.BehavingSyncer()
    bs.sync(fb, fr, force_run=True)

    # Put trial numbers into OE db
    OE.open_db('sqlite:///%s' % db_filename)

    # Each segment in the db is named trial%d, corresponding to the
    # ordinal TIMESTAMP, which means neural trial time.
    # We want to mark it with the behavioral trial number.
    # For now, put the behavioral trial number into Segment.info
    # TODO: Put the skip-1 behavior into the syncer so we don't have to
    # use the trick. Then we can use map_n_to_b_masked without fear.
    # Note that the 1010 data is NOT missing the first trial.
    # Double check that the neural TIMESTAMP matches the value in peh.
    # Also, add the check_audio_waveforms functionality here so that it's
    # all done at once.
    trial_name_re = re.compile(r'trial(\d+)')
    id_segs, name_segs = OE.sql('select segment.id, segment.name from segment')
    for id_seg, name_seg in zip(id_segs, name_segs):
        # Extract neural trial number from name_seg
        match = trial_name_re.search(name_seg)
        if match is None:
            raise ValueError(
                "segment name %r does not contain 'trial<N>'" % (name_seg,))
        n_trial = int(match.group(1))

        # Convert to behavioral trial number
        # We use the 'trial_number' field of TRIALS_INFO
        # IE the original Matlab numbering of the trial
        # Here we correct for the dropped first trial.
        try:
            b_trial = bcl.data['TRIALS_INFO']['TRIAL_NUMBER'][
                bs.map_n_to_b_masked[n_trial] + 1]
        except IndexError:
            # masked trial
            if n_trial == 0:
                print("WARNING: Assuming this is the dropped first trial")
                b_trial = bcl.data['TRIALS_INFO']['TRIAL_NUMBER'][0]
            else:
                print("WARNING: can't find trial")
                b_trial = -99

        # Store behavioral trial number in the info field
        seg = OE.Segment().load(id_seg)
        seg.info = '%d' % b_trial
        seg.save()
Example #9
0
def run(db_name, CAR=True, smooth_spikes=True):
    """Filters the data for spike extraction.

    Builds a new block named 'CAR Tetrode Data' from the existing
    'Raw Data' block. It creates one RecordingPoint per good channel and,
    for every raw segment, a segment of the same name in the new block.
    The good channels are the ones listed in the TETRODE_CHANNELS file next
    to the database. Each good channel has the common-average reference
    (CAR) subtracted and is filtered with the spike filter. The result is
    stored in the new segment. Signals whose name matches '% Speaker %'
    are copied over unchanged.

    db_name: Name of the OpenElectrophy db file
    CAR: If True, subtract the common-average of every channel.
    smooth_spikes: If True, add an additional low-pass filtering step to
        the spike filter.

    Returns None. If a 'CAR Tetrode Data' block already exists, prints a
    message and returns without writing anything.
    """
    # Open connection to the database
    OE.open_db(url=('sqlite:///%s' % db_name))

    # Check that I haven't already run
    id_blocks, = OE.sql(
        "SELECT block.id FROM block WHERE block.name='CAR Tetrode Data'")
    if len(id_blocks) > 0:
        print("CAR Tetrode Data already exists, no need to recompute")
        return

    # Find the block
    id_blocks, = OE.sql(
        "SELECT block.id FROM block WHERE block.name='Raw Data'")
    assert (len(id_blocks) == 1)
    id_block = id_blocks[0]

    # Define spike filter
    # TODO: fix so that doesn't assume all sampling rates the same!
    # NOTE(review): the rate is taken from AnalogSignal id 1.
    fixed_sampling_rate = OE.AnalogSignal().load(1).sampling_rate
    FILTER_B, FILTER_A = define_spike_filter(fixed_sampling_rate)

    # If requested, define second spike filter
    if smooth_spikes is True:
        FILTER_B2, FILTER_A2 = define_spike_filter_2(fixed_sampling_rate)

    # Find TETRODE_CHANNELS file in data directory of db
    data_dir = path.split(db_name)[0]
    TETRODE_CHANNELS = get_tetrode_channels(
        path.join(data_dir, 'TETRODE_CHANNELS'))

    # Flatten TETRODE_CHANNELS into the worthwhile channels. Only membership
    # is tested below, so a set gives constant-time lookups.
    GOOD_CHANNELS = set(ch for ch_list in TETRODE_CHANNELS for ch in ch_list)

    # Create a new block for referenced data, and save to db.
    car_block = OE.Block(
        name='CAR Tetrode Data',
        info='Raw neural data, now referenced and ordered by tetrode',
        fileOrigin=db_name)
    id_car_block = car_block.save()

    # Make RecordingPoint for each channel, linked to tetrode number with `group`
    # Also keep track of link between channel and RP with ch2rpid dict
    ch2rpid = dict()
    for tn, ch_list in enumerate(TETRODE_CHANNELS):
        for ch in ch_list:
            rp = OE.RecordingPoint(name=('RP%d' % ch),
                                   id_block=id_car_block,
                                   trodness=len(ch_list),
                                   channel=float(ch),
                                   group=tn)
            ch2rpid[ch] = rp.save()

    # Find all segments in the block of raw data
    id_segments, = OE.sql('SELECT segment.id FROM segment '
                          'WHERE segment.id_block = :id_block',
                          id_block=id_block)

    # For each segment in this block, average the AnalogSignals of the good
    # channels to compute the CAR, then subtract it from each of them.
    for id_segment in id_segments:
        # Create a new segment in the new block with the same name
        old_seg = OE.Segment().load(id_segment)
        car_seg = OE.Segment(
            name=old_seg.name,
            id_block=id_car_block,
        )
        id_car_seg = car_seg.save()

        # Find all AnalogSignals in this segment
        id_sigs, = OE.sql('SELECT analogsignal.id FROM analogsignal '
                          'WHERE analogsignal.id_segment = :id_segment',
                          id_segment=id_segment)

        # Sum the good channels; divided by n_summed below to get the CAR
        running_car = 0
        n_summed = 0
        for id_sig in id_sigs:
            sig = OE.AnalogSignal().load(id_sig)
            if sig.channel not in GOOD_CHANNELS:
                continue
            running_car = running_car + sig.signal
            n_summed = n_summed + 1

        # Zero out CAR if CAR is not wanted
        # TODO: eliminate the actual calculation of CAR above in this case
        # For now, just want to avoid weird bugs
        if CAR is False:
            running_car = np.zeros(running_car.shape)

        # Save the CAR signal.
        # NOTE(review): id_segment is the RAW segment, so the CAR signal is
        # stored in the raw block, not the new block — confirm intended.
        # Channel, t_start and sampling_rate are not assigned.
        car_sig = OE.AnalogSignal(
            name='CAR',
            signal=old_div(running_car, n_summed),
            info='CAR calculated from good channels for this segment',
            id_segment=id_segment)
        car_sig.save()

        # Put all the subtractions in id_car_seg
        for id_sig in id_sigs:
            # Load the raw signal (skip bad channels)
            sig = OE.AnalogSignal().load(id_sig)
            if sig.channel not in GOOD_CHANNELS:
                continue

            # Subtract the CAR
            referenced_signal = sig.signal - car_sig.signal

            # Filter (zero-phase, forward and backward)
            filtered_signal = scipy.signal.filtfilt(FILTER_B, FILTER_A,
                                                    referenced_signal)
            if smooth_spikes is True:
                filtered_signal = scipy.signal.filtfilt(
                    FILTER_B2, FILTER_A2, filtered_signal)

            # Check for infs or nans (reported, not fatal)
            if np.isnan(filtered_signal).any():
                print("ERROR: Filtered signal contains NaN!")
            if np.isinf(filtered_signal).any():
                print("ERROR: Filtered signal contains Inf!")

            # Store in db
            new_sig = OE.AnalogSignal(
                name=sig.name,
                signal=filtered_signal,
                info='CAR has been subtracted',
                id_segment=id_car_seg,
                id_recordingpoint=ch2rpid[sig.channel],
                channel=sig.channel,
                t_start=sig.t_start,
                sampling_rate=sig.sampling_rate)
            new_sig.save()

        # Finally, copy the audio channel over from the old block
        id_audio_sigs, = OE.sql('SELECT analogsignal.id FROM analogsignal '
                                'WHERE analogsignal.id_segment = :id_segment AND '
                                "analogsignal.name LIKE '% Speaker %'",
                                id_segment=id_segment)
        for id_audio_sig in id_audio_sigs:
            old_sig = OE.AnalogSignal().load(id_audio_sig)
            OE.AnalogSignal(
                name=old_sig.name,
                signal=old_sig.signal,
                id_segment=id_car_seg,
                channel=old_sig.channel,
                t_start=old_sig.t_start,
                sampling_rate=old_sig.sampling_rate).save()
Example #10
0
def get_tetrode_block_id():
    """Return the database id of the "Spike-filtered Data" block.

    Only the first id returned by the query is used. An empty result fails
    on the final index (presumably IndexError).
    """
    query = ('select block.id from block where '
             '        block.name = "Spike-filtered Data"')
    (block_ids,) = OE.sql(query)
    return block_ids[0]
Example #11
0
def get_tetrode_block_id():
    """Look up the id of the block named "Spike-filtered Data".

    The query yields a single column; its first entry is returned.
    """
    sql_text = ('select block.id from block where' + ' ' * 9 +
                'block.name = "Spike-filtered Data"')
    (first_column,) = OE.sql(sql_text)
    return first_column[0]
Example #12
0
# Grabs spike times from db that were calculated from within OE

import OpenElectrophy as OE
import numpy as np
import matplotlib.pyplot as plt

# Script: load every neuron of the "CAR Tetrode Data" block and collect its
# spike times, each taken relative to the t_start of its own spike train.
# Spike times are kept per neuron (big_spiketimes) and pooled across all
# neurons (bigger_spiketimes). A 100-bin histogram is computed per neuron.
# NOTE(review): the database path is hard-coded; edit db_name to analyze a
# different session (two alternatives are kept commented out below).
db_name = '/home/chris/Public/20110401_CR13A_audresp_data/0327_002/datafile_CR_CR13A_110327_002.db'
#db_name = '/home/chris/Public/20110401_CR13A_audresp_data/0403_002/datafile_CR_CR13A_110403_002.db'
#db_name = '/home/chris/Public/20110401_CR13A_audresp_data/0329_002/datafile_CR_CR13A_110329_002.db'
OE.open_db(url=('sqlite:///%s' % db_name))

# Load neurons of the first block named "CAR Tetrode Data"
id_block = OE.sql('select block.id from block where block.name = \
    "CAR Tetrode Data"')[0][0]
id_neurons, = OE.sql('select neuron.id from neuron where neuron.id_block = \
    :id_block', id_block=id_block)

plt.figure()
# Pooled spike times from every neuron, grown one neuron at a time
bigger_spiketimes = np.array([])
for id_neuron in id_neurons:
    n = OE.Neuron().load(id_neuron)
    # NOTE(review): `_spiketrains` is a private attribute of OE.Neuron
    spiketrains = list(n._spiketrains)
    if not spiketrains:
        # np.concatenate cannot join an empty list; nothing to add
        continue

    # Grab spike times from all trials (segments), each shifted so that 0
    # is the start of its own spike train
    big_spiketimes = np.concatenate(
        [spiketrain.spike_times - spiketrain.t_start
         for spiketrain in spiketrains])
    bigger_spiketimes = np.concatenate([bigger_spiketimes, big_spiketimes])

    # Compute histogram, then turn the 101 bin edges into the 100 bin
    # centres (left edge + half the bin width). The previous
    # `np.diff(x) + x[:-1]` gave the right edges, i.e. x[1:].
    nh, x = np.histogram(big_spiketimes, bins=100)
    x = x[:-1] + np.diff(x) / 2.
def run(control_params, auto_validate=True, v2_behavior=False):
    """Write each segment's behavioral trial number into ``Segment.info``.

    Neural trial onsets are read from the TIMESTAMPS file in
    control_params['data_dir']. Behavioral onsets come from the Bcontrol
    file control_params['behavior_filename']. The two are aligned with
    DataSession.BehavingSyncer. Then every segment in the OpenElectrophy db
    control_params['db_name'] whose name contains 'trial<N>' gets the
    matching TRIALS_INFO['TRIAL_NUMBER'], stored as a decimal string.

    When the mapping lookup raises IndexError:
    - neural trial 0 is labeled with the first behavioral trial (its onset
      was dropped before syncing);
    - any other trial is labeled -99.

    auto_validate, v2_behavior: forwarded to bcontrol.Bcontrol_Loader.

    Raises ValueError if a segment name does not contain 'trial<N>'.
    """
    # Location of data
    data_dir = control_params['data_dir']

    # Location of the Bcontrol file
    bdata_filename = control_params['behavior_filename']

    # Location of TIMESTAMPS
    timestamps_filename = os.path.join(data_dir, 'TIMESTAMPS')

    # Location of OE db
    db_filename = control_params['db_name']

    # Load timestamps calculated from audio onsets in ns5 file.
    # The builtin int is the same dtype as the removed alias np.int.
    ns5_times = np.loadtxt(timestamps_filename, dtype=int)

    # Load bcontrol data (will also validate)
    bcl = bcontrol.Bcontrol_Loader(filename=bdata_filename,
                                   auto_validate=auto_validate,
                                   v2_behavior=v2_behavior)
    bcl.load()

    # Grab timestamps from behavior file
    b_onsets = bcl.data['onsets']

    # Try to convert this stuff into the format expected by the syncer:
    # minimal objects exposing only an `audio_onsets` attribute.
    class fake_bcl(object):
        def __init__(self, onsets):
            self.audio_onsets = onsets

    class fake_rdl(object):
        def __init__(self, onsets):
            self.audio_onsets = onsets

    # Convert into desired format, also throwing away first behavior onset.
    # This is corrected below (the "+ 1" index and the n_trial == 0 case).
    fb = fake_bcl(b_onsets[1:])
    fr = fake_rdl(ns5_times)

    # Sync. Will write CORR files to disk.
    # Also produces bs.map_n_to_b_masked and vice versa for trial mapping
    bs = DataSession.BehavingSyncer()
    bs.sync(fb, fr, force_run=True)

    # Put trial numbers into OE db
    OE.open_db('sqlite:///%s' % db_filename)

    # Each segment in the db is named trial%d, corresponding to the
    # ordinal TIMESTAMP, which means neural trial time.
    # We want to mark it with the behavioral trial number.
    # For now, put the behavioral trial number into Segment.info
    # TODO: Put the skip-1 behavior into the syncer so we don't have to
    # use the trick. Then we can use map_n_to_b_masked without fear.
    # Note that the 1010 data is NOT missing the first trial.
    # Double check that the neural TIMESTAMP matches the value in peh.
    # Also, add the check_audio_waveforms functionality here so that it's
    # all done at once.
    trial_name_re = re.compile(r'trial(\d+)')
    id_segs, name_segs = OE.sql('select segment.id, segment.name from segment')
    for id_seg, name_seg in zip(id_segs, name_segs):
        # Extract neural trial number from name_seg
        match = trial_name_re.search(name_seg)
        if match is None:
            raise ValueError(
                "segment name %r does not contain 'trial<N>'" % (name_seg,))
        n_trial = int(match.group(1))

        # Convert to behavioral trial number
        # We use the 'trial_number' field of TRIALS_INFO
        # IE the original Matlab numbering of the trial
        # Here we correct for the dropped first trial.
        try:
            b_trial = bcl.data['TRIALS_INFO']['TRIAL_NUMBER'][
                bs.map_n_to_b_masked[n_trial] + 1]
        except IndexError:
            # masked trial; function-call print works on Python 2 and 3
            if n_trial == 0:
                print("WARNING: Assuming this is the dropped first trial")
                b_trial = bcl.data['TRIALS_INFO']['TRIAL_NUMBER'][0]
            else:
                print("WARNING: can't find trial")
                b_trial = -99

        # Store behavioral trial number in the info field
        seg = OE.Segment().load(id_seg)
        seg.info = '%d' % b_trial
        seg.save()