Exemplo n.º 1
0
    def test_torch_unlabeled_iteration(self):
        '''
        Verify that an unlabeled dataset can be wrapped in a data loader
        and fully iterated: any batching/collation failure would raise
        during the loop below.
        '''
        # Sample fixtures
        vocab = utils.build_sample_vocab()
        categories = utils.build_sample_tag_vocab()
        dataset = utils.construct_sample_unlabeled_dataset()

        loader = conlldataloader.get_unlabeled_data_loader(
            vocab=vocab,
            categories=categories,
            unlabeled_data=dataset,
            batch_size=2,    # >1 so the collate function actually batches
            shuffle=False,
            num_workers=0,   # in-process loading keeps the test deterministic
        )

        # Exhausting the iterator is the whole assertion here.
        for _ in enumerate(loader):
            pass
Exemplo n.º 2
0
def main_train(d: Dict[str, object], num_epochs: int, trainer_args, trainer_kwargs, database_items):
    """
    Train a model with ``Trainer`` and label ``database_items`` using the
    best checkpoint found during training.

    Returns a 4-tuple: (best_epoch, best_epoch_summary, best_model,
    labels computed over ``database_items``).
    """
    print('Started main train')

    def _epoch_comparator(incoming, best) -> bool:
        # An epoch is "better" when its average train F1 improves.
        improved = incoming['train_f1_avg'] > best['train_f1_avg']
        if improved:
            print("Found better!")
        return improved

    def _label_fn(data, index):
        # Database item -> (sentence_id, sentence, label) triple.
        entry = data[index]
        return (entry[0], entry[1][0], entry[1][1])

    trainer = Trainer(
        *trainer_args,
        train_label_fn=_label_fn,
        test_label_fn=_label_fn,
        epoch_comparator=_epoch_comparator,
        verbose_log=False,
        logger=None,
        **trainer_kwargs,
    )

    best_epoch, best_epoch_summary = trainer.train(epochs=num_epochs, update_dict=d)

    unlabeled_loader = conlldataloader.get_unlabeled_data_loader(
        vocab=trainer_kwargs['vocab'],
        categories=trainer_kwargs['tags'],
        unlabeled_data=database_items,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        unlabeled_example_fn=lambda dataset, index: (dataset[index][0], dataset[index][1][0]),
        collate_fn=conlldataloader.collate_unlabeld_fn_with_sid,
    )

    labels = ner.utils.compute_labels(
        trainer.get_best_model(),
        unlabeled_loader,
        tag_vocab=trainer_kwargs['tags'],
        verbose=True,
        device=trainer_kwargs['device'],
    )

    return best_epoch, best_epoch_summary, trainer.get_best_model(), labels
Exemplo n.º 3
0
    def test_compute_labels(self):
        """Labels computed for an unlabeled dataset cover every example."""
        loader = get_unlabeled_data_loader(
            vocab=utils.build_sample_vocab(),
            categories=utils.build_sample_tag_vocab(),
            unlabeled_data=utils.construct_sample_unlabeled_dataset(),
            batch_size=1,
            shuffle=False,
            num_workers=0,
        )

        result = compute_labels(
            model=utils.build_all_models()[1],
            data_loader=loader,
            tag_vocab=utils.build_sample_tag_vocab(),
            verbose=True,
        )

        # Exactly one label per unlabeled example.
        assert len(result) == len(utils.construct_sample_unlabeled_dataset())
Exemplo n.º 4
0
    def evaluate(self, labels=None):
        """
        Label every sentence in the database with the current model and
        record a snapshot of prediction statistics in ``self.predictions``.

        Snapshots are keyed by the current labeled-set size
        (len(train) + len(test)); a repeated call at the same size
        overwrites the previous snapshot instead of appending a new one.

        Args:
            labels: pre-computed labels for the database items; when
                ``None`` they are computed with ``ner.utils.compute_labels``.

        Returns:
            The updated ``self.predictions`` dictionary.
        """
        database_dataset = list(self.database.database.items())
        if labels is None:
            labels = ner.utils.compute_labels(
                self.model,
                conlldataloader.get_unlabeled_data_loader(
                    vocab=self.vocab,
                    categories=self.tag_vocab,
                    unlabeled_data=database_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=0,
                    # Database items are (s_id, (sentence, label)) pairs.
                    unlabeled_example_fn=lambda dataset, index:
                    (dataset[index][0], dataset[index][1][0]),
                    collate_fn=conlldataloader.collate_unlabeld_fn_with_sid,
                ),
                tag_vocab=self.tag_vocab,
                verbose=True,
                device=self.device,
            )

        # How many predictions flipped (either direction) since last call.
        pos_flipped, neg_flipped = utils.compute_total_flip(
            self.prev_labels, labels)

        number_labeled = len(self.train_data) + len(self.test_data)

        flipped_data = {
            "labeled_set_sizes": number_labeled,
            "pos_flipped": pos_flipped,
            "neg_flipped": neg_flipped,
            "total_flipped": pos_flipped + neg_flipped,
        }

        self.prev_labels = labels

        if (len(self.predictions["labeled_set_sizes"]) == 0 or
                number_labeled != self.predictions["labeled_set_sizes"][-1]):
            # First snapshot at this labeled-set size: append new entries.
            self.predictions["flipped_data"].append(flipped_data)
            self.predictions["labeled_set_sizes"].append(number_labeled)
            self.predictions["training_set_sizes"].append(len(self.train_data))
            # BUG FIX: previously appended len(self.train_data) here.
            self.predictions["testing_set_sizes"].append(len(self.test_data))
        else:
            # Same labeled-set size as the last snapshot: overwrite the
            # last entry of each series.
            # BUG FIX: the old code assigned scalars over the list-valued
            # keys (e.g. self.predictions["labeled_set_sizes"] =
            # number_labeled), which broke the len(...)/[-1] accesses
            # above on the next call; it also stored len(self.train_data)
            # under "testing_set_sizes".
            self.predictions["flipped_data"][-1] = flipped_data
            self.predictions["labeled_set_sizes"][-1] = number_labeled
            self.predictions["training_set_sizes"][-1] = len(self.train_data)
            self.predictions["testing_set_sizes"][-1] = len(self.test_data)

        for i, label in enumerate(labels):
            s_id, (sent, real_label) = database_dataset[i]
            ranges, _, entities = self.explain_labels(sent, label)

            stored_sent, stored_label_info = self.predictions[
                "predicted_data"][s_id]
            real_ranges, _, real_entities = self.explain_labels(
                sent, real_label)
            if stored_label_info is None or len(
                    self.predictions["labeled_set_sizes"]) == 0:
                # First prediction recorded for this sentence.
                self.predictions["predicted_data"][s_id] = (
                    sent,
                    {
                        "labeled_set_sizes":
                        [number_labeled],  # len(self.train_data)],
                        "ranges": [ranges],
                        "entities": [entities],
                        "real_ranges":
                        real_ranges,
                        "real_entities":
                        real_entities,
                        "is_test": (s_id in self.test_data
                                    and real_label is not None),
                        "is_train": (s_id in self.train_data
                                     and real_label is not None),
                    })
            else:
                if (len(stored_label_info['labeled_set_sizes']) == 0
                        or number_labeled !=
                        stored_label_info['labeled_set_sizes'][-1]):
                    # New labeled-set size: append a fresh history entry.
                    stored_label_info['labeled_set_sizes'].append(
                        number_labeled)
                    stored_label_info['ranges'].append(ranges)
                    stored_label_info['entities'].append(entities)
                else:
                    # Same size: overwrite the latest history entry.
                    stored_label_info['labeled_set_sizes'][
                        -1] = number_labeled  #len(self.train_data)
                    stored_label_info['ranges'][-1] = ranges
                    stored_label_info['entities'][-1] = entities
                stored_label_info[
                    'is_test'] = s_id in self.test_data and real_label is not None
                stored_label_info[
                    'is_train'] = s_id in self.train_data and real_label is not None
                stored_label_info["real_ranges"] = real_ranges
                stored_label_info["real_entities"] = real_entities

        return self.predictions
Exemplo n.º 5
0
    def test_cached_model(self):
        """
        Exercise a cached-embedder model end to end: cached vs. direct
        embedding agreement, eval-mode forward passes with and without
        sentence ids, training, and label computation.
        """
        dataset = utils.construct_sample_unlabeled_dataset()
        vocab = utils.build_sample_vocab()
        tag_set = utils.build_sample_tag_vocab()
        cached_embedder = utils.construct_cached_embedder()

        model = ner.utils.build_model(
            model_type='cached',
            embedding_dim=1024,
            hidden_dim=300,
            vocab=vocab,
            tag_vocab=tag_set,
            batch_size=2,
        )
        model.embedder = cached_embedder

        def _make_loader(batch_size):
            # Loader yielding (sentence_ids, sentences, sentence_chars).
            return conlldataloader.get_unlabeled_data_loader(
                vocab=vocab,
                categories=tag_set,
                unlabeled_data=dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=0,
                collate_fn=conlldataloader.collate_unlabeld_fn_with_sid,
            )

        for s_ids, sentence, sentence_chars in _make_loader(1):
            cached = cached_embedder.batched_forward_cached(s_ids, sentence)
            direct = cached_embedder.forward(sentence_chars)
            direct_again = cached_embedder.forward(sentence_chars)
            assert len(sentence[0]) == len(cached[0])
            assert cached.shape == direct.shape

        data_loader = _make_loader(2)

        model.eval()
        for s_ids, sentence, sentence_chars in data_loader:
            # Reseed before each call so the cached (with ids) and
            # uncached (ids=None) paths see identical randomness.
            torch.random.manual_seed(0)
            with_ids = model(sentence, sentence_chars, s_ids)
            torch.random.manual_seed(0)
            without_ids = model(sentence, sentence_chars, None)

        model.train()
        self._test_single_model_train(model=model)

        model.eval()
        labels = ner_utils.compute_labels(
            model=model,
            data_loader=data_loader,
            tag_vocab=tag_set,
            verbose=False,
        )

        print(labels)