示例#1
0
        def seq2seq_f(lstm_inputs, decoder_inputs, seq_length, do_decode):
            """Build a bidirectional-LSTM encoder feeding an attention decoder.

            This fragment reads names from its enclosing scope that are not
            defined here: attn_num_layers, attn_num_hidden, encoder_masks,
            cell, target_vocab_size, target_embedding_size and
            embedding_attention_decoder.

            Args:
              lstm_inputs: per-step encoder inputs, passed unchanged to
                tf.contrib.rnn.static_bidirectional_rnn.
              decoder_inputs: forwarded to embedding_attention_decoder.
              seq_length: how many leading entries of encoder_masks to use;
                zip() stops at the shorter of the outputs and the masks.
              do_decode: forwarded to the decoder as feed_previous.

            Returns:
              (outputs, attention_weights_history) from
              embedding_attention_decoder; its second result is discarded.
            """

            # Each encoder direction is as wide as the whole decoder stack.
            num_hidden = attn_num_layers * attn_num_hidden
            lstm_fw_cell = rnn_cell_impl.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)
            # Backward direction cell
            lstm_bw_cell = rnn_cell_impl.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)

            pre_encoder_inputs, output_state_fw, output_state_bw = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs,
                initial_state_fw=None, initial_state_bw=None,
                dtype=tf.float32, sequence_length=None, scope=None)

            # Multiply each step's encoder output by its mask (presumably to
            # zero out padded steps -- confirm against how masks are built).
            encoder_inputs = [e*f for e,f in zip(pre_encoder_inputs,encoder_masks[:seq_length])]
            # Reshape every step to [batch, 1, num_hidden*2] (the *2 assumes
            # forward and backward outputs are concatenated per step) and
            # stack along axis 1 into the attention memory.
            top_states = [array_ops.reshape(e, [-1, 1, num_hidden*2])
                    for e in encoder_inputs]
            attention_states = array_ops.concat(top_states, 1)
            # With state_is_tuple=False each direction's final state is one
            # tensor, so both are joined along axis 1.
            # NOTE(review): this is twice the width of a single direction's
            # state; confirm it matches what `cell` / the decoder expects.
            initial_state = tf.concat(axis=1, values=[output_state_fw, output_state_bw])
            outputs, _, attention_weights_history = embedding_attention_decoder(
                    decoder_inputs, initial_state, attention_states, cell,
                    num_symbols=target_vocab_size, 
                    embedding_size=target_embedding_size,
                    num_heads=1,
                    output_size=target_vocab_size, 
                    output_projection=None,
                    feed_previous=do_decode,
                    initial_state_attention=False,
                    attn_num_hidden = attn_num_hidden)
            return outputs, attention_weights_history
示例#2
0
 def testBasicLSTMCell(self):
     """Smoke-test a two-layer MultiRNNCell of BasicLSTMCells with concatenated state.

     Checks the variable names (and their order), the output and state produced
     under a constant 0.5 initializer, and that a lone BasicLSTMCell accepts an
     input whose width differs from num_units.
     """
     with self.test_session() as sess:
         with variable_scope.variable_scope(
                 "root", initializer=init_ops.constant_initializer(0.5)):
             x = array_ops.zeros([1, 2])
             # Concatenated state for 2 layers, each num_units * 2 = 4 wide.
             m = array_ops.zeros([1, 8])
             cell = rnn_cell_impl.MultiRNNCell([
                 rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
                 for _ in range(2)
             ],
                                               state_is_tuple=False)
             g, out_m = cell(x, m)
             # Same list is compared against trainable_variables now and
             # against global_variables() after running, so order matters.
             expected_variable_names = [
                 "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
                 rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
                 "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
                 rnn_cell_impl._BIAS_VARIABLE_NAME,
                 "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
                 rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
                 "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
                 rnn_cell_impl._BIAS_VARIABLE_NAME
             ]
             self.assertEqual(expected_variable_names,
                              [v.name for v in cell.trainable_variables])
             self.assertFalse(cell.non_trainable_variables)
             sess.run([variables_lib.global_variables_initializer()])
             res = sess.run([g, out_m], {
                 x.name: np.array([[1., 1.]]),
                 m.name: 0.1 * np.ones([1, 8])
             })
             self.assertEqual(len(res), 2)
             variables = variables_lib.global_variables()
             self.assertEqual(expected_variable_names,
                              [v.name for v in variables])
             # The numbers in results were not calculated, this is just a smoke test.
             self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
             expected_mem = np.array([[
                 0.68967271, 0.68967271, 0.44848421, 0.44848421, 0.39897051,
                 0.39897051, 0.24024698, 0.24024698
             ]])
             self.assertAllClose(res[1], expected_mem)
         # Separate scope so the single cell's variables do not collide with
         # the ones created under "root".
         with variable_scope.variable_scope(
                 "other", initializer=init_ops.constant_initializer(0.5)):
             x = array_ops.zeros(
                 [1, 3])  # Test BasicLSTMCell with input_size != num_units.
             m = array_ops.zeros([1, 4])
             g, out_m = rnn_cell_impl.BasicLSTMCell(2,
                                                    state_is_tuple=False)(x,
                                                                          m)
             sess.run([variables_lib.global_variables_initializer()])
             res = sess.run(
                 [g, out_m], {
                     x.name: np.array([[1., 1., 1.]]),
                     m.name: 0.1 * np.ones([1, 4])
                 })
             self.assertEqual(len(res), 2)
示例#3
0
文件: main.py 项目: wolfhu/DARNN
def RNN(encoder_input, decoder_input, weights, biases,
        encoder_attention_states, n_input_encoder, n_steps_encoder,
        n_hidden_encoder, n_input_decoder, n_steps_decoder, n_hidden_decoder):
    """Run the attention encoder and decoder, then project the last decoder output.

    Args:
        encoder_input: batch-major tensor, laid out as
            (batch_size, n_steps_encoder, n_input_encoder).
        decoder_input: batch-major tensor, laid out as
            (batch_size, n_steps_decoder, n_input_decoder).
        weights, biases: dicts holding the 'out1' output projection.
        encoder_attention_states: forwarded to attention_encoder.
        n_input_encoder, n_steps_encoder: sizes used to split encoder_input.
        n_hidden_encoder: number of units in the encoder LSTM.
        n_input_decoder, n_steps_decoder: sizes used to split decoder_input.
        n_hidden_decoder: number of units in the decoder LSTM.

    Returns:
        (outputs[-1] @ weights['out1'] + biases['out1'], attn_weights), where
        attn_weights is the third result of attention_encoder.
    """

    def to_step_list(batch_major, n_features, n_steps):
        # The static rnn helpers want a list of n_steps tensors of shape
        # (batch_size, n_features): go time-major, flatten steps into rows,
        # then cut the rows back into one chunk per step.
        time_major = tf.transpose(batch_major, [1, 0, 2])
        rows = tf.reshape(time_major, [-1, n_features])
        return tf.split(rows, n_steps, 0)

    encoder_steps = to_step_list(encoder_input, n_input_encoder, n_steps_encoder)
    decoder_steps = to_step_list(decoder_input, n_input_decoder, n_steps_decoder)

    with tf.variable_scope('encoder'):
        encoder_cell = rnn_cell.BasicLSTMCell(n_hidden_encoder, forget_bias=1.0)
        encoder_outputs, encoder_state, attn_weights = (
            attention_encoder.attention_encoder(
                encoder_steps, encoder_attention_states, encoder_cell))

    # Stack the per-step encoder outputs into one (batch, steps, output_size)
    # tensor for the decoder to attend over.
    per_step_states = [
        tf.reshape(step_output, [-1, 1, encoder_cell.output_size])
        for step_output in encoder_outputs
    ]
    attention_states = tf.concat(per_step_states, 1)

    with tf.variable_scope('decoder'):
        decoder_cell = rnn_cell.BasicLSTMCell(n_hidden_decoder, forget_bias=1.0)
        decoder_outputs, _ = seq2seq.attention_decoder(
            decoder_steps, encoder_state, attention_states, decoder_cell)

    final_output = decoder_outputs[-1]
    prediction = tf.matmul(final_output, weights['out1']) + biases['out1']
    return prediction, attn_weights
示例#4
0
 def testBasicLSTMCellWithStateTuple(self):
     """A tuple-state MultiRNNCell must reproduce the concatenated-state numbers."""
     with self.test_session() as sess:
         half = init_ops.constant_initializer(0.5)
         with variable_scope.variable_scope("root", initializer=half):
             inputs = array_ops.zeros([1, 2])
             state0 = array_ops.zeros([1, 4])
             state1 = array_ops.zeros([1, 4])
             layers = [
                 rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
                 for _ in range(2)
             ]
             stacked = rnn_cell_impl.MultiRNNCell(layers, state_is_tuple=True)
             output, (new_state0, new_state1) = stacked(inputs, (state0, state1))
             sess.run([variables_lib.global_variables_initializer()])
             feed = {
                 inputs.name: np.array([[1., 1.]]),
                 state0.name: 0.1 * np.ones([1, 4]),
                 state1.name: 0.1 * np.ones([1, 4]),
             }
             fetched = sess.run([output, new_state0, new_state1], feed)
             self.assertEqual(len(fetched), 3)
             # Smoke-test values (not derived analytically); they must equal
             # the state_is_tuple=False version of this test.
             self.assertAllClose(fetched[0], [[0.24024698, 0.24024698]])
             self.assertAllClose(
                 fetched[1],
                 np.array([[0.68967271, 0.68967271, 0.44848421, 0.44848421]]))
             self.assertAllClose(
                 fetched[2],
                 np.array([[0.39897051, 0.39897051, 0.24024698, 0.24024698]]))
示例#5
0
  def testBasicLSTMCellStateTupleType(self):
    """MultiRNNCell(state_is_tuple=True) over BasicLSTMCells yields LSTMStateTuples."""
    with self.test_session():
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        m0 = (array_ops.zeros([1, 2]),) * 2
        m1 = (array_ops.zeros([1, 2]),) * 2
        cell = rnn_cell_impl.MultiRNNCell(
            [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
            state_is_tuple=True)

        def check_lstm_tuples(*items):
          # Every given per-layer entry must be an LSTMStateTuple.
          for item in items:
            self.assertTrue(isinstance(item, rnn_cell_impl.LSTMStateTuple))

        self.assertTrue(isinstance(cell.state_size, tuple))
        check_lstm_tuples(cell.state_size[0], cell.state_size[1])

        # Plain tuples as the incoming state.
        _, (out_m0, out_m1) = cell(x, (m0, m1))
        check_lstm_tuples(out_m0, out_m1)

        # LSTMStateTuples (from zero_state) as the incoming state, reusing
        # the variables created by the first call.
        variable_scope.get_variable_scope().reuse_variables()
        zero_state = cell.zero_state(1, dtypes.float32)
        self.assertTrue(isinstance(zero_state, tuple))
        check_lstm_tuples(zero_state[0], zero_state[1])
        _, (out_m0, out_m1) = cell(x, zero_state)
        check_lstm_tuples(out_m0, out_m1)
示例#6
0
  def testRNNCellSerialization(self):
    """Round-trip each rnn_cell_impl cell through keras RNN get_config/from_config."""
    cells_under_test = [
        rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
        rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
        rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
        rnn_cell_impl.GRUCell(32, dtype=dtypes.float32)
    ]
    # rnn_cell_impl classes are not visible as Keras layers, and LSTMCell /
    # GRUCell clash by name with keras.LSTMCell / keras.GRUCell, so
    # from_config needs them spelled out.
    custom_objects = {
        "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
        "GRUCell": rnn_cell_impl.GRUCell,
        "LSTMCell": rnn_cell_impl.LSTMCell,
        "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
    }
    for cell in cells_under_test:
      with self.cached_session():
        inputs = keras.Input((None, 5))
        source_layer = keras.layers.RNN(cell)
        source_model = keras.models.Model(inputs, source_layer(inputs))
        source_model.compile(optimizer="rmsprop", loss="mse")

        # Basic serialization case: record predictions and weights first.
        batch = np.random.random((6, 5, 5))
        expected = source_model.predict(batch)
        saved_weights = source_model.get_weights()
        layer_config = source_layer.get_config()

        rebuilt_layer = keras.layers.RNN.from_config(
            layer_config, custom_objects=custom_objects)
        rebuilt_model = keras.models.Model(inputs, rebuilt_layer(inputs))
        rebuilt_model.set_weights(saved_weights)
        actual = rebuilt_model.predict(batch)
        self.assertAllClose(expected, actual, atol=1e-4)
示例#7
0
def _build_multi_lstm_cell(num_units,
                           num_layers,
                           train_test_predict,
                           keep_prob=1.0):
    """Build a MultiRNNCell of ``num_layers`` independent BasicLSTMCells.

    Args:
        num_units: number of units in every LSTM layer.
        num_layers: number of stacked layers.
        train_test_predict: mode string. Only 'train' is special: cells are
            created with reuse=False and dropout may be added. Any other
            value creates cells with reuse=True.
        keep_prob: output keep probability for DropoutWrapper. Dropout is
            added only in 'train' mode and only when keep_prob < 1.0.

    Returns:
        A MultiRNNCell whose layers are distinct cell objects.
    """
    is_training = train_test_predict == 'train'

    def _make_layer():
        # A brand-new cell per layer. Repeating one cell object
        # (`[cell for _ in range(n)]`) makes every layer share the same
        # weights, or fail to build when a layer's input width differs
        # from num_units.
        layer = rnn_cell_impl.BasicLSTMCell(num_units, reuse=not is_training)
        if is_training and keep_prob < 1.0:
            layer = rnn_cell_impl.DropoutWrapper(layer,
                                                 output_keep_prob=keep_prob)
        return layer

    return rnn_cell_impl.MultiRNNCell(
        [_make_layer() for _ in range(num_layers)])
示例#8
0
 def testEmbeddingWrapperWithDynamicRnn(self):
     """An EmbeddingWrapper-ed LSTM runs under dynamic_rnn with int64 ids."""
     with self.test_session() as sess:
         with variable_scope.variable_scope("root"):
             # Shape [1, 2, 1]: one sequence of two steps, both id 0.
             token_ids = ops.convert_to_tensor([[[0], [0]]],
                                               dtype=dtypes.int64)
             lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
             lstm = rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True)
             embedded_cell = contrib_rnn.EmbeddingWrapper(
                 lstm, embedding_classes=1, embedding_size=2)
             outputs, _ = rnn.dynamic_rnn(
                 cell=embedded_cell,
                 inputs=token_ids,
                 sequence_length=lengths,
                 dtype=dtypes.float32)
             sess.run([variables_lib.global_variables_initializer()])
             # Fails if the output dtype is inferred from the int64 input.
             sess.run(outputs)
示例#9
0
 def testBasicLSTMCellStateSizeError(self):
   """A concatenated BasicLSTMCell state wider than num_units * 2 is rejected."""
   with self.test_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       num_units = 2
       batch_size = 3
       input_size = 4
       # Deliberately wrong: the state must be num_units * 2 wide.
       bad_state_size = num_units * 3
       x = array_ops.zeros([batch_size, input_size])
       m = array_ops.zeros([batch_size, bad_state_size])
       with self.assertRaises(ValueError):
         lstm = rnn_cell_impl.BasicLSTMCell(num_units, state_is_tuple=False)
         g, out_m = lstm(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         feed = {
             x.name: 1 * np.ones([batch_size, input_size]),
             m.name: 0.1 * np.ones([batch_size, bad_state_size]),
         }
         sess.run([g, out_m], feed)
示例#10
0
  def testBasicLSTMCellInterchangeWithLSTMCell(self):
    """A BasicLSTMCell checkpoint restores into an LSTMCell named basic_lstm_cell."""
    # Phase 1: build a BasicLSTMCell, give it a recognisable bias, save it.
    with self.test_session(graph=ops_lib.Graph()) as sess:
      basic_cell = rnn_cell_impl.BasicLSTMCell(1)
      step_input = array_ops.ones([1, 1])
      start_state = basic_cell.zero_state(batch_size=1, dtype=dtypes.float32)
      basic_cell(step_input, state=start_state)
      self.evaluate([v.initializer for v in basic_cell.variables])
      self.evaluate(basic_cell._bias.assign([10.] * 4))
      checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
      save_path = saver.Saver().save(sess, checkpoint_prefix)

    # Phase 2: an LSTMCell under the same name must load that bias back.
    with self.test_session(graph=ops_lib.Graph()) as sess:
      lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
      step_input = array_ops.ones([1, 1])
      start_state = lstm_cell.zero_state(batch_size=1, dtype=dtypes.float32)
      lstm_cell(step_input, state=start_state)
      self.evaluate([v.initializer for v in lstm_cell.variables])
      saver.Saver().restore(sess, save_path)
      self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
示例#11
0
 def testBasicLSTMCell(self):
     """Run the two-layer BasicLSTMCell smoke test once per dtype (float16, float32).

     Also checks that the MultiRNNCell's dtype is None before the first call
     and is taken from the input afterwards.
     """
     for dtype in [dtypes.float16, dtypes.float32]:
         np_dtype = dtype.as_numpy_dtype
         # Fresh graph per dtype so variable names do not collide.
         with self.test_session(graph=ops.Graph()) as sess:
             with variable_scope.variable_scope(
                     "root",
                     initializer=init_ops.constant_initializer(0.5)):
                 x = array_ops.zeros([1, 2], dtype=dtype)
                 # Concatenated state for 2 layers, each num_units * 2 = 4 wide.
                 m = array_ops.zeros([1, 8], dtype=dtype)
                 cell = rnn_cell_impl.MultiRNNCell([
                     rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
                     for _ in range(2)
                 ],
                                                   state_is_tuple=False)
                 self.assertEqual(cell.dtype, None)
                 g, out_m = cell(x, m)
                 # Layer infers the input type.
                 self.assertEqual(cell.dtype, dtype.name)
                 # Compared against trainable_variables now and against
                 # global_variables() after running, so order matters.
                 expected_variable_names = [
                     "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
                     rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
                     "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
                     rnn_cell_impl._BIAS_VARIABLE_NAME,
                     "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
                     rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
                     "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
                     rnn_cell_impl._BIAS_VARIABLE_NAME
                 ]
                 self.assertEqual(
                     expected_variable_names,
                     [v.name for v in cell.trainable_variables])
                 self.assertFalse(cell.non_trainable_variables)
                 sess.run([variables_lib.global_variables_initializer()])
                 res = sess.run(
                     [g, out_m], {
                         x.name: np.array([[1., 1.]]),
                         m.name: 0.1 * np.ones([1, 8])
                     })
                 self.assertEqual(len(res), 2)
                 variables = variables_lib.global_variables()
                 self.assertEqual(expected_variable_names,
                                  [v.name for v in variables])
                 # The numbers in results were not calculated, this is just a
                 # smoke test.
                 # Values are rounded to 3 digits and compared with a 1e-2
                 # tolerance (presumably loose enough for float16).
                 self.assertAllClose(
                     res[0], np.array([[0.240, 0.240]], dtype=np_dtype),
                     1e-2)
                 expected_mem = np.array([[
                     0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240
                 ]],
                                         dtype=np_dtype)
                 self.assertAllClose(res[1], expected_mem, 1e-2)
             with variable_scope.variable_scope(
                     "other",
                     initializer=init_ops.constant_initializer(0.5)):
                 # Test BasicLSTMCell with input_size != num_units.
                 x = array_ops.zeros([1, 3], dtype=dtype)
                 m = array_ops.zeros([1, 4], dtype=dtype)
                 g, out_m = rnn_cell_impl.BasicLSTMCell(
                     2, state_is_tuple=False)(x, m)
                 sess.run([variables_lib.global_variables_initializer()])
                 res = sess.run(
                     [g, out_m], {
                         x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
                         m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)
                     })
                 self.assertEqual(len(res), 2)
# -*- coding: utf-8 -*-
"""
@author: Daniel
@contact: [email protected]
@file: tf_zero_state_learn.py
@time: 2017/7/20 14:18
"""
import tensorflow as tf
from tensorflow.python.ops import rnn_cell_impl

# zero_state() needs a *scalar* batch size. The original
# `tf.placeholder(tf.int32, 5)` passed 5 as the `shape` argument, which makes
# a length-5 vector. A scalar placeholder defaulting to 5 keeps the apparent
# intent and can still be overridden through feed_dict.
batch_size = tf.placeholder_with_default(5, shape=[], name="batch_size")
cell = rnn_cell_impl.BasicLSTMCell(128)
# Zero LSTM state sized [batch_size, 128] for each state component.
initial_state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
示例#13
0
def TMCAModel(input_x,
              input_pre_y,
              data_info,
              n_output_class,
              n_output_embed,
              n_steps_encoder,
              n_steps_decoder,
              n_hidden_encoder,
              n_hidden_decoder,
              ones_matrix,
              frag_mode=(True, True)):
    """
    Complete TMCA model.

    :param input_x: Encoder input. Tensors.
        n_features x [n_batch_size x n_steps_encoder x n_feature_dim]
    :param input_pre_y: Decoder input. Tensor.
        [n_batch_size x n_steps_decoder x 1]
    :param data_info:  Data information.
        A list of dicts like
        [{"name": "user", "format": "embed", "num": 22209, "dim": 60, "active": true,
        "default": false}]
    :param n_output_class: The number of POIs. Int.
    :param n_output_embed: The dimensions of output POI's hidden vector. Int
    :param n_steps_encoder: The number of encoder steps. Int.
    :param n_steps_decoder: The number of decoder steps. Int.
        In our paper, n_steps_decoder = n_steps_encoder + 1.
    :param n_hidden_encoder: The dimensions of encoder LSTM cell states. Int.
    :param n_hidden_decoder: The dimensions of decoder LSTM cell states. Int.
    :param ones_matrix: Auxiliary vector. Tensor.
        [n_batch_size x 1]
    :param frag_mode: Pair of booleans (any indexable pair; it is only read).
        frag_mode[0]: use multi-level context attention in the encoder,
        otherwise scale the inputs by fixed mean weights.
        frag_mode[1]: use temporal attention in the decoder, otherwise append
        the mean encoder output to every decoder input.

    :return e_y_pred: The hidden vector for next POI (prediction)
        [n_batch_size x n_output_embed]
    """

    # ==================== Rich Context Incorporation ========================
    initializer = tf.random_normal_initializer(0, 0.01)

    encoder_input, embeddings_set = get_encoder_input(input_x, data_info,
                                                      n_steps_encoder,
                                                      ones_matrix, initializer)

    decoder_input, w_y_prev = get_decoder_input(input_pre_y, n_steps_decoder,
                                                n_output_class, n_output_embed,
                                                initializer)
    embeddings_set.append(w_y_prev)

    # ====================== Encoder and Decoder ==============================
    # Zero initial states are only used by the hand-unrolled (mean attention)
    # loops below, i.e. when at least one attention mode is switched off.
    if not (frag_mode[0] and frag_mode[1]):

        batch_size = tf.shape(encoder_input[0][0])[0]
        encoder_state_size = tf.stack([batch_size, n_hidden_encoder])
        encoder_initial_state = [
            tf.zeros(encoder_state_size, dtype=tf.float32) for _ in range(2)
        ]
        # The decoder cell has n_hidden_decoder units, so its zero state must
        # have that width; reusing the encoder-sized state only worked when
        # n_hidden_decoder == n_hidden_encoder.
        decoder_state_size = tf.stack([batch_size, n_hidden_decoder])
        decoder_initial_state = [
            tf.zeros(decoder_state_size, dtype=tf.float32) for _ in range(2)
        ]

    # ============================= Encoder ===================================
    with tf.variable_scope('Encoder'):
        encoder_cell = rnn_cell.BasicLSTMCell(n_hidden_encoder,
                                              forget_bias=0.0)

        if frag_mode[0]:
            # Multi-context attention
            print("We will use multi-level context attention for encoder.")
            encoder_attention_states = [
                tf.transpose(inp, [1, 2, 0]) for inp in encoder_input
            ]
            encoder_outputs, encoder_state, context_attns, feature_attns = \
                AttentionEncoder.attention_encoder(encoder_input, encoder_attention_states, encoder_cell)

        else:
            # Not use multi-context attention
            # Mean attention for encoder input
            print("We will use mean attention for encoder.")
            n_feature = len(encoder_input)
            # Float division so the weights are not truncated to 0 under
            # Python 2 integer division.
            alpha_f = 1.0 / n_feature
            x_index = 0
            for feature in data_info:
                if feature["active"]:
                    if feature["format"] != "float":
                        alpha_c = 1.0 / feature["dim"] * alpha_f
                        encoder_input[x_index] = [
                            inp_ * alpha_c for inp_ in encoder_input[x_index]
                        ]
                    # encoder_input holds one entry per *active* feature.
                    x_index += 1
            # One concatenated input tensor per encoder step.
            encoder_input = [
                tf.concat([inp[si] for inp in encoder_input], 1)
                for si in range(n_steps_encoder)
            ]

            # encoder
            encoder_outputs = []
            state_ = encoder_initial_state
            for step_idx, inp_ in enumerate(encoder_input):
                if step_idx > 0:
                    # Every step after the first shares the step-0 weights.
                    tf.get_variable_scope().reuse_variables()
                outp_, state_ = encoder_cell(inp_, state_)
                encoder_outputs.append(outp_)

    # ============================= Decoder ======================================
    with tf.variable_scope('Decoder'):

        decoder_cell = rnn_cell.BasicLSTMCell(n_hidden_decoder,
                                              forget_bias=0.0)

        if frag_mode[1]:

            # Temporal attention
            decoder_attention_states = tf.concat([
                tf.reshape(h_, [-1, 1, encoder_cell.output_size])
                for h_ in encoder_outputs
            ], 1)
            print("We will use temporal attention for decoder.")
            decoder_outputs, decoder_state, temporal_attns = \
                AttentionDecoder.attention_decoder(decoder_input, decoder_attention_states, decoder_cell,
                                                   output_size=n_output_embed)

        else:

            print("We will use mean attention for decoder.")
            # Not use temporal attention
            # Mean attention for decoder input: average encoder outputs over
            # the step axis and append the result to every decoder input.
            H = tf.concat([
                tf.reshape(h_, [-1, n_hidden_encoder, 1])
                for h_ in encoder_outputs
            ], 2)
            h_tilde = tf.reduce_mean(H, 2)

            # decoder
            decoder_outputs = []
            state_ = decoder_initial_state
            for step_idx, inp_ in enumerate(decoder_input):
                if step_idx > 0:
                    tf.get_variable_scope().reuse_variables()
                inp = tf.concat([inp_, h_tilde], 1)
                outp_, state_ = decoder_cell(inp, state_)
                decoder_outputs.append(outp_)

    e_y_pred = decoder_outputs[-1]  # shape: (batch_size, embed_dim)
    # ===========================================================================

    return e_y_pred
示例#14
0
    def __init__(self, encoder_masks, encoder_inputs_tensor,
            decoder_inputs,
            target_weights,
            target_vocab_size,
            buckets,
            target_embedding_size,
            attn_num_layers,
            attn_num_hidden,
            forward_only,
            use_gru):
        """Create the bucketed attention seq2seq model graph.

        Args:
          encoder_masks: per-step mask tensors; the first `seq_length` of
            them are multiplied element-wise into the bidirectional encoder
            outputs.
          encoder_inputs_tensor: encoder inputs passed to model_with_buckets.
          decoder_inputs: list of decoder input tensors; the training targets
            are this list shifted left by one step.
          target_weights: passed to model_with_buckets.
          target_vocab_size: size of the target vocabulary; used as both the
            decoder's symbol count and its output size.
          buckets: a list of pairs (I, O), where I is the maximum input length
            and O the maximum output length handled by that bucket. We assume
            the list is sorted, e.g., [(2, 4), (8, 16)].
          target_embedding_size: embedding size for decoder symbols.
          attn_num_layers: number of stacked decoder cell layers.
          attn_num_hidden: units per decoder cell layer.
          forward_only: passed to the decoder as feed_previous.
          use_gru: if true, use GRU cells instead of LSTM cells in the decoder.
        """
        self.encoder_inputs_tensor = encoder_inputs_tensor
        self.decoder_inputs = decoder_inputs
        self.target_weights = target_weights
        self.target_vocab_size = target_vocab_size
        self.buckets = buckets
        self.encoder_masks = encoder_masks

        # Create the internal multi-layer cell for our RNN.
        def make_single_cell():
            # A fresh cell object on every call, so that each stacked layer
            # owns its own weights.
            if use_gru:
                return rnn_cell_impl.GRUCell(attn_num_hidden)
            return rnn_cell_impl.BasicLSTMCell(attn_num_hidden, forget_bias=0.0, state_is_tuple=False)

        if use_gru:
            print("using GRU CELL in decoder")

        if attn_num_layers > 1:
            # `[cell] * n` would repeat a single cell object; since TF 1.1
            # that makes all layers share parameters (or fail to build).
            cell = rnn_cell_impl.MultiRNNCell(
                [make_single_cell() for _ in range(attn_num_layers)],
                state_is_tuple=False)
        else:
            cell = make_single_cell()

        # The seq2seq function: we use embedding for the input and attention.
        def seq2seq_f(lstm_inputs, decoder_inputs, seq_length, do_decode):
            # Bidirectional LSTM encoder, each direction as wide as the
            # whole decoder stack.
            num_hidden = attn_num_layers * attn_num_hidden
            lstm_fw_cell = rnn_cell_impl.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)
            # Backward direction cell
            lstm_bw_cell = rnn_cell_impl.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)

            pre_encoder_inputs, output_state_fw, output_state_bw = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs,
                initial_state_fw=None, initial_state_bw=None,
                dtype=tf.float32, sequence_length=None, scope=None)

            # Multiply each step's output by its mask (presumably zeroing
            # padded steps -- confirm against how the masks are built).
            encoder_inputs = [e*f for e,f in zip(pre_encoder_inputs,encoder_masks[:seq_length])]
            top_states = [array_ops.reshape(e, [-1, 1, num_hidden*2])
                    for e in encoder_inputs]
            attention_states = array_ops.concat(top_states, 1)
            # NOTE(review): this concatenation is twice one direction's
            # state width; confirm it matches what `cell` expects.
            initial_state = tf.concat(axis=1, values=[output_state_fw, output_state_bw])
            outputs, _, attention_weights_history = embedding_attention_decoder(
                    decoder_inputs, initial_state, attention_states, cell,
                    num_symbols=target_vocab_size,
                    embedding_size=target_embedding_size,
                    num_heads=1,
                    output_size=target_vocab_size,
                    output_projection=None,
                    feed_previous=do_decode,
                    initial_state_attention=False,
                    attn_num_hidden = attn_num_hidden)
            return outputs, attention_weights_history

        # Our targets are decoder inputs shifted by one. (A slice instead of
        # `xrange`, which does not exist on Python 3.)
        targets = list(decoder_inputs[1:])

        softmax_loss_function = None # default to tf.nn.sparse_softmax_cross_entropy_with_logits

        # Training outputs and losses; forward_only selects feed_previous.
        feed_previous = bool(forward_only)
        self.outputs, self.losses, self.attention_weights_histories = model_with_buckets(
                encoder_inputs_tensor, decoder_inputs, targets,
                self.target_weights, buckets,
                lambda x, y, z: seq2seq_f(x, y, z, feed_previous),
                softmax_loss_function=softmax_loss_function)