Code example #1
    def test_save_and_load(self):
        fp = 'vocab_save_test.txt'
        try:
            # check that word2idx is unchanged and no_create_entry status is preserved
            words = list('abcdefaddfdkjfe')
            no_create_entry = list('12342331')
            unk = '[UNK]'
            vocab = Vocabulary(unknown=unk, max_size=500)

            vocab.add_word_lst(words)
            vocab.add_word_lst(no_create_entry, no_create_entry=True)
            vocab.save(fp)

            new_vocab = Vocabulary.load(fp)

            for word, index in vocab:
                self.assertEqual(new_vocab.to_index(word), index)
            for word in no_create_entry:
                self.assertTrue(new_vocab._is_word_no_create_entry(word))
            for word in words:
                self.assertFalse(new_vocab._is_word_no_create_entry(word))
            for idx in range(len(vocab)):
                self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
            self.assertEqual(vocab.unknown, new_vocab.unknown)

            # test a vocab whose padding and unknown are both None
            vocab = Vocabulary(padding=None, unknown=None)
            words = list('abcdefaddfdkjfe')
            no_create_entry = list('12342331')

            vocab.add_word_lst(words)
            vocab.add_word_lst(no_create_entry, no_create_entry=True)
            vocab.save(fp)

            new_vocab = Vocabulary.load(fp)

            for word, index in vocab:
                self.assertEqual(new_vocab.to_index(word), index)
            for word in no_create_entry:
                self.assertTrue(new_vocab._is_word_no_create_entry(word))
            for word in words:
                self.assertFalse(new_vocab._is_word_no_create_entry(word))
            for idx in range(len(vocab)):
                self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
            self.assertEqual(vocab.unknown, new_vocab.unknown)

        finally:
            import os
            if os.path.exists(fp):
                os.remove(fp)
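Outside the test class, the same round trip can be exercised directly. A minimal sketch, assuming fastNLP exposes Vocabulary at the package top level (as the test above relies on) and using a throwaway file name:

from fastNLP import Vocabulary

vocab = Vocabulary(unknown='[UNK]', max_size=500)
vocab.add_word_lst(list('abcdef'))                      # regular entries
vocab.add_word_lst(list('1234'), no_create_entry=True)  # no-create-entry words
vocab.save('vocab_demo.txt')                            # hypothetical file name

restored = Vocabulary.load('vocab_demo.txt')
assert restored.to_index('a') == vocab.to_index('a')
assert restored._is_word_no_create_entry('1')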
Code example #2
        task.test_set.apply(lambda x: [vocab.to_index(w) for w in x['words']],
                            new_field_name='words_idx')

        task.train_set.set_input('task_id', 'words_idx', flag=True)
        task.train_set.set_target('label', flag=True)

        task.dev_set.set_input('task_id', 'words_idx', flag=True)
        task.dev_set.set_target('label', flag=True)

        task.test_set.set_input('task_id', 'words_idx', flag=True)
        task.test_set.set_target('label', flag=True)

    logger.info('Finished. Dumping vocabulary to data/vocab.txt')
    with open('data/vocab.txt', mode='w', encoding='utf-8') as f:
        for i in range(len(vocab)):
            f.write(vocab.to_word(i) + '\n')

    logger.info('Testing data...')
    for task in task_lst:
        logger.info(str(task.task_id) + ' ' + task.task_name)
        logger.info(task.train_set[0])
        logger.info(task.dev_set[0])
        logger.info(task.test_set[0])

    logger.info('Dumping data...')
    data = {'task_lst': task_lst}
    with open('data/data.pkl', 'wb') as save_file:
        pickle.dump(data, save_file)
    logger.info('Finished. Looking up word embeddings...')
    embed_path = '/remote-home/txsun/data/word-embedding/glove/glove.840B.300d.txt'
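For reference, the artifacts written above can be read back with the standard library alone (a minimal sketch using the same paths):

import pickle

# Rebuild the word list from the vocabulary dump (one word per line).
with open('data/vocab.txt', encoding='utf-8') as f:
    words = [line.rstrip('\n') for line in f]

# Reload the pickled task list.
with open('data/data.pkl', 'rb') as f:
    task_lst = pickle.load(f)['task_lst']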
Code example #3
    def test_to_word(self):
        vocab = Vocabulary()
        vocab.update(text)  # `text` is a token list defined elsewhere in the test module
        self.assertEqual(
            text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
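The snippet assumes a module-level `text` fixture; any list of string tokens satisfies the round-trip assertion. A hypothetical value:

text = ["this", "is", "a", "test", "this", "is", "only", "a", "test"]  # hypothetical fixture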
Code example #4
def prepare_dataInfo(mode,
                     vocab_size,
                     config,
                     train_data_path=None,
                     dev_data_path=None,
                     test_data_path=None):
    def sent_to_words(sents):
        result = []
        for sent in sents:
            result.extend([
                word.strip() for word in sent.split(" ")
                if len(word.strip()) != 0
            ])
        return result

    # dataloader = Cnn_dailymailLodaer()
    # for JSON input files; each JSON object must have the fields 'text' and 'summary', both already tokenized
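    # e.g. one JSON object per line (hypothetical sample): {"text": ["first sentence .", "second sentence ."], "summary": ["reference summary ."]}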
    dataloader = JsonLoader(fields={
        "text": "words",
        "summary": "abstract_sentences"
    })
    if mode == 'train':
        if train_data_path is None or dev_data_path is None:
            print("training with no train data path or dev data path! ")
        paths = {"train": train_data_path, "dev": dev_data_path}
    else:
        if test_data_path is None:
            print("testing with no test data path ! ")
        paths = {"train": train_data_path, "test": test_data_path}
    # dataInfo = dataloader.process(paths, vocab_path, vocab_size)
    print("=" * 10)
    print(paths)
    dataInfo = dataloader.load(paths)
    for key, _dataset in dataInfo.datasets.items():
        _dataset.apply(lambda ins: " ".join(ins['words']),
                       new_field_name='article')
        _dataset.apply(lambda ins: sent_to_words(ins['words']),
                       new_field_name='words')
        _dataset.apply(
            lambda ins: sent_tokenize(" ".join(ins['abstract_sentences'])),
            new_field_name='abstract_sentences')

    vocab = Vocabulary(max_size=vocab_size - 2,
                       padding=PAD_TOKEN,
                       unknown=UNKNOWN_TOKEN)
    vocab.from_dataset(dataInfo.datasets['train'], field_name='words')
    vocab.add(START_DECODING)
    vocab.add(STOP_DECODING)
    print(vocab.to_word(0))
    print(len(vocab))
    assert vocab_size == len(vocab), "vocab_size error!!!"
    dataInfo.set_vocab(vocab, "train")

    for key, dataset in dataInfo.datasets.items():
        data_dict = {
            "enc_len": [],
            "enc_input": [],
            "dec_input": [],
            "target": [],
            "dec_len": [],
            "article_oovs": [],
            "enc_input_extend_vocab": []
        }

        for instance in dataset:
            article = instance["article"]
            abstract_sentences = instance["abstract_sentences"]

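            # getting_full_info (defined elsewhere) is assumed to build the per-example
            # encoder/decoder fields that are collected into data_dict below.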
            enc_len, enc_input, dec_input, target, dec_len, article_oovs, enc_input_extend_vocab = getting_full_info(
                article, abstract_sentences, dataInfo.vocabs['train'], config)

            data_dict["enc_len"].append(enc_len)
            data_dict["enc_input"].append(enc_input)
            data_dict["dec_input"].append(dec_input)
            data_dict["target"].append(target)
            data_dict["dec_len"].append(dec_len)
            data_dict["article_oovs"].append(article_oovs)
            data_dict["enc_input_extend_vocab"].append(enc_input_extend_vocab)

        logger.info("-----prepare_dataInfo for dataset " + key + "-----")
        logger.info(
            str(len(data_dict["enc_len"])) + " " +
            str(len(data_dict["enc_input"])) + " " +
            str(len(data_dict["dec_input"])) + " " +
            str(len(data_dict["target"])) + " " +
            str(len(data_dict["dec_len"])) + " " +
            str(len(data_dict["article_oovs"])) + " " +
            str(len(data_dict["enc_input_extend_vocab"])))
        dataset.add_field("enc_len", data_dict["enc_len"])
        dataset.add_field("enc_input", data_dict["enc_input"])
        dataset.add_field("dec_input", data_dict["dec_input"])
        dataset.add_field("target", data_dict["target"])
        dataset.add_field("dec_len", data_dict["dec_len"])
        dataset.add_field("article_oovs", data_dict["article_oovs"])
        dataset.add_field("enc_input_extend_vocab",
                          data_dict["enc_input_extend_vocab"])

        dataset.set_input("enc_len", "enc_input", "dec_input", "dec_len",
                          "article_oovs", "enc_input_extend_vocab")
        dataset.set_target("target", "article_oovs", "abstract_sentences")
    '''
    for name, dataset in dataInfo.datasets.items():
        for field_name in dataset.get_field_names():
            dataset.apply_field(convert_list_to_ndarray, field_name=field_name, new_field_name=field_name)
    '''
    return dataInfo