Example #1
    def triage_wf(self, wf_tf, threshold):
        """
            Run neural net triage on given spike waveforms

            Parameters:
            -----------
            wf_tf: tf tensor (n_spikes, n_temporal_length, n_neigh)
                tf tensor containing the spike waveforms

            threshold: float
                probability threshold applied to the network output to
                determine whether a waveform is a clear spike

            Returns:
            -----------
            tf tensor (n_spikes,)
                a boolean tensorflow tensor marking which spikes
                are clear
        """
        # get parameters
        K1, K2 = self.filters_dict['filters']

        # first layer: temporal feature
        layer1 = tf.nn.relu(
            conv2d_VALID(tf.expand_dims(wf_tf, -1), self.W1) + self.b1)

        # second layer: feature mapping
        layer11 = tf.nn.relu(conv2d(layer1, self.W11) + self.b11)

        # third layer: spatial convolution
        o_layer = conv2d_VALID(layer11, self.W2) + self.b2

        # threshold it
        return o_layer[:, 0, 0, 0] > np.log(threshold / (1 - threshold))
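The comparison on the final line applies the threshold in logit space: the sigmoid is monotonic, so sigmoid(x) > t exactly when x > log(t / (1 - t)), and the graph never has to evaluate a sigmoid. A minimal standalone check of that identity (NumPy only; the names are illustrative):

import numpy as np

def logit(p):
    # inverse of the sigmoid: sigmoid(logit(p)) == p
    return np.log(p / (1 - p))

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

threshold = 0.5
x = np.linspace(-4, 4, 9)   # raw network outputs (logits)

# thresholding probabilities and thresholding logits agree everywhere
assert np.array_equal(sigmoid(x) > threshold, x > logit(threshold))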
Example #2
    def _make_training_graph(cls, waveform_length, n_neighbors, vars_dict):
        """Make graph for training

        Returns
        -------
        x_tf: tf.tensor
            Input tensor
        y_tf: tf.tensor
            Labels tensor
        o_layer: tf.tensor
            Output tensor (logits)
        sigmoid: tf.tensor
            Sigmoid of the output tensor
        """
        # x and y input tensors
        x_tf = tf.placeholder("float", [None, waveform_length, n_neighbors])
        y_tf = tf.placeholder("float", [None])

        input_tf = tf.expand_dims(x_tf, -1)

        vars_dict, layer11 = (NeuralNetDetector._make_network(input_tf,
                                                              vars_dict,
                                                              padding='VALID'))

        W2 = vars_dict['W2']
        b2 = vars_dict['b2']

        # third layer: spatial convolution
        o_layer = tf.squeeze(conv2d_VALID(layer11, W2) + b2)

        # sigmoid
        sigmoid = tf.sigmoid(o_layer)

        return x_tf, y_tf, o_layer, sigmoid
Example #3
    def make_o_layer_tf_tensors(self, x_tf, channel_index):
        """
        Make a tensorflow tensor that outputs the spike probability at
        each time step and channel

        Parameters
        -----------
        x_tf: tf.tensors (n_observations, n_channels)
            placeholder of recording for running tensorflow

        channel_index: np.array (n_channels, n_neigh)
            Each row indexes its neighboring channels.
            For example, channel_index[c] is the index of
            neighboring channels (including itself)
            If any value is equal to n_channels, it is merely a
            placeholder for a channel that has fewer than
            n_neigh neighboring channels

        Returns
        -------
        output_tf: tf tensor (n_observations, n_channels)
            sigmoid of the network output (spike probability at each
            time step and channel)
        """

        # get parameters
        K1, K2 = self.filters_dict['filters']
        nneigh = self.filters_dict['n_neighbors']

        # save neighbor channel index
        self.channel_index = channel_index[:, :nneigh]

        # Temporal shape of input
        T = tf.shape(x_tf)[0]

        # input tensor into CNN
        x_cnn_tf = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

        # NN structures
        # first temporal layer
        layer1 = tf.nn.relu(conv2d(x_cnn_tf, self.W1) + self.b1)

        # second temporal layer
        layer11 = tf.nn.relu(conv2d(layer1, self.W11) + self.b11)

        # first spatial layer
        zero_added_layer11 = tf.concat(
            (tf.transpose(layer11, [2, 0, 1, 3]), tf.zeros((1, 1, T, K2))),
            axis=0)
        temp = tf.transpose(tf.gather(zero_added_layer11, self.channel_index),
                            [0, 2, 3, 1, 4])
        temp2 = conv2d_VALID(tf.reshape(temp, [-1, T, nneigh, K2]),
                             self.W2) + self.b2

        # output layer
        # transpose to [1, temporal, spatial, 1], then drop singleton dims
        o_layer = tf.transpose(temp2, [2, 1, 0, 3])[0, :, :, 0]
        output_tf = tf.sigmoid(o_layer)

        return output_tf
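The concat-with-zeros plus tf.gather pattern above is what makes the n_channels placeholder entries harmless: they select the appended all-zero channel, so missing neighbors contribute nothing to the spatial convolution. A reduced NumPy sketch of just the indexing step (toy shapes, illustrative only):

import numpy as np

n_channels, n_neigh, T = 3, 2, 5

# one feature trace per channel, shape (n_channels, T)
features = np.arange(n_channels * T, dtype=float).reshape(n_channels, T)

# channel 2 has only itself as a neighbor; n_channels (== 3) marks the gap
channel_index = np.array([[0, 1],
                          [1, 0],
                          [2, 3]])

# append one all-zero "channel" so index n_channels selects zeros
padded = np.concatenate([features, np.zeros((1, T))], axis=0)

# gather neighbors: shape (n_channels, n_neigh, T)
neighbors = padded[channel_index]
assert np.all(neighbors[2, 1] == 0)   # placeholder row reads back as zeros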
Example #4
    def get_spikes(self, x_tf, T, nneigh, c_idx, temporal_window, th):
        """
            Detects and indexes spikes in the recording. The recording is
            chopped into minibatches if its temporal length exceeds 10000.
            A spike is detected at [t, c] when the output probability of
            the neural network detector crosses the detection threshold at
            time t and channel c. Among temporal duplicates within a
            certain temporal radius, the temporal index with the largest
            output probability is kept. Among spatial duplicates on
            neighboring channels, the channel with the highest energy
            is kept.

            Parameters:
            -----------
            x_tf: tf tensor (n_observations, n_channels)
                placeholder for the raw recording

            T: int or tf scalar tensor
                temporal length of the recording

            nneigh: int
                number of neighboring channels

            c_idx: np.array (n_channels, n_neigh)
                neighboring-channel index for each channel

            temporal_window: int
                window size for the temporal local-maximum search

            th: float
                probability threshold for detection

            Returns:
            -----------
            local_max_idx: tf tensor (n_detected_spikes, 2)
                indices of detected spikes. First column is the temporal
                location; second column is the spatial (channel) location.
        """
        # get parameters
        K1, K2 = self.filters_dict['filters']

        # NN structures
        layer1 = tf.nn.relu(
            conv2d(tf.expand_dims(tf.expand_dims(x_tf, -1), 0), self.W1) +
            self.b1)
        layer11 = tf.nn.relu(conv2d(layer1, self.W11) + self.b11)
        zero_added_layer11 = tf.concat(
            (tf.transpose(layer11, [2, 0, 1, 3]), tf.zeros((1, 1, T, K2))),
            axis=0)
        temp = tf.transpose(tf.gather(zero_added_layer11, c_idx),
                            [0, 2, 3, 1, 4])
        temp2 = conv2d_VALID(tf.reshape(temp, [-1, T, nneigh, K2]),
                             self.W2) + self.b2
        o_layer = tf.transpose(temp2, [2, 1, 0, 3])

        temporal_max = max_pool(o_layer, [1, temporal_window, 1, 1])
        local_max_idx = tf.where(
            tf.logical_and(o_layer[0, :, :, 0] >= temporal_max[0, :, :, 0],
                           o_layer[0, :, :, 0] > np.log(th / (1 - th))))

        return local_max_idx
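The max_pool/tf.where pair keeps the points that are both a running maximum over temporal_window samples and above the threshold in logit space. The same selection can be sketched outside the graph with SciPy's 1-D maximum filter (a standalone analogue with toy numbers, not the graph code):

import numpy as np
from scipy.ndimage import maximum_filter1d

th = 0.9                                # probability threshold
o = np.array([0.1, 2.5, 3.1, 2.5, 0.2,  # toy logits for one channel
              2.4, 2.4, 0.0, 3.0, 0.1])

temporal_window = 3
running_max = maximum_filter1d(o, size=temporal_window)

is_local_max = o >= running_max
above_threshold = o > np.log(th / (1 - th))   # logit(0.9) ~= 2.197

spike_times = np.where(is_local_max & above_threshold)[0]
print(spike_times)   # [2 5 6 8]; the tie at 5, 6 is why duplicates
                     # still need to be resolved downstream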
Example #5
    def _make_network(cls, input_tensor, filters_size, waveform_length,
                      n_neighbors):
        """Mates tensorflow network, from first layer to output layer
        """
        K1, K2 = filters_size

        # initialize and save nn weights
        W1 = weight_variable([waveform_length, 1, 1, K1])
        b1 = bias_variable([K1])

        W11 = weight_variable([1, 1, K1, K2])
        b11 = bias_variable([K2])

        W2 = weight_variable([1, n_neighbors, K2, 1])
        b2 = bias_variable([1])

        # first layer: temporal feature
        layer1 = tf.nn.relu(
            conv2d_VALID(tf.expand_dims(input_tensor, -1), W1) + b1)

        # second layer: feature mapping
        layer11 = tf.nn.relu(conv2d(layer1, W11) + b11)

        # third layer: spatial convolution
        o_layer = conv2d_VALID(layer11, W2) + b2

        vars_dict = {
            "W1": W1,
            "W11": W11,
            "W2": W2,
            "b1": b1,
            "b11": b11,
            "b2": b2
        }

        return o_layer, vars_dict
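Shape-wise, the three layers take a (batch, waveform_length, n_neighbors, 1) input down to one logit per waveform: the temporal VALID convolution collapses the time axis, the 1x1 convolution mixes K1 features into K2, and the spatial VALID convolution collapses the neighbor axis. A NumPy sketch of the same computation with einsum (biases omitted, weights random, purely illustrative):

import numpy as np

batch, waveform_length, n_neighbors = 4, 61, 7
K1, K2 = 8, 16
rng = np.random.default_rng(0)

x = rng.standard_normal((batch, waveform_length, n_neighbors))

# layer 1: temporal VALID conv, kernel (waveform_length, 1, 1, K1);
# collapses the time axis
W1 = rng.standard_normal((waveform_length, K1))
layer1 = np.maximum(np.einsum('btc,tk->bck', x, W1), 0)   # (batch, C, K1)

# layer 2: 1x1 conv, i.e. per-position feature mixing, K1 -> K2
W11 = rng.standard_normal((K1, K2))
layer11 = np.maximum(layer1 @ W11, 0)                     # (batch, C, K2)

# layer 3: spatial VALID conv, kernel (1, n_neighbors, K2, 1);
# collapses channels and features into one logit per waveform
W2 = rng.standard_normal((n_neighbors, K2))
o_layer = np.einsum('bck,ck->b', layer11, W2)             # (batch,)

assert o_layer.shape == (batch,)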
Example #6
    def make_detection_tf_tensors(self, x_tf, channel_index, threshold):
        """
        Make a tensorflow tensor that outputs spike index

        Parameters
        -----------
        x_tf: tf.tensors (n_observations, n_channels)
            placeholder of recording for running tensorflow

        channel_index: np.array (n_channels, n_neigh)
            Each row indexes its neighboring channels.
            For example, channel_index[c] is the index of
            neighboring channels (including itself)
            If any value is equal to n_channels, it is merely a
            placeholder for a channel that has fewer than
            n_neigh neighboring channels

        threshold: float
            probability threshold that determines the
            location of spikes

        Returns
        -------
        spike_index_tf: tf tensor (n_spikes, 2)
            tensorflow tensor that produces spike_index
        """

        # get parameters
        K1, K2 = self.filters_dict['filters']
        nneigh = self.filters_dict['n_neighbors']

        # save neighbor channel index
        self.channel_index = channel_index[:, :nneigh]

        # Temporal shape of input
        T = tf.shape(x_tf)[0]

        # input tensor into CNN
        x_cnn_tf = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

        # NN structures
        # first temporal layer
        layer1 = tf.nn.relu(conv2d(x_cnn_tf, self.W1) + self.b1)

        # second temporal layer
        layer11 = tf.nn.relu(conv2d(layer1, self.W11) + self.b11)

        # first spatial layer
        zero_added_layer11 = tf.concat(
            (tf.transpose(layer11, [2, 0, 1, 3]), tf.zeros((1, 1, T, K2))),
            axis=0)
        temp = tf.transpose(tf.gather(zero_added_layer11, self.channel_index),
                            [0, 2, 3, 1, 4])
        temp2 = conv2d_VALID(tf.reshape(temp, [-1, T, nneigh, K2]),
                             self.W2) + self.b2

        # output layer
        o_layer = tf.transpose(temp2, [2, 1, 0, 3])

        # temporal max
        temporal_max = max_pool(o_layer, [1, 3, 1, 1]) - 1e-8

        # spike index is local maximum crossing a threshold
        spike_index_tf = tf.cast(
            tf.where(
                tf.logical_and(
                    o_layer[0, :, :, 0] >= temporal_max[0, :, :, 0],
                    o_layer[0, :, :, 0] > np.log(threshold /
                                                 (1 - threshold)))), 'int32')

        return spike_index_tf
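The 1e-8 subtracted from the pooled maximum presumably makes the >= comparison robust to floating-point noise between the pooled value and the original output, at the price of also letting exact ties through; as the get_spikes docstring above notes, such duplicates are resolved later in the pipeline. A toy illustration:

import numpy as np
from scipy.ndimage import maximum_filter1d

o = np.array([0.0, 1.0, 1.0, 0.0])       # a two-sample plateau
m = maximum_filter1d(o, size=3) - 1e-8    # epsilon-relaxed running max

print(o >= m)   # [False  True  True False]: both plateau samples pass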
Example #7
    def fit(self, x_train, y_train):
        """
        Trains the neural network detector for spike detection

        Parameters:
        -----------
        x_train: np.array
            [number of training data, temporal length, number of channels]
            augmented training data consisting of
            isolated spikes, noise and misaligned spikes.
        y_train: np.array
            [number of training data] labels for x_train. '1' denotes an
            isolated spike and '0' denotes noise or a misaligned spike.

        Notes
        -----
        The trained network is saved to the .ckpt file at
        self.path_to_model
        """
        ######################
        # Loading parameters #
        ######################

        logger = logging.getLogger(__name__)

        # get parameters
        n_data, waveform_length_train, n_neighbors_train = x_train.shape

        if self.waveform_length != waveform_length_train:
            raise ValueError('waveform length from network ({}) does not '
                             'match training data ({})'
                             .format(self.waveform_length,
                                     waveform_length_train))

        if self.n_neighbors != n_neighbors_train:
            raise ValueError('number of n_neighbors from network ({}) does '
                             'not match training data ({})'
                             .format(self.n_neighbors,
                                     n_neighbors_train))

        ####################
        # Building network #
        ####################

        # x and y input tensors
        x_tf = tf.placeholder("float", [self.n_batch, self.waveform_length,
                                        self.n_neighbors])
        y_tf = tf.placeholder("float", [self.n_batch])

        input_tf = tf.expand_dims(x_tf, -1)

        vars_dict, layer11 = (NeuralNetDetector
                              ._make_network(input_tf,
                                             self.waveform_length,
                                             self.filters_size,
                                             self.n_neighbors,
                                             padding='VALID'))

        W2 = vars_dict['W2']
        b2 = vars_dict['b2']

        # third layer: spatial convolution
        o_layer = tf.squeeze(conv2d_VALID(layer11, W2) + b2)

        ##########################
        # Optimization objective #
        ##########################

        # cross entropy
        _ = tf.nn.sigmoid_cross_entropy_with_logits(logits=o_layer,
                                                    labels=y_tf)
        cross_entropy = tf.reduce_mean(_)

        weights = tf.trainable_variables()

        # regularization term
        l2_regularizer = (tf.contrib.layers
                          .l2_regularizer(scale=self.l2_reg_scale))

        regularization = tf.contrib.layers.apply_regularization(l2_regularizer,
                                                                weights)

        regularized_loss = cross_entropy + regularization

        # train step
        train_step = (tf.train.AdamOptimizer(self.train_step_size)
                        .minimize(regularized_loss))

        ############
        # Training #
        ############

        # saver
        saver = tf.train.Saver(vars_dict)
        logger.debug('Training detector network...')

        with tf.Session() as sess:

            init_op = tf.global_variables_initializer()
            sess.run(init_op)

            pbar = trange(self.n_iter)

            for i in pbar:

                # sample n_batch observations from 0, ..., n_data
                idx_batch = np.random.choice(n_data, self.n_batch,
                                             replace=False)

                res = sess.run([train_step, regularized_loss],
                               feed_dict={x_tf: x_train[idx_batch],
                                          y_tf: y_train[idx_batch]})

                if i % 100 == 0:
                    pbar.set_description('Loss: %s' % res[1])

            logger.debug('Saving network: %s', self.path_to_model)
            saver.save(sess, self.path_to_model)

            # estimate tp and fp with a sample
            idx_batch = np.random.choice(n_data, self.n_batch, replace=False)

            output = sess.run(o_layer, feed_dict={x_tf: x_train[idx_batch]})
            y_test = y_train[idx_batch]

            tp = np.mean(output[y_test == 1] > 0)
            fp = np.mean(output[y_test == 0] > 0)

            logger.debug('Approximate training true positive rate: '
                         + str(tp) + ', false positive rate: ' + str(fp))

        path_to_params = change_extension(self.path_to_model, 'yaml')

        logger.debug('Saving network parameters: %s', path_to_params)
        save_detect_network_params(filters_size=self.filters_size,
                                   waveform_length=self.waveform_length,
                                   n_neighbors=self.n_neighbors,
                                   output_path=path_to_params)
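The tp/fp estimate thresholds the raw logits at zero because sigmoid(0) = 0.5, so output > 0 means exactly "predicted spike probability above one half". A quick standalone check with made-up numbers:

import numpy as np

output = np.array([-2.0, -0.1, 0.0, 0.3, 4.0])   # logits from o_layer
prob = 1 / (1 + np.exp(-output))                  # sigmoid

# thresholding logits at 0 == thresholding probabilities at 0.5
assert np.array_equal(output > 0, prob > 0.5)

y_test = np.array([0, 0, 1, 1, 1])
tp = np.mean((output > 0)[y_test == 1])   # fraction of spikes detected
fp = np.mean((output > 0)[y_test == 0])   # fraction of noise flagged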
Example #8
    def _make_graph(cls, threshold, channel_index, waveform_length,
                    filters_size, n_neigh):
        """Build tensorflow graph with input and two output layers

        Parameters
        -----------
        threshold: float
            probability threshold that determines the
            location of spikes

        channel_index: np.array (n_channels, n_neigh)
            Each row indexes its neighboring channels. For example,
            channel_index[c] is the index of neighboring channels (including
            itself). If any value is equal to n_channels, it is merely a
            placeholder for a channel that has fewer than n_neigh
            neighboring channels

        waveform_length: int
            temporal length of the waveforms

        filters_size: tuple (K1, K2)
            number of filters in the temporal and feature-mapping layers

        n_neigh: int
            number of neighboring channels

        Returns
        -------
        x_tf: tf tensor (n_observations, n_channels)
            placeholder for the recording

        spike_index_tf: tf tensor (n_spikes, 2)
            tensorflow tensor that produces spike_index

        probability_tf: tf tensor (n_observations, n_channels)
            sigmoid of the output layer (spike probability)

        waveform_tf: tf tensor
            waveforms at the detected spike indices

        vars_dict: dict
            network weights and biases
        """
        ######################
        # Loading parameters #
        ######################

        # TODO: need to ask why we are sending different channel indexes
        # save neighbor channel index
        small_channel_index = channel_index[:, :n_neigh]

        # placeholder for input recording
        x_tf = tf.placeholder("float", [None, None])

        # Temporal shape of input
        T = tf.shape(x_tf)[0]

        ####################
        # Building network #
        ####################

        # input tensor into CNN - add one dimension at the beginning and
        # at the end
        x_cnn_tf = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

        vars_dict, layer11 = cls._make_network(x_cnn_tf,
                                               waveform_length,
                                               filters_size,
                                               n_neigh,
                                               padding='SAME')
        W2 = vars_dict['W2']
        b2 = vars_dict['b2']

        K1, K2 = filters_size

        # first spatial layer
        zero_added_layer11 = tf.concat((tf.transpose(layer11, [2, 0, 1, 3]),
                                        tf.zeros((1, 1, T, K2))),
                                       axis=0)

        temp = tf.transpose(tf.gather(zero_added_layer11, small_channel_index),
                            [0, 2, 3, 1, 4])

        temp2 = conv2d_VALID(tf.reshape(temp, [-1, T, n_neigh, K2]), W2) + b2

        o_layer = tf.transpose(temp2, [2, 1, 0, 3])

        ################################
        # Output layer transformations #
        ################################

        o_layer_val = tf.squeeze(o_layer)

        # probability output - just sigmoid of output layer
        probability_tf = tf.sigmoid(o_layer_val)

        # spike index output (local maximum crossing a threshold)
        temporal_max = tf.squeeze(max_pool(o_layer, [1, 3, 1, 1]) - 1e-8)

        higher_than_max_pool = o_layer_val >= temporal_max

        higher_than_threshold = (o_layer_val >
                                 np.log(threshold / (1 - threshold)))

        both_higher = tf.logical_and(higher_than_max_pool,
                                     higher_than_threshold)

        index_all = tf.cast(tf.where(both_higher), 'int32')

        spike_index_tf = cls._remove_edge_spikes(x_tf, index_all,
                                                 waveform_length)

        # waveform output from spike index output
        waveform_tf = cls._make_waveform_tf(x_tf, spike_index_tf,
                                            channel_index, waveform_length)

        return x_tf, spike_index_tf, probability_tf, waveform_tf, vars_dict
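_make_waveform_tf itself is not shown in these examples, but the gather it needs — one (waveform_length, n_neigh) patch per (time, channel) row of spike_index, read through channel_index — can be sketched in NumPy. This is an illustrative analogue under assumed conventions (window centered on the spike time, edge spikes already removed as by _remove_edge_spikes), not the library function:

import numpy as np

def extract_waveforms(recording, spike_index, channel_index,
                      waveform_length):
    # NumPy analogue of a waveform gather: one (waveform_length, n_neigh)
    # patch per (time, channel) spike; a zero channel is appended so that
    # placeholder indices (== n_channels) read back as zeros
    R = waveform_length // 2
    padded = np.concatenate(
        [recording, np.zeros((recording.shape[0], 1))], axis=1)
    waveforms = [padded[t - R:t + R + 1][:, channel_index[c]]
                 for t, c in spike_index]
    return np.stack(waveforms)   # (n_spikes, waveform_length, n_neigh)

recording = np.random.default_rng(1).standard_normal((1000, 4))
channel_index = np.array([[0, 1], [1, 0], [2, 3], [3, 2]])
spike_index = np.array([[100, 0], [500, 3]])

wf = extract_waveforms(recording, spike_index, channel_index,
                       waveform_length=31)
print(wf.shape)   # (2, 31, 2)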
Example #9
def train_triage(x_train, y_train, n_filters, n_iter, n_batch, l2_reg_scale,
                 train_step_size, nn_name):
    """
        Trains the triage network

        Parameters:
        -----------
        x_train: np.array
            [number of data, temporal length, number of channels] training
            data for the triage network.
        y_train: np.array
            [number of data] training labels for the triage network.
        n_filters: tuple (K1, K2)
            number of filters in the first two layers.
        n_iter: int
            number of training iterations.
        n_batch: int
            minibatch size.
        l2_reg_scale: float
            scale of the L2 regularization penalty.
        train_step_size: float
            learning rate for the Adam optimizer.
        nn_name: string
            name of the .ckpt file to be saved.
    """
    # get parameters
    ndata, R, C = x_train.shape
    K1, K2 = n_filters

    # x and y input tensors
    x_tf = tf.placeholder("float", [n_batch, R, C])
    y_tf = tf.placeholder("float", [n_batch])

    # first layer: temporal feature
    W1 = weight_variable([R, 1, 1, K1])
    b1 = bias_variable([K1])
    layer1 = tf.nn.relu(conv2d_VALID(tf.expand_dims(x_tf, -1), W1) + b1)

    # second layer: feature mapping
    W11 = weight_variable([1, 1, K1, K2])
    b11 = bias_variable([K2])
    layer11 = tf.nn.relu(conv2d(layer1, W11) + b11)

    # third layer: spatial convolution
    W2 = weight_variable([1, C, K2, 1])
    b2 = bias_variable([1])
    o_layer = tf.squeeze(conv2d_VALID(layer11, W2) + b2)

    # cross entropy
    cross_entropy = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=o_layer, labels=y_tf))

    # regularization term
    weights = tf.trainable_variables()
    l2_regularizer = tf.contrib.layers.l2_regularizer(scale=l2_reg_scale)
    regularization_penalty = tf.contrib.layers.apply_regularization(
        l2_regularizer, weights)
    regularized_loss = cross_entropy + regularization_penalty

    # train step
    train_step = tf.train.AdamOptimizer(train_step_size).minimize(
        regularized_loss)

    # saver
    saver = tf.train.Saver({
        "W1": W1,
        "W11": W11,
        "W2": W2,
        "b1": b1,
        "b11": b11,
        "b2": b2
    })

    ############
    # training #
    ############

    bar = progressbar.ProgressBar(maxval=n_iter)
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        for i in range(0, n_iter):
            idx_batch = np.random.choice(ndata, n_batch, replace=False)
            sess.run(train_step,
                     feed_dict={
                         x_tf: x_train[idx_batch],
                         y_tf: y_train[idx_batch]
                     })
            bar.update(i + 1)
        saver.save(sess, nn_name)

        idx_batch = np.random.choice(ndata, n_batch, replace=False)
        output = sess.run(o_layer, feed_dict={x_tf: x_train[idx_batch]})
        y_test = y_train[idx_batch]
        tp = np.mean(output[y_test == 1] > 0)
        fp = np.mean(output[y_test == 0] > 0)

        print('Approximate training true positive rate: ' + str(tp) +
              ', false positive rate: ' + str(fp))
    bar.finish()
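Both trainers hand raw logits to tf.nn.sigmoid_cross_entropy_with_logits rather than computing -y*log(sigmoid(x)) by hand; the fused op uses the algebraically equivalent form max(x, 0) - x*y + log(1 + exp(-|x|)), which cannot overflow for large |x|. The equivalence, checked in NumPy on values where the naive form is still finite:

import numpy as np

def sigmoid_xent_with_logits(x, y):
    # stable form documented for tf.nn.sigmoid_cross_entropy_with_logits
    return np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))

x = np.array([-3.0, -0.5, 0.0, 2.0])   # logits
y = np.array([0.0, 1.0, 1.0, 1.0])     # labels

# naive form: -y*log(p) - (1-y)*log(1-p); overflows for large |x|
p = 1 / (1 + np.exp(-x))
naive = -y * np.log(p) - (1 - y) * np.log(1 - p)

assert np.allclose(sigmoid_xent_with_logits(x, y), naive)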