Example #1
import numpy as np
from mountainlab_pytools import mdaio  # assumed import; mdaio is MountainLab's .mda I/O helper


def write_firings_file(channels, times, labels, fname):
    # Pack events into the standard 3xL firings array:
    # row 0 = channels, row 1 = timestamps, row 2 = integer unit labels
    L = len(channels)
    X = np.zeros((3, L), dtype='float64')
    X[0, :] = channels
    X[1, :] = times
    X[2, :] = labels
    mdaio.writemda64(X, fname)
def test_compute_templates():
    # M channels, N timepoints, K units, T clip size, L events
    M, N, K, T, L = 5, 1000, 6, 50, 100
    X = np.random.rand(M, N)
    mdaio.writemda32(X, 'tmp.mda')
    F = np.zeros((3, L))
    F[1, :] = 1 + np.random.randint(N, size=(1, L))  # random event times in [1, N]
    F[2, :] = 1 + np.random.randint(K, size=(1, L))  # random unit labels in [1, K]
    mdaio.writemda64(F, 'tmp2.mda')
    ret = compute_templates(timeseries='tmp.mda',
                            firings='tmp2.mda',
                            templates_out='tmp3.mda',
                            clip_size=T)
    assert ret
    templates0 = mdaio.readmda('tmp3.mda')
    assert templates0.shape == (M, T, K)
    return True
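
A minimal usage sketch for write_firings_file (hypothetical data, not part of the original snippet):

channels = np.array([1, 2, 1])     # detection channel per event
times = np.array([100, 250, 900])  # sample indices
labels = np.array([1, 1, 2])       # unit labels
write_firings_file(channels, times, labels, 'firings.mda')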
Example #3
import json

import numpy as np
from mountainlab_pytools import mdaio  # assumed import, as above


def create_label_map(*,
                     metrics,
                     label_map_out,
                     firing_rate_thresh=.05,
                     isolation_thresh=0.95,
                     noise_overlap_thresh=.03,
                     peak_snr_thresh=1.5):
    """
    Generate a label map based on the metrics file; labels mapped to zero are to be removed.

    Parameters
    ----------
    metrics : INPUT
        Path of metrics json file to be used for generating the label map
    label_map_out : OUTPUT
        Path to the output mda file, where the second column is the present label and the first column is the new label
        ...
    firing_rate_thresh : float64
        (Optional) firing rate must be above this
    isolation_thresh : float64
        (Optional) isolation must be above this
    noise_overlap_thresh : float64
        (Optional) noise overlap must be below this
    peak_snr_thresh : float64
        (Optional) peak snr must be above this
    """
    # TODO: allow logic or thresholds to be passed in flexibly

    label_map = []

    # Load the metrics json
    with open(metrics) as metrics_json:
        metrics_data = json.load(metrics_json)

    # Iterate through all clusters
    for cluster in metrics_data['clusters']:
        m = cluster['metrics']
        if m['firing_rate'] < firing_rate_thresh or \
                m['isolation'] < isolation_thresh or \
                m['noise_overlap'] > noise_overlap_thresh or \
                m['peak_snr'] < peak_snr_thresh:
            # Map to zero (mask out)
            label_map.append([0, cluster['label']])
        elif m.get('bursting_parent'):  # merge into the burst parent if one exists
            # .get() avoids a KeyError when a cluster has no 'bursting_parent' entry
            label_map.append([m['bursting_parent'], cluster['label']])
        else:
            # otherwise, map the label to itself
            label_map.append([cluster['label'], cluster['label']])

    # Write out as an Lx2 array
    return mdaio.writemda64(np.array(label_map), label_map_out)
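
A minimal usage sketch for create_label_map (hypothetical metrics content, not from the original source): cluster 1 passes every threshold and maps to itself; cluster 2 fails the firing-rate threshold and maps to zero.

example_metrics = {'clusters': [
    {'label': 1, 'metrics': {'firing_rate': 2.0, 'isolation': 0.99,
                             'noise_overlap': 0.01, 'peak_snr': 4.0}},
    {'label': 2, 'metrics': {'firing_rate': 0.01, 'isolation': 0.99,
                             'noise_overlap': 0.01, 'peak_snr': 4.0}},
]}
with open('metrics.json', 'w') as f:
    json.dump(example_metrics, f)
create_label_map(metrics='metrics.json', label_map_out='label_map.mda')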
Example #4
import numpy as np
from mountainlab_pytools import mdaio  # assumed import, as above


def synthesize_random_firings(*,
                              firings_out,
                              K=20,
                              samplerate=30000,
                              duration=60):
    """
    Synthesize random firing events (times and labels) for use in creating a synthetic timeseries dataset

    Parameters
    ----------
    firings_out : OUTPUT
        Path to the output firings mda file. 3xL, where L is the number of events; the second row holds timestamps and the third row holds integer unit labels
    
    K : int
        (Optional) number of simulated units
    samplerate : double
        (Optional) sampling frequency in Hz
    duration : double
        (Optional) duration of the simulated acquisition in seconds
    """
    firing_rates = 3 * np.ones(K)  # 3 Hz for every unit
    refr = 4  # refractory period in ms

    N = np.int64(duration * samplerate)  # total number of timepoints

    # events/sec * sec/timepoint * N = expected event count per unit
    populations = np.ceil(firing_rates / samplerate * N).astype('int')
    times = np.zeros(0)
    labels = np.zeros(0)
    for k in range(1, K + 1):
        refr_timepoints = refr / 1000 * samplerate  # refractory period in samples

        # uniform random event times in [1, N)
        times0 = np.random.rand(populations[k - 1]) * (N - 1) + 1

        ## make an interesting autocorrelogram shape:
        ## duplicate each event at a random offset drawn by rand_distr2,
        ## keep a random half, then clip to the valid range
        times0 = np.hstack(
            (times0, times0 +
             rand_distr2(refr_timepoints, refr_timepoints * 20, times0.size)))
        times0 = times0[np.random.choice(times0.size, int(times0.size / 2))]
        times0 = times0[np.where((0 <= times0) & (times0 < N))]

        times0 = enforce_refractory_period(times0, refr_timepoints)
        times = np.hstack((times, times0))
        labels = np.hstack((labels, k * np.ones(times0.shape)))

    sort_inds = np.argsort(times)
    times = times[sort_inds]
    labels = labels[sort_inds]

    # assemble the 3xL firings array; row 0 (channels) is left as zeros
    firings = np.zeros((3, times.size), dtype=np.float64)
    firings[1, :] = times
    firings[2, :] = labels
    return mdaio.writemda64(firings, firings_out)
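
The snippet above calls two helpers, rand_distr2 and enforce_refractory_period, that are defined elsewhere in the source module. A minimal sketch of each, assuming rand_distr2 draws offsets on [a, b] skewed toward a and enforce_refractory_period drops events that follow their predecessor too closely:

def rand_distr2(a, b, num):
    # assumption: squaring a uniform draw skews offsets toward a
    X = np.random.rand(num)
    return a + (b - a) * X ** 2


def enforce_refractory_period(times_in, refr):
    # greedily keep events at least refr samples after the last kept event
    kept = []
    last = -np.inf
    for t in np.sort(times_in):
        if t - last >= refr:
            kept.append(t)
            last = t
    return np.array(kept)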
Example #5
import numpy as np
from mountainlab_pytools import mdaio  # assumed import, as above


def apply_label_map(*, firings, label_map, firings_out):
    """
    Apply a label map to a given firings file, including masking and merging

    Parameters
    ----------
    firings : INPUT
        Path of input firings mda file
    label_map : INPUT
        Path of input label map mda file [base 1, mapping to zero removes from firings]
    firings_out : OUTPUT
        ...
    """
    firings = mdaio.readmda(firings)
    label_map = mdaio.readmda(label_map)
    label_map = np.reshape(label_map, (-1, 2))
    label_map = label_map[np.argsort(label_map[:, 0])]  # ensure input is sorted

    # Propagate merge pairs to the lowest label number
    for idx, label in enumerate(label_map[:, 1]):
        # jfm changed on 12/8/17 because isin() is not available in older versions of numpy
        # label_map[np.isin(label_map[:, 0], label), 0] = label_map[idx, 0]  # input should be sorted
        label_map[np.where(label_map[:, 0] == label)[0], 0] = label_map[idx, 0]  # input should be sorted

    # Apply label map
    for label_pair in range(label_map.shape[0]):
        # jfm changed on 12/8/17 because isin() is not available in older versions of numpy
        # firings[2, np.isin(firings[2, :], label_map[label_pair, 1])] = label_map[label_pair, 0]
        firings[2, np.where(firings[2, :] == label_map[label_pair, 1])[0]] = label_map[label_pair, 0]

    # Mask out all labels mapped to zero
    firings = firings[:, firings[2, :] != 0]

    # Write remapped firings
    return mdaio.writemda64(firings, firings_out)
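
A minimal end-to-end sketch chaining the two processors above (file names hypothetical): create_label_map decides which labels to drop or merge, and apply_label_map rewrites the firings accordingly.

create_label_map(metrics='metrics.json', label_map_out='label_map.mda')
apply_label_map(firings='firings.mda', label_map='label_map.mda',
                firings_out='firings.curated.mda')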
Example #6
    def run(self):
        # assumed imports for this method: os, numpy as np, shutil.copyfile,
        # mdaio, and the sc_results_to_firings helper from the same package
        tmpdir=os.environ.get('ML_PROCESSOR_TEMPDIR')
        if not tmpdir:
            raise Exception('Environment variable not set: ML_PROCESSOR_TEMPDIR')
        
        source_dir=os.path.dirname(os.path.realpath(__file__))
        
        ## todo: link rather than copy
        print('Copying timeseries file: {} -> {}'.format(self.timeseries,tmpdir+'/raw.mda'))
        copyfile(self.timeseries,tmpdir+'/raw.mda')
        
        print('Reading timeseries header...')
        HH=mdaio.readmda_header(tmpdir+'/raw.mda')
        num_channels=HH.dims[0]
        num_timepoints=HH.dims[1]
        duration_minutes=num_timepoints/self.samplerate/60
        print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(num_channels,num_timepoints,duration_minutes))
        
        print('Creating .prb file...')
        prb_text=self._read_text_file(source_dir+'/template.prb')
        prb_text=prb_text.replace('$num_channels$','{}'.format(num_channels))
        prb_text=prb_text.replace('$radius$','{}'.format(self.adjacency_radius))
        geom=np.genfromtxt(self.geom, delimiter=',')
        geom_str='{\n'
        for m in range(geom.shape[0]):
            geom_str+='  {}: [{},{}],\n'.format(m,geom[m,0],geom[m,1]) # todo: handle 3d geom
        geom_str+='}'
        prb_text=prb_text.replace('$geometry$','{}'.format(geom_str))
        self._write_text_file(tmpdir+'/geometry.prb',prb_text)
        
        print('Creating .params file...')
        txt=self._read_text_file(source_dir+'/template.params')
        txt=txt.replace('$header_size$','{}'.format(HH.header_size))
        txt=txt.replace('$prb_file$',tmpdir+'/geometry.prb')
        txt=txt.replace('$dtype$',HH.dt)
        txt=txt.replace('$num_channels$','{}'.format(num_channels))
        txt=txt.replace('$samplerate$','{}'.format(self.samplerate))
        txt=txt.replace('$template_width_ms$','{}'.format(self.template_width_ms))
        txt=txt.replace('$spike_thresh$','{}'.format(self.spike_thresh))
        if self.detect_sign>0:
            peaks_str='positive'
        elif self.detect_sign<0:
            peaks_str='negative'
        else:
            peaks_str='both'
        txt=txt.replace('$peaks$',peaks_str)
        txt=txt.replace('$whitening_max_elts$','{}'.format(self.whitening_max_elts))
        txt=txt.replace('$clustering_max_elts$','{}'.format(self.clustering_max_elts))
        self._write_text_file(tmpdir+'/raw.params',txt)
        
        print('Running spyking circus...')
        #num_threads=np.maximum(1,int(os.cpu_count()/2))
        num_threads=1 # for some reason, using more than 1 thread causes an error
        cmd='spyking-circus {} -c {} '.format(tmpdir+'/raw.mda',num_threads)
        print(cmd)
        retcode=self._run_command_and_print_output(cmd)

        if retcode != 0:
            raise Exception('Spyking circus returned a non-zero exit code')

        result_fname=tmpdir+'/raw/raw.result.hdf5'
        if not os.path.exists(result_fname):
            raise Exception('Result file does not exist: '+result_fname)
        
        firings=sc_results_to_firings(result_fname)
        print(firings.shape)
        mdaio.writemda64(firings,self.firings_out)
        
        return True
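# For reference, the $geometry$ substitution above yields a dict literal of this
# shape in geometry.prb (channel index -> [x, y]; values hypothetical):
# {
#   0: [0.0,0.0],
#   1: [0.0,25.0],
#   2: [0.0,50.0],
# }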
# assumed imports for this function: os, errno, numpy as np,
# from tkinter import Tk, filedialog; plus mdaio, QtHelperUtils, and the
# module-level constants MODULE_IDENTIFIER and DEFAULT_SEARCH_PATH
def separateSpikesInEpochs(data_dir=None,
                           firings_file='firings.curated.mda',
                           timestamp_files=None,
                           write_separated_spikes=True):
    """
    Takes curated spikes from MountainSort and combines this information with spike timestamps to create separate curated spikes for each epoch

    :firings_file: Curated firings file
    :timestamp_files: Spike timestamps file list
    :write_separated_spikes: If the separated spikes should be written back to the data directory.
    :returns: List of spikes for each epoch
    """

    if data_dir is None:
        # Get the firings file
        data_dir = QtHelperUtils.get_directory(
            message="Select Curated firings location")

    separated_tetrodes = []
    curated_firings = []
    merged_curated_firings = []
    tetrode_list = os.listdir(data_dir)
    for tt_dir in tetrode_list:
        try:
            if firings_file in os.listdir(data_dir + '/' + tt_dir):
                curated_firings.append([])
                separated_tetrodes.append(tt_dir)
                firings_file_location = '/'.join(
                    [data_dir, tt_dir, firings_file])
                merged_curated_firings.append(
                    mdaio.readmda(firings_file_location))
                print(MODULE_IDENTIFIER +
                      'Read merged firings file for tetrode %s!' % tt_dir)
            else:
                print(MODULE_IDENTIFIER +
                      'Merged firings %s not found for tetrode %s!' %
                      (firings_file, tt_dir))
        except (FileNotFoundError, IOError) as err:
            print(MODULE_IDENTIFIER +
                  'Unable to read merged firings file for tetrode %s!' %
                  tt_dir)
            print(err)

    gui_root = Tk()
    gui_root.withdraw()
    timestamp_headers = []
    if timestamp_files is None:
        # Read all the timestamp files
        timestamp_files = filedialog.askopenfilenames(initialdir=DEFAULT_SEARCH_PATH, \
                title="Select all timestamp files", \
                filetypes=(("Timestamps", ".mda"), ("All Files", "*.*")))
    gui_root.destroy()

    for ts_file in timestamp_files:
        timestamp_headers.append(mdaio.readmda_header(ts_file))

    # Now that we have both the timestamp headers and the timestamp files, we
    # can separate spikes out.  It is important here for the timestamp files to
    # be in the same order as the curated firings as that is the only way for
    # us to tell that the firings are being split up correctly.

    print(MODULE_IDENTIFIER + 'Looking at spike timestamps in order')
    print(timestamp_files)

    # First split the curated spikes into individual epochs
    for tt_idx, tt_firings in enumerate(merged_curated_firings):
        for ep_idx, ts_header in enumerate(timestamp_headers):
            if tt_firings is None:
                curated_firings[tt_idx].append(None)
                continue

            n_data_points = ts_header.dims[0]
            print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx) + ': ' +
                  str(n_data_points) + ' samples.')
            last_spike_from_epoch = np.searchsorted(
                tt_firings[1], n_data_points, side='left') - 1

            # If there are no spikes in this epoch, there might still be some in future epochs!
            if last_spike_from_epoch < 0:
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                curated_firings[tt_idx].append(None)
                continue

            last_spike_sample_number = tt_firings[1][last_spike_from_epoch]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': First spike ' + str(tt_firings[1][0])\
                    + ', Last spike ' + str(last_spike_sample_number))
            # include the spike at index last_spike_from_epoch itself
            epoch_spikes = tt_firings[:, :last_spike_from_epoch + 1]
            curated_firings[tt_idx].append(epoch_spikes)

            if last_spike_from_epoch < (len(tt_firings[1]) - 1):
                # Slice the merged curated firing to only have the remaining spikes
                tt_firings = tt_firings[:, last_spike_from_epoch + 1:]
                print(MODULE_IDENTIFIER +
                      'Trimming curated spikes. Aggregate sample start ' +
                      str(tt_firings[1][0]))
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                print(MODULE_IDENTIFIER + 'Sample number trimmed to ' +
                      str(tt_firings[1][0]))
            else:
                print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] +
                      ', Reached end of curated firings at Epoch ' +
                      str(ep_idx))
                tt_firings = None

    print(MODULE_IDENTIFIER +
          'Spikes separated in epochs. Substituting timestamps!')
    # For each epoch, replace the sample numbers with the corresponding
    # timestamps. This has gone through multiple revisions so that we only
    # have to load one timestamp file at a time.
    for ep_idx, ts_file in enumerate(timestamp_files):
        epoch_timestamps = mdaio.readmda(ts_file)
        print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx))
        for tt_idx, tt_curated_firings in enumerate(curated_firings):
            if tt_curated_firings[ep_idx] is None:
                continue
            # Need to use the original array because changes to tt_curated_firings do not get copied back
            curated_firings[tt_idx][ep_idx][1] = epoch_timestamps[np.array(
                tt_curated_firings[ep_idx][1], dtype=int)]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': Samples (' + \
                    str(tt_curated_firings[ep_idx][1][0]) + ', ' + str(tt_curated_firings[ep_idx][1][-1]) + ')')

    if write_separated_spikes:
        try:
            for tt_idx, tet in enumerate(separated_tetrodes):
                for ep_idx in range(len(timestamp_files)):
                    if curated_firings[tt_idx][ep_idx] is not None:
                        ep_firings_file_name = data_dir + '/' + tet + '/firings-' + \
                                str(ep_idx+1) + '.curated.mda'
                        mdaio.writemda64(curated_firings[tt_idx][ep_idx],
                                         ep_firings_file_name)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                print(MODULE_IDENTIFIER +
                      'Unable to write timestamped firings!')
                print(exception)

    return curated_firings
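
A minimal call sketch (paths hypothetical); note that timestamp_files must be listed in epoch order, and the return value is indexed by tetrode, then epoch:

epoch_firings = separateSpikesInEpochs(
    data_dir='/data/animal1/day1',
    timestamp_files=['/data/animal1/day1/timestamps-1.mda',
                     '/data/animal1/day1/timestamps-2.mda'])
# epoch_firings[tt_idx][ep_idx] is a 3xL firings array (or None if the
# tetrode had no spikes in that epoch)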