def __init__(self):
    """Configure the predictor from FLAGS and build its inference graph.

    Every setting is copied from the module-level ``FLAGS`` object. The
    number of output classes is obtained from ``labels_vocab.txt`` inside
    the vocabulary directory, and ``init_predict_graph`` is invoked last.
    """
    flags = FLAGS

    # Vocabulary and model artifact locations.
    self.vocab_path = flags.vocab_path
    self.checkpoint_path = flags.checkpoint_path
    self.freeze_graph_path = flags.freeze_graph_path
    self.saved_model_path = flags.saved_model_path

    # Decoding configuration.
    self.use_crf = flags.use_crf
    self.num_steps = flags.num_steps
    self.default_label = flags.default_label
    self.default_score = flags.default_predict_score

    # Helpers, then the label count needed to size the output layer.
    self.data_utils = DataUtils()
    self.tensorflow_utils = TensorflowUtils()
    labels_vocab_file = os.path.join(flags.vocab_path, 'labels_vocab.txt')
    self.num_classes = self.data_utils.get_vocabulary_size(labels_vocab_file)

    self.sequence_labeling_model = SequenceLabelingModel()
    self.init_predict_graph()
def __init__(self):
    """Copy the sequence-labeling model hyper-parameters out of FLAGS.

    ``hidden_size`` is not read from its own flag; it mirrors
    ``embedding_size``. A ``TensorflowUtils`` helper is created last.
    """
    flags = FLAGS

    # Data and vocabulary locations.
    self.raw_data_path = flags.raw_data_path
    self.vocab_path = flags.vocab_path

    # Architecture switches.
    self.use_stored_embedding = flags.use_stored_embedding
    self.use_lstm = flags.use_lstm
    self.use_dynamic_rnn = flags.use_dynamic_rnn
    self.use_bidirectional_rnn = flags.use_bidirectional_rnn

    # Sizes.
    self.batch_size = flags.batch_size
    self.num_steps = flags.num_steps
    self.num_layers = flags.num_layers
    self.embedding_size = flags.embedding_size
    # The RNN width is tied to the embedding width.
    self.hidden_size = flags.embedding_size

    # Regularisation.
    self.keep_prob = flags.keep_prob

    self.tensorflow_utils = TensorflowUtils()
def __init__(self):
    """Read training settings from FLAGS and prepare helper objects.

    The train/test TFRecords files are expected as ``train.tfrecords`` and
    ``test.tfrecords`` under ``FLAGS.tfrecords_path``. The class count is
    taken from ``labels_vocab.txt`` in ``FLAGS.vocab_path``.
    """
    flags = FLAGS

    # Input / output directories.
    self.tfrecords_path = flags.tfrecords_path
    self.checkpoint_path = flags.checkpoint_path
    self.tensorboard_path = flags.tensorboard_path

    # Model variant and optimisation schedule.
    self.use_crf = flags.use_crf
    self.learning_rate = flags.learning_rate
    self.learning_rate_decay_factor = flags.learning_rate_decay_factor
    self.decay_steps = flags.decay_steps
    self.clip_norm = flags.clip_norm
    self.max_training_step = flags.max_training_step

    records_dir = self.tfrecords_path
    self.train_tfrecords_filename = os.path.join(records_dir, 'train.tfrecords')
    self.test_tfrecords_filename = os.path.join(records_dir, 'test.tfrecords')

    self.data_utils = DataUtils()
    labels_vocab_file = os.path.join(flags.vocab_path, 'labels_vocab.txt')
    self.num_classes = self.data_utils.get_vocabulary_size(labels_vocab_file)
    self.tensorflow_utils = TensorflowUtils()
    self.sequence_labeling_model = SequenceLabelingModel()
def train(self):
    """Prepare the HDFS corpus, build vocabulary/TFRecords, then train.

    Each per-split stage runs on the ``train`` split and then on the
    ``test`` split:

    1. download ``<input_path>/<split>.txt`` to ``<datasets_path>/<split>.txt``;
    2. ``data_utils.label_segment_file`` -> ``label_<split>.txt``;
    3. ``data_utils.split_label_file`` -> ``split_<split>.txt``;
    4. build the vocabulary from ``split_train.txt`` only;
    5. convert both split files to word/label id lists;
    6. write ``<tfrecords_path>/<split>.tfrecords``;
    7. upload the vocabulary directory under ``output_path``;
    8. run ``Train().train()``.

    ``self.train_is_alive`` is True while this method runs. It is reset to
    False in a ``finally`` clause, so a failing stage can no longer leave it
    stuck at True. Exceptions still propagate to the caller.
    """
    splits = ('train', 'test')

    def dataset_file(file_name):
        # Path of a working file inside the local datasets directory.
        return os.path.join(self.flags.datasets_path, file_name)

    self.train_is_alive = True
    try:
        for split in splits:
            self.hdfs_client.hdfs_download(
                os.path.join(self.flags.input_path, '{}.txt'.format(split)),
                dataset_file('{}.txt'.format(split)))
        for split in splits:
            self.data_utils.label_segment_file(
                dataset_file('{}.txt'.format(split)),
                dataset_file('label_{}.txt'.format(split)))
        for split in splits:
            self.data_utils.split_label_file(
                dataset_file('label_{}.txt'.format(split)),
                dataset_file('split_{}.txt'.format(split)))

        # The vocabulary is built from the training split only.
        words_vocab, labels_vocab = self.data_utils.create_vocabulary(
            dataset_file('split_train.txt'), self.flags.vocab_path,
            self.flags.vocab_drop_limit)

        ids_by_split = {}
        for split in splits:
            word_ids_list, label_ids_list = self.data_utils.file_to_word_ids(
                dataset_file('split_{}.txt'.format(split)), words_vocab,
                labels_vocab)
            ids_by_split[split] = (word_ids_list, label_ids_list)

        tensorflow_utils = TensorflowUtils()
        for split in splits:
            word_ids_list, label_ids_list = ids_by_split[split]
            tensorflow_utils.create_record(
                word_ids_list, label_ids_list,
                os.path.join(self.flags.tfrecords_path,
                             '{}.tfrecords'.format(split)))

        # normpath strips a trailing separator so basename yields the dir name.
        vocab_dir_name = os.path.basename(
            os.path.normpath(self.flags.vocab_path))
        self.hdfs_client.hdfs_upload(
            self.flags.vocab_path,
            os.path.join(self.flags.output_path, vocab_dir_name))

        Train().train()
    finally:
        self.train_is_alive = False
class Predict(object):
    """Restore a trained sequence-labeling model and run predictions.

    The constructor builds the TensorFlow prediction graph, opens a session
    (``self.sess``) and restores the latest checkpoint from
    ``checkpoint_path``. The graph can also be exported as a frozen graph
    or as a SavedModel.
    """

    def __init__(self):
        self.vocab_path = FLAGS.vocab_path
        self.checkpoint_path = FLAGS.checkpoint_path
        self.freeze_graph_path = FLAGS.freeze_graph_path
        self.saved_model_path = FLAGS.saved_model_path
        self.use_crf = FLAGS.use_crf
        self.num_steps = FLAGS.num_steps
        self.default_label = FLAGS.default_label
        self.default_score = FLAGS.default_predict_score
        self.data_utils = DataUtils()
        self.tensorflow_utils = TensorflowUtils()
        self.num_classes = self.data_utils.get_vocabulary_size(
            os.path.join(FLAGS.vocab_path, 'labels_vocab.txt'))
        self.sequence_labeling_model = SequenceLabelingModel()
        self.init_predict_graph()

    @staticmethod
    def _read_vocab_lines(vocab_filename):
        """Return the stripped, non-blank lines of a UTF-8 vocabulary file."""
        with open(vocab_filename, encoding='utf-8', mode='rt') as vocab_file:
            return [line.strip() for line in vocab_file if line.strip()]

    def init_predict_graph(self):
        """Build the prediction graph and restore the latest checkpoint.

        Input: ``self.input_sentences``, a 1-D string placeholder of
        space-separated sentences. Only the first ``num_steps`` words of each
        sentence go through the model. Words are mapped to ids with
        ``words_vocab.txt``, labels are decoded (CRF or per-step argmax), and
        label ids are mapped back to strings with ``labels_vocab.txt``.
        Words past ``num_steps`` are handed to ``sparse_concat`` together
        with ``default_label`` / ``'0.0'``. Presumably this fills those
        positions with defaults; the helper is defined elsewhere.

        Sets ``self.format_predict_labels``, ``self.format_predict_scores``
        and ``self.sess``. If no checkpoint is found, only a message is
        printed and the model variables remain unrestored.
        """
        # Split the 1-D string dense tensor into a words SparseTensor.
        self.input_sentences = tf.placeholder(
            dtype=tf.string, shape=[None], name='input_sentences')
        sparse_words = tf.string_split(self.input_sentences, delimiter=' ')

        # Partition the words by their column (word position): positions
        # < num_steps feed the model, the rest are "excess".
        valid_indices = tf.less(
            sparse_words.indices, tf.constant([self.num_steps], dtype=tf.int64))
        valid_indices = tf.reshape(
            tf.split(valid_indices, [1, 1], axis=1)[1], [-1])
        valid_sparse_words = tf.sparse_retain(sparse_words, valid_indices)

        excess_indices = tf.greater_equal(
            sparse_words.indices, tf.constant([self.num_steps], dtype=tf.int64))
        excess_indices = tf.reshape(
            tf.split(excess_indices, [1, 1], axis=1)[1], [-1])
        excess_sparse_words = tf.sparse_retain(sparse_words, excess_indices)

        # Sentence lengths = number of kept words per row.
        int_values = tf.ones(
            shape=tf.shape(valid_sparse_words.values), dtype=tf.int64)
        int_valid_sparse_words = tf.SparseTensor(
            indices=valid_sparse_words.indices,
            values=int_values,
            dense_shape=valid_sparse_words.dense_shape)
        input_sentences_lengths = tf.sparse_reduce_sum(
            int_valid_sparse_words, axis=1)

        # Sparse to dense [batch, num_steps], padded with the first
        # special vocabulary token.
        default_padding_word = self.data_utils._START_VOCAB[0]
        words = tf.sparse_to_dense(
            sparse_indices=valid_sparse_words.indices,
            output_shape=[valid_sparse_words.dense_shape[0], self.num_steps],
            sparse_values=valid_sparse_words.values,
            default_value=default_padding_word)

        # Words -> ids; unknown words map to _START_VOCAB_ID[3].
        words_table_list = self._read_vocab_lines(
            os.path.join(self.vocab_path, 'words_vocab.txt'))
        words_table_tensor = tf.constant(words_table_list, dtype=tf.string)
        words_table = lookup.index_table_from_tensor(
            mapping=words_table_tensor,
            default_value=self.data_utils._START_VOCAB_ID[3])
        words_ids = words_table.lookup(words)

        # Model prediction.
        with tf.variable_scope('model', reuse=None):
            logits = self.sequence_labeling_model.inference(
                words_ids, input_sentences_lengths, self.num_classes,
                is_training=False)

            if self.use_crf:
                logits = tf.reshape(
                    logits, shape=[-1, self.num_steps, self.num_classes])
                # NOTE(review): this variable's scoped name must match the
                # 'transitions' variable stored in the training checkpoint.
                transition_params = tf.get_variable(
                    "transitions", [self.num_classes, self.num_classes])
                input_sentences_lengths = tf.to_int32(input_sentences_lengths)
                predict_labels_ids, sequence_scores = crf.crf_decode(
                    logits, transition_params, input_sentences_lengths)
                predict_labels_ids = tf.to_int64(predict_labels_ids)
                sequence_scores = tf.reshape(sequence_scores, shape=[-1, 1])
                normalized_sequence_scores = (
                    self.tensorflow_utils.score_normalize(sequence_scores))
                # [batch, 1] x [1, num_steps]: repeat the per-sentence score
                # at every step.
                predict_scores = tf.matmul(
                    normalized_sequence_scores,
                    tf.ones(shape=[1, self.num_steps], dtype=tf.float32))
            else:
                props = tf.nn.softmax(logits)
                max_prop_values, max_prop_indices = tf.nn.top_k(props, k=1)
                predict_labels_ids = tf.reshape(
                    max_prop_indices, shape=[-1, self.num_steps])
                predict_labels_ids = tf.to_int64(predict_labels_ids)
                predict_scores = tf.reshape(
                    max_prop_values, shape=[-1, self.num_steps])
        predict_scores = tf.as_string(predict_scores, precision=3)

        # Label ids -> label strings; out-of-range ids map to default_label.
        labels_table_list = self._read_vocab_lines(
            os.path.join(self.vocab_path, 'labels_vocab.txt'))
        labels_table_tensor = tf.constant(labels_table_list, dtype=tf.string)
        labels_table = lookup.index_to_string_table_from_tensor(
            mapping=labels_table_tensor, default_value=self.default_label)
        predict_labels = labels_table.lookup(predict_labels_ids)

        sparse_predict_labels = self.tensorflow_utils.sparse_concat(
            predict_labels, valid_sparse_words, excess_sparse_words,
            self.default_label)
        sparse_predict_scores = self.tensorflow_utils.sparse_concat(
            predict_scores, valid_sparse_words, excess_sparse_words, '0.0')

        # Node names 'predict_labels' / 'predict_scores' are the outputs
        # used by freeze_graph.
        self.format_predict_labels = self.tensorflow_utils.sparse_string_join(
            sparse_predict_labels, 'predict_labels')
        self.format_predict_scores = self.tensorflow_utils.sparse_string_join(
            sparse_predict_scores, 'predict_scores')

        saver = tf.train.Saver()
        tables_init_op = tf.tables_initializer()
        self.sess = tf.Session()
        self.sess.run(tables_init_op)
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_path)
        if ckpt and ckpt.model_checkpoint_path:
            print('read model from {}'.format(ckpt.model_checkpoint_path))
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found at %s' % self.checkpoint_path)

    def predict(self, words_list):
        """Predict labels and scores for space-separated sentences.

        Word-to-id conversion happens inside the TensorFlow graph. Each
        sentence is first cut with ``data_utils.split_long_sentence(...,
        num_steps)``. All pieces are predicted in a single session run, and
        the piece results are re-joined with spaces per input sentence.

        :param words_list: list of sentences, words separated by spaces.
        :return: ``(labels, scores)``, two lists of space-separated strings,
            one entry per input sentence. An empty input returns ``([], [])``
            without running the session.
        """
        if not words_list:
            return [], []

        split_words_list = []
        # For each input sentence, the indexes of its pieces in
        # split_words_list.
        map_split_indexes = []
        for words in words_list:
            pieces = self.data_utils.split_long_sentence(words, self.num_steps)
            start = len(split_words_list)
            map_split_indexes.append(list(range(start, start + len(pieces))))
            split_words_list.extend(pieces)

        predict_labels, predict_scores = self.sess.run(
            [self.format_predict_labels, self.format_predict_scores],
            feed_dict={self.input_sentences: split_words_list})
        predict_labels_str = [label.decode('utf-8') for label in predict_labels]
        predict_scores_str = [score.decode('utf-8') for score in predict_scores]

        merge_predict_labels_str = []
        merge_predict_scores_str = []
        for indexes in map_split_indexes:
            merge_predict_labels_str.append(
                ' '.join(predict_labels_str[index] for index in indexes))
            merge_predict_scores_str.append(
                ' '.join(predict_scores_str[index] for index in indexes))
        return merge_predict_labels_str, merge_predict_scores_str

    def file_predict(self, data_filename, predict_filename):
        """Predict every labelled line of a file and write a text report.

        Lines are parsed with ``data_utils.split`` into words and gold
        labels; lines where either is empty are skipped. For each kept line,
        ``predict_filename`` receives the following, then a blank line:

        * ``Passage:`` the words concatenated without separators;
        * ``SinglePredict:`` ``word/predicted_label`` pairs;
        * ``Merge:`` gold labels merged with ``data_utils.merge_label``;
        * ``MergePredict:`` predicted labels merged the same way.

        Both merged views drop entries labelled ``self.default_label``.

        :param data_filename: UTF-8 input file.
        :param predict_filename: UTF-8 report file (overwritten).
        """
        print('Predict file ' + data_filename)
        sentence_list = []
        words_list = []
        labels_list = []
        predict_labels_list = []
        with open(data_filename, encoding='utf-8', mode='rt') as data_file:
            for line in data_file:
                words, labels = self.data_utils.split(line)
                if words and labels:
                    sentence_list.append(''.join(words))
                    words_list.append(' '.join(words))
                    labels_list.append(' '.join(labels))
                    predict_labels, _ = self.predict([' '.join(words)])
                    predict_labels_list.append(predict_labels[0])

        word_predict_label_list = []
        word_category_list = []
        word_predict_category_list = []
        for words, labels, predict_labels in zip(words_list, labels_list,
                                                 predict_labels_list):
            word_list = words.split()
            label_list = labels.split()
            predict_label_list = predict_labels.split()
            word_predict_label_list.append(' '.join(
                word + '/' + predict_label
                for word, predict_label in zip(word_list, predict_label_list)))

            # Merge gold labels.
            merge_word_list, merge_label_list = self.data_utils.merge_label(
                word_list, label_list)
            word_category_list.append(' '.join(
                word + '/' + label
                for word, label in zip(merge_word_list, merge_label_list)
                if label != self.default_label))

            # Merge predicted labels, filtered by the same default label as
            # the gold view (previously a hard-coded 'O').
            merge_predict_word_list, merge_predict_label_list = (
                self.data_utils.merge_label(word_list, predict_label_list))
            word_predict_category_list.append(' '.join(
                predict_word + '/' + predict_label
                for predict_word, predict_label in zip(
                    merge_predict_word_list, merge_predict_label_list)
                if predict_label != self.default_label))

        with open(predict_filename, encoding='utf-8', mode='wt') as predict_file:
            for (sentence, word_predict_label, word_category,
                 word_predict_category) in zip(sentence_list,
                                               word_predict_label_list,
                                               word_category_list,
                                               word_predict_category_list):
                predict_file.write('Passage: ' + sentence + '\n')
                predict_file.write('SinglePredict: ' + word_predict_label + '\n')
                predict_file.write('Merge: ' + word_category + '\n')
                predict_file.write('MergePredict: ' + word_predict_category + '\n\n')

    def freeze_graph(self):
        """Write the graph, with variables folded into constants, as a .pb.

        Output nodes kept: ``init_all_tables``, ``predict_labels`` and
        ``predict_scores``. Written to
        ``<freeze_graph_path>/frozen_graph.pb`` in binary form.
        """
        graph = tf.graph_util.convert_variables_to_constants(
            self.sess, self.sess.graph_def,
            ['init_all_tables', 'predict_labels', 'predict_scores'])
        tf.train.write_graph(graph, self.freeze_graph_path, 'frozen_graph.pb',
                             as_text=False)
        print('Successfully freeze model to %s' % self.freeze_graph_path)

    def saved_model_pb(self):
        """Export a SavedModel (.pb plus variables) for TensorFlow Serving.

        The export goes to ``<saved_model_path>/1``; an existing directory
        there is deleted first. The signature ``predict_segment`` maps
        ``input_sentences`` to ``predict_labels`` and ``predict_scores``.
        """
        saved_model_path = os.path.join(self.saved_model_path, '1')
        if os.path.exists(saved_model_path):
            shutil.rmtree(saved_model_path)

        builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)
        input_tensor_info = tf.saved_model.utils.build_tensor_info(
            self.input_sentences)
        output_labels_tensor_info = tf.saved_model.utils.build_tensor_info(
            self.format_predict_labels)
        output_scores_tensor_info = tf.saved_model.utils.build_tensor_info(
            self.format_predict_scores)
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_sentences': input_tensor_info},
            outputs={
                'predict_labels': output_labels_tensor_info,
                'predict_scores': output_scores_tensor_info
            },
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        # Lookup tables must be initialised when the model is loaded.
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            self.sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={'predict_segment': prediction_signature},
            legacy_init_op=legacy_init_op)
        builder.save()
        print('Successfully exported model to %s' % saved_model_path)
class SequenceLabelingModel(object):
    """Embedding + (bi)RNN sequence labeler with softmax or CRF losses.

    All hyper-parameters are read from FLAGS at construction time.
    """

    def __init__(self):
        self.raw_data_path = FLAGS.raw_data_path
        self.vocab_path = FLAGS.vocab_path
        self.use_stored_embedding = FLAGS.use_stored_embedding
        self.use_lstm = FLAGS.use_lstm
        self.use_dynamic_rnn = FLAGS.use_dynamic_rnn
        self.use_bidirectional_rnn = FLAGS.use_bidirectional_rnn
        self.batch_size = FLAGS.batch_size
        self.num_steps = FLAGS.num_steps
        self.num_layers = FLAGS.num_layers
        self.embedding_size = FLAGS.embedding_size
        # The RNN width is tied to the embedding width (no separate flag).
        self.hidden_size = FLAGS.embedding_size
        self.keep_prob = FLAGS.keep_prob
        # Number of rows of a trainable embedding matrix. It used to be read
        # without ever being assigned (AttributeError); now it is resolved
        # lazily from words_vocab.txt by _get_vocab_size().
        self.vocab_size = None
        self.tensorflow_utils = TensorflowUtils()

    @staticmethod
    def _count_vocabulary_entries(vocab_filename):
        """Return the number of non-blank lines in a UTF-8 vocabulary file."""
        with open(vocab_filename, encoding='utf-8', mode='rt') as vocab_file:
            return sum(1 for line in vocab_file if line.strip())

    def _get_vocab_size(self):
        """Return the word vocabulary size, counting words_vocab.txt once.

        Non-blank lines are counted, which matches how the prediction graph
        builds its word lookup table from the same file. A value already
        assigned to ``self.vocab_size`` is used as-is.
        """
        if getattr(self, 'vocab_size', None) is None:
            self.vocab_size = self._count_vocabulary_entries(
                os.path.join(self.vocab_path, 'words_vocab.txt'))
        return self.vocab_size

    def _make_rnn_cell(self, is_training):
        """Create one fresh LSTM/GRU cell, wrapped in dropout when training."""
        if self.use_lstm:
            cell = tf.nn.rnn_cell.LSTMCell(
                num_units=self.hidden_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                forget_bias=1.0)
        else:
            cell = tf.nn.rnn_cell.GRUCell(num_units=self.hidden_size)
        if is_training and self.keep_prob < 1.0:
            cell = tf.nn.rnn_cell.DropoutWrapper(
                cell, output_keep_prob=self.keep_prob)
        return cell

    def inference(self, inputs, inputs_sequence_length, num_classes,
                  is_training):
        """Build the embedding -> (bi)RNN -> linear projection graph.

        :param inputs: word-id tensor looked up in the embedding matrix
            (presumably [batch, num_steps]; the static-RNN path unstacks
            axis 1).
        :param inputs_sequence_length: per-sentence lengths given to the RNN.
        :param num_classes: number of output labels.
        :param is_training: enables input/output dropout when keep_prob < 1.
        :return: logits of shape [-1, num_classes], one row per time step of
            the flattened RNN output.
        """
        with tf.device('/cpu:0'):
            if self.use_stored_embedding:
                embedding = self.tensorflow_utils.load_embedding(
                    os.path.join(self.raw_data_path, 'embedding.txt'),
                    os.path.join(self.vocab_path, 'words_vocab.txt'))
            else:
                embedding = tf.get_variable(
                    'embedding', [self._get_vocab_size(), self.embedding_size],
                    initializer=tf.random_uniform_initializer(),
                    dtype=tf.float32)
            inputs_embedding = tf.nn.embedding_lookup(embedding, inputs)

        if is_training and self.keep_prob < 1:
            inputs_embedding = tf.nn.dropout(inputs_embedding, self.keep_prob)

        bi_flag = 2 if self.use_bidirectional_rnn else 1
        # Every layer needs its own cell instance. Reusing one instance for
        # all layers either errors or ties the layer weights together.
        # NOTE(review): checkpoints written by the old shared-cell code with
        # num_layers > 1 may lack the cell_1+ variables.
        rnn_cell_collection = [
            tf.nn.rnn_cell.MultiRNNCell(
                [self._make_rnn_cell(is_training)
                 for _ in range(self.num_layers)])
            for _ in range(bi_flag)
        ]

        if self.use_dynamic_rnn:
            if self.use_bidirectional_rnn:
                outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                    rnn_cell_collection[0], rnn_cell_collection[1],
                    inputs_embedding, dtype=tf.float32,
                    sequence_length=inputs_sequence_length)
                outputs = tf.concat(outputs, axis=2)
            else:
                outputs, _ = tf.nn.dynamic_rnn(
                    rnn_cell_collection[0], inputs_embedding,
                    dtype=tf.float32, sequence_length=inputs_sequence_length)
        else:
            # Static RNNs take a Python list with one tensor per time step.
            inputs_embedding = tf.unstack(inputs_embedding, axis=1)
            if self.use_bidirectional_rnn:
                outputs, _, _ = tf.nn.static_bidirectional_rnn(
                    rnn_cell_collection[0], rnn_cell_collection[1],
                    inputs_embedding, dtype=tf.float32,
                    sequence_length=inputs_sequence_length)
            else:
                outputs, _ = tf.nn.static_rnn(
                    rnn_cell_collection[0], inputs_embedding,
                    dtype=tf.float32, sequence_length=inputs_sequence_length)
            outputs = tf.stack(outputs, axis=1)

        outputs = tf.reshape(outputs, shape=[-1, bi_flag * self.hidden_size])
        weights = tf.get_variable(
            'weights', [bi_flag * self.hidden_size, num_classes],
            dtype=tf.float32)
        biases = tf.get_variable('biases', [num_classes], dtype=tf.float32)
        logits = tf.matmul(outputs, weights) + biases
        return logits

    def loss(self, logits, labels):
        """Mean sparse softmax cross entropy between logits and labels."""
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        loss = tf.reduce_mean(cross_entropy, name='loss')
        return loss

    def accuracy(self, logits, labels):
        """Fraction of rows whose argmax prediction equals the label."""
        props = tf.nn.softmax(logits)
        prediction_labels = tf.argmax(props, 1)
        correct_prediction = tf.equal(prediction_labels, labels)
        accuracy_value = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        return accuracy_value

    def slice_seq(self, logits, labels, sequence_lengths):
        """Keep only the rows of valid (non-padded) time steps.

        Row ``i * num_steps + j`` is kept when ``j < sequence_lengths[i]``.
        Order is preserved, as with the previous per-batch-row gather loop,
        but the batch size no longer has to equal ``self.batch_size``.

        :param logits: [batch * num_steps, num_classes] logits.
        :param labels: labels reshaped internally to [batch * num_steps].
        :param sequence_lengths: per-sentence valid lengths.
        :return: ``(slice_logits, slice_labels)``.
        """
        labels = tf.reshape(labels, shape=[-1])
        valid_mask = tf.reshape(
            tf.sequence_mask(sequence_lengths, self.num_steps), shape=[-1])
        slice_logits = tf.boolean_mask(logits, valid_mask)
        slice_labels = tf.boolean_mask(labels, valid_mask)
        return slice_logits, slice_labels

    def crf_loss(self, logits, labels, sequence_lengths, num_classes):
        """Mean negative CRF log-likelihood.

        :return: ``(loss, transition_params)``; the transition matrix is the
            variable created by ``crf_log_likelihood``.
        """
        logits = tf.reshape(logits, shape=[-1, self.num_steps, num_classes])
        labels = tf.cast(labels, tf.int32)
        log_likelihood, transition_params = crf.crf_log_likelihood(
            logits, labels, sequence_lengths)
        loss = tf.reduce_mean(-log_likelihood, name='loss')
        return loss, transition_params

    def crf_accuracy(self, logits, labels, sequence_length, transition_params,
                     num_classes):
        """Token accuracy of CRF-decoded labels over valid time steps only.

        Padded positions (``j >= sequence_length[i]``) are excluded, matching
        the non-CRF path, which slices them away with ``slice_seq``.
        """
        logits = tf.reshape(logits, shape=[-1, self.num_steps, num_classes])
        sequence_length = tf.to_int32(sequence_length)
        predict_indices, _ = crf.crf_decode(logits, transition_params,
                                            sequence_length)
        predict_indices = tf.to_int64(predict_indices,
                                      name='predict_label_indices')
        correct_prediction = tf.equal(predict_indices, labels)
        valid_mask = tf.sequence_mask(sequence_length, self.num_steps)
        correct_prediction = tf.boolean_mask(correct_prediction, valid_mask)
        accuracy_value = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        return accuracy_value
class Train(object):
    """Train the sequence-labeling model from TFRecords and checkpoint it."""

    def __init__(self):
        self.tfrecords_path = FLAGS.tfrecords_path
        self.checkpoint_path = FLAGS.checkpoint_path
        self.tensorboard_path = FLAGS.tensorboard_path
        self.use_crf = FLAGS.use_crf
        self.learning_rate = FLAGS.learning_rate
        self.learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
        self.decay_steps = FLAGS.decay_steps
        self.clip_norm = FLAGS.clip_norm
        self.max_training_step = FLAGS.max_training_step
        self.train_tfrecords_filename = os.path.join(self.tfrecords_path,
                                                     'train.tfrecords')
        self.test_tfrecords_filename = os.path.join(self.tfrecords_path,
                                                    'test.tfrecords')
        self.data_utils = DataUtils()
        self.num_classes = self.data_utils.get_vocabulary_size(
            os.path.join(FLAGS.vocab_path, 'labels_vocab.txt'))
        self.tensorflow_utils = TensorflowUtils()
        self.sequence_labeling_model = SequenceLabelingModel()

    def train(self):
        """Train the (bi)RNN (+CRF) model until stopped.

        Training stops at ``max_training_step``, or when the input queue is
        exhausted. Every 100 steps the accuracy on a test batch is evaluated,
        printed and written to TensorBoard. A checkpoint is saved (after
        ``data_clean.clean_checkpoint``) only when both accuracy and loss
        improve on their best values so far. The summary writer is closed
        on exit so pending events are flushed.
        """
        train_data = self.tensorflow_utils.read_and_decode(
            self.train_tfrecords_filename)
        train_batch_features, train_batch_labels, train_batch_features_lengths = train_data
        test_data = self.tensorflow_utils.read_and_decode(
            self.test_tfrecords_filename)
        test_batch_features, test_batch_labels, test_batch_features_lengths = test_data

        with tf.device('/cpu:0'):
            global_step = tf.Variable(0, name='global_step', trainable=False)
            # Decay the learning rate exponentially based on the number of steps.
            lr = tf.train.exponential_decay(self.learning_rate, global_step,
                                            self.decay_steps,
                                            self.learning_rate_decay_factor,
                                            staircase=True)
            optimizer = tf.train.RMSPropOptimizer(lr)

        with tf.variable_scope('model'):
            logits = self.sequence_labeling_model.inference(
                train_batch_features, train_batch_features_lengths,
                self.num_classes, is_training=True)
            train_batch_labels = tf.to_int64(train_batch_labels)
            if self.use_crf:
                loss, transition_params = self.sequence_labeling_model.crf_loss(
                    logits, train_batch_labels, train_batch_features_lengths,
                    self.num_classes)
            else:
                slice_logits, slice_train_batch_labels = self.sequence_labeling_model.slice_seq(
                    logits, train_batch_labels, train_batch_features_lengths)
                loss = self.sequence_labeling_model.loss(
                    slice_logits, slice_train_batch_labels)

        # Evaluation graph shares the training variables.
        with tf.variable_scope('model', reuse=True):
            accuracy_logits = self.sequence_labeling_model.inference(
                test_batch_features, test_batch_features_lengths,
                self.num_classes, is_training=False)
            test_batch_labels = tf.to_int64(test_batch_labels)
            if self.use_crf:
                accuracy = self.sequence_labeling_model.crf_accuracy(
                    accuracy_logits, test_batch_labels,
                    test_batch_features_lengths, transition_params,
                    self.num_classes)
            else:
                slice_accuracy_logits, slice_test_batch_labels = self.sequence_labeling_model.slice_seq(
                    accuracy_logits, test_batch_labels,
                    test_batch_features_lengths)
                accuracy = self.sequence_labeling_model.accuracy(
                    slice_accuracy_logits, slice_test_batch_labels)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('lr', lr)

        # Compute, clip (by global norm) and apply gradients.
        gradients, variables = zip(*optimizer.compute_gradients(loss))
        clip_gradients, _ = tf.clip_by_global_norm(gradients, self.clip_norm)
        train_op = optimizer.apply_gradients(zip(clip_gradients, variables),
                                             global_step=global_step)

        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=None)
        checkpoint_filename = os.path.join(self.checkpoint_path, 'model.ckpt')
        china_tz = pytz.timezone('Asia/Shanghai')

        with tf.Session() as sess:
            summary_op = tf.summary.merge_all()
            writer = tf.summary.FileWriter(self.tensorboard_path, sess.graph)
            sess.run(init_op)

            ckpt = tf.train.get_checkpoint_state(self.checkpoint_path)
            if ckpt and ckpt.model_checkpoint_path:
                print('Continue training from the model {}'.format(
                    ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            max_accuracy = 0.0
            min_loss = float('inf')
            try:
                while not coord.should_stop():
                    _, loss_value, step = sess.run([train_op, loss, global_step])
                    if step % 100 == 0:
                        accuracy_value, summary_value, lr_value = sess.run(
                            [accuracy, summary_op, lr])
                        # Log every evaluation, not only improving ones.
                        writer.add_summary(summary_value, step)
                        current_time = datetime.datetime.now(china_tz)
                        print('[{}] Step: {}, loss: {}, accuracy: {}, lr: {}'.
                              format(current_time, step, loss_value,
                                     accuracy_value, lr_value))
                        if accuracy_value > max_accuracy and loss_value < min_loss:
                            data_clean.clean_checkpoint(self.checkpoint_path)
                            saver.save(sess, checkpoint_filename,
                                       global_step=step)
                            print('save model to %s-%d' %
                                  (checkpoint_filename, step))
                            max_accuracy = accuracy_value
                            min_loss = loss_value
                    if step >= self.max_training_step:
                        print('Done training after %d step' % step)
                        break
            except tf.errors.OutOfRangeError:
                print('Done training after reading all data')
            finally:
                coord.request_stop()
                coord.join(threads)
                writer.close()