Example #1
    def __init__(self):
        self.vocab_path = FLAGS.vocab_path
        self.checkpoint_path = FLAGS.checkpoint_path
        self.freeze_graph_path = FLAGS.freeze_graph_path
        self.saved_model_path = FLAGS.saved_model_path

        self.use_crf = FLAGS.use_crf
        self.num_steps = FLAGS.num_steps

        self.default_label = FLAGS.default_label
        self.default_score = FLAGS.default_predict_score

        self.data_utils = DataUtils()
        self.tensorflow_utils = TensorflowUtils()
        self.num_classes = self.data_utils.get_vocabulary_size(
            os.path.join(FLAGS.vocab_path, 'labels_vocab.txt'))
        self.sequence_labeling_model = SequenceLabelingModel()
        self.init_predict_graph()
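
The FLAGS object read throughout these constructors comes from TensorFlow's flag module. A rough sketch of the matching declarations is shown below; the flag names follow the attributes read above, but the default values here are placeholders, not the project's real configuration.

import tensorflow as tf

# Hypothetical flag declarations; defaults are illustrative only.
flags = tf.app.flags
flags.DEFINE_string('vocab_path', 'vocab', 'Directory holding words_vocab.txt and labels_vocab.txt.')
flags.DEFINE_string('checkpoint_path', 'checkpoint', 'Directory of model checkpoints.')
flags.DEFINE_string('freeze_graph_path', 'freeze_graph', 'Output directory for the frozen graph.')
flags.DEFINE_string('saved_model_path', 'saved_model', 'Output directory for the SavedModel export.')
flags.DEFINE_boolean('use_crf', True, 'Add a CRF layer on top of the RNN.')
flags.DEFINE_integer('num_steps', 200, 'Maximum sequence length in time steps.')
flags.DEFINE_string('default_label', 'O', 'Label used for padding and out-of-range words.')
flags.DEFINE_float('default_predict_score', 0.0, 'Score reported for padded positions.')
FLAGS = flags.FLAGS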
Example #2
    def __init__(self):
        self.raw_data_path = FLAGS.raw_data_path
        self.vocab_path = FLAGS.vocab_path

        self.use_stored_embedding = FLAGS.use_stored_embedding
        self.use_lstm = FLAGS.use_lstm
        self.use_dynamic_rnn = FLAGS.use_dynamic_rnn
        self.use_bidirectional_rnn = FLAGS.use_bidirectional_rnn

        self.batch_size = FLAGS.batch_size
        self.num_steps = FLAGS.num_steps
        self.num_layers = FLAGS.num_layers
        self.embedding_size = FLAGS.embedding_size
        # self.hidden_size = FLAGS.hidden_size
        self.hidden_size = FLAGS.embedding_size
        self.keep_prob = FLAGS.keep_prob

        self.tensorflow_utils = TensorflowUtils()
Example #3
    def __init__(self):
        self.tfrecords_path = FLAGS.tfrecords_path
        self.checkpoint_path = FLAGS.checkpoint_path
        self.tensorboard_path = FLAGS.tensorboard_path

        self.use_crf = FLAGS.use_crf
        self.learning_rate = FLAGS.learning_rate
        self.learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
        self.decay_steps = FLAGS.decay_steps
        self.clip_norm = FLAGS.clip_norm
        self.max_training_step = FLAGS.max_training_step

        self.train_tfrecords_filename = os.path.join(self.tfrecords_path,
                                                     'train.tfrecords')
        self.test_tfrecords_filename = os.path.join(self.tfrecords_path,
                                                    'test.tfrecords')

        self.data_utils = DataUtils()
        self.num_classes = self.data_utils.get_vocabulary_size(
            os.path.join(FLAGS.vocab_path, 'labels_vocab.txt'))
        self.tensorflow_utils = TensorflowUtils()
        self.sequence_labeling_model = SequenceLabelingModel()
Example #4
    def train(self):
        self.train_is_alive = True

        self.hdfs_client.hdfs_download(
            os.path.join(self.flags.input_path, 'train.txt'),
            os.path.join(self.flags.datasets_path, 'train.txt'))
        self.hdfs_client.hdfs_download(
            os.path.join(self.flags.input_path, 'test.txt'),
            os.path.join(self.flags.datasets_path, 'test.txt'))

        self.data_utils.label_segment_file(
            os.path.join(self.flags.datasets_path, 'train.txt'),
            os.path.join(self.flags.datasets_path, 'label_train.txt'))
        self.data_utils.label_segment_file(
            os.path.join(self.flags.datasets_path, 'test.txt'),
            os.path.join(self.flags.datasets_path, 'label_test.txt'))

        self.data_utils.split_label_file(
            os.path.join(self.flags.datasets_path, 'label_train.txt'),
            os.path.join(self.flags.datasets_path, 'split_train.txt'))
        self.data_utils.split_label_file(
            os.path.join(self.flags.datasets_path, 'label_test.txt'),
            os.path.join(self.flags.datasets_path, 'split_test.txt'))

        words_vocab, labels_vocab = self.data_utils.create_vocabulary(
            os.path.join(self.flags.datasets_path, 'split_train.txt'),
            self.flags.vocab_path, self.flags.vocab_drop_limit)

        train_word_ids_list, train_label_ids_list = self.data_utils.file_to_word_ids(
            os.path.join(self.flags.datasets_path, 'split_train.txt'),
            words_vocab, labels_vocab)
        test_word_ids_list, test_label_ids_list = self.data_utils.file_to_word_ids(
            os.path.join(self.flags.datasets_path, 'split_test.txt'),
            words_vocab, labels_vocab)

        tensorflow_utils = TensorflowUtils()
        tensorflow_utils.create_record(
            train_word_ids_list, train_label_ids_list,
            os.path.join(self.flags.tfrecords_path, 'train.tfrecords'))
        tensorflow_utils.create_record(
            test_word_ids_list, test_label_ids_list,
            os.path.join(self.flags.tfrecords_path, 'test.tfrecords'))

        self.hdfs_client.hdfs_upload(
            self.flags.vocab_path,
            os.path.join(
                self.flags.output_path,
                os.path.basename(os.path.normpath(self.flags.vocab_path))))

        train = Train()
        train.train()

        self.train_is_alive = False
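
The train_is_alive flag suggests this method is meant to run off the main thread while a caller polls for completion. Below is a minimal driver sketch under that assumption; 'service' is a placeholder name for whichever object defines the train() method above.

import threading

# Run the download / preprocess / train / upload pipeline in the background.
worker = threading.Thread(target=service.train, daemon=True)
worker.start()

# Elsewhere (e.g. a status endpoint) the caller can check service.train_is_alive
# to report whether training is still running.
worker.join()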
Example #5
class Predict(object):
    def __init__(self):
        self.vocab_path = FLAGS.vocab_path
        self.checkpoint_path = FLAGS.checkpoint_path
        self.freeze_graph_path = FLAGS.freeze_graph_path
        self.saved_model_path = FLAGS.saved_model_path

        self.use_crf = FLAGS.use_crf
        self.num_steps = FLAGS.num_steps

        self.default_label = FLAGS.default_label
        self.default_score = FLAGS.default_predict_score

        self.data_utils = DataUtils()
        self.tensorflow_utils = TensorflowUtils()
        self.num_classes = self.data_utils.get_vocabulary_size(
            os.path.join(FLAGS.vocab_path, 'labels_vocab.txt'))
        self.sequence_labeling_model = SequenceLabelingModel()
        self.init_predict_graph()

    def init_predict_graph(self):
        """
        Build the prediction graph and restore the model from the latest checkpoint
        :return:
        """
        # split 1-D String dense Tensor to words SparseTensor
        self.input_sentences = tf.placeholder(dtype=tf.string,
                                              shape=[None],
                                              name='input_sentences')
        sparse_words = tf.string_split(self.input_sentences, delimiter=' ')

        # slice SparseTensor
        valid_indices = tf.less(sparse_words.indices,
                                tf.constant([self.num_steps], dtype=tf.int64))
        valid_indices = tf.reshape(
            tf.split(valid_indices, [1, 1], axis=1)[1], [-1])
        valid_sparse_words = tf.sparse_retain(sparse_words, valid_indices)

        excess_indices = tf.greater_equal(
            sparse_words.indices, tf.constant([self.num_steps],
                                              dtype=tf.int64))
        excess_indices = tf.reshape(
            tf.split(excess_indices, [1, 1], axis=1)[1], [-1])
        excess_sparse_words = tf.sparse_retain(sparse_words, excess_indices)

        # compute sentences lengths
        int_values = tf.ones(shape=tf.shape(valid_sparse_words.values),
                             dtype=tf.int64)
        int_valid_sparse_words = tf.SparseTensor(
            indices=valid_sparse_words.indices,
            values=int_values,
            dense_shape=valid_sparse_words.dense_shape)
        input_sentences_lengths = tf.sparse_reduce_sum(int_valid_sparse_words,
                                                       axis=1)

        # sparse to dense
        default_padding_word = self.data_utils._START_VOCAB[0]
        words = tf.sparse_to_dense(
            sparse_indices=valid_sparse_words.indices,
            output_shape=[valid_sparse_words.dense_shape[0], self.num_steps],
            sparse_values=valid_sparse_words.values,
            default_value=default_padding_word)

        # dict words to ids
        with open(os.path.join(self.vocab_path, 'words_vocab.txt'),
                  encoding='utf-8',
                  mode='rt') as data_file:
            words_table_list = [
                line.strip() for line in data_file if line.strip()
            ]
        words_table_tensor = tf.constant(words_table_list, dtype=tf.string)
        words_table = lookup.index_table_from_tensor(
            mapping=words_table_tensor,
            default_value=self.data_utils._START_VOCAB_ID[3])
        # words_table = lookup.index_table_from_file(os.path.join(vocab_path, 'words_vocab.txt'), default_value=3)
        words_ids = words_table.lookup(words)

        # blstm model predict
        with tf.variable_scope('model', reuse=None):
            logits = self.sequence_labeling_model.inference(
                words_ids,
                input_sentences_lengths,
                self.num_classes,
                is_training=False)

        if self.use_crf:
            logits = tf.reshape(logits,
                                shape=[-1, self.num_steps, self.num_classes])
            transition_params = tf.get_variable(
                "transitions", [self.num_classes, self.num_classes])
            input_sentences_lengths = tf.to_int32(input_sentences_lengths)
            predict_labels_ids, sequence_scores = crf.crf_decode(
                logits, transition_params, input_sentences_lengths)
            predict_labels_ids = tf.to_int64(predict_labels_ids)
            sequence_scores = tf.reshape(sequence_scores, shape=[-1, 1])
            normalized_sequence_scores = self.tensorflow_utils.score_normalize(
                sequence_scores)
            predict_scores = tf.matmul(
                normalized_sequence_scores,
                tf.ones(shape=[1, self.num_steps], dtype=tf.float32))
        else:
            props = tf.nn.softmax(logits)
            max_prop_values, max_prop_indices = tf.nn.top_k(props, k=1)
            predict_labels_ids = tf.reshape(max_prop_indices,
                                            shape=[-1, self.num_steps])
            predict_labels_ids = tf.to_int64(predict_labels_ids)
            predict_scores = tf.reshape(max_prop_values,
                                        shape=[-1, self.num_steps])
        predict_scores = tf.as_string(predict_scores, precision=3)

        # dict ids to labels
        with open(os.path.join(self.vocab_path, 'labels_vocab.txt'),
                  encoding='utf-8',
                  mode='rt') as data_file:
            labels_table_list = [
                line.strip() for line in data_file if line.strip()
            ]
        labels_table_tensor = tf.constant(labels_table_list, dtype=tf.string)
        labels_table = lookup.index_to_string_table_from_tensor(
            mapping=labels_table_tensor, default_value=self.default_label)
        # labels_table = lookup.index_to_string_table_from_file(os.path.join(vocab_path, 'labels_vocab.txt'), default_value='O')
        predict_labels = labels_table.lookup(predict_labels_ids)

        sparse_predict_labels = self.tensorflow_utils.sparse_concat(
            predict_labels, valid_sparse_words, excess_sparse_words,
            self.default_label)
        sparse_predict_scores = self.tensorflow_utils.sparse_concat(
            predict_scores, valid_sparse_words, excess_sparse_words, '0.0')

        self.format_predict_labels = self.tensorflow_utils.sparse_string_join(
            sparse_predict_labels, 'predict_labels')
        self.format_predict_scores = self.tensorflow_utils.sparse_string_join(
            sparse_predict_scores, 'predict_scores')

        saver = tf.train.Saver()
        tables_init_op = tf.tables_initializer()

        self.sess = tf.Session()
        self.sess.run(tables_init_op)
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_path)
        if ckpt and ckpt.model_checkpoint_path:
            print('read model from {}'.format(ckpt.model_checkpoint_path))
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found at %s' % self.checkpoint_path)
            return

    def predict(self, words_list):
        """
        Predict labels; the word-to-id conversion is done by TensorFlow ops inside the graph
        :param words_list: list of whitespace-separated word strings
        :return: merged predict label strings and score strings, one per input sentence
        """
        split_words_list = []
        map_split_indexes = []
        for index in range(len(words_list)):
            temp_words_list = self.data_utils.split_long_sentence(
                words_list[index], self.num_steps)
            map_split_indexes.append(
                list(
                    range(len(split_words_list),
                          len(split_words_list) + len(temp_words_list))))
            split_words_list.extend(temp_words_list)

        predict_labels, predict_scores = self.sess.run(
            [self.format_predict_labels, self.format_predict_scores],
            feed_dict={self.input_sentences: split_words_list})
        predict_labels_str = [
            predict_label.decode('utf-8') for predict_label in predict_labels
        ]
        predict_scores_str = [
            predict_score.decode('utf-8') for predict_score in predict_scores
        ]

        merge_predict_labels_str = []
        merge_predict_scores_str = []
        for indexes in map_split_indexes:
            merge_predict_label_str = ' '.join(
                [predict_labels_str[index] for index in indexes])
            merge_predict_labels_str.append(merge_predict_label_str)
            merge_predict_score_str = ' '.join(
                [predict_scores_str[index] for index in indexes])
            merge_predict_scores_str.append(merge_predict_score_str)

        return merge_predict_labels_str, merge_predict_scores_str

    def file_predict(self, data_filename, predict_filename):
        """
        Predict the sentences in data_filename and write the results to predict_filename
        Labels are per single word, using the -B/-M/-E/-S scheme
        :param data_filename:
        :param predict_filename:
        :return:
        """
        print('Predict file ' + data_filename)
        sentence_list = []
        words_list = []
        labels_list = []
        predict_labels_list = []
        with open(data_filename, encoding='utf-8', mode='rt') as data_file:
            for line in data_file:
                words, labels = self.data_utils.split(line)
                if words and labels:
                    sentence_list.append(''.join(words))
                    words_list.append(' '.join(words))
                    labels_list.append(' '.join(labels))
                    predict_labels, _ = self.predict([' '.join(words)])
                    predict_labels_list.append(predict_labels[0])
        word_predict_label_list = []
        word_category_list = []
        word_predict_category_list = []
        for (words, labels, predict_labels) in zip(words_list, labels_list,
                                                   predict_labels_list):
            word_list = words.split()
            label_list = labels.split()
            predict_label_list = predict_labels.split()
            word_predict_label = ' '.join([
                word + '/' + predict_label
                for (word, predict_label) in zip(word_list, predict_label_list)
            ])
            word_predict_label_list.append(word_predict_label)
            # merge label
            merge_word_list, merge_label_list = self.data_utils.merge_label(
                word_list, label_list)
            word_category = ' '.join([
                word + '/' + label
                for (word, label) in zip(merge_word_list, merge_label_list)
                if label != self.default_label
            ])
            word_category_list.append(word_category)
            # merge predict label
            merge_predict_word_list, merge_predict_label_list = self.data_utils.merge_label(
                word_list, predict_label_list)
            word_predict_category = ' '.join([
                predict_word + '/' + predict_label
                for (predict_word, predict_label) in zip(
                    merge_predict_word_list, merge_predict_label_list)
                if predict_label != self.default_label
            ])
            word_predict_category_list.append(word_predict_category)
        with open(predict_filename, encoding='utf-8',
                  mode='wt') as predict_file:
            for (sentence, word_predict_label, word_category, word_predict_category) in \
                    zip(sentence_list, word_predict_label_list, word_category_list, word_predict_category_list):
                predict_file.write('Passage: ' + sentence + '\n')
                predict_file.write('SinglePredict: ' + word_predict_label +
                                   '\n')
                predict_file.write('Merge: ' + word_category + '\n')
                predict_file.write('MergePredict: ' + word_predict_category +
                                   '\n\n')

    def freeze_graph(self):
        """
        Save graph into .pb file
        :return:
        """
        graph = tf.graph_util.convert_variables_to_constants(
            self.sess, self.sess.graph_def,
            ['init_all_tables', 'predict_labels', 'predict_scores'])
        tf.train.write_graph(graph,
                             self.freeze_graph_path,
                             'frozen_graph.pb',
                             as_text=False)
        print('Successfully freeze model to %s' % self.freeze_graph_path)

    def saved_model_pb(self):
        """
        Export a SavedModel (.pb plus variables files) that can be loaded by TensorFlow Serving
        :return:
        """
        saved_model_path = os.path.join(self.saved_model_path, '1')
        if os.path.exists(saved_model_path):
            shutil.rmtree(saved_model_path)
        builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)
        input_tensor_info = tf.saved_model.utils.build_tensor_info(
            self.input_sentences)
        output_labels_tensor_info = tf.saved_model.utils.build_tensor_info(
            self.format_predict_labels)
        output_scores_tensor_info = tf.saved_model.utils.build_tensor_info(
            self.format_predict_scores)
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_sentences': input_tensor_info},
            outputs={
                'predict_labels': output_labels_tensor_info,
                'predict_scores': output_scores_tensor_info
            },
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        legacy_init_op = tf.group(tf.tables_initializer(),
                                  name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            self.sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={'predict_segment': prediction_signature},
            legacy_init_op=legacy_init_op)
        builder.save()
        print('Successfully exported model to %s' % saved_model_path)
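
A minimal usage sketch for the Predict class, assuming the flags above are defined and a trained checkpoint exists under checkpoint_path. Input sentences are whitespace-separated word strings; sentences longer than num_steps are split internally and the results merged back.

predictor = Predict()

# One label string and one score string are returned per input sentence.
labels, scores = predictor.predict(['word1 word2 word3'])
print(labels[0], scores[0])

# Optional exports for deployment.
predictor.freeze_graph()      # writes frozen_graph.pb under freeze_graph_path
predictor.saved_model_pb()    # writes a SavedModel under saved_model_path/1 for TensorFlow Serving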
Example #6
class SequenceLabelingModel(object):
    def __init__(self):
        self.raw_data_path = FLAGS.raw_data_path
        self.vocab_path = FLAGS.vocab_path

        self.use_stored_embedding = FLAGS.use_stored_embedding
        self.use_lstm = FLAGS.use_lstm
        self.use_dynamic_rnn = FLAGS.use_dynamic_rnn
        self.use_bidirectional_rnn = FLAGS.use_bidirectional_rnn

        self.batch_size = FLAGS.batch_size
        self.num_steps = FLAGS.num_steps
        self.num_layers = FLAGS.num_layers
        self.embedding_size = FLAGS.embedding_size
        # self.hidden_size = FLAGS.hidden_size
        self.hidden_size = FLAGS.embedding_size
        self.keep_prob = FLAGS.keep_prob

        self.tensorflow_utils = TensorflowUtils()

    def inference(self, inputs, inputs_sequence_length, num_classes,
                  is_training):
        """
        RNN (LSTM/GRU, optionally bidirectional) inference, returns per-step logits
        :param inputs:
        :param inputs_sequence_length:
        :param num_classes:
        :param is_training:
        :return:
        """
        with tf.device('/cpu:0'):
            if self.use_stored_embedding:
                embedding = self.tensorflow_utils.load_embedding(
                    os.path.join(self.raw_data_path, 'embedding.txt'),
                    os.path.join(self.vocab_path, 'words_vocab.txt'))
            else:
                embedding = tf.get_variable(
                    'embedding', [self.vocab_size, self.embedding_size],
                    initializer=tf.random_uniform_initializer(),
                    dtype=tf.float32)
            inputs_embedding = tf.nn.embedding_lookup(embedding, inputs)
        if is_training and self.keep_prob < 1:
            inputs_embedding = tf.nn.dropout(inputs_embedding, self.keep_prob)

        def build_rnn_cell():
            # Build a fresh cell for every layer and direction so that layers
            # do not silently share weights inside MultiRNNCell.
            initializer = tf.random_uniform_initializer(-0.1, 0.1)
            if self.use_lstm:
                cell = tf.nn.rnn_cell.LSTMCell(num_units=self.hidden_size,
                                               initializer=initializer,
                                               forget_bias=1.0)
            else:
                cell = tf.nn.rnn_cell.GRUCell(num_units=self.hidden_size)
            if is_training and self.keep_prob < 1.0:
                cell = tf.nn.rnn_cell.DropoutWrapper(
                    cell, output_keep_prob=self.keep_prob)
            return cell

        rnn_cell_collection = []
        bi_flag = 2 if self.use_bidirectional_rnn else 1
        for _ in range(bi_flag):
            multi_cell = tf.nn.rnn_cell.MultiRNNCell(
                [build_rnn_cell() for _ in range(self.num_layers)])
            rnn_cell_collection.append(multi_cell)

        if self.use_dynamic_rnn:
            if self.use_bidirectional_rnn:
                outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                    rnn_cell_collection[0],
                    rnn_cell_collection[1],
                    inputs_embedding,
                    dtype=tf.float32,
                    sequence_length=inputs_sequence_length)
                outputs = tf.concat(outputs, axis=2)
            else:
                outputs, _ = tf.nn.dynamic_rnn(
                    rnn_cell_collection[0],
                    inputs_embedding,
                    dtype=tf.float32,
                    sequence_length=inputs_sequence_length)
        else:
            inputs_embedding = tf.unstack(inputs_embedding, axis=1)
            if self.use_bidirectional_rnn:
                outputs, _, _ = tf.nn.static_bidirectional_rnn(
                    rnn_cell_collection[0],
                    rnn_cell_collection[1],
                    inputs_embedding,
                    dtype=tf.float32,
                    sequence_length=inputs_sequence_length)
            else:
                outputs, _ = tf.nn.static_rnn(
                    rnn_cell_collection[0],
                    inputs_embedding,
                    dtype=tf.float32,
                    sequence_length=inputs_sequence_length)
            outputs = tf.stack(outputs, axis=1)
        outputs = tf.reshape(outputs, shape=[-1, bi_flag * self.hidden_size])
        weights = tf.get_variable('weights',
                                  [bi_flag * self.hidden_size, num_classes],
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', [num_classes], dtype=tf.float32)
        logits = tf.matmul(outputs, weights) + biases
        return logits

    def loss(self, logits, labels):
        """
        Cross-entropy loss between logits and labels
        :param logits:
        :param labels:
        :return:
        """
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        loss = tf.reduce_mean(cross_entropy, name='loss')
        return loss

    def accuracy(self, logits, labels):
        """
        Compute the accuracy of the RNN model
        :param logits:
        :param labels:
        :return:
        """
        props = tf.nn.softmax(logits)
        prediction_labels = tf.argmax(props, 1)
        correct_prediction = tf.equal(prediction_labels, labels)
        accuracy_value = tf.reduce_mean(tf.cast(correct_prediction,
                                                tf.float32))
        return accuracy_value

    def slice_seq(self, logits, labels, sequence_lengths):
        """
        Slice sequences to their valid lengths, used by the loss and accuracy computations
        :param logits:
        :param labels:
        :param sequence_lengths:
        :return:
        """
        labels = tf.reshape(labels, shape=[-1])
        slice_indices = tf.constant([], dtype=tf.int64)
        for index in range(self.batch_size):
            sub_slice_indices = tf.range(sequence_lengths[index])
            sub_slice_indices = tf.add(
                tf.constant(index * self.num_steps, dtype=tf.int64),
                sub_slice_indices)
            slice_indices = tf.concat([slice_indices, sub_slice_indices],
                                      axis=0)
        slice_logits = tf.gather(logits, slice_indices)
        slice_labels = tf.gather(labels, slice_indices)
        return slice_logits, slice_labels

    def crf_loss(self, logits, labels, sequence_lengths, num_classes):
        """
        CRF negative log-likelihood loss
        :param logits:
        :param labels:
        :param sequence_lengths:
        :param num_classes:
        :return:
        """
        logits = tf.reshape(
            logits, shape=[self.batch_size, self.num_steps, num_classes])
        labels = tf.cast(labels, tf.int32)
        log_likelihood, transition_params = crf.crf_log_likelihood(
            logits, labels, sequence_lengths)
        loss = tf.reduce_mean(-log_likelihood, name='loss')
        return loss, transition_params

    def crf_accuracy(self, logits, labels, sequence_length, transition_params,
                     num_classes):
        """
        Compute the accuracy of the RNN + CRF model
        :param logits:
        :param labels:
        :param sequence_length:
        :param transition_params:
        :param num_classes:
        :return:
        """
        logits = tf.reshape(
            logits, shape=[self.batch_size, self.num_steps, num_classes])
        sequence_length = tf.to_int32(sequence_length)
        predict_indices, _ = crf.crf_decode(logits, transition_params,
                                            sequence_length)
        predict_indices = tf.to_int64(predict_indices,
                                      name='predict_label_indices')
        correct_prediction = tf.equal(predict_indices, labels)
        accuracy_value = tf.reduce_mean(tf.cast(correct_prediction,
                                                tf.float32))
        return accuracy_value
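
The crf_loss and crf_accuracy methods above wrap tf.contrib.crf. Below is a self-contained sketch of that pattern with dummy tensors (shapes are illustrative only):

import tensorflow as tf
from tensorflow.contrib import crf

batch_size, num_steps, num_classes = 4, 10, 5
logits = tf.random_normal([batch_size, num_steps, num_classes])  # per-step class scores
labels = tf.zeros([batch_size, num_steps], dtype=tf.int32)       # gold tag ids
lengths = tf.fill([batch_size], num_steps)                       # true sequence lengths

# Training: negative mean log-likelihood of the gold tag sequences.
log_likelihood, transition_params = crf.crf_log_likelihood(logits, labels, lengths)
loss = tf.reduce_mean(-log_likelihood)

# Decoding: best tag sequence and its score for each example.
decode_tags, best_scores = crf.crf_decode(logits, transition_params, lengths)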
Example #7
class Train(object):
    def __init__(self):
        self.tfrecords_path = FLAGS.tfrecords_path
        self.checkpoint_path = FLAGS.checkpoint_path
        self.tensorboard_path = FLAGS.tensorboard_path

        self.use_crf = FLAGS.use_crf
        self.learning_rate = FLAGS.learning_rate
        self.learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
        self.decay_steps = FLAGS.decay_steps
        self.clip_norm = FLAGS.clip_norm
        self.max_training_step = FLAGS.max_training_step

        self.train_tfrecords_filename = os.path.join(self.tfrecords_path,
                                                     'train.tfrecords')
        self.test_tfrecords_filename = os.path.join(self.tfrecords_path,
                                                    'test.tfrecords')

        self.data_utils = DataUtils()
        self.num_classes = self.data_utils.get_vocabulary_size(
            os.path.join(FLAGS.vocab_path, 'labels_vocab.txt'))
        self.tensorflow_utils = TensorflowUtils()
        self.sequence_labeling_model = SequenceLabelingModel()

    def train(self):
        """
        Train the BiLSTM + CRF model
        :return:
        """
        train_data = self.tensorflow_utils.read_and_decode(
            self.train_tfrecords_filename)
        train_batch_features, train_batch_labels, train_batch_features_lengths = train_data
        test_data = self.tensorflow_utils.read_and_decode(
            self.test_tfrecords_filename)
        test_batch_features, test_batch_labels, test_batch_features_lengths = test_data

        with tf.device('/cpu:0'):
            global_step = tf.Variable(0, name='global_step', trainable=False)
        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(self.learning_rate,
                                        global_step,
                                        self.decay_steps,
                                        self.learning_rate_decay_factor,
                                        staircase=True)
        optimizer = tf.train.RMSPropOptimizer(lr)

        with tf.variable_scope('model'):
            logits = self.sequence_labeling_model.inference(
                train_batch_features,
                train_batch_features_lengths,
                self.num_classes,
                is_training=True)
        train_batch_labels = tf.to_int64(train_batch_labels)

        if self.use_crf:
            loss, transition_params = self.sequence_labeling_model.crf_loss(
                logits, train_batch_labels, train_batch_features_lengths,
                self.num_classes)
        else:
            slice_logits, slice_train_batch_labels = self.sequence_labeling_model.slice_seq(
                logits, train_batch_labels, train_batch_features_lengths)
            loss = self.sequence_labeling_model.loss(slice_logits,
                                                     slice_train_batch_labels)

        with tf.variable_scope('model', reuse=True):
            accuracy_logits = self.sequence_labeling_model.inference(
                test_batch_features,
                test_batch_features_lengths,
                self.num_classes,
                is_training=False)
        test_batch_labels = tf.to_int64(test_batch_labels)
        if self.use_crf:
            accuracy = self.sequence_labeling_model.crf_accuracy(
                accuracy_logits, test_batch_labels,
                test_batch_features_lengths, transition_params,
                self.num_classes)
        else:
            slice_accuracy_logits, slice_test_batch_labels = self.sequence_labeling_model.slice_seq(
                accuracy_logits, test_batch_labels,
                test_batch_features_lengths)
            accuracy = self.sequence_labeling_model.accuracy(
                slice_accuracy_logits, slice_test_batch_labels)

        # summary
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('lr', lr)

        # compute and update gradient
        # train_op = optimizer.minimize(loss, global_step=global_step)

        # compute, clip and update gradients
        gradients, variables = zip(*optimizer.compute_gradients(loss))
        clip_gradients, _ = tf.clip_by_global_norm(gradients, self.clip_norm)
        train_op = optimizer.apply_gradients(zip(clip_gradients, variables),
                                             global_step=global_step)

        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=None)
        checkpoint_filename = os.path.join(self.checkpoint_path, 'model.ckpt')

        with tf.Session() as sess:
            summary_op = tf.summary.merge_all()
            writer = tf.summary.FileWriter(self.tensorboard_path, sess.graph)
            sess.run(init_op)

            ckpt = tf.train.get_checkpoint_state(self.checkpoint_path)
            if ckpt and ckpt.model_checkpoint_path:
                print('Continue training from the model {}'.format(
                    ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            max_accuracy = 0.0
            min_loss = 100000000.0
            try:
                while not coord.should_stop():
                    _, loss_value, step = sess.run(
                        [train_op, loss, global_step])
                    if step % 100 == 0:
                        accuracy_value, summary_value, lr_value = sess.run(
                            [accuracy, summary_op, lr])
                        china_tz = pytz.timezone('Asia/Shanghai')
                        current_time = datetime.datetime.now(china_tz)
                        print('[{}] Step: {}, loss: {}, accuracy: {}, lr: {}'.
                              format(current_time, step, loss_value,
                                     accuracy_value, lr_value))
                        if accuracy_value > max_accuracy and loss_value < min_loss:
                            writer.add_summary(summary_value, step)
                            data_clean.clean_checkpoint(self.checkpoint_path)
                            saver.save(sess,
                                       checkpoint_filename,
                                       global_step=step)
                            print('save model to %s-%d' %
                                  (checkpoint_filename, step))
                            max_accuracy = accuracy_value
                            min_loss = loss_value
                    if step >= self.max_training_step:
                        print('Done training after %d step' % step)
                        break
            except tf.errors.OutOfRangeError:
                print('Done training after reading all data')
            finally:
                coord.request_stop()
            coord.join(threads)
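
A possible entry point for this trainer, assuming the flag declarations sketched after Example #1 (plus the training-specific flags) are in place:

def main(_):
    trainer = Train()
    trainer.train()

if __name__ == '__main__':
    tf.app.run()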