示例#1
0
    def generate_tfrecord(cls, file, token2id_dct, tfrecord_file):
        """Convert a tab-separated text file of (multi-sentence, sentence) pairs to TFRecord.

        Each line of ``file`` must contain exactly two tab-separated fields:
        a ``$$$``-delimited list of sentences and a single sentence. Both are
        mapped to ids with ``token2id_dct['word2id']`` (``max_word_len=50``)
        via ``cls.multi_sent2ids`` / ``cls.sent2ids``. The results are written
        with ``items2tfrecord`` under the keys ``'multi_s1'`` and ``'s2'``.

        Lines without exactly two fields are printed and skipped. Lines whose
        conversion raises are also printed and skipped. Successfully converted
        lines among the first five lines of the file are echoed as a sanity check.

        Args:
            file: path of the input text file (opened as UTF-8).
            token2id_dct: mapping that must contain a ``'word2id'`` vocabulary.
            tfrecord_file: output path handed to ``items2tfrecord``.

        Returns:
            The value returned by ``items2tfrecord``; presumably the number of
            records written (TODO confirm against tfrecord_utils).
        """
        from qiznlp.common.tfrecord_utils import items2tfrecord
        word2id = token2id_dct['word2id']

        def items_gen():
            with open(file, 'r', encoding='U8') as f:
                for i, line in enumerate(f):
                    # NOTE: strip() removes surrounding spaces as well as the newline.
                    item = line.strip().split('\t')
                    if len(item) != 2:
                        print('error item:', repr(line))
                        continue
                    # Guard only the conversion (and the sample echo). The yield
                    # lives in ``else`` so that an exception thrown into this
                    # generator by the consumer is not swallowed as a "bad line".
                    try:
                        multi_s1 = item[0].split('$$$')
                        s2 = item[1]
                        multi_s1_ids = cls.multi_sent2ids(multi_s1, word2id, max_word_len=50)
                        s2_ids = cls.sent2ids(s2, word2id, max_word_len=50)
                        if i < 5:  # check: echo a few converted lines
                            print(f'check {i}:')
                            print(f'{multi_s1} -> {multi_s1_ids}')
                            print(f'{s2} -> {s2_ids}')
                    except Exception as e:
                        # Best-effort: report the failing line's error and move on.
                        print('Exception occur in items_gen()!\n', e)
                    else:
                        yield {
                            'multi_s1': multi_s1_ids,
                            's2': s2_ids,
                        }

        count = items2tfrecord(items_gen(), tfrecord_file)
        return count
示例#2
0
    def generate_tfrecord(cls, file, token2id_dct, tfrecord_file):
        """Convert a tab-separated text file of (sentence, label) pairs to TFRecord.

        Each line of ``file`` must contain exactly two tab-separated fields:
        a sentence and a label. The sentence is mapped to ids with
        ``token2id_dct['word2id']`` via ``cls.sent2ids`` (``max_word_len=50``).
        The label is mapped with ``token2id_dct['label2id']`` via
        ``cls.label2id``. The results are written with ``items2tfrecord``
        under the keys ``'s1'`` and ``'target'``.

        Lines without exactly two fields are printed (``repr``) and skipped.
        Lines whose conversion raises are also printed and skipped.
        Successfully converted lines among the first five lines of the file
        are echoed as a sanity check.

        Args:
            file: path of the input text file (opened as UTF-8).
            token2id_dct: mapping that must contain ``'word2id'`` and ``'label2id'``.
            tfrecord_file: output path handed to ``items2tfrecord``.

        Returns:
            The value returned by ``items2tfrecord``; presumably the number of
            records written (TODO confirm against tfrecord_utils).
        """
        from qiznlp.common.tfrecord_utils import items2tfrecord
        word2id = token2id_dct['word2id']
        label2id = token2id_dct['label2id']

        def items_gen():
            with open(file, 'r', encoding='U8') as f:
                for i, line in enumerate(f):
                    # NOTE: strip() removes surrounding spaces as well as the newline.
                    item = line.strip().split('\t')
                    if len(item) != 2:
                        print(repr(line))
                        continue
                    # Guard only the conversion (and the sample echo). The yield
                    # lives in ``else`` so that an exception thrown into this
                    # generator by the consumer is not swallowed as a "bad line".
                    try:
                        s1 = item[0]
                        y = item[1]
                        s1_ids = cls.sent2ids(s1, word2id, max_word_len=50)
                        y_id = cls.label2id(y, label2id)
                        if i < 5:  # check: echo a few converted lines
                            print(f'check {i}:')
                            print(f'{s1} -> {s1_ids}')
                            print(f'{y} -> {y_id}')
                    except Exception as e:
                        # Best-effort: report the failing line's error and move on.
                        print('Exception occur in items_gen()!\n', e)
                    else:
                        yield {
                            's1': s1_ids,
                            'target': y_id,
                        }

        count = items2tfrecord(items_gen(), tfrecord_file)
        return count
示例#3
0
    def generate_tfrecord(cls, file, token2id_dct, tfrecord_file):
        """Convert a tab-separated text file of (sentence, BMEO tags) pairs to TFRecord.

        Each line of ``file`` must contain exactly two tab-separated fields:
        a sentence and its BMEO tag string. The sentence is mapped to ids with
        ``token2id_dct['char2id']`` via ``cls.sent2ids``. The tags are mapped
        with ``token2id_dct['bmeo2id']`` via ``cls.bmeo2ids``. Both use
        ``max_word_len=100``. The results are written with ``items2tfrecord``
        under the keys ``'s1'`` and ``'ner_label'``.

        The exact tag-string format (one tag per character vs. a separator)
        is defined by ``cls.bmeo2ids``, which is not visible here; TODO
        confirm.

        Lines without exactly two fields are printed (``repr``) and skipped.
        Lines whose conversion raises are also printed and skipped.
        Successfully converted lines among the first five lines of the file
        are echoed as a sanity check.

        Args:
            file: path of the input text file (opened as UTF-8).
            token2id_dct: mapping that must contain ``'char2id'`` and ``'bmeo2id'``.
            tfrecord_file: output path handed to ``items2tfrecord``.

        Returns:
            The value returned by ``items2tfrecord``; presumably the number of
            records written (TODO confirm against tfrecord_utils).
        """
        from qiznlp.common.tfrecord_utils import items2tfrecord
        char2id = token2id_dct['char2id']
        bmeo2id = token2id_dct['bmeo2id']

        def items_gen():
            with open(file, 'r', encoding='U8') as f:
                for i, line in enumerate(f):
                    # NOTE(review): strip() also drops leading/trailing spaces of the
                    # sentence, which could misalign characters and tags if a
                    # sentence legitimately starts or ends with a space.
                    item = line.strip().split('\t')
                    if len(item) != 2:
                        print(repr(line))
                        continue
                    # Guard only the conversion (and the sample echo). The yield
                    # lives in ``else`` so that an exception thrown into this
                    # generator by the consumer is not swallowed as a "bad line".
                    try:
                        s1 = item[0]
                        bmeo = item[1]
                        s1_ids = cls.sent2ids(s1, char2id, max_word_len=100)
                        bmeo_ids = cls.bmeo2ids(bmeo,
                                                bmeo2id,
                                                max_word_len=100)
                        if i < 5:  # check: echo a few converted lines
                            print(f'check {i}:')
                            print(f'{s1} -> {s1_ids}')
                            print(f'{bmeo} -> {bmeo_ids}')
                    except Exception as e:
                        # Best-effort: report the failing line's error and move on.
                        print('Exception occur in items_gen()!\n', e)
                    else:
                        yield {
                            's1': s1_ids,
                            'ner_label': bmeo_ids,
                        }

        count = items2tfrecord(items_gen(), tfrecord_file)
        return count