Example #1
    def test_set_from_file_reads_padded_files(self):
        # pylint: disable=protected-access
        vocab_filename = self.TEST_DIR / 'vocab_file'
        with codecs.open(vocab_filename, 'w', 'utf-8') as vocab_file:
            vocab_file.write('<S>\n')
            vocab_file.write('</S>\n')
            vocab_file.write('<UNK>\n')
            vocab_file.write('a\n')
            vocab_file.write('tricky\x0bchar\n')
            vocab_file.write('word\n')
            vocab_file.write('another\n')

        vocab = Vocabulary()
        vocab.set_from_file(vocab_filename, is_padded=True, oov_token="<UNK>")

        assert vocab._oov_token == DEFAULT_OOV_TOKEN
        assert vocab.get_token_index("random string") == 3
        assert vocab.get_token_index("<S>") == 1
        assert vocab.get_token_index("</S>") == 2
        assert vocab.get_token_index(DEFAULT_OOV_TOKEN) == 3
        assert vocab.get_token_index("a") == 4
        assert vocab.get_token_index("tricky\x0bchar") == 5
        assert vocab.get_token_index("word") == 6
        assert vocab.get_token_index("another") == 7
        assert vocab.get_token_from_index(0) == vocab._padding_token
        assert vocab.get_token_from_index(1) == "<S>"
        assert vocab.get_token_from_index(2) == "</S>"
        assert vocab.get_token_from_index(3) == DEFAULT_OOV_TOKEN
        assert vocab.get_token_from_index(4) == "a"
        assert vocab.get_token_from_index(5) == "tricky\x0bchar"
        assert vocab.get_token_from_index(6) == "word"
        assert vocab.get_token_from_index(7) == "another"
Example #2
    def test_set_from_file_reads_padded_files(self):

        vocab_filename = self.TEST_DIR / "vocab_file"
        with codecs.open(vocab_filename, "w", "utf-8") as vocab_file:
            vocab_file.write("<S>\n")
            vocab_file.write("</S>\n")
            vocab_file.write("<UNK>\n")
            vocab_file.write("a\n")
            vocab_file.write("tricky\x0bchar\n")
            vocab_file.write("word\n")
            vocab_file.write("another\n")

        vocab = Vocabulary()
        vocab.set_from_file(vocab_filename, is_padded=True, oov_token="<UNK>")

        assert vocab._oov_token == DEFAULT_OOV_TOKEN
        assert vocab.get_token_index("random string") == 3
        assert vocab.get_token_index("<S>") == 1
        assert vocab.get_token_index("</S>") == 2
        assert vocab.get_token_index(DEFAULT_OOV_TOKEN) == 3
        assert vocab.get_token_index("a") == 4
        assert vocab.get_token_index("tricky\x0bchar") == 5
        assert vocab.get_token_index("word") == 6
        assert vocab.get_token_index("another") == 7
        assert vocab.get_token_from_index(0) == vocab._padding_token
        assert vocab.get_token_from_index(1) == "<S>"
        assert vocab.get_token_from_index(2) == "</S>"
        assert vocab.get_token_from_index(3) == DEFAULT_OOV_TOKEN
        assert vocab.get_token_from_index(4) == "a"
        assert vocab.get_token_from_index(5) == "tricky\x0bchar"
        assert vocab.get_token_from_index(6) == "word"
        assert vocab.get_token_from_index(7) == "another"
Example #3
    def test_set_from_file_reads_padded_files(self):
        # pylint: disable=protected-access
        vocab_filename = self.TEST_DIR + 'vocab_file'
        with codecs.open(vocab_filename, 'w', 'utf-8') as vocab_file:
            vocab_file.write('<S>\n')
            vocab_file.write('</S>\n')
            vocab_file.write('<UNK>\n')
            vocab_file.write('a\n')
            vocab_file.write('word\n')
            vocab_file.write('another\n')

        vocab = Vocabulary()
        vocab.set_from_file(vocab_filename, is_padded=True, oov_token="<UNK>")

        assert vocab._oov_token == DEFAULT_OOV_TOKEN
        assert vocab.get_token_index("random string") == 3
        assert vocab.get_token_index("<S>") == 1
        assert vocab.get_token_index("</S>") == 2
        assert vocab.get_token_index(DEFAULT_OOV_TOKEN) == 3
        assert vocab.get_token_index("a") == 4
        assert vocab.get_token_index("word") == 5
        assert vocab.get_token_index("another") == 6
        assert vocab.get_token_from_index(0) == vocab._padding_token
        assert vocab.get_token_from_index(1) == "<S>"
        assert vocab.get_token_from_index(2) == "</S>"
        assert vocab.get_token_from_index(3) == DEFAULT_OOV_TOKEN
        assert vocab.get_token_from_index(4) == "a"
        assert vocab.get_token_from_index(5) == "word"
        assert vocab.get_token_from_index(6) == "another"
Example #4
# Imports assumed by this snippet; namespace_match is defined in
# allennlp.data.vocabulary in older AllenNLP releases (allennlp.common.util in newer ones).
import codecs
import os

from allennlp.data.vocabulary import Vocabulary, namespace_match


def load_vocab_from_directory(directory: str,
                              padding_token: str = "[PAD]",
                              oov_token: str = "[UNK]") -> Vocabulary:
    """
    Load pre-trained vocabulary form a directory (since the original method does not work --> OOV problem)
    
    Args:
        directory (str)
        padding_token (str): default OOV token symbol ("[PAD]" our case, since we are using BERT)
        oov_token (str): default OOV token symbol ("[UNK]" our case, since we are using BERT)

    Returns:
        Vocabulary
    """
    NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'

    print("Loading token dictionary from", directory)
    with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r',
                     'utf-8') as namespace_file:
        non_padded_namespaces = [
            namespace_str.strip() for namespace_str in namespace_file
        ]

    vocab = Vocabulary(non_padded_namespaces=non_padded_namespaces)

    # Check every file in the directory.
    for namespace_filename in os.listdir(directory):
        if namespace_filename == NAMESPACE_PADDING_FILE:
            continue
        if namespace_filename.startswith("."):
            continue
        namespace = namespace_filename.replace('.txt', '')
        if any(
                namespace_match(pattern, namespace)
                for pattern in non_padded_namespaces):
            is_padded = False
        else:
            is_padded = True
        filename = os.path.join(directory, namespace_filename)
        vocab.set_from_file(filename,
                            is_padded,
                            oov_token=oov_token,
                            namespace=namespace)
        vocab._padding_token = padding_token

    return vocab
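
A minimal usage sketch for the helper above, assuming the directory was produced earlier with Vocabulary.save_to_files; the path and the "tokens" namespace below are hypothetical and only illustrate the call:

# Usage sketch (assumptions: "./saved_vocab" exists and contains a "tokens" namespace).
vocab = load_vocab_from_directory("./saved_vocab",
                                  padding_token="[PAD]",
                                  oov_token="[UNK]")
print(vocab.get_vocab_size("tokens"))                            # number of entries loaded
print(vocab.get_token_index(vocab._oov_token, namespace="tokens"))  # index of the OOV token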
Example #5
    def test_set_from_file_reads_non_padded_files(self):

        vocab_filename = self.TEST_DIR / "vocab_file"
        with codecs.open(vocab_filename, "w", "utf-8") as vocab_file:
            vocab_file.write("B-PERS\n")
            vocab_file.write("I-PERS\n")
            vocab_file.write("O\n")
            vocab_file.write("B-ORG\n")
            vocab_file.write("I-ORG\n")

        vocab = Vocabulary()
        vocab.set_from_file(vocab_filename, is_padded=False, namespace="tags")
        assert vocab.get_token_index("B-PERS", namespace="tags") == 0
        assert vocab.get_token_index("I-PERS", namespace="tags") == 1
        assert vocab.get_token_index("O", namespace="tags") == 2
        assert vocab.get_token_index("B-ORG", namespace="tags") == 3
        assert vocab.get_token_index("I-ORG", namespace="tags") == 4
        assert vocab.get_token_from_index(0, namespace="tags") == "B-PERS"
        assert vocab.get_token_from_index(1, namespace="tags") == "I-PERS"
        assert vocab.get_token_from_index(2, namespace="tags") == "O"
        assert vocab.get_token_from_index(3, namespace="tags") == "B-ORG"
        assert vocab.get_token_from_index(4, namespace="tags") == "I-ORG"
Example #6
    def test_set_from_file_reads_non_padded_files(self):
        # pylint: disable=protected-access
        vocab_filename = self.TEST_DIR / 'vocab_file'
        with codecs.open(vocab_filename, 'w', 'utf-8') as vocab_file:
            vocab_file.write('B-PERS\n')
            vocab_file.write('I-PERS\n')
            vocab_file.write('O\n')
            vocab_file.write('B-ORG\n')
            vocab_file.write('I-ORG\n')

        vocab = Vocabulary()
        vocab.set_from_file(vocab_filename, is_padded=False, namespace='tags')
        assert vocab.get_token_index("B-PERS", namespace='tags') == 0
        assert vocab.get_token_index("I-PERS", namespace='tags') == 1
        assert vocab.get_token_index("O", namespace='tags') == 2
        assert vocab.get_token_index("B-ORG", namespace='tags') == 3
        assert vocab.get_token_index("I-ORG", namespace='tags') == 4
        assert vocab.get_token_from_index(0, namespace='tags') == "B-PERS"
        assert vocab.get_token_from_index(1, namespace='tags') == "I-PERS"
        assert vocab.get_token_from_index(2, namespace='tags') == "O"
        assert vocab.get_token_from_index(3, namespace='tags') == "B-ORG"
        assert vocab.get_token_from_index(4, namespace='tags') == "I-ORG"
Example #7
    def test_set_from_file_reads_non_padded_files(self):
        # pylint: disable=protected-access
        vocab_filename = self.TEST_DIR + 'vocab_file'
        with codecs.open(vocab_filename, 'w', 'utf-8') as vocab_file:
            vocab_file.write('B-PERS\n')
            vocab_file.write('I-PERS\n')
            vocab_file.write('O\n')
            vocab_file.write('B-ORG\n')
            vocab_file.write('I-ORG\n')

        vocab = Vocabulary()
        vocab.set_from_file(vocab_filename, is_padded=False, namespace='tags')
        assert vocab.get_token_index("B-PERS", namespace='tags') == 0
        assert vocab.get_token_index("I-PERS", namespace='tags') == 1
        assert vocab.get_token_index("O", namespace='tags') == 2
        assert vocab.get_token_index("B-ORG", namespace='tags') == 3
        assert vocab.get_token_index("I-ORG", namespace='tags') == 4
        assert vocab.get_token_from_index(0, namespace='tags') == "B-PERS"
        assert vocab.get_token_from_index(1, namespace='tags') == "I-PERS"
        assert vocab.get_token_from_index(2, namespace='tags') == "O"
        assert vocab.get_token_from_index(3, namespace='tags') == "B-ORG"
        assert vocab.get_token_from_index(4, namespace='tags') == "I-ORG"
Example #8
# coding=utf-8
# @Author: 莫冉
# @Date: 2020-08-06

from allennlp.data.vocabulary import Vocabulary

vocab_file = "../data/base_bert/vocab.txt"
save_path = "../../../vocab_path"

vocab = Vocabulary(padding_token="[PAD]", oov_token="[UNK]")

vocab.set_from_file(vocab_file, is_padded=True, oov_token="[UNK]")

vocab.save_to_files(save_path)

print(vocab.get_token_index(vocab._oov_token))
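
A small hedged follow-up to the script above, querying the freshly built vocabulary; the wordpiece "hello" is an assumption about the contents of vocab.txt and only serves as an illustration:

print(vocab.get_token_index("hello"))           # index of a wordpiece assumed to be in vocab.txt
print(vocab.get_token_index("definitely-oov"))  # unseen tokens fall back to the [UNK] index
print(vocab.get_vocab_size())                   # total number of entries in the default namespace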