Example 1
    def test_unlabeled(self):
        """Entries of an UnlabeledBIODataset must carry no gold 'output' field."""
        corpus = UnlabeledBIODataset(
            dataset_id=0,
            bio_data=UnlabeledBIODatasetTest.create_fake_data(),
        )

        for entry in corpus:
            # labels are stripped; only the remaining three fields survive
            assert 'output' not in entry
            assert len(entry) == 3
Example 2
    def test_unlabeled(self):
        """RandomHeuristic must return one tensor score per corpus item."""
        fake_data = RandomHeuristicTests.create_fake_data()
        corpus = UnlabeledBIODataset(dataset_id=0, bio_data=fake_data)

        scores = RandomHeuristic().evaluate(corpus)

        # a flat torch tensor with exactly one score per unlabeled example
        assert type(scores) == torch.Tensor
        assert scores.shape == (len(corpus), )
Example 3
    def test_gold_oracle(self):
        """GoldOracle must echo each example's gold annotation with weight 1.0."""
        bio_data = GoldOracleTest.create_fake_data()
        unlabeled_corpus = UnlabeledBIODataset(dataset_id=0, bio_data=bio_data)

        oracle = GoldOracle(bio_data)

        for item in bio_data:
            response = oracle.get_query(query=(item['id'], item['input']))

            # the oracle response must match the original example verbatim
            assert response['id'] == item['id']
            assert item['input'] == response['input']
            assert response['output'] == item['output']
            # gold annotations always carry full weight
            assert response['weight'] == 1.0
Example 4
    def test_unlabeled(self):
        """ClusteringHeuristic gives exactly SAMPLE_SIZE items the selected score."""
        data = ClusteringHeuristicTest.create_fake_data()
        cwr = ClusteringHeuristicTest.build_cache_cwr()
        corpus = UnlabeledBIODataset(dataset_id=0, bio_data=data)

        sample_size = ClusteringHeuristicTest.SAMPLE_SIZE
        scores = ClusteringHeuristic(cwr, corpus).evaluate(corpus, sample_size)

        assert type(scores) == torch.Tensor
        assert scores.shape == (len(corpus), )

        # corpus indexes ranked from highest score to lowest
        ranked = sorted(range(len(corpus)),
                        key=lambda i: scores[i],
                        reverse=True)

        top_score = scores[ranked][0]
        bottom_score = scores[ranked][-1]

        # exactly sample_size entries share the "selected" score; every
        # remaining entry shares the "unselected" score
        assert (scores == bottom_score).sum() == (len(corpus) - sample_size)
        assert (scores == top_score).sum() == sample_size
Example 5
    def test_keyword_func(self):
        """KeywordMatchFunction must learn keyword counters and emit valid BIO tags."""
        dataset = KeywordFunctionTest.create_fake_data()
        corpus = UnlabeledBIODataset(bio_data=dataset,
                                     dataset_id=dataset.dataset_id)

        func = KeywordMatchFunction('Tag')
        func.train(train_data=dataset)
        annotated = func.evaluate(unlabeled_corpus=corpus)

        # frequencies the function should have learned from the fake data
        expected_pos = Counter({'single': 2, 'double': 1})
        expected_neg = Counter({'triple': 1, 'no_label': 1})

        assert expected_pos == func.keywords['pos']
        assert expected_neg == func.keywords['neg']

        # converting the annotations must produce a well-formed BIO scheme
        converter = BIOConverter(binary_class='Tag')
        annotated = converter.convert(annotated)
        for entry in annotated:
            assert self._verify_bio_scheme(entry['output'], 'Tag')
Example 6
def main():
    """Entry point: parse CLI args, load BIO datasets, build a model,
    and launch the active-learning training loop."""
    args = get_active_args().parse_args()
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    # use CUDA only when both available and explicitly requested
    device = 'cuda' if torch.cuda.is_available() and args.cuda else 'cpu'

    train_file, valid_file, test_file = get_dataset_files(dataset=args.dataset)

    class_labels: List[str] = construct_f1_class_labels(args.binary_class)

    # training split; dataset_id 0 identifies it downstream
    train_bio = BIODataset(
        dataset_id=0,
        file_name=train_file,
        binary_class=args.binary_class,
    )

    train_bio.parse_file()

    # --test swaps the validation split for the test split
    if args.test:
        print('using test set')
    valid_bio = BIODataset(
        dataset_id=1,
        file_name=valid_file if not args.test else test_file,
        binary_class=args.binary_class,
    )

    valid_bio.parse_file()

    # vocabulary covers both splits so validation tokens are in-vocab
    vocab = construct_vocab([train_bio, valid_bio])

    # the training split doubles as the unlabeled pool for active learning
    unlabeled_corpus = UnlabeledBIODataset(
        dataset_id=train_bio.dataset_id,
        bio_data=train_bio,
    )

    model = build_model(
        model_type=args.model_type,
        vocab=vocab,
        hidden_dim=args.hidden_dim,
        class_labels=class_labels,
        cached=args.cached,
    )

    # oracle answers label queries from the gold training annotations
    oracle = GoldOracle(train_bio)

    active_train(
        model=model,
        unlabeled_dataset=unlabeled_corpus,
        valid_dataset=valid_bio,
        vocab=vocab,
        oracle=oracle,
        optimizer_type=args.opt_type,
        optimizer_learning_rate=args.opt_lr,
        optimizer_weight_decay=args.opt_weight_decay,
        use_weak=args.use_weak,
        weak_fine_tune=args.use_weak_fine_tune,
        weak_weight=args.weak_weight,
        weak_function=args.weak_function,
        weak_collator=args.weak_collator,
        sample_strategy=args.sample_strategy,
        batch_size=args.batch_size,
        patience=args.patience,
        num_epochs=args.num_epochs,
        device=device,
        log_dir=args.log_dir,
        model_name=args.model_name,
    )
Example 7
def active_train_iteration(
    heuristic: RandomHeuristic,
    unlabeled_dataset: UnlabeledBIODataset,
    sample_size: int,
    labeled_indexes: List[int],
    oracle: Oracle,
    train_data: DatasetType,
    valid_reader: DatasetReader,
    vocab: Vocabulary,
    model: Model,
    cached_text_field_embedders: List[CachedTextFieldEmbedder],
    spacy_feature_extractor: SpaCyFeatureExtractor,
    optimizer_type: str,
    optimizer_learning_rate: float,
    optimizer_weight_decay: float,
    use_weak: bool,
    weak_weight: float,
    weak_function: List[str],
    weak_collator: str,
    sample_strategy: str,
    batch_size: int,
    patience: int,
    num_epochs: int,
    device: str,
) -> Tuple[Model, Dict[str, object]]:
    """Run one iteration of the active-learning loop.

    Scores the unlabeled corpus with ``heuristic``, selects up to
    ``sample_size`` points (multinomial sampling or top-k), queries the
    ``oracle`` for their gold labels, moves them from the unlabeled corpus
    into ``train_data`` (recording their ids in ``labeled_indexes``),
    optionally augments training with weakly-labeled data, and retrains
    ``model``.

    Returns:
        Tuple of the retrained model and the training metrics dict.

    Raises:
        Exception: if ``sample_strategy`` is neither 'sample' nor 'top_k'.
    """
    # distribution holds one heuristic score per unlabeled index
    distribution = heuristic.evaluate(unlabeled_dataset, sample_size)

    # never request more points than are available
    # NOTE(review): the ``- 1`` leaves one point permanently unsampled when
    # sample_size >= len(distribution) — confirm this is intended
    sample_size = min(sample_size, len(distribution) - 1)
    if sample_strategy == 'sample':
        new_points = torch.multinomial(distribution, sample_size)
    elif sample_strategy == 'top_k':
        new_points = sorted(
            range(len(distribution)),
            reverse=True,
            key=lambda ind: distribution[ind],
        )
    else:
        raise Exception(f'Unknown sampling strategy: {sample_strategy}')
    new_points = new_points[:sample_size]

    # new_points indexes into the unlabeled corpus; turn each index into an
    # (id, input) query the oracle understands
    query = [
        (
            unlabeled_dataset[ind]['id'],
            unlabeled_dataset[ind]['input'],
        ) for ind in new_points
    ]

    # remember which ids have now been labeled
    labeled_indexes.extend(
        ind for (ind, _) in query
    )

    # annotate the selected points and add them to the training set
    oracle_labels = [oracle.get_query(q) for q in query]
    train_data.extend(oracle_labels)

    # remove the now-labeled points from the unlabeled corpus
    for q in query:
        unlabeled_dataset.remove(q)

    weak_data = []
    if use_weak:
        # build a weakly-labeled set from the remaining unlabeled corpus to
        # augment the training set
        weak_data = build_weak_data(
            train_data,
            unlabeled_dataset,
            model,
            weight=weak_weight,
            function_types=weak_function,
            collator_type=weak_collator,
            contextual_word_embeddings=cached_text_field_embedders,
            spacy_feature_extractor=spacy_feature_extractor,
            vocab=vocab,
        )

    model, metrics = train(
        model=model,
        binary_class=unlabeled_dataset.binary_class,
        train_data=train_data + weak_data,
        valid_reader=valid_reader,
        vocab=vocab,
        optimizer_type=optimizer_type,
        optimizer_learning_rate=optimizer_learning_rate,
        optimizer_weight_decay=optimizer_weight_decay,
        batch_size=batch_size,
        patience=patience,
        num_epochs=num_epochs,
        device=device,
    )

    return model, metrics