def build_ensemble(model, ens):
    """Builds an `.Ensemble` object into a model.

    A brief summary of what happens in the ensemble build process, in order:

    1. Generate evaluation points and encoders.
    2. Normalize encoders to unit length.
    3. Determine bias and gain.
    4. Create neuron input signal.
    5. Add operator for injecting bias.
    6. Call build function for neuron type.
    7. Scale encoders by gain and radius.
    8. Add operators for multiplying decoded input signal by encoders and
       incrementing the result in the neuron input signal.
    9. Call build function for injected noise.

    Some of these steps may be altered or omitted depending on the parameters
    of the ensemble, in particular the neuron type. For example, most steps are
    omitted for the `.Direct` neuron type: the encoders are the identity, no
    bias operator or neuron operator is added, and the neuron "output" signal
    is the same object as the neuron "input" signal.

    Parameters
    ----------
    model : Model
        The model to build into.
    ens : Ensemble
        The ensemble to build.

    Notes
    -----
    Sets ``model.sig[ens]`` entries ``'in'``, ``'encoders'`` and ``'out'``,
    sets ``model.sig[ens.neurons]`` entries ``'in'`` and ``'out'`` (plus
    ``'bias'`` for non-`.Direct` neuron types), and sets ``model.params[ens]``
    to a `.BuiltEnsemble` instance.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[ens])

    eval_points = gen_eval_points(ens, ens.eval_points, rng=rng)

    # Set up signal: the decoded (vector-space) input of the ensemble, which
    # is reset every step so incoming connections can increment into it
    model.sig[ens]['in'] = Signal(np.zeros(ens.dimensions),
                                  name="%s.signal" % ens)
    model.add_op(Reset(model.sig[ens]['in']))

    # Set up encoders
    if isinstance(ens.neuron_type, Direct):
        # Direct mode passes the vector straight through
        encoders = np.identity(ens.dimensions)
    elif isinstance(ens.encoders, Distribution):
        encoders = get_samples(
            ens.encoders, ens.n_neurons, ens.dimensions, rng=rng)
    else:
        encoders = npext.array(ens.encoders, min_dims=2, dtype=np.float64)
    if ens.normalize_encoders:
        # scale each encoder (row) to unit length
        encoders /= npext.norm(encoders, axis=1, keepdims=True)

    # Build the neurons
    # (gain/bias are computed for every neuron type; for Direct they are
    # only recorded in the BuiltEnsemble below)
    gain, bias, max_rates, intercepts = get_gain_bias(ens, rng)

    if isinstance(ens.neuron_type, Direct):
        # In Direct mode the "neuron" input and output are the same signal
        model.sig[ens.neurons]['in'] = Signal(
            np.zeros(ens.dimensions), name='%s.neuron_in' % ens)
        model.sig[ens.neurons]['out'] = model.sig[ens.neurons]['in']
        model.add_op(Reset(model.sig[ens.neurons]['in']))
    else:
        model.sig[ens.neurons]['in'] = Signal(
            np.zeros(ens.n_neurons), name="%s.neuron_in" % ens)
        model.sig[ens.neurons]['out'] = Signal(
            np.zeros(ens.n_neurons), name="%s.neuron_out" % ens)
        model.sig[ens.neurons]['bias'] = Signal(
            bias, name="%s.bias" % ens, readonly=True)
        # Copying the bias (rather than resetting to zero) initialises the
        # neuron input with the bias current each step
        model.add_op(Copy(model.sig[ens.neurons]['bias'],
                          model.sig[ens.neurons]['in']))
        # This adds the neuron's operator and sets other signals
        model.build(ens.neuron_type, ens.neurons)

    # Scale the encoders
    if isinstance(ens.neuron_type, Direct):
        scaled_encoders = encoders
    else:
        # fold per-neuron gain and the ensemble radius into the encoders
        scaled_encoders = encoders * (gain / ens.radius)[:, np.newaxis]

    model.sig[ens]['encoders'] = Signal(
        scaled_encoders, name="%s.scaled_encoders" % ens, readonly=True)

    # Inject noise if specified
    if ens.noise is not None:
        model.build(ens.noise, sig_out=model.sig[ens.neurons]['in'], inc=True)

    # Encode the decoded input into the neuron input: a DotInc of the scaled
    # encoders with the ensemble input, incremented into the neuron input
    model.add_op(DotInc(
        model.sig[ens]['encoders'],
        model.sig[ens]['in'],
        model.sig[ens.neurons]['in'],
        tag="%s encoding" % ens))

    # Output is neural output
    model.sig[ens]['out'] = model.sig[ens.neurons]['out']

    model.params[ens] = BuiltEnsemble(eval_points=eval_points,
                                      encoders=encoders,
                                      intercepts=intercepts,
                                      max_rates=max_rates,
                                      scaled_encoders=scaled_encoders,
                                      gain=gain,
                                      bias=bias)
def build_connection(model, conn):  # noqa: C901
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process, in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre/post object has no usable signal, if function points are
        used on a direct-mode connection, or if a decoder/weight learning rule
        is attached to a connection without a 2-D weight matrix.
    NotImplementedError
        If a learning rule is attached to a connection whose transform is
        neither `.Dense` nor `.NoTransform`.

    Notes
    -----
    Sets ``model.sig[conn]`` entries ``"in"``, ``"out"``, ``"weights"`` and
    ``"weighted"``, and sets ``model.params[conn]`` to a `.BuiltConnection`
    instance.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = "out" if is_pre else "in"
        role = "pre" if is_pre else "post"

        if target not in model.sig:
            raise BuildError(
                f"Building {conn}: the '{role}' object {target} "
                "is not in the model, or has a size of zero.")

        signal = model.sig[target].get(key, None)
        if signal is None or signal.size == 0:
            raise BuildError(
                f"Building {conn}: the '{role}' object {target} "
                f"has a '{key}' size of zero.")

        return signal

    model.sig[conn]["in"] = get_prepost_signal(is_pre=True)
    model.sig[conn]["out"] = get_prepost_signal(is_pre=False)

    decoders = None
    encoders = None
    eval_points = None
    solver_info = None
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]["in"]
    if isinstance(conn.pre_obj, Node) or (
            isinstance(conn.pre_obj, Ensemble)
            and isinstance(conn.pre_obj.neuron_type, Direct)):
        # Node or Decoded connection in directmode: the function (if any) is
        # evaluated directly in Python instead of being decoded
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            in_signal = Signal(shape=conn.size_mid, name=f"{conn}.func")
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = model.build(
            conn.solver, conn, rng)
        if isinstance(conn.post_obj, Ensemble) and conn.solver.weights:
            # weight solvers target the post neurons directly
            model.sig[conn]["out"] = model.sig[conn.post_obj.neurons]["in"]

            encoders = model.params[conn.post_obj].scaled_encoders.T
            encoders = encoders[conn.post_slice]

            # post slice already applied to encoders (either here or in
            # `build_decoders`), so don't apply later
            post_slice = None
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Build transform
    if conn.solver.weights and not conn.solver.compositional:
        # special case for non-compositional weight solvers, where
        # the solver is solving for the full weight matrix. so we don't
        # need to combine decoders/transform/encoders.
        weighted, weights = model.build(
            Dense(decoders.shape, init=decoders), in_signal, rng=rng)
    else:
        weighted, weights = model.build(
            conn.transform, in_signal, decoders=decoders, encoders=encoders,
            rng=rng)

    model.sig[conn]["weights"] = weights

    # Build synapse
    if conn.synapse is not None:
        weighted = model.build(conn.synapse, weighted, mode="update")

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]["weighted"] = weighted

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(
            model.params[conn.post_obj.ensemble].gain[post_slice],
            name=f"{conn}.gains",
        )

        if is_integer(post_slice) or isinstance(post_slice, slice):
            sliced_out = model.sig[conn]["out"][post_slice]
        else:
            # advanced indexing not supported on Signals, so we need to set up
            # an intermediate signal and use a Copy op to perform the indexing
            sliced_out = Signal(shape=gains.shape, name=f"{conn}.sliced_out")
            model.add_op(Reset(sliced_out))
            model.add_op(
                Copy(sliced_out, model.sig[conn]["out"], dst_slice=post_slice,
                     inc=True))

        model.add_op(
            ElementwiseInc(gains, weighted, sliced_out,
                           tag=f"{conn}.gains_elementwiseinc"))
    else:
        # Copy to the proper slice
        model.add_op(
            Copy(
                weighted,
                model.sig[conn]["out"],
                dst_slice=post_slice,
                inc=True,
                tag=f"{conn}",
            ))

    # Build learning rules
    if conn.learning_rule is not None:
        # TODO: provide a general way for transforms to expose learnable params
        if not isinstance(conn.transform, (Dense, NoTransform)):
            raise NotImplementedError(
                f"Learning on connections with {type(conn.transform).__name__} "
                "transforms is not supported")

        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in rule.values() if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        if "encoders" in targets:
            encoder_sig = model.sig[conn.post_obj]["encoders"]
            encoder_sig.readonly = False
        if "decoders" in targets or "weights" in targets:
            # The transform builder may return no weight signal at all
            # (presumably for transforms that build no weight matrix; the
            # BuiltConnection below also allows for this via getattr).
            # Report that as a BuildError instead of an AttributeError.
            if weights is None or weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]["weights"].readonly = False

    model.params[conn] = BuiltConnection(
        eval_points=eval_points,
        solver_info=solver_info,
        transform=conn.transform,
        weights=getattr(weights, "initial_value", None),
    )
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process, in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre/post object has no usable signal, if function points are
        used on a direct-mode connection, or if a decoder/weight learning rule
        is attached to a connection whose weights are not 2-D.
    NotImplementedError
        If a learning rule is attached to a connection with a `.Convolution`
        transform.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero."
                % (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    encoders = None
    eval_points = None
    solver_info = None
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node) or
            (isinstance(conn.pre_obj, Ensemble) and
             isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode: the function (if any) is
        # evaluated directly in Python instead of being decoded
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = model.build(
            conn.solver, conn, rng)
        if conn.solver.weights:
            # weight solvers write straight into the post neurons' input
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']

            if isinstance(conn.post_obj, Ensemble):
                encoders = model.params[conn.post_obj].scaled_encoders.T
                encoders = encoders[conn.post_slice]

                # post slice already applied to encoders, don't apply later
                post_slice = None
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Build transform: combines transform with decoders/encoders (when set)
    # and returns the weighted signal together with the weight signal
    weighted, weights = model.build(conn.transform, in_signal,
                                    decoders=decoders, encoders=encoders,
                                    rng=rng)

    model.sig[conn]["weights"] = weights

    # Build synapse
    if conn.synapse is not None:
        weighted = model.build(conn.synapse, weighted)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = weighted

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(model.params[conn.post_obj.ensemble].gain[post_slice],
                       name="%s.gains" % conn)
        model.add_op(ElementwiseInc(
            gains, weighted, model.sig[conn]['out'][post_slice],
            tag="%s.gains_elementwiseinc" % conn))
    else:
        # Copy to the proper slice
        model.add_op(Copy(
            weighted, model.sig[conn]['out'], dst_slice=post_slice,
            inc=True, tag="%s" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        # TODO: provide a general way for transforms to expose learnable params
        if isinstance(conn.transform, Convolution):
            raise NotImplementedError(
                "Learning on convolutional connections is not supported")

        # normalise a single rule / list / dict of rules into an iterable
        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in itervalues(rule) if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        # signals modified by a learning rule must become writable
        if 'encoders' in targets:
            encoder_sig = model.sig[conn.post_obj]['encoders']
            encoder_sig.readonly = False
        if 'decoders' in targets or 'weights' in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=conn.transform,
                                         weights=weights.initial_value)
def test_mergeable():
    """Exercise ``mergeable`` on the operator combinations it must accept/reject.

    Every case calls ``mergeable(op, candidates)`` and checks that its
    truthiness matches the expected flag. The cases cover generic signal
    properties (number of sets/incs/reads/updates, dtype, shape, display
    shape) and the operator-specific rules for Copy, ElementwiseInc,
    SimPyFunc, SimNeurons, SimProcess, SimTensorNode and SimBCM.
    """

    def check(expected, op, candidates):
        # compare truthiness, exactly like a bare `assert` / `assert not`
        assert bool(mergeable(op, candidates)) is expected

    def sets_op(**signal_kwargs):
        return DummyOp(sets=[DummySignal(**signal_kwargs)])

    def elementwise(**a_kwargs):
        return ElementwiseInc(DummySignal(**a_kwargs), DummySignal(),
                              DummySignal())

    def neurons(neuron_type, **op_kwargs):
        return SimNeurons(neuron_type, DummySignal(), DummySignal(),
                          **op_kwargs)

    def process(synapse, **op_kwargs):
        return SimProcess(synapse, None, None, DummySignal(), **op_kwargs)

    # anything is mergeable with an empty list
    check(True, None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for role in ("sets", "incs", "reads", "updates"):
        check(False, DummyOp(**{role: [DummySignal()]}), [DummyOp()])
    check(True, sets_op(), [sets_op()])

    # check matching dtypes
    check(False, sets_op(dtype=np.float32), [sets_op(dtype=np.float64)])

    # shape mismatch
    check(False, sets_op(shape=(1, 2)), [sets_op(shape=(1, 3))])

    # display shape mismatch
    check(False, sets_op(base_shape=(2, 2), shape=(4, 1)),
          [sets_op(base_shape=(2, 2), shape=(1, 4))])

    # first dimension mismatch
    check(True, sets_op(shape=(3, 2)), [sets_op(shape=(4, 2))])

    # Copy (inc must match)
    check(True, Copy(DummySignal(), DummySignal(), inc=True),
          [Copy(DummySignal(), DummySignal(), inc=True)])
    check(False, Copy(DummySignal(), DummySignal(), inc=True),
          [Copy(DummySignal(), DummySignal(), inc=False)])

    # elementwise (first dimension must match)
    check(True, elementwise(), [elementwise()])
    check(True, elementwise(shape=(1,)), [elementwise(shape=())])
    check(False, elementwise(shape=(3,)), [elementwise(shape=(2,))])

    # simpyfunc (t input must match)
    time = DummySignal()
    check(True, SimPyFunc(None, None, time, None),
          [SimPyFunc(None, None, time, None)])
    check(True, SimPyFunc(None, None, None, DummySignal()),
          [SimPyFunc(None, None, None, DummySignal())])
    check(False, SimPyFunc(None, None, DummySignal(), None),
          [SimPyFunc(None, None, None, DummySignal())])

    # simneurons: check matching TF_NEURON_IMPL
    check(True, neurons(LIF()), [neurons(LIF())])
    check(False, neurons(LIF()), [neurons(LIFRate())])
    # custom vs non-custom implementation
    check(False, neurons(LIF()), [neurons(Izhikevich())])
    # non-custom matching
    check(False, neurons(Izhikevich()), [neurons(AdaptiveLIF())])
    check(False,
          neurons(Izhikevich(), states=[DummySignal(dtype=np.float32)]),
          [neurons(Izhikevich(), states=[DummySignal(dtype=np.int32)])])
    check(True,
          neurons(Izhikevich(), states=[DummySignal(shape=(3,))]),
          [neurons(Izhikevich(), states=[DummySignal(shape=(2,))])])
    check(False,
          neurons(Izhikevich(), states=[DummySignal(shape=(2, 1))]),
          [neurons(Izhikevich(), states=[DummySignal(shape=(2, 2))])])

    # simprocess: mode must match
    check(False, process(Lowpass(0), mode="inc"),
          [process(Lowpass(0), mode="set")])
    # matching TF_PROCESS_IMPL (only one entry exists at the moment, so a
    # mismatch between two implemented processes isn't possible yet)
    check(True, process(Lowpass(0)), [process(Lowpass(0))])
    # custom vs non-custom
    check(False, process(Lowpass(0)), [process(Alpha(0))])
    # non-custom matching
    check(True, process(Triangle(0)), [process(Alpha(0))])

    # simtensornode: never mergeable, not even with itself
    tensor_node = SimTensorNode(None, DummySignal(), None, DummySignal())
    check(False, tensor_node, [tensor_node])

    # learning rules
    bcm_small = SimBCM(DummySignal((4,)), DummySignal(), DummySignal(),
                       DummySignal(), DummySignal())
    bcm_large = SimBCM(DummySignal((5,)), DummySignal(), DummySignal(),
                       DummySignal(), DummySignal())
    check(False, bcm_small, [bcm_large])
def build_pes(model, pes, rule):
    """
    Builds a `nengo.PES` object into a Nengo model.

    This replaces Nengo's own PES builder so that encoders never need to be
    sliced along an axis > 0, which NengoDL does not currently support.

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        Learning rule type to build.
    rule : LearningRule
        The learning rule object being built.

    Notes
    -----
    Leaves ``model.params[]`` untouched, so it is safe to call this more than
    once with the same `nengo.PES` instance. Exposes ``"in"``, ``"error"``
    and ``"activities"`` on ``model.sig[rule]``.
    """
    conn = rule.connection

    # The incoming error connection attaches to this (reset every step)
    error = Signal(shape=(rule.size_in, ), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error

    acts = build_or_passthrough(
        model, pes.pre_synapse, model.sig[conn.pre_obj]["out"])

    if conn.is_decoded:
        # decoded connections use the error signal as-is
        local_error = error
    else:
        # project the error through the post encoders -> per-neuron error
        post = get_post_ens(conn)
        encoders = model.sig[post]["encoders"]

        full_error = error
        if conn.post_obj is not conn.post:
            # Rather than slicing the encoders along axis 1, copy the error
            # into a zero-padded signal covering the full base dimensionality
            # and use the whole encoder matrix.
            full_error = Signal(shape=(encoders.shape[1], ))
            model.add_op(Copy(error, full_error, dst_slice=conn.post_slice))

        local_error = Signal(shape=(post.n_neurons, ))
        model.add_op(Reset(local_error))
        model.add_op(
            DotInc(encoders, full_error, local_error, tag="PES:encode"))

    delta = model.sig[rule]["delta"]
    model.operators.append(
        SimPES(acts, local_error, delta, pes.learning_rate))

    # made available so they can be probed
    model.sig[rule]["error"] = error
    model.sig[rule]["activities"] = acts
def build_test_rule(model, test_rule, rule):
    """Build a minimal test learning rule.

    Creates a zero-initialised error signal sized to the connection input
    (with a Reset op), exposes its first ``rule.size_in`` entries as the
    rule's ``'in'`` signal, and copies the whole error signal into the rule's
    existing ``'delta'`` signal.
    """
    err_signal = Signal(np.zeros(rule.connection.size_in))
    model.add_op(Reset(err_signal))

    rule_sigs = model.sig[rule]
    rule_sigs['in'] = err_signal[:rule.size_in]
    model.add_op(Copy(err_signal, rule_sigs['delta']))
def test_remove_constant_copies():
    """Check when ``remove_constant_copies`` rewrites a Copy into a Reset.

    A Copy is replaced when its source has no writer or is only written by a
    Reset; it is kept when the source is labelled as a Node output, is
    trainable, or is updated/incremented/set by another op. Destination
    slicing and ``inc=True`` (which yields a ``ResetInc``) must carry over.
    """

    def check_reset(op, dst, value):
        assert isinstance(op, Reset)
        assert op.dst is dst
        assert op.value == value

    def check_untouched(ops):
        assert remove_constant_copies(ops) == ops

    # a Copy whose source has no writers becomes a Reset(0) of its target
    target = DummySignal()
    result = remove_constant_copies([Copy(DummySignal(), target)])
    assert len(result) == 1
    check_reset(result[0], target, 0)

    # a Copy reading a Node output is kept
    check_untouched(
        [Copy(DummySignal(label="<Node lorem ipsum"), DummySignal())])

    # a Copy reading a trainable signal is kept
    trainable = DummySignal()
    trainable.trainable = True
    check_untouched([Copy(trainable, DummySignal())])

    # a Copy whose source is updated / inc'd / set elsewhere is kept
    for role in ("updates", "incs", "sets"):
        src = DummySignal()
        check_untouched([Copy(src, DummySignal()), DummyOp(**{role: [src]})])

    # ops that merely read the source/target don't block the rewrite
    src, dst = DummySignal(), DummySignal()
    ops = [Copy(src, dst), DummyOp(reads=[src]), DummyOp(reads=[dst])]
    result = remove_constant_copies(ops)
    assert len(result) == 3
    assert result[1:] == ops[1:]
    check_reset(result[0], dst, 0)

    # a source driven only by Reset(2) turns the Copy into Reset(2)
    src, dst = DummySignal(), DummySignal()
    result = remove_constant_copies([Copy(src, dst), Reset(src, 2)])
    assert len(result) == 1
    check_reset(result[0], dst, 2)

    # a sliced destination becomes a Reset of the matching view
    src = DummySignal()
    base = Signal(initial_value=[0, 0])
    result = remove_constant_copies(
        [Copy(src, base, dst_slice=slice(1, 2)), Reset(src, 2)])
    assert len(result) == 1
    reset = result[0]
    assert isinstance(reset, Reset)
    assert reset.dst.shape == (1, )
    assert reset.dst.is_view
    assert reset.dst.elemoffset == 1
    assert reset.dst.base is base
    assert reset.value == 2

    # an incrementing Copy becomes a ResetInc that incs (not sets) its target
    src, dst = DummySignal(), DummySignal()
    result = remove_constant_copies([Copy(src, dst, inc=True), Reset(src, 2)])
    assert len(result) == 1
    reset_inc = result[0]
    assert isinstance(reset_inc, op_builders.ResetInc)
    assert reset_inc.dst is dst
    assert reset_inc.value == 2
    assert len(reset_inc.incs) == 1
    assert len(reset_inc.sets) == 0
def test_remove_zero_incs():
    """Check which increment ops ``remove_zero_incs`` drops.

    Ops whose input is a never-written zero signal, or is only written by a
    ``Reset(0)``, are removed (a non-inc Copy becomes a ``Reset``). Ops whose
    input is labelled as a Node output, is trainable, is updated /
    incremented / set by another op, or is reset to a non-zero value are kept.
    """
    # check that zero inputs get removed (for A or X)
    operators = [
        DotInc(DummySignal(), DummySignal(initial_value=1), DummySignal())
    ]
    assert remove_zero_incs(operators) == []

    operators = [
        DotInc(DummySignal(initial_value=1), DummySignal(), DummySignal())
    ]
    assert remove_zero_incs(operators) == []

    # check that zero inputs (copy) get removed; Copy takes (src, dst) here -
    # a third positional argument would be interpreted as `src_slice`
    operators = [Copy(DummySignal(), DummySignal(), inc=True)]
    assert remove_zero_incs(operators) == []

    # check that node inputs don't get removed
    x = DummySignal(label="<Node lorem ipsum")
    operators = [DotInc(DummySignal(initial_value=1), x, DummySignal())]
    assert remove_zero_incs(operators) == operators

    # check that zero inputs + trainable don't get removed
    x = DummySignal()
    x.trainable = True
    operators = [DotInc(DummySignal(initial_value=1), x, DummySignal())]
    assert remove_zero_incs(operators) == operators

    # check that updated / inc'd / set input doesn't get removed
    for role in ("updates", "incs", "sets"):
        x = DummySignal()
        operators = [
            DotInc(DummySignal(initial_value=1), x, DummySignal()),
            DummyOp(**{role: [x]}),
        ]
        assert remove_zero_incs(operators) == operators

    # check that Reset(0) input does get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        Reset(x),
    ]
    assert remove_zero_incs(operators) == operators[1:]

    # check that Reset(1) input does not get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        Reset(x, 1),
    ]
    assert remove_zero_incs(operators) == operators

    # check that set's get turned into a reset
    x = DummySignal()
    operators = [Copy(DummySignal(), x)]
    new_operators = remove_zero_incs(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is x
    assert new_operators[0].value == 0
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process, in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre/post object has no usable signal, if function points are
        used on a direct-mode connection, or if a decoder/weight learning rule
        is attached while the weights are not 2-D.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero."
                % (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    # `weights` is filled in by whichever pre-object branch below applies
    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Sample transform if given a distribution
    transform = get_samples(
        conn.transform, conn.size_out, d=conn.size_mid, rng=rng)

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node) or
            (isinstance(conn.pre_obj, Ensemble) and
             isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        weights = transform
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        # here the decoders (combined with the transform) are the weights
        eval_points, weights, solver_info = build_decoders(
            model, conn, rng, transform)

        if conn.solver.weights:
            # weight solvers write straight into the post neurons' input
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = None  # don't apply slice later
    else:
        weights = transform
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    if isinstance(conn.post_obj, Neurons):
        # fold the post neurons' gains into the weights
        weights = multiply(
            model.params[conn.post_obj.ensemble].gain[post_slice], weights)

    # Add operator for applying weights
    model.sig[conn]['weights'] = Signal(
        weights, name="%s.weights" % conn, readonly=True)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    # scalar/1-D weights are applied elementwise, 2-D weights as a matrix
    # product (NOTE: the tag says "elementwiseinc" in both cases)
    op = ElementwiseInc if weights.ndim < 2 else DotInc
    model.add_op(op(model.sig[conn]['weights'],
                    in_signal,
                    signal,
                    tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = model.build(conn.synapse, signal)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = signal

    # Copy to the proper slice
    model.add_op(Copy(
        signal, model.sig[conn]['out'], dst_slice=post_slice,
        inc=True, tag="%s.gain" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        # normalise a single rule / list / dict of rules into an iterable
        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in itervalues(rule) if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        # signals modified by a learning rule must become writable
        if 'encoders' in targets:
            encoder_sig = model.sig[conn.post_obj]['encoders']
            encoder_sig.readonly = False
        if 'decoders' in targets or 'weights' in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=weights)
def build_mpes(model, mpes, rule):
    """Build a memristor-based PES (mPES) learning rule into a model.

    Creates the error input signal for ``rule``, filters the presynaptic
    activities, samples per-synapse memristor parameters and initial
    memristor states, projects the error through the post ensemble's
    encoders, and appends a ``SimmPES`` operator acting on the connection's
    weights.

    Parameters
    ----------
    model : Model
        The model to build into.
    mpes : mPES
        Learning rule type to build. Its ``pre_synapse``, ``seed``,
        ``r_min``, ``r_max``, ``exponent``, ``noise_percentage`` (indices
        0-3), ``gain`` and ``learning_rate`` attributes are read here.
    rule : LearningRule
        The learning rule object being built.

    Notes
    -----
    Adds ``"pos_memristors"`` / ``"neg_memristors"`` to ``model.sig[conn]``
    and ``"in"``, ``"error"``, ``"activities"``, ``"pos_memristors"`` and
    ``"neg_memristors"`` to ``model.sig[rule]``. Reads
    ``model.sig[conn]["weights"]``, so the connection's weights must already
    have been built. Reseeds NumPy's *global* random state.
    """
    from scipy.stats import truncnorm

    conn = rule.connection

    # Create input error signal
    error = Signal(shape=(rule.size_in, ), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error  # error connection will attach here

    acts = build_or_passthrough(
        model, mpes.pre_synapse, model.sig[conn.pre_obj]["out"])

    post = get_post_ens(conn)
    encoders = model.sig[post]["encoders"]

    # one memristor pair per (encoder row, presynaptic activity) entry
    out_size = encoders.shape[0]
    in_size = acts.shape[0]
    shape = (out_size, in_size)

    def get_truncated_normal(mean, sd, low, upp):
        """Sample a `shape` array from N(mean, sd) truncated to [low, upp].

        With ``sd == 0`` (no noise) every entry is simply ``mean``.
        """
        # Check explicitly rather than relying on ZeroDivisionError: that is
        # only raised for Python floats, whereas a NumPy scalar `sd` of zero
        # silently produces inf/nan truncation bounds.
        if sd == 0:
            return np.full(shape, mean)
        return truncnorm(
            (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd
        ).rvs(out_size * in_size).reshape(shape)

    # NOTE(review): the global NumPy RNG is reseeded with the same seed before
    # each draw, so these arrays are not sampled independently (e.g.
    # `exponent_noisy` and `pos_mem_initial` share the same standard-normal
    # deviates). Kept as-is to preserve reproducibility of existing results;
    # confirm this is intended.
    np.random.seed(mpes.seed)
    r_min_noisy = get_truncated_normal(
        mpes.r_min, mpes.r_min * mpes.noise_percentage[0], 0, np.inf)
    np.random.seed(mpes.seed)
    # when noisy, r_max is lower-bounded by the largest sampled r_min
    r_max_noisy = get_truncated_normal(
        mpes.r_max, mpes.r_max * mpes.noise_percentage[1],
        np.max(r_min_noisy), np.inf)
    np.random.seed(mpes.seed)
    exponent_noisy = np.random.normal(
        mpes.exponent, np.abs(mpes.exponent) * mpes.noise_percentage[2], shape)
    np.random.seed(mpes.seed)
    pos_mem_initial = np.random.normal(
        1e8, 1e8 * mpes.noise_percentage[3], shape)
    # `mpes.seed + 1` would fail for an unset seed; presumably None means
    # "unseeded", so fall back to NumPy's own seeding in that case
    np.random.seed(None if mpes.seed is None else mpes.seed + 1)
    neg_mem_initial = np.random.normal(
        1e8, 1e8 * mpes.noise_percentage[3], shape)

    pos_memristors = Signal(shape=shape, name="mPES:pos_memristors",
                            initial_value=pos_mem_initial)
    neg_memristors = Signal(shape=shape, name="mPES:neg_memristors",
                            initial_value=neg_mem_initial)

    model.sig[conn]["pos_memristors"] = pos_memristors
    model.sig[conn]["neg_memristors"] = neg_memristors

    if conn.post_obj is not conn.post:
        # in order to avoid slicing encoders along an axis > 0, we pad
        # `error` out to the full base dimensionality and then do the
        # dotinc with the full encoder matrix
        # comes into effect when slicing post connection
        padded_error = Signal(shape=(encoders.shape[1], ))
        model.add_op(Copy(error, padded_error, dst_slice=conn.post_slice))
    else:
        padded_error = error

    # error = dot(encoders, error)
    local_error = Signal(shape=(post.n_neurons, ))
    model.add_op(Reset(local_error))
    model.add_op(
        DotInc(encoders, padded_error, local_error, tag="PES:encode"))

    model.operators.append(
        SimmPES(acts,
                local_error,
                mpes.learning_rate,
                model.sig[conn]["pos_memristors"],
                model.sig[conn]["neg_memristors"],
                model.sig[conn]["weights"],
                mpes.noise_percentage,
                mpes.gain,
                r_min_noisy,
                r_max_noisy,
                exponent_noisy))

    # expose these for probes
    model.sig[rule]["error"] = error
    model.sig[rule]["activities"] = acts
    model.sig[rule]["pos_memristors"] = pos_memristors
    model.sig[rule]["neg_memristors"] = neg_memristors