def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
    """Build a tied-embedding seq2seq model on a non-tuple-state LSTM cell.

    Args:
      enc_inp: Encoder inputs, forwarded unchanged.
      dec_inp: Decoder inputs, forwarded unchanged.
      feed_previous: Forwarded unchanged as the `feed_previous` keyword.

    Returns:
      Whatever `seq2seq_lib.embedding_tied_rnn_seq2seq` returns.
    """
    # NOTE(review): `num_decoder_symbols` is not a parameter; it has to be
    # supplied by the surrounding scope, which is not visible here.
    lstm = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
    model_options = {"embedding_size": 2, "feed_previous": feed_previous}
    return seq2seq_lib.embedding_tied_rnn_seq2seq(
        enc_inp, dec_inp, lstm, num_decoder_symbols, **model_options)
def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
    """Return a tied-embedding seq2seq model whose LSTM state is not a tuple.

    The three arguments are passed straight through to
    `seq2seq_lib.embedding_tied_rnn_seq2seq`; the cell is a 2-unit
    `BasicLSTMCell` created with `state_is_tuple=False` and the embedding
    size is fixed at 2.
    """
    # NOTE(review): `num_decoder_symbols` comes from an enclosing scope that
    # is outside this block.
    return seq2seq_lib.embedding_tied_rnn_seq2seq(
        enc_inp,
        dec_inp,
        core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False),
        num_decoder_symbols,
        embedding_size=2,
        feed_previous=feed_previous,
    )
def testEmbeddingTiedRNNSeq2Seq(self):
    """Exercise `embedding_tied_rnn_seq2seq` in four configurations.

    Checks the number and shape of the decoder outputs for the default
    model, for an explicit `num_decoder_symbols`, and for an externally
    supplied output projection.  Finally checks that, with `feed_previous`
    enabled, the decoder outputs are the same whichever decoder inputs
    follow the first one.

    Every model gets its own freshly constructed cell instead of one shared
    instance, so no single cell object is used under more than one variable
    scope.
    """

    def make_cell():
        # A new 2-unit LSTM with tuple state for each model built below.
        return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)

    with self.test_session() as sess:
        with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
            enc_inp = [
                constant_op.constant(1, dtypes.int32, shape=[2])
                for _ in range(2)
            ]
            dec_inp = [
                constant_op.constant(i, dtypes.int32, shape=[2])
                for i in range(3)
            ]
            dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
                enc_inp, dec_inp, make_cell(), num_symbols=5,
                embedding_size=2)
            sess.run([variables.global_variables_initializer()])
            res = sess.run(dec)
            self.assertEqual(3, len(res))
            self.assertEqual((2, 5), res[0].shape)

            res = sess.run([mem])
            self.assertEqual((2, 2), res[0].c.shape)
            self.assertEqual((2, 2), res[0].h.shape)

            # Test when num_decoder_symbols is provided, the size of decoder
            # output is num_decoder_symbols.
            with variable_scope.variable_scope("decoder_symbols_seq2seq"):
                dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    enc_inp,
                    dec_inp,
                    make_cell(),
                    num_symbols=5,
                    num_decoder_symbols=3,
                    embedding_size=2)
            sess.run([variables.global_variables_initializer()])
            res = sess.run(dec)
            self.assertEqual(3, len(res))
            self.assertEqual((2, 3), res[0].shape)

            # Test externally provided output projection.
            w = variable_scope.get_variable("proj_w", [2, 5])
            b = variable_scope.get_variable("proj_b", [5])
            with variable_scope.variable_scope("proj_seq2seq"):
                dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    enc_inp,
                    dec_inp,
                    make_cell(),
                    num_symbols=5,
                    embedding_size=2,
                    output_projection=(w, b))
            sess.run([variables.global_variables_initializer()])
            res = sess.run(dec)
            self.assertEqual(3, len(res))
            # With an external projection the decoder output keeps a last
            # dimension of 2 rather than the symbol count.
            self.assertEqual((2, 2), res[0].shape)

            # Test that previous-feeding model ignores inputs after the first.
            dec_inp2 = [constant_op.constant(0, dtypes.int32, shape=[2])] * 3
            with variable_scope.variable_scope("other"):
                d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    enc_inp,
                    dec_inp2,
                    make_cell(),
                    num_symbols=5,
                    embedding_size=2,
                    feed_previous=constant_op.constant(True))
                sess.run([variables.global_variables_initializer()])
                # From here on, models built in this scope reuse the
                # variables created for d3 instead of creating new ones.
                variable_scope.get_variable_scope().reuse_variables()
                d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    enc_inp,
                    dec_inp,
                    make_cell(),
                    num_symbols=5,
                    embedding_size=2,
                    feed_previous=True)
                d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    enc_inp,
                    dec_inp2,
                    make_cell(),
                    num_symbols=5,
                    embedding_size=2,
                    feed_previous=True)
                res1 = sess.run(d1)
                res2 = sess.run(d2)
                res3 = sess.run(d3)
                self.assertAllClose(res1, res2)
                self.assertAllClose(res1, res3)
def testEmbeddingTiedRNNSeq2Seq(self):
    """Check output shapes and feed-previous behaviour of the tied model.

    Four variants of `seq2seq_lib.embedding_tied_rnn_seq2seq` are built,
    each with a cell obtained from a `functools.partial` factory: the
    default model, one with `num_decoder_symbols=3`, one with an external
    output projection, and a group of `feed_previous` models whose outputs
    are compared against each other.
    """
    with self.test_session() as sess:

        def initialize_variables():
            # Build a fresh initializer op each time so it covers every
            # variable that exists at the moment of the call.
            sess.run([variables.global_variables_initializer()])

        half_init = init_ops.constant_initializer(0.5)
        with variable_scope.variable_scope("root", initializer=half_init):
            encoder_inputs = [
                constant_op.constant(1, dtypes.int32, shape=[2])
                for _ in range(2)
            ]
            decoder_inputs = [
                constant_op.constant(step, dtypes.int32, shape=[2])
                for step in range(3)
            ]
            new_cell = functools.partial(
                core_rnn_cell_impl.BasicLSTMCell, 2, state_is_tuple=True)

            outputs_op, state_op = seq2seq_lib.embedding_tied_rnn_seq2seq(
                encoder_inputs, decoder_inputs, new_cell(),
                num_symbols=5, embedding_size=2)
            initialize_variables()
            outputs = sess.run(outputs_op)
            self.assertEqual(3, len(outputs))
            self.assertEqual((2, 5), outputs[0].shape)

            final_state = sess.run([state_op])[0]
            self.assertEqual((2, 2), final_state.c.shape)
            self.assertEqual((2, 2), final_state.h.shape)

            # Test when num_decoder_symbols is provided, the size of decoder
            # output is num_decoder_symbols.
            with variable_scope.variable_scope("decoder_symbols_seq2seq"):
                outputs_op, state_op = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    encoder_inputs, decoder_inputs, new_cell(),
                    num_symbols=5, num_decoder_symbols=3, embedding_size=2)
            initialize_variables()
            outputs = sess.run(outputs_op)
            self.assertEqual(3, len(outputs))
            self.assertEqual((2, 3), outputs[0].shape)

            # Test externally provided output projection.
            proj_weights = variable_scope.get_variable("proj_w", [2, 5])
            proj_biases = variable_scope.get_variable("proj_b", [5])
            with variable_scope.variable_scope("proj_seq2seq"):
                outputs_op, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    encoder_inputs, decoder_inputs, new_cell(),
                    num_symbols=5, embedding_size=2,
                    output_projection=(proj_weights, proj_biases))
            initialize_variables()
            outputs = sess.run(outputs_op)
            self.assertEqual(3, len(outputs))
            self.assertEqual((2, 2), outputs[0].shape)

            # Test that previous-feeding model ignores inputs after the first.
            zero_decoder_inputs = [
                constant_op.constant(0, dtypes.int32, shape=[2])
            ] * 3
            with variable_scope.variable_scope("other"):
                tensor_flag_dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    encoder_inputs, zero_decoder_inputs, new_cell(),
                    num_symbols=5, embedding_size=2,
                    feed_previous=constant_op.constant(True))
                initialize_variables()
                variable_scope.get_variable_scope().reuse_variables()
                counting_dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    encoder_inputs, decoder_inputs, new_cell(),
                    num_symbols=5, embedding_size=2, feed_previous=True)
                zeros_dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
                    encoder_inputs, zero_decoder_inputs, new_cell(),
                    num_symbols=5, embedding_size=2, feed_previous=True)
                counting_out = sess.run(counting_dec)
                zeros_out = sess.run(zeros_dec)
                tensor_flag_out = sess.run(tensor_flag_dec)
                self.assertAllClose(counting_out, zeros_out)
                self.assertAllClose(counting_out, tensor_flag_out)