def test_can_make_noise(path_to_tests, path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # spike_train is assumed to be available at module level (test fixture)
    n_spikes, _ = spike_train.shape

    # add a weight of one to every spike
    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    templates = crop_and_align_templates(templates_uncropped,
                                         CONFIG.spike_size,
                                         CONFIG.neigh_channels,
                                         CONFIG.geom)

    spatial_SIG, temporal_SIG = noise_cov(path_to_standarized_data,
                                          CONFIG.neigh_channels,
                                          CONFIG.geom,
                                          templates.shape[1])

    x_clean = make_clean(templates, min_amp=2, max_amp=10, nk=100)

    make_noise(x_clean, noise_ratio=10, templates=templates,
               spatial_SIG=spatial_SIG, temporal_SIG=temporal_SIG)
def get_noise_covariance(reader, save_dir, CONFIG, chunk=None):

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    fname_spatial_sig = os.path.join(save_dir, 'spatial_sig.npy')
    fname_temporal_sig = os.path.join(save_dir, 'temporal_sig.npy')

    # covariance estimates are cached; return the saved files if they exist
    if os.path.exists(fname_spatial_sig) and os.path.exists(
            fname_temporal_sig):
        return fname_spatial_sig, fname_temporal_sig

    # only need subset of channels: pick the channel with the most
    # neighbors and keep its neighborhood
    n_neigh_channels = np.sum(CONFIG.neigh_channels, axis=1)
    c_chosen = np.where(n_neigh_channels == np.max(n_neigh_channels))[0][0]
    channels = np.where(CONFIG.neigh_channels[c_chosen])[0]

    if chunk is None:
        rec = reader.read_data(reader.start, reader.end, channels)
    else:
        rec = reader.read_data(chunk[0], chunk[1], channels)

    spatial_SIG, temporal_SIG = noise_cov(rec,
                                          temporal_size=reader.spike_size,
                                          window_size=reader.spike_size,
                                          sample_size=1000,
                                          threshold=3.0)

    np.save(fname_spatial_sig, spatial_SIG)
    np.save(fname_temporal_sig, temporal_SIG)

    return fname_spatial_sig, fname_temporal_sig
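# A minimal, self-contained sketch (illustration only, not library code) of
# the channel-selection step inside get_noise_covariance: pick the channel
# with the most neighbors and keep only its neighborhood, shown here on a
# toy 4-channel adjacency matrix. Assumes numpy is imported as np at module
# level, as in the rest of this file.
def example_channel_subset():
    neigh_channels = np.array([[1, 1, 0, 0],
                               [1, 1, 1, 0],
                               [0, 1, 1, 1],
                               [0, 0, 1, 1]], dtype=bool)

    n_neigh_channels = np.sum(neigh_channels, axis=1)
    c_chosen = np.where(n_neigh_channels == np.max(n_neigh_channels))[0][0]
    channels = np.where(neigh_channels[c_chosen])[0]

    # here: c_chosen == 1, channels == array([0, 1, 2])
    return c_chosen, channels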
def test_can_compute_noise_cov(path_to_tests, path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # spike_train is assumed to be available at module level (test fixture)
    n_spikes, _ = spike_train.shape

    # add a weight of one to every spike
    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    noise_cov(path_to_standarized_data,
              CONFIG.neigh_channels,
              CONFIG.geom,
              templates_uncropped.shape[1])
def test_can_compute_noise_cov(path_to_tests, path_to_standarized_data):
    recordings = RecordingsReader(path_to_standarized_data,
                                  loader='array')._data

    spatial_SIG, temporal_SIG = noise_cov(recordings,
                                          temporal_size=10,
                                          sample_size=100,
                                          threshold=3.0,
                                          window_size=10)
def test_can_estimate_temporal_and_spatial_sig(path_to_standarized_data):
    recordings = RecordingsReader(path_to_standarized_data,
                                  loader='array')._data

    (spatial_SIG,
     temporal_SIG) = noise.noise_cov(recordings,
                                     temporal_size=40,
                                     sample_size=1000,
                                     threshold=3.0,
                                     window_size=10)

    # check no nans
    assert (~np.isnan(spatial_SIG)).all()
    assert (~np.isnan(temporal_SIG)).all()
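# Self-contained sketch (illustration only, not library code) of how the
# spatial_SIG / temporal_SIG factors returned by noise_cov are consumed:
# i.i.d. Gaussian samples are colored in time by right-multiplying each
# channel with temporal_SIG, then colored across channels by
# right-multiplying with spatial_SIG. This mirrors the noise blocks inside
# the make_training_data variants below. Assumes numpy is imported as np.
def example_colored_noise(n_samples, temporal_SIG, spatial_SIG):
    """Draw noise with the estimated temporal/spatial covariance structure

    temporal_SIG: (T, T), spatial_SIG: (C, C) -> returns (n_samples, T, C)
    """
    T, C = temporal_SIG.shape[0], spatial_SIG.shape[0]

    # start from white noise
    noise = np.random.normal(size=(n_samples, T, C))

    # color each channel in time
    for c in range(C):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    # color across channels
    reshaped_noise = np.reshape(noise, (-1, C))
    return np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                      (n_samples, T, C))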
def make_training_data(CONFIG, spike_train, chosen_templates, min_amp,
                       nspikes, data_folder):
    """Make training data for the detector, triage and autoencoder networks

    Parameters
    ----------

    Returns
    -------
    """
    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError('Standarized data does not exist in: {}, this is '
                         'needed to generate training data, run the '
                         'preprocessor first to generate it'
                         .format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates
    templates, _ = get_templates(spike_train, path_to_data, CONFIG.spikeSize)
    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)

    if templates.shape[0] == 0:
        raise ValueError("Couldn't find any good templates...")

    logger.info('Good looking templates of shape: {}'
                .format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spikeSize,
                               CONFIG.neighChannels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data,
                                          PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_format'],
                                          CONFIG.neighChannels,
                                          CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spikeSize
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################

    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))

    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk
                      + min_amp)[:, np.newaxis, np.newaxis]
        # normalize to unit amplitude, then scale over the amplitude grid
        x_clean[k * nk:(k + 1) * nk] = ((tt / amp_now)[np.newaxis, :, :]
                                        * amps_range)

    #############
    # collision #
    #############

    x_collision = np.zeros(x_clean.shape)
    max_shift = 2 * R

    temporal_shifts = np.random.randint(max_shift * 2,
                                        size=nk * K) - max_shift
    temporal_shifts[temporal_shifts < 0] = (
        temporal_shifts[temporal_shifts < 0] - 5)
    temporal_shifts[temporal_shifts >= 0] = (
        temporal_shifts[temporal_shifts >= 0] + 6)

    amp_per_data = np.max(x_clean[:, :, 0], axis=1)

    for j in range(nk * K):
        shift = temporal_shifts[j]

        x_collision[j] = np.copy(x_clean[j])
        # pick a second spike with comparable amplitude to collide with
        idx_candidate = np.where(amp_per_data > amp_per_data[j] * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0],
                                                    size=1)[0]]
        x_clean2 = np.copy(x_clean[idx_match]
                           [:, np.random.choice(nneigh, nneigh,
                                                replace=False)])

        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += \
                x_clean2[shift:]
        elif shift < 0:
            x_collision[j, (-shift):] += x_clean2[:(x_collision.shape[1]
                                                    + shift)]
        else:
            x_collision[j] += x_clean2

    #####################
    # misaligned spikes #
    #####################

    x_misaligned = np.zeros(x_clean.shape)

    temporal_shifts = np.random.randint(max_shift * 2,
                                        size=nk * K) - max_shift
    temporal_shifts[temporal_shifts < 0] = (
        temporal_shifts[temporal_shifts < 0] - 5)
    temporal_shifts[temporal_shifts >= 0] = (
        temporal_shifts[temporal_shifts >= 0] + 6)

    for j in range(nk * K):
        shift = temporal_shifts[j]

        x_clean2 = np.copy(x_clean[j][:, np.random.choice(nneigh, nneigh,
                                                          replace=False)])

        if shift > 0:
            x_misaligned[j, :(x_collision.shape[1] - shift)] += \
                x_clean2[shift:]
        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(x_collision.shape[1]
                                                     + shift)]
        else:
            x_misaligned[j] += x_clean2

    #########
    # noise #
    #########

    # get noise: color white noise in time, then across channels
    noise = np.random.normal(size=x_clean.shape)

    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [x_clean.shape[0], x_clean.shape[1],
                        x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_clean.shape[0]))
    y_misaligned = np.zeros((x_clean.shape[0]))
    y_noise = np.zeros((x_clean.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    x = np.concatenate(
        (x_clean + noise,
         x_collision + noise[np.random.permutation(noise.shape[0])],
         x_misaligned + noise[np.random.permutation(noise.shape[0])],
         noise))
    x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))

    # get training set for triage
    x = np.concatenate(
        (x_clean + noise,
         x_collision + noise[np.random.permutation(noise.shape[0])]))
    x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_triage = np.concatenate((y_clean, np.zeros((x_clean.shape[0]))))

    # get training set for autoencoder
    ae_shift_max = 1
    temporal_shifts_ae = np.random.randint(
        ae_shift_max * 2 + 1, size=x_clean.shape[0]) - ae_shift_max

    y_ae = np.zeros((x_clean.shape[0], 2 * R + 1))
    x_ae = np.zeros((x_clean.shape[0], 2 * R + 1))

    for j in range(x_ae.shape[0]):
        y_ae[j] = x_clean[j,
                          (mid_point - R + temporal_shifts_ae[j]):
                          (mid_point + R + 1 + temporal_shifts_ae[j]),
                          0]
        x_ae[j] = (x_clean[j,
                           (mid_point - R + temporal_shifts_ae[j]):
                           (mid_point + R + 1 + temporal_shifts_ae[j]),
                           0]
                   + noise[j,
                           (mid_point - R + temporal_shifts_ae[j]):
                           (mid_point + R + 1 + temporal_shifts_ae[j]),
                           0])

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
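# Self-contained sketch (illustration only, not library code) of the
# amplitude augmentation used for the clean spikes above, and again in the
# later variants as make_clean / util.make_from_templates: each template is
# normalized to unit maximum absolute amplitude, then scaled over a linear
# grid between min_amp and max_amp, yielding nk augmented copies per
# template. Assumes numpy is imported as np.
def example_amplitude_augmentation(templates, min_amp, max_amp, nk):
    """templates: (K, T, C) -> (nk * K, T, C) amplitude-scaled spikes"""
    K, T, C = templates.shape
    x_clean = np.zeros((nk * K, T, C))

    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk
                      + min_amp)[:, np.newaxis, np.newaxis]
        x_clean[k * nk:(k + 1) * nk] = ((tt / amp_now)[np.newaxis, :, :]
                                        * amps_range)

    return x_clean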
def training_data(CONFIG, templates_uncropped, min_amp, max_amp,
                  n_isolated_spikes, path_to_standarized, noise_ratio=10,
                  collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                  multi_channel=True, return_metadata=False):
    """Makes training sets for detector, triage and autoencoder

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    max_amp: float
        Maximum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    n_isolated_spikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    path_to_standarized: str
        Folder storing the standarized data (if not exist, run preprocess
        to automatically generate)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated
        spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net.
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect
    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of
        channels] Training data for the triage net.
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise,
              Collided spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise

    * Triage training data
        * Multi channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise
    """
    # FIXME: should we add collided spikes with the first spike non-centered
    # to the detection training set?
    logger = logging.getLogger(__name__)

    # STEP 1: Load recordings data and select one channel at random (with
    # the right number of neighbors), then swap the channels so the first
    # one corresponds to the selected channel, then the nearest neighbor,
    # then the second nearest and so on... this is only used for estimating
    # noise structure

    # ##### FIXME: this needs to be removed, the user should already
    # pass data with the desired channels

    rec = RecordingsReader(path_to_standarized, loader='array')

    channel_n_neighbors = np.sum(CONFIG.neigh_channels, 0)
    max_neighbors = np.max(channel_n_neighbors)
    channels_with_max_neighbors = np.where(channel_n_neighbors
                                           == max_neighbors)[0]
    logger.debug('The following channels have %i neighbors: %s',
                 max_neighbors, channels_with_max_neighbors)

    # reference channel: channel with max number of neighbors
    channel_selected = np.random.choice(channels_with_max_neighbors)

    logger.debug('Selected channel %i', channel_selected)

    # neighbors for the reference channel
    channel_neighbors = np.where(CONFIG.neigh_channels[channel_selected])[0]

    # ordered neighbors for reference channel
    channel_idx, _ = order_channels_by_distance(channel_selected,
                                                channel_neighbors,
                                                CONFIG.geom)
    # read the selected channels
    rec = rec[:, channel_idx]

    # ##### FIXME: end of section to be removed

    # STEP 2: load templates

    processor = TemplatesProcessor(templates_uncropped)

    # swap channels, first channel is main channel, then nearest neighbor
    # and so on, only keep neigh_channels
    templates = (processor.crop_spatially(CONFIG.neigh_channels, CONFIG.geom)
                 .values)

    # TODO: remove, this data can be obtained from other variables
    K, _, n_channels = templates_uncropped.shape

    # make training data set
    R = CONFIG.spike_size

    logger.debug('Output will be of size %s', 2 * R + 1)

    # make clean augmented spikes
    nk = int(np.ceil(n_isolated_spikes / K))
    max_shift = 2 * R

    # make spikes from templates
    x_templates = util.make_from_templates(templates, min_amp, max_amp, nk)

    # make collided spikes - max shift is set to R since 2 * R + 1 will be
    # the final dimension for the spikes. one of the spikes is kept with the
    # main channel, the other one is shifted and channels are changed
    x_collision = util.make_collided(x_templates, collision_ratio,
                                     multi_channel, max_shift=R, min_shift=5,
                                     return_metadata=return_metadata)

    # make temporally misaligned spikes
    x_temporally_misaligned = util.make_temporally_misaligned(
        x_templates, misalign_ratio,
        multi_channel=multi_channel,
        max_shift=max_shift)

    # now spatially misalign those
    x_misaligned = util.make_spatially_misaligned(
        x_temporally_misaligned, n_per_spike=misalign_ratio2)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(rec,
                                          temporal_size=templates.shape[1],
                                          window_size=templates.shape[1],
                                          sample_size=1000,
                                          threshold=3.0)

    # make noise
    n_noise = int(x_templates.shape[0] * noise_ratio)
    noise = util.make_noise(n_noise, spatial_SIG, temporal_SIG)

    # make labels
    y_clean_1 = np.ones((x_templates.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))
    y_misaligned_0 = np.zeros((x_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    mid_point = int((x_templates.shape[1] - 1) / 2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    # TODO: replace _make_noisy with a new function
    x_templates_noisy = util._make_noisy(x_templates, noise)
    x_collision_noisy = util._make_noisy(x_collision, noise)
    x_misaligned_noisy = util._make_noisy(x_misaligned, noise)

    #############
    # Detection #
    #############

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,
                              x_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        y_detect = np.concatenate((y_clean_1, y_collision_1,
                                   y_misaligned_0, y_noise_0))
    else:
        x = yarr.concatenate((x_templates_noisy, x_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = yarr.concatenate((y_clean_1, y_misaligned_0, y_noise_0))

    ##########
    # Triage #
    ##########

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy))
        x_triage = x[:, MID_POINT_IDX, :]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))
    else:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,))
        x_triage = x[:, MID_POINT_IDX, 0]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # # TODO: need to abstract this part of the code, create a separate
    # # function and document it
    # neighbors_ae = np.ones((n_channels, n_channels), 'int32')

    # templates_ae = crop_and_align_templates(templates_uncropped,
    #                                         CONFIG.spike_size,
    #                                         neighbors_ae,
    #                                         CONFIG.geom)

    # tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    # tt = tt[:, np.ptp(tt, axis=0) > 2]
    # max_amp = np.max(np.ptp(tt, axis=0))

    # y_ae = np.zeros((nk * tt.shape[1], tt.shape[0]))

    # for k in range(tt.shape[1]):
    #     amp_now = np.ptp(tt[:, k])
    #     amps_range = (np.arange(nk) * (max_amp - min_amp)
    #                   / nk + min_amp)[:, np.newaxis, np.newaxis]

    #     y_ae[k * nk:(k + 1) * nk] = ((tt[:, k] / amp_now)[np.newaxis, :]
    #                                  * amps_range[:, :, 0])

    # noise_ae = np.random.normal(size=y_ae.shape)
    # noise_ae = np.matmul(noise_ae, temporal_SIG)

    # x_ae = y_ae + noise_ae
    # x_ae = x_ae[:, MID_POINT_IDX]
    # y_ae = y_ae[:, MID_POINT_IDX]

    x_ae = None
    y_ae = None

    # FIXME: y_ae is no longer used, autoencoder was replaced by PCA
    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
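# Sketch (illustration only, not library code) of the temporal-shift
# superposition that util.make_collided performs, written out explicitly:
# a second spike is added to the first at a non-zero temporal offset, and
# whatever falls outside the window is clipped. This mirrors the explicit
# shift loops in the make_training_data variants elsewhere in this file.
# Assumes numpy is imported as np.
def example_add_shifted(base, other, shift):
    """Superpose `other` onto `base` (both (T, C)) at temporal `shift`."""
    out = np.copy(base)
    T = out.shape[0]

    if shift > 0:
        out[:T - shift] += other[shift:]
    elif shift < 0:
        out[-shift:] += other[:T + shift]
    else:
        out += other

    return out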
def make_training_data(CONFIG, spike_train, chosen_templates, min_amp,
                       nspikes, data_folder, noise_ratio=10,
                       collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                       multi=True):
    """Make training data for the detector, triage and autoencoder networks

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file
    spike_train: numpy.ndarray [number of spikes, 2]
        Ground truth for training. First column is the spike time, second
        column is the spike id
    chosen_templates: list
        List of chosen templates' id's
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    nspikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    data_folder: str
        Folder storing the standarized data (if not exist, run preprocess
        to automatically generate)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated
        spikes
    multi: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net.
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect
    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of
        channels] Training data for the triage net.
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae
    """
    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError('Standarized data does not exist in: {}, this is '
                         'needed to generate training data, run the '
                         'preprocessor first to generate it'
                         .format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates, adding a weight of one to every spike
    templates, _ = get_templates(
        np.hstack((spike_train,
                   np.ones((spike_train.shape[0], 1), 'int32'))),
        path_to_data, CONFIG.resources.max_memory, 4 * CONFIG.spike_size)

    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)
    templates_uncropped = np.copy(templates)

    if templates.shape[0] == 0:
        raise ValueError("Couldn't find any good templates...")

    logger.info('Good looking templates of shape: {}'
                .format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spike_size,
                               CONFIG.neigh_channels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data,
                                          PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_order'],
                                          CONFIG.neigh_channels,
                                          CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spike_size
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################

    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))

    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk
                      + min_amp)[:, np.newaxis, np.newaxis]
        x_clean[k * nk:(k + 1) * nk] = ((tt / amp_now)[np.newaxis, :, :]
                                        * amps_range)

    #############
    # collision #
    #############

    x_collision = np.zeros((x_clean.shape[0] * int(collision_ratio),
                            templates.shape[1], templates.shape[2]))
    max_shift = 2 * R

    temporal_shifts = np.random.randint(
        max_shift * 2, size=x_collision.shape[0]) - max_shift
    temporal_shifts[temporal_shifts < 0] = (
        temporal_shifts[temporal_shifts < 0] - 5)
    temporal_shifts[temporal_shifts >= 0] = (
        temporal_shifts[temporal_shifts >= 0] + 6)

    amp_per_data = np.max(x_clean[:, :, 0], axis=1)

    for j in range(x_collision.shape[0]):
        shift = temporal_shifts[j]

        x_collision[j] = np.copy(x_clean[np.random.choice(
            x_clean.shape[0], 1, replace=True)])
        # pick a second spike with comparable amplitude to collide with
        idx_candidate = np.where(
            amp_per_data > np.max(x_collision[j, :, 0]) * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0],
                                                    size=1)[0]]

        if multi:
            x_clean2 = np.copy(x_clean[idx_match]
                               [:, np.random.choice(nneigh, nneigh,
                                                    replace=False)])
        else:
            x_clean2 = np.copy(x_clean[idx_match])

        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += \
                x_clean2[shift:]
        elif shift < 0:
            x_collision[j, (-shift):] += x_clean2[:(x_collision.shape[1]
                                                    + shift)]
        else:
            x_collision[j] += x_clean2

    ##############################################
    # temporally and spatially misaligned spikes #
    ##############################################

    x_misaligned = np.zeros((x_clean.shape[0] * int(misalign_ratio),
                             templates.shape[1], templates.shape[2]))

    temporal_shifts = np.random.randint(
        max_shift * 2, size=x_misaligned.shape[0]) - max_shift
    temporal_shifts[temporal_shifts < 0] = (
        temporal_shifts[temporal_shifts < 0] - 5)
    temporal_shifts[temporal_shifts >= 0] = (
        temporal_shifts[temporal_shifts >= 0] + 6)

    for j in range(x_misaligned.shape[0]):
        shift = temporal_shifts[j]

        if multi:
            x_clean2 = np.copy(
                x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)]
                [:, :, np.random.choice(nneigh, nneigh, replace=False)])
            x_clean2 = np.squeeze(x_clean2)
        else:
            x_clean2 = np.copy(x_clean[np.random.choice(
                x_clean.shape[0], 1, replace=True)])
            x_clean2 = np.squeeze(x_clean2)

        if shift > 0:
            x_misaligned[j, :(x_misaligned.shape[1] - shift)] += \
                x_clean2[shift:]
        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(x_misaligned.shape[1]
                                                     + shift)]
        else:
            x_misaligned[j] += x_clean2

    ###############################
    # spatially misaligned spikes #
    ###############################

    if multi:
        x_misaligned2 = np.zeros((x_clean.shape[0] * int(misalign_ratio2),
                                  templates.shape[1], templates.shape[2]))
        for j in range(x_misaligned2.shape[0]):
            x_misaligned2[j] = np.copy(
                x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)]
                [:, :, np.random.choice(nneigh, nneigh, replace=False)])

    #########
    # noise #
    #########

    # get noise: color white noise in time, then across channels
    noise = np.random.normal(size=[x_clean.shape[0] * int(noise_ratio),
                                   templates.shape[1], templates.shape[2]])

    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [noise.shape[0], x_clean.shape[1], x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_collision.shape[0]))
    y_misaligned = np.zeros((x_misaligned.shape[0]))
    if multi:
        y_misaligned2 = np.zeros((x_misaligned2.shape[0]))
    y_noise = np.zeros((noise.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    if multi:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_collision + noise[np.random.choice(
                 noise.shape[0], x_collision.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))
        x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))
    else:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))
        x_detect = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_detect = np.concatenate((y_clean, y_misaligned, y_noise))

    # get training set for triage
    if multi:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_collision + noise[np.random.choice(
                 noise.shape[0], x_collision.shape[0], replace=False)],
             x_misaligned2 + noise[np.random.choice(
                 noise.shape[0], x_misaligned2.shape[0], replace=False)],
             ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_triage = np.concatenate((y_clean,
                                   np.zeros((x_collision.shape[0])),
                                   y_misaligned2))
    else:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_collision + noise[np.random.choice(
                 noise.shape[0], x_collision.shape[0], replace=False)],
             ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_triage = np.concatenate((y_clean,
                                   np.zeros((x_collision.shape[0]))))

    ###############
    # Autoencoder #
    ###############

    n_channels = templates_uncropped.shape[2]
    templates_ae = crop_templates(templates_uncropped,
                                  CONFIG.spike_size,
                                  np.ones((n_channels, n_channels), 'int32'),
                                  CONFIG.geom)

    tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    tt = tt[:, np.ptp(tt, axis=0) > 2]
    max_amp = np.max(np.ptp(tt, axis=0))

    y_ae = np.zeros((nk * tt.shape[1], tt.shape[0]))

    for k in range(tt.shape[1]):
        amp_now = np.ptp(tt[:, k])
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk
                      + min_amp)[:, np.newaxis, np.newaxis]
        y_ae[k * nk:(k + 1) * nk] = ((tt[:, k] / amp_now)[np.newaxis, :]
                                     * amps_range[:, :, 0])

    noise = np.random.normal(size=y_ae.shape)
    noise = np.matmul(noise, temporal_SIG)

    x_ae = y_ae + noise
    x_ae = x_ae[:, (mid_point - R):(mid_point + R + 1)]
    y_ae = y_ae[:, (mid_point - R):(mid_point + R + 1)]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
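# Small sketch (illustration only, not library code) of the final cropping
# step above: training windows of length 2 * R + 1 are cut around the
# temporal midpoint of the augmented waveforms, which is why the augmented
# spikes are built on a wider window than the networks consume. Assumes
# numpy is imported as np.
def example_crop_mid_point(x, R):
    """x: (N, T, C) -> (N, 2 * R + 1, C) centered at the temporal midpoint"""
    mid_point = int((x.shape[1] - 1) / 2)
    return x[:, (mid_point - R):(mid_point + R + 1), :]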
def make_training_data(CONFIG, spike_train, chosen_templates_indexes,
                       min_amp, nspikes, data_folder, noise_ratio=10,
                       collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                       multi_channel=True):
    """Makes training sets for detector, triage and autoencoder

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file
    spike_train: numpy.ndarray [number of spikes, 2]
        Ground truth for training. First column is the spike time, second
        column is the spike id
    chosen_templates_indexes: list
        List of chosen templates' id's
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    nspikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    data_folder: str
        Folder storing the standarized data (if not exist, run preprocess
        to automatically generate)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated
        spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net.
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect
    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of
        channels] Training data for the triage net.
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise,
              Collided spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise

    * Triage training data
        * Multi channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise,
              spatially misaligned spikes + noise
    """
    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'preprocess', 'standarized.bin')

    n_spikes, _ = spike_train.shape

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError('Standarized data does not exist in: {}, this is '
                         'needed to generate training data, run the '
                         'preprocessor first to generate it'
                         .format(path_to_data))

    logger.info('Getting templates...')

    # add a weight of one to every spike
    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    # get templates
    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    K, _, n_channels = templates_uncropped.shape

    logger.info('Got templates ndarray of shape: {}'
                .format(templates_uncropped.shape))

    # choose good templates (user selected and amplitude above threshold)
    # TODO: maybe the minimum_amplitude parameter should be selected by the
    # user
    templates_uncropped = choose_templates(templates_uncropped,
                                           chosen_templates_indexes,
                                           minimum_amplitude=4)

    if templates_uncropped.shape[0] == 0:
        raise ValueError("Couldn't find any good templates...")

    logger.info('Good looking templates of shape: {}'
                .format(templates_uncropped.shape))

    templates = crop_and_align_templates(templates_uncropped,
                                         CONFIG.spike_size,
                                         CONFIG.neigh_channels,
                                         CONFIG.geom)

    # make training data set
    R = CONFIG.spike_size
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]
    max_shift = 2 * R

    # make clean spikes
    x_clean = make_clean(templates, min_amp, max_amp, nk)

    # make collided spikes
    x_collision = make_collided(x_clean, collision_ratio, templates, R,
                                multi_channel, nneigh)

    # make misaligned spikes
    (x_temporally_misaligned,
     x_spatially_misaligned) = make_misaligned(x_clean, templates, max_shift,
                                               misalign_ratio,
                                               misalign_ratio2,
                                               multi_channel, nneigh)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data,
                                          CONFIG.neigh_channels,
                                          CONFIG.geom,
                                          templates.shape[1])

    # make noise
    noise = make_noise(x_clean, noise_ratio, templates, spatial_SIG,
                       temporal_SIG)

    # make labels
    y_clean_1 = np.ones((x_clean.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))
    y_misaligned_0 = np.zeros((x_temporally_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    if multi_channel:
        y_misaligned2_0 = np.zeros((x_spatially_misaligned.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    x_clean_noisy = make_noisy(x_clean, noise)
    x_collision_noisy = make_noisy(x_collision, noise)
    x_temporally_misaligned_noisy = make_noisy(x_temporally_misaligned,
                                               noise)
    x_spatially_misaligned_noisy = make_noisy(x_spatially_misaligned, noise)

    #############
    # Detection #
    #############

    if multi_channel:
        x = np.concatenate((x_clean_noisy, x_collision_noisy,
                            x_temporally_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        y_detect = np.concatenate((y_clean_1, y_collision_1,
                                   y_misaligned_0, y_noise_0))
    else:
        x = np.concatenate((x_clean_noisy, x_temporally_misaligned_noisy,
                            noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = np.concatenate((y_clean_1, y_misaligned_0, y_noise_0))

    ##########
    # Triage #
    ##########

    if multi_channel:
        x = np.concatenate((x_clean_noisy, x_collision_noisy,
                            x_spatially_misaligned_noisy))
        x_triage = x[:, MID_POINT_IDX, :]

        y_triage = np.concatenate((y_clean_1, y_collision_0,
                                   y_misaligned2_0))
    else:
        x = np.concatenate((x_clean_noisy, x_collision_noisy,))
        x_triage = x[:, MID_POINT_IDX, 0]

        y_triage = np.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # TODO: need to abstract this part of the code, create a separate
    # function and document it
    neighbors_ae = np.ones((n_channels, n_channels), 'int32')

    templates_ae = crop_and_align_templates(templates_uncropped,
                                            CONFIG.spike_size,
                                            neighbors_ae,
                                            CONFIG.geom)

    tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    tt = tt[:, np.ptp(tt, axis=0) > 2]
    max_amp = np.max(np.ptp(tt, axis=0))

    y_ae = np.zeros((nk * tt.shape[1], tt.shape[0]))

    for k in range(tt.shape[1]):
        amp_now = np.ptp(tt[:, k])
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk
                      + min_amp)[:, np.newaxis, np.newaxis]
        y_ae[k * nk:(k + 1) * nk] = ((tt[:, k] / amp_now)[np.newaxis, :]
                                     * amps_range[:, :, 0])

    noise_ae = np.random.normal(size=y_ae.shape)
    noise_ae = np.matmul(noise_ae, temporal_SIG)

    x_ae = y_ae + noise_ae
    x_ae = x_ae[:, MID_POINT_IDX]
    y_ae = y_ae[:, MID_POINT_IDX]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
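# Hedged end-to-end usage sketch for the make_training_data variant above
# (kept commented out because the config path, spike train and template
# indexes below are hypothetical placeholders, not shipped fixtures). The
# function expects a loaded CONFIG, a ground-truth spike train and a
# data_folder containing preprocess/standarized.bin from the preprocessor.
#
# yass.set_config('config.yaml')  # assumed config file
# CONFIG = yass.read_config()
#
# (x_detect, y_detect,
#  x_triage, y_triage,
#  x_ae, y_ae) = make_training_data(CONFIG, spike_train,
#                                   chosen_templates_indexes=[0, 1, 2],
#                                   min_amp=4, nspikes=5000,
#                                   data_folder='output/')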