示例#1
0
def teardown_function(function):
    """Put the RC back into a clean state after each test in this module.

    py.test calls this after every test here. Examples are allowed to change
    RC settings, so the RC is reloaded (``reload_rc([])``, presumably without
    reading any RC files) and the decoder cache is disabled again.
    """
    rc.reload_rc([])
    # Tests in this module expect the decoder cache to be off.
    rc.set('decoder_cache', 'enabled', 'False')
示例#2
0
def pytest_runtest_setup(item):  # noqa: C901
    """Reset RC state and decide whether ``item`` should run.

    - Tests marked ``example`` or ``slow`` are skipped unless the matching
      command-line option is enabled.
    - ``noassertions`` tests are skipped unless at least one of the output
      fixtures they use (analytics/plt/logger) has its option enabled; this
      includes the case where they use none of those fixtures.
    - Tests using the ``Simulator`` fixture whose node id matches an entry
      of ``TestConfig.Simulator.unsupported`` are marked xfail.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    # Items without an underlying test object get no further processing.
    if not hasattr(item, 'obj'):
        return

    test_obj = item.obj
    config = item.config

    gated_marks = (
        ('example', 'noexamples', "examples not requested"),
        ('slow', 'slow', "slow tests not requested"),
    )
    for mark_name, option_name, skip_msg in gated_marks:
        marked = getattr(test_obj, mark_name, None)
        if marked and not config.getvalue(option_name):
            pytest.skip(skip_msg)

    if getattr(test_obj, 'noassertions', None):
        output_fixtures = (
            ('analytics', 'analytics', "analytics not requested"),
            ('plt', 'plots', "plots not requested"),
            ('logger', 'logs', "logs not requested"),
        )
        reasons = []
        wanted = False
        for fixture_name, option_name, reason in output_fixtures:
            if fixture_name not in item.fixturenames:
                continue
            if config.getvalue(option_name):
                # One requested output is enough for the test to be useful.
                wanted = True
                break
            reasons.append(reason)
        if not wanted:
            pytest.skip(" and ".join(reasons))

    if 'Simulator' not in item.fixturenames:
        return
    for pattern, reason in TestConfig.Simulator.unsupported:
        # The leading '*' lets a pattern match without the user having to
        # write a wildcard before the name of the test function.
        if fnmatch(item.nodeid, '*' + pattern):
            pytest.xfail(reason)
示例#3
0
    def test_cache_benchmark(self, varying_param, analytics, plt):
        """Benchmark build time for several decoder-cache configurations.

        ``varying_param`` ('D', 'N' or 'M') selects which model parameter is
        swept over the values below; the others come from
        ``self.get_args``. Median build times are plotted per configuration
        and the raw timings are stored in ``analytics``.
        """
        varying = {
            'D': np.asarray(np.linspace(1, 512, 10), dtype=int),
            'N': np.asarray(np.linspace(10, 500, 8), dtype=int),
            'M': np.asarray(np.linspace(750, 2500, 8), dtype=int)
        }[varying_param]
        axis_label = self.param_to_axis_label[varying_param]

        # The timed configurations presumably modify the global RC (decoder
        # cache settings; TODO confirm in time_all), so restore the defaults
        # even if timing fails part-way through.
        try:
            times = [self.time_all(self.get_args(varying_param, v))
                     for v in varying]
        finally:
            default = RC_DEFAULTS['decoder_cache', 'enabled']
            rc.set("decoder_cache", "enabled", str(default))
            default = RC_DEFAULTS['decoder_cache', 'readonly']
            rc.set("decoder_cache", "readonly", str(default))

        # The swept values are the same for every configuration; record once.
        analytics.add_data(varying_param, varying, axis_label)
        for i, data in enumerate(zip(*times)):
            plt.plot(varying, np.median(data, axis=1), label=self.labels[i])
            analytics.add_data(self.keys[i], data)

        plt.xlabel("Number of %s" % axis_label)
        plt.ylabel("Build time (s)")
        plt.legend(loc='best')
示例#4
0
def test_dtype(Simulator, request, seed, bits):
    """Check that signals and built data use the requested precision.

    ``bits`` (e.g. "32") is written to the RC ``precision.bits`` option, and
    every signal must then be ``float<bits>`` or ``int<bits>`` while every
    ndarray in ``sim.data`` (and the probe data) must be ``float<bits>``.
    """
    # Ensure dtype is set back to default after the test, even if it fails
    request.addfinalizer(lambda: rc.set("precision", "bits", "64"))

    float_dtype = np.dtype(getattr(np, "float%s" % bits))
    int_dtype = np.dtype(getattr(np, "int%s" % bits))

    with nengo.Network() as model:
        u = nengo.Node([0.5, -0.4])
        a = nengo.Ensemble(10, 2)
        nengo.Connection(u, a)
        p = nengo.Probe(a)

    # RC option values should be strings (a ConfigParser-based RC rejects
    # anything else), so coerce in case ``bits`` is parametrized as an int.
    rc.set("precision", "bits", str(bits))
    with Simulator(model) as sim:
        sim.step()

        for k, v in sim.signals.items():
            assert v.dtype in (float_dtype,
                               int_dtype), "Signal '%s' wrong dtype" % k

        objs = (obj for obj in model.all_objects if sim.data[obj] is not None)
        for obj in objs:
            for x in (x for x in sim.data[obj] if isinstance(x, np.ndarray)):
                assert x.dtype == float_dtype, obj

        assert sim.data[p].dtype == float_dtype
示例#5
0
def pytest_runtest_setup(item):
    """Reset RC state, then skip or xfail tests listed as unsupported.

    The ``nengo_test_unsupported`` ini option is a list of lines that are
    joined and split shell-style (quoted strings preserved) into alternating
    ``pattern reason`` tokens. A test whose name (from ``get_item_name``)
    matches ``'*' + pattern`` is marked xfail when
    ``TestConfig.run_unsupported`` is true and skipped otherwise.

    Raises
    ------
    ValueError
        If ``nengo_test_unsupported`` has an odd number of tokens, i.e. some
        test pattern is missing its reason.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    item_name = get_item_name(item)

    # join all the lines and then split (preserving quoted strings)
    tokens = shlex.split(" ".join(
        item.config.getini("nengo_test_unsupported")))
    if len(tokens) % 2 != 0:
        raise ValueError(
            "nengo_test_unsupported must list (test pattern, reason) pairs, "
            "but has an odd number of entries (%d); last entry: %r"
            % (len(tokens), tokens[-1]))
    # group pairs (representing testname + reason)
    unsupported = zip(tokens[::2], tokens[1::2])

    for test, reason in unsupported:
        # wrap square brackets to interpret them literally
        # (see https://docs.python.org/3/library/fnmatch.html)
        test = "".join("[%s]" % c if c in ('[', ']') else c for c in test)

        # We add a '*' before test to eliminate the surprise of needing
        # a '*' before the name of a test function.
        if fnmatch(item_name, "*" + test):
            if TestConfig.run_unsupported:
                item.add_marker(pytest.mark.xfail(reason=reason))
            else:
                pytest.skip(reason)
示例#6
0
文件: conftest.py 项目: nengo/nengo
def pytest_runtest_setup(item):
    """Reset RC state, then skip or xfail tests listed as unsupported.

    The ``nengo_test_unsupported`` ini option is a list of lines that are
    joined and split shell-style (quoted strings preserved) into alternating
    ``pattern reason`` tokens. A test whose name (from ``get_item_name``)
    matches ``'*' + pattern`` is marked xfail when
    ``TestConfig.run_unsupported`` is true and skipped otherwise.

    Raises
    ------
    ValueError
        If ``nengo_test_unsupported`` has an odd number of tokens, i.e. some
        test pattern is missing its reason.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    item_name = get_item_name(item)

    # join all the lines and then split (preserving quoted strings)
    tokens = shlex.split(
        " ".join(item.config.getini("nengo_test_unsupported")))
    if len(tokens) % 2 != 0:
        raise ValueError(
            "nengo_test_unsupported must list (test pattern, reason) pairs, "
            "but has an odd number of entries (%d); last entry: %r"
            % (len(tokens), tokens[-1]))
    # group pairs (representing testname + reason)
    unsupported = zip(tokens[::2], tokens[1::2])

    for test, reason in unsupported:
        # wrap square brackets to interpret them literally
        # (see https://docs.python.org/3/library/fnmatch.html)
        test = "".join("[%s]" % c if c in ('[', ']') else c for c in test)

        # We add a '*' before test to eliminate the surprise of needing
        # a '*' before the name of a test function.
        if fnmatch(item_name, "*" + test):
            if TestConfig.run_unsupported:
                item.add_marker(pytest.mark.xfail(reason=reason))
            else:
                pytest.skip(reason)
示例#7
0
文件: ipynb.py 项目: tuchang/nengo
def load_ipython_extension(ipython):
    """Hook run by IPython when this notebook extension is loaded.

    On IPython 5+ the extension is unnecessary, so only a warning is issued.
    On older IPython, if notebook widgets are available and the progress bar
    setting is still 'auto', a deprecation warning is issued, the widget
    frontend is loaded, and the RC progress bar is pointed at this module's
    ``IPython2ProgressBar``.
    """
    ipython_major = IPython.version_info[0]
    if ipython_major >= 5:
        warnings.warn(
            "Loading the nengo.ipynb notebook extension is no longer "
            "required. Progress bars are automatically activated for IPython "
            "version 5 and later.")
        return

    widget_usable = (has_ipynb_widgets()
                     and rc.get('progress', 'progress_bar') == 'auto')
    if not widget_usable:
        return

    warnings.warn(
        "The nengo.ipynb notebook extension is deprecated. Please upgrade "
        "to IPython version 5 or later.")
    IPythonProgressWidget.load_frontend(ipython)
    bar_path = '.'.join((__name__, IPython2ProgressBar.__name__))
    rc.set('progress', 'progress_bar', bar_path)
示例#8
0
文件: ipynb.py 项目: nengo/nengo
def load_ipython_extension(ipython):
    """Hook run by IPython when this notebook extension is loaded.

    On IPython 5+ the extension is unnecessary, so only a warning is issued.
    On older IPython, if notebook widgets are available and the progress bar
    setting is still 'auto', a deprecation warning is issued, the widget
    frontend is loaded, and the RC progress bar is pointed at this module's
    ``IPython2ProgressBar``.
    """
    if IPython.version_info[0] >= 5:
        warnings.warn(
            "Loading the nengo.ipynb notebook extension is no longer "
            "required. Progress bars are automatically activated for IPython "
            "version 5 and later.")
        return

    if not has_ipynb_widgets():
        return
    if rc.get('progress', 'progress_bar') != 'auto':
        return

    warnings.warn(
        "The nengo.ipynb notebook extension is deprecated. Please upgrade "
        "to IPython version 5 or later.")
    IPythonProgressWidget.load_frontend(ipython)
    widget_bar = '.'.join((__name__, IPython2ProgressBar.__name__))
    rc.set('progress', 'progress_bar', widget_bar)
示例#9
0
def pytest_runtest_setup(item):  # noqa: C901
    """Reset RC state, then skip or xfail ``item`` according to the config.

    Skip conditions are checked in order and the first match wins:
    examples or slow tests not requested, compare tests not requested,
    frontend tests while frontend tests are being skipped, and
    RefSimulator/Simulator tests whose simulator is not overridden while
    frontend tests are being skipped. Failing those, ``noassertions`` tests
    are skipped if any output fixture they use lacks its option. Finally,
    backend tests matching ``TestConfig.Simulator.unsupported`` are xfailed.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    if not hasattr(item, 'obj'):
        return  # Occurs for doctests, possibly other weird tests

    conf = item.config
    test_fn = item.obj
    uses_compare = getattr(test_fn, 'compare', None) is not None
    fixtures = item.fixturenames
    uses_sim = 'Simulator' in fixtures
    uses_refsim = 'RefSimulator' in fixtures
    is_frontend_test = not uses_sim and not uses_refsim

    # pytest.skip raises, so the first satisfied condition ends the hook.
    if getattr(test_fn, 'example', None) and not conf.getvalue('noexamples'):
        pytest.skip("examples not requested")
    if getattr(test_fn, 'slow', None) and not conf.getvalue('slow'):
        pytest.skip("slow tests not requested")
    if not TestConfig.compare_requested and uses_compare:
        pytest.skip("compare tests not requested")
    if TestConfig.is_skipping_frontend_tests() and is_frontend_test:
        pytest.skip("frontend tests not run for alternate backends")
    if (TestConfig.is_skipping_frontend_tests() and uses_refsim
            and not TestConfig.is_refsim_overridden()):
        pytest.skip("RefSimulator not overridden")
    if (TestConfig.is_skipping_frontend_tests() and uses_sim
            and not TestConfig.is_sim_overridden()):
        pytest.skip("Simulator not overridden")
    if getattr(test_fn, 'noassertions', None):
        fixture_options = (('analytics', 'analytics'),
                           ('plt', 'plots'),
                           ('logger', 'logs'))
        missing = [option for fixture, option in fixture_options
                   if fixture in fixtures and not conf.getvalue(option)]
        if missing:
            pytest.skip("%s not requested" % " and ".join(missing))

    if is_frontend_test:
        return

    item_name = get_item_name(item)
    for pattern, reason in TestConfig.Simulator.unsupported:
        # A leading '*' means patterns need not start with a wildcard to
        # match the name of a test function.
        if fnmatch(item_name, '*' + pattern):
            pytest.xfail(reason)
示例#10
0
文件: conftest.py 项目: tuchang/nengo
def pytest_runtest_setup(item):  # noqa: C901
    """Reset RC state, then skip or xfail ``item`` according to the config.

    Skip conditions are checked in order and the first match wins:
    examples or slow tests not requested, compare tests not requested,
    frontend tests while frontend tests are being skipped, and
    RefSimulator/Simulator tests whose simulator is not overridden while
    frontend tests are being skipped. Failing those, ``noassertions`` tests
    are skipped if any output fixture they use lacks its option. Finally,
    backend tests matching ``TestConfig.Simulator.unsupported`` are xfailed.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    if not hasattr(item, 'obj'):
        return  # Occurs for doctests, possibly other weird tests

    conf = item.config
    test_fn = item.obj
    uses_compare = getattr(test_fn, 'compare', None) is not None
    fixtures = item.fixturenames
    uses_sim = 'Simulator' in fixtures
    uses_refsim = 'RefSimulator' in fixtures
    is_frontend_test = not uses_sim and not uses_refsim

    # pytest.skip raises, so the first satisfied condition ends the hook.
    if getattr(test_fn, 'example', None) and not conf.getvalue('noexamples'):
        pytest.skip("examples not requested")
    if getattr(test_fn, 'slow', None) and not conf.getvalue('slow'):
        pytest.skip("slow tests not requested")
    if not TestConfig.compare_requested and uses_compare:
        pytest.skip("compare tests not requested")
    if TestConfig.is_skipping_frontend_tests() and is_frontend_test:
        pytest.skip("frontend tests not run for alternate backends")
    if (TestConfig.is_skipping_frontend_tests() and uses_refsim
            and not TestConfig.is_refsim_overridden()):
        pytest.skip("RefSimulator not overridden")
    if (TestConfig.is_skipping_frontend_tests() and uses_sim
            and not TestConfig.is_sim_overridden()):
        pytest.skip("Simulator not overridden")
    if getattr(test_fn, 'noassertions', None):
        fixture_options = (('analytics', 'analytics'),
                           ('plt', 'plots'),
                           ('logger', 'logs'))
        missing = [option for fixture, option in fixture_options
                   if fixture in fixtures and not conf.getvalue(option)]
        if missing:
            pytest.skip("%s not requested" % " and ".join(missing))

    if is_frontend_test:
        return

    item_name = get_item_name(item)
    for pattern, reason in TestConfig.Simulator.unsupported:
        # A leading '*' means patterns need not start with a wildcard to
        # match the name of a test function.
        if fnmatch(item_name, '*' + pattern):
            pytest.xfail(reason)
示例#11
0
def pytest_runtest_setup(item):  # noqa: C901
    """Reset RC state, then xfail backend tests listed as unsupported.

    Only items that use the ``Simulator`` or ``RefSimulator`` fixture are
    checked against ``TestConfig.Simulator.unsupported``.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    # Doctests (and possibly other unusual items) have no ``obj``.
    if not hasattr(item, 'obj'):
        return

    fixtures = item.fixturenames
    if 'Simulator' not in fixtures and 'RefSimulator' not in fixtures:
        return  # frontend-only test: nothing backend-specific to check

    item_name = get_item_name(item)
    for pattern, reason in TestConfig.Simulator.unsupported:
        # Prefixing '*' spares users from writing a wildcard before the
        # test function name.
        if fnmatch(item_name, '*' + pattern):
            pytest.xfail(reason)
示例#12
0
def pytest_runtest_setup(item):  # noqa: C901
    """Reset RC state and decide whether ``item`` should run.

    - Tests marked ``example`` or ``slow`` are skipped unless the matching
      command-line option is enabled.
    - ``noassertions`` tests are skipped unless at least one of the output
      fixtures they use (analytics/plt/logger) has its option enabled; this
      includes the case where they use none of those fixtures.
    - Tests using the ``Simulator`` fixture whose node id matches an entry
      of ``TestConfig.Simulator.unsupported`` are marked xfail.
    """
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'False')
    rc.set('exceptions', 'simplified', 'False')

    # Items without an underlying test object get no further processing.
    if not hasattr(item, 'obj'):
        return

    test_obj = item.obj
    config = item.config

    gated_marks = (
        ('example', 'noexamples', "examples not requested"),
        ('slow', 'slow', "slow tests not requested"),
    )
    for mark_name, option_name, skip_msg in gated_marks:
        marked = getattr(test_obj, mark_name, None)
        if marked and not config.getvalue(option_name):
            pytest.skip(skip_msg)

    if getattr(test_obj, 'noassertions', None):
        output_fixtures = (
            ('analytics', 'analytics', "analytics not requested"),
            ('plt', 'plots', "plots not requested"),
            ('logger', 'logs', "logs not requested"),
        )
        reasons = []
        wanted = False
        for fixture_name, option_name, reason in output_fixtures:
            if fixture_name not in item.fixturenames:
                continue
            if config.getvalue(option_name):
                # One requested output is enough for the test to be useful.
                wanted = True
                break
            reasons.append(reason)
        if not wanted:
            pytest.skip(" and ".join(reasons))

    if 'Simulator' not in item.fixturenames:
        return
    for pattern, reason in TestConfig.Simulator.unsupported:
        # The leading '*' lets a pattern match without the user having to
        # write a wildcard before the name of the test function.
        if fnmatch(item.nodeid, '*' + pattern):
            pytest.xfail(reason)
示例#13
0
def test_validation_error(request):
    """Check ValidationError messages and tracebacks in both exception modes.

    With ``exceptions.simplified`` set to "False" the traceback is expected
    to contain nengo's internal frames down to where the error is raised;
    with "True" only this test's frame and the raising ``__init__`` frame
    are expected. ``check_tb_entries`` is defined elsewhere; it presumably
    compares each entry's function name and source statement, with a
    trailing ``...`` meaning a prefix match (TODO confirm).

    NOTE: the expected statements for ``test_validation_error`` are this
    test's own source lines, so editing those lines requires updating the
    expectations below.
    """
    # Ensure settings are set back to default after the test, even if it fails
    # NOTE(review): the finalizer uses ``rc`` while the body uses
    # ``nengo.rc``; presumably the same object — confirm.
    request.addfinalizer(
        lambda: rc.set("exceptions", "simplified",
                       str(RC_DEFAULTS["exceptions"]["simplified"])))

    # Full (non-simplified) tracebacks first.
    nengo.rc["exceptions"]["simplified"] = "False"

    # The Ensemble is created inside a Network context.
    with nengo.Network():
        with pytest.raises(ValidationError) as excinfo:
            nengo.Ensemble(n_neurons=0, dimensions=1)

    assert str(excinfo.value) == (
        "Ensemble.n_neurons: Value must be greater than or equal to 1 (got 0)")
    check_tb_entries(
        excinfo.traceback,
        [
            ("test_validation_error",
             "nengo.Ensemble(n_neurons=0, dimensions=1)"),
            ("__call__", "inst.__init__(*args, **kwargs)"),
            ("__init__", "self.n_neurons = n_neurons"),
            ("__setattr__", "super().__setattr__(name, val)"),
            ("__setattr__", "super().__setattr__(name, val)"),
            ("__set__", "self.data[instance] = self.coerce(instance, value)"),
            ("coerce", "return super().coerce(instance, num)"),
            ("coerce", "raise ValidationError(..."),
        ],
    )

    # Simplified tracebacks: internal frames are expected to be hidden.
    nengo.rc["exceptions"]["simplified"] = "True"

    with pytest.raises(ValidationError) as excinfo:
        nengo.dists.PDF(x=[1, 1], p=[0.1, 0.2])

    assert str(
        excinfo.value) == "PDF.p: PDF must sum to one (sums to 0.300000)"
    check_tb_entries(
        excinfo.traceback,
        [
            ("test_validation_error",
             "nengo.dists.PDF(x=[1, 1], p=[0.1, 0.2])"),
            ("__init__", "raise ValidationError(..."),
        ],
    )
示例#14
0
    def run(self, params):
        """Build and simulate one trial of the working-memory/decision model.

        Parameters
        ----------
        params : sequence
            ``(decision_type, drug_type, drug, trial, seed, P)`` where ``P``
            is a dict of model parameters (``'dt'``, ``'t_cue'``, ...).

        Returns
        -------
        list
            ``[df_primary, df_firing, abc, xyz]``: the dataframes built by
            ``primary_dataframe`` and ``firing_dataframe``, the absolute
            values probed from the accuracy sensor, and the sampled spikes
            of the working-memory ensemble.

        Raises
        ------
        ValueError
            If ``decision_type`` is neither ``'default'`` nor
            ``'basal_ganglia'``.
        """
        self.decision_type = params[0]
        self.drug_type = params[1]
        self.drug = params[2]
        self.trial = params[3]
        self.seed = params[4]
        self.P = params[5]
        self.dt = self.P['dt']
        self.dt_sample = self.P['dt_sample']
        self.t_cue = self.P['t_cue']
        self.t_delay = self.P['t_delay']
        self.drug_effect_neural = self.P['drug_effect_neural']
        self.drug_effect_functional = self.P['drug_effect_functional']
        self.drug_effect_biophysical = self.P['drug_effect_biophysical']
        self.enc_min_cutoff = self.P['enc_min_cutoff']
        self.enc_max_cutoff = self.P['enc_max_cutoff']
        self.sigma_smoothing = self.P['sigma_smoothing']
        self.frac = self.P['frac']
        self.neurons_inputs = self.P['neurons_inputs']
        self.neurons_wm = self.P['neurons_wm']
        self.neurons_decide = self.P['neurons_decide']
        self.time_scale = self.P['time_scale']
        self.cue_scale = self.P['cue_scale']
        self.tau = self.P['tau']
        self.tau_wm = self.P['tau_wm']
        self.noise_wm = self.P['noise_wm']
        self.noise_decision = self.P['noise_decision']
        self.perceived = self.P['perceived']
        self.cues = self.P['cues']

        # Fail early: otherwise ``decision``/``wm_to_decision`` are never
        # defined and a NameError surfaces deep inside model construction.
        if self.decision_type not in ('default', 'basal_ganglia'):
            raise ValueError(
                "Unknown decision_type %r; expected 'default' or "
                "'basal_ganglia'" % (self.decision_type,))

        if self.drug_type == 'biophysical':
            # Biophysical drugs alter the built model via reset_gain_bias,
            # so don't try to remember old decoders.
            rc.set("decoder_cache", "enabled", "False")
        else:
            rc.set("decoder_cache", "enabled", "True")

        # First pass: build only the WM ensemble and its recurrent connection
        # (solved for full weights); degrade_synaptic_connection returns the
        # neuron-to-neuron transform used for the recurrence below.
        with nengo.Network(seed=self.net_seed) as model:
            wm = nengo.Ensemble(self.neurons_wm, 2,
                                neuron_type=self.neuron_type,
                                seed=self.ens_seed, label='Working Memory')
            self.wm_recurrent = nengo.Connection(
                wm, wm, synapse=self.tau_wm,
                function=self.wm_recurrent_function, seed=self.con_seed,
                solver=nengo.solvers.LstsqL2(weights=True))
        with nengo.Simulator(model, dt=self.dt, seed=self.sim_seed) as sim:
            pass
        weights = self.degrade_synaptic_connection(sim, wm)

        # Second pass: the full model.
        with nengo.Network(seed=self.net_seed) as model:
            cue = nengo.Node(output=self.cue_function, label='Cue')
            time_node = nengo.Node(output=self.time_function, label='Time')
            inputs = nengo.Ensemble(self.neurons_inputs, 2, seed=self.ens_seed,
                                    label='Input Neurons')
            noise_wm_node = nengo.Node(output=self.noise_bias_function,
                                       label='Noise injection (WM node)')
            noise_decision_node = nengo.Node(
                output=self.noise_decision_function,
                label='Noise injection (decision node)')
            wm = nengo.Ensemble(self.neurons_wm, 2,
                                neuron_type=self.neuron_type,
                                seed=self.ens_seed, label='Working Memory')
            cor = nengo.Ensemble(1, 1, neuron_type=nengo.Direct(),
                                 seed=self.ens_seed, label='Accuracy sensor')
            dbs = nengo.Node(output=self.DBS_function,
                             label='Deep Brain Stimulation Node', size_out=1)
            if self.decision_type == 'default':
                decision = nengo.Ensemble(self.neurons_decide, 2,
                                          seed=self.ens_seed,
                                          label='Decision Maker')
            elif self.decision_type == 'basal_ganglia':
                utilities = nengo.networks.EnsembleArray(
                    self.neurons_inputs, n_ensembles=2, seed=self.ens_seed,
                    label='Utility network')
                basal_ganglia = nengo.networks.BasalGanglia(
                    2, self.neurons_decide)
                decision = nengo.networks.EnsembleArray(
                    self.neurons_decide, n_ensembles=2,
                    intercepts=Uniform(0.2, 1), encoders=Uniform(1, 1),
                    seed=self.ens_seed,
                    label='Decision ensemble (Basal Ganglia')
                temp = nengo.Ensemble(self.neurons_decide, 2,
                                      neuron_type=self.neuron_type,
                                      seed=self.ens_seed)
                bias = nengo.Node([1] * 2, label='bias node')
            output = nengo.Ensemble(self.neurons_decide, 1,
                                    neuron_type=self.neuron_type,
                                    seed=self.ens_seed, label='Output')

            nengo.Connection(cue, inputs[0], synapse=None, seed=self.con_seed)
            nengo.Connection(time_node, inputs[1], synapse=None,
                             seed=self.con_seed)
            nengo.Connection(inputs, wm, synapse=self.tau_wm,
                             function=self.inputs_function, seed=self.con_seed)
            self.wm_recurrent = nengo.Connection(
                wm.neurons, wm.neurons, synapse=self.tau_wm,
                seed=self.con_seed, transform=weights)
            nengo.Connection(
                noise_wm_node, wm.neurons, synapse=self.tau_wm,
                transform=np.ones((self.neurons_wm, 1)) * self.tau_wm,
                seed=self.con_seed)

            if self.DBS:
                nengo.Connection(dbs, wm.neurons, synapse=0,
                                 seed=self.con_seed,
                                 transform=np.ones((self.neurons_wm, 1)))

            if self.decision_type == 'default':
                wm_to_decision = nengo.Connection(
                    wm[0], decision[0], synapse=self.tau, seed=self.con_seed)
                nengo.Connection(noise_decision_node, decision[1],
                                 synapse=None, seed=self.con_seed)
                nengo.Connection(decision, output,
                                 function=self.decision_function,
                                 seed=self.con_seed)
                nengo.Connection(decision, cor, synapse=0.025,
                                 function=self.f_dec, seed=self.con_seed)
            elif self.decision_type == 'basal_ganglia':
                wm_to_decision = nengo.Connection(
                    wm[0], utilities.input, synapse=self.tau,
                    function=self.BG_rescale, seed=self.con_seed)
                # Feed the utilities into the basal ganglia; without this the
                # utility network is a dead end and the BG only sees noise.
                nengo.Connection(utilities.output, basal_ganglia.input,
                                 synapse=None, seed=self.con_seed)
                nengo.Connection(basal_ganglia.output, decision.input,
                                 synapse=self.tau, seed=self.con_seed)
                nengo.Connection(noise_decision_node, basal_ganglia.input,
                                 synapse=None, seed=self.con_seed)
                nengo.Connection(bias, decision.input, synapse=self.tau,
                                 seed=self.con_seed)
                nengo.Connection(decision.input, decision.output,
                                 transform=(np.eye(2) - 1),
                                 synapse=self.tau / 2.0, seed=self.con_seed)
                nengo.Connection(decision.output, temp, seed=self.con_seed)
                nengo.Connection(temp, output,
                                 function=self.decision_function,
                                 seed=self.con_seed, synapse=None)
                nengo.Connection(temp, cor, synapse=0.2, seed=self.con_seed,
                                 function=self.f_dec)

            probe_wm = nengo.Probe(wm[0], synapse=0.024,
                                   sample_every=self.dt_sample)
            probe_spikes = nengo.Probe(wm.neurons, 'spikes',
                                       sample_every=self.dt_sample)
            probe_output = nengo.Probe(output, synapse=None,
                                       sample_every=self.dt_sample)
            p_cor = nengo.Probe(cor, synapse=None,
                                sample_every=self.dt_sample)

        print('Running trial %s...\n' % (self.trial + 1))
        with nengo.Simulator(model, dt=self.dt, seed=self.sim_seed) as sim:
            if self.drug_type == 'biophysical':
                # NOTE(review): if reset_gain_bias returns a new Simulator,
                # only the original one is closed by this with-block.
                sim = reset_gain_bias(self.P, model, sim, wm,
                                      self.wm_recurrent, wm_to_decision,
                                      self.drug)
            sim.run(self.t_cue + self.t_delay)
            xyz = sim.data[probe_spikes]
            abc = np.abs(sim.data[p_cor])
            print('Constructing Dataframes...')
            df_primary = primary_dataframe(self.P, sim, self.drug, self.trial,
                                           probe_wm, probe_output, self.day)
            df_firing = firing_dataframe(self.P, sim, self.drug, self.trial,
                                         sim.data[wm], probe_spikes, self.day)

        return [df_primary, df_firing, abc, xyz]
示例#15
0
def run(params):
    """Build and simulate one trial of the working-memory/decision model.

    All imports happen inside the function, as in the original; presumably
    so it can be shipped to worker processes (TODO confirm). Works under
    both Python 2 and 3.

    Parameters
    ----------
    params : sequence
        ``(decision_type, drug_type, drug, trial, seed, P)`` where ``P`` is
        a dict of model parameters (``'dt'``, ``'t_cue'``, ...).

    Returns
    -------
    list
        ``[df_primary, df_firing]`` as built by ``primary_dataframe`` and
        ``firing_dataframe``.

    Raises
    ------
    ValueError
        If ``decision_type`` is neither ``'default'`` nor ``'basal_ganglia'``.
    """
    import nengo
    from nengo.dists import Choice, Exponential, Uniform  # noqa: F401
    from nengo.rc import rc
    import numpy as np
    import pandas as pd  # noqa: F401
    from helper import (MySolver, reset_gain_bias, primary_dataframe,  # noqa: F401
                        firing_dataframe, get_correct)

    decision_type = params[0]
    drug_type = params[1]
    drug = params[2]
    trial = params[3]
    seed = params[4]
    P = params[5]
    dt = P['dt']
    dt_sample = P['dt_sample']
    t_cue = P['t_cue']
    t_delay = P['t_delay']
    drug_effect_neural = P['drug_effect_neural']
    drug_effect_functional = P['drug_effect_functional']
    drug_effect_biophysical = P['drug_effect_biophysical']
    enc_min_cutoff = P['enc_min_cutoff']
    enc_max_cutoff = P['enc_max_cutoff']
    sigma_smoothing = P['sigma_smoothing']
    frac = P['frac']
    neurons_inputs = P['neurons_inputs']
    neurons_wm = P['neurons_wm']
    neurons_decide = P['neurons_decide']
    time_scale = P['time_scale']
    cue_scale = P['cue_scale']
    tau = P['tau']
    tau_wm = P['tau_wm']
    noise_wm = P['noise_wm']
    noise_decision = P['noise_decision']
    perceived = P['perceived']
    cues = P['cues']

    # Fail early: otherwise ``decision``/``wm_to_decision`` are never defined
    # and a NameError surfaces deep inside model construction.
    if decision_type not in ('default', 'basal_ganglia'):
        raise ValueError("Unknown decision_type %r; expected 'default' or "
                         "'basal_ganglia'" % (decision_type,))

    if drug_type == 'biophysical':
        # don't try to remember old decoders
        rc.set("decoder_cache", "enabled", "False")
    else:
        rc.set("decoder_cache", "enabled", "True")

    def cue_function(t):
        if t < t_cue and perceived[trial] != 0:
            return cue_scale * cues[trial]
        return 0

    def time_function(t):
        if t > t_cue:
            return time_scale
        return 0

    def noise_bias_function(t):
        if drug_type == 'neural':
            return np.random.normal(drug_effect_neural[drug], noise_wm)
        return np.random.normal(0.0, noise_wm)

    def noise_decision_function(t):
        if decision_type == 'default':
            return np.random.normal(0.0, noise_decision)
        elif decision_type == 'basal_ganglia':
            return np.random.normal(0.0, noise_decision, size=2)

    def inputs_function(x):
        return x * tau_wm

    def wm_recurrent_function(x):
        if drug_type == 'functional':
            return x * drug_effect_functional[drug]
        return x

    def decision_function(x):
        output = 0.0
        if decision_type == 'default':
            value = x[0] + x[1]
            if value > 0.0:
                output = 1.0
            elif value < 0.0:
                output = -1.0
        elif decision_type == 'basal_ganglia':
            if x[0] > x[1]:
                output = 1.0
            elif x[0] < x[1]:
                output = -1.0
        return output

    def BG_rescale(x):
        """Map x in [-1, 1] to a 2-D utility with components in [0.4, 1].

        The first component grows with x, the second shrinks with it.
        """
        pos_x = 0.5 * (x + 1)
        return 0.4 + 0.6 * pos_x, 0.4 + 0.6 * (1 - pos_x)

    # Model definition
    with nengo.Network(seed=seed + trial) as model:

        # Ensembles
        cue = nengo.Node(output=cue_function)
        time_node = nengo.Node(output=time_function)
        inputs = nengo.Ensemble(neurons_inputs, 2)
        noise_wm_node = nengo.Node(output=noise_bias_function)
        noise_decision_node = nengo.Node(output=noise_decision_function)
        wm = nengo.Ensemble(neurons_wm, 2)
        if decision_type == 'default':
            decision = nengo.Ensemble(neurons_decide, 2)
        elif decision_type == 'basal_ganglia':
            utilities = nengo.networks.EnsembleArray(neurons_inputs,
                                                     n_ensembles=2)
            BG = nengo.networks.BasalGanglia(2, neurons_decide)
            decision = nengo.networks.EnsembleArray(
                neurons_decide, n_ensembles=2,
                intercepts=Uniform(0.2, 1), encoders=Uniform(1, 1))
            temp = nengo.Ensemble(neurons_decide, 2)
            bias = nengo.Node([1] * 2)
        output = nengo.Ensemble(neurons_decide, 1)

        # Connections
        nengo.Connection(cue, inputs[0], synapse=None)
        nengo.Connection(time_node, inputs[1], synapse=None)
        nengo.Connection(inputs, wm, synapse=tau_wm, function=inputs_function)
        wm_recurrent = nengo.Connection(wm, wm, synapse=tau_wm,
                                        function=wm_recurrent_function)
        nengo.Connection(noise_wm_node, wm.neurons, synapse=tau_wm,
                         transform=np.ones((neurons_wm, 1)) * tau_wm)
        if decision_type == 'default':
            wm_to_decision = nengo.Connection(wm[0], decision[0], synapse=tau)
            nengo.Connection(noise_decision_node, decision[1], synapse=None)
            nengo.Connection(decision, output, function=decision_function)
        elif decision_type == 'basal_ganglia':
            wm_to_decision = nengo.Connection(wm[0], utilities.input,
                                              synapse=tau,
                                              function=BG_rescale)
            nengo.Connection(utilities.output, BG.input, synapse=None)
            nengo.Connection(BG.output, decision.input, synapse=tau)
            # added external noise?
            nengo.Connection(noise_decision_node, BG.input, synapse=None)
            nengo.Connection(bias, decision.input, synapse=tau)
            nengo.Connection(decision.input, decision.output,
                             transform=(np.eye(2) - 1), synapse=tau / 2.0)
            nengo.Connection(decision.output, temp)
            nengo.Connection(temp, output, function=decision_function)

        # Probes
        probe_wm = nengo.Probe(wm[0], synapse=0.01, sample_every=dt_sample)
        probe_spikes = nengo.Probe(wm.neurons, 'spikes',
                                   sample_every=dt_sample)
        probe_output = nengo.Probe(output, synapse=None,
                                   sample_every=dt_sample)

    # Simulation
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('Running drug "%s", trial %s...' % (drug, trial + 1))
    with nengo.Simulator(model, dt=dt) as sim:
        if drug_type == 'biophysical':
            sim = reset_gain_bias(
                P, model, sim, wm, wm_recurrent, wm_to_decision, drug)
        sim.run(t_cue + t_delay)
        df_primary = primary_dataframe(P, sim, drug, trial, probe_wm,
                                       probe_output)
        df_firing = firing_dataframe(P, sim, drug, trial, sim.data[wm],
                                     probe_spikes)
    return [df_primary, df_firing]
示例#16
0
class TestCacheBenchmark(object):
    """Benchmark model build times with and without the decoder cache.

    Each configuration dict pairs an ``'rc'`` snippet, appended to the
    ``timeit`` setup code, with a ``'stmt'`` snippet that is timed. Because
    ``timeit`` executes setup code in the current interpreter, the ``'rc'``
    snippets change this process's global RC settings.
    """
    n_trials = 25

    setup = '''
import numpy as np
import nengo
import nengo.cache
from nengo.rc import rc

model = nengo.Network(seed=1)
with model:
    a = nengo.Ensemble({N}, dimensions={D}, n_eval_points={M})
    b = nengo.Ensemble({N}, dimensions={D}, n_eval_points={M})
    conn = nengo.Connection(a, b)
    '''

    without_cache = {
        'rc': 'rc.set("decoder_cache", "enabled", "False")',
        'stmt': 'sim = nengo.Simulator(model)'
    }

    with_cache_miss_ro = {
        'rc': '''
rc.set("decoder_cache", "enabled", "True")
rc.set("decoder_cache", "readonly", "True")
''',
        'stmt': '''
nengo.cache.DecoderCache().invalidate()
sim = nengo.Simulator(model)
'''
    }

    with_cache_miss = {
        'rc': '''
rc.set("decoder_cache", "enabled", "True")
rc.set("decoder_cache", "readonly", "False")
''',
        'stmt': '''
nengo.cache.DecoderCache().invalidate()
sim = nengo.Simulator(model)
'''
    }

    # The setup builds the model once so the timed build hits the cache.
    with_cache_hit = {
        'rc': '''
rc.set("decoder_cache", "enabled", "True")
rc.set("decoder_cache", "readonly", "False")
sim = nengo.Simulator(model)
''',
        'stmt': 'sim = nengo.Simulator(model)'
    }

    # Order must match the tuple returned by time_all.
    labels = ["no cache", "cache miss", "cache miss ro", "cache hit"]
    keys = [label.replace(' ', '_') for label in labels]
    param_to_axis_label = {
        'D': "dimensions",
        'N': "neurons",
        'M': "evaluation points"
    }
    defaults = {'D': 1, 'N': 50, 'M': 1000}

    def time_code(self, code, args):
        """Time ``code['stmt']`` ``n_trials`` times, one run per repeat."""
        return timeit.repeat(
            stmt=code['stmt'], setup=self.setup.format(**args) + code['rc'],
            number=1, repeat=self.n_trials)

    def time_all(self, args):
        """Time every cache configuration for the model sized by ``args``."""
        return (self.time_code(self.without_cache, args),
                self.time_code(self.with_cache_miss, args),
                self.time_code(self.with_cache_miss_ro, args),
                self.time_code(self.with_cache_hit, args))

    def get_args(self, varying_param, value):
        """Return a copy of ``defaults`` with ``varying_param`` set."""
        args = dict(self.defaults)  # make a copy
        args[varying_param] = value
        return args

    @pytest.mark.slow
    @pytest.mark.noassertions
    @pytest.mark.parametrize('varying_param', ['D', 'N', 'M'])
    def test_cache_benchmark(self, varying_param, analytics, plt):
        """Sweep one model parameter and record build times per config."""
        varying = {
            'D': np.asarray(np.linspace(1, 512, 10), dtype=int),
            'N': np.asarray(np.linspace(10, 500, 8), dtype=int),
            'M': np.asarray(np.linspace(750, 2500, 8), dtype=int)
        }[varying_param]
        axis_label = self.param_to_axis_label[varying_param]

        # The 'rc' setup snippets modify the global RC, so restore the
        # defaults even if timing fails part-way through.
        try:
            times = [self.time_all(self.get_args(varying_param, v))
                     for v in varying]
        finally:
            default = RC_DEFAULTS['decoder_cache', 'enabled']
            rc.set("decoder_cache", "enabled", str(default))
            default = RC_DEFAULTS['decoder_cache', 'readonly']
            rc.set("decoder_cache", "readonly", str(default))

        # The swept values are the same for every configuration; record once.
        analytics.add_data(varying_param, varying, axis_label)
        for i, data in enumerate(zip(*times)):
            plt.plot(varying, np.median(data, axis=1), label=self.labels[i])
            analytics.add_data(self.keys[i], data)

        plt.xlabel("Number of %s" % axis_label)
        plt.ylabel("Build time (s)")
        plt.legend(loc='best')

    @staticmethod
    def reject_outliers(data):
        """Drop values far from the median of ``data`` (a NumPy array).

        Values are kept when they lie within 1.5 times the median-to-quartile
        distance on each side of the median. Bounds are inclusive: with
        strict bounds, values equal to the median were rejected whenever the
        median tied a quartile, and constant data became empty.
        """
        med = np.median(data)
        low, high = 1.5 * (np.percentile(data, [25, 75]) - med) + med
        return data[(data >= low) & (data <= high)]

    @pytest.mark.compare
    @pytest.mark.parametrize('varying_param', ['D', 'N', 'M'])
    def test_compare_cache_benchmark(self, varying_param, analytics_data, plt):
        """Compare build times of two recorded benchmark runs."""
        stats = pytest.importorskip('scipy.stats')

        d1, d2 = analytics_data
        assert np.all(d1[varying_param] == d2[varying_param]), (
            'Cannot compare different parametrizations')
        axis_label = self.param_to_axis_label[varying_param]

        print("Cache, varying {0}:".format(axis_label))
        for label, key in zip(self.labels, self.keys):
            clean_d1 = [self.reject_outliers(d) for d in d1[key]]
            clean_d2 = [self.reject_outliers(d) for d in d2[key]]
            diff = [np.median(b) - np.median(a)
                    for a, b in zip(clean_d1, clean_d2)]

            # Request the two-sided p-value explicitly. Old SciPy defaulted to
            # a one-sided value (hence the former doubling); newer SciPy is
            # two-sided by default, where doubling could push p above 1.
            p_values = np.array([
                stats.mannwhitneyu(a, b, alternative='two-sided')[1]
                for a, b in zip(clean_d1, clean_d2)])
            # Combine per-point p-values as 1 - prod(1 - p) (assumes the
            # individual tests are independent).
            overall_p = 1. - np.prod(1. - p_values)
            if overall_p < .05:
                print("  {label}: Significant change (p <= {p:.3f}). See plots"
                      " for details.".format(
                          label=label, p=np.ceil(overall_p * 1000.) / 1000.))
            else:
                print("  {label}: No significant change.".format(label=label))

            plt.plot(d1[varying_param], diff, label=label)

        plt.xlabel("Number of %s" % axis_label)
        plt.ylabel("Difference in build time (s)")
        plt.legend(loc='best')
示例#17
0
def load_ipython_extension(ipython):
    """Hook run by IPython when this notebook extension is loaded.

    Acts only when notebook widgets are available and the progress bar
    setting is still 'auto': the widget frontend is loaded and the RC
    progress bar is pointed at this module's ``IPython2ProgressBar``.
    """
    if not has_ipynb_widgets():
        return
    if rc.get('progress', 'progress_bar') != 'auto':
        return
    IPythonProgressWidget.load_frontend(ipython)
    bar_path = '.'.join((__name__, IPython2ProgressBar.__name__))
    rc.set('progress', 'progress_bar', bar_path)
示例#18
0
def load_ipython_extension(ipython):
    """Hook run by IPython when this notebook extension is loaded.

    Acts only when notebook widgets are available and the progress bar
    setting is still 'auto': the widget frontend is loaded and the RC
    progress bar is pointed at this module's ``IPython2ProgressBar``.
    """
    widgets_wanted = (has_ipynb_widgets()
                      and rc.get('progress', 'progress_bar') == 'auto')
    if widgets_wanted:
        IPythonProgressWidget.load_frontend(ipython)
        qualified = '.'.join((__name__, IPython2ProgressBar.__name__))
        rc.set('progress', 'progress_bar', qualified)
示例#19
0
def pytest_configure(config):
    """Start the test session with a freshly reloaded RC and no decoder cache."""
    rc.reload_rc([])
    rc.set("decoder_cache", "enabled", "false")
示例#20
0
文件: conftest.py 项目: epaxon/nengo
def pytest_configure(config):
    """Start the test session with a freshly reloaded RC and no decoder cache."""
    rc.reload_rc([])
    rc.set("decoder_cache", "enabled", "false")
示例#21
0
文件: conftest.py 项目: CamZHU/nengo
def pytest_configure(config):
    """Start the test session with a freshly reloaded RC and no decoder cache."""
    rc.reload_rc([])
    rc.set('decoder_cache', 'enabled', 'false')
示例#22
0
def pytest_runtest_setup(item):
    """Prepare RC settings before each test runs.

    Reloads the RC (``reload_rc([])``, presumably without any RC files),
    then disables the decoder cache, turns off simplified exceptions, and
    enables ``fail_fast`` for ``nengo.Simulator``.
    """
    rc.reload_rc([])
    rc.set("decoder_cache", "enabled", "False")
    rc.set("exceptions", "simplified", "False")
    rc.set("nengo.Simulator", "fail_fast", "True")