Example #1
    def test_case_1(self):
        args = {
            "epochs": 3,
            "batch_size": 2,
            "validate": False,
            "use_cuda": False,
            "pickle_path": "./save/",
            "save_best_dev": True,
            "model_name": "default_model_name.pkl",
            "loss": Loss("cross_entropy"),
            "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5,
            "evaluator": SeqLabelEvaluator()
        }
        trainer = SeqLabelTrainer(**args)

        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
            '!': 5,
            '@': 6,
            '#': 7,
            '$': 8,
            '?': 9
        }
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, is_target=False)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=False)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(args)

        trainer.train(network=model, train_data=data_set, dev_data=data_set)
        # If this can run, everything is OK.

        os.system("rm -rf save")
        print("pickle path deleted")
Example #2
 def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
     if in_word_splitter is None:
         in_word_splitter = self.in_word_splitter
     dataset = DataSet()
     with open(filepath, 'r') as f:
         words = []
         for line in f:
             line = line.strip()
             if len(line) == 0:  # blank line marks the end of a sentence
                 if len(words) == 0:  # consecutive blank lines: nothing to flush
                     continue
                 line = ' '.join(words)
                 if cut_long_sent:
                     sents = cut_long_sentence(line)
                 else:
                     sents = [line]
                 for sent in sents:
                     instance = Instance(raw_sentence=sent)
                     dataset.append(instance)
                 words = []
             else:
                 line = line.split()[0]
                 if in_word_splitter is None:
                     words.append(line)
                 else:
                     words.append(line.split(in_word_splitter)[0])
     # flush the last sentence if the file does not end with a blank line
     if len(words) > 0:
         line = ' '.join(words)
         sents = cut_long_sentence(line) if cut_long_sent else [line]
         for sent in sents:
             dataset.append(Instance(raw_sentence=sent))
     return dataset
Example #3
 def test_append(self):
     dd = DataSet()
     for _ in range(3):
         dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
     self.assertEqual(len(dd), 3)
     self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
     self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
Example #4
 def _load(self, path: str = None):
     logging.info(path)
     ds = DataSet()
     with open(path, 'r', encoding='utf-8') as f:
         for line in f:
             if line.strip() == '': continue  # raw lines keep their trailing '\n', so compare stripped
             splits = line.strip().split('\t')
             if len(splits) == 4:
                 raw_targets = [int(i) for i in splits[3].strip().lstrip('[').rstrip(']').split(' ')]
             elif len(splits) == 3:
                 raw_targets = [0, 0, 0, 0, 0]
             else:
                 logging.error('data format error')
                 continue  # skip malformed lines instead of reusing stale raw_targets
             raw_query = splits[0]
             raw_entity = splits[1]
             left_context = raw_query[0:raw_query.find(raw_entity)]
             right_context = raw_query[raw_query.find(raw_entity) + len(raw_entity):]
             if left_context == '': left_context = '-'
             if right_context == '': right_context = '-'
             raw_entity_label = splits[2]
             if left_context and right_context and raw_entity and raw_entity_label:
                 ds.append(
                     Instance(left_context=tokenize(left_context),
                              right_context=tokenize(right_context),
                              raw_entity=tokenize(raw_entity),
                              raw_entity_label=entity_label_tokenize(raw_entity_label),
                              target=raw_targets))
     return ds
Example #5
    def convert(self, data):
        """Convert a 3D list to a DataSet object.

        :param data: a 3D list.
            Example::
                [
                    [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
                    [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
                    ...
                ]

        :return: A DataSet object.
        """

        data_set = DataSet()

        for example in data:
            p, h, l = example
            # list, list, str
            instance = Instance()
            instance.add_field("premise", p)
            instance.add_field("hypothesis", h)
            instance.add_field("truth", l)
            data_set.append(instance)
        data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len")
        data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len")
        data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len")
        data_set.set_target("truth")
        return data_set
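A minimal usage sketch for this converter, with toy data in the docstring's nested-list layout (the enclosing reader class name here is illustrative):

    reader = SNLIDataSetReader()  # hypothetical enclosing class exposing convert()
    toy = [
        [["a", "man", "walks"], ["a", "person", "moves"], ["entailment"]],
    ]
    ds = reader.convert(toy)
    # ds now carries premise/hypothesis/truth plus the derived *_len fields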
Example #6
    def convert(self, data):
        """Convert a 3D list to a DataSet object.

        :param data: a 3D list.
            [
                [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
                [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
                ...
            ]
        :return: data_set, a DataSet object.
        """

        data_set = DataSet()

        for example in data:
            p, h, l = example
            # list, list, str
            x1 = TextField(p, is_target=False)
            x2 = TextField(h, is_target=False)
            x1_len = TextField([1] * len(p), is_target=False)
            x2_len = TextField([1] * len(h), is_target=False)
            y = LabelField(l, is_target=True)
            instance = Instance()
            instance.add_field("premise", x1)
            instance.add_field("hypothesis", x2)
            instance.add_field("premise_len", x1_len)
            instance.add_field("hypothesis_len", x2_len)
            instance.add_field("truth", y)
            data_set.append(instance)

        return data_set
Example #7
    def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
        """
        Two input formats are accepted (tokens are separated by \t or spaces by default):
            这是 fastNLP , 一个 非常 good 的 包 .
        and
            也/D  在/P  團員/Na  之中/Ng  ,/COMMACATEGORY
        If in_word_splitter is not None, the second format is assumed: each token such as "也/D" is split
        by the splitter and only the first part is kept, e.g. "也/D".split('/')[0].
        :param filepath: path to the input file.
        :param in_word_splitter: separator inside each token; defaults to self.in_word_splitter.
        :return: a DataSet with one raw_sentence field per sentence.
        """
        if in_word_splitter is None:
            in_word_splitter = self.in_word_splitter
        dataset = DataSet()
        with open(filepath, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line.replace(' ', '')) == 0:  # skip blank lines
                    continue

                if in_word_splitter is not None:
                    words = []
                    for part in line.split():
                        word = part.split(in_word_splitter)[0]
                        words.append(word)
                    line = ' '.join(words)
                if cut_long_sent:
                    sents = cut_long_sentence(line)
                else:
                    sents = [line]
                for sent in sents:
                    instance = Instance(raw_sentence=sent)
                    dataset.append(instance)

        return dataset
Example #8
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet(name='conll')
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            ds.append(
                Instance(word_seq=TextField(res[0], is_target=False),
                         pos_seq=TextField(res[1], is_target=False),
                         head_indices=SeqLabelField(res[2], is_target=True),
                         head_labels=TextField(res[3], is_target=True)))

        return ds
Example #9
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        data = [self.get_one(sample) for sample in datalist]
        data_list = list(filter(lambda x: x is not None, data))

        ds = DataSet()
        for example in data_list:
            ds.append(
                Instance(words=example[0],
                         pos_tags=example[1],
                         heads=example[2],
                         labels=example[3]))
        return ds
Example #10
    def load(self, path, cut_long_sent=False):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            if res is None:
                continue
            line = '  '.join(res)
            if cut_long_sent:
                sents = cut_long_sentence(line)
            else:
                sents = [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))

        return ds
Example #11
    def test(self):
        data = DataSet()
        for text, label in zip(texts, labels):
            x = TextField(text, is_target=False)
            y = LabelField(label, is_target=True)
            ins = Instance(text=x, label=y)
            data.append(ins)

        # use vocabulary to index data
        data.index_field("text", vocab)

        # define naive sampler for batch class
        class SeqSampler:
            def __call__(self, dataset):
                return list(range(len(dataset)))

        # use batch to iterate dataset
        data_iterator = Batch(data, 2, SeqSampler(), False)
        total_data = 0
        for batch_x, batch_y in data_iterator:
            total_data += batch_x["text"].size(0)
            self.assertTrue(batch_x["text"].size(0) == 2
                            or total_data == len(raw_texts))
            self.assertTrue(isinstance(batch_x, dict))
            self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
            self.assertTrue(isinstance(batch_y, dict))
            self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))
Example #12
    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_conll(path,
                                     indexes=self.indexes,
                                     dropna=self.dropna):
            #            if data[0][0] == '#':
            #                data[0] = data[0][1:]
            #                data[1] = data[1][1:]
            for i in range(len(self.headers)):
                if data[i][0].startswith('NE-'):
                    data[i] = data[i][1:]
                if 'TOKEN' in data[i][0]:
                    data[i] = data[i][1:]

            # print(data) #data[1] = iob(list(data[1]))
            doc_start = False
            for i, h in enumerate(self.headers):
                field = data[i]
                if str(' '.join(list(field))).startswith(' #'):
                    continue
                if str(field[0]).startswith('-DOCSTART-'):
                    doc_start = True
                    break
            if doc_start:
                continue
            ins = {h: data[i] for i, h in enumerate(self.headers)}
            ds.append(Instance(**ins))
        if len(ds) == 0:
            raise RuntimeError("No data found {}.".format(path))
        return ds
Example #13
 def test(self):
     data = DataSet()
     for text in texts:
         x = TextField(text, is_target=False)
         ins = Instance(text=x)
         data.append(ins)
     data_set = create_dataset_from_lists(texts, vocab, has_target=False)
     self.assertTrue(type(data) == type(data_set))
Example #14
    def load(self, path):
        """
        The returned DataSet contains the following fields:
            words: list of str,
            tag: list of str, with BMES tags added; e.g. an original tag sequence ['VP', 'NN', 'NN', ...] becomes ['S-VP', 'B-NN', 'M-NN', ...]
        Assumes the input is in CoNLL format: sentences are separated by blank lines and each line has 7 columns, i.e.
        ::

            1	编者按	编者按	NN	O	11	nmod:topic
            2	:	:	PU	O	11	punct
            3	7月	7月	NT	DATE	4	compound:nn
            4	12日	12日	NT	DATE	11	nmod:tmod
            5	,	,	PU	O	11	punct

            1	这	这	DT	O	3	det
            2	款	款	M	O	1	mark:clf
            3	飞行	飞行	NN	O	8	nsubj
            4	从	从	P	O	5	case
            5	外型	外型	NN	O	8	nmod:prep

        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            if res is None:
                continue
            char_seq = []
            pos_seq = []
            for word, tag in zip(res[0], res[1]):
                char_seq.extend(list(word))
                if len(word) == 1:
                    pos_seq.append('S-{}'.format(tag))
                elif len(word) > 1:
                    pos_seq.append('B-{}'.format(tag))
                    for _ in range(len(word) - 2):
                        pos_seq.append('M-{}'.format(tag))
                    pos_seq.append('E-{}'.format(tag))
                else:
                    raise ValueError("Zero length of word detected.")

            ds.append(Instance(words=char_seq, tag=pos_seq))

        return ds
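The word-to-BMES expansion inside the loop can be read in isolation; a standalone sketch of the same scheme (word_to_bmes is a hypothetical helper, not part of the loader):

    def word_to_bmes(word, tag):
        # one character -> S-tag; longer words -> B-, (len - 2) M-'s, then E-
        if len(word) == 1:
            return ['S-{}'.format(tag)]
        return (['B-{}'.format(tag)]
                + ['M-{}'.format(tag)] * (len(word) - 2)
                + ['E-{}'.format(tag)])

    print(word_to_bmes('编者按', 'NN'))  # ['B-NN', 'M-NN', 'E-NN']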
Example #15
 def convert(self, data):
     data_set = DataSet()
     for item in data:
         sent_words, sent_pos_tag = item[0], item[1]
         data_set.append(Instance(words=sent_words, tags=sent_pos_tag))
     data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
     data_set.set_target("tags")
     data_set.set_input("words")
     data_set.set_input("seq_len")
     return data_set
Example #16
    def convert(self, parsed_data):
        dataset = DataSet()
        for sample in parsed_data:
            label0_list = [labels[0] for labels in sample[1]]
            label1_list = [labels[1] for labels in sample[1]]
            label2_list = [labels[2] for labels in sample[1]]
            dataset.append(
                Instance(token_list=sample[0],
                         label0_list=label0_list,
                         label1_list=label1_list,
                         label2_list=label2_list))

        return dataset
Example #17
 def convert(self, data):
     dataset = DataSet()
     for sample in data:
         word_seq = [BOS] + sample[0] + [EOS]
         pos_seq = [BOS] + sample[1] + [EOS]
         heads = [0] + list(map(int, sample[2])) + [0]
         head_tags = [BOS] + sample[3] + [EOS]
         dataset.append(
             Instance(word_seq=TextField(word_seq, is_target=False),
                      pos_seq=TextField(pos_seq, is_target=False),
                      gold_heads=SeqLabelField(heads, is_target=False),
                      head_indices=SeqLabelField(heads, is_target=True),
                      head_labels=TextField(head_tags, is_target=True)))
     return dataset
Example #18
def convert(data):
    dataset = DataSet()
    for sample in data:
        word_seq = [BOS] + sample[0]
        pos_seq = [BOS] + sample[1]
        heads = [0] + list(map(int, sample[2]))
        head_tags = [BOS] + sample[3]
        dataset.append(
            Instance(words=word_seq,
                     pos=pos_seq,
                     gold_heads=heads,
                     arc_true=heads,
                     tags=head_tags))
    return dataset
Example #19
 def convert(data):
     BOS = '<BOS>'
     dataset = DataSet()
     for sample in data:
         word_seq = [BOS] + sample[0]
         pos_seq = [BOS] + sample[1]
         heads = [0] + sample[2]
         head_tags = [BOS] + sample[3]
         dataset.append(
             Instance(raw_words=word_seq,
                      pos=pos_seq,
                      gold_heads=heads,
                      arc_true=heads,
                      tags=head_tags))
     return dataset
Example #20
def convert_seq_dataset(data):
    """Create an DataSet instance that contains no labels.

    :param data: list of list of strings, [num_examples, *].
            ::
            [
                [word_11, word_12, ...],
                ...
            ]

    :return: a DataSet.
    """
    dataset = DataSet()
    for word_seq in data:
        dataset.append(Instance(word_seq=word_seq))
    return dataset
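A minimal usage sketch with toy input in the docstring's layout:

    ds = convert_seq_dataset([
        ["this", "is", "fastNLP"],
        ["hello", "world"],
    ])
    # len(ds) == 2; each instance holds a single word_seq field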
Example #21
    def load(self, path, cut_long_sent=False):
        """
        The returned DataSet contains only the raw_sentence field, whose content is a str.
        Assumes the input is in CoNLL format: sentences are separated by blank lines and each line has 7 columns, i.e.
        ::

            1	编者按	编者按	NN	O	11	nmod:topic
            2	:	:	PU	O	11	punct
            3	7月	7月	NT	DATE	4	compound:nn
            4	12日	12日	NT	DATE	11	nmod:tmod
            5	,	,	PU	O	11	punct

            1	这	这	DT	O	3	det
            2	款	款	M	O	1	mark:clf
            3	飞行	飞行	NN	O	8	nsubj
            4	从	从	P	O	5	case
            5	外型	外型	NN	O	8	nmod:prep

        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.strip().split())
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_char_lst(sample)
            if res is None:
                continue
            line = ' '.join(res)
            if cut_long_sent:
                sents = cut_long_sentence(line)
            else:
                sents = [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))
        return ds
Example #22
 def convert(self, data):
     data_set = DataSet()
     for item in data:
         sent_words = item[0]
         if self.pos is True and self.ner is True:
             instance = Instance(words=sent_words,
                                 pos_tags=item[1],
                                 ner=item[2])
         elif self.pos is True:
             instance = Instance(words=sent_words, pos_tags=item[1])
         elif self.ner is True:
             instance = Instance(words=sent_words, ner=item[1])
         else:
             instance = Instance(words=sent_words)
         data_set.append(instance)
     data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
     return data_set
Example #23
def convert_seq2seq_dataset(data):
    """Convert list of data into DataSet

    :param data: list of list of strings, [num_examples, *].
            ::
            [
                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
                ...
            ]

    :return: a DataSet.
    """
    dataset = DataSet()
    for sample in data:
        dataset.append(Instance(word_seq=sample[0], label_seq=sample[1]))
    return dataset
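A minimal usage sketch with one toy (words, labels) pair:

    ds = convert_seq2seq_dataset([
        [["I", "love", "NLP"], ["O", "O", "B-TOPIC"]],
    ])
    # len(ds) == 1; the instance holds word_seq and label_seq fields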
Example #24
    def convert_to_dataset(self, data, vocab, label_vocab):
        """Convert list of indices into a DataSet object.

        :param data: a list of examples; each entry is a (words, label) pair.
        :param vocab: a dict, mapping string (token) to index (int).
        :param label_vocab: a dict, mapping string (label) to index (int).
        :return data_set: a DataSet object
        """
        use_word_seq = False
        use_label_seq = False
        use_label_str = False

        # construct a DataSet object and fill it with Instances
        data_set = DataSet()
        for example in data:
            words, label = example[0], example[1]
            instance = Instance()

            if isinstance(words, list):
                x = TextField(words, is_target=False)
                instance.add_field("word_seq", x)
                use_word_seq = True
            else:
                raise NotImplementedError("words is a {}".format(type(words)))

            if isinstance(label, list):
                y = TextField(label, is_target=True)
                instance.add_field("label_seq", y)
                use_label_seq = True
            elif isinstance(label, str):
                y = LabelField(label, is_target=True)
                instance.add_field("label", y)
                use_label_str = True
            else:
                raise NotImplementedError("label is a {}".format(type(label)))
            data_set.append(instance)

        # convert strings to indices
        if use_word_seq:
            data_set.index_field("word_seq", vocab)
        if use_label_seq:
            data_set.index_field("label_seq", label_vocab)
        if use_label_str:
            data_set.index_field("label", label_vocab)

        return data_set
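A minimal usage sketch with toy vocabularies (the reader class name is illustrative; any object exposing this method works):

    reader = SeqLabelDataSetReader()  # hypothetical enclosing class
    vocab = {"hello": 0, "world": 1}
    label_vocab = {"B": 0, "O": 1}
    data = [[["hello", "world"], ["B", "O"]]]
    ds = reader.convert_to_dataset(data, vocab, label_vocab)
    # word_seq and label_seq are now sequences of indices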
Example #25
    def test_case_1(self):
        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }
        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                      "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
                      "use_cuda": False, "print_every_step": 1}

        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, is_target=False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, label_seq=y)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)

        model = SeqLabeling(model_args)

        tester = SeqLabelTester(**valid_args)
        tester.test(network=model, dev_data=data_set)
        # If this can run, everything is OK.

        os.system("rm -rf save")
        print("pickle path deleted")
Example #26
def convert_seq2seq_dataset(data):
    """Convert list of data into DataSet

    :param data: list of list of strings, [num_examples, *].
            ::
            [
                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
                ...
            ]

    :return: a DataSet.
    """
    dataset = DataSet()
    for sample in data:
        word_seq, label_seq = sample[0], sample[1]
        ins = Instance()
        ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
            .add_field("label_seq", TextField(label_seq, is_target=True))
        dataset.append(ins)
    return dataset
Example #27
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            if res is None:
                continue
            char_seq = []
            pos_seq = []
            for word, tag in zip(res[0], res[1]):
                if len(word) == 1:
                    char_seq.append(word)
                    pos_seq.append('S-{}'.format(tag))
                elif len(word) > 1:
                    pos_seq.append('B-{}'.format(tag))
                    for _ in range(len(word) - 2):
                        pos_seq.append('M-{}'.format(tag))
                    pos_seq.append('E-{}'.format(tag))
                    char_seq.extend(list(word))
                else:
                    raise ValueError("Zero length of word detected.")

            ds.append(Instance(words=char_seq,
                               tag=pos_seq))

        return ds
Example #28
    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_conll(path,
                                     sep=self.sep,
                                     indexes=self.indexes,
                                     dropna=self.dropna):
            ins = {h: data[i] for i, h in enumerate(self.headers)}
            ds.append(Instance(**ins))
        return ds


# Step 1: define a concrete loader that reads input data in the corresponding format.
# (If no suitable built-in loader exists, implement a custom loader class as above.)
# An Instance stores its field information as a dict and represents one concrete sample;
# a DataSet holds Instances.
#
# loader = myConllLoader(headers=['raw_words', 'ner'], indexes=[0, 1])
# paths = {
#     'train': "/Users/wangming/.fastNLP/dataset/weibo_NER/train.conll",
#     'dev': "/Users/wangming/.fastNLP/dataset/weibo_NER/dev.conll",
#     "test": "/Users/wangming/.fastNLP/dataset/weibo_NER/test.conll"
# }
# datasets = loader.load(paths).datasets
# # print(*list(datasets.keys()))
# # print(datasets['train'][0:3])

# word_vocab = Vocabulary()
# label_vocab = Vocabulary(padding=None, unknown=None)

# word_vocab.from_dataset(datasets['train'],
#                         field_name='raw_words',
#                         no_create_entry_dataset=[datasets['dev'], datasets['test']])
# label_vocab.from_dataset(datasets['train'], field_name='ner')
# print('label_vocab:{}\n{}'.format(len(label_vocab), label_vocab.idx2word))

# word_vocab.index_dataset(*list(datasets.values()), field_name='raw_words', new_field_name='raw_words_index')
# label_vocab.index_dataset(*list(datasets.values()), field_name='ner', new_field_name='ner_index')

# print(datasets['train'][0:3])
# print(word_vocab.idx2word[791])
Example #29
    def test(self, filepath):
        data = ConllxDataLoader().load(filepath)
        ds = DataSet()
        for ins1, ins2 in zip(add_seg_tag(data), data):
            ds.append(
                Instance(words=ins1[0],
                         tag=ins1[1],
                         gold_words=ins2[0],
                         gold_pos=ins2[1],
                         gold_heads=ins2[2],
                         gold_head_tags=ins2[3]))

        pp = self.pipeline
        for p in pp:
            if p.field_name == 'word_list':
                p.field_name = 'gold_words'
            elif p.field_name == 'pos_list':
                p.field_name = 'gold_pos'
        pp(ds)
        head_cor, label_cor, total = 0, 0, 0
        for ins in ds:
            head_gold = ins['gold_heads']
            head_pred = ins['heads']
            length = len(head_gold)
            total += length
            for i in range(length):
                head_cor += 1 if head_pred[i] == head_gold[i] else 0
        uas = head_cor / total
        print('uas:{:.2f}'.format(uas))

        for p in pp:
            if p.field_name == 'gold_words':
                p.field_name = 'word_list'
            elif p.field_name == 'gold_pos':
                p.field_name = 'pos_list'

        return uas
Example #30
    # prepare vocabulary
    vocab = {}
    for text in texts:
        for token in text.split():
            if token not in vocab:
                vocab[token] = len(vocab)
    print("vocabulary: ", vocab)

    # prepare input dataset
    data = DataSet()
    for text, label in zip(texts, labels):
        x = TextField(text.split(), is_target=False)
        y = LabelField(label, is_target=True)
        ins = Instance(text=x, label=y)
        data.append(ins)

    # use vocabulary to index data
    data.index_field("text", vocab)

    # define naive sampler for batch class
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use batch to iterate dataset
    data_iterator = Batch(data, 2, SeqSampler(), False)
    for epoch in range(1):
        for batch_x, batch_y in data_iterator:
            print(batch_x)
            print(batch_y)