Exemplo n.º 1
0
def test_create_dot():
    """Constructing a .dot file for a model.

    A constant 3-D input feeds EnsembleArray ``a`` (which also has an
    all-to-all recurrent connection).  ``a`` feeds EnsembleArray ``b`` (which
    has a 0.9-scaled recurrent connection), and ``b`` feeds an output Node
    whose output is None.  The generated graph is checked by line count,
    once for the full model and once after pass-through Nodes are removed.
    """
    dims = 3
    tau = 0.01
    model = nengo.Network()
    with model:
        stim = nengo.Node([1] * dims, label='input')
        arr_a = nengo.networks.EnsembleArray(50, dims, label='a')
        arr_b = nengo.networks.EnsembleArray(50, dims, label='b')
        sink = nengo.Node(None, size_in=dims, label='output')

        nengo.Connection(stim, arr_a.input, synapse=tau)
        nengo.Connection(arr_a.output, arr_b.input, synapse=tau)
        nengo.Connection(arr_b.output, arr_b.input,
                         synapse=tau, transform=0.9)
        nengo.Connection(arr_a.output, arr_a.input,
                         synapse=tau, transform=np.ones((dims, dims)))
        nengo.Connection(arr_b.output, sink, synapse=tau)

    full_dot = generate_graphviz(*objs_and_connections(model))
    # Only the size of the output is checked; it is unclear what else to
    # assert about the generated graph.
    assert len(full_dot.splitlines()) == 31

    pruned = remove_passthrough_nodes(*objs_and_connections(model))
    pruned_dot = generate_graphviz(*pruned)
    assert len(pruned_dot.splitlines()) == 27
Exemplo n.º 2
0
def test_create_dot():
    """Constructing a .dot file for a model (legacy ``nengo.Model`` API).

    Builds the two-EnsembleArray model with ``filter=`` time constants.  It
    renders the model's objects and connections to a graph and checks the
    line count, both before and after pass-through Nodes are removed.
    """
    dims = 3
    tau = 0.01
    model = nengo.Model()
    with model:
        stim = nengo.Node([1] * dims, label='input')
        arr_a = nengo.networks.EnsembleArray(50, dims, label='a')
        arr_b = nengo.networks.EnsembleArray(50, dims, label='b')
        sink = nengo.Node(None, size_in=dims, label='output')

        nengo.Connection(stim, arr_a.input, filter=tau)
        nengo.Connection(arr_a.output, arr_b.input, filter=tau)
        nengo.Connection(arr_b.output, arr_b.input,
                         filter=tau, transform=0.9)
        nengo.Connection(arr_a.output, arr_a.input,
                         filter=tau, transform=np.ones((dims, dims)))
        nengo.Connection(arr_b.output, sink, filter=tau)

    full_dot = generate_graphviz(model.objs.values(),
                                 model.connections.values())
    # Only the size of the output is checked; it is unclear what else to
    # assert about the generated graph.
    assert len(full_dot.splitlines()) == 31

    kept_objs, kept_conns = remove_passthrough_nodes(
        model.objs.values(), model.connections.values())
    pruned_dot = generate_graphviz(kept_objs, kept_conns)
    assert len(pruned_dot.splitlines()) == 27
def test_create_host_network_nested():
    """Host network construction when a pass-through Node lives in a subnet.

    The pass-through Node inside the nested network must be removed.  The
    host network must keep the three function Nodes plus two extra I/O Nodes.
    """
    outer = nengo.Network()
    with outer:
        inner = nengo.Network()
        with inner:
            passthrough = nengo.Node(None, size_in=1, label='PassNode')
            ens = nengo.Ensemble(1, 1)
            inner_fn = nengo.Node(lambda t, v: v, size_in=1, label='n1')

            nengo.Connection(passthrough, ens, synapse=None)
            nengo.Connection(ens, inner_fn)

        source = nengo.Node(np.sin, label='input')
        outer_fn = nengo.Node(lambda t, v: v, size_in=1, size_out=1,
                              label='output')

        nengo.Connection(source, passthrough)
        host_conn = nengo.Connection(inner_fn, outer_fn)

    io_stub = mock.Mock()
    kept_objs, kept_conns = remove_passthrough_nodes(
        *objs_and_connections(outer))
    node_objs = [obj for obj in kept_objs if isinstance(obj, nengo.Node)]
    host_network = nodes.create_host_network(node_objs, kept_conns, io_stub)

    # Expect 5 nodes: n1, input, output, plus an input Node for n1 and an
    # output Node for the 'input' source.
    assert len(host_network.nodes) == 5
    assert passthrough not in host_network.nodes
    assert inner_fn in host_network.nodes
    assert source in host_network.nodes
    assert outer_fn in host_network.nodes

    assert len(host_network.connections) == 3
    assert host_conn in host_network.connections
Exemplo n.º 4
0
def test_remove_passthrough():
    """Removing Nodes with output=None from a two-EnsembleArray model.

    Once the pass-through Nodes are gone, 8 objects and 21 connections
    should remain.
    """
    def printout(t, x):
        print(t, x)

    dims = 3
    tau = 0.01
    model = nengo.Network()
    with model:
        stim = nengo.Node([1]*dims, label='input')
        arr_a = nengo.networks.EnsembleArray(50, dims, label='a')
        arr_b = nengo.networks.EnsembleArray(50, dims, label='b')
        sink = nengo.Node(printout, size_in=dims, label='output')

        nengo.Connection(stim, arr_a.input, synapse=tau)
        nengo.Connection(arr_a.output, arr_b.input, synapse=tau)
        nengo.Connection(arr_b.output, arr_b.input,
                         synapse=tau, transform=0.9)
        nengo.Connection(arr_a.output, arr_a.input,
                         synapse=tau, transform=np.ones((dims, dims)))
        nengo.Connection(arr_b.output, sink, synapse=tau)

    kept_objs, kept_conns = remove_passthrough_nodes(
        *objs_and_connections(model))

    assert len(kept_objs) == 8
    assert len(kept_conns) == 21
Exemplo n.º 5
0
def test_remove_passthrough():
    """Removing Nodes with output=None from a two-EnsembleArray model.

    The output Node logs every value it receives.  After pass-through
    removal, 8 objects and 21 connections should remain.
    """
    def printout(t, x):
        logging.info("%s, %s", t, x)

    dims = 3
    tau = 0.01
    model = nengo.Network()
    with model:
        stim = nengo.Node([1] * dims, label="input")
        arr_a = nengo.networks.EnsembleArray(50, dims, label="a")
        arr_b = nengo.networks.EnsembleArray(50, dims, label="b")
        sink = nengo.Node(printout, size_in=dims, label="output")

        nengo.Connection(stim, arr_a.input, synapse=tau)
        nengo.Connection(arr_a.output, arr_b.input, synapse=tau)
        nengo.Connection(arr_b.output, arr_b.input,
                         synapse=tau, transform=0.9)
        nengo.Connection(arr_a.output, arr_a.input,
                         synapse=tau, transform=np.ones((dims, dims)))
        nengo.Connection(arr_b.output, sink, synapse=tau)

    kept_objs, kept_conns = remove_passthrough_nodes(
        *objs_and_connections(model))

    assert len(kept_objs) == 8
    assert len(kept_conns) == 21
Exemplo n.º 6
0
def test_passthrough_errors():
    """Test errors removing Nodes with output=None.

    Two cases are checked:
    - A pass-through Node bridging two Ensembles must raise
      NotImplementedError.
    - A pass-through Node connected to itself must raise some exception.
    """
    # Case 1: Ensemble -> pass-through Node -> Ensemble.
    bridged = nengo.Network()
    with bridged:
        src = nengo.Ensemble(10, 1)
        dst = nengo.Ensemble(10, 1)
        relay = nengo.Node(None, size_in=1)
        nengo.Connection(src, relay, synapse=0.01)
        nengo.Connection(relay, dst, synapse=0.01)
    with pytest.raises(NotImplementedError):
        remove_passthrough_nodes(*objs_and_connections(bridged))

    # Case 2: a pass-through Node feeding back into itself.
    looped = nengo.Network()
    with looped:
        relay = nengo.Node(None, size_in=1)
        nengo.Connection(relay, relay, synapse=0.01)
    with pytest.raises(Exception):
        remove_passthrough_nodes(*objs_and_connections(looped))
Exemplo n.º 7
0
def test_passthrough_errors():
    """Test errors removing Nodes with output=None.

    Both unsupported topologies must raise ``Unconvertible``:
    - a pass-through Node bridging two Ensembles;
    - a pass-through Node connected to itself.
    """
    # Case 1: Ensemble -> pass-through Node -> Ensemble.
    bridged = nengo.Network()
    with bridged:
        src = nengo.Ensemble(10, 1)
        dst = nengo.Ensemble(10, 1)
        relay = nengo.Node(None, size_in=1)
        nengo.Connection(src, relay, synapse=0.01)
        nengo.Connection(relay, dst, synapse=0.01)
    with pytest.raises(Unconvertible):
        remove_passthrough_nodes(*objs_and_connections(bridged))

    # Case 2: a pass-through Node feeding back into itself.
    looped = nengo.Network()
    with looped:
        relay = nengo.Node(None, size_in=1)
        nengo.Connection(relay, relay, synapse=0.01)
    with pytest.raises(Unconvertible):
        remove_passthrough_nodes(*objs_and_connections(looped))
Exemplo n.º 8
0
def test_passthrough_errors():
    """Test errors removing Nodes with output=None (legacy ``nengo.Model``).

    Two cases are checked:
    - A pass-through Node bridging two Ensembles must raise
      NotImplementedError.
    - A pass-through Node connected to itself must raise some exception.
    """
    # Case 1: Ensemble -> pass-through Node -> Ensemble.
    bridged = nengo.Model()
    with bridged:
        src = nengo.Ensemble(10, 1)
        dst = nengo.Ensemble(10, 1)
        relay = nengo.Node(None, size_in=1)
        nengo.Connection(src, relay, filter=0.01)
        nengo.Connection(relay, dst, filter=0.01)
    with pytest.raises(NotImplementedError):
        remove_passthrough_nodes(bridged.objs, bridged.connections)

    # Case 2: a pass-through Node feeding back into itself.
    looped = nengo.Model()
    with looped:
        relay = nengo.Node(None, size_in=1)
        nengo.Connection(relay, relay, filter=0.01)
    with pytest.raises(Exception):
        remove_passthrough_nodes(looped.objs, looped.connections)
Exemplo n.º 9
0
def test_obj_conn_diagram():
    """Constructing a .dot file for a set of objects and connections.

    The network returned by ``make_net`` is rendered twice: once as-is
    (31 lines) and once with pass-through Nodes removed (27 lines).
    """
    net = make_net()
    all_objs = net.all_nodes + net.all_ensembles
    all_conns = net.all_connections

    full_diagram = obj_conn_diagram(all_objs, all_conns)
    assert len(full_diagram.splitlines()) == 31

    pruned_objs, pruned_conns = remove_passthrough_nodes(all_objs, all_conns)
    pruned_diagram = obj_conn_diagram(pruned_objs, pruned_conns)
    assert len(pruned_diagram.splitlines()) == 27
Exemplo n.º 10
0
def test_obj_conn_diagram():
    """Constructing a .dot file for a set of objects and connections.

    Checks the line count of the diagram for the ``make_net`` network.  It
    is checked first with every object and then with pass-through Nodes
    stripped out.
    """
    def line_count(text):
        return len(text.splitlines())

    net = make_net()
    objects = net.all_nodes + net.all_ensembles
    connections = net.all_connections

    assert line_count(obj_conn_diagram(objects, connections)) == 31

    reduced = remove_passthrough_nodes(objects, connections)
    assert line_count(obj_conn_diagram(*reduced)) == 27
Exemplo n.º 11
0
def test_create_host_network():
    """Test creating a network to simulate on the host.  All I/O connections
    will have been replaced with new Nodes which handle communication with the
    IO system.

    Topology: A -> B -> C -> D -> E, plus A -> N, B -> N and N -> D, where N
    is a Node with a ``spinnaker_build`` method.  There is also an
    unconnected "Orphan" Node.  Only B and C (and the I/O Nodes standing in
    for their off-host neighbours) are expected on the host.
    """
    class TestNode(nengo.Node):
        def spinnaker_build(self, builder):
            pass

    model = nengo.Network()
    with model:
        a = nengo.Ensemble(1, 1, label="A")
        b = nengo.Node(lambda t, v: v, size_in=1, size_out=1, label="B")
        c = nengo.Node(lambda t, v: v**2, size_in=1, size_out=1, label="C")
        d = nengo.Ensemble(1, 1, label="D")
        n = TestNode(output=lambda t, v: v, size_in=1, size_out=1, label="N")
        o = nengo.Node(lambda t, v: v, size_in=1, size_out=1, label="Orphan")
        e = nengo.Ensemble(1, 1, label="E")

        # The A->B, B->C and C->D connections are not inspected directly;
        # they are checked via the loop over host connections below.
        nengo.Connection(a, b)
        nengo.Connection(b, c)
        nengo.Connection(c, d)

        a_n = nengo.Connection(a, n)
        b_n = nengo.Connection(b, n)
        n_d = nengo.Connection(n, d)

        d_e = nengo.Connection(d, e)

    mock_io = mock.Mock()
    objs, conns = remove_passthrough_nodes(*objs_and_connections(model))
    # Use a distinct comprehension variable: under Python 2 a list
    # comprehension's loop variable leaks into the enclosing scope.  Writing
    # ``[n for n in objs ...]`` there would rebind the TestNode ``n`` to the
    # last object, so the ``n not in host_network.nodes`` check below would
    # test the wrong object.
    host_nodes = [obj for obj in objs if isinstance(obj, nengo.Node)]
    host_network = nodes.create_host_network(host_nodes, conns, mock_io)

    assert len(host_network.ensembles) == 0
    assert len(host_network.nodes) == 5  # b, c, (a->b), (c->d), (b->n)
    assert len(host_network.connections) == 4  # (a)->b, b->c, c->(d), b->(n)

    assert b in host_network.nodes
    assert c in host_network.nodes

    assert a_n not in host_network.connections
    assert b_n not in host_network.connections
    assert n_d not in host_network.connections
    assert d_e not in host_network.connections
    assert n not in host_network.nodes
    assert o not in host_network.nodes

    # Every connection touching an off-host object must have been rerouted
    # through an I/O Node whose ``output.node`` refers back to B or C.
    for conn in host_network.connections:
        if conn.post == b:
            assert conn.pre.output.node == b
        if conn.pre == b:
            assert conn.post == c or conn.post.output.node == b
        if conn.pre == c:
            assert conn.post.output.node == c
Exemplo n.º 12
0
def get_weight_matrices_requirements(network):
    """Estimate the memory needed to store full synaptic weight matrices.

    Pass-through Nodes are removed from ``network`` first.  Each Ensemble
    is charged an ``N_pre x N_post`` matrix, where ``N_pre`` is the total
    neuron count of the distinct Ensembles connecting into it.  A source
    Ensemble counts once, however many connections it makes to the target.

    Returns:
        The synapse count multiplied by ``BYTES_PER_SYNAPSE``.
    """
    objects, connections = remove_passthrough_nodes(
        *objs_and_connections(network))

    total_synapses = 0
    for post in objects:
        if not isinstance(post, nengo.Ensemble):
            continue
        # Deduplicate sources so parallel connections are not double-counted.
        unique_pres = {conn.pre for conn in connections
                       if conn.post is post
                       and isinstance(conn.pre, nengo.Ensemble)}
        n_pre_neurons = sum(pre.n_neurons for pre in unique_pres)
        total_synapses += n_pre_neurons * post.n_neurons

    return total_synapses * BYTES_PER_SYNAPSE
Exemplo n.º 13
0
def test_remove_passthrough_bg():
    """Removing Nodes with output=None from a model built around BasalGanglia.

    After removal, 17 objects and 42 connections should remain.
    """
    def printout(t, x):
        print(t, x)

    dims = 3
    model = nengo.Network()
    with model:
        stim = nengo.Node([1]*dims, label='input')
        sink = nengo.Node(printout, size_in=dims, label='output')
        basal = nengo.networks.BasalGanglia(dims, 20)
        nengo.Connection(stim, basal.input, synapse=0.01)
        nengo.Connection(basal.output, sink, synapse=0.01)

    kept_objs, kept_conns = remove_passthrough_nodes(
        *objs_and_connections(model))

    assert len(kept_objs) == 17
    assert len(kept_conns) == 42
def test_remove_passthrough_bg():
    """Removing Nodes with output=None from a labelled BasalGanglia model.

    After removal, 17 objects and 42 connections should remain.
    """
    def printout(t, x):
        print(t, x)

    dims = 3
    model = nengo.Network()
    with model:
        stim = nengo.Node([1]*dims, label='input')
        sink = nengo.Node(printout, size_in=dims, label='output')
        basal = nengo.networks.BasalGanglia(dims, 20, label='BG')
        nengo.Connection(stim, basal.input, synapse=0.01)
        nengo.Connection(basal.output, sink, synapse=0.01)

    kept_objs, kept_conns = remove_passthrough_nodes(
        *objs_and_connections(model))

    assert len(kept_objs) == 17
    assert len(kept_conns) == 42
Exemplo n.º 15
0
def get_factored_weight_matrices_requirements(network):
    """Estimate the memory needed for factored (encoder/decoder) weights.

    Pass-through Nodes are removed from ``network`` first.  Then, for every
    Ensemble:

    * Outgoing cost: ``(n_neurons + 1) * out_dims``.  ``out_dims`` counts
      the transform rows containing a non-zero entry, summed over the
      Ensemble's connections to other Ensembles.
    * Incoming cost: ``n_neurons * dimensions``.

    Returns:
        The total value count multiplied by ``BYTES_PER_ENC_DECODER``.
    """
    objects, connections = remove_passthrough_nodes(
        *objs_and_connections(network))

    total_values = 0
    for ens in (obj for obj in objects if isinstance(obj, nengo.Ensemble)):
        out_dims = 0
        for conn in connections:
            if conn.pre is not ens or not isinstance(conn.post,
                                                     nengo.Ensemble):
                continue
            transform = full_transform(conn, allow_scalars=False)
            # Rows that are entirely zero carry no output, so skip them.
            out_dims += np.sum(np.any(np.abs(transform) > 0., axis=1))

        total_values += (ens.n_neurons + 1) * out_dims
        total_values += ens.n_neurons * ens.dimensions

    return total_values * BYTES_PER_ENC_DECODER