예제 #1
0
 def testBasicLSTMCellWithStateTuple(self):
     """Smoke-test a two-layer `MultiRNNCell` fed a tuple of per-layer states.

     Both layers are `BasicLSTMCell(2, state_is_tuple=False)`, so each layer's
     state is a single `[1, 4]` tensor; only the wrapping `MultiRNNCell` is
     built with `state_is_tuple=True`.
     """
     half_init = init_ops.constant_initializer(0.5)
     with self.test_session() as sess:
         with variable_scope.variable_scope("root", initializer=half_init):
             x = array_ops.zeros([1, 2])
             m0 = array_ops.zeros([1, 4])
             m1 = array_ops.zeros([1, 4])
             layers = [
                 core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
                 for _ in range(2)
             ]
             cell = core_rnn_cell_impl.MultiRNNCell(layers,
                                                    state_is_tuple=True)
             g, (out_m0, out_m1) = cell(x, (m0, m1))
             sess.run([variables_lib.global_variables_initializer()])
             # Override the zero tensors by name with the actual test values.
             feeds = {
                 x.name: np.array([[1., 1.]]),
                 m0.name: 0.1 * np.ones([1, 4]),
                 m1.name: 0.1 * np.ones([1, 4]),
             }
             res = sess.run([g, out_m0, out_m1], feeds)
             self.assertEqual(len(res), 3)
             output_val, mem0_val, mem1_val = res
             # The numbers in results were not calculated, this is just a
             # smoke test. Note, however, these values should match the
             # original version having state_is_tuple=False.
             self.assertAllClose(output_val, [[0.24024698, 0.24024698]])
             expected_mem0 = np.array(
                 [[0.68967271, 0.68967271, 0.44848421, 0.44848421]])
             expected_mem1 = np.array(
                 [[0.39897051, 0.39897051, 0.24024698, 0.24024698]])
             self.assertAllClose(mem0_val, expected_mem0)
             self.assertAllClose(mem1_val, expected_mem1)
예제 #2
0
def ndlstm_base_unrolled(inputs, noutput, scope=None, reverse=False):
    """Run an LSTM, either forward or backward.

    This is a 1D LSTM implementation using unrolling and the TensorFlow
    LSTM op.

    Args:
      inputs: input sequence (length, batch_size, ninput)
      noutput: depth of output
      scope: optional scope name
      reverse: run LSTM in reverse

    Returns:
      Output sequence (length, batch_size, noutput)
    """
    with variable_scope.variable_scope(scope, "SeqLstmUnrolled", [inputs]):
        _, batch_size, _ = _shape(inputs)
        lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput,
                                                     state_is_tuple=False)
        state = array_ops.zeros([batch_size, lstm_cell.state_size])
        inputs_u = array_ops.unstack(inputs)
        if reverse:
            inputs_u = list(reversed(inputs_u))
        output_u = []
        # Iterate over the unstacked steps directly; this avoids the
        # Python-2-only `xrange` builtin and manual indexing.
        for step, step_input in enumerate(inputs_u):
            if step > 0:
                # Every step after the first shares the first step's weights.
                variable_scope.get_variable_scope().reuse_variables()
            output, state = lstm_cell(step_input, state)
            output_u.append(output)
        if reverse:
            # Restore the original time order of the outputs.
            output_u = list(reversed(output_u))
        return array_ops.stack(output_u)
예제 #3
0
def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False):
    """Run a 1D LSTM over a time-major sequence via `dynamic_rnn`.

    Args:
      inputs: input sequence (length, batch_size, ninput)
      noutput: depth of output
      scope: optional scope name
      reverse: if true, process the sequence back to front

    Returns:
      Output sequence (length, batch_size, noutput), in the original time
      order regardless of `reverse`.
    """
    with variable_scope.variable_scope(scope, "SeqLstm", [inputs]):
        # TODO(tmb) make batch size, sequence_length dynamic
        # example: sequence_length = tf.shape(inputs)[0]
        unused_length, batch_size, unused_depth = _shape(inputs)
        cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
        initial_state = array_ops.zeros([batch_size, cell.state_size])
        num_steps = int(inputs.get_shape()[0])
        # Every batch entry is treated as a full-length sequence.
        steps_per_example = array_ops.fill([batch_size], num_steps)
        steps_per_example = math_ops.to_int64(steps_per_example)
        rnn_inputs = array_ops.reverse_v2(inputs, [0]) if reverse else inputs
        outputs, _ = rnn.dynamic_rnn(cell,
                                     rnn_inputs,
                                     sequence_length=steps_per_example,
                                     initial_state=initial_state,
                                     time_major=True)
        if not reverse:
            return outputs
        # Flip the outputs back so they line up with the caller's time axis.
        return array_ops.reverse_v2(outputs, [0])
예제 #4
0
 def testAttentionCellWrapperZeros(self):
     """Check that zero inputs and a zero state give zero output and state.

     Runs once with a flat (concatenated) state and once with a tuple state,
     each in a fresh graph, verifying the static shapes and that the summed
     absolute value of every returned tensor is below 1e-6.
     """
     num_units = 8
     attn_length = 16
     batch_size = 3
     input_size = 4
     for state_is_tuple in [False, True]:
         with ops.Graph().as_default():
             with self.test_session() as sess:
                 with variable_scope.variable_scope("state_is_tuple_" +
                                                    str(state_is_tuple)):
                     lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
                         num_units, state_is_tuple=state_is_tuple)
                     cell = rnn_cell.AttentionCellWrapper(
                         lstm_cell,
                         attn_length,
                         state_is_tuple=state_is_tuple)
                     if state_is_tuple:
                         zeros = array_ops.zeros([batch_size, num_units],
                                                 dtype=np.float32)
                         attn_state_zeros = array_ops.zeros(
                             [batch_size, attn_length * num_units],
                             dtype=np.float32)
                         zero_state = ((zeros, zeros), zeros,
                                       attn_state_zeros)
                     else:
                         # Flat layout: all state pieces concatenated along
                         # the last axis.
                         zero_state = array_ops.zeros([
                             batch_size, num_units * 2 +
                             attn_length * num_units + num_units
                         ],
                                                      dtype=np.float32)
                     inputs = array_ops.zeros([batch_size, input_size],
                                              dtype=dtypes.float32)
                     output, state = cell(inputs, zero_state)
                     self.assertEqual(output.get_shape(),
                                      [batch_size, num_units])
                     if state_is_tuple:
                         self.assertEqual(len(state), 3)
                         self.assertEqual(len(state[0]), 2)
                         self.assertEqual(state[0][0].get_shape(),
                                          [batch_size, num_units])
                         self.assertEqual(state[0][1].get_shape(),
                                          [batch_size, num_units])
                         self.assertEqual(state[1].get_shape(),
                                          [batch_size, num_units])
                         self.assertEqual(
                             state[2].get_shape(),
                             [batch_size, attn_length * num_units])
                         tensors = [output] + list(state)
                     else:
                         self.assertEqual(state.get_shape(), [
                             batch_size, num_units * 2 + num_units +
                             attn_length * num_units
                         ])
                         tensors = [output, state]
                     # Total absolute mass over every returned tensor.
                     zero_result = sum([
                         math_ops.reduce_sum(math_ops.abs(x))
                         for x in tensors
                     ])
                     sess.run(variables.global_variables_initializer())
                     self.assertLess(sess.run(zero_result), 1e-6)
예제 #5
0
def sequence_to_final(inputs, noutput, scope=None, name=None, reverse=False):
    """Run an LSTM across all steps and return only the last step's output.

    Args:
      inputs: (length, batch_size, depth) tensor
      noutput: size of output vector
      scope: optional scope name
      name: optional name for output tensor
      reverse: run in reverse

    Returns:
      Batch of size (batch_size, noutput): the cell output produced by the
      last processed step.

    Raises:
      ValueError: if `inputs` unstacks into zero time steps.
    """
    with variable_scope.variable_scope(scope, "SequenceToFinal", [inputs]):
        _, batch_size, _ = _shape(inputs)
        lstm = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
        state = array_ops.zeros([batch_size, lstm.state_size])
        inputs_u = array_ops.unstack(inputs)
        if not inputs_u:
            # Without at least one step there is no output to return.
            raise ValueError("sequence_to_final requires at least one step")
        if reverse:
            inputs_u = list(reversed(inputs_u))
        output = None
        # Iterate over the unstacked steps directly; this avoids the
        # Python-2-only `xrange` builtin and manual indexing.
        for step, step_input in enumerate(inputs_u):
            if step > 0:
                # Every step after the first shares the first step's weights.
                variable_scope.get_variable_scope().reuse_variables()
            output, state = lstm(step_input, state)
        return array_ops.reshape(output, [batch_size, noutput], name=name)
예제 #6
0
    def testLSTMBasicToBlockCell(self):
        """Check `LSTMBlockCell` against `BasicLSTMCell` in a two-layer stack.

        Both stacks are built with the same seeded initializer and fed the
        same input and state values; every fetched tensor must agree.
        """
        with self.test_session(use_gpu=self._use_gpu) as sess:
            x = array_ops.zeros([1, 2])
            x_values = np.random.randn(1, 2)

            m0_val = 0.1 * np.ones([1, 2])
            m1_val = -0.1 * np.ones([1, 2])
            m2_val = -0.2 * np.ones([1, 2])
            m3_val = 0.2 * np.ones([1, 2])

            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890212)

            def run_stack(scope_name, make_cell):
                # Build a two-layer stack under `scope_name` and evaluate it
                # on the shared input / state values.
                with variable_scope.variable_scope(scope_name,
                                                   initializer=initializer):
                    m0 = array_ops.zeros([1, 2])
                    m1 = array_ops.zeros([1, 2])
                    m2 = array_ops.zeros([1, 2])
                    m3 = array_ops.zeros([1, 2])
                    stack = core_rnn_cell_impl.MultiRNNCell(
                        [make_cell() for _ in range(2)], state_is_tuple=True)
                    g, ((out_m0, out_m1),
                        (out_m2, out_m3)) = stack(x, ((m0, m1), (m2, m3)))
                    sess.run([variables.global_variables_initializer()])
                    feeds = {
                        x.name: x_values,
                        m0.name: m0_val,
                        m1.name: m1_val,
                        m2.name: m2_val,
                        m3.name: m3_val,
                    }
                    return sess.run([g, out_m0, out_m1, out_m2, out_m3],
                                    feeds)

            basic_res = run_stack(
                "basic",
                lambda: core_rnn_cell_impl.BasicLSTMCell(
                    2, state_is_tuple=True))
            block_res = run_stack("block", lambda: lstm_ops.LSTMBlockCell(2))

            self.assertEqual(len(basic_res), len(block_res))
            for basic, block in zip(basic_res, block_res):
                self.assertAllClose(basic, block)
예제 #7
0
 def testAttentionCellWrapperCorrectResult(self):
   """Compare `AttentionCellWrapper` output and state with recorded values.

   The wrapped `BasicLSTMCell` is run twice, first with a flat state and then
   with a tuple state; the second pass opens the variable scope with
   `reuse=True`. Both passes are checked against the same `expected_output`
   and `expected_state` arrays.
   """
   num_units = 4
   attn_length = 6
   batch_size = 2
   expected_output = np.array(
       [[1.068372, 0.45496, -0.678277, 0.340538],
        [1.018088, 0.378983, -0.572179, 0.268591]],
       dtype=np.float32)
   expected_state = np.array(
       [[0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
         0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
         0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
         0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
         0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
         0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
         0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
         0.51843399],
        [0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
         0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
         0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
         0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
         0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
         0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
         0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
         0.70582712]],
       dtype=np.float32)
   seed = 12345
   # NOTE(review): the expected values above presumably depend on this seed
   # and on the order in which ops are created below — confirm before
   # reordering any statements.
   random_seed.set_random_seed(seed)
   for state_is_tuple in [False, True]:
     with session.Session() as sess:
       # reuse is False on the first pass and True on the second.
       with variable_scope.variable_scope(
           "state_is_tuple", reuse=state_is_tuple,
           initializer=init_ops.glorot_uniform_initializer()):
         lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
             num_units, state_is_tuple=state_is_tuple)
         cell = rnn_cell.AttentionCellWrapper(
             lstm_cell, attn_length, state_is_tuple=state_is_tuple)
         # Despite the `zeros*` names, these are seeded random-uniform
         # tensors created with bounds 0.0 and 1.0.
         zeros1 = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 1)
         zeros2 = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 2)
         zeros3 = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 3)
         attn_state_zeros = random_ops.random_uniform(
             (batch_size, attn_length * num_units), 0.0, 1.0, seed=seed + 4)
         zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
         if not state_is_tuple:
           # Flat-state pass: concatenate the pieces along axis 1.
           zero_state = array_ops.concat([
               zero_state[0][0], zero_state[0][1], zero_state[1], zero_state[2]
           ], 1)
         inputs = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 5)
         output, state = cell(inputs, zero_state)
         if state_is_tuple:
           # Flatten the tuple state so both passes compare against the same
           # expected_state array.
           state = array_ops.concat(
               [state[0][0], state[0][1], state[1], state[2]], 1)
         sess.run(variables.global_variables_initializer())
         self.assertAllClose(sess.run(output), expected_output)
         self.assertAllClose(sess.run(state), expected_state)
예제 #8
0
 def testBasicLSTMCell(self):
   """Smoke-test `BasicLSTMCell` with flat (non-tuple) state.

   First checks a two-layer `MultiRNNCell`: the names of the four variables
   it creates and its recorded output/state values. Then runs a single cell
   whose input width (3) differs from its number of units (2).
   """
   with self.test_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       x = array_ops.zeros([1, 2])
       m = array_ops.zeros([1, 8])
       g, out_m = core_rnn_cell_impl.MultiRNNCell(
           [core_rnn_cell_impl.BasicLSTMCell(
               2, state_is_tuple=False) for _ in range(2)],
           state_is_tuple=False)(x, m)
       sess.run([variables_lib.global_variables_initializer()])
       res = sess.run(
           [g, out_m],
           {x.name: np.array([[1., 1.]]),
            m.name: 0.1 * np.ones([1, 8])})
       self.assertEqual(len(res), 2)
       created_vars = variables_lib.global_variables()
       expected_names = [
           "root/multi_rnn_cell/cell_0/basic_lstm_cell/weights",
           "root/multi_rnn_cell/cell_0/basic_lstm_cell/biases",
           "root/multi_rnn_cell/cell_1/basic_lstm_cell/weights",
           "root/multi_rnn_cell/cell_1/basic_lstm_cell/biases",
       ]
       self.assertEqual(4, len(created_vars))
       for var, expected_name in zip(created_vars, expected_names):
         self.assertEqual(var.op.name, expected_name)
       # The numbers in results were not calculated, this is just a smoke test.
       self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
       expected_mem = np.array([[
           0.68967271, 0.68967271, 0.44848421, 0.44848421, 0.39897051,
           0.39897051, 0.24024698, 0.24024698
       ]])
       self.assertAllClose(res[1], expected_mem)
     with variable_scope.variable_scope(
         "other", initializer=init_ops.constant_initializer(0.5)):
       x = array_ops.zeros(
           [1, 3])  # Test BasicLSTMCell with input_size != num_units.
       m = array_ops.zeros([1, 4])
       g, out_m = core_rnn_cell_impl.BasicLSTMCell(
           2, state_is_tuple=False)(x, m)
       sess.run([variables_lib.global_variables_initializer()])
       res = sess.run(
           [g, out_m],
           {x.name: np.array([[1., 1., 1.]]),
            m.name: 0.1 * np.ones([1, 4])})
       self.assertEqual(len(res), 2)
예제 #9
0
    def train(self,
              train_x,
              train_y,
              learning_rate=0.05,
              limit=1000,
              batch_n=1,
              seq_len=3,
              input_n=2,
              hidden_n=5,
              output_n=2):
        """Build the LSTM graph and run `limit` Adam training steps.

        The graph pieces are stored on `self` (`input_layer`, `label_layer`,
        `weights`, `biases`, `lstm_cell`, `prediction`, `loss`, `trainer`)
        and executed with `self.session`.

        Args:
          train_x: object with a `reshape` method (presumably a NumPy array)
            that can be reshaped to (batch_n, seq_len, input_n).
          train_y: object with a `reshape` method that can be reshaped to
            (seq_len, output_n).
          learning_rate: learning rate passed to the Adam optimizer.
          limit: number of training steps to run.
          batch_n: number of input placeholders handed to `static_rnn`.
          seq_len: first dimension of each input placeholder and the labels.
          input_n: second dimension of each input placeholder.
          hidden_n: number of units in the LSTM cell.
          output_n: second dimension of the labels and of the prediction.
        """
        # Input layer: one placeholder per entry of the reshaped train_x.
        self.input_layer = [
            tf.placeholder("float", [seq_len, input_n]) for _ in range(batch_n)
        ]

        # Labels.
        self.label_layer = tf.placeholder("float", [seq_len, output_n])

        self.weights = tf.Variable(tf.random_normal([hidden_n, output_n]))
        self.biases = tf.Variable(tf.random_normal([output_n]))

        # Set the number of units and the forget_bias parameter.
        self.lstm_cell = core_rnn_cell_impl.BasicLSTMCell(hidden_n,
                                                          forget_bias=1.0)

        outputs, states = tf.contrib.rnn.static_rnn(self.lstm_cell,
                                                    self.input_layer,
                                                    dtype=tf.float32)

        self.prediction = tf.matmul(outputs[-1], self.weights) + self.biases

        # Loss: the logits must be the in-graph prediction tensor so that
        # gradients reach the model variables, and the labels must be the
        # label placeholder. The previous code passed them the other way
        # round and used self.predict(...) instead of self.prediction.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=self.label_layer, logits=self.prediction))

        # Train with the Adam optimizer.
        self.trainer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.loss)

        initer = tf.global_variables_initializer()

        train_x = train_x.reshape((batch_n, seq_len, input_n))
        train_y = train_y.reshape((seq_len, output_n))

        # Feed every input placeholder, not just the first, so batch_n > 1
        # does not leave placeholders without a value.
        feed_dict = {
            placeholder: train_x[index]
            for index, placeholder in enumerate(self.input_layer)
        }
        feed_dict[self.label_layer] = train_y

        # Run the graph.
        self.session.run(initer)
        for _ in range(limit):
            self.session.run(self.trainer, feed_dict=feed_dict)
예제 #10
0
 def testAttentionCellWrapperValues(self):
     """Check that distinct input rows give distinct outputs and states.

     Runs once with a flat state and once with a tuple state, each in a fresh
     graph, starting from a state filled with 0.1.
     """
     num_units = 8
     attn_length = 16
     batch_size = 3

     def tenths(width):
         # Constant float32 tensor of shape [batch_size, width] filled
         # with 0.1.
         return constant_op.constant(
             0.1 * np.ones([batch_size, width], dtype=np.float32),
             dtype=dtypes.float32)

     for state_is_tuple in [False, True]:
         with ops.Graph().as_default():
             with self.test_session() as sess:
                 scope_name = "state_is_tuple_" + str(state_is_tuple)
                 with variable_scope.variable_scope(scope_name):
                     lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
                         num_units, state_is_tuple=state_is_tuple)
                     cell = rnn_cell.AttentionCellWrapper(
                         lstm_cell,
                         attn_length,
                         state_is_tuple=state_is_tuple)
                     if state_is_tuple:
                         unit_state = tenths(num_units)
                         attn_states = tenths(attn_length * num_units)
                         initial_state = ((unit_state, unit_state),
                                          unit_state, attn_states)
                     else:
                         initial_state = tenths(num_units * 2 + num_units +
                                                attn_length * num_units)
                     input_rows = np.array(
                         [[1., 1., 1., 1.], [2., 2., 2., 2.],
                          [3., 3., 3., 3.]],
                         dtype=np.float32)
                     inputs = constant_op.constant(input_rows,
                                                   dtype=dtypes.float32)
                     output, state = cell(inputs, initial_state)
                     if state_is_tuple:
                         concat_state = array_ops.concat(
                             [state[0][0], state[0][1], state[1], state[2]],
                             1)
                     else:
                         concat_state = state
                     sess.run(variables.global_variables_initializer())
                     output, state = sess.run([output, concat_state])
                     # Different inputs so different outputs and states
                     for row in range(1, batch_size):
                         for values in (output, state):
                             gap = np.linalg.norm(values[0, :] -
                                                  values[row, :])
                             self.assertTrue(float(gap) > 1e-6)
  def testMultiRNNState(self):
    """Test that state flattening/reconstruction works for `MultiRNNCell`."""
    batch_size = 11
    sequence_length = 16
    train_steps = 5
    cell_sizes = [4, 8, 7]
    learning_rate = 0.1
    # Two flattened state pieces per layer, each of that layer's size
    # (presumably the c and h of every LSTM layer). Derived from cell_sizes
    # so the two stay in sync.
    flat_state_sizes = [size for size in cell_sizes for _ in range(2)]

    def get_shift_input_fn(batch_size, sequence_length, seed=None):
      """Return an input_fn yielding shifted random bit sequences."""

      def input_fn():
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        labels = array_ops.slice(random_sequence, [0, 0],
                                 [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.to_float(
                array_ops.slice(random_sequence, [0, 1],
                                [batch_size, sequence_length])), 2)
        # Random initial value for every flattened state piece. With
        # seed=None the per-piece seed stays None instead of raising on
        # `int * None`.
        input_dict = {
            dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
                [batch_size, cell_size],
                seed=None if seed is None else (i + 1) * seed)
            for i, cell_size in enumerate(flat_state_sizes)
        }
        input_dict['inputs'] = inputs
        return input_dict, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=21212)
    cell = core_rnn_cell_impl.MultiRNNCell(
        [core_rnn_cell_impl.BasicLSTMCell(size) for size in cell_sizes])
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        sequence_feature_columns=seq_columns,
        cell_type=cell,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    # Every flattened state piece must come back with its original shape.
    for i, state_size in enumerate(flat_state_sizes):
      state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
      self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
예제 #12
0
 def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
   """Build `embedding_tied_rnn_seq2seq` on a 2-unit non-tuple-state LSTM."""
   lstm = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
   model = seq2seq_lib.embedding_tied_rnn_seq2seq(
       enc_inp, dec_inp, lstm, num_decoder_symbols,
       embedding_size=2, feed_previous=feed_previous)
   return model
예제 #13
0
 def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
   """Build `embedding_attention_seq2seq` on a 2-unit tuple-state LSTM."""
   lstm = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
   model = seq2seq_lib.embedding_attention_seq2seq(
       enc_inp, dec_inp, lstm, num_encoder_symbols, num_decoder_symbols,
       embedding_size=2, feed_previous=feed_previous)
   return model
 def testEmbeddingWrapperWithDynamicRnn(self):
     """Run an int64-fed `EmbeddingWrapper` cell through `dynamic_rnn`."""
     with self.test_session() as sess, variable_scope.variable_scope("root"):
         inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
         input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
         lstm = core_rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True)
         embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
             lstm, embedding_classes=1, embedding_size=2)
         outputs, _ = rnn.dynamic_rnn(cell=embedding_cell,
                                      inputs=inputs,
                                      sequence_length=input_lengths,
                                      dtype=dtypes.float32)
         sess.run([variables_lib.global_variables_initializer()])
         # This will fail if output's dtype is inferred from input's.
         sess.run(outputs)
예제 #15
0
 def testBasicLSTMCellStateSizeError(self):
     """Tests that state_size must be num_units * 2."""
     num_units = 2
     state_size = num_units * 3  # state_size must be num_units * 2
     batch_size = 3
     input_size = 4
     with self.test_session() as sess, variable_scope.variable_scope(
             "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([batch_size, input_size])
         m = array_ops.zeros([batch_size, state_size])
         # Everything from cell construction to the final run stays inside
         # the assertRaises block, so a ValueError at any stage satisfies
         # the test.
         with self.assertRaises(ValueError):
             lstm = core_rnn_cell_impl.BasicLSTMCell(num_units,
                                                     state_is_tuple=False)
             g, out_m = lstm(x, m)
             sess.run([variables_lib.global_variables_initializer()])
             feeds = {
                 x.name: 1 * np.ones([batch_size, input_size]),
                 m.name: 0.1 * np.ones([batch_size, state_size]),
             }
             sess.run([g, out_m], feeds)
    def testBasicLSTMCellStateTupleType(self):
        """Check that a tuple-state LSTM stack keeps `LSTMStateTuple` types.

        Covers `state_size`, `zero_state`, and the states returned by
        calling the cell with both plain tuples and `LSTMStateTuple`s.
        """
        state_tuple_type = core_rnn_cell_impl.LSTMStateTuple
        with self.test_session():
            with variable_scope.variable_scope(
                    "root", initializer=init_ops.constant_initializer(0.5)):
                x = array_ops.zeros([1, 2])
                m0 = (array_ops.zeros([1, 2]), ) * 2
                m1 = (array_ops.zeros([1, 2]), ) * 2
                cell = core_rnn_cell_impl.MultiRNNCell(
                    [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
                    state_is_tuple=True)
                # assertIsInstance reports the actual type on failure,
                # unlike assertTrue(isinstance(...)).
                self.assertIsInstance(cell.state_size, tuple)
                self.assertIsInstance(cell.state_size[0], state_tuple_type)
                self.assertIsInstance(cell.state_size[1], state_tuple_type)

                # Pass in regular tuples
                _, (out_m0, out_m1) = cell(x, (m0, m1))
                self.assertIsInstance(out_m0, state_tuple_type)
                self.assertIsInstance(out_m1, state_tuple_type)

                # Pass in LSTMStateTuples
                variable_scope.get_variable_scope().reuse_variables()
                zero_state = cell.zero_state(1, dtypes.float32)
                self.assertIsInstance(zero_state, tuple)
                self.assertIsInstance(zero_state[0], state_tuple_type)
                self.assertIsInstance(zero_state[1], state_tuple_type)
                _, (out_m0, out_m1) = cell(x, zero_state)
                self.assertIsInstance(out_m0, state_tuple_type)
                self.assertIsInstance(out_m1, state_tuple_type)
 def testAttentionCellWrapperCorrectResult(self):
     """Compare `AttentionCellWrapper` output and state with recorded values.

     The wrapped `BasicLSTMCell` is run twice, first with a flat state and
     then with a tuple state; the second pass opens the variable scope with
     `reuse=True`. Both passes are checked against the same
     `expected_output` and `expected_state` arrays.
     """
     num_units = 4
     attn_length = 6
     batch_size = 2
     expected_output = np.array([[0.955392, 0.408507, -0.60122, 0.270718],
                                 [0.903681, 0.331165, -0.500238, 0.224052]],
                                dtype=np.float32)
     expected_state = np.array(
         [[
             0.81331915, 0.32036272, 0.28079176, 1.08888793, 0.41264394,
             0.1062041, 0.10444493, 0.32050529, 0.64655536, 0.70794445,
             0.51896095, 0.31809306, 0.58086717, 0.49446869, 0.7641536,
             0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
             0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
             0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
             0.99211812, 0.12295902, 1.01412082, 0.33123279, -0.71114945,
             0.40583119
         ],
          [
              0.59962207, 0.42597458, -0.22491696, 0.98063421, 0.32548007,
              0.11623692, -0.10100613, 0.27708149, 0.76956916, 0.6360054,
              0.51719815, 0.50458527, 0.73000264, 0.66986895, 0.73576689,
              0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
              0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
              0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
              0.36127412, 0.12125921, 0.99780077, 0.31886846, -0.67595094,
              0.56531656
          ]],
         dtype=np.float32)
     seed = 12345
     # NOTE(review): the expected values above presumably depend on this seed
     # and on the order in which ops are created below — confirm before
     # reordering any statements.
     random_seed.set_random_seed(seed)
     for state_is_tuple in [False, True]:
         with session.Session() as sess:
             # reuse is False on the first pass and True on the second.
             with variable_scope.variable_scope("state_is_tuple",
                                                reuse=state_is_tuple):
                 lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
                     num_units, state_is_tuple=state_is_tuple)
                 cell = rnn_cell.AttentionCellWrapper(
                     lstm_cell, attn_length, state_is_tuple=state_is_tuple)
                 # Despite the `zeros*` names, these are seeded
                 # random-uniform tensors created with bounds 0.0 and 1.0.
                 zeros1 = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 1)
                 zeros2 = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 2)
                 zeros3 = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 3)
                 attn_state_zeros = random_ops.random_uniform(
                     (batch_size, attn_length * num_units),
                     0.0,
                     1.0,
                     seed=seed + 4)
                 zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
                 if not state_is_tuple:
                     # Flat-state pass: concatenate the pieces along axis 1.
                     zero_state = array_ops.concat_v2([
                         zero_state[0][0], zero_state[0][1], zero_state[1],
                         zero_state[2]
                     ], 1)
                 inputs = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 5)
                 output, state = cell(inputs, zero_state)
                 if state_is_tuple:
                     # Flatten the tuple state so both passes compare against
                     # the same expected_state array.
                     state = array_ops.concat_v2(
                         [state[0][0], state[0][1], state[1], state[2]], 1)
                 sess.run(variables.global_variables_initializer())
                 self.assertAllClose(sess.run(output), expected_output)
                 self.assertAllClose(sess.run(state), expected_state)
예제 #18
0
  def testLSTMFusedSequenceLengths(self):
    """Check LSTMBlockFusedCell against BasicLSTMCell with sequence lengths.

    The same random inputs and per-example sequence lengths are run through a
    BasicLSTMCell (via static_rnn), a single LSTMBlockFusedCell call, and an
    LSTMBlockFusedCell applied one time step at a time. Outputs, the first
    state component, input gradients and weight gradients of both fused
    variants must agree with the BasicLSTMCell results.
    """
    with self.test_session(use_gpu=self._use_gpu) as sess:
      num_rows, in_dim, num_units, num_steps = 3, 4, 5, 6

      step_inputs = [
          ops.convert_to_tensor(
              np.random.randn(num_rows, in_dim), dtype=dtypes.float32)
          for _ in range(num_steps)
      ]
      lengths_t = constant_op.constant([3, 4, 5])

      init = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890213)

      def trainables_under(prefix):
        # Trainable variables whose name starts with the given scope prefix.
        return [
            var for var in variables.trainable_variables()
            if var.name.startswith(prefix)
        ]

      # Reference run: BasicLSTMCell unrolled by static_rnn.
      with variable_scope.variable_scope("basic", initializer=init):
        ref_cell = core_rnn_cell_impl.BasicLSTMCell(
            num_units, state_is_tuple=True)
        ref_out_t, ref_state_t = core_rnn.static_rnn(
            ref_cell,
            step_inputs,
            dtype=dtypes.float32,
            sequence_length=lengths_t)
        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_state = sess.run([ref_out_t, ref_state_t[0]])
        basic_grads = sess.run(
            gradients_impl.gradients(ref_out_t, step_inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(ref_out_t,
                                     variables.trainable_variables()))

      def assert_matches_basic(outs, final_state, in_grads, weight_grads):
        # Weight gradients are compared with a looser tolerance.
        self.assertAllClose(basic_outputs, outs)
        self.assertAllClose(basic_state, final_state)
        self.assertAllClose(basic_grads, in_grads)
        for expected, actual in zip(basic_wgrads, weight_grads):
          self.assertAllClose(expected, actual, rtol=1e-2, atol=1e-2)

      # Whole sequence through the fused cell in one call.
      with variable_scope.variable_scope("fused", initializer=init):
        fused_cell = lstm_ops.LSTMBlockFusedCell(
            num_units, cell_clip=0, use_peephole=False)
        fused_out_t, fused_state_t = fused_cell(
            step_inputs, dtype=dtypes.float32, sequence_length=lengths_t)

        sess.run([variables.global_variables_initializer()])
        fused_outputs, fused_state = sess.run(
            [fused_out_t, fused_state_t[0]])
        fused_grads = sess.run(
            gradients_impl.gradients(fused_out_t, step_inputs))
        fused_wgrads = sess.run(
            gradients_impl.gradients(fused_out_t, trainables_under("fused/")))

      assert_matches_basic(fused_outputs, fused_state, fused_grads,
                           fused_wgrads)

      # Verify that state propagation works if we turn our sequence into
      # tiny (single-time) subsequences, i.e. unfuse the cell
      with variable_scope.variable_scope(
          "unfused", initializer=init) as unfused_scope:
        stepwise_cell = lstm_ops.LSTMBlockFusedCell(
            num_units, cell_clip=0, use_peephole=False)
        per_step_outputs = []
        carried_state = None
        for step, step_input in enumerate(step_inputs):
          # An example is "live" at this step (length 1) while the step index
          # is below its sequence length, otherwise its length is 0.
          step_lengths = [
              1 if step < limit else 0 for limit in lengths_t.eval()
          ]
          step_out, carried_state = stepwise_cell(
              [step_input],
              initial_state=carried_state,
              dtype=dtypes.float32,
              sequence_length=step_lengths)
          unfused_scope.reuse_variables()
          per_step_outputs.append(step_out[0])
        stacked_out_t = array_ops.stack(per_step_outputs)

        sess.run([variables.global_variables_initializer()])
        unfused_outputs, unfused_state = sess.run(
            [stacked_out_t, carried_state[0]])
        unfused_grads = sess.run(
            gradients_impl.gradients(stacked_out_t, step_inputs))
        unfused_wgrads = sess.run(
            gradients_impl.gradients(stacked_out_t,
                                     trainables_under("unfused/")))

      assert_matches_basic(unfused_outputs, unfused_state, unfused_grads,
                           unfused_wgrads)
예제 #19
0
  def testLSTMBasicToBlock(self):
    """Compare BasicLSTMCell with the block_lstm op and LSTMBlockFusedCell.

    All three are fed the same random input sequence and use the same seeded
    initializer. Outputs, input gradients and weight gradients must agree
    with the BasicLSTMCell results; for the fused cell the first state
    component is compared as well.
    """
    with self.test_session(use_gpu=self._use_gpu) as sess:
      num_rows, in_dim, num_units, num_steps = 2, 3, 4, 5

      step_inputs = [
          ops.convert_to_tensor(
              np.random.randn(num_rows, in_dim), dtype=dtypes.float32)
          for _ in range(num_steps)
      ]

      init = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212)

      # Reference run: BasicLSTMCell unrolled by static_rnn.
      with variable_scope.variable_scope("basic", initializer=init):
        ref_cell = core_rnn_cell_impl.BasicLSTMCell(
            num_units, state_is_tuple=True)
        ref_out_t, ref_state_t = core_rnn.static_rnn(
            ref_cell, step_inputs, dtype=dtypes.float32)

        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_state = sess.run([ref_out_t, ref_state_t[0]])
        basic_grads = sess.run(
            gradients_impl.gradients(ref_out_t, step_inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(ref_out_t,
                                     variables.trainable_variables()))

      def assert_wgrads_match_basic(weight_grads):
        # Weight gradients are compared with a looser tolerance.
        for expected, actual in zip(basic_wgrads, weight_grads):
          self.assertAllClose(expected, actual, rtol=1e-2, atol=1e-2)

      # Raw block_lstm op with explicitly created weight and bias variables.
      with variable_scope.variable_scope("block", initializer=init):
        weights = variable_scope.get_variable(
            "w",
            shape=[in_dim + num_units, num_units * 4],
            dtype=dtypes.float32)
        biases = variable_scope.get_variable(
            "b",
            shape=[num_units * 4],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer())

        seq_len_t = ops.convert_to_tensor(num_steps, dtype=dtypes.int64)
        # Only the last of the seven results is used below.
        (_, _, _, _, _, _, block_out_t) = block_lstm(
            seq_len_t, step_inputs, weights, biases, cell_clip=0)

        sess.run([variables.global_variables_initializer()])
        block_outputs = sess.run(block_out_t)
        block_grads = sess.run(
            gradients_impl.gradients(block_out_t, step_inputs))
        block_wgrads = sess.run(
            gradients_impl.gradients(block_out_t, [weights, biases]))

      self.assertAllClose(basic_outputs, block_outputs)
      self.assertAllClose(basic_grads, block_grads)
      assert_wgrads_match_basic(block_wgrads)

      # Fused cell over the whole sequence.
      with variable_scope.variable_scope("fused", initializer=init):
        fused_cell = lstm_ops.LSTMBlockFusedCell(
            num_units, cell_clip=0, use_peephole=False)
        fused_out_t, fused_state_t = fused_cell(
            step_inputs, dtype=dtypes.float32)

        sess.run([variables.global_variables_initializer()])
        fused_outputs, fused_state = sess.run(
            [fused_out_t, fused_state_t[0]])
        fused_grads = sess.run(
            gradients_impl.gradients(fused_out_t, step_inputs))
        fused_params = [
            var for var in variables.trainable_variables()
            if var.name.startswith("fused/")
        ]
        fused_wgrads = sess.run(
            gradients_impl.gradients(fused_out_t, fused_params))

      self.assertAllClose(basic_outputs, fused_outputs)
      self.assertAllClose(basic_state, fused_state)
      self.assertAllClose(basic_grads, fused_grads)
      assert_wgrads_match_basic(fused_wgrads)
예제 #20
0
    def __init__(self,
                 source_vocab_size,
                 target_vocab_size,
                 buckets,
                 size,
                 num_layers,
                 max_gradient_norm,
                 batch_size,
                 learning_rate,
                 learning_rate_decay_factor,
                 use_lstm=False,
                 num_samples=512,
                 forward_only=False,
                 config=None,
                 corrective_tokens_mask=None):
        """Create the model.

        Args:
          source_vocab_size: size of the source vocabulary.
          target_vocab_size: size of the target vocabulary.
          buckets: a list of pairs (I, O), where I specifies maximum input
            length that will be processed in that bucket, and O specifies
            maximum output length. Training instances that have longer than I
            or outputs longer than O will be pushed to the next bucket and
            padded accordingly. We assume that the list is sorted, e.g., [(2,
            4), (8, 16)].
          size: number of units in each layer of the model.
          num_layers: number of layers in the model.
          max_gradient_norm: gradients will be clipped to maximally this norm.
          batch_size: the size of the batches used during training;
            the model construction is independent of batch_size, so it can be
            changed after initialization if this is convenient, e.g.,
            for decoding.
          learning_rate: learning rate to start with.
          learning_rate_decay_factor: decay learning rate by this much when
            needed.
          use_lstm: if true, we use LSTM cells instead of GRU cells.
          num_samples: number of samples for sampled softmax.
          forward_only: if set, we do not construct the backward pass in the
            model.
          config: optional configuration object, stored as `self.config`.
            Only its `use_rms_prop` attribute is read here (when building the
            backward pass); if `config` is None or has no such attribute,
            plain gradient descent with `learning_rate` is used.
          corrective_tokens_mask: optional sequence (list or array) with one
            entry per target token. It is only turned into a constant tensor
            here; the mask consumed by the decoder is the one fed through the
            `corrective_tokens` placeholder.
        """
        self.source_vocab_size = source_vocab_size
        self.target_vocab_size = target_vocab_size
        self.buckets = buckets
        self.batch_size = batch_size
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)
        self.config = config

        # Feeds for inputs.
        self.encoder_inputs = []
        self.decoder_inputs = []
        self.target_weights = []
        for i in range(buckets[-1][0]):  # Last bucket is the biggest one.
            self.encoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="encoder{0}".format(i)))
        for i in range(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        # One hot encoding of corrective tokens.
        # Test for None / emptiness explicitly: plain truthiness raises
        # ValueError when the mask is a NumPy array with several elements.
        has_mask = (corrective_tokens_mask is not None
                    and len(corrective_tokens_mask) > 0)
        corrective_tokens_tensor = tf.constant(
            corrective_tokens_mask
            if has_mask else np.zeros(self.target_vocab_size),
            shape=[self.target_vocab_size],
            dtype=tf.float32)
        # NOTE(review): this stacked constant is not used anywhere else in
        # the constructor; it is kept so the constructed graph is unchanged.
        batched_corrective_tokens = tf.stack([corrective_tokens_tensor] *
                                             self.batch_size)
        self.batch_corrective_tokens_mask = batch_corrective_tokens_mask = \
            tf.placeholder(
            tf.float32,
            shape=[None, None],
            name="corrective_tokens")

        # Our targets are decoder inputs shifted by one.
        targets = self.decoder_inputs[1:]

        # If we use sampled softmax, we need an output projection.
        output_projection = None
        softmax_loss_function = None
        # Sampled softmax only makes sense if we sample less than vocabulary
        # size.
        if 0 < num_samples < self.target_vocab_size:
            # Distinct local names so that the bucket loops further down can
            # never rebind what the `sampled_loss` closure refers to.
            proj_w = tf.get_variable("proj_w", [size, self.target_vocab_size])
            proj_w_t = tf.transpose(proj_w)
            proj_b = tf.get_variable("proj_b", [self.target_vocab_size])

            output_projection = (proj_w, proj_b)

            def sampled_loss(labels, inputs):
                labels = tf.reshape(labels, [-1, 1])
                return tf.nn.sampled_softmax_loss(proj_w_t, proj_b, labels,
                                                  inputs, num_samples,
                                                  self.target_vocab_size)

            softmax_loss_function = sampled_loss

        # Create the internal multi-layer cell for our RNN.
        if use_lstm:
            single_cell = core_rnn_cell_impl.BasicLSTMCell(size)
        else:
            single_cell = core_rnn_cell_impl.GRUCell(size)
        cell = single_cell
        if num_layers > 1:
            cell = core_rnn_cell_impl.MultiRNNCell([single_cell] * num_layers)

        # The seq2seq function: we use embedding for the input and attention.
        def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
            """Build the embedding-attention seq2seq graph for one bucket.

            :param encoder_inputs: list of length equal to the input bucket
            length of 1-D tensors (of length equal to the batch size) whose
            elements consist of the token index of each sample in the batch
            at a given index in the input.
            :param decoder_inputs: decoder input tensors for the bucket.
            :param do_decode: passed on as `feed_previous`; when true a loop
            function built from the input bias is injected as well.
            :return: whatever `seq2seq.embedding_attention_seq2seq` returns.
            """
            extra_kwargs = {}
            if do_decode:
                # Modify bias here to bias the model towards selecting words
                # present in the input sentence.
                input_bias = self.build_input_bias(
                    encoder_inputs, batch_corrective_tokens_mask)
                # The redefined seq2seq accepts a factory for the decoding
                # loop function, which is how the input bias is injected.
                extra_kwargs["loop_fn_factory"] = (
                    apply_input_bias_and_extract_argmax_fn_factory(input_bias))

            return seq2seq.embedding_attention_seq2seq(
                encoder_inputs,
                decoder_inputs,
                cell,
                num_encoder_symbols=source_vocab_size,
                num_decoder_symbols=target_vocab_size,
                embedding_size=size,
                output_projection=output_projection,
                feed_previous=do_decode,
                **extra_kwargs)

        # Training outputs and losses. Decoding mode is used exactly when the
        # backward pass is skipped.
        do_decode = bool(forward_only)
        self.outputs, self.losses = seq2seq.model_with_buckets(
            self.encoder_inputs,
            self.decoder_inputs,
            targets,
            self.target_weights,
            buckets,
            lambda x, y: seq2seq_f(x, y, do_decode),
            softmax_loss_function=softmax_loss_function)

        if forward_only and output_projection is not None:
            for bucket_id, bucket in enumerate(buckets):
                # We need to apply the same input bias used during model
                # evaluation when decoding.
                input_bias = self.build_input_bias(
                    self.encoder_inputs[:bucket[0]],
                    batch_corrective_tokens_mask)
                self.outputs[bucket_id] = [
                    project_and_apply_input_bias(output, output_projection,
                                                 input_bias)
                    for output in self.outputs[bucket_id]
                ]

        # Gradients and SGD update operation for training the model.
        params = tf.trainable_variables()
        if not forward_only:
            self.gradient_norms = []
            self.updates = []
            # `config` defaults to None, so read the flag defensively instead
            # of failing with AttributeError.
            if getattr(self.config, "use_rms_prop", False):
                opt = tf.train.RMSPropOptimizer(0.001)
            else:
                opt = tf.train.GradientDescentOptimizer(self.learning_rate)

            for bucket_id in range(len(buckets)):
                gradients = tf.gradients(self.losses[bucket_id], params)
                clipped_gradients, norm = tf.clip_by_global_norm(
                    gradients, max_gradient_norm)
                self.gradient_norms.append(norm)
                self.updates.append(
                    opt.apply_gradients(zip(clipped_gradients, params),
                                        global_step=self.global_step))

        self.saver = tf.train.Saver(tf.global_variables())
예제 #21
0
    def testMultipleRuns(self):
        """Tests resuming training by feeding state.

        A classifier is trained on alternating 0/1 sequences, then predictions
        are made once over the whole horizon and once in several increments
        where the state returned by each `predict` call is fed into the next.
        The trailing predictions and the final states of both runs must be
        identical.
        """
        cell_sizes = [4, 7]
        batch_size = 11
        learning_rate = 0.1
        train_sequence_length = 21
        train_steps = 121
        prediction_steps = [3, 2, 5, 11, 6]

        def get_input_fn(batch_size,
                         sequence_length,
                         state_dict,
                         starting_step=0):
            """Return an input_fn yielding (features, labels).

            The features are the entries of `state_dict` plus an 'inputs'
            tensor; `state_dict` itself is left unmodified.
            """

            def input_fn():
                sequence = constant_op.constant(
                    [[(starting_step + i + j) % 2
                      for j in range(sequence_length + 1)]
                     for i in range(batch_size)],
                    dtype=dtypes.int32)
                labels = array_ops.slice(sequence, [0, 0],
                                         [batch_size, sequence_length])
                inputs = array_ops.expand_dims(
                    math_ops.to_float(
                        array_ops.slice(sequence, [0, 1],
                                        [batch_size, sequence_length])), 2)
                # Work on a copy so the caller's dict is not mutated as a
                # side effect of building the features.
                input_dict = dict(state_dict)
                input_dict['inputs'] = inputs
                return input_dict, labels

            return input_fn

        seq_columns = [
            feature_column.real_valued_column('inputs', dimension=1)
        ]
        config = run_config.RunConfig(tf_random_seed=21212)
        cell = core_rnn_cell_impl.MultiRNNCell(
            [core_rnn_cell_impl.BasicLSTMCell(size) for size in cell_sizes])

        model_dir = tempfile.mkdtemp()
        sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
            num_classes=2,
            num_units=None,
            sequence_feature_columns=seq_columns,
            cell_type=cell,
            learning_rate=learning_rate,
            config=config,
            model_dir=model_dir)

        train_input_fn = get_input_fn(batch_size,
                                      train_sequence_length,
                                      state_dict={})

        sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

        def incremental_predict(estimator, increments):
            """Run `estimator.predict` for `i` steps for `i` in `increments`.

            State entries of each prediction (keys starting with the RNN
            state prefix) are fed into the next call. Returns the prediction
            dict of the last increment.
            """
            step = 0
            incremental_state_dict = {}
            for increment in increments:
                input_fn = get_input_fn(batch_size,
                                        increment,
                                        state_dict=incremental_state_dict,
                                        starting_step=step)
                prediction_dict = estimator.predict(input_fn=input_fn,
                                                    as_iterable=False)
                step += increment
                incremental_state_dict = {
                    k: v
                    for (k, v) in prediction_dict.items()
                    if k.startswith(dynamic_rnn_estimator.RNNKeys.STATE_PREFIX)
                }
            return prediction_dict

        pred_all_at_once = incremental_predict(sequence_estimator,
                                               [sum(prediction_steps)])
        pred_step_by_step = incremental_predict(sequence_estimator,
                                                prediction_steps)

        # Check that the last `prediction_steps[-1]` steps give the same
        # predictions.
        np.testing.assert_array_equal(
            pred_all_at_once['predictions'][:, -1 * prediction_steps[-1]:],
            pred_step_by_step['predictions'],
            err_msg='Mismatch on last {} predictions.'.format(
                prediction_steps[-1]))
        # Check that final states are identical.
        for k, v in pred_all_at_once.items():
            if k.startswith(dynamic_rnn_estimator.RNNKeys.STATE_PREFIX):
                np.testing.assert_array_equal(
                    v,
                    pred_step_by_step[k],
                    err_msg='Mismatch on state {}.'.format(k))
    def __init__(self,
                 batch_size=100,
                 vocab_size=5620,
                 word_dim=100,
                 lstm_dim=100,
                 num_classes=4,
                 num_corpus=8,
                 l2_reg_lambda=0.0,
                 adv_weight=0.05,
                 lr=0.001,
                 clip=5,
                 init_embedding=None,
                 gates=False,
                 adv=False,
                 reuseshare=False,
                 sep=False):
        """Build the multi-corpus BiLSTM + CRF tagging graph.

        A shared bidirectional LSTM ("shared" scope) feeds one private
        bidirectional LSTM + CRF layer per corpus ("task1" .. "taskN"
        scopes). Optionally a corpus classifier ("domain" scope) on top of
        the shared layer provides adversarial losses.

        Args:
          batch_size: stored on the instance; not used to build the graph.
          vocab_size: number of rows of the default (all-zero) embedding.
          word_dim: embedding width; each position carries 9 embeddings.
          lstm_dim: number of units of every LSTM cell.
          num_classes: number of tag classes scored per position.
          num_corpus: number of task scopes and of domain classifier outputs.
          l2_reg_lambda: scale of the L2 regularizer on the output weights.
          adv_weight: weight of the entropy loss added to each task loss
            when `adv` is set.
          lr: learning rate of every Adam optimizer created here.
          clip: global-norm threshold for gradient clipping.
          init_embedding: initial embedding matrix; zeros of shape
            [vocab_size, word_dim] when None.
          gates: if set, the shared output is gated before it is combined
            with the private input. Gating only happens when `sep` is False.
          adv: if set, build the domain classifier, its losses and the
            combined-loss train ops.
          reuseshare: unless False, the shared output is concatenated to the
            private LSTM output before the final projection.
          sep: unless False, private layers see only the embedded input and
            ignore the shared output.
        """

        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.word_dim = word_dim
        self.lstm_dim = lstm_dim
        self.num_classes = num_classes
        self.num_corpus = num_corpus
        self.l2_reg_lambda = l2_reg_lambda
        self.lr = lr
        self.clip = clip
        self.gateus = gates
        self.adv = adv
        self.reuse = reuseshare

        # placeholders
        self.x = tf.placeholder(tf.int32, [None, None])
        self.y = tf.placeholder(tf.int32, [None, None])
        self.y_class = tf.placeholder(tf.int32, [None])
        self.seq_len = tf.placeholder(tf.int32, [None])
        self.dropout_keep_prob = tf.placeholder(tf.float32,
                                                name="dropout_keep_prob")

        if init_embedding is None:
            self.init_embedding = np.zeros([vocab_size, word_dim],
                                           dtype=np.float32)
        else:
            self.init_embedding = init_embedding

        with tf.variable_scope("embedding"):
            self.embedding = tf.Variable(self.init_embedding, name="embedding")

        lstm_fw_cell = core_rnn_cell_impl.BasicLSTMCell(self.lstm_dim)
        lstm_bw_cell = core_rnn_cell_impl.BasicLSTMCell(self.lstm_dim)

        def _shared_layer(input_data, seq_len):
            """Bidirectional LSTM over the input; returns fw/bw concat."""
            (forward_output,
             backward_output), _ = tf.nn.bidirectional_dynamic_rnn(
                 lstm_fw_cell,
                 lstm_bw_cell,
                 input_data,
                 dtype=tf.float32,
                 sequence_length=seq_len,
             )
            output = tf.concat(axis=2,
                               values=[forward_output, backward_output])

            return output

        def _private_layer(output_pub, input_data, seq_len, y):
            """Task-specific BiLSTM + CRF on top of the shared output.

            Returns (unary_scores, log_likelihood, transition_params) and,
            when gating is enabled on the instance, a fourth element with the
            gate tensor (None if no gate was built because `sep` is set).
            """
            size = tf.shape(input_data)[0]
            # Stays None unless the gating branch below actually runs, so the
            # return statement can never hit an unbound local.
            gate = None
            if sep is False:
                if self.gateus:
                    target_dim = tf.shape(output_pub)[2]
                    factor = tf.concat(axis=2, values=[input_data, output_pub])
                    dim = tf.shape(factor)[2]
                    W_g = tf.get_variable(
                        shape=[
                            self.lstm_dim * 2 + self.word_dim * 9,
                            self.lstm_dim * 2
                        ],
                        initializer=tf.truncated_normal_initializer(
                            stddev=0.01),
                        name="w_gates")
                    factor = tf.reshape(factor, [-1, dim])
                    gate = tf.matmul(factor, W_g)

                    output_prep = tf.nn.sigmoid(
                        tf.reshape(output_pub, [-1, target_dim]))
                    output_pub = tf.multiply(output_prep, gate)
                    output_pub = tf.reshape(output_pub, [size, -1, target_dim])

                combined_input_data = tf.concat(
                    axis=2, values=[input_data, output_pub])
                combined_input_data = tf.reshape(
                    combined_input_data,
                    [size, -1, self.lstm_dim * 2 + self.word_dim * 9])
            else:
                combined_input_data = input_data

            (forward_output,
             backward_output), _ = tf.nn.bidirectional_dynamic_rnn(
                 lstm_fw_cell,
                 lstm_bw_cell,
                 combined_input_data,
                 dtype=tf.float32,
                 sequence_length=seq_len,
             )
            output = tf.concat(axis=2,
                               values=[forward_output, backward_output])

            if self.reuse is False:
                output = tf.reshape(output, [-1, self.lstm_dim * 2])
                W = tf.get_variable(
                    shape=[lstm_dim * 2, num_classes],
                    initializer=tf.truncated_normal_initializer(stddev=0.01),
                    name="weights",
                    regularizer=tf.contrib.layers.l2_regularizer(
                        self.l2_reg_lambda))
            else:
                output = tf.concat(axis=2, values=[output, output_pub])
                output = tf.reshape(output, [-1, self.lstm_dim * 4])
                W = tf.get_variable(
                    shape=[lstm_dim * 4, num_classes],
                    initializer=tf.truncated_normal_initializer(stddev=0.01),
                    name="weights",
                    regularizer=tf.contrib.layers.l2_regularizer(
                        self.l2_reg_lambda))

            # NOTE(review): the name is given to the zeros op, not to the
            # variable; left as is so existing variable names do not change.
            b = tf.Variable(tf.zeros([num_classes], name="bias"))

            matricized_unary_scores = tf.matmul(output, W) + b
            unary_scores = tf.reshape(matricized_unary_scores,
                                      [size, -1, self.num_classes])

            # The CRF uses the raw int32 length placeholder rather than the
            # `seq_len` argument.
            log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
                unary_scores, y, self.seq_len)

            if self.gateus:
                return unary_scores, log_likelihood, transition_params, gate
            return unary_scores, log_likelihood, transition_params

        #domain layer
        def _domain_layer(
                output_pub,
                seq_len):  #output_pub batch_size * seq_len * (2 * lstm_dim)
            """Corpus classifier logits from the averaged shared output."""
            W_classifier = tf.get_variable(
                shape=[2 * lstm_dim, num_corpus],
                initializer=tf.truncated_normal_initializer(
                    stddev=1.0 / math.sqrt(float(num_corpus))),
                name='W_classifier')
            bias = tf.Variable(tf.zeros([num_corpus], name="class_bias"))
            output_avg = reduce_avg(output_pub, seq_len,
                                    1)  #output_avg batch_size * (2 * lstm_dim)
            logits = tf.matmul(
                output_avg,
                W_classifier) + bias  #logits batch_size * num_corpus
            return logits

        def _Hloss(logits):
            """Sum over classes of the batch mean of p * log(p)."""
            log_soft = tf.nn.log_softmax(logits)  # batch_size * num_corpus
            soft = tf.nn.softmax(logits)
            H_mid = tf.reduce_mean(tf.multiply(soft, log_soft),
                                   axis=0)  # [num_corpus]
            H_loss = tf.reduce_sum(H_mid)
            return H_loss

        def _Dloss(logits, y_class):
            """Mean sparse softmax cross-entropy of the corpus classifier."""
            labels = tf.to_int64(y_class)
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=labels, name='xentropy')
            D_loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
            return D_loss

        def _loss(log_likelihood):
            """Mean negative CRF log-likelihood."""
            return tf.reduce_mean(-log_likelihood)

        def _clipped_adam_step(loss, tvars):
            """Adam update of `tvars` on `loss` with global-norm clipping.

            Returns (train_op, global_step); every call creates its own
            optimizer and its own step counter.
            """
            optimizer = tf.train.AdamOptimizer(self.lr)
            global_step = tf.Variable(0, name="global_step", trainable=False)
            grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars),
                                              self.clip)
            train_op = optimizer.apply_gradients(list(zip(grads, tvars)),
                                                 global_step=global_step)

            return train_op, global_step

        def _scope_vars(scope):
            """Trainable variables created directly under `scope`.

            The scope filter is a regex prefix match on variable names, so
            the trailing slash keeps e.g. 'task1' from also picking up
            'task10/...' when there are ten or more corpora.
            """
            return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope=scope + '/')

        def _training(loss):
            # Updates every trainable variable in the graph.
            return _clipped_adam_step(loss, tf.trainable_variables())

        def _trainingPrivate(loss, taskid):
            # Updates only the variables of one task scope.
            return _clipped_adam_step(loss, _scope_vars(taskid))

        def _trainingDomain(loss):
            # Updates only the domain classifier variables.
            return _clipped_adam_step(loss, _scope_vars('domain'))

        def _trainingShared(loss, taskid):
            # Updates shared layer, one task scope and the embedding.
            tvars = (_scope_vars('shared') + _scope_vars(taskid) +
                     _scope_vars('embedding'))
            return _clipped_adam_step(loss, tvars)

        seq_len = tf.cast(self.seq_len, tf.int64)
        x = tf.nn.embedding_lookup(
            self.embedding, self.x)  # batch_size * (sequence*9) * word_dim
        size = tf.shape(x)[0]
        # we use window_size 5 and bi_gram, which means for each position,
        # there will be 5+4=9 (character or word) features
        x = tf.reshape(x, [size, -1, 9 * word_dim])  # ba*se*(9*wd)
        x = tf.nn.dropout(x, self.dropout_keep_prob)
        #task1:msr 2:as 3 pku 4 ctb 5 ckip 6 cityu 7 ncc 8 sxu 9 weibo
        with tf.variable_scope("shared"):
            output_pub = _shared_layer(x, seq_len)

        # add adversarial ops
        if self.adv:
            with tf.variable_scope("domain"):
                logits = _domain_layer(output_pub, seq_len)
            self.H_loss = _Hloss(logits)
            self.D_loss = _Dloss(logits, self.y_class)

        self.scores = []
        self.transition = []
        self.gate = []
        loglike = []
        #add task op
        for i in range(1, self.num_corpus + 1):
            Taskid = 'task' + str(i)
            with tf.variable_scope(Taskid):
                condition = _private_layer(output_pub, x, seq_len, self.y)
                self.scores.append(condition[0])
                loglike.append(condition[1])
                self.transition.append(condition[2])
                # No gate tensor exists when `sep` disabled the gating path.
                if self.gateus and condition[3] is not None:
                    self.gate.append(condition[3])

        # loss_com is the combined loss (task loss + adv_weight * H_loss),
        # losses is the basic per-task loss
        self.losses = [_loss(o) for o in loglike]
        if self.adv:
            self.loss_com = [adv_weight * self.H_loss + o for o in self.losses]
            self.domain_op, self.global_step_domain = _trainingDomain(
                self.D_loss)
        #task_basic_op is for basic train
        self.task_basic_op = []
        self.global_basic_step = []
        for i in range(1, self.num_corpus + 1):
            res = _training(self.losses[i - 1])
            self.task_basic_op.append(res[0])
            self.global_basic_step.append(res[1])

        #task_op is for combination train (task loss + H_loss * adv_weight)
        if self.adv:
            self.task_op = []
            self.global_step = []
            for i in range(1, self.num_corpus + 1):
                Taskid = 'task' + str(i)
                res = _trainingShared(self.loss_com[i - 1], taskid=Taskid)
                self.task_op.append(res[0])
                self.global_step.append(res[1])

        #task_op_ss is for private train
        self.task_op_ss = []
        self.global_pristep = []
        for i in range(1, self.num_corpus + 1):
            Taskid = 'task' + str(i)
            res = _trainingPrivate(self.losses[i - 1], Taskid)
            self.task_op_ss.append(res[0])
            self.global_pristep.append(res[1])
예제 #23
0
  def testOne2ManyRNNSeq2Seq(self):
    """Smoke-tests one2many_rnn_seq2seq with two decoders, "0" and "1".

    First checks the number and shape of each decoder's outputs and final
    state, then checks that with feed_previous enabled the decoder inputs
    after the first time-step do not influence the outputs.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):

        def int_batch(value):
          # One time-step of input: a length-2 int32 vector filled with value.
          return constant_op.constant(value, dtypes.int32, shape=[2])

        encoder_steps = [int_batch(1) for _ in range(2)]
        # Decoder "0" is fed 3 steps and decoder "1" 4 steps; step t holds t.
        decoder_steps = {
            "0": [int_batch(t) for t in range(3)],
            "1": [int_batch(t) for t in range(4)],
        }
        symbols_per_decoder = {"0": 5, "1": 6}
        lstm = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)

        def build(decoder_inputs, feed_previous):
          # Same model configuration each time; only the decoder inputs and
          # the feed_previous flag vary between the calls below.
          return seq2seq_lib.one2many_rnn_seq2seq(
              encoder_steps,
              decoder_inputs,
              lstm,
              2,
              symbols_per_decoder,
              embedding_size=2,
              feed_previous=feed_previous)

        outputs_dict, state_dict = seq2seq_lib.one2many_rnn_seq2seq(
            encoder_steps,
            decoder_steps,
            lstm,
            2,
            symbols_per_decoder,
            embedding_size=2)

        sess.run([variables.global_variables_initializer()])
        # One output per decoder step, each of shape (batch, num_symbols).
        for key, num_steps, num_symbols in (("0", 3, 5), ("1", 4, 6)):
          step_outputs = sess.run(outputs_dict[key])
          self.assertEqual(num_steps, len(step_outputs))
          self.assertEqual((2, num_symbols), step_outputs[0].shape)
        # Each decoder ends with an LSTM state tuple of (batch, num_units).
        for key in ("0", "1"):
          (final_state,) = sess.run([state_dict[key]])
          self.assertEqual((2, 2), final_state.c.shape)
          self.assertEqual((2, 2), final_state.h.shape)

        # Test that previous-feeding model ignores inputs after the first,
        # i.e. all_zero_steps differs from decoder_steps after the first
        # time-step.
        all_zero_steps = {
            "0": [int_batch(0) for _ in range(3)],
            "1": [int_batch(0) for _ in range(4)],
        }
        with variable_scope.variable_scope("other"):
          tensor_flag_outputs, _ = build(all_zero_steps,
                                         constant_op.constant(True))
        sess.run([variables.global_variables_initializer()])
        variable_scope.get_variable_scope().reuse_variables()
        original_input_outputs, _ = build(decoder_steps, True)
        zero_input_outputs, _ = build(all_zero_steps, True)

        res1 = sess.run(original_input_outputs["0"])
        res2 = sess.run(zero_input_outputs["0"])
        res3 = sess.run(tensor_flag_outputs["0"])
        self.assertAllClose(res1, res2)
        self.assertAllClose(res1, res3)
예제 #24
0
  def testEmbeddingAttentionSeq2Seq(self):
    """Smoke-tests embedding_attention_seq2seq output and state shapes.

    Covers three configurations: a default BasicLSTMCell, a cell built with
    state_is_tuple=False, and an externally supplied output projection.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(1, dtypes.int32, shape=[2]) for _ in range(2)
        ]
        dec_inp = [
            constant_op.constant(step, dtypes.int32, shape=[2])
            for step in range(3)
        ]

        def build(rnn_cell, **extra):
          # Shared model configuration; extra carries per-case options.
          return seq2seq_lib.embedding_attention_seq2seq(
              enc_inp,
              dec_inp,
              rnn_cell,
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              **extra)

        def check_decoder_outputs(decoder_outputs, expected_step_shape):
          # Three decoder steps are expected, each with the given shape.
          step_values = sess.run(decoder_outputs)
          self.assertEqual(3, len(step_values))
          self.assertEqual(expected_step_shape, step_values[0].shape)

        dec, mem = build(core_rnn_cell_impl.BasicLSTMCell(2))
        sess.run([variables.global_variables_initializer()])
        check_decoder_outputs(dec, (2, 5))

        (final_state,) = sess.run([mem])
        self.assertEqual((2, 2), final_state.c.shape)
        self.assertEqual((2, 2), final_state.h.shape)

        # Test with state_is_tuple=False.
        make_flat_state_cell = lambda: core_rnn_cell_impl.BasicLSTMCell(  # pylint: disable=g-long-lambda
            2, state_is_tuple=False)
        with variable_scope.variable_scope("no_tuple"):
          dec, mem = build(make_flat_state_cell())
          sess.run([variables.global_variables_initializer()])
          check_decoder_outputs(dec, (2, 5))

          # c and h are concatenated into a single (batch, 2 * units) tensor.
          (final_state,) = sess.run([mem])
          self.assertEqual((2, 4), final_state.shape)

        # Test externally provided output projection.  The cell used here is
        # still the state_is_tuple=False variant.
        w = variable_scope.get_variable("proj_w", [2, 5])
        b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          dec, _ = build(make_flat_state_cell(), output_projection=(w, b))
        sess.run([variables.global_variables_initializer()])
        # With a projection supplied, raw outputs keep the cell's width of 2.
        check_decoder_outputs(dec, (2, 2))
예제 #25
0
  def testAttentionDecoderStateIsTuple(self):
    """Checks attention_decoder with a 2-layer LSTM using tuple states.

    Both the per-layer cells and the MultiRNNCell wrapper are built with an
    explicit state_is_tuple=True.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell(  # pylint: disable=g-long-lambda
            2, state_is_tuple=True)
        cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell(  # pylint: disable=g-long-lambda
            cells=[single_cell() for _ in range(2)], state_is_tuple=True)
        cell = cell_fn()
        # Two encoder time-steps, each of shape (batch=2, input_size=2).
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        enc_outputs, enc_state = core_rnn.static_rnn(
            cell, inp, dtype=dtypes.float32)
        # Stack the per-step encoder outputs along a new time axis so the
        # decoder can attend over them.
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
        ], 1)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        # One LSTM state tuple per layer, each part of shape (2, 2).
        res = sess.run([mem])
        self.assertEqual(2, len(res[0]))
        self.assertEqual((2, 2), res[0][0].c.shape)
        self.assertEqual((2, 2), res[0][0].h.shape)
        self.assertEqual((2, 2), res[0][1].c.shape)
        self.assertEqual((2, 2), res[0][1].h.shape)

  def testDynamicAttentionDecoderStateIsTuple(self):
    """Checks attention_decoder with a 2-layer LSTM using default arguments.

    This test was previously indented one level too deep, which made it a
    local function of testAttentionDecoderStateIsTuple that was defined but
    never called.  It is now a method of its own so the test runner collects
    it.  The encoder input is a list of per-step tensors, which is what
    static_rnn is given in the sibling test; the original passed a single
    3-D tensor.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        # Cells rely on the library defaults for state_is_tuple here.
        cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell(  # pylint: disable=g-long-lambda
            cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
        cell = cell_fn()
        # Two encoder time-steps, each of shape (batch=2, input_size=2).
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        enc_outputs, enc_state = core_rnn.static_rnn(
            cell, inp, dtype=dtypes.float32)
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size])
            for e in enc_outputs
        ], 1)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        # One LSTM state tuple per layer, each part of shape (2, 2).
        res = sess.run([mem])
        self.assertEqual(2, len(res[0]))
        self.assertEqual((2, 2), res[0][0].c.shape)
        self.assertEqual((2, 2), res[0][0].h.shape)
        self.assertEqual((2, 2), res[0][1].c.shape)
        self.assertEqual((2, 2), res[0][1].h.shape)
예제 #26
0
  def testEmbeddingRNNSeq2Seq(self):
    """Smoke-tests embedding_rnn_seq2seq output and state shapes.

    Covers a default BasicLSTMCell, a state_is_tuple=False cell, an external
    output projection, and finally checks that with feed_previous enabled the
    decoder inputs after the first step do not affect the outputs.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(1, dtypes.int32, shape=[2]) for _ in range(2)
        ]
        dec_inp = [
            constant_op.constant(step, dtypes.int32, shape=[2])
            for step in range(3)
        ]

        def make_cell():
          return core_rnn_cell_impl.BasicLSTMCell(2)

        def build(decoder_inputs, rnn_cell, **extra):
          # Shared model configuration; extra carries per-case options.
          return seq2seq_lib.embedding_rnn_seq2seq(
              enc_inp,
              decoder_inputs,
              rnn_cell,
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              **extra)

        def check_decoder_outputs(decoder_outputs, expected_step_shape):
          # Three decoder steps are expected, each with the given shape.
          step_values = sess.run(decoder_outputs)
          self.assertEqual(3, len(step_values))
          self.assertEqual(expected_step_shape, step_values[0].shape)

        dec, mem = build(dec_inp, make_cell())
        sess.run([variables.global_variables_initializer()])
        check_decoder_outputs(dec, (2, 5))

        (final_state,) = sess.run([mem])
        self.assertEqual((2, 2), final_state.c.shape)
        self.assertEqual((2, 2), final_state.h.shape)

        # Test with state_is_tuple=False.
        with variable_scope.variable_scope("no_tuple"):
          flat_state_cell = core_rnn_cell_impl.BasicLSTMCell(
              2, state_is_tuple=False)
          dec, mem = build(dec_inp, flat_state_cell)
          sess.run([variables.global_variables_initializer()])
          check_decoder_outputs(dec, (2, 5))

          # c and h are concatenated into a single (batch, 2 * units) tensor.
          (final_state,) = sess.run([mem])
          self.assertEqual((2, 4), final_state.shape)

        # Test externally provided output projection.
        w = variable_scope.get_variable("proj_w", [2, 5])
        b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          dec, _ = build(dec_inp, make_cell(), output_projection=(w, b))
        sess.run([variables.global_variables_initializer()])
        # With a projection supplied, raw outputs keep the cell's width of 2.
        check_decoder_outputs(dec, (2, 2))

        # Test that previous-feeding model ignores inputs after the first.
        zero_dec_inp = [
            constant_op.constant(0, dtypes.int32, shape=[2]) for _ in range(3)
        ]
        with variable_scope.variable_scope("other"):
          tensor_flag_dec, _ = build(
              zero_dec_inp,
              make_cell(),
              feed_previous=constant_op.constant(True))
        sess.run([variables.global_variables_initializer()])
        variable_scope.get_variable_scope().reuse_variables()
        original_input_dec, _ = build(dec_inp, make_cell(), feed_previous=True)
        zero_input_dec, _ = build(
            zero_dec_inp, make_cell(), feed_previous=True)

        res1 = sess.run(original_input_dec)
        res2 = sess.run(zero_input_dec)
        res3 = sess.run(tensor_flag_dec)
        self.assertAllClose(res1, res2)
        self.assertAllClose(res1, res3)
    def __init__(self,
                 batch_size=100,
                 vocab_size=5620,
                 word_dim=100,
                 lstm_dim=100,
                 num_classes=4,
                 l2_reg_lambda=0.0,
                 lr=0.001,
                 clip=5,
                 init_embedding=None,
                 bi_gram=False,
                 stack=False,
                 lstm_net=False,
                 bi_direction=False):
        """Build the sequence-labelling graph and its training op.

        The graph looks up embeddings for the input ids, runs them through
        one or two recurrent layers, projects every position to
        ``num_classes`` scores, scores the label sequence with a CRF
        log-likelihood, and minimises the mean negative log-likelihood with
        Adam using gradients clipped by global norm.

        Args:
            batch_size: Stored on the instance; not otherwise read here.
            vocab_size: Row count of the zero embedding matrix created when
                ``init_embedding`` is None.
            word_dim: Embedding width; also sets the per-position feature
                width used when reshaping the looked-up embeddings.
            lstm_dim: Unit count passed to each recurrent cell.
            num_classes: Number of output scores per position.
            l2_reg_lambda: Scale passed to the L2 regularizer attached to the
                softmax weight matrix.
            lr: Learning rate passed to the Adam optimizer.
            clip: Clip norm passed to ``tf.clip_by_global_norm``.
            init_embedding: Initial value of the embedding variable; a
                float32 zero matrix of shape [vocab_size, word_dim] is used
                when None.
            bi_gram: When exactly ``False`` each position has ``word_dim``
                features; otherwise ``9 * word_dim``.
            stack: When truthy, a second recurrent layer is applied on top of
                the first.
            lstm_net: When exactly ``False`` GRU cells are used; otherwise
                BasicLSTM cells.
            bi_direction: When truthy, a bidirectional RNN is used and the
                forward and backward outputs are concatenated.
        """

        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.word_dim = word_dim
        self.lstm_dim = lstm_dim
        self.num_classes = num_classes
        self.l2_reg_lambda = l2_reg_lambda
        self.lr = lr
        self.clip = clip
        self.stack = stack
        self.lstm_net = lstm_net
        self.bi_direction = bi_direction

        # Fall back to an all-zero embedding table when none is supplied.
        if init_embedding is None:
            self.init_embedding = np.zeros([vocab_size, word_dim],
                                           dtype=np.float32)
        else:
            self.init_embedding = init_embedding

        # placeholders
        # x and y are rank-2 int32 inputs and seq_len is rank-1; all of their
        # dimensions are left unspecified.
        self.x = tf.placeholder(tf.int32, [None, None])
        self.y = tf.placeholder(tf.int32, [None, None])
        self.seq_len = tf.placeholder(tf.int32, [None])
        self.dropout_keep_prob = tf.placeholder(tf.float32,
                                                name="dropout_keep_prob")

        with tf.variable_scope("embedding"):
            self.embedding = tf.Variable(self.init_embedding, name="embedding")

        with tf.variable_scope("softmax"):
            # The projection input is twice as wide in the bidirectional case
            # because forward and backward outputs are concatenated below.
            if self.bi_direction:
                self.W = tf.get_variable(
                    shape=[lstm_dim * 2, num_classes],
                    initializer=tf.truncated_normal_initializer(stddev=0.01),
                    name="weights",
                    regularizer=tf.contrib.layers.l2_regularizer(
                        self.l2_reg_lambda))
            else:
                self.W = tf.get_variable(
                    shape=[lstm_dim, num_classes],
                    initializer=tf.truncated_normal_initializer(stddev=0.01),
                    name="weights",
                    regularizer=tf.contrib.layers.l2_regularizer(
                        self.l2_reg_lambda))

            # NOTE(review): name="bias" is passed to tf.zeros, not to
            # tf.Variable, so it names the initial-value op rather than the
            # variable -- confirm whether that is intended.
            self.b = tf.Variable(tf.zeros([num_classes], name="bias"))

        with tf.variable_scope("lstm"):
            # Identity comparison: only the literal False selects GRU cells.
            if self.lstm_net is False:
                self.fw_cell = core_rnn_cell_impl.GRUCell(self.lstm_dim)
                self.bw_cell = core_rnn_cell_impl.GRUCell(self.lstm_dim)
            else:
                self.fw_cell = core_rnn_cell_impl.BasicLSTMCell(self.lstm_dim)
                self.bw_cell = core_rnn_cell_impl.BasicLSTMCell(self.lstm_dim)

        with tf.variable_scope("forward"):
            seq_len = tf.cast(self.seq_len, tf.int64)
            x = tf.nn.embedding_lookup(
                self.embedding,
                self.x)  # batch_size * (sequence*9 or 1) * word_dim
            x = tf.nn.dropout(x, self.dropout_keep_prob)

            # Leading dimension of x, read at run time.
            size = tf.shape(x)[0]
            # Fold the looked-up embeddings so each position carries either
            # one embedding or nine concatenated embeddings.
            if bi_gram is False:
                x = tf.reshape(x, [size, -1, word_dim])  # ba*se*wd
            else:
                x = tf.reshape(x, [size, -1, 9 * word_dim])

            if self.bi_direction:
                (forward_output,
                 backward_output), _ = tf.nn.bidirectional_dynamic_rnn(
                     self.fw_cell,
                     self.bw_cell,
                     x,
                     dtype=tf.float32,
                     sequence_length=seq_len,
                     scope='layer_1')
                # Join both directions along the feature axis.
                output = tf.concat(axis=2,
                                   values=[forward_output, backward_output])
                if self.stack:
                    # NOTE(review): layer_2 is given the same fw_cell/bw_cell
                    # objects as layer_1 under a different scope -- confirm
                    # that reusing the cell instances here is intended.
                    (forward_output,
                     backward_output), _ = tf.nn.bidirectional_dynamic_rnn(
                         self.fw_cell,
                         self.bw_cell,
                         output,
                         dtype=tf.float32,
                         sequence_length=seq_len,
                         scope='layer_2')
                    output = tf.concat(
                        axis=2, values=[forward_output, backward_output])
            else:
                forward_output, _ = tf.nn.dynamic_rnn(self.fw_cell,
                                                      x,
                                                      dtype=tf.float32,
                                                      sequence_length=seq_len,
                                                      scope='layer_1')
                output = forward_output
                if self.stack:
                    # NOTE(review): same cell object as layer_1 -- confirm.
                    forward_output, _ = tf.nn.dynamic_rnn(
                        self.fw_cell,
                        output,
                        dtype=tf.float32,
                        sequence_length=seq_len,
                        scope='layer_2')
                    output = forward_output

            # Flatten to 2-D so a single matmul scores every position.
            if self.bi_direction:
                output = tf.reshape(output, [-1, 2 * self.lstm_dim])
            else:
                output = tf.reshape(output, [-1, self.lstm_dim])

            matricized_unary_scores = tf.matmul(output, self.W) + self.b

            # Restore the leading dimension: [size, -1, num_classes].
            self.unary_scores = tf.reshape(matricized_unary_scores,
                                           [size, -1, self.num_classes])

        with tf.variable_scope("loss") as scope:
            # CRF log likelihood
            log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
                self.unary_scores, self.y, self.seq_len)

            # The loss is the mean negative log-likelihood.
            self.loss = tf.reduce_mean(-log_likelihood)

        with tf.variable_scope("train_ops") as scope:
            self.optimizer = tf.train.AdamOptimizer(self.lr)

            # Step counter advanced by apply_gradients; excluded from training.
            self.global_step = tf.Variable(0,
                                           name="global_step",
                                           trainable=False)

            # Differentiate the loss w.r.t. everything returned by
            # tf.trainable_variables() and clip by global norm self.clip.
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars),
                                              self.clip)
            self.train_op = self.optimizer.apply_gradients(
                zip(grads, tvars), global_step=self.global_step)