Example #1
def read_instances_from_file(file, max_len=400, keep_case=False):
    ''' Collect instances and construct vocab '''

    dataset = DataSet()
    trimmed_sent = 0

    with open(file) as f:
        lines = f.readlines()
        for l in lines:
            l = l.strip().split('\t')
            if len(l) < 2:
                continue
            label = int(l[0])
            sent = l[1]
            if not keep_case:
                sent = sent.lower()
            word_lst = sent.split()
            if len(word_lst) > max_len:
                word_lst = word_lst[:max_len]
                trimmed_sent += 1
            if word_lst:
                dataset.append(Instance(words=word_lst, label=label))

    logger.info('Got {} instances from file {}'.format(len(dataset), file))
    if trimmed_sent:
        logger.info('{} sentences are trimmed. Max sentence length: {}.'
                    .format(trimmed_sent, max_len))

    return dataset
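A minimal usage sketch for the function above (not part of the original example). It assumes a tab-separated file whose lines look like label<TAB>sentence, plus the fastNLP imports and module-level logger that the snippet relies on; the file path is hypothetical.

import logging
from fastNLP import DataSet, Instance

logger = logging.getLogger(__name__)

# 'data/train.tsv' is a hypothetical path with one "label\tsentence" pair per line.
train_set = read_instances_from_file('data/train.tsv', max_len=400, keep_case=False)
print(len(train_set), train_set[0]['words'][:5], train_set[0]['label'])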
Example #2
def process_poems_large(file_name, sentence_len, vocab_size):
    sentences = []
    with open(file_name, "r", encoding='utf-8', ) as f:
        for line in f.readlines():
            try:
                line = line.strip()
                if line:
                    title, content = line.split(':')
                    # print(title)
                    # print(content)
                    # content = line.replace(' ', '').replace(',','').replace('。','')
                    content = content.replace(' ', '')  # punctuation is kept
                    # optionally keep only five-character-line poems:
                    # if len(content) < 6 or content[5] != ',':
                    #     continue
                    if len(content) < 20:
                        continue
                    if ':' in content or '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
                        continue
                    # truncate to the maximum length
                    if len(content) > sentence_len:
                        content = content[:sentence_len]
                    content = content + end_token
                    sentences.append(content)
            except ValueError as e:
                pass

    dataset = DataSet()
    # sentences = random.sample(sentences, 5000)
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        instance['target'] = sentence[1:] + sentence[-1]
        dataset.append(instance)

    dataset.set_input("raw_sentence")
    dataset.set_target("target")
    
    # for iter in dataset:
    #     print(iter['raw_sentence'])
    print("dataset_size:", len(dataset))

    train_data, dev_data = dataset.split(0.2)
    train_data.rename_field("raw_sentence", "sentence")
    dev_data.rename_field("raw_sentence", "sentence")
    vocab = Vocabulary(max_size=vocab_size, min_freq=2, unknown='<unk>', padding='<pad>')

    # build the vocabulary
    train_data.apply(lambda x: [vocab.add(word) for word in x['sentence']])
    vocab.build_vocab()

    # index sentences with the vocabulary
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['sentence']], new_field_name='sentence')
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['target']], new_field_name='target')
    dev_data.apply(lambda x: [vocab.to_index(word) for word in x['sentence']], new_field_name='sentence')
    dev_data.apply(lambda x: [vocab.to_index(word) for word in x['target']], new_field_name='target')

    print("vocabulary_size:", len(vocab))

    return train_data, dev_data, vocab
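A hedged usage sketch (not from the original): the function reads a corpus with one "title:content" poem per line and references a module-level end_token that the snippet does not define, so the path and token value below are assumptions.

from fastNLP import DataSet, Instance, Vocabulary

end_token = 'E'  # assumption: an end-of-sequence marker expected by the snippet
# 'poems.txt' is a hypothetical corpus with lines of the form "title:content"
train_data, dev_data, vocab = process_poems_large('poems.txt', sentence_len=100, vocab_size=6000)
print('train:', len(train_data), 'dev:', len(dev_data), 'vocab:', len(vocab))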
Example #3
def read_file(filename, processing_word=get_processing_word(lowercase=False)):
    dataset = DataSet()
    niter = 0
    with codecs.open(filename, "r", "utf-16") as f:
        words, tags = [], []
        for line in f:
            line = line.strip()
            if len(line) == 0 or line.startswith("-DOCSTART-"):
                if len(words) != 0:
                    assert len(words) > 2
                    if niter == 1:
                        print(words, tags)
                    niter += 1
                    dataset.append(
                        Instance(ori_words=words[:-1], ori_tags=tags[:-1]))
                    words, tags = [], []
            else:
                word, tag = line.split()
                word = processing_word(word)
                words.append(word)
                tags.append(tag.lower())

    dataset.apply_field(lambda x: [x[0]],
                        field_name='ori_words',
                        new_field_name='task')
    dataset.apply_field(lambda x: len(x),
                        field_name='ori_tags',
                        new_field_name='seq_len')
    dataset.apply_field(lambda x: expand(x),
                        field_name='ori_words',
                        new_field_name="bi1")
    return dataset
Example #4
File: loader.py Project: yhcc/BertForRD
    def load(self, folder):
        fns ={
            'dev':'{}_dev.csv'.format(self.lg1_lg2),
            'test':'{}_test500.csv'.format(self.lg1_lg2),
            'train': '{}_train500_10.csv'.format(self.lg1_lg2)
        }
        target_lg = self.lg1_lg2.split('_')[0]
        data_bundle = DataBundle()
        for name, fn in fns.items():
            path = os.path.join(folder, fn)
            ds = DataSet()
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        parts = line.split('\t')
                        if self.lower:
                            ins = Instance(word=parts[1].lower(), definition=parts[-1].lower())
                        else:
                            ins = Instance(word=parts[1], definition=parts[-1])
                        ds.append(ins)
            data_bundle.set_dataset(ds, name=name)
        target_words = {}
        with open(os.path.join(folder, '{}.txt'.format(target_lg)), encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    if self.lower:
                        line = line.lower()
                    target_words[line] = 1
        target_words = list(target_words.keys())

        setattr(data_bundle, 'target_words', target_words)
        return data_bundle
Example #5
def construct_dataset(sentences):
    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        dataset.append(instance)
    return dataset
Example #6
 def load(self, path: str, bigram: bool = False) -> DataSet:
     """
     :param path: str
     :param bigram: whether to use bigram features
     :return:
     """
     dataset = DataSet()
     with open(path, 'r', encoding='utf-8') as f:
         for line in f:
             line = line.strip()
             if not line:  # skip empty lines
                 continue
             parts = line.split()
             word_lens = map(len, parts)
             chars = list(''.join(parts))
             tags = self._word_len_to_target(word_lens)
             assert len(chars) == len(tags['target'])
             dataset.append(
                 Instance(raw_chars=chars, **tags, seq_len=len(chars)))
     if len(dataset) == 0:
         raise RuntimeError(f"{path} has no valid data.")
     if bigram:
         dataset.apply_field(self._gen_bigram,
                             field_name='raw_chars',
                             new_field_name='bigrams')
     return dataset
Example #7
def preprocess():
    train_set = DataSet()
    for i in range(len(raw_train.data)):
        train_set.append(
            Instance(sentence=raw_train.data[i],
                     target=int(raw_train.target[i])))

    train_set.apply(lambda x: x['sentence'].translate(
        str.maketrans("", "", string.punctuation)).lower(),
                    new_field_name='sentence')
    train_set.apply(lambda x: x['sentence'].split(), new_field_name='words')
    train_set.apply(lambda x: len(x['words']), new_field_name='seq_len')

    test_set = DataSet()
    for i in range(len(raw_test.data)):
        test_set.append(
            Instance(sentence=raw_test.data[i],
                     target=int(raw_test.target[i])))

    test_set.apply(lambda x: x['sentence'].translate(
        str.maketrans("", "", string.punctuation)).lower(),
                   new_field_name='sentence')
    test_set.apply(lambda x: x['sentence'].split(), new_field_name='words')
    test_set.apply(lambda x: len(x['words']), new_field_name='seq_len')

    vocab = Vocabulary(min_freq=10)
    train_set.apply(lambda x: [vocab.add(word) for word in x['words']])
    test_set.apply(lambda x: [vocab.add(word) for word in x['words']])
    vocab.build_vocab()
    vocab.index_dataset(train_set, field_name='words', new_field_name='words')
    vocab.index_dataset(test_set, field_name='words', new_field_name='words')

    return train_set, test_set, vocab
Example #8
def make_dataset(data):
    dataset = DataSet()
    mx = 0
    le = None
    for x, y in zip(data.data, data.target):
        xx = deal(x)
        ins = Instance(sentence=xx, label=int(y))
        if mx < len(xx.split()):
            mx = max(mx, len(xx.split()))
            le = xx
        dataset.append(ins)
    print(mx)
    dataset.apply_field(lambda x: x.split(),
                        field_name='sentence',
                        new_field_name='words')
    dataset.apply_field(lambda x: len(x),
                        field_name='words',
                        new_field_name='seq_len')

    dataset.rename_field('words', Const.INPUT)
    dataset.rename_field('seq_len', Const.INPUT_LEN)
    dataset.rename_field('label', Const.TARGET)

    dataset.set_input(Const.INPUT, Const.INPUT_LEN)
    dataset.set_target(Const.TARGET)
    return dataset
Example #9
def get_joke_data(data_path):
    data_set = DataSet()
    sample_num = 0
    sample_len = []
    if os.path.exists(data_path):
        with open(data_path, 'r', encoding='utf-8') as fin:
            for lid, line in enumerate(fin):
                joke = json.loads(line)
                if joke['support'] > 0:
                    if len(joke['content']) == 0:
                        continue
                    else:
                        instance = Instance(raw_joke=joke['content'])
                        data_set.append(instance)
                        sample_num += 1
                        sample_len.append(len(joke['content']))
    else:
        print("The data path doesn't exist.")
    print("Got {} samples from file.".format(sample_num))
    import random
    for i in range(5):
        idx = random.randint(0, sample_num - 1)  # randint is inclusive, so stay within range
        print("sample {}: {}".format(idx, data_set[idx]['raw_joke']))

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    plt.hist(sample_len, bins=50, range=(0, 1000))
    plt.savefig("./examples.jpg")
    count = 0
    for i in sample_len:
        if i < 255:
            count += 1
    print(count, '/', len(sample_len))
    return data_set
Example #10
def get_fastnlp_dataset():
    text_train, text_test = get_text_classification_datasets()
    train_data = DataSet()
    test_data = DataSet()
    for i in range(len(text_train.data)):
        train_data.append(
            Instance(text=split_sent(text_train.data[i]),
                     target=int(text_train.target[i])))
    for i in range(len(text_test.data)):
        test_data.append(
            Instance(text=split_sent(text_test.data[i]),
                     target=int(text_test.target[i])))

    # build the vocabulary
    vocab = Vocabulary(min_freq=5, unknown='<unk>', padding='<pad>')
    train_data.apply(lambda x: [vocab.add(word) for word in x['text']])
    vocab.build_vocab()

    # map sentences to indices with the vocabulary
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['text']],
                     new_field_name='word_seq')
    test_data.apply(lambda x: [vocab.to_index(word) for word in x['text']],
                    new_field_name='word_seq')

    # set the input and target fields
    train_data.set_input("word_seq")
    test_data.set_input("word_seq")
    train_data.set_target("target")
    test_data.set_target("target")

    return train_data, test_data, vocab
Example #11
def make_dataset(data):
    dataset = DataSet()
    tot = 0
    for x in data:

        seq = "[CLS] " + x["raw_text"]
        seq = tokenizer.encode(seq)
        """
        seq=["[CLS]"]+word_tokenize(x["raw_text"])
        seq=tokenizer.convert_tokens_to_ids(seq)
        """
        if len(seq) > 512:
            seq = seq[:512]
            tot += 1
            # print(x["raw_text"])
            # print()

        label = int(x["label"])
        ins = Instance(origin=x["raw_text"],
                       seq=seq,
                       label=label,
                       seq_len=len(seq))
        dataset.append(ins)

    dataset.set_input("seq", "seq_len")
    dataset.set_target("label")
    print(dataset[5])
    print("number:", len(dataset), tot)
    print()
    return dataset
Example #12
 def test_append(self):
     dd = DataSet()
     for _ in range(3):
         dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
     self.assertEqual(len(dd), 3)
     self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
     self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
Example #13
def load(path):

    data = DataSet()
    _data = []

    with open(path, "r", encoding="utf-8") as fil:
        fil.readline()

        for line in fil:
            try:
                tradi, verna = line.strip().split("\t")
            except ValueError:
                continue

            tradi = chinese_tokenizer(tradi)
            verna = chinese_tokenizer(verna)

            vocab.add_word_lst(tradi)
            vocab.add_word_lst(verna)

            _data.append(Instance(traditional=tradi, vernacular=verna))

    random.shuffle(_data)
    for x in _data:
        data.append(x)

    data.set_input("vernacular")
    data.set_target("traditional")
    return data
Example #14
def create_dataset():
        # categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space', 'rec.motorcycles']
        # categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space', 'rec.motorcycles', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale']
        categories = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware',
                      'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
                      'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space',
                      'soc.religion.christian', 'talk.politics.guns',
                      'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']

        newsgroups_train = fetch_20newsgroups(subset='train', categories=categories, data_home='../../..')
        newsgroups_test = fetch_20newsgroups(subset='test', categories=categories, data_home='../../..')

        dataset = DataSet()

        for i in range(len(newsgroups_train.data)):
            if len(newsgroups_train.data[i]) <= 2000:
                dataset.append(Instance(raw_sentence=newsgroups_train.data[i], target=int(newsgroups_train.target[i])))
        for i in range(len(newsgroups_test.data)):
            if len(newsgroups_test.data[i]) <= 2000:
                dataset.append(Instance(raw_sentence=newsgroups_test.data[i], target=int(newsgroups_test.target[i])))

        dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='sentence')
        dataset.apply(lambda x: x['sentence'].split(), new_field_name='words')
        dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')

        vocab = Vocabulary(min_freq=2).from_dataset(dataset, field_name='words')
        vocab.index_dataset(dataset, field_name='words', new_field_name='words')

        dataset.set_input('words', 'seq_len')
        dataset.set_target('target')

        train_dev_data, test_data = dataset.split(0.1)
        train_data, dev_data = train_dev_data.split(0.1)

        return vocab, train_data, dev_data, test_data
Example #15
    def dataset(self):
        d = DataSet()
        for key in self.data:
            for ins in self.data[key]['dataset']['chars']:
                ins = Instance(chars=ins)
                d.append(ins)

        return d
Example #16
 def test_demo(self):
     # related to issue https://github.com/fastnlp/fastNLP/issues/324#issue-705081091
     from fastNLP import DataSet, Instance
     from fastNLP.io import DataBundle
     data_bundle = DataBundle()
     ds = DataSet()
     ds.append(Instance(raw_words="截流 进入 最后 冲刺 ( 附 图片 1 张 )"))
     data_bundle.set_dataset(ds, name='train')
     data_bundle = CWSPipe().process(data_bundle)
     self.assertFalse('<' in data_bundle.get_vocab('chars'))
Example #17
def combine_data_set(ds_a, ds_b):
    ds = DataSet()
    for ins in ds_a:
        ds.append(ins)
    for ins in ds_b:
        ds.append(ins)
    for k in ds_a.field_arrays.keys():
        ds.set_input(k, flag=ds_a.field_arrays[k].is_input)
        ds.set_target(k, flag=ds_a.field_arrays[k].is_target)
    return ds
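A small usage sketch (not from the original): two toy DataSets with the same fields are merged, and the input/target flags are copied over from the first one.

from fastNLP import DataSet, Instance

ds_a, ds_b = DataSet(), DataSet()
for i in range(2):
    ds_a.append(Instance(words=[str(i)], target=i))
    ds_b.append(Instance(words=[str(i + 2)], target=i + 2))
ds_a.set_input('words')
ds_a.set_target('target')

merged = combine_data_set(ds_a, ds_b)
print(len(merged))  # 4 instances, with ds_a's input/target flags preserved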
Example #18
def process_data_1(embed_file, cws_train):
    embed, vocab = EmbedLoader.load_without_vocab(embed_file)
    time.sleep(1)  # 测试是否通过读取cache获得结果
    with open(cws_train, 'r', encoding='utf-8') as f:
        d = DataSet()
        for line in f:
            line = line.strip()
            if len(line) > 0:
                d.append(Instance(raw=line))
    return embed, vocab, d
Example #19
def readdata():
    global target_len
    min_count = 10
    #categories = ['comp.os.ms-windows.misc', 'rec.motorcycles', 'sci.space', 'talk.politics.misc', ]
    dataset_train = fetch_20newsgroups(subset='train', data_home='../../..')
    dataset_test = fetch_20newsgroups(subset='test', data_home='../../..')

    data = dataset_train.data
    target = dataset_train.target
    target_len = len(dataset_train.target_names)
    train_data = DataSet()
    padding = 0
    for i in range(len(data)):
        data_t = re.sub(r"\d+|\s+|/", " ", data[i])
        temp = [word.strip(string.punctuation).lower() for word in data_t.split() if word.strip(string.punctuation) != '']
        train_data.append(Instance(raw = data[i], label = int(target[i]), words = temp))
        if len(temp) > padding:
            padding = len(temp)
    train_data.apply(lambda x: x['raw'].lower(), new_field_name='raw')

    data = dataset_test.data
    target = dataset_test.target
    test_data = DataSet()
    padding = 0
    for i in range(len(data)):
        data_t = re.sub(r"\d+|\s+|/", " ", data[i])
        temp = [word.strip(string.punctuation).lower() for word in data_t.split() if word.strip(string.punctuation) != '']
        test_data.append(Instance(raw = data[i], label = int(target[i]), words = temp))
        if len(temp) > padding:
            padding = len(temp)
    test_data.apply(lambda x: x['raw'].lower(), new_field_name='raw')

    train_data.apply(lambda x: len(x['words']), new_field_name='len')
    test_data.apply(lambda x: len(x['words']), new_field_name='len')

    vocab = Vocabulary(min_freq=10)
    train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
    vocab.build_vocab()
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='seq')
    test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='seq')
    train_data.rename_field('seq', Const.INPUT)
    train_data.rename_field('len', Const.INPUT_LEN)
    train_data.rename_field('label', Const.TARGET)

    test_data.rename_field('seq', Const.INPUT)
    test_data.rename_field('len', Const.INPUT_LEN)
    test_data.rename_field('label', Const.TARGET)

    test_data.set_input(Const.INPUT, Const.INPUT_LEN)
    test_data.set_target(Const.TARGET)
    train_data.set_input(Const.INPUT, Const.INPUT_LEN)
    train_data.set_target(Const.TARGET)

    test_data, dev_data = test_data.split(0.5)
    return train_data, dev_data, test_data, vocab
Example #20
def construct_dataset(sentences):
    """Construct a data set from a list of sentences.

    :param sentences: list of list of str
    :return dataset: a DataSet object
    """
    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        dataset.append(instance)
    return dataset
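A brief usage sketch (not part of the original): following the docstring, each element of sentences is a tokenized sentence (a list of str) and is stored under the raw_sentence field of one Instance; the sample sentences are made up.

from fastNLP import DataSet, Instance

sentences = [['the', 'quick', 'brown', 'fox'],
             ['jumps', 'over', 'the', 'lazy', 'dog']]
ds = construct_dataset(sentences)
print(len(ds), ds[0]['raw_sentence'])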
Example #21
 def _load(self, path: str = None):
     ds = DataSet()
     with codecs.open(path, mode='r', encoding='utf-8') as fr:
         for line in fr:
             line = line.strip()
             if len(line) == 0:
                 continue
             sep_index = line.index('\t')
             raw_chars = line[sep_index + 1:]
             target = line[:sep_index]
             if raw_chars:
                 ds.append(Instance(raw_chars=raw_chars, target=target))
     return ds
Example #22
    def __init__(self, path=".data/yelp", dataset="yelp", batch_size=32):

        if dataset == "yelp":
            dataset = DataSet()

            for db_set in ['train']:
                text_file = os.path.join(path, 'sentiment.' + db_set + '.text')
                label_file = os.path.join(path,
                                          'sentiment.' + db_set + '.labels')
                with io.open(text_file, 'r', encoding="utf-8") as tf, io.open(
                        label_file, 'r', encoding="utf-8") as lf:
                    for text in tf:
                        label = lf.readline()
                        dataset.append(Instance(text=text, label=label))

            dataset.apply(lambda x: x['text'].lower(), new_field_name='text')
            dataset.apply(
                lambda x: ['<start>'] + x['text'].split() + ['<eos>'],
                new_field_name='words')
            dataset.drop(lambda x: len(x['words']) > 1 + 15 + 1)
            dataset.apply(lambda x: x['words'] + ['<pad>'] *
                          (17 - len(x['words'])),
                          new_field_name='words')
            dataset.apply(lambda x: int(x['label']),
                          new_field_name='label_seq',
                          is_target=True)

            _train_data, _test_data = dataset.split(0.3)

            _vocab = Vocabulary(min_freq=2)
            _train_data.apply(
                lambda x: [_vocab.add(word) for word in x['words']])
            _vocab.build_vocab()

            _train_data.apply(
                lambda x: [_vocab.to_index(word) for word in x['words']],
                new_field_name='word_seq',
                is_input=True)
            _test_data.apply(
                lambda x: [_vocab.to_index(word) for word in x['words']],
                new_field_name='word_seq',
                is_input=True)

        self.train_data = _train_data
        self.test_data = _test_data
        self.vocab = _vocab
        self.batch_size = batch_size
        self.train_iter = iter(
            Batch(dataset=self.train_data,
                  batch_size=self.batch_size,
                  sampler=SequentialSampler()))
Example #23
def get_data():
    s = ''
    dataset = DataSet()
    for line in open('../handout/tangshi.txt'):
        if (line == '\n'):
            dataset.append(Instance(raw_sentence=s, label='0'))
            #print(s)
            s = ''
        else:
            s += line.replace('\n', '')

    dataset.apply(add_end, new_field_name='raw_sentence')
    dataset.apply(split_sent, new_field_name='words')
    return dataset
Example #24
 def _load(self, path: str) -> DataSet:
     ds = DataSet()
     all_count = 0
     csv_reader = csv.reader(open(path, encoding='utf-8'), delimiter='\t')
     skip_row = 0
     for idx, row in enumerate(csv_reader):
         if idx <= skip_row:
             continue
         target = row[1]
         words = self.tokenizer(row[0])
         ds.append(Instance(words=words, target=target))
         all_count += 1
     print("all count:", all_count)
     return ds
Example #25
def preprocess(input):
    data = input.data
    target = input.target
    dataset = DataSet()
    for i in range(len(data)):
        data_tmp = data[i]
        for c in string.whitespace:
            data_tmp = data_tmp.replace(c, ' ')
        for c in string.punctuation:
            data_tmp = data_tmp.replace(c, '')
        data_tmp = data_tmp.lower().split()
        # print(data_tmp)
        dataset.append(Instance(sentence=data_tmp, target=int(target[i])))
    dataset.apply(lambda x: len(x['sentence']), new_field_name='seq_len')
    return dataset
Example #26
File: loader.py Project: yhcc/BertForRD
def read_dataset(path, lower, word_idx=1, def_idx=-1):
    ds = DataSet()
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split('\t')
                if lower:
                    ins = Instance(word=parts[word_idx].lower(),
                                   definition=parts[def_idx].lower())
                else:
                    ins = Instance(word=parts[word_idx],
                                   definition=parts[def_idx])
                ds.append(ins)
    return ds
Example #27
    def _load(self, path):
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                target = parts[0]
                words = parts[1].lower().split()
                dataset.append(Instance(words=words, target=target))
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")

        return dataset
Example #28
 def _load(self, path):
     """
     :param str path: path to the data file
     :return: a :class:`~fastNLP.DataSet` object
     """
     datalist = []
     with open(path, 'r', encoding='utf-8') as f:
         datas = []
         for l in f:
             datas.extend([(s, self.tag_v[t])
                           for s, t in self._get_one(l, self.subtree)])
     ds = DataSet()
     for words, tag in datas:
         ds.append(Instance(words=words, target=tag))
     return ds
Example #29
def process_poems(file_name, sentence_len, vocab_size):
    sentences = []
    with open(file_name, "r", encoding='utf-8', ) as f:
        for line in f.readlines():
            try:
                line = line.strip()
                if line:
                    # content = line.replace(' ', '').replace(',','').replace('。','')
                    content = line.replace(' ', '')  # punctuation is kept
                    if len(content) < 10 or len(content) > sentence_len:
                        continue
                    # print(content)
                    content = content + end_token
                    sentences.append(content)
            except ValueError as e:
                pass

    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        instance['target'] = sentence[1:] + sentence[-1]
        dataset.append(instance)

    dataset.set_input("raw_sentence")
    dataset.set_target("target")

    # for iter in dataset:
    #     print(iter)
    print("dataset_size:", len(dataset))

    train_data, dev_data = dataset.split(0.2)
    train_data.rename_field("raw_sentence", "sentence")
    dev_data.rename_field("raw_sentence", "sentence")
    vocab = Vocabulary(max_size=vocab_size, min_freq=2, unknown='<unk>', padding='<pad>')

    # build the vocabulary
    train_data.apply(lambda x: [vocab.add(word) for word in x['sentence']])
    vocab.build_vocab()
    print("vocabulary_size:", len(vocab))

    # index sentences with the vocabulary
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['sentence']], new_field_name='sentence')
    train_data.apply(lambda x: [vocab.to_index(word) for word in x['target']], new_field_name='target')
    dev_data.apply(lambda x: [vocab.to_index(word) for word in x['sentence']], new_field_name='sentence')
    dev_data.apply(lambda x: [vocab.to_index(word) for word in x['target']], new_field_name='target')

    return train_data, dev_data, vocab
Example #30
File: yelpLoader.py Project: zxlzr/fastNLP
 def _load(self, path):
     ds = DataSet()
     csv_reader = csv.reader(open(path, encoding='utf-8'))
     all_count = 0
     real_count = 0
     for row in csv_reader:
         all_count += 1
         if len(row) == 2:
             target = self.tag_v[row[0] + ".0"]
             words = clean_str(row[1], self.tokenizer, self.lower)
             if len(words) != 0:
                 ds.append(Instance(words=words, target=target))
                 real_count += 1
     print("all count:", all_count)
     print("real count:", real_count)
     return ds