class serv_ESN:
    """Serve step-by-step predictions from a pre-trained ESN.

    On construction the network is built with fixed hyper-parameters, its
    trained state is loaded from ``trainied.pickle`` located in the same
    directory as this source file, and the stepping function returned by
    ``esn.step_taped()`` is kept for repeated ``serv_step`` calls.
    """

    @staticmethod
    def _pickle_path(directory):
        """Return the path of the trained-network pickle inside ``directory``.

        ``os.path.join`` is used (instead of ``directory + "/..."``) so that an
        empty ``directory`` -- what ``os.path.dirname(__file__)`` yields when
        the module is run via a bare relative file name -- gives a relative
        path rather than a path at the filesystem root.
        """
        return os.path.join(directory, "trainied.pickle")

    def __init__(self):
        weight_scale = 1.0
        weight_inp = 0.2
        weight_fb = 10 ** (-3)
        alpha = 0.99
        fback = False
        units = 28 * 28
        indim = 6
        outdim = 6
        self.esn = ESN(units, indim, outdim, weight_scale, weight_inp,
                       weight_fb, alpha, fback)
        local_dir = os.path.dirname(__file__)
        self.esn.load(self._pickle_path(local_dir))
        self.stepper = self.esn.step_taped()
        # Output buffer passed to the stepper on every call. Nothing in this
        # class writes to it, so it stays all zeros.
        self.outputs = np.zeros((3, outdim))
        print("ESN:: init")

    def serv_close(self):
        """No-op close hook; always returns True."""
        return True

    def serv_train(self):
        """No-op training hook (the network is loaded pre-trained); returns True."""
        return True

    def serv_step(self, val_in):
        """Advance the network on ``val_in`` and return ``(state, output)``.

        ``val_in`` and the zero output buffer are passed to the stepper with a
        final argument of ``0.0``; the stepper's third return value is
        discarded. Uniform noise drawn from ``np.random.random`` (values in
        [0, 1)) is added to ``output`` before it is returned.
        """
        state, output, _ = self.stepper(val_in, self.outputs, 0.0)
        # NOTE(review): the added noise looks like deliberate jitter for the
        # serving demo -- confirm it is wanted wherever this is deployed.
        output += np.random.random(output.shape)
        return state, output
def _plot_test_results(plot_output, outputs, plot_state, inputs):
    """Draw network outputs, target outputs, state and inputs on 4 stacked axes."""
    _fig, axarr = plt.subplots(4, sharex=True)
    for oid, tpt in enumerate(np.array(plot_output).transpose()):
        try:
            axarr[0].plot(tpt, label="output" + str(oid))
        except Exception as err:
            # Best effort: report the failing trace and keep drawing the rest.
            print("could not plot output%d: %s" % (oid, err))
    axarr[0].set_title('output')
    axarr[1].plot(outputs, label="outputs")
    axarr[1].set_title('outputs')
    axarr[2].plot(plot_state, label="state")
    axarr[2].set_title('state')
    axarr[3].plot(inputs, label="inputs")
    axarr[3].set_title('inputs')
    plt.show()


def test(sys, weight_scale, weight_inp, weight_fb, alpha, inital_washout,
         padding_s):
    """Run the pre-trained ESN over a dataset and optionally plot the result.

    Parameters
    ----------
    sys : module-like
        Object exposing ``argv``. ``argv[1]`` and ``argv[2]`` are passed to
        ``read_dataset``; plotting is enabled when ``int(argv[3]) == 1``.
    weight_scale, weight_inp, weight_fb, alpha :
        Hyper-parameters forwarded to the ``ESN`` constructor.
    inital_washout, padding_s :
        Accepted for interface compatibility; not used by this function.
    """
    import time

    units = 28 * 28
    indim = 6
    outdim = 6
    # NOTE(review): ``fback`` is neither a parameter nor a local here, so it
    # is looked up in module globals -- confirm the caller defines it.
    esn = ESN(units, indim, outdim, weight_scale, weight_inp, weight_fb,
              alpha, fback)
    # NOTE(review): this path is relative to the current working directory,
    # not to this source file.
    esn.load("trainied.pickle")
    stepper = esn.step_taped()

    dtsets = read_dataset(sys.argv[1], sys.argv[2])
    # Only the first dataset entry is used; the padding/sequence index arrays
    # are needed only by the disabled training path below.
    inputs, outputs, _pad_idxs, _idxs = dtsets[0]

    start = time.time()
    # Offline training is disabled. To re-enable it: collect states with
    # ``train(idxs, padIdxs, esn, stepper, inputs, outputs)``, solve
    # ``W_out = pinv(states) . arctanh(targets)`` and install the result with
    # ``esn.W_out.set_value(W_out)``.
    print("Time taken  %s" % (time.time() - start))

    # Testing.
    # NOTE(review): the stepper is fed the unshifted ``outputs``; if teacher
    # forcing with a one-step delay was intended (row 0 zeros, then
    # ``outputs[:-1]``), pass that array instead -- confirm.
    state, output, _ = stepper(inputs, outputs, 0.)
    print(output.shape)
    plot_state = list(state[:, :units])
    plot_output = list(output)

    if int(sys.argv[3]) == 1:
        _plot_test_results(plot_output, outputs, plot_state, inputs)