def make(learn_type='hPES', nperd=20, learn_rate=5e-5, supervision_ratio=0.5,
         oja=False):
    """Build the 'Learn Digits' network with a learned pre -> post mapping.

    The input node's 'input' origin drives 'pre'.  The learning template
    connects 'pre' to 'post' and is trained on an 'error' signal computed as
    the input node's 'label' origin minus 'post'.  A gate ('Gate'), driven by
    the input node's 'learnswitch' origin, gates the error population so that
    learning can be switched on and off.

    :param learn_type: learning rule, either 'PES' or 'hPES'.
    :param nperd: neurons per dimension; pre gets IND * nperd neurons and
        post/error get OUTD * nperd.
    :param learn_rate: rate passed to the learning-rule template.
    :param supervision_ratio: hPES only; passed as ``supervisionRatio``.
    :param oja: PES only; passed through to the PES template.
    :returns: the constructed ``nef.Network``.
    :raises ValueError: if ``learn_type`` is not 'PES' or 'hPES'.
    """
    # Fail fast: any other value would leave no 'error' population, and the
    # connections and gate below would then fail with a far less obvious
    # error.
    if learn_type not in ('PES', 'hPES'):
        raise ValueError("learn_type must be 'PES' or 'hPES', got %r"
                         % (learn_type,))

    net = nef.Network('Learn Digits')  # creates a network in Nengo
    # Named `inp` rather than `input` so the builtin is not shadowed.
    inp = net.add(Input('input'))  # create the input node
    pre = net.make('pre', IND * nperd, IND, radius=4.5)
    net.make('post', OUTD * nperd, OUTD, radius=2.5)

    # Establish the learned connection between pre and post.  The template
    # is given errName='error' and N_err, so it presumably also creates the
    # 'error' population used below (nothing else here creates it).
    if learn_type == 'PES':
        pes.make(net, preName='pre', postName='post', errName='error',
                 N_err=OUTD * nperd, rate=learn_rate, oja=oja)
    elif learn_type == 'hPES':
        hpes.make(net, preName='pre', postName='post', errName='error',
                  N_err=OUTD * nperd, rate=learn_rate,
                  supervisionRatio=supervision_ratio)

    # Connect parts of network: error = label - post.
    net.connect(inp.getOrigin('label'), 'error')
    net.connect(inp.getOrigin('input'), pre)
    net.connect('post', 'error', weight=-1)

    # Create a gate for turning learning on and off.
    gating.make(net, name='Gate', gated='error', neurons=40, pstc=0.01)
    # Have the 'learning switch position' file drive the gate.
    net.connect(inp.getOrigin('learnswitch'), 'Gate')

    return net
def make(self):
    """Build the learning network for this configuration, caching it.

    If ``self.net`` is already set it is returned unchanged.  Otherwise the
    network is built, stored on ``self.net`` and returned.  It contains:

    * 'input' -- a Fourier-series input of ``self.in_d`` dimensions feeding
      'pre';
    * 'pre' / 'post' -- ``self.nperd`` neurons per dimension;
    * a pre -> post connection that is learned ('PES' / 'hPES') or computed
      directly from ``self.func`` ('control');
    * 'error' = func(pre) - post, gated by 'Gate'.  'Gate' is driven by a
      'switch' input built by ``LearnBuilder.get_learning_times``, which
      starts in the 'test' phase when ``self.testtype == 'full'`` and in the
      'train' phase otherwise;
    * 'actual error' = func(input) - post, in direct mode, for measurement.

    :returns: the network (also stored on ``self.net``).
    :raises ValueError: if ``self.learn_type`` is not 'PES', 'hPES' or
        'control'.
    """
    # Return the cached network before doing any other work.
    if self.net is not None:
        return self.net

    # Fail fast: any other learn_type would leave no 'error' population,
    # and the gate and connections below would fail with a much less
    # obvious error.
    if self.learn_type not in ('PES', 'hPES', 'control'):
        raise ValueError(
            "learn_type must be 'PES', 'hPES' or 'control', got %r"
            % (self.learn_type,))

    import nef
    import nef.templates.learned_termination as pes
    import nef.templates.hpes_termination as hpes
    import nef.templates.gate as gating

    # NOTE(review): this reseeds the module-level `random` RNG, which also
    # affects any other code that uses `random` afterwards.
    random.seed(self.seed)
    net = nef.Network('Learn Network', seed=random.randrange(0x7fffffff))

    net.make('pre', self.nperd * self.in_d, self.in_d)
    net.make('post', self.nperd * self.out_d, self.out_d)
    net.make_fourier_input('input', dimensions=self.in_d, base=0.25, high=40)
    net.connect('input', 'pre')

    if self.learn_type == 'PES':
        pes.make(net, preName='pre', postName='post', errName='error',
                 N_err=self.nperd * self.out_d, rate=self.learn_rate,
                 oja=self.oja)
    elif self.learn_type == 'hPES':
        hpes.make(net, preName='pre', postName='post', errName='error',
                  N_err=self.nperd * self.out_d, rate=self.learn_rate,
                  supervisionRatio=self.supervision_ratio)
    elif self.learn_type == 'control':
        # Compute func directly instead of learning it.
        net.connect('pre', 'post', func=self.func, origin_name='pre_00')
        # 'error' must still exist for the connections and gate below, but
        # no learning rule reads it in this mode.
        net.make('error', 1, self.out_d, mode='direct')

    # Error signal: func(pre) - post.
    net.connect('pre', 'error', func=self.func)
    net.connect('post', 'error', weight=-1)

    start = 'test' if self.testtype == 'full' else 'train'
    net.make_input('switch', LearnBuilder.get_learning_times(
        self.train, self.test, start))
    gating.make(net, name='Gate', gated='error', neurons=50, pstc=0.01)
    net.connect('switch', 'Gate')

    # Calculate actual error: func(input) - post, in direct mode.
    net.make('actual', 1, self.in_d, mode='direct')
    net.connect('input', 'actual')
    net.make('actual error', 1, self.out_d, mode='direct')
    net.connect('actual', 'actual error', func=self.func)
    net.connect('post', 'actual error', weight=-1)

    self.net = net
    return self.net