Ejemplo n.º 1
0
def inference_bucket(config):
    """Inference for bucket.
    """

    # create model and compile
    model = Model(config)
    model.compile()
    sess = model.sess

    # restore model
    saver = tf.train.Saver()
    saver.restore(sess, config.restore)
    sys.stderr.write('model restored' + '\n')
    '''
    print(tf.global_variables())
    print(tf.trainable_variables())
    '''
    num_buckets = 0
    total_duration_time = 0.0
    bucket = []
    while 1:
        try:
            line = sys.stdin.readline()
        except KeyboardInterrupt:
            break
        if not line: break
        line = line.strip()
        if not line and len(bucket) >= 1:
            start_time = time.time()
            inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input)
            if 'bert' in config.emb_class:
                # compute bert embedding at runtime
                bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                           feed_dict=feed_dict)
                # update feed_dict
                feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                    config, bert_embeddings,
                    inp.example['bert_wordidx2tokenidx'], -1)
            logits_indices, sentence_lengths = sess.run(
                [model.logits_indices, model.sentence_lengths],
                feed_dict=feed_dict)
            tags = config.logit_indices_to_tags(logits_indices[0],
                                                sentence_lengths[0])
            for i in range(len(bucket)):
                out = bucket[i] + ' ' + tags[i]
                sys.stdout.write(out + '\n')
            sys.stdout.write('\n')
            bucket = []
            duration_time = time.time() - start_time
            out = 'duration_time : ' + str(duration_time) + ' sec'
            tf.logging.info(out)
            num_buckets += 1
            if num_buckets != 1:  # first one may takes longer time, so ignore in computing duration.
                total_duration_time += duration_time
        if line: bucket.append(line)
    if len(bucket) != 0:
        start_time = time.time()
        inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input)
        if 'bert' in config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict)
            # update feed_dict
            feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                config, bert_embeddings, inp.example['bert_wordidx2tokenidx'],
                -1)
        logits_indices, sentence_lengths = sess.run(
            [model.logits_indices, model.sentence_lengths],
            feed_dict=feed_dict)
        tags = config.logit_indices_to_tags(logits_indices[0],
                                            sentence_lengths[0])
        for i in range(len(bucket)):
            out = bucket[i] + ' ' + tags[i]
            sys.stdout.write(out + '\n')
        sys.stdout.write('\n')
        duration_time = time.time() - start_time
        out = 'duration_time : ' + str(duration_time) + ' sec'
        tf.logging.info(out)
        num_buckets += 1
        total_duration_time += duration_time

    out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n'
    out += 'average processing time / bucket : ' + str(
        total_duration_time / (num_buckets - 1)) + ' sec'
    tf.logging.info(out)

    sess.close()
Ejemplo n.º 2
0
def inference_line(config):
    """Inference for raw string.
    """
    def get_entity(doc, begin, end):
        for ent in doc.ents:
            # check included
            if ent.start_char <= begin and end <= ent.end_char:
                if ent.start_char == begin: return 'B-' + ent.label_
                else: return 'I-' + ent.label_
        return 'O'

    def build_bucket(nlp, line):
        bucket = []
        doc = nlp(line)
        for token in doc:
            begin = token.idx
            end = begin + len(token.text) - 1
            temp = []
            '''
            print(token.i, token.text, token.lemma_, token.pos_, token.tag_, token.dep_,
                  token.shape_, token.is_alpha, token.is_stop, begin, end)
            '''
            temp.append(token.text)
            temp.append(token.tag_)
            temp.append('O')  # no chunking info
            entity = get_entity(doc, begin, end)
            temp.append(entity)  # entity by spacy
            temp = ' '.join(temp)
            bucket.append(temp)
        return bucket

    import spacy
    nlp = spacy.load('en')

    # create model and compile
    model = Model(config)
    model.compile()
    sess = model.sess

    # restore model
    saver = tf.train.Saver()
    saver.restore(sess, config.restore)
    tf.logging.info('model restored' + '\n')

    while 1:
        try:
            line = sys.stdin.readline()
        except KeyboardInterrupt:
            break
        if not line: break
        line = line.strip()
        if not line: continue
        # create bucket
        try:
            bucket = build_bucket(nlp, line)
        except Exception as e:
            sys.stderr.write(str(e) + '\n')
            continue
        inp, feed_dict = feed.build_input_feed_dict(model, bucket)
        if 'bert' in config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict)
            # update feed_dict
            feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                config, bert_embeddings, inp.example['bert_wordidx2tokenidx'],
                -1)
        logits_indices, sentence_lengths = sess.run(
            [model.logits_indices, model.sentence_lengths],
            feed_dict=feed_dict)
        tags = config.logit_indices_to_tags(logits_indices[0],
                                            sentence_lengths[0])
        for i in range(len(bucket)):
            out = bucket[i] + ' ' + tags[i]
            sys.stdout.write(out + '\n')
        sys.stdout.write('\n')

    sess.close()