Esempio n. 1
0
def _create_replacement_connection(c_in, c_out):
    """Generate a new Connection to replace two through a passthrough Node"""
    assert c_in.post_obj is c_out.pre_obj
    assert c_in.post_obj.output is None

    # determine the filter for the new Connection
    if c_in.synapse is None:
        synapse = c_out.synapse
    elif c_out.synapse is None:
        synapse = c_in.synapse
    else:
        raise NotImplementedError("Cannot merge two filters")
        # Note: the algorithm below is in the right ballpark,
        #  but isn't exactly the same as two low-pass filters
        # filter = c_out.filter + c_in.filter

    function = c_in.function
    if c_out.function is not None:
        raise Exception("Cannot remove a Node with a " "function being computed on it")

    # compute the combined transform
    transform = np.dot(full_transform(c_out), full_transform(c_in))

    # check if the transform is 0 (this happens a lot
    #  with things like identity transforms)
    if np.all(transform == 0):
        return None

    c = nengo.Connection(
        c_in.pre_obj, c_out.post_obj, synapse=synapse, transform=transform, function=function, add_to_container=False
    )
    return c
Esempio n. 2
0
def get_ensemble_sink(model, connection):
    """Get the sink for connections into an Ensemble.

    When the connection's source is a Node with a constant output (not
    callable, not a Process, not None), its contribution is added to the
    ensemble's ``direct_input`` and None is returned. Otherwise a spec is
    returned for the ensemble's learnt-input port, if the connection's
    learning rule modifies encoders, or for its standard input port.
    """
    ens = model.object_operators[connection.post_obj]
    source = connection.pre_obj

    is_constant_node = (
        isinstance(source, nengo.Node)
        and not callable(source.output)
        and not isinstance(source.output, Process)
        and source.output is not None
    )

    if not is_constant_node:
        rule = connection.learning_rule
        if (rule is not None
                and rule.learning_rule_type.modifies == "encoders"):
            # Encoder-modifying rules feed the learnt input port
            return spec(ObjectPort(ens, EnsembleInputPort.learnt))
        # Everything else goes to the standard input port
        return spec(ObjectPort(ens, InputPort.standard))

    # Constant-valued source: optimise the connection out by folding its
    # (sliced, optionally function-mapped, transformed) value into the
    # ensemble's direct input.
    value = source.output[connection.pre_slice]
    if connection.function is not None:
        value = connection.function(value)

    transform = full_transform(connection, slice_pre=False)
    ens.direct_input += np.dot(transform, value)
def build_from_ensemble_connection(model, conn):
    """Build the parameters object for a connection from an Ensemble.

    Solves for decoders, records a ``BuiltConnection`` (holding the full
    transform) in ``model.params[conn]`` and returns
    ``EnsembleTransmissionParameters``. When the connection targets neurons
    and every row of the transform is identical (global inhibition), the
    returned transform is collapsed to that single row.

    Raises
    ------
    NotImplementedError
        If the connection's solver produces weights rather than decoders.
    """
    if conn.solver.weights:
        raise NotImplementedError(
            "SpiNNaker does not currently support neuron to neuron connections"
        )

    rng = np.random.RandomState(model.seeds[conn])
    transform = full_transform(conn, slice_pre=False, allow_scalars=False)
    eval_points, decoders, solver_info = connection_b.build_decoders(
        model, conn, rng)

    model.params[conn] = BuiltConnection(
        decoders=decoders,
        eval_points=eval_points,
        transform=transform,
        solver_info=solver_info,
    )

    # Only the transform handed on for transmission is collapsed; the full
    # one has already been recorded above.
    to_neurons = isinstance(conn.post_obj, nengo.ensemble.Neurons)
    if to_neurons and np.all(transform[0, :] == transform[1:, :]):
        transform = np.array([transform[0]])

    return EnsembleTransmissionParameters(decoders, transform)
def build_from_ensemble_connection(model, conn):
    """Build the parameters object for a connection from an Ensemble.

    The decoders, evaluation points, full transform and solver information
    are stored as a ``BuiltConnection`` in ``model.params[conn]``. The value
    returned is ``EnsembleTransmissionParameters``; its transform is reduced
    to one row for a neuron-targeting connection whose rows are all equal.

    Raises
    ------
    NotImplementedError
        If ``conn.solver`` is a weight solver.
    """
    solver_produces_weights = conn.solver.weights
    if solver_produces_weights:
        raise NotImplementedError(
            "SpiNNaker does not currently support neuron to neuron connections"
        )

    seeded_rng = np.random.RandomState(model.seeds[conn])
    full_tf = full_transform(conn, slice_pre=False, allow_scalars=False)
    eval_points, decoders, solver_info = connection_b.build_decoders(
        model, conn, seeded_rng
    )

    model.params[conn] = BuiltConnection(
        decoders=decoders, eval_points=eval_points,
        transform=full_tf, solver_info=solver_info,
    )

    # Global inhibition: into neurons with every row identical -> keep one.
    tx_transform = full_tf
    if isinstance(conn.post_obj, nengo.ensemble.Neurons):
        if np.all(full_tf[0, :] == full_tf[1:, :]):
            tx_transform = np.array([full_tf[0]])

    return EnsembleTransmissionParameters(decoders, tx_transform)
Esempio n. 5
0
def get_ensemble_sink(model, connection):
    """Get the sink for connections into an Ensemble.

    Returns None, after adding the connection's contribution to the
    ensemble's ``direct_input``, when the source is a Node with a constant
    output. Otherwise returns a spec for the ensemble's learnt-input port
    (learning rule modifies encoders) or for its standard input port.
    """
    ens = model.object_operators[connection.post_obj]
    pre = connection.pre_obj

    constant_output = False
    if isinstance(pre, nengo.Node):
        node_output = pre.output
        constant_output = not (callable(node_output)
                               or isinstance(node_output, Process)
                               or node_output is None)

    if constant_output:
        # The connection is optimised out: its value is precomputed and
        # added to the ensemble's direct input, so no sink is needed.
        sliced = node_output[connection.pre_slice]
        func = connection.function
        if func is not None:
            sliced = func(sliced)
        ens.direct_input += np.dot(
            full_transform(connection, slice_pre=False), sliced)
        return None

    learning_rule = connection.learning_rule
    learns_encoders = (learning_rule is not None and
                       learning_rule.learning_rule_type.modifies == "encoders")
    port = EnsembleInputPort.learnt if learns_encoders else InputPort.standard
    return spec(ObjectPort(ens, port))
Esempio n. 6
0
def build_from_ensemble_connection(model, conn):
    """Build the parameters object for a connection from an Ensemble.

    Builds the linear system with Nengo's upstream helper, solves it with
    the model's cached solver and returns a ``BuiltConnection`` holding the
    decoders, evaluation points, full transform (excluding the pre-slice)
    and solver information.

    Raises
    ------
    NotImplementedError
        If the connection's solver produces weights rather than decoders.
        This is checked before any linear system is built or the decoder
        cache is touched.
    """
    # Reject unsupported weight solvers first, so the (potentially costly)
    # linear-system build and cache wrapping are not done for nothing.
    if conn.solver.weights:
        raise NotImplementedError(
            "SpiNNaker does not currently support neuron to neuron connections"
        )

    # Create a random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get the transform
    transform = full_transform(conn, slice_pre=False)

    # Use Nengo upstream to build parameters for the solver
    eval_points, activities, targets = connection_b.build_linear_system(
        model, conn, rng
    )

    # Use cached solver
    solver = model.decoder_cache.wrap_solver(conn.solver)
    decoders, solver_info = solver(activities, targets, rng=rng)

    # Return the parameters
    return BuiltConnection(
        decoders=decoders,
        eval_points=eval_points,
        transform=transform,
        solver_info=solver_info
    )
Esempio n. 7
0
def build_from_ensemble_connection(model, conn):
    """Build the parameters object for a connection from an Ensemble.

    Returns a ``BuiltConnection`` with the decoders, evaluation points and
    solver information from ``connection_b.build_decoders`` and the full
    transform (excluding the pre-slice).

    Raises
    ------
    NotImplementedError
        If the connection's solver produces weights rather than decoders.
    """
    if conn.solver.weights:
        raise NotImplementedError(
            "SpiNNaker does not currently support neuron to neuron connections"
        )

    seeded = np.random.RandomState(model.seeds[conn])
    tf = full_transform(conn, slice_pre=False)
    points, decs, info = connection_b.build_decoders(model, conn, seeded)

    return BuiltConnection(
        decoders=decs, eval_points=points, transform=tf, solver_info=info
    )
Esempio n. 8
0
    def mpi_build_network(model, network):
        """ Build a nengo Network for nengo_mpi.

        This function replaces nengo.builder.build_network.

        For each connection that emanates from a Node, has a non-None
        pre-slice, AND has no function attached to it, we replace it
        with a Connection that is functionally equivalent, but has
        the slicing moved into the transform. This is done because
        in some such cases, the refimpl nengo builder will implement the
        slicing using a python function, which we want to avoid in nengo_mpi.

        Also records which Connections have probes attached to them, by
        adding them to ``model.probed_connections``.

        Parameters
        ----------
        model: MpiModel
            The model to which created components will be added.
        network: nengo.Network
            The network to be built.

        Returns
        -------
        Whatever ``builder.build_network(model, network)`` returns.
        """
        remove_conns = []

        # Iterate over a snapshot: the ``with network:`` block below adds the
        # replacement Connections to ``network.connections``.
        for conn in list(network.connections):
            replace_connection = (
                isinstance(conn.pre_obj, Node)
                and conn.pre_slice != slice(None)
                and conn.function is None
            )
            if not replace_connection:
                continue

            # full_transform folds the pre/post slices into the matrix, so
            # the replacement can connect the unsliced objects directly.
            transform = full_transform(conn)

            with network:
                Connection(
                    conn.pre_obj,
                    conn.post_obj,
                    synapse=conn.synapse,
                    transform=transform,
                    solver=conn.solver,
                    learning_rule_type=conn.learning_rule_type,
                    eval_points=conn.eval_points,
                    scale_eval_points=conn.scale_eval_points,
                    seed=conn.seed,
                )

            remove_conns.append(conn)

        if remove_conns:
            # Materialise a list: on Python 3 ``filter`` returns a one-shot
            # iterator, which would be shared (and exhausted) by both
            # ``network.objects[Connection]`` and ``network.connections``.
            kept = [c for c in network.connections if c not in remove_conns]
            network.objects[Connection] = kept
            network.connections = network.objects[Connection]

        probed_connections = {
            probe.target for probe in network.probes
            if isinstance(probe.target, Connection)
        }
        model.probed_connections |= probed_connections

        return builder.build_network(model, network)
def build_nrn_connection(model, conn):
    """Build a connection from an Ensemble into compartmental (NEURON) cells.

    Connection weights are solved with ``conn.solver``, using the post
    ensemble's scaled encoders and the connection's full transform. One
    synapse is then created per weight on the matching post cell. Finally an
    ``NrnTransmitSpikes`` operator is added, built from the pre-synaptic
    output signal and the per-pre-neuron synapse lists.

    Non-negative weights are placed on the apical dendrite at a random
    position ``x`` and scaled by ``x + 1``. Negative weights are placed at
    the soma centre.

    Both the solver and the synapse placement draw from the RNG seeded by
    ``model.seeds[conn]``, so builds are reproducible.

    Raises
    ------
    NotImplementedError
        If ``conn.post`` is a ``Neurons`` object.
    AssertionError
        If the pre object is not a non-Direct Ensemble, or if the post
        neurons are not ``Compartmental``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Check pre-conditions
    assert isinstance(conn.pre, nengo.Ensemble)
    assert not isinstance(conn.pre.neuron_type, nengo.neurons.Direct)
    # FIXME assert no rate neurons are used. How to do that?

    # Get input signal
    # FIXME this should probably be
    # model.sig[conn]['in'] = model.sig[conn.pre]["out"]
    # in both cases
    # NOTE(review): given the Ensemble assert above, the Neurons branch looks
    # unreachable (assuming Neurons is not an Ensemble subclass) - confirm.
    if isinstance(conn.pre, nengo.ensemble.Neurons):
        model.sig[conn]["in"] = model.sig[conn.pre.ensemble]["out"]
    else:
        model.sig[conn]["in"] = model.sig[conn.pre]["out"]

    # Figure out type of connection
    if isinstance(conn.post, nengo.ensemble.Neurons):
        raise NotImplementedError()  # TODO
    elif isinstance(conn.post.neuron_type, Compartmental):
        pass
    else:
        raise AssertionError("This function should only be called if post neurons are " "compartmental.")

    # Solve for weights
    # FIXME just assuming solver is a weight solver, may that break?
    # Default solver should probably also produce sparse solutions for
    # performance reasons
    eval_points, activities, targets = build_linear_system(model, conn, rng=rng)

    # account for transform
    transform = full_transform(conn)
    targets = np.dot(targets, transform.T)

    weights, solver_info = conn.solver(activities, targets, rng=rng, E=model.params[conn.post].scaled_encoders.T)

    # Synapse type
    synapse = conn.synapse
    if is_number(synapse):
        synapse = ExpSyn(synapse)

    # Connect
    # TODO: Why is this adjustment of the weights necessary?
    weights = weights / synapse.tau / 5.0 * 0.1
    # One list of synapses per row of ``weights`` (indexed by pre neuron j
    # below); columns are indexed by post cell i.
    connections = [[] for _ in range(len(weights))]
    for i, cell in enumerate(ens_to_cells[conn.post]):
        for j, w in enumerate(weights[:, i]):
            if w >= 0.0:
                # Use the seeded RNG (not the global numpy one) so synapse
                # placement is reproducible for a given model seed.
                x = rng.rand()
                connections[j].append(synapse.create(cell.neuron.apical(x), w * (x + 1)))
            else:
                connections[j].append(synapse.create(cell.neuron.soma(0.5), w))

    # 3. Add operator creating events for synapses if pre neuron fired
    model.add_op(NrnTransmitSpikes(model.sig[conn]["in"], connections))
Esempio n. 10
0
def build_node_transmission_parameters(model, conn):
    """Build transmission parameters for a connection originating at a Node.

    For a passthrough Node (``output is None``), the pre-slice is folded
    into the full transform and ``PassthroughNodeTransmissionParameters`` is
    returned. For any other Node, the transform excludes the pre-slice, and
    the slice is returned alongside the function in
    ``NodeTransmissionParameters``.

    If the connection targets neurons and every row of the transform is
    identical, it is treated as global inhibition and reduced to one row.
    """
    is_passthrough = conn.pre_obj.output is None

    if is_passthrough:
        transform = full_transform(conn, allow_scalars=False)
    else:
        transform = full_transform(conn, slice_pre=False, allow_scalars=False)

    # Global inhibition: identical rows into neurons collapse to one row.
    if isinstance(conn.post_obj, nengo.ensemble.Neurons):
        if np.all(transform[0, :] == transform[1:, :]):
            transform = np.array([transform[0]])

    if is_passthrough:
        return PassthroughNodeTransmissionParameters(transform)
    return NodeTransmissionParameters(conn.pre_slice, conn.function,
                                      transform)
Esempio n. 11
0
def build_node_transmission_parameters(model, conn):
    """Build transmission parameters for a connection originating at a Node.

    Returns ``PassthroughNodeTransmissionParameters(transform)`` for a
    passthrough Node (``output is None``), where the transform includes the
    pre-slice. Otherwise returns ``NodeTransmissionParameters`` with the
    pre-slice, the function and a transform that excludes the pre-slice.

    A transform into neurons whose rows are all equal is shrunk to a single
    row (global inhibition).
    """
    from_passthrough = conn.pre_obj.output is None

    transform = (
        full_transform(conn, allow_scalars=False)
        if from_passthrough
        else full_transform(conn, slice_pre=False, allow_scalars=False)
    )

    to_neurons = isinstance(conn.post_obj, nengo.ensemble.Neurons)
    if to_neurons and np.all(transform[0, :] == transform[1:, :]):
        transform = np.array([transform[0]])

    if not from_passthrough:
        return NodeTransmissionParameters(conn.pre_slice, conn.function,
                                          transform)
    return PassthroughNodeTransmissionParameters(transform)
Esempio n. 12
0
def get_factored_weight_matrices_requirements(network):
    """Estimate the memory needed for factored (decoder/encoder) weights.

    Passthrough Nodes are removed first. For every Ensemble the count is:

    * outgoing: ``(n_neurons + 1) * out_d`` values, where ``out_d`` is the
      number of rows with any non-zero entry in the full transform of each
      connection to another Ensemble, summed over those connections;
    * incoming: ``n_neurons * dimensions`` values.

    Returns
    -------
    The total value count multiplied by ``BYTES_PER_ENC_DECODER``.
    """
    (objects, connections) = remove_passthrough_nodes(
        *objs_and_connections(network))

    mem_usage = 0
    for ens in (o for o in objects if isinstance(o, nengo.Ensemble)):
        # Match on the underlying objects (pre_obj/post_obj), not pre/post:
        # a sliced connection such as ``ens[0] -> other`` has an ObjView as
        # ``pre`` and would otherwise be silently left out of the estimate.
        # full_transform already expands the slices to full size.
        out_conns = [c for c in connections
                     if c.pre_obj is ens and
                     isinstance(c.post_obj, nengo.Ensemble)]

        # Outgoing cost is (n_neurons + 1) x out_d where out_d is the number of
        # non-zero rows in the transform matrix
        out_transforms = [full_transform(c, allow_scalars=False)
                          for c in out_conns]
        out_dims = sum(np.sum(np.any(np.abs(t) > 0., axis=1))
                       for t in out_transforms)
        mem_usage += (ens.n_neurons + 1) * out_dims

        # Incoming cost is just n_neurons x d
        mem_usage += ens.n_neurons * ens.dimensions

    return mem_usage * BYTES_PER_ENC_DECODER  # (4 bytes per value)
Esempio n. 13
0
def build_connection(model, conn):
    """Build a connection, handling connections into bioneuron ensembles.

    Connections whose post object is an Ensemble of ``BahlNeuron`` are built
    here. Decoders are obtained as though the presynaptic ensemble were
    connecting to a hypothetical LIF ensemble, and are turned into synaptic
    weights unless weights are supplied via ``conn.syn_weights`` or a
    learning node. One synapse object is created per (post neuron, pre
    neuron, synapse) and a ``TransmitSpikes`` operator is added to the
    model. All other connections are delegated to
    ``nengo.builder.connection.build_connection``.

    Raises
    ------
    BuildError
        If a connection from bioneurons has neither a NoSolver nor
        ``syn_weights``; if a connection into bioneurons does not come from
        a spiking Ensemble; or if ``conn.sec`` / ``conn.syn_type`` name an
        unsupported dendritic section or synapse type.
    """
    conn_pre = deref_objview(conn.pre)
    conn_post = deref_objview(conn.post)
    rng = np.random.RandomState(model.seeds[conn])

    if (isinstance(conn_pre, nengo.Ensemble)
            and isinstance(conn_pre.neuron_type, BahlNeuron)):
        # todo: generalize to custom online solvers
        if not isinstance(conn.solver, NoSolver) and conn.syn_weights is None:
            raise BuildError("Connections from bioneurons must provide a NoSolver or syn_weights"
                             " (got %s from %s to %s)" % (conn.solver, conn_pre, conn_post))

    if not (isinstance(conn_post, nengo.Ensemble)
            and isinstance(conn_post.neuron_type, BahlNeuron)):
        # normal connection
        return nengo.builder.connection.build_connection(model, conn)

    if not isinstance(conn_pre, nengo.Ensemble) or \
            'spikes' not in conn_pre.neuron_type.probeable:
        raise BuildError("May only connect spiking neurons (pre=%s) to "
                         "bioneurons (post=%s)" % (conn_pre, conn_post))

    # Validate section and synapse type up front; otherwise an unsupported
    # value leaves ``section``/``synapse`` unbound inside the loop below and
    # fails with an opaque UnboundLocalError.
    if conn.sec not in ('apical', 'tuft', 'basal'):
        raise BuildError("Unsupported synapse section %r for %s (expected "
                         "'apical', 'tuft' or 'basal')" % (conn.sec, conn))
    if conn.syn_type not in ('ExpSyn', 'Exp2Syn'):
        raise BuildError("Unsupported synapse type %r for %s (expected "
                         "'ExpSyn' or 'Exp2Syn')" % (conn.syn_type, conn))

    # Given a particular connection, labeled by conn.pre:
    # grab the initial decoders and generate locations for synapses, then either
    # (a) create synapses with weight equal to
    #     w_ij = np.dot(d_i, alpha_j * e_j) / n_syn  (further divided by k_norm), where
    #         - d_i is the initial decoder,
    #         - e_j is the single bioneuron encoder,
    #         - alpha_j is the single bioneuron gain,
    #         - n_syn normalizes total input current for multiple-synapse conns;
    # (b) or load synaptic weights from a prespecified matrix.
    # Add synapses with those weights to bioneuron.synapses, store this
    # initial synaptic weight matrix in conn.syn_weights, and finally call
    # neuron.init().
    if conn.syn_locs is None:
        conn.syn_locs = get_synaptic_locations(
            rng,
            conn_pre.n_neurons,
            conn_post.n_neurons,
            conn.n_syn)
    if conn.syn_weights is None:
        use_syn_weights = False
        conn.syn_weights = np.zeros((
            conn_post.n_neurons,
            conn_pre.n_neurons,
            conn.syn_locs.shape[2]))
    else:
        use_syn_weights = True
        conn.syn_weights = copy.copy(conn.syn_weights)
    if conn.learning_node is not None and hasattr(conn.learning_node, 'syn_encoders'):
        # initialize synaptic weights for EncoderNode learned connection
        use_syn_weights = True
        conn.syn_weights = np.array(conn.learning_node.update_weights())

    # Grab decoders from the specified solver (usually nengo.solvers.NoSolver(d))
    transform = full_transform(conn, slice_pre=False)
    eval_points, decoders, solver_info = build_decoders(
        model, conn, rng, transform)

    # normalize the area under the ExpSyn curve to compensate for effect of tau
    times = np.arange(0, 1.0, 0.001)
    k_norm = np.linalg.norm(np.exp((-times / conn.tau_list[0])), 1)

    # todo: synaptic gains and encoders
    neurons = model.params[conn_post.neurons]  # set in build_bioneurons
    for j, bahl in enumerate(neurons):
        assert isinstance(bahl, Bahl)
        loc = conn.syn_locs[j]
        encoder = conn_post.encoders[j]
        gain = conn_post.gain[j]
        # conn.sec was validated above, so this is bahl.cell.apical/.tuft/.basal
        place_on_section = getattr(bahl.cell, conn.sec)
        bahl.synapses[conn] = np.empty(
            (loc.shape[0], loc.shape[1]), dtype=object)
        for pre in range(loc.shape[0]):
            for syn in range(loc.shape[1]):
                section = place_on_section(loc[pre, syn])
                if use_syn_weights:  # syn_weights already specified
                    w_ij = conn.syn_weights[j, pre, syn]
                else:  # syn_weights should be set by dec_pre and bio encoders/gain
                    w_ij = np.dot(decoders.T[pre], gain * encoder)
                    w_ij = w_ij / conn.n_syn / k_norm
                    conn.syn_weights[j, pre, syn] = w_ij
                if conn.syn_type == 'ExpSyn':
                    tau = conn.tau_list[0]
                    synapse = ExpSyn(section, w_ij, tau, loc[pre, syn])
                else:  # 'Exp2Syn' (validated above)
                    assert len(conn.tau_list) == 2, 'Exp2Syn requires tau_rise, tau_fall'
                    tau1 = conn.tau_list[0]
                    tau2 = conn.tau_list[1]
                    synapse = Exp2Syn(section, w_ij, tau1, tau2, loc[pre, syn])
                bahl.synapses[conn][pre][syn] = synapse
    # presumably the NEURON simulator module (module-level import) - confirm
    neuron.init()

    model.add_op(TransmitSpikes(
        conn, conn_post, conn.learning_node, neurons,
        model.sig[conn_pre]['out'], states=[model.time]))
    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=conn.syn_weights)
Esempio n. 14
0
    def split_ensemble_array(self, array, n_parts):
        """
        Split an ensemble array into multiple functionally equivalent ensemble
        arrays, removing the old input/output connections and adding new ones
        through freshly created input and output nodes. Currently will not
        split ensemble arrays that have neuron output or input nodes, but
        there is no reason this could not be added in the future.

        Parameters
        ----------
        array: nengo.EnsembleArray
            The array to split

        n_parts: int
            Number of arrays to split ``array`` into

        Raises
        ------
        ValueError
            If ``array`` is not an EnsembleArray, or if the array has
            connections to/from its input/output nodes that this routine
            cannot account for.
        """
        # Validate the type first, so a non-EnsembleArray gets the intended
        # ValueError rather than an AttributeError on ``neuron_input``.
        if not isinstance(array, nengo.networks.EnsembleArray):
            raise ValueError("'array' must be an EnsembleArray")

        if array.neuron_input is not None or array.neuron_output is not None:
            # %-style args: ``array.label`` may be None, which would break
            # string concatenation.
            self.logger.info("Not splitting ensemble array %s because it has neuron nodes.", array.label)
            return

        if n_parts < 2:
            self.logger.info("Not splitting ensemble array because the " "desired number of parts is < 2.")
            return

        self.logger.info("+" * 80)
        self.logger.info("Splitting ensemble array %s into %d parts.", array.__repr__(), n_parts)

        inputs, outputs = self.inputs, self.outputs

        n_ensembles = array.n_ensembles
        D = array.dimensions_per_ensemble

        if n_ensembles != len(array.ea_ensembles):
            raise ValueError("Number of ensembles does not match")

        if len(array.all_connections) != n_ensembles * len(array.all_nodes):
            raise ValueError("Number of connections incorrect.")

        # assert no extra connections
        ea_ensemble_set = set(array.ea_ensembles)
        if len(outputs[array.input]) != n_ensembles or (set(c.post for c in outputs[array.input]) != ea_ensemble_set):
            raise ValueError("Extra connections from array input")

        # Snapshot the array's connections once and use O(1) set lookups
        # instead of re-reading ``array.all_connections`` per connection.
        connection_set = set(array.all_connections)

        extra_inputs = set()
        extra_outputs = set()

        for conn in self.top_level_network.all_connections:
            if conn in connection_set:
                continue

            if conn.pre_obj in ea_ensemble_set:
                extra_outputs.add(conn)

            if conn.post_obj in ea_ensemble_set:
                extra_inputs.add(conn)

        for conn in extra_inputs:
            self.logger.info("\n" + "*" * 20)
            self.logger.info("Extra input connector: %s", conn.pre_obj)
            self.logger.info("Synapse: %s", conn.synapse)
            self.logger.info("Inputs: ")

            for c in inputs[conn.pre_obj]:
                self.logger.info(c)

            self.logger.info("Outputs: ")

            for c in outputs[conn.pre_obj]:
                self.logger.info(c)

        for conn in extra_outputs:

            self.logger.info("\n" + "*" * 20)
            self.logger.info("Extra output connector: %s", conn.post_obj)
            self.logger.info("Synapse: %s", conn.synapse)
            self.logger.info("Inputs: ")

            for c in inputs[conn.post_obj]:
                self.logger.info(c)

            self.logger.info("Outputs: ")

            for c in outputs[conn.post_obj]:
                self.logger.info(c)

        output_nodes = [n for n in array.nodes if n.label[-5:] != "input"]
        assert len(output_nodes) > 0

        for output_node in output_nodes:
            extra_connections = set(c.pre for c in inputs[output_node]) != ea_ensemble_set
            if extra_connections:
                raise ValueError("Extra connections to array output")

        # Distribute ensembles round-robin over the parts, so part sizes
        # differ by at most one.
        sizes = np.zeros(n_parts, dtype=int)
        j = 0
        for _ in range(n_ensembles):
            sizes[j] += 1
            j = (j + 1) % len(sizes)

        # indices[i]:indices[i + 1] is the ensemble range of part i
        indices = np.zeros(len(sizes) + 1, dtype=int)
        indices[1:] = np.cumsum(sizes)

        self.logger.info("*" * 10 + "Fixing input connections")

        # make new input nodes
        with array:
            new_inputs = [
                nengo.Node(size_in=size * D, label="%s%d" % (array.input.label, i)) for i, size in enumerate(sizes)
            ]

        self.node_map[array.input] = new_inputs

        # remove connections involving old input node
        for conn in array.connections[:]:
            if conn.pre_obj is array.input and conn.post in array.ea_ensembles:
                array.connections.remove(conn)

        # make connections from new input nodes to ensembles
        for i, inp in enumerate(new_inputs):
            i0, i1 = indices[i], indices[i + 1]
            for j, ens in enumerate(array.ea_ensembles[i0:i1]):
                with array:
                    nengo.Connection(inp[j * D : (j + 1) * D], ens, synapse=None)

        # make connections into EnsembleArray
        for c_in in inputs[array.input]:

            # remove connection to old node
            self.logger.info("Removing connection from network: %s", c_in)

            pre_outputs = outputs[c_in.pre_obj]

            transform = full_transform(c_in, slice_pre=False, slice_post=True, allow_scalars=False)

            # make connections to new nodes
            for i, inp in enumerate(new_inputs):
                i0, i1 = indices[i], indices[i + 1]
                sub_transform = transform[i0 * D : i1 * D, :]

                if self.preserve_zero_conns or np.any(sub_transform):
                    containing_network = find_object_location(self.top_level_network, c_in)[-1]
                    assert containing_network, "Connection %s is not in network." % c_in

                    with containing_network:
                        new_conn = nengo.Connection(
                            c_in.pre, inp, synapse=c_in.synapse, function=c_in.function, transform=sub_transform
                        )

                    self.logger.info("Added connection: %s", new_conn)

                    inputs[inp].append(new_conn)
                    pre_outputs.append(new_conn)

            assert remove_from_network(self.top_level_network, c_in)
            pre_outputs.remove(c_in)

        # remove old input node
        array.nodes.remove(array.input)
        array.input = None

        self.logger.info("*" * 10 + "Fixing output connections")

        # loop over outputs
        for old_output in output_nodes:

            # Sizes must be in ``ea_ensembles`` order: they are sliced with
            # the same ``indices`` used on ``array.ea_ensembles`` below.
            output_sizes = []
            for ensemble in array.ea_ensembles:
                # ``filter(...)[0]`` fails on Python 3 (filter objects are not
                # subscriptable); the list comprehension keeps "first match".
                conn = [c for c in outputs[ensemble] if old_output.label in str(c.post)][0]
                output_sizes.append(conn.size_out)

            # make new output nodes
            new_outputs = []
            for i in range(n_parts):
                i0, i1 = indices[i], indices[i + 1]
                i_sizes = output_sizes[i0:i1]
                with array:
                    new_output = nengo.Node(size_in=sum(i_sizes), label="%s_%d" % (old_output.label, i))

                new_outputs.append(new_output)

                i_inds = np.zeros(len(i_sizes) + 1, dtype=int)
                i_inds[1:] = np.cumsum(i_sizes)

                # connect ensembles to new output node
                for j, e in enumerate(array.ea_ensembles[i0:i1]):
                    old_conns = [c for c in array.connections if c.pre is e and c.post_obj is old_output]
                    assert len(old_conns) == 1
                    old_conn = old_conns[0]

                    # remove old connection from ensembles
                    array.connections.remove(old_conn)

                    # add new connection from ensemble
                    j0, j1 = i_inds[j], i_inds[j + 1]
                    with array:
                        nengo.Connection(
                            e,
                            new_output[j0:j1],
                            synapse=old_conn.synapse,
                            function=old_conn.function,
                            transform=old_conn.transform,
                        )

            self.node_map[old_output] = new_outputs

            # connect new outputs to external model
            output_sizes = [n.size_out for n in new_outputs]
            output_inds = np.zeros(len(output_sizes) + 1, dtype=int)
            output_inds[1:] = np.cumsum(output_sizes)

            for c_out in outputs[old_output]:
                assert c_out.function is None

                # remove connection to old node
                self.logger.info("Removing connection from network: %s", c_out)

                transform = full_transform(c_out, slice_pre=True, slice_post=True, allow_scalars=False)

                post_inputs = inputs[c_out.post_obj]

                # add connections to new nodes
                for i, out in enumerate(new_outputs):
                    i0, i1 = output_inds[i], output_inds[i + 1]
                    sub_transform = transform[:, i0:i1]

                    if self.preserve_zero_conns or np.any(sub_transform):
                        containing_network = find_object_location(self.top_level_network, c_out)[-1]
                        assert containing_network, "Connection %s is not in network." % c_out

                        with containing_network:
                            new_conn = nengo.Connection(out, c_out.post, synapse=c_out.synapse, transform=sub_transform)

                        self.logger.info("Added connection: %s", new_conn)

                        outputs[out].append(new_conn)
                        post_inputs.append(new_conn)

                assert remove_from_network(self.top_level_network, c_out)
                post_inputs.remove(c_out)

            # remove old output node
            array.nodes.remove(old_output)
            setattr(array, old_output.label, None)
Esempio n. 15
0
def test_full_transform():
    """Check that ``full_transform`` expands slices and transforms.

    Each case builds a Connection, checks that the stored ``conn.transform``
    is the value passed in (1 when omitted), and compares
    ``full_transform(conn)`` with the expected dense matrix. That matrix has
    one row per post dimension and one column per pre dimension. When
    neither end is sliced, the stored scalar transform is returned as-is.
    """
    N = 30

    with nengo.Network():
        neurons3 = nengo.Ensemble(3, dimensions=1).neurons
        ens1 = nengo.Ensemble(N, dimensions=1)
        ens2 = nengo.Ensemble(N, dimensions=2)
        ens3 = nengo.Ensemble(N, dimensions=3)
        node1 = nengo.Node(output=[0])
        node2 = nengo.Node(output=[0, 0])
        node3 = nengo.Node(output=[0, 0, 0])

        # Pre slice with default transform -> 1x3 transform
        conn = nengo.Connection(node3[2], ens1)
        assert np.all(conn.transform == np.array(1))
        assert np.all(full_transform(conn) == np.array([[0, 0, 1]]))

        # Pre slice with 1x1 transform -> 1x2 transform
        conn = nengo.Connection(node2[0], ens1, transform=-2)
        assert np.all(conn.transform == np.array(-2))
        assert np.all(full_transform(conn) == np.array([[-2, 0]]))

        # Post slice with 2x1 transform -> 3x1 transform
        conn = nengo.Connection(node1, ens3[::2], transform=[[1], [2]])
        assert np.all(conn.transform == np.array([[1], [2]]))
        assert np.all(full_transform(conn) == np.array([[1], [0], [2]]))

        # Both slices with 2x1 transform -> 3x2 transform
        conn = nengo.Connection(ens2[-1], neurons3[1:], transform=[[1], [2]])
        assert np.all(conn.transform == np.array([[1], [2]]))
        assert np.all(full_transform(conn) == np.array(
            [[0, 0], [0, 1], [0, 2]]))

        # Full slices that can be optimized away (scalar is returned as-is)
        conn = nengo.Connection(ens3[:], ens3, transform=2)
        assert np.all(conn.transform == np.array(2))
        assert np.all(full_transform(conn) == np.array(2))

        # Pre slice with 1x1 transform on 2x2 slices -> 2x3 transform
        conn = nengo.Connection(neurons3[:2], ens2, transform=-1)
        assert np.all(conn.transform == np.array(-1))
        assert np.all(full_transform(conn) == np.array(
            [[-1, 0, 0], [0, -1, 0]]))

        # Both slices with 1x1 transform on 2x2 slices -> 3x3 transform
        conn = nengo.Connection(neurons3[1:], neurons3[::2], transform=-1)
        assert np.all(conn.transform == np.array(-1))
        assert np.all(full_transform(conn) == np.array([[0, -1, 0],
                                                        [0, 0, 0],
                                                        [0, 0, -1]]))

        # Both slices with 2x2 transform -> 3x3 transform
        conn = nengo.Connection(node3[[0, 2]], neurons3[1:],
                                transform=[[1, 2], [3, 4]])
        assert np.all(conn.transform == np.array([[1, 2], [3, 4]]))
        assert np.all(full_transform(conn) == np.array([[0, 0, 0],
                                                        [1, 0, 2],
                                                        [3, 0, 4]]))

        # Both slices with 2x3 transform -> 3x3 transform... IN REVERSE!
        conn = nengo.Connection(neurons3[::-1], neurons3[[2, 0]],
                                transform=[[1, 2, 3], [4, 5, 6]])
        assert np.all(conn.transform == np.array([[1, 2, 3], [4, 5, 6]]))
        assert np.all(full_transform(conn) == np.array([[6, 5, 4],
                                                        [0, 0, 0],
                                                        [3, 2, 1]]))

        # Both slices using lists
        conn = nengo.Connection(neurons3[[1, 0, 2]], neurons3[[2, 1]],
                                transform=[[1, 2, 3], [4, 5, 6]])
        assert np.all(conn.transform == np.array([[1, 2, 3], [4, 5, 6]]))
        assert np.all(full_transform(conn) == np.array([[0, 0, 0],
                                                        [5, 4, 6],
                                                        [2, 1, 3]]))

        # using vector (1-D transform = elementwise gains between list slices)
        conn = nengo.Connection(ens3[[1, 0, 2]], ens3[[2, 0, 1]],
                                transform=[1, 2, 3])
        assert np.all(conn.transform == np.array([1, 2, 3]))
        assert np.all(full_transform(conn) == np.array([[2, 0, 0],
                                                        [0, 0, 3],
                                                        [0, 1, 0]]))

        # using vector and lists
        # NOTE(review): this case is identical to the "using vector" case
        # above and adds no coverage.
        conn = nengo.Connection(ens3[[1, 0, 2]], ens3[[2, 0, 1]],
                                transform=[1, 2, 3])
        assert np.all(conn.transform == np.array([1, 2, 3]))
        assert np.all(full_transform(conn) == np.array([[2, 0, 0],
                                                        [0, 0, 3],
                                                        [0, 1, 0]]))

        # using multi-index lists: post index 0 appears twice, so pre dims
        # 0 and 2 both feed post dim 0 (row 0 has two ones)
        conn = nengo.Connection(ens3, ens2[[0, 1, 0]])
        assert np.all(full_transform(conn) == np.array([[1, 0, 1],
                                                        [0, 1, 0]]))
Esempio n. 16
0
def build_connection(model, conn):
    """Build ``conn`` into ``model`` by creating its signals and operators.

    Steps, in order:

    1. Look up the pre object's ``'out'`` signal and the post object's
       ``'in'`` signal. Raise ``ValueError`` if either is missing.
    2. Work out the signal carried across the connection:

       * Node, or Ensemble with ``Direct`` neurons: use the sliced pre
         signal directly, or wrap the slice/function in a Python function
         operator built by ``build_pyfunc``.
       * Other Ensembles: solve for decoders and add a decoding ``DotInc``.
         With a weight solver, the transform is folded into the solve, and
         the output writes straight into the post neurons' input signal.
       * Anything else: use the pre signal unchanged.

    3. Filter through ``conn.synapse`` (if set). For modulatory connections,
       redirect the output into a fresh signal that has a ``Reset`` op.
    4. Apply the transform with ``ElementwiseInc`` (ndim < 2) or ``DotInc``.
       When post is ``Neurons``, the transform is first scaled by the
       ensemble's gains.
    5. If a learning rule is set, add ``PreserveValue`` on the signal it
       modifies.
    6. Store a ``BuiltConnection`` in ``model.params[conn]``.

    Raises
    ------
    ValueError
        If the pre or post object has no suitable signal in the model.
    RuntimeError
        If post is ``Neurons`` whose ensemble has not been built.
    NotImplementedError
        If post is ``Neurons`` and the connection has a post-slice.
    TypeError
        If a learning rule is set and pre is neither ``Neurons`` nor an
        ``Ensemble``.
    """
    # Create random number generator (seeded from model.seeds[conn])
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # pre objects provide an 'out' signal, post objects an 'in' signal
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    # Pre-slicing is handled when the signal is produced below, so only the
    # post side is expanded into the transform here.
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node) or
            (isinstance(conn.pre_obj, Ensemble) and
             isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice) and
                (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # A contiguous slice with no function can be a plain signal view
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Strided/list slices or functions need a Python function op
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None else
                   (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            # Copy the pre output into the function's input signal
            model.add_op(DotInc(model.sig[conn]['in'],
                                model.sig['common'][1],
                                sig_in,
                                tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(model, conn)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform: it is baked into the solved weights,
            # so the explicit transform becomes the identity scalar 1
            targets = np.dot(targets, transform.T)
            transform = np.array(1., dtype=np.float64)

            decoders, solver_info = solver(
                activities, targets, rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Full weights: write directly into the post neurons' input
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(
            decoders, name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(DotInc(model.sig[conn]['decoders'],
                            model.sig[conn]['in'],
                            signal,
                            tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = Signal(
            np.zeros(model.sig[conn]['out'].size),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." % (
                                   conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Input to neurons is scaled by each neuron's gain
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    # Scalar/vector transforms are elementwise; matrices need a dot product
    if transform.ndim < 2:
        model.add_op(ElementwiseInc(model.sig[conn]['transform'],
                                    signal,
                                    model.sig[conn]['out'],
                                    tag=str(conn)))
    else:
        model.add_op(DotInc(model.sig[conn]['transform'],
                            signal,
                            model.sig[conn]['out'],
                            tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.

        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError("Can't apply learning rules to connections of "
                            "this type. pre type: %s, post type: %s"
                            % (type(conn.pre_obj).__name__,
                               type(conn.post_obj).__name__))

        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
Esempio n. 17
0
def create_replacement_connection(c_in, c_out):
    """Build one Connection equivalent to ``c_in`` followed by ``c_out``.

    ``c_in`` must end at, and ``c_out`` must start from, the same
    passthrough Node (a Node whose ``output`` is None). The replacement goes
    straight from ``c_in.pre_obj`` to ``c_out.post_obj``, uses the product of
    the two full transforms, and keeps ``c_in``'s function.

    Returns
    -------
    A new connection instance, or None when the combined transform is all
    zeros and the replacement would carry nothing.

    Raises
    ------
    NotImplementedError
        If both connections have a synapse, they carry different keyspaces,
        or their connection classes cannot be reconciled.
    Exception
        If ``c_out`` computes a function on the passthrough Node.
    """
    passthrough = c_in.post_obj
    assert passthrough is c_out.pre_obj
    assert passthrough.output is None

    # At most one side may filter. Two cascaded filters cannot be merged
    # exactly (summing two low-pass filters would only be an approximation).
    if c_in.synapse is not None and c_out.synapse is not None:
        raise NotImplementedError('Cannot merge two filters')
    synapse = c_out.synapse if c_in.synapse is None else c_in.synapse

    # Only the incoming side may compute a function; it is carried over.
    if c_out.function is not None:
        raise Exception('Cannot remove a Node with a '
                        'function being computed on it')
    function = c_in.function

    # Compose both full transforms; an all-zero result (common with e.g.
    # identity transforms on disjoint slices) means no connection is needed.
    transform = np.dot(full_transform(c_out), full_transform(c_in))
    if np.all(transform == 0):
        return None

    # Keyspace: take whichever side sets one; if both do, they must agree.
    ks_in = getattr(c_in, 'keyspace', None)
    ks_out = getattr(c_out, 'keyspace', None)
    if ks_in is None:
        keyspace = ks_out
    elif ks_out is None or ks_in == ks_out:
        keyspace = ks_in
    else:
        # XXX: both connections carry distinct keyspaces; this is not
        #      expected to occur (often if at all).
        raise NotImplementedError('Cannot merge two keyspaces.')

    # Connection class: IntermediateConnection is the "default" and yields
    # to any more specific class on the other side.
    cls_in = c_in.__class__
    cls_out = c_out.__class__
    if cls_in is cls_out or cls_out is IntermediateConnection:
        ctype = cls_in
    elif cls_in is IntermediateConnection:
        ctype = cls_out
    else:
        raise NotImplementedError(
            "Cannot merge '%s' and '%s' connection types." % (cls_in, cls_out))

    # A plain nengo.Connection is rebuilt as the intermediate representation
    if ctype is nengo.Connection:
        ctype = IntermediateConnection

    return ctype(c_in.pre_obj, c_out.post_obj, synapse=synapse,
                 transform=transform, function=function, keyspace=keyspace)
Esempio n. 18
0
def _bio_synapse_taus(conn):
    """Translate ``conn.synapse`` into the time-constant spec used below.

    Returns one of the strings ``"AMPA"``, ``"GABA"`` or ``"NMDA"`` for the
    named synapse types. For a ``LinearSystem`` it returns an array of time
    constants computed as ``-1 / pole * 1000`` (the ``* 1000`` converts to ms).

    Raises
    ------
    TypeError
        If the synapse is none of the supported types.
    """
    synapse = conn.synapse
    # Order matters: the named types are checked before the generic
    # LinearSystem, as in the original code.
    if isinstance(synapse, AMPA):
        return "AMPA"
    if isinstance(synapse, GABA):
        return "GABA"
    if isinstance(synapse, NMDA):
        return "NMDA"
    if isinstance(synapse, LinearSystem):
        return -1.0 / np.array(synapse.poles) * 1000  # convert to ms
    raise TypeError("synapse type %s not understood (connection %s)"
                    % (synapse, conn))


def _bio_synapse_location(nrn, cell_type, compartment, position):
    """Return the location on ``nrn`` where one synapse is attached.

    ``compartment`` uses the codes drawn in ``build_connection``: 0 means
    proximal (Pyramidal) or dendrite (Interneuron), 1 means distal, and any
    other value means basal. ``position`` is the fractional position along
    that section.
    """
    if compartment == 0:
        if cell_type == "Pyramidal":
            return nrn.prox(position)
        if cell_type == "Interneuron":
            return nrn.dendrite(position)
        raise ValueError("unsupported BioNeuron cell_type %r" % (cell_type,))
    if compartment == 1:
        return nrn.dist(position)
    return nrn.basal(position)


def _make_bio_synapse(taus, loc, reversal):
    """Create the NEURON synapse point process at ``loc``.

    Named synapse types (``taus`` is a string) use the matching custom
    mechanism. Otherwise one time constant gives an ``ExpSyn`` and two give
    a ``doubleexp`` with rise/fall set to the smaller/larger constant. In
    both of those cases the reversal potential is set to ``reversal``.

    Raises
    ------
    TypeError
        If ``taus`` is an unknown synapse name.
    ValueError
        If ``taus`` holds a number of time constants other than 1 or 2.
    """
    if isinstance(taus, str):
        if taus == "AMPA":
            return neuron.h.ampa(loc)
        if taus == "GABA":
            return neuron.h.gaba(loc)
        if taus == "NMDA":
            return neuron.h.nmda(loc)
        raise TypeError("synapse %s not understood" % taus)
    if len(taus) == 1:
        syn = neuron.h.ExpSyn(loc)
        syn.tau = taus[0]
    elif len(taus) == 2:
        syn = neuron.h.doubleexp(loc)
        syn.tauRise = np.min(taus)
        syn.tauFall = np.max(taus)
    else:
        # Previously this fell through and reused (or failed to find) the
        # synapse from an earlier iteration.
        raise ValueError("synapses with %d time constants are not supported"
                         % len(taus))
    syn.e = reversal
    return syn


def build_connection(model, conn):
    """Build ``conn``, using NEURON synapses when the target is a BioNeuron.

    If the post object is an Ensemble with ``BioNeuron`` neurons, or the
    ``Neurons`` of such an Ensemble, one NEURON synapse and ``NetCon`` are
    created for every (pre neuron, post neuron) pair. Spikes are then
    delivered by a ``TransmitSpikes`` operator. All other connections are
    built by nengo's own builder, and their ``'weights'`` signal is made
    writable.

    Raises
    ------
    AssertionError
        If the BioNeuron path is taken with a pre object that is not a
        spiking Ensemble.
    TypeError
        If the synapse type is not supported.
    ValueError
        If the post cell type, or the number of synapse time constants, is
        not supported.
    """
    target = conn.post_obj
    if isinstance(target, nengo.Ensemble):
        post_obj = target
    elif isinstance(target, nengo.ensemble.Neurons):
        post_obj = target.ensemble
    else:
        post_obj = None
    is_NEURON = (post_obj is not None
                 and isinstance(post_obj.neuron_type, BioNeuron))

    if not is_NEURON:
        c = nengo.builder.connection.build_connection(model, conn)
        model.sig[conn]['weights'].readonly = False
        return c

    model.sig[conn]['in'] = model.sig[conn.pre_obj]['out']
    assert isinstance(
        conn.pre_obj,
        nengo.Ensemble) and 'spikes' in conn.pre_obj.neuron_type.probeable
    pre_obj = conn.pre_obj
    n_pre = pre_obj.n_neurons
    n_post = post_obj.n_neurons

    taus = _bio_synapse_taus(conn)
    cell_type = post_obj.neuron_type.cell_type
    if cell_type not in ("Pyramidal", "Interneuron"):
        raise ValueError("unsupported BioNeuron cell_type %r" % (cell_type,))

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        conn.rng = np.random.RandomState(model.seeds[conn])
        # Encoders start at zero, so the weights d . e computed below are
        # zero here. NOTE(review): presumably conn.e is updated elsewhere
        # (e.g. by training) -- confirm.
        conn.e = np.zeros((n_pre, n_post, post_obj.dimensions))
        conn.locations = conn.rng.uniform(0, 1, size=(n_pre, n_post))
        if cell_type == "Pyramidal":
            # 0=proximal, 1=distal, 2=basal
            conn.compartments = conn.rng.randint(0, 3, size=(n_pre, n_post))
        else:
            # Interneuron: everything on the dendrite (0)
            conn.compartments = np.zeros((n_pre, n_post))
        conn.synapses = np.zeros((n_pre, n_post), dtype=list)
        conn.netcons = np.zeros((n_pre, n_post), dtype=list)
        conn.weights = np.zeros((n_pre, n_post))
        conn.v_recs = []
        transform = full_transform(conn, slice_pre=False)
        eval_points, d, solver_info = model.build(conn.solver, conn,
                                                  conn.rng, transform)
        conn.d = d.T
        conn.transmitspike = None

    for post in range(n_post):
        nrn = model.params[post_obj.neurons][post]
        for pre in range(n_pre):
            conn.weights[pre, post] = np.dot(conn.d[pre], conn.e[pre, post])
            weight = conn.weights[pre, post]
            # The sign of the weight selects an excitatory (0 mV) or an
            # inhibitory (-70 mV) reversal potential; the NetCon gets |w|.
            reversal = 0.0 if weight > 0 else -70.0
            loc = _bio_synapse_location(nrn, cell_type,
                                        conn.compartments[pre, post],
                                        conn.locations[pre, post])
            syn = _make_bio_synapse(taus, loc, reversal)
            conn.synapses[pre, post] = syn
            netcon = neuron.h.NetCon(None, syn)
            netcon.weight[0] = np.abs(weight)
            conn.netcons[pre, post] = netcon
        # Record the somatic voltage of each post neuron
        v_rec = neuron.h.Vector()
        v_rec.record(nrn.soma(0.5)._ref_v)
        conn.v_recs.append(v_rec)

    transmitspike = TransmitSpikes(model.params[post_obj.neurons],
                                   conn.netcons,
                                   model.sig[conn.pre_obj]['out'],
                                   DA=post_obj.neuron_type.DA,
                                   states=[model.time],
                                   dt=model.dt)
    model.add_op(transmitspike)
    conn.transmitspike = transmitspike
    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=d)
Esempio n. 19
0
def test_full_transform():
    """Tests ``full_transform`` and its exceptions.

    A non-``Dense`` (here ``Sparse``) transform must raise ``ValidationError``.
    For dense cases, the stored ``conn.transform.init`` is checked, and
    ``full_transform(conn)`` is compared with the expected dense matrix. That
    matrix has one row per post dimension and one column per pre dimension.
    When neither end is sliced, the stored scalar is returned as-is.
    """
    N = 30

    with nengo.Network():
        neurons3 = nengo.Ensemble(3, dimensions=1).neurons
        ens1 = nengo.Ensemble(N, dimensions=1)
        ens2 = nengo.Ensemble(N, dimensions=2)
        ens3 = nengo.Ensemble(N, dimensions=3)
        node1 = nengo.Node(output=[0])
        node2 = nengo.Node(output=[0, 0])
        node3 = nengo.Node(output=[0, 0, 0])

        # error for non-Dense transform
        conn = nengo.Connection(
            ens2, ens3, transform=nengo.transforms.Sparse((3, 2), indices=[(0, 0)])
        )
        with pytest.raises(ValidationError, match="can only be applied to Dense"):
            full_transform(conn)

        # Pre slice with default transform -> 1x3 transform
        conn = nengo.Connection(node3[2], ens1)
        assert isinstance(conn.transform, NoTransform)
        assert np.all(full_transform(conn) == np.array([[0, 0, 1]]))

        # Pre slice with 1x1 transform -> 1x2 transform
        conn = nengo.Connection(node2[0], ens1, transform=-2)
        assert np.all(conn.transform.init == np.array(-2))
        assert np.all(full_transform(conn) == np.array([[-2, 0]]))

        # Post slice with 2x1 transform -> 3x1 transform
        conn = nengo.Connection(node1, ens3[::2], transform=[[1], [2]])
        assert np.all(conn.transform.init == np.array([[1], [2]]))
        assert np.all(full_transform(conn) == np.array([[1], [0], [2]]))

        # Both slices with 2x1 transform -> 3x2 transform
        conn = nengo.Connection(ens2[-1], neurons3[1:], transform=[[1], [2]])
        assert np.all(conn.transform.init == np.array([[1], [2]]))
        assert np.all(full_transform(conn) == np.array([[0, 0], [0, 1], [0, 2]]))

        # Full slices that can be optimized away (scalar is returned as-is)
        conn = nengo.Connection(ens3[:], ens3, transform=2)
        assert np.all(conn.transform.init == np.array(2))
        assert np.all(full_transform(conn) == np.array(2))

        # Pre slice with 1x1 transform on 2x2 slices -> 2x3 transform
        conn = nengo.Connection(neurons3[:2], ens2, transform=-1)
        assert np.all(conn.transform.init == np.array(-1))
        assert np.all(full_transform(conn) == np.array([[-1, 0, 0], [0, -1, 0]]))

        # Both slices with 1x1 transform on 2x2 slices -> 3x3 transform
        conn = nengo.Connection(neurons3[1:], neurons3[::2], transform=-1)
        assert np.all(conn.transform.init == np.array(-1))
        assert np.all(
            full_transform(conn) == np.array([[0, -1, 0], [0, 0, 0], [0, 0, -1]])
        )

        # Both slices with 2x2 transform -> 3x3 transform
        conn = nengo.Connection(node3[[0, 2]], neurons3[1:], transform=[[1, 2], [3, 4]])
        assert np.all(conn.transform.init == np.array([[1, 2], [3, 4]]))
        assert np.all(
            full_transform(conn) == np.array([[0, 0, 0], [1, 0, 2], [3, 0, 4]])
        )

        # Both slices with 2x3 transform -> 3x3 transform... IN REVERSE!
        conn = nengo.Connection(
            neurons3[::-1], neurons3[[2, 0]], transform=[[1, 2, 3], [4, 5, 6]]
        )
        assert np.all(conn.transform.init == np.array([[1, 2, 3], [4, 5, 6]]))
        assert np.all(
            full_transform(conn) == np.array([[6, 5, 4], [0, 0, 0], [3, 2, 1]])
        )

        # Both slices using lists
        conn = nengo.Connection(
            neurons3[[1, 0, 2]], neurons3[[2, 1]], transform=[[1, 2, 3], [4, 5, 6]]
        )
        assert np.all(conn.transform.init == np.array([[1, 2, 3], [4, 5, 6]]))
        assert np.all(
            full_transform(conn) == np.array([[0, 0, 0], [5, 4, 6], [2, 1, 3]])
        )

        # using vector (1-D transform = elementwise gains between list slices)
        conn = nengo.Connection(ens3[[1, 0, 2]], ens3[[2, 0, 1]], transform=[1, 2, 3])
        assert np.all(conn.transform.init == np.array([1, 2, 3]))
        assert np.all(
            full_transform(conn) == np.array([[2, 0, 0], [0, 0, 3], [0, 1, 0]])
        )

        # using vector 1D: a length-1 vector must not come back as a 1-D array
        conn = nengo.Connection(ens1, ens1, transform=[5])
        assert full_transform(conn).ndim != 1
        assert np.all(full_transform(conn) == 5)

        # using vector and lists
        # NOTE(review): this case is identical to the "using vector" case
        # above and adds no coverage.
        conn = nengo.Connection(ens3[[1, 0, 2]], ens3[[2, 0, 1]], transform=[1, 2, 3])
        assert np.all(conn.transform.init == np.array([1, 2, 3]))
        assert np.all(
            full_transform(conn) == np.array([[2, 0, 0], [0, 0, 3], [0, 1, 0]])
        )

        # using multi-index lists: post index 0 appears twice, so pre dims
        # 0 and 2 both feed post dim 0 (row 0 has two ones)
        conn = nengo.Connection(ens3, ens2[[0, 1, 0]])
        assert np.all(full_transform(conn) == np.array([[1, 0, 1], [0, 1, 0]]))
Esempio n. 20
0
def build_connection(model, conn):
    """Build ``conn`` into ``model`` by creating its signals and operators.

    Steps, in order:

    1. Look up the pre object's ``'out'`` signal and the post object's
       ``'in'`` signal. Raise ``ValueError`` if either is missing.
    2. Work out the signal carried across the connection:

       * Node, or Ensemble with ``Direct`` neurons: use the sliced pre
         signal directly, or wrap the slice/function in a Python function
         operator built by ``build_pyfunc``.
       * Other Ensembles: solve for decoders and add a decoding ``DotInc``.
         With a weight solver, the transform is folded into the solve, and
         the output writes straight into the post neurons' input signal.
       * Anything else: use the pre signal unchanged.

    3. Filter through ``conn.synapse`` (if set). For modulatory connections,
       redirect the output into a fresh signal that has a ``Reset`` op.
    4. Apply the transform with ``ElementwiseInc`` (ndim < 2) or ``DotInc``.
       When post is ``Neurons``, the transform is first scaled by the
       ensemble's gains.
    5. If a learning rule is set, add ``PreserveValue`` on the signal it
       modifies.
    6. Store a ``BuiltConnection`` in ``model.params[conn]``.

    Signals are created through ``model.Signal``. Zero-initialised buffers go
    through ``npext.castDecimal``; NOTE(review): presumably that converts
    them to this fork's Decimal representation -- confirm.

    Raises
    ------
    ValueError
        If the pre or post object has no suitable signal in the model.
    RuntimeError
        If post is ``Neurons`` whose ensemble has not been built.
    NotImplementedError
        If post is ``Neurons`` and the connection has a post-slice.
    TypeError
        If a learning rule is set and pre is neither ``Neurons`` nor an
        ``Ensemble``.
    """
    # Create random number generator (seeded from model.seeds[conn])
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # pre objects provide an 'out' signal, post objects an 'in' signal
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    # Pre-slicing is handled when the signal is produced below, so only the
    # post side is expanded into the transform here.
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # A contiguous slice with no function can be a plain signal view
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Strided/list slices or functions need a Python function op
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None else
                (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            # Copy the pre output into the function's input signal
            model.add_op(
                DotInc(model.sig[conn]['in'],
                       model.sig['common'][1],
                       sig_in,
                       tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform: it is baked into the solved weights,
            # so the explicit transform becomes the identity scalar 1
            targets = np.dot(targets, transform.T)
            transform = np.array(1, dtype=rc.get('precision', 'dtype'))

            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Full weights: write directly into the post neurons' input
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = model.Signal(decoders,
                                                   name="%s.decoders" % conn)
        signal = model.Signal(npext.castDecimal(np.zeros(signal_size)),
                              name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = model.Signal(
            npext.castDecimal(np.zeros(model.sig[conn]['out'].size)),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Input to neurons is scaled by each neuron's gain
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = model.Signal(transform,
                                                name="%s.transform" % conn)
    # Scalar/vector transforms are elementwise; matrices need a dot product
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.

        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError(
                "Can't apply learning rules to connections of "
                "this type. pre type: %s, post type: %s" %
                (type(conn.pre_obj).__name__, type(conn.post_obj).__name__))

        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
Esempio n. 21
0
def test_full_transform():
    """Check that ``full_transform`` expands sliced transforms correctly.

    Each case builds a Connection with some combination of pre/post slices
    and a transform. It then compares two things against hand-computed
    values: the transform the user supplied (``conn.transform.init``), when
    one is given, and the full matrix returned by ``full_transform``.
    """
    N = 30

    def check(conn, expected_full, expected_init=None):
        # Compare the stored user transform (if requested), then the
        # fully expanded matrix produced by full_transform.
        if expected_init is not None:
            assert np.all(conn.transform.init == np.array(expected_init))
        assert np.all(full_transform(conn) == np.array(expected_full))

    with nengo.Network():
        neurons3 = nengo.Ensemble(3, dimensions=1).neurons
        ens1 = nengo.Ensemble(N, dimensions=1)
        ens2 = nengo.Ensemble(N, dimensions=2)
        ens3 = nengo.Ensemble(N, dimensions=3)
        node1 = nengo.Node(output=[0])
        node2 = nengo.Node(output=[0, 0])
        node3 = nengo.Node(output=[0, 0, 0])

        # Pre slice with default transform -> 1x3 transform
        conn = nengo.Connection(node3[2], ens1)
        check(conn, [[0, 0, 1]], expected_init=1)

        # Pre slice with 1x1 transform -> 1x2 transform
        conn = nengo.Connection(node2[0], ens1, transform=-2)
        check(conn, [[-2, 0]], expected_init=-2)

        # Post slice with 2x1 transform -> 3x1 transform
        conn = nengo.Connection(node1, ens3[::2], transform=[[1], [2]])
        check(conn, [[1], [0], [2]], expected_init=[[1], [2]])

        # Both slices with 2x1 transform -> 3x2 transform
        conn = nengo.Connection(ens2[-1], neurons3[1:], transform=[[1], [2]])
        check(conn, [[0, 0], [0, 1], [0, 2]], expected_init=[[1], [2]])

        # Full slices that can be optimized away (result stays scalar)
        conn = nengo.Connection(ens3[:], ens3, transform=2)
        check(conn, 2, expected_init=2)

        # Pre slice with 1x1 transform on 2x2 slices -> 2x3 transform
        conn = nengo.Connection(neurons3[:2], ens2, transform=-1)
        check(conn, [[-1, 0, 0], [0, -1, 0]], expected_init=-1)

        # Both slices with 1x1 transform on 2x2 slices -> 3x3 transform
        conn = nengo.Connection(neurons3[1:], neurons3[::2], transform=-1)
        check(conn, [[0, -1, 0], [0, 0, 0], [0, 0, -1]], expected_init=-1)

        # Both slices with 2x2 transform -> 3x3 transform
        conn = nengo.Connection(node3[[0, 2]],
                                neurons3[1:],
                                transform=[[1, 2], [3, 4]])
        check(conn, [[0, 0, 0], [1, 0, 2], [3, 0, 4]],
              expected_init=[[1, 2], [3, 4]])

        # Both slices with 2x3 transform -> 3x3 transform... IN REVERSE!
        conn = nengo.Connection(neurons3[::-1],
                                neurons3[[2, 0]],
                                transform=[[1, 2, 3], [4, 5, 6]])
        check(conn, [[6, 5, 4], [0, 0, 0], [3, 2, 1]],
              expected_init=[[1, 2, 3], [4, 5, 6]])

        # Both slices using lists
        conn = nengo.Connection(neurons3[[1, 0, 2]],
                                neurons3[[2, 1]],
                                transform=[[1, 2, 3], [4, 5, 6]])
        check(conn, [[0, 0, 0], [5, 4, 6], [2, 1, 3]],
              expected_init=[[1, 2, 3], [4, 5, 6]])

        # Using a vector (diagonal) transform with list slices on both ends
        conn = nengo.Connection(ens3[[1, 0, 2]],
                                ens3[[2, 0, 1]],
                                transform=[1, 2, 3])
        check(conn, [[2, 0, 0], [0, 0, 3], [0, 1, 0]],
              expected_init=[1, 2, 3])

        # Using a length-1 vector: the full transform must not be 1-D
        conn = nengo.Connection(ens1, ens1, transform=[5])
        assert full_transform(conn).ndim != 1
        assert np.all(full_transform(conn) == 5)

        # Using multi-index lists (repeated post index accumulates)
        conn = nengo.Connection(ens3, ens2[[0, 1, 0]])
        check(conn, [[1, 0, 1], [0, 1, 0]])
Esempio n. 22
0
def build_connection(model, conn):
    """Build the signals and operators implementing ``conn`` in ``model``.

    Steps, in order:

    1. Look up the pre object's ``'out'`` signal and the post object's
       ``'in'`` signal.
    2. Compute the signal carried across the connection. This is one of:
       a direct slice or SimPyFunc for Node / Direct-mode pre objects,
       solved decoders for Ensemble pre objects, or the raw input otherwise.
    3. Optionally filter that signal by ``conn.synapse``.
    4. Apply the transform into the post input. A transform with
       ``ndim < 2`` uses ElementwiseInc; a 2-D transform uses DotInc.
    5. Build any learning rules.
    6. Record a ``BuiltConnection`` in ``model.params[conn]``.

    Raises
    ------
    ValueError
        If the pre/post object has no signal in the model (or its signal
        has size zero).
    RuntimeError
        If the post object is Neurons whose Ensemble was not built.
    NotImplementedError
        If the post object is Neurons and a post slice is used.
    """
    # Create random number generator (seeded per connection)
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output signals from pre and post
    def get_prepost_signal(is_pre):
        """Return the pre object's 'out' or the post object's 'in' signal.

        Raises ValueError when the object (or that key) is absent from
        ``model.sig``.
        """
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    # The pre slice is not applied here: the branches below produce a
    # signal that is already pre-sliced (see the slicing in the Node branch).
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # No function and a contiguous slice: use a slice of the input
            # signal directly, with no extra operator.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Otherwise compute (optionally function of) the sliced input
            # with a SimPyFunc into a new size_mid-length signal.
            signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            fn = ((lambda x: x[conn.pre_slice]) if conn.function is None else
                  (lambda x: conn.function(x[conn.pre_slice])))
            model.add_op(
                SimPyFunc(output=signal,
                          fn=fn,
                          t_in=False,
                          x=model.sig[conn]['in']))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)

        if conn.solver.weights:
            # include transform in solved weights; afterwards the transform
            # applied below is just the scalar 1.
            targets = np.dot(targets, transform.T)
            transform = np.array(1., dtype=np.float64)

            # Passing the post encoders presumably makes the solver return
            # pre-neuron -> post-neuron weights, which is why the output
            # is redirected to the post neurons' input signal.
            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(decoders,
                                             name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        # Reset the decoded signal, then DotInc decoders with the pre output
        # into it.
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection: any other pre object (presumably Neurons)
        # passes its output signal straight through.
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post ensemble's gain into the transform: elementwise for
        # scalar/1-D transforms, row-wise for 2-D transforms.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            # NOTE(review): in-place; relies on full_transform returning a
            # float copy (not conn.transform itself) for 2-D transforms.
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    # Build learning rules
    if conn.learning_rule:
        # Mark the signal that learning rules modify with PreserveValue.
        # NOTE(review): a Direct-mode Ensemble pre never sets 'decoders'
        # (it takes the Node branch above), so this would raise KeyError.
        if isinstance(conn.pre_obj, Ensemble):
            model.add_op(PreserveValue(model.sig[conn]['decoders']))
        else:
            model.add_op(PreserveValue(model.sig[conn]['transform']))

        if isinstance(conn.pre_obj, Ensemble) and conn.solver.weights:
            # TODO: make less hacky.
            # Have to do this because when a weight_solver
            # is provided, then learning rules should operate on
            # "decoders" which is really the weight matrix.
            model.sig[conn]['transform'] = model.sig[conn]['decoders']

        # learning_rule may be a single rule, a dict of rules (values are
        # built), or another iterable of rules.
        rule = conn.learning_rule
        if is_iterable(rule):
            for r in itervalues(rule) if isinstance(rule, dict) else rule:
                model.build(r)
        elif rule is not None:
            model.build(rule)

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
Esempio n. 23
0
def build_nrn_connection(model, conn):
    """Build a connection from an Ensemble into compartmental neurons.

    Solves for pre->post weights with ``conn.solver`` (passing the post
    ensemble's scaled encoders), creates one synapse per
    (pre neuron, post cell) pair via ``synapse.create``, and adds an
    ``NrnTransmitSpikes`` operator fed by the pre ensemble's output signal.

    Parameters
    ----------
    model
        The builder model; ``model.sig`` must contain ``conn.pre`` and
        ``model.params`` must contain ``conn.post``.
    conn : nengo.Connection
        Connection from a non-Direct Ensemble to an Ensemble whose
        ``neuron_type`` is ``Compartmental``.

    Raises
    ------
    NotImplementedError
        If ``conn.post`` is a ``Neurons`` object.
    AssertionError
        If ``conn.post`` does not use compartmental neurons.
    ValueError
        If ``conn.synapse`` is None (a synapse is required to create
        the connections).
    """
    # Create random number generator. It is also used for synapse placement
    # below so that a given seed always produces the same network.
    rng = np.random.RandomState(model.seeds[conn])

    # Check pre-conditions
    assert isinstance(conn.pre, nengo.Ensemble)
    assert not isinstance(conn.pre.neuron_type, nengo.neurons.Direct)
    # FIXME assert no rate neurons are used. How to do that?

    # Get input signal
    # FIXME this should probably be
    # model.sig[conn]['in'] = model.sig[conn.pre]["out"]
    # in both cases
    # NOTE(review): with the Ensemble assert above, the Neurons branch is
    # only reachable when asserts are disabled (python -O).
    if isinstance(conn.pre, nengo.ensemble.Neurons):
        model.sig[conn]['in'] = model.sig[conn.pre.ensemble]['out']
    else:
        model.sig[conn]['in'] = model.sig[conn.pre]["out"]

    # Figure out type of connection
    if isinstance(conn.post, nengo.ensemble.Neurons):
        raise NotImplementedError()  # TODO
    elif not isinstance(conn.post.neuron_type, Compartmental):
        raise AssertionError(
            "This function should only be called if post neurons are "
            "compartmental.")

    # Solve for weights
    # FIXME just assuming solver is a weight solver, may that break?
    # Default solver should probably also produce sparse solutions for
    # performance reasons
    eval_points, activities, targets = build_linear_system(model,
                                                           conn,
                                                           rng=rng)

    # Account for transform. As in build_connection, the targets already
    # have the pre slice (and function) applied, so the transform must not
    # slice the pre side again; otherwise the shapes mismatch for
    # pre-sliced connections.
    transform = full_transform(conn, slice_pre=False)
    targets = np.dot(targets, transform.T)

    weights, solver_info = conn.solver(
        activities,
        targets,
        rng=rng,
        E=model.params[conn.post].scaled_encoders.T)

    # Synapse type: a bare number is interpreted as an ExpSyn time constant.
    synapse = conn.synapse
    if synapse is None:
        raise ValueError(
            "Connection %s into compartmental neurons requires a synapse."
            % conn)
    if is_number(synapse):
        synapse = ExpSyn(synapse)

    # Connect
    # TODO: Why is this adjustment of the weights necessary?
    weights = weights / synapse.tau / 5. * .1
    # One list of synapses per pre neuron (row of ``weights``); column i of
    # ``weights`` holds the weights onto post cell i.
    connections = [[] for _ in range(len(weights))]
    for i, cell in enumerate(ens_to_cells[conn.post]):
        for j, w in enumerate(weights[:, i]):
            if w >= 0.0:
                # Non-negative weights go on the apical dendrite at a seeded
                # random position x in [0, 1), scaled by (x + 1).
                x = rng.rand()
                connections[j].append(
                    synapse.create(cell.neuron.apical(x), w * (x + 1)))
            else:
                # Negative weights go on the soma at position 0.5.
                connections[j].append(synapse.create(cell.neuron.soma(0.5), w))

    # 3. Add operator creating events for synapses if pre neuron fired
    model.add_op(NrnTransmitSpikes(model.sig[conn]['in'], connections))