def train_step(model, data, summary_op, summary_writer):
    """Run one training epoch and write a summary for every batch.

    Batches are pulled from a fresh initializable iterator over
    `data.dataset`. The epoch stops after `data.num_batches` batches, or
    earlier if the iterator raises `tf.errors.OutOfRangeError`.

    Args:
        model: provides `sess`, `config.emb_class` and the fetched tensors
            (`global_step`, `train_op`, `loss`, `accuracy`, `f1`,
            `learning_rate`, plus `bert_embeddings_subgraph` when the
            embedding class contains 'bert').
        data: provides `dataset`, `num_batches` and `max_sentence_length`.
        summary_op: summary tensor evaluated together with the train op.
        summary_writer: receives `add_summary(summaries, step)` per batch.
    """
    def _debug_first_row(title, container, key):
        # Log the title, then the shape and the content of the first row
        # of container[key].
        tf.logging.debug(title)
        first_row = container[key][:1]
        tf.logging.debug(' '.join([str(dim) for dim in np.shape(first_row)]))
        tf.logging.debug(' '.join([str(item) for item in first_row]))

    epoch_began = time.time()
    session = model.sess
    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    progress = Progbar(target=data.num_batches)
    batch_iterator = data.dataset.make_initializable_iterator()
    next_batch = batch_iterator.get_next()
    session.run(batch_iterator.initializer)

    for batch_no in range(data.num_batches):
        try:
            batch = session.run(next_batch)
        except tf.errors.OutOfRangeError:
            break

        # The trailing True is presumably the "is training" flag of the
        # feed builder -- confirm against feed.build_feed_dict.
        feed_dict = feed.build_feed_dict(
            model, batch, data.max_sentence_length, True)

        if 'bert' in model.config.emb_class:
            # The bert embeddings are computed in a separate run and then
            # injected back into the feed dict.
            bert_embeddings = session.run(
                [model.bert_embeddings_subgraph],
                feed_dict=feed_dict,
                options=run_options)
            if batch_no == 0:
                # Dump the first example of the first batch for debugging.
                _debug_first_row('# bert_token_ids', batch, 'bert_token_ids')
                _debug_first_row('# bert_token_masks', batch, 'bert_token_masks')
                _debug_first_row('# bert_embedding', bert_embeddings, 0)
                _debug_first_row('# bert_wordidx2tokenidx', batch,
                                 'bert_wordidx2tokenidx')
            feed.update_feed_dict(
                model, feed_dict, bert_embeddings,
                batch['bert_wordidx2tokenidx'])

        fetches = [
            model.global_step,
            summary_op,
            model.train_op,
            model.loss,
            model.accuracy,
            model.f1,
            model.learning_rate,
        ]
        step, summaries, _, loss, accuracy, f1, learning_rate = session.run(
            fetches, feed_dict=feed_dict, options=run_options)

        summary_writer.add_summary(summaries, step)
        progress.update(
            batch_no + 1,
            [('step', step),
             ('train loss', loss),
             ('train accuracy', accuracy),
             ('train f1', f1),
             ('lr(invalid if use_bert_optimization)', learning_rate)])

    elapsed = time.time() - epoch_began
    tf.logging.debug(
        '\nduration_time : ' + str(elapsed) + ' sec for this epoch')
def dev_step(model, data, summary_writer, epoch):
    """Evaluate the model on dev data, log metrics and write a summary.

    The dev set is evaluated batch by batch (to prevent out-of-memory
    errors). Per-batch loss/accuracy/f1 are averaged over the batches that
    were actually evaluated, and token-level and chunk-level scores are
    computed over the concatenated predictions.

    Args:
        model: provides `sess`, `config` and the fetched tensors
            (`global_step`, `logits_indices`, `sentence_lengths`, `loss`,
            `accuracy`, `f1`, plus `bert_embeddings_subgraph` when the
            embedding class contains 'bert').
        data: provides `dataset`, `num_batches` and `max_sentence_length`.
        summary_writer: receives one manually built summary, keyed by the
            global step read during the last evaluated batch.
        epoch: current epoch, used only in the log message.

    Returns:
        tuple: `(token_f1, chunk_f1, avg_f1)`.
    """
    sess = model.sess
    runopts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    sum_loss = 0.0
    sum_accuracy = 0.0
    sum_f1 = 0.0
    sum_output_indices = None
    sum_logits_indices = None
    sum_sentence_lengths = None
    global_step = 0
    num_evaluated = 0  # batches really processed; may be < data.num_batches
    prog = Progbar(target=data.num_batches)
    iterator = data.dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess.run(iterator.initializer)

    # evaluate on dev data sliced by batch_size to prevent OOM(Out Of Memory).
    for idx in range(data.num_batches):
        try:
            dataset = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        # The trailing False is presumably the "is training" flag of the
        # feed builder -- confirm against feed.build_feed_dict.
        feed_dict = feed.build_feed_dict(
            model, dataset, data.max_sentence_length, False)
        if 'bert' in model.config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run(
                [model.bert_embeddings_subgraph],
                feed_dict=feed_dict,
                options=runopts)
            # update feed_dict
            feed.update_feed_dict(
                model, feed_dict, bert_embeddings,
                dataset['bert_wordidx2tokenidx'])
        global_step, logits_indices, sentence_lengths, loss, accuracy, f1 = \
            sess.run([model.global_step, model.logits_indices,
                      model.sentence_lengths, model.loss, model.accuracy,
                      model.f1],
                     feed_dict=feed_dict,
                     options=runopts)
        prog.update(idx + 1,
                    [('dev loss', loss),
                     ('dev accuracy', accuracy),
                     ('dev f1', f1)])
        sum_loss += loss
        sum_accuracy += accuracy
        sum_f1 += f1
        # Gold tag indices are taken as the argmax over axis 2 of 'tags'.
        sum_output_indices = np_concat(
            sum_output_indices, np.argmax(dataset['tags'], 2))
        sum_logits_indices = np_concat(sum_logits_indices, logits_indices)
        sum_sentence_lengths = np_concat(
            sum_sentence_lengths, sentence_lengths)
        num_evaluated += 1

    # Average over the batches that were actually evaluated: the loop can
    # stop early on OutOfRangeError, and dividing by data.num_batches would
    # then understate every metric. The max() guard avoids a
    # ZeroDivisionError when no batch was produced.
    denominator = max(num_evaluated, 1)
    avg_loss = sum_loss / denominator
    avg_accuracy = sum_accuracy / denominator
    avg_f1 = sum_f1 / denominator

    tag_preds = model.config.logits_indices_to_tags_seq(
        sum_logits_indices, sum_sentence_lengths)
    tag_corrects = model.config.logits_indices_to_tags_seq(
        sum_output_indices, sum_sentence_lengths)

    tf.logging.debug('\n[epoch %s/%s] dev precision, recall, f1(token): ' %
                     (epoch, model.config.epoch))
    token_f1, l_token_prec, l_token_rec, l_token_f1 = TokenEval.compute_f1(
        model.config.class_size, sum_logits_indices, sum_output_indices,
        sum_sentence_lengths)
    tf.logging.debug('[' + ' '.join([str(x) for x in l_token_prec]) + ']')
    tf.logging.debug('[' + ' '.join([str(x) for x in l_token_rec]) + ']')
    tf.logging.debug('[' + ' '.join([str(x) for x in l_token_f1]) + ']')

    chunk_prec, chunk_rec, chunk_f1 = ChunkEval.compute_f1(
        tag_preds, tag_corrects)
    tf.logging.debug(
        'dev precision(chunk), recall(chunk), f1(chunk): %s, %s, %s' %
        (chunk_prec, chunk_rec, chunk_f1) +
        '(invalid for bert due to X tag)')

    # create summaries manually.
    summary_value = [
        tf.Summary.Value(tag='loss', simple_value=avg_loss),
        tf.Summary.Value(tag='accuracy', simple_value=avg_accuracy),
        tf.Summary.Value(tag='f1', simple_value=avg_f1),
        tf.Summary.Value(tag='token_f1', simple_value=token_f1),
        tf.Summary.Value(tag='chunk_f1', simple_value=chunk_f1),
    ]
    summaries = tf.Summary(value=summary_value)
    summary_writer.add_summary(summaries, global_step)

    return token_f1, chunk_f1, avg_f1
def inference_bucket(config):
    """Tag buckets read from stdin and write the tagged lines to stdout.

    Non-empty stdin lines are collected into a bucket; an empty line (or
    end of input) closes the bucket, which is then run through the restored
    model. Each bucket line is written back followed by a space and its
    predicted tag, and every bucket is terminated by an empty line.

    Timing of each bucket is logged. The first processed bucket is treated
    as warm-up and excluded from the total/average statistics.

    Args:
        config: provides `restore` (checkpoint to restore), `emb_class` and
            `logit_indices_to_tags`; it is also passed to `Model`.
    """
    # create model and compile
    model = Model(config)
    model.compile()
    sess = model.sess

    def _tag_bucket(bucket):
        # Run the model on one bucket, print the tagged lines, log the
        # elapsed time and return it in seconds.
        start_time = time.time()
        inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input)
        if 'bert' in config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run(
                [model.bert_embeddings_subgraph], feed_dict=feed_dict)
            # update feed_dict
            feed.update_feed_dict(
                model, feed_dict, bert_embeddings,
                inp.example['bert_wordidx2tokenidx'], -1)
        logits_indices, sentence_lengths = sess.run(
            [model.logits_indices, model.sentence_lengths],
            feed_dict=feed_dict)
        # Only the first row of the outputs is used: one bucket per run.
        tags = config.logit_indices_to_tags(
            logits_indices[0], sentence_lengths[0])
        for i, entry in enumerate(bucket):
            sys.stdout.write(entry + ' ' + tags[i] + '\n')
        sys.stdout.write('\n')
        duration_time = time.time() - start_time
        tf.logging.info('duration_time : ' + str(duration_time) + ' sec')
        return duration_time

    try:
        # restore model
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        sys.stderr.write('model restored' + '\n')

        num_buckets = 0
        total_duration_time = 0.0
        bucket = []
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:  # end of input
                break
            line = line.strip()
            if line:
                bucket.append(line)
                continue
            if not bucket:
                # Consecutive blank lines: nothing to flush.
                continue
            duration_time = _tag_bucket(bucket)
            bucket = []
            num_buckets += 1
            if num_buckets != 1:
                # first one may take longer, so ignore it in the timing.
                total_duration_time += duration_time

        if bucket:
            # Input ended without a trailing blank line: flush the rest,
            # applying the same warm-up rule as inside the loop.
            duration_time = _tag_bucket(bucket)
            num_buckets += 1
            if num_buckets != 1:
                total_duration_time += duration_time

        # The warm-up bucket is not timed, so the average is taken over
        # num_buckets - 1; with zero or one bucket there is nothing to
        # average and 0.0 is reported instead of dividing by 0 or -1.
        num_timed = num_buckets - 1 if num_buckets > 1 else 0
        average_time = total_duration_time / num_timed if num_timed else 0.0
        out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n'
        out += 'average processing time / bucket : ' + str(average_time) + ' sec'
        tf.logging.info(out)
    finally:
        # Release the session even if restoring or inference failed.
        sess.close()
def inference_line(config):
    """Tag raw text lines read from stdin and write the result to stdout.

    Each non-empty stdin line is tokenized with spaCy and turned into a
    bucket of 'text pos O entity' lines, which is run through the restored
    model. Each bucket line is written back followed by a space and its
    predicted tag, and every input line's output is terminated by an empty
    line. Lines for which bucket building fails are reported on stderr and
    skipped.

    Args:
        config: provides `restore` (checkpoint to restore), `emb_class` and
            `logit_indices_to_tags`; it is also passed to `Model`.
    """
    def get_entity(doc, begin, end):
        # Return the 'B-'/'I-' label of the first spaCy entity whose
        # character span covers [begin, end], or 'O' if there is none.
        for ent in doc.ents:
            # check included
            if ent.start_char <= begin and end <= ent.end_char:
                if ent.start_char == begin:
                    return 'B-' + ent.label_
                return 'I-' + ent.label_
        return 'O'

    def build_bucket(nlp, line):
        # Build one 'text tag O entity' line per spaCy token.
        bucket = []
        doc = nlp(line)
        for token in doc:
            begin = token.idx
            end = begin + len(token.text) - 1  # inclusive last char offset
            entity = get_entity(doc, begin, end)  # entity by spacy
            # 'O' in third position: no chunking info
            bucket.append(' '.join([token.text, token.tag_, 'O', entity]))
        return bucket

    # Imported here so that spaCy is only needed for this entry point.
    import spacy
    nlp = spacy.load('en')

    # create model and compile
    model = Model(config)
    model.compile()
    sess = model.sess
    try:
        # restore model
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        tf.logging.info('model restored' + '\n')

        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:  # end of input
                break
            line = line.strip()
            if not line:
                continue
            # create bucket
            try:
                bucket = build_bucket(nlp, line)
            except Exception as e:
                # Best effort: report the failure and go on with the next
                # input line.
                sys.stderr.write(str(e) + '\n')
                continue
            # NOTE(review): inference_bucket calls this helper with a third
            # `Input` argument; confirm which arity feed.build_input_feed_dict
            # actually expects.
            inp, feed_dict = feed.build_input_feed_dict(model, bucket)
            if 'bert' in config.emb_class:
                # compute bert embedding at runtime
                bert_embeddings = sess.run(
                    [model.bert_embeddings_subgraph], feed_dict=feed_dict)
                # update feed_dict
                feed.update_feed_dict(
                    model, feed_dict, bert_embeddings,
                    inp.example['bert_wordidx2tokenidx'], -1)
            logits_indices, sentence_lengths = sess.run(
                [model.logits_indices, model.sentence_lengths],
                feed_dict=feed_dict)
            # Only the first row of the outputs is used: one bucket per run.
            tags = config.logit_indices_to_tags(
                logits_indices[0], sentence_lengths[0])
            for i, entry in enumerate(bucket):
                sys.stdout.write(entry + ' ' + tags[i] + '\n')
            sys.stdout.write('\n')
    finally:
        # Release the session even if restoring or inference failed.
        sess.close()