Example #1
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """
        self.T = 0
        self.max_T = -1
        self.initialized = True
        self.has_regressors = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot() # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))
        self.W_x = glorot_init(generate_key(), (h, n))
        self.W_out = glorot_init(generate_key(), (m, h))
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)

        def _step(x, hid):
            next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)

        self._step = jax.jit(_step)
        return self.step()
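The comment above notes that stax.glorot() returns an initializer function rather than an array. A minimal standalone sketch of that call pattern, assuming the jax.experimental.stax and jax.random imports these examples rely on; the key and shape are illustrative only (newer JAX releases expose similar initializers under jax.nn.initializers):

from jax import random
from jax.experimental import stax

glorot_init = stax.glorot()      # initializer: (rng, shape) -> array
key = random.PRNGKey(0)
W = glorot_init(key, (64, 10))   # sample a 64 x 10 Glorot-scaled weight matrix
assert W.shape == (64, 10)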
Example #2
    def initialize(self, n, m, h=64):
        """
        Description:
            Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """

        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_hh = glorot_init(generate_key(),
                                (4 * h, h))  # maps h_t to gates
        self.W_xh = glorot_init(generate_key(),
                                (4 * h, n))  # maps x_t to gates
        self.b_h = np.zeros(4 * h)
        self.b_h = jax.ops.index_update(
            self.b_h, jax.ops.index[h:2 * h],
            np.ones(h))  # forget gate biased initialization
        self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        self.cell = np.zeros(h)  # long-term memory
        self.hid = np.zeros(h)  # short-term memory
        self.sigmoid = lambda x: 1. / (1. + np.exp(
            -x))  # no JAX implementation of sigmoid it seems?
        return np.dot(self.W_out, self.hid)
Example #3
def MaskedDense(mask, bias=True, W_init=glorot(), b_init=randn()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the weights of the layer.
    :param bool bias: whether to include a bias term.
    :param callable W_init: initialization method for the weights.
    :param callable b_init: initialization method for the bias terms.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return np.dot(inputs, W * mask) + b
        else:
            W = params
            return np.dot(inputs, W * mask)

    return init_fun, apply_fun
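As the docstring above explains, the constructor returns an (init_fun, apply_fun) pair in the stax style. A hedged usage sketch, assuming the MaskedDense defined above (with its glorot()/randn() defaults) is importable and that np is jax.numpy and random is jax.random; the mask shape and batch size are made up for illustration:

import jax.numpy as np
from jax import random

mask = np.ones((3, 2))                                     # (input_dim, out_dim)
init_fun, apply_fun = MaskedDense(mask)
out_shape, params = init_fun(random.PRNGKey(0), (10, 3))   # batch of 10 inputs of dim 3
y = apply_fun(params, np.ones((10, 3)))                    # out_shape == (10, 2) == y.shape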
Example #4
    def initialize(cls,
                   rng,
                   in_spec,
                   dim_out,
                   kernel_init=stax.glorot(),
                   bias_init=stax.zeros):
        """Initializes Dense Layer.

    Args:
      rng: Random key.
      in_spec: Input Spec.
      dim_out: Output dimensions.
      kernel_init: Kernel initialization function.
      bias_init: Bias initialization function.

    Returns:
      Tuple with the output shape and the LayerParams.
    """
        if rng is None:
            raise ValueError('Need valid RNG to instantiate Dense layer.')
        dim_in = in_spec.shape[-1]
        k1, k2 = random.split(rng)
        params = DenseParams(
            base.create_parameter(k1, (dim_in, dim_out), init=kernel_init),
            base.create_parameter(k2, (dim_out, ), init=bias_init))
        return base.LayerParams(params)
Example #5
    def new(input_size, output_size, hidden_layers, key):

        _randn_fn = randn()

        def vector_init(shape):
            if isinstance(shape, int):
                shape = (shape, )
            nonlocal key
            key, rng = random.split(key)
            return _randn_fn(rng, shape)

        _glorot_fn = glorot()

        def matrix_init(shape):
            nonlocal key
            key, rng = random.split(key)
            return _glorot_fn(rng, shape)

        input_state = vector_init(input_size)
        hidden_states = []
        for size in hidden_layers:
            hidden_states.append(vector_init(size))
        output_states = vector_init(output_size)
        states = [input_state, *hidden_states, output_states]

        # weights
        fwd_weights, bwd_weights = [], []
        for prev, post in zip(states[:-1], states[1:]):
            fwd_weights.append(matrix_init((prev.shape[0], post.shape[0])))
            bwd_weights.append(matrix_init((post.shape[0], prev.shape[0])))

        return LayeredNet(states, [*fwd_weights, *bwd_weights])
Example #6
    def initialize(self, n = 1, m = None, p = 3, optimizer = OGD):
        """
        Description: Initializes autoregressive method parameters

        Args:
            n (int): dimension of the data
            m: unused
            p (int): Length of history used for prediction
            optimizer (class): optimizer choice
        """
        self.initialized = True
        self.n = n
        self.p = p

        self.past = np.zeros((p, self.n))

        glorot_init = stax.glorot() # returns a function that initializes weights

        # self.params = glorot_init(generate_key(), (p+1,1))
        self.params = {'phi' : glorot_init(generate_key(), (p+1,1))}

        def _update_past(self_past, x):
            new_past = np.roll(self_past, self.n)
            new_past = jax.ops.index_update(new_past, 0, x)
            return new_past
        self._update_past = jax.jit(_update_past)

        def _predict(params, x):
            phi = list(params.values())[0]
            x_plus_bias = np.vstack((np.ones((1, self.n)), x))
            return np.dot(x_plus_bias.T, phi).squeeze()
        self._predict = jax.jit(_predict)

        self._store_optimizer(optimizer, self._predict)
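The _predict helper above prepends a row of ones to the (p, n) history so that the first entry of phi acts as a bias term. A shape-level sketch of that computation, with p, n, and the arrays chosen purely for illustration (np is jax.numpy):

import jax.numpy as np

p, n = 3, 1
phi = np.ones((p + 1, 1))                       # same shape as params['phi'] above
x = np.ones((p, n))                             # history of the last p observations
x_plus_bias = np.vstack((np.ones((1, n)), x))   # shape (p + 1, n)
y = np.dot(x_plus_bias.T, phi).squeeze()        # scalar prediction when n == 1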
Example #7
    def initialize(self,
                   n=1,
                   m=1,
                   l=32,
                   h=64,
                   optimizer=OGD,
                   loss=mse,
                   lr=0.003):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            l (int): Length of memory for update step purposes.
            h (int): Default value 64. Hidden dimension of RNN.
            optimizer (class): optimizer choice
            loss (class): loss choice
            lr (float): learning rate for update
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.l, self.h = n, m, l, h

        # initialize parameters
        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        W_h = glorot_init(generate_key(), (h, h))
        W_x = glorot_init(generate_key(), (h, n))
        W_out = glorot_init(generate_key(), (m, h))
        b_h = np.zeros(h)
        # self.params = [W_h, W_x, W_out, b_h]
        self.params = {'W_h': W_h, 'W_x': W_x, 'W_out': W_out, 'b_h': b_h}
        self.hid = np.zeros(h)
        self.x = np.zeros((l, n))
        """ private helper methods"""
        @jax.jit
        def _update_x(self_x, x):
            new_x = np.roll(self_x, -self.n)
            new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
            return new_x

        @jax.jit
        def _fast_predict(carry, x):
            params, hid = carry  # unroll tuple in carry
            W_h, W_x, W_out, b_h = params.values()
            next_hid = np.tanh(np.dot(W_h, hid) + np.dot(W_x, x) + b_h)
            y = np.dot(W_out, next_hid)
            return (params, next_hid), y

        @jax.jit
        def _predict(params, x):
            _, y = jax.lax.scan(_fast_predict, (params, np.zeros(h)), x)
            return y[-1]

        self.transform = lambda x: float(x) if (self.m == 1) else x
        self._update_x = _update_x
        self._fast_predict = _fast_predict
        self._predict = _predict
        self._store_optimizer(optimizer, self._predict)
Example #8
def Dense(name, out_dim, W_init=stax.glorot(), b_init=stax.randn()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, example_input):
        input_shape = example_input.shape
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (out_dim, input_shape[-1])), b_init(k2, (out_dim, ))
        return W, b

    def apply_fun(params, inputs):
        W, b = params
        return np.dot(W, inputs) + b

    return core.Layer(name, init_fun, apply_fun).bind
Example #9
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """

        self.T = 0
        self.max_T = -1
        self.initialized = True
        self.has_regressors = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_hh = glorot_init(generate_key(),
                                (4 * h, h))  # maps h_t to gates
        self.W_xh = glorot_init(generate_key(),
                                (4 * h, n))  # maps x_t to gates
        self.b_h = np.zeros(4 * h)
        self.b_h = jax.ops.index_update(
            self.b_h, jax.ops.index[h:2 * h],
            np.ones(h))  # forget gate biased initialization
        self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        self.cell = np.zeros(h)  # long-term memory
        self.hid = np.zeros(h)  # short-term memory

        def _step(x, hid, cell):
            sigmoid = lambda x: 1. / (1. + np.exp(
                -x))  # no JAX implementation of sigmoid it seems?
            gate = np.dot(self.W_hh, hid) + np.dot(self.W_xh, x) + self.b_h
            i, f, g, o = np.split(gate,
                                  4)  # order: input, forget, cell, output
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, next_cell, y)

        self._step = jax.jit(_step)
        return self.step()
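The gate computation above packs the input, forget, cell, and output gates into a single matrix-vector product and then splits the result into four equal chunks. A self-contained sketch of that split, with a toy hidden size and a placeholder gate vector:

import jax.numpy as np

h = 4
gate = np.arange(4 * h, dtype=np.float32)   # stands in for W_hh @ hid + W_xh @ x + b_h
i, f, g, o = np.split(gate, 4)              # order: input, forget, cell, output
assert i.shape == f.shape == g.shape == o.shape == (h,)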
Example #10
    def initialize(self, n, m, h=64):
        """
        Description:
            Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot() # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))
        self.W_x = glorot_init(generate_key(), (h, n))
        self.W_out = glorot_init(generate_key(), (m, h))
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)
        return np.dot(self.W_out, self.hid)
Example #11
def conv_info(in_shape,
              out_chan,
              filter_shape,
              strides=None,
              padding='VALID',
              kernel_init=None,
              bias_init=stax.randn(1e-6),
              transpose=False):
    """Returns parameters and output shape information given input shapes."""
    # Essentially the `stax` implementation
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    in_shape = (1, ) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    one = (1, ) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or stax.glorot(rhs_spec.index('O'),
                                             rhs_spec.index('I'))
    filter_shape_iter = iter(filter_shape)
    kernel_shape = tuple([
        out_chan if c == 'O' else
        in_shape[lhs_spec.index('C')] if c == 'I' else next(filter_shape_iter)
        for c in rhs_spec
    ])
    if transpose:
        out_shape = lax.conv_transpose_shape_tuple(in_shape, kernel_shape,
                                                   strides, padding,
                                                   DIMENSION_NUMBERS)
    else:
        out_shape = lax.conv_general_shape_tuple(in_shape, kernel_shape,
                                                 strides, padding,
                                                 DIMENSION_NUMBERS)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    out_shape = out_shape[1:]
    shapes = (out_shape, kernel_shape, bias_shape)
    inits = (kernel_init, bias_init)
    return shapes, inits, (strides, padding, one)
Example #12
def MaskedDense(out_dim, mask, W_init=glorot(), b_init=randn()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng key and applies the layer.

    :param int out_dim: Number of output dimensions.
    :param array mask: Mask applied to the weights of the layer.
    :param callable W_init: initialization method for the weights.
    :param callable b_init: initialization method for the bias terms.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return np.dot(inputs, W * mask) + b

    return init_fun, apply_fun
Example #13
    def initialize(self, params={'n': 1, "m": None, "p": 3, "optimizer": OGD}):
        """
		Description: Initializes autoregressive method parameters

		Args:
			p (int): Length of history used for prediction
			optimizer (class): optimizer choice
			loss (class): loss choice
			lr (float): learning rate for update
		"""
        self.initialized = True
        self.n = 1  #params['n']
        self.p = params['p']

        self.past = np.zeros((self.p, self.n))

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights

        # self.params = glorot_init(generate_key(), (p+1,1))
        self.params = {'phi': glorot_init(generate_key(), (self.p, 1))}

        def _update_past(self_past, x):
            new_past = np.roll(self_past, self.n)
            new_past = jax.ops.index_update(new_past, 0, x)
            return new_past

        self._update_past = jax.jit(_update_past)

        def _predict(params, x):
            phi = list(params.values())[0]
            return np.dot(x.T, phi).squeeze()

        self._predict = jax.jit(_predict)

        self._store_optimizer(params['optimizer'], self._predict)
Example #14
def create_parameter(rng, spec, init=stax.glorot()):
    return init(rng, spec)
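A hypothetical call to the helper above, assuming the rng comes from jax.random and the default stax.glorot() initializer is used:

from jax import random

W = create_parameter(random.PRNGKey(0), (128, 64))   # Glorot-initialized (128, 64) matrix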
Example #15
 def testGlorotInitShape(self, shape):
     key = random.PRNGKey(0)
     out = stax.glorot()(key, shape)
     self.assertEqual(out.shape, shape)
Example #16
    def initialize(self, n=1, m=None, p=3, d=1, optimizer=OGD):
        """
		Description: Initializes autoregressive method parameters

		Args:
			n (int): dimension of the data
			p (int): Length of history used for prediction
			d (int): number of difference orders to use. For zero the original autoregressor is used.
			optimizer (class): optimizer choice
		"""
        self.initialized = True
        self.n = n
        self.p = p
        self.d = d

        self.past = np.zeros(
            (p + d,
             self.n))  #store the last d x values to compute the differences

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights

        # self.params = glorot_init(generate_key(), (p+1,1))
        self.params = {'phi': glorot_init(generate_key(), (p + 1, 1))}

        def _update_past(self_past, x):
            new_past = np.roll(self_past, 1)
            new_past = jax.ops.index_update(new_past, 0, x)
            return new_past

        self._update_past = jax.jit(_update_past)
        '''	unused (easy, but also inefficient version)
		def _computeDifference(x, d):
			if(d < 0):
				return 0.0
			if(d == 0):
				return x[0]
			else:
				return _computeDifference(x, d - 1) - _computeDifference(x[1:], d - 1)
		self._computeDifference = jax.jit(_computeDifference)'''

        def _getDifference(index, d, matrix):
            if (d < 0):
                return np.zeros(matrix[0, 0].shape)
            else:
                return matrix[d, index]

        self._getDifference = jax.jit(_getDifference)

        def _computeDifferenceMatrix(x, d):
            result = np.zeros((d + 1, x.shape[0], x.shape[1]))

            #first row (zeroth difference is the original time series)
            result = jax.ops.index_update(result, jax.ops.index[0, 0:len(x)],
                                          x)

            #fill the next rows
            for k in range(1, d + 1):
                result = jax.ops.index_update(
                    result, jax.ops.index[k, 0:len(x) - k - 1],
                    result[k - 1, 0:len(x) - k - 1] -
                    result[k - 1, 1:len(x) - k])

            return result

        self._computeDifferenceMatrix = jax.jit(_computeDifferenceMatrix)

        def _predict(params, x):
            phi = list(params.values())[0]
            differences = _computeDifferenceMatrix(x, self.d)
            x_plus_bias = np.vstack(
                (np.ones((1, self.n)),
                 np.array([
                     _getDifference(i, self.d, differences)
                     for i in range(0, p)
                 ])))
            return np.dot(x_plus_bias.T, phi).squeeze() + np.sum(
                [_getDifference(0, k, differences) for k in range(self.d)])

        self._predict = jax.jit(_predict)

        def _getUpdateValues(x):
            diffMatrix = _computeDifferenceMatrix(x, self.d)
            differences = np.array([
                _getDifference(i, self.d, diffMatrix)
                for i in range(1, self.p + 1)
            ])
            label = _getDifference(0, self.d, diffMatrix)
            return differences, label

        self._getUpdateValues = jax.jit(_getUpdateValues)

        self._store_optimizer(optimizer, self._predict)
Example #17
    def initialize(self, u_dim, y_dim, hid_dim=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            u_dim (int): Input dimension.
            y_dim (int): Observation/output dimension.
            hid_dim (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """

        self.T = 0
        self.initialized = True
        # self.n, self.m, self.h = n, m, h

        self.u_dim = u_dim  # input dimension
        self.y_dim = y_dim  # output dimension
        self.hid_dim = hid_dim  # hidden state dimension
        self.cell_dim = hid_dim  # observable state dimension

        # self.m = self.y_dim # state dimension
        # self.n = self.u_dim # input dimension
        self.rollout_controller = None
        self.target = jax.random.uniform(generate_key(),
                                         shape=(self.y_dim, ),
                                         minval=-1,
                                         maxval=1)

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_hh = glorot_init(
            generate_key(),
            (4 * self.hid_dim, self.hid_dim))  # maps h_t to gates
        self.W_uh = glorot_init(
            generate_key(),
            (4 * self.hid_dim, self.u_dim))  # maps x_t to gates
        self.b_h = np.zeros(4 * self.hid_dim)
        self.b_h = jax.ops.index_update(
            self.b_h, jax.ops.index[self.hid_dim:2 * self.hid_dim],
            np.ones(self.hid_dim))  # forget gate biased initialization
        self.W_out = glorot_init(
            generate_key(), (self.y_dim, self.hid_dim))  # maps h_t to output
        # self.cell = np.zeros(self.hid_dim) # long-term memory
        # self.hid = np.zeros(self.hid_dim) # short-term memory
        self.hid_cell = np.hstack(
            (np.zeros(self.hid_dim), np.zeros(self.hid_dim)))
        '''
        def _step(x, hid, cell):
            sigmoid = lambda x: 1. / (1. + np.exp(-x)) # no JAX implementation of sigmoid it seems?
            gate = np.dot(self.W_hh, hid) + np.dot(self.W_uh, x) + self.b_h 
            i, f, g, o = np.split(gate, 4) # order: input, forget, cell, output
            next_cell =  sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, next_cell, y)'''
        def _dynamics(hid_cell_state, u):
            hid = hid_cell_state[:self.hid_dim]
            cell = hid_cell_state[self.hid_dim:]

            sigmoid = lambda u: 1. / (1. + np.exp(-u))
            gate = np.dot(self.W_hh, hid) + np.dot(self.W_uh, u) + self.b_h
            i, f, g, o = np.split(gate,
                                  4)  # order: input, forget, cell, output
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(self.W_out, next_hid)
            return (np.hstack((next_hid, next_cell)), y)

        self._dynamics = jax.jit(
            _dynamics
        )  # MUST store as self._dynamics for default rollout implementation to work
        # C_x, C_u = (np.diag(np.array([0.2, 0.05, 1.0, 0.05])), np.diag(np.array([0.05])))
        # self._loss = jax.jit(lambda x, u: x.T @ C_x @ x + u.T @ C_u @ u) # MUST store as self._loss
        self._loss = lambda x, u: np.sum(
            (self.target - self._dynamics(x, u)[1])**2)  # scalar loss on the output y

        # stack the jacobians of environment dynamics gradient
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        self._dynamics_jacobian = jax.jit(
            lambda x, u: np.hstack(jacobian(x, u)))

        # stack the gradients of environment loss
        loss_grad = jax.grad(self._loss, argnums=(0, 1))
        self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

        # block the hessian of environment loss
        block_hessian = lambda A: np.vstack(
            [np.hstack([A[0][0], A[0][1]]),
             np.hstack([A[1][0], A[1][1]])])
        hessian = jax.hessian(self._loss, argnums=(0, 1))
        self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))

        def _rollout(act, dyn, x_0, T):
            def f(x, i):
                u = act(x)
                x_next = dyn(x, u)
                return x_next, np.hstack((x, u))

            _, trajectory = jax.lax.scan(f, x_0, np.arange(T))
            return trajectory

        self._rollout = jax.jit(_rollout, static_argnums=(0, 1, 3))

        # self._step = jax.jit(_step)
        # return np.dot(self.W_out, self.hid)
        return np.dot(self.W_out, self.hid_cell[:self.hid_dim])
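The block_hessian helper above assembles the 2x2 nested structure returned by jax.hessian(..., argnums=(0, 1)) into one square matrix. A tiny standalone sketch of that assembly; the quadratic loss and the argument sizes are made up for illustration:

import jax
import jax.numpy as np

loss = lambda x, u: np.sum(x ** 2) + np.sum(u ** 2) + np.sum(x) * np.sum(u)
H = jax.hessian(loss, argnums=(0, 1))(np.ones(3), np.ones(2))   # 2x2 tuple of blocks
block = np.vstack([np.hstack([H[0][0], H[0][1]]),
                   np.hstack([H[1][0], H[1][1]])])
assert block.shape == (5, 5)                                    # (3 + 2) x (3 + 2)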
Example #18
 def testGlorotInitShape(self, shape):
     key = random.PRNGKey(0)
     out = stax.glorot()(key, shape)
     self.assertEqual(out.shape, shape)
Example #19
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))
        self.W_u = glorot_init(generate_key(), (h, n))
        self.W_out = glorot_init(generate_key(), (m, h))
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)

        self.rollout_controller = None
        self.target = jax.random.uniform(generate_key(),
                                         shape=(self.m, ),
                                         minval=-1,
                                         maxval=1)
        '''
        def _step(x, hid):
            next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)'''
        def _dynamics(hid, u):
            next_hid = np.tanh(
                np.dot(self.W_h, hid) + np.dot(self.W_u, u) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)

        # self._step = jax.jit(_step)
        self._dynamics = jax.jit(_dynamics)

        self._loss = lambda x, u: np.sum(
            (self.target - self._dynamics(x, u)[1])**2)  # scalar loss on the output y

        # stack the jacobians of environment dynamics gradient
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        self._dynamics_jacobian = jax.jit(
            lambda x, u: np.hstack(jacobian(x, u)))

        # stack the gradients of environment loss
        loss_grad = jax.grad(self._loss, argnums=(0, 1))
        self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

        # block the hessian of environment loss
        block_hessian = lambda A: np.vstack(
            [np.hstack([A[0][0], A[0][1]]),
             np.hstack([A[1][0], A[1][1]])])
        hessian = jax.hessian(self._loss, argnums=(0, 1))
        self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))

        def _rollout(act, dyn, x_0, T):
            def f(x, i):
                u = act(x)
                x_next = dyn(x, u)
                return x_next, np.hstack((x, u))

            _, trajectory = jax.lax.scan(f, x_0, np.arange(T))
            return trajectory

        self._rollout = jax.jit(_rollout, static_argnums=(0, 1, 3))
        return np.dot(self.W_out, self.hid)
Example #20
    def initialize(self, n=1, m=1, l=32, h=64, optimizer=OGD):
        """
        Description: Randomly initialize the LSTM.
        Args:
            n (int): Observation/output dimension.
            m (int): Input action dimension.
            l (int): Length of memory for update step purposes.
            h (int): Default value 64. Hidden dimension of LSTM.
            optimizer (class): optimizer choice
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.l, self.h = n, m, l, h

        # initialize parameters
        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        W_hh = glorot_init(generate_key(), (4 * h, h))  # maps h_t to gates
        W_xh = glorot_init(generate_key(), (4 * h, n))  # maps x_t to gates
        W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        b_h = np.zeros(4 * h)
        b_h = jax.ops.index_update(
            b_h, jax.ops.index[h:2 * h],
            np.ones(h))  # forget gate biased initialization
        # self.params = [W_hh, W_xh, W_out, b_h]
        self.params = {'W_hh': W_hh, 'W_xh': W_xh, 'W_out': W_out, 'b_h': b_h}
        self.hid = np.zeros(h)
        self.cell = np.zeros(h)
        self.x = np.zeros((l, m))
        """ private helper methods"""

        #@jax.jit
        def _update_x(self_x, x):
            new_x = np.roll(self_x, -self.n)
            new_x = jax.ops.index_update(new_x, jax.ops.index[-1, :], x)
            return new_x

        @jax.jit
        def _fast_predict(carry, x):
            params, hid, cell = carry  # unroll tuple in carry
            W_hh, W_xh, W_out, b_h = params.values()
            sigmoid = lambda x: 1. / (1. + np.exp(
                -x))  # no JAX implementation of sigmoid it seems?
            gate = np.dot(W_hh, hid) + np.dot(W_xh, x) + b_h
            i, f, g, o = np.split(gate,
                                  4)  # order: input, forget, cell, output
            next_cell = sigmoid(f) * cell + sigmoid(i) * np.tanh(g)
            next_hid = sigmoid(o) * np.tanh(next_cell)
            y = np.dot(W_out, next_hid)
            return (params, next_hid, next_cell), y

        @jax.jit
        def _predict(params, x):
            _, y = jax.lax.scan(_fast_predict,
                                (params, np.zeros(h), np.zeros(h)), x)
            return y[-1]

        self.transform = lambda x: float(x) if (self.m == 1) else x
        self._update_x = _update_x
        self._fast_predict = _fast_predict
        self._predict = _predict
        self._store_optimizer(optimizer, self._predict)
Example #21
 def __init__(self, out_dim, W_init=stax.glorot(), b_init=stax.randn()):
     super(Dense, self).__init__()
     self.out_dim = out_dim
     self.W_init = W_init
     self.b_init = b_init