Example #1
def launch(task,
           filename,
           nb_cpu,
           nb_gpu,
           use_gpu,
           output=None,
           benchmark=None,
           extension='',
           sim_same_elec=None):

    from circus.shared.parser import CircusParser
    params = CircusParser(filename)

    if task not in ['filtering', 'benchmarking']:
        params.get_data_file()

    module = importlib.import_module('circus.' + task)

    if task == 'benchmarking':
        module.main(params, nb_cpu, nb_gpu, use_gpu, output, benchmark,
                    sim_same_elec)
    elif task in ['converting', 'merging']:
        module.main(params, nb_cpu, nb_gpu, use_gpu, extension)
    else:
        module.main(params, nb_cpu, nb_gpu, use_gpu)
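For orientation, here is a hedged sketch of how this dispatcher could be invoked; the recording path and the CPU/GPU counts are placeholders, not values taken from this page.

# Hypothetical call: run the whitening step on a placeholder recording,
# using 2 CPUs and no GPU.
launch('whitening', '/data/recording.dat', nb_cpu=2, nb_gpu=0, use_gpu=False)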
class TestConverting(unittest.TestCase):
    def setUp(self):
        self.all_spikes = None
        self.max_chunk = '100'
        dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name = os.path.join(self.path, 'fitting.dat')
        self.source_dataset = get_dataset(self)
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                       self.file_name, 'fitting', 1)
            mpi_launch('whitening', self.file_name, 2, 0, 'False')
            self.parser = CircusParser(self.file_name)
            self.parser.write('fitting', 'max_chunk', '10')
            mpi_launch('fitting', self.file_name, 2, 0, 'False')
        else:
            self.parser = CircusParser(self.file_name)

    def test_converting_some(self):
        self.parser.write('converting', 'export_pcs', 'some')
        mpi_launch('converting', self.file_name, 1, 0, 'False')
        self.parser.write('converting', 'export_pcs', 'prompt')

    def test_converting_all(self):
        self.parser.write('converting', 'export_pcs', 'all')
        mpi_launch('converting', self.file_name, 2, 0, 'False')
        self.parser.write('converting', 'export_pcs', 'prompt')
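These test classes are presumably collected by the standard unittest runner, for example with the usual module guard:

if __name__ == '__main__':
    unittest.main()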
Example #3
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to split results obtained when using N streams
into N individual results files, one per stream
    '''
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for slicing results',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    if extension != '':
        extension = '-' + extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)
    file_out_suff = params.get('data', 'file_out_suff')

    if params.get('data', 'stream_mode') in ['None', 'none']:
        print_and_log(['No streams in the datafile!'], 'error', logger)
        sys.exit(1)

    data_file = params.get_data_file()
    result = circus.shared.files.get_results(params, extension=extension)
    times = []
    for source in data_file._sources:
        times += [[source.t_start, source.t_stop]]

    sub_results = slice_result(result, times)

    for count, result in enumerate(sub_results):
        keys = ['spiketimes', 'amplitudes']
        mydata = h5py.File(file_out_suff + '.result%s_%d.hdf5' %
                           (extension, count),
                           'w',
                           libver='earliest')
        for key in keys:
            mydata.create_group(key)
            for temp in result[key].keys():
                tmp_path = '%s/%s' % (key, temp)
                mydata.create_dataset(tmp_path, data=result[key][temp])
        mydata.close()
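As a quick sanity check, one of the per-stream files written above can be reopened with h5py; the path below is a placeholder that follows the file_out_suff + '.result<extension>_<count>.hdf5' naming used in the loop (empty extension, first stream).

import h5py

# Hedged sketch: list the spike-time datasets of the first stream's result file.
with h5py.File('recording/recording.result_0.hdf5', 'r') as stream_result:
    print(list(stream_result['spiketimes'].keys()))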
Example #4
def setUp(self):
    self.all_spikes = None
    self.max_chunk = '100'
    dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
    self.path = os.path.join(dirname, 'synthetic')
    if not os.path.exists(self.path):
        os.makedirs(self.path)
    self.file_name = os.path.join(self.path, 'fitting.dat')
    self.source_dataset = get_dataset(self)
    if not os.path.exists(self.file_name):
        mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                   self.file_name, 'fitting')
        mpi_launch('whitening', self.file_name, 2, 0, 'False')
    self.parser = CircusParser(self.file_name)
def setUp(self):
    self.all_matches    = None
    self.all_templates  = None
    dirname             = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
    self.path           = os.path.join(dirname, 'synthetic')
    if not os.path.exists(self.path):
        os.makedirs(self.path)
    self.file_name      = os.path.join(self.path, 'smart_search.dat')
    self.source_dataset = get_dataset(self)
    if not os.path.exists(self.file_name):
        mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False', self.file_name, 'smart-search', 1)
        mpi_launch('whitening', self.file_name, 2, 0, 'False')
    self.parser = CircusParser(self.file_name)
    self.parser.write('clustering', 'max_elts', '2000')
def setUp(self):
    dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
    self.path = os.path.join(dirname, 'synthetic')
    if not os.path.exists(self.path):
        os.makedirs(self.path)
    self.file_name = os.path.join(self.path, 'whitening.dat')
    self.source_dataset = get_dataset(self)
    self.whitening = None
    if not os.path.exists(self.file_name):
        mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                   self.file_name, 'fitting', 1)
    self.params = CircusParser(self.file_name)
    self.params.write('clustering', 'max_elts', '1000')
    self.params.write('whitening', 'spatial', 'True')
    self.params.write('clustering', 'temporal', 'False')
Example #7
def setUp(self):
    dirname             = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
    self.path           = os.path.join(dirname, 'synthetic')
    if not os.path.exists(self.path):
        os.makedirs(self.path)
    self.file_name      = os.path.join(self.path, 'drifts.dat')
    self.source_dataset = get_dataset(self)
    if not os.path.exists(self.file_name):
        mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False', self.file_name, 'drifts', 1)
        mpi_launch('whitening', self.file_name, 2, 0, 'False')
    self.parser = CircusParser(self.file_name)
class TestSmartSearch(unittest.TestCase):

    def setUp(self):
        self.all_matches    = None
        self.all_templates  = None
        dirname             = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path           = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name      = os.path.join(self.path, 'smart_search.dat')
        self.source_dataset = get_dataset(self)
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False', self.file_name, 'smart-search', 1)
            mpi_launch('whitening', self.file_name, 2, 0, 'False')
        self.parser = CircusParser(self.file_name)
        self.parser.write('clustering', 'max_elts', '2000')

    #def tearDown(self):
    #    data_path = '.'.join(self.file_name.split('.')[:-1])
    #    shutil.rmtree(data_path)

    def test_smart_search_on(self):
        self.parser.write('clustering', 'smart_search', 'True')
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        self.parser.write('clustering', 'smart_search', 'False')
        res = get_performance(self.file_name, 'smart_search_on')

    def test_smart_search_off(self):
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        res = get_performance(self.file_name, 'smart_search_off')
Example #9
class TestValidating(unittest.TestCase):
    def setUp(self):
        self.all_spikes = None
        self.max_chunk = '100'
        dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name = os.path.join(self.path, 'fitting.dat')
        self.source_dataset = get_dataset(self)
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                       self.file_name, 'fitting', 1)
            mpi_launch('whitening', self.file_name, 2, 0, 'False')
        self.parser = CircusParser(self.file_name)
        self.length = self.parser.get_data_file().duration

    def test_validating(self):
        #mpi_launch('fitting', self.file_name, 2, 0, 'False')

        a, b = os.path.splitext(os.path.basename(self.file_name))
        file_name, ext = os.path.splitext(self.file_name)
        file_out = os.path.join(os.path.abspath(file_name), a)
        result_name = os.path.join(file_name, 'injected')
        spikes = {}
        result = h5py.File(os.path.join(result_name, '%s.result.hdf5' % a), 'r')
        for key in result.get('spiketimes').keys():
            spikes[key] = result.get('spiketimes/%s' % key)[:]

        juxta_file = file_out + '.juxta.dat'

        f = numpy.memmap(juxta_file,
                         shape=(self.length, 1),
                         dtype=self.parser.get('validating', 'juxta_dtype'),
                         mode='w+')
        f[spikes['temp_9']] = 100
        del f

        mpi_launch('validating', self.file_name, 2, 0, 'False')
class TestWhitening(unittest.TestCase):
    def setUp(self):
        dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name = os.path.join(self.path, 'whitening.dat')
        self.source_dataset = get_dataset(self)
        self.whitening = None
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                       self.file_name, 'fitting', 1)
        self.params = CircusParser(self.file_name)
        self.params.write('clustering', 'max_elts', '1000')
        self.params.write('whitening', 'spatial', 'True')
        self.params.write('clustering', 'temporal', 'False')

    def test_whitening_one_CPU(self):
        mpi_launch('whitening', self.file_name, 1, 0, 'False')
        res = get_performance(self.file_name, 'one_CPU')
        if self.whitening is None:
            self.whitening = res
        assert ((res['spatial'] - self.whitening['spatial'])**2).mean() < 0.1

    def test_whitening_two_CPU(self):
        mpi_launch('whitening', self.file_name, 2, 0, 'False')
        res = get_performance(self.file_name, 'two_CPU')
        if self.whitening is None:
            self.whitening = res
        assert ((res['spatial'] - self.whitening['spatial'])**2).mean() < 0.1

    def test_whitening_safety_time(self):
        self.params.write('clustering', 'safety_time', '5')
        mpi_launch('whitening', self.file_name, 1, 0, 'False')
        self.params.write('clustering', 'safety_time', 'auto')
        res = get_performance(self.file_name, 'safety_time')
        if self.whitening is None:
            self.whitening = res
        assert ((res['spatial'] - self.whitening['spatial'])**2).mean() < 0.1
Example #11
class TestGarbage(unittest.TestCase):
    def setUp(self):
        self.all_spikes = None
        self.max_chunk = '100'
        dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name = os.path.join(self.path, 'fitting.dat')
        self.source_dataset = get_dataset(self)
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                       self.file_name, 'fitting')
            mpi_launch('whitening', self.file_name, 2, 0, 'False')
        self.parser = CircusParser(self.file_name)

    def test_collect_all(self):
        self.parser.write('fitting', 'max_chunk', self.max_chunk)
        self.parser.write('fitting', 'collect_all', 'True')
        mpi_launch('fitting', self.file_name, 1, 0, 'False')
        self.parser.write('fitting', 'max_chunk', 'inf')
        self.parser.write('fitting', 'collect_all', 'False')
        ctruth, cspikes, cgarbage = get_performance(self.file_name)
        assert cgarbage < cspikes
"""Inspect basis waveforms."""
import argparse
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os

from circus.shared.parser import CircusParser

# Parse arguments.
parser = argparse.ArgumentParser(description="Inspect basis waveforms.")
parser.add_argument('datafile', help="data file")
args = parser.parse_args()

# Load parameters.
params = CircusParser(args.datafile)
_ = params.get_data_file()
sampling_rate = params.rate
file_out_suff = params.get('data', 'file_out_suff')
nb_channels = params.getint('data', 'N_e')
nb_time_steps = params.getint('detection', 'N_t')

# Load basis waveforms.
basis_path = "{}.basis.hdf5".format(file_out_suff)
if not os.path.isfile(basis_path):
    raise FileNotFoundError(basis_path)
with h5py.File(basis_path, mode='r', libver='earliest') as basis_file:
    if 'proj' not in basis_file:
        raise RuntimeError(
            "No projection matrix found in {}.".format(basis_path))
    projection_matrix = basis_file['proj'][:]
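A minimal follow-up sketch, not part of the original script, that displays the loaded projection matrix using the matplotlib import already available above:

# Hedged visualization of the PCA projection matrix loaded above.
fig, ax = plt.subplots()
image = ax.imshow(projection_matrix, aspect='auto')
ax.set_xlabel("principal component")
ax.set_ylabel("time step")
fig.colorbar(image, ax=ax, label="weight")
plt.show()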
def main(params, nb_cpu, nb_gpu, use_gpu, file_name, benchmark, sim_same_elec):
    """
    Useful tool to create synthetic datasets for benchmarking.
    
    Arguments
    ---------
    benchmark : {'fitting', 'clustering', 'synchrony', 'pca-validation', 'smart-search', 'drifts'}
        
    """
    if sim_same_elec is None:
        sim_same_elec = 0.8

    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.benchmarking')

    numpy.random.seed(265)
    file_name      = os.path.abspath(file_name)
    data_path      = os.path.dirname(file_name)
    data_suff, ext = os.path.splitext(os.path.basename(file_name))
    file_out, ext  = os.path.splitext(file_name)

    if ext == '':
        ext = '.dat'
        file_name += ext
    
    if ext != '.dat':
        if comm.rank == 0:
            print_and_log(['Benchmarking produces raw files: select a .dat extension'], 'error', logger)
        sys.exit(0)

    if benchmark not in ['fitting', 'clustering', 'synchrony', 'smart-search', 'drifts']:
        if comm.rank == 0:
            print_and_log(['Benchmark should be one of [fitting, clustering, synchrony, smart-search, drifts]'], 'error', logger)
        sys.exit(0)

    # The extension `.p` or `.pkl` or `.pickle` seems more appropriate than `.pic`.
    # see: http://stackoverflow.com/questions/4530111/python-saving-objects-and-using-pickle-extension-of-filename
    # see: https://wiki.python.org/moin/UsingPickle
    def write_benchmark(filename, benchmark, cells, rates, amplitudes, sampling, probe, trends=None):
        """Save benchmark parameters in a file to remember them."""
        import cPickle
        to_write = {'benchmark' : benchmark}
        to_write['cells']      = cells
        to_write['rates']      = rates
        to_write['probe']      = probe
        to_write['amplitudes'] = amplitudes
        to_write['sampling']   = sampling
        if benchmark == 'drifts':
            to_write['drifts'] = trends
        with open(filename + '.pic', 'wb') as pic_file:
            cPickle.dump(to_write, pic_file)

    # Retrieve some key parameters.
    templates = io.load_data(params, 'templates')
    N_tm = templates.shape[1] // 2
    trends          = None

    # Normalize some variables.
    if benchmark == 'fitting':
        nb_insert       = 25
        n_cells         = numpy.random.random_integers(0, N_tm - 1, nb_insert)
        rate            = nb_insert * [10]
        amplitude       = numpy.linspace(0.5, 5, nb_insert)
    if benchmark == 'clustering':
        n_point         = 5
        n_cells         = numpy.random.random_integers(0, N_tm - 1, n_point ** 2)
        x, y            = numpy.mgrid[0:n_point, 0:n_point]
        rate            = numpy.linspace(0.5, 20, n_point)[x.flatten()]
        amplitude       = numpy.linspace(0.5, 5, n_point)[y.flatten()]
    if benchmark == 'synchrony':
        nb_insert       = 5
        corrcoef        = 0.2
        n_cells         = nb_insert * [numpy.random.random_integers(0, N_tm - 1, 1)[0]]
        rate            = 10. / corrcoef
        amplitude       = 2
    if benchmark == 'pca-validation':
        nb_insert       = 10
        n_cells         = numpy.random.random_integers(0, N_tm - 1, nb_insert)
        rate_min        = 0.5
        rate_max        = 20.0
        rate            = rate_min + (rate_max - rate_min) * numpy.random.random_sample(nb_insert)
        amplitude_min   = 0.5
        amplitude_max   = 5.0
        amplitude       = amplitude_min + (amplitude_max - amplitude_min) * numpy.random.random_sample(nb_insert)
    if benchmark == 'smart-search':
        nb_insert       = 10
        n_cells         = nb_insert*[numpy.random.random_integers(0, templates.shape[1]//2-1, 1)[0]]
        rate            = 1 + 5*numpy.arange(nb_insert)
        amplitude       = 2
    if benchmark == 'drifts':
        n_point         = 5
        n_cells         = numpy.random.random_integers(0, templates.shape[1]//2-1, n_point**2)
        x, y            = numpy.mgrid[0:n_point,0:n_point]
        rate            = 5*numpy.ones(n_point)[x.flatten()]
        amplitude       = numpy.linspace(0.5, 5, n_point)[y.flatten()]
        trends          = numpy.random.randn(n_point**2)

    # Delete the output directory tree if this output directory exists.
    if comm.rank == 0:
        if os.path.exists(file_out):
            shutil.rmtree(file_out)

    # Check and normalize some variables.
    if n_cells is None:
        n_cells    = 1
        cells      = [numpy.random.permutation(numpy.arange(n_cells))[0]]
    elif not numpy.iterable(n_cells):
        cells      = [n_cells]
        n_cells    = 1
    else:
        cells      = n_cells
        n_cells    = len(cells)

    if numpy.iterable(rate):
        assert len(rate) == len(cells), "Should have the same number of rates and cells"
    else:
        rate = [rate] * len(cells)

    if numpy.iterable(amplitude):
        assert len(amplitude) == len(cells), "Should have the same number of amplitudes and cells"
    else:
        amplitude = [amplitude] * len(cells)

    # Retrieve some additional key parameters.
    #params           = detect_memory(params)
    data_file        = params.get_data_file(source=True)
    N_e              = params.getint('data', 'N_e')
    N_total          = params.nb_channels
    hdf5_compress    = params.getboolean('data', 'hdf5_compress')
    nodes, edges     = get_nodes_and_edges(params)
    N_t              = params.getint('detection', 'N_t')
    inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    N_tm_init             = templates.shape[1]//2
    thresholds            = io.load_data(params, 'thresholds')
    limits                = io.load_data(params, 'limits')
    best_elecs            = io.load_data(params, 'electrodes')
    norms                 = io.load_data(params, 'norm-templates')

    # Create output directory if it does not exist.
    if comm.rank == 0:
        if not os.path.exists(file_out):
            os.makedirs(file_out)

    # Save benchmark parameters in a file to remember them.
    if comm.rank == 0:
        write_benchmark(file_out, benchmark, cells, rate, amplitude,
                        params.rate, params.get('data', 'mapping'), trends)

    # Synchronize all the threads/processes.
    comm.Barrier()

    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    # Retrieve some additional key parameters.
    chunk_size     = params.getint('data', 'chunk_size')
    scalings       = []
    
    params.set('data', 'data_file', file_name)

    data_file_out = params.get_data_file(is_empty=True)
    data_file_out.allocate(shape=data_file.shape)

    # Synchronize all the threads/processes.
    comm.Barrier()

    # For each synthesized cell, insert a generated template into the set of
    # existing templates.
    for gcount, cell_id in enumerate(cells):
        best_elec   = best_elecs[cell_id]
        indices     = inv_nodes[edges[nodes[best_elec]]]
        count       = 0
        new_indices = []
        all_elecs   = numpy.random.permutation(numpy.arange(N_e))
        reference   = templates[:, cell_id].toarray().reshape(N_e, N_t)
        # Initialize the similarity (i.e. default value).
        similarity = 1.0
        # Find the first eligible template for the wanted synthesized cell.
        while len(new_indices) != len(indices) or (similarity > sim_same_elec): 
            similarity  = 0
            if count == len(all_elecs):
                if comm.rank == 0:
                    print_and_log(["No electrode to move template %d (max similarity is %g)" %(cell_id, similarity)], 'error', logger)
                sys.exit(0)
            else:
                # Get the next shuffled electrode.
                n_elec = all_elecs[count]

                if benchmark not in ['synchrony', 'smart-search']:
                    # Process if the shuffled electrode and the nearest electrode
                    # to the synthesized cell are not identical.
                    local_test = n_elec != best_elec
                else:
                    # Process if the shuffled electrode and the nearest electrode
                    # to the synthesized cell are identical.
                    local_test = n_elec == best_elec

                if local_test:
                    # Shuffle the neighboring electrodes without modifying
                    # the nearest electrode to the synthesized cell.
                    new_indices = inv_nodes[edges[nodes[n_elec]]]
                    idx = numpy.where(new_indices != best_elec)[0]
                    new_indices[idx] = numpy.random.permutation(new_indices[idx])

                    if len(new_indices) == len(indices):
                        # Shuffle the templates on the neighboring electrodes.
                        new_temp = numpy.zeros(reference.shape,
                                               dtype=numpy.float32)
                        new_temp[new_indices, :] = reference[indices, :]
                        # Compute the scaling factor that normalizes the
                        # shuffled template.
                        gmin = new_temp.min()
                        data = numpy.where(new_temp == gmin)
                        scaling = -thresholds[data[0][0]]/gmin
                        for i in xrange(templates.shape[1]//2):
                            match = templates[:, i].toarray().reshape(N_e, N_t)
                            d = numpy.corrcoef(match.flatten(),
                                               scaling * new_temp.flatten())[0, 1]
                            if d > similarity:
                                similarity = d
                else:
                    new_indices = []
            # Go to the next shuffled electrode.
            count += 1

        #if comm.rank == 0:
        #    print "Template", cell_id, "is shuffled from electrode", best_elec, "to", n_elec, "(max similarity is %g)" %similarity

        N_tm           = templates.shape[1]//2
        to_insert      = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert[new_indices] = scaling*amplitude[gcount]*templates[:, cell_id].toarray().reshape(N_e, N_t)[indices]
        to_insert2     = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert2[new_indices] = scaling*amplitude[gcount]*templates[:, cell_id + N_tm].toarray().reshape(N_e, N_t)[indices]

        ## Insert the selected template.
        
        # Retrieve the number of existing templates in the dataset.
        N_tm           = templates.shape[1]//2

        # Generate the template of the synthesized cell from the selected
        # template, the target amplitude and the rescaling (i.e. threshold of
        # the target electrode).
        to_insert = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert[new_indices] = scaling * amplitude[gcount] * templates[:, cell_id].toarray().reshape(N_e, N_t)[indices]
        to_insert = to_insert.flatten()
        to_insert2 = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert2[new_indices] = scaling * amplitude[gcount] * templates[:, cell_id + N_tm].toarray().reshape(N_e, N_t)[indices]
        to_insert2 = to_insert2.flatten()

        # Compute the norm of the generated template.
        mynorm     = numpy.sqrt(numpy.sum(to_insert ** 2) / (N_e * N_t))
        mynorm2    = numpy.sqrt(numpy.sum(to_insert2 ** 2) / (N_e * N_t))

        # Insert the limits of the generated template.
        limits     = numpy.vstack((limits, limits[cell_id]))
        # Insert the best electrode of the generated template.
        best_elecs = numpy.concatenate((best_elecs, [n_elec]))

        # Insert the norm of the generated template (i.e. central component and
        # orthogonal component).
        norms      = numpy.insert(norms, N_tm, mynorm)
        norms      = numpy.insert(norms, 2 * N_tm + 1, mynorm2)
        # Insert the scaling of the generated template.
        scalings  += [scaling]

        # Retrieve the data about the existing templates.
        templates = templates.tocoo()
        xdata     = templates.row
        ydata     = templates.col
        zdata     = templates.data

        # Shift by one the orthogonal components of the existing templates.
        idx       = numpy.where(ydata >= N_tm)[0]
        ydata[idx] += 1

        # Insert the central component of the selected template.
        dx    = to_insert.nonzero()[0].astype(numpy.int32)
        xdata = numpy.concatenate((xdata, dx))
        ydata = numpy.concatenate((ydata, N_tm * numpy.ones(len(dx), dtype=numpy.int32)))
        zdata = numpy.concatenate((zdata, to_insert[dx]))

        # Insert the orthogonal component of the selected template.
        dx    = to_insert2.nonzero()[0].astype(numpy.int32)
        xdata = numpy.concatenate((xdata, dx))
        ydata = numpy.concatenate((ydata, (2 * N_tm + 1) * numpy.ones(len(dx), dtype=numpy.int32)))
        zdata = numpy.concatenate((zdata, to_insert2[dx]))

        # Reconstruct the matrix of templates.
        templates = scipy.sparse.csc_matrix((zdata, (xdata, ydata)), shape=(N_e * N_t, 2 * (N_tm + 1)))

    # Remove all the expired data.
    if benchmark == 'pca-validation':
        # Remove all the expired data.
        N_tm_init = 0
        N_tm = templates.shape[1] // 2

        limits = limits[N_tm - nb_insert:, :]
        best_elecs = best_elecs[N_tm - nb_insert:]
        norms = numpy.concatenate((norms[N_tm-nb_insert:N_tm], norms[2*N_tm-nb_insert:2*N_tm]))
        scalings = scalings
        
        templates = templates.tocoo()
        xdata = templates.row
        ydata = templates.col
        zdata = templates.data
        
        idx_cen = numpy.logical_and(N_tm - nb_insert <= ydata, ydata < N_tm)
        idx_cen = numpy.where(idx_cen)[0]
        idx_ort = numpy.logical_and(2 * N_tm - nb_insert <= ydata, ydata < 2 * N_tm)
        idx_ort = numpy.where(idx_ort)[0]
        ydata[idx_cen] = ydata[idx_cen] - (N_tm - nb_insert)
        ydata[idx_ort] = ydata[idx_ort] - 2 * (N_tm - nb_insert)
        idx = numpy.concatenate((idx_cen, idx_ort))
        xdata = xdata[idx]
        ydata = ydata[idx]
        zdata = zdata[idx]
        templates = scipy.sparse.csc_matrix((zdata, (xdata, ydata)), shape=(N_e * N_t, 2 * nb_insert))
        
    # Retrieve the information about the organisation of the chunks of data.
    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Display information about the generated benchmark.
    if comm.rank == 0:
        print_and_log(["Generating benchmark data [%s] with %d cells" %(benchmark, n_cells)], 'info', logger)
        purge(file_out, '.data')


    template_shift = params.getint('detection', 'template_shift')
    all_chunks     = numpy.arange(nb_chunks)
    to_process     = all_chunks[numpy.arange(comm.rank, nb_chunks, comm.size)]
    loc_nb_chunks  = len(to_process)
    numpy.random.seed(comm.rank)

    to_explore = xrange(comm.rank, nb_chunks, comm.size)

    # Initialize the progress bar about the generation of the benchmark.
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    # Open the file for collective I/O.
    #g = myfile.Open(comm, file_name, MPI.MODE_RDWR)
    #g.Set_view(data_offset, data_mpi, data_mpi)
    data_file_out.open(mode='r+')

    # Open the thread/process' files to collect the results.
    spiketimes_filename = os.path.join(file_out, data_suff + '.spiketimes-%d.data' %comm.rank)
    spiketimes_file = open(spiketimes_filename, 'wb')
    amplitude_filename = os.path.join(file_out, data_suff + '.amplitudes-%d.data' %comm.rank)
    amplitudes_file = open(amplitude_filename, 'wb')
    templates_filename = os.path.join(file_out, data_suff + '.templates-%d.data' %comm.rank)
    templates_file = open(templates_filename, 'wb')
    real_amps_filename = os.path.join(file_out, data_suff + '.real_amps-%d.data' %comm.rank)
    real_amps_file = open(real_amps_filename, 'wb')
    voltages_filename = os.path.join(file_out, data_suff + '.voltages-%d.data' %comm.rank)
    voltages_file = open(voltages_filename, 'wb')

    # For each chunk of data associated with the current thread/process, generate
    # the new chunk of data (i.e. taking the added synthesized cells into account).
    for count, gidx in enumerate(to_explore):

        #if (last_chunk_len > 0) and (gidx == (nb_chunks - 1)):
        #    chunk_len  = last_chunk_len
        #    chunk_size = last_chunk_len // N_total

        result         = {'spiketimes' : [], 'amplitudes' : [], 
                          'templates' : [], 'real_amps' : [],
                          'voltages' : []}
        offset         = gidx * chunk_size
        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)

        if benchmark == 'pca-validation':
            # Clear the current data chunk.
            local_chunk = numpy.zeros(local_chunk.shape, dtype=local_chunk.dtype)

        # Handle whitening if necessary.
        if do_spatial_whitening:
            local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        if benchmark == 'synchrony':
            # Generate some spike indices (i.e. times) at the given rate for
            # 'synchrony' mode. Each synthesized cell will use a subset of this
            # spike times.
            mips = numpy.random.rand(chunk_size) < rate[0] / float(params.rate)

        # For each synthesized cell generate its spike indices (i.e. times) and
        # add them to the dataset.
        for idx in xrange(len(cells)):
            if benchmark == 'synchrony':
                # Choose a subset of the spike indices generated before. The
                # size of this subset is parameterized by the target correlation
                # coefficients.
                sidx       = numpy.where(mips == True)[0]
                spikes     = numpy.zeros(chunk_size, dtype=numpy.bool)
                spikes[sidx[numpy.random.rand(len(sidx)) < corrcoef]] = True
            else:
                # Generate some spike indices at the given rate.
                spikes     = numpy.random.rand(chunk_size) < rate[idx] / float(params.rate)
            if benchmark == 'drifts':
                amplitudes = numpy.ones(len(spikes)) + trends[idx]*((spikes + offset)/(5*60*float(params.rate)))
            else:
                amplitudes = numpy.ones(len(spikes))
            # Padding with `False` to avoid the insertion of partial spikes at
            # the edges of the signal.
            spikes[:N_t]   = False
            spikes[-N_t:]  = False
            # Find the indices of the spike samples.
            spikes         = numpy.where(spikes == True)[0]
            n_template     = N_tm_init + idx
            loc_template   = templates[:, n_template].toarray().reshape(N_e, N_t)
            first_flat     = loc_template.T.flatten()
            norm_flat      = numpy.sum(first_flat ** 2)
            # For each index (i.e. spike sample location) add the spike to the
            # chunk of data.
            refractory     = int(5 * 1e-3 * params.rate)         
            t_last         = - refractory
            for scount, spike in enumerate(spikes):
                if (spike - t_last) > refractory:
                    local_chunk[spike-template_shift:spike+template_shift+1, :] += amplitudes[scount]*loc_template.T
                    amp        = numpy.dot(local_chunk[spike-template_shift:spike+template_shift+1, :].flatten(), first_flat)
                    amp       /= norm_flat
                    result['real_amps']  += [amp]
                    result['spiketimes'] += [spike + offset]
                    result['amplitudes'] += [(amplitudes[scount], 0)]
                    result['templates']  += [n_template]
                    result['voltages']   += [local_chunk[spike, best_elecs[idx]]]
                    t_last                = spike

        # Write the results into the thread/process' files.
        spikes_to_write     = numpy.array(result['spiketimes'], dtype=numpy.uint32)
        amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
        templates_to_write  = numpy.array(result['templates'], dtype=numpy.int32)
        real_amps_to_write  = numpy.array(result['real_amps'], dtype=numpy.float32)
        voltages_to_write   = numpy.array(result['voltages'], dtype=numpy.float32)

        spiketimes_file.write(spikes_to_write.tostring())   
        amplitudes_file.write(amplitudes_to_write.tostring())
        templates_file.write(templates_to_write.tostring())
        real_amps_file.write(real_amps_to_write.tostring())
        voltages_file.write(voltages_to_write.tostring())

        #print count, 'spikes inserted...'
        #new_chunk    = numpy.zeros((chunk_size, N_total), dtype=numpy.float32)
        #new_chunk[:, nodes] = local_chunk

        # Overwrite the new chunk of data using explicit offset. 
        #new_chunk   = new_chunk.flatten()
        #g.Write_at(gidx * chunk_len, new_chunk)
        data_file_out.set_data(offset, local_chunk)

        # Update the progress bar about the generation of the benchmark.
        
    # Close the thread/process' files.
    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    real_amps_file.flush()
    os.fsync(real_amps_file.fileno())
    real_amps_file.close()

    voltages_file.flush()
    os.fsync(voltages_file.fileno())
    voltages_file.close()


    # Close the file for collective I/O.
    data_file_out.close()
    data_file.close()

    
    # Synchronize all the threads/processes.
    comm.Barrier()

    
    ## Eventually, perform all the administrative tasks.
    ## (i.e. files and folders management).

    file_params = file_out + '.params'

    if comm.rank == 0:
        # Create `injected` directory if it does not exist
        result_path = os.path.join(file_out, 'injected') 
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        # Copy initial configuration file from `<dataset1>.params` to `<dataset2>.params`.
        shutil.copy2(params.get('data', 'data_file_noext') + '.params', file_params)
        new_params = CircusParser(file_name)
        # Copy initial basis file from `<dataset1>/<dataset1>.basis.hdf5` to
        # `<dataset2>/injected/<dataset2>.basis.hdf5.
        shutil.copy2(params.get('data', 'file_out') + '.basis.hdf5',
                     os.path.join(result_path, data_suff + '.basis.hdf5'))


        # Save templates into `<dataset>/<dataset>.templates.hdf5`.
        mydata = h5py.File(os.path.join(file_out, data_suff + '.templates.hdf5'), 'w')
        templates = templates.tocoo()
        if hdf5_compress:
            mydata.create_dataset('temp_x', data=templates.row, compression='gzip')
            mydata.create_dataset('temp_y', data=templates.col, compression='gzip')
            mydata.create_dataset('temp_data', data=templates.data, compression='gzip')
        else:
            mydata.create_dataset('temp_x', data=templates.row)
            mydata.create_dataset('temp_y', data=templates.col)
            mydata.create_dataset('temp_data', data=templates.data)
        mydata.create_dataset('temp_shape', data=numpy.array([N_e, N_t, templates.shape[1]],
                                                             dtype=numpy.int32))
        mydata.create_dataset('limits', data=limits)
        mydata.create_dataset('norms', data=norms)
        mydata.close()

        # Save electrodes into `<dataset>/<dataset>.clusters.hdf5`.
        mydata = h5py.File(os.path.join(file_out, data_suff + '.clusters.hdf5'), 'w')
        mydata.create_dataset('electrodes', data=best_elecs)
        mydata.close()

    comm.Barrier()
    if comm.rank == 0:
        # Gather data from all threads/processes.
        f_next, extension = os.path.splitext(file_name)
        file_out_bis = os.path.join(f_next, os.path.basename(f_next))
        #new_params.set('data', 'file_out', file_out_bis) # Output file without suffix
        #new_params.set('data', 'file_out_suff', file_out_bis  + params.get('data', 'suffix'))
    
        new_params.get_data_file()
        io.collect_data(comm.size, new_params, erase=True, with_real_amps=True, with_voltages=True, benchmark=True)
        # Change some flags in the configuration file.
        new_params.write('whitening', 'temporal', 'False') # Disable temporal filtering
        new_params.write('whitening', 'spatial', 'False') # Disable spatial filtering
        new_params.write('data', 'data_dtype', 'float32') # Set type of the data to float32
        new_params.write('data', 'dtype_offset', 'auto') # Set padding for data to auto
        # Move results from `<dataset>/<dataset>.result.hdf5` to
        # `<dataset>/injected/<dataset>.result.hdf5`.
        
        shutil.move(os.path.join(file_out, data_suff + '.result.hdf5'), os.path.join(result_path, data_suff + '.result.hdf5'))
                
        # Save scalings into `<dataset>/injected/<dataset>.scalings.npy`.
        numpy.save(os.path.join(result_path, data_suff + '.scalings'), scalings)

        file_name_noext, ext = os.path.splitext(file_name)

        # Copy basis from `<dataset>/injected/<dataset>.basis.hdf5` to
        # `<dataset>/<dataset>.basis.hdf5`.
        shutil.copy2(os.path.join(result_path, data_suff + '.basis.hdf5'),
                     os.path.join(file_out, data_suff + '.basis.hdf5'))

        if benchmark not in ['fitting', 'synchrony']:
            # Move templates from `<dataset>/<dataset>.templates.hdf5` to
            # `<dataset>/injected/<dataset>.templates.hdf5`.
            shutil.move(os.path.join(file_out, data_suff + '.templates.hdf5'),
                        os.path.join(result_path, data_suff + '.templates.hdf5'))
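For reference, a hedged sketch of reading back the benchmark description pickled by write_benchmark above; the path is a placeholder and the snippet assumes the file was written with the default pickle protocol.

import pickle  # cPickle on Python 2

# Hypothetical readback of the '<file_out>.pic' description written above.
with open('synthetic/fitting.pic', 'rb') as pic_file:
    description = pickle.load(pic_file)
print(description['benchmark'], description['cells'], description['rates'])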
Example #14
                    ],
                    help="fitted snippets selection")
parser.add_argument('-n',
                    '--nb-snippets',
                    default=10,
                    type=int,
                    help="number of snippets to select")
args = parser.parse_args()
# Adjust extension argument.
if args.extension is None:
    args.extension = ""
else:
    args.extension = "-" + args.extension

# Load parameters.
params = CircusParser(args.datafile)
_ = params.get_data_file()
sampling_rate = params.rate
duration = params.data_file.duration

# Load spike times and amplitudes.
results = load_data(params, 'results', extension=args.extension)
template_key = 'temp_{}'.format(args.template_id)
spike_times = results['spiketimes'][template_key][:]
amplitudes = results['amplitudes'][template_key][:, 0]

# Check number of spikes.
nb_spikes = spike_times.size
if nb_spikes == 0:
    warnings.warn("No fitted spikes for template {}.".format(args.template_id),
                  category=UserWarning)
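A hedged follow-up, not in the original snippet, that converts the fitted spike times to seconds and plots their amplitudes over the recording:

import matplotlib.pyplot as plt

# Hedged sketch: spike times are in samples, so divide by the sampling rate.
spike_seconds = spike_times / sampling_rate
fig, ax = plt.subplots()
ax.scatter(spike_seconds, amplitudes, s=5)
ax.set_xlabel("time (s)")
ax.set_ylabel("fitted amplitude")
ax.set_xlim(0.0, duration)
plt.show()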
Example #15
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    logger = init_logging(params.logfile)
    logger = logging.getLogger(__name__)
    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    gain = data_file.gain
    t_start = data_file.t_start
    file_format = data_file.description

    if file_format not in supported_by_matlab:
        print_and_log([
            "File format %s is not supported by MATLAB. Waveforms disabled" %
            file_format
        ], 'info', logger)

    if numpy.iterable(gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1

    file_out_suff = params.get('data', 'file_out_suff')
    if hasattr(data_file, 'data_offset'):
        data_offset = data_file.data_offset
    else:
        data_offset = 0
    probe = params.probe
    if extension != '':
        extension = '-' + extension

    def generate_matlab_mapping(probe):
        p = {}
        positions = []
        nodes = []
        for key in probe['channel_groups'].keys():
            p.update(probe['channel_groups'][key]['geometry'])
            nodes += probe['channel_groups'][key]['channels']
            positions += [
                p[channel]
                for channel in probe['channel_groups'][key]['channels']
            ]
        idx = numpy.argsort(nodes)
        positions = numpy.array(positions)[idx]

        t = tempfile.NamedTemporaryFile().name + '.hdf5'
        cfile = h5py.File(t, 'w')
        to_write = {
            'positions': positions / 10.,
            'permutation': numpy.sort(nodes),
            'nb_total': numpy.array([probe['total_nb_channels']])
        }
        write_datasets(cfile, to_write.keys(), to_write)
        cfile.close()
        return t

    mapping = generate_matlab_mapping(probe)

    if not params.getboolean('data', 'overwrite'):
        filename = params.get('data', 'data_file_no_overwrite')
    else:
        filename = params.get('data', 'data_file')

    apply_patch_for_similarities(params, extension)

    gui_file = pkg_resources.resource_filename(
        'circus', os.path.join('matlab_GUI', 'SortingGUI.m'))
    # Change to the directory of the matlab file
    os.chdir(os.path.abspath(os.path.dirname(gui_file)))

    # Use quotation marks for string arguments
    if file_format not in supported_by_matlab:
        gui_params = [
            params.rate,
            os.path.abspath(file_out_suff),
            '%s.mat' % extension, mapping, 2, t_start
        ]
        is_string = [False, True, True, True, False, False]

    else:

        gui_params = [
            params.rate,
            os.path.abspath(file_out_suff),
            '%s.mat' % extension, mapping, 2, t_start, data_dtype, data_offset,
            gain, filename
        ]
        is_string = [
            False, True, True, True, False, False, True, False, False, True
        ]

    arguments = ', '.join([
        "'%s'" % arg if s else "%s" % arg
        for arg, s in zip(gui_params, is_string)
    ])
    matlab_command = 'SortingGUI(%s)' % arguments

    print_and_log(["Launching the MATLAB GUI..."], 'info', logger)
    print_and_log([matlab_command], 'debug', logger)

    if params.getboolean('fitting', 'collect_all'):
        print_and_log([
            'You can not view the unfitted spikes with the MATLAB GUI',
            'Please consider using phy if you really would like to see them'
        ], 'info', logger)

    try:
        sys.exit(
            subprocess.call(
                ['matlab', '-nodesktop', '-nosplash', '-r', matlab_command]))
    except Exception:
        print_and_log(
            ["Something wrong with MATLAB. Try circus-gui-python instead?"],
            'error', logger)
        sys.exit(1)
Example #16
def get_label_for_sc_index(sc_index: int,
                           dead_channels: list,
                           ordered_mcs_indices: list,
                           pad_with_zero: bool = False):
    return get_label_for_mcs_index(get_mcs_index(sc_index, dead_channels),
                                   ordered_mcs_indices, pad_with_zero)


# CHANGE PATHS TO ANALYZE DIFFERENT FILES
filtered_filepath = r'/mnt/Data/Albina/data from home/2021-06-29T12-05-34Control_slice2.h5'
result_filepath = r'/mnt/Data/Albina/data from home/2021-06-29T12-05-34Control_slice2/' \
                  r'2021-06-29T12-05-34Control_slice2.clusters.hdf5'

filter_file = h5py.File(filtered_filepath, 'r')
result_file = h5py.File(result_filepath, 'r')
params = CircusParser(filtered_filepath)

filter_file_voltage_traces = filter_file['Data']['Recording_0'][
    'AnalogStream']['Stream_0']['ChannelData']
ch_ids = [
    ch[1] for ch in filter_file['Data']['Recording_0']['AnalogStream']
    ['Stream_0']['InfoChannel']
]
spike_indices = retrieve_spiketimes(result_file)
sorted_spike_indices = [
    np.sort(spike_indices[m]) for m in range(len(spike_indices))
]
sampling_rate = float(params.get('data', 'sampling_rate'))
spiketimes = np.array(spike_indices, dtype=object) / sampling_rate
sorted_spiketimes = [np.sort(spiketimes[i]) for i in range(len(spiketimes))]
spike_differences = [
Example #17
    def set_params_spc(self, main_params, npy_file, output):
        '''
        Set parameters file for Spyking Circus

        Parameters
        ----------
        main_params : dictionary
            Parameters used to configure the run:
                -- 'N_t'
                -- 'cut_off'
                -- 'stream_mode'
                -- dead_channels ('grad_idx'/'mag_idx')
                -- 'cc_merge'
        npy_file : pathlib.PosixPath
            The path to the numpy data file in the CIRCUS folder
        output : pathlib.PosixPath
            The directory where the results will be saved. 
            Different for magnetometers and gradiometers
        '''
        from shutil import copyfile
        from circus.shared.parser import CircusParser

        self.params = CircusParser(npy_file, create_folders=False)
        ### data
        self.params.write('data','file_format','numpy')
        self.params.write('data','stream_mode', main_params['stream_mode'])
        self.params.write('data','mapping', str(output.parent / 'meg_306.prb'))
        self.params.write('data','output_dir',str(output))
        self.params.write('data','sampling_rate','1000')
        ### detection
        self.params.write('detection','radius', '6')
        self.params.write('detection','N_t', str(main_params['N_t']))
        #self.params.write('detection','spike_thresh', '6')
        self.params.write('detection','peaks','both')
        self.params.write('detection','alignment','False')
        self.params.write('detection','isolation','False')
        if self.sensors == 'mag':
            grad = '{ 1 : %s}'%main_params['grad_idx']
            self.params.write('detection','dead_channels', grad)
        else:
            mag = '{ 1 : %s}'%main_params['mag_idx']
            self.params.write('detection','dead_channels', mag)
        ### filtering
        filt_param = '{}, {}'.format(main_params['cut_off'][0], 
                                     main_params['cut_off'][1])
        self.params.write('filtering','cut_off',filt_param)
        ### whitening
        self.params.write('whitening','safety_time','auto')
        self.params.write('whitening','max_elts','10000')
        self.params.write('whitening','nb_elts','0.1')
        self.params.write('whitening','spatial','False')
        ### clustering
        self.params.write('clustering','extraction','mean-raw')
        self.params.write('clustering','safety_space','False')
        self.params.write('clustering','safety_time','1')
        self.params.write('clustering','max_elts','10000')
        self.params.write('clustering','nb_elts','0.001')
        self.params.write('clustering','nclus_min','0.0001')
        self.params.write('clustering','smart_search','False')
        self.params.write('clustering','sim_same_elec','1')
        self.params.write('clustering','sensitivity','5')
        self.params.write('clustering','cc_merge', str(main_params['cc_merge']))
        self.params.write('clustering','dispersion','(5, 5)')
        self.params.write('clustering','noise_thr','0.9')
        #self.params.write('clustering','remove_mixture','False')
        self.params.write('clustering','cc_mixtures','0.1')
        self.params.write('clustering','make_plots','png')
        ### fitting
        self.params.write('fitting','chunk_size','60')
        self.params.write('fitting','amp_limits','(0.01,10)')
        self.params.write('fitting','amp_auto','False')
        self.params.write('fitting','collect_all','True')
        ### merging
        self.params.write('merging','cc_overlap','0.4')
        self.params.write('merging','cc_bin','200')
        
        self.params = CircusParser(npy_file, create_folders=False)
        copyfile(self.path_params, output / 'config.param')
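A hedged sketch of how this configuration step might be driven; the dictionary keys mirror those listed in the docstring, while every value, both paths, and the calling instance are placeholders.

from pathlib import Path

# Hypothetical parameter dictionary with the keys documented above (placeholder values).
meg_params = {
    'N_t': 80,
    'cut_off': (1, 100),
    'stream_mode': 'multi-files',
    'grad_idx': [0, 1, 2],
    'mag_idx': [3, 4, 5],
    'cc_merge': 0.95,
}
npy_file = Path('CIRCUS/recording.npy')  # placeholder .npy data file
output = Path('CIRCUS/output_mag')       # placeholder output directory
# An instance of the surrounding class (not shown on this page) would then call:
# sorter.set_params_spc(meg_params, npy_file, output)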
Example #18
def get_dataset(self):
    dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
    filename = os.path.join(dirname, 'data')
    if not os.path.exists(filename):
        os.makedirs(filename)
    result = os.path.join(filename, 'data')
    filename = os.path.join(filename, 'data.dat')
    if not os.path.exists(filename):
        print "Generating a synthetic dataset of 4 channels, 1min at 20kHz..."
        sampling_rate = 20000
        N_total = 4
        gain = 0.5
        data = (gain *
                numpy.random.randn(sampling_rate * N_total * 1 * 60)).astype(
                    numpy.float32)
        myfile = open(filename, 'wb')
        myfile.write(data.tostring())
        myfile.close()

    src_path = os.path.abspath(os.path.join(dirname, 'snippet'))

    if not os.path.exists(result):
        os.makedirs(result)
        shutil.copy(os.path.join(src_path, 'test.basis.hdf5'),
                    os.path.join(result, 'data.basis.hdf5'))
        shutil.copy(os.path.join(src_path, 'test.templates.hdf5'),
                    os.path.join(result, 'data.templates.hdf5'))
        shutil.copy(os.path.join(src_path, 'test.clusters.hdf5'),
                    os.path.join(result, 'data.clusters.hdf5'))

    config_file = os.path.abspath(
        pkg_resources.resource_filename('circus', 'config.params'))
    file_params = os.path.abspath(filename.replace('.dat', '.params'))
    if not os.path.exists(file_params):

        shutil.copyfile(config_file, file_params)
        probe_file = os.path.join(src_path, 'test.prb')
        parser = CircusParser(filename, mapping=probe_file)
        parser.write('data', 'file_format', 'raw_binary')
        parser.write('data', 'data_offset', '0')
        parser.write('data', 'data_dtype', 'float32')
        parser.write('data', 'sampling_rate', '20000')
        parser.write('whitening', 'temporal', 'False')
        parser.write('data', 'mapping', probe_file)
        parser.write('clustering', 'make_plots', 'png')
        parser.write('clustering', 'nb_repeats', '3')
        parser.write('detection', 'N_t', '3')
        parser.write('clustering', 'smart_search', 'False')
        parser.write('clustering', 'max_elts', '10000')
        parser.write('noedits', 'filter_done', 'True')
        parser.write('clustering', 'extraction', 'median-raw')

    a, b = os.path.splitext(os.path.basename(filename))
    c, d = os.path.splitext(filename)
    file_out = os.path.join(os.path.abspath(c), a)

    return filename
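For completeness, a hedged driver mirroring the setUp fixtures shown earlier on this page; mpi_launch is the same test helper used there, the target path is a placeholder, and None is passed because the argument of get_dataset is unused in the body above.

source_dataset = get_dataset(None)
target_file = os.path.join(os.path.dirname(source_dataset), 'fitting.dat')
# Inject the 'fitting' benchmark into a new file, then whiten it on 2 MPI processes.
mpi_launch('benchmarking', source_dataset, 2, 0, 'False', target_file, 'fitting', 1)
mpi_launch('whitening', target_file, 2, 0, 'False')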
Example #19
                    help="template identifier",
                    dest='template_id')
# parser.add_argument('-s', '--size', choices=['amplitude', 'norm'], help="marker size")
parser.add_argument('-c',
                    '--color',
                    choices=['amplitude', 'norm'],
                    help="marker color")
parser.add_argument('-n',
                    '--nb-snippets-max',
                    default=20,
                    type=int,
                    help="maximum number of snippets")
args = parser.parse_args()

# Load parameters.
params = CircusParser(args.datafile)
_ = params.get_data_file()
sampling_rate = params.rate
# duration = params.data_file.duration
nb_time_steps = params.getint('detection', 'N_t')

# Load spike times.
results = load_data(params, 'results',
                    extension='')  # TODO support other extensions?
results_data = dict()
for key in results['spiketimes'].keys():
    if key[:5] == 'temp_':
        template_key = key
        template_id = int(key[5:])
        if 'spike_times' not in results_data:
            results_data['spike_times'] = dict()
Example #20
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to concatenate artefacts/dead times before using 
stream mode. Code will look for .dead and .trig files, and 
concatenate them automatically taking care of file offsets
    '''
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    # parser.add_argument('-w', '--window', help='text file with artefact window files',
    #                     default=None)

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)
    # if args.window is None:
    #     window_file = None
    # else:
    #     window_file = os.path.abspath(args.window)

    filename = os.path.abspath(args.datafile)
    params = CircusParser(filename)
    dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
    trig_in_ms = params.getboolean('triggers', 'trig_in_ms')

    if os.path.exists(params.logfile):
        os.remove(params.logfile)

    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    if params.get('data', 'stream_mode') == 'multi-files':
        data_file = params.get_data_file(source=True, has_been_created=False)
        all_times_dead = numpy.zeros((0, 2), dtype=numpy.int64)
        all_times_trig = numpy.zeros((0, 2), dtype=numpy.int64)

        for f in data_file._sources:
            name, ext = os.path.splitext(f.file_name)
            dead_file = f.file_name.replace(ext, '.dead')
            trig_file = f.file_name.replace(ext, '.trig')

            if os.path.exists(dead_file):
                print_and_log(['Found file %s' % dead_file], 'default', logger)
                times = get_dead_times(dead_file, data_file.sampling_rate,
                                       dead_in_ms)
                if times.max() > f.duration or times.min() < 0:
                    print_and_log([
                        'Dead zones larger than duration for file %s' %
                        f.file_name, '-> Clipping automatically'
                    ], 'error', logger)
                    times = numpy.minimum(times, f.duration)
                    times = numpy.maximum(times, 0)
                times += f.t_start
                all_times_dead = numpy.vstack((all_times_dead, times))

            if os.path.exists(trig_file):
                print_and_log(['Found file %s' % trig_file], 'default', logger)

                times = get_trig_times(trig_file, data_file.sampling_rate,
                                       trig_in_ms)
                if times[:, 1].max() > f.duration or times[:, 1].min() < 0:
                    print_and_log([
                        'Triggers larger than duration for file %s' %
                        f.file_name
                    ], 'error', logger)
                    sys.exit(0)
                times[:, 1] += f.t_start
                all_times_trig = numpy.vstack((all_times_trig, times))

        if len(all_times_dead) > 0:
            output_file = os.path.join(os.path.dirname(filename),
                                       'dead_zones.txt')
            print_and_log(['Saving global artefact file in %s' % output_file],
                          'default', logger)
            if dead_in_ms:
                all_times_dead = all_times_dead.astype(
                    numpy.float32) / data_file.sampling_rate
            numpy.savetxt(output_file, all_times_dead)

        if len(all_times_trig) > 0:
            output_file = os.path.join(os.path.dirname(filename),
                                       'triggers.txt')
            print_and_log(['Saving global artefact file in %s' % output_file],
                          'default', logger)
            if trig_in_ms:
                all_times_trig = all_times_trig.astype(
                    numpy.float32) / data_file.sampling_rate
            numpy.savetxt(output_file, all_times_trig)

    elif params.get('data', 'stream_mode') == 'single-file':
        print_and_log(['Not implemented'], 'error', logger)
        sys.exit(0)
    else:
        print_and_log(
            ['You should select a valid stream_mode such as multi-files'],
            'error', logger)
        sys.exit(0)
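
# get_dead_times() and get_trig_times() are not included in this listing; a
# hypothetical, minimal loader compatible with the calls above (one
# "start stop" pair per line, optionally expressed in milliseconds) could be:
def get_dead_times_sketch(dead_file, sampling_rate, dead_in_ms):
    # numpy is already imported by the script above.
    times = numpy.atleast_2d(numpy.loadtxt(dead_file))
    if dead_in_ms:
        # Convert millisecond intervals into sample indices.
        times = times * sampling_rate / 1000.0
    return times.astype(numpy.int64)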
Example #21
import h5py

from circus.shared.parser import CircusParser

base_filepath = '/home/lisa_ruth/spyking-circus/' \
                '2021-06-22T15-51-27human_slices_Slice1_BL.h5'
filepath = '/home/lisa_ruth/spyking-circus/2021-06-22T15-51-27human_slices_Slice1_BL/' \
           '2021-06-22T15-51-27human_slices_Slice1_BL.clusters.hdf5'

file = h5py.File(filepath, 'r')
basefile = h5py.File(base_filepath, 'r')

sampling_frequency = 1000000 / \
                         basefile['Data']['Recording_0']['AnalogStream']['Stream_0']['InfoChannel']['Tick'][0]
duration_index = basefile['Data']['Recording_0']['AnalogStream']['Stream_0'][
    'ChannelDataTimeStamps'][0][2]
duration = duration_index / sampling_frequency
print(duration)
params = CircusParser(base_filepath)
dead_channels = params.get('detection', 'dead_channels')
if len(dead_channels) > 1:
    dead_channels = [int(s) for s in dead_channels[5:-2].split(',')]

ids = [
    ch[0]
    for ch in basefile['Data/Recording_0/AnalogStream/Stream_0/InfoChannel']
]
labels = [
    ch[4].decode('utf8')
    for ch in basefile['Data/Recording_0/AnalogStream/Stream_0/InfoChannel']
]
same_len_labels = [
    str(label[0]) + '0' + str(label[1]) if len(label) < 3 else label
    for label in labels
]

class TestClustering(unittest.TestCase):
    def setUp(self):
        self.all_matches = None
        self.all_templates = None
        dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name = os.path.join(self.path, 'clustering.dat')
        self.source_dataset = get_dataset(self)
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                       self.file_name, 'clustering', 1)
            mpi_launch('whitening', self.file_name, 2, 0, 'False')

        self.parser = CircusParser(self.file_name)
        self.parser.write('clustering', 'max_elts', '1000')

    def test_clustering_one_CPU(self):
        mpi_launch('clustering', self.file_name, 1, 0, 'False')
        res = get_performance(self.file_name, 'one_CPU')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

    def test_clustering_two_CPU(self):
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        res = get_performance(self.file_name, 'two_CPU')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

    def test_clustering_pca(self):
        self.parser.write('clustering', 'extraction', 'median-pca')
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        self.parser.write('clustering', 'extraction', 'median-raw')
        res = get_performance(self.file_name, 'median-pca')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

    def test_clustering_nb_passes(self):
        self.parser.write('clustering', 'nb_repeats', '1')
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        self.parser.write('clustering', 'nb_repeats', '3')
        res = get_performance(self.file_name, 'nb_passes')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

    def test_clustering_sim_same_elec(self):
        self.parser.write('clustering', 'sim_same_elec', '5')
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        self.parser.write('clustering', 'sim_same_elec', '3')
        res = get_performance(self.file_name, 'sim_same_elec')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

    def test_clustering_cc_merge(self):
        self.parser.write('clustering', 'cc_merge', '0.8')
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        self.parser.write('clustering', 'cc_merge', '0.95')
        res = get_performance(self.file_name, 'cc_merge')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

    def test_remove_mixtures(self):
        self.parser.write('clustering', 'remove_mixtures', 'False')
        mpi_launch('clustering', self.file_name, 2, 0, 'False')
        self.parser.write('clustering', 'remove_mixtures', 'True')
        res = get_performance(self.file_name, 'remove_mixtures')
        if self.all_templates is None:
            self.all_templates = res[0]
            self.all_matches = res[1]

import h5py
import numpy as np
import matplotlib.pyplot as plt
import neo
from IPython import embed

from circus.shared.parser import CircusParser

# Note: retrieve_spiketimes is a user-defined helper that is not part of this
# listing (a hypothetical sketch is given after this snippet).
file_path_sc_base = r'/home/lisa_ruth/spyking-circus/' \
                    r'2021-02-17T15-15-06FHM3_GS967_BL6_P16_female_400ms_7psi_Slice4_Test1_rdy_4_SC.h5'
file_path_sc_result = r'/home/lisa_ruth/spyking-circus/' \
                      r'2021-02-17T15-15-06FHM3_GS967_BL6_P16_female_400ms_7psi_Slice4_Test1_rdy_4_SC/' \
                      r'2021-02-17T15-15-06FHM3_GS967_BL6_P16_female_400ms_7psi_Slice4_Test1_rdy_4_SC.clusters.hdf5'

base_file = h5py.File(file_path_sc_base, 'r')
result_file = h5py.File(file_path_sc_result, 'r')

fs = 10000.

bf_vt = base_file['scaled']
time = np.arange(0, len(bf_vt[0])/fs, 1/fs)

params = CircusParser(file_path_sc_base)
embed()
spiketimes = retrieve_spiketimes(result_file)
labels = ['A5', 'A6', 'B5', 'B6']

figure, axs = plt.subplots(2, 2, figsize=(12, 9))
figure.subplots_adjust(hspace=0.5)
axs = axs.flat
neo_spiketrains = []
for i, ax in enumerate(axs):
    spike_indices = spiketimes[i]
    scatter_time = np.array(time[spike_indices])
    neo_spiketrains.append(neo.SpikeTrain(times=scatter_time, units='sec', t_stop=300.0))
    scatter_height = np.array(bf_vt[i][spike_indices])
    ax.plot(time, bf_vt[i], zorder=1)
    ax.scatter(scatter_time, scatter_height, color='r', zorder=2)
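
# retrieve_spiketimes() is a user-defined helper that is not part of this
# listing; a hypothetical version reading per-electrode spike indices from the
# .clusters.hdf5 file (datasets assumed to be named 'times_<electrode>') might be:
def retrieve_spiketimes_sketch(result_file):
    keys = sorted((k for k in result_file.keys() if k.startswith('times_')),
                  key=lambda k: int(k.split('_')[1]))
    return [np.sort(np.asarray(result_file[k][:], dtype=np.int64)) for k in keys]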
Example #24
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    mytest = StrictVersion(phycontrib.__version__) >= StrictVersion("1.0.12")
    if not mytest:
        print_and_log(
            ['You need to update phy-contrib to the latest git version'],
            'error', logger)
        sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    if hasattr(data_file, 'data_offset'):
        data_offset = data_file.data_offset
    else:
        data_offset = 0
    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" %
            file_format
        ], 'info', logger)

    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1
    else:
        if data_file.gain != 1:
            print_and_log([
                "Gain of %g is not supported by phy. Expecting a scaling mismatch"
                % data_file.gain
            ], 'info', logger)
            gain = data_file.gain

    probe = params.probe
    if extension != '':
        extension = '-' + extension
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
    else:

        print_and_log(["Launching the phy GUI..."], 'info', logger)

        gui_params = {}
        if file_format in supported_by_phy:
            gui_params['dat_path'] = params.get('data', 'data_file')
        else:
            gui_params['dat_path'] = ''
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_dtype
        gui_params['offset'] = data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['hp_filtered'] = True

        os.chdir(output_path)
        create_app()
        controller = TemplateController(**gui_params)
        gui = controller.create_gui()

        gui.show()
        run_app()
        gui.close()
        del gui
Example #25
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to launch the phy GUI and visualize the results.
[data must first be converted with the converting mode]
    '''
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)
    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    if extension != '':
        extension = '-' + extension

    try:
        import traitlets
    except ImportError:
        print_and_log(
            ['The package traitlets required by phy is not installed'],
            'error', logger)
        sys.exit(1)

    try:
        import click
    except ImportError:
        print_and_log(['The package click required by phy is not installed'],
                      'error', logger)
        sys.exit(1)

    try:
        import joblib
    except ImportError:
        print_and_log(['The package joblib required by phy is not installed'],
                      'error', logger)
        sys.exit(1)

    if HAVE_PHYCONTRIB:
        mytest = StrictVersion(
            phycontrib.__version__) >= StrictVersion("1.0.12")
        if not mytest:
            print_and_log(
                ['You need to update phy-contrib to the latest git version'],
                'error', logger)
            sys.exit(1)

        print_and_log([
            'phy-contrib is deprecated, you should upgrade to phy 2.0 and phylib'
        ], 'info', logger)

    if HAVE_PHYLIB:
        try:
            import colorcet
        except ImportError:
            print_and_log(
                ['The package colorcet required by phy is not installed'],
                'error', logger)
            sys.exit(1)

        try:
            import qtconsole
        except ImportError:
            print_and_log(
                ['The package qtconsole required by phy is not installed'],
                'error', logger)
            sys.exit(1)

    if not test_patch_for_similarities(params, extension):
        print_and_log(
            ['You should re-export the data because of a fix in 0.6'], 'error',
            logger)
        continue_anyway = query_yes_no(
            Fore.WHITE + "Continue anyway (results may not be fully correct)?",
            default=None)
        if not continue_anyway:
            sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" %
            file_format
        ], 'info', logger)

    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1
    else:
        if data_file.gain != 1:
            print_and_log([
                "Gain of %g is not supported by phy. Expecting a scaling mismatch"
                % data_file.gain
            ], 'info', logger)
            gain = data_file.gain

    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
    else:

        print_and_log(["Launching the phy GUI..."], 'info', logger)

        gui_params = {}
        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = r"%s" % params.get(
                    'data', 'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = [
                        r"%s" % f for f in data_file.get_file_names()
                    ]
                else:
                    gui_params['dat_path'] = r"%s" % params.get(
                        'data', 'data_file')
        else:
            gui_params['dat_path'] = 'giverandomname.dat'

        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_dtype
        gui_params['offset'] = data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        os.chdir(output_path)
        create_app()
        controller = TemplateController(**gui_params)
        gui = controller.create_gui()

        gui.show()
        run_app()
        gui.close()
        del gui
Example #26
"""Inspect thresholds."""
import argparse
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os

from circus.shared.parser import CircusParser

# Parse arguments.
parser = argparse.ArgumentParser(description="Inspect threhsolds.")
parser.add_argument('datafile', help="data file")
args = parser.parse_args()

# Load parameters.
params = CircusParser(args.datafile)
_ = params.get_data_file()
file_out_suff = params.get('data', 'file_out_suff')
nb_channels = params.getint('data', 'N_e')

# Load spatial matrix.
basis_path = "{}.basis.hdf5".format(file_out_suff)
if not os.path.isfile(basis_path):
    raise FileNotFoundError(basis_path)
with h5py.File(basis_path, mode='r', libver='earliest') as basis_file:
    if 'thresholds' not in basis_file:
        raise RuntimeError("No thresholds found in {}".format(basis_path))
    thresholds = basis_file['thresholds'][:]
    assert thresholds.shape == (nb_channels, ), (thresholds.shape, nb_channels)

# Plot thresholds.
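
# The listing breaks off before the actual plot; a minimal sketch of how the
# per-channel thresholds loaded above could be displayed (reusing the matplotlib
# and numpy imports from the top of this example) might be:
fig, ax = plt.subplots()
ax.plot(np.arange(nb_channels), thresholds, marker='o', linestyle='None')
ax.set_xlabel("channel")
ax.set_ylabel("threshold")
ax.set_title("Detection thresholds")
plt.show()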
from circus.shared.probes import get_nodes_and_edges
from circus.shared.files import get_stas

# Parse arguments.
parser = argparse.ArgumentParser(description="Inspect scalar products.")
parser.add_argument('datafile', help="data file")
parser.add_argument('-t',
                    '--template',
                    default=0,
                    type=int,
                    help="template index",
                    dest='template_id')
args = parser.parse_args()

# Load parameters.
params = CircusParser(args.datafile)
_ = params.get_data_file()
nb_electrodes = params.nb_channels
sampling_rate = params.rate  # Hz
duration = params.data_file.duration / sampling_rate  # s
nb_channels = params.getint('data', 'N_e')
nb_time_steps = params.getint('detection', 'N_t')
nb_snippets = params.getint('clustering', 'nb_snippets')

# Load snippets.
clusters_data = load_clusters_data(params, extension='')
electrode = clusters_data['electrodes'][args.template_id]
local_cluster = clusters_data['local_clusters'][args.template_id]
assert electrode.shape == local_cluster.shape, (electrode.shape,
                                                local_cluster.shape)
times = clusters_data['times'][electrode]
Example #28
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    parallel_hdf5 = h5py.get_config().mpi
    user_path = pjoin(os.path.expanduser('~'), 'spyking-circus')
    tasks_list = None

    if not os.path.exists(user_path):
        os.makedirs(user_path)

    try:
        import cudamat as cmt
        cmt.init()
        HAVE_CUDA = True
    except Exception:
        HAVE_CUDA = False

    all_steps = [
        'whitening', 'clustering', 'fitting', 'gathering', 'extracting',
        'filtering', 'converting', 'deconverting', 'benchmarking',
        'merging', 'validating', 'thresholding'
    ]

    config_file = os.path.abspath(pkg_resources.resource_filename('circus', 'config.params'))

    header = get_colored_header()
    header += Fore.GREEN + 'Local CPUs    : ' + Fore.CYAN + str(psutil.cpu_count()) + '\n'
    # header += Fore.GREEN + 'GPU detected  : ' + Fore.CYAN + str(HAVE_CUDA) + '\n'
    header += Fore.GREEN + 'Parallel HDF5 : ' + Fore.CYAN + str(parallel_hdf5) + '\n'

    do_upgrade = ''
    if not SHARED_MEMORY:
        do_upgrade = Fore.WHITE + '   [please consider upgrading MPI]'

    header += Fore.GREEN + 'Shared memory : ' + Fore.CYAN + str(SHARED_MEMORY) + do_upgrade + '\n'
    header += '\n'
    header += Fore.GREEN + "##################################################################"
    header += Fore.RESET

    method_help = '''by default, all steps are performed,
but a subset x,y can be done. Steps are:
 - filtering
 - whitening
 - clustering
 - fitting
 - merging [with or without a GUI for meta merging]
 - (extra) converting [export results to phy format]
 - (extra) thresholding [to get MUA activity only]
 - (extra) deconverting [import results from phy format]
 - (extra) gathering [force collection of results]
 - (extra) extracting [get templates from spike times]
 - (extra) benchmarking [with -o and -t]
 - (extra) validating [to compare performance with GT neurons]'''

    parser = argparse.ArgumentParser(description=header,
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file (or a list of commands if batch mode)')
    parser.add_argument('-i', '--info', help='list the file formats supported by SpyKING CIRCUS', action='store_true')
    parser.add_argument('-m', '--method',
                        default='filtering,whitening,clustering,fitting,merging',
                        help=method_help)
    parser.add_argument('-c', '--cpu', type=int, default=max(1, int(psutil.cpu_count()/2)), help='number of CPU')
    # parser.add_argument('-g', '--gpu', type=int, default=0, help='number of GPU')
    parser.add_argument('-H', '--hostfile', help='hostfile for MPI',
                        default=pjoin(user_path, 'circus.hosts'))
    parser.add_argument('-b', '--batch', help='datafile is a list of commands to launch, in a batch mode',
                        action='store_true')
    parser.add_argument('-p', '--preview', help='GUI to display the first second filtered with thresholds',
                        action='store_true')
    parser.add_argument('-r', '--result', help='GUI to display the results on top of raw data',
                        action='store_true')
    parser.add_argument('-s', '--second', type=int, default=0, help='in preview mode, beginning of the preview [in s]')
    parser.add_argument('-e', '--extension', help='extension to consider for merging, converting and deconverting',
                        default='None')
    parser.add_argument('-o', '--output', help='output file [for generation of synthetic benchmarks]')
    parser.add_argument('-t', '--type', help='benchmark type',
                        choices=['fitting', 'clustering', 'synchrony'])

    if len(argv) == 0:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(argv)

    steps = args.method.split(',')
    for step in steps:
        if step not in all_steps:
            print_error(['The method "%s" is not recognized' % step])
            sys.exit(0)

    # To save some typing later
    nb_gpu = 0
    (nb_cpu, hostfile, batch, preview, result, extension, output, benchmark, info, second) = \
        (args.cpu, args.hostfile, args.batch, args.preview, args.result, args.extension, args.output, args.type, args.info, args.second)
    filename = os.path.abspath(args.datafile)
    real_file = filename

    f_next, extens = os.path.splitext(filename)

    if info:
        if args.datafile.lower() in __supported_data_files__:
            filename = 'tmp'
            if len(__supported_data_files__[args.datafile.lower()].extension) > 0:
                filename += __supported_data_files__[args.datafile.lower()].extension[0]

            __supported_data_files__[args.datafile.lower()](filename, {}, is_empty=True)._display_requirements_()
        else:
            print_and_log([
                '',
                'To get info on any particular file format, do:',
                '>> spyking-circus file_format -i',
                ''
            ], 'default')
            print_and_log(list_all_file_format())
        sys.exit(0)

    if extens == '.params':
        print_error(['You should launch the code on the data file!'])
        sys.exit(0)

    file_params = f_next + '.params'
    if not os.path.exists(file_params) and not batch:
        print(Fore.RED + 'The parameter file %s is not present!' % file_params)
        create_params = query_yes_no(Fore.WHITE + "Do you want SpyKING CIRCUS to create a parameter file?")

        if create_params:
            print(Fore.WHITE + "Creating %s" % file_params)
            print(Fore.WHITE + "Fill it properly before launching the code! (see documentation)")
            print_info(['Keep in mind that filtering is performed on site, so please',
                        'be sure to keep a copy of your data elsewhere'])
            shutil.copyfile(config_file, file_params)
        sys.exit(0)
    elif batch:
        tasks_list = filename

    if not batch:
        file_params = f_next + '.params'

        if not os.path.exists(file_params):
            print_and_log(["%s does not exist" % file_params], 'error')
            sys.exit(0)

        try:
            import configparser
        except ImportError:  # Python 2 fallback
            import ConfigParser as configparser
        parser = configparser.ConfigParser()
        myfile = open(file_params, 'r')
        lines = myfile.readlines()
        myfile.close()
        myfile = open(file_params, 'w')
        for l in lines:
            myfile.write(l.replace('\t', ''))
        myfile.close()

        parser.read(file_params)

        for section in CircusParser.__all_sections__:
            if parser.has_section(section):
                for (key, value) in parser.items(section):
                    parser.set(section, key, value.split('#')[0].rstrip())
            else:
                parser.add_section(section)

        try:
            use_output_dir = parser.get('data', 'output_dir') != ''
        except Exception:
            use_output_dir = False

        if use_output_dir:
            path = os.path.abspath(os.path.expanduser(parser.get('data', 'output_dir')))
            file_out = os.path.join(path, os.path.basename(f_next))
            if not os.path.exists(file_out):
                os.makedirs(file_out)
        else:
            file_out = f_next


        logfile = file_out + '.log'
        if os.path.exists(logfile):
            os.remove(logfile)

        logger = init_logging(logfile)
        params = CircusParser(filename)
        data_file = params.get_data_file(source=True, has_been_created=False)
        overwrite = params.getboolean('data', 'overwrite')
        file_format = params.get('data', 'file_format')
        if overwrite:
            support_parallel_write = data_file.parallel_write
            is_writable = data_file.is_writable
        else:
            support_parallel_write = __supported_data_files__['raw_binary'].parallel_write
            is_writable = __supported_data_files__['raw_binary'].is_writable

    if preview:
        print_and_log(['Preview mode, showing only seconds [%d-%d] of the recording' % (second, second+1)], 'info', logger)
        tmp_path_loc = os.path.join(os.path.abspath(params.get('data', 'file_out')), 'tmp')

        if not os.path.exists(tmp_path_loc):
            os.makedirs(tmp_path_loc)

        filename = os.path.join(tmp_path_loc, 'preview.dat')
        f_next, extens = os.path.splitext(filename)
        preview_params = f_next + '.params'
        shutil.copyfile(file_params, preview_params)
        steps = ['filtering', 'whitening']

        chunk_size = int(params.rate)

        data_file.open()
        nb_chunks, _ = data_file.analyze(chunk_size)

        if nb_chunks <= (second + 1):
            print_and_log(['Recording is too short to display seconds [%d-%d]' % (second, second+1)])
            sys.exit(0)
        local_chunk = data_file.get_snippet(int(second*params.rate), int(1.2*chunk_size))
        description = data_file.get_description()
        data_file.close()

        new_params = CircusParser(filename, create_folders=False)

        new_params.write('data', 'chunk_size', '1')
        new_params.write('data', 'file_format', 'raw_binary')
        new_params.write('data', 'data_dtype', 'float32')
        new_params.write('data', 'data_offset', '0')
        new_params.write('data', 'dtype_offset', '0')
        new_params.write('data', 'stream_mode', 'None')
        new_params.write('data', 'overwrite', 'True')
        new_params.write('triggers', 'ignore_times', 'False')
        new_params.write('data', 'sampling_rate', str(params.rate))
        new_params.write('whitening', 'safety_time', '0')
        new_params.write('clustering', 'safety_time', '0')
        new_params.write('whitening', 'chunk_size', '1')
        new_params.write('data', 'preview_path', params.file_params)
        new_params.write('data', 'output_dir', '')

        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0
        description['gain'] = 1.
        new_params = CircusParser(filename)
        data_file_out = new_params.get_data_file(is_empty=True, params=description)

        support_parallel_write = data_file_out.parallel_write
        is_writable = data_file_out.is_writable

        data_file_out.allocate(shape=local_chunk.shape, data_dtype=numpy.float32)
        data_file_out.open('r+')
        data_file_out.set_data(0, local_chunk)
        data_file_out.close()

    if tasks_list is not None:
        with open(tasks_list, 'r') as f:
            for line in f:
                if len(line) > 0:
                    subprocess.check_call(['spyking-circus'] + line.replace('\n', '').split(" "))
    else:

        print_and_log(['Config file: %s' % (f_next + '.params')], 'debug', logger)
        print_and_log(['Data file  : %s' % filename], 'debug', logger)

        print(get_colored_header())
        print(Fore.GREEN + "File          : " + Fore.CYAN + real_file)
        if preview:
            print(Fore.GREEN + "Steps         : " + Fore.CYAN + "preview mode")
        elif result:
            print(Fore.GREEN + "Steps         : " + Fore.CYAN + "result mode")
        else:
            print(Fore.GREEN + "Steps         : " + Fore.CYAN + ", ".join(steps))
        # print Fore.GREEN + "GPU detected  : ", Fore.CYAN + str(HAVE_CUDA)
        print(Fore.GREEN + "Number of CPU : " + Fore.CYAN + str(nb_cpu) + "/" + str(psutil.cpu_count()))
        # if HAVE_CUDA:
        #     print Fore.GREEN + "Number of GPU : ", Fore.CYAN + str(nb_gpu)
        print(Fore.GREEN + "Parallel HDF5 : " + Fore.CYAN + str(parallel_hdf5))

        do_upgrade = ''
        use_shared_memory = get_shared_memory_flag(params)
        if not SHARED_MEMORY:
            do_upgrade = Fore.WHITE + '   [please consider upgrading MPI]'

        print(Fore.GREEN + "Shared memory : " + Fore.CYAN + str(use_shared_memory) + do_upgrade)
        print(Fore.GREEN + "Hostfile      : " + Fore.CYAN + hostfile)
        print("")
        print(Fore.GREEN + "##################################################################")
        print("")
        print(Fore.RESET)

        # Launch the subtasks
        subtasks = [('filtering', 'mpirun'),
                    ('whitening', 'mpirun'),
                    ('clustering', 'mpirun'),
                    ('fitting', 'mpirun'),
                    ('extracting', 'mpirun'),
                    ('gathering', 'python'),
                    ('converting', 'mpirun'),
                    ('deconverting', 'mpirun'),
                    ('benchmarking', 'mpirun'),
                    ('merging', 'mpirun'),
                    ('validating', 'mpirun'),
                    ('thresholding', 'mpirun')]

        # if HAVE_CUDA and nb_gpu > 0:
        #     use_gpu = 'True'
        # else:
        use_gpu = 'False'

        time = data_stats(params) / 60.0

        if preview:
            params = new_params

        if nb_cpu < psutil.cpu_count():
            if use_gpu != 'True' and not result:
                print_and_log(['Using only %d out of %d local CPUs available (-c to change)' % (nb_cpu, psutil.cpu_count())], 'info', logger)

        if params.getboolean('detection', 'matched-filter') and not params.getboolean('clustering', 'smart_search'):
            print_and_log(['Smart Search should be activated for matched filtering'], 'info', logger)

        if time > 30 and not params.getboolean('clustering', 'smart_search'):
            print_and_log(['Smart Search should be activated for long recordings'], 'info', logger)

        n_edges = get_averaged_n_edges(params)
        if n_edges > 100 and not params.getboolean('clustering', 'compress'):
            print_and_log(['Template compression is highly recommended based on parameters'], 'info', logger)

        if not result:
            for subtask, command in subtasks:
                if subtask in steps:
                    if command == 'python':
                        # Directly call the launcher
                        try:
                            circus.launch(subtask, filename, nb_cpu, nb_gpu, use_gpu)
                        except Exception:
                            print_and_log(['Step "%s" failed!' % subtask], 'error', logger)
                            sys.exit(0)
                    elif command == 'mpirun':
                        # Use mpirun to make the call
                        mpi_args = gather_mpi_arguments(hostfile, params)
                        one_cpu = False

                        if subtask in ['filtering', 'benchmarking'] and not is_writable:
                            if not preview and overwrite:
                                print_and_log(['The file format %s is read only!' % file_format,
                                               'You should set overwrite to False to create a copy of the data.',
                                               'However, note that if you have streams, information on times',
                                               'will be discarded'], 'info', logger)
                                sys.exit(0)

                        if subtask in ['filtering'] and not support_parallel_write and (args.cpu > 1):
                            print_and_log(['No parallel writes for %s: only 1 node used for %s' %(file_format, subtask)], 'info', logger)
                            nb_tasks = str(1)
                            one_cpu = True

                        else:
                            if subtask != 'fitting':
                                nb_tasks = str(args.cpu)
                            else:
                                # if use_gpu == 'True':
                                #     nb_tasks = str(args.gpu)
                                # else:
                                nb_tasks = str(args.cpu)

                        if subtask == 'benchmarking':
                            if (output is None) or (benchmark is None):
                                print_and_log(["To generate synthetic datasets, you must provide output and type"], 'error', logger)
                                sys.exit(0)
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask',
                                subtask, filename, str(nb_cpu), str(nb_gpu),
                                use_gpu, output, benchmark
                            ]
                        elif subtask in ['merging', 'converting']:
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask',
                                subtask, filename, str(nb_cpu), str(nb_gpu),
                                use_gpu, extension
                            ]
                        elif subtask in ['deconverting']:
                            nb_tasks = str(1)
                            nb_cpu = 1
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask', subtask,
                                filename, str(nb_cpu), str(nb_gpu), use_gpu,
                                extension
                            ]
                        else:
                            mpi_args += [
                                '-np', nb_tasks, 'spyking-circus-subtask',
                                subtask, filename, str(nb_cpu), str(nb_gpu),
                                use_gpu, str(one_cpu)
                            ]

                        print_and_log(['Launching task %s' % subtask], 'debug', logger)
                        print_and_log(['Command: %s' % str(mpi_args)], 'debug', logger)

                        try:
                            subprocess.check_call(mpi_args)
                        except subprocess.CalledProcessError as e:
                            print_and_log(['Step "%s" failed for reason %s!' % (subtask, e)], 'error', logger)
                            sys.exit(0)

    if preview or result:
        from circus.shared import gui
        import pylab
        try:
            from PyQt5.QtWidgets import QApplication
        except ImportError:
            from matplotlib.backends import qt_compat
            use_pyside = qt_compat.QT_API == qt_compat.QT_API_PYSIDE
            if use_pyside:
                from PySide.QtGui import QApplication
            else:
                from PyQt4.QtGui import QApplication
        app = QApplication([])
        try:
            pylab.style.use('ggplot')
        except Exception:
            pass

        if preview:
            print_and_log(['Launching the preview GUI...'], 'debug', logger)
            mygui = gui.PreviewGUI(new_params)
            shutil.rmtree(tmp_path_loc)
        elif result:
            data_file = params.get_data_file()
            print_and_log(['Launching the result GUI...'], 'debug', logger)
            mygui = gui.PreviewGUI(params, show_fit=True)
        sys.exit(app.exec_())
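
# gather_mpi_arguments() is not shown in this listing; a hypothetical, minimal
# version that only builds the base mpirun invocation (adding the hostfile when
# it exists) and that matches how it is used above could look like:
def gather_mpi_arguments_sketch(hostfile, params):
    mpi_args = ['mpirun']
    if os.path.exists(hostfile):
        mpi_args += ['-hostfile', hostfile]
    return mpi_args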
Example #29
def main(argv=None):

    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e',
                        '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    mytest = StrictVersion(phycontrib.__version__) >= StrictVersion("1.0.12")
    if not mytest:
        print_and_log(
            ['You need to update phy-contrib to the latest git version'],
            'error', logger)
        sys.exit(1)

    if not test_patch_for_similarities(params, extension):
        print_and_log(
            ['You should re-export the data because of a fix in 0.6'], 'error',
            logger)
        continue_anyway = query_yes_no(
            Fore.WHITE + "Continue anyway (results may not be fully correct)?",
            default=None)
        if not continue_anyway:
            sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" %
            file_format
        ], 'info', logger)

    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1
    else:
        if data_file.gain != 1:
            print_and_log([
                "Gain of %g is not supported by phy. Expecting a scaling mismatch"
                % data_file.gain
            ], 'info', logger)
            gain = data_file.gain

    probe = params.probe
    if extension != '':
        extension = '-' + extension
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
    else:

        print_and_log(["Launching the phy GUI..."], 'info', logger)

        gui_params = {}
        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = params.get('data',
                                                    'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = ' '.join(
                        data_file.get_file_names())
                else:
                    gui_params['dat_path'] = params.get('data', 'data_file')
        else:
            gui_params['dat_path'] = 'giverandomname.dat'
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_dtype
        gui_params['offset'] = data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        f = open(os.path.join(output_path, 'params.py'), 'w')
        for key, value in gui_params.items():
            if key in ['dir_path', 'dat_path', 'dtype']:
                f.write('%s = "%s"\n' % (key, value))
            else:
                f.write("%s = %s\n" % (key, value))
        f.close()
        os.chdir(output_path)
        create_app()
        controller = TemplateController(**gui_params)
        gui = controller.create_gui()

        gui.show()
        run_app()
        gui.close()
        del gui
Example #30
class TestFitting(unittest.TestCase):
    def setUp(self):
        self.all_spikes = None
        self.max_chunk = '100'
        dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
        self.path = os.path.join(dirname, 'synthetic')
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        self.file_name = os.path.join(self.path, 'fitting.dat')
        self.source_dataset = get_dataset(self)
        if not os.path.exists(self.file_name):
            mpi_launch('benchmarking', self.source_dataset, 2, 0, 'False',
                       self.file_name, 'fitting', 1)
            mpi_launch('whitening', self.file_name, 2, 0, 'False')
        self.parser = CircusParser(self.file_name)

    def test_fitting_one_CPU(self):
        self.parser.write('fitting', 'max_chunk', self.max_chunk)
        mpi_launch('fitting', self.file_name, 1, 0, 'False')
        self.parser.write('fitting', 'max_chunk', 'inf')
        res = get_performance(self.file_name, 'one_CPU')
        if self.all_spikes is None:
            self.all_spikes = res
        assert numpy.all(self.all_spikes == res)

    def test_fitting_two_CPUs(self):
        self.parser.write('fitting', 'max_chunk', self.max_chunk)
        mpi_launch('fitting', self.file_name, 2, 0, 'False')
        self.parser.write('fitting', 'max_chunk', 'inf')
        res = get_performance(self.file_name, 'two_CPU')
        if self.all_spikes is None:
            self.all_spikes = res
        assert numpy.all(self.all_spikes == res)

    def test_fitting_one_GPU(self):
        HAVE_CUDA = False
        try:
            import cudamat
            HAVE_CUDA = True
        except ImportError:
            pass
        if HAVE_CUDA:
            self.parser.write('fitting', 'max_chunk', self.max_chunk)
            mpi_launch('fitting', self.file_name, 1, 0, 'False')
            self.parser.write('fitting', 'max_chunk', 'inf')
            res = get_performance(self.file_name, 'one_GPU')
            if self.all_spikes is None:
                self.all_spikes = res
            assert numpy.all(self.all_spikes == res)

    def test_fitting_large_chunks(self):
        self.parser.write('fitting', 'chunk_size', '1')
        self.parser.write('fitting', 'max_chunk',
                          str(int(self.max_chunk) // 2))
        mpi_launch('fitting', self.file_name, 2, 0, 'False')
        self.parser.write('fitting', 'max_chunk', 'inf')
        self.parser.write('fitting', 'chunk_size', '0.5')
        res = get_performance(self.file_name, 'large_chunks')
        if self.all_spikes is None:
            self.all_spikes = res
        assert numpy.all(self.all_spikes == res)
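
# To run these tests directly (assuming mpi_launch, get_dataset and
# get_performance are importable from the surrounding test module), the
# standard unittest entry point applies:
if __name__ == '__main__':
    unittest.main()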