Example #1
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """
        self.T = 0
        self.max_T = -1
        self.initialized = True
        self.has_regressors = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot() # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))
        self.W_x = glorot_init(generate_key(), (h, n))
        self.W_out = glorot_init(generate_key(), (m, h))
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)

        def _step(x, hid):
            next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)

        self._step = jax.jit(_step)
        return self.step()
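The core pattern here is a single jitted transition function that closes over the weight matrices. Below is a minimal, self-contained sketch of that pattern in plain jax.numpy; the dimensions and the normal-scaled initialization are stand-ins chosen only for illustration (the example above uses stax.glorot and the library's own generate_key).

import jax
import jax.numpy as jnp

n, h, m = 3, 8, 2                                    # input, hidden, output dims (arbitrary)
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
W_h = jax.random.normal(k1, (h, h)) / jnp.sqrt(h)    # rough stand-in for Glorot init
W_x = jax.random.normal(k2, (h, n)) / jnp.sqrt(n)
W_out = jax.random.normal(k3, (m, h)) / jnp.sqrt(h)
b_h = jnp.zeros(h)

@jax.jit
def step(x, hid):
    # one RNN transition: new hidden state and the output read off it
    next_hid = jnp.tanh(W_h @ hid + W_x @ x + b_h)
    return next_hid, W_out @ next_hid

hid = jnp.zeros(h)
x = jax.random.normal(k4, (n,))
hid, y = step(x, hid)                                # y has shape (m,)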
Example #2
    def initialize(self,
                   n=1,
                   m=1,
                   l=32,
                   h=64,
                   optimizer=OGD,
                   loss=mse,
                   lr=0.003):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            l (int): Length of memory for update step purposes.
            h (int): Default value 64. Hidden dimension of RNN.
            optimizer (class): optimizer choice
            loss (class): loss choice
            lr (float): learning rate for update
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.l, self.h = n, m, l, h

        # initialize parameters
        glorot_init = stax.glorot()  # returns a function that initializes weights
        W_h = glorot_init(generate_key(), (h, h))
        W_x = glorot_init(generate_key(), (h, n))
        W_out = glorot_init(generate_key(), (m, h))
        b_h = np.zeros(h)
        # self.params = [W_h, W_x, W_out, b_h]
        self.params = {'W_h': W_h, 'W_x': W_x, 'W_out': W_out, 'b_h': b_h}
        self.hid = np.zeros(h)
        self.x = np.zeros((l, n))
        """ private helper methods"""
        @jax.jit
        def _update_x(self_x, x):
            new_x = np.roll(self_x, -self.n)
            new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
            return new_x

        @jax.jit
        def _fast_predict(carry, x):
            params, hid = carry  # unroll tuple in carry
            W_h, W_x, W_out, b_h = params.values()
            next_hid = np.tanh(np.dot(W_h, hid) + np.dot(W_x, x) + b_h)
            y = np.dot(W_out, next_hid)
            return (params, next_hid), y

        @jax.jit
        def _predict(params, x):
            _, y = jax.lax.scan(_fast_predict, (params, np.zeros(h)), x)
            return y[-1]

        self.transform = lambda x: float(x) if (self.m == 1) else x
        self._update_x = _update_x
        self._fast_predict = _fast_predict
        self._predict = _predict
        self._store_optimizer(optimizer, self._predict)
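The _predict helper above relies on jax.lax.scan to unroll the recurrence over the stored window of the last l inputs. A minimal, self-contained sketch of that scan idiom, with made-up dimensions and randomly drawn parameters standing in for the library's initialization:

import jax
import jax.numpy as jnp

n, h, m, l = 3, 8, 2, 5
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
params = {
    'W_h': jax.random.normal(k1, (h, h)) / jnp.sqrt(h),
    'W_x': jax.random.normal(k2, (h, n)) / jnp.sqrt(n),
    'W_out': jax.random.normal(k3, (m, h)) / jnp.sqrt(h),
    'b_h': jnp.zeros(h),
}

def cell(carry, x):
    # carry threads (params, hidden state) through the scan
    params, hid = carry
    next_hid = jnp.tanh(params['W_h'] @ hid + params['W_x'] @ x + params['b_h'])
    return (params, next_hid), params['W_out'] @ next_hid

@jax.jit
def predict(params, window):                     # window: (l, n)
    _, ys = jax.lax.scan(cell, (params, jnp.zeros(h)), window)
    return ys[-1]                                # output after the last input

window = jax.random.normal(k4, (l, n))
print(predict(params, window).shape)             # (m,)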
Example #3
def test_random(show=False):
    set_key(5)
    a1 = get_global_key()
    r1 = generate_key()
    set_key(5)
    a2 = get_global_key()
    r2 = generate_key()
    assert str(a1) == str(a2)
    assert str(r1) == str(r2)
    print("test_random passed")
Example #4
    def initialize(self, n = 1, m = None, p = 3, optimizer = OGD):
        """
        Description: Initializes autoregressive method parameters

        Args:
            p (int): Length of history used for prediction
            optimizer (class): optimizer choice
            loss (class): loss choice
            lr (float): learning rate for update
        """
        self.initialized = True
        self.n = n
        self.p = p

        self.past = np.zeros((p, self.n))

        glorot_init = stax.glorot() # returns a function that initializes weights

        # self.params = glorot_init(generate_key(), (p+1,1))
        self.params = {'phi' : glorot_init(generate_key(), (p+1,1))}

        def _update_past(self_past, x):
            new_past = np.roll(self_past, self.n)
            new_past = jax.ops.index_update(new_past, 0, x)
            return new_past
        self._update_past = jax.jit(_update_past)

        def _predict(params, x):
            phi = list(params.values())[0]
            x_plus_bias = np.vstack((np.ones((1, self.n)), x))
            return np.dot(x_plus_bias.T, phi).squeeze()
        self._predict = jax.jit(_predict)

        self._store_optimizer(optimizer, self._predict)
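The prediction itself is just an affine function of the last p observations: a bias row is stacked on top of the history and contracted with the coefficient vector phi. A minimal, self-contained sketch with synthetic data and a plain random draw in place of the Glorot initializer:

import jax
import jax.numpy as jnp

p, n = 3, 1
key = jax.random.PRNGKey(0)
phi = jax.random.normal(key, (p + 1, 1)) * 0.1   # [bias; p coefficients]

past = jnp.arange(1.0, p + 1).reshape(p, n)      # last p observations (synthetic)
x_plus_bias = jnp.vstack((jnp.ones((1, n)), past))
y_pred = jnp.dot(x_plus_bias.T, phi).squeeze()   # affine forecast
print(y_pred)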
Example #5
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """

        self.T = 0
        self.max_T = -1
        self.initialized = True
        self.has_regressors = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot()  # returns a function that initializes weights
        self.W_hh = glorot_init(generate_key(),
                                (4 * h, h))  # maps h_t to gates
        self.W_xh = glorot_init(generate_key(),
                                (4 * h, n))  # maps x_t to gates
        self.b_h = np.zeros(4 * h)
        self.b_h = jax.ops.index_update(
            self.b_h, jax.ops.index[h:2 * h],
            np.ones(h))  # forget gate biased initialization
        self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        self.cell = np.zeros(h)  # long-term memory
        self.hid = np.zeros(h)  # short-term memory

        def _step(x, hid, cell):
            sigmoid = lambda x: 1. / (1. + np.exp(-x))  # newer JAX versions also provide jax.nn.sigmoid
            gate = np.dot(self.W_hh, hid) + np.dot(self.W_xh, x) + self.b_h
            i, f, g, o = np.split(gate,
                                  4)  # order: input, forget, cell, output
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, next_cell, y)

        self._step = jax.jit(_step)
        return self.step()
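Both LSTM examples initialize the forget-gate slice of the bias to ones via jax.ops.index_update / jax.ops.index, an API that has since been removed from JAX. On recent JAX versions the equivalent functional update uses the .at[...] syntax; a minimal sketch (dimensions arbitrary):

import jax.numpy as jnp

h = 4
b_h = jnp.zeros(4 * h)
b_h = b_h.at[h:2 * h].set(1.0)   # bias the forget gate toward remembering early in training
print(b_h)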
Example #6
 def step(self):
     """
     Description: Moves the system dynamics one time-step forward.
     Args:
         None
     Returns:
         The next value in the time-series.
     """
     assert self.initialized
     self.T += 1
     return random.normal(generate_key())
Example #7
 def initialize(self):
     """
     Description: Randomly initialize the hidden dynamics of the system.
     Args:
         None
     Returns:
          The first value in the time-series.
     """
     self.T = 0
     self.max_T = -1
     self.initialized = True
     return random.normal(generate_key())
Example #8
 def step(self):
     """
      Description: Draws a random input and moves the LSTM dynamics one time-step forward.
      Args:
          None
      Returns:
          The random input and the corresponding output of the LSTM.
     """
     assert self.initialized
     self.T += 1
     x = random.normal(generate_key(), shape=(self.n, ))
     self.hid, self.cell, y = self._step(x, self.hid, self.cell)
     return x, y
Example #9
    def initialize(self, params={'n': 1, "m": None, "p": 3, "optimizer": OGD}):
        """
		Description: Initializes autoregressive method parameters

		Args:
			p (int): Length of history used for prediction
			optimizer (class): optimizer choice
			loss (class): loss choice
			lr (float): learning rate for update
		"""
        self.initialized = True
        self.n = 1  #params['n']
        self.p = params['p']

        self.past = np.zeros((self.p, self.n))

        glorot_init = stax.glorot()  # returns a function that initializes weights

        # self.params = glorot_init(generate_key(), (p+1,1))
        self.params = {'phi': glorot_init(generate_key(), (self.p, 1))}

        def _update_past(self_past, x):
            new_past = np.roll(self_past, self.n)
            new_past = jax.ops.index_update(new_past, 0, x)
            return new_past

        self._update_past = jax.jit(_update_past)

        def _predict(params, x):
            phi = list(params.values())[0]
            return np.dot(x.T, phi).squeeze()

        self._predict = jax.jit(_predict)

        self._store_optimizer(params['optimizer'], self._predict)
Example #10
    def initialize(self, n=1, m=None, p=3, d=1, optimizer=OGD):
        """
		Description: Initializes autoregressive method parameters

		Args:
			n (int): dimension of the data
			p (int): Length of history used for prediction
			d (int): number of difference orders to use. For zero the original autoregressor is used.
			optimizer (class): optimizer choice
		"""
        self.initialized = True
        self.n = n
        self.p = p
        self.d = d

        self.past = np.zeros((p + d, self.n))  # store p + d past values so d-th order differences can be computed

        glorot_init = stax.glorot()  # returns a function that initializes weights

        # self.params = glorot_init(generate_key(), (p+1,1))
        self.params = {'phi': glorot_init(generate_key(), (p + 1, 1))}

        def _update_past(self_past, x):
            new_past = np.roll(self_past, 1)
            new_past = jax.ops.index_update(new_past, 0, x)
            return new_past

        self._update_past = jax.jit(_update_past)
        '''	unused (easy, but also inefficient version)
		def _computeDifference(x, d):
			if(d < 0):
				return 0.0
			if(d == 0):
				return x[0]
			else:
				return _computeDifference(x, d - 1) - _computeDifference(x[1:], d - 1)
		self._computeDifference = jax.jit(_computeDifference)'''

        def _getDifference(index, d, matrix):
            if (d < 0):
                return np.zeros(matrix[0, 0].shape)
            else:
                return matrix[d, index]

        self._getDifference = jax.jit(_getDifference)

        def _computeDifferenceMatrix(x, d):
            result = np.zeros((d + 1, x.shape[0], x.shape[1]))

            #first row (zeroth difference is the original time series)
            result = jax.ops.index_update(result, jax.ops.index[0, 0:len(x)],
                                          x)

            #fill the next rows
            for k in range(1, d + 1):
                result = jax.ops.index_update(
                    result, jax.ops.index[k, 0:len(x) - k - 1],
                    result[k - 1, 0:len(x) - k - 1] -
                    result[k - 1, 1:len(x) - k])

            return result

        self._computeDifferenceMatrix = jax.jit(_computeDifferenceMatrix)

        def _predict(params, x):
            phi = list(params.values())[0]
            differences = _computeDifferenceMatrix(x, self.d)
            x_plus_bias = np.vstack(
                (np.ones((1, self.n)),
                 np.array([
                     _getDifference(i, self.d, differences)
                     for i in range(0, p)
                 ])))
            return np.dot(x_plus_bias.T, phi).squeeze() + np.sum(
                [_getDifference(0, k, differences) for k in range(self.d)])

        self._predict = jax.jit(_predict)

        def _getUpdateValues(x):
            diffMatrix = _computeDifferenceMatrix(x, self.d)
            differences = np.array([
                _getDifference(i, self.d, diffMatrix)
                for i in range(1, self.p + 1)
            ])
            label = _getDifference(0, self.d, diffMatrix)
            return differences, label

        self._getUpdateValues = jax.jit(_getUpdateValues)

        self._store_optimizer(optimizer, self._predict)
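The differencing logic above builds a matrix whose row k holds the k-th order differences of the stored window (row 0 is the window itself, newest value first). A minimal, self-contained illustration of that construction with made-up data:

import jax.numpy as jnp

x = jnp.array([[1.0], [4.0], [9.0], [16.0], [25.0]])   # window of shape (p + d, n), newest first
d = 2
rows = [x]
for k in range(1, d + 1):
    prev = rows[-1]
    rows.append(prev[:-1] - prev[1:])                  # k-th order differences
for k, r in enumerate(rows):
    print(k, r.ravel())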
Example #11
    def initialize(self, n=1, m=1, l=32, h=64, optimizer=OGD):
        """
        Description: Randomly initialize the LSTM.
        Args:
            n (int): Observation/output dimension.
            m (int): Input action dimension.
            l (int): Length of memory for update step purposes.
            h (int): Default value 64. Hidden dimension of LSTM.
            optimizer (class): optimizer choice
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.l, self.h = n, m, l, h

        # initialize parameters
        glorot_init = stax.glorot()  # returns a function that initializes weights
        W_hh = glorot_init(generate_key(), (4 * h, h))  # maps h_t to gates
        W_xh = glorot_init(generate_key(), (4 * h, n))  # maps x_t to gates
        W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        b_h = np.zeros(4 * h)
        b_h = jax.ops.index_update(
            b_h, jax.ops.index[h:2 * h],
            np.ones(h))  # forget gate biased initialization
        # self.params = [W_hh, W_xh, W_out, b_h]
        self.params = {'W_hh': W_hh, 'W_xh': W_xh, 'W_out': W_out, 'b_h': b_h}
        self.hid = np.zeros(h)
        self.cell = np.zeros(h)
        self.x = np.zeros((l, m))
        """ private helper methods"""

        #@jax.jit
        def _update_x(self_x, x):
            new_x = np.roll(self_x, -self.n)
            new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
            return new_x

        @jax.jit
        def _fast_predict(carry, x):
            params, hid, cell = carry  # unroll tuple in carry
            W_hh, W_xh, W_out, b_h = params.values()
            sigmoid = lambda x: 1. / (1. + np.exp(-x))  # newer JAX versions also provide jax.nn.sigmoid
            gate = np.dot(W_hh, hid) + np.dot(W_xh, x) + b_h
            i, f, g, o = np.split(gate,
                                  4)  # order: input, forget, cell, output
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(W_out, next_hid)
            return (params, next_hid, next_cell), y

        @jax.jit
        def _predict(params, x):
            _, y = jax.lax.scan(_fast_predict,
                                (params, np.zeros(h), np.zeros(h)), x)
            return y[-1]

        self.transform = lambda x: float(x) if (self.m == 1) else x
        self._update_x = _update_x
        self._fast_predict = _fast_predict
        self._predict = _predict
        self._store_optimizer(optimizer, self._predict)
Example #12
    def search(self,
               method_id,
               method_params,
               problem_id,
               problem_params,
               loss,
               search_space,
               trials=None,
               smoothing=10,
               min_steps=100,
               verbose=0):
        """
        Description: Search for optimal method parameters
        Args:
            method_id (string): id of method
            method_params (dict): initial method parameters dict (updated by search space)
            problem_id (string): id of problem to try on
            problem_params (dict): problem parameters dict
            loss (function): a function mapping y_pred, y_true -> scalar loss
            search_space (dict): dict mapping parameter names to a finite set of options
            trials (int, None): number of random trials to sample from the search space; if None, try all parameter combinations
            smoothing (int): loss computed over smoothing number of steps to decrease variance
            min_steps (int): minimum number of steps that the method gets to run for
            verbose (int): if 1, print progress and current parameters
        """
        self.method_id = method_id
        self.method_params = method_params
        self.problem_id = problem_id
        self.problem_params = problem_params
        self.loss = loss

        # store the order to test parameters
        param_list = list(
            itertools.product(*[v for k, v in search_space.items()]))
        index = np.arange(
            len(param_list)
        )  # np.random.shuffle doesn't work directly on non-JAX objects
        shuffled_index = random.shuffle(generate_key(), index)
        param_order = [param_list[int(i)]
                       for i in shuffled_index]  # shuffle order of elements

        # helper method
        def _update_smoothing(l, val):
            """ update smoothing loss list with new val """
            return jax.ops.index_update(np.roll(l, 1), 0, val)

        self._update_smoothing = jit(_update_smoothing)

        # store optimal params and optimal loss
        optimal_params, optimal_loss = {}, None
        t = 0
        for params in param_order:  # loop over all params in the given order
            t += 1
            curr_params = method_params.copy()
            curr_params.update(
                {k: v
                 for k, v in zip(search_space.keys(), params)})
            loss = self._run_test(curr_params,
                                  smoothing=smoothing,
                                  min_steps=min_steps,
                                  verbose=verbose)
            if optimal_loss is None or loss < optimal_loss:
                optimal_params = curr_params
                optimal_loss = loss
            if t == trials:  # break after trials number of attempts, unless trials is None
                break
        return optimal_params, optimal_loss
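The search order above is built by enumerating the Cartesian product of the search space and shuffling it. A minimal, self-contained sketch of that step; jax.random.permutation is used here in place of the now-deprecated jax.random.shuffle, and the search space values are made up purely for illustration:

import itertools
import jax
import jax.numpy as jnp

search_space = {'lr': [0.1, 0.01, 0.001], 'h': [32, 64]}
param_list = list(itertools.product(*search_space.values()))

key = jax.random.PRNGKey(0)
order = jax.random.permutation(key, jnp.arange(len(param_list)))   # random visiting order
param_order = [param_list[int(i)] for i in order]

for params in param_order[:3]:        # e.g. cap the loop at `trials` attempts
    print(dict(zip(search_space.keys(), params)))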