Example #1
0
 def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
   """Build a tied-embedding seq2seq model on a non-tuple-state LSTM cell.

   Args:
     enc_inp: encoder inputs, forwarded unchanged.
     dec_inp: decoder inputs, forwarded unchanged.
     feed_previous: forwarded unchanged as `feed_previous`.

   Returns:
     The result of `seq2seq_lib.embedding_tied_rnn_seq2seq`.
   """
   # 2-unit LSTM whose state is a single tensor instead of a (c, h) tuple.
   lstm = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
   # `num_decoder_symbols` is a free variable from an enclosing scope that is
   # not visible here (NOTE(review): confirm). It is used as the shared
   # (tied) vocabulary size, i.e. the `num_symbols` argument.
   return seq2seq_lib.embedding_tied_rnn_seq2seq(
       enc_inp, dec_inp, lstm,
       num_symbols=num_decoder_symbols,
       embedding_size=2,
       feed_previous=feed_previous)
Example #2
0
 def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
   """Build a tied-embedding seq2seq model on a non-tuple-state LSTM cell.

   Args:
     enc_inp: encoder inputs, forwarded unchanged.
     dec_inp: decoder inputs, forwarded unchanged.
     feed_previous: forwarded unchanged as `feed_previous`.

   Returns:
     The result of `seq2seq_lib.embedding_tied_rnn_seq2seq`.
   """
   # 2-unit LSTM whose state is a single tensor instead of a (c, h) tuple.
   lstm = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
   # `num_decoder_symbols` is a free variable from an enclosing scope that is
   # not visible here (NOTE(review): confirm). It is used as the shared
   # (tied) vocabulary size, i.e. the `num_symbols` argument.
   return seq2seq_lib.embedding_tied_rnn_seq2seq(
       enc_inp, dec_inp, lstm,
       num_symbols=num_decoder_symbols,
       embedding_size=2,
       feed_previous=feed_previous)
Example #3
0
  def testEmbeddingTiedRNNSeq2Seq(self):
    """Exercise `embedding_tied_rnn_seq2seq` in several configurations.

    Verifies, in order:
      * the default decoder output shape and the (c, h) final-state shapes;
      * that an explicit `num_decoder_symbols` sets the decoder output width;
      * that with an external `output_projection` outputs stay cell-sized;
      * that `feed_previous=True` ignores every decoder input after the first.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(
                1, dtypes.int32, shape=[2]) for _ in range(2)
        ]
        dec_inp = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]

        # Build a fresh cell for every model instead of sharing one instance
        # across the different variable scopes below. Newer RNNCell
        # implementations raise when a single cell object is reused in a
        # different variable scope (NOTE(review): confirm for the TF version
        # this test targets).
        def cell():
          return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)

        dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
            enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        # One output per decoder input (3), batch 2, width num_symbols (5).
        self.assertEqual(3, len(res))
        self.assertEqual((2, 5), res[0].shape)

        # state_is_tuple=True, so the final state exposes `.c` and `.h`.
        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)

        # Test when num_decoder_symbols is provided, the size of decoder output
        # is num_decoder_symbols.
        with variable_scope.variable_scope("decoder_symbols_seq2seq"):
          dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell(),
              num_symbols=5,
              num_decoder_symbols=3,
              embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 3), res[0].shape)

        # Test externally provided output projection: the returned outputs are
        # the unprojected cell outputs (width 2).
        w = variable_scope.get_variable("proj_w", [2, 5])
        b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell(),
              num_symbols=5,
              embedding_size=2,
              output_projection=(w, b))
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 2), res[0].shape)

        # Test that previous-feeding model ignores inputs after the first.
        # dec_inp2 shares only its first element's value (0) with dec_inp.
        dec_inp2 = [constant_op.constant(0, dtypes.int32, shape=[2])] * 3
        with variable_scope.variable_scope("other"):
          d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp2,
              cell(),
              num_symbols=5,
              embedding_size=2,
              feed_previous=constant_op.constant(True))
        sess.run([variables.global_variables_initializer()])
        # d1 and d2 reuse the variables of the first model built in "root".
        variable_scope.get_variable_scope().reuse_variables()
        d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
            enc_inp,
            dec_inp,
            cell(),
            num_symbols=5,
            embedding_size=2,
            feed_previous=True)
        d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
            enc_inp,
            dec_inp2,
            cell(),
            num_symbols=5,
            embedding_size=2,
            feed_previous=True)
        res1 = sess.run(d1)
        res2 = sess.run(d2)
        res3 = sess.run(d3)
        self.assertAllClose(res1, res2)
        self.assertAllClose(res1, res3)
Example #4
0
  def testEmbeddingTiedRNNSeq2Seq(self):
    """Check output shapes and feed-previous behavior of the tied model.

    Every model gets its own freshly constructed 2-unit LSTM cell. The default
    model is built directly in the "root" scope and the variants in their own
    sub-scopes. The two `feed_previous=True` models at the end reuse the
    "root" variables and are compared with the one built in "other".
    """
    with self.test_session() as sess:
      root_init = init_ops.constant_initializer(0.5)
      with variable_scope.variable_scope("root", initializer=root_init):
        encoder_inputs = [
            constant_op.constant(1, dtypes.int32, shape=[2])
            for _ in range(2)
        ]
        decoder_inputs = [
            constant_op.constant(step, dtypes.int32, shape=[2])
            for step in range(3)
        ]

        def new_cell():
          # Fresh tuple-state LSTM per model.
          return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)

        def init_vars():
          # Re-run initialization so variables created since the last call
          # are initialized too.
          sess.run([variables.global_variables_initializer()])

        def check_outputs(outputs, width):
          # One output per decoder input (3 steps), each of shape (2, width).
          values = sess.run(outputs)
          self.assertEqual(3, len(values))
          self.assertEqual((2, width), values[0].shape)

        outputs, state = seq2seq_lib.embedding_tied_rnn_seq2seq(
            encoder_inputs, decoder_inputs, new_cell(),
            num_symbols=5, embedding_size=2)
        init_vars()
        check_outputs(outputs, 5)

        # A tuple-state LSTM yields a final state with `.c` and `.h` parts.
        (final_state,) = sess.run([state])
        self.assertEqual((2, 2), final_state.c.shape)
        self.assertEqual((2, 2), final_state.h.shape)

        # An explicit num_decoder_symbols determines the decoder output width.
        with variable_scope.variable_scope("decoder_symbols_seq2seq"):
          outputs, state = seq2seq_lib.embedding_tied_rnn_seq2seq(
              encoder_inputs, decoder_inputs, new_cell(),
              num_symbols=5, num_decoder_symbols=3, embedding_size=2)
        init_vars()
        check_outputs(outputs, 3)

        # With an external output projection the outputs stay cell-sized.
        proj_w = variable_scope.get_variable("proj_w", [2, 5])
        proj_b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          outputs, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              encoder_inputs, decoder_inputs, new_cell(),
              num_symbols=5, embedding_size=2,
              output_projection=(proj_w, proj_b))
        init_vars()
        check_outputs(outputs, 2)

        # feed_previous must make every decoder input after the first
        # irrelevant; zero_inputs matches decoder_inputs only at step 0.
        zero_inputs = [constant_op.constant(0, dtypes.int32, shape=[2])] * 3
        with variable_scope.variable_scope("other"):
          fed_other, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              encoder_inputs, zero_inputs, new_cell(),
              num_symbols=5, embedding_size=2,
              feed_previous=constant_op.constant(True))
        init_vars()
        variable_scope.get_variable_scope().reuse_variables()
        fed_real, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
            encoder_inputs, decoder_inputs, new_cell(),
            num_symbols=5, embedding_size=2, feed_previous=True)
        fed_zero, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
            encoder_inputs, zero_inputs, new_cell(),
            num_symbols=5, embedding_size=2, feed_previous=True)
        real_vals = sess.run(fed_real)
        zero_vals = sess.run(fed_zero)
        other_vals = sess.run(fed_other)
        self.assertAllClose(real_vals, zero_vals)
        self.assertAllClose(real_vals, other_vals)