コード例 #1
0
ファイル: blocks_test.py プロジェクト: yalechang/fold
  def test_rnn_with_cells(self):
    """Evaluates character-level RNN blocks built around several cells.

    Four pipelines share the same front end (string -> list of ord() codes ->
    per-element Embedding(128, 8)) and differ only in the RNN stage:
      * a GRU cell scoped by the name 'gru1', default initial state;
      * a GRU cell scoped by the name 'gru2', initial state given as a tensor;
      * a GRU cell scoped by an explicit variable scope, initial state given
        as a FromTensor block;
      * a BasicLSTMCell scoped by the name 'lstm'.
    The test only requires that each pipeline evaluates without raising.
    """
    def embedded_chars():
      # A fresh front end (and therefore a fresh Embedding layer) is built on
      # every call, exactly as when the expression is written out inline.
      to_codes = tdb.InputTransform(lambda text: [ord(ch) for ch in text])
      embed_each = tdb.Map(
          tdb.Scalar('int32') >> tdb.Function(tdl.Embedding(128, 8)))
      return to_codes >> embed_each

    # All cells are created up front, before any pipeline is assembled.
    named_gru_a = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), 'gru1')
    named_gru_b = tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=16), 'gru2')
    with tf.variable_scope('gru3') as gru3_scope:
      scoped_gru = tdl.ScopedLayer(
          tf.contrib.rnn.GRUCell(num_units=16), gru3_scope)
    lstm_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=16), 'lstm')

    default_state_gru = embedded_chars() >> tdb.RNN(named_gru_a)
    tensor_state_gru = embedded_chars() >> tdb.RNN(
        named_gru_b, initial_state=tf.ones(16))
    block_state_gru = embedded_chars() >> tdb.RNN(
        scoped_gru, initial_state=tdb.FromTensor(tf.ones(16)))
    lstm = embedded_chars() >> tdb.RNN(lstm_cell)

    cases = (
        (default_state_gru, 'abcde'),
        (tensor_state_gru, 'ABCDE'),
        (block_state_gru, 'vghj'),
        (lstm, '123abc'),
    )
    with self.test_session():
      for pipeline, text in cases:
        pipeline.eval(text)
コード例 #2
0
ファイル: blocks_test.py プロジェクト: yalechang/fold
  def test_hierarchical_rnn(self):
    """Evaluates a two-level LSTM: one over characters, one over words.

    Each word is run through a character-level LSTM; item 1 of that result is
    concatenated and fed, word by word, to a second LSTM. The test only
    requires that evaluation on a short word list succeeds.
    """
    char_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=16), 'char_cell')
    word_cell = tdl.ScopedLayer(
        tf.contrib.rnn.BasicLSTMCell(num_units=32), 'word_cell')

    # Character level: word -> list of ord() codes -> embeddings -> RNN.
    to_codes = tdb.InputTransform(lambda word: [ord(ch) for ch in word])
    embed_each = tdb.Map(
        tdb.Scalar('int32') >> tdb.Function(tdl.Embedding(128, 8)))
    char_lstm = to_codes >> embed_each >> tdb.RNN(char_cell)

    # Word level: summarise every word with char_lstm, then run the RNN over
    # the resulting sequence.
    word_summary = char_lstm >> tdb.GetItem(1) >> tdb.Concat()
    word_lstm = tdb.Map(word_summary) >> tdb.RNN(word_cell)

    sentence = ['the', 'cat', 'sat', 'on', 'a', 'mat']
    with self.test_session():
      word_lstm.eval(sentence)
コード例 #3
0
 def test_compiler_input_tensor(self):
     """Checks that a compiler fed from a string variable tracks its value.

     The root block maps each input to its len(). The loom is initialised
     with a tf.Variable as its input tensor; the output is checked once for
     the variable's initial contents and again after assigning new contents.
     """
     strings = tf.Variable(['foobar', 'baz'],
                           dtype=tf.string,
                           name='input_variable')
     # Created after the variable so that the initializer covers it.
     initializer = tf.global_variables_initializer()

     length_block = tdb.InputTransform(len) >> tdb.Scalar()
     compiler = tdc.Compiler()
     compiler.compile(length_block)
     compiler.init_loom(max_depth=1, input_tensor=strings)
     (lengths,) = compiler.output_tensors

     def check_lengths(sess, expected):
         # Re-runs the output tensor and compares it element by element.
         results = sess.run(lengths)
         self.assertEqual(len(results), 2)
         for position, want in enumerate(expected):
             self.assertEqual(results[position], want)

     with self.test_session() as sess:
         sess.run(initializer)
         check_lengths(sess, (6., 3.))
         sess.run(strings.assign(['foo', 'blah']))
         check_lengths(sess, (3., 4.))
コード例 #4
0
ファイル: blocks_test.py プロジェクト: yalechang/fold
 def test_eval_py_object(self):
   """Checks InputTransform(str) against input 42 and expected value '42'."""
   self.assertBuildsConst('42', tdb.InputTransform(str), 42)
コード例 #5
0
ファイル: blocks_test.py プロジェクト: yalechang/fold
 def test_input_transform_const(self):
   """Maps letters to 1-based alphabet positions and checks the int32 result.

   Uses assertBuildsConst with 'abcd' as input and [1, 2, 3, 4] as the
   expected value.
   """
   alphabet_position = tdb.InputTransform(
       lambda letter: 1 + ord(letter) - ord('a'))
   per_letter = alphabet_position >> tdb.Scalar('int32')
   block = tdb.Map(per_letter)
   self.assertBuildsConst([1, 2, 3, 4], block, 'abcd')
コード例 #6
0
ファイル: blocks_test.py プロジェクト: yalechang/fold
 def test_input_transform(self):
   """Maps letters to negated 1-based alphabet positions.

   Each letter becomes an int32 scalar holding its alphabet position, which
   is then passed through tf.negative; 'abcd' is expected to give
   [-1, -2, -3, -4].
   """
   alphabet_position = tdb.InputTransform(
       lambda letter: 1 + ord(letter) - ord('a'))
   as_scalar = alphabet_position >> tdb.Scalar('int32')
   negated = as_scalar >> tdb.Function(tf.negative)
   block = tdb.Map(negated)
   self.assertBuilds([-1, -2, -3, -4], block, 'abcd')
コード例 #7
0
ファイル: blocks_test.py プロジェクト: yalechang/fold
  def test_repr(self):
    """Checks repr() of many block types against golden strings.

    `goldens` maps a freshly constructed block (the key) to the exact string
    its repr() is expected to return (the value). Every entry is then
    compared with assertEqual.
    """
    # NOTE(review): several goldens embed tensor names such as 'zeros:0' and
    # 'ones:0'. These presumably depend on the order in which the tf ops are
    # created within the graph, so avoid reordering entries or adding earlier
    # tf.zeros / tf.ones calls without confirming the names still match.
    goldens = {
        # Tensor-like input blocks.
        tdb.Tensor([]): '<td.Tensor dtype=\'float32\' shape=()>',
        tdb.Tensor([1, 2], 'int32', name='foo'):
        '<td.Tensor \'foo\' dtype=\'int32\' shape=(1, 2)>',

        tdb.Scalar('int64'): '<td.Scalar dtype=\'int64\'>',

        tdb.Vector(42): '<td.Vector dtype=\'float32\' size=42>',

        tdb.FromTensor(tf.zeros(3)): '<td.FromTensor \'zeros:0\'>',

        # Function / identity / Python-side transforms.
        tdb.Function(tf.negative,
                     name='foo'): '<td.Function \'foo\' tf_fn=\'negative\'>',

        tdb.Identity(): '<td.Identity>',
        tdb.Identity('foo'): '<td.Identity \'foo\'>',

        tdb.InputTransform(ord): '<td.InputTransform py_fn=\'ord\'>',

        tdb.SerializedMessageToTree('foo'):
        '<td.SerializedMessageToTree \'foo\' '
        'py_fn=\'serialized_message_to_tree\'>',

        tdb.GetItem(3, 'mu'): '<td.GetItem \'mu\' key=3>',

        tdb.Length(): '<td.Length dtype=\'float32\'>',

        tdb.Slice(stop=2): '<td.Slice key=slice(None, 2, None)>',
        tdb.Slice(stop=2, name='x'):
        '<td.Slice \'x\' key=slice(None, 2, None)>',

        # The key here is the result of calling the ForwardDeclaration.
        tdb.ForwardDeclaration(name='foo')():
        '<td.ForwardDeclaration() \'foo\'>',

        # Composite blocks.
        tdb.Composition(name='x').input: '<td.Composition.input \'x\'>',
        tdb.Composition(name='x').output: '<td.Composition.output \'x\'>',
        tdb.Composition(name='x'): '<td.Composition \'x\'>',

        tdb.Pipe(): '<td.Pipe>',
        tdb.Pipe(tdb.Scalar(), tdb.Identity()): '<td.Pipe>',

        # A dict argument is expected to report ordered=False, a tuple
        # ordered=True.
        tdb.Record({}, name='x'): '<td.Record \'x\' ordered=False>',
        tdb.Record((), name='x'): '<td.Record \'x\' ordered=True>',

        tdb.AllOf(): '<td.AllOf>',
        tdb.AllOf(tdb.Identity()): '<td.AllOf>',
        tdb.AllOf(tdb.Identity(), tdb.Identity()): '<td.AllOf>',

        tdb.AllOf(name='x'): '<td.AllOf \'x\'>',
        tdb.AllOf(tdb.Identity(), name='x'): '<td.AllOf \'x\'>',
        tdb.AllOf(tdb.Identity(), tdb.Identity(), name='x'): '<td.AllOf \'x\'>',

        # Sequence blocks.
        tdb.Map(tdb.Scalar(), name='x'):
        '<td.Map \'x\' element_block=<td.Scalar dtype=\'float32\'>>',

        tdb.Fold(tdb.Function(tf.add), tf.ones([]), name='x'):
        '<td.Fold \'x\' combine_block=<td.Function tf_fn=\'add\'> '
        'start_block=<td.FromTensor \'ones:0\'>>',

        # The expected RNN repr shows only the optional name, with or without
        # an initial_state argument.
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8))):
        '<td.RNN>',
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8)), name='x'):
        '<td.RNN \'x\'>',
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8)),
                initial_state=tf.ones(8)):
        '<td.RNN>',
        tdb.RNN(tdl.ScopedLayer(tf.contrib.rnn.GRUCell(num_units=8)),
                initial_state=tf.ones(8), name='x'):
        '<td.RNN \'x\'>',

        # Reductions.
        tdb.Reduce(tdb.Function(tf.add), name='x'):
        '<td.Reduce \'x\' combine_block=<td.Function tf_fn=\'add\'>>',

        tdb.Sum(name='foo'):
        '<td.Sum \'foo\' combine_block=<td.Function tf_fn=\'add\'>>',

        tdb.Min(name='foo'):
        '<td.Min \'foo\' combine_block=<td.Function tf_fn=\'minimum\'>>',

        tdb.Max(name='foo'):
        '<td.Max \'foo\' combine_block=<td.Function tf_fn=\'maximum\'>>',

        tdb.Mean(name='foo'): '<td.Mean \'foo\'>',

        # Conditional blocks.
        tdb.OneOf(ord, (tdb.Scalar(), tdb.Scalar()), name='x'):
        '<td.OneOf \'x\'>',

        tdb.Optional(tdb.Scalar(), name='foo'):
        '<td.Optional \'foo\' some_case_block=<td.Scalar dtype=\'float32\'>>',

        # Remaining utility blocks.
        tdb.Concat(1, True, 'x'):
        '<td.Concat \'x\' concat_dim=1 flatten=True>',

        tdb.Broadcast(name='x'): '<td.Broadcast \'x\'>',

        tdb.Zip(name='x'): '<td.Zip \'x\'>',

        tdb.NGrams(n=42, name='x'): '<td.NGrams \'x\' n=42>',

        tdb.OneHot(2, 3, name='x'):
        '<td.OneHot \'x\' dtype=\'float32\' start=2 stop=3>',
        tdb.OneHot(3): '<td.OneHot dtype=\'float32\' start=0 stop=3>',

        tdb.OneHotFromList(['a', 'b']): '<td.OneHotFromList>',
        tdb.OneHotFromList(['a', 'b'], name='foo'):
        '<td.OneHotFromList \'foo\'>',

        tdb.Nth(name='x'): '<td.Nth \'x\'>',

        tdb.Zeros([], 'x'): '<td.Zeros \'x\'>',

        tdb.Void(): '<td.Void>',
        tdb.Void('foo'): '<td.Void \'foo\'>',

        tdm.Metric('foo'): '<td.Metric \'foo\'>'}
    # Iterate in order of the expected string rather than in dict order, so
    # entries are checked in the same sequence on every run.
    for block, expected_repr in sorted(six.iteritems(goldens),
                                       key=lambda kv: kv[1]):
      self.assertEqual(repr(block), expected_repr)