コード例 #1
0
 def test_simple_rnn(self):
     """Check that one ``N.RNN`` op gives consistent outputs across graphs.

     The same op instance is applied to several placeholders and the
     compiled functions are evaluated on one fixed random batch:

     * applied to ``X`` (with an all-ones mask) -> reference output;
     * applied to a different placeholder ``X1`` -> must match the reference;
     * pickled, unpickled and applied to ``X2`` -> must match as well;
     * applied to ``X3`` whose last dimension is 33 instead of 32 -> some
       step of build / compile / run is expected to raise.
     """
     np.random.seed(12082518)  # fixed seed so the input batch is reproducible
     x = np.random.rand(128, 8, 32)
     #
     X = K.placeholder(shape=(None, 8, 32))
     X1 = K.placeholder(shape=(None, 8, 32))
     X2 = K.placeholder(shape=(None, 8, 32))
     X3 = K.placeholder(shape=(None, 8, 33))  # deliberately mismatched last dim
     f = N.RNN(32, activation=K.relu, input_mode='skip')
     # ====== reference output ====== #
     y = f(X, mask=K.ones(shape=(128, 8)))
     graph = K.ComputationGraph(y)
     # the graph built from y should expose exactly one input
     self.assertEqual(len(graph.inputs), 1)
     f1 = K.function([X], y)
     x1 = f1(x)
     # ====== different placeholder ====== #
     y = f(X1)
     f2 = K.function([X1], y)
     # Evaluate the function compiled for X1; calling f1 again here would
     # compare the reference output with itself and always pass.
     x2 = f2(x)
     self.assertTrue(np.array_equal(x1[0], x2[0]))
     # ====== pickle load ====== #
     # NOTE: pickle round-trip of a locally created object only (trusted data).
     f = cPickle.loads(cPickle.dumps(f))
     y = f(X2)
     f3 = K.function([X2], y)
     x3 = f3(x)
     self.assertTrue(np.array_equal(x2[0], x3[0]))
     # ====== other input shape ====== #
     # The exact exception type is not pinned down by this test, so any
     # Exception raised while building, compiling or running is accepted.
     with self.assertRaises(Exception):
         y = f(X3)
         f4 = K.function([X3], y)
         f4(np.random.rand(128, 8, 33))
コード例 #2
0
ファイル: compare_test.py プロジェクト: liqin123/odin
 def odin_net3():
     """RNN

     Build a Dense(32, linear) -> RNN(32, relu) sequence with explicitly
     initialised weights, apply it to the module-level ``X1`` with a
     ``zeros(1, 32)`` initial hidden state, and return ``(X1, output)``.
     """
     # Weights are drawn in a fixed order: dense kernel, recurrent kernel,
     # dense bias, then one extra binary draw.
     dense_kernel = random(28, 32)
     recurrent_kernel = random(32, 32)
     dense_bias = random(32)
     # Result is not used by this network; the call is kept so that any
     # effect it has on shared state stays the same.
     random_bin(12, 28)

     dense_layer = N.Dense(
         num_units=32,
         W_init=dense_kernel,
         b_init=dense_bias,
         activation=K.linear,
     )
     recurrent_layer = N.RNN(
         num_units=32,
         activation=K.relu,
         W_init=recurrent_kernel,
     )
     network = N.Sequence([dense_layer, recurrent_layer])

     initial_hidden = zeros(1, 32)
     return X1, network(X1, hid_init=initial_hidden)