Exemple #1
0
    def _setup_recording(self, recording, output_folder):
        """Prepare the IronClust input dataset for ``recording``.

        Creates ``<output_folder>/ironclust_dataset`` if it is missing and
        exports the recording there with
        ``se.MdaRecordingExtractor.write_recording``. It then writes
        ``argfile.txt``, with one ``key=value`` line per sorter setting.
        ``samplerate`` comes from the recording; every other value comes from
        ``self.params``.

        Raises:
            ImportError: if IronClust is not installed at
                ``IronclustSorter.ironclust_path``.
        """
        from mountainlab_pytools import mdaio
        params = self.params

        if not check_if_installed(IronclustSorter.ironclust_path):
            raise ImportError(IronclustSorter.installation_mesg)

        dataset_dir = (output_folder / 'ironclust_dataset').absolute()
        if not dataset_dir.is_dir():
            dataset_dir.mkdir()

        # Export the recording into the dataset directory (raw.mda, geom.csv, params.json).
        se.MdaRecordingExtractor.write_recording(recording=recording,
                                                 save_path=dataset_dir)
        samplerate = recording.get_sampling_frequency()

        if self.debug:
            print('Reading timeseries header...')
        header = mdaio.readmda_header(str(dataset_dir / 'raw.mda'))
        n_channels, n_timepoints = header.dims[0], header.dims[1]
        minutes = n_timepoints / samplerate / 60
        if self.debug:
            print(
                'Num. channels = {}, Num. timepoints = {}, duration = {} minutes'
                .format(n_channels, n_timepoints, minutes))
            print('Creating .params file...')

        # The order of the keys below is the order of the lines in argfile.txt.
        settings = [('samplerate', samplerate)]
        settings.extend(
            (name, params[name])
            for name in ('detect_sign', 'adjacency_radius', 'detect_threshold',
                         'merge_thresh', 'freq_min', 'freq_max', 'pc_per_chan',
                         'prm_template_name', 'fGpu'))
        txt = ''.join('{}={}\n'.format(name, value) for name, value in settings)
        _write_text_file(dataset_dir / 'argfile.txt', txt)
Exemple #2
0
def ironclust(*,
    recording, # Recording object
    tmpdir, # Temporary working directory
    detect_sign=-1, # Polarity of the spikes, -1, 0, or 1
    adjacency_radius=-1, # Channel neighborhood adjacency radius corresponding to geom file
    detect_threshold=5, # Threshold for detection
    merge_thresh=.98, # Cluster merging threshold 0..1
    freq_min=300, # Lower frequency limit for band-pass filter
    freq_max=6000, # Upper frequency limit for band-pass filter
    pc_per_chan=3, # Number of pc per channel
    prm_template_name, # Name of the template file
    ironclust_src=None # IronClust source directory; falls back to $IRONCLUST_SRC
):
    """Run IronClust on ``recording`` in a MATLAB subprocess and return the sorting.

    The recording is exported to ``<tmpdir>/ironclust_dataset`` (raw.mda,
    geom.csv, params.json), and the sorter parameters are written to
    ``argfile.txt`` in the same directory. ``p_ironclust`` is then invoked
    through ``matlab -r`` and is expected to write ``<tmpdir>/firings.mda``.

    Returns:
        An ``si.NumpySortingExtractor``. Its spike times are taken from row 1
        of firings.mda and its labels from row 2.

    Raises:
        Exception: if no IronClust source directory is available, if MATLAB
            exits with a non-zero code, or if firings.mda is not produced.
    """
    if ironclust_src is None:
        ironclust_src = os.getenv('IRONCLUST_SRC', None)
    if not ironclust_src:
        raise Exception('You must either set the IRONCLUST_SRC environment variable, or pass the ironclust_src parameter')

    dataset_dir = tmpdir + '/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    si.MdaRecordingExtractor.writeRecording(recording=recording, save_path=dataset_dir)

    samplerate = recording.getSamplingFrequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(num_channels, num_timepoints, duration_minutes))

    print('Creating .params file...')
    txt = ''
    txt += 'samplerate={}\n'.format(samplerate)
    txt += 'detect_sign={}\n'.format(detect_sign)
    txt += 'adjacency_radius={}\n'.format(adjacency_radius)
    txt += 'detect_threshold={}\n'.format(detect_threshold)
    txt += 'merge_thresh={}\n'.format(merge_thresh)
    txt += 'freq_min={}\n'.format(freq_min)
    txt += 'freq_max={}\n'.format(freq_max)
    txt += 'pc_per_chan={}\n'.format(pc_per_chan)
    txt += 'prm_template_name={}\n'.format(prm_template_name)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    print('Running IronClust...')
    cmd_path = "addpath('{0}', '{0}/matlab', '{0}/mdaio');".format(ironclust_src)
    #"p_ironclust('$(tempdir)','$timeseries$','$geom$','$prm$','$firings_true$','$firings_out$','$(argfile)');"
    cmd_call = "p_ironclust('{}', '{}', '{}', '', '', '{}', '{}');"\
        .format(tmpdir, dataset_dir + '/raw.mda', dataset_dir + '/geom.csv', tmpdir + '/firings.mda', dataset_dir + '/argfile.txt')
    # With `matlab -r`, an error inside the statement leaves MATLAB at its
    # interactive prompt, so a trailing `quit` is never reached and the process
    # hangs. The try/catch guarantees that MATLAB exits, and that it exits with
    # a non-zero code on failure, which the retcode check below relies on.
    cmd = 'matlab -nosplash -nodisplay -r "try, {} {} catch err, disp(getReport(err)), exit(1), end, exit(0);"'.format(cmd_path, cmd_call)
    print(cmd)
    retcode = _run_command_and_print_output(cmd)

    if retcode != 0:
        raise Exception('IronClust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = si.NumpySortingExtractor()
    sorting.setTimesLabels(firings[1, :], firings[2, :])
    return sorting
Exemple #3
0
def convert_array(*,
                  input,
                  output,
                  format='',
                  format_out='',
                  dimensions='',
                  dtype='',
                  dtype_out='',
                  channels=''):
    """
    Convert a multi-dimensional array between various formats ('.mda', '.npy', '.dat') based on the file extensions of the input/output files

    Parameters
    ----------
    input : INPUT
        Path of input array file, or a list of paths whose contents are
        concatenated along the second (time) axis.
    output : OUTPUT
        Path of the output array file.
    format : string
        The format for the input array (mda, npy, dat), or determined from the file extension if empty
    format_out : string
        The format for the output array (mda, npy, dat), or determined from the file extension if empty
    dimensions : string
        Comma-separated list of dimensions (shape). If empty, it is auto-determined, if possible, by the input array. If second dim is -1 then it will be extrapolated from the total size of all input files / first dim (.dat input only).
    dtype : string
        The data format for the input array. Choices: int8, int16, int32, uint16, uint32, float32, float64 (possibly float16 in the future).
    dtype_out : string
        The data format for the output array. If empty, the dtype for the input array is used.
    channels : string
        Comma-separated list of channels (rows) to keep in output. Zero-based indexing. Applied to .dat input and to multi-file .mda/.npy input when writing .mda/.npy.

    Returns
    -------
    bool
        True once the output file has been written.

    Raises
    ------
    Exception
        If the input list is empty, if the dtype or dimensions cannot be
        determined or disagree with the input header, or if the requested
        conversion is not supported.
    """
    if isinstance(input, (list, )):
        multifile = True
        inputs = input
        if not inputs:
            raise Exception('No input files given')
        input = inputs[0]
    else:
        multifile = False
        inputs = [input]

    format_in = format
    if not format_in:
        format_in = determine_file_format(file_extension(input), dimensions)
    if not format_out:
        format_out = determine_file_format(file_extension(output), dimensions)
    print('Input/output formats: {}/{}'.format(format_in, format_out))

    dims = None

    if (format_in == 'mda') and (dtype == ''):
        header = mdaio.readmda_header(input)
        dtype = header.dt
        dims = list(header.dims)

    if (format_in == 'npy') and (dtype == ''):
        A = np.load(input, mmap_mode='r')
        dtype = npy_dtype_to_string(A.dtype)
        dims = list(A.shape)
        del A  # only the header information was needed; release the memory map

    if dimensions:
        dims2 = [int(entry) for entry in dimensions.split(',')]
        if dims:
            if len(dims) != len(dims2):
                raise Exception(
                    'Inconsistent number of dimensions for input array')
            if not np.all(np.array(dims) == np.array(dims2)):
                raise Exception('Inconsistent dimensions for input array')
        dims = dims2

    if not dtype:
        raise Exception('Unable to determine datatype for input array')
    if not dtype_out:
        dtype_out = dtype

    # Validate before indexing into dims (dims may still be None here).
    if not dims:
        raise Exception('Unable to determine dimensions for input array')

    # A second dimension of -1 means "infer it from the total input size".
    if len(dims) > 1 and dims[1] == -1 and dims[0] > 0:
        if format_in != 'dat':
            raise Exception('Could not infer dimensions')
        # Use numpy's itemsize rather than parsing the dtype name, which fails
        # for single-digit widths such as int8/uint8.
        bytes_per_entry = np.dtype(dtype).itemsize
        filebytes = sum(os.stat(path).st_size for path in inputs)
        entries = filebytes // bytes_per_entry
        dims[1] = entries // dims[0]
        if DEBUG:
            print(bytes_per_entry * 8)
            print(filebytes)
            print(entries)
            print(dims)

    if not channels:
        channels = range(0, dims[0])
    else:
        channels = np.array([int(entry) for entry in channels.split(',')])

    if DEBUG:
        print(channels)

    print('Using dtype={}, dtype_out={}, dimensions={}'.format(
        dtype, dtype_out, ','.join(str(item) for item in dims)))

    if (format_in == format_out) and (dtype == dtype_out):
        if multifile and (format_in == 'dat'):
            print('Concatenating Files')
            with open(output, "wb") as outfile:
                for input_file in inputs:
                    with open(input_file, "rb") as inpt:
                        shutil.copyfileobj(inpt, outfile)
            return True
        if not multifile:
            print('Simply copying file...')
            shutil.copyfile(input, output)
            print('Done.')
            return True
        # Multi-file .mda/.npy input of the same format still needs a real
        # concatenation. Fall through instead of returning without output.

    load_warning = ('Warning: loading entire array into memory. '
                    'This should be avoided in the future.')

    if format_out == 'dat' and not multifile:
        if format_in == 'mda':
            H = mdaio.readmda_header(input)
            copy_raw_file_data(input,
                               output,
                               start_byte=H.header_size,
                               num_entries=np.prod(dims),
                               dtype=dtype,
                               dtype_out=dtype_out)
            return True
        if format_in == 'npy':
            print(load_warning)
            A = np.load(input, mmap_mode='r').astype(dtype=dtype_out,
                                                     order='F',
                                                     copy=False)
            # The .dat output is written in column-major (Fortran) order.
            A.ravel(order='F').tofile(output)
            return True
        if format_in == 'dat':
            raise Exception('This case not yet implemented.')
        raise Exception('Unexpected case.')

    if format_out not in ('mda', 'npy'):
        raise Exception('Unexpected output format: {}'.format(format_out))

    if format_in == 'dat':
        print(load_warning)
        if multifile:
            # Each file stores dims[0] rows in column-major order; its column
            # count follows from its own size.
            parts = []
            for path in inputs:
                part = np.fromfile(path, dtype=dtype)
                part = part.reshape((dims[0], -1), order='F')
                parts.append(part[channels, :])
            A = np.concatenate(parts, axis=1)
        else:
            A = np.fromfile(input, dtype=dtype, count=np.prod(dims))
            A = A.reshape(tuple(dims), order='F')
            A = A[channels, :]
    elif format_in == 'mda':
        print(load_warning)
        if multifile:
            # Concatenate along the time axis, consistent with the .dat paths.
            A = np.concatenate(
                [mdaio.readmda(path)[channels, :] for path in inputs], axis=1)
        else:
            A = mdaio.readmda(input)
    elif format_in == 'npy':
        print(load_warning)
        arrays = [
            np.load(path, mmap_mode='r').astype(dtype=dtype,
                                                order='F',
                                                copy=False)
            for path in inputs
        ]
        if multifile:
            # Concatenate along the time axis, consistent with the .dat paths.
            A = np.concatenate([arr[channels, :] for arr in arrays], axis=1)
        else:
            A = arrays[0]
    else:
        raise Exception('Unexpected case.')

    if DEBUG:
        print(A.shape)

    if format_out == 'mda':
        mdaio.writemda(A, output, dtype=dtype_out)
    else:
        mdaio.writenpy(A, output, dtype=dtype_out)
    return True
Exemple #4
0
    def run(self):
        """Sort ``self.timeseries`` with spyking-circus and write ``self.firings_out``.

        Working files are created in the directory named by the
        ML_PROCESSOR_TEMPDIR environment variable: a copy of the timeseries
        (raw.mda), a probe file (geometry.prb) and a parameter file
        (raw.params). The probe and parameter files are filled in from
        template.prb and template.params, which live next to this module.

        Returns:
            True on success.

        Raises:
            Exception: if ML_PROCESSOR_TEMPDIR is unset, if spyking-circus
                exits with a non-zero code, or if its result file is missing.
        """
        tmpdir = os.environ.get('ML_PROCESSOR_TEMPDIR')
        if not tmpdir:
            raise Exception('Environment variable not set: ML_PROCESSOR_TEMPDIR')

        source_dir = os.path.dirname(os.path.realpath(__file__))

        ## todo: link rather than copy
        print('Copying timeseries file: {} -> {}'.format(self.timeseries, tmpdir + '/raw.mda'))
        copyfile(self.timeseries, tmpdir + '/raw.mda')

        print('Reading timeseries header...')
        HH = mdaio.readmda_header(tmpdir + '/raw.mda')
        num_channels = HH.dims[0]
        num_timepoints = HH.dims[1]
        duration_minutes = num_timepoints / self.samplerate / 60
        print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(num_channels, num_timepoints, duration_minutes))

        print('Creating .prb file...')
        prb_text = self._read_text_file(source_dir + '/template.prb')
        prb_text = prb_text.replace('$num_channels$', '{}'.format(num_channels))
        prb_text = prb_text.replace('$radius$', '{}'.format(self.adjacency_radius))
        geom = np.genfromtxt(self.geom, delimiter=',')
        if geom.ndim < 2 and num_channels == 1:
            # genfromtxt returns a 1-D array for a single-row file. Restore the
            # (channel, coordinate) layout indexed below.
            geom = geom.reshape(1, -1)
        geom_lines = ''.join(
            '  {}: [{},{}],\n'.format(m, geom[m, 0], geom[m, 1])  # todo: handle 3d geom
            for m in range(geom.shape[0]))
        geom_str = '{\n' + geom_lines + '}'
        prb_text = prb_text.replace('$geometry$', geom_str)
        self._write_text_file(tmpdir + '/geometry.prb', prb_text)

        print('Creating .params file...')
        txt = self._read_text_file(source_dir + '/template.params')
        txt = txt.replace('$header_size$', '{}'.format(HH.header_size))
        txt = txt.replace('$prb_file$', tmpdir + '/geometry.prb')
        txt = txt.replace('$dtype$', HH.dt)
        txt = txt.replace('$num_channels$', '{}'.format(num_channels))
        txt = txt.replace('$samplerate$', '{}'.format(self.samplerate))
        txt = txt.replace('$template_width_ms$', '{}'.format(self.template_width_ms))
        txt = txt.replace('$spike_thresh$', '{}'.format(self.spike_thresh))
        if self.detect_sign > 0:
            peaks_str = 'positive'
        elif self.detect_sign < 0:
            peaks_str = 'negative'
        else:
            peaks_str = 'both'
        txt = txt.replace('$peaks$', peaks_str)
        txt = txt.replace('$whitening_max_elts$', '{}'.format(self.whitening_max_elts))
        txt = txt.replace('$clustering_max_elts$', '{}'.format(self.clustering_max_elts))
        self._write_text_file(tmpdir + '/raw.params', txt)

        print('Running spyking circus...')
        #num_threads=np.maximum(1,int(os.cpu_count()/2))
        num_threads = 1  # for some reason, using more than 1 thread causes an error
        cmd = 'spyking-circus {} -c {} '.format(tmpdir + '/raw.mda', num_threads)
        print(cmd)
        retcode = self._run_command_and_print_output(cmd)

        if retcode != 0:
            raise Exception('Spyking circus returned a non-zero exit code')

        result_fname = tmpdir + '/raw/raw.result.hdf5'
        if not os.path.exists(result_fname):
            raise Exception('Result file does not exist: ' + result_fname)

        firings = sc_results_to_firings(result_fname)
        print(firings.shape)
        mdaio.writemda64(firings, self.firings_out)

        return True
def separateSpikesInEpochs(data_dir=None,
                           firings_file='firings.curated.mda',
                           timestamp_files=None,
                           write_separated_spikes=True):
    """
    Takes curated spikes from MountainSort and combines this information with spike timestamps to create separate curated spikes for each epoch

    :data_dir: Directory holding one sub-directory per tetrode. If None, the
        user is asked to select it.
    :firings_file: Name of the curated firings file inside each tetrode directory.
    :timestamp_files: Spike timestamps file list, one file per epoch, in the
        same order as the epochs in the merged firings. If None, the user is
        asked to select them in a file dialog.
    :write_separated_spikes: If the separated spikes should be written back to
        the data directory (as firings-<epoch number, 1-based>.curated.mda).
    :returns: List with one entry per tetrode whose firings file was read.
        Each entry is a list with one firings array per epoch, with row 1
        (sample numbers) replaced by that epoch's timestamps, or None when
        the tetrode has no spikes in that epoch.
    """

    if data_dir is None:
        # Get the firings file
        data_dir = QtHelperUtils.get_directory(
            message="Select Curated firings location")

    separated_tetrodes = []
    curated_firings = []
    merged_curated_firings = []
    tetrode_list = os.listdir(data_dir)
    for tt_dir in tetrode_list:
        try:
            if firings_file in os.listdir(data_dir + '/' + tt_dir):
                firings_file_location = '/'.join(
                    [data_dir, tt_dir, firings_file])
                tt_merged_firings = mdaio.readmda(firings_file_location)
                # Register the tetrode only after a successful read, so that the
                # three lists stay index-aligned.
                merged_curated_firings.append(tt_merged_firings)
                curated_firings.append([])
                separated_tetrodes.append(tt_dir)
                print(MODULE_IDENTIFIER +
                      'Read merged firings file for tetrode %s!' % tt_dir)
            else:
                print(MODULE_IDENTIFIER +
                      'Merged firings %s not  found for tetrode %s!' %
                      (firings_file, tt_dir))
        except (FileNotFoundError, IOError) as err:
            print(MODULE_IDENTIFIER +
                  'Unable to read merged firings file for tetrode %s!' %
                  tt_dir)
            print(err)

    if timestamp_files is None:
        # Only create a (hidden) Tk root when a file dialog is actually needed.
        gui_root = Tk()
        gui_root.withdraw()
        try:
            # Read all the timestamp files
            timestamp_files = filedialog.askopenfilenames(initialdir=DEFAULT_SEARCH_PATH, \
                    title="Select all timestamp files", \
                    filetypes=(("Timestamps", ".mda"), ("All Files", "*.*")))
        finally:
            gui_root.destroy()

    timestamp_headers = [mdaio.readmda_header(ts_file) for ts_file in timestamp_files]

    # Now that we have both the timestamp headers and the timestamp files, we
    # can separate spikes out.  It is important here for the timestamp files to
    # be in the same order as the curated firings as that is the only way for
    # us to tell that the firings are being split up correctly.

    print(MODULE_IDENTIFIER + 'Looking at spike timestamps in order')
    print(timestamp_files)

    # First splice up curated spikes into individual epochs
    for tt_idx, tt_firings in enumerate(merged_curated_firings):
        for ep_idx, ts_header in enumerate(timestamp_headers):
            if tt_firings is None:
                curated_firings[tt_idx].append(None)
                continue

            n_data_points = ts_header.dims[0]
            print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx) + ': ' +
                  str(n_data_points) + ' samples.')
            # Index of the last spike with sample number < n_data_points.
            # searchsorted requires row 1 to be sorted in ascending order.
            last_spike_from_epoch = np.searchsorted(
                tt_firings[1], n_data_points, side='left') - 1

            # If there are no spikes in this epoch, there might still be some in future epochs!
            if last_spike_from_epoch < 0:
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                curated_firings[tt_idx].append(None)
                continue

            last_spike_sample_number = tt_firings[1][last_spike_from_epoch]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': First spike ' + str(tt_firings[1][0])\
                    + ', Last spike ' + str(last_spike_sample_number))
            # The slice end is exclusive, so +1 keeps the epoch's last spike.
            epoch_spikes = tt_firings[:, :last_spike_from_epoch + 1]
            curated_firings[tt_idx].append(epoch_spikes)

            if last_spike_from_epoch < (len(tt_firings[1]) - 1):
                # Slice the merged curated firing to only have the remaining spikes
                tt_firings = tt_firings[:, last_spike_from_epoch + 1:]
                print(MODULE_IDENTIFIER +
                      'Trimming curated spikes. Aggregate sample start ' +
                      str(tt_firings[1][0]))
                # Make the remaining sample numbers relative to the next epoch's start.
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                print(MODULE_IDENTIFIER + 'Sample number trimmed to ' +
                      str(tt_firings[1][0]))
            else:
                print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] +
                      ', Reached end of curated firings at Epoch ' +
                      str(ep_idx))
                tt_firings = None

    print(MODULE_IDENTIFIER +
          'Spikes separated in epochs. Substituting timestamps!')
    # For each epoch replace the sample numbers with the corresponding
    # timestamps. We are going through multiple revisions for this so that we
    # only have to load one timestamp file at a time
    for ep_idx, ts_file in enumerate(timestamp_files):
        epoch_timestamps = mdaio.readmda(ts_file)
        print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx))
        for tt_idx, tt_curated_firings in enumerate(curated_firings):
            if tt_curated_firings[ep_idx] is None:
                continue
            # Need to use the original array because changes to tt_curated_firings do not get copied back
            curated_firings[tt_idx][ep_idx][1] = epoch_timestamps[np.array(
                tt_curated_firings[ep_idx][1], dtype=int)]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': Samples (' +
                  str(tt_curated_firings[ep_idx][1][0]) + ', ' +
                  str(tt_curated_firings[ep_idx][1][-1]) + ')')

    if write_separated_spikes:
        for tt_idx, tet in enumerate(separated_tetrodes):
            for ep_idx in range(len(timestamp_files)):
                if curated_firings[tt_idx][ep_idx] is None:
                    continue
                ep_firings_file_name = data_dir + '/' + tet + '/firings-' + \
                        str(ep_idx+1) + '.curated.mda'
                # Handle each file separately so one failed write does not
                # prevent the remaining epochs/tetrodes from being written.
                try:
                    mdaio.writemda64(curated_firings[tt_idx][ep_idx],
                                     ep_firings_file_name)
                except OSError as exception:
                    if exception.errno != errno.EEXIST:
                        print(MODULE_IDENTIFIER +
                              'Unable to write timestamped firings!')
                        print(exception)

    return curated_firings