def test_rnn_with_cells(self):
    """Evaluate tdb.RNN over character embeddings with several cell setups.

    Covers GRU cells scoped by name and by an existing variable scope, an
    initial state passed both as a raw tensor and wrapped in
    tdb.FromTensor, and a BasicLSTMCell using the default initial state.
    """
    def char_embeddings():
        # Builds a new block on every call:
        # string -> code points -> per-character tdl.Embedding(128, 8).
        return (tdb.InputTransform(lambda text: list(map(ord, text)))
                >> tdb.Map(tdb.Scalar('int32')
                           >> tdb.Function(tdl.Embedding(128, 8))))

    first_gru = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), 'gru1')
    second_gru = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), 'gru2')
    with tf.variable_scope('gru3') as gru3_scope:
        # Scope the third cell by an existing variable scope, not a name.
        third_gru = tdl.ScopedLayer(
            tf.contrib.rnn.GRUCell(num_units=16), gru3_scope)
    lstm_layer = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=16), 'lstm')

    gru1 = char_embeddings() >> tdb.RNN(first_gru)
    gru2 = char_embeddings() >> tdb.RNN(
        second_gru, initial_state=tf.ones(16))
    gru3 = char_embeddings() >> tdb.RNN(
        third_gru, initial_state=tdb.FromTensor(tf.ones(16)))
    lstm = char_embeddings() >> tdb.RNN(lstm_layer)

    runs = ((gru1, 'abcde'), (gru2, 'ABCDE'), (gru3, 'vghj'),
            (lstm, '123abc'))
    with self.test_session():
        for block, text in runs:
            block.eval(text)
def test_embedding_raises(self):
    """tdl.Embedding rejects a mismatched table and ill-typed inputs."""
    # NOTE(review): the third positional argument is presumably the
    # initializer; a 3x3 array is rejected for a 2x2 embedding.
    self.assertRaises(ValueError, tdl.Embedding, 2, 2, np.zeros([3, 3]))
    bad_cases = (
        ('Embeddings take scalar inputs.', [[0, 0]], 'int32'),
        ('Embeddings take integer inputs.', [0], 'float32'),
    )
    for message, value, dtype in bad_cases:
        # A fresh layer per case, applied to a rank-2 or a float tensor.
        six.assertRaisesRegex(self, TypeError, message,
                              tdl.Embedding(2, 2), tf.constant(value, dtype))
def test_lstm_cell(self):
    """Build an LSTM cell by hand from a td.Composition and evaluate it.

    Mirrors the documentation example; the test only checks that the
    wiring compiles and runs on a short string.
    """
    td = tdb
    num_hidden = 32

    def fc_gate(activation, source):
        # One fully connected layer of width num_hidden reading `source`.
        return td.Function(tdl.FC(num_hidden, activation)).reads(source)

    cell = td.Composition()
    with cell.scope():
        # Input slot 1 is the recurrent state: element [0] feeds the cell
        # update, element [1] is concatenated with the current input.
        prev_state = td.Identity().reads(cell.input[1])
        x_h = td.Concat().reads(cell.input[0], prev_state[1])
        input_gate = fc_gate(tf.nn.sigmoid, x_h)
        forget_gate = fc_gate(tf.nn.sigmoid, x_h)
        output_gate = fc_gate(tf.nn.sigmoid, x_h)
        candidate = fc_gate(tf.nn.tanh, x_h)
        new_c = td.Function(lambda c, i, f, g: c * f + i * g).reads(
            prev_state[0], input_gate, forget_gate, candidate)
        new_h = td.Function(lambda c, o: tf.tanh(c) * o).reads(
            new_c, output_gate)
        next_state = td.Identity().reads(new_c, new_h)
        cell.output.reads(new_h, next_state)

    str_lstm = (td.InputTransform(lambda text: list(map(ord, text)))
                >> td.Map(td.Scalar('int32')
                          >> td.Function(tdl.Embedding(128, 16)))
                >> td.RNN(cell, initial_state=td.AllOf(
                    tf.zeros(num_hidden), tf.zeros(num_hidden))))
    with self.test_session():
        str_lstm.eval('The quick brown fox.')
def test_embedding(self):
    """Lookups honor the initializer table; out-of-range indices wrap.

    With the default settings, indices 7 and -5 are expected to land on
    row 1 of the 2-row table.
    """
    table = np.array([[1, 2], [3, 4]], dtype='float32')
    layer = tdl.Embedding(2, 2, initializer=table)
    indices = [0, 1, 7, -5]
    expected = [[[1, 2]], [[3, 4]], [[3, 4]], [[3, 4]]]
    with self.test_session() as sess:
        lookups = [layer(tf.constant([idx])) for idx in indices]
        sess.run(tf.global_variables_initializer())
        self.assertAllEqual(expected, sess.run(lookups))
def test_embedding_initializer(self):
    """A TF initializer object fills the table: indices 0, 1 and 7 give ones."""
    layer = tdl.Embedding(2, 2, initializer=tf.constant_initializer(1.0))
    with self.test_session() as sess:
        lookups = [layer(tf.constant([idx])) for idx in (0, 1, 7)]
        sess.run(tf.global_variables_initializer())
        self.assertAllEqual([[[1, 1]]] * 3, sess.run(lookups))
def test_embedding_nomod(self):
    """With mod_inputs=False, in-range lookups work and index 2 fails."""
    table = np.array([[1, 2], [3, 4]], dtype='float32')
    layer = tdl.Embedding(2, 2, initializer=table, mod_inputs=False)
    with self.test_session() as sess:
        lookups = [layer(tf.constant([idx])) for idx in (0, 1)]
        sess.run(tf.global_variables_initializer())
        self.assertAllEqual([[[1, 2]], [[3, 4]]], sess.run(lookups))
        out_of_range = layer(tf.constant([2]))
        # The API doesn't specify what tf.gather() throws, so catch broadly.
        self.assertRaises(Exception, sess.run, out_of_range)
def test_record_doc_example(self):
    """Documentation example: encode a dict record with td.Record.

    Only checks that the example compiles and evaluates on one datum.
    """
    td = tdb
    example_datum = {'id': 8, 'name': 'Joe Smith', 'location': (2.5, 7.0)}
    num_ids = 16
    embed_len = 16
    # Embed each character, then fold over the embeddings: the running
    # 32-d state (seeded with zeros) is concatenated with each embedding
    # and passed through an FC(32) layer.
    char_rnn = (td.InputTransform(lambda text: list(map(ord, text)))
                >> td.Map(td.Scalar('int32')
                          >> td.Function(tdl.Embedding(128, 16)))
                >> td.Fold(td.Concat() >> td.Function(tdl.FC(32)),
                           td.FromTensor(tf.zeros(32))))
    id_block = (td.Scalar('int32')
                >> td.Function(tdl.Embedding(num_ids, embed_len)))
    record = (td.Record([('id', id_block),
                         ('name', char_rnn),
                         ('location', td.Vector(2))])
              >> td.Concat()
              >> td.Function(tdl.FC(256)))
    with self.test_session():
        record.eval(example_datum)
def test_hierarchical_rnn(self):
    """Feed a per-word character LSTM into a word-level LSTM and run it."""
    char_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=16), 'char_cell')
    word_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=32), 'word_cell')
    char_lstm = (tdb.InputTransform(lambda word: list(map(ord, word)))
                 >> tdb.Map(tdb.Scalar('int32')
                            >> tdb.Function(tdl.Embedding(128, 8)))
                 >> tdb.RNN(char_cell))
    # NOTE(review): item 1 of the RNN output is presumably the final LSTM
    # state; Concat joins its parts into a single vector per word.
    word_vector = char_lstm >> tdb.GetItem(1) >> tdb.Concat()
    word_lstm = tdb.Map(word_vector) >> tdb.RNN(word_cell)
    with self.test_session():
        word_lstm.eval('the cat sat on a mat'.split())