示例#1
0
    def _clone_and_simulate(self, net, trial_idx):
        """Build the network for a single trial and simulate it.

        The ``prng_*`` seed parameters of ``net`` are offset by
        ``trial_idx`` so that drive jitter differs between trials. The offset
        is always applied relative to the seeds ``net`` had on entry, and
        those seeds are restored before returning. Repeated calls on the same
        ``net`` object (e.g. trials executed one after another in the same
        process) therefore see ``seed + trial_idx`` rather than an
        accumulating offset.

        Parameters
        ----------
        net : Network
            Network to build and simulate. Its ``params`` are modified only
            for the duration of this call.
        trial_idx : int
            Trial index; used as the seed offset and forwarded to
            ``_simulate_single_trial``.

        Returns
        -------
        dpl
            Whatever ``_simulate_single_trial`` returns for this trial.
        spikedata
            Whatever ``NetworkBuilder.get_data_from_neuron`` returns.
        """
        # avoid relative lookups after being forked by joblib
        from hnn_core.network_builder import NetworkBuilder
        from hnn_core.network_builder import _simulate_single_trial

        # XXX this should be built into NetworkBuilder
        # Snapshot the seeds so the per-trial offset is relative to the
        # caller's values. An in-place ``+= trial_idx`` would compound across
        # trials whenever the same ``net`` object is reused.
        prng_seedcore_initial = net.params['prng_*'].copy()
        for param_key in prng_seedcore_initial.keys():
            net.params[param_key] = (prng_seedcore_initial[param_key] +
                                     trial_idx)

        try:
            neuron_net = NetworkBuilder(net)
            dpl = _simulate_single_trial(neuron_net, trial_idx)
            spikedata = neuron_net.get_data_from_neuron()
        finally:
            # leave the caller's network parameters as they were on entry
            for param_key in prng_seedcore_initial.keys():
                net.params[param_key] = prng_seedcore_initial[param_key]

        return dpl, spikedata
示例#2
0
def test_add_cell_type():
    """Test adding a new cell type."""
    params = read_params(params_fname)
    net = jones_2009_model(params)
    # NetworkBuilder needs the drive events to exist before it is constructed
    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])

    n_cells_before = net._n_cells
    new_positions = [(0, y_coord, 0) for y_coord in range(10)]
    gabaa_tau1 = 0.6

    # clone an existing cell template and register it under a new name
    template = net.cell_types['L2_basket'].copy()
    net._add_cell_type('new_type', pos=new_positions, cell_template=template)
    net.cell_types['new_type'].synapses['gabaa']['tau1'] = gabaa_tau1

    n_new_type = len(net.gid_ranges['new_type'])
    assert n_new_type == len(new_positions)
    net.add_connection('L2_basket', 'new_type', loc='proximal',
                       receptor='gabaa', weight=8e-3, delay=1, lamtha=2)

    builder = NetworkBuilder(net)
    assert net._n_cells == n_cells_before + len(new_positions)
    n_basket = len(net.gid_ranges['L2_basket'])
    netcons = builder.ncs['L2Basket_new_type_gabaa']
    assert len(netcons) == n_basket * n_new_type
    # the customised tau1 must reach the instantiated synapse
    assert netcons[0].syn().tau1 == gabaa_tau1
示例#3
0
def _clone_and_simulate(net, trial_idx):
    """Build the network for one trial, simulate it and collect the results.

    Shared by both backends: MPIBackend invokes it from mpi_child.py once per
    trial (blocking), while JoblibBackend dispatches one call per trial
    (non-blocking).

    Returns a ``(dpl, spikedata)`` tuple, where ``dpl`` comes from
    ``_simulate_single_trial`` and ``spikedata`` from
    ``NetworkBuilder.get_data_from_neuron``.
    """
    # import here so names resolve inside the worker after a Joblib fork
    from hnn_core.network_builder import (NetworkBuilder,
                                          _simulate_single_trial)

    builder = NetworkBuilder(net, trial_idx=trial_idx)
    dipole = _simulate_single_trial(builder, trial_idx)
    return dipole, builder.get_data_from_neuron()
示例#4
0
def test_gid_assignment():
    """Test that gids are assigned without overlap across ranks"""

    net = jones_2009_model(add_drives_from_params=False)
    weights_ampa = dict(L2_basket=1.0, L2_pyramidal=2.0, L5_pyramidal=3.0)
    syn_delays = dict(L2_basket=.1, L2_pyramidal=.2, L5_pyramidal=.3)

    net.add_bursty_drive('bursty_dist', location='distal', burst_rate=10,
                         weights_ampa=weights_ampa, synaptic_delays=syn_delays,
                         cell_specific=False, n_drive_cells=5)
    net.add_evoked_drive('evoked_prox', mu=1.0, sigma=1.0, numspikes=1,
                         weights_ampa=weights_ampa, location='proximal',
                         synaptic_delays=syn_delays, cell_specific=True,
                         n_drive_cells='n_cells')
    net._instantiate_drives(tstop=20, n_trials=2)

    # every gid known to the network, across cell types and drives
    expected_gids = sorted(gid for gid_range in net.gid_ranges.values()
                           for gid in gid_range)

    n_hosts = 3
    assigned_gids = list()
    for rank in range(n_hosts):
        builder = NetworkBuilder(net)
        # reset the list before assigning gids for this particular rank
        builder._gid_list = list()
        builder._gid_assign(rank=rank, n_hosts=n_hosts)
        assigned_gids += builder._gid_list
    assigned_gids.sort()
    # no gid may be claimed by more than one rank ...
    assert assigned_gids == sorted(set(assigned_gids))
    # ... and together the ranks must cover every gid
    assert expected_gids == assigned_gids
示例#5
0
def test_cell():
    """Test cells object."""
    param_dir = op.join(op.dirname(hnn_core.__file__), 'param')
    params = read_params(op.join(param_dir, 'default.json'))

    net = Network(params)
    with NetworkBuilder(net) as neuron_net:
        neuron_net.cells[0].plot_voltage()

    # test that ExpSyn always takes nrn.Segment, not float
    soma_props = dict(L=22.1, diam=23.4, cm=0.6195, Ra=200.0,
                      pos=(0., 0., 0.), name='test_cell')
    cell = _Cell(gid=1, soma_props=soma_props)
    with pytest.raises(TypeError, match='secloc must be instance of'):
        cell.syn_create(0.5, e=0., tau1=0.5, tau2=5.)
示例#6
0
def test_network():
    """Test network object."""
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    # add rhythmic inputs (i.e., a type of common input)
    params.update({
        'input_dist_A_weight_L2Pyr_ampa': 5.4e-5,
        'input_dist_A_weight_L5Pyr_ampa': 5.4e-5,
        't0_input_dist': 50,
        'input_prox_A_weight_L2Pyr_ampa': 5.4e-5,
        'input_prox_A_weight_L5Pyr_ampa': 5.4e-5,
        't0_input_prox': 50
    })
    net = Network(deepcopy(params), add_drives_from_params=True)
    network_builder = NetworkBuilder(net)  # needed to populate net.cells

    # Assert that params are conserved across Network initialization
    for p in params:
        assert params[p] == net._params[p]
    assert len(params) == len(net._params)
    print(network_builder)
    print(network_builder.cells[:2])

    # Assert that proper number of gids are created for Network drives
    dns_from_gids = [
        name for name in net.gid_ranges.keys() if name not in net.cellname_list
    ]
    assert len(dns_from_gids) == len(net.external_drives)
    for dn in dns_from_gids:
        assert dn in net.external_drives.keys()
        this_src_gids = set([
            gid for drive_conn in net.external_drives[dn]['conn'].values()
            for gid in drive_conn['src_gids']
        ])  # NB set: globals
        assert len(net.gid_ranges[dn]) == len(this_src_gids)
        assert len(net.external_drives[dn]['events']) == 1  # single trial!

    assert len(net.gid_ranges['bursty1']) == 1
    for drive in net.external_drives.values():
        assert len(drive['events']) == 1  # single trial simulated
        if drive['type'] == 'evoked':
            for kw in ['mu', 'sigma', 'numspikes']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == net.n_cells
            # this also implicitly tests that events are always a list
            assert len(drive['events'][0][0]) == drive['dynamics']['numspikes']
        elif drive['type'] == 'gaussian':
            for kw in ['mu', 'sigma', 'numspikes']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == net.n_cells
        elif drive['type'] == 'poisson':
            for kw in ['tstart', 'tstop', 'rate_constant']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == net.n_cells

        elif drive['type'] == 'bursty':
            for kw in [
                    'distribution', 'tstart', 'tstart_std', 'tstop',
                    'burst_rate', 'burst_std', 'numspikes', 'repeats'
            ]:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == 1
            n_events = (
                drive['dynamics']['numspikes'] *  # 2
                drive['dynamics']['repeats'] *  # 10
                (1 +
                 (drive['dynamics']['tstop'] - drive['dynamics']['tstart'] - 1)
                 // (1000. / drive['dynamics']['burst_rate'])))
            assert len(drive['events'][0][0]) == n_events  # 40

    # make sure the PRNGs are consistent.
    target_times = {
        'evdist1': [66.30498327062551, 61.54362532343694],
        'evprox1': [23.80641637082997, 30.857310915553647],
        'evprox2': [141.76252038319825, 137.73942375578602]
    }
    for drive_name in target_times:
        for idx in [0, -1]:  # first and last
            assert_allclose(net.external_drives[drive_name]['events'][0][idx],
                            target_times[drive_name][idx],
                            rtol=1e-12)

    # check select AMPA weights
    target_weights = {
        'evdist1': {
            'L2_basket': 0.006562,
            'L5_pyramidal': 0.142300
        },
        'evprox1': {
            'L2_basket': 0.08831,
            'L5_pyramidal': 0.00865
        },
        'evprox2': {
            'L2_basket': 0.000003,
            'L5_pyramidal': 0.684013
        }
    }
    for drive_name in target_weights:
        for cellname in target_weights[drive_name]:
            assert_allclose(net.external_drives[drive_name]['conn'][cellname]
                            ['ampa']['A_weight'],
                            target_weights[drive_name][cellname],
                            rtol=1e-12)

    # check select synaptic delays
    target_delays = {
        'evdist1': {
            'L2_basket': 0.1,
            'L5_pyramidal': 0.1
        },
        'evprox1': {
            'L2_basket': 0.1,
            'L5_pyramidal': 1.
        },
        'evprox2': {
            'L2_basket': 0.1,
            'L5_pyramidal': 1.
        }
    }
    for drive_name in target_delays:
        for cellname in target_delays[drive_name]:
            assert_allclose(net.external_drives[drive_name]['conn'][cellname]
                            ['ampa']['A_delay'],
                            target_delays[drive_name][cellname],
                            rtol=1e-12)

    # Assert that an empty CellResponse object is created as an attribute
    assert net.cell_response == CellResponse()
    # array of simulation times is created in Network.__init__, but passed
    # to CellResponse-constructor for storage (Network is agnostic of time)
    with pytest.raises(TypeError,
                       match="'times' is an np.ndarray of simulation times"):
        _ = CellResponse(times=[1, 2, 3])

    # Assert that all external drives are initialized
    n_evoked_sources = net.n_cells * 3
    n_pois_sources = net.n_cells
    n_gaus_sources = net.n_cells
    n_bursty_sources = 2

    # test that expected number of external driving events are created
    assert len(
        network_builder._drive_cells) == (n_evoked_sources + n_pois_sources +
                                          n_gaus_sources + n_bursty_sources)
    assert len(network_builder._gid_list) ==\
        len(network_builder._drive_cells) + net.n_cells
    # first 'evoked drive' comes after real cells and common inputs, i.e.
    # it is the drive cell right after the n_bursty_sources bursty cells
    assert network_builder._drive_cells[n_bursty_sources].gid ==\
        net.n_cells + n_bursty_sources

    # Assert that netcons are created properly
    # proximal
    assert 'L2Pyr_L2Pyr_nmda' in network_builder.ncs
    n_pyr = len(net.gid_ranges['L2_pyramidal'])
    n_connections = 3 * (n_pyr**2 - n_pyr)  # 3 synapses / cell
    assert len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) == n_connections
    nc = network_builder.ncs['L2Pyr_L2Pyr_nmda'][0]
    assert nc.threshold == params['threshold']

    # create a new connection between cell types
    net = Network(deepcopy(params), add_drives_from_params=True)
    nc_dict = {'A_delay': 1, 'A_weight': 1e-5, 'lamtha': 20, 'threshold': 0.5}
    net._all_to_all_connect('bursty1',
                            'L5_basket',
                            'soma',
                            'gabaa',
                            nc_dict,
                            unique=False)
    network_builder = NetworkBuilder(net)
    assert 'bursty1_L5Basket_gabaa' in network_builder.ncs
    n_conn = len(net.gid_ranges['bursty1']) * len(net.gid_ranges['L5_basket'])
    assert len(network_builder.ncs['bursty1_L5Basket_gabaa']) == n_conn

    # try unique=True
    net = Network(deepcopy(params), add_drives_from_params=True)
    net._all_to_all_connect('extgauss',
                            'L5_basket',
                            'soma',
                            'gabaa',
                            nc_dict,
                            unique=True)
    network_builder = NetworkBuilder(net)
    n_conn = len(net.gid_ranges['L5_basket'])
    assert len(network_builder.ncs['extgauss_L5Basket_gabaa']) == n_conn

    # Test inputs for connectivity API
    net = Network(deepcopy(params), add_drives_from_params=True)
    # NB: baseline count taken from the builder of the previous section;
    # presumably its L2Basket->L2Pyr connections match this fresh network
    # since both are built from the same params.
    n_conn = len(network_builder.ncs['L2Basket_L2Pyr_gabaa'])
    kwargs_default = dict(src_gids=[0, 1],
                          target_gids=[35, 36],
                          loc='soma',
                          receptor='gabaa',
                          weight=5e-4,
                          delay=1.0,
                          lamtha=1e9)
    net.add_connection(**kwargs_default)  # smoke test
    network_builder = NetworkBuilder(net)
    assert len(network_builder.ncs['L2Basket_L2Pyr_gabaa']) == n_conn + 4
    nc = network_builder.ncs['L2Basket_L2Pyr_gabaa'][-1]
    assert_allclose(nc.weight[0], kwargs_default['weight'])

    kwargs_good = [('src_gids', 0), ('src_gids', 'L2_pyramidal'),
                   ('src_gids', range(2)), ('target_gids', 35),
                   ('target_gids', range(2)), ('target_gids', 'L2_pyramidal'),
                   ('target_gids', [[35, 36], [37, 38]])]
    for arg, item in kwargs_good:
        kwargs = kwargs_default.copy()
        kwargs[arg] = item
        net.add_connection(**kwargs)

    kwargs_bad = [('src_gids', 0.0), ('src_gids', [0.0]),
                  ('target_gids', 35.0), ('target_gids', [35.0]),
                  ('target_gids', [[35], [36.0]]), ('loc', 1.0),
                  ('receptor', 1.0), ('weight', '1.0'), ('delay', '1.0'),
                  ('lamtha', '1.0')]
    for arg, item in kwargs_bad:
        match = ('must be an instance of')
        with pytest.raises(TypeError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            net.add_connection(**kwargs)

    kwargs_bad = [('src_gids', -1), ('src_gids', [-1]), ('target_gids', -1),
                  ('target_gids', [-1]), ('target_gids', [[35], [-1]]),
                  ('target_gids', [[35]]), ('src_gids', [0, 100]),
                  ('target_gids', [0, 100])]
    for arg, item in kwargs_bad:
        with pytest.raises(AssertionError):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            net.add_connection(**kwargs)

    for arg in ['src_gids', 'target_gids', 'loc', 'receptor']:
        string_arg = 'invalid_string'
        match = f"Invalid value for the '{arg}' parameter"
        with pytest.raises(ValueError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = string_arg
            net.add_connection(**kwargs)

    net.clear_connectivity()
    assert len(net.connectivity) == 0
示例#7
0
def test_network_models():
    """Test instantiations of the network object"""
    # Make sure critical biophysics for Law model are updated
    net_law = law_2021_model()
    # instantiate drive events for NetworkBuilder
    net_law._instantiate_drives(tstop=net_law._params['tstop'],
                                n_trials=net_law._params['N_trials'])

    for cell_name in ['L5_pyramidal', 'L2_pyramidal']:
        assert net_law.cell_types[cell_name].synapses['gabab']['tau1'] == 45.0
        assert net_law.cell_types[cell_name].synapses['gabab']['tau2'] == 200.0

    # Check add_erp_drives_to_jones_model()
    net_default = jones_2009_model()
    with pytest.raises(TypeError, match='net must be'):
        add_erp_drives_to_jones_model(net='invalid_input')
    with pytest.raises(TypeError, match='tstart must be'):
        add_erp_drives_to_jones_model(net=net_default, tstart='invalid_input')
    n_conn = len(net_default.connectivity)
    add_erp_drives_to_jones_model(net_default)
    for drive_name in ['evdist1', 'evprox1', 'evprox2']:
        assert drive_name in net_default.external_drives.keys()
    # 15 drive connections are added as follows: evdist1: 3 ampa + 3 nmda,
    # evprox1: 4 ampa, evprox2: 4 ampa, and 1 extra zero-weighted ampa
    # evdist1->L5_basket connection is added to comply with legacy_mode
    assert len(net_default.connectivity) == n_conn + 15

    # Ensure distance-dependent calcium gbar
    net_calcium = calcium_model()
    # instantiate drive events for NetworkBuilder
    net_calcium._instantiate_drives(tstop=net_calcium._params['tstop'],
                                    n_trials=net_calcium._params['N_trials'])
    network_builder = NetworkBuilder(net_calcium)
    gid = net_calcium.gid_ranges['L5_pyramidal'][0]
    for section_name, section in \
            network_builder._cells[gid]._nrn_sections.items():
        # Section endpoints where seg.x == 0.0 or 1.0 don't have 'ca' mech,
        # so only the interior segments are inspected.
        segments = list(section.allseg())[1:-1]
        ca_gbar = [seg.ca.gbar for seg in segments]
        na_gbar = [seg.hh2.gnabar for seg in segments]
        k_gbar = [seg.hh2.gkbar for seg in segments]

        # Ensure positive distance dependent calcium gbar with plateau
        if section_name == 'apical_tuft':
            assert np.all(np.diff(ca_gbar) == 0)
        else:
            assert np.all(np.diff(ca_gbar) > 0)

        # Ensure negative distance dependent sodium gbar with plateau
        if section_name == 'apical_2':
            assert np.all(np.diff(na_gbar[0:3]) < 0)
            assert np.all(np.diff(na_gbar[3:]) == 0)
        elif section_name == 'apical_tuft':
            assert np.all(np.diff(na_gbar) == 0)
        else:
            assert np.all(np.diff(na_gbar) < 0)

        # Ensure negative exponential distance dependent K gbar
        assert np.all(np.diff(k_gbar) < 0)
        assert np.all(np.diff(k_gbar, n=2) > 0)  # positive 2nd derivative
示例#8
0
def test_network():
    """Test network object."""
    params = read_params(params_fname)
    # add rhythmic inputs (i.e., a type of common input)
    params.update({
        'input_dist_A_weight_L2Pyr_ampa': 1.4e-5,
        'input_dist_A_weight_L5Pyr_ampa': 2.4e-5,
        't0_input_dist': 50,
        'input_prox_A_weight_L2Pyr_ampa': 3.4e-5,
        'input_prox_A_weight_L5Pyr_ampa': 4.4e-5,
        't0_input_prox': 50
    })

    net = jones_2009_model(deepcopy(params), add_drives_from_params=True)
    # instantiate drive events for NetworkBuilder
    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])
    network_builder = NetworkBuilder(net)  # needed to instantiate cells

    # Assert that params are conserved across Network initialization
    for p in params:
        assert params[p] == net._params[p]
    assert len(params) == len(net._params)
    print(network_builder)
    print(network_builder._cells[:2])

    # Assert that proper number/types of gids are created for Network drives
    dns_from_gids = [
        name for name in net.gid_ranges.keys() if name not in net.cell_types
    ]
    assert sorted(dns_from_gids) == sorted(net.external_drives.keys())
    for dn in dns_from_gids:
        n_drive_cells = net.external_drives[dn]['n_drive_cells']
        assert len(net.gid_ranges[dn]) == n_drive_cells

    # Check drive dict structure for each external drive
    for drive in net.external_drives.values():
        # Check that connectivity sources correspond to gid_ranges
        conn_idxs = pick_connection(net, src_gids=drive['name'])
        this_src_gids = set([
            gid for conn_idx in conn_idxs
            for gid in net.connectivity[conn_idx]['src_gids']
        ])  # NB set: globals
        assert sorted(this_src_gids) == list(net.gid_ranges[drive['name']])
        # Check type-specific dynamics and events
        n_drive_cells = drive['n_drive_cells']
        assert len(drive['events']) == 1  # single trial simulated
        if drive['type'] == 'evoked':
            for kw in ['mu', 'sigma', 'numspikes']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
            # this also implicitly tests that events are always a list
            assert len(drive['events'][0][0]) == drive['dynamics']['numspikes']
        elif drive['type'] == 'gaussian':
            for kw in ['mu', 'sigma', 'numspikes']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
        elif drive['type'] == 'poisson':
            for kw in ['tstart', 'tstop', 'rate_constant']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
        elif drive['type'] == 'bursty':
            for kw in [
                    'tstart', 'tstart_std', 'tstop', 'burst_rate', 'burst_std',
                    'numspikes'
            ]:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
            n_events = (
                drive['dynamics']['numspikes'] *  # 2
                (1 +
                 (drive['dynamics']['tstop'] - drive['dynamics']['tstart'] - 1)
                 // (1000. / drive['dynamics']['burst_rate'])))
            assert len(drive['events'][0][0]) == n_events  # 4

    # make sure the PRNGs are consistent.
    target_times = {
        'evdist1': [66.30498327062551, 61.54362532343694],
        'evprox1': [23.80641637082997, 30.857310915553647],
        'evprox2': [141.76252038319825, 137.73942375578602]
    }
    for drive_name in target_times:
        for idx in [0, -1]:  # first and last
            assert_allclose(net.external_drives[drive_name]['events'][0][idx],
                            target_times[drive_name][idx],
                            rtol=1e-12)

    # check select AMPA weights
    target_weights = {
        'evdist1': {
            'L2_basket': 0.006562,
            'L5_pyramidal': 0.142300
        },
        'evprox1': {
            'L2_basket': 0.08831,
            'L5_pyramidal': 0.00865
        },
        'evprox2': {
            'L2_basket': 0.000003,
            'L5_pyramidal': 0.684013
        },
        'bursty1': {
            'L2_pyramidal': 0.000034,
            'L5_pyramidal': 0.000044
        },
        'bursty2': {
            'L2_pyramidal': 0.000014,
            'L5_pyramidal': 0.000024
        }
    }
    for drive_name in target_weights:
        for target_type in target_weights[drive_name]:
            conn_idxs = pick_connection(net,
                                        src_gids=drive_name,
                                        target_gids=target_type,
                                        receptor='ampa')
            for conn_idx in conn_idxs:
                drive_conn = net.connectivity[conn_idx]
                assert_allclose(drive_conn['nc_dict']['A_weight'],
                                target_weights[drive_name][target_type],
                                rtol=1e-12)

    # check select synaptic delays
    target_delays = {
        'evdist1': {
            'L2_basket': 0.1,
            'L5_pyramidal': 0.1
        },
        'evprox1': {
            'L2_basket': 0.1,
            'L5_pyramidal': 1.
        },
        'evprox2': {
            'L2_basket': 0.1,
            'L5_pyramidal': 1.
        }
    }
    for drive_name in target_delays:
        for target_type in target_delays[drive_name]:
            conn_idxs = pick_connection(net,
                                        src_gids=drive_name,
                                        target_gids=target_type,
                                        receptor='ampa')
            for conn_idx in conn_idxs:
                drive_conn = net.connectivity[conn_idx]
                assert_allclose(drive_conn['nc_dict']['A_delay'],
                                target_delays[drive_name][target_type],
                                rtol=1e-12)

    # array of simulation times is created in Network.__init__, but passed
    # to CellResponse-constructor for storage (Network is agnostic of time)
    with pytest.raises(TypeError,
                       match="'times' is an np.ndarray of simulation times"):
        _ = CellResponse(times='blah')

    # Assert that all external drives are initialized
    # Assumes legacy mode where cell-specific drives create artificial cells
    # for all network cells regardless of connectivity
    n_evoked_sources = 3 * net._n_cells
    n_pois_sources = net._n_cells
    n_gaus_sources = net._n_cells
    n_bursty_sources = (net.external_drives['bursty1']['n_drive_cells'] +
                        net.external_drives['bursty2']['n_drive_cells'])
    # test that expected number of external driving events are created
    assert len(
        network_builder._drive_cells) == (n_evoked_sources + n_pois_sources +
                                          n_gaus_sources + n_bursty_sources)
    assert len(network_builder._gid_list) ==\
        len(network_builder._drive_cells) + net._n_cells
    # first 'evoked drive' comes after real cells and bursty drive cells
    assert network_builder._drive_cells[n_bursty_sources].gid ==\
        net._n_cells + n_bursty_sources

    # Assert that netcons are created properly
    n_pyr = len(net.gid_ranges['L2_pyramidal'])
    n_basket = len(net.gid_ranges['L2_basket'])

    # Check pyramidal-pyramidal connection where allow_autapses=False
    assert 'L2Pyr_L2Pyr_nmda' in network_builder.ncs
    n_connections = 3 * (n_pyr**2 - n_pyr)  # 3 synapses / cell
    assert len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) == n_connections
    nc = network_builder.ncs['L2Pyr_L2Pyr_nmda'][0]
    assert nc.threshold == params['threshold']

    # Check bursty drives which use cell_specific=False
    assert 'bursty1_L2Pyr_ampa' in network_builder.ncs
    n_bursty1_sources = net.external_drives['bursty1']['n_drive_cells']
    n_connections = n_bursty1_sources * 3 * n_pyr  # 3 synapses / cell
    assert len(network_builder.ncs['bursty1_L2Pyr_ampa']) == n_connections
    nc = network_builder.ncs['bursty1_L2Pyr_ampa'][0]
    assert nc.threshold == params['threshold']

    # Check basket-basket connection where allow_autapses=True
    assert 'L2Basket_L2Basket_gabaa' in network_builder.ncs
    n_connections = n_basket**2  # 1 synapse / cell
    assert len(network_builder.ncs['L2Basket_L2Basket_gabaa']) == n_connections
    nc = network_builder.ncs['L2Basket_L2Basket_gabaa'][0]
    assert nc.threshold == params['threshold']

    # Check evoked drives which use cell_specific=True
    assert 'evdist1_L2Basket_nmda' in network_builder.ncs
    n_connections = n_basket  # 1 synapse / cell
    assert len(network_builder.ncs['evdist1_L2Basket_nmda']) == n_connections
    nc = network_builder.ncs['evdist1_L2Basket_nmda'][0]
    assert nc.threshold == params['threshold']

    # Test inputs for connectivity API
    net = jones_2009_model(deepcopy(params), add_drives_from_params=True)
    # instantiate drive events for NetworkBuilder
    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])
    # NB: baseline count comes from the builder constructed above for an
    # identically-built network (same params, same constructor call)
    n_conn = len(network_builder.ncs['L2Basket_L2Pyr_gabaa'])
    kwargs_default = dict(src_gids=[0, 1],
                          target_gids=[35, 36],
                          loc='soma',
                          receptor='gabaa',
                          weight=5e-4,
                          delay=1.0,
                          lamtha=1e9,
                          probability=1.0)
    net.add_connection(**kwargs_default)  # smoke test
    network_builder = NetworkBuilder(net)
    assert len(network_builder.ncs['L2Basket_L2Pyr_gabaa']) == n_conn + 4
    nc = network_builder.ncs['L2Basket_L2Pyr_gabaa'][-1]
    assert_allclose(nc.weight[0], kwargs_default['weight'])

    kwargs_good = [('src_gids', 0), ('src_gids', 'L2_pyramidal'),
                   ('src_gids', range(2)), ('target_gids', 35),
                   ('target_gids', range(2)), ('target_gids', 'L2_pyramidal'),
                   ('target_gids', [[35, 36], [37, 38]]), ('probability', 0.5)]
    for arg, item in kwargs_good:
        kwargs = kwargs_default.copy()
        kwargs[arg] = item
        net.add_connection(**kwargs)

    kwargs_bad = [('src_gids', 0.0), ('src_gids', [0.0]),
                  ('target_gids', 35.0), ('target_gids', [35.0]),
                  ('target_gids', [[35], [36.0]]), ('loc', 1.0),
                  ('receptor', 1.0), ('weight', '1.0'), ('delay', '1.0'),
                  ('lamtha', '1.0'), ('probability', '0.5')]
    for arg, item in kwargs_bad:
        match = ('must be an instance of')
        with pytest.raises(TypeError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            net.add_connection(**kwargs)

    kwargs_bad = [('src_gids', -1), ('src_gids', [-1]), ('target_gids', -1),
                  ('target_gids', [-1]), ('target_gids', [[35], [-1]]),
                  ('target_gids', [[35]]), ('src_gids', [0, 100]),
                  ('target_gids', [0, 100])]
    for arg, item in kwargs_bad:
        with pytest.raises(AssertionError):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            net.add_connection(**kwargs)

    for arg in ['src_gids', 'target_gids', 'loc', 'receptor']:
        string_arg = 'invalid_string'
        match = f"Invalid value for the '{arg}' parameter"
        with pytest.raises(ValueError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = string_arg
            net.add_connection(**kwargs)

    # Check probability=0.5 produces half as many connections as default
    net.add_connection(**kwargs_default)
    kwargs = kwargs_default.copy()
    kwargs['probability'] = 0.5
    net.add_connection(**kwargs)
    n_connections = np.sum(
        [len(t_gids) for t_gids in net.connectivity[-2]['gid_pairs'].values()])
    n_connections_new = np.sum(
        [len(t_gids) for t_gids in net.connectivity[-1]['gid_pairs'].values()])
    assert n_connections_new == np.round(n_connections * 0.5).astype(int)
    assert net.connectivity[-1]['probability'] == 0.5
    with pytest.raises(ValueError, match='probability must be'):
        kwargs = kwargs_default.copy()
        kwargs['probability'] = -1.0
        net.add_connection(**kwargs)

    # Test net.pick_connection()
    kwargs_default = dict(net=net,
                          src_gids=None,
                          target_gids=None,
                          loc=None,
                          receptor=None)

    kwargs_good = [('src_gids', 0), ('src_gids', 'L2_pyramidal'),
                   ('src_gids', range(2)), ('src_gids', None),
                   ('target_gids', 35), ('target_gids', range(2)),
                   ('target_gids', 'L2_pyramidal'), ('target_gids', None),
                   ('loc', 'soma'), ('loc', None), ('receptor', 'gabaa'),
                   ('receptor', None)]
    for arg, item in kwargs_good:
        kwargs = kwargs_default.copy()
        kwargs[arg] = item
        indices = pick_connection(**kwargs)
        for conn_idx in indices:
            if (arg == 'src_gids' or arg == 'target_gids') and \
                    isinstance(item, str):
                # np.isin replaces np.in1d (deprecated in NumPy 2.0); it only
                # differs in preserving the input shape, which np.all ignores
                assert np.all(
                    np.isin(net.connectivity[conn_idx][arg],
                            net.gid_ranges[item]))
            elif item is None:
                pass
            else:
                assert np.any(np.isin([item], net.connectivity[conn_idx][arg]))

    # Check that a given gid isn't present in any connection profile that
    # pick_connection can't identify
    conn_idxs = pick_connection(net, src_gids=0)
    for conn_idx in range(len(net.connectivity)):
        if conn_idx not in conn_idxs:
            assert 0 not in net.connectivity[conn_idx]['src_gids']

    # Check that pick_connection returns empty lists when searching for
    # a drive targeting the wrong location
    conn_idxs = pick_connection(net, src_gids='evdist1', loc='proximal')
    assert len(conn_idxs) == 0
    assert not pick_connection(net, src_gids='evprox1', loc='distal')

    # Check condition where no connections match
    assert pick_connection(net, loc='distal', receptor='gabab') == list()

    kwargs_bad = [('src_gids', 0.0),
                  ('src_gids', [0.0]), ('target_gids', 35.0),
                  ('target_gids', [35.0]), ('target_gids', [35, [36.0]]),
                  ('loc', 1.0), ('receptor', 1.0)]
    for arg, item in kwargs_bad:
        match = ('must be an instance of')
        with pytest.raises(TypeError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            pick_connection(**kwargs)

    kwargs_bad = [('src_gids', -1), ('src_gids', [-1]), ('target_gids', -1),
                  ('target_gids', [-1]), ('src_gids', [35, -1]),
                  ('target_gids', [35, -1])]
    for arg, item in kwargs_bad:
        with pytest.raises(AssertionError):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            pick_connection(**kwargs)

    for arg in ['src_gids', 'target_gids', 'loc', 'receptor']:
        string_arg = 'invalid_string'
        match = f"Invalid value for the '{arg}' parameter"
        with pytest.raises(ValueError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = string_arg
            pick_connection(**kwargs)

    # Test removing connections from net.connectivity
    # Needs to be updated if number of drives change in preceding tests
    net.clear_connectivity()
    assert len(net.connectivity) == 50
    net.clear_drives()
    assert len(net.connectivity) == 0
示例#9
0
def run_mpi_simulation():
    """Run the network simulation inside an MPI child process.

    Rank 0 reads a base64-encoded pickle of the simulation parameters from
    stdin and broadcasts it to every rank. All ranks build the network and
    simulate ``params['N_trials']`` trials; rank 0 collects one
    ``(dpl, spikedata)`` tuple per trial.

    While the simulation runs, file-descriptor-level stderr points at
    ``os.devnull`` and Python-level ``sys.stderr`` is captured in a
    ``StringIO``. Afterwards the captured text is forwarded to stdout, stderr
    is restored, and rank 0 writes the base64-encoded pickled results to
    stderr.

    Stderr is restored in a ``finally`` clause, so an exception raised during
    setup or simulation propagates with a visible traceback instead of being
    swallowed by the redirection.

    Returns
    -------
    int
        0 on success.
    """
    from mpi4py import MPI

    import pickle
    import codecs
    import os
    import io

    # suppress output to stderr at the file-descriptor level; presumably so
    # that nothing but the result data reaches the parent through stderr
    stderr_fileno = sys.stderr.fileno()
    null_fd = os.open(os.devnull, os.O_RDWR)
    old_err_fd = os.dup(stderr_fileno)
    os.dup2(null_fd, stderr_fileno)

    # temporarily use a StringIO object to capture Python-level stderr
    str_err = io.StringIO()
    sys.stderr = str_err

    # base64-encoded pickled results; only ever set on rank 0
    repickled_bytes = None
    try:
        # imported while stderr is redirected so their output is captured
        from hnn_core import Network
        from hnn_core.network_builder import (NetworkBuilder,
                                              _simulate_single_trial)

        # using template for reading stdin from:
        # https://github.com/cloudpipe/cloudpickle/blob/master/tests/testutils.py

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()

        # get parameters from stdin
        if rank == 0:
            stream_in = sys.stdin
            # Force the use of bytes streams under Python 3
            if hasattr(sys.stdin, 'buffer'):
                stream_in = sys.stdin.buffer
            input_bytes = _read_all_bytes(stream_in)
            stream_in.close()

            # NOTE: unpickling can execute arbitrary code; stdin must only
            # come from the trusted parent process that launched this job.
            params = pickle.loads(codecs.decode(input_bytes, "base64"))
        else:
            params = None

        params = comm.bcast(params, root=0)
        net = Network(params)
        # XXX store the initial prng_seedcore params to be referenced in each
        # trial
        prng_seedcore_initial = net.params['prng_*'].copy()

        sim_data = []
        for trial_idx in range(params['N_trials']):
            # XXX this should be built into NetworkBuilder
            # update prng_seedcore params to provide jitter between trials
            for param_key in prng_seedcore_initial.keys():
                net.params[param_key] = (prng_seedcore_initial[param_key] +
                                         trial_idx)
            neuron_net = NetworkBuilder(net)
            dpl = _simulate_single_trial(neuron_net, trial_idx)
            if rank == 0:
                spikedata = neuron_net.get_data_from_neuron()
                sim_data.append((dpl, spikedata))

        if rank == 0:
            # pickle dpls and spikedata for the parent process
            pickled_string = pickle.dumps(sim_data)

            # pad data before encoding, always add at least 4 "=" to mark end
            # NOTE(review): ``len % 4`` is the remainder, not the number of
            # bytes needed to reach a multiple of 4; kept unchanged because
            # the reader of this stream may rely on the exact bytes.
            padding = len(pickled_string) % 4
            pickled_string += b"=" * padding
            pickled_string += b"=" * 4

            # encode as base64 before sending to stderr
            repickled_bytes = codecs.encode(pickled_string, 'base64')
    finally:
        # flush anything captured from stderr to stdout
        str_err.flush()
        sys.stdout.write(str_err.getvalue())

        # restore the old stderr (also on error, so tracebacks are visible)
        os.dup2(old_err_fd, stderr_fileno)
        sys.stderr = open(old_err_fd, 'w')
        os.close(null_fd)

        # close the StringIO object
        str_err.close()

    # send results back to the parent process through stderr (rank 0 only)
    if repickled_bytes is not None:
        sys.stderr.write(repickled_bytes.decode())

    MPI.Finalize()
    return 0
示例#10
0
def test_network():
    """Test the Network object built from the default parameter file.

    Adds rhythmic (common) inputs to the default parameters, builds a
    ``Network`` and a ``NetworkBuilder``, then checks that parameters are
    conserved, that each external input gets the expected number of gids and
    feed cells, and that ``_connect_celltypes`` creates the expected netcons
    for both ``unique=False`` and ``unique=True``.
    """
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    # add rhythmic inputs (i.e., a type of common input)
    params.update({
        'input_dist_A_weight_L2Pyr_ampa': 5.4e-5,
        'input_dist_A_weight_L5Pyr_ampa': 5.4e-5,
        't0_input_dist': 50,
        'input_prox_A_weight_L2Pyr_ampa': 5.4e-5,
        'input_prox_A_weight_L5Pyr_ampa': 5.4e-5,
        't0_input_prox': 50
    })
    # pass a copy so the comparison below is against an untouched ``params``
    net = Network(deepcopy(params))
    network_builder = NetworkBuilder(net)  # needed to populate net.cells

    # Assert that params are conserved across Network initialization
    for p in params:
        assert params[p] == net.params[p]
    assert len(params) == len(net.params)
    print(network_builder)
    print(network_builder.cells[:2])

    # Assert that proper number of gids are created for Network inputs
    n_common_sources = 2
    assert len(net.gid_dict['common']) == n_common_sources
    assert len(net.gid_dict['extgauss']) == net.n_cells
    assert len(net.gid_dict['extpois']) == net.n_cells
    for ev_input in params['t_ev*']:
        # drop the 't_' prefix and the separator before the trailing index
        # (assumes keys shaped like 't_evprox_1' -> gid key 'evprox1')
        type_key = ev_input[2:-2] + ev_input[-1]
        assert len(net.gid_dict[type_key]) == net.n_cells

    # Assert that an empty Spikes object is created as an attribute
    assert net.spikes == Spikes()

    # Assert that all external feeds are initialized: one source per cell for
    # each of the 3 evoked inputs, for the Poisson and for the Gaussian input,
    # plus the common (rhythmic) sources
    n_evoked_sources = net.n_cells * 3
    n_pois_sources = net.n_cells
    n_gaus_sources = net.n_cells
    assert len(
        network_builder._feed_cells) == (n_evoked_sources + n_pois_sources +
                                         n_gaus_sources + n_common_sources)

    # Assert that netcons are created properly
    # proximal
    assert 'L2Pyr_L2Pyr_nmda' in network_builder.ncs
    n_pyr = len(net.gid_dict['L2_pyramidal'])
    n_connections = 3 * (n_pyr**2 - n_pyr)  # 3 synapses / cell, no self-conn
    assert len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) == n_connections
    nc = network_builder.ncs['L2Pyr_L2Pyr_nmda'][0]
    assert nc.threshold == params['threshold']

    # create a new connection between cell types
    nc_dict = {'A_delay': 1, 'A_weight': 1e-5, 'lamtha': 20, 'threshold': 0.5}
    network_builder._connect_celltypes('common',
                                       'L5Basket',
                                       'soma',
                                       'gabaa',
                                       nc_dict,
                                       unique=False)
    assert 'common_L5Basket_gabaa' in network_builder.ncs
    # unique=False: every source connects to every target
    n_conn = len(net.gid_dict['common']) * len(net.gid_dict['L5_basket'])
    assert len(network_builder.ncs['common_L5Basket_gabaa']) == n_conn

    # try unique=True: one connection per target cell
    network_builder._connect_celltypes('extgauss',
                                       'L5Basket',
                                       'soma',
                                       'gabaa',
                                       nc_dict,
                                       unique=True)
    n_conn = len(net.gid_dict['L5_basket'])
    assert len(network_builder.ncs['extgauss_L5Basket_gabaa']) == n_conn
示例#11
0
def test_network():
    """Test the Network object built from the default parameter file.

    Adds rhythmic (common) inputs to the default parameters, builds a
    ``Network`` and a ``NetworkBuilder``, then checks parameter conservation,
    gid ranges of the external inputs, the empty ``CellResponse`` attribute,
    the generated feed times (count and a reference value), feed-cell gid
    ordering, and netcon creation via ``_connect_celltypes``.
    """
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    # add rhythmic inputs (i.e., a type of common input)
    params.update({
        'input_dist_A_weight_L2Pyr_ampa': 5.4e-5,
        'input_dist_A_weight_L5Pyr_ampa': 5.4e-5,
        't0_input_dist': 50,
        'input_prox_A_weight_L2Pyr_ampa': 5.4e-5,
        'input_prox_A_weight_L5Pyr_ampa': 5.4e-5,
        't0_input_prox': 50
    })
    # pass a copy so the comparison below is against an untouched ``params``
    net = Network(deepcopy(params))
    network_builder = NetworkBuilder(net)  # needed to populate net.cells

    # Assert that params are conserved across Network initialization
    for p in params:
        assert params[p] == net.params[p]
    assert len(params) == len(net.params)
    print(network_builder)
    print(network_builder.cells[:2])

    # Assert that proper number of gids are created for Network inputs
    n_common_sources = 2
    assert len(net.gid_ranges['common']) == n_common_sources
    assert len(net.gid_ranges['extgauss']) == net.n_cells
    assert len(net.gid_ranges['extpois']) == net.n_cells
    for ev_input in params['t_ev*']:
        # drop the 't_' prefix and the separator before the trailing index
        # (assumes keys shaped like 't_evprox_1' -> gid key 'evprox1')
        type_key = ev_input[2:-2] + ev_input[-1]
        assert len(net.gid_ranges[type_key]) == net.n_cells

    # Assert that an empty CellResponse object is created as an attribute
    assert net.cell_response == CellResponse()
    # array of simulation times is created in Network.__init__, but passed
    # to CellResponse-constructor for storage (Network is agnostic of time)
    with pytest.raises(TypeError,
                       match="'times' is an np.ndarray of simulation times"):
        _ = CellResponse(times=[1, 2, 3])

    # Assert that all external feeds are initialized: one source per cell for
    # each of the 3 evoked inputs, for the Poisson and for the Gaussian input,
    # plus the common (rhythmic) sources
    n_evoked_sources = net.n_cells * 3
    n_pois_sources = net.n_cells
    n_gaus_sources = net.n_cells

    # test that expected number of external driving events are created, and
    # make sure the PRNGs are consistent.
    assert isinstance(net.feed_times, dict)
    # single trial simulated
    assert all(
        len(src_feed_times) == 1
        for src_type, src_feed_times in net.feed_times.items()
        if src_type != 'tonic')
    assert len(net.feed_times['common'][0]) == n_common_sources
    assert len(net.feed_times['common'][0][0]) == 40  # 40 spikes
    assert isinstance(net.feed_times['evprox1'][0][0], list)
    assert len(net.feed_times['evprox1'][0]) == net.n_cells
    assert_allclose(net.feed_times['evprox1'][0][0], [23.80641637082997],
                    rtol=1e-12)

    assert len(
        network_builder._feed_cells) == (n_evoked_sources + n_pois_sources +
                                         n_gaus_sources + n_common_sources)
    assert len(network_builder._gid_list) ==\
        len(network_builder._feed_cells) + net.n_cells
    # first 'evoked feed' comes after real cells and common inputs
    first_evoked = network_builder._feed_cells[n_common_sources]
    assert first_evoked.gid == net.n_cells + n_common_sources

    # Assert that netcons are created properly
    # proximal
    assert 'L2Pyr_L2Pyr_nmda' in network_builder.ncs
    n_pyr = len(net.gid_ranges['L2_pyramidal'])
    n_connections = 3 * (n_pyr**2 - n_pyr)  # 3 synapses / cell, no self-conn
    assert len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) == n_connections
    nc = network_builder.ncs['L2Pyr_L2Pyr_nmda'][0]
    assert nc.threshold == params['threshold']

    # create a new connection between cell types
    nc_dict = {'A_delay': 1, 'A_weight': 1e-5, 'lamtha': 20, 'threshold': 0.5}
    network_builder._connect_celltypes('common',
                                       'L5Basket',
                                       'soma',
                                       'gabaa',
                                       nc_dict,
                                       unique=False)
    assert 'common_L5Basket_gabaa' in network_builder.ncs
    # unique=False: every source connects to every target
    n_conn = len(net.gid_ranges['common']) * len(net.gid_ranges['L5_basket'])
    assert len(network_builder.ncs['common_L5Basket_gabaa']) == n_conn

    # try unique=True: one connection per target cell
    network_builder._connect_celltypes('extgauss',
                                       'L5Basket',
                                       'soma',
                                       'gabaa',
                                       nc_dict,
                                       unique=True)
    n_conn = len(net.gid_ranges['L5_basket'])
    assert len(network_builder.ncs['extgauss_L5Basket_gabaa']) == n_conn
示例#12
0
import os.path as op

import hnn_core
from hnn_core import read_params, Network
from hnn_core.network_builder import NetworkBuilder

hnn_core_root = op.dirname(hnn_core.__file__)

###############################################################################
# Then we read the parameters file
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

###############################################################################
# Now let's build the network
import matplotlib.pyplot as plt

net = Network(params)
with NetworkBuilder(net) as network_builder:
    # The cells are stored in the network builder as a list
    cells = network_builder.cells
    print(cells[:5])

    # We have different kinds of cells with different cell IDs (gids)
    gids = [0, 35, 135, 170]
    for gid in gids:
        print(cells[gid].name)

    # We can plot the firing pattern of individual cells. Use an explicit
    # gid so the title matches the plotted cell (the loop variable above
    # ends at the last gid in ``gids``).
    plot_gid = 0
    cells[plot_gid].plot_voltage()
    plt.title('%s (gid=%d)' % (cells[plot_gid].name, plot_gid))