Ejemplo n.º 1
0
def save_data(sim_type, output_dir):
    """Write a reduced copy of a simulation's outputs as an expected-results file.

    Reads ``spikes.h5``, ``soma_vars.h5`` and ``ecp.h5`` from *output_dir* and
    stores a subset of each in ``expected/sim_output_<sim_type>.h5``.

    :param sim_type: label used to build the output file name.
    :param output_dir: directory containing the simulation output files.
    """
    save_file = 'expected/sim_output_{}.h5'.format(sim_type)

    # Context managers guarantee the output file is flushed and closed (and
    # the ECP input released) even if one of the copy steps raises.
    with h5py.File(save_file, 'w') as sample_data:
        # spikes data
        input_spikes = SpikesFile(os.path.join(output_dir, 'spikes.h5'))
        spikes_df = input_spikes.to_dataframe()
        sample_data.create_dataset('/spikes/gids',
                                   data=np.array(spikes_df['gids']))
        sample_data.create_dataset('/spikes/timestamps',
                                   data=np.array(spikes_df['timestamps']))
        sample_data['/spikes/gids'].attrs['sorting'] = 'time'

        # soma data: only gid 3 over the 2500-3000 ms window is kept
        soma_reports = CellVarsFile(os.path.join(output_dir, 'soma_vars.h5'))
        soma_grp = sample_data.create_group('/soma')
        # presumably [t_start, t_stop, dt] of the stored window -- confirm
        # against the CellVarsFile reader
        soma_grp.create_dataset('mapping/time', data=[2500.0, 3000.0, 0.1])
        soma_grp.create_dataset('mapping/gids', data=[3])
        for var_name in soma_reports.variables:
            soma_grp.create_dataset(var_name,
                                    data=soma_reports.data(var_name,
                                                           gid=3,
                                                           time_window=(2500.0,
                                                                        3000.0)))

        # ecp data: only the first column of 'data' is stored
        with h5py.File(os.path.join(output_dir, 'ecp.h5'), 'r') as ecp_file:
            ecp_grp = sample_data.create_group('/ecp')
            ecp_grp.create_dataset('data', data=ecp_file['data'][:, 0])
            ecp_grp.create_dataset('channel_id', data=ecp_file['channel_id'])
Ejemplo n.º 2
0
def plot_report(config_file=None,
                report_file=None,
                report_name=None,
                variables=None,
                gids=None):
    """Plot cell-variable traces from a cell-vars report with matplotlib.

    :param config_file: simulation config used to locate the report when
        *report_file* is not given.
    :param report_file: path to the cell-vars report; overrides *config_file*.
    :param report_name: name of the report to look up in *config_file*.
    :param variables: variable name or list of names to plot (default: all
        variables in the report).
    :param gids: gid or list of gids to plot (default: all gids in the report).

    One subplot per variable is drawn when several variables are requested;
    a single figure is used for one variable. Nothing is drawn when the
    variable list is empty.
    """
    if report_file is None:
        report_name, report_file = _get_cell_report(config_file, report_name)

    var_report = CellVarsFile(report_file)
    variables = listify(
        variables) if variables is not None else var_report.variables
    gids = listify(gids) if gids is not None else var_report.gids
    time_steps = var_report.time_trace

    def _units_str(var):
        # Fall back to the module-level missing_units table when the report
        # does not record units for this variable.
        units = var_report.units(var)
        if units == CellVarsFile.UNITS_UNKNOWN:
            units = missing_units.get(var, '')
        return '({})'.format(units) if units else ''

    n_plots = len(variables)
    if n_plots > 1:
        # If more than one variable to plot do so in different subplots
        f, axarr = plt.subplots(n_plots, 1)
        for i, var in enumerate(variables):
            for gid in gids:
                axarr[i].plot(time_steps,
                              var_report.data(gid=gid, var_name=var),
                              label='gid {}'.format(gid))

            axarr[i].legend()
            axarr[i].set_ylabel('{} {}'.format(var, _units_str(var)))
            if i < n_plots - 1:
                # only the bottom subplot keeps its x tick labels
                axarr[i].set_xticklabels([])

        axarr[i].set_xlabel('time (ms)')

    elif n_plots == 1:
        # For plotting a single variable
        plt.figure()
        for gid in gids:
            # Plot each requested gid (previously gid 0 was plotted for all).
            plt.plot(time_steps,
                     var_report.data(gid=gid, var_name=variables[0]),
                     label='gid {}'.format(gid))
        plt.ylabel('{} {}'.format(variables[0], _units_str(variables[0])))
        plt.xlabel('time (ms)')
        # Show the per-gid labels, matching the multi-variable branch.
        plt.legend()

    else:
        return

    plt.show()

    #for gid in gids:
    #    plt.plot(times, var_report.data(gid=0, var_name='v'), label='gid {}'.format(gid))
    '''
Ejemplo n.º 3
0
def get_variable_report(config_file=None, report_file=None, report_name=None, variable=None, gid=None):
    """Load a single cell variable trace together with its time axis.

    The report is taken from *report_file* when given; otherwise it is
    resolved from *config_file* and *report_name* via ``_get_cell_report``.

    :return: tuple ``(data, time_steps)``. ``data`` is what
        ``CellVarsFile.data`` returns for (*gid*, *variable*). ``time_steps``
        is the report's ``time_trace``.
    """
    if report_file is None:
        _, report_file = _get_cell_report(config_file, report_name)

    cell_vars = CellVarsFile(report_file)
    trace = cell_vars.time_trace
    values = cell_vars.data(gid=gid, var_name=variable)
    return values, trace
Ejemplo n.º 4
0
def plot_vars(file_names, cell_var='v', gid_list=None, t_min=None, t_max=None):
    """Plot variable traces from one or more SONATA cell-var h5 files.

    When several files are given, their traces for each gid are overlaid on
    the same subplot for a side-by-side comparison.

    :param file_names: a cell-var file name or a list/tuple of them.
    :param cell_var: cell variable to plot.
    :param gid_list: gids to plot, one subplot each. If None or empty, all
        gids of the first file are plotted.
    :param t_min: left x-axis limit (default: first time of the first file).
    :param t_max: right x-axis limit (default: last time of the first file).
    :raises ValueError: if no file names are given.
    """
    # convert to list if a single file name is passed in
    file_names = [file_names] if not isinstance(file_names, (tuple, list)) else file_names
    if len(file_names) == 0:
        raise ValueError('plot_vars requires at least one cell-var file')

    # Use bmtk to parse the cell-var files; fall back to the default h5 root
    # when the "/report/internal" group is missing.
    cell_var_files = []
    for fn in file_names:
        try:
            cell_var_files.append(CellVarsFile(fn, h5_root="/report/internal"))
        except KeyError:
            cell_var_files.append(CellVarsFile(fn))

    # The first file defines the default time window and gid set.
    cvf_base = cell_var_files[0]
    time_trace = cvf_base.time_trace
    # Use the same x-axis across all subplots; explicit 0 limits are honored.
    xlim = [t_min if t_min is not None else time_trace[0],
            t_max if t_max is not None else time_trace[-1]]
    if gid_list is None or len(gid_list) == 0:
        gid_list = cvf_base.gids
    # One subplot per gid actually plotted.
    n_cells = len(gid_list)

    # squeeze=False keeps a 2-D array even when n_cells == 1.
    fig, ax_grid = plt.subplots(n_cells, 1, figsize=(10, 10), squeeze=False)
    axes = ax_grid[:, 0]
    for subplot, gid in enumerate(gid_list):
        ax = axes[subplot]
        for i, cvf in enumerate(cell_var_files):
            # plot this gid's trace from every file
            ax.plot(cvf.time_trace, cvf.data(gid, cell_var), label=file_names[i])

        ax.yaxis.set_label_position("right")
        ax.set_ylabel('gid {}'.format(gid), fontsize='xx-small')
        ax.set_xlim(xlim)
        if subplot + 1 < n_cells:
            # remove x-axis labels on all but the last plot
            ax.set_xticklabels([])
        else:
            # Use the last plot to get the legend
            handles, labels = ax.get_legend_handles_labels()
            fig.legend(handles, labels, loc='upper right')

    plt.show()
Ejemplo n.º 5
0
def save_data(sim_type, conn_type, output_dir):
    """Saves the expected results.

    Writes ``expected/sim_output_<sim_type>.h5`` with:

    * root attributes recording the bmtk, python and NEURON versions, the date
      and the platform;
    * the spikes from ``spikes.h5``;
    * trimmed copies of the soma (``soma_vars.h5``) and compartmental
      (``full_cell_vars.h5``) reports;
    * the first column of the ECP data from ``ecp.h5``.

    :param sim_type: label used to build the output file name.
    :param conn_type: accepted for API compatibility; not used here.
    :param output_dir: directory containing the simulation output files.
    """
    from bmtk import __version__ as bmtk_version
    from neuron import h
    import platform

    save_file = 'expected/sim_output_{}.h5'.format(sim_type)

    # Context managers guarantee the output file is flushed and closed (and
    # the ECP input released) even if one of the copy steps raises.
    with h5py.File(save_file, 'w') as sample_data:
        root_group = sample_data['/']
        root_group.attrs['bmtk'] = bmtk_version
        root_group.attrs['date'] = str(datetime.datetime.now())
        root_group.attrs['python'] = '{}.{}'.format(*sys.version_info[0:2])
        root_group.attrs['NEURON'] = h.nrnversion()
        root_group.attrs['system'] = platform.system()
        root_group.attrs['arch'] = platform.machine()

        # spikes data
        input_spikes = SpikesFile(os.path.join(output_dir, 'spikes.h5'))
        spikes_df = input_spikes.to_dataframe()
        sample_data.create_dataset('/spikes/gids',
                                   data=np.array(spikes_df['gids']))
        sample_data.create_dataset('/spikes/timestamps',
                                   data=np.array(spikes_df['timestamps']))
        sample_data['/spikes/gids'].attrs['sorting'] = 'time'

        # soma data
        _copy_cell_vars_report(sample_data, '/soma',
                               os.path.join(output_dir, 'soma_vars.h5'))

        # compartmental report
        _copy_cell_vars_report(sample_data, '/compartmental',
                               os.path.join(output_dir, 'full_cell_vars.h5'))

        # ecp data: only the first column of 'data' is stored
        with h5py.File(os.path.join(output_dir, 'ecp.h5'), 'r') as ecp_file:
            ecp_grp = sample_data.create_group('/ecp')
            ecp_grp.create_dataset('data', data=ecp_file['data'][:, 0])
            ecp_grp.create_dataset('channel_id', data=ecp_file['channel_id'])


def _copy_cell_vars_report(sample_data, group_name, report_path, n_steps=5000):
    """Copy a cell-vars report's mapping and last *n_steps* rows of data into a new group.

    :param sample_data: open, writable h5py file/group receiving the copy.
    :param group_name: name of the group created in *sample_data*.
    :param report_path: path of the cell-vars report to copy from.
    :param n_steps: number of trailing rows kept for each variable. 5000 equals
        (3000 - 2500) / 0.1, matching the hard-coded ``mapping/time``.
        This presumably assumes the simulation ends at 3000 ms -- confirm.
    """
    report = CellVarsFile(report_path)
    out_grp = sample_data.create_group(group_name)
    out_grp.create_dataset('mapping/time', data=[2500.0, 3000.0, 0.1])
    # The mapping tables are copied unchanged from the source report.
    for mapping_name in ('gids', 'element_id', 'element_pos', 'index_pointer'):
        ds_name = 'mapping/{}'.format(mapping_name)
        out_grp.create_dataset(ds_name, data=report.h5[ds_name])
    for var_name in report.variables:
        ds_name = '{}/data'.format(var_name)
        out_grp.create_dataset(ds_name, data=report.h5[ds_name][-n_steps:, :])
Ejemplo n.º 6
0
def test_bionet(input_type='virt',
                conn_type='nsyns',
                capture_output=True,
                tol=1e-05):
    """Run a BioNet simulation and check its outputs against expected results.

    After the run, rank 0 compares the outputs against the file returned by
    ``get_expected_results(input_type, conn_type)``. It checks the spikes,
    the soma report, the compartmental report and the ECP written to the
    output directory. Any mismatch raises ``AssertionError``. Every rank ends
    by calling ``bionet.nrn.quit_execution()``.

    :param input_type: input configuration passed to ``get_inputs``.
    :param conn_type: network variant passed to ``get_network_path``.
    :param capture_output: write into ``output`` when true, otherwise into a
        fresh temporary directory.
    :param tol: relative tolerance (``rtol``) for the ``np.allclose`` checks.
    """
    if MPI_rank == 0:
        print(
            'Testing BioNet with {} inputs and {} synaptic connections (nodes: {})'
            .format(input_type, conn_type, MPI_size))

    output_dir = 'output' if capture_output else tempfile.mkdtemp()
    # Close the config file promptly instead of leaking the handle.
    with open('config_base.json') as config_fh:
        config_base = json.load(config_fh)
    config_base['manifest']['$OUTPUT_DIR'] = output_dir
    config_base['inputs'] = get_inputs(input_type)
    config_base['manifest']['$NETWORK_DIR'] = os.path.join(
        '$BASE_DIR', get_network_path(conn_type))

    conf = bionet.Config.from_dict(config_base, validate=True)
    conf.build_env()

    net = bionet.BioNetwork.from_config(conf)
    sim = bionet.BioSimulator.from_config(conf, net)
    sim.run()
    barrier()

    if MPI_rank == 0:
        print('Verifying output.')
        expected_file = get_expected_results(input_type, conn_type)

        # Check spikes file
        assert (SpikesFile(os.path.join(
            output_dir, 'spikes.h5')) == SpikesFile(expected_file))

        # soma reports: the new output is sliced to the expected file's
        # [t_start, t_stop] window before comparing.
        soma_report_expected = CellVarsFile(expected_file, h5_root='/soma')
        soma_reports = CellVarsFile(os.path.join(output_dir, 'soma_vars.h5'))
        t_window = soma_report_expected.t_start, soma_report_expected.t_stop
        assert (soma_report_expected.dt == soma_reports.dt)
        assert (soma_report_expected.gids == soma_reports.gids)
        assert (soma_report_expected.variables == soma_reports.variables)
        for gid in soma_report_expected.gids:
            assert (soma_reports.compartment_ids(gid) ==
                    soma_report_expected.compartment_ids(gid)).all()
            for var in soma_report_expected.variables:
                assert np.allclose(
                    soma_reports.data(gid=gid,
                                      var_name=var,
                                      time_window=t_window),
                    soma_report_expected.data(gid=gid, var_name=var),
                    rtol=tol)

        # Compartmental reports
        compart_report_exp = CellVarsFile(expected_file,
                                          h5_root='/compartmental')
        compart_report = CellVarsFile(
            os.path.join(output_dir, 'full_cell_vars.h5'))
        t_window = compart_report_exp.t_start, compart_report_exp.t_stop
        assert (compart_report_exp.dt == compart_report.dt)
        assert (compart_report_exp.variables == compart_report.variables)
        for gid in compart_report_exp.gids:
            assert ((compart_report.compartment_ids(gid) ==
                     compart_report_exp.compartment_ids(gid)).all())
            for var in compart_report_exp.variables:
                assert np.allclose(
                    compart_report.data(gid,
                                        var_name=var,
                                        time_window=t_window,
                                        compartments='all'),
                    compart_report_exp.data(gid,
                                            var_name=var,
                                            compartments='all'),
                    rtol=tol)

        # ecp: only the first column of the new output is compared. Both
        # files are closed when the block exits.
        with h5py.File(os.path.join(output_dir, 'ecp.h5'), 'r') as ecp_report, \
                h5py.File(expected_file, 'r') as ecp_report_exp:
            ecp_grp_exp = ecp_report_exp['/ecp']
            assert np.allclose(np.array(ecp_report['/data'][:, 0]),
                               np.array(ecp_grp_exp['data']), rtol=tol)
        print('Success!')

    barrier()
    bionet.nrn.quit_execution()
Ejemplo n.º 7
0
# Number of cells per population. NOTE(review): numBladaff, numEUSaff and
# numPAGaff are used below but not defined in this excerpt -- presumably
# set earlier in the original script; verify.
numIND      = 10
numHypo     = 10
numINmplus  = 10
numINmminus = 10
numPGN      = 10
numFB       = 10
numIMG      = 10
numMPG      = 10
numEUSmn    = 10
numBladmn   = 10

config_file = "simulation_config.json"
report_name = None
report_name, report_file = _get_cell_report(config_file, report_name)

var_report = CellVarsFile(report_file)
time_steps = var_report.time_trace

# Plot spike raster ---------------------------------------
# Populations occupy consecutive, non-overlapping gid ranges in this order.
# Each range starts where the previous one ends and is as long as its own
# population. Previously every range was a shifted copy of Blad_gids, so it
# silently had numBladaff cells. The result is unchanged when all sizes are
# equal.
_pop_sizes = (numBladaff, numEUSaff, numPAGaff, numIND, numHypo, numINmplus,
              numINmminus, numPGN, numFB, numIMG, numMPG, numEUSmn)
_pop_starts = np.cumsum((0,) + _pop_sizes[:-1])
(Blad_gids, EUS_gids, PAG_gids, IND_gids, Hypo_gids, INmplus_gids,
 INmminus_gids, PGN_gids, FB_gids, IMG_gids, MPG_gids,
 EUSmn_gids) = [np.arange(start, start + size)
                for start, size in zip(_pop_starts, _pop_sizes)]