def save_data(sim_type, output_dir):
    """Write a reduced copy of a simulation's outputs to an expected-results file.

    The file ``expected/sim_output_<sim_type>.h5`` receives:

    * ``/spikes/gids`` and ``/spikes/timestamps`` from ``<output_dir>/spikes.h5``
      (``gids`` is tagged with ``sorting='time'``);
    * ``/soma``: a mapping (time window ``[2500.0, 3000.0, 0.1]`` and the single
      gid 3) plus, for every variable in ``<output_dir>/soma_vars.h5``, gid 3's
      trace restricted to that time window;
    * ``/ecp/data`` (first column of ``<output_dir>/ecp.h5:data``) and
      ``/ecp/channel_id``.

    :param sim_type: suffix used to build the output file name
    :param output_dir: directory containing the simulation output files
    """
    save_file = 'expected/sim_output_{}.h5'.format(sim_type)
    soma_gid = 3
    t_window = (2500.0, 3000.0)

    # Context managers guarantee both HDF5 files are flushed/closed even if a
    # copy step fails part-way through.
    with h5py.File(save_file, 'w') as sample_data:
        # spikes data
        input_spikes = SpikesFile(os.path.join(output_dir, 'spikes.h5'))
        spikes_df = input_spikes.to_dataframe()
        sample_data.create_dataset('/spikes/gids', data=np.array(spikes_df['gids']))
        sample_data.create_dataset('/spikes/timestamps', data=np.array(spikes_df['timestamps']))
        sample_data['/spikes/gids'].attrs['sorting'] = 'time'

        # soma data
        soma_reports = CellVarsFile(os.path.join(output_dir, 'soma_vars.h5'))
        soma_grp = sample_data.create_group('/soma')
        soma_grp.create_dataset('mapping/time', data=[t_window[0], t_window[1], 0.1])
        soma_grp.create_dataset('mapping/gids', data=[soma_gid])
        for var_name in soma_reports.variables:
            # Pass everything by keyword: elsewhere data() takes gid as its
            # first positional argument, so a positional var_name would
            # collide with gid=.
            soma_grp.create_dataset(
                var_name,
                data=soma_reports.data(gid=soma_gid, var_name=var_name, time_window=t_window))

        # extracellular potential (ecp) report
        with h5py.File(os.path.join(output_dir, 'ecp.h5'), 'r') as ecp_file:
            ecp_grp = sample_data.create_group('/ecp')
            ecp_grp.create_dataset('data', data=ecp_file['data'][:, 0])
            ecp_grp.create_dataset('channel_id', data=ecp_file['channel_id'])
def plot_report(config_file=None, report_file=None, report_name=None, variables=None, gids=None):
    """Plot cell-variable traces from a cell report against time.

    When ``report_file`` is not given it is resolved from ``config_file`` /
    ``report_name`` through ``_get_cell_report``. With several variables, each
    one gets its own stacked subplot. With a single variable, a single figure
    is drawn. With none, nothing is plotted.

    :param config_file: simulation config used to locate the report when
        ``report_file`` is None
    :param report_file: path to the cell report file
    :param report_name: name of the report inside the config
    :param variables: variable name(s) to plot; defaults to all report variables
    :param gids: gid(s) to plot; defaults to all gids in the report
    """
    if report_file is None:
        report_name, report_file = _get_cell_report(config_file, report_name)

    var_report = CellVarsFile(report_file)
    variables = listify(variables) if variables is not None else var_report.variables
    gids = listify(gids) if gids is not None else var_report.gids
    time_steps = var_report.time_trace

    def __units_str(var):
        # Fall back to the module-level ``missing_units`` mapping when the
        # report does not record units for this variable.
        units = var_report.units(var)
        if units == CellVarsFile.UNITS_UNKNOWN:
            units = missing_units.get(var, '')
        return '({})'.format(units) if units else ''

    n_plots = len(variables)
    if n_plots > 1:
        # If more than one variable to plot do so in different subplots
        f, axarr = plt.subplots(n_plots, 1)
        for i, var in enumerate(variables):
            for gid in gids:
                axarr[i].plot(time_steps, var_report.data(gid=gid, var_name=var), label='gid {}'.format(gid))

            axarr[i].legend()
            axarr[i].set_ylabel('{} {}'.format(var, __units_str(var)))
            if i < n_plots - 1:
                # Only the bottom subplot keeps its x tick labels.
                axarr[i].set_xticklabels([])

        axarr[-1].set_xlabel('time (ms)')

    elif n_plots == 1:
        # For plotting a single variable
        plt.figure()
        for gid in gids:
            # Plot each requested gid (previously every trace was read for gid 0).
            plt.plot(time_steps, var_report.data(gid=gid, var_name=variables[0]), label='gid {}'.format(gid))
        plt.ylabel('{} {}'.format(variables[0], __units_str(variables[0])))
        plt.xlabel('time (ms)')
        plt.legend()

    else:
        return

    plt.show()
def get_variable_report(config_file=None, report_file=None, report_name=None, variable=None, gid=None):
    """Return one cell's trace for a single variable, plus the report's time axis.

    If ``report_file`` is omitted it is looked up via ``_get_cell_report``
    from ``config_file`` and ``report_name``.

    :param config_file: simulation config used to locate the report
    :param report_file: path to the cell report file (overrides the lookup)
    :param report_name: name of the report inside the config
    :param variable: name of the variable to fetch
    :param gid: cell whose trace is returned
    :return: tuple ``(trace, time_trace)``
    """
    if report_file is None:
        _, report_file = _get_cell_report(config_file, report_name)

    report = CellVarsFile(report_file)
    times = report.time_trace
    trace = report.data(gid=gid, var_name=variable)
    return trace, times
def plot_vars(file_names, cell_var='v', gid_list=None, t_min=None, t_max=None):
    """Plots variable traces for one or more SONATA cell-var h5 files.

    If multiple files are specified, their traces are overlaid on the same
    subplot for each gid, giving a side-by-side comparison.

    :param file_names: a single file name, or a list/tuple of cell_var file names
    :param cell_var: cell variable to plot trace
    :param gid_list: gids to show, one subplot per gid (if None or empty, plot
        every gid of the first file)
    :param t_min: lower x-axis limit; defaults to the first time of the first file
    :param t_max: upper x-axis limit; defaults to the last time of the first file
    """
    # convert to list if a single file name was passed in
    file_names = [file_names] if not isinstance(file_names, (tuple, list)) else file_names
    assert(len(file_names) > 0)

    # Use bmtk to parse the cell-var files; fall back to the default h5 root
    # when the file has no /report/internal group.
    cell_var_files = []
    for fn in file_names:
        try:
            cvf = CellVarsFile(fn, h5_root="/report/internal")
        except KeyError:
            cvf = CellVarsFile(fn)
        cell_var_files.append(cvf)

    # The first file defines the default time range and gid set.
    cvf_base = cell_var_files[0]
    base_times = cvf_base.time_trace
    # Use the same x-axis across all subplots; `is None` so that 0 is honoured.
    xlim = [base_times[0] if t_min is None else t_min,
            base_times[-1] if t_max is None else t_max]

    if gid_list is None or len(gid_list) == 0:
        gid_list = cvf_base.gids

    # One subplot per requested gid (not per gid in the file).
    n_cells = len(gid_list)
    # squeeze=False keeps a 2-D array even for a single subplot so indexing works.
    fig, ax = plt.subplots(n_cells, 1, figsize=(10, 10), squeeze=False)
    ax = ax[:, 0]

    for subplot, gid in enumerate(gid_list):
        axis = ax[subplot]
        for fn, cvf in zip(file_names, cell_var_files):
            # plot all traces
            axis.plot(cvf.time_trace, cvf.data(gid, cell_var), label=fn)

        axis.yaxis.set_label_position("right")
        axis.set_ylabel('gid {}'.format(gid), fontsize='xx-small')
        axis.set_xlim(xlim)
        if subplot + 1 < n_cells:
            # remove x-axis labels on all but the last plot
            axis.set_xticklabels([])
        else:
            # Use the last plot to get the legend
            handles, labels = axis.get_legend_handles_labels()
            fig.legend(handles, labels, loc='upper right')

    plt.show()
def _copy_cell_vars_report(sample_data, report, group_name):
    """Copy a cell-vars report's mapping and the last 5000 rows of every variable into a new group."""
    grp = sample_data.create_group(group_name)
    grp.create_dataset('mapping/time', data=[2500.0, 3000.0, 0.1])
    for key in ('gids', 'element_id', 'element_pos', 'index_pointer'):
        ds_name = 'mapping/{}'.format(key)
        grp.create_dataset(ds_name, data=report.h5[ds_name])

    for var_name in report.variables:
        ds_name = '{}/data'.format(var_name)
        # Keep only the trailing 5000 rows of each variable's data.
        grp.create_dataset(ds_name, data=report.h5[ds_name][-5000:, :])


def save_data(sim_type, conn_type, output_dir):
    """Saves the expected results.

    Writes ``expected/sim_output_<sim_type>.h5`` containing:

    * root attributes recording the bmtk, Python and NEURON versions, the date,
      the OS and the machine architecture;
    * ``/spikes/gids`` and ``/spikes/timestamps`` from ``<output_dir>/spikes.h5``;
    * ``/soma`` and ``/compartmental`` groups copied from ``soma_vars.h5`` and
      ``full_cell_vars.h5`` (mapping datasets plus the last 5000 rows of every
      variable);
    * ``/ecp/data`` (first column of ``ecp.h5:data``) and ``/ecp/channel_id``.

    :param sim_type: suffix used to build the output file name
    :param conn_type: not used by this function
    :param output_dir: directory containing the simulation output files
    """
    from bmtk import __version__ as bmtk_version
    from neuron import h
    import platform

    save_file = 'expected/sim_output_{}.h5'.format(sim_type)
    # Context managers guarantee both HDF5 files are closed even on failure.
    with h5py.File(save_file, 'w') as sample_data:
        root_group = sample_data['/']
        root_group.attrs['bmtk'] = bmtk_version
        root_group.attrs['date'] = str(datetime.datetime.now())
        root_group.attrs['python'] = '{}.{}'.format(*sys.version_info[0:2])
        root_group.attrs['NEURON'] = h.nrnversion()
        root_group.attrs['system'] = platform.system()
        root_group.attrs['arch'] = platform.machine()

        # spikes data
        input_spikes = SpikesFile(os.path.join(output_dir, 'spikes.h5'))
        spikes_df = input_spikes.to_dataframe()
        sample_data.create_dataset('/spikes/gids', data=np.array(spikes_df['gids']))
        sample_data.create_dataset('/spikes/timestamps', data=np.array(spikes_df['timestamps']))
        sample_data['/spikes/gids'].attrs['sorting'] = 'time'

        # soma data
        soma_report = CellVarsFile(os.path.join(output_dir, 'soma_vars.h5'))
        _copy_cell_vars_report(sample_data, soma_report, '/soma')

        # compartmental report
        compart_report = CellVarsFile(os.path.join(output_dir, 'full_cell_vars.h5'))
        _copy_cell_vars_report(sample_data, compart_report, '/compartmental')

        # ecp data
        with h5py.File(os.path.join(output_dir, 'ecp.h5'), 'r') as ecp_file:
            ecp_grp = sample_data.create_group('/ecp')
            ecp_grp.create_dataset('data', data=ecp_file['data'][:, 0])
            ecp_grp.create_dataset('channel_id', data=ecp_file['channel_id'])
def test_bionet(input_type='virt', conn_type='nsyns', capture_output=True, tol=1e-05):
    """Run a BioNet simulation and, on MPI rank 0, compare its outputs to the expected results.

    Compared outputs: spikes, soma reports, compartmental reports and the
    first ecp channel.

    :param input_type: input configuration passed to ``get_inputs`` and
        ``get_expected_results``
    :param conn_type: connection type passed to ``get_network_path`` and
        ``get_expected_results``
    :param capture_output: write outputs to ``output`` when True, otherwise to a
        fresh temporary directory
    :param tol: relative tolerance (``rtol``) used by ``np.allclose``
    """
    if MPI_rank == 0:
        print('Testing BioNet with {} inputs and {} synaptic connections (nodes: {})'
              .format(input_type, conn_type, MPI_size))

    output_dir = 'output' if capture_output else tempfile.mkdtemp()
    with open('config_base.json') as config_fh:
        config_base = json.load(config_fh)
    config_base['manifest']['$OUTPUT_DIR'] = output_dir
    config_base['inputs'] = get_inputs(input_type)
    config_base['manifest']['$NETWORK_DIR'] = os.path.join('$BASE_DIR', get_network_path(conn_type))

    conf = bionet.Config.from_dict(config_base, validate=True)
    conf.build_env()
    net = bionet.BioNetwork.from_config(conf)
    sim = bionet.BioSimulator.from_config(conf, net)
    sim.run()
    barrier()

    if MPI_rank == 0:
        print('Verifying output.')
        expected_file = get_expected_results(input_type, conn_type)

        # Check spikes file
        assert (SpikesFile(os.path.join(output_dir, 'spikes.h5')) == SpikesFile(expected_file))

        # soma reports
        soma_report_expected = CellVarsFile(expected_file, h5_root='/soma')
        soma_reports = CellVarsFile(os.path.join(output_dir, 'soma_vars.h5'))
        t_window = soma_report_expected.t_start, soma_report_expected.t_stop
        assert (soma_report_expected.dt == soma_reports.dt)
        assert (soma_report_expected.gids == soma_reports.gids)
        assert (soma_report_expected.variables == soma_reports.variables)
        for gid in soma_report_expected.gids:
            assert (soma_reports.compartment_ids(gid) == soma_report_expected.compartment_ids(gid)).all()
            for var in soma_report_expected.variables:
                assert (np.allclose(soma_reports.data(gid=gid, var_name=var, time_window=t_window),
                                    soma_report_expected.data(gid=gid, var_name=var), rtol=tol))

        # Compartmental reports
        compart_report_exp = CellVarsFile(expected_file, h5_root='/compartmental')
        compart_report = CellVarsFile(os.path.join(output_dir, 'full_cell_vars.h5'))
        t_window = compart_report_exp.t_start, compart_report_exp.t_stop
        assert (compart_report_exp.dt == compart_report.dt)
        assert (compart_report_exp.variables == compart_report.variables)
        for gid in compart_report_exp.gids:
            assert ((compart_report.compartment_ids(gid) == compart_report_exp.compartment_ids(gid)).all())
            for var in compart_report_exp.variables:
                assert (np.allclose(compart_report.data(gid, var_name=var, time_window=t_window, compartments='all'),
                                    compart_report_exp.data(gid, var_name=var, compartments='all'), rtol=tol))

        # ecp: open both files in context managers so the handles are released.
        with h5py.File(os.path.join(output_dir, 'ecp.h5'), 'r') as ecp_report, \
                h5py.File(expected_file, 'r') as ecp_report_exp:
            ecp_grp_exp = ecp_report_exp['/ecp']
            assert (np.allclose(np.array(ecp_report['/data'][:, 0]), np.array(ecp_grp_exp['data']), rtol=tol))

        print('Success!')

    barrier()
    bionet.nrn.quit_execution()
# Number of cells in each population.
numIND = 10
numHypo = 10
numINmplus = 10
numINmminus = 10
numPGN = 10
numFB = 10
numIMG = 10
numMPG = 10
numEUSmn = 10
numBladmn = 10

config_file = "simulation_config.json"
report_name = None
report_name, report_file = _get_cell_report(config_file, report_name)
var_report = CellVarsFile(report_file)
time_steps = var_report.time_trace

# Plot spike raster ---------------------------------------
# Each population owns a contiguous block of gids. A block starts where the
# previous one ends and is as long as that population's own size. Shifting the
# previous block (e.g. ``EUS_gids = Blad_gids + numBladaff``) would give every
# population ``numBladaff`` gids, which is only right when all sizes are equal.
# NOTE(review): numBladaff, numEUSaff and numPAGaff are not defined in this
# section -- confirm they are set before this point.
_pop_sizes = [numBladaff, numEUSaff, numPAGaff, numIND, numHypo, numINmplus,
              numINmminus, numPGN, numFB, numIMG, numMPG, numEUSmn]
_pop_starts = np.cumsum([0] + _pop_sizes[:-1])
(Blad_gids, EUS_gids, PAG_gids, IND_gids, Hypo_gids, INmplus_gids,
 INmminus_gids, PGN_gids, FB_gids, IMG_gids, MPG_gids, EUSmn_gids) = [
    np.arange(start, start + size)
    for start, size in zip(_pop_starts, _pop_sizes)
]