Example #1
def train(directory, config_train, logger_level):
    """
    Train neural networks. DIRECTORY must be a folder containing the
    output of `yass sort`; CONFIG_TRAIN must be the location of a file with
    the training parameters
    """
    logging.basicConfig(level=getattr(logging, logger_level))
    logger = logging.getLogger(__name__)

    path_to_spike_train = path.join(directory, 'spike_train.npy')
    spike_train = np.load(path_to_spike_train)

    logger.info(
        'Loaded spike train with {:,} spikes and {:,} different IDs'.format(
            len(spike_train), len(np.unique(spike_train[:, 1]))))

    path_to_config = path.join(directory, 'config.yaml')
    CONFIG = Config.from_yaml(path_to_config)

    CONFIG_TRAIN = load_yaml(config_train)

    train_neural_networks(CONFIG,
                          CONFIG_TRAIN,
                          spike_train,
                          data_folder=directory)
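Given the click-style docstring above, this appears to be a CLI command; a hypothetical direct call (both paths are illustrative, not from the source) would look like:

train('tmp/', 'config_train.yaml', 'INFO')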
Example #2
def test_can_preprocess_without_filtering(path_to_threshold_config):
    CONFIG = load_yaml(path_to_threshold_config)
    CONFIG['preprocess']['apply_filter'] = False

    yass.set_config(CONFIG)

    standarized_path, standarized_params, whiten_filter = preprocess.run()
Example #3
def test_can_preprocess_in_parallel(path_to_threshold_config):
    CONFIG = load_yaml(path_to_threshold_config)
    CONFIG['resources']['processes'] = 'max'

    yass.set_config(CONFIG)

    standarized_path, standarized_params, whiten_filter = preprocess.run()
Example #4
    def __init__(self, path_to_model, input_tensor=None):
        """
        Initializes the attributes for the class AutoEncoder.

        Parameters:
        -----------
        path_to_model: str
            location of trained neural net autoencoder
        """
        if not path_to_model.endswith('.ckpt'):
            path_to_model = path_to_model + '.ckpt'

        self.path_to_model = path_to_model

        # load parameter of autoencoder
        path_to_filters_ae = change_extension(path_to_model, 'yaml')
        self.ae_dict = load_yaml(path_to_filters_ae)
        n_input = self.ae_dict['n_input']
        n_features = self.ae_dict['n_features']

        # initialize autoencoder weight
        self.W_ae = tf.Variable(
            tf.random_uniform((n_input, n_features), -1.0 / np.sqrt(n_input),
                              1.0 / np.sqrt(n_input)))

        # create saver variables
        self.saver = tf.train.Saver({"W_ae": self.W_ae})

        # make score tensorflow tensor from waveform
        self.score_tf = self._make_graph(input_tensor)
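A plausible usage sketch for the saver created above, assuming the surrounding class is the autoencoder wrapper (TF1-style session; the checkpoint path is illustrative):

ae = AutoEncoder('models/autoencoder')  # '.ckpt' is appended internally
with tf.Session() as sess:
    # restore W_ae from the checkpoint written at training time
    ae.saver.restore(sess, ae.path_to_model)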
Example #5
    def __init__(self, path_to_ae_model):
        """
        Initializes the attributes for the class AutoEncoder.

        Parameters:
        -----------
        path_to_ae_model: str
            location of trained neural net autoencoder
        """

        # add locations as attributes
        self.path_to_ae_model = path_to_ae_model

        # load parameter of autoencoder
        path_to_filters_ae = change_extension(path_to_ae_model, 'yaml')
        self.ae_dict = load_yaml(path_to_filters_ae)
        n_input = self.ae_dict['n_input']
        n_features = self.ae_dict['n_features']

        # initialize autoencoder weight
        self.W_ae = tf.Variable(
            tf.random_uniform((n_input, n_features), -1.0 / np.sqrt(n_input),
                              1.0 / np.sqrt(n_input)))

        # create saver variables
        self.saver_ae = tf.train.Saver({"W_ae": self.W_ae})
Example #6
    def __init__(self, path_to_triage_model):
        """
        Initializes the attributes for the class NeuralNetTriage.

        Parameters:
        -----------
        path_to_triage_model: str
            location of trained neural net triage
        """

        self.path_to_triage_model = path_to_triage_model

        path_to_filters = change_extension(path_to_triage_model, 'yaml')
        self.filters_dict = load_yaml(path_to_filters)

        R1 = self.filters_dict['size']
        K1, K2 = self.filters_dict['filters']
        C = self.filters_dict['n_neighbors']

        self.W1 = weight_variable([R1, 1, 1, K1])
        self.b1 = bias_variable([K1])

        self.W11 = weight_variable([1, 1, K1, K2])
        self.b11 = bias_variable([K2])

        self.W2 = weight_variable([1, C, K2, 1])
        self.b2 = bias_variable([1])

        self.saver = tf.train.Saver({
            "W1": self.W1,
            "W11": self.W11,
            "W2": self.W2,
            "b1": self.b1,
            "b11": self.b11,
            "b2": self.b2
        })
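weight_variable and bias_variable are not defined in these examples; a plausible TF1-style sketch of what they do (an assumption, not the source's actual code):

def weight_variable(shape):
    # truncated-normal initialization, conventional for small conv nets
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    # small constant bias
    return tf.Variable(tf.constant(0.1, shape=shape))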
Example #7
def test_can_preprocess_without_filtering(path_to_config, make_tmp_folder):
    CONFIG = load_yaml(path_to_config)
    CONFIG['preprocess'] = dict(apply_filter=False)

    yass.set_config(CONFIG, make_tmp_folder)

    (standardized_path, standardized_params) = preprocess.run(
        os.path.join(make_tmp_folder, 'preprocess'))
Example #8
def test_can_preprocess_in_parallel(path_to_config, make_tmp_folder):
    CONFIG = load_yaml(path_to_config)
    CONFIG['resources']['processes'] = 'max'

    yass.set_config(CONFIG, make_tmp_folder)

    (standardized_path, standardized_params) = preprocess.run(
        os.path.join(make_tmp_folder, 'preprocess'))
Example #9
def test_can_preprocess_without_filtering(path_to_config,
                                          make_tmp_folder):
    CONFIG = load_yaml(path_to_config)
    CONFIG['preprocess'] = dict(apply_filter=False)

    yass.set_config(CONFIG, make_tmp_folder)

    standarized_path, standarized_params, whiten_filter = preprocess.run()
Example #10
    def __init__(self, path_to_detector_model, path_to_ae_model):
        """
        Initializes the attributes for the class NeuralNetDetector.

        Parameters:
        -----------
        path_to_detector_model: str
            location of trained neural net detector
        path_to_ae_model: str
            location of trained neural net autoencoder
        """

        self.path_to_detector_model = path_to_detector_model
        self.path_to_ae_model = path_to_ae_model

        path_to_filters = change_extension(path_to_detector_model, 'yaml')
        self.filters_dict = load_yaml(path_to_filters)

        R1 = self.filters_dict['size']
        K1, K2 = self.filters_dict['filters']
        C = self.filters_dict['n_neighbors']

        self.W1 = weight_variable([R1, 1, 1, K1])
        self.b1 = bias_variable([K1])

        self.W11 = weight_variable([1, 1, K1, K2])
        self.b11 = bias_variable([K2])

        self.W2 = weight_variable([1, C, K2, 1])
        self.b2 = bias_variable([1])

        # output of ae encoding (1st layer)
        path_to_filters_ae = change_extension(path_to_ae_model, 'yaml')
        ae_dict = load_yaml(path_to_filters_ae)
        n_input = ae_dict['n_input']
        n_features = ae_dict['n_features']
        self.W_ae = tf.Variable(
            tf.random_uniform((n_input, n_features), -1.0 / np.sqrt(n_input),
                              1.0 / np.sqrt(n_input)))

        self.saver_ae = tf.train.Saver({"W_ae": self.W_ae})
        self.saver = tf.train.Saver({
            "W1": self.W1,
            "W11": self.W11,
            "W2": self.W2,
            "b1": self.b1,
            "b11": self.b11,
            "b2": self.b2
        })
Example #11
    def load(cls, path_to_model, input_tensor=None):

        if not path_to_model.endswith('.ckpt'):
            path_to_model = path_to_model + '.ckpt'

        # load parameter of autoencoder
        path_to_params = change_extension(path_to_model, 'yaml')
        params = load_yaml(path_to_params)

        return cls(path_to_model, params['waveform_length'],
                   params['n_features'], input_tensor)
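change_extension, used by all of these loaders, is presumably a small path helper; a minimal sketch under that assumption:

import os

def change_extension(path, new_extension):
    # swap the file extension: 'model.ckpt' -> 'model.yaml'
    root, _ = os.path.splitext(path)
    return '{}.{}'.format(root, new_extension)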
Example #12
    def load(cls, path_to_model, threshold, channel_index):

        if not path_to_model.endswith('.ckpt'):
            path_to_model = path_to_model+'.ckpt'

        # load nn parameter files
        path_to_params = change_extension(path_to_model, 'yaml')
        params = load_yaml(path_to_params)

        return cls(path_to_model, params['filters_size'],
                   params['waveform_length'], params['n_neighbors'],
                   threshold, channel_index)
Example #13
    def load(cls,
             path_to_model,
             threshold,
             input_tensor=None,
             load_test_set=False):
        """Load a model from a file
        """
        if not path_to_model.endswith('.ckpt'):
            path_to_model = path_to_model + '.ckpt'

        # load necessary parameters
        path_to_params = change_extension(path_to_model, 'yaml')
        params = load_yaml(path_to_params)

        return cls(path_to_model=path_to_model,
                   filters_size=params['filters_size'],
                   waveform_length=params['waveform_length'],
                   n_neighbors=params['n_neighbors'],
                   threshold=threshold,
                   input_tensor=input_tensor,
                   load_test_set=load_test_set)
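A hypothetical call, assuming this classmethod belongs to the triage network class and that 'models/triage.ckpt' and 'models/triage.yaml' sit next to each other:

triage = NeuralNetTriage.load('models/triage', threshold=0.5)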
Example #14
    def __init__(self, path_to_detector_model):
        """
        Initializes the attributes for the class NeuralNetDetector.

        Parameters:
        -----------
        path_to_detector_model: str
            location of trained neural net detector
        """

        # add locations as attributes
        self.path_to_detector_model = path_to_detector_model

        # load nn parameter files
        path_to_filters = change_extension(path_to_detector_model, 'yaml')
        self.filters_dict = load_yaml(path_to_filters)

        # initialize neural net weights and add as attributes
        R1 = self.filters_dict['size']
        K1, K2 = self.filters_dict['filters']
        C = self.filters_dict['n_neighbors']

        self.W1 = weight_variable([R1, 1, 1, K1])
        self.b1 = bias_variable([K1])

        self.W11 = weight_variable([1, 1, K1, K2])
        self.b11 = bias_variable([K2])

        self.W2 = weight_variable([1, C, K2, 1])
        self.b2 = bias_variable([1])

        # create saver variables
        self.saver = tf.train.Saver({
            "W1": self.W1,
            "W11": self.W11,
            "W2": self.W2,
            "b1": self.b1,
            "b11": self.b11,
            "b2": self.b2
        })
Example #15
    def __init__(self, path_to_triage_model):
        """
        Initializes the attributes for the class NeuralNetTriage.

        Parameters:
        -----------
        path_to_triage_model: str
            location of trained neural net triage
        """
        # save path to the model as an attribute
        self.path_to_triage_model = path_to_triage_model

        # load necessary parameters
        path_to_filters = change_extension(path_to_triage_model, 'yaml')
        self.filters_dict = load_yaml(path_to_filters)
        R1 = self.filters_dict['size']
        K1, K2 = self.filters_dict['filters']
        C = self.filters_dict['n_neighbors']

        # initialize and save nn weights
        self.W1 = weight_variable([R1, 1, 1, K1])
        self.b1 = bias_variable([K1])

        self.W11 = weight_variable([1, 1, K1, K2])
        self.b11 = bias_variable([K2])

        self.W2 = weight_variable([1, C, K2, 1])
        self.b2 = bias_variable([1])

        # initialize savers
        self.saver = tf.train.Saver({
            "W1": self.W1,
            "W11": self.W11,
            "W2": self.W2,
            "b1": self.b1,
            "b11": self.b11,
            "b2": self.b2
        })
Example #16
def params(path_to_config):
    """
    Generate phy's params.py from YASS' config.yaml
    """
    template = load_asset('phy/params.py')
    config = load_yaml(path_to_config)

    timestamp = datetime.now().strftime('%B %-d, %Y at %H:%M')
    dat_path = path.join(config['data']['root_folder'],
                         config['data']['recordings'])
    n_channels_dat = config['recordings']['n_channels']
    dtype = config['recordings']['dtype']
    sample_rate = config['recordings']['sampling_rate']

    params = template.format(timestamp=timestamp,
                             dat_path=dat_path,
                             n_channels_dat=n_channels_dat,
                             dtype=dtype,
                             offset=0,
                             sample_rate=sample_rate,
                             hp_filtered='True')

    return params
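A short usage sketch (both paths are illustrative):

content = params('config.yaml')
with open('params.py', 'w') as f:
    f.write(content)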
Example #17
def make_training_data(CONFIG,
                       spike_train,
                       chosen_templates,
                       min_amp,
                       nspikes,
                       data_folder,
                       noise_ratio=10,
                       collision_ratio=1,
                       misalign_ratio=1,
                       misalign_ratio2=1,
                       multi=True):
    """[Description]
    Parameters
    ----------
    CONFIG: object
        YASS configuration object (as returned by Config.from_yaml)
    spike_train: numpy.ndarray
        [number of spikes, 2] Ground truth for training. First column is the
        spike time, second column is the spike id
    chosen_templates: list
        List of chosen templates' id's
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    nspikes: int
        Number of isolated spikes to generate. This is different from the
        total number of training examples in x_detect
    data_folder: str
        Folder storing the standarized data (if it does not exist, run the
        preprocessor to generate it)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi: bool
        If multi= True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net.
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect

    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Training data for the triage net.
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage


    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    """

    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocessor first to generate it'.format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates
    templates, _ = get_templates(
        np.hstack((spike_train, np.ones((spike_train.shape[0], 1), 'int32'))),
        path_to_data, CONFIG.resources.max_memory, 4 * CONFIG.spike_size)

    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)
    templates_uncropped = np.copy(templates)

    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spike_size,
                               CONFIG.neigh_channels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data, PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_order'],
                                          CONFIG.neigh_channels, CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spike_size
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################
    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))
    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk +
                      min_amp)[:, np.newaxis, np.newaxis]
        x_clean[k * nk:(k + 1) *
                nk] = (tt / amp_now)[np.newaxis, :, :] * amps_range

    #############
    # collision #
    #############
    x_collision = np.zeros((x_clean.shape[0] * int(collision_ratio),
                            templates.shape[1], templates.shape[2]))
    max_shift = 2 * R

    temporal_shifts = np.random.randint(max_shift * 2,
                                        size=x_collision.shape[0]) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    amp_per_data = np.max(x_clean[:, :, 0], axis=1)
    for j in range(x_collision.shape[0]):
        shift = temporal_shifts[j]

        x_collision[j] = np.copy(x_clean[np.random.choice(x_clean.shape[0],
                                                          1,
                                                          replace=True)])
        idx_candidate = np.where(
            amp_per_data > np.max(x_collision[j, :, 0]) * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0],
                                                    size=1)[0]]
        if multi:
            x_clean2 = np.copy(
                x_clean[idx_match]
                [:, np.random.choice(nneigh, nneigh, replace=False)])
        else:
            x_clean2 = np.copy(x_clean[idx_match])

        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += x_clean2[shift:]

        elif shift < 0:
            x_collision[j,
                        (-shift):] += x_clean2[:(x_collision.shape[1] + shift)]
        else:
            x_collision[j] += x_clean2

    ##############################################
    # temporally and spatially misaligned spikes #
    ##############################################
    x_misaligned = np.zeros((x_clean.shape[0] * int(misalign_ratio),
                             templates.shape[1], templates.shape[2]))

    temporal_shifts = np.random.randint(max_shift * 2,
                                        size=x_misaligned.shape[0]) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    for j in range(x_misaligned.shape[0]):
        shift = temporal_shifts[j]
        if multi:
            x_clean2 = np.copy(
                x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)]
                [:, :, np.random.choice(nneigh, nneigh, replace=False)])
            x_clean2 = np.squeeze(x_clean2)
        else:
            x_clean2 = np.copy(x_clean[np.random.choice(x_clean.shape[0],
                                                        1,
                                                        replace=True)])
            x_clean2 = np.squeeze(x_clean2)

        if shift > 0:
            x_misaligned[j, :(x_misaligned.shape[1] -
                              shift)] += x_clean2[shift:]

        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(x_misaligned.shape[1] +
                                                     shift)]
        else:
            x_misaligned[j] += x_clean2

    ###############################
    # spatially misaligned spikes #
    ###############################
    if multi:
        x_misaligned2 = np.zeros((x_clean.shape[0] * int(misalign_ratio2),
                                  templates.shape[1], templates.shape[2]))
        for j in range(x_misaligned2.shape[0]):
            x_misaligned2[j] = np.copy(
                x_clean[np.random.choice(x_clean.shape[0], 1, replace=True)]
                [:, :, np.random.choice(nneigh, nneigh, replace=False)])

    #########
    # noise #
    #########

    # get noise
    noise = np.random.normal(size=[
        x_clean.shape[0] *
        int(noise_ratio), templates.shape[1], templates.shape[2]
    ])
    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [noise.shape[0], x_clean.shape[1], x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_collision.shape[0]))
    y_misaligned = np.zeros((x_misaligned.shape[0]))
    if multi:
        y_misaligned2 = np.zeros((x_misaligned2.shape[0]))
    y_noise = np.zeros((noise.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    if multi:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_collision + noise[np.random.choice(
                 noise.shape[0], x_collision.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))

        x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))
    else:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))
        x_detect = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_detect = np.concatenate((y_clean, y_misaligned, y_noise))

    # get training set for triage
    if multi:
        x = np.concatenate((
            x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
            x_collision + noise[np.random.choice(
                noise.shape[0], x_collision.shape[0], replace=False)],
            x_misaligned2 + noise[np.random.choice(
                noise.shape[0], x_misaligned2.shape[0], replace=False)],
        ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_triage = np.concatenate((y_clean, np.zeros(
            (x_collision.shape[0])), y_misaligned2))
    else:
        x = np.concatenate((
            x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
            x_collision + noise[np.random.choice(
                noise.shape[0], x_collision.shape[0], replace=False)],
        ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_triage = np.concatenate((y_clean, np.zeros((x_collision.shape[0]))))

    ###############
    # Autoencoder #
    ###############

    n_channels = templates_uncropped.shape[2]
    templates_ae = crop_templates(templates_uncropped, CONFIG.spike_size,
                                  np.ones((n_channels, n_channels), 'int32'),
                                  CONFIG.geom)

    tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    tt = tt[:, np.ptp(tt, axis=0) > 2]
    max_amp = np.max(np.ptp(tt, axis=0))

    y_ae = np.zeros((nk * tt.shape[1], tt.shape[0]))
    for k in range(tt.shape[1]):
        amp_now = np.ptp(tt[:, k])
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk +
                      min_amp)[:, np.newaxis, np.newaxis]

        y_ae[k * nk:(k + 1) * nk] = ((tt[:, k] / amp_now)[np.newaxis, :] *
                                     amps_range[:, :, 0])

    noise = np.random.normal(size=y_ae.shape)
    noise = np.matmul(noise, temporal_SIG)

    x_ae = y_ae + noise
    x_ae = x_ae[:, (mid_point - R):(mid_point + R + 1)]
    y_ae = y_ae[:, (mid_point - R):(mid_point + R + 1)]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
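A hypothetical invocation, assuming the preprocessor already produced standarized.bin/standarized.yaml in data_folder and CONFIG is an already-loaded YASS configuration object; all argument values are illustrative:

spike_train = np.load('tmp/spike_train.npy')
(x_detect, y_detect,
 x_triage, y_triage,
 x_ae, y_ae) = make_training_data(CONFIG, spike_train,
                                  chosen_templates=[0, 1, 2, 5],
                                  min_amp=4.0, nspikes=10000,
                                  data_folder='tmp/')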
Example #18
def run(config,
        logger_level='INFO',
        clean=False,
        output_dir='tmp/',
        complete=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)
    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"
    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''
    # preprocess
    start = time.time()
    (standardized_path, standardized_params, whiten_filter) = preprocess.run(
        if_file_exists=CONFIG.preprocess.if_file_exists)

    time_preprocess = time.time() - start
    ''' **********************************************
        ************** DETECT EVENTS *****************
        **********************************************
    '''
    # detect
    # Cat: This code now runs with open tensorflow calls
    start = time.time()
    spike_index_all = detect.run(standardized_path,
                                 standardized_params,
                                 whiten_filter,
                                 if_file_exists=CONFIG.detect.if_file_exists,
                                 save_results=CONFIG.detect.save_results)
    spike_index_clear = None
    time_detect = time.time() - start
    ''' **********************************************
        ***************** CLUSTER ********************
        **********************************************
    '''

    # cluster
    start = time.time()
    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))

    time_cluster = time.time() - start
    ''' **********************************************
        ************** DECONVOLUTION *****************
        **********************************************
    '''

    # run deconvolution
    start = time.time()
    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER,
                                    'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save template
    path_to_templates = path.join(TMP_FOLDER,
                                  'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))
    ''' **********************************************
        ************* POST PROCESSING ****************
        **********************************************
    '''

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info(
            'Saved all scores in {}...'.format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = postdeconv_templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info(
            'Saved all templates scores in {}...'.format(
                path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)', human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster), time_cluster / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
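A minimal invocation sketch ('config.yaml' is illustrative):

spike_train = run('config.yaml', clean=True, output_dir='tmp/')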
Example #19
def _run_pipeline(config,
                  output_file,
                  logger_level='INFO',
                  clean=True,
                  output_dir='tmp/',
                  complete=False):
    """
    Run the entire pipeline given a path to a config file
    and output path
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # run preprocessor
    (score, spike_index_clear,
     spike_index_collision) = preprocess.run(output_directory=output_dir)

    # run processor
    (spike_train_clear, templates,
     spike_index_collision) = process.run(score,
                                          spike_index_clear,
                                          spike_index_collision,
                                          output_directory=output_dir)

    # run deconvolution
    spike_train = deconvolute.run(spike_train_clear,
                                  templates,
                                  spike_index_collision,
                                  output_directory=output_dir)

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')
    shutil.copy2(config, path_to_config_copy)
    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # save templates
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    logging.info('Saving templates in {}'.format(path_to_templates))
    np.save(path_to_templates, templates)

    path_to_spike_train = path.join(TMP_FOLDER, output_file)
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spikeSize,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_format=PARAMS['data_format'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info(
            'Saved all scores in {}...'.format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info(
            'Saved all templates scores in {}...'.format(
                path_to_templates_score))
Example #20
def export(directory, output_dir):
    """
    Generates phy input files; 'yass sort' (with the `--complete` option)
    must be run first to generate all the necessary files
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    TMP_FOLDER = directory

    # verify that the tmp/ folder exists before reading anything from it,
    # otherwise abort
    if not os.path.exists(TMP_FOLDER):
        click.echo("{} directory does not exist, this means you "
                   "haven't run 'yass sort'; run it before running "
                   "'yass export' again...".format(TMP_FOLDER))
        raise click.Abort()

    PATH_TO_CONFIG = path.join(TMP_FOLDER, 'config.yaml')
    CONFIG = load_yaml(PATH_TO_CONFIG)
    ROOT_FOLDER = CONFIG['data']['root_folder']
    N_CHANNELS = CONFIG['recordings']['n_channels']

    if output_dir is None:
        PHY_FOLDER = path.join(TMP_FOLDER, 'phy/')
    else:
        PHY_FOLDER = output_dir

    if not os.path.exists(PHY_FOLDER):
        logger.info('Creating directory: {}'.format(PHY_FOLDER))
        os.makedirs(PHY_FOLDER)

    # TODO: convert data to wide format

    # generate params.py
    params = generate.params(PATH_TO_CONFIG)
    path_to_params = path.join(PHY_FOLDER, 'params.py')

    with open(path_to_params, 'w') as f:
        f.write(params)

    logger.info('Saved {}...'.format(path_to_params))

    # channel_positions.npy
    logger.info('Generating channel_positions.npy')
    path_to_geom = path.join(ROOT_FOLDER, CONFIG['data']['geometry'])
    geom = geometry.parse(path_to_geom, N_CHANNELS)
    path_to_channel_positions = path.join(PHY_FOLDER, 'channel_positions.npy')
    np.save(path_to_channel_positions, geom)
    logger.info('Saved {}...'.format(path_to_channel_positions))

    # channel_map.npy
    channel_map = generate.channel_map(N_CHANNELS)
    path_to_channel_map = path.join(PHY_FOLDER, 'channel_map.npy')
    np.save(path_to_channel_map, channel_map)
    logger.info('Saved {}...'.format(path_to_channel_map))

    # load spike train
    path_to_spike_train = path.join(TMP_FOLDER, 'spike_train.npy')
    logger.info('Loading spike train from {}...'.format(path_to_spike_train))
    spike_train = np.load(path_to_spike_train)
    N_SPIKES, _ = spike_train.shape
    logger.info('Spike train contains {:,} spikes'.format(N_SPIKES))

    # load templates
    logging.info('Loading previously saved templates...')
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    templates = np.load(path_to_templates)
    _, _, N_TEMPLATES = templates.shape

    # pc_features_ind.npy
    path_to_pc_features_ind = path.join(PHY_FOLDER, 'pc_feature_ind.npy')
    ch_neighbors = geometry.find_channel_neighbors
    neigh_channels = ch_neighbors(geom, CONFIG['recordings']['spatial_radius'])

    pc_feature_ind = generate.pc_feature_ind(N_SPIKES, N_TEMPLATES, N_CHANNELS,
                                             geom, neigh_channels, spike_train,
                                             templates)
    np.save(path_to_pc_features_ind, pc_feature_ind)

    # similar_templates.npy
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    path_to_similar_templates = path.join(PHY_FOLDER, 'similar_templates.npy')
    templates = np.load(path_to_templates)
    similar_templates = generate.similar_templates(templates)
    np.save(path_to_similar_templates, similar_templates)
    logger.info('Saved {}...'.format(path_to_similar_templates))

    # spike_templates.npy and spike_times.npy
    path_to_spike_templates = path.join(PHY_FOLDER, 'spike_templates.npy')
    np.save(path_to_spike_templates, spike_train[:, 1])
    logger.info('Saved {}...'.format(path_to_spike_templates))

    path_to_spike_times = path.join(PHY_FOLDER, 'spike_times.npy')
    np.save(path_to_spike_times, spike_train[:, 0])
    logger.info('Saved {}...'.format(path_to_spike_times))

    # template_feature_ind.npy
    path_to_template_feature_ind = path.join(PHY_FOLDER,
                                             'template_feature_ind.npy')
    template_feature_ind = generate.template_feature_ind(
        N_TEMPLATES, similar_templates)
    np.save(path_to_template_feature_ind, template_feature_ind)
    logger.info('Saved {}...'.format(path_to_template_feature_ind))

    # template_features.npy
    templates_score = np.load(path.join(TMP_FOLDER, 'templates_score.npy'))
    templates_main_channel = np.load(
        path.join(TMP_FOLDER, 'templates_main_channel.npy'))
    waveforms_score = np.load(path.join(TMP_FOLDER, 'waveforms_score.npy'))

    path_to_template_features = path.join(PHY_FOLDER, 'template_features.npy')
    template_features = generate.template_features(
        N_SPIKES, N_CHANNELS, N_TEMPLATES, spike_train, templates_main_channel,
        neigh_channels, geom, templates_score, template_feature_ind,
        waveforms_score)
    np.save(path_to_template_features, template_features)
    logger.info('Saved {}...'.format(path_to_template_features))

    # templates.npy
    path_to_phy_templates = path.join(PHY_FOLDER, 'templates.npy')
    np.save(path_to_phy_templates, np.transpose(templates, [2, 1, 0]))
    logging.info(
        'Saved phy-compatible templates in {}'.format(path_to_phy_templates))

    # templates_ind.npy
    templates_ind = generate.templates_ind(N_TEMPLATES, N_CHANNELS)
    path_to_templates_ind = path.join(PHY_FOLDER, 'templates_ind.npy')
    np.save(path_to_templates_ind, templates_ind)
    logger.info('Saved {}...'.format(path_to_templates_ind))

    # whitening_mat.npy and whitening_mat_inv.npy
    logging.info('Generating whitening_mat.npy and whitening_mat_inv.npy...')
    path_to_whitening = path.join(TMP_FOLDER, 'whitening.npy')
    path_to_whitening_mat = path.join(PHY_FOLDER, 'whitening_mat.npy')
    shutil.copy2(path_to_whitening, path_to_whitening_mat)
    logging.info('Saving copy of whitening: {} in {}'.format(
        path_to_whitening, path_to_whitening_mat))

    path_to_whitening_mat_inv = path.join(PHY_FOLDER, 'whitening_mat_inv.npy')
    whitening = np.load(path_to_whitening)
    np.save(path_to_whitening_mat_inv, np.linalg.inv(whitening))
    logger.info('Saving inverse of whitening matrix in {}...'.format(
        path_to_whitening_mat_inv))

    logging.info('Done.')
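generate.channel_map is not shown here; given that phy's channel_map.npy maps rows of the data to probe channels, a plausible identity implementation (an assumption) is:

import numpy as np

def channel_map(n_channels):
    # identity mapping: channel i in the data is channel i on the probe
    return np.arange(n_channels)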
Example #21
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (relative to CONFIG.data.root_folder) to store the
        output data, defaults to tmp/

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standarized.bin`` - Standardized recordings (from preprocess)
    * ``standarized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger and start coloredlogs
    logger = logging.getLogger(__name__)
    coloredlogs.install(logger=logger)

    if set_zero_seed:
        logger.warning('Set numpy seed to zero')
        np.random.seed(0)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # preprocess
    start = time.time()
    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = (preprocess
                       .run(output_directory=output_dir,
                            if_file_exists=CONFIG.preprocess.if_file_exists))
    time_preprocess = time.time() - start

    # detect
    start = time.time()
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter,
                                   output_directory=output_dir,
                                   if_file_exists=CONFIG.detect.if_file_exists,
                                   save_results=CONFIG.detect.save_results)
    time_detect = time.time() - start

    # cluster
    start = time.time()
    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score,
        spike_index_clear,
        output_directory=output_dir,
        if_file_exists=CONFIG.cluster.if_file_exists,
        save_results=CONFIG.cluster.save_results)
    time_cluster = time.time() - start

    # get templates
    start = time.time()
    (templates,
     spike_train_clear_after_templates,
     groups,
     idx_good_templates) = get_templates.run(
        spike_train_clear, tmp_loc,
        output_directory=output_dir,
        if_file_exists=CONFIG.templates.if_file_exists,
        save_results=CONFIG.templates.save_results)
    time_templates = time.time() - start

    # run deconvolution
    start = time.time()
    spike_train = deconvolute.run(spike_index_all, templates,
                                  output_directory=output_dir)
    time_deconvolution = time.time() - start

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(config,
                                                          path_to_config_copy))

    # TODO: complete flag saves other files needed for integrating phy
    # with yass, the integration hasn't been completed yet
    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info(
            'Saved all scores in {}...'.format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_templates
             + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess/total*100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect/total*100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster/total*100)
    logger.info('\t Templates: %s (%.2f %%)',
                human_readable_time(time_templates),
                time_templates/total*100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution/total*100)

    return spike_train
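human_readable_time is used for the timing summary but not defined in these examples; a plausible implementation, offered only as an assumption:

def human_readable_time(seconds):
    # format a duration using the largest convenient unit
    if seconds < 60:
        return '{:.4g} seconds'.format(seconds)
    elif seconds < 3600:
        return '{:.4g} minutes'.format(seconds / 60)
    return '{:.4g} hours'.format(seconds / 3600)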
Example #22
def make_training_data(CONFIG, spike_train, chosen_templates, min_amp, nspikes,
                       data_folder):
    """[Description]

    Parameters
    ----------

    Returns
    -------
    """

    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocessor first to generate it'.format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates
    templates, _ = get_templates(spike_train, path_to_data, CONFIG.spikeSize)

    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)

    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spikeSize,
                               CONFIG.neighChannels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data, PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_format'],
                                          CONFIG.neighChannels, CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spikeSize
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################
    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))
    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk +
                      min_amp)[:, np.newaxis, np.newaxis]
        x_clean[k * nk:(k + 1) *
                nk] = (tt / amp_now)[np.newaxis, :, :] * amps_range

    #############
    # collision #
    #############
    x_collision = np.zeros(x_clean.shape)
    max_shift = 2 * R

    temporal_shifts = np.random.randint(max_shift * 2, size=nk * K) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    amp_per_data = np.max(x_clean[:, :, 0], axis=1)
    for j in range(nk * K):
        shift = temporal_shifts[j]

        x_collision[j] = np.copy(x_clean[j])
        idx_candidate = np.where(amp_per_data > amp_per_data[j] * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0],
                                                    size=1)[0]]
        x_clean2 = np.copy(x_clean[idx_match]
                           [:,
                            np.random.choice(nneigh, nneigh, replace=False)])

        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += x_clean2[shift:]

        elif shift < 0:
            x_collision[j,
                        (-shift):] += x_clean2[:(x_collision.shape[1] + shift)]
        else:
            x_collision[j] += x_clean2

    #####################
    # misaligned spikes #
    #####################
    x_misaligned = np.zeros(x_clean.shape)

    temporal_shifts = np.random.randint(max_shift * 2, size=nk * K) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    for j in range(nk * K):
        shift = temporal_shifts[j]
        x_clean2 = np.copy(
            x_clean[j][:, np.random.choice(nneigh, nneigh, replace=False)])

        if shift > 0:
            x_misaligned[j, :(x_collision.shape[1] -
                              shift)] += x_clean2[shift:]

        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(x_collision.shape[1] +
                                                     shift)]
        else:
            x_misaligned[j] += x_clean2

    #########
    # noise #
    #########

    # get noise
    noise = np.random.normal(size=x_clean.shape)
    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [x_clean.shape[0], x_clean.shape[1], x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_clean.shape[0]))
    y_misaligned = np.zeros((x_clean.shape[0]))
    y_noise = np.zeros((x_clean.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    x = np.concatenate(
        (x_clean + noise,
         x_collision + noise[np.random.permutation(noise.shape[0])],
         x_misaligned + noise[np.random.permutation(noise.shape[0])], noise))

    x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))

    # get training set for triage
    x = np.concatenate((
        x_clean + noise,
        x_collision + noise[np.random.permutation(noise.shape[0])],
    ))
    x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_triage = np.concatenate((y_clean, np.zeros((x_clean.shape[0]))))

    # get training set for the autoencoder
    ae_shift_max = 1
    temporal_shifts_ae = np.random.randint(
        ae_shift_max * 2 + 1, size=x_clean.shape[0]) - ae_shift_max
    y_ae = np.zeros((x_clean.shape[0], 2 * R + 1))
    x_ae = np.zeros((x_clean.shape[0], 2 * R + 1))
    for j in range(x_ae.shape[0]):
        y_ae[j] = x_clean[j, (mid_point - R +
                              temporal_shifts_ae[j]):(mid_point + R + 1 +
                                                      temporal_shifts_ae[j]),
                          0]
        x_ae[j] = x_clean[j, (mid_point - R + temporal_shifts_ae[j]):(
            mid_point + R + 1 + temporal_shifts_ae[j]), 0] + noise[j, (
                mid_point - R + temporal_shifts_ae[j]):(
                    mid_point + R + 1 + temporal_shifts_ae[j]), 0]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae