Example #1
def test_get_new_tokens_learns_cased_tokens_when_cased_arg_is_True(
        tokenizer: PreTrainedTokenizerFast, training_corpus: Corpus):
    augmentor = VocabAugmentor(tokenizer,
                               cased=True,
                               target_vocab_size=len(tokenizer) + 3)
    new_tokens = augmentor.get_new_tokens(training_corpus)
    assert any(c.isupper() for c in "".join(new_tokens))
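The `tokenizer` and `training_corpus` fixtures are not shown in these examples; a minimal sketch of what they could look like (the checkpoint name and corpus text are assumptions):

import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerFast


@pytest.fixture
def tokenizer() -> PreTrainedTokenizerFast:
    # Any fast tokenizer works here; the checkpoint is an assumption.
    return AutoTokenizer.from_pretrained("bert-base-uncased")


@pytest.fixture
def training_corpus() -> list:
    # A few documents containing words unlikely to be in the base vocabulary.
    return [
        "Tokenizer augmentation adds domain-specific vocabulary.",
        "Domain-specific vocabulary improves downstream tokenization.",
    ]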
Example #2
def test_get_new_tokens_return_correct_number_of_new_tokens(
        tokenizer: PreTrainedTokenizerFast, training_corpus: Corpus,
        n_extra_tokens):
    augmentor = VocabAugmentor(tokenizer,
                               cased=False,
                               target_vocab_size=len(tokenizer) +
                               n_extra_tokens)
    new_tokens = augmentor.get_new_tokens(training_corpus)
    assert len(new_tokens) <= n_extra_tokens
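`n_extra_tokens` looks like a parametrized fixture; a possible sketch (the parameter values are assumptions):

import pytest


@pytest.fixture(params=[1, 5, 50])
def n_extra_tokens(request) -> int:
    # Upper bound on how many new tokens the augmentor is allowed to learn.
    return request.param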
Example #3
def test__get_training_files_tmpfile_returned_properly_saves_text(
    named_tmpfile: IO[str], ):
    corpus: Corpus = [
        "This is a document.", "The document following the first."
    ]
    train_files = VocabAugmentor._get_training_files(corpus, named_tmpfile)
    assert Path(train_files[0]).read_text() == "".join(corpus)
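`named_tmpfile` is typed as IO[str] and must expose a usable .name, which suggests something like tempfile.NamedTemporaryFile; a sketch under that assumption:

import tempfile
from typing import IO, Iterator

import pytest


@pytest.fixture
def named_tmpfile() -> Iterator[IO[str]]:
    # A named, text-mode temporary file that _get_training_files can write to.
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".txt") as f:
        yield f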
Example #4
def test_VocabAugmentor_error_raised_when_target_vocab_size_is_less_than_tokenizer_vocab_size(
    vocab_size_multiplier, ):
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    target_vocab_size = int(vocab_size_multiplier * len(tokenizer))
    with pytest.raises(ValueError):
        VocabAugmentor(tokenizer,
                       cased=True,
                       target_vocab_size=target_vocab_size)
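`vocab_size_multiplier` only needs to produce targets below the current vocabulary size; a sketch (the values are assumptions):

import pytest


@pytest.fixture(params=[0.25, 0.5, 0.99])
def vocab_size_multiplier(request) -> float:
    # Multipliers strictly below 1.0 guarantee target_vocab_size < len(tokenizer).
    return request.param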
Example #5
def test__get_training_files_return_tmpfile_when_corpus_is_of_type_Corpus(
    named_tmpfile: IO[str], ):
    corpus: Corpus = [
        "This is a document.", "The document following the first."
    ]
    train_files = VocabAugmentor._get_training_files(corpus, named_tmpfile)
    assert len(train_files) == 1
    assert train_files[0] == named_tmpfile.name
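Examples #3 and #5 pass a plain list of document strings where a `Corpus` is expected, so the alias is presumably something like the following (an assumption; the real project may define it differently):

from typing import List

Corpus = List[str]  # assumed alias: one string per document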
Example #6
def test__get_training_files_correctness_single_file(tmp_path,
                                                     input_corpus_as_str,
                                                     named_tmpfile: IO[str]):
    corpus = tmp_path / "corpus.txt"
    corpus.write_text("")  # Creates the file

    # corpus is already a Path; convert to str only when testing the str input.
    train_files = VocabAugmentor._get_training_files(
        str(corpus) if input_corpus_as_str else corpus, named_tmpfile)

    assert len(train_files) == 1
    assert isinstance(train_files[0], str)
Example #7
def test__get_training_files_correctness_single_directory(
        tmp_path, input_corpus_as_str, named_tmpfile: IO[str]):
    n_files = 3
    corpus_dir = tmp_path
    # Create multiple text files
    for i in range(n_files):
        (corpus_dir / f"corpus{i}.txt").write_text("")

    # corpus_dir is already a Path; convert to str only when testing the str input.
    train_files = VocabAugmentor._get_training_files(
        str(corpus_dir) if input_corpus_as_str else corpus_dir, named_tmpfile)

    assert len(train_files) == n_files
    assert all(isinstance(file, str) for file in train_files)
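`input_corpus_as_str` toggles whether the corpus location is handed over as a str or as a pathlib.Path; a sketch:

import pytest


@pytest.fixture(params=[True, False])
def input_corpus_as_str(request) -> bool:
    # True -> pass the path as str, False -> pass it as pathlib.Path.
    return request.param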
Example #8
def test__get_training_files_raise_error_on_nonexistent_file(
        named_tmpfile: IO[str]):
    with pytest.raises(FileNotFoundError):
        VocabAugmentor._get_training_files("nonexistent_file.txt",
                                           named_tmpfile)
Example #9
def test__remove_overlapping_tokens_correctness(augmentor: VocabAugmentor):
    # "apple" and "day" are assumed to already be in the tokenizer vocabulary,
    # so only the unseen "a_new_token" should survive the filtering.
    c = Counter(["apple", "a_new_token", "day", "a_new_token"])
    output = augmentor._remove_overlapping_tokens(c)
    assert set(output) == {"a_new_token"}
Example #10
@pytest.fixture
def augmentor(tokenizer) -> VocabAugmentor:
    return VocabAugmentor(tokenizer,
                          cased=False,
                          target_vocab_size=int(1.1 * len(tokenizer)))
Example #11
def test_get_new_tokens_does_not_return_existing_tokens(
        augmentor: VocabAugmentor):
    training_corpus = ["An apple a day keeps the doctors away"]
    new_tokens = augmentor.get_new_tokens(training_corpus)
    # Strict subset: every new token must come from the corpus, while words the
    # tokenizer already knows should be left out of the result.
    assert set(new_tokens) < set(training_corpus[0].split())
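Beyond the tests, the learned tokens would typically be registered with the tokenizer afterwards; a hedged usage sketch (the corpus text and token budget are made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
augmentor = VocabAugmentor(tokenizer,
                           cased=False,
                           target_vocab_size=len(tokenizer) + 100)
new_tokens = augmentor.get_new_tokens(["Some domain-specific training text."])
tokenizer.add_tokens(new_tokens)  # standard Hugging Face API
# model.resize_token_embeddings(len(tokenizer))  # if a model consumes this tokenizer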