def initialize(self, n, m, h=64):
    """
    Description: Randomly initialize the LSTM.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        h (int): Default value 64. Hidden dimension of the LSTM.
    Returns:
        The first value in the time-series: W_out applied to the all-zero
        initial hidden state (a length-m zero vector).
    """
    self.T = 0
    self.initialized = True
    self.n, self.m, self.h = n, m, h

    glorot_init = stax.glorot()  # returns a function (key, shape) -> weights
    self.W_hh = glorot_init(generate_key(), (4 * h, h))  # maps h_t to gates
    self.W_xh = glorot_init(generate_key(), (4 * h, n))  # maps x_t to gates

    # Gate bias of length 4h: zeros except ones on [h:2h], the forget-gate
    # block (biased initialization). Built directly instead of with an
    # in-place update: JAX arrays are immutable, so jax.ops.index_update
    # returns a NEW array that must be assigned. Discarding it leaves the
    # bias all-zero. jax.ops.index_update is also gone in newer JAX.
    self.b_h = np.concatenate([np.zeros(h), np.ones(h), np.zeros(2 * h)])

    self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
    self.cell = np.zeros(h)  # long-term memory
    self.hid = np.zeros(h)  # short-term memory
    # Logistic sigmoid, stored on the instance.
    # NOTE(review): newer JAX provides jax.nn.sigmoid; consider it if the
    # pinned JAX version has it.
    self.sigmoid = lambda x: 1. / (1. + np.exp(-x))
    return np.dot(self.W_out, self.hid)
def test_random():
    """
    Check that reseeding the global key with the same seed reproduces both
    the global key and the next generated key.
    """
    def _reseed_and_read():
        # Reset the seed, then capture the global key and one derived key.
        set_key(5)
        return get_global_key(), generate_key()

    global_a, derived_a = _reseed_and_read()
    global_b, derived_b = _reseed_and_read()
    assert str(global_a) == str(global_b)
    assert str(derived_a) == str(derived_b)
    print("test_random passed")
def test_lstm(steps=100, show_plot=False, verbose=False):
    """
    Smoke-test LSTM_Output: drive it with `steps` random inputs and check
    that its time counter starts at 0 and ends at `steps`.

    Args:
        steps (int): Number of time-steps to simulate.
        show_plot (bool): If True, briefly display a plot of the outputs.
        verbose (bool): If True, print a progress line whenever
            step * 10 is a multiple of `steps` (every 10% when `steps` is
            divisible by 10), and print the result of problem.hidden().
    """
    total = steps
    n, m, l, h = 5, 3, 5, 10
    problem = LSTM_Output()
    problem.initialize(n, m, l, h)
    assert problem.T == 0

    outputs = []
    for done in range(1, total + 1):  # `done` = number of steps completed after this one
        if verbose and done * 10 % total == 0:
            print("{} timesteps".format(done))
        u = random.normal(generate_key(), shape=(n,))
        outputs.append(problem.step(u))

    info = problem.hidden()
    if verbose:
        print(info)
    assert problem.T == total

    if show_plot:
        plt.plot(outputs)
        plt.title("lstm")
        plt.show(block=False)
        plt.pause(1)
        plt.close()
    print("test_lstm passed")
def test_lds(steps=100, show_plot=False, verbose=False):
    """
    Smoke-test LDS: drive it with `steps` random inputs and check that its
    time counter starts at 0 and ends at `steps`.

    Args:
        steps (int): Number of time-steps to simulate.
        show_plot (bool): If True, briefly display a plot of the outputs.
        verbose (bool): If True, print the result of problem.hidden().
            Added as a keyword with default False: the body already read
            `verbose`, but it was never defined here, causing a NameError.
    """
    T = steps
    n, m, d = 5, 3, 10
    problem = LDS()
    problem.initialize(n, m, d)
    assert problem.T == 0

    test_output = []
    for t in range(T):
        u = random.normal(generate_key(), shape=(n,))
        test_output.append(problem.step(u))

    info = problem.hidden()
    if verbose:
        print(info)
    assert problem.T == T

    if show_plot:
        plt.plot(test_output)
        plt.title("lds")
        plt.show(block=False)
        plt.pause(1)
        plt.close()
    print("test_lds passed")
def step(self):
    """
    Description: Advance the system dynamics by one time-step.
    Args:
        None
    Returns:
        The next value in the time-series: a random.normal draw made with a
        freshly generated key.
    """
    assert self.initialized
    self.T = self.T + 1
    key = generate_key()  # one new key per step
    return random.normal(key)
def initialize(self, n, m, h=64):
    """
    Description: Randomly initialize the RNN.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        h (int): Default value 64. Hidden dimension of RNN.
    Returns:
        The first value in the time-series: W_out applied to the all-zero
        initial hidden state.
    """
    self.T = 0
    self.initialized = True
    self.n, self.m, self.h = n, m, h

    make_weight = stax.glorot()  # factory: (key, shape) -> initialized weights

    def draw(shape):
        # Each call consumes one fresh key, so the draw order below matters.
        return make_weight(generate_key(), shape)

    self.W_h = draw((h, h))    # hidden (h) -> hidden (h)
    self.W_x = draw((h, n))    # input (n) -> hidden (h)
    self.W_out = draw((m, h))  # hidden (h) -> output (m)
    self.b_h = np.zeros(h)     # hidden bias
    self.hid = np.zeros(h)     # initial hidden state
    return np.dot(self.W_out, self.hid)