def add_event_type(fd, id=None, evt=None):
    """Add a new, empty event type group to the KWIK file.

    Arguments:
      * fd: returned by `open_files`: it is a dict {type: tb_file_handle}.
        It must contain a 'kwik' entry.
      * id: name of the group to create under '/event_types'. If None, the
        largest purely numeric name among the existing event types plus one
        is used, or '0' when no numeric name exists yet.
      * evt: not used by this function.

    """
    kwik = fd.get('kwik', None)
    # The KWIK needs to be there.
    assert kwik is not None, "the 'kwik' file handle is missing from fd"
    if id is None:
        # Take the maximum integer index among the existing event type
        # names, + 1. Non-numeric names are ignored; if none of the names
        # is numeric the numbering starts at 0 (max() of an empty list
        # would raise otherwise).
        numeric_ids = [int(node._v_name)
                       for node in kwik.listNodes('/event_types')
                       if node._v_name.isdigit()]
        id = str(max(numeric_ids) + 1) if numeric_ids else '0'

    event_type = kwik.createGroup('/event_types', id)
    kwik.createGroup(event_type, 'user_data')

    app = kwik.createGroup(event_type, 'application_data')
    kv = kwik.createGroup(app, 'klustaviewa')
    kv._f_setAttr('color', None)

    # Empty, extendable arrays that will receive the events.
    events = kwik.createGroup(event_type, 'events')
    kwik.createEArray(events, 'time_samples', tb.UInt64Atom(), (0,))
    kwik.createEArray(events, 'recording', tb.UInt16Atom(), (0,))
    kwik.createGroup(events, 'user_data')
def inizialize_dataset():
    """Create the output HDF5 file and its three empty, extendable arrays.

    The file is opened for writing at the module-level path ``dsOut``
    (any existing file is overwritten). The file handle itself is not
    returned; it stays open behind the returned arrays.

    Returns:
        (X, Y_ID, desc): extendable arrays stored under the file root.
        ``X`` holds uint8 data with per-row shape ``sizedb[:3]``, ``Y_ID``
        is a flat uint64 array and ``desc`` holds uint64 rows of length 6.
    """
    out_file = tables.open_file(dsOut, mode='w')
    root = out_file.root

    uint8_atom = tables.UInt8Atom()
    uint64_atom = tables.UInt64Atom()
    sample_shape = (sizedb[0], sizedb[1], sizedb[2])

    # The leading 0 marks the extendable dimension of each EArray.
    samples = out_file.create_earray(root, 'X', uint8_atom,
                                     shape=(0,) + sample_shape)
    ids = out_file.create_earray(root, 'Y_ID', uint64_atom, shape=(0, ))
    descriptions = out_file.create_earray(root, 'desc', uint64_atom,
                                          shape=(0, 6))
    return samples, ids, descriptions
def inizialize_dataset():
    """Create the embeddings HDF5 file and its three empty, extendable arrays.

    The file is opened for writing at the module-level path ``dbOut_path``
    (any existing file is overwritten). The file handle itself is not
    returned; it stays open behind the returned arrays.

    Returns:
        (X, Y_ID, desc): extendable arrays stored under the file root.
        ``X`` holds float32 rows of length ``size_embedding``, ``Y_ID`` is
        a flat uint64 array and ``desc`` holds uint64 rows of length 6.
    """
    out_file = tables.open_file(dbOut_path, mode='w')
    root = out_file.root

    float32_atom = tables.Float32Atom()
    uint64_atom = tables.UInt64Atom()

    # The leading 0 marks the extendable dimension of each EArray.
    embeddings = out_file.create_earray(root, 'X', float32_atom,
                                        shape=(0, size_embedding))
    ids = out_file.create_earray(root, 'Y_ID', uint64_atom, shape=(0, ))
    descriptions = out_file.create_earray(root, 'desc', uint64_atom,
                                          shape=(0, 6))
    return embeddings, ids, descriptions
def inizialize_dataset():
    """Create the HDF5 database and publish its arrays as module globals.

    The file is opened for writing at the module-level path ``db_path``
    (any existing file is overwritten). Three empty, extendable arrays are
    created under the file root and bound to the globals ``X_storage``
    (uint8, per-row shape ``sizedb[:3]``), ``Y_storageID`` (flat uint64)
    and ``desc_storage`` (uint64 rows of length 6). Nothing is returned.
    """
    global X_storage, Y_storageID, desc_storage

    db_file = tables.open_file(db_path, mode='w')
    root = db_file.root

    uint8_atom = tables.UInt8Atom()
    uint64_atom = tables.UInt64Atom()
    sample_shape = (sizedb[0], sizedb[1], sizedb[2])

    # The leading 0 marks the extendable dimension of each EArray.
    X_storage = db_file.create_earray(root, 'X', uint8_atom,
                                      shape=(0,) + sample_shape)
    Y_storageID = db_file.create_earray(root, 'Y_ID', uint64_atom,
                                        shape=(0, ))
    # Each description row: video, frame, bounding box.
    desc_storage = db_file.create_earray(root, 'desc', uint64_atom,
                                         shape=(0, 6))
def initfile(h5name, ncsf, q_down, include_times=True):
    """
    Initialize an h5 file that will store converted data.

    Creates ``/data/rawdata`` as an empty, extendable int16 array and
    attaches the attributes ``ADBitVolts``, ``timestep``, ``Q`` and
    ``AcqEntName`` to it. ``ADBitVolts`` and ``AcqEntName`` come from
    ``ncsf.header``, ``timestep`` from ``ncsf.timestep`` and ``Q`` is
    *q_down*. When *include_times* is true an empty, extendable uint64
    array ``/time`` is created as well.

    Returns the open file handle; the caller is responsible for closing it.
    """
    # Read everything from the source object before touching the disk.
    raw_attributes = [
        ('ADBitVolts', ncsf.header['ADBitVolts']),
        ('timestep', ncsf.timestep),
    ]
    channel_name = ncsf.header['AcqEntName']
    raw_attributes.append(('Q', q_down))
    raw_attributes.append(('AcqEntName', channel_name))

    h5f = tables.open_file(h5name, 'w')
    h5f.create_group('/', 'data')
    h5f.create_earray('/data', 'rawdata', tables.Int16Atom(), [0])

    for attr_name, attr_value in raw_attributes:
        h5f.root.data.rawdata.set_attr(attr_name, attr_value)

    if include_times:
        h5f.create_earray('/', 'time', tables.UInt64Atom(), [0])

    return h5f
def _create_table(self, name, example):
    """
    Create a new table within the HDF file, where the tables shape and its
    datatype are determined by *example*.

    *name* may contain dots; every dot-separated component but the last is
    treated as a group (created if it does not exist yet) and the last
    component names the array itself.

    * numpy arrays get an extendable array whose rows have the shape of
      *example*;
    * strings get a variable-length string array;
    * for any other type no table is created and only the type is recorded
      in ``self.types``.

    Raises TypeError if *example* is an array with an unsupported dtype.
    """
    # Fixed-width keys only: the removed ``np.int`` / ``np.bool`` aliases
    # are not needed, and signed 64-bit integers are stored as signed
    # (they used to end up in an unsigned atom).
    type_map = {
        np.dtype(np.float64): tables.Float64Atom,
        np.dtype(np.float32): tables.Float32Atom,
        np.dtype(np.int8): tables.Int8Atom,
        np.dtype(np.uint8): tables.UInt8Atom,
        np.dtype(np.int16): tables.Int16Atom,
        np.dtype(np.uint16): tables.UInt16Atom,
        np.dtype(np.int32): tables.Int32Atom,
        np.dtype(np.uint32): tables.UInt32Atom,
        np.dtype(np.int64): tables.Int64Atom,
        np.dtype(np.uint64): tables.UInt64Atom,
        np.dtype(np.bool_): tables.BoolAtom,
    }

    is_array = isinstance(example, np.ndarray)
    is_string = isinstance(example, str)

    if is_array:
        try:
            h5type = type_map[example.dtype]()
        except KeyError:
            raise TypeError(
                "Could not create table %s because of unknown dtype '%s'" %
                (name, example.dtype))
    elif is_string:
        h5type = tables.VLStringAtom()

    if is_array or is_string:
        h5 = self.h5
        filters = tables.Filters(complevel=self.compression_level,
                                 complib='zlib',
                                 shuffle=True)

        # Walk down the dotted name, creating the missing groups.
        components = name.split('.')
        path = '/'
        for group_name in components[:-1]:
            try:
                h5.is_visible_node(path + group_name)
            except tables.NoSuchNodeError:
                h5.create_group(path, group_name)
            path += group_name + '/'

        if is_array:
            # Leading 0: the array grows along its first dimension.
            h5dim = (0, ) + example.shape
            self.tables[name] = h5.create_earray(path, components[-1],
                                                 h5type, h5dim,
                                                 filters=filters)
        else:
            self.tables[name] = h5.create_vlarray(path, components[-1],
                                                  h5type, filters=filters)

    self.types[name] = type(example)
def create_kwik(path, experiment_name=None, prm=None, prb=None,
                overwrite=True):
    """Create a KWIK file.

    Arguments:
      * path: path to the .kwik file.
      * experiment_name
      * prm: a dictionary representing the contents of the PRM file (used for
        SpikeDetekt)
      * prb: a dictionary with the contents of the PRB file
      * overwrite: if False and `path` already exists, nothing is done.

    The file is always closed before returning, even if an error occurs
    while it is being populated.

    """
    if experiment_name is None:
        experiment_name = ''
    if prm is None:
        prm = {}
    if prb is None:
        prb = {}

    if not overwrite and os.path.exists(path):
        return

    kwik = tb.openFile(path, mode='w')
    try:
        kwik.root._f_setAttr('kwik_version', 2)
        kwik.root._f_setAttr('name', experiment_name)

        kwik.createGroup('/', 'application_data')

        # Set the SpikeDetekt parameters
        kwik.createGroup('/application_data', 'spikedetekt')
        for prm_name, prm_value in iteritems(prm):
            kwik.root.application_data.spikedetekt._f_setAttr(prm_name,
                                                              prm_value)

        kwik.createGroup('/', 'user_data')

        # Create channel groups.
        kwik.createGroup('/', 'channel_groups')

        # iteritems() is used for prb exactly as for prm above, so that
        # both dictionaries are iterated the same way.
        for igroup, group_info in iteritems(prb):
            igroup = int(igroup)
            group = kwik.createGroup('/channel_groups', str(igroup))
            # group_info: channel, graph, geometry
            group._f_setAttr('name', 'channel_group_{0:d}'.format(igroup))
            group._f_setAttr(
                'adjacency_graph',
                np.array(group_info.get('graph', np.zeros((0, 2))),
                         dtype=np.int32))
            kwik.createGroup(group, 'application_data')
            kwik.createGroup(group, 'user_data')

            # Create channels.
            kwik.createGroup(group, 'channels')
            channels = group_info.get('channels', [])
            # Add the channel order.
            group._f_setAttr('channel_order',
                             np.array(channels, dtype=np.int32))
            for channel_idx in channels:
                # channel is the absolute channel index.
                channel = kwik.createGroup(group.channels, str(channel_idx))
                channel._f_setAttr('name',
                                   'channel_{0:d}'.format(channel_idx))

                # "channels" only contains not-ignored channels here
                channel._f_setAttr('ignored', False)
                pos = group_info.get('geometry', {}).get(channel_idx, None)
                if pos is not None:
                    pos = np.array(pos, dtype=np.float32)
                channel._f_setAttr('position', pos)

                channel._f_setAttr('voltage_gain',
                                   prm.get('voltage_gain', 0.))
                channel._f_setAttr('display_threshold', 0.)
                kwik.createGroup(channel, 'application_data')
                kwik.createGroup(channel.application_data, 'spikedetekt')
                kwik.createGroup(channel.application_data, 'klustaviewa')
                kwik.createGroup(channel, 'user_data')

            # Create spikes.
            spikes = kwik.createGroup(group, 'spikes')
            kwik.createEArray(spikes, 'time_samples', tb.UInt64Atom(), (0,),
                              expectedrows=1000000)
            kwik.createEArray(spikes, 'time_fractional', tb.UInt8Atom(),
                              (0,), expectedrows=1000000)
            kwik.createEArray(spikes, 'recording', tb.UInt16Atom(), (0,),
                              expectedrows=1000000)
            clusters = kwik.createGroup(spikes, 'clusters')
            kwik.createEArray(clusters, 'main', tb.UInt32Atom(), (0,),
                              expectedrows=1000000)
            kwik.createEArray(clusters, 'original', tb.UInt32Atom(), (0,),
                              expectedrows=1000000)

            # The three groups below only carry an 'hdf5_path' attribute
            # whose value starts with the literal '{kwx}' placeholder.
            fm = kwik.createGroup(spikes, 'features_masks')
            fm._f_setAttr(
                'hdf5_path',
                '{{kwx}}/channel_groups/{0:d}/features_masks'.format(igroup))
            wr = kwik.createGroup(spikes, 'waveforms_raw')
            wr._f_setAttr(
                'hdf5_path',
                '{{kwx}}/channel_groups/{0:d}/waveforms_raw'.format(igroup))
            wf = kwik.createGroup(spikes, 'waveforms_filtered')
            wf._f_setAttr(
                'hdf5_path',
                '{{kwx}}/channel_groups/{0:d}/waveforms_filtered'.format(
                    igroup))

            # Create clusters.
            clusters = kwik.createGroup(group, 'clusters')
            kwik.createGroup(clusters, 'main')
            kwik.createGroup(clusters, 'original')

            # Create cluster groups.
            cluster_groups = kwik.createGroup(group, 'cluster_groups')
            kwik.createGroup(cluster_groups, 'main')
            kwik.createGroup(cluster_groups, 'original')

        # Create recordings.
        kwik.createGroup('/', 'recordings')

        # Create event types.
        kwik.createGroup('/', 'event_types')
    finally:
        kwik.close()
def compress_twix(infile, outfile, remove_os=False, cc_mode=False, ncc=None,
                  cc_tol=0.05, zfp=False, zfp_tol=1e-5, zfp_prec=None,
                  rm_fidnav=False):
    """Read a twix file and write a compressed HDF5 version of it.

    Parameters
    ----------
    infile : path of the twix file, read with ``twixtools.read_twix``.
    outfile : path of the HDF5 file to create (overwritten if it exists).
    remove_os : forwarded to ``get_cal_data`` and ``reduce_data``
        (presumably toggles oversampling removal -- confirm in those
        helpers).
    cc_mode : False for no coil compression, otherwise the coil
        compression mode. 'scc' and 'gcc' are calibrated with
        ``calibrate_mtx``, any other value with ``calibrate_mtx_bart``.
    ncc : number of virtual coils. In the 'scc'/'gcc' modes the value
        returned by ``calibrate_mtx`` is used; in the other modes it
        defaults to half the number of channels when None.
    cc_tol : tolerance forwarded to ``calibrate_mtx``.
    zfp : if true, the reduced data is compressed with ``pyzfp`` instead of
        the lossless zlib filter.
    zfp_tol, zfp_prec : tolerance / precision forwarded to
        ``pyzfp.compress``; stored as -1 in the file attributes when None.
    rm_fidnav : if true, measurement data blocks flagged 'noname60' are
        dropped and the scan counters of the remaining blocks are
        renumbered.

    The compression settings are stored as attributes of the file root and
    every measurement is written to its own ``scan<N>`` group. The elapsed
    time and the achieved compression factor are printed at the end.
    """
    with suppress_stdout_stderr():
        twix = twixtools.read_twix(infile)

    # lossless compression settings
    filters = tables.Filters(complevel=5, complib='zlib')

    mtx = None
    # NOTE: calibration of a noise decorrelation matrix is currently
    # disabled, so both noise matrices stay None and are never written.
    noise_mtx = None
    noise_dmtx = None

    if cc_mode:
        # calibrate coil compression based on last scan in list (image scan)
        # use the calibration coil weights for all data that fits
        cal_data = get_cal_data(twix[-1], remove_os)
        if cc_mode == 'scc' or cc_mode == 'gcc':
            mtx, ncc = calibrate_mtx(cal_data, cc_mode, ncc, cc_tol)
        else:
            mtx = calibrate_mtx_bart(cal_data, cc_mode)
            if ncc is None:
                # set default
                ncc = mtx.shape[-1] // 2
        del cal_data
        print('coil compression from %d channels to %d virtual channels' %
              (mtx.shape[-1], ncc))

    t_start = time.time()

    with tables.open_file(outfile, mode="w") as f:
        attrs = f.root._v_attrs
        attrs.original_filename = os.path.basename(infile)
        attrs.cc_mode = cc_mode
        attrs.ncc = ncc
        attrs.zfp = zfp
        attrs.zfp_tol = -1 if zfp_tol is None else zfp_tol
        attrs.zfp_prec = -1 if zfp_prec is None else zfp_prec

        f.create_carray(f.root, "multi_header",
                        obj=np.frombuffer(twix[0].tobytes(), 'S1'),
                        filters=filters)

        if mtx is not None:
            # save mtx for coil compression
            f.create_carray(f.root, "mtx", obj=mtx, filters=filters)

        if noise_dmtx is not None:
            f.create_carray(f.root, "noise_dmtx", obj=noise_dmtx,
                            filters=filters)
            f.create_carray(f.root, "noise_mtx", obj=noise_mtx,
                            filters=filters)

        scanlist = []
        for meas_key, meas in enumerate(twix[1:]):
            scan_name = "scan%d" % (meas_key)
            scanlist.append(scan_name)
            grp = f.create_group("/", scan_name)
            f.create_carray(grp, "hdr_str", obj=meas['hdr_str'],
                            filters=filters)

            # remove fidnav scans if necessary
            if rm_fidnav:
                # Rebuild the list in place. Deleting entries while
                # enumerating the same list skips the element that follows
                # every deletion, so consecutive fidnav scans survived.
                meas['mdb'][:] = [mdb for mdb in meas['mdb']
                                  if not mdb.is_flag_set('noname60')]

            mdh_count = len(meas['mdb'])

            # create info array with mdh, coil & compression information
            f.create_carray(grp, "info",
                            shape=[mdh_count, datinfo_type.itemsize],
                            atom=tables.UInt8Atom(), filters=filters)

            dt = tables.UInt64Atom(shape=())
            if zfp:
                # zfp output is already compressed: no zlib filter on top.
                f.create_vlarray(grp, "DATA", atom=dt,
                                 expectedrows=mdh_count)
            else:
                f.create_vlarray(grp, "DATA", atom=dt, filters=filters,
                                 expectedrows=mdh_count)

            syncscans = 0
            for mdb_key, mdb in enumerate(meas['mdb']):
                info = np.zeros(1, dtype=datinfo_type)[0]
                is_syncscan = mdb.is_flag_set('SYNCDATA')

                if rm_fidnav:
                    # we have to update the scan counters
                    if not is_syncscan:
                        # scanCounter starts at 1
                        mdb.mdh['ulScanCounter'] = mdb_key + 1 - syncscans
                    else:
                        syncscans += 1

                # store mdh
                info['mdh_info'] = mdb.mdh

                if is_syncscan or mdb.is_flag_set('ACQEND'):
                    # these blocks are stored as they are
                    data = np.ascontiguousarray(mdb.data).view('uint64')
                else:
                    restrictions = get_restrictions(mdb.get_flags())
                    if restrictions == 'NO_COILCOMP':
                        data, info['rm_os_active'], _ = reduce_data(
                            mdb.data, mdb.mdh, remove_os, cc_mode=False)
                    else:
                        (data, info['rm_os_active'],
                         info['cc_active']) = reduce_data(
                             mdb.data, mdb.mdh, remove_os, cc_mode=cc_mode,
                             mtx=mtx, ncc=ncc)

                    data = data.flatten()
                    if zfp:
                        data = pyzfp.compress(data.view('float32'),
                                              tolerance=zfp_tol,
                                              precision=zfp_prec,
                                              parallel=True)
                        data = np.frombuffer(data, dtype='uint64')
                    else:
                        data = data.view('uint64')

                    if len(mdb.channel_hdr) > 0:
                        # keep only the first channel header plus the list
                        # of channel ids of all coils
                        mdb.channel_hdr[0]['ulScanCounter'] = \
                            mdb.mdh['ulScanCounter']
                        info['coil_info'] = mdb.channel_hdr[0]
                        coil_list = np.asarray(
                            [item['ulChannelId']
                             for item in mdb.channel_hdr],
                            dtype='uint8')
                        info['coil_list'][:len(coil_list)] = coil_list

                # write data
                grp.DATA.append(data)
                grp.info[mdb_key] = np.frombuffer(info, dtype='uint8')

        f.root._v_attrs.scanlist = scanlist

    elapsed_time = (time.time() - t_start)
    print("compression finished in %d:%02d:%02d h" %
          (elapsed_time // 3600, (elapsed_time % 3600) // 60,
           elapsed_time % 60))
    print("compression factor = %.2f" %
          (os.path.getsize(infile) / os.path.getsize(outfile)))
def init_h5file(self):
    """Create the HDF5 file of a new acquisition and its empty data arrays.

    The file name comes from ``self.get_new_file_name()``; it is written to
    the 'acquisition'/'temp_file' setting and the file is opened for
    writing as ``self.h5file``. The root group receives the serialized
    settings plus the 'type' and 'format_name' attributes. Three empty,
    extendable 1D arrays are then created under the root and kept as
    ``self.marker_array`` (uint8), ``self.nanotimes_array`` (uint16) and
    ``self.timestamp_array`` (uint64).
    """
    base_name, curr_dir = self.get_new_file_name()
    h5_name = base_name + '.h5'

    self.settings.child('acquisition', 'temp_file').setValue(h5_name)
    self.h5file = tables.open_file(os.path.join(curr_dir, h5_name), mode='w')

    root = self.h5file.root
    root._v_attrs['settings'] = \
        customparameter.parameter_to_xml_string(self.settings)
    root._v_attrs.type = 'detector'
    root._v_attrs['format_name'] = 'timestamps'

    # Indexes of the enabled channels (computed but not used further here).
    channels_index = [self.channels_enabled[key]['index']
                      for key in self.channels_enabled.keys()
                      if self.channels_enabled[key]['enabled']]

    # (attribute on self, node name / title, atom class)
    array_specs = (
        ('marker_array', 'markers', tables.UInt8Atom),
        ('nanotimes_array', 'nanotimes', tables.UInt16Atom),
        ('timestamp_array', 'timestamps', tables.UInt64Atom),
    )
    for attribute, node_name, atom_class in array_specs:
        earray = self.h5file.create_earray(self.h5file.root, node_name,
                                           atom_class(), (0,),
                                           title=node_name)
        setattr(self, attribute, earray)
        earray._v_attrs['data_type'] = '1D'
        earray._v_attrs['type'] = 'tttr_data'
def make_h5(obj, handle, objname):
    """given a vector or dictionary of objects associated
    with the set of DHS sites, save the objects in the
    hdf5 format used by other scripts.

    Arguments:
      * obj: either a numpy array with one entry per site, in the order of
        the sites of the location file (chromosome by chromosome, following
        the module-level `chromosomes`), or a dictionary mapping a
        chromosome name to a numpy array.
      * handle: open, writable PyTables file; one group per chromosome is
        created under its root.
      * objname: name of the compressed array created in each group.

    Objects currently handled:
        UInt8
        UInt64
        Float64
        Float32

    Returns `handle`. Raises TypeError for arrays of any other dtype.
    """

    def atom_for(dtype):
        # selecting an appropriate atom
        # NOTE(review): signed int8/int64 data is stored in unsigned atoms,
        # as before -- confirm that the values are never negative.
        if dtype == np.float64:
            return tables.Float64Atom()
        if dtype == np.float32:
            return tables.Float32Atom()
        if dtype == np.int8:
            return tables.UInt8Atom()
        if dtype == np.int64:
            return tables.UInt64Atom()
        raise TypeError("make_h5 cannot store arrays of dtype '%s'" % dtype)

    if isinstance(obj, np.ndarray):
        # One chromosome label per site, in the order of the location file.
        locfile = tables.openFile(
            '/mnt/lustre/home/anilraj/linspec/cache/dhslocations.h5', 'r')
        try:
            labels = [chrom for chrom in chromosomes
                      for _ in locfile.getNode('/' + chrom).start[:]]
        finally:
            locfile.close()

        # Split the flat vector into one list of values per chromosome.
        grouped = dict()
        for value, chrom in zip(obj, labels):
            grouped.setdefault(chrom, []).append(value)

        data = dict((chrom, np.array(values).astype(obj.dtype))
                    for chrom, values in grouped.items())
    else:
        data = dict(obj.items())

    filters = tables.Filters(complevel=5, complib='zlib')
    for chrom, dat in data.items():
        # The atom follows the dtype of the array actually written.
        atom = atom_for(dat.dtype)
        chrgroup = handle.createGroup(handle.root, chrom, chrom)
        values = handle.createCArray(chrgroup, objname, atom, dat.shape,
                                     filters=filters)
        values[:] = dat[:]

    return handle
def get_PS_waveform(fn_mwk, fn_nev, fn_out, movie_begin_fname=None, n_samples=DEFAULT_SAMPLES_PER_SPK, n_max_spks=DEFAULT_MAX_SPKS, **kwargs): """Get waveform data around stimuli presented for later spike sorting. This will give completely different output file format. NOTE: this function is memory intensive! Will require approximately as much memory as the size of the files.""" # -- some housekeeping things... kwargs['verbose'] = 2 kwargs['only_new_t'] = True t_start0 = kwargs['t_start0'] t_stop0 = kwargs['t_stop0'] iid2idx = {} idx2iid = [] ch2idx = {} idx2ch = [] n_spks = 0 # does "n_spks_lim" reach "n_max_spks"? b_warn_max_spks_lim = False # list of image presentations without spikes l_empty_spks = [] for info in getspk(fn_mwk, fn_nev=fn_nev, **kwargs): # -- get the metadata. this must be called before other clauses if info['type'] == 'preamble': actvelecs = info['actvelecs'] t_adjust = info['t_adjust'] chn_info = info['chn_info'] n_spks_lim = min(info['n_packets'], n_max_spks) print '* n_spks_lim =', n_spks_lim for ch in sorted(actvelecs): makeavail(ch, ch2idx, idx2ch) # Data for snippets === # Msnp: snippet data # Msnp_tabs: when it spiked (absolute time) # Msnp_ch: which channel ID spiked? 
# Msnp_pos: corresponding file position Msnp = np.empty((n_spks_lim, n_samples), dtype='int16') Msnp_tabs = np.empty(n_spks_lim, dtype='uint64') Msnp_ch = np.empty(n_spks_lim, dtype='uint32') Msnp_pos = np.empty(n_spks_lim, dtype='uint64') # Data for images === # Mimg: image indices in the order of presentations # Mimg_tabs: image onset time (absolute) Mimg = [] Mimg_tabs = [] # -- do some housekeeping things once per each img if info['type'] == 'begin': t_abs = info['t_imgonset'] iid = info['imgid'] i_img = info['i_img'] makeavail(iid, iid2idx, idx2iid) Mimg.append(iid2idx[iid]) Mimg_tabs.append(t_abs) b_no_spks = True # process movie if requested if movie_begin_fname is not None: raise NotImplementedError('Movies are not supported yet.') # -- put actual spiking info elif info['type'] == 'spike': wav = info['wavinfo']['waveform'] t_abs = info['t_abs'] i_ch = ch2idx[info['ch']] pos = info['pos'] Msnp[n_spks] = wav Msnp_tabs[n_spks] = t_abs Msnp_ch[n_spks] = i_ch Msnp_pos[n_spks] = pos b_no_spks = False n_spks += 1 if n_spks >= n_spks_lim: warnings.warn('n_spks exceedes n_spks_lim! ' 'Aborting further additions.') b_warn_max_spks_lim = True break elif info['type'] == 'end': if not b_no_spks: continue # if there's no spike at all, list the stim warnings.warn('No spikes are there! ') l_empty_spks.append(i_img) # -- done! # finished calculation.... 
Msnp = Msnp[:n_spks] Msnp_tabs = Msnp_tabs[:n_spks] Msnp_ch = Msnp_ch[:n_spks] Msnp_pos = Msnp_pos[:n_spks] Mimg = np.array(Mimg, dtype='uint32') Mimg_tabs = np.array(Mimg_tabs, dtype='uint64') filters = tbl.Filters(complevel=4, complib='blosc') t_int16 = tbl.Int16Atom() t_uint32 = tbl.UInt32Atom() t_uint64 = tbl.UInt64Atom() h5o = tbl.openFile(fn_out, 'w') CMsnp = h5o.createCArray(h5o.root, 'Msnp', t_int16, Msnp.shape, filters=filters) CMsnp_tabs = h5o.createCArray(h5o.root, 'Msnp_tabs', t_uint64, Msnp_tabs.shape, filters=filters) CMsnp_ch = h5o.createCArray(h5o.root, 'Msnp_ch', t_uint32, Msnp_ch.shape, filters=filters) CMsnp_pos = h5o.createCArray(h5o.root, 'Msnp_pos', t_uint64, Msnp_pos.shape, filters=filters) CMsnp[...] = Msnp CMsnp_tabs[...] = Msnp_tabs CMsnp_ch[...] = Msnp_ch CMsnp_pos[...] = Msnp_pos h5o.createArray(h5o.root, 'Mimg', Mimg) h5o.createArray(h5o.root, 'Mimg_tabs', Mimg_tabs) meta = h5o.createGroup('/', 'meta', 'Metadata') h5o.createArray(meta, 't_start0', t_start0) h5o.createArray(meta, 't_stop0', t_stop0) h5o.createArray(meta, 't_adjust', t_adjust) h5o.createArray(meta, 'chn_info_pk', pk.dumps(chn_info)) h5o.createArray(meta, 'kwargs_pk', pk.dumps(kwargs)) h5o.createArray(meta, 'idx2iid', idx2iid) h5o.createArray(meta, 'iid2idx_pk', pk.dumps(iid2idx)) h5o.createArray(meta, 'idx2ch', idx2ch) h5o.createArray(meta, 'ch2idx_pk', pk.dumps(ch2idx)) # some error signals h5o.createArray(meta, 'b_warn_max_spks_lim', b_warn_max_spks_lim) if len(l_empty_spks) > 0: h5o.createArray(meta, 'l_empty_spks', l_empty_spks) h5o.close()
def cluster(fn_inp, fn_out, opts): config = {} config['skimspk'] = {} config['skimspk']['tb0'] = SKIMSPK_TB config['skimspk']['te0'] = SKIMSPK_TE config['extract'] = {} config['extract']['nperimg'] = EXTRACT_NPERIMG config['extract']['nmax'] = EXTRACT_NMAX config['feat'] = {} config['feat']['kssort'] = FEAT_KSSORT config['feat']['outdim'] = FEAT_OUTDIM config['cluster'] = {} config['cluster']['metd'] = CLUSTERING_ALG config['cluster']['commonp'] = AFFINITYPRP_COMMONP config['qc'] = {} config['qc']['qc'] = QC config['qc']['kwargs'] = {} config['qc']['kwargs']['min_snr'] = QC_MINSNR config['qc']['kwargs']['ks_plevel'] = QC_KS_PLEVEL config['qc']['kwargs']['min_size'] = QC_MINSIZE config['nn'] = {} config['nn']['nneigh'] = NN_NNEIGH config['nn']['radius'] = NN_RADIUS n_jobs = NCPU reference = None # -- process opts if 'njobs' in opts: n_jobs = int(opts['njobs']) print '* n_jobs =', n_jobs if 'ref' in opts: reference = opts['ref'] print '* Using the reference file:', reference h5r = tbl.openFile(reference) config = pk.loads(h5r.root.meta.config_clu_pk.read()) # TODO: implement other options!!! # -- preps print '-> Initializing...' h5 = tbl.openFile(fn_inp) Msnp = h5.root.Msnp.read() Msnp_feat = h5.root.Msnp_feat.read() Msnp_ch = h5.root.Msnp_ch.read() Msnp_pos = h5.root.Msnp_pos.read() Msnp_tabs = h5.root.Msnp_tabs.read() Msnp_selected = h5.root.Msnp_selected.read() Mimg = h5.root.Mimg.read() Mimg_tabs = h5.root.Mimg_tabs.read() t_adjust = h5.root.meta.t_adjust.read() t_start0 = h5.root.meta.t_start0.read() t_stop0 = h5.root.meta.t_stop0.read() idx2iid = h5.root.meta.idx2iid.read() iid2idx_pk = h5.root.meta.iid2idx_pk.read() idx2ch = h5.root.meta.idx2ch.read() ch2idx_pk = h5.root.meta.ch2idx_pk.read() all_chs = range(len(idx2ch)) if reference is None: # -- get training examples... print '-> Collecting snippet examples...' 
ibie, iuimg = skim_imgs(Mimg, Mimg_tabs, Msnp_tabs, t_adjust, **config['skimspk']) clu_feat_train = get_example_spikes(Msnp_feat, Msnp_ch, ibie, all_chs, **config['extract']) clu_train = get_example_spikes(Msnp, Msnp_ch, ibie, all_chs, **config['extract']) # -- get feature indices to use... print '-> Finding useful axes...' outdim = config['feat']['outdim'] if config['feat']['kssort']: Msnp_feat_use = [] for i_ch in xrange(len(all_chs)): # get deviations from Gaussian devs, _ = KS_all(clu_feat_train[i_ch]) # got top-n deviations devs = np.argsort(-devs)[:outdim] Msnp_feat_use.append(devs) else: Msnp_feat_use = [range(outdim)] * len(all_chs) Msnp_feat_use = np.array(Msnp_feat_use) # -- XXX: DEBUG SUPPORT __DBG__ = False if __DBG__: clu_feat_train = clu_feat_train[:4] clu_train = clu_train[:4] Msnp_feat_use = Msnp_feat_use[:4] all_chs = all_chs[:4] # -- get clusters... print '-> Clustering...' clu_labels, clu_ulabels, clu_nclus, clu_centers = \ find_clusters_par(clu_feat_train, Msnp_feat_use, n_jobs=n_jobs, **config['cluster']) # -- quality control if config['qc']['qc']: print '-> Run signal quality-based screening...' clu_sig_q = quality_meas_par(clu_train, clu_labels) quality_ctrl_par(clu_sig_q, clu_labels, clu_ulabels, clu_nclus, clu_centers, n_jobs=n_jobs, **config['qc']['kwargs']) else: # -- Bypass all and get the pre-computed clustering template print '-> Loading reference data...' clu_feat_train = pk.loads(h5r.root.clu_pk.clu_feat_train_pk.read()) clu_train = pk.loads(h5r.root.clu_pk.clu_train_pk.read()) clu_labels = pk.loads(h5r.root.clu_pk.clu_labels_pk.read()) clu_ulabels = pk.loads(h5r.root.clu_pk.clu_ulabels_pk.read()) clu_nclus = pk.loads(h5r.root.clu_pk.clu_nclus_pk.read()) clu_centers = pk.loads(h5r.root.clu_pk.clu_centers_pk.read()) clu_sig_q = pk.loads(h5r.root.clu_pk.clu_sig_q_pk.read()) Msnp_feat_use = h5r.root.Msnp_feat_use.read() h5r.close() # -- NN search print '-> Template matching...' 
Msnp_cid = nearest_neighbor_par(clu_feat_train, Msnp_feat, Msnp_feat_use, Msnp_ch, clu_labels, clu_ulabels, all_chs, n_jobs=n_jobs, **config['nn']) # -- final quality report... print '-> Computing the final signal quality report...' clu_sig_q = quality_meas_par2(Msnp, Msnp_ch, Msnp_cid, all_chs, n_jobs=n_jobs) # ... and update clu_ulabels, clu_nclus clu_ulabels = [np.array(clu_sig_q_ch.keys(), dtype='int') for clu_sig_q_ch in clu_sig_q] clu_nclus = [len(clu_sig_q_ch.keys()) for clu_sig_q_ch in clu_sig_q] # -- done! write everything... print '-> Writing results...' filters = tbl.Filters(complevel=4, complib='blosc') t_int16 = tbl.Int16Atom() t_uint32 = tbl.UInt32Atom() t_uint64 = tbl.UInt64Atom() h5o = tbl.openFile(fn_out, 'w') CMsnp_tabs = h5o.createCArray(h5o.root, 'Msnp_tabs', t_uint64, Msnp_tabs.shape, filters=filters) CMsnp_ch = h5o.createCArray(h5o.root, 'Msnp_ch', t_uint32, Msnp_ch.shape, filters=filters) CMsnp_cid = h5o.createCArray(h5o.root, 'Msnp_cid', t_int16, Msnp_cid.shape, filters=filters) CMsnp_pos = h5o.createCArray(h5o.root, 'Msnp_pos', t_uint64, Msnp_pos.shape, filters=filters) CMsnp_feat_use = h5o.createCArray(h5o.root, 'Msnp_feat_use', t_uint64, Msnp_feat_use.shape, filters=filters) CMsnp_selected = h5o.createCArray(h5o.root, 'Msnp_selected', t_uint64, Msnp_selected.shape, filters=filters) CMsnp_tabs[...] = Msnp_tabs CMsnp_ch[...] = Msnp_ch CMsnp_cid[...] = Msnp_cid CMsnp_pos[...] = Msnp_pos CMsnp_feat_use[...] = Msnp_feat_use CMsnp_selected[...] 
= Msnp_selected h5o.createArray(h5o.root, 'Mimg', Mimg) h5o.createArray(h5o.root, 'Mimg_tabs', Mimg_tabs) meta = h5o.createGroup('/', 'meta', 'Metadata') h5o.createArray(meta, 't_start0', t_start0) h5o.createArray(meta, 't_stop0', t_stop0) h5o.createArray(meta, 't_adjust', t_adjust) h5o.createArray(meta, 'config_clu_pk', pk.dumps(config)) h5o.createArray(meta, 'idx2iid', idx2iid) h5o.createArray(meta, 'iid2idx_pk', iid2idx_pk) h5o.createArray(meta, 'idx2ch', idx2ch) h5o.createArray(meta, 'ch2idx_pk', ch2idx_pk) h5o.createArray(meta, 'fn_inp', fn_inp) clupk = h5o.createGroup('/', 'clu_pk', 'Pickles') h5o.createArray(clupk, 'clu_feat_train_pk', pk.dumps(clu_feat_train)) h5o.createArray(clupk, 'clu_train_pk', pk.dumps(clu_train)) h5o.createArray(clupk, 'clu_labels_pk', pk.dumps(clu_labels)) h5o.createArray(clupk, 'clu_ulabels_pk', pk.dumps(clu_ulabels)) h5o.createArray(clupk, 'clu_nclus_pk', pk.dumps(clu_nclus)) h5o.createArray(clupk, 'clu_centers_pk', pk.dumps(clu_centers)) h5o.createArray(clupk, 'clu_sig_q_pk', pk.dumps(clu_sig_q)) h5o.close() h5.close()
def get_features(fn_inp, fn_out, opts):
    """Re-threshold, align and featurize the snippets stored in `fn_inp`.

    `fn_inp` is an HDF5 file holding the snippets and the metadata read
    below; the processed snippets, their features and the metadata are
    written to the HDF5 file `fn_out`. The only recognized key of `opts`
    is 'njobs', the number of parallel jobs (default: NCPU).

    When the configured re-threshold multiplier is a number, the snippets
    are re-thresholded and only the selected ones are kept. Otherwise all
    the snippets are kept: 'Msnp_selected' then lists every snippet index
    and no 'thrs' array is written.

    Raises NotImplementedError for the 'pca' feature method and ValueError
    for any method other than 'wavelet' and 'pca'.
    """
    config = {
        'rethreshold_mult': RETHRESHOLD_MULT,
        'align': {
            'subsmp': ALIGN_SUBSMP,
            'maxdt': ALIGN_MAXDT,
            'peakloc': ALIGN_PEAKLOC,
            'findbwd': ALIGN_FINDBWD,
            'findfwd': ALIGN_FINDFWD,
            'cutat': ALIGN_CUTAT,
            'outdim': ALIGN_OUTDIM,
            'peakfunc': ALIGN_PEAKFUNC,
        },
        'feat': {
            'metd': FEAT_METHOD,
            'kwargs': {'level': FEAT_WAVL_LEV},
        },
    }

    n_jobs = NCPU

    # -- process opts
    if 'njobs' in opts:
        n_jobs = int(opts['njobs'])
        print('* n_jobs = %s' % (n_jobs,))
    # TODO: implement!!!

    # -- preps
    print('-> Initializing...')
    h5 = tbl.openFile(fn_inp)
    try:
        # everything is read into memory, so the input can be closed
        # right after
        Msnp = h5.root.Msnp.read()
        Msnp_ch = h5.root.Msnp_ch.read()
        Msnp_pos = h5.root.Msnp_pos.read()
        Msnp_tabs = h5.root.Msnp_tabs.read()
        Mimg = h5.root.Mimg.read()
        Mimg_tabs = h5.root.Mimg_tabs.read()

        t_adjust = h5.root.meta.t_adjust.read()
        t_start0 = h5.root.meta.t_start0.read()
        t_stop0 = h5.root.meta.t_stop0.read()

        idx2iid = h5.root.meta.idx2iid.read()
        iid2idx_pk = h5.root.meta.iid2idx_pk.read()
        idx2ch = h5.root.meta.idx2ch.read()
        ch2idx_pk = h5.root.meta.ch2idx_pk.read()
    finally:
        h5.close()
    all_chs = range(len(idx2ch))

    # -- re-threshold
    print('-> Re-thresholding...')
    # NOTE: the former test "type(...) is float or int" was always true.
    if isinstance(config['rethreshold_mult'], (int, float)):
        thr_sel, thrs = rethreshold_by_multiplier(Msnp, Msnp_ch, all_chs,
                                                  config['rethreshold_mult'])

        Msnp = Msnp[thr_sel]
        Msnp_ch = Msnp_ch[thr_sel]
        Msnp_pos = Msnp_pos[thr_sel]
        Msnp_tabs = Msnp_tabs[thr_sel]
        Msnp_selected = np.nonzero(thr_sel)[0]
    else:
        # no re-thresholding: every snippet is kept
        thr_sel = None
        thrs = None
        Msnp_selected = np.arange(len(Msnp))

    # -- align
    print('-> Aligning...')
    Msnp = par_comp(align_core, Msnp, n_jobs=n_jobs, **config['align'])

    # -- feature extraction
    print('-> Extracting features...')
    feat_metd = config['feat']['metd']
    if feat_metd == 'wavelet':
        Msnp_feat = par_comp(wavelet_core, Msnp, n_jobs=n_jobs,
                             **config['feat']['kwargs'])
    elif feat_metd == 'pca':
        # 'level' is a wavelet-only option
        config['feat']['kwargs'].pop('level')
        raise NotImplementedError('PCA not implemented yet')
    else:
        raise ValueError('Not recognized "feat_metd"')

    # -- done! write everything...
    print('-> Writing results...')
    filters = tbl.Filters(complevel=4, complib='blosc')
    t_int16 = tbl.Int16Atom()
    t_uint32 = tbl.UInt32Atom()
    t_uint64 = tbl.UInt64Atom()
    t_float32 = tbl.Float32Atom()

    h5o = tbl.openFile(fn_out, 'w')
    try:
        CMsnp = h5o.createCArray(h5o.root, 'Msnp', t_int16, Msnp.shape,
                                 filters=filters)
        CMsnp_tabs = h5o.createCArray(h5o.root, 'Msnp_tabs', t_uint64,
                                      Msnp_tabs.shape, filters=filters)
        CMsnp_ch = h5o.createCArray(h5o.root, 'Msnp_ch', t_uint32,
                                    Msnp_ch.shape, filters=filters)
        CMsnp_pos = h5o.createCArray(h5o.root, 'Msnp_pos', t_uint64,
                                     Msnp_pos.shape, filters=filters)
        CMsnp_feat = h5o.createCArray(h5o.root, 'Msnp_feat', t_float32,
                                      Msnp_feat.shape, filters=filters)
        CMsnp_selected = h5o.createCArray(h5o.root, 'Msnp_selected',
                                          t_uint64, Msnp_selected.shape,
                                          filters=filters)

        CMsnp[...] = Msnp
        CMsnp_tabs[...] = Msnp_tabs
        CMsnp_ch[...] = Msnp_ch
        CMsnp_pos[...] = Msnp_pos
        CMsnp_feat[...] = Msnp_feat
        CMsnp_selected[...] = Msnp_selected

        h5o.createArray(h5o.root, 'Mimg', Mimg)
        h5o.createArray(h5o.root, 'Mimg_tabs', Mimg_tabs)

        meta = h5o.createGroup('/', 'meta', 'Metadata')
        h5o.createArray(meta, 't_start0', t_start0)
        h5o.createArray(meta, 't_stop0', t_stop0)
        h5o.createArray(meta, 't_adjust', t_adjust)
        h5o.createArray(meta, 'config_feat_pk', pk.dumps(config))

        h5o.createArray(meta, 'idx2iid', idx2iid)
        h5o.createArray(meta, 'iid2idx_pk', iid2idx_pk)
        h5o.createArray(meta, 'idx2ch', idx2ch)
        h5o.createArray(meta, 'ch2idx_pk', ch2idx_pk)
        if thrs is not None:
            # thresholds only exist when the snippets were re-thresholded
            h5o.createArray(meta, 'thrs', thrs)
        h5o.createArray(meta, 'fn_inp', fn_inp)
    finally:
        h5o.close()