Example #1
def train_step(model, data, summary_op, summary_writer):
    """Train one epoch
    """
    start_time = time.time()
    sess = model.sess
    runopts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    prog = Progbar(target=data.num_batches)
    iterator = data.dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess.run(iterator.initializer)
    for idx in range(data.num_batches):
        try:
            dataset = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        feed_dict = feed.build_feed_dict(model, dataset,
                                         data.max_sentence_length, True)
        if 'bert' in model.config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict,
                                       options=runopts)
            if idx == 0:
                tf.logging.debug('# bert_token_ids')
                t = dataset['bert_token_ids'][:1]
                tf.logging.debug(' '.join([str(x) for x in np.shape(t)]))
                tf.logging.debug(' '.join([str(x) for x in t]))
                tf.logging.debug('# bert_token_masks')
                t = dataset['bert_token_masks'][:1]
                tf.logging.debug(' '.join([str(x) for x in np.shape(t)]))
                tf.logging.debug(' '.join([str(x) for x in t]))
                tf.logging.debug('# bert_embedding')
                t = bert_embeddings[0][:1]
                tf.logging.debug(' '.join([str(x) for x in np.shape(t)]))
                tf.logging.debug(' '.join([str(x) for x in t]))
                tf.logging.debug('# bert_wordidx2tokenidx')
                t = dataset['bert_wordidx2tokenidx'][:1]
                tf.logging.debug(' '.join([str(x) for x in np.shape(t)]))
                tf.logging.debug(' '.join([str(x) for x in t]))
            # update feed_dict: inject the runtime BERT embeddings, aligned from
            # wordpiece tokens back to words via bert_wordidx2tokenidx
            feed.update_feed_dict(model, feed_dict, bert_embeddings,
                                  dataset['bert_wordidx2tokenidx'])
        step, summaries, _, loss, accuracy, f1, learning_rate = \
            sess.run([model.global_step, summary_op, model.train_op,
                      model.loss, model.accuracy, model.f1,
                      model.learning_rate], feed_dict=feed_dict, options=runopts)

        summary_writer.add_summary(summaries, step)
        prog.update(idx + 1,
                    [('step', step), ('train loss', loss),
                     ('train accuracy', accuracy), ('train f1', f1),
                     ('lr(invalid if use_bert_optimization)', learning_rate)])
    duration_time = time.time() - start_time
    out = '\nduration_time : ' + str(duration_time) + ' sec for this epoch'
    tf.logging.debug(out)
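For context, train_step is typically called once per epoch from an outer loop together with dev_step (Example #2 below). The following sketch is a minimal, hypothetical driver, not code from the repository: the merged summary op, the FileWriter paths, train_data/dev_data, and the saver are illustrative assumptions, while model.config.epoch and the dev_step return values follow the examples here.

# hypothetical driver loop (illustrative names: train_data, dev_data, checkpoint path)
summary_op = tf.summary.merge_all()
train_writer = tf.summary.FileWriter('runs/train', model.sess.graph)
dev_writer = tf.summary.FileWriter('runs/dev')
saver = tf.train.Saver()

best_f1 = 0.0
for epoch in range(model.config.epoch):
    train_step(model, train_data, summary_op, train_writer)
    token_f1, chunk_f1, avg_f1 = dev_step(model, dev_data, dev_writer, epoch)
    if chunk_f1 > best_f1:  # keep only the best-scoring checkpoint
        best_f1 = chunk_f1
        saver.save(model.sess, 'checkpoint/ner.ckpt', global_step=epoch)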
Example #2
def dev_step(model, data, summary_writer, epoch):
    """Evaluate dev data
    """
    sess = model.sess
    runopts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    sum_loss = 0.0
    sum_accuracy = 0.0
    sum_f1 = 0.0
    sum_output_indices = None
    sum_logits_indices = None
    sum_sentence_lengths = None
    trans_params = None
    global_step = 0
    prog = Progbar(target=data.num_batches)
    iterator = data.dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess.run(iterator.initializer)

    # evaluate on dev data sliced by batch_size to prevent OOM (out of memory).
    for idx in range(data.num_batches):
        try:
            dataset = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        feed_dict = feed.build_feed_dict(model, dataset,
                                         data.max_sentence_length, False)
        if 'bert' in model.config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict,
                                       options=runopts)
            # update feed_dict: inject the runtime BERT embeddings, aligned from
            # wordpiece tokens back to words via bert_wordidx2tokenidx
            feed.update_feed_dict(model, feed_dict, bert_embeddings,
                                  dataset['bert_wordidx2tokenidx'])
        global_step, logits_indices, sentence_lengths, loss, accuracy, f1 = \
            sess.run([model.global_step, model.logits_indices, model.sentence_lengths, \
                      model.loss, model.accuracy, model.f1], feed_dict=feed_dict)
        prog.update(idx + 1, [('dev loss', loss), ('dev accuracy', accuracy),
                              ('dev f1', f1)])
        sum_loss += loss
        sum_accuracy += accuracy
        sum_f1 += f1
        sum_output_indices = np_concat(sum_output_indices,
                                       np.argmax(dataset['tags'], 2))
        sum_logits_indices = np_concat(sum_logits_indices, logits_indices)
        sum_sentence_lengths = np_concat(sum_sentence_lengths,
                                         sentence_lengths)
    avg_loss = sum_loss / data.num_batches
    avg_accuracy = sum_accuracy / data.num_batches
    avg_f1 = sum_f1 / data.num_batches
    tag_preds = model.config.logits_indices_to_tags_seq(
        sum_logits_indices, sum_sentence_lengths)
    tag_corrects = model.config.logits_indices_to_tags_seq(
        sum_output_indices, sum_sentence_lengths)
    tf.logging.debug('\n[epoch %s/%s] dev precision, recall, f1(token): ' %
                     (epoch, model.config.epoch))
    token_f1, l_token_prec, l_token_rec, l_token_f1 = TokenEval.compute_f1(
        model.config.class_size, sum_logits_indices, sum_output_indices,
        sum_sentence_lengths)
    tf.logging.debug('[' + ' '.join([str(x) for x in l_token_prec]) + ']')
    tf.logging.debug('[' + ' '.join([str(x) for x in l_token_rec]) + ']')
    tf.logging.debug('[' + ' '.join([str(x) for x in l_token_f1]) + ']')
    chunk_prec, chunk_rec, chunk_f1 = ChunkEval.compute_f1(
        tag_preds, tag_corrects)
    tf.logging.debug('dev precision(chunk), recall(chunk), f1(chunk): %s, %s, %s' % \
        (chunk_prec, chunk_rec, chunk_f1) + \
        ' (invalid for bert due to X tag)')

    # create summaries manually.
    summary_value = [
        tf.Summary.Value(tag='loss', simple_value=avg_loss),
        tf.Summary.Value(tag='accuracy', simple_value=avg_accuracy),
        tf.Summary.Value(tag='f1', simple_value=avg_f1),
        tf.Summary.Value(tag='token_f1', simple_value=token_f1),
        tf.Summary.Value(tag='chunk_f1', simple_value=chunk_f1)
    ]
    summaries = tf.Summary(value=summary_value)
    summary_writer.add_summary(summaries, global_step)

    return token_f1, chunk_f1, avg_f1
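Example #2 relies on an np_concat helper that is not shown here (presumably defined at module level); Example #3 below contains an identical nested definition. For reference, a standalone equivalent is:

import numpy as np

def np_concat(sum_var, var):
    # accumulate per-batch arrays along axis 0; on the first call just keep var
    if sum_var is not None:
        return np.concatenate((sum_var, var), axis=0)
    return var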
Example #3
def dev_step(model, data, summary_writer, epoch):
    """Evaluate dev data
    """
    def np_concat(sum_var, var):
        if sum_var is not None:
            sum_var = np.concatenate((sum_var, var), axis=0)
        else:
            sum_var = var
        return sum_var

    sess = model.sess
    runopts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    sum_loss = 0.0
    sum_accuracy = 0.0
    sum_f1 = 0.0
    sum_output_indices = None
    sum_logits_indices = None
    sum_sentence_lengths = None
    trans_params = None
    global_step = 0
    prog = Progbar(target=data.num_batches)
    iterator = data.dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess.run(iterator.initializer)

    # evaluate on dev data sliced by batch_size to prevent OOM (out of memory).
    for idx in range(data.num_batches):
        try:
            dataset = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        feed_dict = feed.build_feed_dict(model, dataset,
                                         data.max_sentence_length, False)
        if 'bert' in model.config.emb_class:
            # compute bert embedding at runtime
            bert_embeddings = sess.run([model.bert_embeddings_subgraph],
                                       feed_dict=feed_dict,
                                       options=runopts)
            # update feed_dict: inject the runtime BERT embeddings, aligned from
            # wordpiece tokens back to words via bert_wordidx2tokenidx
            feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(
                model.config, bert_embeddings, dataset['bert_wordidx2tokenidx'], idx)
        global_step, logits_indices, sentence_lengths, loss, accuracy, f1 = \
            sess.run([model.global_step, model.logits_indices, model.sentence_lengths, \
                      model.loss, model.accuracy, model.f1], feed_dict=feed_dict)
        prog.update(idx + 1, [('dev loss', loss), ('dev accuracy', accuracy),
                              ('dev f1(tf_metrics)', f1)])
        sum_loss += loss
        sum_accuracy += accuracy
        sum_f1 += f1
        sum_output_indices = np_concat(sum_output_indices,
                                       np.argmax(dataset['tags'], 2))
        sum_logits_indices = np_concat(sum_logits_indices, logits_indices)
        sum_sentence_lengths = np_concat(sum_sentence_lengths,
                                         sentence_lengths)
    avg_loss = sum_loss / data.num_batches
    avg_accuracy = sum_accuracy / data.num_batches
    avg_f1 = sum_f1 / data.num_batches
    tag_preds = model.config.logits_indices_to_tags_seq(
        sum_logits_indices, sum_sentence_lengths)
    tag_corrects = model.config.logits_indices_to_tags_seq(
        sum_output_indices, sum_sentence_lengths)
    seqeval_prec = precision_score(tag_corrects, tag_preds)
    seqeval_rec = recall_score(tag_corrects, tag_preds)
    seqeval_f1 = f1_score(tag_corrects, tag_preds)
    tf.logging.debug('\n[epoch %s/%s] dev precision(seqeval), recall(seqeval), f1(seqeval): %s, %s, %s' % \
        (epoch, model.config.epoch, seqeval_prec, seqeval_rec, seqeval_f1))

    # create summaries manually.
    summary_value = [
        tf.Summary.Value(tag='loss', simple_value=avg_loss),
        tf.Summary.Value(tag='accuracy', simple_value=avg_accuracy),
        tf.Summary.Value(tag='f1(tf_metrics)', simple_value=avg_f1),
        tf.Summary.Value(tag='f1(seqeval)', simple_value=seqeval_f1)
    ]
    summaries = tf.Summary(value=summary_value)
    summary_writer.add_summary(summaries, global_step)

    return seqeval_f1, avg_f1
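Example #3 scores the decoded tag sequences with seqeval's entity-level metrics. precision_score, recall_score, and f1_score each expect lists of tag sequences (one list of IOB tags per sentence), which is exactly the shape of tag_corrects and tag_preds above. A minimal, self-contained illustration with toy tags (not data from the repository):

from seqeval.metrics import precision_score, recall_score, f1_score

# one list of IOB2 tags per sentence, e.g. decoded from logits_indices
y_true = [['B-PER', 'I-PER', 'O', 'B-LOC'], ['O', 'B-ORG']]
y_pred = [['B-PER', 'I-PER', 'O', 'B-LOC'], ['O', 'O']]

print(precision_score(y_true, y_pred))  # 1.0   -> both predicted entities are correct
print(recall_score(y_true, y_pred))     # ~0.67 -> 2 of 3 true entities were found
print(f1_score(y_true, y_pred))         # 0.8   -> harmonic mean of the two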