Exemplo n.º 1
0
def run(db_name, save_fig_name=None):
    """Plot the average raw signal of each recording point in 'Raw Data'.

    Opens the sqlite OpenElectrophy database `db_name` and finds the block
    named 'Raw Data'. For each of its first 16 recording points, it averages
    every AnalogSignal attached to that point. Each average is drawn on a
    4x4 grid of subplots, with the x axis in milliseconds.

    db_name: path to the sqlite database file.
    save_fig_name: if None, show the figure interactively. Otherwise save it
        to this path (ex. '~/test.png'; a leading '~' is expanded) and close
        the figure.
    """
    import os.path

    # Load the raw data
    OE.open_db(url=('sqlite:///%s' % db_name))
    # Single quotes: in SQL, double quotes delimit identifiers, and sqlite
    # only accepts "Raw Data" as a string literal through a legacy fallback.
    id_blocks, = OE.sql(
        "SELECT block.id FROM block WHERE block.name='Raw Data'")
    id_block = id_blocks[0]

    id_recordingpoints, rp_names = OE.sql(
        "SELECT recordingpoint.id, recordingpoint.name "
        "FROM recordingpoint "
        "WHERE recordingpoint.id_block = :id_block",
        id_block=id_block)

    f = plt.figure(figsize=(10, 10))

    # Process each recording point separately (only 16 fit on the 4x4 grid)
    for n, (id_rp, tt) in enumerate(zip(id_recordingpoints[:16],
                                        rp_names[:16])):
        # Find all signals from all segments with this recording point
        id_sigs, = OE.sql(
            'SELECT analogsignal.id FROM analogsignal '
            'WHERE analogsignal.id_recordingpoint = :id_recordingpoint',
            id_recordingpoint=id_rp)

        # Average the signals. Each is loaded only once; the accumulator is
        # sized from the first signal.
        avgsig = None
        for id_sig in id_sigs:
            sig = OE.AnalogSignal().load(id_sig)
            if avgsig is None:
                avgsig = np.zeros(sig.signal.shape)
            avgsig = avgsig + sig.signal
        if avgsig is None:
            # No signals for this recording point: leave its subplot empty
            # rather than failing on an empty id list.
            continue
        avgsig = old_div(avgsig, len(id_sigs))

        # Plot the average signal of this recording point. The time axis
        # uses the sampling rate of the last signal loaded.
        ax = f.add_subplot(4, 4, n + 1)
        ax.plot(
            old_div(np.arange(len(avgsig)), sig.sampling_rate) * 1000, avgsig)
        ax.set_title(tt)

    if save_fig_name is None:
        plt.show()
    else:
        # savefig does not expand '~' itself.
        f.savefig(os.path.expanduser(save_fig_name))
        plt.close(f)
Exemplo n.º 2
0
def _load_trial_slice(loader, ch, start, stop):
    """Return samples ``start:stop`` of channel ``ch`` from an ns5 loader as an ndarray."""
    return np.array(loader._get_channel(ch)[start:stop])


def stuff(filename,
          db_name,
          TIMESTAMPS,
          SHOVE_CHANNELS,
          pre_slice_len,
          post_slice_len,
          db_type='sqlite'):
    """Slice raw ns5 data around each trial and store it in an OpenElectrophy db.

    Creates a new Block named 'Raw Data' in the database. It then creates:
      * one RecordingPoint per channel in SHOVE_CHANNELS;
      * one Segment ('trial<N>') per entry of TIMESTAMPS.
    Each segment holds one AnalogSignal per SHOVE channel, scaled by
    uV_QUANTUM. It also holds one unscaled AnalogSignal per audio channel,
    named 'L Speaker Trial <N>' or 'R Speaker Trial <N>'.

    filename: path of the ns5 file to load.
    db_name: sqlite database file (not used when db_type == 'postgres').
    TIMESTAMPS: sample index at which each trial starts (used as slice
        indices, so they must be integers).
    SHOVE_CHANNELS: neural channel numbers to store.
    pre_slice_len, post_slice_len: amount of data to keep before and after
        each trial start, in units of 1 / header.f_samp (presumably seconds).
    db_type: 'postgres' selects the hard-coded postgres URL; anything else
        uses sqlite.

    Returns (id_segment, id_block): the id of the last segment saved (None
    if TIMESTAMPS is empty) and the id of the new block.

    Raises ValueError if any trial window would start before sample 0.
    Validation happens before anything is written to the database.
    """
    # Load the file and file header
    loader = ns5.Loader(filename=filename)
    loader.load_file()
    f_samp = loader.header.f_samp

    # Audio channel numbers. The first is the left speaker and the second is
    # the right. setdefault keeps 'L' if both numbers are equal, matching the
    # original if/elif precedence. This tolerates fewer than two channels.
    AUDIO_CHANNELS = loader.get_audio_channel_numbers()
    speaker_side = {}
    for side, ch in zip(('L', 'R'), AUDIO_CHANNELS):
        speaker_side.setdefault(ch, side)

    # Convert requested slice lengths to samples
    pre_slice_len_samples = int(pre_slice_len * f_samp)
    post_slice_len_samples = int(post_slice_len * f_samp)

    # A negative slice start would silently wrap around to the end of the
    # channel and yield an empty or wrong signal, so reject it up front.
    TIMESTAMPS = list(TIMESTAMPS)
    for tn, trial_start in enumerate(TIMESTAMPS):
        if trial_start - pre_slice_len_samples < 0:
            raise ValueError(
                'trial %d starts at sample %d, less than the %d pre-trial '
                'samples requested' % (tn, trial_start, pre_slice_len_samples))

    # Open connection to OE db and create a block.
    # `==` rather than `is`: string identity is an interning accident.
    if db_type == 'postgres':
        # NOTE(review): this URL is hard-coded (db_name is ignored) and its
        # user@host part looks mangled; confirm before relying on it.
        OE.open_db(url=(
            'postgresql://[email protected]/test'))  # %s' % db_name))
        print('post')
    else:
        OE.open_db(url=('sqlite:///%s' % db_name))
    block = OE.Block(name='Raw Data',
                     info='Raw data sliced around trials',
                     fileOrigin=filename)
    id_block = block.save()  # need this later

    # Add a RecordingPoint per neural channel
    ch2rpid = dict()
    for ch in SHOVE_CHANNELS:
        rp = OE.RecordingPoint(name=('RP%d' % ch),
                               id_block=id_block,
                               channel=float(ch))
        ch2rpid[ch] = rp.save()

    # Extract each trial as a segment
    id_segment = None  # returned as-is when TIMESTAMPS is empty
    for tn, trial_start in enumerate(TIMESTAMPS):
        start = trial_start - pre_slice_len_samples
        stop = trial_start + post_slice_len_samples
        # True division: old_div on integer sample counts floored this to
        # whole seconds.
        t_start = start / float(f_samp)

        # Create segment for this trial
        segment = OE.Segment(id_block=id_block,
                             name=('trial%d' % tn),
                             info='raw data loaded from good channels')
        id_segment = segment.save()

        # Create an AnalogSignal (in uV) for each neural channel
        for ch in SHOVE_CHANNELS:
            x = _load_trial_slice(loader, ch, start, stop) * uV_QUANTUM
            sig = OE.AnalogSignal(signal=x,
                                  channel=float(ch),
                                  sampling_rate=f_samp,
                                  t_start=t_start,
                                  id_segment=id_segment,
                                  id_recordingpoint=ch2rpid[ch],
                                  name=('Channel %d Trial %d' % (ch, tn)))

            # An audio channel listed among the neural channels is named
            # after its speaker instead
            if ch in speaker_side:
                sig.name = ('%s Speaker Trial %d' % (speaker_side[ch], tn))

            sig.save()

        # Audio channels: unscaled and not linked to a RecordingPoint
        for ch in AUDIO_CHANNELS:
            x = _load_trial_slice(loader, ch, start, stop)
            if ch in speaker_side:
                sname = ('%s Speaker Trial %d' % (speaker_side[ch], tn))
            else:
                # Beyond the first two audio channels no side is known
                sname = ('Channel %d Trial %d' % (ch, tn))

            sig = OE.AnalogSignal(signal=x,
                                  channel=float(ch),
                                  sampling_rate=f_samp,
                                  t_start=t_start,
                                  id_segment=id_segment,
                                  name=sname)
            sig.save()

        # Save segment (with all analogsignals) to db
        # Actually this may be unnecessary
        # Does saving the signals link to the segment automatically?
        segment.save()

    return (id_segment, id_block)
Exemplo n.º 3
0
def run(db_name, CAR=True, smooth_spikes=True):
    """Filters the data for spike extraction.

    Reads every segment of the block named 'Raw Data' and writes a new block
    named 'CAR Tetrode Data'. For each raw segment, the new block gets a
    segment with the same name, containing:
      * a 'CAR' AnalogSignal: the mean of the signals on good channels
        (all zeros when CAR is falsy);
      * for each good channel, its signal minus the CAR, filtered with
        scipy.signal.filtfilt using the spike filter (and, if requested, a
        second filter), linked to a RecordingPoint;
      * a copy of every signal whose name matches SQL LIKE '% Speaker %'.
    Good channels are those read from the TETRODE_CHANNELS file in the same
    directory as the database. Each good channel gets a RecordingPoint whose
    `group` is its tetrode index.

    db_name: Name of the OpenElectrophy db file
    CAR: If truthy, subtract the common-average of every good channel.
    smooth_spikes: If truthy, add an additional low-pass filtering step to
        the spike filter.

    Returns None. Prints a message and does nothing if a 'CAR Tetrode Data'
    block already exists. Raises ValueError unless exactly one 'Raw Data'
    block exists.
    """
    # Open connection to the database
    OE.open_db(url=('sqlite:///%s' % db_name))

    # Check that I haven't already run.
    # NOTE(review): the new block is saved before any data is written. A run
    # that crashes part-way therefore leaves a partial block, which this
    # check will treat as complete.
    id_blocks, = OE.sql(
        "SELECT block.id FROM block WHERE block.name='CAR Tetrode Data'")
    if len(id_blocks) > 0:
        print("CAR Tetrode Data already exists, no need to recompute")
        return

    # Find the raw-data block; there must be exactly one. This uses an
    # explicit raise rather than assert, which is stripped under -O.
    id_blocks, = OE.sql(
        "SELECT block.id FROM block WHERE block.name='Raw Data'")
    if len(id_blocks) != 1:
        raise ValueError("Expected exactly one 'Raw Data' block in %s, "
                         "found %d" % (db_name, len(id_blocks)))
    id_block = id_blocks[0]

    # Define spike filter
    # TODO: fix so that doesn't assume all sampling rates the same!
    # (the rate is taken from the AnalogSignal with id 1)
    fixed_sampling_rate = OE.AnalogSignal().load(1).sampling_rate
    FILTER_B, FILTER_A = define_spike_filter(fixed_sampling_rate)

    # If requested, define second spike filter
    if smooth_spikes:
        FILTER_B2, FILTER_A2 = define_spike_filter_2(fixed_sampling_rate)

    # Find TETRODE_CHANNELS file in data directory of db
    data_dir = path.split(db_name)[0]
    TETRODE_CHANNELS = get_tetrode_channels(
        path.join(data_dir, 'TETRODE_CHANNELS'))

    # For convenience, flatten TETRODE_CHANNELS to just get worthwhile channels
    GOOD_CHANNELS = [item for sublist in TETRODE_CHANNELS for item in sublist]

    # Create a new block for referenced data, and save to db.
    car_block = OE.Block(
        name='CAR Tetrode Data',
        info='Raw neural data, now referenced and ordered by tetrode',
        fileOrigin=db_name)
    id_car_block = car_block.save()

    # Make RecordingPoint for each channel, linked to tetrode number with `group`
    # Also keep track of link between channel and RP with ch2rpid dict
    ch2rpid = dict()
    for tn, ch_list in enumerate(TETRODE_CHANNELS):
        for ch in ch_list:
            rp = OE.RecordingPoint(name=('RP%d' % ch),
                                   id_block=id_car_block,
                                   trodness=len(ch_list),
                                   channel=float(ch),
                                   group=tn)
            ch2rpid[ch] = rp.save()

    # Find all segments in the block of raw data
    id_segments, = OE.sql(
        'SELECT segment.id FROM segment '
        'WHERE segment.id_block = :id_block', id_block=id_block)

    # For each raw segment: compute the CAR over the good channels, subtract
    # it from each good-channel signal, filter, then copy the audio.
    for id_segment in id_segments:
        # Create a new segment in the new block with the same name
        old_seg = OE.Segment().load(id_segment)
        car_seg = OE.Segment(
            name=old_seg.name,
            id_block=id_car_block,
        )
        id_car_seg = car_seg.save()

        # Find all AnalogSignals in this segment
        id_sigs, = OE.sql(
            'SELECT analogsignal.id FROM analogsignal '
            'WHERE analogsignal.id_segment = :id_segment',
            id_segment=id_segment)

        # Sum the good-channel signals
        running_car = 0
        n_summed = 0
        for id_sig in id_sigs:
            sig = OE.AnalogSignal().load(id_sig)
            if sig.channel not in GOOD_CHANNELS:
                continue
            running_car = running_car + sig.signal
            n_summed = n_summed + 1

        if n_summed == 0:
            # Nothing to reference. Previously this crashed with 0 // 0, or
            # with int having no .shape. The audio is still copied below.
            print("WARNING: segment %s has no good-channel signals"
                  % old_seg.name)
        else:
            # Zero out CAR if CAR is not wanted
            # TODO: eliminate the actual calculation of CAR above in this case
            if not CAR:
                running_car = np.zeros(running_car.shape)

            # Put the CAR into the new block's segment. It previously went
            # into the raw segment (id_segment).
            # not assigning channel, t_start, sample_rate, maybe more?
            car_sig = OE.AnalogSignal(
                name='CAR',
                signal=old_div(running_car, n_summed),
                info='CAR calculated from good channels for this segment',
                id_segment=id_car_seg)
            car_sig.save()

            # Put all the subtractions in id_car_seg
            for id_sig in id_sigs:
                # Load the raw signal (skip bad channels)
                sig = OE.AnalogSignal().load(id_sig)
                if sig.channel not in GOOD_CHANNELS:
                    continue

                # Subtract the CAR, then filter (filtfilt: forward-backward)
                referenced_signal = sig.signal - car_sig.signal
                filtered_signal = scipy.signal.filtfilt(FILTER_B, FILTER_A,
                                                        referenced_signal)
                if smooth_spikes:
                    filtered_signal = scipy.signal.filtfilt(
                        FILTER_B2, FILTER_A2, filtered_signal)

                # Check for infs or nans
                if np.isnan(filtered_signal).any():
                    print("ERROR: Filtered signal contains NaN!")
                if np.isinf(filtered_signal).any():
                    print("ERROR: Filtered signal contains Inf!")

                # Store in db
                new_sig = OE.AnalogSignal(
                    name=sig.name,
                    signal=filtered_signal,
                    info='CAR has been subtracted',
                    id_segment=id_car_seg,
                    id_recordingpoint=ch2rpid[sig.channel],
                    channel=sig.channel,
                    t_start=sig.t_start,
                    sampling_rate=sig.sampling_rate)
                new_sig.save()

        # Finally, copy the audio channel over from the old block
        id_audio_sigs, = OE.sql(
            'SELECT analogsignal.id FROM analogsignal '
            'WHERE analogsignal.id_segment = :id_segment AND '
            "analogsignal.name LIKE '% Speaker %'", id_segment=id_segment)
        for id_audio_sig in id_audio_sigs:
            old_sig = OE.AnalogSignal().load(id_audio_sig)
            OE.AnalogSignal(
                name=old_sig.name,
                signal=old_sig.signal,
                id_segment=id_car_seg,
                channel=old_sig.channel,
                t_start=old_sig.t_start,
                sampling_rate=old_sig.sampling_rate).save()