import os
import tempfile

from omegaconf import DictConfig

from nemo.collections.asr.models import EncDecRNNTBPEModel
from nemo.collections.common import tokenizers


# The three methods below belong to a pytest test class for EncDecRNNTBPEModel;
# `asr_model` (a small RNNT-BPE model) and `test_data_dir` (the test-data root)
# are pytest fixtures.

    def test_save_restore_artifact_agg(self, asr_model, test_data_dir):
        tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers",
                                     "an4_spe_128")
        tok_en = {"dir": tokenizer_dir, "type": "wpe"}
        # The tokenizer below is really an English tokenizer, but we pretend it is Spanish.
        tok_es = {"dir": tokenizer_dir, "type": "wpe"}
        tcfg = DictConfig({
            "type": "agg",
            "langs": {
                "en": tok_en,
                "es": tok_es
            }
        })
        with tempfile.TemporaryDirectory() as tmpdir:
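            # For the aggregate tokenizer, the entire DictConfig is passed via
            # `new_tokenizer_dir` (not a directory path), with new_tokenizer_type="agg".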
            asr_model.change_vocabulary(new_tokenizer_dir=tcfg,
                                        new_tokenizer_type="agg")

            save_path = os.path.join(tmpdir, "rnnt_agg.nemo")
            asr_model.train()
            asr_model.save_to(save_path)

            new_model = EncDecRNNTBPEModel.restore_from(save_path)
            assert isinstance(new_model, type(asr_model))
            assert isinstance(new_model.tokenizer,
                              tokenizers.AggregateTokenizer)

            # vocab size should be double: two copies of the same tokenizer are aggregated
            assert new_model.tokenizer.tokenizer.vocab_size == 254
            assert len(new_model.tokenizer.tokenizer.get_vocab()) == 254
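
Once restored, the aggregate tokenizer dispatches each call to the matching
per-language sub-tokenizer. A minimal usage sketch (the lang_id argument
follows NeMo's AggregateTokenizer interface; verify the exact signature
against your NeMo version):

            # e.g., after restore_from() in the block above:
            ids_en = new_model.tokenizer.text_to_ids("hello world", "en")
            ids_es = new_model.tokenizer.text_to_ids("hola mundo", "es")
            # Both use the same underlying SPE model here, but each lang id
            # selects its own sub-tokenizer and id-offset range.
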
    def test_save_restore_artifact(self, asr_model):
        asr_model.train()

        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, 'rnnt_bpe.nemo')
            asr_model.save_to(path)

            new_model = EncDecRNNTBPEModel.restore_from(path)
            assert isinstance(new_model, type(asr_model))
            assert new_model.vocab_path == 'vocab.txt'

            assert len(new_model.tokenizer.tokenizer.get_vocab()) == 128
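
The `.nemo` file written by save_to() is a self-contained archive carrying the
model config and tokenizer artifacts, which is why `vocab_path` survives the
round trip. To inspect a checkpoint without building the model, restore_from
can return just the config (the `return_config` flag is part of NeMo's ModelPT
API; treat its availability as an assumption for your NeMo version):

            # Sketch: peek at the stored tokenizer config without instantiating
            cfg = EncDecRNNTBPEModel.restore_from(path, return_config=True)
            print(cfg.tokenizer)  # tokenizer section of the config stored in the archive
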
    def test_save_restore_artifact_spe(self, asr_model, test_data_dir):
        asr_model.train()

        with tempfile.TemporaryDirectory() as tmpdir:
            tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers",
                                         "an4_spe_128")
            asr_model.change_vocabulary(new_tokenizer_dir=tokenizer_dir,
                                        new_tokenizer_type='bpe')
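            # Here, unlike the "agg" case above, new_tokenizer_dir is an actual
            # directory path and the type is plain 'bpe'.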

            save_path = os.path.join(tmpdir, 'rnnt_spe.nemo')
            asr_model.train()
            asr_model.save_to(save_path)

            new_model = EncDecRNNTBPEModel.restore_from(save_path)
            assert isinstance(new_model, type(asr_model))
            assert isinstance(new_model.tokenizer,
                              tokenizers.SentencePieceTokenizer)
            assert new_model.model_path.endswith('_tokenizer.model')
            assert new_model.vocab_path.endswith('_vocab.txt')
            assert new_model.spe_vocab_path.endswith('_tokenizer.vocab')
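
All three methods assume an `asr_model` fixture that constructs a small
EncDecRNNTBPEModel and a `test_data_dir` fixture pointing at the NeMo
test-data root. A minimal sketch of the wiring, assuming a hypothetical helper
`make_rnnt_bpe_test_config` that assembles the encoder/decoder/joint/tokenizer
config (the real fixture spells that config out in full):

import pytest

@pytest.fixture()
def asr_model(test_data_dir):
    # make_rnnt_bpe_test_config is hypothetical: it builds a small model config
    # whose `tokenizer` section points at the an4_spe_128 directory.
    cfg = make_rnnt_bpe_test_config(test_data_dir)
    return EncDecRNNTBPEModel(cfg=cfg)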