Esempio n. 1
0
 def testSequentialEncoder(self, transition_layer_fn):
   """Checks a SequentialEncoder built from two DenseEncoder stages.

   The encoder is invoked directly (``encoder(inputs)``) on an all-zero batch.
   The test asserts that four states come back and that the output tensor has
   shape ``[3, 5, 20]``.

   Args:
     transition_layer_fn: Passed through unchanged as the
       ``transition_layer_fn`` argument of ``SequentialEncoder``. It is
       presumably supplied by a parameterized-test decorator outside this view.
   """
   zeros_batch = tf.zeros([3, 5, 10])
   stages = [DenseEncoder(1, 20), DenseEncoder(3, 20)]
   seq_encoder = encoders.SequentialEncoder(
       stages, transition_layer_fn=transition_layer_fn)
   encoded, encoder_states, _ = seq_encoder(zeros_batch)
   # NOTE(review): 4 looks like 1 + 3, i.e. one state per dense layer if the
   # first DenseEncoder argument is the layer count -- confirm.
   self.assertEqual(len(encoder_states), 4)
   encoded_value = self.evaluate(encoded)
   self.assertAllEqual(encoded_value.shape, [3, 5, 20])
Esempio n. 2
0
 def testSequentialEncoder(self):
   """Checks the sequence lengths reported by a unidirectional + pyramidal stack.

   Dummy sequences of lengths 17, 21 and 20 go through a SequentialEncoder.
   The third value returned by ``encode`` is evaluated in a session and must
   equal ``[4, 5, 5]``.
   """
   lengths = [17, 21, 20]
   dummy_inputs = _build_dummy_sequences(lengths)
   stacked_encoders = [
       encoders.UnidirectionalRNNEncoder(1, 20),
       encoders.PyramidalRNNEncoder(3, 10, reduction_factor=2),
   ]
   seq_encoder = encoders.SequentialEncoder(stacked_encoders)
   _, _, reduced_lengths = seq_encoder.encode(
       dummy_inputs, sequence_length=lengths)
   with self.test_session() as session:
     session.run(tf.global_variables_initializer())
     # NOTE(review): the expected values match floor(length / 4), which
     # presumably reflects the pyramidal encoder's reductions -- confirm.
     length_values = session.run(reduced_lengths)
     self.assertAllEqual([4, 5, 5], length_values)
Esempio n. 3
0
 def testSequentialEncoder(self, transition_layer_fn):
   """Checks ``SequentialEncoder.encode`` over two DenseEncoder stages.

   Works in both TF1 (graph) and TF2 modes. In TF1 the global variables are
   initialized inside a test session before the outputs are evaluated.

   Args:
     transition_layer_fn: Passed through unchanged as the
       ``transition_layer_fn`` argument of ``SequentialEncoder``. It is
       presumably supplied by a parameterized-test decorator outside this view.
   """
   zeros_batch = tf.zeros([3, 5, 10])
   seq_encoder = encoders.SequentialEncoder(
       [DenseEncoder(1, 20), DenseEncoder(3, 20)],
       transition_layer_fn=transition_layer_fn)
   encoded, encoder_states, _ = seq_encoder.encode(zeros_batch)
   self.assertEqual(len(encoder_states), 4)
   graph_mode = not compat.is_tf2()
   if graph_mode:
     # NOTE(review): the later self.evaluate runs outside this `with` block.
     # It presumably relies on test_session reusing a cached session, so the
     # initialized variables are still visible -- verify.
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
   encoded_value = self.evaluate(encoded)
   self.assertAllEqual(encoded_value.shape, [3, 5, 20])
Esempio n. 4
0
 def testSequentialEncoder(self):
     """Checks states and lengths from a unidirectional + pyramidal RNN stack.

     Asserts that ``encode`` returns exactly four states, each an
     ``LSTMStateTuple``, and that the evaluated lengths equal ``[4, 5, 5]``.
     """
     lengths = [17, 21, 20]
     dummy_inputs = _build_dummy_sequences(lengths)
     stacked_encoders = [
         encoders.UnidirectionalRNNEncoder(1, 20),
         encoders.PyramidalRNNEncoder(3, 10, reduction_factor=2),
     ]
     seq_encoder = encoders.SequentialEncoder(stacked_encoders)
     _, final_states, reduced_lengths = seq_encoder.encode(
         dummy_inputs, sequence_length=lengths)
     self.assertEqual(4, len(final_states))
     for layer_state in final_states:
         self.assertIsInstance(layer_state, tf.contrib.rnn.LSTMStateTuple)
     with self.test_session() as session:
         session.run(tf.global_variables_initializer())
         length_values = session.run(reduced_lengths)
         self.assertAllEqual([4, 5, 5], length_values)
Esempio n. 5
0
 def testSequentialEncoderWithTooManyTransitionLayers(self):
     """Expects ValueError when two transition functions are given for two encoders.

     NOTE(review): presumably two sub-encoders allow only one transition
     between them -- confirm against SequentialEncoder's validation.
     """
     # Keep every construction inside assertRaises, as in the original, so a
     # ValueError raised by any of these calls still satisfies the test.
     with self.assertRaises(ValueError):
         encoders.SequentialEncoder(
             [DenseEncoder(1, 20), DenseEncoder(3, 20)],
             transition_layer_fn=[tf.identity, tf.identity])