def build_bcm(model, bcm, rule):
    """Build the BCM learning rule: filtered activities, theta, delta ops."""
    connection = rule.connection

    def _ensemble_of(obj):
        # A Neurons object carries a reference back to its parent ensemble.
        return obj if isinstance(obj, Ensemble) else obj.ensemble

    pre_ens = _ensemble_of(connection.pre_obj)
    post_ens = _ensemble_of(connection.post_obj)

    transform = model.sig[connection]['transform']

    # Low-pass filter the activities on both sides of the connection.
    pre_filtered = filtered_signal(
        model, bcm, model.sig[pre_ens.neurons]['out'], bcm.pre_tau)
    post_filtered = filtered_signal(
        model, bcm, model.sig[post_ens.neurons]['out'], bcm.post_tau)
    # Modification threshold: a slower filter over the post activity.
    theta = filtered_signal(model, bcm, post_filtered, bcm.theta_tau)

    # Per-step weight-change accumulator.
    delta = Signal(np.zeros((post_ens.n_neurons, pre_ens.n_neurons)),
                   name='BCM: Delta')

    model.add_op(SimBCM(pre_filtered, post_filtered, theta, delta,
                        learning_rate=bcm.learning_rate))
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, transform,
                                tag="BCM: Inc Transform"))

    # expose for probing
    model.sig[rule]['theta'] = theta
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_bcm(model, bcm, rule):
    """Build the BCM learning rule (Decimal-signal backend variant)."""
    connection = rule.connection

    def _ensemble_of(obj):
        # A Neurons object carries a reference back to its parent ensemble.
        return obj if isinstance(obj, Ensemble) else obj.ensemble

    pre_ens = _ensemble_of(connection.pre_obj)
    post_ens = _ensemble_of(connection.post_obj)

    transform = model.sig[connection]['transform']

    # Low-pass filter the activities on both sides of the connection.
    pre_filtered = filtered_signal(
        model, bcm, model.sig[pre_ens.neurons]['out'], bcm.pre_tau)
    post_filtered = filtered_signal(
        model, bcm, model.sig[post_ens.neurons]['out'], bcm.post_tau)
    # Modification threshold: a slower filter over the post activity.
    theta = filtered_signal(model, bcm, post_filtered, bcm.theta_tau)

    # Weight-change accumulator, cast to Decimal to match this backend's
    # signal representation.
    delta = model.Signal(
        npext.castDecimal(np.zeros((post_ens.n_neurons, pre_ens.n_neurons))),
        name='BCM: Delta')

    model.add_op(SimBCM(pre_filtered, post_filtered, theta, delta,
                        learning_rate=bcm.learning_rate))
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, transform,
                                tag="BCM: Inc Transform"))

    # expose for probing
    model.sig[rule]['theta'] = theta
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_oja(model, oja, rule):
    """Build the Oja learning rule: filtered activities and delta operators."""
    connection = rule.connection

    def _ensemble_of(obj):
        # A Neurons object carries a reference back to its parent ensemble.
        return obj if isinstance(obj, Ensemble) else obj.ensemble

    pre_ens = _ensemble_of(connection.pre_obj)
    post_ens = _ensemble_of(connection.post_obj)

    transform = model.sig[connection]['transform']

    # Low-pass filter the activities on both sides of the connection.
    pre_filtered = filtered_signal(
        model, oja, model.sig[pre_ens.neurons]['out'], oja.pre_tau)
    post_filtered = filtered_signal(
        model, oja, model.sig[post_ens.neurons]['out'], oja.post_tau)

    # Per-step weight-change accumulator.
    delta = Signal(np.zeros((post_ens.n_neurons, pre_ens.n_neurons)),
                   name='Oja: Delta')

    model.add_op(SimOja(pre_filtered, post_filtered, transform, delta,
                        learning_rate=oja.learning_rate, beta=oja.beta))
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, transform,
                                tag="Oja: Inc Transform"))

    # expose for probing
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_oja(model, oja, rule):
    """Build the Oja learning rule against the connection's weight matrix."""
    connection = rule.connection

    # Low-pass filter the activities on both sides of the connection.
    pre_filtered = filtered_signal(
        model, oja, model.sig[get_pre_ens(connection).neurons]['out'],
        oja.pre_tau)
    post_filtered = filtered_signal(
        model, oja, model.sig[get_post_ens(connection).neurons]['out'],
        oja.post_tau)

    model.add_op(SimOja(pre_filtered, post_filtered,
                        model.sig[connection]['weights'],
                        model.sig[rule]['delta'],
                        learning_rate=oja.learning_rate, beta=oja.beta))

    # expose for probing
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_oja(model, oja, rule):
    """Build the Oja learning rule against the connection's transform."""
    connection = rule.connection

    # Low-pass filter the activities on both sides of the connection.
    pre_filtered = filtered_signal(
        model, oja, model.sig[get_pre_ens(connection).neurons]['out'],
        oja.pre_tau)
    post_filtered = filtered_signal(
        model, oja, model.sig[get_post_ens(connection).neurons]['out'],
        oja.post_tau)

    model.add_op(SimOja(pre_filtered, post_filtered,
                        model.sig[connection]['transform'],
                        model.sig[rule]['delta'],
                        learning_rate=oja.learning_rate, beta=oja.beta))

    # expose for probing
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_bcm(model, bcm, rule):
    """Build the BCM learning rule into the model (SimBCM op + probe signals)."""
    connection = rule.connection

    # Low-pass filter the activities on both sides of the connection.
    pre_filtered = filtered_signal(
        model, bcm, model.sig[get_pre_ens(connection).neurons]['out'],
        bcm.pre_tau)
    post_filtered = filtered_signal(
        model, bcm, model.sig[get_post_ens(connection).neurons]['out'],
        bcm.post_tau)
    # Modification threshold: a slower filter over the post activity.
    theta = filtered_signal(model, bcm, post_filtered, bcm.theta_tau)

    model.add_op(SimBCM(pre_filtered, post_filtered, theta,
                        model.sig[rule]['delta'],
                        learning_rate=bcm.learning_rate))

    # expose for probing
    model.sig[rule]['theta'] = theta
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_pes(model, pes, rule):
    """Build the PES learning rule: error input, scaled correction, delta ops."""
    connection = rule.connection

    # Error signal; external error connections attach to this input.
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error

    # Filtered presynaptic activities, viewed as a single row for the
    # broadcast with the (column) local error below.
    acts = filtered_signal(
        model, pes, model.sig[connection.pre_obj]['out'], pes.pre_tau)
    acts_view = acts.reshape((1, acts.size))

    # correction = -learning_rate * (dt / n_neurons) * error
    correction = Signal(np.zeros(error.shape), name="PES:correction")
    local_error = correction.reshape((error.size, 1))
    model.add_op(Reset(correction))

    if isinstance(connection.pre_obj, Ensemble):
        n_neurons = connection.pre_obj.n_neurons
    else:
        n_neurons = connection.pre_obj.size_in
    lr_sig = Signal(-pes.learning_rate * model.dt / n_neurons,
                    name="PES:learning_rate")
    model.add_op(DotInc(lr_sig, error, correction, tag="PES:correct"))

    neuron_to_neuron = (isinstance(connection.pre_obj, Neurons)
                        and isinstance(connection.post_obj, Neurons))
    if connection.solver.weights or neuron_to_neuron:
        # Learning on the full weight matrix: project the correction back
        # through the post-population encoders.
        post = get_post_ens(connection)
        transform = model.sig[connection]['transform']
        encoders = model.sig[post]['encoders']
        encoded = Signal(np.zeros(transform.shape[0]), name="PES:encoded")
        model.add_op(Reset(encoded))
        model.add_op(DotInc(encoders, correction, encoded, tag="PES:encode"))
        local_error = encoded.reshape((encoded.size, 1))
    elif not isinstance(connection.pre_obj, (Ensemble, Neurons)):
        raise ValueError("'pre' object '%s' not suitable for PES learning"
                         % (connection.pre_obj))

    # delta = local_error * activities
    model.add_op(Reset(model.sig[rule]['delta']))
    model.add_op(ElementwiseInc(local_error, acts_view,
                                model.sig[rule]['delta'],
                                tag="PES:Inc Delta"))

    # expose for probing
    model.sig[rule]['error'] = error
    model.sig[rule]['correction'] = correction
    model.sig[rule]['activities'] = acts
    model.params[rule] = None  # no build-time info to return
def build_bcm(model, bcm, rule):
    """Wire up the BCM rule: filtered activities, theta threshold, SimBCM op."""
    connection = rule.connection

    pre_spikes = model.sig[get_pre_ens(connection).neurons]['out']
    pre_filtered = filtered_signal(model, bcm, pre_spikes, bcm.pre_tau)
    post_spikes = model.sig[get_post_ens(connection).neurons]['out']
    post_filtered = filtered_signal(model, bcm, post_spikes, bcm.post_tau)
    # Modification threshold: a slower filter over the post activity.
    theta = filtered_signal(model, bcm, post_filtered, bcm.theta_tau)

    model.add_op(SimBCM(pre_filtered, post_filtered, theta,
                        model.sig[rule]['delta'],
                        learning_rate=bcm.learning_rate))

    # expose for probing
    model.sig[rule]['theta'] = theta
    model.sig[rule]['pre_filtered'] = pre_filtered
    model.sig[rule]['post_filtered'] = post_filtered

    model.params[rule] = None  # no build-time info to return
def build_pes(model, pes, rule):
    """Build the PES learning rule (weights-signal variant)."""
    connection = rule.connection

    # Error signal; external error connections attach to this input.
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error

    # Filtered presynaptic activities, viewed as a single row.
    acts = filtered_signal(
        model, pes, model.sig[connection.pre_obj]['out'], pes.pre_tau)
    acts_view = acts.reshape((1, acts.size))

    # correction = -learning_rate * (dt / n_neurons) * error
    correction = Signal(np.zeros(error.shape), name="PES:correction")
    local_error = correction.reshape((error.size, 1))
    model.add_op(Reset(correction))

    if isinstance(connection.pre_obj, Ensemble):
        n_neurons = connection.pre_obj.n_neurons
    else:
        n_neurons = connection.pre_obj.size_in
    lr_sig = Signal(-pes.learning_rate * model.dt / n_neurons,
                    name="PES:learning_rate")
    model.add_op(DotInc(lr_sig, error, correction, tag="PES:correct"))

    neuron_to_neuron = (isinstance(connection.pre_obj, Neurons)
                        and isinstance(connection.post_obj, Neurons))
    if connection.solver.weights or neuron_to_neuron:
        # Learning on the full weight matrix: project the correction back
        # through the post-population encoders.
        post = get_post_ens(connection)
        weights = model.sig[connection]['weights']
        encoders = model.sig[post]['encoders']
        encoded = Signal(np.zeros(weights.shape[0]), name="PES:encoded")
        model.add_op(Reset(encoded))
        model.add_op(DotInc(encoders, correction, encoded, tag="PES:encode"))
        local_error = encoded.reshape((encoded.size, 1))
    elif not isinstance(connection.pre_obj, (Ensemble, Neurons)):
        raise ValueError("'pre' object '%s' not suitable for PES learning"
                         % (connection.pre_obj))

    # delta = local_error * activities
    model.add_op(Reset(model.sig[rule]['delta']))
    model.add_op(ElementwiseInc(local_error, acts_view,
                                model.sig[rule]['delta'],
                                tag="PES:Inc Delta"))

    # expose for probing
    model.sig[rule]['error'] = error
    model.sig[rule]['correction'] = correction
    model.sig[rule]['activities'] = acts
    model.params[rule] = None  # no build-time info to return
def synapse_probe(model, key, probe):
    """Attach `probe` to the signal `model.sig[probe.obj][key]`, optionally
    sliced and filtered through the probe's synapse.

    Raises
    ------
    ValueError
        If `key` is not a probeable attribute of `probe.obj`.
    """
    try:
        sig = model.sig[probe.obj][key]
    except (KeyError, IndexError):
        # BUG FIX: a missing key in a mapping raises KeyError, which the old
        # `except IndexError` never caught, so the friendly ValueError below
        # was unreachable for dict-backed signal tables.
        raise ValueError("Attribute '%s' is not probable on %s."
                         % (key, probe.obj))

    # Probe only the requested view of the signal, if a slice was given.
    if probe.slice is not None:
        sig = sig[probe.slice]

    if probe.synapse is None:
        model.sig[probe]['in'] = sig
    else:
        # Filter the probed signal through the probe's synapse.
        model.sig[probe]['in'] = filtered_signal(
            model, probe, sig, probe.synapse)
def synapse_probe(model, key, probe):
    """Attach `probe` to `model.sig[probe.obj][key]`, slicing and filtering
    as the probe requests.

    Raises
    ------
    ValueError
        If `key` is not a probeable attribute of `probe.obj`.
    NotImplementedError
        If `probe.slice` is set but is not a plain `slice` object.
    """
    try:
        sig = model.sig[probe.obj][key]
    except (KeyError, IndexError):
        # BUG FIX: mapping lookups raise KeyError; the old `except IndexError`
        # never fired, making the informative ValueError unreachable.
        raise ValueError("Attribute '%s' is not probable on %s."
                         % (key, probe.obj))

    # BUG FIX: a probe with no slice (probe.slice is None) previously fell
    # into the `else` branch and raised NotImplementedError; "no slice" is a
    # valid configuration (see the sibling implementation) and must pass
    # through untouched.
    if probe.slice is not None:
        if isinstance(probe.slice, slice):
            sig = sig[probe.slice]
        else:
            raise NotImplementedError("Indexing slices not implemented")

    if probe.synapse is None:
        model.sig[probe]['in'] = sig
    else:
        # Filter the probed signal through the probe's synapse.
        model.sig[probe]['in'] = filtered_signal(
            model, probe, sig, probe.synapse)
def build_connection(model, conn):
    """Build operators and signals for a Connection.

    Resolves the pre/post signals, solves for decoders when the pre object is
    an Ensemble, applies synaptic filtering and the transform, and triggers
    the build of any attached learning rules.  Stores a BuiltConnection in
    ``model.params[conn]``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # Look up the signal this connection reads from ('out' of pre) or
        # writes to ('in' of post), validating that it exists and is nonzero.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'
        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero."
                             % (conn, 'pre' if is_pre else 'post',
                                target, key))
        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None
                     or conn.pre_slice.step == 1)):
            # Contiguous slice and no function: use a direct view of the
            # input signal, no extra operator needed.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Otherwise evaluate the function (and/or fancy slice) in Python
            # each timestep.
            signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            fn = ((lambda x: x[conn.pre_slice]) if conn.function is None
                  else (lambda x: conn.function(x[conn.pre_slice])))
            model.add_op(SimPyFunc(
                output=signal, fn=fn, t_in=False, x=model.sig[conn]['in']))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # include transform in solved weights
            targets = np.dot(targets, transform.T)
            # Transform is now folded into the solved weights, so the
            # later transform step becomes a scalar no-op.
            transform = np.array(1., dtype=np.float64)

            decoders, solver_info = solver(
                activities, targets, rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Solved for full weights: write straight into the post neurons.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(
            decoders, name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(DotInc(
            model.sig[conn]['decoders'], model.sig[conn]['in'], signal,
            tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble."
                               % (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post-ensemble gains into the transform when targeting
        # neurons directly.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    # Scalar/vector transforms broadcast elementwise; matrices need DotInc.
    if transform.ndim < 2:
        model.add_op(ElementwiseInc(
            model.sig[conn]['transform'], signal, model.sig[conn]['out'],
            tag=str(conn)))
    else:
        model.add_op(DotInc(
            model.sig[conn]['transform'], signal, model.sig[conn]['out'],
            tag=str(conn)))

    # Build learning rules
    if conn.learning_rule:
        # The learned signal must survive across timesteps (learning rules
        # increment it in place).
        if isinstance(conn.pre_obj, Ensemble):
            model.add_op(PreserveValue(model.sig[conn]['decoders']))
        else:
            model.add_op(PreserveValue(model.sig[conn]['transform']))

        if isinstance(conn.pre_obj, Ensemble) and conn.solver.weights:
            # TODO: make less hacky.
            # Have to do this because when a weight_solver
            # is provided, then learning rules should operate on
            # "decoders" which is really the weight matrix.
            model.sig[conn]['transform'] = model.sig[conn]['decoders']

        rule = conn.learning_rule
        if is_iterable(rule):
            for r in itervalues(rule) if isinstance(rule, dict) else rule:
                model.build(r)
        elif rule is not None:
            model.build(rule)

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
def build_connection(model, conn):
    """Build operators and signals for a Connection (weights variant).

    Resolves the pre/post signals, solves for decoders when the pre object is
    an Ensemble, folds everything into a single ``weights`` signal, applies
    synaptic filtering, copies into the post slice, and builds any attached
    learning rules.  Stores a BuiltConnection in ``model.params[conn]``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # Look up the signal this connection reads from ('out' of pre) or
        # writes to ('in' of post), validating that it exists and is nonzero.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'
        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero."
                             % (conn, 'pre' if is_pre else 'post',
                                target, key))
        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is not None:
            # Evaluate the connection function in Python each timestep.
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(
                output=in_signal, fn=conn.function, t_in=False, x=sliced_in))
        else:
            in_signal = sliced_in
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, decoders, solver_info = build_decoders(model, conn, rng)
        if conn.solver.weights:
            # Solved for full weights: write straight into the post neurons
            # and skip the post-slice copy below.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = Ellipsis  # don't apply slice later
            weights = decoders.T
        else:
            weights = multiply(conn.transform, decoders.T)
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Add operator for applying weights
    if weights is None:
        weights = np.array(conn.transform)
    if isinstance(conn.post_obj, Neurons):
        # Fold post-ensemble gains into the weights when targeting neurons.
        gain = model.params[conn.post_obj.ensemble].gain[post_slice]
        weights = multiply(gain, weights)
    if conn.learning_rule is not None and weights.ndim < 2:
        raise ValueError("Learning connection must have full transform matrix")
    # BUG FIX: both signal names below were the literal strings
    # "%s.weights" / "%s.weighted" — the '% conn' interpolation was missing.
    model.sig[conn]['weights'] = Signal(weights, name="%s.weights" % conn)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    # Scalar/vector weights broadcast elementwise; matrices need DotInc.
    op = ElementwiseInc if weights.ndim < 2 else DotInc
    model.add_op(op(model.sig[conn]['weights'], in_signal, signal,
                    tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    # Copy to the proper slice
    model.add_op(SlicedCopy(
        signal, model.sig[conn]['out'], b_slice=post_slice, inc=True,
        tag="%s.gain" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        # Learned weights must survive across timesteps.
        model.add_op(PreserveValue(model.sig[conn]['weights']))

        rule = conn.learning_rule
        if is_iterable(rule):
            for r in itervalues(rule) if isinstance(rule, dict) else rule:
                model.build(r)
        elif rule is not None:
            model.build(rule)

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         weights=weights)
def build_connection(model, conn):
    """Build operators and signals for a Connection (Decimal-signal variant).

    Resolves the pre/post signals, solves for decoders when the pre object is
    an Ensemble, applies synaptic filtering, handles modulatory connections,
    applies the transform, and marks learned signals for preservation.
    Stores a BuiltConnection in ``model.params[conn]``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # Look up the signal this connection reads from ('out' of pre) or
        # writes to ('in' of post), validating that it exists and is nonzero.
        # BUG FIX: removed leftover debug print of pre/post objects.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'
        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero."
                             % (conn, 'pre' if is_pre else 'post',
                                target, key))
        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None
                     or conn.pre_slice.step == 1)):
            # Contiguous slice and no function: use a direct view.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Otherwise evaluate the function (and/or fancy slice) in Python.
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None
                else (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            model.add_op(DotInc(
                model.sig[conn]['in'], model.sig['common'][1], sig_in,
                tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform
            targets = np.dot(targets, transform.T)
            # Transform is folded into the solved weights; later transform
            # step becomes a scalar no-op.
            transform = np.array(1, dtype=rc.get('precision', 'dtype'))

            decoders, solver_info = solver(
                activities, targets, rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Solved for full weights: write straight into the post neurons.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = model.Signal(
            decoders, name="%s.decoders" % conn)
        signal = model.Signal(npext.castDecimal(np.zeros(signal_size)),
                              name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(DotInc(
            model.sig[conn]['decoders'], model.sig[conn]['in'], signal,
            tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = model.Signal(
            npext.castDecimal(np.zeros(model.sig[conn]['out'].size)),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble."
                               % (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post-ensemble gains into the transform.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = model.Signal(transform,
                                                name="%s.transform" % conn)
    # BUG FIX: removed leftover debug prints of signal values here and in
    # the elementwise branch below.
    if transform.ndim < 2:
        model.add_op(ElementwiseInc(
            model.sig[conn]['transform'], signal, model.sig[conn]['out'],
            tag=str(conn)))
    else:
        model.add_op(DotInc(
            model.sig[conn]['transform'], signal, model.sig[conn]['out'],
            tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.
        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError(
                "Can't apply learning rules to connections of "
                "this type. pre type: %s, post type: %s"
                % (type(conn.pre_obj).__name__,
                   type(conn.post_obj).__name__))
        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
def build_connection(model, conn):
    """Build operators and signals for a Connection.

    Resolves the pre/post signals, solves for decoders when the pre object is
    an Ensemble, applies synaptic filtering, handles modulatory connections,
    applies the transform, and marks learned signals for preservation.
    Stores a BuiltConnection in ``model.params[conn]``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # Look up the signal this connection reads from ('out' of pre) or
        # writes to ('in' of post), validating that it exists and is nonzero.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'
        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero."
                             % (conn, 'pre' if is_pre else 'post',
                                target, key))
        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None
                     or conn.pre_slice.step == 1)):
            # Contiguous slice and no function: use a direct view.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Otherwise evaluate the function (and/or fancy slice) in Python.
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None
                else (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            model.add_op(DotInc(model.sig[conn]['in'],
                                model.sig['common'][1],
                                sig_in,
                                tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        # NOTE(review): sibling versions of this builder pass `rng` to
        # build_linear_system; it is omitted here — confirm the intended
        # signature of build_linear_system in this version.
        eval_points, activities, targets = build_linear_system(model, conn)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform
            targets = np.dot(targets, transform.T)
            # Transform is folded into the solved weights; later transform
            # step becomes a scalar no-op.
            transform = np.array(1., dtype=np.float64)

            decoders, solver_info = solver(
                activities, targets, rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Solved for full weights: write straight into the post neurons.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(
            decoders, name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(DotInc(model.sig[conn]['decoders'],
                            model.sig[conn]['in'],
                            signal,
                            tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = Signal(
            np.zeros(model.sig[conn]['out'].size),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble."
                               % (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post-ensemble gains into the transform.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    # Scalar/vector transforms broadcast elementwise; matrices need DotInc.
    if transform.ndim < 2:
        model.add_op(ElementwiseInc(model.sig[conn]['transform'],
                                    signal,
                                    model.sig[conn]['out'],
                                    tag=str(conn)))
    else:
        model.add_op(DotInc(model.sig[conn]['transform'],
                            signal,
                            model.sig[conn]['out'],
                            tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.
        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError("Can't apply learning rules to connections of "
                            "this type. pre type: %s, post type: %s"
                            % (type(conn.pre_obj).__name__,
                               type(conn.post_obj).__name__))
        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)