Exemple #1
0
    def test(self):
        """Index a small DataSet and check the batches yielded by ``Batch``.

        ``texts``, ``labels``, ``vocab`` and ``raw_texts`` are not defined in
        this method; they are read from the enclosing scope.
        """
        data = DataSet()
        for sentence, tag in zip(texts, labels):
            text_field = TextField(sentence, is_target=False)
            label_field = LabelField(tag, is_target=True)
            data.append(Instance(text=text_field, label=label_field))

        # Replace the tokens stored in the "text" field by their vocab indices.
        data.index_field("text", vocab)

        class SeqSampler:
            """Naive sampler: yields dataset positions in their natural order."""

            def __call__(self, dataset):
                return [*range(len(dataset))]

        # Batch size is 2; the last positional flag is passed through as in
        # the original call (presumably a CUDA switch -- TODO confirm).
        batches = Batch(data, 2, SeqSampler(), False)
        seen = 0
        for batch_x, batch_y in batches:
            seen += batch_x["text"].size(0)
            # Every batch must be full, except possibly the last one, which is
            # accepted once the running count has reached len(raw_texts).
            self.assertTrue(
                batch_x["text"].size(0) == 2 or seen == len(raw_texts)
            )
            self.assertTrue(isinstance(batch_x, dict))
            self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
            self.assertTrue(isinstance(batch_y, dict))
            self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))
Exemple #2
0
    def test_case_1(self):
        """Smoke-test ``SeqLabelTrainer.train`` on a tiny hand-built data set.

        The same ``args`` dict configures both the trainer and the
        ``SeqLabeling`` model. The data set holds six examples, each a pair of
        a 5-token word sequence and a 5-token label sequence; ``vocab`` has 10
        entries and ``label_vocab`` has 5. The test passes if training runs
        without raising.

        The ``save`` directory (``pickle_path`` is ``"./save/"``) is removed
        afterwards, even when training raises, so a failing run leaves no
        artifacts behind.
        """
        # Local import keeps this method self-contained; shutil.rmtree is a
        # portable replacement for shelling out to ``rm -rf``.
        import shutil

        args = {
            "epochs": 3,
            "batch_size": 2,
            "validate": False,
            "use_cuda": False,
            "pickle_path": "./save/",
            "save_best_dev": True,
            "model_name": "default_model_name.pkl",
            "loss": Loss("cross_entropy"),
            "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5,
            "evaluator": SeqLabelEvaluator()
        }
        trainer = SeqLabelTrainer(**args)

        # Each entry is [word sequence, label sequence].
        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
            '!': 5,
            '@': 6,
            '#': 7,
            '$': 8,
            '?': 9
        }
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, False)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=False)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        # Map tokens and labels to their integer indices.
        data_set.index_field("word_seq", vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(args)

        try:
            # The same data set is used for training and as dev data.
            trainer.train(network=model, train_data=data_set, dev_data=data_set)
            # If this can run, everything is OK.
        finally:
            # ignore_errors mirrors ``rm -rf``: a missing directory is not an error.
            shutil.rmtree("save", ignore_errors=True)
            print("pickle path deleted")
Exemple #3
0
    def test_label_field(self):
        """Check ``LabelField`` with a string label and with an integer label."""
        # String label: has length 1 and is looked up in the given mapping.
        str_label = LabelField("A", is_target=True)
        self.assertEqual(str_label.get_length(), 1)
        label_to_index = {"A": 10}
        self.assertEqual(str_label.index(label_to_index), 10)

        # Integer label: has length 1 and converts to a 0-d tensor holding 30.
        int_label = LabelField(30, is_target=True)
        self.assertEqual(int_label.get_length(), 1)
        as_tensor = int_label.to_tensor(0)
        self.assertEqual(as_tensor.shape, ())
        self.assertEqual(int(as_tensor), 30)
Exemple #4
0
 def convert_for_infer(self, data, vocabs):
     """Append one unlabeled Instance per word sequence, then index the words.

     :param data: iterable of word sequences (each is passed to ``len``).
     :param vocabs: mapping; ``vocabs["word_vocab"]`` is used to index the
         ``"word_seq"`` field.
     """
     for tokens in data:
         token_field = TextField(tokens, is_target=False)
         length_field = LabelField(len(tokens), is_target=False)
         sample = Instance()
         for field_name, field in (
             ("word_seq", token_field),
             ("word_seq_origin_len", length_field),
         ):
             sample.add_field(field_name, field)
         self.append(sample)
     # Indexing happens once, after all instances have been appended.
     self.index_field("word_seq", vocabs["word_vocab"])
Exemple #5
0
 def convert_with_vocabs(self, data, vocabs):
     """Append one (word_seq, label) Instance per example and index both fields.

     :param data: iterable of examples; ``example[0]`` is the word sequence
         and ``example[1]`` is its label.
     :param vocabs: mapping providing ``"word_vocab"`` and ``"label_vocab"``,
         used to index ``"word_seq"`` and ``"label"`` respectively.
     """
     for example in data:
         tokens = example[0]
         tag = example[1]
         token_field = TextField(tokens, is_target=False)
         tag_field = LabelField(tag, is_target=True)
         sample = Instance()
         for field_name, field in (("word_seq", token_field), ("label", tag_field)):
             sample.add_field(field_name, field)
         self.append(sample)
     # Index only after every instance has been added.
     for field_name, vocab_key in (("word_seq", "word_vocab"), ("label", "label_vocab")):
         self.index_field(field_name, vocabs[vocab_key])
Exemple #6
0
 def convert(self, data):
     """Grow the vocabularies from ``data`` and append indexed Instances.

     For every example, ``example[0]`` (word sequence) updates
     ``self.word_vocab`` and ``example[1]`` (label) updates
     ``self.label_vocab``; an Instance with fields ``"word_seq"`` and
     ``"label"`` is then appended. Both fields are indexed with the
     updated vocabularies once all examples have been processed.

     :param data: iterable of examples indexable at positions 0 and 1.
     """
     for example in data:
         tokens = example[0]
         tag = example[1]
         self.word_vocab.update(tokens)
         self.label_vocab.update(tag)
         token_field = TextField(tokens, is_target=False)
         tag_field = LabelField(tag, is_target=True)
         sample = Instance()
         for field_name, field in (("word_seq", token_field), ("label", tag_field)):
             sample.add_field(field_name, field)
         self.append(sample)
     # Vocabularies are complete only now, so indexing is deferred to the end.
     self.index_field("word_seq", self.word_vocab)
     self.index_field("label", self.label_vocab)
Exemple #7
0
 def convert_with_vocabs(self, data, vocabs):
     """Append sequence-labeling Instances and index them with given vocabs.

     Each Instance carries ``"word_seq"``, ``"truth"`` and
     ``"word_seq_origin_len"`` (the length of the word sequence).

     :param data: iterable of examples; ``example[0]`` is the word sequence
         and ``example[1]`` is the label sequence.
     :param vocabs: mapping providing ``"word_vocab"`` and ``"label_vocab"``,
         used to index ``"word_seq"`` and ``"truth"`` respectively.
     """
     for example in data:
         tokens = example[0]
         tags = example[1]
         token_field = TextField(tokens, is_target=False)
         length_field = LabelField(len(tokens), is_target=False)
         tag_field = TextField(tags, is_target=False)
         sample = Instance()
         for field_name, field in (
             ("word_seq", token_field),
             ("truth", tag_field),
             ("word_seq_origin_len", length_field),
         ):
             sample.add_field(field_name, field)
         self.append(sample)
     # Index only after every instance has been added.
     for field_name, vocab_key in (("word_seq", "word_vocab"), ("truth", "label_vocab")):
         self.index_field(field_name, vocabs[vocab_key])
Exemple #8
0
    def convert_to_dataset(self, data, vocab, label_vocab):
        """Build an indexed DataSet from ``(words, label)`` examples.

        ``example[0]`` must be a list and becomes the ``"word_seq"`` field.
        ``example[1]`` becomes ``"label_seq"`` when it is a list, or
        ``"label"`` when it is a string. Every field that was added at least
        once is indexed after all examples have been converted.

        :param data: list of examples, each indexable at positions 0 and 1.
        :param vocab: a dict, mapping string (token) to index (int).
        :param label_vocab: a dict, mapping string (label) to index (int).
        :return data_set: a DataSet object
        :raises NotImplementedError: if the words are not a list, or the label
            is neither a list nor a string.
        """
        dataset = DataSet()
        # Names of the fields that were actually added to some instance.
        added_fields = set()

        for example in data:
            tokens, target = example[0], example[1]
            sample = Instance()

            if not isinstance(tokens, list):
                raise NotImplementedError("words is a {}".format(type(tokens)))
            sample.add_field("word_seq", TextField(tokens, is_target=False))
            added_fields.add("word_seq")

            if isinstance(target, list):
                sample.add_field("label_seq", TextField(target, is_target=True))
                added_fields.add("label_seq")
            elif isinstance(target, str):
                sample.add_field("label", LabelField(target, is_target=True))
                added_fields.add("label")
            else:
                raise NotImplementedError("label is a {}".format(type(target)))

            dataset.append(sample)

        # Convert strings to indices, in a fixed field order.
        indexing_plan = (
            ("word_seq", vocab),
            ("label_seq", label_vocab),
            ("label", label_vocab),
        )
        for field_name, mapping in indexing_plan:
            if field_name in added_fields:
                dataset.index_field(field_name, mapping)

        return dataset
Exemple #9
0
    def convert(self, data):
        """Turn (word sequence, label sequence) pairs into indexed Instances.

        Each example updates ``self.word_vocab`` and ``self.label_vocab`` and
        is appended as an Instance with fields ``"word_seq"``, ``"truth"`` and
        ``"word_seq_origin_len"``. A progress bar is advanced once per
        example. ``"word_seq"`` and ``"truth"`` are indexed at the end.

        :param data: 3-level lists. Entries are strings.
        """
        progress = ProgressBar(total=len(data))
        for example in data:
            tokens = example[0]
            tags = example[1]
            self.word_vocab.update(tokens)
            self.label_vocab.update(tags)
            token_field = TextField(tokens, is_target=False)
            length_field = LabelField(len(tokens), is_target=False)
            tag_field = TextField(tags, is_target=False)
            sample = Instance()
            for field_name, field in (
                ("word_seq", token_field),
                ("truth", tag_field),
                ("word_seq_origin_len", length_field),
            ):
                sample.add_field(field_name, field)
            self.append(sample)
            progress.move()
        # Vocabularies are complete only now, so indexing is deferred to the end.
        self.index_field("word_seq", self.word_vocab)
        self.index_field("truth", self.label_vocab)
Exemple #10
0
    def test_case_1(self):
        """Smoke-test ``SeqLabelTester.test`` on a tiny hand-built data set.

        ``model_args`` configures the ``SeqLabeling`` model and ``valid_args``
        configures the tester. The data set holds six examples, each a pair of
        a 5-token word sequence and a 5-token label sequence; ``vocab`` has 10
        entries and ``label_vocab`` has 5. The test passes if evaluation runs
        without raising.

        The ``save`` directory (``pickle_path`` is ``"./save/"``) is removed
        afterwards, even when evaluation raises, so a failing run leaves no
        artifacts behind.
        """
        # Local import keeps this method self-contained; shutil.rmtree is a
        # portable replacement for shelling out to ``rm -rf``.
        import shutil

        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }
        valid_args = {
            "save_output": True,
            "validate_in_training": True,
            "save_dev_input": True,
            "save_loss": True,
            "batch_size": 2,
            "pickle_path": "./save/",
            "use_cuda": False,
            "print_every_step": 1,
            "evaluator": SeqLabelEvaluator()
        }

        # Each entry is [word sequence, label sequence].
        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
            '!': 5,
            '@': 6,
            '#': 7,
            '$': 8,
            '?': 9
        }
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = SeqLabelDataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, False)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        # Map tokens and labels to their integer indices.
        data_set.index_field("word_seq", vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(model_args)

        tester = SeqLabelTester(**valid_args)
        try:
            tester.test(network=model, dev_data=data_set)
            # If this can run, everything is OK.
        finally:
            # ignore_errors mirrors ``rm -rf``: a missing directory is not an error.
            shutil.rmtree("save", ignore_errors=True)
            print("pickle path deleted")
Exemple #11
0
    texts = ["i am a cat", "this is a test of new batch", "haha"]
    labels = [0, 1, 0]

    # prepare vocabulary
    vocab = {}
    for text in texts:
        for tokens in text.split():
            if tokens not in vocab:
                vocab[tokens] = len(vocab)
    print("vocabulary: ", vocab)

    # prepare input dataset
    data = DataSet()
    for text, label in zip(texts, labels):
        x = TextField(text.split(), False)
        y = LabelField(label, is_target=True)
        ins = Instance(text=x, label=y)
        data.append(ins)

    # use vocabulary to index data
    data.index_field("text", vocab)

    # define naive sampler for batch class
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use batch to iterate dataset
    data_iterator = Batch(data, 2, SeqSampler(), False)
    for epoch in range(1):
        for batch_x, batch_y in data_iterator: