def initialize(self, n, m, h=64):
    """
    Description: Randomly initialize the RNN.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        h (int): Default value 64. Hidden dimension of RNN.
    Returns:
        The first value in the time-series
    """
    self.T = 0
    self.max_T = -1
    self.initialized = True
    self.has_regressors = True
    self.n, self.m, self.h = n, m, h

    glorot_init = stax.glorot()  # returns a function that initializes weights
    self.W_h = glorot_init(generate_key(), (h, h))
    self.W_x = glorot_init(generate_key(), (h, n))
    self.W_out = glorot_init(generate_key(), (m, h))
    self.b_h = np.zeros(h)
    self.hid = np.zeros(h)

    def _step(x, hid):
        next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
        y = np.dot(self.W_out, next_hid)
        return (next_hid, y)

    self._step = jax.jit(_step)
    return self.step()
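# Illustrative sketch (not part of the library): the recurrence that _step above
# implements, written as a standalone function so the update
#     h_{t+1} = tanh(W_h h_t + W_x x_t + b_h),  y_t = W_out h_{t+1}
# can be checked in isolation. The shapes (n=2, m=1, h=4) and the plain
# jax.random initializers (instead of stax.glorot/generate_key) are assumptions.
import jax
import jax.numpy as jnp

def rnn_step(W_h, W_x, W_out, b_h, hid, x):
    next_hid = jnp.tanh(W_h @ hid + W_x @ x + b_h)  # hidden-state update
    return next_hid, W_out @ next_hid               # new state and output

n, m, h = 2, 1, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W_h = jax.random.normal(k1, (h, h))
W_x = jax.random.normal(k2, (h, n))
W_out = jax.random.normal(k3, (m, h))
hid, y = rnn_step(W_h, W_x, W_out, jnp.zeros(h), jnp.zeros(h), jnp.ones(n))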
def initialize(self, n=1, m=1, l=32, h=64, optimizer=OGD, loss=mse, lr=0.003):
    """
    Description: Randomly initialize the RNN.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        l (int): Length of memory for update step purposes.
        h (int): Default value 64. Hidden dimension of RNN.
        optimizer (class): optimizer choice
        loss (class): loss choice
        lr (float): learning rate for update
    """
    self.T = 0
    self.initialized = True
    self.n, self.m, self.l, self.h = n, m, l, h

    # initialize parameters
    glorot_init = stax.glorot()  # returns a function that initializes weights
    W_h = glorot_init(generate_key(), (h, h))
    W_x = glorot_init(generate_key(), (h, n))
    W_out = glorot_init(generate_key(), (m, h))
    b_h = np.zeros(h)
    # self.params = [W_h, W_x, W_out, b_h]
    self.params = {'W_h': W_h, 'W_x': W_x, 'W_out': W_out, 'b_h': b_h}
    self.hid = np.zeros(h)
    self.x = np.zeros((l, n))

    """ private helper methods """

    @jax.jit
    def _update_x(self_x, x):
        new_x = np.roll(self_x, -self.n)
        new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
        return new_x

    @jax.jit
    def _fast_predict(carry, x):
        params, hid = carry  # unroll tuple in carry
        W_h, W_x, W_out, b_h = params.values()
        next_hid = np.tanh(np.dot(W_h, hid) + np.dot(W_x, x) + b_h)
        y = np.dot(W_out, next_hid)
        return (params, next_hid), y

    @jax.jit
    def _predict(params, x):
        _, y = jax.lax.scan(_fast_predict, (params, np.zeros(h)), x)
        return y[-1]

    self.transform = lambda x: float(x) if (self.m == 1) else x
    self._update_x = _update_x
    self._fast_predict = _fast_predict
    self._predict = _predict
    self._store_optimizer(optimizer, self._predict)
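# Illustrative sketch (not part of the library): how jax.lax.scan unrolls a
# per-step function like _fast_predict above over the (l, n) input buffer. The
# carry holds the parameters and hidden state; scan returns the stacked
# per-step outputs, and the prediction is the last one. All names, shapes, and
# initializers below are assumptions for the sake of a self-contained example.
import jax
import jax.numpy as jnp

def cell(carry, x):
    (W_h, W_x, W_out, b_h), hid = carry
    next_hid = jnp.tanh(W_h @ hid + W_x @ x + b_h)
    return ((W_h, W_x, W_out, b_h), next_hid), W_out @ next_hid

l, n, m, h = 32, 1, 1, 64
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = (jax.random.normal(keys[0], (h, h)), jax.random.normal(keys[1], (h, n)),
          jax.random.normal(keys[2], (m, h)), jnp.zeros(h))
xs = jax.random.normal(keys[3], (l, n))        # buffer of the l most recent inputs
_, ys = jax.lax.scan(cell, (params, jnp.zeros(h)), xs)  # ys has shape (l, m)
prediction = ys[-1]                            # same role as y[-1] in _predict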
def test_random(show=False):
    set_key(5)
    a1 = get_global_key()
    r1 = generate_key()
    set_key(5)
    a2 = get_global_key()
    r2 = generate_key()
    assert str(a1) == str(a2)
    assert str(r1) == str(r2)
    print("test_random passed")
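# Illustrative sketch (not part of the library): the kind of deterministic key
# handling that set_key/get_global_key/generate_key presumably wrap. Re-seeding
# with the same integer reproduces the same global key and the same sequence of
# split-off subkeys, which is what test_random asserts above.
import jax

global_key = jax.random.PRNGKey(5)                    # set_key(5)
global_key, subkey_a = jax.random.split(global_key)   # generate_key()
reseeded = jax.random.PRNGKey(5)                      # set_key(5) again
reseeded, subkey_b = jax.random.split(reseeded)
assert (subkey_a == subkey_b).all()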
def initialize(self, n=1, m=None, p=3, optimizer=OGD):
    """
    Description: Initializes autoregressive method parameters
    Args:
        n (int): dimension of the data
        p (int): Length of history used for prediction
        optimizer (class): optimizer choice
    """
    self.initialized = True
    self.n = n
    self.p = p
    self.past = np.zeros((p, self.n))

    glorot_init = stax.glorot()  # returns a function that initializes weights
    # self.params = glorot_init(generate_key(), (p+1, 1))
    self.params = {'phi': glorot_init(generate_key(), (p + 1, 1))}

    def _update_past(self_past, x):
        new_past = np.roll(self_past, self.n)
        new_past = jax.ops.index_update(new_past, 0, x)
        return new_past

    self._update_past = jax.jit(_update_past)

    def _predict(params, x):
        phi = list(params.values())[0]
        x_plus_bias = np.vstack((np.ones((1, self.n)), x))
        return np.dot(x_plus_bias.T, phi).squeeze()

    self._predict = jax.jit(_predict)

    self._store_optimizer(optimizer, self._predict)
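# Illustrative sketch (not part of the library): the prediction _predict above
# computes for a scalar series (n = 1), i.e. an AR(p) model with a bias term:
#     y_hat = phi_0 + phi_1 * x_{t-1} + ... + phi_p * x_{t-p}.
# The concrete numbers below are arbitrary.
import jax.numpy as jnp

p, n = 3, 1
phi = jnp.array([[0.5], [0.3], [0.2], [0.1]])       # shape (p + 1, 1): bias + p lags
past = jnp.array([[1.0], [2.0], [3.0]])             # shape (p, n): most recent value first
x_plus_bias = jnp.vstack((jnp.ones((1, n)), past))  # prepend the bias row
y_hat = jnp.dot(x_plus_bias.T, phi).squeeze()       # 0.5 + 0.3*1 + 0.2*2 + 0.1*3 = 1.5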
def initialize(self, n, m, h=64):
    """
    Description: Randomly initialize the LSTM.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        h (int): Default value 64. Hidden dimension of LSTM.
    Returns:
        The first value in the time-series
    """
    self.T = 0
    self.max_T = -1
    self.initialized = True
    self.has_regressors = True
    self.n, self.m, self.h = n, m, h

    glorot_init = stax.glorot()  # returns a function that initializes weights
    self.W_hh = glorot_init(generate_key(), (4 * h, h))  # maps h_t to gates
    self.W_xh = glorot_init(generate_key(), (4 * h, n))  # maps x_t to gates
    self.b_h = np.zeros(4 * h)
    self.b_h = jax.ops.index_update(self.b_h, jax.ops.index[h:2 * h],
                                    np.ones(h))  # forget gate biased initialization
    self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
    self.cell = np.zeros(h)  # long-term memory
    self.hid = np.zeros(h)  # short-term memory

    def _step(x, hid, cell):
        sigmoid = lambda x: 1. / (1. + np.exp(-x))  # logistic sigmoid, defined inline
        gate = np.dot(self.W_hh, hid) + np.dot(self.W_xh, x) + self.b_h
        i, f, g, o = np.split(gate, 4)  # order: input, forget, cell, output
        next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
        next_hid = sigmoid(o) * np.tanh(next_cell)
        y = np.dot(self.W_out, next_hid)
        return (next_hid, next_cell, y)

    self._step = jax.jit(_step)
    return self.step()
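# Illustrative sketch (not part of the library): the gating computed by _step
# above, written as explicit equations. The four gate pre-activations are
# stacked in one (4h,) vector and split in the order input, forget, cell, output:
#     c_{t+1} = sigma(f) * c_t + sigma(i) * tanh(g)
#     h_{t+1} = sigma(o) * tanh(c_{t+1})
# Shapes (n=2, m=1, h=4), plain-normal weights, and jax.nn.sigmoid are assumptions.
import jax
import jax.numpy as jnp

def lstm_step(W_hh, W_xh, W_out, b_h, hid, cell, x):
    gate = W_hh @ hid + W_xh @ x + b_h
    i, f, g, o = jnp.split(gate, 4)
    next_cell = jax.nn.sigmoid(f) * cell + jax.nn.sigmoid(i) * jnp.tanh(g)
    next_hid = jax.nn.sigmoid(o) * jnp.tanh(next_cell)
    return next_hid, next_cell, W_out @ next_hid

n, m, h = 2, 1, 4
k = jax.random.split(jax.random.PRNGKey(0), 3)
hid, cell, y = lstm_step(jax.random.normal(k[0], (4 * h, h)),
                         jax.random.normal(k[1], (4 * h, n)),
                         jax.random.normal(k[2], (m, h)),
                         jnp.zeros(4 * h), jnp.zeros(h), jnp.zeros(h), jnp.ones(n))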
def step(self):
    """
    Description: Moves the system dynamics one time-step forward.
    Args:
        None
    Returns:
        The next value in the time-series.
    """
    assert self.initialized
    self.T += 1
    return random.normal(generate_key())
def initialize(self):
    """
    Description: Randomly initialize the hidden dynamics of the system.
    Args:
        None
    Returns:
        The first value in the time-series.
    """
    self.T = 0
    self.max_T = -1
    self.initialized = True
    return random.normal(generate_key())
def step(self):
    """
    Description: Samples a random input and produces the next output of the LSTM.
    Args:
        None
    Returns:
        Tuple (x, y) of the randomly sampled n-dimensional input and the
        corresponding LSTM output.
    """
    assert self.initialized
    self.T += 1
    x = random.normal(generate_key(), shape=(self.n,))
    self.hid, self.cell, y = self._step(x, self.hid, self.cell)
    return x, y
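# Illustrative sketch (not part of the library): how repeated calls to a step()
# like the one above generate a time series -- each step draws a fresh Gaussian
# input and threads the recurrent state forward. For brevity this uses a single
# tanh update instead of the full LSTM gates; that simplification is an assumption.
import jax
import jax.numpy as jnp

n, h, T = 2, 4, 5
key = jax.random.PRNGKey(0)
kW, key = jax.random.split(key)
W = jax.random.normal(kW, (h, h + n))
hid = jnp.zeros(h)
series = []
for _ in range(T):
    key, k_x = jax.random.split(key)
    x = jax.random.normal(k_x, (n,))               # random input, as in step()
    hid = jnp.tanh(W @ jnp.concatenate([hid, x]))  # thread the state forward
    series.append(float(hid[0]))                   # one scalar observation per step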
def initialize(self, params={'n': 1, "m": None, "p": 3, "optimizer": OGD}):
    """
    Description: Initializes autoregressive method parameters
    Args:
        params (dict): dict with keys 'n' (data dimension), 'm', 'p' (length of
            history used for prediction), and 'optimizer' (optimizer choice)
    """
    self.initialized = True
    self.n = 1  # params['n'] is currently ignored
    self.p = params['p']
    self.past = np.zeros((self.p, self.n))

    glorot_init = stax.glorot()  # returns a function that initializes weights
    # self.params = glorot_init(generate_key(), (p+1, 1))
    self.params = {'phi': glorot_init(generate_key(), (self.p, 1))}

    def _update_past(self_past, x):
        new_past = np.roll(self_past, self.n)
        new_past = jax.ops.index_update(new_past, 0, x)
        return new_past

    self._update_past = jax.jit(_update_past)

    def _predict(params, x):
        phi = list(params.values())[0]
        return np.dot(x.T, phi).squeeze()

    self._predict = jax.jit(_predict)

    self._store_optimizer(params['optimizer'], self._predict)
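# Illustrative sketch (not part of the library): the bias-free AR(p) prediction
# computed by _predict above -- here phi has p entries and no intercept row, so
#     y_hat = phi_1 * x_{t-1} + ... + phi_p * x_{t-p}.
# The concrete numbers below are arbitrary.
import jax.numpy as jnp

p, n = 3, 1
phi = jnp.array([[0.3], [0.2], [0.1]])      # shape (p, 1): one weight per lag
past = jnp.array([[1.0], [2.0], [3.0]])     # shape (p, n): most recent value first
y_hat = jnp.dot(past.T, phi).squeeze()      # 0.3*1 + 0.2*2 + 0.1*3 = 1.0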
def initialize(self, n=1, m=None, p=3, d=1, optimizer=OGD):
    """
    Description: Initializes autoregressive method parameters
    Args:
        n (int): dimension of the data
        p (int): Length of history used for prediction
        d (int): number of difference orders to use. For zero the original
            autoregressor is used.
        optimizer (class): optimizer choice
    """
    self.initialized = True
    self.n = n
    self.p = p
    self.d = d
    self.past = np.zeros((p + d, self.n))  # store the last p + d values to compute the differences

    glorot_init = stax.glorot()  # returns a function that initializes weights
    # self.params = glorot_init(generate_key(), (p+1, 1))
    self.params = {'phi': glorot_init(generate_key(), (p + 1, 1))}

    def _update_past(self_past, x):
        new_past = np.roll(self_past, 1)
        new_past = jax.ops.index_update(new_past, 0, x)
        return new_past

    self._update_past = jax.jit(_update_past)

    ''' unused (easy, but also inefficient version)
    def _computeDifference(x, d):
        if d < 0:
            return 0.0
        if d == 0:
            return x[0]
        else:
            return _computeDifference(x, d - 1) - _computeDifference(x[1:], d - 1)
    self._computeDifference = jax.jit(_computeDifference)'''

    def _getDifference(index, d, matrix):
        if d < 0:
            return np.zeros(matrix[0, 0].shape)
        else:
            return matrix[d, index]

    self._getDifference = jax.jit(_getDifference)

    def _computeDifferenceMatrix(x, d):
        result = np.zeros((d + 1, x.shape[0], x.shape[1]))
        # first row (zeroth difference) is the original time series
        result = jax.ops.index_update(result, jax.ops.index[0, 0:len(x)], x)
        # fill the next rows
        for k in range(1, d + 1):
            result = jax.ops.index_update(
                result, jax.ops.index[k, 0:len(x) - k - 1],
                result[k - 1, 0:len(x) - k - 1] - result[k - 1, 1:len(x) - k])
        return result

    self._computeDifferenceMatrix = jax.jit(_computeDifferenceMatrix)

    def _predict(params, x):
        phi = list(params.values())[0]
        differences = _computeDifferenceMatrix(x, self.d)
        x_plus_bias = np.vstack(
            (np.ones((1, self.n)),
             np.array([_getDifference(i, self.d, differences) for i in range(0, p)])))
        return np.dot(x_plus_bias.T, phi).squeeze() + np.sum(
            [_getDifference(0, k, differences) for k in range(self.d)])

    self._predict = jax.jit(_predict)

    def _getUpdateValues(x):
        diffMatrix = _computeDifferenceMatrix(x, self.d)
        differences = np.array(
            [_getDifference(i, self.d, diffMatrix) for i in range(1, self.p + 1)])
        label = _getDifference(0, self.d, diffMatrix)
        return differences, label

    self._getUpdateValues = jax.jit(_getUpdateValues)

    self._store_optimizer(optimizer, self._predict)
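# Illustrative sketch (not part of the library): the successive differences that
# _computeDifferenceMatrix above stores row by row. Row 0 is the series itself
# (most recent value first), and row k holds the k-th order differences
#     x_t^{(k)} = x_t^{(k-1)} - x_{t+1}^{(k-1)}.
# The concrete series below is arbitrary.
import jax.numpy as jnp

x = jnp.array([8.0, 5.0, 3.0, 2.0])   # most recent value first
d0 = x                                # zeroth difference: 8, 5, 3, 2
d1 = d0[:-1] - d0[1:]                 # first difference:  3, 2, 1
d2 = d1[:-1] - d1[1:]                 # second difference: 1, 1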
def initialize(self, n=1, m=1, l=32, h=64, optimizer=OGD):
    """
    Description: Randomly initialize the LSTM.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        l (int): Length of memory for update step purposes.
        h (int): Default value 64. Hidden dimension of LSTM.
        optimizer (class): optimizer choice
    """
    self.T = 0
    self.initialized = True
    self.n, self.m, self.l, self.h = n, m, l, h

    # initialize parameters
    glorot_init = stax.glorot()  # returns a function that initializes weights
    W_hh = glorot_init(generate_key(), (4 * h, h))  # maps h_t to gates
    W_xh = glorot_init(generate_key(), (4 * h, n))  # maps x_t to gates
    W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
    b_h = np.zeros(4 * h)
    b_h = jax.ops.index_update(b_h, jax.ops.index[h:2 * h],
                               np.ones(h))  # forget gate biased initialization
    # self.params = [W_hh, W_xh, W_out, b_h]
    self.params = {'W_hh': W_hh, 'W_xh': W_xh, 'W_out': W_out, 'b_h': b_h}
    self.hid = np.zeros(h)
    self.cell = np.zeros(h)
    self.x = np.zeros((l, n))  # buffer of the last l n-dimensional inputs

    """ private helper methods """

    # @jax.jit
    def _update_x(self_x, x):
        new_x = np.roll(self_x, -self.n)
        new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
        return new_x

    @jax.jit
    def _fast_predict(carry, x):
        params, hid, cell = carry  # unroll tuple in carry
        W_hh, W_xh, W_out, b_h = params.values()
        sigmoid = lambda x: 1. / (1. + np.exp(-x))  # logistic sigmoid, defined inline
        gate = np.dot(W_hh, hid) + np.dot(W_xh, x) + b_h
        i, f, g, o = np.split(gate, 4)  # order: input, forget, cell, output
        next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
        next_hid = sigmoid(o) * np.tanh(next_cell)
        y = np.dot(W_out, next_hid)
        return (params, next_hid, next_cell), y

    @jax.jit
    def _predict(params, x):
        _, y = jax.lax.scan(_fast_predict, (params, np.zeros(h), np.zeros(h)), x)
        return y[-1]

    self.transform = lambda x: float(x) if (self.m == 1) else x
    self._update_x = _update_x
    self._fast_predict = _fast_predict
    self._predict = _predict
    self._store_optimizer(optimizer, self._predict)
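# Illustrative sketch (not part of the library): the rolling-buffer update that
# _update_x above performs. np.roll shifts the oldest row out and the new input
# is written into the last row, so the buffer always holds the l most recent
# n-dimensional inputs in chronological order. The shapes and values are
# arbitrary, and .at[...].set(...) is used as the modern equivalent of
# jax.ops.index_update.
import jax.numpy as jnp

l, n = 4, 2
buf = jnp.arange(l * n, dtype=jnp.float32).reshape(l, n)
new_x = jnp.array([100.0, 101.0])
buf = jnp.roll(buf, -n)            # drop the oldest row, shift the rest up
buf = buf.at[-1, :].set(new_x)     # write the newest input into the last row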
def search(self,
           method_id,
           method_params,
           problem_id,
           problem_params,
           loss,
           search_space,
           trials=None,
           smoothing=10,
           min_steps=100,
           verbose=0):
    """
    Description: Search for optimal method parameters
    Args:
        method_id (string): id of method
        method_params (dict): initial method parameters dict (updated by search space)
        problem_id (string): id of problem to try on
        problem_params (dict): problem parameters dict
        loss (function): a function mapping y_pred, y_true -> scalar loss
        search_space (dict): dict mapping parameter names to a finite set of options
        trials (int, None): number of random trials to sample from search space / try all parameters
        smoothing (int): loss computed over smoothing number of steps to decrease variance
        min_steps (int): minimum number of steps that the method gets to run for
        verbose (int): if 1, print progress and current parameters
    """
    self.method_id = method_id
    self.method_params = method_params
    self.problem_id = problem_id
    self.problem_params = problem_params
    self.loss = loss

    # store the order to test parameters
    param_list = list(itertools.product(*[v for k, v in search_space.items()]))
    # random.shuffle doesn't work directly on non-JAX objects, so shuffle an index array instead
    index = np.arange(len(param_list))
    shuffled_index = random.shuffle(generate_key(), index)
    param_order = [param_list[int(i)] for i in shuffled_index]  # shuffled order of elements

    # helper method
    def _update_smoothing(l, val):
        """ update smoothing loss list with new val """
        return jax.ops.index_update(np.roll(l, 1), 0, val)

    self._update_smoothing = jit(_update_smoothing)

    # store optimal params and optimal loss
    optimal_params, optimal_loss = {}, None
    t = 0
    for params in param_order:  # loop over all params in the given order
        t += 1
        curr_params = method_params.copy()
        curr_params.update({k: v for k, v in zip(search_space.keys(), params)})
        curr_loss = self._run_test(curr_params,
                                   smoothing=smoothing,
                                   min_steps=min_steps,
                                   verbose=verbose)
        if optimal_loss is None or curr_loss < optimal_loss:
            optimal_params = curr_params
            optimal_loss = curr_loss
        if t == trials:  # break after trials number of attempts, unless trials is None
            break
    return optimal_params, optimal_loss
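# Illustrative sketch (not part of the library): the core of the search loop
# above -- enumerate the Cartesian product of the search space, shuffle it with
# a JAX PRNG key, and keep the parameter setting with the lowest loss. The
# dummy_run_test stand-in for self._run_test and the search_space values are
# assumptions.
import itertools
import jax
import jax.numpy as jnp

search_space = {'lr': (0.01, 0.003), 'h': (32, 64)}
param_list = list(itertools.product(*search_space.values()))
perm = jax.random.permutation(jax.random.PRNGKey(0), jnp.arange(len(param_list)))
param_order = [param_list[int(i)] for i in perm]

def dummy_run_test(params):            # stands in for self._run_test
    return float(params['lr'] * params['h'])

best_params, best_loss = None, None
for values in param_order:
    candidate = dict(zip(search_space.keys(), values))
    loss = dummy_run_test(candidate)
    if best_loss is None or loss < best_loss:
        best_params, best_loss = candidate, loss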