Example 1
 def test_len(self):
     vocab = Vocabulary(max_size=None,
                        min_freq=None,
                        unknown=None,
                        padding=None)
     vocab.update(text)
     self.assertEqual(len(vocab), len(counter))
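Most of the short unit tests in this collection (Examples 1, 5, 10, 11, 12, 20, 22 and 23) rely on module-level fixtures named text and counter that are not shown. A minimal sketch of such fixtures, with made-up contents and an assumed import path for Vocabulary, could look like this:

from collections import Counter

from fastNLP import Vocabulary  # import path assumed; older releases expose fastNLP.core.vocabulary.Vocabulary

# Hypothetical fixtures standing in for the ones defined elsewhere in the test module.
text = ["well", "this", "is", "a", "small", "test", "corpus", "this", "is", "well"]
counter = Counter(text)

vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text)
# With unknown and padding disabled, the vocabulary holds exactly the distinct tokens.
assert len(vocab) == len(counter)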
Example 2
    def test_vocab(self):
        import _pickle as pickle
        import os
        vocab = Vocabulary()
        filename = 'vocab'
        vocab.update(filename)
        vocab.update([filename, ['a'], [['b']], ['c']])
        idx = vocab[filename]
        before_pic = (vocab.to_word(idx), vocab[filename])

        with open(filename, 'wb') as f:
            pickle.dump(vocab, f)
        with open(filename, 'rb') as f:
            vocab = pickle.load(f)
        os.remove(filename)

        vocab.build_reverse_vocab()
        after_pic = (vocab.to_word(idx), vocab[filename])
        TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
        TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
        TRUE_IDXDICT = {
            0: '<pad>',
            1: '<unk>',
            2: '<reserved-2>',
            3: '<reserved-3>',
            4: '<reserved-4>',
            5: 'vocab',
            6: 'a',
            7: 'b',
            8: 'c'
        }
        self.assertEqual(before_pic, after_pic)
        self.assertDictEqual(TRUE_DICT, vocab.word2idx)
        self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)
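The point of the test above is that pickling preserves the word-to-index mapping and that the reverse index can be rebuilt after loading. A trimmed-down round trip under the same assumptions, serializing in memory instead of through a temporary file:

import pickle

from fastNLP import Vocabulary  # import path assumed

vocab = Vocabulary()
vocab.update(["vocab", "a", "b", "c"])
idx = vocab["vocab"]                    # the first lookup builds the index

restored = pickle.loads(pickle.dumps(vocab))
restored.build_reverse_vocab()          # recreate the idx -> word mapping after loading

assert restored.to_word(idx) == "vocab"
assert restored["vocab"] == idx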
Example 3
def mock_cws():
    os.makedirs("mock", exist_ok=True)
    text = ["这是最好的基于深度学习的中文分词系统。", "大王叫我来巡山。", "我党多年来致力于改善人民生活水平。"]

    word2id = Vocabulary()
    word_list = [ch for ch in "".join(text)]
    word2id.update(word_list)
    save_pickle(word2id, "./mock/", "word2id.pkl")

    class2id = Vocabulary(need_default=False)
    label_list = ['B', 'M', 'E', 'S']
    class2id.update(label_list)
    save_pickle(class2id, "./mock/", "label2id.pkl")

    model_args = {
        "vocab_size": len(word2id),
        "word_emb_dim": 50,
        "rnn_hidden_units": 50,
        "num_classes": len(class2id)
    }
    config_file = """
    [test_section]
    vocab_size = {}
    word_emb_dim = 50
    rnn_hidden_units = 50
    num_classes = {}
    """.format(len(word2id), len(class2id))
    with open("mock/test.cfg", "w", encoding="utf-8") as f:
        f.write(config_file)

    model = AdvSeqLabel(model_args)
    ModelSaver("mock/cws_basic_model_v_0.pkl").save_pytorch(model)
Example 4
def mock_pos_tag():
    os.makedirs("mock", exist_ok=True)
    text = ["这是最好的基于深度学习的中文分词系统。", "大王叫我来巡山。", "我党多年来致力于改善人民生活水平。"]

    vocab = Vocabulary()
    word_list = [ch for ch in "".join(text)]
    vocab.update(word_list)
    save_pickle(vocab, "./mock/", "word2id.pkl")

    idx2label = Vocabulary(need_default=False)
    label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
    idx2label.update(label_list)
    save_pickle(idx2label, "./mock/", "label2id.pkl")

    model_args = {
        "vocab_size": len(vocab),
        "word_emb_dim": 50,
        "rnn_hidden_units": 50,
        "num_classes": len(idx2label)
    }
    config_file = """
        [test_section]
        vocab_size = {}
        word_emb_dim = 50
        rnn_hidden_units = 50
        num_classes = {}
        """.format(len(vocab), len(idx2label))
    with open("mock/test.cfg", "w", encoding="utf-8") as f:
        f.write(config_file)

    model = AdvSeqLabel(model_args)
    ModelSaver("mock/pos_tag_model_v_0.pkl").save_pytorch(model)
Example 5
 def test_contains(self):
     vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
     vocab.update(text)
     self.assertTrue(text[-1] in vocab)
     self.assertFalse("~!@#" in vocab)
     self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
     self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))
Example 6
    def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl):
        """Load the pre-trained embedding and combine with the given dictionary.

        :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
        :param emb_file: str, the pre-trained embedding file path.
        :param emb_type: str, the pre-trained embedding format, support glove now
        :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
        :param emb_pkl: str, the embedding pickle file.
        :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
                vocab: input vocab or vocab built by pre-train
        TODO: fragile code
        """
        # If the embedding pickle exists, load it and return.
        # if os.path.exists(emb_pkl):
        #     with open(emb_pkl, "rb") as f:
        #         embedding_tensor, vocab = _pickle.load(f)
        #     return embedding_tensor, vocab
        # Otherwise, load the pre-trained embedding.
        pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
        if vocab is None:
            # build vocabulary from pre-trained embedding
            vocab = Vocabulary()
            for w in pretrain.keys():
                vocab.update(w)
        embedding_tensor = torch.randn(len(vocab), emb_dim)
        for w, v in pretrain.items():
            if len(v.shape) > 1 or emb_dim != v.shape[0]:
                raise ValueError('pre-trained embedding dim is {}, mismatching the required {}'.format(v.shape, (emb_dim,)))
            if vocab.has_word(w):
                embedding_tensor[vocab[w]] = v

        # save and return the result
        # with open(emb_pkl, "wb") as f:
        #     _pickle.dump((embedding_tensor, vocab), f)
        return embedding_tensor, vocab
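The loader follows a "random init, then overwrite" pattern: the whole embedding matrix starts as Gaussian noise, and only the rows of words found in the pre-trained file are replaced. A toy, self-contained illustration of that pattern (the in-memory pretrain dict is made up; the real code reads it from emb_file):

import torch

from fastNLP import Vocabulary  # import path assumed

emb_dim = 4
pretrain = {"the": torch.ones(emb_dim), "of": torch.zeros(emb_dim)}  # toy "pre-trained" vectors

vocab = Vocabulary()
vocab.update(["the", "of", "completely_unseen_word"])

embedding_tensor = torch.randn(len(vocab), emb_dim)
for w, v in pretrain.items():
    if vocab.has_word(w):
        embedding_tensor[vocab[w]] = v  # overwrite the random row with the pre-trained vector
# "completely_unseen_word" (and any special tokens) keep their random initialization.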
Example 7
    def test_index(self):
        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
        vocab.update(text)
        res = [vocab[w] for w in set(text)]
        self.assertEqual(len(res), len(set(res)))

        res = [vocab.to_index(w) for w in set(text)]
        self.assertEqual(len(res), len(set(res)))
Example 8
class SeqLabelDataSet(DataSet):
    def __init__(self, instances=None, load_func=POSDataSetLoader().load):
        super(SeqLabelDataSet, self).__init__(name="",
                                              instances=instances,
                                              load_func=load_func)
        self.word_vocab = Vocabulary()
        self.label_vocab = Vocabulary()

    def convert(self, data):
        """Convert lists of strings into Instances with Fields.

        :param data: 3-level lists. Entries are strings.
        """
        bar = ProgressBar(total=len(data))
        for example in data:
            word_seq, label_seq = example[0], example[1]
            # list, list
            self.word_vocab.update(word_seq)
            self.label_vocab.update(label_seq)
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            y = TextField(label_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("truth", y)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
            bar.move()
        self.index_field("word_seq", self.word_vocab)
        self.index_field("truth", self.label_vocab)
        # no need to index "word_seq_origin_len"

    def convert_with_vocabs(self, data, vocabs):
        for example in data:
            word_seq, label_seq = example[0], example[1]
            # list, list
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            y = TextField(label_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("truth", y)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        self.index_field("truth", vocabs["label_vocab"])
        # no need to index "word_seq_origin_len"

    def convert_for_infer(self, data, vocabs):
        for word_seq in data:
            # list
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
Example 9
    def test_vocab(self):
        vocab = Vocabulary()
        word_list = "this is a word list".split()
        vocab.update(word_list)

        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = ConfusionMatrixMetric(vocab=vocab)
        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())
Example 10
 def test_contains(self):
     vocab = Vocabulary(max_size=None,
                        min_freq=None,
                        unknown=None,
                        padding=None)
     vocab.update(text)
     self.assertTrue(text[-1] in vocab)
     self.assertFalse("~!@#" in vocab)
     self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
     self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))
Example 11
    def test_additional_update(self):
        vocab = Vocabulary(max_size=None, min_freq=None)
        vocab.update(text)

        _ = vocab["well"]
        self.assertEqual(vocab.rebuild, False)

        vocab.add("hahaha")
        self.assertEqual(vocab.rebuild, True)

        _ = vocab["hahaha"]
        self.assertEqual(vocab.rebuild, False)
        self.assertTrue("hahaha" in vocab)
Example 12
    def test_warning(self):
        vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
        vocab.update(text)
        self.assertEqual(vocab.rebuild, True)
        print(len(vocab))
        self.assertEqual(vocab.rebuild, False)

        vocab.update([
            "hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg",
            "feqfw"
        ])
        # this will print a warning
        self.assertEqual(vocab.rebuild, True)
Example 13
class VocabProcessor(Processor):
    def __init__(self, field_name):
        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary()

    def process(self, *datasets):
        for dataset in datasets:
            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
            for ins in dataset:
                tokens = ins[self.field_name]
                self.vocab.update(tokens)

    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab
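Both VocabProcessor variants (this one and the apply-based one in Example 17) follow the same two-step contract: process(*datasets) accumulates token counts, and get_vocab() freezes them by calling build_vocab(). A hypothetical usage sketch; the dict-style DataSet construction is an assumption about the fastNLP API of this era:

from fastNLP import DataSet  # import path assumed

train = DataSet({"words": [["I", "like", "apples"], ["I", "like", "pears"]]})
dev = DataSet({"words": [["you", "like", "apples"]]})

proc = VocabProcessor(field_name="words")
proc.process(train, dev)     # count tokens from every dataset passed in
vocab = proc.get_vocab()     # build_vocab() is called here, freezing the counts
print(len(vocab))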
Example 14
class TextClassifyDataSet(DataSet):
    def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
        super(TextClassifyDataSet, self).__init__(name="",
                                                  instances=instances,
                                                  load_func=load_func)
        self.word_vocab = Vocabulary()
        self.label_vocab = Vocabulary(need_default=False)

    def convert(self, data):
        for example in data:
            word_seq, label = example[0], example[1]
            # list, str
            self.word_vocab.update(word_seq)
            self.label_vocab.update(label)
            x = TextField(word_seq, is_target=False)
            y = LabelField(label, is_target=True)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("label", y)
            self.append(instance)
        self.index_field("word_seq", self.word_vocab)
        self.index_field("label", self.label_vocab)

    def convert_with_vocabs(self, data, vocabs):
        for example in data:
            word_seq, label = example[0], example[1]
            # list, str
            x = TextField(word_seq, is_target=False)
            y = LabelField(label, is_target=True)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("label", y)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        self.index_field("label", vocabs["label_vocab"])

    def convert_for_infer(self, data, vocabs):
        for word_seq in data:
            # list
            x = TextField(word_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
Example 15
def add_words_field_2_databundle(data_bundle):
    train_cws_field = "data/wb_cws/train_cws_word.txt"
    dev_cws_field = "data/wb_cws/dev_cws_word.txt"
    test_cws_field = "data/wb_cws/test_cws_word.txt"

    train_field = _read_txt(train_cws_field)
    dev_field = _read_txt(dev_cws_field)
    test_field = _read_txt(test_cws_field)
    #
    #
    data_bundle.get_dataset('train').add_field(field_name="raw_words",
                                               fields=train_field)
    data_bundle.get_dataset('dev').add_field(field_name="raw_words",
                                             fields=dev_field)
    data_bundle.get_dataset('test').add_field(field_name="raw_words",
                                              fields=test_field)

    # Add the word vocabulary
    words_vocab = Vocabulary()
    word_list = get_corpus_words(train_cws_field, dev_cws_field,
                                 test_cws_field)
    words_vocab.update(word_list)
    data_bundle.set_vocab(words_vocab, field_name="words")

    # Convert raw_words into word ids
    for dataset in ["train", "dev", "test"]:
        raw_words = list(data_bundle.get_dataset(dataset)["raw_words"])
        words_ids = []
        for words in raw_words:
            words_id = []
            for word in words:
                words_id.append(words_vocab.to_index(word))
            words_ids.append(words_id)
        data_bundle.get_dataset(dataset).add_field(field_name="words",
                                                   fields=words_ids)
    data_bundle.set_input('words')
    data_bundle.set_ignore_type('words', flag=False)
    data_bundle.set_pad_val("words", 0)
    return data_bundle
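As a readability aside, the nested loop that turns raw_words into ids is a plain per-token lookup; the same result can be written as a nested comprehension (a purely cosmetic, hypothetical refactor):

words_ids = [[words_vocab.to_index(word) for word in words] for words in raw_words]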
Example 16
def mock_text_classify():
    os.makedirs("mock", exist_ok=True)
    text = [
        "世界物联网大会明日在京召开龙头股启动在即", "乌鲁木齐市新增一处城市中心旅游目的地",
        "朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"
    ]
    vocab = Vocabulary()
    word_list = [ch for ch in "".join(text)]
    vocab.update(word_list)
    save_pickle(vocab, "./mock/", "word2id.pkl")

    idx2label = Vocabulary(need_default=False)
    label_list = [
        'class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'
    ]
    idx2label.update(label_list)
    save_pickle(idx2label, "./mock/", "label2id.pkl")

    model_args = {
        "vocab_size": len(vocab),
        "word_emb_dim": 50,
        "rnn_hidden_units": 50,
        "num_classes": len(idx2label)
    }
    config_file = """
            [test_section]
            vocab_size = {}
            word_emb_dim = 50
            rnn_hidden_units = 50
            num_classes = {}
            """.format(len(vocab), len(idx2label))
    with open("mock/test.cfg", "w", encoding="utf-8") as f:
        f.write(config_file)

    model = CNNText(model_args)
    ModelSaver("mock/text_class_model_v0.pkl").save_pytorch(model)
Example 17
class VocabProcessor(Processor):
    def __init__(self, field_name, min_freq=1, max_size=None):

        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size)

    def process(self, *datasets):
        for dataset in datasets:
            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
            dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))

    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab

    def get_vocab_size(self):
        return len(self.vocab)
Example 18
class Preprocessor(object):
    """Preprocessors are responsible for converting data of strings into data of indices.
    During the pre-processing, the following pickle files will be built:

        - "word2id.pkl", a Vocabulary object, mapping words to indices.
        - "class2id.pkl", a Vocabulary object, mapping labels to indices.
        - "data_train.pkl", a DataSet object for training
        - "data_dev.pkl", a DataSet object for validation, if train_dev_split > 0.
        - "data_test.pkl", a DataSet object for testing, if test_data is not None.

    These pickle files are expected to be saved in the given pickle directory once they are constructed.
    Preprocessors will check if those files are already in the directory and will reuse them in future calls.
    """
    def __init__(self,
                 label_is_seq=False,
                 share_vocab=False,
                 add_char_field=False):
        """

        :param label_is_seq: bool, whether the label is a sequence. If True, the label vocabulary will reserve
                several special tokens for sequence processing.
        :param share_vocab: bool, whether word sequence and label sequence share the same vocabulary. Typically, this
                is only available when label_is_seq is True. Default: False.
        :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
        """
        print("Preprocessor is about to deprecate. Please use DataSet class.")
        self.data_vocab = Vocabulary()
        if label_is_seq is True:
            if share_vocab is True:
                self.label_vocab = self.data_vocab
            else:
                self.label_vocab = Vocabulary()
        else:
            self.label_vocab = Vocabulary(need_default=False)

        self.character_vocab = Vocabulary(need_default=False)
        self.add_char_field = add_char_field

    @property
    def vocab_size(self):
        return len(self.data_vocab)

    @property
    def num_classes(self):
        return len(self.label_vocab)

    @property
    def char_vocab_size(self):
        if self.character_vocab is None:
            self.build_char_dict()
        return len(self.character_vocab)

    def run(self,
            train_dev_data,
            test_data=None,
            pickle_path="./",
            train_dev_split=0,
            cross_val=False,
            n_fold=10):
        """Main pre-processing pipeline.

        :param train_dev_data: three-level list, with either single label or multiple labels in a sample.
        :param test_data: three-level list, with either single label or multiple labels in a sample. (optional)
        :param pickle_path: str, the path to save the pickle files.
        :param train_dev_split: float, in [0, 1]. The fraction of the training data used as the validation set.
        :param cross_val: bool, whether to do cross validation.
        :param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True.
        :return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset.
                If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
                is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
        """
        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(
                pickle_path, "class2id.pkl"):
            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
        else:
            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")

        self.build_reverse_dict()

        train_set = []
        dev_set = []
        if not cross_val:
            if not pickle_exist(pickle_path, "data_train.pkl"):
                if train_dev_split > 0 and not pickle_exist(
                        pickle_path, "data_dev.pkl"):
                    split = int(len(train_dev_data) * train_dev_split)
                    data_dev = train_dev_data[:split]
                    data_train = train_dev_data[split:]
                    train_set = self.convert_to_dataset(
                        data_train, self.data_vocab, self.label_vocab)
                    dev_set = self.convert_to_dataset(data_dev,
                                                      self.data_vocab,
                                                      self.label_vocab)

                    save_pickle(dev_set, pickle_path, "data_dev.pkl")
                    print("{} of the training data is split for validation. ".
                          format(train_dev_split))
                else:
                    train_set = self.convert_to_dataset(
                        train_dev_data, self.data_vocab, self.label_vocab)
                save_pickle(train_set, pickle_path, "data_train.pkl")
            else:
                train_set = load_pickle(pickle_path, "data_train.pkl")
                if pickle_exist(pickle_path, "data_dev.pkl"):
                    dev_set = load_pickle(pickle_path, "data_dev.pkl")
        else:
            # cross_val is True
            if not pickle_exist(pickle_path, "data_train_0.pkl"):
                # cross validation
                data_cv = self.cv_split(train_dev_data, n_fold)
                for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
                    data_train_cv = self.convert_to_dataset(
                        data_train_cv, self.data_vocab, self.label_vocab)
                    data_dev_cv = self.convert_to_dataset(
                        data_dev_cv, self.data_vocab, self.label_vocab)
                    save_pickle(data_train_cv, pickle_path,
                                "data_train_{}.pkl".format(i))
                    save_pickle(data_dev_cv, pickle_path,
                                "data_dev_{}.pkl".format(i))
                    train_set.append(data_train_cv)
                    dev_set.append(data_dev_cv)
                print("{}-fold cross validation.".format(n_fold))
            else:
                for i in range(n_fold):
                    data_train_cv = load_pickle(pickle_path,
                                                "data_train_{}.pkl".format(i))
                    data_dev_cv = load_pickle(pickle_path,
                                              "data_dev_{}.pkl".format(i))
                    train_set.append(data_train_cv)
                    dev_set.append(data_dev_cv)

        # prepare test data if provided
        test_set = []
        if test_data is not None:
            if not pickle_exist(pickle_path, "data_test.pkl"):
                test_set = self.convert_to_dataset(test_data, self.data_vocab,
                                                   self.label_vocab)
                save_pickle(test_set, pickle_path, "data_test.pkl")

        # return preprocessed results
        results = [train_set]
        if cross_val or train_dev_split > 0:
            results.append(dev_set)
        if test_data:
            results.append(test_set)
        if len(results) == 1:
            return results[0]
        else:
            return tuple(results)

    def build_dict(self, data):
        for example in data:
            word, label = example
            self.data_vocab.update(word)
            self.label_vocab.update(label)
        return self.data_vocab, self.label_vocab

    def build_char_dict(self):
        char_collection = set()
        for word in self.data_vocab.word2idx:
            if len(word) == 0:
                continue
            for ch in word:
                if ch not in char_collection:
                    char_collection.add(ch)
        self.character_vocab.update(list(char_collection))

    def build_reverse_dict(self):
        self.data_vocab.build_reverse_vocab()
        self.label_vocab.build_reverse_vocab()

    def data_split(self, data, train_dev_split):
        """Split data into train and dev set."""
        split = int(len(data) * train_dev_split)
        data_dev = data[:split]
        data_train = data[split:]
        return data_train, data_dev

    def cv_split(self, data, n_fold):
        """Split data for cross validation.

        :param data: list of string
        :param n_fold: int
        :return data_cv:

            ::
            [
                (data_train, data_dev),  # 1st fold
                (data_train, data_dev),  # 2nd fold
                ...
            ]

        """
        data_copy = data.copy()
        np.random.shuffle(data_copy)
        fold_size = round(len(data_copy) / n_fold)
        data_cv = []
        for i in range(n_fold - 1):
            start = i * fold_size
            end = (i + 1) * fold_size
            data_dev = data_copy[start:end]
            data_train = data_copy[:start] + data_copy[end:]
            data_cv.append((data_train, data_dev))
        start = (n_fold - 1) * fold_size
        data_dev = data_copy[start:]
        data_train = data_copy[:start]
        data_cv.append((data_train, data_dev))
        return data_cv

    def convert_to_dataset(self, data, vocab, label_vocab):
        """Convert list of indices into a DataSet object.

        :param data: list. Entries are strings.
        :param vocab: a dict, mapping string (token) to index (int).
        :param label_vocab: a dict, mapping string (label) to index (int).
        :return data_set: a DataSet object
        """
        use_word_seq = False
        use_label_seq = False
        use_label_str = False

        # construct a DataSet object and fill it with Instances
        data_set = DataSet()
        for example in data:
            words, label = example[0], example[1]
            instance = Instance()

            if isinstance(words, list):
                x = TextField(words, is_target=False)
                instance.add_field("word_seq", x)
                use_word_seq = True
            else:
                raise NotImplementedError("words is a {}".format(type(words)))

            if isinstance(label, list):
                y = TextField(label, is_target=True)
                instance.add_field("label_seq", y)
                use_label_seq = True
            elif isinstance(label, str):
                y = LabelField(label, is_target=True)
                instance.add_field("label", y)
                use_label_str = True
            else:
                raise NotImplementedError("label is a {}".format(type(label)))
            data_set.append(instance)

        # convert strings to indices
        if use_word_seq:
            data_set.index_field("word_seq", vocab)
        if use_label_seq:
            data_set.index_field("label_seq", label_vocab)
        if use_label_str:
            data_set.index_field("label", label_vocab)

        return data_set
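The class is driven through run, which builds the vocabularies, optionally splits off a dev set, and caches everything as pickles. A hedged usage sketch, assuming the class above and its helpers are importable, with made-up data in the [token_list, label] shape that build_dict and convert_to_dataset expect:

import os

# Hypothetical training data: each example is [word list, label string].
train_dev_data = [
    [["I", "like", "apples"], "positive"],
    [["I", "like", "pears"], "positive"],
    [["I", "hate", "rain"], "negative"],
    [["I", "hate", "queues"], "negative"],
]

os.makedirs("cache", exist_ok=True)
preprocessor = Preprocessor(label_is_seq=False)
train_set, dev_set = preprocessor.run(train_dev_data,
                                      pickle_path="cache/",
                                      train_dev_split=0.25)
print(preprocessor.vocab_size, preprocessor.num_classes)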
Example 19
class VocabIndexerProcessor(Processor):
    """
    根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供
        new_added_field_name, 则覆盖原有的field_name.

    """
    def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
                 verbose=0, is_input=True):
        """

        :param field_name: the field to build the vocabulary from, and the field that will be indexed.
        :param new_added_filed_name: name of the index field created during indexing; if not given, field_name is overwritten.
        :param min_freq: minimum number of occurrences a word needs in order to enter the Vocabulary.
        :param max_size: maximum number of words the Vocabulary may hold.
        :param verbose: 0 prints nothing; 1 prints information.
        :param bool is_input:
        """
        super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name)
        self.min_freq = min_freq
        self.max_size = max_size

        self.verbose = verbose
        self.is_input = is_input

    def construct_vocab(self, *datasets):
        """
        Build the vocabulary from the given DataSets.

        :param datasets: DataSet objects used to build the vocabulary.
        :return:
        """
        self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size)
        for dataset in datasets:
            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
            dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
        self.vocab.build_vocab()
        if self.verbose:
            print("Vocabulary Constructed, has {} items.".format(len(self.vocab)))

    def process(self, *datasets, only_index_dataset=None):
        """
        If no Vocabulary has been built yet, build one from the given datasets; otherwise reuse the existing one. Once
            the vocabulary is available, both datasets and only_index_dataset are indexed with it.

        :param datasets: DataSet objects.
        :param only_index_dataset: DataSet, or list of DataSet. These are only indexed; they do not contribute to the vocabulary.
        :return:
        """
        if len(datasets)==0 and not hasattr(self,'vocab'):
            raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.")
        if not hasattr(self, 'vocab'):
            self.construct_vocab(*datasets)
        else:
            if self.verbose:
                print("Using constructed vocabulary with {} items.".format(len(self.vocab)))
        to_index_datasets = []
        if len(datasets)!=0:
            for dataset in datasets:
                assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
                to_index_datasets.append(dataset)

        if not (only_index_dataset is None):
            if isinstance(only_index_dataset, list):
                for dataset in only_index_dataset:
                    assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
                    to_index_datasets.append(dataset)
            elif isinstance(only_index_dataset, DataSet):
                to_index_datasets.append(only_index_dataset)
            else:
                raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset)))

        for dataset in to_index_datasets:
            assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
            dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
                          new_field_name=self.new_added_field_name, is_input=self.is_input)
        # Return a single DataSet so that inference stays consistent with the other processors.
        if len(to_index_datasets) == 1:
            return to_index_datasets[0]

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))
        self.vocab = vocab

    def delete_vocab(self):
        del self.vocab

    def get_vocab_size(self):
        return len(self.vocab)

    def set_verbose(self, verbose):
        """
        Set the processor's verbosity.

        :param verbose: int, 0 prints nothing; 1 prints vocabulary information.
        :return:
        """
        self.verbose = verbose
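A hypothetical end-to-end use of the processor above (the dict-style DataSet construction and the field names are assumptions, as in the sketch after Example 13):

from fastNLP import DataSet  # import path assumed

train = DataSet({"words": [["this", "is", "fine"], ["this", "works"]]})
test = DataSet({"words": [["this", "is", "unseen"]]})

proc = VocabIndexerProcessor(field_name="words", new_added_filed_name="word_ids", verbose=1)
# Build the vocabulary from `train` only, then index both datasets with it;
# `test` is indexed but contributes nothing to the vocabulary.
proc.process(train, only_index_dataset=test)
print(proc.get_vocab_size())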
Example 20
 def test_to_word(self):
     vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
     vocab.update(text)
     self.assertEqual(
         text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
Example 21
 def test_case(self):
     vocab = Vocabulary()
     vocab.update(["the", "in", "I", "to", "of", "hahaha"])
     embedding = EmbedLoader().fast_load_embedding(
         50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
     self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
Example 22
 def test_len(self):
     vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
     vocab.update(text)
     self.assertEqual(len(vocab), len(counter))
Example 23
 def test_update(self):
     vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
     vocab.update(text)
     self.assertEqual(vocab.word_count, counter)
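Taken together, the last few snippets (Examples 20 to 23) cover the whole Vocabulary round trip. A compact, self-contained version with made-up fixtures (same assumptions as the sketch after Example 1):

from collections import Counter

from fastNLP import Vocabulary  # import path assumed

text = ["the", "quick", "fox", "jumps", "over", "the", "lazy", "dog"]
counter = Counter(text)

vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.update(text)

assert vocab.word_count == counter                              # raw frequencies (Example 23)
assert text == [vocab.to_word(vocab[w]) for w in text]          # word -> index -> word (Example 20)
assert vocab.has_word(text[-1]) and not vocab.has_word("~!@#")  # membership checks (Example 5)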