Example #1
0
  def test_rnn_with_cells(self):
    """tdb.RNN runs with GRU and LSTM cells wrapped in tdl.ScopedLayer.

    Exercises a ScopedLayer whose scope is given by name and one given as a
    variable-scope object, plus RNNs with no initial_state argument, with a
    raw tensor as initial_state, and with a tdb.FromTensor block as
    initial_state.
    """
    gru_cell1 = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), 'gru1')
    gru_cell2 = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), 'gru2')

    # ScopedLayer also accepts an existing variable scope object.
    with tf.variable_scope('gru3') as vscope:
      gru_cell3 = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), vscope)

    lstm_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=16), 'lstm')

    def embedded_chars():
      # Builds a fresh string -> per-character embedding pipeline. Each call
      # creates its own (unshared) Embedding layer, as the four copies of
      # this expression previously did.
      return (tdb.InputTransform(lambda s: [ord(c) for c in s]) >>
              tdb.Map(tdb.Scalar('int32') >>
                      tdb.Function(tdl.Embedding(128, 8))))

    gru1 = embedded_chars() >> tdb.RNN(gru_cell1)
    # Initial state supplied as a plain tensor...
    gru2 = embedded_chars() >> tdb.RNN(gru_cell2, initial_state=tf.ones(16))
    # ...and as a block wrapping a tensor.
    gru3 = embedded_chars() >> tdb.RNN(
        gru_cell3, initial_state=tdb.FromTensor(tf.ones(16)))
    lstm = embedded_chars() >> tdb.RNN(lstm_cell)

    with self.test_session():
      gru1.eval('abcde')
      gru2.eval('ABCDE')
      gru3.eval('vghj')
      lstm.eval('123abc')
Example #2
0
 def test_embedding_raises(self):
     """Embedding rejects a mis-shaped initializer and bad input tensors."""
     # A (3, 3) array, passed positionally (presumably as the initializer),
     # does not fit a 2 x 2 embedding table.
     with self.assertRaises(ValueError):
         tdl.Embedding(2, 2, np.zeros([3, 3]))

     # A rank-2 id tensor is rejected when the layer is applied.
     scalar_only = tdl.Embedding(2, 2)
     matrix_ids = tf.constant([[0, 0]], 'int32')
     with six.assertRaisesRegex(self, TypeError,
                                'Embeddings take scalar inputs.'):
         scalar_only(matrix_ids)

     # So is a float id tensor.
     int_only = tdl.Embedding(2, 2)
     float_ids = tf.constant([0], 'float32')
     with six.assertRaisesRegex(self, TypeError,
                                'Embeddings take integer inputs.'):
         int_only(float_ids)
Example #3
0
  def test_lstm_cell(self):
    """Builds an LSTM cell by hand from a Composition and runs it in an RNN.

    Guards the documentation example: it must compile and evaluate.
    """
    td = tdb  # The documentation example refers to the blocks module as `td`.
    num_hidden = 32

    # Create an LSTM cell, the hard way. Input 0 is the current element and
    # input 1 is the (c, h) state; the output is (h, (c, h)).
    lstm_cell = td.Composition()
    with lstm_cell.scope():
      in_state = td.Identity().reads(lstm_cell.input[1])
      # Gates see the current element concatenated with the previous h.
      bx = td.Concat().reads(lstm_cell.input[0], in_state[1])
      bi = td.Function(tdl.FC(num_hidden, tf.nn.sigmoid)).reads(bx)  # input
      bf = td.Function(tdl.FC(num_hidden, tf.nn.sigmoid)).reads(bx)  # forget
      bo = td.Function(tdl.FC(num_hidden, tf.nn.sigmoid)).reads(bx)  # output
      bg = td.Function(tdl.FC(num_hidden, tf.nn.tanh)).reads(bx)  # candidate
      # New cell state c' = c * f + i * g.
      bc = td.Function(lambda c, i, f, g: c*f + i*g).reads(
          in_state[0], bi, bf, bg)
      # New hidden output h' = tanh(c') * o.
      by = td.Function(lambda c, o: tf.tanh(c) * o).reads(bc, bo)
      out_state = td.Identity().reads(bc, by)
      lstm_cell.output.reads(by, out_state)

    # The zero (c, h) initial state is sized from num_hidden so it always
    # matches the cell's FC layers (previously a hard-coded 32).
    str_lstm = (td.InputTransform(lambda s: [ord(c) for c in s]) >>
                td.Map(td.Scalar('int32') >>
                       td.Function(tdl.Embedding(128, 16))) >>
                td.RNN(lstm_cell,
                       initial_state=td.AllOf(tf.zeros(num_hidden),
                                              tf.zeros(num_hidden))))

    with self.test_session():
      str_lstm.eval('The quick brown fox.')
Example #4
0
 def test_embedding(self):
     """Embedding looks rows up from an explicit weight array.

     Ids outside [0, 2) are expected to land on row 1 (7 -> row 1,
     -5 -> row 1), consistent with ids being reduced modulo the two buckets.
     """
     table = np.array([[1, 2], [3, 4]], dtype='float32')
     layer = tdl.Embedding(2, 2, initializer=table)
     ids = [0, 1, 7, -5]
     expected = [[[1, 2]], [[3, 4]], [[3, 4]], [[3, 4]]]
     with self.test_session() as sess:
         lookups = [layer(tf.constant([i])) for i in ids]
         sess.run(tf.global_variables_initializer())
         self.assertAllEqual(expected, sess.run(lookups))
Example #5
0
 def test_embedding_initializer(self):
     """Embedding accepts a TensorFlow initializer instead of a weight array."""
     fill_ones = tf.constant_initializer(1.0)
     layer = tdl.Embedding(2, 2, initializer=fill_ones)
     with self.test_session() as sess:
         lookups = [layer(tf.constant([i])) for i in (0, 1, 7)]
         sess.run(tf.global_variables_initializer())
         # Every row is all ones, so each id (including 7) yields [1, 1].
         self.assertAllEqual([[[1, 1]]] * 3, sess.run(lookups))
Example #6
0
 def test_embedding_nomod(self):
     """With mod_inputs=False, in-range ids work and out-of-range ids fail."""
     table = np.array([[1, 2], [3, 4]], dtype='float32')
     layer = tdl.Embedding(2, 2, initializer=table, mod_inputs=False)
     with self.test_session() as sess:
         lookups = [layer(tf.constant([i])) for i in (0, 1)]
         sess.run(tf.global_variables_initializer())
         self.assertAllEqual([[[1, 2]], [[3, 4]]], sess.run(lookups))
         # Id 2 is past the last row. The API doesn't specify what
         # tf.gather() throws, so any Exception is accepted.
         out_of_range = layer(tf.constant([2]))
         with self.assertRaises(Exception):
             sess.run(out_of_range)
Example #7
0
 def test_record_doc_example(self):
   """Compiles and evaluates the td.Record example from the documentation."""
   td = tdb  # The documentation example refers to the blocks module as `td`.
   num_ids = 16
   embed_len = 16
   example_datum = {'id': 8,
                    'name': 'Joe Smith',
                    'location': (2.5, 7.0)}

   # Character-level encoder: embed each character, then fold the sequence
   # through a fully connected layer starting from a zero tensor.
   to_codes = td.InputTransform(lambda s: [ord(c) for c in s])
   embed_chars = td.Map(td.Scalar('int32') >>
                        td.Function(tdl.Embedding(128, 16)))
   fold_chars = td.Fold(td.Concat() >> td.Function(tdl.FC(32)),
                        td.FromTensor(tf.zeros(32)))
   char_rnn = to_codes >> embed_chars >> fold_chars

   # One sub-block per record field; their outputs are concatenated and
   # passed through a final fully connected layer.
   id_embedding = td.Scalar('int32') >> td.Function(
       tdl.Embedding(num_ids, embed_len))
   fields = [('id', id_embedding),
             ('name', char_rnn),
             ('location', td.Vector(2))]
   record_block = td.Record(fields) >> td.Concat() >> td.Function(tdl.FC(256))

   with self.test_session():
     record_block.eval(example_datum)
Example #8
0
  def test_hierarchical_rnn(self):
    """Runs a word-level LSTM over vectors produced by a char-level LSTM."""
    char_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=16), 'char_cell')
    word_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=32), 'word_cell')

    # Per word: characters -> embeddings -> character LSTM.
    char_ids = tdb.InputTransform(lambda s: [ord(c) for c in s])
    char_embeddings = tdb.Map(tdb.Scalar('int32') >>
                              tdb.Function(tdl.Embedding(128, 8)))
    char_lstm = char_ids >> char_embeddings >> tdb.RNN(char_cell)

    # Item 1 of the char RNN's output (presumably its final state) is
    # concatenated into a single vector per word; the word LSTM then runs
    # over the sequence of those vectors.
    word_vector = char_lstm >> tdb.GetItem(1) >> tdb.Concat()
    word_lstm = tdb.Map(word_vector) >> tdb.RNN(word_cell)

    with self.test_session():
      word_lstm.eval(['the', 'cat', 'sat', 'on', 'a', 'mat'])