Example #1
def add_event_type(fd, id=None, evt=None):
    """fd is returned by `open_files`: it is a dict {type: tb_file_handle}."""
    kwik = fd.get('kwik', None)
    # The KWIK needs to be there.
    assert kwik is not None
    if id is None:
        # If id is None, take the maximum integer index among the existing
        # event type names, + 1.
        event_types = sorted([n._v_name
                              for n in kwik.listNodes('/event_types')])
        numeric_ids = [int(r) for r in event_types if r.isdigit()]
        if numeric_ids:
            id = str(max(numeric_ids) + 1)
        else:
            id = '0'
    event_type = kwik.createGroup('/event_types', id)
    
    kwik.createGroup(event_type, 'user_data')
    
    app = kwik.createGroup(event_type, 'application_data')
    kv = kwik.createGroup(app, 'klustaviewa')
    kv._f_setAttr('color', None)
    
    events = kwik.createGroup(event_type, 'events')
    kwik.createEArray(events, 'time_samples', tb.UInt64Atom(), (0,))
    kwik.createEArray(events, 'recording', tb.UInt16Atom(), (0,))
    kwik.createGroup(events, 'user_data')
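A minimal driver sketch for the function above, assuming a PyTables version that still provides the legacy camelCase API (`openFile`, `createGroup`) that the snippet itself uses; `open_files` would normally build the `fd` dict, so the hand-made file below is a hypothetical stand-in.

import tables as tb

# Hypothetical stand-in for the dict returned by `open_files`.
kwik = tb.openFile('experiment.kwik', mode='w')
kwik.createGroup('/', 'event_types')
fd = {'kwik': kwik}

add_event_type(fd)                  # auto-numbered: creates /event_types/0
add_event_type(fd, id='stimulus')   # explicit name
kwik.close()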
Example #2
def inizialize_dataset():
    tab = tables.open_file(dsOut, mode='w')
    data_shape = (0, sizedb[0], sizedb[1], sizedb[2])
    img_dtype = tables.UInt8Atom()
    label_dtype = tables.UInt64Atom()
    X_storage = tab.create_earray(tab.root, 'X', img_dtype, shape=data_shape)
    Y_storageID = tab.create_earray(tab.root, 'Y_ID', label_dtype, shape=(0, ))
    Y_desc = tab.create_earray(tab.root, 'desc', label_dtype, shape=(0, 6))
    return X_storage, Y_storageID, Y_desc
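For context, a usage sketch: `dsOut` and `sizedb` are module-level settings in the source project, so the values below are assumptions. EArrays created with a 0-length leading axis grow row by row via `append()`.

import numpy as np
import tables

dsOut = 'dataset.h5'   # assumed output path
sizedb = (64, 64, 3)   # assumed image shape (H, W, C)

X_storage, Y_storageID, Y_desc = inizialize_dataset()
X_storage.append(np.zeros((1,) + sizedb, dtype=np.uint8))   # one image
Y_storageID.append(np.array([42], dtype=np.uint64))         # one label
Y_desc.append(np.zeros((1, 6), dtype=np.uint64))            # one descriptor row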
Example #3
def inizialize_dataset():
    h5 = tables.open_file(dbOut_path, mode='w')
    data_shape = (0, size_embedding)
    img_dtype = tables.Float32Atom()
    label_dtype = tables.UInt64Atom()
    X_storage = h5.create_earray(h5.root, 'X', img_dtype, shape=data_shape)
    Y_storageID = h5.create_earray(h5.root, 'Y_ID', label_dtype, shape=(0, ))
    Y_desc = h5.create_earray(h5.root, 'desc', label_dtype, shape=(0, 6))
    return X_storage, Y_storageID, Y_desc
Example #4
def inizialize_dataset():
    global X_storage, Y_storageID, desc_storage
    h5 = tables.open_file(db_path, mode='w')
    data_shape = (0, sizedb[0], sizedb[1], sizedb[2])
    img_dtype = tables.UInt8Atom()
    label_dtype = tables.UInt64Atom()
    X_storage = h5.create_earray(h5.root, 'X', img_dtype, shape=data_shape)
    Y_storageID = h5.create_earray(h5.root, 'Y_ID', label_dtype, shape=(0, ))
    desc_storage = h5.create_earray(h5.root, 'desc', label_dtype,
                                    shape=(0, 6))  # video, frame, bounding box
Example #5
def initfile(h5name, ncsf, q_down, include_times=True):
    """
    Initializes an HDF5 file to store converted data.
    """
    adbitvolts = ncsf.header['ADBitVolts']
    timestep = ncsf.timestep

    chname = ncsf.header['AcqEntName']

    h5f = tables.open_file(h5name, 'w')

    h5f.create_group('/', 'data')
    h5f.create_earray('/data', 'rawdata', tables.Int16Atom(), [0])
    h5f.root.data.rawdata.set_attr('ADBitVolts', adbitvolts)
    h5f.root.data.rawdata.set_attr('timestep', timestep)
    h5f.root.data.rawdata.set_attr('Q', q_down)
    h5f.root.data.rawdata.set_attr('AcqEntName', chname)

    if include_times:
        h5f.create_earray('/', 'time', tables.UInt64Atom(), [0])

    return h5f
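A hedged usage sketch: the real `ncsf` is a Neuralynx NCS reader object, so `FakeNcs` below is a stand-in exposing only the fields `initfile` touches.

import numpy as np
import tables

class FakeNcs:
    # minimal stand-in for the NCS reader the function expects
    header = {'ADBitVolts': 1e-7, 'AcqEntName': 'CSC1'}
    timestep = 1.0 / 32000

h5f = initfile('converted.h5', FakeNcs(), q_down=4)
h5f.root.data.rawdata.append(np.zeros(512, dtype=np.int16))  # samples
h5f.root.time.append(np.arange(512, dtype=np.uint64))        # timestamps
h5f.close()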
Example #6
    def _create_table(self, name, example):
        """
        Create a new table within the HDF file, where the table's shape and
        datatype are determined by *example*.
        """
        type_map = {
            np.dtype(np.float64): tables.Float64Atom(),
            np.dtype(np.float32): tables.Float32Atom(),
            np.dtype(np.int8): tables.Int8Atom(),
            np.dtype(np.uint8): tables.UInt8Atom(),
            np.dtype(np.int16): tables.Int16Atom(),
            np.dtype(np.uint16): tables.UInt16Atom(),
            np.dtype(np.int32): tables.Int32Atom(),
            np.dtype(np.uint32): tables.UInt32Atom(),
            np.dtype(np.int64): tables.Int64Atom(),
            np.dtype(np.uint64): tables.UInt64Atom(),
            np.dtype(np.bool_): tables.BoolAtom(),
        }

        try:
            if type(example) == np.ndarray:
                h5type = type_map[example.dtype]
            elif type(example) == str:
                h5type = tables.VLStringAtom()
        except KeyError:
            raise TypeError(
                "Could not create table %s because of unknown dtype '%s'" %
                (name, example.dtype))

        if type(example) == np.ndarray:
            h5dim = (0, ) + example.shape

            h5 = self.h5
            filters = tables.Filters(complevel=self.compression_level,
                                     complib='zlib',
                                     shuffle=True)

            # 'a.b.c' -> ['a/', 'b/', 'c']: nested groups, 'c' is the leaf
            nmpt = name.replace('.', '/\n')
            nmpt = nmpt.split('\n')

            path = '/'
            for kay in range(len(nmpt) - 1):
                # create intermediate groups that do not exist yet
                try:
                    h5.is_visible_node(path + nmpt[kay][:-1])
                except tables.NoSuchNodeError:
                    h5.create_group(path, nmpt[kay][:-1])
                path += nmpt[kay]

            self.tables[name] = h5.create_earray(path,
                                                 nmpt[-1],
                                                 h5type,
                                                 h5dim,
                                                 filters=filters)

        elif type(example) == str:
            h5 = self.h5
            filters = tables.Filters(complevel=self.compression_level,
                                     complib='zlib',
                                     shuffle=True)

            # same dotted-name handling as in the ndarray branch above
            nmpt = name.replace('.', '/\n')
            nmpt = nmpt.split('\n')

            path = '/'
            for kay in range(len(nmpt) - 1):
                # create intermediate groups that do not exist yet
                try:
                    h5.is_visible_node(path + nmpt[kay][:-1])
                except tables.NoSuchNodeError:
                    h5.create_group(path, nmpt[kay][:-1])
                path += nmpt[kay]

            self.tables[name] = h5.create_vlarray(path,
                                                  nmpt[-1],
                                                  h5type,
                                                  filters=filters)

        self.types[name] = type(example)
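The dotted-name convention maps a name like 'train.loss' to nested groups: /train is created on demand and the EArray 'loss' lands inside it. A minimal harness sketch for the surrounding class; the attribute names (`h5`, `tables`, `types`, `compression_level`) are inferred from the method body, so treat the class as an assumption.

import numpy as np
import tables

class Storage:
    # assume _create_table from the snippet above is pasted into this class
    def __init__(self, path):
        self.h5 = tables.open_file(path, mode='w')
        self.tables = {}
        self.types = {}
        self.compression_level = 5

s = Storage('log.h5')
s._create_table('loss', np.zeros(1, dtype=np.float32))  # EArray at /loss
s.tables['loss'].append(np.float32([[0.5]]))            # append one row
s.h5.close()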
Example #7
def create_kwik(path, experiment_name=None, prm=None, prb=None, overwrite=True):
    """Create a KWIK file.
    
    Arguments:
      * path: path to the .kwik file.
      * experiment_name
      * prm: a dictionary representing the contents of the PRM file (used for
        SpikeDetekt)
      * prb: a dictionary with the contents of the PRB file
    
    """
    if experiment_name is None:
        experiment_name = ''
    if prm is None:
        prm = {}
    if prb is None:
        prb = {}
    
    if not overwrite and os.path.exists(path):
        return
    
    file = tb.openFile(path, mode='w')
    
    file.root._f_setAttr('kwik_version', 2)
    file.root._f_setAttr('name', experiment_name)

    file.createGroup('/', 'application_data')
    
    # Set the SpikeDetekt parameters
    file.createGroup('/application_data', 'spikedetekt')
    for prm_name, prm_value in iteritems(prm):
        file.root.application_data.spikedetekt._f_setAttr(prm_name, prm_value)
    
    file.createGroup('/', 'user_data')
    
    # Create channel groups.
    file.createGroup('/', 'channel_groups')
    
    for igroup, group_info in iteritems(prb):
        igroup = int(igroup)
        group = file.createGroup('/channel_groups', str(igroup))
        # group_info: channel, graph, geometry
        group._f_setAttr('name', 'channel_group_{0:d}'.format(igroup))
        group._f_setAttr('adjacency_graph', 
            np.array(group_info.get('graph', np.zeros((0, 2))), dtype=np.int32))
        file.createGroup(group, 'application_data')
        file.createGroup(group, 'user_data')
        
        # Create channels.
        file.createGroup(group, 'channels')
        channels = group_info.get('channels', [])
        
        # Add the channel order.
        group._f_setAttr('channel_order', np.array(channels, dtype=np.int32))
        
        for channel_idx in channels:
            # channel is the absolute channel index.
            channel = file.createGroup(group.channels, str(channel_idx))
            channel._f_setAttr('name', 'channel_{0:d}'.format(channel_idx))
            
            channel._f_setAttr('ignored', False)  # "channels" only contains 
                                                  # not-ignored channels here
            pos = group_info.get('geometry', {}). \
                get(channel_idx, None)
            if pos is not None:
                pos = np.array(pos, dtype=np.float32)
            channel._f_setAttr('position', pos)
            channel._f_setAttr('voltage_gain', prm.get('voltage_gain', 0.))
            channel._f_setAttr('display_threshold', 0.)
            file.createGroup(channel, 'application_data')
            file.createGroup(channel.application_data, 'spikedetekt')
            file.createGroup(channel.application_data, 'klustaviewa')
            file.createGroup(channel, 'user_data')
            
        # Create spikes.
        spikes = file.createGroup(group, 'spikes')
        file.createEArray(spikes, 'time_samples', tb.UInt64Atom(), (0,),
                          expectedrows=1000000)
        file.createEArray(spikes, 'time_fractional', tb.UInt8Atom(), (0,),
                          expectedrows=1000000)
        file.createEArray(spikes, 'recording', tb.UInt16Atom(), (0,),
                          expectedrows=1000000)
        clusters = file.createGroup(spikes, 'clusters')
        file.createEArray(clusters, 'main', tb.UInt32Atom(), (0,),
                          expectedrows=1000000)
        file.createEArray(clusters, 'original', tb.UInt32Atom(), (0,),
                          expectedrows=1000000)
        
        fm = file.createGroup(spikes, 'features_masks')
        fm._f_setAttr('hdf5_path', '{{kwx}}/channel_groups/{0:d}/features_masks'. \
            format(igroup))
        wr = file.createGroup(spikes, 'waveforms_raw')
        wr._f_setAttr('hdf5_path', '{{kwx}}/channel_groups/{0:d}/waveforms_raw'. \
            format(igroup))
        wf = file.createGroup(spikes, 'waveforms_filtered')
        wf._f_setAttr('hdf5_path', '{{kwx}}/channel_groups/{0:d}/waveforms_filtered'. \
            format(igroup))
        
        # Create clusters.
        clusters = file.createGroup(group, 'clusters')
        file.createGroup(clusters, 'main')
        file.createGroup(clusters, 'original')
        
        # Create cluster groups.
        cluster_groups = file.createGroup(group, 'cluster_groups')
        file.createGroup(cluster_groups, 'main')
        file.createGroup(cluster_groups, 'original')
        
    # Create recordings.
    file.createGroup('/', 'recordings')
    
    # Create event types.
    file.createGroup('/', 'event_types')
            
    file.close()
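A minimal invocation sketch, assuming the Python 2 / legacy PyTables environment the snippet targets (`tb.openFile`, `iteritems`); the one-group probe definition below is hypothetical.

prb = {0: {'channels': [0, 1],
           'graph': [(0, 1)],
           'geometry': {0: (0.0, 0.0), 1: (0.0, 100.0)}}}
prm = {'voltage_gain': 10.0}

create_kwik('experiment.kwik', experiment_name='demo', prm=prm, prb=prb)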
Example #8
def compress_twix(infile,
                  outfile,
                  remove_os=False,
                  cc_mode=False,
                  ncc=None,
                  cc_tol=0.05,
                  zfp=False,
                  zfp_tol=1e-5,
                  zfp_prec=None,
                  rm_fidnav=False):

    with suppress_stdout_stderr():
        twix = twixtools.read_twix(infile)

    filters = tables.Filters(complevel=5,
                             complib='zlib')  # lossless compression settings
    #filters = None

    mtx = None
    noise_mtx = None
    noise_dmtx = None
    if cc_mode or zfp:
        # # calibrate noise decorrelation matrix for better compression
        # noise = list()
        # for mdb in twix[1]['mdb']:
        #     if mdb.is_flag_set('NOISEADJSCAN'):
        #         noise.append(mdb.data)
        # if len(noise)>0:
        #     noise_dmtx, noise_mtx = calculate_prewhitening(np.asarray(noise).swapaxes(0,1))
        # del(noise)
        pass

    if cc_mode:
        # calibrate coil compression based on last scan in list (image scan)
        # use the calibration coil weights for all data that fits
        cal_data = get_cal_data(twix[-1], remove_os)
        if cc_mode == 'scc' or cc_mode == 'gcc':
            mtx, ncc = calibrate_mtx(cal_data, cc_mode, ncc, cc_tol)
            del (cal_data)
            print('coil compression from %d channels to %d virtual channels' %
                  (mtx.shape[-1], ncc))
        else:
            mtx = calibrate_mtx_bart(cal_data, cc_mode)
            del (cal_data)
            if ncc is None:
                # set default
                ncc = mtx.shape[-1] // 2
            print('coil compression from %d channels to %d virtual channels' %
                  (mtx.shape[-1], ncc))

    t_start = time.time()
    with tables.open_file(outfile, mode="w") as f:
        f.root._v_attrs.original_filename = os.path.basename(infile)
        f.root._v_attrs.cc_mode = cc_mode
        f.root._v_attrs.ncc = ncc
        f.root._v_attrs.zfp = zfp

        if zfp_tol is None:
            f.root._v_attrs.zfp_tol = -1
        else:
            f.root._v_attrs.zfp_tol = zfp_tol
        if zfp_prec is None:
            f.root._v_attrs.zfp_prec = -1
        else:
            f.root._v_attrs.zfp_prec = zfp_prec

        f.create_carray(f.root,
                        "multi_header",
                        obj=np.frombuffer(twix[0].tobytes(), 'S1'),
                        filters=filters)

        if mtx is not None:
            # save mtx for coil compression
            f.create_carray(f.root, "mtx", obj=mtx, filters=filters)
        if noise_dmtx is not None:
            f.create_carray(f.root,
                            "noise_dmtx",
                            obj=noise_dmtx,
                            filters=filters)
            f.create_carray(f.root,
                            "noise_mtx",
                            obj=noise_mtx,
                            filters=filters)

        scanlist = []
        for meas_key, meas in enumerate(twix[1:]):
            scanlist.append("scan%d" % (meas_key))
            grp = f.create_group("/", "scan%d" % (meas_key))
            f.create_carray(grp,
                            "hdr_str",
                            obj=meas['hdr_str'],
                            filters=filters)

            # remove fidnav scans if necessary; filter rather than deleting
            # from the list while enumerating it (which would skip entries)
            if rm_fidnav:
                meas['mdb'] = [mdb for mdb in meas['mdb']
                               if not mdb.is_flag_set('noname60')]

            mdh_count = len(meas['mdb'])

            # create info array with mdh, coil & compression information
            f.create_carray(grp,
                            "info",
                            shape=[mdh_count, datinfo_type.itemsize],
                            atom=tables.UInt8Atom(),
                            filters=filters)

            dt = tables.UInt64Atom(shape=())
            if zfp:
                f.create_vlarray(grp, "DATA", atom=dt, expectedrows=mdh_count)
            else:
                f.create_vlarray(grp,
                                 "DATA",
                                 atom=dt,
                                 filters=filters,
                                 expectedrows=mdh_count)

            syncscans = 0
            for mdb_key, mdb in enumerate(meas['mdb']):
                info = np.zeros(1, dtype=datinfo_type)[0]
                is_syncscan = mdb.is_flag_set('SYNCDATA')
                if rm_fidnav:  # we have to update the scan counters
                    if not is_syncscan:
                        # scanCounter starts at 1
                        mdb.mdh['ulScanCounter'] = mdb_key + 1 - syncscans
                    else:
                        syncscans += 1

                # store mdh
                info['mdh_info'] = mdb.mdh

                if is_syncscan or mdb.is_flag_set('ACQEND'):
                    data = np.ascontiguousarray(mdb.data).view('uint64')
                else:
                    restrictions = get_restrictions(mdb.get_flags())
                    if restrictions == 'NO_COILCOMP':
                        data, info['rm_os_active'], _ = reduce_data(
                            mdb.data, mdb.mdh, remove_os, cc_mode=False)
                    else:
                        data, info['rm_os_active'], info[
                            'cc_active'] = reduce_data(mdb.data,
                                                       mdb.mdh,
                                                       remove_os,
                                                       cc_mode=cc_mode,
                                                       mtx=mtx,
                                                       ncc=ncc)
                    data = data.flatten()
                    if zfp:
                        data = pyzfp.compress(data.view('float32'),
                                              tolerance=zfp_tol,
                                              precision=zfp_prec,
                                              parallel=True)
                        data = np.frombuffer(data, dtype='uint64')
                    else:
                        data = data.view('uint64')
                    if len(mdb.channel_hdr) > 0:
                        mdb.channel_hdr[0]['ulScanCounter'] = mdb.mdh[
                            'ulScanCounter']
                        info['coil_info'] = mdb.channel_hdr[0]
                        coil_list = np.asarray(
                            [item['ulChannelId'] for item in mdb.channel_hdr],
                            dtype='uint8')
                        info['coil_list'][:len(coil_list)] = coil_list

                # write data
                grp.DATA.append(data)
                grp.info[mdb_key] = np.frombuffer(info, dtype='uint8')

        f.root._v_attrs.scanlist = scanlist

        # from joblib import Parallel, delayed
        # Parallel(n_jobs=2)(delayed(task)(mdb_key, mdb, is_byte, count, grp, remove_os, zfp, zfp_tol, zfp_prec, mtx) for mdb_key, (mdb, is_byte, count) in enumerate(zip(meas['mdb'], is_bytearray, data_counter)))

    elapsed_time = (time.time() - t_start)
    print("compression finished in %d:%02d:%02d h" %
          (elapsed_time // 3600,
           (elapsed_time % 3600) // 60, elapsed_time % 60))
    print("compression factor = %.2f" %
          (os.path.getsize(infile) / os.path.getsize(outfile)))
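The storage trick above is that every measurement block, whatever its true dtype, becomes one row of a VLArray of scalar uint64 words and is re-viewed on read. A self-contained round-trip sketch of that pattern, independent of twixtools:

import numpy as np
import tables

with tables.open_file('rows.h5', mode='w') as f:
    vl = f.create_vlarray(f.root, 'DATA', atom=tables.UInt64Atom(shape=()))
    payload = np.arange(6, dtype='float32')   # 24 bytes == 3 uint64 words
    vl.append(np.ascontiguousarray(payload).view('uint64'))

with tables.open_file('rows.h5', mode='r') as f:
    restored = f.root.DATA[0].view('float32')  # back to the original dtype
    assert np.array_equal(restored, np.arange(6, dtype='float32'))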
Example #9
    def init_h5file(self):

        file, curr_dir = self.get_new_file_name()

        self.settings.child('acquisition', 'temp_file').setValue(file+'.h5')
        self.h5file = tables.open_file(os.path.join(curr_dir, file+'.h5'), mode='w')
        h5group = self.h5file.root
        h5group._v_attrs['settings'] = customparameter.parameter_to_xml_string(self.settings)
        h5group._v_attrs.type = 'detector'
        h5group._v_attrs['format_name'] = 'timestamps'

        channels_index = [self.channels_enabled[k]['index'] for k in self.channels_enabled.keys() if
                          self.channels_enabled[k]['enabled']]
        self.marker_array = self.h5file.create_earray(self.h5file.root, 'markers', tables.UInt8Atom(), (0,),
                                                      title='markers')
        self.marker_array._v_attrs['data_type'] = '1D'
        self.marker_array._v_attrs['type'] = 'tttr_data'

        self.nanotimes_array = self.h5file.create_earray(self.h5file.root, 'nanotimes', tables.UInt16Atom(), (0,),
                                                         title='nanotimes')
        self.nanotimes_array._v_attrs['data_type'] = '1D'
        self.nanotimes_array._v_attrs['type'] = 'tttr_data'

        self.timestamp_array = self.h5file.create_earray(self.h5file.root, 'timestamps', tables.UInt64Atom(), (0,),
                                                   title='timestamps')
        self.timestamp_array._v_attrs['data_type'] = '1D'
        self.timestamp_array._v_attrs['type'] = 'tttr_data'
Example #10
def make_h5(obj, handle, objname):
    """given a vector or dictionary of objects associated with the set of DHS sites,
    save the objects in the hdf5 format used by other scripts.
    
    Objects currently handled:
        UInt8
        UInt64
        Float64
    """

    if type(obj) == np.ndarray:

        file = tables.openFile(
            '/mnt/lustre/home/anilraj/linspec/cache/dhslocations.h5', 'r')
        locdata = dict([(chr, file.getNode('/' + chr)) for chr in chromosomes])
        locations = [(chr, l) for chr in chromosomes
                     for l in locdata[chr].start[:]]
        file.close()

        data = dict()
        for v, loc in zip(obj, locations):
            try:
                data[loc[0]].append(v)
            except KeyError:
                data[loc[0]] = [v]

        for k, v in data.iteritems():
            data[k] = np.array(v).astype(obj.dtype)

        # selecting an appropriate atom
        if obj.dtype == np.float64:
            atom = tables.Float64Atom()
        elif obj.dtype == np.float32:
            atom = tables.Float32Atom()
        elif obj.dtype == np.int8:
            atom = tables.UInt8Atom()
        elif obj.dtype == np.int64:
            atom = tables.UInt64Atom()

    else:

        data = dict()
        for chr, vals in obj.iteritems():
            data[chr] = vals

        # selecting an appropriate atom
        if obj[chr].dtype == np.float64:
            atom = tables.Float64Atom()
        elif obj[chr].dtype == np.float32:
            atom = tables.Float32Atom()
        elif obj[chr].dtype == np.int8:
            atom = tables.UInt8Atom()
        elif obj[chr].dtype == np.int64:
            atom = tables.UInt64Atom()

    filters = tables.Filters(complevel=5, complib='zlib')

    for chr, dat in data.iteritems():
        chrgroup = handle.createGroup(handle.root, chr, chr)
        values = handle.createCArray(chrgroup,
                                     objname,
                                     atom,
                                     dat.shape,
                                     filters=filters)
        values[:] = dat[:]

    return handle
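A sketch of the dictionary branch, under the snippet's own assumptions (Python 2 `iteritems`, legacy camelCase PyTables API); `chromosomes` and the hard-coded cache path only matter for the ndarray branch.

import numpy as np
import tables

obj = {'chr1': np.zeros(100, dtype=np.float64),
       'chr2': np.ones(50, dtype=np.float64)}
handle = tables.openFile('scores.h5', 'w')
handle = make_h5(obj, handle, 'scores')   # writes /chr1/scores, /chr2/scores
handle.close()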
Example #11
def get_PS_waveform(fn_mwk,
                    fn_nev,
                    fn_out,
                    movie_begin_fname=None,
                    n_samples=DEFAULT_SAMPLES_PER_SPK,
                    n_max_spks=DEFAULT_MAX_SPKS,
                    **kwargs):
    """Get waveform data around stimuli presented for later spike sorting.
    This will give completely different output file format.
    NOTE: this function is memory intensive!  Will require approximately
    as much memory as the size of the files."""

    # -- some housekeeping things...
    kwargs['verbose'] = 2
    kwargs['only_new_t'] = True
    t_start0 = kwargs['t_start0']
    t_stop0 = kwargs['t_stop0']

    iid2idx = {}
    idx2iid = []
    ch2idx = {}
    idx2ch = []
    n_spks = 0

    # does "n_spks_lim" reach "n_max_spks"?
    b_warn_max_spks_lim = False
    # list of image presentations without spikes
    l_empty_spks = []

    for info in getspk(fn_mwk, fn_nev=fn_nev, **kwargs):
        # -- get the metadata. this must be called before other clauses
        if info['type'] == 'preamble':
            actvelecs = info['actvelecs']
            t_adjust = info['t_adjust']
            chn_info = info['chn_info']
            n_spks_lim = min(info['n_packets'], n_max_spks)
            print '* n_spks_lim =', n_spks_lim

            for ch in sorted(actvelecs):
                makeavail(ch, ch2idx, idx2ch)

            # Data for snippets ===
            # Msnp: snippet data
            # Msnp_tabs: when it spiked (absolute time)
            # Msnp_ch: which channel ID spiked?
            # Msnp_pos: corresponding file position
            Msnp = np.empty((n_spks_lim, n_samples), dtype='int16')
            Msnp_tabs = np.empty(n_spks_lim, dtype='uint64')
            Msnp_ch = np.empty(n_spks_lim, dtype='uint32')
            Msnp_pos = np.empty(n_spks_lim, dtype='uint64')

            # Data for images ===
            # Mimg: image indices in the order of presentations
            # Mimg_tabs: image onset time (absolute)
            Mimg = []
            Mimg_tabs = []

        # -- do some housekeeping things once per each img
        if info['type'] == 'begin':
            t_abs = info['t_imgonset']
            iid = info['imgid']
            i_img = info['i_img']

            makeavail(iid, iid2idx, idx2iid)
            Mimg.append(iid2idx[iid])
            Mimg_tabs.append(t_abs)
            b_no_spks = True

            # process movie if requested
            if movie_begin_fname is not None:
                raise NotImplementedError('Movies are not supported yet.')

        # -- put actual spiking info
        elif info['type'] == 'spike':
            wav = info['wavinfo']['waveform']
            t_abs = info['t_abs']
            i_ch = ch2idx[info['ch']]
            pos = info['pos']

            Msnp[n_spks] = wav
            Msnp_tabs[n_spks] = t_abs
            Msnp_ch[n_spks] = i_ch
            Msnp_pos[n_spks] = pos
            b_no_spks = False

            n_spks += 1
            if n_spks >= n_spks_lim:
                warnings.warn('n_spks exceeds n_spks_lim! '
                              'Aborting further additions.')
                b_warn_max_spks_lim = True
                break

        elif info['type'] == 'end':
            if not b_no_spks:
                continue
            # if there's no spike at all, list the stim
            warnings.warn('No spikes for this image presentation!')
            l_empty_spks.append(i_img)

    # -- done! trim arrays to the actual number of spikes
    Msnp = Msnp[:n_spks]
    Msnp_tabs = Msnp_tabs[:n_spks]
    Msnp_ch = Msnp_ch[:n_spks]
    Msnp_pos = Msnp_pos[:n_spks]

    Mimg = np.array(Mimg, dtype='uint32')
    Mimg_tabs = np.array(Mimg_tabs, dtype='uint64')

    filters = tbl.Filters(complevel=4, complib='blosc')
    t_int16 = tbl.Int16Atom()
    t_uint32 = tbl.UInt32Atom()
    t_uint64 = tbl.UInt64Atom()

    h5o = tbl.openFile(fn_out, 'w')
    CMsnp = h5o.createCArray(h5o.root,
                             'Msnp',
                             t_int16,
                             Msnp.shape,
                             filters=filters)
    CMsnp_tabs = h5o.createCArray(h5o.root,
                                  'Msnp_tabs',
                                  t_uint64,
                                  Msnp_tabs.shape,
                                  filters=filters)
    CMsnp_ch = h5o.createCArray(h5o.root,
                                'Msnp_ch',
                                t_uint32,
                                Msnp_ch.shape,
                                filters=filters)
    CMsnp_pos = h5o.createCArray(h5o.root,
                                 'Msnp_pos',
                                 t_uint64,
                                 Msnp_pos.shape,
                                 filters=filters)

    CMsnp[...] = Msnp
    CMsnp_tabs[...] = Msnp_tabs
    CMsnp_ch[...] = Msnp_ch
    CMsnp_pos[...] = Msnp_pos

    h5o.createArray(h5o.root, 'Mimg', Mimg)
    h5o.createArray(h5o.root, 'Mimg_tabs', Mimg_tabs)

    meta = h5o.createGroup('/', 'meta', 'Metadata')
    h5o.createArray(meta, 't_start0', t_start0)
    h5o.createArray(meta, 't_stop0', t_stop0)
    h5o.createArray(meta, 't_adjust', t_adjust)
    h5o.createArray(meta, 'chn_info_pk', pk.dumps(chn_info))
    h5o.createArray(meta, 'kwargs_pk', pk.dumps(kwargs))

    h5o.createArray(meta, 'idx2iid', idx2iid)
    h5o.createArray(meta, 'iid2idx_pk', pk.dumps(iid2idx))
    h5o.createArray(meta, 'idx2ch', idx2ch)
    h5o.createArray(meta, 'ch2idx_pk', pk.dumps(ch2idx))

    # some error signals
    h5o.createArray(meta, 'b_warn_max_spks_lim', b_warn_max_spks_lim)
    if len(l_empty_spks) > 0:
        h5o.createArray(meta, 'l_empty_spks', l_empty_spks)

    h5o.close()
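The write idiom throughout this function is: allocate a compressed CArray with `createCArray`, then fill it with a single full-slice assignment. The same pattern with PyTables 3 names, as a standalone sketch:

import numpy as np
import tables as tbl

filters = tbl.Filters(complevel=4, complib='blosc')
Msnp = np.zeros((100, 48), dtype='int16')   # stand-in snippet matrix

with tbl.open_file('out.h5', 'w') as h5o:
    c = h5o.create_carray(h5o.root, 'Msnp', tbl.Int16Atom(),
                          Msnp.shape, filters=filters)
    c[...] = Msnp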
Example #12
File: spksort.py Project: stothe2/maru
def cluster(fn_inp, fn_out, opts):
    config = {}
    config['skimspk'] = {}
    config['skimspk']['tb0'] = SKIMSPK_TB
    config['skimspk']['te0'] = SKIMSPK_TE
    config['extract'] = {}
    config['extract']['nperimg'] = EXTRACT_NPERIMG
    config['extract']['nmax'] = EXTRACT_NMAX

    config['feat'] = {}
    config['feat']['kssort'] = FEAT_KSSORT
    config['feat']['outdim'] = FEAT_OUTDIM

    config['cluster'] = {}
    config['cluster']['metd'] = CLUSTERING_ALG
    config['cluster']['commonp'] = AFFINITYPRP_COMMONP

    config['qc'] = {}
    config['qc']['qc'] = QC
    config['qc']['kwargs'] = {}
    config['qc']['kwargs']['min_snr'] = QC_MINSNR
    config['qc']['kwargs']['ks_plevel'] = QC_KS_PLEVEL
    config['qc']['kwargs']['min_size'] = QC_MINSIZE

    config['nn'] = {}
    config['nn']['nneigh'] = NN_NNEIGH
    config['nn']['radius'] = NN_RADIUS

    n_jobs = NCPU
    reference = None

    # -- process opts
    if 'njobs' in opts:
        n_jobs = int(opts['njobs'])
        print '* n_jobs =', n_jobs

    if 'ref' in opts:
        reference = opts['ref']
        print '* Using the reference file:', reference
        h5r = tbl.openFile(reference)
        config = pk.loads(h5r.root.meta.config_clu_pk.read())

    # TODO: implement other options!!!

    # -- preps
    print '-> Initializing...'
    h5 = tbl.openFile(fn_inp)
    Msnp = h5.root.Msnp.read()
    Msnp_feat = h5.root.Msnp_feat.read()
    Msnp_ch = h5.root.Msnp_ch.read()
    Msnp_pos = h5.root.Msnp_pos.read()
    Msnp_tabs = h5.root.Msnp_tabs.read()
    Msnp_selected = h5.root.Msnp_selected.read()
    Mimg = h5.root.Mimg.read()
    Mimg_tabs = h5.root.Mimg_tabs.read()

    t_adjust = h5.root.meta.t_adjust.read()
    t_start0 = h5.root.meta.t_start0.read()
    t_stop0 = h5.root.meta.t_stop0.read()

    idx2iid = h5.root.meta.idx2iid.read()
    iid2idx_pk = h5.root.meta.iid2idx_pk.read()
    idx2ch = h5.root.meta.idx2ch.read()
    ch2idx_pk = h5.root.meta.ch2idx_pk.read()
    all_chs = range(len(idx2ch))

    if reference is None:
        # -- get training examples...
        print '-> Collecting snippet examples...'
        ibie, iuimg = skim_imgs(Mimg, Mimg_tabs, Msnp_tabs,
                t_adjust, **config['skimspk'])

        clu_feat_train = get_example_spikes(Msnp_feat, Msnp_ch, ibie, all_chs,
                **config['extract'])
        clu_train = get_example_spikes(Msnp, Msnp_ch, ibie, all_chs,
                **config['extract'])

        # -- get feature indices to use...
        print '-> Finding useful axes...'
        outdim = config['feat']['outdim']
        if config['feat']['kssort']:
            Msnp_feat_use = []

            for i_ch in xrange(len(all_chs)):
                # get deviations from Gaussian
                devs, _ = KS_all(clu_feat_train[i_ch])
                # got top-n deviations
                devs = np.argsort(-devs)[:outdim]
                Msnp_feat_use.append(devs)
        else:
            Msnp_feat_use = [range(outdim)] * len(all_chs)
        Msnp_feat_use = np.array(Msnp_feat_use)

        # -- XXX: DEBUG SUPPORT
        __DBG__ = False
        if __DBG__:
            clu_feat_train = clu_feat_train[:4]
            clu_train = clu_train[:4]
            Msnp_feat_use = Msnp_feat_use[:4]
            all_chs = all_chs[:4]

        # -- get clusters...
        print '-> Clustering...'
        clu_labels, clu_ulabels, clu_nclus, clu_centers = \
                find_clusters_par(clu_feat_train,
                Msnp_feat_use, n_jobs=n_jobs,
                **config['cluster'])

        # -- quality control
        if config['qc']['qc']:
            print '-> Run signal quality-based screening...'
            clu_sig_q = quality_meas_par(clu_train, clu_labels)
            quality_ctrl_par(clu_sig_q, clu_labels, clu_ulabels,
                    clu_nclus, clu_centers, n_jobs=n_jobs,
                    **config['qc']['kwargs'])
    else:
        # -- Bypass all and get the pre-computed clustering template
        print '-> Loading reference data...'
        clu_feat_train = pk.loads(h5r.root.clu_pk.clu_feat_train_pk.read())
        clu_train = pk.loads(h5r.root.clu_pk.clu_train_pk.read())
        clu_labels = pk.loads(h5r.root.clu_pk.clu_labels_pk.read())
        clu_ulabels = pk.loads(h5r.root.clu_pk.clu_ulabels_pk.read())
        clu_nclus = pk.loads(h5r.root.clu_pk.clu_nclus_pk.read())
        clu_centers = pk.loads(h5r.root.clu_pk.clu_centers_pk.read())
        clu_sig_q = pk.loads(h5r.root.clu_pk.clu_sig_q_pk.read())
        Msnp_feat_use = h5r.root.Msnp_feat_use.read()
        h5r.close()

    # -- NN search
    print '-> Template matching...'
    Msnp_cid = nearest_neighbor_par(clu_feat_train, Msnp_feat,
            Msnp_feat_use, Msnp_ch, clu_labels, clu_ulabels,
            all_chs, n_jobs=n_jobs, **config['nn'])

    # -- final quality report...
    print '-> Computing the final signal quality report...'
    clu_sig_q = quality_meas_par2(Msnp, Msnp_ch, Msnp_cid,
            all_chs, n_jobs=n_jobs)
    # ... and update clu_ulabels, clu_nclus
    clu_ulabels = [np.array(clu_sig_q_ch.keys(), dtype='int') for
            clu_sig_q_ch in clu_sig_q]
    clu_nclus = [len(clu_sig_q_ch.keys()) for clu_sig_q_ch in clu_sig_q]

    # -- done! write everything...
    print '-> Writing results...'
    filters = tbl.Filters(complevel=4, complib='blosc')
    t_int16 = tbl.Int16Atom()
    t_uint32 = tbl.UInt32Atom()
    t_uint64 = tbl.UInt64Atom()

    h5o = tbl.openFile(fn_out, 'w')
    CMsnp_tabs = h5o.createCArray(h5o.root, 'Msnp_tabs', t_uint64,
            Msnp_tabs.shape, filters=filters)
    CMsnp_ch = h5o.createCArray(h5o.root, 'Msnp_ch', t_uint32,
            Msnp_ch.shape, filters=filters)
    CMsnp_cid = h5o.createCArray(h5o.root, 'Msnp_cid', t_int16,
            Msnp_cid.shape, filters=filters)
    CMsnp_pos = h5o.createCArray(h5o.root, 'Msnp_pos', t_uint64,
            Msnp_pos.shape, filters=filters)
    CMsnp_feat_use = h5o.createCArray(h5o.root, 'Msnp_feat_use', t_uint64,
            Msnp_feat_use.shape, filters=filters)
    CMsnp_selected = h5o.createCArray(h5o.root, 'Msnp_selected',
            t_uint64, Msnp_selected.shape, filters=filters)

    CMsnp_tabs[...] = Msnp_tabs
    CMsnp_ch[...] = Msnp_ch
    CMsnp_cid[...] = Msnp_cid
    CMsnp_pos[...] = Msnp_pos
    CMsnp_feat_use[...] = Msnp_feat_use
    CMsnp_selected[...] = Msnp_selected

    h5o.createArray(h5o.root, 'Mimg', Mimg)
    h5o.createArray(h5o.root, 'Mimg_tabs', Mimg_tabs)

    meta = h5o.createGroup('/', 'meta', 'Metadata')
    h5o.createArray(meta, 't_start0', t_start0)
    h5o.createArray(meta, 't_stop0', t_stop0)
    h5o.createArray(meta, 't_adjust', t_adjust)
    h5o.createArray(meta, 'config_clu_pk', pk.dumps(config))

    h5o.createArray(meta, 'idx2iid', idx2iid)
    h5o.createArray(meta, 'iid2idx_pk', iid2idx_pk)
    h5o.createArray(meta, 'idx2ch', idx2ch)
    h5o.createArray(meta, 'ch2idx_pk', ch2idx_pk)

    h5o.createArray(meta, 'fn_inp', fn_inp)

    clupk = h5o.createGroup('/', 'clu_pk', 'Pickles')
    h5o.createArray(clupk, 'clu_feat_train_pk', pk.dumps(clu_feat_train))
    h5o.createArray(clupk, 'clu_train_pk', pk.dumps(clu_train))
    h5o.createArray(clupk, 'clu_labels_pk', pk.dumps(clu_labels))
    h5o.createArray(clupk, 'clu_ulabels_pk', pk.dumps(clu_ulabels))
    h5o.createArray(clupk, 'clu_nclus_pk', pk.dumps(clu_nclus))
    h5o.createArray(clupk, 'clu_centers_pk', pk.dumps(clu_centers))
    h5o.createArray(clupk, 'clu_sig_q_pk', pk.dumps(clu_sig_q))

    h5o.close()
    h5.close()
Example #13
File: spksort.py Project: stothe2/maru
def get_features(fn_inp, fn_out, opts):
    config = {}
    config['rethreshold_mult'] = RETHRESHOLD_MULT
    config['align'] = {}
    config['align']['subsmp'] = ALIGN_SUBSMP
    config['align']['maxdt'] = ALIGN_MAXDT
    config['align']['peakloc'] = ALIGN_PEAKLOC
    config['align']['findbwd'] = ALIGN_FINDBWD
    config['align']['findfwd'] = ALIGN_FINDFWD
    config['align']['cutat'] = ALIGN_CUTAT
    config['align']['outdim'] = ALIGN_OUTDIM
    config['align']['peakfunc'] = ALIGN_PEAKFUNC

    config['feat'] = {}
    config['feat']['metd'] = FEAT_METHOD
    config['feat']['kwargs'] = {'level': FEAT_WAVL_LEV}

    n_jobs = NCPU

    # -- process opts
    if 'njobs' in opts:
        n_jobs = int(opts['njobs'])
        print '* n_jobs =', n_jobs
    # TODO: implement!!!

    # -- preps
    print '-> Initializing...'
    h5 = tbl.openFile(fn_inp)
    Msnp = h5.root.Msnp.read()
    Msnp_ch = h5.root.Msnp_ch.read()
    Msnp_pos = h5.root.Msnp_pos.read()
    Msnp_tabs = h5.root.Msnp_tabs.read()
    Mimg = h5.root.Mimg.read()
    Mimg_tabs = h5.root.Mimg_tabs.read()

    t_adjust = h5.root.meta.t_adjust.read()
    t_start0 = h5.root.meta.t_start0.read()
    t_stop0 = h5.root.meta.t_stop0.read()

    idx2iid = h5.root.meta.idx2iid.read()
    iid2idx_pk = h5.root.meta.iid2idx_pk.read()
    idx2ch = h5.root.meta.idx2ch.read()
    ch2idx_pk = h5.root.meta.ch2idx_pk.read()
    all_chs = range(len(idx2ch))

    # -- re-threshold
    print '-> Re-thresholding...'
    if isinstance(config['rethreshold_mult'], (int, float)):
        thr_sel, thrs = rethreshold_by_multiplier(Msnp, Msnp_ch,
                all_chs, config['rethreshold_mult'])

        Msnp = Msnp[thr_sel]
        Msnp_ch = Msnp_ch[thr_sel]
        Msnp_pos = Msnp_pos[thr_sel]
        Msnp_tabs = Msnp_tabs[thr_sel]
        Msnp_selected = np.nonzero(thr_sel)[0]

    else:
        thr_sel = None
        thrs = None
        Msnp_selected = None

    # -- align
    print '-> Aligning...'
    Msnp = par_comp(align_core, Msnp, n_jobs=n_jobs, **config['align'])

    # -- feature extraction
    print '-> Extracting features...'
    if config['feat']['metd'] == 'wavelet':
        Msnp_feat = par_comp(wavelet_core, Msnp, n_jobs=n_jobs,
                **config['feat']['kwargs'])

    elif config['feat']['metd'] == 'pca':
        config['feat']['kwargs'].pop('level')
        raise NotImplementedError('PCA is not implemented yet')

    else:
        raise ValueError('Unrecognized "feat_metd"')

    # -- done! write everything...
    print '-> Writing results...'
    filters = tbl.Filters(complevel=4, complib='blosc')
    t_int16 = tbl.Int16Atom()
    t_uint32 = tbl.UInt32Atom()
    t_uint64 = tbl.UInt64Atom()
    t_float32 = tbl.Float32Atom()

    h5o = tbl.openFile(fn_out, 'w')
    CMsnp = h5o.createCArray(h5o.root, 'Msnp', t_int16,
            Msnp.shape, filters=filters)
    CMsnp_tabs = h5o.createCArray(h5o.root, 'Msnp_tabs', t_uint64,
            Msnp_tabs.shape, filters=filters)
    CMsnp_ch = h5o.createCArray(h5o.root, 'Msnp_ch', t_uint32,
            Msnp_ch.shape, filters=filters)
    CMsnp_pos = h5o.createCArray(h5o.root, 'Msnp_pos', t_uint64,
            Msnp_pos.shape, filters=filters)
    CMsnp_feat = h5o.createCArray(h5o.root, 'Msnp_feat', t_float32,
            Msnp_feat.shape, filters=filters)
    # TODO: support when thr_sel is None
    CMsnp_selected = h5o.createCArray(h5o.root, 'Msnp_selected',
            t_uint64, Msnp_selected.shape, filters=filters)

    CMsnp[...] = Msnp
    CMsnp_tabs[...] = Msnp_tabs
    CMsnp_ch[...] = Msnp_ch
    CMsnp_pos[...] = Msnp_pos
    CMsnp_feat[...] = Msnp_feat
    CMsnp_selected[...] = Msnp_selected

    h5o.createArray(h5o.root, 'Mimg', Mimg)
    h5o.createArray(h5o.root, 'Mimg_tabs', Mimg_tabs)

    meta = h5o.createGroup('/', 'meta', 'Metadata')
    h5o.createArray(meta, 't_start0', t_start0)
    h5o.createArray(meta, 't_stop0', t_stop0)
    h5o.createArray(meta, 't_adjust', t_adjust)
    h5o.createArray(meta, 'config_feat_pk', pk.dumps(config))

    h5o.createArray(meta, 'idx2iid', idx2iid)
    h5o.createArray(meta, 'iid2idx_pk', iid2idx_pk)
    h5o.createArray(meta, 'idx2ch', idx2ch)
    h5o.createArray(meta, 'ch2idx_pk', ch2idx_pk)

    h5o.createArray(meta, 'thrs', thrs)
    h5o.createArray(meta, 'fn_inp', fn_inp)

    h5o.close()
    h5.close()