示例#1
0
    def test_build_vocab(self):
        """Check that build_vocab clamps an oversized request and honours a valid one.

        The training file is faked by patching ``builtins.open``: entering the
        patched handle returns the handle itself, and iterating it yields the
        lines of ``txt_content``.
        """
        model = bpe.BPE()

        # Ask for more entries than the data can provide. Per the original
        # comment, the size falls back to the number of word types (9).
        with patch("builtins.open") as fake_open:
            handle = fake_open.return_value
            handle.__enter__ = fake_open
            handle.__iter__ = Mock(return_value=iter(txt_content))

            size = model.build_vocab(txt_path="no_exist_file.txt", vocab_size=20)
            assert size == 9
            assert model.max_bpe_len == len(model.eow_symbol) + 9

        # Ask for a reachable size: the vocabulary should hold exactly 12 entries.
        with patch("builtins.open") as fake_open:
            handle = fake_open.return_value
            handle.__enter__ = fake_open
            handle.__iter__ = Mock(return_value=iter(txt_content))

            size = model.build_vocab(txt_path="no_exist_file.txt", vocab_size=12)
            assert size == len(model.vocab) == 12
            assert model.max_bpe_len == 2
示例#2
0
    def test_vocab_init(self):
        """Verify the symbol tallies and merge-candidate tables produced by _init_vocab."""
        model = bpe.BPE()

        with patch("builtins.open") as fake_open:
            # Entering the patched handle yields the handle; iterating it
            # yields the lines of txt_content.
            handle = fake_open.return_value
            handle.__enter__ = fake_open
            handle.__iter__ = Mock(return_value=iter(txt_content))
            model._init_vocab(txt_path="no_exist_file.txt")

            # Tally every symbol, weighted by the frequency of the entry it
            # appears in.
            symbol_counts = Counter()
            for entry, count in model.current_train_data:
                for symbol in entry:
                    symbol_counts[symbol] += count

            assert symbol_counts[model.eow_symbol] == 11
            assert symbol_counts["3"] == 7
            assert len(symbol_counts) == 10
            # Multi-character sequences such as "12" or "123" must not be
            # present before any merge.
            assert "12" not in symbol_counts
            assert "123" not in symbol_counts

            pair = ("2", "3")
            assert len(model.merge_candidate_indices) == 17
            assert model.merge_candidate_indices[pair] == {0, 2, 6, 7}

            assert len(model.merge_candidate_freq) == 17
            assert model.merge_candidate_freq[pair] == 5
示例#3
0
    def test_bpe_merge(self):
        """Exercise merge_candidate_into_vocab with absent and present pairs."""
        model = bpe.BPE()

        with patch("builtins.open") as fake_open:
            handle = fake_open.return_value
            handle.__enter__ = fake_open
            handle.__iter__ = Mock(return_value=iter(txt_content))
            model._init_vocab(txt_path="no_exist_file.txt")

            def merge_and_check(pair, expected_size):
                # Merge one pair, then confirm the resulting vocabulary size.
                model.merge_candidate_into_vocab(merge_candidate=pair)
                assert len(model.vocab) == expected_size

            # ("3", "1") does not occur, so the vocabulary is unchanged.
            merge_and_check(("3", "1"), 10)

            # ("2", "3") occurs, so one new entry is added.
            merge_and_check(("2", "3"), 11)

            # ("3", "4") occurs; entry "3" should be removed from the vocab,
            # so the size stays the same.
            merge_and_check(("3", "4"), 11)

            # ("3", <eow>) no longer occurs.
            merge_and_check(("3", model.eow_symbol), 11)
示例#4
0
    def test_segment_file(self):
        """Check that segment_txt writes the same segmentation as segment_word.

        ``txt_content`` is written to a temporary file and used to build a
        12-entry vocabulary. The file is then segmented with ``segment_txt``,
        and the result is compared with calling ``segment_word`` on every
        word directly.
        """
        bpe_model = bpe.BPE()

        # TemporaryDirectory removes the directory even when an assertion
        # fails; a trailing shutil.rmtree would be skipped in that case.
        with tempfile.TemporaryDirectory() as tmp_dir:
            input_file = path.join(tmp_dir, "test.in")
            output_file = path.join(tmp_dir, "test1.out")

            with open(input_file, "w", encoding="utf-8") as writer:
                writer.write("\n".join(txt_content))
            bpe_model.build_vocab(txt_path=input_file, vocab_size=12)

            # Expected output: one line per input line, with each word's
            # segments space-separated.
            expected_lines = []
            for line in txt_content:
                segmented_words = (
                    " ".join(bpe_model.segment_word(word))
                    for word in line.strip().split()
                )
                expected_lines.append(" ".join(segmented_words))
            expected_output = "\n".join(expected_lines).strip()

            bpe_model.segment_txt(input_path=input_file, output_path=output_file)
            # Close the output handle deterministically instead of leaking it.
            with open(output_file, "r", encoding="utf-8") as reader:
                model_output = reader.read().strip()

            assert expected_output == model_output
示例#5
0
    def test_bpe_merge(self):
        """Exercise the pool-based merge_candidate_into_vocab API.

        The training file is faked by patching ``builtins.open`` so that
        iterating the opened handle yields the lines of ``txt_content``.
        Each merge call returns the vocabulary size, which is checked after
        merging absent and present candidate pairs.
        """
        bpe_model = bpe.BPE()

        with patch("builtins.open") as mock_open:
            mock_open.return_value.__enter__ = mock_open
            mock_open.return_value.__iter__ = Mock(return_value=iter(txt_content))
            bpe_model._init_vocab(txt_path="no_exist_file.txt")
            num_cpus = 3

            # Manage the pool with `with` so its worker processes are
            # terminated even if an assertion fails; an unclosed Pool leaks
            # its workers past the end of the test.
            with Pool(processes=num_cpus) as pool:
                # Merging a candidate that does not exist: size stays at 10.
                vocab_size = bpe_model.merge_candidate_into_vocab(
                    candidate=("3", "1"), num_cpus=num_cpus, pool=pool
                )
                assert vocab_size == 10

                # Merging a candidate that does exist adds one entry.
                vocab_size = bpe_model.merge_candidate_into_vocab(
                    candidate=("2", "3"), num_cpus=num_cpus, pool=pool
                )
                assert vocab_size == 11

                # Merging a candidate that does exist; entry "3" should be
                # removed from the vocab, so the size stays the same.
                vocab_size = bpe_model.merge_candidate_into_vocab(
                    candidate=("3", "4"), num_cpus=num_cpus, pool=pool
                )
                assert vocab_size == 11

                # Merging a candidate that no longer exists.
                vocab_size = bpe_model.merge_candidate_into_vocab(
                    candidate=("3", bpe_model.eow_symbol),
                    num_cpus=num_cpus,
                    pool=pool,
                )
                assert vocab_size == 11
示例#6
0
    def test_best_candidate(self):
        """After _init_vocab, get_best_candidate should report ("1", "2")."""
        model = bpe.BPE()

        with patch("builtins.open") as fake_open:
            # The patched handle enters as itself and iterates over txt_content.
            handle = fake_open.return_value
            handle.__enter__ = fake_open
            handle.__iter__ = Mock(return_value=iter(txt_content))
            model._init_vocab(txt_path="no_exist_file.txt")

            best = model.get_best_candidate()
            assert best == ("1", "2")
示例#7
0
    def test_segment_word(self):
        """Segment a known word and one with an unknown character (12-entry vocab)."""
        model = bpe.BPE()

        with patch("builtins.open") as fake_open:
            handle = fake_open.return_value
            handle.__enter__ = fake_open
            handle.__iter__ = Mock(return_value=iter(txt_content))

            model.build_vocab(txt_path="no_exist_file.txt", vocab_size=12)

            known = model.segment_word("1234")
            assert known == ["12", "34", model.eow_symbol]

            # "6" is an unknown character; it is expected as its own segment.
            with_unknown = model.segment_word("12634")
            assert with_unknown == ["12", "6", "34", model.eow_symbol]
示例#8
0
    def test_vocab_init(self):
        """Check symbol frequencies derived from ``current_train_data`` after ``_init_vocab``."""
        bpe_model = bpe.BPE()

        # Fake the training file: the patched open() returns a mock that
        # enters as itself and yields the lines of txt_content when iterated.
        with patch("builtins.open") as mock_open:
            mock_open.return_value.__enter__ = mock_open
            mock_open.return_value.__iter__ = Mock(
                return_value=iter(txt_content))
            bpe_model._init_vocab(txt_path="no_exist_file.txt")

            # Tally each symbol weighted by its entry's frequency. In this
            # version each entry is split on whitespace into its symbols.
            vocab_items = Counter()
            for vocab_entry, freq in bpe_model.current_train_data:
                items = vocab_entry.split()
                for item in items:
                    vocab_items[item] += freq

            assert vocab_items[bpe_model.eow_symbol] == 11
            assert vocab_items["3"] == 7
            assert len(vocab_items) == 10
            # No merged multi-character sequences exist right after initialisation.
            assert "12" not in vocab_items
            assert "123" not in vocab_items