def _setup_recording(self, recording, output_folder):
    """Prepare the IronClust dataset directory for a recording.

    Writes the recording through ``se.MdaRecordingExtractor.write_recording``
    into ``<output_folder>/ironclust_dataset`` and then writes an
    ``argfile.txt`` next to it holding the sampling rate and the sorter
    parameters taken from ``self.params``.

    Parameters
    ----------
    recording
        Recording object; must provide ``get_sampling_frequency()``.
    output_folder
        Path-like object supporting the ``/`` operator (e.g. ``pathlib.Path``).

    Raises
    ------
    ImportError
        If ``check_if_installed`` reports that IronClust is not available.
    """
    from mountainlab_pytools import mdaio

    sorter_params = self.params

    if not check_if_installed(IronclustSorter.ironclust_path):
        raise ImportError(IronclustSorter.installation_mesg)

    dataset_dir = (output_folder / 'ironclust_dataset').absolute()
    if not dataset_dir.is_dir():
        dataset_dir.mkdir()

    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    se.MdaRecordingExtractor.write_recording(recording=recording,
                                             save_path=dataset_dir)
    samplerate = recording.get_sampling_frequency()

    if self.debug:
        print('Reading timeseries header...')
    # The header is read regardless of the debug flag, as before.
    header = mdaio.readmda_header(str(dataset_dir / 'raw.mda'))
    num_channels, num_timepoints = header.dims[0], header.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    if self.debug:
        print(
            'Num. channels = {}, Num. timepoints = {}, duration = {} minutes'
            .format(num_channels, num_timepoints, duration_minutes))
        print('Creating .params file...')

    # One "key=value" line per entry; the sampling rate comes first, then
    # the sorter parameters in a fixed order.
    param_keys = (
        'detect_sign',
        'adjacency_radius',
        'detect_threshold',
        'merge_thresh',
        'freq_min',
        'freq_max',
        'pc_per_chan',
        'prm_template_name',
        'fGpu',
    )
    arg_lines = ['samplerate={}\n'.format(samplerate)]
    arg_lines.extend('{}={}\n'.format(key, sorter_params[key])
                     for key in param_keys)
    _write_text_file(dataset_dir / 'argfile.txt', ''.join(arg_lines))
def ironclust(*,
              recording,  # Recording object
              tmpdir,  # Temporary working directory
              detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
              adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
              detect_threshold=5,  # Threshold for detection
              merge_thresh=.98,  # Cluster merging threhold 0..1
              freq_min=300,  # Lower frequency limit for band-pass filter
              freq_max=6000,  # Upper frequency limit for band-pass filter
              pc_per_chan=3,  # Number of pc per channel
              prm_template_name,  # Name of the template file
              ironclust_src=None
              ):
    """Run IronClust (through MATLAB) on a recording and return the sorting.

    The recording is written to ``<tmpdir>/ironclust_dataset``, an
    ``argfile.txt`` with the sorting parameters is placed beside it, and
    MATLAB is launched to call ``p_ironclust``. The resulting
    ``<tmpdir>/firings.mda`` is read back; rows 1 and 2 of that array are
    passed to ``setTimesLabels`` of a new ``si.NumpySortingExtractor``.

    ``ironclust_src`` falls back to the ``IRONCLUST_SRC`` environment
    variable when it is ``None``.

    Raises
    ------
    Exception
        If no IronClust source location is available, if the MATLAB command
        returns a non-zero exit code, or if the result file is missing.
    """
    if ironclust_src is None:
        ironclust_src = os.getenv('IRONCLUST_SRC', None)
    if not ironclust_src:
        raise Exception('You must either set the IRONCLUST_SRC environment variable, or pass the ironclust_src parameter')

    # Not used below; kept so this rewrite evaluates the same expressions
    # as the previous implementation.
    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/ironclust_dataset'
    raw_path = dataset_dir + '/raw.mda'
    geom_path = dataset_dir + '/geom.csv'
    argfile_path = dataset_dir + '/argfile.txt'
    result_fname = tmpdir + '/firings.mda'

    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    si.MdaRecordingExtractor.writeRecording(recording=recording, save_path=dataset_dir)
    samplerate = recording.getSamplingFrequency()

    print('Reading timeseries header...')
    header = mdaio.readmda_header(raw_path)
    num_channels, num_timepoints = header.dims[0], header.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))

    print('Creating .params file...')
    arg_entries = (
        ('samplerate', samplerate),
        ('detect_sign', detect_sign),
        ('adjacency_radius', adjacency_radius),
        ('detect_threshold', detect_threshold),
        ('merge_thresh', merge_thresh),
        ('freq_min', freq_min),
        ('freq_max', freq_max),
        ('pc_per_chan', pc_per_chan),
        ('prm_template_name', prm_template_name),
    )
    argfile_text = ''.join('{}={}\n'.format(key, value) for key, value in arg_entries)
    _write_text_file(argfile_path, argfile_text)

    print('Running IronClust...')
    cmd_path = "addpath('{}', '{}/matlab', '{}/mdaio');".format(
        ironclust_src, ironclust_src, ironclust_src)
    # Argument order of the MATLAB entry point:
    # "p_ironclust('$(tempdir)','$timeseries$','$geom$','$prm$','$firings_true$','$firings_out$','$(argfile)');"
    cmd_call = "p_ironclust('{}', '{}', '{}', '', '', '{}', '{}');".format(
        tmpdir, raw_path, geom_path, result_fname, argfile_path)
    cmd = 'matlab -nosplash -nodisplay -r "{} {} quit;"'.format(cmd_path, cmd_call)
    print(cmd)
    if _run_command_and_print_output(cmd) != 0:
        raise Exception('IronClust returned a non-zero exit code')

    # parse output
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)
    firings = mdaio.readmda(result_fname)

    sorting = si.NumpySortingExtractor()
    sorting.setTimesLabels(firings[1, :], firings[2, :])
    return sorting
def convert_array(*, input, output, format='', format_out='', dimensions='', dtype='', dtype_out='', channels=''):
    """
    Convert a multi-dimensional array between various formats ('.mda', '.npy', '.dat') based on the file extensions of the input/output files

    Parameters
    ----------
    input : INPUT
        Path of input array file (can be repeated for concatenation).
    output : OUTPUT
        Path of the output array file.

    format : string
        The format for the input array (mda, npy, dat), or determined from the file extension if empty
    format_out : string
        The format for the output array (mda, npy, dat), or determined from the file extension if empty
    dimensions : string
        Comma-separated list of dimensions (shape). If empty, it is auto-determined, if possible, by the input array. If second dim is -1 then it will be extrapolated from file size / first dim.
    dtype : string
        The data format for the input array. Choices: int8, int16, int32, uint16, uint32, float32, float64 (possibly float16 in the future).
    dtype_out : string
        The data format for the output array. If empty, the dtype for the input array is used.
    channels : string
        Comma-separated list of channels to keep in output. Zero-based indexing. Only works for .dat to .mda conversions.
    """
    # NOTE(review): the "name : TYPE" entries above look like a
    # machine-readable spec, so supplementary notes are kept in comments
    # rather than added to the docstring - confirm before restructuring it.
    #
    # Returns True on success. Raises Exception when the datatype or the
    # dimensions cannot be determined, when the supplied dimensions
    # contradict those found in the input file, or when the requested
    # conversion is not supported.

    if isinstance(input, (list, )):
        multifile = True
        inputs = input
        input = inputs[0]
    else:
        multifile = False
        inputs = [input]

    format_in = format
    if not format_in:
        format_in = determine_file_format(file_extension(input), dimensions)
    if not format_out:
        format_out = determine_file_format(file_extension(output), dimensions)
    print('Input/output formats: {}/{}'.format(format_in, format_out))

    # For self-describing formats, take dtype and shape from the first file.
    dims = None
    if (format_in == 'mda') and (dtype == ''):
        header = mdaio.readmda_header(input)
        dtype = header.dt
        dims = header.dims
    if (format_in == 'npy') and (dtype == ''):
        A = np.load(input, mmap_mode='r')
        dtype = npy_dtype_to_string(A.dtype)
        dims = A.shape
        del A  # drop the memory map right away

    if dimensions:
        dims2 = [int(entry) for entry in dimensions.split(',')]
        if dims:
            if len(dims) != len(dims2):
                raise Exception(
                    'Inconsistent number of dimensions for input array')
            if not np.all(np.array(dims) == np.array(dims2)):
                raise Exception('Inconsistent dimensions for input array')
        dims = dims2

    if not dtype_out:
        dtype_out = dtype
    if not dtype:
        raise Exception('Unable to determine datatype for input array')
    if not dtype_out:
        raise Exception('Unable to determine datatype for output array')
    # This check must come before dims is indexed below; previously a
    # missing shape surfaced as a TypeError on dims[1].
    if not dims:
        raise Exception('Unable to determine dimensions for input array')

    # Shape of the first input file; only meaningful for multi-file .dat input.
    dims1 = None
    if (len(dims) > 1) and (dims[1] == -1) and (dims[0] > 0):
        if dtype and (format_in == 'dat'):
            # Bytes per entry come from numpy rather than from parsing the
            # trailing digits of the name, which broke for 'int8'.
            itemsize = np.dtype(dtype).itemsize
            bits = 8 * itemsize
            if not multifile:
                filebytes = os.stat(input).st_size  # bytes in input file
            else:
                dims1 = np.copy(dims)
                filebytes1 = os.stat(input).st_size  # bytes in first input file
                entries1 = filebytes1 // itemsize
                dims1[1] = entries1 // dims1[0]
                filebytes = sum(os.stat(inp).st_size for inp in inputs)
            entries = filebytes // itemsize  # entries in input file(s)
            dims[1] = entries // dims[0]  # calculated second dimension
            if DEBUG:
                print(bits)
                print(filebytes)
                print(entries)
                print(dims)
        else:
            raise Exception('Could not infer dimensions')

    if not channels:
        channels = range(0, dims[0])
    else:
        channels = np.array([int(entry) for entry in channels.split(',')])
        if DEBUG:
            print(channels)

    print('Using dtype={}, dtype_out={}, dimensions={}'.format(
        dtype, dtype_out, ','.join(str(item) for item in dims)))

    def _write_result(array):
        # Write the converted array as .mda or .npy, whichever was requested.
        if format_out == 'mda':
            mdaio.writemda(array, output, dtype=dtype_out)
        else:
            mdaio.writenpy(array, output, dtype=dtype_out)

    # Shortcuts when neither the format nor the dtype changes. Multi-file
    # input in a headered format (mda/npy) cannot be handled by a raw copy;
    # it used to fall in here, write nothing and still report success. It
    # now falls through to the conversion branches below.
    same_encoding = (format_in == format_out) and (dtype == dtype_out)
    if same_encoding and multifile and (format_in == 'dat'):
        print('Concatenating Files')
        with open(output, "wb") as outfile:
            for input_file in inputs:
                with open(input_file, "rb") as inpt:
                    # Stream in chunks instead of reading whole files.
                    shutil.copyfileobj(inpt, outfile)
        print('Done.')
        return True
    if same_encoding and not multifile:
        print('Simply copying file...')
        shutil.copyfile(input, output)
        print('Done.')
        return True

    if format_out == 'dat' and not multifile:
        if format_in == 'mda':
            H = mdaio.readmda_header(input)
            copy_raw_file_data(input,
                               output,
                               start_byte=H.header_size,
                               num_entries=np.prod(dims),
                               dtype=dtype,
                               dtype_out=dtype_out)
            return True
        elif format_in == 'npy':
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.load(input, mmap_mode='r').astype(dtype=dtype_out,
                                                     order='F',
                                                     copy=False)
            # A raw byte copy of the .npy payload was problematic because of
            # row-major ordering, so the data is flattened column-major.
            A = A.ravel(order='F')
            A.tofile(output)
            return True
        elif format_in == 'dat':
            raise Exception('This case not yet implemented.')
        else:
            raise Exception('Unexpected case.')
    elif (format_out == 'mda') or (format_out == 'npy'):
        if format_in == 'dat' and multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            print(channels)  #DEBUG
            A = np.fromfile(inputs[0], dtype=dtype, count=np.prod(dims))
            if dims1 is None:
                # Explicit dimensions were given, so the first file's shape
                # was never derived (this used to be a NameError). Infer its
                # second dimension from its size, as done for later files.
                dims1 = np.copy(dims)
                dims1[1] = A.size // dims1[0]
            A = A.reshape(tuple(dims1), order='F')
            A = A[channels, :]
            for inputn in inputs[1:]:
                An = np.fromfile(inputn, dtype=dtype, count=np.prod(dims))
                dimsN = np.copy(dims1)
                dimsN[1] = An.size // dims1[0]
                An = An.reshape(tuple(dimsN), order='F')
                An = An[channels, :]
                print(A.shape)
                print(An.shape)
                A = np.concatenate((A, An), axis=1)
            print(A.shape)  #DEBUG
            _write_result(A)
            return True
        elif format_in == 'dat' and not multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.fromfile(input, dtype=dtype, count=np.prod(dims))
            A = A.reshape(tuple(dims), order='F')
            A = A[channels, :]
            _write_result(A)
            return True
        elif format_in == 'mda' and not multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = mdaio.readmda(input)
            _write_result(A)
            return True
        elif format_in == 'mda' and multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = mdaio.readmda(inputs[0])
            A = A[channels, :]
            # NOTE(review): files are joined along axis 0 here but along
            # axis 1 in the .dat branch above - confirm which is intended.
            for inputn in inputs[1:]:
                An = mdaio.readmda(inputn)
                An = An[channels, :]
                A = np.concatenate((A, An), axis=0)
            _write_result(A)
            return True
        elif format_in == 'npy' and not multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.load(input, mmap_mode='r').astype(dtype=dtype,
                                                     order='F',
                                                     copy=False)
            _write_result(A)
            return True
        elif format_in == 'npy' and multifile:
            print(
                'Warning: loading entire array into memory. This should be avoided in the future.'
            )
            A = np.load(inputs[0], mmap_mode='r').astype(dtype=dtype,
                                                         order='F',
                                                         copy=False)
            A = A[channels, :]
            # NOTE(review): axis 0 here vs. axis 1 for .dat - see above.
            for inputn in inputs[1:]:
                An = np.load(inputn, mmap_mode='r').astype(dtype=dtype,
                                                           order='F',
                                                           copy=False)
                An = An[channels, :]
                A = np.concatenate((A, An), axis=0)
            _write_result(A)
            return True
        else:
            raise Exception('Unexpected case.')
    else:
        raise Exception('Unexpected output format: {}'.format(format_out))
    raise Exception('Unexpected error.')
def run(self):
    """Run Spyking Circus on ``self.timeseries`` and write ``self.firings_out``.

    Works inside the directory named by the ``ML_PROCESSOR_TEMPDIR``
    environment variable: the timeseries is copied there as ``raw.mda``,
    ``geometry.prb`` and ``raw.params`` are generated from the
    ``template.prb`` / ``template.params`` files located next to this source
    file, and the ``spyking-circus`` command is executed with one thread.
    The result file is converted with ``sc_results_to_firings`` and written
    via ``mdaio.writemda64``.

    Returns
    -------
    bool
        ``True`` on success.

    Raises
    ------
    Exception
        If ``ML_PROCESSOR_TEMPDIR`` is not set, if the command returns a
        non-zero exit code, or if the expected result file is missing.
    """
    tmpdir = os.environ.get('ML_PROCESSOR_TEMPDIR')
    if not tmpdir:
        raise Exception('Environment variable not set: ML_PROCESSOR_TEMPDIR')

    source_dir = os.path.dirname(os.path.realpath(__file__))
    raw_path = tmpdir + '/raw.mda'
    prb_path = tmpdir + '/geometry.prb'

    ## todo: link rather than copy
    print('Copying timeseries file: {} -> {}'.format(self.timeseries, raw_path))
    copyfile(self.timeseries, raw_path)

    print('Reading timeseries header...')
    header = mdaio.readmda_header(raw_path)
    num_channels, num_timepoints = header.dims[0], header.dims[1]
    duration_minutes = num_timepoints / self.samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))

    print('Creating .prb file...')
    prb_text = self._read_text_file(source_dir + '/template.prb')
    prb_text = prb_text.replace('$num_channels$', '{}'.format(num_channels))
    prb_text = prb_text.replace('$radius$', '{}'.format(self.adjacency_radius))
    geom = np.genfromtxt(self.geom, delimiter=',')
    # One "index: [x,y]" entry per geometry row.
    # todo: handle 3d geom
    geom_rows = [
        ' {}: [{},{}],\n'.format(row, geom[row, 0], geom[row, 1])
        for row in range(geom.shape[0])
    ]
    geom_str = '{\n' + ''.join(geom_rows) + '}'
    prb_text = prb_text.replace('$geometry$', '{}'.format(geom_str))
    self._write_text_file(prb_path, prb_text)

    print('Creating .params file...')
    txt = self._read_text_file(source_dir + '/template.params')
    if self.detect_sign > 0:
        peaks_str = 'positive'
    elif self.detect_sign < 0:
        peaks_str = 'negative'
    else:
        peaks_str = 'both'
    # Placeholder -> replacement text, applied in this order.
    substitutions = (
        ('$header_size$', '{}'.format(header.header_size)),
        ('$prb_file$', prb_path),
        ('$dtype$', header.dt),
        ('$num_channels$', '{}'.format(num_channels)),
        ('$samplerate$', '{}'.format(self.samplerate)),
        ('$template_width_ms$', '{}'.format(self.template_width_ms)),
        ('$spike_thresh$', '{}'.format(self.spike_thresh)),
        ('$peaks$', peaks_str),
        ('$whitening_max_elts$', '{}'.format(self.whitening_max_elts)),
        ('$clustering_max_elts$', '{}'.format(self.clustering_max_elts)),
    )
    for placeholder, replacement in substitutions:
        txt = txt.replace(placeholder, replacement)
    self._write_text_file(tmpdir + '/raw.params', txt)

    print('Running spyking circus...')
    #num_threads=np.maximum(1,int(os.cpu_count()/2))
    num_threads = 1  # for some reason, using more than 1 thread causes an error
    cmd = 'spyking-circus {} -c {} '.format(raw_path, num_threads)
    print(cmd)
    if self._run_command_and_print_output(cmd) != 0:
        raise Exception('Spyking circus returned a non-zero exit code')

    result_fname = tmpdir + '/raw/raw.result.hdf5'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = sc_results_to_firings(result_fname)
    print(firings.shape)
    mdaio.writemda64(firings, self.firings_out)
    return True
def separateSpikesInEpochs(data_dir=None, firings_file='firings.curated.mda', timestamp_files=None, write_separated_spikes=True):
    """
    Takes curated spikes from MountainSort and combines this information with
    spike timestamps to create separate curated spikes for each epoch

    :data_dir: Directory holding one sub-directory per tetrode. If None, the
        user is asked to pick it through a dialog.
    :firings_file: Curated firings file
    :timestamp_files: Spike timestamps file list, one per epoch and in epoch
        order. If None, the user is asked to pick them through a file dialog.
    :write_separated_spikes: If the separated spikes should be written back to
        the data directory.
    :returns: List of spikes for each epoch. One entry per tetrode that had a
        firings file; each entry is a list with one element per epoch, which
        is either a firings array or None when the tetrode has no spikes in
        that epoch.
    """

    if data_dir is None:
        # Get the firings file
        data_dir = QtHelperUtils.get_directory(
            message="Select Curated firings location")

    separated_tetrodes = []
    curated_firings = []
    merged_curated_firings = []
    tetrode_list = os.listdir(data_dir)
    for tt_dir in tetrode_list:
        try:
            if firings_file in os.listdir(data_dir + '/' + tt_dir):
                curated_firings.append([])
                separated_tetrodes.append(tt_dir)
                firings_file_location = '/'.join(
                    [data_dir, tt_dir, firings_file])
                merged_curated_firings.append(
                    mdaio.readmda(firings_file_location))
                print(MODULE_IDENTIFIER +
                      'Read merged firings file for tetrode %s!' % tt_dir)
            else:
                print(MODULE_IDENTIFIER +
                      'Merged firings %s not found for tetrode %s!' %
                      (firings_file, tt_dir))
        except (FileNotFoundError, IOError) as err:
            print(MODULE_IDENTIFIER +
                  'Unable to read merged firings file for tetrode %s!' %
                  tt_dir)
            print(err)

    timestamp_headers = []
    if timestamp_files is None:
        # A hidden Tk root is only needed to host the file dialog. Creating
        # it lazily keeps the function usable without a display when the
        # timestamp files are passed in by the caller.
        gui_root = Tk()
        gui_root.withdraw()
        try:
            # Read all the timestamp files
            timestamp_files = filedialog.askopenfilenames(
                initialdir=DEFAULT_SEARCH_PATH,
                title="Select all timestamp files",
                filetypes=(("Timestamps", ".mda"), ("All Files", "*.*")))
        finally:
            gui_root.destroy()

    for ts_file in timestamp_files:
        timestamp_headers.append(mdaio.readmda_header(ts_file))

    # Now that we have both the timestamp headers and the timestamp files, we
    # can separate spikes out. It is important here for the timestamp files to
    # be in the same order as the curated firings as that is the only way for
    # us to tell that the firings are being split up correctly.
    print(MODULE_IDENTIFIER + 'Looking at spike timestamps in order')
    print(timestamp_files)

    # First splice up curated spikes into individual epochs.
    # Row 1 of each firings array is treated as the sample number of each
    # spike; np.searchsorted assumes it is sorted in ascending order
    # - TODO confirm against the firings file format.
    for tt_idx, tt_firings in enumerate(merged_curated_firings):
        for ep_idx, ts_header in enumerate(timestamp_headers):
            if tt_firings is None:
                # All spikes of this tetrode were consumed by earlier epochs.
                curated_firings[tt_idx].append(None)
                continue

            n_data_points = ts_header.dims[0]
            print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx) + ': ' +
                  str(n_data_points) + ' samples.')
            # Index of the last spike whose sample number is below the
            # number of samples in this epoch (-1 if there is none).
            last_spike_from_epoch = np.searchsorted(
                tt_firings[1], n_data_points, side='left') - 1

            # If there are no spikes in this epoch, there might still be some in future epochs!
            if last_spike_from_epoch < 0:
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                curated_firings[tt_idx].append(None)
                continue

            last_spike_sample_number = tt_firings[1][last_spike_from_epoch]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': First spike ' + str(tt_firings[1][0])\
                    + ', Last spike ' + str(last_spike_sample_number))
            # The slice end is exclusive, so +1 is required to keep the last
            # spike of the epoch. Without it that spike was dropped from
            # every epoch, since the remainder below starts at index + 1.
            epoch_spikes = tt_firings[:, :last_spike_from_epoch + 1]
            curated_firings[tt_idx].append(epoch_spikes)

            if last_spike_from_epoch < (len(tt_firings[1]) - 1):
                # Slice the merged curated firing to only have the remaining spikes
                tt_firings = tt_firings[:, last_spike_from_epoch + 1:]
                print(MODULE_IDENTIFIER +
                      'Trimming curated spikes. Aggregate sample start ' +
                      str(tt_firings[1][0]))
                # Re-base the remaining sample numbers to the next epoch.
                tt_firings[1] = tt_firings[1] - float(n_data_points)
                print(MODULE_IDENTIFIER + 'Sample number trimmed to ' +
                      str(tt_firings[1][0]))
            else:
                print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] +
                      ', Reached end of curated firings at Epoch ' +
                      str(ep_idx))
                tt_firings = None

    print(MODULE_IDENTIFIER +
          'Spikes separated in epochs. Substituting timestamps!')

    # For each epoch replace the sample numbers with the corresponding
    # timestamps. We are going through multiple revisions for this so that we
    # only have to load one timestamp file at a time
    for ep_idx, ts_file in enumerate(timestamp_files):
        epoch_timestamps = mdaio.readmda(ts_file)
        print(MODULE_IDENTIFIER + 'Epoch ' + str(ep_idx))
        for tt_idx, tt_curated_firings in enumerate(curated_firings):
            if tt_curated_firings[ep_idx] is None:
                continue
            # Need to use the original array because changes to tt_curated_firings do not get copied back
            curated_firings[tt_idx][ep_idx][1] = epoch_timestamps[np.array(
                tt_curated_firings[ep_idx][1], dtype=int)]
            print(MODULE_IDENTIFIER + separated_tetrodes[tt_idx] + ': Samples (' + \
                    str(tt_curated_firings[ep_idx][1][0]) + ', ' + str(tt_curated_firings[ep_idx][1][-1]), ')')

    if write_separated_spikes:
        try:
            for tt_idx, tet in enumerate(separated_tetrodes):
                for ep_idx in range(len(timestamp_files)):
                    if curated_firings[tt_idx][ep_idx] is not None:
                        ep_firings_file_name = data_dir + '/' + tet + '/firings-' + \
                                str(ep_idx+1) + '.curated.mda'
                        mdaio.writemda64(curated_firings[tt_idx][ep_idx],
                                         ep_firings_file_name)
        except OSError as exception:
            # Best effort: a write failure is reported but does not prevent
            # the separated spikes from being returned.
            if exception.errno != errno.EEXIST:
                print(MODULE_IDENTIFIER +
                      'Unable to write timestamped firings!')
                print(exception)
    return curated_firings