def get_network(probability=1.0):
    """Build a Jones 2009 network with a small custom set of connections.

    The default connectivity of the model is cleared and replaced by
    connections from L5 pyramidal cells (distal AMPA synapses) and L2 basket
    cells (somatic GABAA synapses) onto both L5 pyramidal and L2 basket
    cells.

    Parameters
    ----------
    probability : float
        Connection probability forwarded to ``Network.add_connection``.

    Returns
    -------
    net : Network
        The network with drives from ``params`` and the custom connectivity.
    """
    # NOTE(review): relies on a module-level ``params`` defined elsewhere.
    net = jones_2009_model(params, add_drives_from_params=True)
    net.clear_connectivity()

    weight, delay, lamtha = 1.0, 1.0, 70
    targets = ['L5_pyramidal', 'L2_basket']
    # (source cell type, synapse location, receptor)
    sources = [('L5_pyramidal', 'distal', 'ampa'),
               ('L2_basket', 'soma', 'gabaa')]
    for src, location, receptor in sources:
        for target in targets:
            # weight/delay/lamtha passed by keyword: the positional order of
            # add_connection is (..., receptor, weight, delay, lamtha)
            net.add_connection(src,
                               target,
                               location,
                               receptor,
                               weight=weight,
                               delay=delay,
                               lamtha=lamtha,
                               probability=probability)
    return net
Exemple #2
0
def test_add_cell_type():
    """Check that a cloned cell type can be added and wired into the net."""
    params = read_params(params_fname)
    net = jones_2009_model(params)
    # drive events must exist before a NetworkBuilder is constructed
    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])

    n_cells_before = net._n_cells
    new_positions = [(0, y_coord, 0) for y_coord in range(10)]
    gabaa_tau1 = 0.6

    template = net.cell_types['L2_basket'].copy()
    net._add_cell_type('new_type', pos=new_positions, cell_template=template)
    net.cell_types['new_type'].synapses['gabaa']['tau1'] = gabaa_tau1

    n_new_type = len(net.gid_ranges['new_type'])
    assert n_new_type == len(new_positions)
    net.add_connection('L2_basket', 'new_type', loc='proximal',
                       receptor='gabaa', weight=8e-3, delay=1, lamtha=2)

    builder = NetworkBuilder(net)
    assert net._n_cells == n_cells_before + len(new_positions)
    # every L2 basket cell should project onto every new cell
    expected_n_netcons = len(net.gid_ranges['L2_basket']) * n_new_type
    netcons = builder.ncs['L2Basket_new_type_gabaa']
    assert len(netcons) == expected_n_netcons
    # the modified synaptic time constant must reach the instantiated synapse
    assert netcons[0].syn().tau1 == gabaa_tau1
def test_extracellular_api():
    """Test extracellular recording API.

    Checks ``Network.add_electrode_array`` (naming and uniqueness) and the
    input validation, read-only attributes, indexing and sampling-rate
    behaviour of ``ExtracellularArray``.
    """
    net = jones_2009_model(deepcopy(params), add_drives_from_params=True)

    # Test LFP electrodes
    electrode_pos = (1, 2, 3)
    net.add_electrode_array('el1', electrode_pos)
    electrode_pos = [(1, 2, 3), (-1, -2, -3)]
    net.add_electrode_array('arr1', electrode_pos)
    assert len(net.rec_arrays) == 2
    assert len(net.rec_arrays['arr1'].positions) == 2

    # ensure unique names
    pytest.raises(ValueError, net.add_electrode_array, 'arr1', [(6, 6, 800)])
    # all remaining input arguments checked by ExtracellularArray

    # Python < 3.11 reports "can't set attribute" for read-only properties,
    # Python >= 3.11 reports "property ... has no setter"; accept both
    read_only_msg = "can't set attribute|has no setter"
    rec_arr = ExtracellularArray(electrode_pos)
    with pytest.raises(AttributeError, match=read_only_msg):
        rec_arr.times = [1, 2, 3]
    with pytest.raises(AttributeError, match=read_only_msg):
        rec_arr.voltages = [1, 2, 3]
    with pytest.raises(TypeError, match="trial index must be int"):
        _ = rec_arr['0']
    with pytest.raises(IndexError, match="the data contain"):
        _ = rec_arr[42]

    # positions are 3-tuples
    bad_positions = [[(1, 2), (1, 2, 3)], [42, (1, 2, 3)]]
    for bogus_pos in bad_positions:
        pytest.raises((ValueError, TypeError), ExtracellularArray, bogus_pos)

    good_positions = [(1, 2, 3), (100, 200, 300)]
    for cond in ['0.3', [0.3], -1]:  # conductivity is positive float
        pytest.raises((TypeError, AssertionError), ExtracellularArray,
                      good_positions, conductivity=cond)
    for meth in ['foo', 0.3]:  # method is 'psa' or 'lsa' (or None for test)
        pytest.raises((TypeError, AssertionError, ValueError),
                      ExtracellularArray, good_positions, method=meth)
    for mind in ['foo', -1, None]:  # minimum distance to segment boundary
        pytest.raises((TypeError, AssertionError), ExtracellularArray,
                      good_positions, min_distance=mind)

    pytest.raises(ValueError, ExtracellularArray,  # more chans than voltages
                  good_positions, times=[1], voltages=[[[42]]])
    pytest.raises(ValueError, ExtracellularArray,  # fewer times than voltages
                  good_positions, times=[1], voltages=[[[42, 42], [84, 84]]])

    rec_arr = ExtracellularArray(good_positions,
                                 times=[0, 0.1, 0.21, 0.3],  # uneven sampling
                                 voltages=[[[0, 0, 0, 0], [0, 0, 0, 0]]])
    with pytest.raises(RuntimeError, match="Extracellular sampling times"):
        _ = rec_arr.sfreq
    rec_arr._reset()
    assert len(rec_arr.times) == len(rec_arr.voltages) == 0
    assert rec_arr.sfreq is None
    # a single time sample is not enough to define a sampling rate
    rec_arr = ExtracellularArray(good_positions,
                                 times=[0], voltages=[[[0], [0]]])
    with pytest.raises(RuntimeError, match="Sampling rate is not defined"):
        _ = rec_arr.sfreq
Exemple #4
0
def test_network_visualization():
    """Test network visualisations.

    Smoke-tests ``plot_cells``, ``plot_morphology``,
    ``plot_connectivity_matrix`` and ``plot_cell_connectivity``, checks their
    argument validation, and checks that clicking a source cell in the
    interactive cell-connectivity plot moves the highlighted source cell.
    """
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    params.update({'N_pyr_x': 3,
                   'N_pyr_y': 3})
    net = jones_2009_model(params)
    plot_cells(net)
    ax = net.cell_types['L2_pyramidal'].plot_morphology()
    assert len(ax.lines) == 8

    conn_idx = 0
    plot_connectivity_matrix(net, conn_idx, show=False)
    with pytest.raises(TypeError, match='net must be an instance of'):
        plot_connectivity_matrix('blah', conn_idx, show_weight=False)

    with pytest.raises(TypeError, match='conn_idx must be an instance of'):
        plot_connectivity_matrix(net, 'blah', show_weight=False)

    with pytest.raises(TypeError, match='show_weight must be an instance of'):
        plot_connectivity_matrix(net, conn_idx, show_weight='blah')

    src_gid = 5
    plot_cell_connectivity(net, conn_idx, src_gid, show=False)
    with pytest.raises(TypeError, match='net must be an instance of'):
        plot_cell_connectivity('blah', conn_idx, src_gid=src_gid)

    with pytest.raises(TypeError, match='conn_idx must be an instance of'):
        plot_cell_connectivity(net, 'blah', src_gid)

    with pytest.raises(TypeError, match='src_gid must be an instance of'):
        plot_cell_connectivity(net, conn_idx, src_gid='blah')

    with pytest.raises(ValueError, match='src_gid -1 not a valid cell ID'):
        plot_cell_connectivity(net, conn_idx, src_gid=-1)

    # test interactive clicking updates the position of src_cell in plot
    del net.connectivity[-1]
    net.add_connection(net.gid_ranges['L2_pyramidal'][::2],
                       'L5_basket', 'soma',
                       'ampa', 0.00025, 1.0, lamtha=3.0,
                       probability=0.8)
    # index of the connection just appended; derived rather than hard-coded
    # so the test does not depend on the size of the default connectivity
    conn_idx = len(net.connectivity) - 1
    fig = plot_cell_connectivity(net, conn_idx)
    ax_src, ax_target, _ = fig.axes

    pos = net.pos_dict['L2_pyramidal'][2]
    _fake_click(fig, ax_src, [pos[0], pos[1]])
    pos_in_plot = ax_target.collections[2].get_offsets().data[0]
    assert_allclose(pos[:2], pos_in_plot)
def test_transmembrane_currents():
    """Test that net transmembrane current is zero at all times."""
    from copy import deepcopy

    # work on a copy: updating the module-level ``params`` in place would
    # leak the reduced settings into every test that runs afterwards
    params_reduced = deepcopy(params)
    params_reduced.update({'N_pyr_x': 3,
                           'N_pyr_y': 3,
                           't_evprox_1': 5,
                           't_evdist_1': 10,
                           't_evprox_2': 20,
                           'N_trials': 1})
    net = jones_2009_model(params_reduced, add_drives_from_params=True)
    electrode_pos = (0, 0, 0)  # irrelevant where electrode is
    # all transfer resistances set to unity
    net.add_electrode_array('net_Im', electrode_pos, method=None)
    _ = simulate_dipole(net, tstop=40.)
    assert_allclose(net.rec_arrays['net_Im'].voltages, 0,
                    rtol=1e-10, atol=1e-10)
Exemple #6
0
def test_dipole_visualization():
    """Test dipole visualisations.

    Simulates two trials of a reduced network driven by a bursty drive, then
    smoke-tests dipole, TFR and PSD plotting. It also checks the errors
    raised for invalid decimation factors and for dipoles that are scaled or
    sampled differently.
    """
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    params.update({'N_pyr_x': 3,
                   'N_pyr_y': 3})
    net = jones_2009_model(params)
    weights_ampa_p = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
    syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}

    net.add_bursty_drive(
        'beta_prox', tstart=0., burst_rate=25, burst_std=5,
        numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
        weights_ampa=weights_ampa_p, synaptic_delays=syn_delays_p,
        event_seed=14)

    dpls = simulate_dipole(net, tstop=100., n_trials=2)
    fig = dpls[0].plot()  # plot the first dipole alone
    axes = fig.get_axes()[0]
    dpls[0].copy().smooth(window_len=10).plot(ax=axes)  # add smoothed versions
    dpls[0].copy().savgol_filter(h_freq=30).plot(ax=axes)  # on top

    # test decimation options
    plot_dipole(dpls[0], decim=2)
    for dec in [-1, [2, 2.]]:
        with pytest.raises(ValueError,
                           match='each decimation factor must be a positive'):
            plot_dipole(dpls[0], decim=dec)

    # test plotting multiple dipoles as overlay
    plot_dipole(dpls)

    # multiple TFRs get averaged
    plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3)

    with pytest.raises(RuntimeError,
                       match="All dipoles must be scaled equally!"):
        plot_dipole([dpls[0].copy().scale(10), dpls[1].copy().scale(20)])
    with pytest.raises(RuntimeError,
                       match="All dipoles must be scaled equally!"):
        plot_psd([dpls[0].copy().scale(10), dpls[1].copy().scale(20)])

    # build the mismatched dipole outside pytest.raises so that only the
    # plotting call is expected to raise
    dpl_sfreq = dpls[0].copy()
    dpl_sfreq.sfreq /= 10
    with pytest.raises(RuntimeError,
                       match="All dipoles must be sampled equally!"):
        plot_psd([dpls[0], dpl_sfreq])
Exemple #7
0
    def _run_hnn_core_fixture(backend=None, n_procs=None, n_jobs=1,
                              reduced=False, record_vsoma=False,
                              record_isoma=False, postproc=False,
                              electrode_array=None):
        """Simulate the default Jones 2009 network and return the results.

        Parameters
        ----------
        backend : 'mpi' | 'joblib' | None
            Parallel backend context to simulate in. Any other value
            (including None) simulates without a backend context.
        n_procs : int | None
            Number of MPI processes (only used when ``backend='mpi'``).
        n_jobs : int
            Number of joblib jobs (only used when ``backend='joblib'``).
        reduced : bool
            If True, use a 3 x 3 pyramidal grid, earlier evoked drives,
            2 trials and ``tstop=40.`` (instead of ``tstop=170.``).
        record_vsoma, record_isoma, postproc : bool
            Forwarded to ``simulate_dipole`` under the same names.
        electrode_array : dict | None
            Mapping of electrode-array name to positions; each entry is
            added to the network with ``add_electrode_array``.

        Returns
        -------
        dpls
            Whatever ``simulate_dipole`` returns.
        net : Network
            The simulated network.
        """
        from contextlib import nullcontext

        hnn_core_root = op.dirname(hnn_core.__file__)

        # default params
        params_fname = op.join(hnn_core_root, 'param', 'default.json')
        params = read_params(params_fname)

        tstop = 170.
        if reduced:
            params.update({'N_pyr_x': 3,
                           'N_pyr_y': 3,
                           't_evprox_1': 5,
                           't_evdist_1': 10,
                           't_evprox_2': 20,
                           'N_trials': 2})
            tstop = 40.
        net = jones_2009_model(params, add_drives_from_params=True)
        if electrode_array is not None:
            for name, positions in electrode_array.items():
                net.add_electrode_array(name, positions)

        if backend == 'mpi':
            backend_ctx = MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec')
        elif backend == 'joblib':
            backend_ctx = JoblibBackend(n_jobs=n_jobs)
        else:
            backend_ctx = nullcontext()
        with backend_ctx:
            # each recording flag goes to the keyword of the same name
            dpls = simulate_dipole(net, record_vsoma=record_vsoma,
                                   record_isoma=record_isoma,
                                   postproc=postproc, tstop=tstop)

        # check that the network object is picklable after the simulation
        pickle.dumps(net)

        # number of trials simulated
        for drive in net.external_drives.values():
            assert len(drive['events']) == params['N_trials']

        return dpls, net
def test_rec_array_calculation():
    """Test LFP calculation.

    Simulates a reduced (3 x 3) network with two electrodes that share x/y
    coordinates but differ in height. It checks the shape and type of the
    recorded data and that ``Network.copy()`` drops recorded voltages. It
    also checks that re-simulating discards earlier results and that the
    default array ('arr1') and an LSA array ('arr2') agree at the distant
    electrode.
    """
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    params.update({'N_pyr_x': 3,
                   'N_pyr_y': 3,
                   't_evprox_1': 7,
                   't_evdist_1': 17})
    net = jones_2009_model(params, add_drives_from_params=True)

    # one electrode inside, one above the active elements of the network
    electrode_pos = [(1.5, 1.5, 1000), (1.5, 1.5, 3000)]
    net.add_electrode_array('arr1', electrode_pos)
    _ = simulate_dipole(net, tstop=5, n_trials=1)

    # test accessing simulated voltages
    assert (len(net.rec_arrays['arr1']) ==
            len(net.rec_arrays['arr1'].voltages) == 1)  # n_trials
    assert len(net.rec_arrays['arr1'].voltages[0]) == 2  # n_contacts
    # one voltage sample per simulation time point
    assert (len(net.rec_arrays['arr1'].voltages[0][0]) ==
            len(net.rec_arrays['arr1'].times))
    # ensure copy drops data (but retains electrode position information etc.)
    net_copy = net.copy()
    assert isinstance(net_copy.rec_arrays['arr1'], ExtracellularArray)
    assert len(net_copy.rec_arrays['arr1'].voltages) == 0

    assert isinstance(net.rec_arrays['arr1'].voltages, np.ndarray)
    assert isinstance(net.rec_arrays['arr1'].times, np.ndarray)

    # using the same electrode positions, but a different method: LSA
    net.add_electrode_array('arr2', electrode_pos, method='lsa')

    # make sure no sinister segfaults are triggered when running mult. trials
    n_trials = 5  # NB 5 trials!
    _ = simulate_dipole(net, tstop=5, n_trials=n_trials)

    # simulate_dipole is run twice above, first 1 then 5 trials.
    # Make sure that previous results are discarded on each run
    assert len(net.rec_arrays['arr1']._data) == n_trials

    for trial_idx in range(n_trials):
        # LSA and PSA should agree far away (second electrode)
        assert_allclose(net.rec_arrays['arr1']._data[trial_idx][1],
                        net.rec_arrays['arr2']._data[trial_idx][1],
                        rtol=1e-3, atol=1e-3)
Exemple #9
0
    def test_run_mpibackend_oversubscribed(self, run_hnn_core_fixture):
        """Run MPIBackend with more processes than there are CPU cores."""
        hnn_core_root = op.dirname(hnn_core.__file__)
        params = read_params(op.join(hnn_core_root, 'param', 'default.json'))
        # shrink the network and move drives early so tstop=40 is enough
        reduced_settings = dict(N_pyr_x=3,
                                N_pyr_y=3,
                                t_evprox_1=5,
                                t_evdist_1=10,
                                t_evprox_2=20,
                                N_trials=2)
        params.update(reduced_settings)
        net = jones_2009_model(params, add_drives_from_params=True)

        # request 1.5x the available cores, rounded to the nearest integer
        n_procs_requested = round(cpu_count() * 1.5)
        with MPIBackend(n_procs=n_procs_requested) as mpi_backend:
            # the backend must honour the request rather than clipping it
            assert mpi_backend.n_procs == n_procs_requested
            simulate_dipole(net, tstop=40)
Exemple #10
0
def test_network_cell_positions():
    """Test manipulation of cell positions in the network object.

    Checks the default in-plane distance and layer separation, that
    ``set_cell_positions`` changes each independently and rejects
    non-positive values, and that drive-cell positions follow the network
    origin when positions are changed after drives have been added.
    """

    net = jones_2009_model()
    assert np.isclose(net._inplane_distance, 1.)  # default
    assert np.isclose(net._layer_separation, 1307.4)  # default

    # change both from their default values
    net.set_cell_positions(inplane_distance=2.)
    assert np.isclose(net._layer_separation, 1307.4)  # still the default
    net.set_cell_positions(layer_separation=1000.)
    assert np.isclose(net._inplane_distance, 2.)  # mustn't change

    # check that in-plane distance is now 2. for the default 10 x 10 grid
    assert np.allclose(  # x-coordinate jumps every 10th gid
        np.diff(np.array(net.pos_dict['L5_pyramidal'])[9::10, 0], axis=0), 2.)
    assert np.allclose(  # test first 10 y-coordinates
        np.diff(np.array(net.pos_dict['L5_pyramidal'])[:9, 1], axis=0), 2.)

    # check that layer separation has changed (L5 is zero) to 1000.
    assert np.isclose(net.pos_dict['L2_pyramidal'][0][2], 1000.)

    with pytest.raises(ValueError, match='In-plane distance must be positive'):
        net.set_cell_positions(inplane_distance=0.)
    with pytest.raises(ValueError, match='Layer separation must be positive'):
        net.set_cell_positions(layer_separation=0.)

    # Check that the origin of the drive cells matches the new 'origin'
    # when set_cell_positions is called after adding drives.
    # As the network dimensions increase, so does the center-of-mass of the
    # grid points, which is where all hnn drives should be located. The lamtha-
    # dependent weights and delays of the drives are calculated with respect to
    # this origin.
    add_erp_drives_to_jones_model(net)
    net.set_cell_positions(inplane_distance=20.)
    for drive_name, drive in net.external_drives.items():
        # one position entry per drive cell
        assert len(net.pos_dict[drive_name]) == drive['n_drive_cells']
        # just test the 0th index, assume all others then fine too
        for idx in range(3):  # x,y,z coords
            assert (net.pos_dict[drive_name][0][idx] == net.pos_dict['origin']
                    [idx])
Exemple #11
0
def test_dipole_simulation():
    """Test data produced from simulate_dipole() call.

    Checks argument validation of ``simulate_dipole`` (trial count and
    boolean recording flags). It also checks that ``Network.copy()`` drops
    drive events after a simulation and that ``Dipole.copy()`` reproduces
    the aggregate data. Finally, it checks that simulating a bare
    ``Network`` warns about missing connections and that its empty spike
    raster can still be plotted.
    """
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)
    # reduced 3 x 3 network with early drives so tstop=25 covers them
    params.update({
        'N_pyr_x': 3,
        'N_pyr_y': 3,
        'dipole_smooth_win': 5,
        't_evprox_1': 5,
        't_evdist_1': 10,
        't_evprox_2': 20
    })
    net = jones_2009_model(params, add_drives_from_params=True)
    with pytest.raises(ValueError, match="Invalid number of simulations: 0"):
        simulate_dipole(net, tstop=25., n_trials=0)
    # ints are rejected even though they are truthy/falsy like bools
    with pytest.raises(TypeError, match="record_vsoma must be bool, got int"):
        simulate_dipole(net, tstop=25., n_trials=1, record_vsoma=0)
    with pytest.raises(TypeError, match="record_isoma must be bool, got int"):
        simulate_dipole(net,
                        tstop=25.,
                        n_trials=1,
                        record_vsoma=False,
                        record_isoma=0)

    # test Network.copy() returns 'bare' network after simulating
    dpl = simulate_dipole(net, tstop=25., n_trials=1)[0]
    net_copy = net.copy()
    assert len(net_copy.external_drives['evprox1']['events']) == 0

    # test that Dipole.copy() returns the expected exact copy
    assert_allclose(dpl.data['agg'], dpl.copy().data['agg'])

    with pytest.warns(UserWarning, match='No connections'):
        net = Network(params)
        # warning triggered on simulate_dipole()
        simulate_dipole(net, tstop=0.1, n_trials=1)

        # Smoke test for raster plot with no spikes
        net.cell_response.plot_spikes_raster()
Exemple #12
0
def test_gid_assignment():
    """Check that every gid is assigned to exactly one of several ranks."""

    net = jones_2009_model(add_drives_from_params=False)
    weights_ampa = {'L2_basket': 1.0, 'L2_pyramidal': 2.0, 'L5_pyramidal': 3.0}
    syn_delays = {'L2_basket': .1, 'L2_pyramidal': .2, 'L5_pyramidal': .3}

    # one drive with shared drive cells, one with a drive cell per network cell
    net.add_bursty_drive('bursty_dist',
                         location='distal',
                         burst_rate=10,
                         weights_ampa=weights_ampa,
                         synaptic_delays=syn_delays,
                         cell_specific=False,
                         n_drive_cells=5)
    net.add_evoked_drive('evoked_prox',
                         mu=1.0,
                         sigma=1.0,
                         numspikes=1,
                         weights_ampa=weights_ampa,
                         location='proximal',
                         synaptic_delays=syn_delays,
                         cell_specific=True,
                         n_drive_cells='n_cells')
    net._instantiate_drives(tstop=20, n_trials=2)

    expected_gids = sorted(
        gid for gid_range in net.gid_ranges.values() for gid in gid_range)

    n_hosts = 3
    assigned_gids = list()
    for rank in range(n_hosts):
        builder = NetworkBuilder(net)
        builder._gid_list = list()
        builder._gid_assign(rank=rank, n_hosts=n_hosts)
        assigned_gids += builder._gid_list
    assigned_gids.sort()

    # no gid may be claimed by more than one rank ...
    assert assigned_gids == sorted(set(assigned_gids))
    # ... and together the ranks must cover every gid
    assert expected_gids == assigned_gids
Exemple #13
0
    def test_terminate_mpibackend(self, run_hnn_core_fixture):
        """Test terminating MPIBackend from thread.

        A background thread (``_terminate_mpibackend``) is started while an
        MPI simulation runs. The test expects ``simulate_dipole`` to raise a
        ``RuntimeError`` reporting return code 1, and expects the first
        emitted ``UserWarning`` to mention that the child process failed.
        """
        hnn_core_root = op.dirname(hnn_core.__file__)
        params_fname = op.join(hnn_core_root, 'param', 'default.json')
        params = read_params(params_fname)
        # reduced network / early drives so tstop=40 covers the drives
        params.update({
            'N_pyr_x': 3,
            'N_pyr_y': 3,
            't_evprox_1': 5,
            't_evdist_1': 10,
            't_evprox_2': 20,
            'N_trials': 2
        })
        net = jones_2009_model(params, add_drives_from_params=True)

        with MPIBackend() as backend:
            event = Event()
            # start background thread that will kill all MPIBackends
            # until event.set()
            kill_t = Thread(target=_terminate_mpibackend,
                            args=(event, backend))
            # make thread a daemon in case we throw an exception
            # and don't run event.set() so that py.test will
            # not hang before exiting
            kill_t.daemon = True
            kill_t.start()

            with pytest.warns(UserWarning) as record:
                with pytest.raises(
                        RuntimeError,
                        match="MPI simulation failed. Return code: 1"):
                    simulate_dipole(net, tstop=40)

            # signal the kill thread to stop
            event.set()
        # ``record`` stays available after the pytest.warns block exits
        expected_string = "Child process failed unexpectedly"
        assert expected_string in record[0].message.args[0]
Exemple #14
0
def test_network():
    """Test network object."""
    params = read_params(params_fname)
    # add rhythmic inputs (i.e., a type of common input)
    params.update({
        'input_dist_A_weight_L2Pyr_ampa': 1.4e-5,
        'input_dist_A_weight_L5Pyr_ampa': 2.4e-5,
        't0_input_dist': 50,
        'input_prox_A_weight_L2Pyr_ampa': 3.4e-5,
        'input_prox_A_weight_L5Pyr_ampa': 4.4e-5,
        't0_input_prox': 50
    })

    net = jones_2009_model(deepcopy(params), add_drives_from_params=True)
    # instantiate drive events for NetworkBuilder
    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])
    network_builder = NetworkBuilder(net)  # needed to instantiate cells

    # Assert that params are conserved across Network initialization
    for p in params:
        assert params[p] == net._params[p]
    assert len(params) == len(net._params)
    print(network_builder)
    print(network_builder._cells[:2])

    # Assert that proper number/types of gids are created for Network drives
    dns_from_gids = [
        name for name in net.gid_ranges.keys() if name not in net.cell_types
    ]
    assert sorted(dns_from_gids) == sorted(net.external_drives.keys())
    for dn in dns_from_gids:
        n_drive_cells = net.external_drives[dn]['n_drive_cells']
        assert len(net.gid_ranges[dn]) == n_drive_cells

    # Check drive dict structure for each external drive
    for drive in net.external_drives.values():
        # Check that connectivity sources correspond to gid_ranges
        conn_idxs = pick_connection(net, src_gids=drive['name'])
        this_src_gids = set([
            gid for conn_idx in conn_idxs
            for gid in net.connectivity[conn_idx]['src_gids']
        ])  # NB set: globals
        assert sorted(this_src_gids) == list(net.gid_ranges[drive['name']])
        # Check type-specific dynamics and events
        n_drive_cells = drive['n_drive_cells']
        assert len(drive['events']) == 1  # single trial simulated
        if drive['type'] == 'evoked':
            for kw in ['mu', 'sigma', 'numspikes']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
            # this also implicitly tests that events are always a list
            assert len(drive['events'][0][0]) == drive['dynamics']['numspikes']
        elif drive['type'] == 'gaussian':
            for kw in ['mu', 'sigma', 'numspikes']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
        elif drive['type'] == 'poisson':
            for kw in ['tstart', 'tstop', 'rate_constant']:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
        elif drive['type'] == 'bursty':
            for kw in [
                    'tstart', 'tstart_std', 'tstop', 'burst_rate', 'burst_std',
                    'numspikes'
            ]:
                assert kw in drive['dynamics'].keys()
            assert len(drive['events'][0]) == n_drive_cells
            n_events = (
                drive['dynamics']['numspikes'] *  # 2
                (1 +
                 (drive['dynamics']['tstop'] - drive['dynamics']['tstart'] - 1)
                 // (1000. / drive['dynamics']['burst_rate'])))
            assert len(drive['events'][0][0]) == n_events  # 4

    # make sure the PRNGs are consistent.
    target_times = {
        'evdist1': [66.30498327062551, 61.54362532343694],
        'evprox1': [23.80641637082997, 30.857310915553647],
        'evprox2': [141.76252038319825, 137.73942375578602]
    }
    for drive_name in target_times:
        for idx in [0, -1]:  # first and last
            assert_allclose(net.external_drives[drive_name]['events'][0][idx],
                            target_times[drive_name][idx],
                            rtol=1e-12)

    # check select AMPA weights
    target_weights = {
        'evdist1': {
            'L2_basket': 0.006562,
            'L5_pyramidal': 0.142300
        },
        'evprox1': {
            'L2_basket': 0.08831,
            'L5_pyramidal': 0.00865
        },
        'evprox2': {
            'L2_basket': 0.000003,
            'L5_pyramidal': 0.684013
        },
        'bursty1': {
            'L2_pyramidal': 0.000034,
            'L5_pyramidal': 0.000044
        },
        'bursty2': {
            'L2_pyramidal': 0.000014,
            'L5_pyramidal': 0.000024
        }
    }
    for drive_name in target_weights:
        for target_type in target_weights[drive_name]:
            conn_idxs = pick_connection(net,
                                        src_gids=drive_name,
                                        target_gids=target_type,
                                        receptor='ampa')
            for conn_idx in conn_idxs:
                drive_conn = net.connectivity[conn_idx]
                assert_allclose(drive_conn['nc_dict']['A_weight'],
                                target_weights[drive_name][target_type],
                                rtol=1e-12)

    # check select synaptic delays
    target_delays = {
        'evdist1': {
            'L2_basket': 0.1,
            'L5_pyramidal': 0.1
        },
        'evprox1': {
            'L2_basket': 0.1,
            'L5_pyramidal': 1.
        },
        'evprox2': {
            'L2_basket': 0.1,
            'L5_pyramidal': 1.
        }
    }
    for drive_name in target_delays:
        for target_type in target_delays[drive_name]:
            conn_idxs = pick_connection(net,
                                        src_gids=drive_name,
                                        target_gids=target_type,
                                        receptor='ampa')
            for conn_idx in conn_idxs:
                drive_conn = net.connectivity[conn_idx]
                assert_allclose(drive_conn['nc_dict']['A_delay'],
                                target_delays[drive_name][target_type],
                                rtol=1e-12)

    # array of simulation times is created in Network.__init__, but passed
    # to CellResponse-constructor for storage (Network is agnostic of time)
    with pytest.raises(TypeError,
                       match="'times' is an np.ndarray of simulation times"):
        _ = CellResponse(times='blah')

    # Assert that all external drives are initialized
    # Assumes legacy mode where cell-specific drives create artificial cells
    # for all network cells regardless of connectivity
    n_evoked_sources = 3 * net._n_cells
    n_pois_sources = net._n_cells
    n_gaus_sources = net._n_cells
    n_bursty_sources = (net.external_drives['bursty1']['n_drive_cells'] +
                        net.external_drives['bursty2']['n_drive_cells'])
    # test that expected number of external driving events are created
    assert len(
        network_builder._drive_cells) == (n_evoked_sources + n_pois_sources +
                                          n_gaus_sources + n_bursty_sources)
    assert len(network_builder._gid_list) ==\
        len(network_builder._drive_cells) + net._n_cells
    # first 'evoked drive' comes after real cells and bursty drive cells
    assert network_builder._drive_cells[n_bursty_sources].gid ==\
        net._n_cells + n_bursty_sources

    # Assert that netcons are created properly
    n_pyr = len(net.gid_ranges['L2_pyramidal'])
    n_basket = len(net.gid_ranges['L2_basket'])

    # Check basket-basket connection where allow_autapses=False
    assert 'L2Pyr_L2Pyr_nmda' in network_builder.ncs
    n_connections = 3 * (n_pyr**2 - n_pyr)  # 3 synapses / cell
    assert len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) == n_connections
    nc = network_builder.ncs['L2Pyr_L2Pyr_nmda'][0]
    assert nc.threshold == params['threshold']

    # Check bursty drives which use cell_specific=False
    assert 'bursty1_L2Pyr_ampa' in network_builder.ncs
    n_bursty1_sources = net.external_drives['bursty1']['n_drive_cells']
    n_connections = n_bursty1_sources * 3 * n_pyr  # 3 synapses / cell
    assert len(network_builder.ncs['bursty1_L2Pyr_ampa']) == n_connections
    nc = network_builder.ncs['bursty1_L2Pyr_ampa'][0]
    assert nc.threshold == params['threshold']

    # Check basket-basket connection where allow_autapses=True
    assert 'L2Basket_L2Basket_gabaa' in network_builder.ncs
    n_connections = n_basket**2  # 1 synapse / cell
    assert len(network_builder.ncs['L2Basket_L2Basket_gabaa']) == n_connections
    nc = network_builder.ncs['L2Basket_L2Basket_gabaa'][0]
    assert nc.threshold == params['threshold']

    # Check evoked drives which use cell_specific=True
    assert 'evdist1_L2Basket_nmda' in network_builder.ncs
    n_connections = n_basket  # 1 synapse / cell
    assert len(network_builder.ncs['evdist1_L2Basket_nmda']) == n_connections
    nc = network_builder.ncs['evdist1_L2Basket_nmda'][0]
    assert nc.threshold == params['threshold']

    # Test inputs for connectivity API
    net = jones_2009_model(deepcopy(params), add_drives_from_params=True)
    # instantiate drive events for NetworkBuilder
    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])
    n_conn = len(network_builder.ncs['L2Basket_L2Pyr_gabaa'])
    kwargs_default = dict(src_gids=[0, 1],
                          target_gids=[35, 36],
                          loc='soma',
                          receptor='gabaa',
                          weight=5e-4,
                          delay=1.0,
                          lamtha=1e9,
                          probability=1.0)
    net.add_connection(**kwargs_default)  # smoke test
    network_builder = NetworkBuilder(net)
    assert len(network_builder.ncs['L2Basket_L2Pyr_gabaa']) == n_conn + 4
    nc = network_builder.ncs['L2Basket_L2Pyr_gabaa'][-1]
    assert_allclose(nc.weight[0], kwargs_default['weight'])

    kwargs_good = [('src_gids', 0), ('src_gids', 'L2_pyramidal'),
                   ('src_gids', range(2)), ('target_gids', 35),
                   ('target_gids', range(2)), ('target_gids', 'L2_pyramidal'),
                   ('target_gids', [[35, 36], [37, 38]]), ('probability', 0.5)]
    for arg, item in kwargs_good:
        kwargs = kwargs_default.copy()
        kwargs[arg] = item
        net.add_connection(**kwargs)

    kwargs_bad = [('src_gids', 0.0), ('src_gids', [0.0]),
                  ('target_gids', 35.0), ('target_gids', [35.0]),
                  ('target_gids', [[35], [36.0]]), ('loc', 1.0),
                  ('receptor', 1.0), ('weight', '1.0'), ('delay', '1.0'),
                  ('lamtha', '1.0'), ('probability', '0.5')]
    for arg, item in kwargs_bad:
        match = ('must be an instance of')
        with pytest.raises(TypeError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            net.add_connection(**kwargs)

    kwargs_bad = [('src_gids', -1), ('src_gids', [-1]), ('target_gids', -1),
                  ('target_gids', [-1]), ('target_gids', [[35], [-1]]),
                  ('target_gids', [[35]]), ('src_gids', [0, 100]),
                  ('target_gids', [0, 100])]
    for arg, item in kwargs_bad:
        with pytest.raises(AssertionError):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            net.add_connection(**kwargs)

    for arg in ['src_gids', 'target_gids', 'loc', 'receptor']:
        string_arg = 'invalid_string'
        match = f"Invalid value for the '{arg}' parameter"
        with pytest.raises(ValueError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = string_arg
            net.add_connection(**kwargs)

    # Check probability=0.5 produces half as many connections as default
    net.add_connection(**kwargs_default)
    kwargs = kwargs_default.copy()
    kwargs['probability'] = 0.5
    net.add_connection(**kwargs)
    n_connections = np.sum(
        [len(t_gids) for t_gids in net.connectivity[-2]['gid_pairs'].values()])
    n_connections_new = np.sum(
        [len(t_gids) for t_gids in net.connectivity[-1]['gid_pairs'].values()])
    assert n_connections_new == np.round(n_connections * 0.5).astype(int)
    assert net.connectivity[-1]['probability'] == 0.5
    with pytest.raises(ValueError, match='probability must be'):
        kwargs = kwargs_default.copy()
        kwargs['probability'] = -1.0
        net.add_connection(**kwargs)

    # Test net.pick_connection()
    kwargs_default = dict(net=net,
                          src_gids=None,
                          target_gids=None,
                          loc=None,
                          receptor=None)

    kwargs_good = [('src_gids', 0), ('src_gids', 'L2_pyramidal'),
                   ('src_gids', range(2)), ('src_gids', None),
                   ('target_gids', 35), ('target_gids', range(2)),
                   ('target_gids', 'L2_pyramidal'), ('target_gids', None),
                   ('loc', 'soma'), ('loc', None), ('receptor', 'gabaa'),
                   ('receptor', None)]
    for arg, item in kwargs_good:
        kwargs = kwargs_default.copy()
        kwargs[arg] = item
        indices = pick_connection(**kwargs)
        for conn_idx in indices:
            if (arg == 'src_gids' or arg == 'target_gids') and \
                    isinstance(item, str):
                assert np.all(
                    np.in1d(net.connectivity[conn_idx][arg],
                            net.gid_ranges[item]))
            elif item is None:
                pass
            else:
                assert np.any(np.in1d([item], net.connectivity[conn_idx][arg]))

    # Check that a given gid isn't present in any connection profile that
    # pick_connection can't identify
    conn_idxs = pick_connection(net, src_gids=0)
    for conn_idx in range(len(net.connectivity)):
        if conn_idx not in conn_idxs:
            assert 0 not in net.connectivity[conn_idx]['src_gids']

    # Check that pick_connection returns empty lists when searching for
    # a drive targeting the wrong location
    conn_idxs = pick_connection(net, src_gids='evdist1', loc='proximal')
    assert len(conn_idxs) == 0
    assert not pick_connection(net, src_gids='evprox1', loc='distal')

    # Check condition where no connections match
    assert pick_connection(net, loc='distal', receptor='gabab') == list()

    kwargs_bad = [('src_gids', 0.0),
                  ('src_gids', [0.0]), ('target_gids', 35.0),
                  ('target_gids', [35.0]), ('target_gids', [35, [36.0]]),
                  ('loc', 1.0), ('receptor', 1.0)]
    for arg, item in kwargs_bad:
        match = ('must be an instance of')
        with pytest.raises(TypeError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            pick_connection(**kwargs)

    kwargs_bad = [('src_gids', -1), ('src_gids', [-1]), ('target_gids', -1),
                  ('target_gids', [-1]), ('src_gids', [35, -1]),
                  ('target_gids', [35, -1])]
    for arg, item in kwargs_bad:
        with pytest.raises(AssertionError):
            kwargs = kwargs_default.copy()
            kwargs[arg] = item
            pick_connection(**kwargs)

    for arg in ['src_gids', 'target_gids', 'loc', 'receptor']:
        string_arg = 'invalid_string'
        match = f"Invalid value for the '{arg}' parameter"
        with pytest.raises(ValueError, match=match):
            kwargs = kwargs_default.copy()
            kwargs[arg] = string_arg
            pick_connection(**kwargs)

    # Test removing connections from net.connectivity
    # Needs to be updated if the number of drives changes in preceding tests
    net.clear_connectivity()
    assert len(net.connectivity) == 50
    net.clear_drives()
    assert len(net.connectivity) == 0
Exemple #15
0
import os.path as op
import tempfile

import matplotlib.pyplot as plt

###############################################################################
# Let us import hnn_core

import hnn_core
from hnn_core import simulate_dipole, jones_2009_model
from hnn_core.viz import plot_dipole

###############################################################################
# Let us first create our default network and visualize the cells
# inside it.
net = jones_2009_model()
net.plot_cells()
# Then look at the morphology of the L5 pyramidal cell type on its own.
l5_pyramidal_cell = net.cell_types['L5_pyramidal']
l5_pyramidal_cell.plot_morphology()

###############################################################################
# The network of cells is now defined, to which we add external drives as
# required. Weights are prescribed separately for AMPA and NMDA receptors
# (receptors that are not used can be omitted or set to zero). The possible
# drive types include the following (click on the links for documentation):
#
# - :meth:`hnn_core.Network.add_evoked_drive`
# - :meth:`hnn_core.Network.add_poisson_drive`
# - :meth:`hnn_core.Network.add_bursty_drive`

###############################################################################
# First, we add a distal evoked drive
Exemple #16
0
from hnn_core.viz import plot_dipole

###############################################################################
# We begin by instantiating the network model from Law et al. 2021 [1]_.
net = law_2021_model()

###############################################################################
# The Law 2021 model is based on the network model described in
# Jones et al. 2009 [2]_ with several important modifications. One of the most
# significant changes is a substantial increase in the rise and fall time
# constants of the GABAb conductances on L2 and L5 pyramidal cells. Another
# important change is the removal of calcium channels from the basal dendrites
# and soma of L5 pyramidal cells specifically.
# We can inspect these properties with the ``net.cell_types`` attribute, which
# contains information on the biophysics and geometry of each cell.
net_jones = jones_2009_model()

# GABAb synapse parameters of the L5 pyramidal cell in each model
# (tau1 is the rise time constant, tau2 the fall time constant).
gabab_jones = net_jones.cell_types['L5_pyramidal'].synapses['gabab']
gabab_law = net.cell_types['L5_pyramidal'].synapses['gabab']

jones_rise, law_rise = gabab_jones['tau1'], gabab_law['tau1']
print(f'GABAb Rise (ms): {jones_rise} -> {law_rise}')

jones_fall, law_fall = gabab_jones['tau2'], gabab_law['tau2']
print(f'GABAb Fall (ms): {jones_fall} -> {law_fall}\n')

# Compare the mechanisms inserted in an apical vs. a basal dendrite.
l5_sections = net.cell_types['L5_pyramidal'].sections
print('Apical Dendrite Channels:')
print(l5_sections['apical_1'].mechs.keys())
print("\nBasal Dendrite Channels ('ca' missing):")
print(l5_sections['basal_1'].mechs.keys())

###############################################################################
Exemple #17
0
# Directory of the installed hnn_core package; the parameter file below is
# read from its 'param' subdirectory.
hnn_core_root = op.dirname(hnn_core.__file__)

###############################################################################
# Then we read the parameter file
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

###############################################################################
# To explore how to modify network connectivity, we will start with simulating
# the evoked response from the
# :ref:`evoked example <sphx_glr_auto_examples_plot_simulate_evoked.py>`, and
# explore how it changes with new connections. We first instantiate the
# network. (Note: Setting ``add_drives_from_params=True`` loads a set of
# predefined drives without the drives API shown previously).
net_erp = jones_2009_model(params, add_drives_from_params=True)

###############################################################################
# Instantiating the network comes with a predefined set of connections that
# reflect the canonical neocortical microcircuit. ``net.connectivity``
# is a list of dictionaries that detail every cell-cell and drive-cell
# connection. The weights of these connections can be visualized with
# :func:`~hnn_core.viz.plot_connectivity_matrix` as well as
# :func:`~hnn_core.viz.plot_cell_connectivity`. We can search for specific
# connections using ``pick_connection``, which returns the indices
# of ``net.connectivity`` that match the provided parameters.
from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity
from hnn_core.network import pick_connection

# Number of connection entries (one dictionary each) currently in ``net_erp``.
print(len(net_erp.connectivity))
Exemple #18
0
from hnn_core import (read_params, read_spikes, jones_2009_model,
                      simulate_dipole)

# Directory of the installed hnn_core package; the parameter file below is
# read from its 'param' subdirectory.
hnn_core_root = op.dirname(hnn_core.__file__)

###############################################################################
# Then we read the parameter file
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

###############################################################################
# Now let's build the network. We have used the same weights as in the
# :ref:`evoked example <sphx_glr_auto_examples_plot_simulate_evoked.py>`.
import matplotlib.pyplot as plt

net = jones_2009_model(params)

###############################################################################
# ``net`` does not have any driving inputs and only defines the local network
# connectivity. Let us go ahead and first add a distal evoked drive.
# We need to define the AMPA and NMDA weights for the connections. An
# "evoked drive" defines inputs that are normally distributed with a certain
# mean and standard deviation.

# AMPA weights for the distal evoked drive, keyed by target cell type.
weights_ampa_d1 = {
    'L2_basket': 0.006562,
    'L2_pyramidal': 7e-6,
    'L5_pyramidal': 0.142300
}
weights_nmda_d1 = {
    'L2_basket': 0.019482,
Exemple #19
0
import os.path as op
import tempfile

###############################################################################
# Let us import ``hnn_core``.

import hnn_core
from hnn_core import read_spikes, jones_2009_model, simulate_dipole

###############################################################################
# Now let's build the network. We have used the same weights as in the
# :ref:`evoked example <sphx_glr_auto_examples_plot_simulate_evoked.py>`.
import matplotlib.pyplot as plt

# Default network (no parameter file passed); no drives are attached yet.
net = jones_2009_model()

###############################################################################
# ``net`` does not have any driving inputs and only defines the local network
# connectivity. Let us go ahead and first add a distal evoked drive.
# We need to define the AMPA and NMDA weights for the connections. An
# "evoked drive" defines inputs that are normally distributed with a certain
# mean and standard deviation.

# AMPA weights for the distal evoked drive, keyed by target cell type.
weights_ampa_d1 = {
    'L2_basket': 0.006562,
    'L2_pyramidal': 7e-6,
    'L5_pyramidal': 0.142300
}
weights_nmda_d1 = {
    'L2_basket': 0.019482,
# Read the base parameters from a file
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

###############################################################################
# Let's first simulate the dipole with some initial parameters. The parameter
# definitions also contain the drives. Even though we could add drives
# explicitly through our API
# (see :ref:`sphx_glr_auto_examples_workflows_plot_simulate_evoked.py`),
# for conciseness,
# we add them automatically from the parameter file.

# Scaling factor and smoothing window applied to the simulated dipole below.
scale_factor = 3000.
smooth_window_len = 30.
# Simulate up to the last time sample of the experimental dipole.
# NOTE(review): ``exp_dpl``, ``n_procs`` and ``MPIBackend`` are defined
# outside this excerpt.
tstop = exp_dpl.times[-1]
net = jones_2009_model(params=params, add_drives_from_params=True)
with MPIBackend(n_procs=n_procs):
    print("Running simulation with initial parameters")
    # A single trial is run; keep that one dipole.
    initial_dpl = simulate_dipole(net, tstop=tstop, n_trials=1)[0]
    initial_dpl = initial_dpl.scale(scale_factor).smooth(smooth_window_len)

###############################################################################
# Now we start the optimization!

from hnn_core.optimization import optimize_evoked

with MPIBackend(n_procs=n_procs):
    params_optim = optimize_evoked(params,
                                   exp_dpl,
                                   initial_dpl,
                                   scale_factor=scale_factor,
Exemple #21
0
def test_network_models():
    """Test instantiations of the network object."""
    # Make sure critical biophysics for Law model are updated
    net_law = law_2021_model()
    # instantiate drive events for NetworkBuilder
    net_law._instantiate_drives(tstop=net_law._params['tstop'],
                                n_trials=net_law._params['N_trials'])

    # Both pyramidal cell types must carry the Law 2021 GABAb rise (tau1)
    # and fall (tau2) time constants
    for cell_name in ['L5_pyramidal', 'L2_pyramidal']:
        gabab = net_law.cell_types[cell_name].synapses['gabab']
        assert gabab['tau1'] == 45.0
        assert gabab['tau2'] == 200.0

    # Check add_erp_drives_to_jones_model(), including its input validation
    net_default = jones_2009_model()
    with pytest.raises(TypeError, match='net must be'):
        add_erp_drives_to_jones_model(net='invalid_input')
    with pytest.raises(TypeError, match='tstart must be'):
        add_erp_drives_to_jones_model(net=net_default, tstart='invalid_input')
    n_conn = len(net_default.connectivity)
    add_erp_drives_to_jones_model(net_default)
    for drive_name in ['evdist1', 'evprox1', 'evprox2']:
        assert drive_name in net_default.external_drives
    # 15 drive connections are added as follows: evdist1: 3 ampa + 3 nmda,
    # evprox1: 4 ampa, evprox2: 4 ampa, and 1 extra zero-weighted ampa
    # evdist1->L5_basket connection is added to comply with legacy_mode
    assert len(net_default.connectivity) == n_conn + 15

    # Ensure distance-dependent channel densities in the calcium model
    net_calcium = calcium_model()
    # instantiate drive events for NetworkBuilder
    net_calcium._instantiate_drives(tstop=net_calcium._params['tstop'],
                                    n_trials=net_calcium._params['N_trials'])
    network_builder = NetworkBuilder(net_calcium)
    gid = net_calcium.gid_ranges['L5_pyramidal'][0]
    for section_name, section in \
            network_builder._cells[gid]._nrn_sections.items():
        # Section endpoints where seg.x == 0.0 or 1.0 don't have 'ca' mech,
        # so only the interior segments are compared
        inner_segs = list(section.allseg())[1:-1]
        ca_gbar = [seg.ca.gbar for seg in inner_segs]
        na_gbar = [seg.hh2.gnabar for seg in inner_segs]
        k_gbar = [seg.hh2.gkbar for seg in inner_segs]

        # Ensure positive distance dependent calcium gbar with plateau
        if section_name == 'apical_tuft':
            assert np.all(np.diff(ca_gbar) == 0)
        else:
            assert np.all(np.diff(ca_gbar) > 0)

        # Ensure negative distance dependent sodium gbar with plateau
        if section_name == 'apical_2':
            assert np.all(np.diff(na_gbar[0:3]) < 0)
            assert np.all(np.diff(na_gbar[3:]) == 0)
        elif section_name == 'apical_tuft':
            assert np.all(np.diff(na_gbar) == 0)
        else:
            assert np.all(np.diff(na_gbar) < 0)

        # Ensure negative exponential distance dependent K gbar
        assert np.all(np.diff(k_gbar) < 0)
        assert np.all(np.diff(k_gbar, n=2) > 0)  # positive 2nd derivative
# Finish the figure: y-axis label, x-range 0-170, and a dotted black line at
# zero. NOTE(review): the figure itself is created outside this excerpt.
plt.ylabel('Current Dipole (nAm)')
plt.xlim((0, 170))
plt.axhline(0, c='k', ls=':')
plt.show()

###############################################################################
# Now, let us try to simulate the same with ``hnn-core``. We read in the
# network parameters from ``N20.json`` and instantiate the network.

import hnn_core
from hnn_core import simulate_dipole, jones_2009_model
from hnn_core import average_dipoles, JoblibBackend

# N20.json ships in hnn_core's 'param' directory; its path is passed directly
# to ``jones_2009_model`` rather than being read with ``read_params`` first.
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'N20.json')
net = jones_2009_model(params_fname)

###############################################################################
# To simulate the source of the median nerve evoked response, we add a
# sequence of synchronous evoked drives: 1 proximal, 2 distal, and 1 final
# proximal drive. In order to understand the physiological implications of
# proximal and distal drive as well as the general process used to articulate
# a sequence of exogenous drive for simulating evoked responses, see the
# `HNN ERP tutorial`_. Note that setting ``n_drive_cells=1`` and
# ``cell_specific=False`` creates a drive with synchronous input across cells
# in the network.

# Early proximal drive
weights_ampa_p = {
    'L2_basket': 0.0036,
    'L2_pyramidal': 0.0039,
Exemple #23
0
def test_optimize_evoked():
    """Test running the full routine in a reduced network."""
    hnn_core_root = op.dirname(hnn_core.__file__)
    params_fname = op.join(hnn_core_root, 'param', 'default.json')
    params = read_params(params_fname)

    tstop = 10.
    n_trials = 1

    # simulate a dipole to establish ground-truth drive parameters
    mu_orig = 6.
    params.update({
        'N_pyr_x': 3,
        'N_pyr_y': 3,
        't_evprox_1': mu_orig,
        'sigma_t_evprox_1': 2.
    })
    net_orig = jones_2009_model(params, add_drives_from_params=True)
    del net_orig.external_drives['evprox2']
    del net_orig.external_drives['evdist1']
    dpl_orig = simulate_dipole(net_orig, tstop=tstop, n_trials=n_trials)[0]

    # simulate a dipole with a time-shifted drive
    mu_offset = 4.
    params.update({
        'N_pyr_x': 3,
        'N_pyr_y': 3,
        't_evprox_1': mu_offset,
        'sigma_t_evprox_1': 2.
    })
    net_offset = jones_2009_model(params, add_drives_from_params=True)
    del net_offset.external_drives['evprox2']
    del net_offset.external_drives['evdist1']
    dpl_offset = simulate_dipole(net_offset, tstop=tstop, n_trials=n_trials)[0]
    # get drive params from the pre-optimization Network instance
    _, _, drive_static_params_orig = _get_drive_params(net_offset, ['evprox1'])

    with pytest.raises(ValueError,
                       match='The current Network instance lacks '
                       'any evoked drives'):
        net_empty = net_offset.copy()
        del net_empty.external_drives['evprox1']
        net_opt = optimize_evoked(net_empty,
                                  tstop=tstop,
                                  n_trials=n_trials,
                                  target_dpl=dpl_orig,
                                  initial_dpl=dpl_offset)

    net_opt = optimize_evoked(net_offset,
                              tstop=tstop,
                              n_trials=n_trials,
                              target_dpl=dpl_orig,
                              initial_dpl=dpl_offset,
                              timing_range_multiplier=3.,
                              sigma_range_multiplier=50.,
                              synweight_range_multiplier=500.,
                              maxiter=10)

    # the names of drives should be preserved during optimization
    assert net_offset.external_drives.keys() == net_opt.external_drives.keys()

    drive_dynamics_opt, drive_syn_weights_opt, drive_static_params_opt = \
        _get_drive_params(net_opt, ['evprox1'])

    # ensure that params corresponding to only one evoked drive are discovered
    assert (len(drive_dynamics_opt) == len(drive_syn_weights_opt) ==
            len(drive_static_params_opt) == 1)

    # static drive params should remain constant
    assert drive_static_params_opt == drive_static_params_orig