示例#1
0
    def __init__(self, kwik_path, config_dir=None, **kwargs):
        """Open a Kwik file and wire up the supervisor, selector and plugins.

        kwik_path: path of the .kwik file (symlinks are resolved).
        config_dir: forwarded to attach_plugins and stored on the instance.
        **kwargs: forwarded to KwikModel; ``channel_group`` also selects a
            cache subdirectory and ``plugins`` is forwarded to attach_plugins.
        """
        super(KwikController, self).__init__()

        kwik_path = op.realpath(kwik_path)
        _backup(kwik_path)
        model = KwikModel(kwik_path, **kwargs)
        self.model = model

        positions = model.channel_positions
        # Channel indices ordered by their y coordinate.
        self.channel_vertical_order = np.argsort(positions[:, 1])
        self.distance_max = _get_distance_max(positions)

        # Cache lives in <kwik dir>/.phy, with one subfolder per channel group.
        cache_dir = op.join(op.dirname(kwik_path), '.phy')
        channel_group = kwargs.get('channel_group', None)
        if channel_group is not None:
            cache_dir = op.join(cache_dir, str(channel_group))
        self.cache_dir = cache_dir
        self.context = Context(cache_dir)
        self.config_dir = config_dir

        # Keep this initialisation order: the supervisor setup reads
        # self.context, and later steps may rely on earlier ones.
        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)
示例#2
0
    def set_kwik_file(self, kpath):
        """! @brief Binds this object to the kwik file found at kpath

        A new KwikModel is loaded from kpath. The object's name is replaced
        by the name of that model.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        model = KwikModel(kpath)
        self.kwik_model = model
        self.name = model.name
        self.kpath = kpath
示例#3
0
    def __init__(self, kpath=None, name=None):
        """! @brief Creates the object, optionally loading a kwik file

        kpath: kwik file to load; when None nothing is loaded.
        name: object name; when None and kpath is given, the loaded
        model's name is used.
        """
        self.kwik_model = None
        self.name = name
        self.kpath = None

        if kpath is None:
            print("It still with no path:(")
            return

        self.kwik_model = KwikModel(kpath)
        self.kpath = kpath
        if name is None:
            self.name = self.kwik_model.name
        print("Created class on = %s !" % kpath)
示例#4
0
 def __init__(self, path, channel_group=None, clustering=None):
     """Open the Kwik file at ``path`` and prepare the controller.

     path: path of the .kwik file (``~`` expanded, symlinks resolved).
     channel_group: channel group to load; also names the cache
         subdirectory (``'default'`` when falsy).
     clustering: clustering to load; also names the cache subdirectory
         (``'main'`` when falsy).
     """
     path = op.realpath(op.expanduser(path))
     _backup(path)
     self.path = path
     self.cache_dir = op.join(op.dirname(path), '.phy',
                              str(clustering or 'main'),
                              str(channel_group or 'default'),
                              )
     # Forward ``clustering`` (it used to be hard-coded to None) so that the
     # loaded clustering matches the cache directory chosen above.
     self.model = KwikModel(path,
                            channel_group=channel_group,
                            clustering=clustering,
                            )
     # Called last, presumably because the base initializer relies on
     # self.model / self.cache_dir being set -- keep this order.
     super(KwikController, self).__init__()
示例#5
0
 def describe(path, channel_group=0, clustering='main'):
     """Describe a Kwik file."""
     # Open the requested channel group / clustering and delegate to the
     # model's own describe() method.
     model = KwikModel(path, channel_group=channel_group,
                       clustering=clustering)
     model.describe()
def autoAnalysis(kwik_fullpath, brush_sample_rate=20000):
    """Plot the spikes of a Kwik file above the MATLAB brush signal.

    Spikes from every channel group (shank) of ``kwik_fullpath`` are pooled
    into a single raster (top panel). The brush trace returned by
    ``loadMatlabFile`` is drawn below it (bottom panel). The figure is saved
    as ``brushSpikes.png`` in the directory of the Kwik file.

    Parameters:
    kwik_fullpath: path of the .kwik file.
    brush_sample_rate: sampling rate (Hz) of the brush trace, used to build
        its time axis (default 20000, the previously hard-coded value).

    Raises:
    ValueError: if Intan digital input '0' never changes state, so no
        offset can be passed to loadMatlabFile.
    """
    kwikPath = os.path.split(kwik_fullpath)[0]

    model = KwikModel(kwik_fullpath)
    try:
        sample_chunks = []
        cluster_chunks = []
        # Walk through the channel groups (i.e. those on different shanks).
        for group in model.channel_groups:
            model.channel_group = group
            # np.array copies the data, so it survives switching the group
            # and closing the model.
            sample_chunks.append(np.array(model.spike_samples))
            cluster_chunks.append(np.array(model.spike_clusters))
        sample_rate = model.sample_rate
    finally:
        model.close()

    if sample_chunks:
        allsamples = np.concatenate(sample_chunks)
        allspikes = np.concatenate(cluster_chunks)
    else:
        allsamples = np.array([])
        allspikes = np.array([])
    alltimes = allsamples / sample_rate

    di, _ = liic.load_intan_input_channels()
    intan_trigger = di['0'][:]
    # Indices i at which the trigger changes value between i and i + 1.
    intan_transitions = np.where(intan_trigger[:-1] != intan_trigger[1:])[0]
    if len(intan_transitions) == 0:
        raise ValueError(
            "No state change found on Intan digital input '0'; "
            "cannot compute the offset for loadMatlabFile.")
    # First trigger transition, passed to loadMatlabFile as its offset
    # (presumably used to align the MATLAB data -- confirm in loadMatlabFile).
    adjustmentAmount = intan_transitions[0]
    _, _, matlab_brush = loadMatlabFile(adjustmentAmount)

    grey = [.5, .5, .5]
    plt.figure(figsize=(15, 10))
    plt.subplot(211)
    plt.plot(alltimes, allspikes, '|', color=grey, markersize=5, mew=.4)
    plt.ylabel('unit')
    # Pin the raster's x-limits to their current autoscaled values.
    plt.xlim(plt.xlim())
    plt.subplot(212)
    # Integer sample indices guarantee len(brush_times) == len(matlab_brush);
    # a float-step np.arange may produce one extra point.
    brush_times = np.arange(len(matlab_brush)) / float(brush_sample_rate)
    plt.plot(brush_times, matlab_brush, color=grey, linewidth=.5)
    plt.savefig(os.path.join(kwikPath, 'brushSpikes.png'))
示例#7
0
 def __init__(self, path, channel_group=None, clustering=None):
     """Open the Kwik file at ``path`` and prepare the controller.

     path: path of the .kwik file (``~`` expanded, symlinks resolved).
     channel_group: channel group to load; also names the cache
         subdirectory (``0`` when falsy).
     clustering: clustering to load; also names the cache subdirectory
         (``'main'`` when falsy).
     """
     path = op.realpath(op.expanduser(path))
     _backup(path)
     self.path = path
     # The cache directory depends on the filename, clustering, and
     # channel_group to avoid conflicts.
     self.cache_dir = op.join(op.dirname(path), '.phy',
                              op.splitext(op.basename(path))[0],
                              str(clustering or 'main'),
                              str(channel_group or 0),
                              )
     # Forward ``clustering`` (it used to be hard-coded to None) so that the
     # loaded clustering matches the cache directory chosen above.
     self.model = KwikModel(path,
                            channel_group=channel_group,
                            clustering=clustering,
                            )
     # Called last, presumably because the base initializer relies on
     # self.model / self.cache_dir being set -- keep this order.
     super(KwikController, self).__init__()
示例#8
0
class KwikFile:
    """!  @brief Model for Kwik file, strongly based on KwikModel from phy project

    The main purpose of this class is to provide an abstraction for kwik files
    produced by the phy project. The current version contains a basic set of
    fundamental methods used with kwik files.

    @author: Nivaldo A P de Vasconcelos
    @date: 2018.Feb.02
    """

    def __init__(self, kpath=None, name=None):
        """! @brief Creates the object, optionally bound to a kwik file

        Parameters:
        kpath: path of the kwik file. When None, no KwikModel is loaded and
        set_kwik_file() must be called before using the data accessors.
        name: name of the object. When None and kpath is given, the name of
        the loaded KwikModel is used instead.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model = None
        self.name = name
        self.kpath = None
        if (kpath is not None):
            self.kwik_model = KwikModel(kpath)
            self.kpath = kpath
            if (name is None):
                self.name = self.kwik_model.name
            print("Created class on = %s !" % kpath)
        else:
            print("It still with no path:(")

    def get_name(self):
        """! @brief Returns the name of this object.

        It is the name given to the constructor or, otherwise, the name of
        the loaded KwikModel (set_kwik_file() always uses the model's name).

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return (self.name)

    def set_kwik_file(self, kpath):
        """! @brief Defines the corresponding kwik file

        Loads a new KwikModel from kpath and takes its name.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model = KwikModel(kpath)
        self.name = self.kwik_model.name
        self.kpath = kpath

    def sampling_rate(self):
        """! @brief Returns the sampling rate used during the recordings

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return (self.kwik_model.sample_rate)

    def shank(self):
        """! @brief Returns the shank/population's id used to group the recordings.

        The id is the name of the underlying KwikModel.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return (self.kwik_model.name)

    def get_spike_samples(self):
        """! @brief Returns the spike's samples on the recordings.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return (self.kwik_model.spike_samples)

    def get_spike_clusters(self):
        """! @brief Returns the corresponding spike's clusters on the recordings.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return (self.kwik_model.spike_clusters)

    def describe(self):
        """! @brief Describes the kwik file

        It calls the describe method in KwikModel

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model.describe()

    def close(self):
        """! @brief Closes the corresponding kwik model

        It calls the close method in KwikModel

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model.close()

    def list_of_groups(self):
        """! @brief Returns the list of groups found in kwik file

        The result has a list's form, without duplicates and in no
        particular order.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return list(set(self.groups().values()))

    def list_of_non_noisy_groups(self):
        """! @brief Returns the list of groups found in kwik file which are not called noise

        The result has a list's form, without duplicates and in no
        particular order.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return list(set(self.groups().values()) - {'noise'})

    def all_clusters(self):
        """! @brief Returns the list of all clusters in kwik file

        The result has a list's form.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return list(set(self.groups().keys()))

    def groups(self):
        """! @brief Returns a dict with cluster label and its respective group

        Raises ValueError when no KwikModel is bound to this object.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if not (isinstance(self.kwik_model, KwikModel)):
            raise ValueError("There is no KwikModel assigned for this object.")
        return (self.kwik_model.cluster_groups)

    def clusters(self, group_name=None):
        """! @brief Returns the list of clusters on kwik file

        It can be used to get the sorted list of clusters of a given group
        by providing the group's name in group_name. When group_name is
        None, all clusters are returned.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if (group_name is None):
            return (self.all_clusters())

        if group_name not in self.list_of_groups():
            raise ValueError("\nThis group was not found in kwik file: %s\n" %
                             group_name)
        group = self.groups()
        return sorted(c for c in self.all_clusters()
                      if group[c] == group_name)

    def all_spikes_on_groups(self, group_names):
        """! @brief Returns the all spike samples within a list of groups

        Usually the clusters are organized in groups. Ex: noise, mua, sua,
        unsorted. This method returns, in a single sorted list of spike
        samples, all spikes found in a list of groups (group_names).

        Parameters:
        group_names: list of group names, where the spikes will be searched.

        Raises ValueError if group_names is not a list or contains an
        unknown group.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if not (isinstance(group_names, list)):
            raise ValueError("\nThe argument must be a list.")

        all_spikes = self.get_spike_samples()
        all_labels = self.get_spike_clusters()
        known_groups = self.list_of_groups()

        spikes = []
        for group_name in group_names:
            if group_name not in known_groups:
                raise ValueError(
                    "\nThis group was not found in kwik file: %s\n" %
                    group_name)
            for c in self.clusters(group_name=group_name):
                # extend() avoids rebuilding the whole list on every cluster.
                spikes.extend(all_spikes[all_labels == c])
        spikes.sort()
        return (spikes)

    def all_spike_id_on_groups(self, group_names):
        """! @brief Returns the all spike id within a list of groups

        Usually the clusters are organized in groups. Ex: noise, mua, sua,
        unsorted. This method returns, in a single sorted list, the ids of
        all spikes found in a list of groups (group_names).

        Parameters:
        group_names: list of group names, where the spike ids will be searched.

        Raises ValueError if group_names is not a list or contains an
        unknown group.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Jun.05
        """
        if not (isinstance(group_names, list)):
            raise ValueError("\nThe argument must be a list.")

        all_spk_id = self.kwik_model.spike_ids
        all_labels = self.get_spike_clusters()
        known_groups = self.list_of_groups()

        spk_id = []
        for group_name in group_names:
            if group_name not in known_groups:
                raise ValueError(
                    "\nThis group was not found in kwik file: %s\n" %
                    group_name)
            for c in self.clusters(group_name=group_name):
                spk_id.extend(all_spk_id[all_labels == c])
        spk_id.sort()
        return (spk_id)

    def all_spike_id_on_cluster(self, cluster_id):
        """! @brief Returns the sorted list of spike ids within a single cluster

        Parameters:
        cluster_id: id used to identify the cluster.

        Raises ValueError if the cluster is not found in the kwik file.
        """
        if not (cluster_id in self.all_clusters()):
            raise ValueError(
                "\nThis cluster was not found in kwik file: %s\n" % cluster_id)
        all_spk_id = self.kwik_model.spike_ids
        all_labels = self.get_spike_clusters()
        spk_id = list(all_spk_id[all_labels == cluster_id])
        spk_id.sort()
        return (spk_id)

    def group_firing_rate(self, group_names=None, a=None, b=None):
        """! @brief Returns firing rate in a given set of groups found in kwik file.

        Usually, the clusters are organized in groups. Ex: noise, mua, sua,
        unsorted. This method returns, in a doubled dictionary, the firing rate
        for each cluster, organized by groups.

        Parameters:
        group_names: list of group names, where the spikes will be
        searched. When this input is 'None' all non-noise groups are taken.
        The resulting dictionary has the first keys as groups, and the second
        keys as the respective cluster id's, whereas the value is the
        corresponding firing rate within [a,b].

        Please refer to the method cluster_firing_rate in order to get more
        details about the firing calculation.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if not (isinstance(group_names, list)) and not (group_names is None):
            raise ValueError("\nThe argument must be a list or a None.")
        if group_names is None:
            group_names = self.list_of_non_noisy_groups()
        known_groups = self.list_of_groups()

        spk = dict()
        for group_name in group_names:
            if group_name not in known_groups:
                raise ValueError(
                    "\nThis group was not found in kwik file: %s\n" %
                    group_name)
            spk[group_name] = dict()
            for c in self.clusters(group_name=group_name):
                spk[group_name][c] = self.cluster_firing_rate(c, a=a, b=b)
        return (spk)

    def cluster_firing_rate(self, cluster_id, a=None, b=None):
        """! @brief Returns firing rate in a given cluster_id found in the kwik file

        In the kwik file, a cluster stores the spike times sorted for a given
        neuronal unit. The firing rate here is calculated by dividing the
        number of spike times within [a,b] by the number of seconds of the
        time period defined by [a,b].
        If a is 'None' it is assigned to zero; if b is 'None', it is assigned
        to the time of the last spike within the cluster.

        Parameters:
        cluster_id: id which identifies the cluster.
        a,b: limits (in seconds) of the time period where the firing rate
        must be calculated.

        Raises ValueError when a == b, when b < a, or when b is None and the
        cluster has no spikes.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        sr = self.sampling_rate()
        spikes = np.array(self.spikes_on_cluster(cluster_id)) / sr
        if a is None:
            a = 0
        if b is None:
            if spikes.size == 0:
                raise ValueError(
                    "\nThe cluster %s has no spikes; the upper limit b "
                    "must be given explicitly\n" % cluster_id)
            b = spikes[-1]
        if (a == b):
            raise ValueError("\nThe limits of the time interval are equal\n")
        if b < a:
            raise ValueError(
                "\nThe upper limit b must be greater than the lower limit a\n")
        piece = spikes[(spikes >= a) & (spikes <= b)]
        # float() keeps true division even for integer limits on Python 2.
        return (len(piece) / float(b - a))

    def spikes_on_cluster(self, cluster_id):
        """! @brief Returns the all spike samples within a single cluster

        Parameters:
        cluster_id: id used to identify the cluster.

        Raises ValueError if the cluster is not found in the kwik file.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if not (cluster_id in self.all_clusters()):
            raise ValueError(
                "\nThis cluster was not found in kwik file: %s\n" % cluster_id)
        all_spikes = self.get_spike_samples()
        all_labels = self.get_spike_clusters()
        spikes = list(all_spikes[all_labels == cluster_id])
        spikes.sort()
        return (spikes)

    def group_firing_rate_to_dataframe(self, group_names=None, a=None, b=None):
        """! @brief Exports the group's firing rate into a pandas dataframe

        Usually, the clusters are organized in groups. Ex: noise, mua, sua,
        unsorted. This method returns a pandas dataframe with one row per
        unit and the columns 'shank_id', 'group', 'label' and 'fr'.

        Parameters:
        group_names: list of group names, where the spikes will be
        searched. When this input is 'None' all non-noise groups are taken.
        a,b: limits of the time period, see cluster_firing_rate.

        Please refer to the method cluster_firing_rate in order to get more
        details about the firing calculation.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        d = self.group_firing_rate(group_names=group_names, a=a, b=b)

        shank_id = self.name
        data = []
        for group_name, rates in d.items():
            for label, fr in rates.items():
                data.append({
                    "shank_id": shank_id,
                    "group": group_name,
                    "label": label,
                    "fr": fr
                })
        return (pd.DataFrame(data))
示例#9
0
class KwikController(Controller):
    gui_name = 'KwikGUI'

    def __init__(self, path, channel_group=None, clustering=None):
        """Open the Kwik file at ``path`` and prepare the controller.

        path: path of the .kwik file (``~`` expanded, symlinks resolved).
        channel_group: channel group to load; also names the cache
            subdirectory (``'default'`` when falsy).
        clustering: clustering to load; also names the cache subdirectory
            (``'main'`` when falsy).
        """
        path = op.realpath(op.expanduser(path))
        _backup(path)
        self.path = path
        self.cache_dir = op.join(
            op.dirname(path),
            '.phy',
            str(clustering or 'main'),
            str(channel_group or 'default'),
        )
        # Forward ``clustering`` (it used to be hard-coded to None) so that
        # the loaded clustering matches the cache directory chosen above.
        self.model = KwikModel(
            path,
            channel_group=channel_group,
            clustering=clustering,
        )
        # Called last, presumably because the base initializer relies on
        # self.model / self.cache_dir being set -- keep this order.
        super(KwikController, self).__init__()

    def _init_data(self):
        """Copy the model's arrays and metadata onto the controller."""
        m = self.model
        self.spike_times = m.spike_times

        self.spike_clusters = m.spike_clusters
        self.cluster_groups = m.cluster_groups
        self.cluster_ids = m.cluster_ids

        self.channel_positions = m.channel_positions
        self.n_samples_waveforms = m.n_samples_waveforms
        self.n_channels = m.n_channels
        self.n_features_per_channel = m.n_features_per_channel
        self.sample_rate = m.sample_rate
        self.duration = m.duration

        self.all_masks = m.all_masks
        self.all_waveforms = m.all_waveforms
        self.all_features = m.all_features
        # WARNING: m.all_traces contains the dead channels, m.traces doesn't.
        # Also, m.traces has the reordered channels as per the prb file.
        self.all_traces = m.traces

    def create_gui(self, config_dir=None):
        """Create the kwik GUI."""
        f = super(KwikController, self).create_gui
        gui = f(
            name=self.gui_name,
            subtitle=self.path,
            config_dir=config_dir,
        )

        @self.manual_clustering.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            # Selected clusters.
            cluster_ids = self.manual_clustering.selected
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))
            spike_clusters, metadata = cluster(
                self.model,
                spike_ids,
                num_starting_clusters=10,
            )
            self.manual_clustering.split(spike_ids, spike_clusters)

        # Save: group names are title-cased before being written.
        @gui.connect_
        def on_request_save(spike_clusters, groups):
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return gui
示例#10
0
def kwik_describe(path, channel_group=None, clustering=None):
    """Describe a Kwik dataset.

    Opens ``path`` with KwikModel for the given channel group and
    clustering, then calls the model's ``describe()`` method.

    Raises:
        ValueError: if ``path`` is empty or None.
    """
    # An explicit check instead of ``assert``, which ``python -O`` strips.
    if not path:
        raise ValueError("A path to a Kwik file is required.")
    KwikModel(path, channel_group=channel_group,
              clustering=clustering).describe()
示例#11
0
class KwikController(Controller):
    gui_name = 'KwikGUI'

    def __init__(self, path, channel_group=None, clustering=None):
        """Open the Kwik file at ``path`` and prepare the controller.

        path: path of the .kwik file (``~`` expanded, symlinks resolved).
        channel_group: channel group to load; also names the cache
            subdirectory (``0`` when falsy).
        clustering: clustering to load; also names the cache subdirectory
            (``'main'`` when falsy).
        """
        path = op.realpath(op.expanduser(path))
        _backup(path)
        self.path = path
        # The cache directory depends on the filename, clustering, and
        # channel_group to avoid conflicts.
        self.cache_dir = op.join(op.dirname(path), '.phy',
                                 op.splitext(op.basename(path))[0],
                                 str(clustering or 'main'),
                                 str(channel_group or 0),
                                 )
        # Forward ``clustering`` (it used to be hard-coded to None) so that
        # the loaded clustering matches the cache directory chosen above.
        self.model = KwikModel(path,
                               channel_group=channel_group,
                               clustering=clustering,
                               )
        # Called last, presumably because the base initializer relies on
        # self.model / self.cache_dir being set -- keep this order.
        super(KwikController, self).__init__()

    def _init_data(self):
        """Copy the model's arrays and metadata onto the controller."""
        m = self.model
        self.spike_times = m.spike_times

        self.spike_clusters = m.spike_clusters
        self.cluster_groups = m.cluster_groups
        self.cluster_ids = m.cluster_ids

        self.channel_positions = m.channel_positions
        self.channel_order = m.channel_order
        self.n_samples_waveforms = m.n_samples_waveforms
        self.n_channels = m.n_channels
        self.n_features_per_channel = m.n_features_per_channel
        self.sample_rate = m.sample_rate
        self.duration = m.duration

        self.all_masks = m.all_masks
        self.all_waveforms = m.all_waveforms
        self.all_features = m.all_features
        # WARNING: m.all_traces contains the dead channels, m.traces doesn't.
        # Also, m.traces has the reordered channels as per the prb file.
        self.all_traces = m.traces

    def create_gui(self, config_dir=None):
        """Create the kwik GUI."""
        f = super(KwikController, self).create_gui
        gui = f(name=self.gui_name,
                subtitle=self.path,
                config_dir=config_dir,
                )

        @self.manual_clustering.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            # Selected clusters.
            cluster_ids = self.manual_clustering.selected
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))

            # Run KK2 in a temporary directory to avoid side effects.
            with TemporaryDirectory() as tempdir:
                spike_clusters, metadata = cluster(self.model,
                                                   spike_ids,
                                                   num_starting_clusters=10,
                                                   tempdir=tempdir,
                                                   )
            self.manual_clustering.split(spike_ids, spike_clusters)

        # Save: group names are title-cased before being written.
        @gui.connect_
        def on_request_save(spike_clusters, groups):
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return gui
示例#12
0
class KwikController(EventEmitter):
    gui_name = 'KwikGUI'

    n_spikes_waveforms = 100
    batch_size_waveforms = 10

    n_spikes_features = 10000
    n_spikes_amplitudes = 10000

    n_spikes_close_clusters = 100
    n_closest_channels = 16

    def __init__(self, kwik_path, config_dir=None, **kwargs):
        """Open a Kwik file and wire up the supervisor, selector and plugins.

        kwik_path: path of the .kwik file (symlinks are resolved).
        config_dir: forwarded to attach_plugins and stored on the instance.
        **kwargs: forwarded to KwikModel; ``channel_group`` also selects a
            cache subdirectory and ``plugins`` is forwarded to attach_plugins.
        """
        super(KwikController, self).__init__()

        kwik_path = op.realpath(kwik_path)
        _backup(kwik_path)
        self.model = model = KwikModel(kwik_path, **kwargs)

        positions = model.channel_positions
        # Channel indices ordered by their y coordinate.
        self.channel_vertical_order = np.argsort(positions[:, 1])
        self.distance_max = _get_distance_max(positions)

        # Cache lives in <kwik dir>/.phy, with one subfolder per channel group.
        channel_group = kwargs.get('channel_group', None)
        base_cache = op.join(op.dirname(kwik_path), '.phy')
        if channel_group is None:
            self.cache_dir = base_cache
        else:
            self.cache_dir = op.join(base_cache, str(channel_group))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        # Keep this order: _set_supervisor reads self.context, and the
        # remaining steps may rely on the earlier ones.
        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        # Register the expensive per-cluster accessors with _cache_methods:
        # the first tuple is its "memcached" list, the second its "cached"
        # list (see _cache_methods for what each kind of caching means).
        memcached_names = ('get_best_channels', 'get_probe_depth',
                           '_get_mean_masks', '_get_mean_waveforms')
        cached_names = ('_get_waveforms', '_get_features', '_get_masks')
        _cache_methods(self, memcached_names, cached_names)

    def _set_supervisor(self):
        # Restore the next free cluster id stored in the context, if any.
        saved_state = self.context.load('new_cluster_id')
        new_cluster_id = saved_state.get('new_cluster_id', None)
        cluster_groups = self.model.cluster_groups
        supervisor = Supervisor(self.model.spike_clusters,
                                similarity=self.similarity,
                                cluster_groups=cluster_groups,
                                new_cluster_id=new_cluster_id,
                                context=self.context,
                                )

        @supervisor.connect
        def on_create_cluster_views():
            # Extra columns shown in the cluster views.
            for column, label in ((self.get_best_channel, 'channel'),
                                  (self.get_probe_depth, 'depth')):
                supervisor.add_column(column, name=label)

            @supervisor.actions.add
            def recluster():
                """Relaunch KlustaKwik on the selected clusters."""
                selected_ids = supervisor.selected
                spike_ids = self.selector.select_spikes(selected_ids)
                logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))

                # KK2 runs inside a throw-away directory so that it has no
                # side effects on the working directory.
                with TemporaryDirectory() as scratch:
                    new_assignments, _ = cluster(
                        self.model,
                        spike_ids,
                        num_starting_clusters=10,
                        tempdir=scratch,
                    )
                self.supervisor.split(spike_ids, new_assignments)

        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            # Group names are title-cased before saving; labels are ignored.
            titled = dict((c, g.title()) for c, g in groups.items())
            self.model.save(spike_clusters, titled)

        return supervisor

    def _set_selector(self):
        # The lookup happens on every call, so the selector always sees the
        # supervisor's current clustering.
        def spikes_per_cluster(cluster_id):
            clustering = self.supervisor.clustering
            return clustering.spikes_per_cluster[cluster_id]

        selector = Selector(spikes_per_cluster)
        return selector

    def _add_view(self, gui, view):
        """Attach ``view`` to ``gui``, emit ``add_view`` and return the view."""
        view.attach(gui)
        # Lets listeners (e.g. plugins) react to the new view.
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_best_channel(self, cluster_id):
        """Return the first channel returned by get_best_channels."""
        best = self.get_best_channels(cluster_id)
        return best[0]

    def get_best_channels(self, cluster_id):
        """Return the channels of ``cluster_id`` ordered by decreasing mean mask.

        Only channels whose mean mask exceeds 0.1 are kept. If no channel
        passes that threshold, the four channels with the largest mean mask
        are returned instead.
        """
        mean_masks = self._get_mean_masks(cluster_id)
        by_strength = np.argsort(mean_masks)[::-1]
        strong = by_strength[mean_masks[by_strength] > .1]
        if len(strong):
            return strong
        return by_strength[:4]

    def get_cluster_position(self, cluster_id):
        """Return the probe position of the cluster's best channel."""
        best = self.get_best_channel(cluster_id)
        positions = self.model.channel_positions
        return positions[best]

    def get_probe_depth(self, cluster_id):
        """Return the second coordinate (index 1) of the cluster position."""
        position = self.get_cluster_position(cluster_id)
        return position[1]

    def similarity(self, cluster_id):
        """Return ``(cluster_id, score)`` pairs for all clusters, best first.

        The score is ``distance_max`` minus the Euclidean distance between
        the 2D positions of the two clusters' best channels.
        """
        reference = self.get_cluster_position(cluster_id)
        assert len(reference) == 2

        scores = []
        for other in self.supervisor.clustering.cluster_ids:
            position = self.get_cluster_position(other)
            assert len(position) == 2
            distance = np.sqrt(np.sum((position - reference) ** 2))
            scores.append((other, self.distance_max - distance))
        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_masks(self, cluster_id):
        # Masks of a spike subset selected with the same parameters as
        # _get_waveforms.
        selected = self.selector.select_spikes([cluster_id],
                                               self.n_spikes_waveforms,
                                               self.batch_size_waveforms)
        all_masks = self.model.all_masks
        return all_masks[selected]

    def _get_mean_masks(self, cluster_id):
        # Average over the selected spikes (first axis).
        masks = self._get_masks(cluster_id)
        return np.mean(masks, axis=0)

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster.

        Channels are reordered by decreasing amplitude, as computed by
        get_waveform_amplitude from the mean masks and mean waveform.
        """
        positions = self.model.channel_positions
        spike_ids = self.selector.select_spikes([cluster_id],
                                                self.n_spikes_waveforms,
                                                self.batch_size_waveforms)
        waveforms = self.model.all_waveforms[spike_ids]
        mean_masks = self._get_mean_masks(cluster_id)
        mean_waveform = np.mean(waveforms, axis=0)
        amplitudes = get_waveform_amplitude(mean_masks, mean_waveform)
        masks = self._get_masks(cluster_id)
        # Best channels first.
        order = np.argsort(amplitudes)[::-1]
        return Bunch(data=waveforms[..., order],
                     channel_ids=order,
                     channel_positions=positions[order],
                     masks=masks[:, order],
                     )

    def _get_mean_waveforms(self, cluster_id):
        # Start from a copy so the (cached) waveform Bunch is not modified.
        bunch = self._get_waveforms(cluster_id).copy()
        # Average over spikes, keeping a leading axis of length 1.
        bunch.data = np.mean(bunch.data, axis=0)[np.newaxis, ...]
        # Mean masks raised to the power 0.1.
        bunch.masks = np.mean(bunch.masks, axis=0)[np.newaxis, ...] ** .1
        bunch['alpha'] = 1.
        return bunch

    def add_waveform_view(self, gui):
        """Add a WaveformView to ``gui`` with an 'm' shortcut for mean waveforms."""
        view = self._add_view(gui, WaveformView(waveforms=self._get_waveforms))

        view.actions.separator()

        @view.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            # Switch between individual and mean waveforms, then redraw.
            individual = self._get_waveforms
            mean = self._get_mean_waveforms
            if view.waveforms == mean:
                view.waveforms = individual
            else:
                view.waveforms = mean
            view.on_select()

        return view

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        # At most n_spikes_features spikes, unless load_all is truthy.
        cap = self.n_spikes_features
        if cluster_id is not None:
            return self.selector.select_spikes([cluster_id],
                                               None if load_all else cap)
        # Background points: evenly strided subset of all spikes.
        total = self.model.n_spikes
        stride = max(1, total // cap)
        return np.arange(0, total, stride)

    def _get_spike_times(self, cluster_id=None, load_all=None):
        # Times of the same spikes _get_features uses, with [0, duration] limits.
        ids = self._get_spike_ids(cluster_id, load_all=load_all)
        times = self.model.spike_times[ids]
        return Bunch(data=times, lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        """Return features and masks of a spike subset as a Bunch.

        cluster_id: cluster whose spikes are used; when None, a strided
            background subset of all spikes is used (see _get_spike_ids).
        channel_ids: channels to keep. When None and a cluster is given,
            the cluster's best channels are used. When both are None, all
            channels are kept.
        load_all: forwarded to _get_spike_ids.

        Returns a Bunch with ``data`` (features), ``masks``, ``spike_ids``
        and ``channel_ids`` (None when all channels were kept).
        """
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        # ``arr[:, None]`` would insert a new axis rather than select all
        # channels, so use an explicit full slice when no channels are given.
        channel_sel = slice(None) if channel_ids is None else channel_ids
        f = self.model.all_features[spike_ids][:, channel_sel]
        m = self.model.all_masks[spike_ids][:, channel_sel]
        return Bunch(
            data=f,
            masks=m,
            spike_ids=spike_ids,
            channel_ids=channel_ids,
        )

    def add_feature_view(self, gui):
        """Add a FeatureView (with a 'time' attribute) to ``gui``."""
        extra_attributes = {'time': self._get_spike_times}
        view = FeatureView(features=self._get_features,
                           attributes=extra_attributes)
        return self._add_view(gui, view)

    # Traces
    # -------------------------------------------------------------------------

    def _get_traces(self, interval):
        """Get traces and spike waveforms."""
        n_samples = self.model.n_samples_waveforms
        model = self.model
        vertical_order = self.channel_vertical_order

        traces = select_traces(model.traces,
                               interval,
                               sample_rate=model.sample_rate)
        # Reorder channels vertically.
        traces = traces[:, vertical_order]

        out = Bunch(data=traces)
        out.waveforms = []
        spike_waveforms = _iter_spike_waveforms(
            interval=interval,
            traces_interval=traces,
            model=self.model,
            supervisor=self.supervisor,
            color_selector=self.color_selector,
            n_samples_waveforms=n_samples,
            get_best_channels=lambda cluster_id: self.get_best_channels(
                cluster_id),
            show_all_spikes=self._show_all_spikes,
        )
        for waveform in spike_waveforms:
            # Label each waveform's channels using the model's channel order.
            waveform.channel_labels = model.channel_order[waveform.channel_ids]
            out.waveforms.append(waveform)
        return out

    def _jump_to_spike(self, view, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        m = self.model
        cluster_ids = self.supervisor.selected
        if len(cluster_ids) == 0:
            return
        spc = self.supervisor.clustering.spikes_per_cluster
        spike_ids = spc[cluster_ids[0]]
        spike_times = m.spike_times[spike_ids]
        ind = np.searchsorted(spike_times, view.time)
        n = len(spike_times)
        view.go_to(spike_times[(ind + delta) % n])

    def add_trace_view(self, gui):
        """Create the TraceView, add it to *gui* and register its actions.

        Registers actions to jump to the next/previous spike of the first
        selected cluster and to toggle highlighting of all spikes. Also
        connects an ``on_spike_click`` handler on the GUI that selects the
        clicked spike's cluster and refreshes the view. Returns the view.
        """
        model = self.model
        view = TraceView(
            traces=self._get_traces,
            n_channels=model.n_channels,
            sample_rate=model.sample_rate,
            duration=model.duration,
            channel_vertical_order=self.channel_vertical_order,
        )
        self._add_view(gui, view)

        # Spike navigation actions.
        view.actions.separator()

        @view.actions.add(shortcut='alt+pgdown')
        def go_to_next_spike():
            """Jump to the next spike from the first selected cluster."""
            self._jump_to_spike(view, +1)

        @view.actions.add(shortcut='alt+pgup')
        def go_to_previous_spike():
            """Jump to the previous spike from the first selected cluster."""
            self._jump_to_spike(view, -1)

        # Display actions.
        view.actions.separator()

        @view.actions.add(shortcut='alt+s')
        def toggle_highlighted_spikes():
            """Toggle between showing all spikes or selected spikes."""
            self._show_all_spikes = not self._show_all_spikes
            view.set_interval(force_update=True)

        @gui.connect_
        def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
            clicked = [cluster_id]
            # Select the clicked cluster, then force the trace view to redraw.
            self.supervisor.select(clicked)
            view.on_select(clicked, force_update=True)

        return view

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        """Compute the correlograms between the given clusters.

        At most 100000 spikes are selected from *cluster_ids*; their times and
        current cluster assignments are passed to ``correlograms``.
        """
        max_spikes = 100000
        selected = self.selector.select_spikes(cluster_ids, max_spikes)
        times = self.model.spike_times[selected]
        assignments = self.supervisor.clustering.spike_clusters[selected]
        return correlograms(times, assignments,
                            sample_rate=self.model.sample_rate,
                            cluster_ids=cluster_ids,
                            bin_size=bin_size,
                            window_size=window_size,
                            )

    def add_correlogram_view(self, gui):
        """Create a CorrelogramView backed by this controller and add it to *gui*.

        Returns whatever ``_add_view`` returns.
        """
        rate = self.model.sample_rate
        view = CorrelogramView(correlograms=self._get_correlograms,
                               sample_rate=rate)
        return self._add_view(gui, view)

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        """Build the GUI, attach the supervisor and add the available views.

        Extra keyword arguments are forwarded to ``GUI``. The trace view is
        only added when ``model.traces`` is not None. The memcache is saved
        when the GUI closes, and ``gui_ready`` is emitted before returning.
        """
        gui = GUI(
            name=self.gui_name,
            subtitle=self.model.kwik_path,
            config_dir=self.config_dir,
            **kwargs
        )
        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        has_traces = self.model.traces is not None
        if has_traces:
            self.add_trace_view(gui)
        self.add_feature_view(gui)
        self.add_correlogram_view(gui)

        @gui.connect_
        def on_close():
            # Persist the in-memory cache when the window goes away.
            self.context.save_memcache()

        self.emit('gui_ready', gui)
        return gui
示例#13
0
文件: gui.py 项目: ablot/phy-contrib
class KwikController(Controller):
    gui_name = 'KwikGUI'

    def __init__(self, path, channel_group=None, clustering=None):
        """Open a Kwik file and initialize the controller.

        Parameters
        ----------
        path : str
            Path to the Kwik file; ``~`` is expanded and the path is resolved
            with ``op.realpath``.
        channel_group : optional
            Channel group passed to ``KwikModel``; also names the cache
            subdirectory (``'default'`` when falsy).
        clustering : optional
            Clustering passed to ``KwikModel``; also names the cache
            subdirectory (``'main'`` when falsy).
        """
        path = op.realpath(op.expanduser(path))
        _backup(path)
        self.path = path
        self.cache_dir = op.join(op.dirname(path), '.phy',
                                 str(clustering or 'main'),
                                 str(channel_group or 'default'),
                                 )
        # Forward the requested clustering: it was previously hard-coded to
        # None, so the model ignored it while the cache dir was keyed on it.
        self.model = KwikModel(path,
                               channel_group=channel_group,
                               clustering=clustering,
                               )
        # NOTE(review): the base initializer is called last, presumably
        # because it relies on self.model (e.g. via _init_data) — confirm.
        super(KwikController, self).__init__()

    def _init_data(self):
        """Mirror the model's data arrays and metadata onto the controller.

        Each attribute listed below is read from ``self.model`` and assigned
        under the same name on ``self``, in the listed order.
        """
        model = self.model
        mirrored = (
            # Spikes.
            'spike_times',
            # Clusters.
            'spike_clusters', 'cluster_groups', 'cluster_ids',
            # Probe and recording metadata.
            'channel_positions', 'n_samples_waveforms', 'n_channels',
            'n_features_per_channel', 'sample_rate', 'duration',
            # Per-spike / raw data arrays.
            'all_masks', 'all_waveforms', 'all_features', 'all_traces',
        )
        for attr in mirrored:
            setattr(self, attr, getattr(model, attr))

    def create_gui(self, config_dir=None):
        """Create the kwik GUI.

        Builds the GUI with the base controller. Then it registers a
        ``recluster`` action, which runs KlustaKwik on the spikes of the
        selected clusters and splits them accordingly. It also connects an
        ``on_request_save`` handler that title-cases the group names and
        saves through the model.
        """
        gui = super(KwikController, self).create_gui(
            name=self.gui_name,
            subtitle=self.path,
            config_dir=config_dir,
        )

        @self.manual_clustering.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            selected = self.manual_clustering.selected
            spike_ids = self.selector.select_spikes(selected)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))
            new_assignments, _ = cluster(self.model,
                                         spike_ids,
                                         num_starting_clusters=10,
                                         )
            self.manual_clustering.split(spike_ids, new_assignments)

        @gui.connect_
        def on_request_save(spike_clusters, groups):
            titled = dict((cluster_id, group.title())
                          for cluster_id, group in groups.items())
            self.model.save(spike_clusters, titled)

        return gui