コード例 #1
0
ファイル: inference.py プロジェクト: johndpope/etagger
def inference_bulk(config):
    """Run inference over the whole test file and print evaluation scores.

    Loads ``data/test.txt`` with ``Input``, restores the model from
    ``config.restore`` and evaluates the entire file in a single
    ``sess.run``. Prints token-level scores (via ``TokenEval``) and
    chunk-level precision/recall/F1 (via ``ChunkEval``).

    Args:
        config: configuration object; this function reads ``restore``,
            ``class_size`` and ``use_crf`` and passes it to ``Input`` and
            ``Model``.
    """

    # Build input data
    test_file = 'data/test.txt'
    test_data = Input(test_file, config)
    print('max_sentence_length = %d' % test_data.max_sentence_length)
    print('loading input data ... done')

    # Create model
    model = Model(config)

    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    # Entering the session as a context manager makes it the default session
    # (as ``as_default()`` did) and also closes it on exit, releasing device
    # resources even if evaluation raises.
    with tf.Session(config=session_conf) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        print('model restored')
        feed_dict = {
            model.input_data_word_ids: test_data.sentence_word_ids,
            model.input_data_wordchr_ids: test_data.sentence_wordchr_ids,
            model.input_data_pos_ids: test_data.sentence_pos_ids,
            model.input_data_etc: test_data.sentence_etc,
            model.output_data: test_data.sentence_tag
        }
        # Only the tensors used below are fetched (the loss was never used).
        logits, logits_indices, trans_params, output_data_indices, length = \
            sess.run([model.logits, model.logits_indices, model.trans_params,
                      model.output_data_indices, model.length],
                     feed_dict=feed_dict)
        print('test precision, recall, f1(token): ')
        TokenEval.compute_f1(config.class_size, logits, test_data.sentence_tag,
                             length)
        if config.use_crf:
            viterbi_sequences = viterbi_decode(logits, trans_params, length)
            tag_preds = test_data.logits_indices_to_tags_seq(
                viterbi_sequences, length)
        else:
            tag_preds = test_data.logits_indices_to_tags_seq(
                logits_indices, length)
        tag_corrects = test_data.logits_indices_to_tags_seq(
            output_data_indices, length)
        test_prec, test_rec, test_f1 = ChunkEval.compute_f1(
            tag_preds, tag_corrects)
        print('test precision, recall, f1(chunk): ', test_prec, test_rec,
              test_f1)
コード例 #2
0
ファイル: inference.py プロジェクト: pvk444/etagger
def inference(config, frozen_pb_path):
    """Tag blank-line-separated buckets read from stdin with a frozen graph.

    Every non-blank stdin line (stripped) is appended to the current bucket;
    a blank line, or end of input, closes the bucket. Each bucket is run
    through the frozen graph. Every line is written to stdout with its
    predicted tag appended, followed by a blank line. Per-bucket and total
    timings are written to stderr.

    Args:
        config: configuration; reads ``emb_class`` and ``use_crf`` and is
            passed to ``Input``.
        frozen_pb_path: path of the frozen graph loaded with
            ``load_frozen_graph``; tensors are looked up under ``prefix/``.
    """

    # load graph
    graph = load_frozen_graph(frozen_pb_path)
    for op in graph.get_operations():
        sys.stderr.write(op.name + '\n')

    # mapping placeholders and tensors
    p_is_train = graph.get_tensor_by_name('prefix/is_train:0')
    p_sentence_length = graph.get_tensor_by_name('prefix/sentence_length:0')
    p_input_data_pos_ids = graph.get_tensor_by_name(
        'prefix/input_data_pos_ids:0')
    p_input_data_word_ids = graph.get_tensor_by_name(
        'prefix/input_data_word_ids:0')
    p_input_data_wordchr_ids = graph.get_tensor_by_name(
        'prefix/input_data_wordchr_ids:0')
    t_logits = graph.get_tensor_by_name('prefix/logits:0')
    t_trans_params = graph.get_tensor_by_name('prefix/loss/trans_params:0')
    t_sentence_lengths = graph.get_tensor_by_name('prefix/sentence_lengths:0')
    # Embedding-specific placeholders. The original used these names without
    # ever defining them, so emb_class 'elmo'/'bert' failed with NameError.
    # NOTE(review): the tensor names assume the same 'prefix/<name>:0'
    # convention as the placeholders above -- confirm against the exported graph.
    if config.emb_class == 'elmo':
        p_elmo_input_data_wordchr_ids = graph.get_tensor_by_name(
            'prefix/elmo_input_data_wordchr_ids:0')
    if config.emb_class == 'bert':
        p_bert_input_data_token_ids = graph.get_tensor_by_name(
            'prefix/bert_input_data_token_ids:0')
        p_bert_input_data_token_masks = graph.get_tensor_by_name(
            'prefix/bert_input_data_token_masks:0')
        p_bert_input_data_segment_ids = graph.get_tensor_by_name(
            'prefix/bert_input_data_segment_ids:0')

    # create session with graph
    # if graph is optimized by tensorRT, then
    # from tensorflow.contrib import tensorrt as trt
    # gpu_ops = tf.GPUOptions(per_process_gpu_memory_fraction = 0.50)
    gpu_ops = tf.GPUOptions()
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False,
                                  gpu_options=gpu_ops,
                                  inter_op_parallelism_threads=1,
                                  intra_op_parallelism_threads=1)
    sess = tf.Session(graph=graph, config=session_conf)

    def process_bucket(bucket):
        """Tag one bucket, write it to stdout and return elapsed seconds."""
        start_time = time.time()
        # Build input data
        inp = Input(bucket, config, build_output=False)
        feed_dict = {
            p_input_data_pos_ids: inp.sentence_pos_ids,
            p_is_train: False,
            p_sentence_length: inp.max_sentence_length,
            p_input_data_word_ids: inp.sentence_word_ids,
            p_input_data_wordchr_ids: inp.sentence_wordchr_ids,
        }
        if config.emb_class == 'elmo':
            feed_dict[p_elmo_input_data_wordchr_ids] = \
                inp.sentence_elmo_wordchr_ids
        if config.emb_class == 'bert':
            feed_dict[p_bert_input_data_token_ids] = \
                inp.sentence_bert_token_ids
            feed_dict[p_bert_input_data_token_masks] = \
                inp.sentence_bert_token_masks
            feed_dict[p_bert_input_data_segment_ids] = \
                inp.sentence_bert_segment_ids
        logits, trans_params, sentence_lengths = sess.run(
            [t_logits, t_trans_params, t_sentence_lengths],
            feed_dict=feed_dict)
        if config.use_crf:
            viterbi_sequences = viterbi_decode(logits, trans_params,
                                               sentence_lengths)
            tags = inp.logit_indices_to_tags(viterbi_sequences[0],
                                             sentence_lengths[0])
        else:
            tags = inp.logit_to_tags(logits[0], sentence_lengths[0])
        for i, entry in enumerate(bucket):
            # BERT tags are per sub-token; map each word to its token index.
            if config.emb_class == 'bert':
                j = inp.sentence_bert_wordidx2tokenidx[0][i]
            else:
                j = i
            sys.stdout.write(entry + ' ' + tags[j] + '\n')
        sys.stdout.write('\n')
        duration_time = time.time() - start_time
        sys.stderr.write('duration_time : ' + str(duration_time) + ' sec' + '\n')
        return duration_time

    num_buckets = 0
    total_duration_time = 0.0
    bucket = []
    try:
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if not line and bucket:
                total_duration_time += process_bucket(bucket)
                num_buckets += 1
                bucket = []
            if line:
                bucket.append(line)
        # flush a final bucket that was not terminated by a blank line
        if bucket:
            total_duration_time += process_bucket(bucket)
            num_buckets += 1

        # avoid ZeroDivisionError when stdin contained no buckets
        average = total_duration_time / num_buckets if num_buckets else 0.0
        out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n'
        out += 'average processing time / bucket : ' + str(average) + ' sec'
        sys.stderr.write(out + '\n')
    finally:
        sess.close()
コード例 #3
0
ファイル: train.py プロジェクト: pvk444/etagger
def dev_step(sess, model, config, data, summary_writer, epoch):
    """Evaluate on dev data in slices, write summaries and return token F1.

    The dev set is fed in slices of ``config.dev_batch_size`` to limit memory
    use. Loss and accuracy are averaged over slices, and logits and sentence
    lengths are concatenated across slices. Token- and chunk-level scores are
    printed, and loss/accuracy/token F1/chunk F1 are written to
    ``summary_writer`` at the last returned global step.

    Args:
        sess: session holding the trained model.
        model: project ``Model`` exposing the placeholders/tensors used below.
        config: configuration; reads ``dev_batch_size``, ``emb_class``,
            ``epoch``, ``class_size`` and ``use_crf``, and sets
            ``is_training`` to False.
        data: dev dataset (``Input``-like).
        summary_writer: writer that receives the manually built summaries.
        epoch: current epoch, used only in the printed header.

    Returns:
        The token-level F1 returned by ``TokenEval.compute_f1``.
    """
    batch_size = config.dev_batch_size
    num_sentences = len(data.sentence_tags)
    nbatches = (num_sentences + batch_size - 1) // batch_size
    prog = Progbar(target=nbatches)
    sum_loss = 0.0
    sum_accuracy = 0.0
    sum_logits = None
    sum_sentence_lengths = None
    trans_params = None
    global_step = 0
    config.is_training = False
    # evaluate on dev data sliced by dev_batch_size to prevent OOM
    for idx, ptr in enumerate(range(0, num_sentences, batch_size)):
        end = ptr + batch_size
        feed_dict = {
            model.input_data_pos_ids: data.sentence_pos_ids[ptr:end],
            model.output_data: data.sentence_tags[ptr:end],
            model.is_train: config.is_training,
            model.sentence_length: data.max_sentence_length,
            model.input_data_word_ids: data.sentence_word_ids[ptr:end],
            model.input_data_wordchr_ids: data.sentence_wordchr_ids[ptr:end],
        }
        if config.emb_class == 'elmo':
            feed_dict[model.elmo_input_data_wordchr_ids] = \
                data.sentence_elmo_wordchr_ids[ptr:end]
        if config.emb_class == 'bert':
            # Slice with the same dev_batch_size window as every other input.
            # The original used config.batch_size here, so BERT inputs
            # disagreed in size and offset with the other placeholders
            # whenever the two settings differed.
            feed_dict[model.bert_input_data_token_ids] = \
                data.sentence_bert_token_ids[ptr:end]
            feed_dict[model.bert_input_data_token_masks] = \
                data.sentence_bert_token_masks[ptr:end]
            feed_dict[model.bert_input_data_segment_ids] = \
                data.sentence_bert_segment_ids[ptr:end]
        global_step, logits, trans_params, sentence_lengths, loss, accuracy = \
            sess.run([model.global_step, model.logits, model.trans_params,
                      model.sentence_lengths, model.loss, model.accuracy],
                     feed_dict=feed_dict)
        prog.update(idx + 1, [('dev loss', loss), ('dev accuracy', accuracy)])
        sum_loss += loss
        sum_accuracy += accuracy
        sum_logits = np_concat(sum_logits, logits)
        sum_sentence_lengths = np_concat(sum_sentence_lengths,
                                         sentence_lengths)
    sum_loss = sum_loss / nbatches
    sum_accuracy = sum_accuracy / nbatches
    print('[epoch %s/%s] dev precision, recall, f1(token): ' %
          (epoch, config.epoch))
    token_f1 = TokenEval.compute_f1(config.class_size, sum_logits,
                                    data.sentence_tags, sum_sentence_lengths)

    if config.use_crf:
        # trans_params is the value fetched on the last slice
        viterbi_sequences = viterbi_decode(sum_logits, trans_params,
                                           sum_sentence_lengths)
        tag_preds = data.logits_indices_to_tags_seq(viterbi_sequences,
                                                    sum_sentence_lengths)
    else:
        sum_logits_indices = np.argmax(sum_logits, 2)
        tag_preds = data.logits_indices_to_tags_seq(sum_logits_indices,
                                                    sum_sentence_lengths)
    sum_output_data_indices = np.argmax(data.sentence_tags, 2)
    tag_corrects = data.logits_indices_to_tags_seq(sum_output_data_indices,
                                                   sum_sentence_lengths)
    prec, rec, chunk_f1 = ChunkEval.compute_f1(tag_preds, tag_corrects)
    print('dev precision, recall, f1(chunk): ', prec, rec, chunk_f1,
          ', this is no meaningful for emb_class=bert')
    # create summaries manually
    summary_value = [
        tf.Summary.Value(tag='loss_1', simple_value=sum_loss),
        tf.Summary.Value(tag='accuracy_1', simple_value=sum_accuracy),
        tf.Summary.Value(tag='token_f1', simple_value=token_f1),
        tf.Summary.Value(tag='chunk_f1', simple_value=chunk_f1)
    ]
    summaries = tf.Summary(value=summary_value)
    summary_writer.add_summary(summaries, global_step)

    # token F1 is the returned metric (chunk F1 is only reported)
    return token_f1
コード例 #4
0
def inference_line(config):
    """Tag raw text lines read from stdin.

    Each non-blank stdin line is tokenized with spaCy (``spacy.load('en')``)
    into a bucket of ``"<text> <tag_> O <entity>"`` entries. The bucket is
    run through the restored model, and each entry is written to stdout with
    the predicted tag appended, followed by a blank line. Lines that spaCy
    processing fails on are reported to stderr and skipped.

    Args:
        config: configuration; reads ``embvec.wrd_embeddings``, ``restore``,
            ``emb_class`` and ``use_crf`` and is passed to ``Model``/``Input``.
    """
    def get_entity(doc, begin, end):
        """Return 'B-'/'I-' + label of the spaCy entity covering chars
        begin..end (inclusive), or 'O' when no entity covers them."""
        for ent in doc.ents:
            # check included
            if ent.start_char <= begin and end <= ent.end_char:
                prefix = 'B-' if ent.start_char == begin else 'I-'
                return prefix + ent.label_
        return 'O'

    def build_bucket(nlp, line):
        """Tokenize ``line`` and return one space-joined entry per token."""
        doc = nlp(line)
        bucket = []
        for token in doc:
            begin = token.idx
            end = begin + len(token.text) - 1  # inclusive end offset
            entity = get_entity(doc, begin, end)  # entity by spacy
            # the third column is 'O' because no chunking info is available
            bucket.append(' '.join([token.text, token.tag_, 'O', entity]))
        return bucket

    import spacy
    nlp = spacy.load('en')

    # Create model
    model = Model(config)

    # Restore model
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    sess = tf.Session(config=session_conf)
    # try/finally so the session is released even if tagging raises.
    try:
        # run the initializer with the word embeddings fed in, then restore
        # the checkpoint on top of it
        init_feed = {model.wrd_embeddings_init: config.embvec.wrd_embeddings}
        sess.run(tf.global_variables_initializer(), feed_dict=init_feed)
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        sys.stderr.write('model restored' + '\n')

        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if not line:
                continue
            # Create bucket (best effort: report and skip failing lines)
            try:
                bucket = build_bucket(nlp, line)
            except Exception as e:
                sys.stderr.write(str(e) + '\n')
                continue
            # Build input data
            inp = Input(bucket, config, build_output=False)
            feed_dict = {
                model.input_data_pos_ids: inp.sentence_pos_ids,
                model.is_train: False,
                model.sentence_length: inp.max_sentence_length,
                model.input_data_word_ids: inp.sentence_word_ids,
                model.input_data_wordchr_ids: inp.sentence_wordchr_ids,
            }
            if config.emb_class == 'elmo':
                feed_dict[model.elmo_input_data_wordchr_ids] = \
                    inp.sentence_elmo_wordchr_ids
            if config.emb_class == 'bert':
                feed_dict[model.bert_input_data_token_ids] = \
                    inp.sentence_bert_token_ids
                feed_dict[model.bert_input_data_token_masks] = \
                    inp.sentence_bert_token_masks
                feed_dict[model.bert_input_data_segment_ids] = \
                    inp.sentence_bert_segment_ids
            logits, trans_params, sentence_lengths = sess.run(
                [model.logits, model.trans_params, model.sentence_lengths],
                feed_dict=feed_dict)
            if config.use_crf:
                viterbi_sequences = viterbi_decode(logits, trans_params,
                                                   sentence_lengths)
                tags = inp.logit_indices_to_tags(viterbi_sequences[0],
                                                 sentence_lengths[0])
            else:
                tags = inp.logit_to_tags(logits[0], sentence_lengths[0])
            for i, entry in enumerate(bucket):
                # BERT tags are per sub-token; map each word to its token index.
                if config.emb_class == 'bert':
                    j = inp.sentence_bert_wordidx2tokenidx[0][i]
                else:
                    j = i
                sys.stdout.write(entry + ' ' + tags[j] + '\n')
            sys.stdout.write('\n')
    finally:
        sess.close()
コード例 #5
0
def inference_bucket(config):
    """Tag blank-line-separated buckets of lines read from stdin.

    Every non-blank stdin line (stripped) is appended to the current bucket;
    a blank line, or end of input, closes it. Each bucket is run through the
    restored model, and every line is written to stdout with its predicted
    tag appended, followed by a blank line. Per-bucket and total timings are
    written to stderr.

    Args:
        config: configuration; reads ``embvec.wrd_embeddings``, ``restore``,
            ``emb_class`` and ``use_crf`` and is passed to ``Model``/``Input``.
    """

    # Create model
    model = Model(config)

    # Restore model
    # (a thread-pinned variant adds inter_op_parallelism_threads=1 and
    # intra_op_parallelism_threads=1 to this ConfigProto)
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    sess = tf.Session(config=session_conf)

    def run_bucket(bucket):
        """Tag one bucket, write it to stdout and return elapsed seconds."""
        start_time = time.time()
        # Build input data. Both the in-loop and the trailing bucket use
        # build_output=False; the original omitted it for the trailing bucket.
        inp = Input(bucket, config, build_output=False)
        feed_dict = {
            model.input_data_pos_ids: inp.sentence_pos_ids,
            model.is_train: False,
            model.sentence_length: inp.max_sentence_length,
            model.input_data_word_ids: inp.sentence_word_ids,
            model.input_data_wordchr_ids: inp.sentence_wordchr_ids,
        }
        if config.emb_class == 'elmo':
            feed_dict[model.elmo_input_data_wordchr_ids] = \
                inp.sentence_elmo_wordchr_ids
        if config.emb_class == 'bert':
            feed_dict[model.bert_input_data_token_ids] = \
                inp.sentence_bert_token_ids
            feed_dict[model.bert_input_data_token_masks] = \
                inp.sentence_bert_token_masks
            feed_dict[model.bert_input_data_segment_ids] = \
                inp.sentence_bert_segment_ids
        logits, trans_params, sentence_lengths = sess.run(
            [model.logits, model.trans_params, model.sentence_lengths],
            feed_dict=feed_dict)
        if config.use_crf:
            viterbi_sequences = viterbi_decode(logits, trans_params,
                                               sentence_lengths)
            tags = inp.logit_indices_to_tags(viterbi_sequences[0],
                                             sentence_lengths[0])
        else:
            tags = inp.logit_to_tags(logits[0], sentence_lengths[0])
        for i, entry in enumerate(bucket):
            # BERT tags are per sub-token; map each word to its token index.
            if config.emb_class == 'bert':
                j = inp.sentence_bert_wordidx2tokenidx[0][i]
            else:
                j = i
            sys.stdout.write(entry + ' ' + tags[j] + '\n')
        sys.stdout.write('\n')
        duration_time = time.time() - start_time
        sys.stderr.write('duration_time : ' + str(duration_time) + ' sec' + '\n')
        return duration_time

    try:
        init_feed = {model.wrd_embeddings_init: config.embvec.wrd_embeddings}
        sess.run(tf.global_variables_initializer(), feed_dict=init_feed)
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        sys.stderr.write('model restored' + '\n')

        num_buckets = 0
        total_duration_time = 0.0
        bucket = []
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if not line and bucket:
                total_duration_time += run_bucket(bucket)
                num_buckets += 1
                bucket = []
            if line:
                bucket.append(line)
        # flush a final bucket that was not terminated by a blank line
        if bucket:
            total_duration_time += run_bucket(bucket)
            num_buckets += 1

        # avoid ZeroDivisionError when stdin contained no buckets
        average = total_duration_time / num_buckets if num_buckets else 0.0
        out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n'
        out += 'average processing time / bucket : ' + str(average) + ' sec'
        sys.stderr.write(out + '\n')
    finally:
        sess.close()
コード例 #6
0
ファイル: train.py プロジェクト: johndpope/etagger
def do_train(model, config, train_data, dev_data, test_data):
    """Train ``model``, keeping the checkpoint with the best dev token F1.

    Each epoch runs mini-batch training over ``train_data`` in batches of
    ``config.batch_size``. It then evaluates all of ``dev_data`` in one run
    and prints token- and chunk-level scores. Whenever dev token F1 improves,
    the model is saved to ``config.checkpoint_dir + '/model_max.ckpt'`` and
    ``test_data`` is evaluated and reported. Train/dev summaries are written
    under ``config.summary_dir/summaries/{train,dev}``. The learning rate
    switches from 0.001 to 0.0001 once the epoch index exceeds 20.

    Args:
        model: project ``Model`` exposing the placeholders and tensors used
            below.
        config: configuration; reads ``restore``, ``summary_dir``, ``epoch``,
            ``batch_size``, ``class_size``, ``use_crf`` and ``checkpoint_dir``.
        train_data: training dataset (``Input``-like).
        dev_data: development dataset used for model selection.
        test_data: test dataset evaluated whenever a new best model is saved.
    """
    learning_rate_init = 0.001    # initial
    learning_rate_final = 0.0001  # final
    learning_rate = learning_rate_init
    intermid_epoch = 20           # after this epoch, change learning rate
    maximum = 0

    def full_feed(data):
        """Return a feed dict covering every sentence in ``data``."""
        return {model.input_data_word_ids: data.sentence_word_ids,
                model.input_data_wordchr_ids: data.sentence_wordchr_ids,
                model.input_data_pos_ids: data.sentence_pos_ids,
                model.input_data_etc: data.sentence_etc,
                model.output_data: data.sentence_tag}

    def eval_chunks(data, logits, logits_indices, trans_params,
                    output_data_indices, length):
        """Return chunk-level (precision, recall, f1) from ``ChunkEval``."""
        if config.use_crf:
            viterbi_sequences = viterbi_decode(logits, trans_params, length)
            tag_preds = data.logits_indices_to_tags_seq(viterbi_sequences, length)
        else:
            tag_preds = data.logits_indices_to_tags_seq(logits_indices, length)
        tag_corrects = data.logits_indices_to_tags_seq(output_data_indices, length)
        return ChunkEval.compute_f1(tag_preds, tag_corrects)

    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_conf)
    train_summary_writer = None
    dev_summary_writer = None
    try:
        with sess.as_default():
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            if config.restore is not None:
                saver.restore(sess, config.restore)
                print('model restored')
            # summary setting
            loss_summary = tf.summary.scalar('loss', model.loss)
            acc_summary = tf.summary.scalar('accuracy', model.accuracy)
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_dir = os.path.join(config.summary_dir, 'summaries', 'train')
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(config.summary_dir, 'summaries', 'dev')
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)
            num_train = len(train_data.sentence_word_ids)
            nbatches = (num_train + config.batch_size - 1) // config.batch_size
            # training steps
            for e in range(config.epoch):
                # run epoch
                prog = Progbar(target=nbatches)
                for idx, ptr in enumerate(range(0, num_train, config.batch_size)):
                    end = ptr + config.batch_size
                    feed_dict = {model.input_data_word_ids: train_data.sentence_word_ids[ptr:end],
                                 model.input_data_wordchr_ids: train_data.sentence_wordchr_ids[ptr:end],
                                 model.input_data_pos_ids: train_data.sentence_pos_ids[ptr:end],
                                 model.input_data_etc: train_data.sentence_etc[ptr:end],
                                 model.output_data: train_data.sentence_tag[ptr:end],
                                 model.learning_rate: learning_rate}
                    step, train_summaries, _, train_loss, train_accuracy = sess.run(
                        [model.global_step, train_summary_op, model.train_op,
                         model.loss, model.accuracy], feed_dict=feed_dict)
                    prog.update(idx + 1, [('train loss', train_loss), ('train accuracy', train_accuracy)])
                    train_summary_writer.add_summary(train_summaries, step)
                # evaluate on dev data
                step, dev_summaries, logits, logits_indices, trans_params, \
                    output_data_indices, length, dev_loss, dev_accuracy = sess.run(
                        [model.global_step, dev_summary_op, model.logits,
                         model.logits_indices, model.trans_params,
                         model.output_data_indices, model.length, model.loss,
                         model.accuracy], feed_dict=full_feed(dev_data))
                print('epoch: %d / %d, step: %d, dev loss: %s, dev accuracy: %s' % (e, config.epoch, step, dev_loss, dev_accuracy))
                dev_summary_writer.add_summary(dev_summaries, step)
                print('dev precision, recall, f1(token): ')
                token_f1 = TokenEval.compute_f1(config.class_size, logits, dev_data.sentence_tag, length)
                dev_prec, dev_rec, dev_f1 = eval_chunks(
                    dev_data, logits, logits_indices, trans_params,
                    output_data_indices, length)
                print('dev precision, recall, f1(chunk): ', dev_prec, dev_rec, dev_f1)
                # save best model; selection uses token F1 (chunk F1 was the
                # alternative, noted as slightly lower than token F1 on test)
                m = token_f1
                if m > maximum:
                    print('new best f1 score!')
                    maximum = m
                    save_path = saver.save(sess, config.checkpoint_dir + '/' + 'model_max.ckpt')
                    print('max model saved in file: %s' % save_path)
                    step, logits, logits_indices, trans_params, \
                        output_data_indices, length, test_loss, test_accuracy = sess.run(
                            [model.global_step, model.logits, model.logits_indices,
                             model.trans_params, model.output_data_indices,
                             model.length, model.loss, model.accuracy],
                            feed_dict=full_feed(test_data))
                    print('epoch: %d / %d, step: %d, test loss: %s, test accuracy: %s' % (e, config.epoch, step, test_loss, test_accuracy))
                    print('test precision, recall, f1(token): ')
                    TokenEval.compute_f1(config.class_size, logits, test_data.sentence_tag, length)
                    test_prec, test_rec, test_f1 = eval_chunks(
                        test_data, logits, logits_indices, trans_params,
                        output_data_indices, length)
                    print('test precision, recall, f1(chunk): ', test_prec, test_rec, test_f1)
                # learning rate change
                if e > intermid_epoch:
                    learning_rate = learning_rate_final
    finally:
        # FileWriter buffers events; closing flushes them to disk. The original
        # never closed the writers or the session.
        for writer in (train_summary_writer, dev_summary_writer):
            if writer is not None:
                writer.close()
        sess.close()
コード例 #7
0
ファイル: inference.py プロジェクト: johndpope/etagger
def inference_bucket(config):
    """Tag blank-line-separated buckets of lines read from stdin.

    Every non-blank stdin line (stripped) is appended to the current bucket;
    a blank line, or end of input, closes it. Each bucket is converted with
    ``Input(bucket, config)`` and run through the restored model. Every line
    is written to stdout with its predicted tag appended, followed by a
    blank line.

    Args:
        config: configuration; reads ``restore`` and ``use_crf`` and is
            passed to ``Model``/``Input``.
    """

    # Create model
    model = Model(config)

    # Restore model
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    sess = tf.Session(config=session_conf)

    def run_bucket(bucket):
        """Tag one bucket and write the tagged lines to stdout."""
        # Build input data
        inp = Input(bucket, config)
        feed_dict = {
            model.input_data_word_ids: inp.sentence_word_ids,
            model.input_data_wordchr_ids: inp.sentence_wordchr_ids,
            model.input_data_pos_ids: inp.sentence_pos_ids,
            model.input_data_etc: inp.sentence_etc,
            model.output_data: inp.sentence_tag
        }
        # Only the tensors needed for decoding are fetched; the original also
        # fetched model.loss and then ignored it.
        logits, trans_params, length = sess.run(
            [model.logits, model.trans_params, model.length],
            feed_dict=feed_dict)
        if config.use_crf:
            viterbi_sequences = viterbi_decode(logits, trans_params, length)
            tags = inp.logit_indices_to_tags(viterbi_sequences[0], length[0])
        else:
            tags = inp.logit_to_tags(logits[0], length[0])
        for i, entry in enumerate(bucket):
            sys.stdout.write(entry + ' ' + tags[i] + '\n')
        sys.stdout.write('\n')

    # try/finally so the session is released even if tagging raises.
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        sys.stderr.write('model restored' + '\n')

        bucket = []
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if not line and bucket:
                run_bucket(bucket)
                bucket = []
            if line:
                bucket.append(line)
        # flush a final bucket that was not terminated by a blank line
        if bucket:
            run_bucket(bucket)
    finally:
        sess.close()
コード例 #8
0
ファイル: inference.py プロジェクト: johndpope/etagger
def inference_line(config):
    """Tag raw text read line by line from stdin and print the results.

    Each non-empty stdin line is tokenized and annotated by spaCy, converted
    into a bucket of "word tag O entity" rows, run through the restored model,
    and written to stdout as one row per token with the predicted tag
    appended, followed by a blank line. Lines that fail to convert are
    reported on stderr and skipped. Reading stops at EOF or on
    KeyboardInterrupt; the session is closed in every case.

    Args:
        config: configuration object; this function reads `config.restore`
            (checkpoint passed to `Saver.restore`) and `config.use_crf`, and
            passes it through to `Model` and `Input`.
    """
    def get_entity(doc, begin, end):
        """Return the IOB entity label of the character span [begin, end).

        `end` is exclusive, like spaCy's `Span.end_char`, so a token that
        starts exactly where an entity ends (e.g. the '.' in 'Apple.') is
        not counted as part of that entity.
        """
        for ent in doc.ents:
            # check the span is included in the entity
            if ent.start_char <= begin and end <= ent.end_char:
                if ent.start_char == begin:
                    return 'B-' + ent.label_
                return 'I-' + ent.label_
        return 'O'

    def build_bucket(nlp, line):
        """Convert one raw text line into a list of 'word tag O entity' rows."""
        # Byte strings (Python 2 stdin) are decoded for spaCy and the rows are
        # re-encoded, so the caller keeps working with byte strings; text
        # strings (Python 3 stdin) are passed through unchanged.
        is_bytes = isinstance(line, bytes)
        uline = line.decode('utf-8', 'ignore') if is_bytes else line
        doc = nlp(uline)
        bucket = []
        for token in doc:
            begin = token.idx
            end = begin + len(token.text)  # exclusive end offset
            entity = get_entity(doc, begin, end)  # entity by spacy
            # 'O' fills the chunk column: spaCy gives no chunking info.
            row = ' '.join([token.text, token.tag_, 'O', entity])
            bucket.append(row.encode('utf-8') if is_bytes else row)
        return bucket

    import spacy
    nlp = spacy.load('en')

    # Create model
    model = Model(config)

    # Restore model
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    sess = tf.Session(config=session_conf)
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        sys.stderr.write('model restored' + '\n')

        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break  # EOF
            line = line.strip()
            if not line:
                continue
            # Create bucket; best effort: report lines that cannot be
            # processed and keep reading.
            try:
                bucket = build_bucket(nlp, line)
            except Exception as e:
                sys.stderr.write(str(e) + '\n')
                continue
            # Build input data
            inp = Input(bucket, config)
            feed_dict = {
                model.input_data_word_ids: inp.sentence_word_ids,
                model.input_data_wordchr_ids: inp.sentence_wordchr_ids,
                model.input_data_pos_ids: inp.sentence_pos_ids,
                model.input_data_etc: inp.sentence_etc,
                model.output_data: inp.sentence_tag
            }
            # model.loss is still fetched (as before); its value is unused.
            logits, trans_params, length, _ = sess.run(
                [model.logits, model.trans_params, model.length, model.loss],
                feed_dict=feed_dict)
            if config.use_crf:
                viterbi_sequences = viterbi_decode(logits, trans_params,
                                                   length)
                tags = inp.logit_indices_to_tags(viterbi_sequences[0],
                                                 length[0])
            else:
                tags = inp.logit_to_tags(logits[0], length[0])
            for i, row in enumerate(bucket):
                out = row + ' ' + tags[i]
                sys.stdout.write(out + '\n')
            sys.stdout.write('\n')
    finally:
        sess.close()
コード例 #9
0
ファイル: inference_trt.py プロジェクト: sayduke/etagger
def inference(config, frozen_pb_path):
    """Tag buckets of lines from stdin using a TensorRT-optimized frozen graph.

    The frozen GraphDef at `frozen_pb_path` is converted with TF-TRT (FP16,
    max batch size 128) and imported under the 'prefix' name scope. stdin is
    read as buckets of non-empty lines separated by blank lines. Each bucket
    is converted with `Input(..., build_output=False)` and run through the
    graph. It is then written to stdout with the predicted tag appended to
    every line, followed by a blank line. Per-bucket and total timings go to
    stderr. Reading stops at EOF or on KeyboardInterrupt. A final bucket
    without a closing blank line is still processed.

    Args:
        config: configuration object; this function reads `config.use_elmo`
            and `config.use_crf` and passes it through to `Input`.
        frozen_pb_path: path of the frozen GraphDef to optimize and load.
    """
    # Load the frozen graph_def and build a TRT-optimized version of it.
    graph_def = load_frozen_graph_def(frozen_pb_path)
    trt_graph_def = trt.create_inference_graph(
        input_graph_def=graph_def,
        outputs=['logits', 'loss/trans_params', 'sentence_lengths'],
        max_batch_size=128,
        max_workspace_size_bytes=1 << 30,
        precision_mode='FP16',  # TRT Engine precision "FP32","FP16" or "INT8"
        minimum_segment_size=3  # minimum number of nodes in an engine
    )

    # Reset the default graph, then import the optimized graph_def.
    tf.reset_default_graph()
    graph = load_graph(trt_graph_def, prefix='prefix')
    for op in graph.get_operations():
        sys.stderr.write(op.name + '\n')

    # Create a session on the optimized graph.
    gpu_ops = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False,
                                  gpu_options=gpu_ops,
                                  inter_op_parallelism_threads=1,
                                  intra_op_parallelism_threads=1)
    sess = tf.Session(graph=graph, config=session_conf)
    try:
        # Map placeholders and output tensors of the imported graph.
        p_is_train = graph.get_tensor_by_name('prefix/is_train:0')
        p_sentence_length = graph.get_tensor_by_name(
            'prefix/sentence_length:0')
        p_input_data_pos_ids = graph.get_tensor_by_name(
            'prefix/input_data_pos_ids:0')
        p_input_data_word_ids = graph.get_tensor_by_name(
            'prefix/input_data_word_ids:0')
        p_input_data_wordchr_ids = graph.get_tensor_by_name(
            'prefix/input_data_wordchr_ids:0')
        p_input_data_etcs = graph.get_tensor_by_name(
            'prefix/input_data_etcs:0')
        # Previously referenced without ever being defined (NameError when
        # use_elmo is set); only looked up when ELMo is actually used.
        p_elmo_input_data_wordchr_ids = None
        if config.use_elmo:
            # NOTE(review): tensor name assumed from the model attribute
            # `elmo_input_data_wordchr_ids`; confirm against the frozen graph.
            p_elmo_input_data_wordchr_ids = graph.get_tensor_by_name(
                'prefix/elmo_input_data_wordchr_ids:0')
        t_logits = graph.get_tensor_by_name('prefix/logits:0')
        t_trans_params = graph.get_tensor_by_name(
            'prefix/loss/trans_params:0')
        t_sentence_lengths = graph.get_tensor_by_name(
            'prefix/sentence_lengths:0')

        def tag_bucket(bucket):
            """Tag one bucket, print it, and return the elapsed seconds."""
            start_time = time.time()
            # Build input data
            inp = Input(bucket, config, build_output=False)
            feed_dict = {
                p_input_data_pos_ids: inp.sentence_pos_ids,
                p_input_data_etcs: inp.sentence_etcs,
                p_is_train: False,
                p_sentence_length: inp.max_sentence_length
            }
            if config.use_elmo:
                feed_dict[p_elmo_input_data_wordchr_ids] = \
                    inp.sentence_elmo_wordchr_ids
            else:
                feed_dict[p_input_data_word_ids] = inp.sentence_word_ids
                feed_dict[p_input_data_wordchr_ids] = inp.sentence_wordchr_ids
            logits, trans_params, sentence_lengths = sess.run(
                [t_logits, t_trans_params, t_sentence_lengths],
                feed_dict=feed_dict)
            if config.use_crf:
                viterbi_sequences = viterbi_decode(logits, trans_params,
                                                   sentence_lengths)
                tags = inp.logit_indices_to_tags(viterbi_sequences[0],
                                                 sentence_lengths[0])
            else:
                tags = inp.logit_to_tags(logits[0], sentence_lengths[0])
            for i, row in enumerate(bucket):
                out = row + ' ' + tags[i]
                sys.stdout.write(out + '\n')
            sys.stdout.write('\n')
            duration_time = time.time() - start_time
            out = 'duration_time : ' + str(duration_time) + ' sec'
            sys.stderr.write(out + '\n')
            return duration_time

        num_buckets = 0
        total_duration_time = 0.0
        bucket = []
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break  # EOF
            line = line.strip()
            if not line and bucket:
                # A blank line closes the current bucket.
                total_duration_time += tag_bucket(bucket)
                num_buckets += 1
                bucket = []
            if line:
                bucket.append(line)
        if bucket:
            # Flush the last bucket when input did not end with a blank line.
            total_duration_time += tag_bucket(bucket)
            num_buckets += 1

        out = 'total_duration_time : ' + str(total_duration_time) + ' sec'
        # Guard against ZeroDivisionError when no bucket was processed.
        if num_buckets > 0:
            out += '\n' + 'average processing time / bucket : ' + str(
                total_duration_time / num_buckets) + ' sec'
        sys.stderr.write(out + '\n')
    finally:
        sess.close()
コード例 #10
0
def inference_bulk(config, test_file='data/test.txt'):
    """Evaluate the restored model on a labelled test file and print scores.

    Loads `test_file` with `Input(..., build_output=True)` and restores the
    checkpoint at `config.restore`. The whole file is run as a single feed.
    Token-level scores are printed via `TokenEval.compute_f1`, followed by
    chunk-level precision, recall and F1 from `ChunkEval.compute_f1`.
    Predicted tags come from Viterbi decoding when `config.use_crf` is set,
    otherwise from the argmax over the logits. The session is closed even if
    evaluation fails.

    Args:
        config: configuration object. This function reads `config.use_elmo`,
            `config.embvec` (only without ELMo), `config.restore`,
            `config.class_size` and `config.use_crf`, and passes it through
            to `Input` and `Model`.
        test_file: path of the labelled test data (default 'data/test.txt').
    """
    # Build input data
    test_data = Input(test_file, config, build_output=True)
    print('loading input data ... done')

    # Create model
    model = Model(config)

    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    sess = tf.Session(config=session_conf)
    try:
        # Initialize variables, then restore them from the checkpoint.
        # Without ELMo, the word-embedding initializer is fed from
        # config.embvec.
        init_feed = {}
        if not config.use_elmo:
            init_feed[model.wrd_embeddings_init] = config.embvec.wrd_embeddings
        sess.run(tf.global_variables_initializer(), feed_dict=init_feed)
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        print('model restored')

        feed_dict = {
            model.input_data_pos_ids: test_data.sentence_pos_ids,
            model.input_data_etcs: test_data.sentence_etcs,
            model.output_data: test_data.sentence_tags,
            model.is_train: False,
            model.sentence_length: test_data.max_sentence_length
        }
        if config.use_elmo:
            feed_dict[model.elmo_input_data_wordchr_ids] = \
                test_data.sentence_elmo_wordchr_ids
        else:
            feed_dict[model.input_data_word_ids] = test_data.sentence_word_ids
            feed_dict[model.input_data_wordchr_ids] = \
                test_data.sentence_wordchr_ids
        logits, trans_params, sentence_lengths = sess.run(
            [model.logits, model.trans_params, model.sentence_lengths],
            feed_dict=feed_dict)

        print('test precision, recall, f1(token): ')
        TokenEval.compute_f1(config.class_size, logits,
                             test_data.sentence_tags, sentence_lengths)

        if config.use_crf:
            viterbi_sequences = viterbi_decode(logits, trans_params,
                                               sentence_lengths)
            tag_preds = test_data.logits_indices_to_tags_seq(
                viterbi_sequences, sentence_lengths)
        else:
            # Highest-scoring class along axis 2 of the logits.
            logits_indices = np.argmax(logits, 2)
            tag_preds = test_data.logits_indices_to_tags_seq(
                logits_indices, sentence_lengths)
        # Gold tag indices, taken along axis 2 the same way as predictions.
        output_data_indices = np.argmax(test_data.sentence_tags, 2)
        tag_corrects = test_data.logits_indices_to_tags_seq(
            output_data_indices, sentence_lengths)
        test_prec, test_rec, test_f1 = ChunkEval.compute_f1(
            tag_preds, tag_corrects)
        print('test precision, recall, f1(chunk): ', test_prec, test_rec,
              test_f1)
    finally:
        sess.close()