Beispiel #1
0
def train_step(model, data, summary_op, summary_writer):
    """Train the model for one epoch.

    A fresh initializable iterator is created over ``data.dataset``; for each
    batch ``model.train_op`` is run once, the evaluated ``summary_op`` is
    written to ``summary_writer`` and a Progbar is updated.  When
    ``model.config.emb_class`` contains 'bert', BERT embeddings are first
    computed via ``model.bert_embeddings_subgraph`` and fed back through the
    ``model.bert_embeddings`` placeholder.

    Args:
        model: compiled model exposing ``sess``, ``config`` and the training
            tensors fetched below (``global_step``, ``train_op``, ``loss``,
            ``accuracy``, ``f1``, ``learning_rate``).
        data: training data wrapper exposing ``dataset``, ``num_batches`` and
            ``max_sentence_length``.
        summary_op: summary op evaluated together with the train op.
        summary_writer: writer that receives one summary per step.
    """
    start_time = time.time()
    sess = model.sess
    # Use the model's own config instead of relying on a module-level
    # ``config`` global being defined by the calling script.
    config = model.config
    use_bert = 'bert' in config.emb_class
    runopts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    prog = Progbar(target=data.num_batches)
    iterator = data.dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess.run(iterator.initializer)
    for idx in range(data.num_batches):
        try:
            batch = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        feed_dict = feed.build_feed_dict(model, batch,
                                         data.max_sentence_length, True)
        if use_bert:
            # compute bert embedding at runtime
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict,
                                       options=runopts)
            if idx == 0:
                # dump the first example's BERT inputs (shape, then values)
                for key in ('bert_token_ids', 'bert_token_masks',
                            'bert_wordidx2tokenidx'):
                    tf.logging.debug('# ' + key)
                    t = batch[key][:1]
                    tf.logging.debug(' '.join([str(x) for x in np.shape(t)]))
                    tf.logging.debug(' '.join([str(x) for x in t]))
            # update feed_dict with the aligned embeddings
            feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                config, bert_embeddings, batch['bert_wordidx2tokenidx'], idx)
        step, summaries, _, loss, accuracy, f1, learning_rate = sess.run(
            [model.global_step, summary_op, model.train_op,
             model.loss, model.accuracy, model.f1, model.learning_rate],
            feed_dict=feed_dict, options=runopts)

        summary_writer.add_summary(summaries, step)
        prog.update(idx + 1,
                    [('step', step), ('train loss', loss),
                     ('train accuracy', accuracy), ('train f1', f1),
                     ('lr(invalid if use_bert_optimization)', learning_rate)])
    duration_time = time.time() - start_time
    out = '\nduration_time : ' + str(duration_time) + ' sec for this epoch'
    tf.logging.debug(out)
Beispiel #2
0
def analyze(graph, sess, query, config, nlp):
    """Tag a raw query with spaCy-derived features and a frozen etagger graph.

    Args:
        graph: frozen graph whose tensors live under the 'prefix/' scope.
        sess: session used to run ``graph``.
        query: raw input text, turned into a bucket by ``build_bucket``.
        config: project config; ``emb_class`` selects the BERT path and
            ``logit_indices_to_tags`` decodes the predictions.
        nlp: spaCy pipeline handed to ``build_bucket``.

    Returns:
        A list with one dict per bucket line, holding the keys 'id', 'word',
        'pos', 'chk', 'tag' and 'predict' (in that order).
    """
    bucket = build_bucket(nlp, query)
    inp, feed_dict = feed.build_input_feed_dict_with_graph(
        graph, config, bucket, Input)
    use_bert = 'bert' in config.emb_class
    tensor = graph.get_tensor_by_name

    # resolve the BERT input/output tensors first, then the tagger outputs
    if use_bert:
        t_bert_embeddings_subgraph = tensor('prefix/bert_embeddings_subgraph:0')
        p_bert_embeddings = tensor('prefix/bert_embeddings:0')
    t_logits_indices = tensor('prefix/logits_indices:0')
    t_sentence_lengths = tensor('prefix/sentence_lengths:0')

    if use_bert:
        # BERT embeddings are computed at runtime and fed back in
        bert_embeddings = sess.run([t_bert_embeddings_subgraph],
                                   feed_dict=feed_dict)
        feed_dict[p_bert_embeddings] = feed.align_bert_embeddings(
            config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1)

    logits_indices, sentence_lengths = sess.run(
        [t_logits_indices, t_sentence_lengths], feed_dict=feed_dict)
    tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0])

    result = []
    for position, row in enumerate(bucket):
        # each bucket line is 'word pos chk tag'; the prediction is appended
        fields = (row + ' ' + tags[position]).split()
        result.append({
            'id': position,
            'word': fields[0],
            'pos': fields[1],
            'chk': fields[2],
            'tag': fields[3],
            'predict': fields[4],
        })
    return result
Beispiel #3
0
def inference_bucket(config):
    """Inference for buckets read from stdin using a restored checkpoint.

    Non-empty stdin lines are accumulated into a bucket; a blank line (or
    EOF) closes the bucket, which is then tagged and echoed to stdout with
    the predicted tag appended to each line, followed by a blank line.
    Per-bucket timing is logged.  The first bucket is treated as warm-up and
    excluded from the total/average timing.

    Args:
        config: project config providing ``restore`` (checkpoint path),
            ``emb_class`` and ``logit_indices_to_tags``.
    """
    # create model and compile
    model = Model(config)
    model.compile()
    sess = model.sess

    def process(bucket):
        """Tag one bucket, print it and return the elapsed seconds."""
        start_time = time.time()
        inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input)
        if 'bert' in config.emb_class:
            # compute bert embedding at runtime and feed it back
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict)
            feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                config, bert_embeddings,
                inp.example['bert_wordidx2tokenidx'], -1)
        logits_indices, sentence_lengths = sess.run(
            [model.logits_indices, model.sentence_lengths],
            feed_dict=feed_dict)
        tags = config.logit_indices_to_tags(logits_indices[0],
                                            sentence_lengths[0])
        for i in range(len(bucket)):
            sys.stdout.write(bucket[i] + ' ' + tags[i] + '\n')
        sys.stdout.write('\n')
        duration_time = time.time() - start_time
        tf.logging.info('duration_time : ' + str(duration_time) + ' sec')
        return duration_time

    try:
        # restore model
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        sys.stderr.write('model restored' + '\n')

        num_buckets = 0
        total_duration_time = 0.0
        bucket = []
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if line:
                bucket.append(line)
            elif bucket:
                duration_time = process(bucket)
                bucket = []
                num_buckets += 1
                # the first bucket may take longer (warm-up), so skip it
                if num_buckets > 1:
                    total_duration_time += duration_time
        if bucket:
            duration_time = process(bucket)
            num_buckets += 1
            # same warm-up rule as in the loop
            if num_buckets > 1:
                total_duration_time += duration_time

        # only buckets after the first contribute to the timing statistics;
        # guard against dividing by zero when fewer than two were processed
        num_timed = num_buckets - 1
        out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n'
        if num_timed > 0:
            out += 'average processing time / bucket : ' + str(
                total_duration_time / num_timed) + ' sec'
        else:
            out += ('average processing time / bucket : n/a '
                    '(first bucket is warm-up; need at least 2 buckets)')
        tf.logging.info(out)
    finally:
        sess.close()
Beispiel #4
0
def inference_line(config):
    """Inference for raw strings read from stdin.

    Each non-empty stdin line is tokenized by spaCy's 'en' pipeline.  Every
    token becomes a bucket line 'word tag O entity': ``tag`` is spaCy's
    ``token.tag_``, 'O' is a placeholder because there is no chunking info,
    and ``entity`` is a B-/I-/O label derived from ``doc.ents``.  The bucket
    is tagged by the restored model and printed with the prediction appended
    to each line, followed by a blank line.  Lines that spaCy fails on are
    reported on stderr and skipped.

    Args:
        config: project config providing ``restore`` (checkpoint path),
            ``emb_class`` and ``logit_indices_to_tags``.
    """
    def get_entity(doc, begin, end):
        """Return 'B-'/'I-' + label of the entity covering [begin, end], else 'O'.

        ``end`` is the inclusive index of the token's last character.
        """
        for ent in doc.ents:
            # check included
            if ent.start_char <= begin and end <= ent.end_char:
                if ent.start_char == begin:
                    return 'B-' + ent.label_
                return 'I-' + ent.label_
        return 'O'

    def build_bucket(nlp, line):
        """Turn a raw line into bucket lines of 'word tag O entity'."""
        bucket = []
        doc = nlp(line)
        for token in doc:
            begin = token.idx
            end = begin + len(token.text) - 1
            entity = get_entity(doc, begin, end)  # entity by spacy
            # 'O' stands in for the missing chunking info
            bucket.append(' '.join([token.text, token.tag_, 'O', entity]))
        return bucket

    import spacy
    nlp = spacy.load('en')

    # create model and compile
    model = Model(config)
    model.compile()
    sess = model.sess

    try:
        # restore model
        saver = tf.train.Saver()
        saver.restore(sess, config.restore)
        tf.logging.info('model restored' + '\n')

        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if not line:
                continue
            # create bucket; spaCy failures are best-effort skipped
            try:
                bucket = build_bucket(nlp, line)
            except Exception as e:
                sys.stderr.write(str(e) + '\n')
                continue
            # pass Input like every other build_input_feed_dict call site
            inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input)
            if 'bert' in config.emb_class:
                # compute bert embedding at runtime and feed it back
                bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                           feed_dict=feed_dict)
                feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                    config, bert_embeddings,
                    inp.example['bert_wordidx2tokenidx'], -1)
            logits_indices, sentence_lengths = sess.run(
                [model.logits_indices, model.sentence_lengths],
                feed_dict=feed_dict)
            tags = config.logit_indices_to_tags(logits_indices[0],
                                                sentence_lengths[0])
            for i in range(len(bucket)):
                sys.stdout.write(bucket[i] + ' ' + tags[i] + '\n')
            sys.stdout.write('\n')
    finally:
        sess.close()
Beispiel #5
0
def inference(config, frozen_pb_path):
    """Inference for buckets read from stdin using a frozen graph.

    All operation names of the loaded graph are first listed on stderr.
    Non-empty stdin lines are accumulated into a bucket; a blank line (or
    EOF) closes the bucket, which is then tagged and echoed to stdout with
    the predicted tag appended to each line, followed by a blank line.
    Per-bucket timing is logged.  The first bucket is treated as warm-up and
    excluded from the total/average timing.

    Args:
        config: project config providing ``emb_class`` and
            ``logit_indices_to_tags``.
        frozen_pb_path: path passed to ``load_frozen_graph``; tensors are
            looked up under the 'prefix/' scope.
    """
    # load graph
    graph = load_frozen_graph(frozen_pb_path)
    for op in graph.get_operations():
        sys.stderr.write(op.name + '\n')

    # create session with graph.
    # If the graph was optimized by TensorRT, import
    # ``from tensorflow.contrib import tensorrt as trt`` and consider e.g.
    # tf.GPUOptions(per_process_gpu_memory_fraction=0.50).
    gpu_ops = tf.GPUOptions()
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False,
                                  gpu_options=gpu_ops,
                                  inter_op_parallelism_threads=0,
                                  intra_op_parallelism_threads=0)
    sess = tf.Session(graph=graph, config=session_conf)

    try:
        use_bert = 'bert' in config.emb_class
        # mapping output/input tensors for bert
        if use_bert:
            t_bert_embeddings_subgraph = graph.get_tensor_by_name(
                'prefix/bert_embeddings_subgraph:0')
            p_bert_embeddings = graph.get_tensor_by_name(
                'prefix/bert_embeddings:0')
        # mapping output tensors
        t_logits_indices = graph.get_tensor_by_name('prefix/logits_indices:0')
        t_sentence_lengths = graph.get_tensor_by_name(
            'prefix/sentence_lengths:0')

        def process(bucket):
            """Tag one bucket, print it and return the elapsed seconds."""
            start_time = time.time()
            inp, feed_dict = feed.build_input_feed_dict_with_graph(
                graph, config, bucket, Input)
            if use_bert:
                # compute bert embedding at runtime and feed it back
                bert_embeddings = sess.run([t_bert_embeddings_subgraph],
                                           feed_dict=feed_dict)
                feed_dict[p_bert_embeddings] = feed.align_bert_embeddings(
                    config, bert_embeddings,
                    inp.example['bert_wordidx2tokenidx'], -1)
            logits_indices, sentence_lengths = sess.run(
                [t_logits_indices, t_sentence_lengths], feed_dict=feed_dict)
            tags = config.logit_indices_to_tags(logits_indices[0],
                                                sentence_lengths[0])
            for i in range(len(bucket)):
                sys.stdout.write(bucket[i] + ' ' + tags[i] + '\n')
            sys.stdout.write('\n')
            duration_time = time.time() - start_time
            tf.logging.info('duration_time : ' + str(duration_time) + ' sec')
            return duration_time

        num_buckets = 0
        total_duration_time = 0.0
        bucket = []
        while True:
            try:
                line = sys.stdin.readline()
            except KeyboardInterrupt:
                break
            if not line:
                break
            line = line.strip()
            if line:
                bucket.append(line)
            elif bucket:
                duration_time = process(bucket)
                bucket = []
                num_buckets += 1
                # the first bucket may take longer (warm-up), so skip it
                if num_buckets > 1:
                    total_duration_time += duration_time
        if bucket:
            duration_time = process(bucket)
            num_buckets += 1
            # same warm-up rule as in the loop
            if num_buckets > 1:
                total_duration_time += duration_time

        # only buckets after the first contribute to the timing statistics;
        # guard against dividing by zero when fewer than two were processed
        num_timed = num_buckets - 1
        out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n'
        if num_timed > 0:
            out += 'average processing time / bucket : ' + str(
                total_duration_time / num_timed) + ' sec'
        else:
            out += ('average processing time / bucket : n/a '
                    '(first bucket is warm-up; need at least 2 buckets)')
        tf.logging.info(out)
    finally:
        sess.close()
Beispiel #6
0
def dev_step(model, data, summary_writer, epoch):
    """Evaluate dev data and write summaries.

    The dev set is evaluated batch by batch to avoid running out of memory.
    Per-batch loss/accuracy/f1 are averaged over the batches actually
    evaluated.  Predictions and gold tags are concatenated across batches and
    scored with seqeval.

    Args:
        model: compiled model exposing ``sess``, ``config`` and the evaluation
            tensors fetched below.
        data: dev data wrapper exposing ``dataset``, ``num_batches`` and
            ``max_sentence_length``.
        summary_writer: writer receiving the manually built summaries.
        epoch: current epoch number, used only for logging.

    Returns:
        Tuple ``(seqeval_f1, avg_f1)``. ``seqeval_f1`` is the seqeval F1 over
        the whole dev set; ``avg_f1`` is the mean of the per-batch ``model.f1``
        values.

    Raises:
        ValueError: if the dev iterator yields no batches.
    """
    def np_concat(sum_var, var):
        """Concatenate ``var`` onto ``sum_var`` along axis 0 (``None`` = empty)."""
        if sum_var is not None:
            sum_var = np.concatenate((sum_var, var), axis=0)
        else:
            sum_var = var
        return sum_var

    sess = model.sess
    # Use the model's own config instead of relying on a module-level
    # ``config`` global being defined by the calling script.
    config = model.config
    runopts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    sum_loss = 0.0
    sum_accuracy = 0.0
    sum_f1 = 0.0
    sum_output_indices = None
    sum_logits_indices = None
    sum_sentence_lengths = None
    global_step = 0
    num_processed = 0
    prog = Progbar(target=data.num_batches)
    iterator = data.dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess.run(iterator.initializer)

    # evaluate on dev data sliced by batch_size to prevent OOM(Out Of Memory).
    for idx in range(data.num_batches):
        try:
            batch = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        feed_dict = feed.build_feed_dict(model, batch,
                                         data.max_sentence_length, False)
        if 'bert' in config.emb_class:
            # compute bert embedding at runtime and feed it back
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict,
                                       options=runopts)
            feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                config, bert_embeddings, batch['bert_wordidx2tokenidx'], idx)
        global_step, logits_indices, sentence_lengths, loss, accuracy, f1 = \
            sess.run([model.global_step, model.logits_indices,
                      model.sentence_lengths, model.loss, model.accuracy,
                      model.f1], feed_dict=feed_dict, options=runopts)
        prog.update(idx + 1, [('dev loss', loss), ('dev accuracy', accuracy),
                              ('dev f1(tf_metrics)', f1)])
        sum_loss += loss
        sum_accuracy += accuracy
        sum_f1 += f1
        # gold tag indices from the one-hot 'tags' (argmax over axis 2)
        sum_output_indices = np_concat(sum_output_indices,
                                       np.argmax(batch['tags'], 2))
        sum_logits_indices = np_concat(sum_logits_indices, logits_indices)
        sum_sentence_lengths = np_concat(sum_sentence_lengths,
                                         sentence_lengths)
        num_processed += 1

    if num_processed == 0:
        raise ValueError('dev_step: dev dataset yielded no batches')
    # average over the batches actually evaluated (the iterator may stop early)
    avg_loss = sum_loss / num_processed
    avg_accuracy = sum_accuracy / num_processed
    avg_f1 = sum_f1 / num_processed
    tag_preds = config.logits_indices_to_tags_seq(
        sum_logits_indices, sum_sentence_lengths)
    tag_corrects = config.logits_indices_to_tags_seq(
        sum_output_indices, sum_sentence_lengths)
    seqeval_prec = precision_score(tag_corrects, tag_preds)
    seqeval_rec = recall_score(tag_corrects, tag_preds)
    seqeval_f1 = f1_score(tag_corrects, tag_preds)
    tf.logging.debug('\n[epoch %s/%s] dev precision(seqeval), recall(seqeval), f1(seqeval): %s, %s, %s' % \
        (epoch, config.epoch, seqeval_prec, seqeval_rec, seqeval_f1))

    # create summaries manually.
    summary_value = [
        tf.Summary.Value(tag='loss', simple_value=avg_loss),
        tf.Summary.Value(tag='accuracy', simple_value=avg_accuracy),
        tf.Summary.Value(tag='f1(tf_metrics)', simple_value=avg_f1),
        tf.Summary.Value(tag='f1(seqeval)', simple_value=seqeval_f1)
    ]
    summaries = tf.Summary(value=summary_value)
    summary_writer.add_summary(summaries, global_step)

    return seqeval_f1, avg_f1