Example #1
0
    def __init__(self,
                 tmp_dir,
                 spikes_file_csv=None,
                 spikes_file=None,
                 spikes_file_nwb=None,
                 spikes_sort_order=None):
        """Configure the spike cache and the requested output files.

        :param tmp_dir: cache directory for SpikeTrains; relative output file
            names are joined onto it (absolute ones are kept as given).
        :param spikes_file_csv: optional CSV output file name.
        :param spikes_file: optional SONATA (h5) output file name.
        :param spikes_file_nwb: optional NWB output file name.
        :param spikes_sort_order: optional key into ``sort_order_lu``; a falsy
            value selects ``sort_order.none``.
        """
        def _resolve(name):
            # Relative names are placed in the output ($OUTPUT_DIR / tmp_dir).
            if name is None:
                return None
            if os.path.isabs(name):
                return name
            return os.path.join(tmp_dir, name)

        self._csv_fname = _resolve(spikes_file_csv)
        self._h5_fname = _resolve(spikes_file)
        self._nwb_fname = _resolve(spikes_file_nwb)

        self._tmp_dir = tmp_dir
        self._tmp_file_base = 'tmp_spike_times'
        self._spike_labels = os.path.join(tmp_dir, self._tmp_file_base)

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir)
        # NOTE(review): these look like settings for an older writer API;
        # confirm SpikeTrains still reads them.
        self._spike_writer.delimiter = '\t'
        self._spike_writer.gid_col = 0
        self._spike_writer.time_col = 1

        if spikes_sort_order:
            self._sort_order = sort_order_lu[spikes_sort_order]
        else:
            self._sort_order = sort_order.none

        self._spike_detector = None
Example #2
0
    def __init__(self, spikes_file_csv=None, spikes_file=None, spikes_file_nwb=None, tmp_dir='output',
                 sort_order='node_id'):
        """Configure where spikes are cached and which output files are written.

        :param spikes_file_csv: optional CSV output path.
        :param spikes_file: optional SONATA (h5) output path.
        :param spikes_file_nwb: optional NWB output path.
        :param tmp_dir: cache directory for SpikeTrains. Relative output paths
            that do not already resolve inside it are placed under it.
        :param sort_order: key into ``sort_order_lu`` (KeyError if unknown).
            Note this parameter shadows any module-level ``sort_order`` name
            inside this method.
        """
        def _get_file_path(file_name):
            if file_name is None or os.path.isabs(file_name):
                return file_name

            real_tmp = os.path.realpath(tmp_dir)
            real_fname = os.path.realpath(file_name)
            # Compare whole path components. A plain string prefix test would
            # wrongly treat 'output2/spikes.h5' as already inside 'output'.
            tmp_prefix = real_tmp.rstrip(os.sep) + os.sep
            if real_fname == real_tmp or real_fname.startswith(tmp_prefix):
                return file_name
            return os.path.join(tmp_dir, file_name)

        self._csv_fname = _get_file_path(spikes_file_csv)
        self._save_csv = spikes_file_csv is not None

        self._h5_fname = _get_file_path(spikes_file)
        self._save_h5 = spikes_file is not None

        self._nwb_fname = _get_file_path(spikes_file_nwb)
        self._save_nwb = spikes_file_nwb is not None

        self._tmpdir = tmp_dir

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir)
        self._sort_order = sort_order_lu[sort_order]
def save_in_mem():
    """Benchmark writing an in-memory spike buffer to a SONATA file.

    Each rank generates random spike trains for nodes rank, rank+size, ... < N.
    The run then measures memory (via memory_usage) and wall time of
    write_sonata('check_out.h5', ...), and prints the results one rank at a time.
    """
    np.random.seed(1000)
    buffer_cls = STMPIBuffer if size > 1 else STMemoryBuffer
    st = SpikeTrains(adaptor=buffer_cls())

    for node_id in range(rank, N, size):
        count = np.random.randint(n_spikes_avg - n_spikes_std,
                                  n_spikes_avg + n_spikes_std)
        st.add_spikes(node_ids=node_id,
                      population='v1',
                      timestamps=np.random.uniform(0.0, 3000.0, size=count))

    total = st.n_spikes('v1')
    if rank == 0:
        print('finished In Memory Version, saving {} spikes...'.format(total))
        sys.stdout.flush()

    # Alternative benchmark: write_csv('check_out.csv', st, sort_order=sort_order.by_id)
    t_begin = timer()
    mem_samples = memory_usage((write_sonata, ('check_out.h5', st)))
    elapsed = timer() - t_begin

    # Ranks take turns printing, with a Barrier after each turn.
    for turn in range(size):
        if turn == rank:
            print('rank {} = {} MB, {} seconds'.format(rank, max(mem_samples),
                                                       elapsed))
            sys.stdout.flush()
        comm.Barrier()
    comm.Barrier()
def save_on_diskv2():
    """Benchmark writing a CSV-cached (on-disk, v2) spike buffer to SONATA.

    Uses the same workload as the in-memory benchmark, but spikes are cached
    under 'tmp_spikes'. Memory and wall time of write_sonata are reported per rank.
    """
    np.random.seed(1000)
    if size > 1:
        adaptor = STCSVMPIBufferV2(cache_dir='tmp_spikes')
    else:
        adaptor = STCSVBuffer(cache_dir='tmp_spikes')
    st = SpikeTrains(adaptor=adaptor)

    for node_id in range(rank, N, size):
        count = np.random.randint(n_spikes_avg - n_spikes_std,
                                  n_spikes_avg + n_spikes_std)
        st.add_spikes(node_ids=node_id,
                      population='v1',
                      timestamps=np.random.uniform(0.0, 3000.0, size=count))

    if rank == 0:
        print('finished On Disk Version2, saving spikes...')

    # Alternative benchmark: write_csv(..., sort_order=sort_order.by_id)
    t_begin = timer()
    mem_samples = memory_usage((write_sonata, ('check_out_origv2.h5', st)))
    elapsed = timer() - t_begin

    # Ranks take turns printing, with a Barrier after each turn.
    for turn in range(size):
        if turn == rank:
            print('rank {} = {} MB, {} seconds'.format(rank, max(mem_samples),
                                                       elapsed))
            sys.stdout.flush()
        comm.Barrier()
    comm.Barrier()
def add_spikes_mem():
    """Benchmark adding spikes to an in-memory SpikeTrains buffer.

    The function runs in two phases:

    1. Time bulk add_spikes() calls across all ranks; rank 0 prints the timing.
    2. Measure memory of add_spike_tst (one add_spike() per timestamp) with a
       fresh adaptor; each rank prints its max in turn.
    """
    np.random.seed(1000)
    t_begin = timer()
    st = SpikeTrains(adaptor=STMPIBuffer() if size > 1 else STMemoryBuffer())
    for node_id in range(rank, N, size):
        count = np.random.randint(n_spikes_avg - n_spikes_std,
                                  n_spikes_avg + n_spikes_std)
        st.add_spikes(node_ids=node_id,
                      population='v1',
                      timestamps=np.random.uniform(0.0, 3000.0, size=count))
    comm.Barrier()
    t_end = timer()

    total = st.n_spikes('v1')
    if rank == 0:
        print('In Memory took {} seconds to add {} spikes'.format(
            t_end - t_begin, total))
    comm.Barrier()
    del st

    fresh_adaptor = STMPIBuffer() if size > 1 else STMemoryBuffer()
    mem_samples = memory_usage((add_spike_tst, (fresh_adaptor, )))
    for turn in range(size):
        if turn == rank:
            print('rank {} = {} MB'.format(rank, max(mem_samples)))
            sys.stdout.flush()
        comm.Barrier()
    comm.Barrier()
Example #6
0
    def __init__(self,
                 tmp_dir,
                 spikes_file_csv=None,
                 spikes_file=None,
                 spikes_file_nwb=None,
                 spikes_sort_order=None):
        """Prepare spike caching and the requested output files.

        :param tmp_dir: cache directory for SpikeTrains; relative output names
            are joined onto it.
        :param spikes_file_csv: optional CSV output; ``_save_csv`` records
            whether it was requested.
        :param spikes_file: optional SONATA (h5) output; see ``_save_h5``.
        :param spikes_file_nwb: optional NWB output; see ``_save_nwb``.
        :param spikes_sort_order: key into ``sort_order_lu``; a missing or
            unknown key falls back to ``sort_order.none``.
        """
        # TODO: Have option to turn off caching spikes to csv.
        def _in_tmp_dir(name):
            if name is not None and not os.path.isabs(name):
                return os.path.join(tmp_dir, name)
            return name

        self._save_csv = spikes_file_csv is not None
        self._csv_fname = _in_tmp_dir(spikes_file_csv)

        self._save_h5 = spikes_file is not None
        self._h5_fname = _in_tmp_dir(spikes_file)

        self._save_nwb = spikes_file_nwb is not None
        self._nwb_fname = _in_tmp_dir(spikes_file_nwb)

        self._tmpdir = tmp_dir
        self._sort_order = sort_order_lu.get(spikes_sort_order, sort_order.none)

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir)
        self._gid_map = None  # not populated in this method
Example #7
0
def load_spike_trains(file_path):
    """Load spike trains from a file located relative to this module.

    The reader is chosen by extension: ``.csv`` (from_csv), ``.h5``
    (from_sonata) or ``.nwb`` (from_nwb).

    :param file_path: path relative to this module's directory. An absolute
        path is used as-is, because ``os.path.join`` discards the prefix.
    :return: the loaded SpikeTrains object.
    :raises ValueError: if the extension is not one of the supported ones.
    """
    cpath = os.path.dirname(os.path.realpath(__file__))
    file_path = os.path.join(cpath, file_path)
    if file_path.endswith('.csv'):
        return SpikeTrains.from_csv(file_path)
    if file_path.endswith('.h5'):
        return SpikeTrains.from_sonata(file_path)
    if file_path.endswith('.nwb'):
        return SpikeTrains.from_nwb(file_path)
    # Previously fell through and returned None, deferring the failure to the caller.
    raise ValueError('Unsupported spike-train file extension: {}'.format(file_path))
def add_spike_tst(adaptor):
    """Fill a SpikeTrains one add_spike() call per timestamp and return it.

    Each rank handles nodes rank, rank+size, ... < N, drawing a random number
    of uniform timestamps in [0, 3000) per node. A Barrier precedes the return.
    """
    np.random.seed(1000)
    trains = SpikeTrains(adaptor)
    for node_id in range(rank, N, size):
        count = np.random.randint(n_spikes_avg - n_spikes_std,
                                  n_spikes_avg + n_spikes_std)
        timestamps = np.random.uniform(0.0, 3000.0, size=count)
        for ts in timestamps:
            trains.add_spike(node_id=node_id, population='v1', timestamp=ts)
    comm.Barrier()
    return trains
Example #9
0
def test_multipop_with_default(path):
    """Loading only population 'tw' makes it the default for get_times()."""
    st = SpikeTrains.from_sonata(full_path(path), population='tw')
    assert 'tw' in st.populations
    assert 'lgn' not in st.populations

    tw_node0 = st.get_times(node_id=0, population='tw')
    assert len(tw_node0) > 0
    assert np.all(tw_node0 == st.get_times(node_id=0))
Example #10
0
def run(config_file=None, sim=None, conf=None):
    """Load a finished simulation's spikes and plot per-population rate stats.

    :param config_file: optional bionet JSON config. When given, it is loaded
        into ``conf`` and supplies ``dt`` and ``n_steps`` (derived from ``tstop``).
    :param sim: optional simulation object. When given, its ``n_steps``/``dt``
        override the config values, and its first ``FeedbackLoop`` module is
        passed to the plot.
    :param conf: config object providing ``output_dir``. It is replaced when
        ``config_file`` is given.
    :raises ValueError: if neither ``config_file`` nor ``sim`` is given, or if
        no config is available to locate ``output_dir``. Also raised (by
        ``list.index``) if ``sim`` has no FeedbackLoop module.
    """
    if config_file is None and sim is None:
        raise ValueError('run() requires config_file or sim')

    if config_file is not None:
        conf = bionet.Config.from_json(config_file, validate=True)
        dt = conf['run']['dt']
        # np.int was deprecated in NumPy 1.20 and removed in 1.24.
        n_steps = int(np.ceil(conf['run']['tstop'] / dt + 1))
        fbmod = None
    if sim is not None:
        n_steps = sim.n_steps
        dt = sim.dt
        fbmod = sim._sim_mods[[
            isinstance(mod, FeedbackLoop) for mod in sim._sim_mods
        ].index(True)]
    if conf is None:
        raise ValueError('run() requires conf (or config_file) to locate output_dir')
    output_dir = conf.output_dir
    print(n_steps, dt)

    spikes_df = pd.read_csv(os.path.join(output_dir, 'spikes.csv'), sep=' ')
    print(spikes_df['node_ids'].unique())
    spike_trains = SpikeTrains.from_sonata(
        os.path.join(output_dir, 'spikes.h5'))

    #plotting
    window_size = 1000
    pops = ['Bladaff', 'PGN', 'PAGaff', 'EUSmn', 'INmminus', 'IND']
    windows = [window_size] * len(pops)
    means = {}
    stdevs = {}
    # NOTE(review): `gids` and `num` are not defined in this function. The call
    # also does not match the plotting_calculator signature in this file.
    # Presumably these are module-level names from another version; confirm.
    for pop, win in zip(pops, windows):
        means[pop], stdevs[pop] = plotting_calculator(spike_trains, n_steps,
                                                      dt, win, gids, num, pop)

    plot_figure(means, stdevs, n_steps, dt, tstep=window_size, fbmod=fbmod)
Example #11
0
def plotting_calculator(plotting_dict,
                        window_size,
                        arange1,
                        arange2,
                        arange3=0,
                        multiplier=1,
                        spike_trains=None,
                        spikes_file='output/spikes.h5'):
    """Compute a box-windowed spike-count trace over a block of gids.

    For each gid in ``np.arange(arange1, arange3 + arange2)``:

    - Build a binary spike vector with ``plotting_dict['n_steps']`` bins of
      width ``plotting_dict['dt']``.
    - Convolve it with ``window_size`` ones and keep the first ``n_steps``
      samples.
    - Add those samples to ``means`` and store them as one row of a per-gid
      buffer.

    :param plotting_dict: mapping with at least ``'n_steps'`` and ``'dt'``.
    :param window_size: length of the box window, in time steps.
    :param arange1: first gid; when > 0 also the modulus mapping gid -> row.
    :param arange2: number of buffer rows, and the averaging divisor.
    :param arange3: added to ``arange2`` to form the exclusive end gid.
    :param multiplier: extra divisor applied to the mean.
    :param spike_trains: optional object with ``get_times(gid)``. When None,
        spikes are loaded from ``spikes_file`` with ``SpikeTrains.from_sonata``.
    :param spikes_file: SONATA file read when ``spike_trains`` is None.
    :return: ``(means, stdevs)``, each an array of length ``n_steps``.
        ``means`` is the summed trace divided by ``arange2 * multiplier``.
        ``stdevs`` is the across-row standard deviation at each step.
    """
    if spike_trains is None:
        spike_trains = SpikeTrains.from_sonata(spikes_file)

    n_steps = plotting_dict['n_steps']
    dt = plotting_dict['dt']
    means = np.zeros(n_steps)
    fr_conv = np.zeros((arange2, n_steps))
    window = np.ones(window_size)

    for gid in np.arange(arange1, arange3 + arange2):
        spikes = np.zeros(n_steps, dtype=int)
        times = np.asarray(spike_trains.get_times(gid))
        if len(times) > 0:
            # np.int was removed in NumPy 1.24; builtin int is the same dtype.
            spikes[(times / dt).astype(int)] = 1

        # Full convolution is n_steps + window_size - 1 long; keep the first n_steps.
        frs = np.convolve(spikes, window)[:n_steps]
        means += frs
        row = gid % arange1 if arange1 > 0 else gid
        fr_conv[row] = frs

    means /= arange2 * multiplier
    stdevs = np.std(fr_conv, axis=0)
    return means, stdevs
Example #12
0
def test_single_populations(path):
    """A single-population file exposes only 'v1', which is the default population."""
    path = full_path(path)
    st = SpikeTrains.from_sonata(path)
    assert (st.populations == ['v1'])
    node0_timestamps = st.get_times(node_id=0, population='v1')

    assert (np.all(st.get_times(node_id=0) == node0_timestamps))
    # Use len(): if get_times returns an ndarray, `== []` is an elementwise
    # comparison whose empty result has ambiguous truthiness.
    assert len(st.get_times(node_id=0, population='should_not_work')) == 0
Example #13
0
def test_old_populations(path):
    """Files without population names load under ``pop_na``.

    Requesting an unrecognised population name still returns that single
    population's spikes.
    """
    path = full_path(path)
    # The path was previously passed through full_path() a second time here.
    st = SpikeTrains.from_sonata(path)
    assert (st.populations == [pop_na])
    node0_timestamps = st.get_times(node_id=0, population=pop_na)
    assert (len(node0_timestamps) > 0)
    assert (np.all(st.get_times(node_id=0) == node0_timestamps))
    assert (np.all(
        st.get_times(node_id=0, population='should_still_work') ==
        node0_timestamps))
def add_spikes_diskv2():
    """Time bulk add_spikes() calls into the CSV-cached (on-disk, v2) buffer.

    Spikes are cached under 'tmp_spikes'; rank 0 prints the elapsed time.
    """
    np.random.seed(1000)
    t_begin = timer()
    adaptor = (STCSVMPIBufferV2(cache_dir='tmp_spikes') if size > 1
               else STCSVBuffer(cache_dir='tmp_spikes'))
    st = SpikeTrains(adaptor=adaptor)
    for node_id in range(rank, N, size):
        count = np.random.randint(n_spikes_avg - n_spikes_std,
                                  n_spikes_avg + n_spikes_std)
        st.add_spikes(node_ids=node_id,
                      population='v1',
                      timestamps=np.random.uniform(0.0, 3000.0, size=count))
    comm.Barrier()
    t_end = timer()

    if rank == 0:
        print('DiskV2 took {} seconds to add spikes'.format(t_end - t_begin))
Example #15
0
    def __spike_writer(self):
        """Exercise STBufferedWriter and flush it.

        Adds one spike for each of nodes 0-999 and a 1000-point train for each
        of nodes 1000-2999. The writer caches into a ``tempfile.mkdtemp()``
        directory, which is removed afterwards so repeated runs do not leak
        temp directories.
        """
        import shutil
        # from bmtk.utils.reports.spike_trains import SpikeTrains
        from bmtk.utils.reports.spike_trains.spike_train_buffer import STBufferedWriter as SpikeTrains
        tmpdir = tempfile.mkdtemp()
        try:
            spike_trains = SpikeTrains(tmpdir)
            for node_id in range(1000):
                spike_trains.add_spike(node_id, 1.0)

            for node_id in range(1000, 3000):
                spike_trains.add_spikes(node_id, np.linspace(0.0, 2000.0, 1000))

            spike_trains.flush()
        finally:
            # mkdtemp() never deletes its directory; clean up best-effort.
            shutil.rmtree(tmpdir, ignore_errors=True)
Example #16
0
 def save_aff(self, path):
     """Write the afferent populations' spikes to SONATA and CSV files.

     'Bladaff' uses the gids in ``self._high_level_neurons``. 'PAGaff' uses
     the gids in ``self._pag_neurons``. Spike times come from
     ``self._spike_events[gid]``. Each population is written to
     ``<path>/<pop>_spikes.h5`` and ``<path>/<pop>_spikes.csv``.
     """
     for pop_name, node_attr in (('Bladaff', '_high_level_neurons'),
                                 ('PAGaff', '_pag_neurons')):
         spiketrains = SpikeTrains(population=pop_name)
         for gid in getattr(self, node_attr):
             spiketrains.add_spikes(gid, self._spike_events[gid],
                                    population=pop_name)
         out_base = os.path.join(path, pop_name + '_spikes')
         spiketrains.to_sonata(out_base + '.h5')
         spiketrains.to_csv(out_base + '.csv')
Example #17
0
    def __init__(self, spikes_file_csv=None, spikes_file=None, spikes_file_nwb=None, tmp_dir='output'):
        """Record which spike outputs to write and set up the spike cache.

        :param spikes_file_csv: optional CSV output path.
        :param spikes_file: optional SONATA (h5) output path.
        :param spikes_file_nwb: optional NWB output path.
        :param tmp_dir: SpikeTrains cache directory; relative output paths are
            joined onto it.
        """
        def _under_tmp(name):
            # None and absolute paths pass through unchanged.
            if name is not None and not os.path.isabs(name):
                return os.path.join(tmp_dir, name)
            return name

        self._save_csv = spikes_file_csv is not None
        self._csv_fname = _under_tmp(spikes_file_csv)

        self._save_h5 = spikes_file is not None
        self._h5_fname = _under_tmp(spikes_file)

        self._save_nwb = spikes_file_nwb is not None
        self._nwb_fname = _under_tmp(spikes_file_nwb)

        self._tmpdir = tmp_dir
        self._spike_writer = SpikeTrains(cache_dir=tmp_dir)
Example #18
0
def test_multi_populations(path):
    """A two-population file keeps 'tw' and 'lgn' spikes separate."""
    path = full_path(path)
    st = SpikeTrains.from_sonata(path)
    assert ('tw' in st.populations and 'lgn' in st.populations)
    n1_tw_ts = st.get_times(node_id=0, population='tw')
    n1_lgn_ts = st.get_times(node_id=0, population='lgn')

    assert (len(n1_tw_ts) > 0)
    assert (len(n1_lgn_ts) > 0)
    assert not np.array_equal(n1_tw_ts, n1_lgn_ts)
    # len() works for lists and ndarrays alike; `== []` on an ndarray is
    # elementwise and its empty result has ambiguous truthiness.
    assert len(st.get_times(node_id=0, population='other')) == 0
Example #19
0
def test_empty_spikes():
    """An empty in-memory SpikeTrains round-trips through SONATA as empty."""
    st = SpikeTrains(adaptor=spike_train_buffer.STMemoryBuffer())
    output_path = full_path('output/tmpspikes.h5')
    st.to_sonata(path=output_path)
    st.close()

    try:
        st_empty = SpikeTrains.from_sonata(output_path)
        assert (st_empty.populations == [])
        assert (st_empty.n_spikes() == 0)
        assert (list(st_empty.spikes()) == [])
    finally:
        # Remove the temporary file even if an assertion above fails.
        os.remove(output_path)
Example #20
0
    def __init__(self,
                 tmp_dir,
                 spikes_file_csv=None,
                 spikes_file=None,
                 spikes_file_nwb=None,
                 spikes_sort_order=None,
                 cache_to_disk=True):
        """Configure the spike cache and the requested output files.

        :param tmp_dir: cache directory for SpikeTrains. Relative output names
            not already inside it are placed under it.
        :param spikes_file_csv: optional CSV output file name.
        :param spikes_file: optional SONATA (h5) output file name.
        :param spikes_file_nwb: optional NWB output file name.
        :param spikes_sort_order: optional key into ``sort_order_lu``; a falsy
            value selects ``sort_order.none``.
        :param cache_to_disk: forwarded to SpikeTrains.
        """
        def _get_path(file_name):
            # Unless file-name is an absolute path (or already inside tmp_dir)
            # it should be placed in the $OUTPUT_DIR.
            if file_name is None:
                return None

            if os.path.isabs(file_name):
                return file_name

            abs_tmp = os.path.abspath(tmp_dir)
            abs_fname = os.path.abspath(file_name)
            # Compare whole path components. A plain string prefix test would
            # treat 'output2/spikes.h5' as already inside 'output'.
            tmp_prefix = abs_tmp.rstrip(os.sep) + os.sep
            if abs_fname == abs_tmp or abs_fname.startswith(tmp_prefix):
                return file_name
            return os.path.join(tmp_dir, file_name)

        self._csv_fname = _get_path(spikes_file_csv)
        self._h5_fname = _get_path(spikes_file)
        self._nwb_fname = _get_path(spikes_file_nwb)

        self._tmp_dir = tmp_dir
        self._tmp_file_base = 'tmp_spike_times'
        self._spike_labels = os.path.join(self._tmp_dir, self._tmp_file_base)

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir,
                                         cache_to_disk=cache_to_disk)
        # NOTE(review): these look like settings for an older writer API;
        # confirm SpikeTrains still reads them.
        self._spike_writer.delimiter = '\t'
        self._spike_writer.gid_col = 0
        self._spike_writer.time_col = 1
        self._sort_order = sort_order.none if not spikes_sort_order else sort_order_lu[
            spikes_sort_order]

        self._spike_detector = None
Example #21
0
def test_subset():
    """A SpikeTrains whose spikes are all contained in another compares as smaller."""
    smaller = SpikeTrains()
    smaller.add_spikes(node_ids=0, population='V1', timestamps=[0.1, 0.2, 0.3, 0.4])
    smaller.add_spikes(node_ids=1, population='V1', timestamps=[1.0])

    larger = SpikeTrains()
    larger.add_spikes(node_ids=1, population='V1', timestamps=[1.0])
    larger.add_spikes(node_ids=0, population='V1', timestamps=[0.3, 0.2, 0.1, 0.4, 0.5])
    larger.add_spike(node_id=2, population='V1', timestamp=0.5)
    larger.add_spikes(node_ids=0, population='V2', timestamps=np.linspace(0.0, 1.0, 11))

    assert smaller != larger
    assert smaller < larger
    assert larger > smaller
    assert smaller <= larger
    assert larger >= smaller
Example #22
0
class SpikesMod(object):
    """Module used for saving spikes recorded by a NEST spike detector.

    The detector writes ``<tmp_dir>/tmp_spike_times*.gdf`` files. On finalize,
    rank 0 parses them into a SpikeTrains writer, which then saves the
    requested CSV / SONATA / NWB files. The gdf files are deleted afterwards.
    """
    def __init__(self,
                 tmp_dir,
                 spikes_file_csv=None,
                 spikes_file=None,
                 spikes_file_nwb=None,
                 spikes_sort_order=None,
                 cache_to_disk=True):
        """
        :param tmp_dir: directory for the NEST gdf files and the SpikeTrains cache.
        :param spikes_file_csv: optional CSV output path.
        :param spikes_file: optional SONATA (h5) output path.
        :param spikes_file_nwb: optional NWB output path.
        :param spikes_sort_order: optional key into ``sort_order_lu``; a falsy
            value selects ``sort_order.none``.
        :param cache_to_disk: forwarded to SpikeTrains.
        """
        self._csv_fname = self._resolve_path(tmp_dir, spikes_file_csv)
        self._h5_fname = self._resolve_path(tmp_dir, spikes_file)
        self._nwb_fname = self._resolve_path(tmp_dir, spikes_file_nwb)

        self._tmp_dir = tmp_dir
        self._tmp_file_base = 'tmp_spike_times'
        self._spike_labels = os.path.join(self._tmp_dir, self._tmp_file_base)

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir,
                                         cache_to_disk=cache_to_disk)
        # NOTE(review): these look like settings for an older writer API;
        # confirm SpikeTrains still reads them.
        self._spike_writer.delimiter = '\t'
        self._spike_writer.gid_col = 0
        self._spike_writer.time_col = 1
        self._sort_order = sort_order.none if not spikes_sort_order else sort_order_lu[
            spikes_sort_order]

        self._spike_detector = None

    @staticmethod
    def _resolve_path(tmp_dir, file_name):
        """Return where an output file should be written.

        None is returned unchanged, and absolute paths are kept. A relative path
        that already resolves inside ``tmp_dir`` is kept. Any other relative
        path is placed under ``tmp_dir`` (the $OUTPUT_DIR).
        """
        if file_name is None or os.path.isabs(file_name):
            return file_name

        abs_tmp = os.path.abspath(tmp_dir)
        abs_fname = os.path.abspath(file_name)
        # Compare whole path components. A plain string prefix test would treat
        # 'output2/spikes.h5' as already inside 'output'.
        tmp_prefix = abs_tmp.rstrip(os.sep) + os.sep
        if abs_fname == abs_tmp or abs_fname.startswith(tmp_prefix):
            return file_name
        return os.path.join(tmp_dir, file_name)

    def initialize(self, sim):
        """Create a file-writing spike detector and connect all network gids to it."""
        self._spike_detector = nest.Create(
            "spike_detector", 1, {
                'label': self._spike_labels,
                'withtime': True,
                'withgid': True,
                'to_file': True
            })

        nest.Connect(sim.net.gid_map.gids, self._spike_detector)

    def finalize(self, sim):
        """Convert the gdf files, write the requested outputs, then clean up."""
        # convert NEST gdf files into SONATA spikes/ format
        # TODO: Create a gdf_adaptor in bmtk/utils/reports/spike_trains to improve conversion speed.
        if MPI_RANK == 0:
            for gdf_file in glob.glob(self._spike_labels + '*.gdf'):
                self.__parse_gdf(gdf_file, sim.net.gid_map)
        io.barrier()

        if self._csv_fname is not None:
            self._spike_writer.to_csv(self._csv_fname,
                                      sort_order=self._sort_order)

        if self._h5_fname is not None:
            # TODO: reimplement with pandas
            self._spike_writer.to_sonata(self._h5_fname,
                                         sort_order=self._sort_order)

        if self._nwb_fname is not None:
            self._spike_writer.to_nwb(self._nwb_fname,
                                      sort_order=self._sort_order)

        self._spike_writer.close()
        self.__clean_gdf_files()

    def __parse_gdf(self, gdf_path, gid_map):
        """Add every spike in a tab-separated gdf file (gid, time, ...) to the writer."""
        with open(gdf_path, 'r') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter='\t')
            for r in csv_reader:
                if not r:
                    # Blank lines would otherwise raise IndexError on r[0].
                    continue
                p = gid_map.get_pool_id(int(r[0]))
                self._spike_writer.add_spike(node_id=p.node_id,
                                             timestamp=float(r[1]),
                                             population=p.population)

    def __clean_gdf_files(self):
        """Delete the detector's gdf files (rank 0 only)."""
        if MPI_RANK == 0:
            for gdf_file in glob.glob(self._spike_labels + '*.gdf'):
                os.remove(gdf_file)
Example #23
0
class SpikesGenerator(SimModule):
    """Generate spike trains from firing-rate profiles and save them to file(s)."""
    def __init__(self, spikes_file_csv=None, spikes_file=None, spikes_file_nwb=None, tmp_dir='output'):
        """
        :param spikes_file_csv: optional CSV output path.
        :param spikes_file: optional SONATA (h5) output path.
        :param spikes_file_nwb: optional NWB output path.
        :param tmp_dir: SpikeTrains cache directory; relative output paths are
            joined onto it.
        """
        def _get_file_path(file_name):
            if file_name is None or os.path.isabs(file_name):
                return file_name

            return os.path.join(tmp_dir, file_name)

        self._csv_fname = _get_file_path(spikes_file_csv)
        self._save_csv = spikes_file_csv is not None

        self._h5_fname = _get_file_path(spikes_file)
        self._save_h5 = spikes_file is not None

        self._nwb_fname = _get_file_path(spikes_file_nwb)
        self._save_nwb = spikes_file_nwb is not None

        self._tmpdir = tmp_dir

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir)

    def save(self, sim, cell, times, rates):
        """Turn a rate profile into a spike train for ``cell`` and buffer it.

        ``times`` appear to be in seconds: both branches scale by 1000 to get ms.
        """
        try:
            spike_trains = np.array(f_rate_to_spike_train(times*1000.0, rates, np.random.randint(10000),
                                                          1000.*min(times), 1000.*max(times), 0.1))
        except Exception:
            # Fall back to pg.generate_inhomogenous_poisson if the first generator
            # fails. A bare `except:` here used to swallow KeyboardInterrupt/SystemExit too.
            # convert to milliseconds and hence the multiplication by 1000
            spike_trains = 1000.0*np.array(pg.generate_inhomogenous_poisson(times, rates,
                                                                            seed=np.random.randint(10000)))

        self._spike_writer.add_spikes(node_ids=cell.gid, timestamps=spike_trains, population=cell.population)

    def finalize(self, sim):
        """Flush buffered spikes, write each requested format, and close the writer."""
        self._spike_writer.flush()

        if self._save_csv:
            self._spike_writer.to_csv(self._csv_fname)

        if self._save_h5:
            self._spike_writer.to_sonata(self._h5_fname)

        if self._save_nwb:
            self._spike_writer.to_nwb(self._nwb_fname)

        self._spike_writer.close()
Example #24
0
class SpikesMod(SimulatorMod):
    """Simulation module that records spikes and saves them to CSV / SONATA / NWB.

    Spikes are drained from the simulator every block into a SpikeTrains
    writer (cached under ``tmp_dir``) and written out in ``finalize``.
    """
    def __init__(self,
                 tmp_dir,
                 spikes_file_csv=None,
                 spikes_file=None,
                 spikes_file_nwb=None,
                 spikes_sort_order=None):
        # TODO: Have option to turn off caching spikes to csv.
        def _in_tmp_dir(name):
            if name is not None and not os.path.isabs(name):
                return os.path.join(tmp_dir, name)
            return name

        self._save_csv = spikes_file_csv is not None
        self._csv_fname = _in_tmp_dir(spikes_file_csv)

        self._save_h5 = spikes_file is not None
        self._h5_fname = _in_tmp_dir(spikes_file)

        self._save_nwb = spikes_file_nwb is not None
        self._nwb_fname = _in_tmp_dir(spikes_file_nwb)

        self._tmpdir = tmp_dir
        # Unknown or missing keys fall back to no sorting.
        self._sort_order = sort_order_lu.get(spikes_sort_order, sort_order.none)

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir)
        self._gid_map = None  # set in initialize()

    def initialize(self, sim):
        """Turn on spike recording and keep the gid -> (population, node_id) map."""
        # TODO: since it's possible that other modules may need to access spikes, set_spikes_recordings() should
        # probably be called in the simulator itself.
        sim.set_spikes_recording()
        self._gid_map = sim.net.gid_pool

    def block(self, sim, block_interval):
        """Move the spikes recorded since the last block into the writer."""
        for gid, spike_times in sim.spikes_table.items():
            pool_id = self._gid_map.get_pool_id(gid)
            for t in spike_times:
                self._spike_writer.add_spike(node_id=pool_id.node_id,
                                             timestamp=t,
                                             population=pool_id.population)

        pc.barrier()  # every rank finishes saving before recording is reset
        sim.set_spikes_recording()  # reset the recording vectors

    def finalize(self, sim):
        """Flush cached spikes, then write each enabled format (barrier after each)."""
        writer = self._spike_writer
        writer.flush()
        pc.barrier()

        outputs = ((self._save_csv, 'to_csv', self._csv_fname),
                   (self._save_h5, 'to_sonata', self._h5_fname),
                   (self._save_nwb, 'to_nwb', self._nwb_fname))
        for enabled, method_name, fname in outputs:
            if enabled:
                getattr(writer, method_name)(fname, sort_order=self._sort_order)
                pc.barrier()

        writer.close()
Example #25
0
def test_equals():
    # The same spikes are added in a different order, both across nodes and within
    # node 0's timestamps. The two containers must still compare equal under every
    # comparison operator.
    lhs = SpikeTrains()
    lhs.add_spikes(node_ids=0, population='V1', timestamps=[0.1, 0.2, 0.3, 0.4])
    lhs.add_spikes(node_ids=1, population='V1', timestamps=[1.0])

    rhs = SpikeTrains()
    rhs.add_spikes(node_ids=1, population='V1', timestamps=[1.0])
    rhs.add_spikes(node_ids=0, population='V1', timestamps=[0.3, 0.2, 0.1, 0.4])

    assert lhs == rhs
    assert lhs <= rhs
    assert lhs >= rhs
    assert not (lhs != rhs)
Example #26
0
    def from_config(cls, configure, graph):
        """Build a network from a simulation configuration.

        :param configure: path to a JSON config file (parsed with cfg.from_json, validated) or an
            already-loaded config dict.
        :param graph: graph whose nodes and recurrent edges are built, and to which the configured
            spike/rate inputs are attached.
        :return: the constructed network (an instance of ``cls``).
        :raises Exception: if ``configure`` has an unsupported type, the config has no "run" section,
            or "block_run" is enabled without a "block_size".
        """
        # load the json file or object
        if isinstance(configure, string_types):
            config = cfg.from_json(configure, validate=True)
        elif isinstance(configure, dict):
            config = configure
        else:
            raise Exception('Could not convert {} (type "{}") to json.'.format(
                configure, type(configure)))

        if 'run' not in config:
            raise Exception(
                'Json file is missing "run" entry. Unable to build Bionetwork.'
            )
        run_dict = config['run']

        # The "output" section is optional. Fall back to an empty mapping rather than
        # raising KeyError when it is absent.
        out_dict = config['output'] if 'output' in config else {}

        # Get network parameters
        # step time (dt) is set in the kernel and should be passed
        overwrite = run_dict.get('overwrite_output_dir', True)
        # NOTE(review): this value is not used (config.dt is what reaches the network). The lookup
        # still raises KeyError when "run" has no "dt".
        dt = run_dict['dt']  # TODO: make sure dt exists
        # NOTE(review): config.tstop / config.dt / config.output_dir use attribute access, which a
        # plain dict passed as ``configure`` does not support -- confirm callers pass a config object.
        tstop = float(config.tstop) / 1000.0
        network = cls(graph, dt=config.dt, tstop=tstop, overwrite=overwrite)

        if 'output_dir' in out_dict:
            network.output_dir = out_dict['output_dir']

        if run_dict.get('block_run'):
            if 'block_size' not in run_dict:
                raise Exception(
                    '"block_run" is set to True but "block_size" not found.')
            network._block_size = run_dict['block_size']

        if 'duration' in run_dict:
            network.duration = run_dict['duration']

        graph.io.log_info('Building cells.')
        graph.build_nodes()

        graph.io.log_info('Building recurrent connections')
        graph.build_recurrent_edges()

        for sim_input in inputs.from_config(config):
            node_set = graph.get_node_set(sim_input.node_set)
            if sim_input.input_type == 'spikes':
                path = sim_input.params['input_file']
                spikes = SpikeTrains.load(path=path,
                                          file_type=sim_input.module,
                                          **sim_input.params)
                graph.io.log_info(
                    'Build virtual cell stimulations for {}'.format(
                        sim_input.name))
                graph.add_spike_trains(spikes, node_set)
            else:
                # Any non-spike input is treated as a firing-rate input.
                graph.io.log_info(
                    'Build virtual cell stimulations for {}'.format(
                        sim_input.name))
                rates = firing_rates.RatesInput(sim_input.params)
                graph.add_rates(rates, node_set)

        # Create the output file
        rates_file = out_dict.get('rates_file', None)
        if rates_file is not None:
            if not os.path.isabs(rates_file):
                rates_file = os.path.join(config.output_dir, rates_file)
            network.rates_file = rates_file
            # Create the directory if required. dirname() is '' for a bare file name, and
            # os.makedirs('') would raise.
            parent_dir = os.path.dirname(rates_file)
            if parent_dir and not os.path.exists(parent_dir):
                os.makedirs(parent_dir)

        if 'log_file' in out_dict:
            network.set_logging(out_dict['log_file'])

        graph.io.log_info('Network created.')
        return network
Example #27
0
    def from_config(cls, config, network, set_recordings=True):
        """Build ``network`` and return a simulator configured from ``config``.

        Steps:
          1. Register any "syn_activity" inputs first, because they change how cells and
             connections are built.
          2. Build nodes and recurrent edges.
          3. Construct the simulator.
          4. Attach every other input.
          5. Add one output module per report.

        :param config: simulation config; dt, tstop, v_init, celsius and block_step are read as attributes.
        :param network: network to build and simulate.
        :param set_recordings: not used by this method; kept for interface compatibility.
        :return: the configured simulator (an instance of ``cls``).
        """
        def _as_list(value, cast_items=True):
            """Normalise a clamp parameter to a list.

            A scalar (anything without len()) becomes a one-element list of float. A sequence is
            returned unchanged, except that with cast_items set and more than one element every
            item is converted to float.
            """
            try:
                len(value)
            except TypeError:
                return [float(value)]
            if cast_items and len(value) > 1:
                return [float(item) for item in value]
            return value

        # Materialised once so both passes below see the same inputs, even if
        # inputs.from_config() returns a one-shot iterator.
        simulation_inputs = list(inputs.from_config(config))

        # Special case for setting synapses to spontaneously (for a given set of pre-synaptic cell-types). Using this
        # input will change the way the network builds cells/connections and thus needs to be set first.
        for sim_input in simulation_inputs:
            if sim_input.input_type == 'syn_activity':
                network.set_spont_syn_activity(
                    precell_filter=sim_input.params['precell_filter'],
                    timestamps=sim_input.params['timestamps']
                )

        # The network must be built before initializing the simulator because
        # gap junctions must be set up before the simulation is initialized.
        network.io.log_info('Building cells.')
        network.build_nodes()

        network.io.log_info('Building recurrent connections')
        network.build_recurrent_edges()

        sim = cls(network=network,
                  dt=config.dt,
                  tstop=config.tstop,
                  v_init=config.v_init,
                  celsius=config.celsius,
                  nsteps_block=config.block_step)

        # TODO: Need to create a gid selector
        for sim_input in simulation_inputs:
            try:
                node_set = network.get_node_set(sim_input.node_set)
            except Exception:
                # Give the user a hint, then propagate the original error.
                print("Parameter node_set must be given in inputs module of simulation_config file. If unsure of what node_set should be, set it to 'all'.")
                raise

            if sim_input.input_type == 'spikes':
                io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
                path = sim_input.params['input_file']
                spikes = SpikeTrains.load(path=path, file_type=sim_input.module, **sim_input.params)
                network.add_spike_trains(spikes, node_set)

            elif sim_input.module == "FileIClamp":
                sim.attach_file_current_clamp(sim_input.params["input_file"])

            elif sim_input.module == 'IClamp':
                # TODO: Parse from csv file
                params = sim_input.params
                params['amp'] = _as_list(params['amp'])
                params['delay'] = _as_list(params['delay'])
                params['duration'] = _as_list(params['duration'])

                amplitude = params['amp']
                delay = params['delay']
                duration = params['duration']

                # Location of the iclamp is hobj.<section_name>[<section_index>](<section_dist>). The
                # default is hobj.soma[0](0.5), the center of the soma.
                section_name = params.get('section_name', 'soma')
                section_index = params.get('section_index', 0)
                section_dist = params.get('section_dist', 0.5)

                gids = params.get('gids')
                if gids is None:
                    gids = list(node_set.gids())

                sim.attach_current_clamp(amplitude, delay, duration, gids, section_name, section_index, section_dist)

            elif sim_input.module == "SEClamp":
                params = sim_input.params
                # amps/durations are only wrapped when scalar; list items are not cast to float.
                params['amps'] = _as_list(params['amps'], cast_items=False)
                params['durations'] = _as_list(params['durations'], cast_items=False)

                amplitudes = params['amps']
                durations = params['durations']
                rs = None
                if 'rs' in params:
                    params['rs'] = _as_list(params['rs'])
                    rs = params['rs']

                gids = params.get('gids')
                if gids is None:
                    gids = list(node_set.gids())

                sim.attach_se_voltage_clamp(amplitudes, durations, gids, rs)

            elif sim_input.module == 'xstim':
                sim.add_mod(mods.XStimMod(**sim_input.params))

            elif sim_input.module == 'syn_activity' or sim_input.input_type == 'syn_activity':
                # Already applied in the first pass above (which keys on input_type).
                pass

            else:
                io.log_exception('Can not parse input format {}'.format(sim_input.name))

        # Parse the "reports" section of the config and load an associated output module for each report
        sim_reports = reports.from_config(config)
        for report in sim_reports:
            if isinstance(report, reports.SpikesReport):
                mod = mods.SpikesMod(**report.params)

            elif report.module == 'netcon_report':
                mod = mods.NetconReport(**report.params)

            elif isinstance(report, reports.MembraneReport):
                if report.params['sections'] == 'soma':
                    mod = mods.SomaReport(**report.params)
                else:
                    mod = mods.MembraneReport(**report.params)

            elif isinstance(report, reports.ClampReport):
                mod = mods.ClampReport(**report.params)

            elif isinstance(report, reports.ECPReport):
                mod = mods.EcpMod(**report.params)
                # Set up the ability for ecp on all relevant cells
                # TODO: According to spec we need to allow a different subset other than only biophysical cells
                for cell in network.cell_type_maps('biophysical').values():
                    cell.setup_ecp()

            elif report.module == 'save_synapses':
                mod = mods.SaveSynapses(**report.params)

            else:
                # TODO: Allow users to register customized modules using pymodules
                io.log_warning('Unrecognized module {}, skipping.'.format(report.module))
                continue

            sim.add_mod(mod)

        return sim
Example #28
0
from bmtk.simulator import bionet
from bmtk.utils.reports.spike_trains import SpikeTrains

config_file = 'config.json'

# Load (with validation) the simulation config and build the network it describes.
conf = bionet.Config.from_json(config_file, validate=True)
#conf.build_env()

graph = bionet.BioNetwork.from_config(conf)
graph.build()

# Group the 'v1' node table by pop_name. The values are the group index labels
# (presumably node ids -- confirm against node_properties()).
node_props = graph.node_properties()
v1_by_pop = node_props['v1'].groupby('pop_name').groups
node_ids = dict((pop_name, ids.tolist()) for pop_name, ids in v1_by_pop.items())
print(node_ids)

st = SpikeTrains.load('output/spikes.h5')
Example #29
0
class SpikesMod(object):
    """Module used for saving spikes.

    A spike detector labelled ``<tmp_dir>/tmp_spike_times`` is connected to every gid.
    On finalize, rank 0 reads the detector's output files into a SpikeTrains buffer.
    The buffer is then exported to whichever of CSV / SONATA / NWB were requested,
    and the temporary detector files are deleted.
    """
    def __init__(self,
                 tmp_dir,
                 spikes_file_csv=None,
                 spikes_file=None,
                 spikes_file_nwb=None,
                 spikes_sort_order=None,
                 cache_to_disk=True):
        """
        :param tmp_dir: working/output directory; relative output names are placed inside it.
        :param spikes_file_csv: optional CSV output file name.
        :param spikes_file: optional SONATA (HDF5) output file name.
        :param spikes_file_nwb: optional NWB output file name.
        :param spikes_sort_order: optional key into sort_order_lu; when falsy no sorting is applied.
        :param cache_to_disk: forwarded to SpikeTrains.
        """
        self._csv_fname = self._resolve_output_path(tmp_dir, spikes_file_csv)
        self._h5_fname = self._resolve_output_path(tmp_dir, spikes_file)
        self._nwb_fname = self._resolve_output_path(tmp_dir, spikes_file_nwb)

        self._tmp_dir = tmp_dir
        self._tmp_file_base = 'tmp_spike_times'
        self._spike_labels = os.path.join(self._tmp_dir, self._tmp_file_base)

        self._spike_writer = SpikeTrains(cache_dir=tmp_dir,
                                         cache_to_disk=cache_to_disk)
        self._spike_writer.delimiter = '\t'
        self._spike_writer.gid_col = 0
        self._spike_writer.time_col = 1
        self._sort_order = sort_order.none if not spikes_sort_order else sort_order_lu[
            spikes_sort_order]

        self._spike_detector = None

    @staticmethod
    def _resolve_output_path(tmp_dir, file_name):
        """Return where an output file should be written.

        None stays None. An absolute path is used unchanged. A relative path is placed
        inside tmp_dir unless it already points inside tmp_dir.
        """
        if file_name is None:
            return None

        if os.path.isabs(file_name):
            return file_name

        abs_tmp = os.path.abspath(tmp_dir)
        abs_fname = os.path.abspath(file_name)
        # Compare whole path components. A bare startswith() would treat "output2/x.h5"
        # as being inside "output".
        if abs_fname == abs_tmp or abs_fname.startswith(abs_tmp.rstrip(os.sep) + os.sep):
            return file_name
        return os.path.join(tmp_dir, file_name)

    def initialize(self, sim):
        """Create the spike detector and connect every gid of the network's gid map to it."""
        self._spike_detector = create_spike_detector(self._spike_labels)
        nest.Connect(sim.net.gid_map.gids, self._spike_detector)

    def finalize(self, sim):
        """Collect detector output into the writer, export requested formats, then clean up."""
        # convert NEST gdf files into SONATA spikes/ format
        # TODO: Create a gdf_adaptor in bmtk/utils/reports/spike_trains to improve conversion speed.
        if MPI_RANK == 0:
            gid_map = sim.net.gid_map
            read_spikes_file(spike_trains_writer=self._spike_writer,
                             gid_map=gid_map,
                             label=self._spike_labels)
        io.barrier()

        if self._csv_fname is not None:
            self._spike_writer.to_csv(self._csv_fname,
                                      sort_order=self._sort_order)
            # io.barrier()

        if self._h5_fname is not None:
            # TODO: reimplement with pandas
            self._spike_writer.to_sonata(self._h5_fname,
                                         sort_order=self._sort_order)
            # io.barrier()

        if self._nwb_fname is not None:
            self._spike_writer.to_nwb(self._nwb_fname,
                                      sort_order=self._sort_order)
            # io.barrier()

        self._spike_writer.close()
        self._clean_files()

    def _clean_files(self):
        """On rank 0, delete the temporary detector files matching the spike label prefix."""
        if MPI_RANK == 0:
            for nest_file in glob.glob(self._spike_labels + '*.' +
                                       NEST_spikes_file_format):
                os.remove(nest_file)
Example #30
0
    def from_config(cls, config, network, set_recordings=True):
        """Build ``network`` and return a simulator configured from ``config``.

        Builds nodes and recurrent edges, constructs the simulator, attaches every configured
        input and adds one output module per report.

        :param config: simulation config; dt, tstop, v_init, celsius and block_step are read as attributes.
        :param network: network to build and simulate.
        :param set_recordings: not used by this method; kept for interface compatibility.
        :return: the configured simulator (an instance of ``cls``).
        """
        # TODO: convert from json to sonata config if necessary

        def _as_list(value, cast_items=True):
            """Normalise a clamp parameter to a list.

            A scalar (anything without len()) becomes a one-element list of float. A sequence is
            returned unchanged, except that with cast_items set and more than one element every
            item is converted to float.
            """
            try:
                len(value)
            except TypeError:
                return [float(value)]
            if cast_items and len(value) > 1:
                return [float(item) for item in value]
            return value

        #The network must be built before initializing the simulator because
        #gap junctions must be set up before the simulation is initialized.
        network.io.log_info('Building cells.')
        network.build_nodes()

        network.io.log_info('Building recurrent connections')
        network.build_recurrent_edges()

        sim = cls(network=network,
                  dt=config.dt,
                  tstop=config.tstop,
                  v_init=config.v_init,
                  celsius=config.celsius,
                  nsteps_block=config.block_step)

        # TODO: Need to create a gid selector
        for sim_input in inputs.from_config(config):
            try:
                node_set = network.get_node_set(sim_input.node_set)
            except Exception:
                # Give the user a hint, then propagate the original error.
                print(
                    "Parameter node_set must be given in inputs module of simulation_config file. If unsure of what node_set should be, set it to 'all'."
                )
                raise

            if sim_input.input_type == 'spikes':
                io.log_info('Building virtual cell stimulations for {}'.format(
                    sim_input.name))
                path = sim_input.params['input_file']
                spikes = SpikeTrains.load(path=path,
                                          file_type=sim_input.module,
                                          **sim_input.params)
                network.add_spike_trains(spikes, node_set)

            elif sim_input.module == "FileIClamp":
                sim.attach_file_current_clamp(sim_input.params["input_file"])

            elif sim_input.module == 'IClamp':
                # TODO: Parse from csv file
                params = sim_input.params
                params['amp'] = _as_list(params['amp'])
                params['delay'] = _as_list(params['delay'])
                params['duration'] = _as_list(params['duration'])

                amplitude = params['amp']
                delay = params['delay']
                duration = params['duration']

                gids = params.get('gids')
                if gids is None:
                    gids = list(node_set.gids())

                sim.attach_current_clamp(amplitude, delay, duration, gids)

            elif sim_input.module == "SEClamp":
                params = sim_input.params
                # amps/durations are only wrapped when scalar; list items are not cast to float.
                params['amps'] = _as_list(params['amps'], cast_items=False)
                params['durations'] = _as_list(params['durations'], cast_items=False)

                amplitudes = params['amps']
                durations = params['durations']
                rs = None
                if 'rs' in params:
                    params['rs'] = _as_list(params['rs'])
                    rs = params['rs']

                gids = params.get('gids')
                if gids is None:
                    gids = list(node_set.gids())

                sim.attach_se_voltage_clamp(amplitudes, durations, gids, rs)

            elif sim_input.module == 'xstim':
                sim.add_mod(mods.XStimMod(**sim_input.params))

            else:
                io.log_exception('Can not parse input format {}'.format(
                    sim_input.name))

        # Parse the "reports" section of the config and load an associated output module for each report
        sim_reports = reports.from_config(config)
        for report in sim_reports:
            if isinstance(report, reports.SpikesReport):
                mod = mods.SpikesMod(**report.params)

            elif report.module == 'netcon_report':
                mod = mods.NetconReport(**report.params)

            elif isinstance(report, reports.MembraneReport):
                if report.params['sections'] == 'soma':
                    mod = mods.SomaReport(**report.params)
                else:
                    mod = mods.MembraneReport(**report.params)

            elif isinstance(report, reports.ClampReport):
                mod = mods.ClampReport(**report.params)

            elif isinstance(report, reports.ECPReport):
                mod = mods.EcpMod(**report.params)
                # Set up the ability for ecp on all relevant cells
                # TODO: According to spec we need to allow a different subset other than only biophysical cells
                for cell in network.cell_type_maps('biophysical').values():
                    cell.setup_ecp()

            elif report.module == 'save_synapses':
                mod = mods.SaveSynapses(**report.params)

            else:
                # TODO: Allow users to register customized modules using pymodules
                io.log_warning('Unrecognized module {}, skipping.'.format(
                    report.module))
                continue

            sim.add_mod(mod)

        return sim