示例#1
0
 def __init__(self, name=None, dir=None, files=None, mode='r', prm=None):
     """`name` must correspond to the basename of the files.

     Arguments:
       * name: basename of the experiment files; used to open them when
         `files` is not given.
       * dir: directory forwarded to `open_files`.
       * files: optional dict {file type: open file or None}; when None the
         files are opened with `open_files(name, dir=dir, mode=mode)`.
       * mode: mode used when this constructor opens the files.
       * prm: optional parameter dictionary kept on the instance; a new
         empty dict is used when it is omitted.
     """
     self.name = name
     self._dir = dir
     self._mode = mode
     self._files = files
     # Fresh dict per instance: a `{}` default is evaluated only once, so
     # every instance created without `prm` would share (and mutate) it.
     self._prm = {} if prm is None else prm
     if self._files is None:
         self._files = open_files(self.name, dir=self._dir, mode=self._mode)

     def _get_filename(fd):
         # Resolved absolute path of an open file; a missing file (None)
         # maps to None.
         if fd is None:
             return None
         return os.path.realpath(fd.filename)

     self._filenames = {file_type: _get_filename(fd)
                        for file_type, fd in iteritems(self._files)}
     super(Experiment, self).__init__(self._files)
     # `_node` is presumably set by the base-class initializer called above.
     self._root = self._node

     self.application_data = NodeWrapper(self._root.application_data)
     self.user_data = NodeWrapper(self._root.user_data)

     self.channel_groups = self._gen_children('channel_groups', ChannelGroup)
     self.recordings = self._gen_children('recordings', Recording)
     self.event_types = self._gen_children('event_types', EventType)
示例#2
0
def _print_instance(obj, depth=0, name=''):
    """Flatten an object tree into a list of ``(depth, label)`` rows.

    Containers (list/dict) are summarized by their first element only:
    - If that element is itself a container, a number, a string or a numpy
      array, nothing is emitted for it.
    - Otherwise, a row labelled with its key ('0' for lists) is emitted,
      followed by the element's own rows.

    Arrays (numpy / PyTables EArray) have no children.

    Other objects with a ``__dict__`` recurse into their public
    (non-underscore) attributes in sorted order, each labelled by its
    attribute name.

    When `name` is non-empty, a row ``(depth, name)`` is prepended. For an
    EArray, that row's label also carries its dtype and shape.
    """
    if isinstance(obj, (list, dict)):
        if not obj:
            # NOTE(review): this returns before the label below is added, so
            # an empty named container produces no row at all.
            return []
        if isinstance(obj, list):
            key, first = '0', obj[0]
        else:
            key, first = next(iteritems(obj))
        # Leaf types are only resolved on this path, as in the original.
        scalar_like = (list, dict, int, long, string_types, np.ndarray,
                       float)
        if isinstance(first, scalar_like):
            rows = []
        else:
            rows = [(depth+1, str(key))] + _print_instance(first, depth+1)
    elif isinstance(obj, (np.ndarray, tb.EArray)):
        # Arrays do not have children.
        rows = []
    elif hasattr(obj, '__dict__'):
        public = dict((attr, value) for attr, value in iteritems(vars(obj))
                      if not attr.startswith('_'))
        rows = []
        for attr in sorted(public):
            rows.extend(_print_instance(public[attr], depth=depth+1,
                                        name=str(attr)))
    else:
        rows = []

    if not name:
        return rows
    if isinstance(obj, tb.EArray):
        label = name + ' [{dtype} {shape}]'.format(dtype=obj.dtype,
                                                    shape=obj.shape)
    else:
        label = name
    return [(depth, label)] + rows
示例#3
0
def open_files(name, dir=None, mode=None):
    """Open every file belonging to experiment `name`.

    Returns a dict mapping each file type reported by
    `get_filenames(name, dir=dir)` to the result of `open_file(path, mode=mode)`.

    If opening any file raises, the files already opened by this call are
    closed before the exception propagates, so no handles are leaked.
    """
    filenames = get_filenames(name, dir=dir)
    files = {}
    try:
        for file_type, filename in iteritems(filenames):
            files[file_type] = open_file(filename, mode=mode)
    except Exception:
        for opened in files.values():
            # Entries may be None (Experiment.__init__ tolerates None files).
            if opened is not None:
                opened.close()
        raise
    return files
示例#4
0
def create_kwik(path, experiment_name=None, prm=None, prb=None):
    """Create a KWIK file.

    Arguments:
      * path: path to the .kwik file (opened with mode 'w').
      * experiment_name: stored as the root `name` attribute; '' if None.
      * prm: a dictionary representing the contents of the PRM file (used for
        SpikeDetekt); each item becomes an attribute of
        /application_data/spikedetekt.
      * prb: a dictionary with the contents of the PRB file. Each entry of
        prb['channel_groups'] becomes one /channel_groups/<index> group.
        Its optional keys 'channels', 'graph' and 'geometry' are used to
        fill that group.

    The file is closed before returning, including when an error is raised
    while the layout is being written.
    """
    if experiment_name is None:
        experiment_name = ''
    if prm is None:
        prm = {}
    if prb is None:
        prb = {}

    # Row-count hint passed to PyTables for every per-spike EArray.
    expected_rows = 1000000

    kwik = tb.openFile(path, mode='w')
    try:
        kwik.root._f_setAttr('kwik_version', 2)
        kwik.root._f_setAttr('name', experiment_name)

        kwik.createGroup('/', 'application_data')

        # Set the SpikeDetekt parameters.
        spikedetekt = kwik.createGroup('/application_data', 'spikedetekt')
        for prm_name, prm_value in iteritems(prm):
            spikedetekt._f_setAttr(prm_name, prm_value)

        kwik.createGroup('/', 'user_data')

        # Create channel groups.
        kwik.createGroup('/', 'channel_groups')
        for igroup, group_info in enumerate(prb.get('channel_groups', [])):
            group = kwik.createGroup('/channel_groups', str(igroup))
            # group_info keys read here: channels, graph, geometry.
            group._f_setAttr('name', 'channel_group_{0:d}'.format(igroup))
            group._f_setAttr('adjacency_graph',
                             group_info.get('graph', np.zeros((0, 2))))
            kwik.createGroup(group, 'application_data')
            kwik.createGroup(group, 'user_data')

            # Create channels; channel_idx is the absolute channel index.
            channels = kwik.createGroup(group, 'channels')
            geometry = group_info.get('geometry', {})
            for channel_idx in group_info.get('channels', []):
                channel = kwik.createGroup(channels, str(channel_idx))
                channel._f_setAttr('name',
                                   'channel_{0:d}'.format(channel_idx))
                # TODO: the values below are hard-coded placeholders.
                channel._f_setAttr('kwd_index', 0)
                channel._f_setAttr('ignored', False)
                channel._f_setAttr('position',
                                   geometry.get(channel_idx, None))
                channel._f_setAttr('voltage_gain', 0.)
                channel._f_setAttr('display_threshold', 0.)
                app_data = kwik.createGroup(channel, 'application_data')
                kwik.createGroup(app_data, 'spikedetekt')
                kwik.createGroup(app_data, 'klustaviewa')
                kwik.createGroup(channel, 'user_data')

            # Create spikes: empty, extendable 1-D arrays.
            spikes = kwik.createGroup(group, 'spikes')
            for array_name, atom in (('time_samples', tb.UInt64Atom()),
                                     ('time_fractional', tb.UInt8Atom()),
                                     ('recording', tb.UInt16Atom())):
                kwik.createEArray(spikes, array_name, atom, (0,),
                                  expectedrows=expected_rows)
            spike_clusters = kwik.createGroup(spikes, 'clusters')
            for array_name in ('main', 'original'):
                kwik.createEArray(spike_clusters, array_name,
                                  tb.UInt32Atom(), (0,),
                                  expectedrows=expected_rows)

            # These groups only point at data stored in the companion file.
            fm = kwik.createGroup(spikes, 'features_masks')
            fm._f_setAttr('hdf5_path',
                          '{{kwx}}/channel_groups/{0:d}/features_masks'
                          .format(igroup))
            wr = kwik.createGroup(spikes, 'waveforms_raw')
            wr._f_setAttr('hdf5_path',
                          '{{kwx}}/channel_groups/{0:d}/waveforms_raw'
                          .format(igroup))
            wf = kwik.createGroup(spikes, 'waveforms_filtered')
            wf._f_setAttr('hdf5_path',
                          '{{kwx}}/channel_groups/{0:d}/waveforms_filtered'
                          .format(igroup))

            # TODO: add clusters 0, 1, 2, 3 by default

            # Create clusters.
            clusters = kwik.createGroup(group, 'clusters')
            kwik.createGroup(clusters, 'main')
            kwik.createGroup(clusters, 'original')

            # Create cluster groups.
            cluster_groups = kwik.createGroup(group, 'cluster_groups')
            kwik.createGroup(cluster_groups, 'main')
            kwik.createGroup(cluster_groups, 'original')

        # Create recordings.
        kwik.createGroup('/', 'recordings')

        # Create event types.
        kwik.createGroup('/', 'event_types')
    finally:
        # Always release the HDF5 handle, even on a partial write.
        kwik.close()
示例#5
0
def create_kwik(path, experiment_name=None, prm=None, prb=None):
    """Create a KWIK file.

    Arguments:
      * path: path to the .kwik file (opened with mode 'w').
      * experiment_name: stored as the root `name` attribute; '' if None.
      * prm: a dictionary representing the contents of the PRM file (used for
        SpikeDetekt); each item becomes an attribute of
        /application_data/spikedetekt.
      * prb: a dictionary with the contents of the PRB file. Each entry of
        prb['channel_groups'] becomes one /channel_groups/<index> group.
        Its optional keys 'channels', 'graph' and 'geometry' are used to
        fill that group.

    The file is closed before returning, including when an error is raised
    while the layout is being written.
    """
    if experiment_name is None:
        experiment_name = ''
    if prm is None:
        prm = {}
    if prb is None:
        prb = {}

    # Row-count hint passed to PyTables for every per-spike EArray.
    expected_rows = 1000000

    kwik = tb.openFile(path, mode='w')
    try:
        kwik.root._f_setAttr('kwik_version', 2)
        kwik.root._f_setAttr('name', experiment_name)

        kwik.createGroup('/', 'application_data')

        # Set the SpikeDetekt parameters.
        spikedetekt = kwik.createGroup('/application_data', 'spikedetekt')
        for prm_name, prm_value in iteritems(prm):
            spikedetekt._f_setAttr(prm_name, prm_value)

        kwik.createGroup('/', 'user_data')

        # Create channel groups.
        kwik.createGroup('/', 'channel_groups')
        for igroup, group_info in enumerate(prb.get('channel_groups', [])):
            group = kwik.createGroup('/channel_groups', str(igroup))
            # group_info keys read here: channels, graph, geometry.
            group._f_setAttr('name', 'channel_group_{0:d}'.format(igroup))
            group._f_setAttr('adjacency_graph',
                             group_info.get('graph', np.zeros((0, 2))))
            kwik.createGroup(group, 'application_data')
            kwik.createGroup(group, 'user_data')

            # Create channels; channel_idx is the absolute channel index.
            channels = kwik.createGroup(group, 'channels')
            geometry = group_info.get('geometry', {})
            for channel_idx in group_info.get('channels', []):
                channel = kwik.createGroup(channels, str(channel_idx))
                channel._f_setAttr('name',
                                   'channel_{0:d}'.format(channel_idx))
                # TODO: the values below are hard-coded placeholders.
                channel._f_setAttr('kwd_index', 0)
                channel._f_setAttr('ignored', False)
                channel._f_setAttr('position',
                                   geometry.get(channel_idx, None))
                channel._f_setAttr('voltage_gain', 0.)
                channel._f_setAttr('display_threshold', 0.)
                app_data = kwik.createGroup(channel, 'application_data')
                kwik.createGroup(app_data, 'spikedetekt')
                kwik.createGroup(app_data, 'klustaviewa')
                kwik.createGroup(channel, 'user_data')

            # Create spikes: empty, extendable 1-D arrays.
            spikes = kwik.createGroup(group, 'spikes')
            for array_name, atom in (('time_samples', tb.UInt64Atom()),
                                     ('time_fractional', tb.UInt8Atom()),
                                     ('recording', tb.UInt16Atom())):
                kwik.createEArray(spikes, array_name, atom, (0,),
                                  expectedrows=expected_rows)
            spike_clusters = kwik.createGroup(spikes, 'clusters')
            for array_name in ('main', 'original'):
                kwik.createEArray(spike_clusters, array_name,
                                  tb.UInt32Atom(), (0,),
                                  expectedrows=expected_rows)

            # These groups only point at data stored in the companion file.
            fm = kwik.createGroup(spikes, 'features_masks')
            fm._f_setAttr('hdf5_path',
                          '{{kwx}}/channel_groups/{0:d}/features_masks'
                          .format(igroup))
            wr = kwik.createGroup(spikes, 'waveforms_raw')
            wr._f_setAttr('hdf5_path',
                          '{{kwx}}/channel_groups/{0:d}/waveforms_raw'
                          .format(igroup))
            wf = kwik.createGroup(spikes, 'waveforms_filtered')
            wf._f_setAttr('hdf5_path',
                          '{{kwx}}/channel_groups/{0:d}/waveforms_filtered'
                          .format(igroup))

            # TODO: add clusters 0, 1, 2, 3 by default

            # Create clusters.
            clusters = kwik.createGroup(group, 'clusters')
            kwik.createGroup(clusters, 'main')
            kwik.createGroup(clusters, 'original')

            # Create cluster groups.
            cluster_groups = kwik.createGroup(group, 'cluster_groups')
            kwik.createGroup(cluster_groups, 'main')
            kwik.createGroup(cluster_groups, 'original')

        # Create recordings.
        kwik.createGroup('/', 'recordings')

        # Create event types.
        kwik.createGroup('/', 'event_types')
    finally:
        # Always release the HDF5 handle, even on a partial write.
        kwik.close()
示例#6
0
def open_files(name, dir=None, mode=None):
    """Open every file belonging to experiment `name`.

    Returns a dict mapping each file type reported by
    `get_filenames(name, dir=dir)` to the result of `open_file(path, mode=mode)`.

    If opening any file raises, the files already opened by this call are
    closed before the exception propagates, so no handles are leaked.
    """
    filenames = get_filenames(name, dir=dir)
    files = {}
    try:
        for file_type, filename in iteritems(filenames):
            files[file_type] = open_file(filename, mode=mode)
    except Exception:
        for opened in files.values():
            # Entries may be None (Experiment.__init__ tolerates None files).
            if opened is not None:
                opened.close()
        raise
    return files