def test_mark_signals():
    """Check the default trainability assigned by ``TensorGraph.mark_signals``.

    A small network is built with a LIF ensemble, a Direct-mode ensemble, a
    default ensemble, and three connections (no learning rule, PES, Voja).
    After ``mark_signals`` runs, specific encoder/bias/weight signals must be
    trainable and every other signal touched by an operator must not be.
    """
    with nengo.Network() as net:
        lif_ens = nengo.Ensemble(10, 1, neuron_type=nengo.LIF())
        direct_ens = nengo.Ensemble(20, 1, neuron_type=nengo.Direct())
        plain_ens = nengo.Ensemble(30, 1)
        static_conn = nengo.Connection(lif_ens, direct_ens)
        pes_conn = nengo.Connection(
            lif_ens, direct_ens, learning_rule_type=nengo.PES())
        voja_conn = nengo.Connection(
            lif_ens, plain_ens, learning_rule_type=nengo.Voja())
        nengo.Probe(plain_ens)

    model = nengo.builder.Model()
    model.build(net)

    graph = tensor_graph.TensorGraph(
        model, None, None, 1, None, utils.NullProgressBar(), None)
    graph.mark_signals()

    sigs = model.sig
    assert sigs[lif_ens]["encoders"].trainable
    assert sigs[direct_ens]["encoders"].trainable
    # plain_ens is the post ensemble of the Voja connection; its encoders are
    # expected to be left non-trainable
    assert not sigs[plain_ens]["encoders"].trainable
    assert sigs[lif_ens.neurons]["bias"].trainable
    assert sigs[plain_ens.neurons]["bias"].trainable
    assert sigs[static_conn]["weights"].trainable
    # the PES connection's weights are expected to be non-trainable
    assert not sigs[pes_conn]["weights"].trainable
    assert sigs[voja_conn]["weights"].trainable

    expected_trainable = (
        sigs[lif_ens]["encoders"],
        sigs[direct_ens]["encoders"],
        sigs[lif_ens.neurons]["bias"],
        sigs[plain_ens.neurons]["bias"],
        sigs[static_conn]["weights"],
        sigs[voja_conn]["weights"],
    )

    # every signal used by any operator is trainable iff it is listed above
    op_signals = (s for op in model.operators for s in op.all_signals)
    for signal in op_signals:
        if signal not in expected_trainable:
            assert not signal.trainable
        else:
            assert signal.trainable
def test_planner_config(config_planner):
    """Check that a planner configured on the network is used by TensorGraph.

    ``config_planner`` selects the set-up:

    - ``None``: the network config is left untouched;
    - ``False``: ``nengo.Network`` is made configurable but no planner is set;
    - ``True``: ``graph_optimizer.noop_planner`` is configured as the planner.

    Two identical ``DotInc`` operators are added to the model.  With the
    no-op planner the plan is expected to contain two entries, otherwise the
    two operators are expected to be merged into a single entry.
    """
    # NOTE(review): ``config_planner`` must be supplied by a
    # ``pytest.mark.parametrize`` decorator or fixture that is not visible in
    # this chunk -- confirm it exists.
    with nengo.Network() as net:
        if config_planner is not None:
            net.config.configures(nengo.Network)
            if config_planner:
                net.config[nengo.Network].set_param(
                    "planner", nengo.params.Parameter(
                        "planner", graph_optimizer.noop_planner))

    model = nengo.builder.Model()
    model.build(net)

    sig = nengo.builder.signal.Signal([1])
    sig2 = nengo.builder.signal.Signal([1])
    sig3 = nengo.builder.signal.Signal([1])
    model.add_op(nengo.builder.operator.DotInc(sig, sig2, sig3))
    model.add_op(nengo.builder.operator.DotInc(sig, sig2, sig3))

    # Use the same positional layout as every other TensorGraph call in this
    # file (minibatch size 1 in the fourth slot, progress bar in the sixth,
    # trailing None).  The previous call inserted ``tf.float32`` as the fourth
    # argument, shifting the minibatch size, device and progress bar into the
    # wrong slots.
    tg = tensor_graph.TensorGraph(
        model, None, None, 1, None, utils.NullProgressBar(), None)

    assert len(tg.plan) == (2 if config_planner else 1)
def test_mark_signals_config():
    """Check that ``trainable`` config settings are honoured by ``mark_signals``.

    Scenarios:

    1. Class-level, ``ens.neurons``-level, per-instance and per-subnetwork
       ``trainable`` settings, including inheritance through nested
       subnetworks.
    2. A PES connection manually marked trainable (``mark_signals`` is
       expected to emit a ``UserWarning``).
    3. Ensembles marked trainable via the class config while a Voja rule
       targets one of them (also expected to warn).
    4. A model with no toplevel network (expected to warn; its only signal
       must stay non-trainable).
    """
    progress = utils.NullProgressBar()

    def make_graph(model):
        # every scenario constructs the TensorGraph with the same arguments
        return tensor_graph.TensorGraph(
            model, None, None, 1, None, progress, None)

    def build_graph(network):
        # build ``network`` into a fresh Model and wrap it in a TensorGraph
        model = nengo.builder.Model()
        model.build(network)
        return model, make_graph(model)

    # --- scenario 1: nested config inheritance and overrides ---
    with nengo.Network() as net:
        config.configure_settings(trainable=None)
        net.config[nengo.Ensemble].trainable = False

        with nengo.Network():
            # an object in a subnetwork inherits the parent's config
            ens0 = nengo.Ensemble(10, 1, label="ens0")

            # ens.neurons can be configured independently of ens
            net.config[ens0.neurons].trainable = True

            with nengo.Network():
                with nengo.Network():
                    # a subnetwork can override its parent's config
                    config.configure_settings(trainable=True)
                    ens1 = nengo.Ensemble(10, 1, label="ens1")

                    with nengo.Network():
                        # a deeper subnetwork inherits the overridden
                        # trainable setting from its parent
                        ens3 = nengo.Ensemble(10, 1, label="ens3")

            # an instance can be configured independently of its class
            ens2 = nengo.Ensemble(10, 1, label="ens2")
            net.config[ens2].trainable = True

    model, tg = build_graph(net)
    tg.mark_signals()

    assert not model.sig[ens0]["encoders"].trainable
    assert model.sig[ens0.neurons]["bias"].trainable
    assert model.sig[ens1]["encoders"].trainable
    assert model.sig[ens2]["encoders"].trainable
    assert model.sig[ens3]["encoders"].trainable

    # --- scenario 2: a learning-rule connection manually set trainable ---
    with nengo.Network() as net:
        config.configure_settings(trainable=None)
        pre = nengo.Ensemble(10, 1)
        post = nengo.Ensemble(10, 1)
        pes_conn = nengo.Connection(
            pre, post, learning_rule_type=nengo.PES())
        net.config[pes_conn].trainable = True

    model, tg = build_graph(net)
    with pytest.warns(UserWarning):
        tg.mark_signals()
    assert model.sig[pes_conn]["weights"].trainable

    # --- scenario 3: Voja-targeted ensemble set trainable via class config ---
    with nengo.Network() as net:
        config.configure_settings(trainable=None)
        stim = nengo.Node([0])
        voja_ens = nengo.Ensemble(10, 1)
        nengo.Connection(stim, voja_ens, learning_rule_type=nengo.Voja())
        net.config[nengo.Ensemble].trainable = True

    model, tg = build_graph(net)
    with pytest.warns(UserWarning):
        tg.mark_signals()
    assert model.sig[voja_ens]["encoders"].trainable

    # --- scenario 4: a model with no toplevel network ---
    lone_sig = nengo.builder.signal.Signal([0])
    reset_op = nengo.builder.operator.Reset(lone_sig, 1)
    model = nengo.builder.Model()
    model.add_op(reset_op)

    tg = make_graph(model)
    with pytest.warns(UserWarning):
        tg.mark_signals()
    assert not lone_sig.trainable