예제 #1
0
'''
import codecs
import collections
import os
import pickle
import tensorflow as tf

from bert_base.bert import tokenization
from bert_base.bert import modeling
from bert_base.bert import optimization
from bert_base.train.models import create_model, InputFeatures
from bert_base.server.helper import set_logger

# Version string of this training module.
__version__ = '0.1.0'

# Module-wide logger created through the project's `set_logger` helper.
# NOTE(review): this file appears to contain a second copy of these module-level
# definitions further down (another `__version__` and a logger named
# 'albert NER Training'); if the file is executed top to bottom, those later
# assignments replace these ones.
logger = set_logger('NER Training')


def write_tokens(tokens, output_dir, mode):
    """
    Write the sequence-labelling result tokens to a file.
    Writing is guarded by ``mode == "test"``.
    :param tokens: iterable of tokens; entries equal to "**NULL**" are filtered
        out by the check below (presumably padding placeholders -- confirm)
    :param output_dir: directory in which ``token_<mode>.txt`` is opened
    :param mode: run mode; only "test" enables the writing branch
    :return: not visible in this excerpt
    """
    if mode == "test":
        path = os.path.join(output_dir, "token_" + mode + ".txt")
        # Opened in append mode ('a'), so repeated calls add to the same file
        # instead of overwriting it.
        # NOTE(review): the rest of this function is cut off in this view;
        # confirm that `wf` is closed afterwards (a `with` block would be safer).
        wf = codecs.open(path, 'a', encoding='utf-8')
        for token in tokens:
            if token != "**NULL**":
from bert_base.albert import optimization
from bert_base.albert import tokenization

# import

from bert_base.train.models import create_albert_model, InputFeatures, InputExample
from bert_base.server.helper import set_logger
__version__ = '0.1.0'

# Names exported by `from <this module> import *`.
__all__ = [
    '__version__',
    'DataProcessor',
    'NerProcessor',
    'write_tokens',
    'convert_single_example',
    'filed_based_convert_examples_to_features',
    'file_based_input_fn_builder',
    'model_fn_builder',
    'train',
]

# Module-wide logger created through the project's `set_logger` helper.
logger = set_logger('albert NER Training')


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""
    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()