def __init__(self, n_pc, input_node, n_neuron_out, lr):
    '''
    Initialize actor net as a nengo network object

    PARAMS:
        n_pc            -   number of place cells
                            (not referenced anywhere in this constructor)
        input_node      -   nengo object passed as the source of the learned
                            connection into the output ensemble
        n_neuron_out    -   number of neurons in Ensemble encoding output
        lr              -   learning rate handed to Learning.TDL
    '''
    out_dims = 8  # dimensionality of the output ensemble

    def zero_target(x):
        # Function given to the connection: always a vector of zeros.
        return [0] * out_dims

    with nengo.Network() as actor_net:
        # Output population; radius is sqrt(out_dims).
        actor_net.output = nengo.Ensemble(
            n_neurons=n_neuron_out,
            dimensions=out_dims,
            radius=np.sqrt(out_dims),
        )

        # Learned connection from the input into the output population.
        # The solver is asked for full weights (weights=True) and the
        # connection carries the TDL learning rule.
        actor_net.conn = nengo.Connection(
            input_node,
            actor_net.output,
            synapse=0.01,
            function=zero_target,
            solver=nengo.solvers.LstsqL2(weights=True),
            learning_rule_type=Learning.TDL(learning_rate=lr),
        )

    # Expose the network so callers can reach .output and .conn through it.
    self.net = actor_net
# --- simulation settings ---------------------------------------------------
BACKEND = 'CPU'
dt = 0.001
duration = 10
discount = 0.9995

env = TestEnvActor(dt=dt, trial_length=40, reset=1000)


def _step_env(t, v):
    # Node callback: the simulation time is ignored, only the input is
    # forwarded to the environment.
    return env.step(v)


with nengo.Network() as net:
    # Environment node: 1 input dimension, 3 output dimensions.
    envnode = nengo.Node(_step_env, size_in=1, size_out=3)

    # encodes position
    in_ens = nengo.Ensemble(n_neurons=1000, radius=2, dimensions=1)
    actor = nengo.Ensemble(n_neurons=1000, radius=1, dimensions=1)
    critic = CriticNet(in_ens, n_neuron_out=1000, lr=1e-5)

    # seems like a reasonable value to have a reward gradient over the
    # entire episode
    error = ErrorNode(discount=discount)

    # needed for compatibility with error implementation
    switch = Switch(state=1, switch_off=False, switchtime=duration / 2)

    # Environment output 0 drives the input ensemble.
    nengo.Connection(envnode[0], in_ens)

    # Learned in_ens -> actor connection; its function always returns [0].
    conn = nengo.Connection(
        in_ens,
        actor,
        function=lambda x: [0],
        solver=nengo.solvers.LstsqL2(weights=True),
        learning_rule_type=Learning.TDL(learning_rate=1e-8),
    )

    # The actor output is fed back into the environment node.
    nengo.Connection(actor, envnode)

    # error node connections
    #   reward = input[0]
    #   value  = input[1]
    #   switch = input[2]
    #   state  = input[3]
    #   reset  = input[4].astype(int)
    _error_inputs = [
        (envnode[1], 0),               # reward connection
        (critic.net.output, 1),        # value prediction
        (switch.net.switch, 2),        # learning switch
        (error.net.errornode[1], 3),   # feed value into next step
        (envnode[2], 4),               # propagate reset signal
    ]
    for _source, _slot in _error_inputs:
        nengo.Connection(_source, error.net.errornode[_slot], synapse=0.01)

    # error to critic (note the sign flip via transform=-1)
    nengo.Connection(
        error.net.errornode[0],
        critic.net.conn.learning_rule,
        transform=-1,
    )
    # same error output, unscaled, to the actor connection's learning rule
    nengo.Connection(error.net.errornode[0], conn.learning_rule)

    # Probes