Code example #1
0
    def initialize(self, n, m, h=64):
        """
        Description:
            Randomly initialize the RNN (LSTM cell).
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series (W_out applied to the
            zero-initialized hidden state, i.e. a zero vector of length m).
        """

        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_hh = glorot_init(generate_key(),
                                (4 * h, h))  # maps h_t to gates
        self.W_xh = glorot_init(generate_key(),
                                (4 * h, n))  # maps x_t to gates
        # Forget-gate bias initialized to 1 (standard LSTM practice). The
        # original called jax.ops.index_update and discarded its return value;
        # JAX arrays are immutable, so the bias was silently left all-zero.
        # Building the vector directly fixes that (and avoids the since-removed
        # jax.ops API). Gate layout assumed: [input, forget, cell, output],
        # h entries each — forget occupies slice [h:2*h].
        self.b_h = np.concatenate([np.zeros(h), np.ones(h), np.zeros(2 * h)])
        self.W_out = glorot_init(generate_key(), (m, h))  # maps h_t to output
        self.cell = np.zeros(h)  # long-term memory
        self.hid = np.zeros(h)  # short-term memory
        self.sigmoid = lambda x: 1. / (1. + np.exp(
            -x))  # no JAX implementation of sigmoid it seems?
        return np.dot(self.W_out, self.hid)
Code example #2
0
File: test_random.py  Project: paula-gradu/ctsb
def test_random():
    """Seeding twice with the same value must reproduce both the global key
    and the next key drawn from it (string comparison, since JAX keys are
    arrays)."""
    set_key(5)
    global_before, drawn_before = get_global_key(), generate_key()
    set_key(5)
    global_after, drawn_after = get_global_key(), generate_key()
    assert str(global_before) == str(global_after)
    assert str(drawn_before) == str(drawn_after)
    print("test_random passed")
Code example #3
0
File: test_lstm_output.py  Project: paula-gradu/ctsb
def test_lstm(steps=100, show_plot=False, verbose=False):
    """Drive an LSTM_Output problem for `steps` timesteps with random inputs
    and check that its internal time counter tracks the number of steps."""
    T = steps
    n, m, l, h = 5, 3, 5, 10
    problem = LSTM_Output()
    problem.initialize(n, m, l, h)
    assert problem.T == 0

    test_output = []
    for t in range(1, T + 1):
        # progress report at every tenth of the run
        if verbose and t * 10 % T == 0:
            print("{} timesteps".format(t))
        test_output.append(
            problem.step(random.normal(generate_key(), shape=(n,))))

    info = problem.hidden()
    if verbose:
        print(info)

    assert problem.T == T
    if show_plot:
        plt.plot(test_output)
        plt.title("lstm")
        plt.show(block=False)
        plt.pause(1)
        plt.close()
    print("test_lstm passed")
    return
Code example #4
0
File: test_lds.py  Project: paula-gradu/ctsb
def test_lds(steps=100, show_plot=False, verbose=False):
    """
    Description:
        Drive an LDS problem for `steps` timesteps with random inputs and
        check that its internal time counter tracks the number of steps.
    Args:
        steps (int): Number of timesteps to simulate.
        show_plot (bool): If True, briefly display a plot of the outputs.
        verbose (bool): If True, print the problem's hidden info.
            NOTE: new parameter — the body referenced `verbose`, which was
            never defined, so every call raised NameError after the loop.
            Adding it (default False) matches test_lstm's signature and is
            backward-compatible.
    """
    T = steps
    n, m, d = 5, 3, 10
    problem = LDS()
    problem.initialize(n, m, d)
    assert problem.T == 0

    test_output = []
    for t in range(T):
        u = random.normal(generate_key(), shape=(n, ))
        test_output.append(problem.step(u))

    info = problem.hidden()
    if verbose:
        print(info)

    assert problem.T == T
    if show_plot:
        plt.plot(test_output)
        plt.title("lds")
        plt.show(block=False)
        plt.pause(1)
        plt.close()
    print("test_lds passed")
    return
Code example #5
0
 def step(self):
     """
     Description:
         Advance the system dynamics by one timestep.
     Args:
         None
     Returns:
         The next value in the time-series (a fresh standard-normal draw).
     """
     assert self.initialized
     self.T = self.T + 1
     return random.normal(generate_key())
Code example #6
0
File: rnn_output.py  Project: paula-gradu/ctsb
    def initialize(self, n, m, h=64):
        """
        Description:
            Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series (W_out applied to the
            zero-initialized hidden state).
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        # stax.glorot() yields an initializer fn: (key, shape) -> weights
        init_weights = stax.glorot()
        self.W_h = init_weights(generate_key(), (h, h))    # hidden -> hidden
        self.W_x = init_weights(generate_key(), (h, n))    # input -> hidden
        self.W_out = init_weights(generate_key(), (m, h))  # hidden -> output
        self.b_h = np.zeros(h)  # hidden bias
        self.hid = np.zeros(h)  # hidden state, starts at the origin
        return np.dot(self.W_out, self.hid)