def train(directory, config_train, logger_level):
    """
    Train neural networks, DIRECTORY must be a folder containing the
    output of `yass sort`, CONFIG_TRAIN must be the location of a file with
    the training parameters
    """
    # NOTE(review): the docstring above is kept verbatim in case it is
    # surfaced as command-line help text -- confirm before rewording it.

    # logger_level is resolved as an attribute of the logging module
    level = getattr(logging, logger_level)
    logging.basicConfig(level=level)
    log = logging.getLogger(__name__)

    # ground-truth spike train produced by a previous sort
    spike_train = np.load(path.join(directory, 'spike_train.npy'))
    n_spikes = len(spike_train)
    n_ids = len(np.unique(spike_train[:, 1]))
    log.info(
        'Loaded spike train with: {:,} spikes and {:,} different IDs'.format(
            n_spikes, n_ids))

    # sorting configuration saved next to the spike train, plus the
    # user-supplied training parameters
    sort_config = Config.from_yaml(path.join(directory, 'config.yaml'))
    training_params = load_yaml(config_train)

    train_neural_networks(sort_config, training_params, spike_train,
                          data_folder=directory)
def test_can_preprocess_without_filtering(path_to_threshold_config):
    """Preprocess must run with the 'apply_filter' option turned off."""
    settings = load_yaml(path_to_threshold_config)
    settings['preprocess']['apply_filter'] = False
    yass.set_config(settings)

    # unpacking doubles as a check that run() yields exactly three values
    path_std, params_std, whitening = preprocess.run()
def test_can_preprocess_in_parallel(path_to_threshold_config):
    """Preprocess must run when resources.processes is set to 'max'."""
    settings = load_yaml(path_to_threshold_config)
    settings['resources']['processes'] = 'max'
    yass.set_config(settings)

    # unpacking doubles as a check that run() yields exactly three values
    path_std, params_std, whitening = preprocess.run()
def __init__(self, path_to_model, input_tensor=None):
    """
    Create the autoencoder weight variable, its saver and the score tensor.

    Parameters:
    -----------
    path_to_model: str
        location of trained neural net autoencoder; '.ckpt' is appended
        when the path does not already end with it
    input_tensor: optional
        passed unchanged to ``self._make_graph``
    """
    checkpoint = path_to_model
    if not checkpoint.endswith('.ckpt'):
        checkpoint += '.ckpt'
    self.path_to_model = checkpoint

    # parameter file located through change_extension(..., 'yaml')
    self.ae_dict = load_yaml(change_extension(checkpoint, 'yaml'))
    n_rows = self.ae_dict['n_input']
    n_cols = self.ae_dict['n_features']

    # uniform initial weights in [-1/sqrt(n_input), 1/sqrt(n_input)]
    bound = 1.0 / np.sqrt(n_rows)
    initial_weights = tf.random_uniform((n_rows, n_cols), -bound, bound)
    self.W_ae = tf.Variable(initial_weights)

    # saver restricted to the autoencoder weight
    self.saver = tf.train.Saver({"W_ae": self.W_ae})

    # score tensor built from the (optional) input tensor
    self.score_tf = self._make_graph(input_tensor)
def __init__(self, path_to_ae_model):
    """
    Create the autoencoder weight variable and its saver.

    Parameters:
    -----------
    path_to_ae_model: str
        location of trained neural net autoencoder
    """
    self.path_to_ae_model = path_to_ae_model

    # parameter file located through change_extension(..., 'yaml')
    self.ae_dict = load_yaml(change_extension(path_to_ae_model, 'yaml'))
    n_rows = self.ae_dict['n_input']
    n_cols = self.ae_dict['n_features']

    # uniform initial weights in [-1/sqrt(n_input), 1/sqrt(n_input)]
    bound = 1.0 / np.sqrt(n_rows)
    initial_weights = tf.random_uniform((n_rows, n_cols), -bound, bound)
    self.W_ae = tf.Variable(initial_weights)

    # saver restricted to the autoencoder weight
    self.saver_ae = tf.train.Saver({"W_ae": self.W_ae})
def __init__(self, path_to_triage_model):
    """
    Create the network weights/biases and the saver that tracks them.

    Parameters:
    -----------
    path_to_triage_model: str
        location of the triage model; its parameter file is located with
        change_extension(path_to_triage_model, 'yaml')
    """
    self.path_to_triage_model = path_to_triage_model

    self.filters_dict = load_yaml(
        change_extension(path_to_triage_model, 'yaml'))
    filter_len = self.filters_dict['size']
    n_filters_1, n_filters_2 = self.filters_dict['filters']
    n_neigh = self.filters_dict['n_neighbors']

    # variables are created in the same order as before: W1, b1, W11,
    # b11, W2, b2
    self.W1 = weight_variable([filter_len, 1, 1, n_filters_1])
    self.b1 = bias_variable([n_filters_1])
    self.W11 = weight_variable([1, 1, n_filters_1, n_filters_2])
    self.b11 = bias_variable([n_filters_2])
    self.W2 = weight_variable([1, n_neigh, n_filters_2, 1])
    self.b2 = bias_variable([1])

    tracked = {
        "W1": self.W1,
        "W11": self.W11,
        "W2": self.W2,
        "b1": self.b1,
        "b11": self.b11,
        "b2": self.b2,
    }
    self.saver = tf.train.Saver(tracked)
def test_can_preprocess_without_filtering(path_to_config, make_tmp_folder):
    """Preprocess must run with a 'preprocess' section that disables
    filtering, writing into <tmp folder>/preprocess."""
    settings = load_yaml(path_to_config)
    settings['preprocess'] = {'apply_filter': False}
    yass.set_config(settings, make_tmp_folder)

    output_folder = os.path.join(make_tmp_folder, 'preprocess')
    # unpacking doubles as a check that run() yields exactly two values
    path_std, params_std = preprocess.run(output_folder)
def test_can_preprocess_in_parallel(path_to_config, make_tmp_folder):
    """Preprocess must run when resources.processes is 'max', writing
    into <tmp folder>/preprocess."""
    settings = load_yaml(path_to_config)
    settings['resources']['processes'] = 'max'
    yass.set_config(settings, make_tmp_folder)

    output_folder = os.path.join(make_tmp_folder, 'preprocess')
    # unpacking doubles as a check that run() yields exactly two values
    path_std, params_std = preprocess.run(output_folder)
def test_can_preprocess_without_filtering(path_to_config, make_tmp_folder):
    """Preprocess must run with a 'preprocess' section that disables
    filtering."""
    settings = load_yaml(path_to_config)
    settings['preprocess'] = {'apply_filter': False}
    yass.set_config(settings, make_tmp_folder)

    # unpacking doubles as a check that run() yields exactly three values
    path_std, params_std, whitening = preprocess.run()
def __init__(self, path_to_detector_model, path_to_ae_model):
    """
    Create the detector weights/biases, the autoencoder weight and one
    saver for each group of variables.

    Parameters:
    -----------
    path_to_detector_model: str
        location of the detector model; its parameter file is located
        with change_extension(path_to_detector_model, 'yaml')
    path_to_ae_model: str
        location of the autoencoder model; its parameter file is located
        with change_extension(path_to_ae_model, 'yaml')
    """
    self.path_to_detector_model = path_to_detector_model
    self.path_to_ae_model = path_to_ae_model

    # detector parameters
    self.filters_dict = load_yaml(
        change_extension(path_to_detector_model, 'yaml'))
    filter_len = self.filters_dict['size']
    n_filters_1, n_filters_2 = self.filters_dict['filters']
    n_neigh = self.filters_dict['n_neighbors']

    # variables are created in the same order as before: W1, b1, W11,
    # b11, W2, b2
    self.W1 = weight_variable([filter_len, 1, 1, n_filters_1])
    self.b1 = bias_variable([n_filters_1])
    self.W11 = weight_variable([1, 1, n_filters_1, n_filters_2])
    self.b11 = bias_variable([n_filters_2])
    self.W2 = weight_variable([1, n_neigh, n_filters_2, 1])
    self.b2 = bias_variable([1])

    # autoencoder parameters (kept local, not stored on the instance)
    ae_params = load_yaml(change_extension(path_to_ae_model, 'yaml'))
    n_rows = ae_params['n_input']
    n_cols = ae_params['n_features']

    # uniform initial weights in [-1/sqrt(n_input), 1/sqrt(n_input)]
    bound = 1.0 / np.sqrt(n_rows)
    self.W_ae = tf.Variable(
        tf.random_uniform((n_rows, n_cols), -bound, bound))

    self.saver_ae = tf.train.Saver({"W_ae": self.W_ae})

    detector_vars = {
        "W1": self.W1,
        "W11": self.W11,
        "W2": self.W2,
        "b1": self.b1,
        "b11": self.b11,
        "b2": self.b2,
    }
    self.saver = tf.train.Saver(detector_vars)
def load(cls, path_to_model, input_tensor=None):
    """Build an instance from a saved model path.

    '.ckpt' is appended to ``path_to_model`` when missing. The values
    'waveform_length' and 'n_features' are read from the parameter file
    located with change_extension(..., 'yaml') and passed, together with
    ``input_tensor``, to the class constructor.
    """
    if path_to_model.endswith('.ckpt'):
        checkpoint = path_to_model
    else:
        checkpoint = path_to_model + '.ckpt'

    saved = load_yaml(change_extension(checkpoint, 'yaml'))

    return cls(checkpoint,
               saved['waveform_length'],
               saved['n_features'],
               input_tensor)
def load(cls, path_to_model, threshold, channel_index):
    """Build an instance from a saved model path.

    '.ckpt' is appended to ``path_to_model`` when missing. The values
    'filters_size', 'waveform_length' and 'n_neighbors' are read from the
    parameter file located with change_extension(..., 'yaml') and passed,
    together with ``threshold`` and ``channel_index``, to the class
    constructor.
    """
    if path_to_model.endswith('.ckpt'):
        checkpoint = path_to_model
    else:
        checkpoint = path_to_model + '.ckpt'

    saved = load_yaml(change_extension(checkpoint, 'yaml'))

    return cls(checkpoint,
               saved['filters_size'],
               saved['waveform_length'],
               saved['n_neighbors'],
               threshold,
               channel_index)
def load(cls, path_to_model, threshold, input_tensor=None,
         load_test_set=False):
    """Load a model from a file

    '.ckpt' is appended to ``path_to_model`` when missing. The values
    'filters_size', 'waveform_length' and 'n_neighbors' are read from the
    parameter file located with change_extension(..., 'yaml'); the
    remaining arguments are forwarded to the constructor unchanged.
    """
    if path_to_model.endswith('.ckpt'):
        checkpoint = path_to_model
    else:
        checkpoint = path_to_model + '.ckpt'

    saved = load_yaml(change_extension(checkpoint, 'yaml'))

    constructor_args = dict(
        path_to_model=checkpoint,
        filters_size=saved['filters_size'],
        waveform_length=saved['waveform_length'],
        n_neighbors=saved['n_neighbors'],
        threshold=threshold,
        input_tensor=input_tensor,
        load_test_set=load_test_set,
    )
    return cls(**constructor_args)
def __init__(self, path_to_detector_model):
    """
    Create the detector weights/biases and the saver that tracks them.

    Parameters:
    -----------
    path_to_detector_model: str
        location of trained neural net detector; its parameter file is
        located with change_extension(path_to_detector_model, 'yaml')
    """
    self.path_to_detector_model = path_to_detector_model

    # network dimensions come from the parameter file
    self.filters_dict = load_yaml(
        change_extension(path_to_detector_model, 'yaml'))
    filter_len = self.filters_dict['size']
    n_filters_1, n_filters_2 = self.filters_dict['filters']
    n_neigh = self.filters_dict['n_neighbors']

    # variables are created in the same order as before: W1, b1, W11,
    # b11, W2, b2
    self.W1 = weight_variable([filter_len, 1, 1, n_filters_1])
    self.b1 = bias_variable([n_filters_1])
    self.W11 = weight_variable([1, 1, n_filters_1, n_filters_2])
    self.b11 = bias_variable([n_filters_2])
    self.W2 = weight_variable([1, n_neigh, n_filters_2, 1])
    self.b2 = bias_variable([1])

    tracked = {
        "W1": self.W1,
        "W11": self.W11,
        "W2": self.W2,
        "b1": self.b1,
        "b11": self.b11,
        "b2": self.b2,
    }
    self.saver = tf.train.Saver(tracked)
def __init__(self, path_to_triage_model):
    """
    Initializes the attributes for the class NeuralNetTriage.

    Parameters:
    -----------
    path_to_triage_model: str
        location of trained neural net triage; its parameter file is
        located with change_extension(path_to_triage_model, 'yaml')
    """
    self.path_to_triage_model = path_to_triage_model

    # network dimensions come from the parameter file
    self.filters_dict = load_yaml(
        change_extension(path_to_triage_model, 'yaml'))
    filter_len = self.filters_dict['size']
    n_filters_1, n_filters_2 = self.filters_dict['filters']
    n_neigh = self.filters_dict['n_neighbors']

    # variables are created in the same order as before: W1, b1, W11,
    # b11, W2, b2
    self.W1 = weight_variable([filter_len, 1, 1, n_filters_1])
    self.b1 = bias_variable([n_filters_1])
    self.W11 = weight_variable([1, 1, n_filters_1, n_filters_2])
    self.b11 = bias_variable([n_filters_2])
    self.W2 = weight_variable([1, n_neigh, n_filters_2, 1])
    self.b2 = bias_variable([1])

    tracked = {
        "W1": self.W1,
        "W11": self.W11,
        "W2": self.W2,
        "b1": self.b1,
        "b11": self.b11,
        "b2": self.b2,
    }
    self.saver = tf.train.Saver(tracked)
def params(path_to_config):
    """
    Generate phy's params.py from YASS' config.yaml

    Parameters
    ----------
    path_to_config: str
        Path to the YASS configuration file. The keys read are
        data.root_folder, data.recordings, recordings.n_channels,
        recordings.dtype and recordings.sampling_rate

    Returns
    -------
    str
        Content of the 'phy/params.py' asset template with its
        placeholders filled in
    """
    template = load_asset('phy/params.py')
    config = load_yaml(path_to_config)

    # '%-d' (day of month without zero padding) is a platform-specific
    # strftime extension that is not available everywhere (e.g. Windows),
    # so the day number is taken from the datetime object instead. The
    # resulting text is the same: e.g. 'March 5, 2018 at 14:30'.
    now = datetime.now()
    timestamp = '{} {}, {}'.format(now.strftime('%B'), now.day,
                                   now.strftime('%Y at %H:%M'))

    dat_path = path.join(config['data']['root_folder'],
                         config['data']['recordings'])
    recordings = config['recordings']

    # offset and hp_filtered are fixed values, not read from the config
    return template.format(timestamp=timestamp,
                           dat_path=dat_path,
                           n_channels_dat=recordings['n_channels'],
                           dtype=recordings['dtype'],
                           offset=0,
                           sample_rate=recordings['sampling_rate'],
                           hp_filtered='True')
def make_training_data(CONFIG, spike_train, chosen_templates, min_amp, nspikes, data_folder, noise_ratio=10, collision_ratio=1, misalign_ratio=1, misalign_ratio2=1, multi=True):
    """Generate synthetic training sets for the detection, triage and
    autoencoder networks

    Templates are computed from the spike train and the standarized
    recordings, rescaled to a range of amplitudes to produce clean spikes,
    and then combined into collided and misaligned examples. Noise drawn
    with the covariance returned by ``noise_cov`` is added to build the
    final arrays. Random draws use the global ``np.random`` state, so the
    output depends on the seed set by the caller.

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file
    spike_train: numpy.ndarray
        [number of spikes, 2] Ground truth for training. First column is
        the spike time, second column is the spike id
    chosen_templates: list
        List of chosen templates' id's
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    nspikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    data_folder: str
        Folder storing the standarized data ('standarized.bin' and
        'standarized.yaml'). A ValueError is raised when
        'standarized.bin' is not there
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated
        spikes
    multi: bool
        If multi= True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net.
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect
    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of
        channels] Training data for the triage net.
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for
        the autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    Raises
    ------
    ValueError
        If the standarized data file does not exist or if
        ``choose_templates`` returns no templates
    """
    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocesor first to generate it'.format(path_to_data))

    # metadata of the standarized recordings ('dtype' and 'data_order'
    # are the keys used below)
    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates; a column of ones is appended to the spike train
    # before it is handed to get_templates
    templates, _ = get_templates(
        np.hstack((spike_train, np.ones((spike_train.shape[0], 1), 'int32'))),
        path_to_data, CONFIG.resources.max_memory, 4 * CONFIG.spike_size)

    # reverse the axis order so that templates are indexed on axis 0
    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)
    # untouched copy, used at the end for the autoencoder data
    templates_uncropped = np.copy(templates)

    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spike_size, CONFIG.neigh_channels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data, PARAMS['dtype'], CONFIG.recordings.n_channels, PARAMS['data_order'], CONFIG.neigh_channels, CONFIG.geom, templates.shape[1])

    # make training data set
    K = templates.shape[0]  # number of templates
    R = CONFIG.spike_size
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))  # clean spikes generated per template
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################
    # every template is normalized by its maximum absolute value and
    # rescaled to nk amplitudes evenly spaced from min_amp towards max_amp
    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))
    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk + min_amp)[:, np.newaxis, np.newaxis]
        x_clean[k * nk:(k + 1) * nk] = (tt / amp_now)[np.newaxis, :, :] * amps_range

    #############
    # collision #
    #############
    x_collision = np.zeros((x_clean.shape[0] * int(collision_ratio), templates.shape[1], templates.shape[2]))
    max_shift = 2 * R

    # random shifts in [-max_shift, max_shift); negative shifts are then
    # pushed 5 further away from zero and non-negative ones 6, so no
    # shift ends up between -5 and 5 (the zero-shift branch below is
    # never taken)
    temporal_shifts = np.random.randint(max_shift * 2, size=x_collision.shape[0]) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    # maximum over time of each clean spike on channel index 0
    amp_per_data = np.max(x_clean[:, :, 0], axis=1)

    for j in range(x_collision.shape[0]):
        shift = temporal_shifts[j]

        # start from a randomly chosen clean spike
        x_collision[j] = np.copy(x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)])

        # second spike: picked among clean spikes whose amplitude exceeds
        # 30% of the first spike's maximum on channel index 0
        idx_candidate = np.where(
            amp_per_data > np.max(x_collision[j, :, 0]) * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0], size=1)[0]]

        if multi:
            # channels of the second spike are randomly permuted
            x_clean2 = np.copy(
                x_clean[idx_match]
                [:, np.random.choice(nneigh, nneigh, replace=False)])
        else:
            x_clean2 = np.copy(x_clean[idx_match])

        # add the second spike shifted in time
        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += x_clean2[shift:]
        elif shift < 0:
            x_collision[j, (-shift):] += x_clean2[:(x_collision.shape[1] + shift)]
        else:
            x_collision[j] += x_clean2

    ##############################################
    # temporally and spatially misaligned spikes #
    ##############################################
    x_misaligned = np.zeros((x_clean.shape[0] * int(misalign_ratio), templates.shape[1], templates.shape[2]))

    # same shift scheme as for collisions
    temporal_shifts = np.random.randint(max_shift * 2, size=x_misaligned.shape[0]) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    for j in range(x_misaligned.shape[0]):
        shift = temporal_shifts[j]

        if multi:
            # random clean spike with randomly permuted channels
            x_clean2 = np.copy(
                x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)]
                [:, :, np.random.choice(nneigh, nneigh, replace=False)])
            x_clean2 = np.squeeze(x_clean2)
        else:
            x_clean2 = np.copy(x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)])
            x_clean2 = np.squeeze(x_clean2)

        # x_misaligned starts as zeros, so this stores the shifted spike
        if shift > 0:
            x_misaligned[j, :(x_misaligned.shape[1] - shift)] += x_clean2[shift:]
        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(x_misaligned.shape[1] + shift)]
        else:
            x_misaligned[j] += x_clean2

    ###############################
    # spatially misaligned spikes #
    ###############################
    # only built (and only used) in the multi-channel case
    if multi:
        x_misaligned2 = np.zeros((x_clean.shape[0] * int(misalign_ratio2), templates.shape[1], templates.shape[2]))
        for j in range(x_misaligned2.shape[0]):
            # random clean spike with randomly permuted channels, no shift
            x_misaligned2[j] = np.copy(
                x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)]
                [:, :, np.random.choice(nneigh, nneigh, replace=False)])

    #########
    # noise #
    #########
    # get noise: standard normal samples, multiplied by temporal_SIG
    # channel by channel and then by spatial_SIG
    noise = np.random.normal(size=[
        x_clean.shape[0] * int(noise_ratio), templates.shape[1],
        templates.shape[2]
    ])
    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG), [noise.shape[0], x_clean.shape[1], x_clean.shape[2]])

    # labels: clean spikes and collisions are 1, misaligned spikes and
    # noise are 0
    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_collision.shape[0]))
    y_misaligned = np.zeros((x_misaligned.shape[0]))
    if multi:
        y_misaligned2 = np.zeros((x_misaligned2.shape[0]))
    y_noise = np.zeros((noise.shape[0]))

    # center of the time axis; outputs keep the 2 * R + 1 samples around it
    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    # noise rows are sampled without replacement, so the noise array must
    # have at least as many rows as each array it is added to
    if multi:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_collision + noise[np.random.choice(
                 noise.shape[0], x_collision.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))
        x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))
    else:
        # single-channel case: collisions are left out and only channel
        # index 0 is kept
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))
        x_detect = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_detect = np.concatenate((y_clean, y_misaligned, y_noise))

    # get training set for triage
    # here collisions are labelled 0 (they were labelled 1 for detection)
    if multi:
        x = np.concatenate((
            x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
            x_collision + noise[np.random.choice(
                noise.shape[0], x_collision.shape[0], replace=False)],
            x_misaligned2 + noise[np.random.choice(
                noise.shape[0], x_misaligned2.shape[0], replace=False)],
        ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_triage = np.concatenate((y_clean, np.zeros(
            (x_collision.shape[0])), y_misaligned2))
    else:
        x = np.concatenate((
            x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
            x_collision + noise[np.random.choice(
                noise.shape[0], x_collision.shape[0], replace=False)],
        ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_triage = np.concatenate((y_clean, np.zeros((x_collision.shape[0]))))

    ###############
    # Autoencoder #
    ###############
    # crop the uncropped templates again, this time with an all-ones
    # neighbor matrix
    n_channels = templates_uncropped.shape[2]
    templates_ae = crop_templates(templates_uncropped, CONFIG.spike_size, np.ones((n_channels, n_channels), 'int32'), CONFIG.geom)

    # one column per (template, channel) waveform; keep only columns whose
    # peak-to-peak value is above 2
    tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    tt = tt[:, np.ptp(tt, axis=0) > 2]
    max_amp = np.max(np.ptp(tt, axis=0))

    # rescale every waveform to nk amplitudes (normalizing by its
    # peak-to-peak value); nk is reused from the clean spikes section
    y_ae = np.zeros((nk * tt.shape[1], tt.shape[0]))
    for k in range(tt.shape[1]):
        amp_now = np.ptp(tt[:, k])
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk + min_amp)[:, np.newaxis, np.newaxis]
        y_ae[k * nk:(k + 1) * nk] = ((tt[:, k] / amp_now)[np.newaxis, :] * amps_range[:, :, 0])

    # noisy input = clean target + noise multiplied by temporal_SIG
    noise = np.random.normal(size=y_ae.shape)
    noise = np.matmul(noise, temporal_SIG)

    x_ae = y_ae + noise
    x_ae = x_ae[:, (mid_point - R):(mid_point + R + 1)]
    y_ae = y_ae[:, (mid_point - R):(mid_point + R + 1)]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    set_zero_seed: bool, optional
        Accepted for backward compatibility; this function does not
        currently read it

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standarized recordings (from preprocess)
    * ``standardized.yaml`` - Standarized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file, send the file handler to TMP_FOLDER
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # ---------------- set environment variables ----------------
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # ---------------- preprocess ----------------
    start = time.time()
    (standardized_path,
     standardized_params,
     whiten_filter) = preprocess.run(
         if_file_exists=CONFIG.preprocess.if_file_exists)
    time_preprocess = time.time() - start

    # ---------------- detect events ----------------
    # Cat: This code now runs with open tensorflow calls
    start = time.time()
    spike_index_all = detect.run(standardized_path,
                                 standardized_params,
                                 whiten_filter,
                                 if_file_exists=CONFIG.detect.if_file_exists,
                                 save_results=CONFIG.detect.save_results)
    # no separate "clear" spike index is produced at this point
    spike_index_clear = None
    time_detect = time.time() - start

    # ---------------- cluster ----------------
    start = time.time()
    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    # clustering is skipped when its output file is already there
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    # results are read back from disk in both cases
    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))
    time_cluster = time.time() - start

    # ---------------- deconvolution ----------------
    start = time.time()
    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER,
                                    'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save template
    path_to_templates = path.join(TMP_FOLDER,
                                  'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))

    # ---------------- post processing ----------------
    # save metadata in tmp (this used to be done twice in a row)
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logger.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/: dump it when a mapping was passed,
    # copy the file otherwise
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logger.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(
            path_to_waveforms_score))

        # score templates
        # NOTE(review): this used to read an undefined name `templates`;
        # the templates returned by deconvolution (the ones saved above)
        # are used instead -- confirm that their axis order is the one
        # the transpose below expects.
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = postdeconv_templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'.format(
            path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster +
             time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
def _run_pipeline(config, output_file, logger_level='INFO', clean=True,
                  output_dir='tmp/', complete=False):
    """
    Run the entire pipeline given a path to a config file
    and output path

    Parameters
    ----------
    config: str
        Path to the configuration file (it is copied into the output
        folder with shutil.copy2)
    output_file: str
        File name, inside the output folder, where the spike train is saved
    logger_level: str, optional
        Level set on the root logger
    clean: bool, optional
        Delete the output folder before running if it exists
    output_dir: str, optional
        Output folder, joined to CONFIG.data.root_folder
    complete: bool, optional
        Also save waveforms, main channels and scores for spikes and
        templates
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file, send the file handler to TMP_FOLDER
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # run preprocessor
    (score, spike_index_clear,
     spike_index_collision) = preprocess.run(output_directory=output_dir)

    # run processor
    (spike_train_clear, templates,
     spike_index_collision) = process.run(score, spike_index_clear,
                                          spike_index_collision,
                                          output_directory=output_dir)

    # run deconvolution
    spike_train = deconvolute.run(spike_train_clear, templates,
                                  spike_index_collision,
                                  output_directory=output_dir)

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logger.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')
    shutil.copy2(config, path_to_config_copy)
    logger.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # save templates
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    logger.info('Saving templates in {}'.format(path_to_templates))
    np.save(path_to_templates, templates)

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER, output_file)
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spikeSize,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_format=PARAMS['data_format'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        # each message below reports the file that was just written (they
        # all used to report path_to_waveforms)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(
            path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'.format(
            path_to_templates_score))
def export(directory, output_dir):
    """
    Generates phy input files, 'yass sort' (with the `--complete` option)
    must be run first to generate all the necessary files

    Parameters
    ----------
    directory: str
        Folder with the output of 'yass sort'. The following files are read
        from it: config.yaml, spike_train.npy, templates.npy,
        templates_score.npy, templates_main_channel.npy,
        waveforms_score.npy and whitening.npy

    output_dir: str or None
        Folder where the phy files are written, it is created if it does
        not exist. If None, the phy/ folder inside directory is used

    Raises
    ------
    click.Abort
        If directory does not exist
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    TMP_FOLDER = directory

    # verify that the tmp/ folder exists before trying to read anything
    # from it, otherwise abort
    if not os.path.exists(TMP_FOLDER):
        click.echo("{} directory does not exist, this means you "
                   "haven't run 'yass sort', run it before running "
                   "'yass export' again...".format(TMP_FOLDER))
        raise click.Abort()

    PATH_TO_CONFIG = path.join(TMP_FOLDER, 'config.yaml')
    CONFIG = load_yaml(PATH_TO_CONFIG)
    ROOT_FOLDER = CONFIG['data']['root_folder']
    N_CHANNELS = CONFIG['recordings']['n_channels']

    if output_dir is None:
        PHY_FOLDER = path.join(TMP_FOLDER, 'phy/')
    else:
        PHY_FOLDER = output_dir

    if not os.path.exists(PHY_FOLDER):
        logger.info('Creating directory: {}'.format(PHY_FOLDER))
        os.makedirs(PHY_FOLDER)

    # TODO: convert data to wide format

    # generate params.py
    params = generate.params(PATH_TO_CONFIG)
    path_to_params = path.join(PHY_FOLDER, 'params.py')

    with open(path_to_params, 'w') as f:
        f.write(params)

    logger.info('Saved {}...'.format(path_to_params))

    # channel_positions.npy
    logger.info('Generating channel_positions.npy')
    path_to_geom = path.join(ROOT_FOLDER, CONFIG['data']['geometry'])
    geom = geometry.parse(path_to_geom, N_CHANNELS)
    path_to_channel_positions = path.join(PHY_FOLDER, 'channel_positions.npy')
    np.save(path_to_channel_positions, geom)
    logger.info('Saved {}...'.format(path_to_channel_positions))

    # channel_map.npy
    channel_map = generate.channel_map(N_CHANNELS)
    path_to_channel_map = path.join(PHY_FOLDER, 'channel_map.npy')
    np.save(path_to_channel_map, channel_map)
    logger.info('Saved {}...'.format(path_to_channel_map))

    # load spike train
    path_to_spike_train = path.join(TMP_FOLDER, 'spike_train.npy')
    logger.info('Loading spike train from {}...'.format(path_to_spike_train))
    spike_train = np.load(path_to_spike_train)
    N_SPIKES, _ = spike_train.shape
    logger.info('Spike train contains {:,} spikes'.format(N_SPIKES))

    # load templates
    logger.info('Loading previously saved templates...')
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    templates = np.load(path_to_templates)
    _, _, N_TEMPLATES = templates.shape

    # pc_features_ind.npy
    path_to_pc_features_ind = path.join(PHY_FOLDER, 'pc_feature_ind.npy')
    ch_neighbors = geometry.find_channel_neighbors
    neigh_channels = ch_neighbors(geom, CONFIG['recordings']['spatial_radius'])
    pc_feature_ind = generate.pc_feature_ind(N_SPIKES, N_TEMPLATES,
                                             N_CHANNELS, geom,
                                             neigh_channels, spike_train,
                                             templates)
    np.save(path_to_pc_features_ind, pc_feature_ind)
    logger.info('Saved {}...'.format(path_to_pc_features_ind))

    # similar_templates.npy
    # NOTE(review): templates were already loaded above from the same path;
    # the reload is kept in case generate.pc_feature_ind modifies its
    # input - confirm and drop it if it does not
    path_to_similar_templates = path.join(PHY_FOLDER, 'similar_templates.npy')
    templates = np.load(path_to_templates)
    similar_templates = generate.similar_templates(templates)
    np.save(path_to_similar_templates, similar_templates)
    logger.info('Saved {}...'.format(path_to_similar_templates))

    # spike_templates.npy and spike_times.npy
    # (second and first column of the spike train, respectively)
    path_to_spike_templates = path.join(PHY_FOLDER, 'spike_templates.npy')
    np.save(path_to_spike_templates, spike_train[:, 1])
    logger.info('Saved {}...'.format(path_to_spike_templates))

    path_to_spike_times = path.join(PHY_FOLDER, 'spike_times.npy')
    np.save(path_to_spike_times, spike_train[:, 0])
    logger.info('Saved {}...'.format(path_to_spike_times))

    # template_feature_ind.npy
    path_to_template_feature_ind = path.join(PHY_FOLDER,
                                             'template_feature_ind.npy')
    template_feature_ind = generate.template_feature_ind(
        N_TEMPLATES, similar_templates)
    np.save(path_to_template_feature_ind, template_feature_ind)
    logger.info('Saved {}...'.format(path_to_template_feature_ind))

    # template_features.npy
    templates_score = np.load(path.join(TMP_FOLDER, 'templates_score.npy'))
    templates_main_channel = np.load(
        path.join(TMP_FOLDER, 'templates_main_channel.npy'))
    waveforms_score = np.load(path.join(TMP_FOLDER, 'waveforms_score.npy'))

    path_to_template_features = path.join(PHY_FOLDER, 'template_features.npy')
    template_features = generate.template_features(
        N_SPIKES, N_CHANNELS, N_TEMPLATES, spike_train,
        templates_main_channel, neigh_channels, geom, templates_score,
        template_feature_ind, waveforms_score)
    np.save(path_to_template_features, template_features)
    logger.info('Saved {}...'.format(path_to_template_features))

    # templates.npy (saved with the axes reversed)
    path_to_phy_templates = path.join(PHY_FOLDER, 'templates.npy')
    np.save(path_to_phy_templates, np.transpose(templates, [2, 1, 0]))
    logger.info(
        'Saved phy-compatible templates in {}'.format(path_to_phy_templates))

    # templates_ind.npy
    templates_ind = generate.templates_ind(N_TEMPLATES, N_CHANNELS)
    path_to_templates_ind = path.join(PHY_FOLDER, 'templates_ind.npy')
    np.save(path_to_templates_ind, templates_ind)
    logger.info('Saved {}...'.format(path_to_templates_ind))

    # whitening_mat.npy and whitening_mat_inv.npy
    logger.info('Generating whitening_mat.npy and whitening_mat_inv.npy...')
    path_to_whitening = path.join(TMP_FOLDER, 'whitening.npy')
    path_to_whitening_mat = path.join(PHY_FOLDER, 'whitening_mat.npy')
    # the whitening matrix is copied as is, only its inverse is computed
    shutil.copy2(path_to_whitening, path_to_whitening_mat)
    logger.info('Saving copy of whitening: {} in {}'.format(
        path_to_whitening, path_to_whitening_mat))

    path_to_whitening_mat_inv = path.join(PHY_FOLDER, 'whitening_mat_inv.npy')
    whitening = np.load(path_to_whitening)
    np.save(path_to_whitening_mat_inv, np.linalg.inv(whitening))
    logger.info('Saving inverse of whitening matrix in {}...'.format(
        path_to_whitening_mat_inv))

    logger.info('Done.')
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (relative to CONFIG.data.root_folder) to store the
        output data, defaults to tmp/

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    set_zero_seed: bool, optional
        Set the numpy random seed to zero before running the pipeline,
        defaults to False

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standarized.bin`` - Standarized recordings (from preprocess)
    * ``standarized.yaml`` - Standarized recordings metadata
      (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    The log file handler is pointed to ``yass.log`` in the same folder.

    If ``complete`` is True, these files are also saved there:
    ``spike_train_waveforms.npy``, ``waveforms_main_channel.npy``,
    ``waveforms_score.npy``, ``templates_main_channel.npy`` and
    ``templates_score.npy``

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger and start coloredlogs
    logger = logging.getLogger(__name__)
    coloredlogs.install(logger=logger)

    if set_zero_seed:
        logger.warning('Set numpy seed to zero')
        np.random.seed(0)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # preprocess
    start = time.time()
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run(
        output_directory=output_dir,
        if_file_exists=CONFIG.preprocess.if_file_exists)
    time_preprocess = time.time() - start

    # detect
    start = time.time()
    (score, spike_index_clear,
     spike_index_all) = detect.run(
        standarized_path, standarized_params, channel_index, whiten_filter,
        output_directory=output_dir,
        if_file_exists=CONFIG.detect.if_file_exists,
        save_results=CONFIG.detect.save_results)
    time_detect = time.time() - start

    # cluster
    start = time.time()
    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score, spike_index_clear,
        output_directory=output_dir,
        if_file_exists=CONFIG.cluster.if_file_exists,
        save_results=CONFIG.cluster.save_results)
    time_cluster = time.time() - start

    # get templates
    start = time.time()
    (templates, spike_train_clear_after_templates, groups,
     idx_good_templates) = get_templates.run(
        spike_train_clear, tmp_loc,
        output_directory=output_dir,
        if_file_exists=CONFIG.templates.if_file_exists,
        save_results=CONFIG.templates.save_results)
    time_templates = time.time() - start

    # run deconvolution
    start = time.time()
    spike_train = deconvolute.run(spike_index_all, templates,
                                  output_directory=output_dir)
    time_deconvolution = time.time() - start

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logger.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/, a mapping is dumped to yaml while a
    # path is copied as is
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logger.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # TODO: complete flag saves other files needed for integrating phy
    # with yass, the integration hasn't been completed yet

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'
                    .format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_templates
             + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess/total*100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect/total*100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster/total*100)
    logger.info('\t Templates: %s (%.2f %%)',
                human_readable_time(time_templates),
                time_templates/total*100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution/total*100)

    return spike_train
def _draw_temporal_shifts(n_shifts, max_shift):
    """Draw n_shifts random integer shifts that are never close to zero

    Shifts are first drawn uniformly from [-max_shift, max_shift), then
    negative values are moved 5 further down and non-negative values 6
    further up, so every shift ends up in [-max_shift - 5, -6] or
    [6, max_shift + 5]
    """
    shifts = np.random.randint(max_shift * 2, size=n_shifts) - max_shift
    shifts[shifts < 0] = shifts[shifts < 0] - 5
    shifts[shifts >= 0] = shifts[shifts >= 0] + 6
    return shifts


def _add_shifted(target, source, shift):
    """Add source to target in place, displaced by shift along the first axis

    A positive shift drops the first ``shift`` rows of source and adds the
    rest at the start of target, a negative shift adds the beginning of
    source starting at row ``-shift`` of target
    """
    n_rows = target.shape[0]

    if shift > 0:
        target[:(n_rows - shift)] += source[shift:]
    elif shift < 0:
        target[(-shift):] += source[:(n_rows + shift)]
    else:
        target += source


def make_training_data(CONFIG, spike_train, chosen_templates, min_amp,
                       nspikes, data_folder):
    """Generate synthetic training sets from templates plus simulated noise

    Templates are computed from the spike train, filtered, cropped and
    rescaled to a range of amplitudes to produce "clean" spikes. Collided
    spikes (a clean spike plus another, time-shifted, clean spike),
    misaligned spikes (a time-shifted clean spike) and noise are generated
    from them and combined in three data sets.

    Parameters
    ----------
    CONFIG
        Configuration object, the attributes spikeSize, neighChannels, geom
        and recordings.n_channels are read

    spike_train
        Spike train, passed to get_templates

    chosen_templates
        Passed to choose_templates to select which templates to keep

    min_amp: float
        Smallest amplitude used when rescaling the templates, the largest
        one is 1.5 times the maximum absolute value found in the templates

    nspikes: int
        Number of clean spikes to generate, ceil(nspikes / K) spikes are
        generated for each one of the K templates

    data_folder: str
        Folder containing standarized.bin and standarized.yaml

    Returns
    -------
    x_detect, y_detect
        Clean, collided, misaligned and noise-only samples (in that order)
        labeled 1, 1, 0 and 0 respectively

    x_triage, y_triage
        Clean and collided samples (in that order) labeled 1 and 0
        respectively

    x_ae, y_ae
        First channel of the clean spikes with a random shift of at most one
        sample, with (x_ae) and without (y_ae) noise

    Raises
    ------
    ValueError
        If standarized.bin does not exist in data_folder or if no templates
        are left after calling choose_templates
    """
    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocesor first to generate it'.format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates
    templates, _ = get_templates(spike_train, path_to_data, CONFIG.spikeSize)
    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)

    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spikeSize,
                               CONFIG.neighChannels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data,
                                          PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_format'],
                                          CONFIG.neighChannels,
                                          CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spikeSize
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]
    n_samples = nk * K

    ################
    # clean spikes #
    ################
    x_clean = np.zeros((n_samples, templates.shape[1], templates.shape[2]))

    # nk evenly spaced amplitudes starting at min_amp, shared by all
    # templates
    amps_range = (np.arange(nk) * (max_amp - min_amp) / nk
                  + min_amp)[:, np.newaxis, np.newaxis]

    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        # normalize the template and rescale it to every amplitude
        x_clean[k * nk:(k + 1) * nk] = ((tt / amp_now)[np.newaxis, :, :]
                                        * amps_range)

    #############
    # collision #
    #############
    x_collision = np.zeros(x_clean.shape)
    max_shift = 2 * R

    temporal_shifts = _draw_temporal_shifts(n_samples, max_shift)
    amp_per_data = np.max(x_clean[:, :, 0], axis=1)

    for j in range(n_samples):
        x_collision[j] = np.copy(x_clean[j])

        # the colliding spike is picked at random among the ones whose
        # amplitude is larger than 30% of the amplitude of the current one
        idx_candidate = np.where(amp_per_data > amp_per_data[j] * 0.3)[0]
        idx_match = idx_candidate[
            np.random.randint(idx_candidate.shape[0], size=1)[0]]

        # shuffle the columns (last axis) of the colliding spike
        x_clean2 = np.copy(
            x_clean[idx_match][:, np.random.choice(nneigh, nneigh,
                                                   replace=False)])

        _add_shifted(x_collision[j], x_clean2, temporal_shifts[j])

    #####################
    # misaligned spikes #
    #####################
    x_misaligned = np.zeros(x_clean.shape)

    temporal_shifts = _draw_temporal_shifts(n_samples, max_shift)

    for j in range(n_samples):
        x_clean2 = np.copy(
            x_clean[j][:, np.random.choice(nneigh, nneigh, replace=False)])

        _add_shifted(x_misaligned[j], x_clean2, temporal_shifts[j])

    #########
    # noise #
    #########

    # get noise: white noise multiplied by temporal_SIG along the second
    # axis and by spatial_SIG along the last one
    noise = np.random.normal(size=x_clean.shape)

    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [x_clean.shape[0], x_clean.shape[1],
                        x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_clean.shape[0]))
    y_misaligned = np.zeros((x_clean.shape[0]))
    y_noise = np.zeros((x_clean.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    x = np.concatenate(
        (x_clean + noise,
         x_collision + noise[np.random.permutation(noise.shape[0])],
         x_misaligned + noise[np.random.permutation(noise.shape[0])],
         noise))
    x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))

    # get training set for triage
    x = np.concatenate(
        (x_clean + noise,
         x_collision + noise[np.random.permutation(noise.shape[0])]))
    x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_triage = np.concatenate((y_clean, np.zeros((x_clean.shape[0]))))

    # get training set for auto encoder
    ae_shift_max = 1
    temporal_shifts_ae = np.random.randint(
        ae_shift_max * 2 + 1, size=x_clean.shape[0]) - ae_shift_max

    y_ae = np.zeros((x_clean.shape[0], 2 * R + 1))
    x_ae = np.zeros((x_clean.shape[0], 2 * R + 1))

    for j in range(x_ae.shape[0]):
        # window of 2R + 1 samples around the mid point, displaced by the
        # random shift
        window = slice(mid_point - R + temporal_shifts_ae[j],
                       mid_point + R + 1 + temporal_shifts_ae[j])
        y_ae[j] = x_clean[j, window, 0]
        x_ae[j] = x_clean[j, window, 0] + noise[j, window, 0]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae