def main() -> None:
    """Build a vocabulary file from tokenized corpus files.

    Command-line arguments:
        --corpus-data: glob pattern selecting the tokenized corpus files.
        --langs: comma-separated list of the pre-trained model languages.
        --output: path the vocabulary file is written to.

    Every file matched by ``--corpus-data`` is added to one ``Dictionary``,
    which is finalized with ``padding_factor=0``, handed to ``pad_dict``
    together with ``len(langs) + 1`` and finally saved to ``--output``.

    Exits through ``parser.error`` when the pattern matches no file, instead
    of silently writing a vocabulary that contains no corpus tokens.
    """
    parser = argparse.ArgumentParser(
        description="Build vocabulary from corpus data.")
    parser.add_argument(
        "--corpus-data",
        type=str,
        required=True,
        help=
        "The path pattern (glob) to all tokenized corpus files (train, test, val)."
    )
    parser.add_argument("--langs",
                        type=str,
                        required=True,
                        help="The pre-trained model languages.")
    parser.add_argument("--output",
                        type=str,
                        required=True,
                        help="The vocabulary file.")
    args = parser.parse_args()

    # Resolve the pattern up front: glob() yields matches in arbitrary order,
    # so sort them to process the corpus files in a reproducible order.
    data_paths = sorted(glob(args.corpus_data))
    if not data_paths:
        # A typo in the pattern would otherwise go unnoticed and produce a
        # vocabulary without any corpus tokens.
        parser.error("--corpus-data pattern matched no files: {!r}".format(
            args.corpus_data))

    langs = args.langs.split(",")
    ft_dict = Dictionary()
    for data_path in data_paths:
        # NOTE(review): the trailing 4 is presumably the worker count
        # expected by add_file_to_dictionary -- confirm against its signature.
        Dictionary.add_file_to_dictionary(data_path, ft_dict, tokenize_line, 4)
    ft_dict.finalize(padding_factor=0)
    pad_dict(ft_dict, len(langs) + 1)
    ft_dict.save(args.output)
Пример #2
0
    def test_finalize(self):
        """Check the ids produced before and after ``Dictionary.finalize()``.

        Builds a dictionary from four sample lines, compares the encoded ids
        with the pre-finalize reference, finalizes and compares with the
        post-finalize reference, then saves the dictionary to a temporary
        file, loads it back and verifies that the ids are unchanged.
        """
        sentences = ["A B C D", "B C D", "C D", "D"]

        # Reference ids before finalize(): A, B, C, D map to 4, 5, 6, 7.
        expected_before = [
            torch.IntTensor(row) for row in (
                [4, 5, 6, 7, 2],
                [5, 6, 7, 2],
                [6, 7, 2],
                [7, 2],
            )
        ]
        # Reference ids after finalize(): D, C, B, A map to 4, 5, 6, 7.
        expected_after = [
            torch.IntTensor(row) for row in (
                [7, 6, 5, 4, 2],
                [6, 5, 4, 2],
                [5, 4, 2],
                [4, 2],
            )
        ]

        def encode_all(dictionary):
            # Encode every sample line without adding new symbols.
            return [
                dictionary.encode_line(sentence, add_if_not_exist=False)
                for sentence in sentences
            ]

        def check_same(actual, expected):
            # Pairwise comparison: identical shape and no differing element.
            for got, want in zip(actual, expected):
                self.assertEqual(got.size(), want.size())
                self.assertEqual(0, (got != want).sum().item())

        # Populate the dictionary from the sample lines.
        vocab = Dictionary()
        for sentence in sentences:
            vocab.encode_line(sentence, add_if_not_exist=True)

        check_same(encode_all(vocab), expected_before)

        # Ids after finalizing the dictionary.
        vocab.finalize()
        ids_finalized = encode_all(vocab)
        check_same(ids_finalized, expected_after)

        # Round-trip through a file on disk.
        with tempfile.NamedTemporaryFile(mode="w") as handle:
            vocab.save(handle.name)
            reloaded = Dictionary.load(handle.name)
            ids_reloaded = encode_all(reloaded)
            check_same(ids_reloaded, expected_after)
            check_same(ids_finalized, ids_reloaded)
Пример #3
0
def write_dictionary(model_dir, lst):
    '''Write out a dictionary in fair seq format.

    Args:
        model_dir: directory in which ``dict.txt`` is created.
        lst: iterable of token sequences; every token is added to the
            dictionary through ``add_symbol``.

    Prints the number of types held by the dictionary before saving it.
    '''
    # Function-scope import keeps this block self-contained.
    import os

    joined_dict = Dictionary()
    for toks in lst:
        for t in toks:
            joined_dict.add_symbol(t)

    print('| dictionary: {} types'.format(len(joined_dict)))

    # os.path.join instead of string concatenation: an empty model_dir now
    # resolves to the working directory rather than the filesystem root.
    dict_path = os.path.join(model_dir, 'dict.txt')
    # Explicit UTF-8 so non-ASCII tokens are written the same way regardless
    # of the locale's default encoding.
    with open(dict_path, 'w', encoding='utf-8') as fd:
        joined_dict.save(fd)
Пример #4
0
    def test_finalize(self):
        """Check the ids produced before and after ``Dictionary.finalize()``.

        Tokenizes four sample lines into a fresh dictionary, compares the ids
        with the pre-finalize reference, finalizes and compares with the
        post-finalize reference, then saves the dictionary to a temporary
        file, loads it back and verifies that the ids are unchanged.
        """
        samples = ['A B C D', 'B C D', 'C D', 'D']

        # Reference ids before finalize(): A, B, C, D map to 4, 5, 6, 7.
        before_rows = [
            [4, 5, 6, 7, 2],
            [5, 6, 7, 2],
            [6, 7, 2],
            [7, 2],
        ]
        # Reference ids after finalize(): D, C, B, A map to 4, 5, 6, 7.
        after_rows = [
            [7, 6, 5, 4, 2],
            [6, 5, 4, 2],
            [5, 4, 2],
            [4, 2],
        ]
        want_before = [torch.IntTensor(row) for row in before_rows]
        want_after = [torch.IntTensor(row) for row in after_rows]

        def tokenize_all(dictionary):
            # Tokenize every sample line without adding new symbols.
            return [
                Tokenizer.tokenize(sample, dictionary, add_if_not_exist=False)
                for sample in samples
            ]

        def expect_equal(actual, expected):
            # Pairwise comparison: identical shape and no differing element.
            for got, want in zip(actual, expected):
                self.assertEqual(got.size(), want.size())
                self.assertEqual(0, (got != want).sum().item())

        # Populate the dictionary from the sample lines.
        vocab = Dictionary()
        for sample in samples:
            Tokenizer.tokenize(sample, vocab, add_if_not_exist=True)

        expect_equal(tokenize_all(vocab), want_before)

        # Ids after finalizing the dictionary.
        vocab.finalize()
        after_finalize = tokenize_all(vocab)
        expect_equal(after_finalize, want_after)

        # Round-trip through a file on disk.
        with tempfile.NamedTemporaryFile(mode='w') as tmp_file:
            vocab.save(tmp_file.name)
            restored = Dictionary.load(tmp_file.name)
            after_reload = tokenize_all(restored)
            expect_equal(after_reload, want_after)
            expect_equal(after_finalize, after_reload)