示例#1
0
def test_save_model(tmpdir):
    """Persisting an (empty) HMMModel must write exactly three artifacts."""
    output_dir = tmpdir.mkdir("some_dir")

    transition_probs = {}
    emission_probs = {}
    vocab = set()

    model = HMMModel(transition_probs, emission_probs, vocab)
    model.save_model(str(output_dir))

    # One file each for A, B and the vocabulary is expected.
    assert len(output_dir.listdir()) == 3
示例#2
0
文件: t.py 项目: newszeng/ai_writer
def train():
    """Train an HMM segmentation model from labelled sentences, save it, and run a demo prediction.

    Reads one JSON object per line from the training file, keeps only items
    whose ``label`` is ``"Yes"``, converts each such sentence into tagged
    training data via ``bulid_mark``, feeds it to the model, persists the
    model to ``data/hmm``, then segments a sample sentence and prints the
    result.
    """
    model_dir_str = "data/hmm"
    hmm_model = HMMModel()
    # NOTE(review): hard-coded absolute path — consider making this a parameter.
    output = '/home/terry/pan/github/Bert-Sentence-streamlining/Bert-Sentence-streamlining/data/train_old.json'
    # Fix: force UTF-8 so the Chinese JSON decodes identically on every
    # platform (without `encoding=` the locale's default encoding is used).
    with open(output, 'r', encoding='utf-8') as f:

        items = []
        for line in tqdm(f):
            j_content = json.loads(line)
            if j_content['label'] == "Yes":
                items.append(j_content)
                # Build the (char, tag) training representation of the sentence.
                one_line = bulid_mark(j_content['sentence'])
                hmm_model.train_one_line(one_line)
        hmm_model.save_model(model_dir_str)
    text = "它们的岗位,一只边牧可以管理上千头羊群呢,它们为主人忠心耿耿的守护着家畜,守护着家园"
    s = jieba_seg_list(text)
    result = hmm_model.predict(s)
    print(result)
    print(hmm_model)
示例#3
0
class HMMTokenizer(BaseTokenizer):
    """Character-level Chinese tokenizer backed by an HMM using the BMES tag set.

    Words are decomposed into (char, tag) pairs for training; prediction
    recovers words by cutting the character stream after every 'S' or 'E' tag.
    """

    def __init__(self, *args, **kwargs):
        super(HMMTokenizer, self).__init__(*args, **kwargs)

        # Underlying statistical model; may be replaced wholesale by
        # load_model() or assign_from_loader().
        self.hmm_model = HMMModel()  # type: HMMModel

    def train_one_line(self, token_list):
        """Feed one pre-tokenized line (iterable of word strings) into the model.

        Each word is expanded into (char, BMES-tag) pairs; the whole line is
        passed to the model as a single training sequence.
        """
        list_of_word_tag_pair = []
        for word in token_list:
            word = word.strip()

            tag = self._generate_char_tag_for_word(word)

            # zip pairs each character with its tag; an empty word yields
            # nothing (tag is '' in that case, see _generate_char_tag_for_word).
            list_of_word_tag_pair.extend(list(zip(word, tag)))

        self.hmm_model.train_one_line(list_of_word_tag_pair)

    def do_train(self):
        """Finalize training (delegates to the underlying model)."""
        self.hmm_model.do_train()

    @staticmethod
    def _generate_char_tag_for_word(word):
        # type: (str) -> str
        """Return the BMES tag string for *word*.

        '' for an empty word, 'S' for a single character, otherwise
        'B' + 'M' * (len - 2) + 'E'.
        """
        # TODO: tag set related function should go to a standalone package
        len_of_word = len(word)

        # Fix: an empty word (e.g. a token that stripped to "") previously
        # fell through and returned None, making zip(word, tag) raise
        # TypeError in train_one_line. Return '' so it contributes nothing.
        if len_of_word == 0:
            return ''

        if len_of_word == 1:
            return 'S'

        number_of_middle = len_of_word - 2
        return 'B' + 'M' * number_of_middle + 'E'

    def predict(self, line, output_graphml_file=None):
        """Segment *line* (a character sequence) into a list of words.

        Runs the HMM tagger over the characters, then decodes the BMES tags:
        a word ends at every 'S' or 'E' tag.
        """
        char_list = line

        char_tag_pair = self.hmm_model.predict(char_list, output_graphml_file)

        # TODO: current BMES decoding is not good, can't raise decoding exception

        token_list = []
        word_char = []
        for char, tag in char_tag_pair:
            # no matter what, word_char still need record
            word_char.append(char)

            if tag == "S" or tag == "E":
                # emission token word
                word = "".join(word_char)
                token_list.append(word)

                # reset word_char cache
                word_char = []

        # no matter what, char can not disappear: flush a trailing partial
        # word (malformed tag sequence ending in B/M).
        if word_char:
            word = "".join(word_char)
            token_list.append(word)

        return token_list

    def segment(self, message):
        # type: (str) -> List[str]
        """Public segmentation entry point; alias for predict()."""
        return self.predict(message)

    def load_model(self):
        """Load the persisted model from self.model_dir (set by the base class)."""
        self.hmm_model = HMMModel.load_model(self.model_dir)

    def persist_to_dir(self, output_dir):
        # type: (str) -> None
        """Save the underlying model's artifacts into *output_dir*."""
        self.hmm_model.save_model(output_dir)

    def assign_from_loader(self, *args, **kwargs):
        """Install an externally loaded model (expects kwargs['hmm_model'])."""
        self.hmm_model = kwargs['hmm_model']

    def get_loader(self):
        """Return the loader class responsible for deserializing this tokenizer."""
        return HMMLoader