예제 #1
0
def test_signal_init_values(RefSimulator):
    """Tests that initial values are not overwritten.

    Builds a model whose DotInc operators only multiply by all-zero
    signals, then checks every signal's value both before and after a
    single simulator step.
    """
    zero = Signal([0])
    one = Signal([1])
    five = Signal([5.0])
    zero_column = Signal([[0], [0], [0]])
    vec = Signal([1, 2, 3])

    model = Model(dt=0)
    model.operators += [
        PreserveValue(five),
        PreserveValue(vec),
        # Both DotIncs use an all-zero first operand, so their products
        # are zero.
        DotInc(zero, zero, five),
        DotInc(zero_column, one, vec),
    ]

    sim = RefSimulator(None, model=model)
    expected_vec = np.array([1, 2, 3])

    def check_values():
        assert sim.signals[zero][0] == 0
        assert sim.signals[one][0] == 1
        assert sim.signals[five][0] == 5.0
        assert np.all(expected_vec == sim.signals[vec])

    check_values()  # initial state
    sim.step()
    check_values()  # still unchanged after one step
예제 #2
0
def build_connection(model, conn):
    """Build ``conn`` into ``model``.

    The pre object's 'out' signal is routed to the post object's 'in'
    signal in three stages:

    1. Pick the signal that gets multiplied by the weights: the (sliced)
       pre output, or the output of a ``SimPyFunc`` evaluating
       ``conn.function`` when the pre object is a Node or a direct-mode
       Ensemble.
    2. Multiply it by the weights into a freshly reset signal
       (``ElementwiseInc`` for scalar/vector weights, ``DotInc`` for
       matrices), optionally filtering the result with ``conn.synapse``.
    3. Add the result into the post signal with ``SlicedCopy``.

    If the connection has a learning rule, the weight signal is made
    writable, wrapped in ``PreserveValue``, and the rule(s) are built.
    ``model.params[conn]`` is set to a ``BuiltConnection``.

    Raises
    ------
    ValueError
        If pre/post is missing from ``model.sig`` or lacks the required
        'out'/'in' entry, or if a learning rule is attached while the
        weights are not 2-D.
    """
    rng = np.random.RandomState(model.seeds[conn])

    def lookup_endpoint(is_pre):
        # Connections read the pre object's 'out' signal and write the
        # post object's 'in' signal.
        if is_pre:
            obj, role, port = conn.pre_obj, 'pre', 'out'
        else:
            obj, role, port = conn.post_obj, 'post', 'in'

        if obj not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, role, obj))
        obj_sigs = model.sig[obj]
        if port not in obj_sigs:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, role, obj, port))
        return obj_sigs[port]

    model.sig[conn]['in'] = lookup_endpoint(is_pre=True)
    model.sig[conn]['out'] = lookup_endpoint(is_pre=False)
    sigs = model.sig[conn]

    weights = eval_points = solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    pre = conn.pre_obj
    in_signal = sigs['in']
    pre_is_direct = isinstance(pre, Node) or (
        isinstance(pre, Ensemble) and isinstance(pre.neuron_type, Direct))

    if pre_is_direct:
        # Node output, or a decoded connection from a direct-mode Ensemble:
        # no decoders are solved; the function (if any) runs in Python.
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        else:
            in_signal = Signal(np.zeros(conn.size_mid),
                               name='%s.func' % conn)
            model.add_op(SimPyFunc(output=in_signal,
                                   fn=conn.function,
                                   t_in=False,
                                   x=sliced_in))
    elif isinstance(pre, Ensemble):
        # Normal decoded connection.
        eval_points, decoders, solver_info = build_decoders(model, conn, rng)
        if not conn.solver.weights:
            weights = multiply(conn.transform, decoders.T)
        else:
            # With a weight solver the solved "decoders" are used directly
            # as weights and the output goes to the post ensemble's neurons.
            sigs['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = Ellipsis  # don't apply slice later
            weights = decoders.T
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    if weights is None:
        # Non-decoded connections use the transform itself as the weights.
        weights = np.array(conn.transform)

    if isinstance(conn.post_obj, Neurons):
        # Fold the post ensemble's gain into the weights.
        weights = multiply(
            model.params[conn.post_obj.ensemble].gain[post_slice], weights)

    if conn.learning_rule is not None and weights.ndim < 2:
        raise ValueError("Learning connection must have full transform matrix")

    sigs['weights'] = Signal(weights,
                             name="%s.weights" % conn,
                             readonly=True)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    inc_op = DotInc if weights.ndim >= 2 else ElementwiseInc
    model.add_op(inc_op(sigs['weights'],
                        in_signal,
                        signal,
                        tag="%s.weights_elementwiseinc" % conn))

    if conn.synapse is not None:
        signal = model.build(conn.synapse, signal)

    # Accumulate into (the selected slice of) the post signal.
    model.add_op(SlicedCopy(signal,
                            sigs['out'],
                            b_slice=post_slice,
                            inc=True,
                            tag="%s.gain" % conn))

    rule = conn.learning_rule
    if rule is not None:
        # Learning rules modify the weights, so the weight signal is made
        # writable and wrapped in PreserveValue before the rules are built.
        learned = sigs['weights']
        learned.readonly = False
        model.add_op(PreserveValue(learned))

        if not is_iterable(rule):
            model.build(rule)
        else:
            for r in (itervalues(rule) if isinstance(rule, dict) else rule):
                model.build(r)

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         weights=weights)
예제 #3
0
def build_connection(model, conn):
    """Build ``conn`` into ``model`` with separate decoder and transform
    signals.

    The signal path is::

        pre 'out' -> [SimPyFunc | decoder DotInc | nothing]
                  -> [synapse filter]
                  -> transform (ElementwiseInc or DotInc)
                  -> post 'in'

    * Node or direct-mode Ensemble pre: if there is no function and the
      pre slice is a ``slice`` with step None or 1, the input signal is
      indexed directly; otherwise a ``SimPyFunc`` applies the slice (and
      the function, if any).
    * Other Ensemble pre: decoders are solved through the model's decoder
      cache and applied with ``DotInc``. With a weight solver the transform
      is folded into the solve targets and the output goes to the post
      ensemble's neurons.
    * Anything else: the pre output is used unchanged.

    For a ``Neurons`` post the transform is scaled by the post ensemble's
    gain. Learning rules are built last. Sets ``model.params[conn]`` to a
    ``BuiltConnection``.

    Raises
    ------
    ValueError
        If pre/post is missing from ``model.sig`` or lacks the required
        'out'/'in' entry.
    RuntimeError
        If the post ``Neurons`` belong to an Ensemble that was not built.
    NotImplementedError
        If a connection to ``Neurons`` has a post slice.
    """
    rng = np.random.RandomState(model.seeds[conn])

    def lookup_endpoint(is_pre):
        # Connections read the pre object's 'out' signal and write the
        # post object's 'in' signal.
        if is_pre:
            obj, role, port = conn.pre_obj, 'pre', 'out'
        else:
            obj, role, port = conn.post_obj, 'post', 'in'

        if obj not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, role, obj))
        obj_sigs = model.sig[obj]
        if port not in obj_sigs:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, role, obj, port))
        return obj_sigs[port]

    model.sig[conn]['in'] = lookup_endpoint(is_pre=True)
    model.sig[conn]['out'] = lookup_endpoint(is_pre=False)
    sigs = model.sig[conn]

    decoders = eval_points = solver_info = None
    transform = full_transform(conn, slice_pre=False)

    pre = conn.pre_obj
    pre_is_direct = isinstance(pre, Node) or (
        isinstance(pre, Ensemble) and isinstance(pre.neuron_type, Direct))

    if pre_is_direct:
        needs_pyfunc = not (
            conn.function is None
            and isinstance(conn.pre_slice, slice)
            and (conn.pre_slice.step is None or conn.pre_slice.step == 1))
        if needs_pyfunc:
            signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            if conn.function is None:
                fn = lambda x: x[conn.pre_slice]  # noqa: E731
            else:
                fn = lambda x: conn.function(x[conn.pre_slice])  # noqa: E731
            model.add_op(SimPyFunc(output=signal,
                                   fn=fn,
                                   t_in=False,
                                   x=sigs['in']))
        else:
            # No function and a unit-step slice: index the input directly.
            signal = sigs['in'][conn.pre_slice]
    elif isinstance(pre, Ensemble):
        # Normal decoded connection.
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)
        solver = model.decoder_cache.wrap_solver(conn.solver)

        if not conn.solver.weights:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid
        else:
            # Solve for weights that already include the transform, leaving
            # a unit transform to apply afterwards.
            targets = np.dot(targets, transform.T)
            transform = np.array(1., dtype=np.float64)
            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            sigs['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = sigs['out'].size

        decoders = decoders.T
        sigs['decoders'] = Signal(decoders, name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(DotInc(sigs['decoders'],
                            sigs['in'],
                            signal,
                            tag="%s decoding" % conn))
    else:
        # Direct connection: pass the pre output through unchanged.
        signal = sigs['in']

    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    post = conn.post_obj
    if isinstance(post, Neurons):
        if not model.has_built(post.ensemble):
            # An unbuilt ensemble was never added to the Network, most
            # likely because these Neurons are not attached to an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))
        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        gain = model.params[post.ensemble].gain[conn.post_slice]
        if transform.ndim >= 2:
            transform *= gain[:, np.newaxis]
        else:
            transform = transform * gain

    sigs['transform'] = Signal(transform, name="%s.transform" % conn)
    inc_op = DotInc if transform.ndim >= 2 else ElementwiseInc
    model.add_op(inc_op(sigs['transform'],
                        signal,
                        sigs['out'],
                        tag=str(conn)))

    if conn.learning_rule:
        pre_is_ensemble = isinstance(conn.pre_obj, Ensemble)
        learned_key = 'decoders' if pre_is_ensemble else 'transform'
        model.add_op(PreserveValue(sigs[learned_key]))

        if pre_is_ensemble and conn.solver.weights:
            # TODO: make less hacky. With a weight solver the "decoders"
            # signal is really the weight matrix, which is what learning
            # rules should operate on, so expose it as 'transform'.
            sigs['transform'] = sigs['decoders']

        rule = conn.learning_rule
        if not is_iterable(rule):
            model.build(rule)
        else:
            for r in (itervalues(rule) if isinstance(rule, dict) else rule):
                model.build(r)

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
예제 #4
0
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Make the signals modified by the learning rules (weights and/or
       encoders) writable and add ``PreserveValue`` operators for them.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre or post object is not in ``model.sig`` or lacks the
        required ``'out'``/``'in'`` signal, or if a learning rule modifies
        ``'decoders'``/``'weights'`` while the weights are not 2-D.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # Returns pre_obj's 'out' signal (is_pre=True) or post_obj's 'in'
        # signal (is_pre=False), raising BuildError if it is missing.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero." %
                (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    # Defaults; the decoded-Ensemble branch below overrides signal_size and
    # post_slice when a weight solver targets the post neurons directly.
    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Sample transform if given a distribution
    # (this draws from rng before build_decoders below also receives rng).
    transform = (conn.transform.sample(
        conn.size_out, conn.size_mid, rng=rng) if isinstance(
            conn.transform, Distribution) else np.array(conn.transform))

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        weights = transform
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is not None:
            # Evaluate conn.function in Python on the sliced input.
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
        else:
            in_signal = sliced_in
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        # build_decoders receives the transform and its result is used
        # directly as the weights, so it presumably already includes the
        # transform -- TODO confirm in build_decoders.
        eval_points, weights, solver_info = build_decoders(
            model, conn, rng, transform)
        if conn.solver.weights:
            # Weight solver: write straight into the post ensemble's neurons.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = Ellipsis  # don't apply slice later
    else:
        # Any other pre object: slice the input and use the transform.
        weights = transform
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    if isinstance(conn.post_obj, Neurons):
        # Fold the post ensemble's gain (indexed by post_slice) into the
        # weights.
        weights = multiply(
            model.params[conn.post_obj.ensemble].gain[post_slice], weights)

    # Add operator for applying weights
    # The weights signal starts read-only; it is made writable further down
    # only if a learning rule modifies 'decoders' or 'weights'.
    model.sig[conn]['weights'] = Signal(weights,
                                        name="%s.weights" % conn,
                                        readonly=True)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    # ElementwiseInc for scalar/vector weights, DotInc for a matrix.
    # NOTE(review): the tag below says "elementwiseinc" even when op is DotInc.
    op = ElementwiseInc if weights.ndim < 2 else DotInc
    model.add_op(
        op(model.sig[conn]['weights'],
           in_signal,
           signal,
           tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = model.build(conn.synapse, signal)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = signal

    # Copy to the proper slice
    model.add_op(
        SlicedCopy(signal,
                   model.sig[conn]['out'],
                   dst_slice=post_slice,
                   inc=True,
                   tag="%s.gain" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        rule = conn.learning_rule
        # Wrap a single rule in a list; dicts and other iterables are kept.
        rule = [rule] if not is_iterable(rule) else rule
        # Collect the name of the signal each rule modifies (r.modifies).
        targets = []
        for r in itervalues(rule) if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        if 'encoders' in targets:
            # The encoder signal belongs to the post object and may already
            # have a PreserveValue (e.g. from another learning connection);
            # add at most one.
            encoder_sig = model.sig[conn.post_obj]['encoders']
            if not any(
                    isinstance(op, PreserveValue) and op.dst is encoder_sig
                    for op in model.operators):
                encoder_sig.readonly = False
                model.add_op(PreserveValue(encoder_sig))
        if 'decoders' in targets or 'weights' in targets:
            # Learned weights must be a full matrix, not a scalar/vector.
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False
            model.add_op(PreserveValue(model.sig[conn]['weights']))

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=weights)
예제 #5
0
def build_connection(model, conn):
    """Build ``conn`` into ``model`` (variant supporting modulatory
    connections).

    The signal path is::

        pre 'out' -> [Python function | decoder DotInc | nothing]
                  -> [synapse filter]
                  -> transform (ElementwiseInc or DotInc)
                  -> post 'in' (or a detached signal if modulatory)

    * Node or direct-mode Ensemble pre: with no function and a ``slice``
      pre slice of step None or 1, the input is indexed directly; otherwise
      ``build_pyfunc`` creates a Python function node fed from the input.
    * Other Ensemble pre: decoders are solved through the model's decoder
      cache and applied with ``DotInc``. With a weight solver the transform
      is folded into the solve targets and the output goes to the post
      ensemble's neurons.
    * Anything else: the pre output is used unchanged.

    A modulatory connection writes into a new, reset ``mod_output`` signal
    instead of the post object's input. For a ``Neurons`` post the
    transform is scaled by the post ensemble's gain. If
    ``conn.learning_rule_type`` is set, the learned signal gets a
    ``PreserveValue`` operator. Sets ``model.params[conn]`` to a
    ``BuiltConnection``.

    Raises
    ------
    ValueError
        If pre/post is missing from ``model.sig`` or lacks the required
        'out'/'in' entry.
    RuntimeError
        If the post ``Neurons`` belong to an Ensemble that was not built.
    NotImplementedError
        If a connection to ``Neurons`` has a post slice.
    TypeError
        If a learning rule type is set but the pre object is neither
        ``Neurons`` nor ``Ensemble``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # Returns pre_obj's 'out' signal (is_pre=True) or post_obj's 'in'
        # signal (is_pre=False), raising ValueError if it is missing.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # No function and a unit-step slice: index the input directly.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Run the (sliced) function in Python; build_pyfunc presumably
            # returns the function's (input, output) signals.
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None else
                (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            # Feed the full pre output into the function's input signal.
            # NOTE(review): model.sig['common'][1] is presumably a constant
            # one -- confirm.
            model.add_op(
                DotInc(model.sig[conn]['in'],
                       model.sig['common'][1],
                       sig_in,
                       tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform
            targets = np.dot(targets, transform.T)
            transform = np.array(1, dtype=rc.get('precision', 'dtype'))

            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = model.Signal(decoders,
                                                   name="%s.decoders" % conn)
        signal = model.Signal(npext.castDecimal(np.zeros(signal_size)),
                              name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = model.Signal(
            npext.castDecimal(np.zeros(model.sig[conn]['out'].size)),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = model.Signal(transform,
                                                name="%s.transform" % conn)
    # ElementwiseInc for scalar/vector transforms, DotInc for matrices.
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.

        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError(
                "Can't apply learning rules to connections of "
                "this type. pre type: %s, post type: %s" %
                (type(conn.pre_obj).__name__, type(conn.post_obj).__name__))

        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)