def triage_wf(self, wf_tf, threshold):
    """
    Build the triage decision tensor for a batch of spike waveforms

    The waveforms are passed through the three layers of the network
    (temporal convolution, feature mapping, spatial convolution) and the
    resulting output is compared against the logit of `threshold`.

    Parameters:
    -----------
    wf_tf: tf tensor (n_spikes, n_temporal_length, n_neigh)
        tf tensor that produces spikes waveforms
    threshold: float
        cutoff on the probability obtained after the network; it is
        converted with log(p / (1 - p)), so it is expected to lie
        strictly between 0 and 1

    Returns:
    -----------
    tf tensor (n_spikes,)
        a boolean tensorflow tensor that is True where the network
        output exceeds the threshold (clear spikes)
    """
    # only the unpacking itself is needed here; the sizes are unused
    _n_temporal, _n_features = self.filters_dict['filters']

    # temporal feature layer (a trailing axis is added to the input)
    net_input = tf.expand_dims(wf_tf, -1)
    temporal = tf.nn.relu(conv2d_VALID(net_input, self.W1) + self.b1)

    # feature mapping layer
    features = tf.nn.relu(conv2d(temporal, self.W11) + self.b11)

    # spatial convolution gives the raw (pre-sigmoid) output
    logits = conv2d_VALID(features, self.W2) + self.b2

    # compare in logit space instead of applying a sigmoid
    scores = logits[:, 0, 0, 0]
    cutoff = np.log(threshold / (1 - threshold))
    return scores > cutoff
def _make_training_graph(cls, waveform_length, n_neighbors, vars_dict):
    """Assemble the tensors used for training

    Returns
    -------
    x_tf: tf.tensor
        Input placeholder, shape [None, waveform_length, n_neighbors]

    y_tf: tf.tensor
        Labels placeholder, shape [None]

    o_layer: tf.tensor
        Output tensor (squeezed result of the spatial convolution)

    sigmoid: tf.tensor
        Sigmoid of the output tensor
    """
    # placeholders for the waveforms and their labels
    waveforms = tf.placeholder("float",
                               [None, waveform_length, n_neighbors])
    labels = tf.placeholder("float", [None])

    # the network receives the input with one extra trailing axis
    net_input = tf.expand_dims(waveforms, -1)
    vars_dict, layer11 = NeuralNetDetector._make_network(
        net_input, vars_dict, padding='VALID')

    spatial_weights = vars_dict['W2']
    spatial_bias = vars_dict['b2']

    # third layer: spatial convolution
    o_layer = tf.squeeze(conv2d_VALID(layer11, spatial_weights)
                         + spatial_bias)

    return waveforms, labels, o_layer, tf.sigmoid(o_layer)
def make_o_layer_tf_tensors(self, x_tf, channel_index):
    """
    Make a tensorflow tensor that outputs the network probability

    Also stores the first `n_neighbors` columns of `channel_index` in
    `self.channel_index`.

    Parameters
    -----------
    x_tf: tf.tensors (n_observations, n_channels)
        placeholder of recording for running tensorflow
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of
        neighboring channels (including itself)
        If any value is equal to n_channels, it is nothing but
        a space holder in a case that a channel has less than
        n_neigh neighboring channels

    Returns
    -------
    output_tf: tf tensor (n_observations, n_channels)
        sigmoid of the output layer
    """
    # network sizes (the first filter count is not needed here)
    _, n_features = self.filters_dict['filters']
    n_neigh = self.filters_dict['n_neighbors']

    # keep the truncated neighbor table on the instance
    self.channel_index = channel_index[:, :n_neigh]

    # temporal length of the input, known only at run time
    n_timesteps = tf.shape(x_tf)[0]

    # add a trailing and a leading axis before the convolutions
    cnn_input = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

    # two temporal layers
    temporal = tf.nn.relu(conv2d(cnn_input, self.W1) + self.b1)
    features = tf.nn.relu(conv2d(temporal, self.W11) + self.b11)

    # spatial layer: put channels first and append one all-zero entry so
    # that the space-holder index (n_channels) gathers zeros
    channels_first = tf.transpose(features, [2, 0, 1, 3])
    zero_entry = tf.zeros((1, 1, n_timesteps, n_features))
    padded = tf.concat((channels_first, zero_entry), axis=0)

    gathered = tf.gather(padded, self.channel_index)
    neighborhoods = tf.reshape(tf.transpose(gathered, [0, 2, 3, 1, 4]),
                               [-1, n_timesteps, n_neigh, n_features])
    spatial = conv2d_VALID(neighborhoods, self.W2) + self.b2

    # output layer, reduced from [1, temporal, spatial, 1] to 2-D
    o_layer = tf.transpose(spatial, [2, 1, 0, 3])[0, :, :, 0]

    return tf.sigmoid(o_layer)
def get_spikes(self, x_tf, T, nneigh, c_idx, temporal_window, th):
    """
    Build the tensor that indexes detected spikes in a recording

    A position is kept when the network output there is at least the
    temporal max-pooled output (window `temporal_window`) and also
    exceeds the logit of the detection threshold `th`.

    Parameters:
    -----------
    x_tf: tf tensor
        recording fed to the network; a trailing and a leading axis are
        added before the first convolution
        (presumably (temporal length, number of channels) - TODO confirm)
    T: int or tf scalar
        temporal length used to size the zero padding and the reshape
    nneigh: int
        number of neighboring channels per channel
    c_idx: np.array
        neighbor channel index used to gather per-channel neighborhoods
    temporal_window: int
        window size of the temporal max pooling
    th: float
        detection threshold on the output probability; it is converted
        with log(p / (1 - p))

    Returns:
    -----------
    local_max_idx: tf tensor
        result of tf.where over the 2-D (temporal, channel) mask, i.e.
        one row per detected position
    """
    # network sizes (the first filter count is not needed here)
    _, n_features = self.filters_dict['filters']

    # temporal layers
    cnn_input = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)
    temporal = tf.nn.relu(conv2d(cnn_input, self.W1) + self.b1)
    features = tf.nn.relu(conv2d(temporal, self.W11) + self.b11)

    # spatial layer: channels first, plus one all-zero entry appended
    channels_first = tf.transpose(features, [2, 0, 1, 3])
    zero_entry = tf.zeros((1, 1, T, n_features))
    padded = tf.concat((channels_first, zero_entry), axis=0)

    gathered = tf.gather(padded, c_idx)
    neighborhoods = tf.reshape(tf.transpose(gathered, [0, 2, 3, 1, 4]),
                               [-1, T, nneigh, n_features])
    spatial = conv2d_VALID(neighborhoods, self.W2) + self.b2

    o_layer = tf.transpose(spatial, [2, 1, 0, 3])

    # temporal local maxima that also cross the threshold (logit space)
    pooled = max_pool(o_layer, [1, temporal_window, 1, 1])
    scores = o_layer[0, :, :, 0]
    is_local_max = scores >= pooled[0, :, :, 0]
    above_threshold = scores > np.log(th / (1 - th))

    local_max_idx = tf.where(tf.logical_and(is_local_max,
                                            above_threshold))
    return local_max_idx
def _make_network(cls, input_tensor, filters_size, waveform_length,
                  n_neighbors):
    """Makes tensorflow network, from first layer to output layer

    Creates the weight and bias variables of the three layers and wires
    them to `input_tensor` (a trailing axis is added to it first).

    Returns
    -------
    o_layer: tf.tensor
        Output of the spatial convolution (no sigmoid applied)

    vars_dict: dict
        The created variables under the keys "W1", "W11", "W2", "b1",
        "b11" and "b2"
    """
    n_temporal, n_features = filters_size

    # create the variables layer by layer, weights before bias
    temporal_w = weight_variable([waveform_length, 1, 1, n_temporal])
    temporal_b = bias_variable([n_temporal])

    mapping_w = weight_variable([1, 1, n_temporal, n_features])
    mapping_b = bias_variable([n_features])

    spatial_w = weight_variable([1, n_neighbors, n_features, 1])
    spatial_b = bias_variable([1])

    # first layer: temporal feature
    net_input = tf.expand_dims(input_tensor, -1)
    temporal = tf.nn.relu(conv2d_VALID(net_input, temporal_w)
                          + temporal_b)

    # second layer: feature mapping
    features = tf.nn.relu(conv2d(temporal, mapping_w) + mapping_b)

    # third layer: spatial convolution
    o_layer = conv2d_VALID(features, spatial_w) + spatial_b

    vars_dict = dict(W1=temporal_w, W11=mapping_w, W2=spatial_w,
                     b1=temporal_b, b11=mapping_b, b2=spatial_b)

    return o_layer, vars_dict
def make_detection_tf_tensors(self, x_tf, channel_index, threshold):
    """
    Make a tensorflow tensor that outputs spike index

    Also stores the first `n_neighbors` columns of `channel_index` in
    `self.channel_index`.

    Parameters
    -----------
    x_tf: tf.tensors (n_observations, n_channels)
        placeholder of recording for running tensorflow
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of
        neighboring channels (including itself)
        If any value is equal to n_channels, it is nothing but
        a space holder in a case that a channel has less than
        n_neigh neighboring channels
    threshold: float
        threshold on a probability to determine location of spikes; it
        is converted with log(p / (1 - p))

    Returns
    -------
    spike_index_tf: tf tensor (n_spikes, 2)
        int32 tensorflow tensor that produces spike_index
    """
    # network sizes (the first filter count is not needed here)
    _, n_features = self.filters_dict['filters']
    n_neigh = self.filters_dict['n_neighbors']

    # keep the truncated neighbor table on the instance
    self.channel_index = channel_index[:, :n_neigh]

    # temporal length of the input, known only at run time
    n_timesteps = tf.shape(x_tf)[0]

    # add a trailing and a leading axis before the convolutions
    cnn_input = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

    # two temporal layers
    temporal = tf.nn.relu(conv2d(cnn_input, self.W1) + self.b1)
    features = tf.nn.relu(conv2d(temporal, self.W11) + self.b11)

    # spatial layer: put channels first and append one all-zero entry so
    # that the space-holder index (n_channels) gathers zeros
    channels_first = tf.transpose(features, [2, 0, 1, 3])
    zero_entry = tf.zeros((1, 1, n_timesteps, n_features))
    padded = tf.concat((channels_first, zero_entry), axis=0)

    gathered = tf.gather(padded, self.channel_index)
    neighborhoods = tf.reshape(tf.transpose(gathered, [0, 2, 3, 1, 4]),
                               [-1, n_timesteps, n_neigh, n_features])
    spatial = conv2d_VALID(neighborhoods, self.W2) + self.b2

    # output layer
    o_layer = tf.transpose(spatial, [2, 1, 0, 3])

    # temporal max over a window of 3, lowered by a small slack before
    # the >= comparison below
    pooled = max_pool(o_layer, [1, 3, 1, 1]) - 1e-8

    # spike index is local maximum crossing a threshold (logit space)
    scores = o_layer[0, :, :, 0]
    is_local_max = scores >= pooled[0, :, :, 0]
    above_threshold = scores > np.log(threshold / (1 - threshold))
    hits = tf.where(tf.logical_and(is_local_max, above_threshold))

    spike_index_tf = tf.cast(hits, 'int32')
    return spike_index_tf
def fit(self, x_train, y_train):
    """
    Trains the neural network detector for spike detection

    The trained variables are saved to `self.path_to_model` and the
    network parameters (filter sizes, waveform length, number of
    neighbors) to a file with the same path and a 'yaml' extension.

    Parameters:
    -----------
    x_train: np.array
        [number of training data, temporal length, number of channels]
        augmented training data consisting of
        isolated spikes, noise and misaligned spikes.
    y_train: np.array
        [number of training data] label for x_train. '1' denotes
        presence of an isolated spike and '0' denotes
        the presence of a noise data or misaligned spike.

    Raises:
    -------
    ValueError
        If the temporal length or the number of channels of `x_train`
        does not match `self.waveform_length` / `self.n_neighbors`.
    """
    ######################
    # Loading parameters #
    ######################

    logger = logging.getLogger(__name__)

    # get parameters
    n_data, waveform_length_train, n_neighbors_train = x_train.shape

    if self.waveform_length != waveform_length_train:
        raise ValueError('waveform length from network ({}) does not '
                         'match training data ({})'
                         .format(self.waveform_length,
                                 waveform_length_train))

    if self.n_neighbors != n_neighbors_train:
        # report the same attribute that was compared above
        raise ValueError('number of n_neighbors from network ({}) does '
                         'not match training data ({})'
                         .format(self.n_neighbors, n_neighbors_train))

    ####################
    # Building network #
    ####################

    # x and y input tensors (fixed batch size)
    x_tf = tf.placeholder("float", [self.n_batch, self.waveform_length,
                                    self.n_neighbors])
    y_tf = tf.placeholder("float", [self.n_batch])

    input_tf = tf.expand_dims(x_tf, -1)

    vars_dict, layer11 = (NeuralNetDetector
                          ._make_network(input_tf,
                                         self.waveform_length,
                                         self.filters_size,
                                         self.n_neighbors,
                                         padding='VALID'))

    W2 = vars_dict['W2']
    b2 = vars_dict['b2']

    # third layer: spatial convolution
    o_layer = tf.squeeze(conv2d_VALID(layer11, W2) + b2)

    ##########################
    # Optimization objective #
    ##########################

    # cross entropy, computed on the raw (pre-sigmoid) output
    per_sample_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=o_layer, labels=y_tf)
    cross_entropy = tf.reduce_mean(per_sample_loss)

    weights = tf.trainable_variables()

    # regularization term
    l2_regularizer = (tf.contrib.layers
                      .l2_regularizer(scale=self.l2_reg_scale))

    regularization = tf.contrib.layers.apply_regularization(
        l2_regularizer, weights)

    regularized_loss = cross_entropy + regularization

    # train step
    train_step = (tf.train.AdamOptimizer(self.train_step_size)
                  .minimize(regularized_loss))

    ############
    # Training #
    ############

    # saver
    saver = tf.train.Saver(vars_dict)
    logger.debug('Training detector network...')

    with tf.Session() as sess:

        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        pbar = trange(self.n_iter)

        for i in pbar:

            # sample n_batch observations from 0, ..., n_data
            idx_batch = np.random.choice(n_data, self.n_batch,
                                         replace=False)

            res = sess.run([train_step, regularized_loss],
                           feed_dict={x_tf: x_train[idx_batch],
                                      y_tf: y_train[idx_batch]})

            # refresh the displayed loss every 100 iterations
            if i % 100 == 0:
                pbar.set_description('Loss: %s' % res[1])

        logger.debug('Saving network: %s', self.path_to_model)
        saver.save(sess, self.path_to_model)

        # estimate tp and fp with a sample
        idx_batch = np.random.choice(n_data, self.n_batch, replace=False)

        output = sess.run(o_layer, feed_dict={x_tf: x_train[idx_batch]})
        y_test = y_train[idx_batch]

        # a positive raw output corresponds to a probability above 0.5
        tp = np.mean(output[y_test == 1] > 0)
        fp = np.mean(output[y_test == 0] > 0)

        logger.debug('Approximate training true positive rate: %s'
                     ', false positive rate: %s', tp, fp)

    path_to_params = change_extension(self.path_to_model, 'yaml')

    logger.debug('Saving network parameters: %s', path_to_params)

    save_detect_network_params(filters_size=self.filters_size,
                               waveform_length=self.waveform_length,
                               n_neighbors=self.n_neighbors,
                               output_path=path_to_params)
def _make_graph(cls, threshold, channel_index, waveform_length,
                filters_size, n_neigh):
    """Build tensorflow graph with input and two output layers

    Parameters
    -----------
    threshold: float
        threshold on a probability to determine location of spikes; it
        is converted with log(p / (1 - p))
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] is the index of neighboring channels (including
        itself) If any value is equal to n_channels, it is nothing but
        a placeholder in a case that a channel has less than n_neigh
        neighboring channels
    waveform_length: int
        passed to the network builder, the edge-spike removal and the
        waveform extraction
    filters_size: pair
        (K1, K2) filter counts; K2 sizes the zero padding and reshape
    n_neigh: int
        number of columns of `channel_index` used by the spatial layer

    Returns
    -------
    x_tf: tf.tensor
        placeholder of recording, declared with shape [None, None]
    spike_index_tf: tf tensor (n_spikes, 2)
        tensorflow tensor that produces spike_index
    probability_tf: tf.tensor
        sigmoid of the output layer
    waveform_tf: tf.tensor
        waveform output built from the spike index output
    vars_dict: dict
        variables returned by the network builder
    """
    ######################
    # Loading parameters #
    ######################

    # TODO: need to ask why we are sending different channel indexes
    # save neighbor channel index
    small_channel_index = channel_index[:, :n_neigh]

    # placeholder for input recording
    x_tf = tf.placeholder("float", [None, None])

    # Temporal shape of input
    T = tf.shape(x_tf)[0]

    ####################
    # Building network #
    ####################

    # input tensor into CNN - add one dimension at the beginning and
    # at the end
    x_cnn_tf = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

    vars_dict, layer11 = cls._make_network(x_cnn_tf,
                                           waveform_length,
                                           filters_size,
                                           n_neigh,
                                           padding='SAME')
    W2 = vars_dict['W2']
    b2 = vars_dict['b2']

    _, K2 = filters_size

    # first spatial layer: channels first, plus one all-zero entry so that
    # the placeholder index (n_channels) gathers zeros
    zero_added_layer11 = tf.concat((tf.transpose(layer11, [2, 0, 1, 3]),
                                    tf.zeros((1, 1, T, K2))), axis=0)

    temp = tf.transpose(tf.gather(zero_added_layer11,
                                  small_channel_index),
                        [0, 2, 3, 1, 4])

    temp2 = conv2d_VALID(tf.reshape(temp, [-1, T, n_neigh, K2]),
                         W2) + b2

    # o_layer: [1, temporal, spatial, 1]
    o_layer = tf.transpose(temp2, [2, 1, 0, 3])

    ################################
    # Output layer transformations #
    ################################

    # Drop only the leading and trailing singleton axes. A bare
    # tf.squeeze would also drop the temporal or channel axis whenever
    # it has size 1, which changes the rank of the spike index output.
    o_layer_val = o_layer[0, :, :, 0]

    # probability output - just sigmoid of output layer
    probability_tf = tf.sigmoid(o_layer_val)

    # spike index output (local maximum crossing a threshold)
    # NOTE(review): assumes max_pool keeps the shape of o_layer, as the
    # element-wise comparison below requires - confirm against max_pool
    temporal_max = (max_pool(o_layer, [1, 3, 1, 1]) - 1e-8)[0, :, :, 0]

    higher_than_max_pool = o_layer_val >= temporal_max

    higher_than_threshold = (o_layer_val >
                             np.log(threshold / (1 - threshold)))

    both_higher = tf.logical_and(higher_than_max_pool,
                                 higher_than_threshold)

    index_all = tf.cast(tf.where(both_higher), 'int32')

    spike_index_tf = cls._remove_edge_spikes(x_tf, index_all,
                                             waveform_length)

    # waveform output from spike index output
    waveform_tf = cls._make_waveform_tf(x_tf, spike_index_tf,
                                        channel_index, waveform_length)

    return x_tf, spike_index_tf, probability_tf, waveform_tf, vars_dict
def train_triage(x_train, y_train, n_filters, n_iter, n_batch,
                 l2_reg_scale, train_step_size, nn_name):
    """
    Trains the triage network

    Builds the three-layer network, minimizes the L2-regularized sigmoid
    cross entropy with Adam for `n_iter` iterations on random batches of
    size `n_batch`, saves the variables to `nn_name` and prints an
    approximate true/false positive rate computed on one more batch.

    Parameters:
    -----------
    x_train: np.array
        [number of data, temporal length, number of channels]
        training data for the triage network.
    y_train: np.array
        [number of data] training label for the triage network.
    n_filters: pair
        filter counts of the first and second layer.
    n_iter: int
        number of training iterations.
    n_batch: int
        number of observations sampled (without replacement) per batch.
    l2_reg_scale: float
        scale of the L2 regularization term.
    train_step_size: float
        step size given to the Adam optimizer.
    nn_name: string
        name of the .ckpt to be saved.
    """
    # sizes
    n_data, n_timesteps, n_channels = x_train.shape
    n_temporal, n_features = n_filters

    # input and label placeholders (fixed batch size)
    x_tf = tf.placeholder("float", [n_batch, n_timesteps, n_channels])
    y_tf = tf.placeholder("float", [n_batch])

    # first layer: temporal feature
    W1 = weight_variable([n_timesteps, 1, 1, n_temporal])
    b1 = bias_variable([n_temporal])
    net_input = tf.expand_dims(x_tf, -1)
    temporal = tf.nn.relu(conv2d_VALID(net_input, W1) + b1)

    # second layer: feature mapping
    W11 = weight_variable([1, 1, n_temporal, n_features])
    b11 = bias_variable([n_features])
    features = tf.nn.relu(conv2d(temporal, W11) + b11)

    # third layer: spatial convolution
    W2 = weight_variable([1, n_channels, n_features, 1])
    b2 = bias_variable([1])
    o_layer = tf.squeeze(conv2d_VALID(features, W2) + b2)

    # cross entropy on the raw (pre-sigmoid) output
    per_sample_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=o_layer, labels=y_tf)
    cross_entropy = tf.reduce_mean(per_sample_loss)

    # regularization term over all trainable variables
    weights = tf.trainable_variables()
    l2_regularizer = tf.contrib.layers.l2_regularizer(scale=l2_reg_scale)
    regularization_penalty = tf.contrib.layers.apply_regularization(
        l2_regularizer, weights)
    regularized_loss = cross_entropy + regularization_penalty

    # train step
    optimizer = tf.train.AdamOptimizer(train_step_size)
    train_step = optimizer.minimize(regularized_loss)

    # saver
    saved_variables = dict(W1=W1, W11=W11, W2=W2, b1=b1, b11=b11, b2=b2)
    saver = tf.train.Saver(saved_variables)

    ############
    # training #
    ############

    bar = progressbar.ProgressBar(maxval=n_iter)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for step in range(n_iter):
            idx_batch = np.random.choice(n_data, n_batch, replace=False)
            batch = {x_tf: x_train[idx_batch], y_tf: y_train[idx_batch]}
            sess.run(train_step, feed_dict=batch)
            bar.update(step + 1)

        saver.save(sess, nn_name)

        # rough quality estimate on one more random batch
        idx_batch = np.random.choice(n_data, n_batch, replace=False)
        output = sess.run(o_layer, feed_dict={x_tf: x_train[idx_batch]})
        y_test = y_train[idx_batch]

        # a positive raw output corresponds to a probability above 0.5
        tp = np.mean(output[y_test == 1] > 0)
        fp = np.mean(output[y_test == 0] > 0)

        summary = ('Approximate training true positive rate: ' + str(tp)
                   + ', false positive rate: ' + str(fp))
        print(summary)
    bar.finish()