Beispiel #1
0
    def add_default_output_vector(self, output_vector, output_name='output',
                                  n_neurons=50, min_activation_value=0.5):
        """Add an ensemble that emits ``output_vector`` when no item is active.

        A 1-D "default vector" ensemble is driven by the constant bias node
        and inhibited by the per-item utilities: ``self.thresh_ens.output``
        when thresholded outputs have been added, otherwise
        ``self.elem_output``. Its decoded value is projected onto
        ``output_vector`` and fed into the existing output node named
        ``output_name``.

        Parameters
        ----------
        output_vector : array_like
            Vector added to the target output while the ensemble is active.
        output_name : str, optional
            Name of an existing output-node attribute of this network (looked
            up with ``getattr``).
        n_neurons : int, optional
            Number of neurons in the default-vector ensemble.
        min_activation_value : float, optional
            Summed utility at which the ensemble's input (bias minus
            inhibition) reaches zero; the inhibitory gain is
            ``1 / min_activation_value``. Must be positive.

        Raises
        ------
        ValueError
            If ``min_activation_value`` is not positive.
        """
        if min_activation_value <= 0:
            raise ValueError("min_activation_value must be positive, got %r"
                             % (min_activation_value,))

        # Ensemble defaults: unit radius, positive-only encoders, evaluation
        # points drawn from [0, 1].
        default_ens_config = nengo.Config(nengo.Ensemble)
        default_ens_config[nengo.Ensemble].radius = 1
        default_ens_config[nengo.Ensemble].intercepts = \
            ClippedExpDist(self.exp_scale, 0.0, 1.0)
        default_ens_config[nengo.Ensemble].encoders = Choice([[1]])
        default_ens_config[nengo.Ensemble].eval_points = Uniform(0.0, 1.0)
        default_ens_config[nengo.Ensemble].n_eval_points = self.n_eval_points

        # Inhibit from the thresholded utilities when they exist, otherwise
        # from the raw element outputs.
        if self.thresh_ens is not None:
            utilities = self.thresh_ens.output
        else:
            utilities = self.elem_output
        inhibit_transform = (-(1.0 / min_activation_value) *
                             np.ones((1, self.num_items)))

        with nested(self, default_ens_config):
            default_vector_ens = nengo.Ensemble(
                n_neurons, 1, label=('Default %s vector' % output_name))

            # Constant drive: the ensemble is active unless inhibited.
            nengo.Connection(self.bias_node, default_vector_ens,
                             synapse=None)

            c = nengo.Connection(utilities, default_vector_ens,
                                 transform=inhibit_transform)

            self.default_output_utility = default_vector_ens
            self.default_output_thresholded_utility = default_vector_ens

            # Recorded so the inhibitory connection can be rerouted if
            # thresholded outputs are added later.
            self.default_vector_inhibit_conns.append(c)

            # Connect to the existing output node; the (D, 1) transform maps
            # the scalar ensemble value onto output_vector.
            output = getattr(self, output_name)
            nengo.Connection(default_vector_ens, output,
                             transform=np.array(output_vector, ndmin=2).T,
                             synapse=None)

            # The global inhibit signal, when present, also suppresses the
            # default vector.
            if self.inhibit is not None:
                nengo.Connection(self.inhibit, default_vector_ens,
                                 transform=-1.0, synapse=None)
Beispiel #2
0
def Product(n_neurons, dimensions, input_magnitude=1, config=None, net=None):
    """Computes the element-wise product of two equally sized vectors.

    Each product uses the identity ``a*b = ((a + b)**2 - (a - b)**2) / 4``.
    One squaring array receives ``(a + b) / sqrt(2)``, the other receives
    ``(a - b) / sqrt(2)``. Their decoded squares are summed into ``output``
    with weights ``+0.5`` and ``-0.5``.

    The network used to calculate the product is described in
    `Precise multiplications with the NEF
    <http://nbviewer.ipython.org/github/ctn-archive/technical-reports/blob/master/Precise-multiplications-with-the-NEF.ipynb#An-alternative-network>`_.

    A simpler version of this network can be found in the `Multiplication
    example <http://pythonhosted.org/nengo/examples/multiplication.html>`_.

    Parameters
    ----------
    n_neurons : int
        Neurons per dimension, split evenly between the two squaring arrays
        (at least one neuron each).
    dimensions : int
        Length of the input vectors ``A`` and ``B``.
    input_magnitude : float, optional
        Expected magnitude of each input element; the squaring ensembles use
        radius ``input_magnitude * sqrt(2)``.
    config : nengo.Config, optional
        Config active while building; an empty Ensemble config if omitted.
    net : nengo.Network, optional
        Network to build into; a new network labelled "Product" if omitted.

    Returns
    -------
    nengo.Network
        ``net`` with nodes ``A``, ``B``, ``output`` and arrays ``sq1``, ``sq2``.
    """
    net = nengo.Network(label="Product") if net is None else net
    config = nengo.Config(nengo.Ensemble) if config is None else config

    neurons_per_square = max(1, n_neurons // 2)
    square_radius = input_magnitude * np.sqrt(2)
    scale = 1. / np.sqrt(2.)

    def make_squarer():
        # One 1-D ensemble per dimension, wide enough for (a +/- b)/sqrt(2).
        return EnsembleArray(neurons_per_square,
                             n_ensembles=dimensions,
                             ens_dimensions=1,
                             radius=square_radius)

    with nested(net, config):
        net.A = nengo.Node(size_in=dimensions, label="A")
        net.B = nengo.Node(size_in=dimensions, label="B")
        net.output = nengo.Node(size_in=dimensions, label="output")

        net.sq1 = make_squarer()
        net.sq2 = make_squarer()

        # sq1 sees (A + B) / sqrt(2); sq2 sees (A - B) / sqrt(2).
        feeds = ((net.A, net.sq1, scale),
                 (net.B, net.sq1, scale),
                 (net.A, net.sq2, scale),
                 (net.B, net.sq2, -scale))
        for source, squarer, weight in feeds:
            nengo.Connection(source, squarer.input, transform=weight,
                             synapse=None)

        # output = 0.5 * sq1**2 - 0.5 * sq2**2 == A * B
        for squarer, weight in ((net.sq1, .5), (net.sq2, -.5)):
            squared = squarer.add_output('square', np.square)
            nengo.Connection(squared, net.output, transform=weight)

    return net
Beispiel #3
0
def Product(n_neurons, dimensions, input_magnitude=1, config=None, net=None):
    """Computes the element-wise product of two equally sized vectors.

    Each ensemble ``i`` of a 2-D EnsembleArray represents the pair
    ``(A[i], B[i])`` and decodes ``x[0] * x[1]``.

    Parameters
    ----------
    n_neurons : int
        Neurons per 2-D product ensemble.
    dimensions : int
        Length of the input vectors ``A`` and ``B``.
    input_magnitude : float, optional
        Expected magnitude of each input element; the ensembles use radius
        ``input_magnitude * sqrt(2)`` so the pair ``(m, m)`` is in range.
    config : nengo.Config, optional
        Config active while building. If omitted, a config is created whose
        Ensemble encoders are chosen from the four diagonal directions.
    net : nengo.Network, optional
        Network to build into; a new network labelled "Product" if omitted.

    Returns
    -------
    nengo.Network
        ``net`` with input nodes ``A`` and ``B``, the ``product``
        EnsembleArray, and ``output`` set to the array's decoded
        ``'product'`` output.
    """
    if net is None:
        net = nengo.Network(label="Product")

    if config is None:
        config = nengo.Config(nengo.Ensemble)
        # Encoders restricted to the four diagonal directions of the plane.
        config[nengo.Ensemble].encoders = Choice(
            [[1, 1], [1, -1], [-1, 1], [-1, -1]])

    with nested(net, config):
        net.A = nengo.Node(size_in=dimensions, label="A")
        net.B = nengo.Node(size_in=dimensions, label="B")

        net.product = EnsembleArray(n_neurons, n_ensembles=dimensions,
                                    ens_dimensions=2,
                                    radius=input_magnitude * np.sqrt(2))
        # Interleave the inputs: even array dimensions receive A, odd ones B.
        nengo.Connection(net.A, net.product.input[::2], synapse=None)
        nengo.Connection(net.B, net.product.input[1::2], synapse=None)

        # The decoded product node is the network output directly; no
        # separate (and otherwise unconnected) output Node is created.
        net.output = net.product.add_output('product', lambda x: x[0] * x[1])
    return net
Beispiel #4
0
def Product(n_neurons, dimensions, input_magnitude=1, config=None, net=None):
    """Computes the element-wise product of two equally sized vectors.

    Each ensemble ``i`` of a 2-D EnsembleArray represents the pair
    ``(A[i], B[i])`` and decodes ``x[0] * x[1]``.

    Parameters
    ----------
    n_neurons : int
        Neurons per 2-D product ensemble.
    dimensions : int
        Length of the input vectors ``A`` and ``B``.
    input_magnitude : float, optional
        Expected magnitude of each input element; the ensembles use radius
        ``input_magnitude * sqrt(2)`` so the pair ``(m, m)`` is in range.
    config : nengo.Config, optional
        Config active while building. If omitted, a config is created whose
        Ensemble encoders are chosen from the four diagonal directions.
    net : nengo.Network, optional
        Network to build into; a new network labelled "Product" if omitted.

    Returns
    -------
    nengo.Network
        ``net`` with input nodes ``A`` and ``B``, the ``product``
        EnsembleArray, and ``output`` set to the array's decoded
        ``'product'`` output.
    """
    if net is None:
        net = nengo.Network(label="Product")

    if config is None:
        config = nengo.Config(nengo.Ensemble)
        # Encoders restricted to the four diagonal directions of the plane.
        config[nengo.Ensemble].encoders = Choice(
            [[1, 1], [1, -1], [-1, 1], [-1, -1]])

    with nested(net, config):
        net.A = nengo.Node(size_in=dimensions, label="A")
        net.B = nengo.Node(size_in=dimensions, label="B")

        net.product = EnsembleArray(n_neurons, n_ensembles=dimensions,
                                    ens_dimensions=2,
                                    radius=input_magnitude * np.sqrt(2))
        # Interleave the inputs: even array dimensions receive A, odd ones B.
        nengo.Connection(net.A, net.product.input[::2], synapse=None)
        nengo.Connection(net.B, net.product.input[1::2], synapse=None)

        # The decoded product node is the network output directly; no
        # separate (and otherwise unconnected) output Node is created.
        net.output = net.product.add_output('product', lambda x: x[0] * x[1])

    return net
Beispiel #5
0
    def __init__(self, input_vectors, output_vectors=None,  # noqa: C901
                 n_neurons=50, threshold=0.3, input_scales=1.0,
                 inhibitable=False, inhibit_scale=1.5, label=None, seed=None,
                 add_to_container=None):
        """Build an associative memory mapping input vectors to outputs.

        One 1-D ensemble is created per item. Item ``i`` receives the input
        projected onto row ``i`` of ``input_vectors``, scaled by that item's
        input scale, minus that item's threshold. It decodes a filtered step
        function. The resulting utilities are projected through
        ``output_vectors`` onto ``output``.

        Parameters
        ----------
        input_vectors : array_like
            One row per item; the column count sets the size of ``input``.
        output_vectors : array_like, optional
            One row per item; the column count sets the size of ``output``.
            Defaults to ``input_vectors`` (autoassociative memory).
        n_neurons : int, optional
            Neurons per item ensemble.
        threshold : float or array_like, optional
            Per-item threshold (a scalar is used for every item), subtracted
            from each ensemble's input via the bias node.
        input_scales : float or array_like, optional
            Per-item scale (a scalar is used for every item) applied to the
            corresponding row of ``input_vectors``.
        inhibitable : bool, optional
            If True, create an ``inhibit`` node that suppresses every item.
        inhibit_scale : float, optional
            Gain of the ``inhibit -> elem_input`` connection.
        label, seed, add_to_container
            Forwarded to the parent network constructor.

        Raises
        ------
        ValueError
            If the number of output vectors, threshold values or input scale
            values does not match the number of input vectors.
        """
        super(AssociativeMemory, self).__init__(label, seed, add_to_container)

        # Label prefix for all the ensemble labels
        label_prefix = "" if label is None else label + "_"

        # If output vocabulary is not specified, use input vector list
        # (i.e autoassociative memory)
        if output_vectors is None:
            output_vectors = input_vectors

        # Handle different vector list types
        if is_iterable(input_vectors):
            input_vectors = np.matrix(input_vectors)

        if is_iterable(output_vectors):
            output_vectors = np.matrix(output_vectors)

        # Fail if number of input items and number of output items don't
        # match
        if input_vectors.shape[0] != output_vectors.shape[0]:
            raise ValueError(
                'Number of input vectors does not match number of output '
                'vectors. %d != %d'
                % (input_vectors.shape[0], output_vectors.shape[0]))

        # Handle possible different threshold values for each element in the
        # associative memory (a scalar is broadcast to every item)
        if not is_iterable(threshold):
            threshold = np.array([threshold] * input_vectors.shape[0])
        else:
            threshold = np.array(threshold)
        if threshold.shape[0] != input_vectors.shape[0]:
            raise ValueError(
                'Number of threshold values do not match number of input '
                'vectors. Got: %d, expected %d.' %
                (threshold.shape[0], input_vectors.shape[0]))

        # Handle scaling of each input vector; input_scale is a 1 x N row
        # matrix holding one scale per item
        if not is_iterable(input_scales):
            input_scale = np.matrix([input_scales] * input_vectors.shape[0])
        else:
            input_scale = np.matrix(input_scales)
        if input_scale.shape[1] != input_vectors.shape[0]:
            raise ValueError(
                'Number of input_scale values do not match number of input '
                'vectors. Got: %d, expected %d.' %
                (input_scale.shape[1], input_vectors.shape[0]))

        # Number of items stored in the memory
        N = input_vectors.shape[0]
        self.num_items = N

        # Scaling factor for exponential distribution and filtered step
        # function
        self.exp_scale = 0.15
        filt_scale = 15
        self.filt_step_func = \
            lambda x: filtered_step(x, 0.0, scale=filt_scale)

        # Evaluation points parameters
        self.n_eval_points = 5000

        # Default configuration to use for the ensembles
        am_ens_config = nengo.Config(nengo.Ensemble)
        am_ens_config[nengo.Ensemble].radius = 1
        am_ens_config[nengo.Ensemble].intercepts = \
            ClippedExpDist(self.exp_scale, 0.0, 1.0)
        am_ens_config[nengo.Ensemble].encoders = Choice([[1]])
        am_ens_config[nengo.Ensemble].eval_points = Uniform(0.0, 1.0)
        am_ens_config[nengo.Ensemble].n_eval_points = self.n_eval_points

        # Store output connections (need to redo them if output thresholding
        # is added by the user - see add_threshold_to_output)
        self.out_conns = []

        # Store the inhibitory connections from elem_output to the default
        # vector ensembles (need to redo them if output thresholding
        # is added by the user - see add_threshold_to_output)
        self.default_vector_inhibit_conns = []

        # Threshold EnsembleArray; None until thresholded outputs are added
        self.thresh_ens = None

        # Flag to indicate if the am network is configured with wta
        self._using_wta = False

        # Create the associative memory network
        with nested(self, am_ens_config):
            self.bias_node = nengo.Node(output=1)

            self.input = nengo.Node(size_in=input_vectors.shape[1],
                                    label="input")
            self.output = nengo.Node(size_in=output_vectors.shape[1],
                                     label="output")

            self.elem_input = nengo.Node(size_in=N, label="element input")
            self.elem_output = nengo.Node(size_in=N, label="element output")

            # Row i of the transform is input vector i times its input scale
            nengo.Connection(self.input, self.elem_input, synapse=None,
                             transform=np.multiply(input_vectors,
                                                   input_scale.T))

            # Make each ensemble
            self.am_ensembles = []
            for i in range(N):
                # Create ensemble
                e = nengo.Ensemble(n_neurons, 1, label=label_prefix + str(i))
                self.am_ensembles.append(e)

                # Connect input and output nodes; the bias connection
                # subtracts this item's threshold
                nengo.Connection(self.bias_node, e, transform=-threshold[i],
                                 synapse=None)
                nengo.Connection(self.elem_input[i], e, synapse=None)
                nengo.Connection(e, self.elem_output[i],
                                 function=self.filt_step_func, synapse=None)

            # Configure associative memory to be inhibitable
            if inhibitable:
                # Input node for inhibitory gating signal (if enabled)
                self.inhibit = nengo.Node(size_in=1, label="inhibit")
                nengo.Connection(self.inhibit, self.elem_input,
                                 transform=-np.ones((N, 1)) * inhibit_scale,
                                 synapse=None)
                # Note: We can use decoded connection here because all the
                # encoding vectors are [1]
            else:
                self.inhibit = None

            # Configure utilities output
            self.utilities = self.elem_output

            c = nengo.Connection(self.elem_output, self.output,
                                 transform=output_vectors.T, synapse=None)

            # Add the output connection to the output connection list
            self.out_conns.append(c)
Beispiel #6
0
    def add_threshold_to_outputs(self, n_neurons=50, inhibit_scale=10):
        """Insert a thresholding stage between the utilities and all outputs.

        Two EnsembleArrays with one ensemble per item are created.
        ``thresh_bias`` is driven by the bias node and inhibited by
        ``elem_output``. ``thresh_ens`` is driven by the bias node and
        inhibited by ``thresh_bias``. Every connection recorded in
        ``out_conns`` and ``default_vector_inhibit_conns`` is rebuilt to
        originate from ``thresh_ens.output``, and the originals are removed
        from the network.

        Only the first call has an effect. Later calls emit a warning and
        return without changing the network.

        Parameters
        ----------
        n_neurons : int, optional
            Neurons per ensemble in both threshold EnsembleArrays.
        inhibit_scale : float, optional
            Gain of the inhibitory ``elem_output -> thresh_bias`` and
            ``thresh_bias -> thresh_ens`` connections.
        """
        if self.thresh_ens is not None:
            warnings.warn('AssociativeMemory network is already configured '
                          'with thresholded outputs. Additional '
                          'add_threshold_to_outputs function calls are '
                          'ignored.')
            return

        # Threshold ensembles: unit radius, positive-only encoders,
        # intercepts drawn from [0.5, 1.0].
        thresh_ens_config = nengo.Config(nengo.Ensemble)
        thresh_ens_config[nengo.Ensemble].radius = 1
        thresh_ens_config[nengo.Ensemble].intercepts = Uniform(0.5, 1.0)
        thresh_ens_config[nengo.Ensemble].encoders = Choice([[1]])
        thresh_ens_config[nengo.Ensemble].eval_points = Uniform(0.75, 1.1)
        thresh_ens_config[nengo.Ensemble].n_eval_points = self.n_eval_points

        def reroute(conns):
            # Rebuild each connection from thresh_ens.output to the same post
            # object with the same transform and synapse, then remove the
            # original from this network. Returns the new connections.
            rerouted = []
            for conn in conns:
                new_conn = nengo.Connection(self.thresh_ens.output, conn.post,
                                            transform=conn.transform,
                                            synapse=conn.synapse)
                self.connections.remove(conn)
                rerouted.append(new_conn)
            return rerouted

        with nested(self, thresh_ens_config):
            self.thresh_bias = EnsembleArray(n_neurons, self.num_items,
                                             label='thresh_bias')
            self.thresh_ens = EnsembleArray(n_neurons, self.num_items,
                                            label='thresh_ens')

            # Both arrays get a constant drive of 1 per item; thresh_bias is
            # inhibited by the element outputs and in turn inhibits
            # thresh_ens, so thresh_ens is released only where an element
            # output suppresses its thresh_bias ensemble.
            nengo.Connection(self.bias_node, self.thresh_bias.input,
                             transform=np.ones((self.num_items, 1)),
                             synapse=None)
            nengo.Connection(self.bias_node, self.thresh_ens.input,
                             transform=np.ones((self.num_items, 1)),
                             synapse=None)
            nengo.Connection(self.elem_output, self.thresh_bias.input,
                             transform=-inhibit_scale)
            nengo.Connection(self.thresh_bias.output, self.thresh_ens.input,
                             transform=-inhibit_scale)

            self.thresholded_utilities = self.thresh_ens.output

            # Reroute the default-vector inhibition first, then the output
            # connections, through thresh_ens.
            self.default_vector_inhibit_conns = reroute(
                self.default_vector_inhibit_conns)
            self.out_conns = reroute(self.out_conns)

            # Make inhibitory connection if inhibit option is set
            if self.inhibit is not None:
                for e in self.thresh_ens.ensembles:
                    nengo.Connection(self.inhibit, e,
                                     transform=-1.5, synapse=None)
Beispiel #7
0
    def __init__(self, input_vectors, output_vectors=None,  # noqa: C901
                 n_neurons=50, threshold=0.3, input_scales=1.0,
                 inhibitable=False,
                 label=None, seed=None, add_to_container=None):
        """Build the core of an associative memory network.

        One 1-D ensemble is created per item. It is fed from ``elem_input``,
        offset by that item's threshold, and decodes a filtered step function
        into ``elem_output``. The ``"input"`` and ``"output"`` mappings are
        then attached with ``add_input_mapping`` and ``add_output_mapping``.

        Parameters
        ----------
        input_vectors : array_like
            One row per item.
        output_vectors : array_like, optional
            One row per item; defaults to ``input_vectors`` (autoassociative
            memory).
        n_neurons : int, optional
            Neurons per item ensemble.
        threshold : float or array_like, optional
            Per-item threshold (a scalar is used for every item), subtracted
            from each ensemble's input via the bias node.
        input_scales : float or array_like, optional
            Passed to ``add_input_mapping`` for the ``"input"`` mapping.
        inhibitable : bool, optional
            If True, create an ``inhibit`` node that suppresses every item
            with gain ``self._inhib_scale``.
        label, seed, add_to_container
            Forwarded to the parent network constructor.

        Raises
        ------
        ValueError
            If there are no input vectors, or the number of output vectors or
            threshold values does not match the number of input vectors.
        """
        super(AssociativeMemory, self).__init__(label, seed, add_to_container)

        # --- Put arguments in canonical form
        if output_vectors is None:
            # If output vocabulary is not specified, use input vector list
            # (i.e autoassociative memory)
            output_vectors = input_vectors
        if is_iterable(input_vectors):
            input_vectors = np.array(input_vectors, ndmin=2)
        if is_iterable(output_vectors):
            output_vectors = np.array(output_vectors, ndmin=2)

        # --- Check item counts (done once; a second, identical
        # input/output comparison would be unreachable)
        if input_vectors.shape[0] == 0:
            raise ValueError('Number of input vectors cannot be 0.')
        elif input_vectors.shape[0] != output_vectors.shape[0]:
            # Fail if number of input items and number of output items don't
            # match
            raise ValueError(
                'Number of input vectors does not match number of output '
                'vectors. %d != %d'
                % (input_vectors.shape[0], output_vectors.shape[0]))

        # Handle possible different threshold values for each element in the
        # associative memory (a scalar is broadcast to every item)
        if not is_iterable(threshold):
            threshold = threshold * np.ones(input_vectors.shape[0])
        else:
            threshold = np.array(threshold)

        self.n_items = input_vectors.shape[0]
        if threshold.shape[0] != self.n_items:
            raise ValueError(
                "Number of threshold values (%d) does not match number of "
                "input vectors (%d)." % (threshold.shape[0], self.n_items))

        # --- Set parameters
        self.out_conns = []  # Used in `add_threshold_to_output`
        # Used in `add_threshold_to_output`
        self.default_vector_inhibit_conns = []
        self.thresh_ens = None  # Will hold thresholded outputs
        self.is_wta = False
        self._inhib_scale = 1.5

        # -- Create the core network
        with nested(self, self.am_ens_config):
            self.bias_node = nengo.Node(output=1)
            self.elem_input = nengo.Node(
                size_in=self.n_items, label="element input")
            self.elem_output = nengo.Node(
                size_in=self.n_items, label="element output")
            self.utilities = self.elem_output

            self.am_ensembles = []
            label_prefix = "" if label is None else label + "_"
            filt_scale = 15
            filt_step_func = lambda x: filtered_step(x, 0.0, scale=filt_scale)
            for i in range(self.n_items):
                e = nengo.Ensemble(n_neurons, 1, label=label_prefix + str(i))
                self.am_ensembles.append(e)

                # Connect input and output nodes; the bias connection
                # subtracts this item's threshold
                nengo.Connection(self.bias_node, e, transform=-threshold[i])
                nengo.Connection(self.elem_input[i], e)
                nengo.Connection(
                    e, self.elem_output[i], function=filt_step_func)

            if inhibitable:
                # Input node for inhibitory gating signal (if enabled)
                self.inhibit = nengo.Node(size_in=1, label="inhibit")
                nengo.Connection(self.inhibit, self.elem_input,
                                 transform=-np.ones((self.n_items, 1))
                                 * self._inhib_scale)
                # Note: We can use a decoded connection here because all the
                # am_ensembles have [1] encoders
            else:
                self.inhibit = None
        self.add_input_mapping("input", input_vectors, input_scales)
        self.add_output_mapping("output", output_vectors)
Beispiel #8
0
def BasalGanglia(dimensions, n_neurons_per_ensemble=100, output_weight=-3,
                 input_bias=0.0, ampa_config=None, gaba_config=None, net=None):
    """Winner takes all; outputs 0 at max dimension, negative elsewhere.

    Builds striatum (D1, D2), subthalamic nucleus, and globus pallidus
    (internus, externus) EnsembleArrays with one ensemble per dimension, and
    wires them with the weights and functions defined on ``Weights``.

    Parameters
    ----------
    dimensions : int
        Number of ensembles in each nucleus (size of ``input``/``output``).
    n_neurons_per_ensemble : int, optional
        Neurons in every ensemble of every EnsembleArray.
    output_weight : float, optional
        Transform from the GPi output onto ``net.output``.
    input_bias : float, optional
        Constant added to every input dimension; no bias node is created
        when it is 0.
    ampa_config, gaba_config : nengo.Config, optional
        Configs active for the excitatory (STN) and inhibitory (striatum,
        GPe) projections. ``config_with_default_synapse`` supplies a default
        Lowpass synapse (2 ms / 8 ms). Any synapse it reports as overridden
        is deleted again before this function returns or raises.
    net : nengo.Network, optional
        Network to build into; a new "Basal Ganglia" network if omitted.

    Returns
    -------
    nengo.Network
        ``net`` with ``input``/``output`` nodes and the ``strD1``, ``strD2``,
        ``stn``, ``gpi`` and ``gpe`` EnsembleArrays.
    """

    if net is None:
        net = nengo.Network("Basal Ganglia")

    ampa_config, override_ampa = config_with_default_synapse(
        ampa_config, nengo.Lowpass(0.002))
    gaba_config, override_gaba = config_with_default_synapse(
        gaba_config, nengo.Lowpass(0.008))

    # Everything after the overrides runs under try/finally so the caller's
    # configs are restored even if the build (or a warning promoted to an
    # error) raises.
    try:
        # Affects all ensembles / connections in the BG
        # unless they've been overridden on `net.config`
        config = nengo.Config(nengo.Ensemble, nengo.Connection)
        config[nengo.Ensemble].radius = 1.5
        config[nengo.Ensemble].encoders = Choice([[1]])
        try:
            # Best, if we have SciPy
            config[nengo.Connection].solver = NnlsL2nz()
        except ImportError:
            # Warn if we can't use the better decoder solver.
            warnings.warn("SciPy is not installed, so BasalGanglia will "
                          "use the default decoder solver. Installing SciPy "
                          "may improve BasalGanglia performance.")

        ea_params = {'n_neurons': n_neurons_per_ensemble,
                     'n_ensembles': dimensions}

        with nested(config, net):
            net.strD1 = EnsembleArray(label="Striatal D1 neurons",
                                      intercepts=Uniform(Weights.e, 1),
                                      **ea_params)
            net.strD2 = EnsembleArray(label="Striatal D2 neurons",
                                      intercepts=Uniform(Weights.e, 1),
                                      **ea_params)
            net.stn = EnsembleArray(label="Subthalamic nucleus",
                                    intercepts=Uniform(Weights.ep, 1),
                                    **ea_params)
            net.gpi = EnsembleArray(label="Globus pallidus internus",
                                    intercepts=Uniform(Weights.eg, 1),
                                    **ea_params)
            net.gpe = EnsembleArray(label="Globus pallidus externus",
                                    intercepts=Uniform(Weights.ee, 1),
                                    **ea_params)

            net.input = nengo.Node(label="input", size_in=dimensions)
            net.output = nengo.Node(label="output", size_in=dimensions)

            # add bias input (BG performs best in the range 0.5--1.5)
            if abs(input_bias) > 0.0:
                net.bias_input = nengo.Node(np.ones(dimensions) * input_bias)
                nengo.Connection(net.bias_input, net.input)

            # spread the input to StrD1, StrD2, and STN
            nengo.Connection(net.input, net.strD1.input, synapse=None,
                             transform=Weights.ws * (1 + Weights.lg))
            nengo.Connection(net.input, net.strD2.input, synapse=None,
                             transform=Weights.ws * (1 - Weights.le))
            nengo.Connection(net.input, net.stn.input, synapse=None,
                             transform=Weights.wt)

            # connect the striatum to the GPi and GPe (inhibitory)
            strD1_output = net.strD1.add_output('func_str', Weights.str_func)
            strD2_output = net.strD2.add_output('func_str', Weights.str_func)
            with gaba_config:
                nengo.Connection(strD1_output, net.gpi.input,
                                 transform=-Weights.wm)
                nengo.Connection(strD2_output, net.gpe.input,
                                 transform=-Weights.wm)

            # connect the STN to GPi and GPe (broad and excitatory)
            tr = Weights.wp * np.ones((dimensions, dimensions))
            stn_output = net.stn.add_output('func_stn', Weights.stn_func)
            with ampa_config:
                nengo.Connection(stn_output, net.gpi.input, transform=tr)
                nengo.Connection(stn_output, net.gpe.input, transform=tr)

            # connect the GPe to GPi and STN (inhibitory)
            gpe_output = net.gpe.add_output('func_gpe', Weights.gpe_func)
            with gaba_config:
                nengo.Connection(gpe_output, net.gpi.input,
                                 transform=-Weights.we)
                nengo.Connection(gpe_output, net.stn.input,
                                 transform=-Weights.wg)

            # connect GPi to output (inhibitory)
            gpi_output = net.gpi.add_output('func_gpi', Weights.gpi_func)
            nengo.Connection(gpi_output, net.output, synapse=None,
                             transform=output_weight)
    finally:
        # Return ampa_config and gaba_config to previous states, if changed
        if override_ampa:
            del ampa_config[nengo.Connection].synapse
        if override_gaba:
            del gaba_config[nengo.Connection].synapse

    return net
def BasalGanglia(dimensions,
                 n_neurons_per_ensemble=100,
                 output_weight=-3,
                 input_bias=0.0,
                 ampa_config=None,
                 gaba_config=None,
                 net=None):
    """Winner takes all; outputs 0 at max dimension, negative elsewhere.

    Builds striatum (D1, D2), subthalamic nucleus, and globus pallidus
    (internus, externus) EnsembleArrays with one ensemble per dimension, and
    wires them with the weights and functions defined on ``Weights``.

    Parameters
    ----------
    dimensions : int
        Number of ensembles in each nucleus (size of ``input``/``output``).
    n_neurons_per_ensemble : int, optional
        Neurons in every ensemble of every EnsembleArray.
    output_weight : float, optional
        Transform from the GPi output onto ``net.output``.
    input_bias : float, optional
        Constant added to every input dimension; no bias node is created
        when it is 0.
    ampa_config, gaba_config : nengo.Config, optional
        Configs active for the excitatory (STN) and inhibitory (striatum,
        GPe) projections. ``config_with_default_synapse`` supplies a default
        Lowpass synapse (2 ms / 8 ms). Any synapse it reports as overridden
        is deleted again before this function returns or raises.
    net : nengo.Network, optional
        Network to build into; a new "Basal Ganglia" network if omitted.

    Returns
    -------
    nengo.Network
        ``net`` with ``input``/``output`` nodes and the ``strD1``, ``strD2``,
        ``stn``, ``gpi`` and ``gpe`` EnsembleArrays.
    """

    if net is None:
        net = nengo.Network("Basal Ganglia")

    ampa_config, override_ampa = config_with_default_synapse(
        ampa_config, nengo.Lowpass(0.002))
    gaba_config, override_gaba = config_with_default_synapse(
        gaba_config, nengo.Lowpass(0.008))

    # Everything after the overrides runs under try/finally so the caller's
    # configs are restored even if the build (or a warning promoted to an
    # error) raises.
    try:
        # Affects all ensembles / connections in the BG
        # unless they've been overridden on `net.config`
        config = nengo.Config(nengo.Ensemble, nengo.Connection)
        config[nengo.Ensemble].radius = 1.5
        config[nengo.Ensemble].encoders = Choice([[1]])
        try:
            # Best, if we have SciPy
            config[nengo.Connection].solver = NnlsL2nz()
        except ImportError:
            # Warn if we can't use the better decoder solver.
            warnings.warn("SciPy is not installed, so BasalGanglia will "
                          "use the default decoder solver. Installing SciPy "
                          "may improve BasalGanglia performance.")

        ea_params = {
            'n_neurons': n_neurons_per_ensemble,
            'n_ensembles': dimensions
        }

        with nested(config, net):
            net.strD1 = EnsembleArray(label="Striatal D1 neurons",
                                      intercepts=Uniform(Weights.e, 1),
                                      **ea_params)
            net.strD2 = EnsembleArray(label="Striatal D2 neurons",
                                      intercepts=Uniform(Weights.e, 1),
                                      **ea_params)
            net.stn = EnsembleArray(label="Subthalamic nucleus",
                                    intercepts=Uniform(Weights.ep, 1),
                                    **ea_params)
            net.gpi = EnsembleArray(label="Globus pallidus internus",
                                    intercepts=Uniform(Weights.eg, 1),
                                    **ea_params)
            net.gpe = EnsembleArray(label="Globus pallidus externus",
                                    intercepts=Uniform(Weights.ee, 1),
                                    **ea_params)

            net.input = nengo.Node(label="input", size_in=dimensions)
            net.output = nengo.Node(label="output", size_in=dimensions)

            # add bias input (BG performs best in the range 0.5--1.5)
            if abs(input_bias) > 0.0:
                net.bias_input = nengo.Node(np.ones(dimensions) * input_bias,
                                            label="basal ganglia bias")
                nengo.Connection(net.bias_input, net.input)

            # spread the input to StrD1, StrD2, and STN
            nengo.Connection(net.input,
                             net.strD1.input,
                             synapse=None,
                             transform=Weights.ws * (1 + Weights.lg))
            nengo.Connection(net.input,
                             net.strD2.input,
                             synapse=None,
                             transform=Weights.ws * (1 - Weights.le))
            nengo.Connection(net.input,
                             net.stn.input,
                             synapse=None,
                             transform=Weights.wt)

            # connect the striatum to the GPi and GPe (inhibitory)
            strD1_output = net.strD1.add_output('func_str', Weights.str_func)
            strD2_output = net.strD2.add_output('func_str', Weights.str_func)
            with gaba_config:
                nengo.Connection(strD1_output,
                                 net.gpi.input,
                                 transform=-Weights.wm)
                nengo.Connection(strD2_output,
                                 net.gpe.input,
                                 transform=-Weights.wm)

            # connect the STN to GPi and GPe (broad and excitatory)
            tr = Weights.wp * np.ones((dimensions, dimensions))
            stn_output = net.stn.add_output('func_stn', Weights.stn_func)
            with ampa_config:
                nengo.Connection(stn_output, net.gpi.input, transform=tr)
                nengo.Connection(stn_output, net.gpe.input, transform=tr)

            # connect the GPe to GPi and STN (inhibitory)
            gpe_output = net.gpe.add_output('func_gpe', Weights.gpe_func)
            with gaba_config:
                nengo.Connection(gpe_output,
                                 net.gpi.input,
                                 transform=-Weights.we)
                nengo.Connection(gpe_output,
                                 net.stn.input,
                                 transform=-Weights.wg)

            # connect GPi to output (inhibitory)
            gpi_output = net.gpi.add_output('func_gpi', Weights.gpi_func)
            nengo.Connection(gpi_output,
                             net.output,
                             synapse=None,
                             transform=output_weight)
    finally:
        # Return ampa_config and gaba_config to previous states, if changed
        if override_ampa:
            del ampa_config[nengo.Connection].synapse
        if override_gaba:
            del gaba_config[nengo.Connection].synapse

    return net