# Example No. 1
# (listing metadata: 0)
 def test_creation_saving_restoring(self):
     """Round-trip an ``MTEncDecModel`` through ``save_to`` / ``restore_from``.

     The model is saved into one temporary directory, and the resulting
     ``.nemo`` file is copied into a second temporary directory. The first
     directory is then removed. The test checks that the save directory and
     the original file are gone while the copy survives. It then checks that
     restoring from the copy gives a model with the same ``num_weights``.
     """
     cfg = AAYNBaseConfig()
     # Relative path; presumably resolved against the repo root (the test's
     # working directory) -- confirm.
     tokenizer_file = 'tests/.data/yttm.4096.en-de.model'
     for tokenizer_cfg in (cfg.encoder_tokenizer, cfg.decoder_tokenizer):
         tokenizer_cfg.tokenizer_name = 'yttm'
         tokenizer_cfg.tokenizer_model = tokenizer_file
     # Every dataset config is cleared; only construction and
     # (de)serialization are exercised here.
     cfg.train_ds = cfg.validation_ds = cfg.test_ds = None

     model = MTEncDecModel(cfg=cfg)
     assert isinstance(model, MTEncDecModel)

     archive_name = f"{model.__class__.__name__}.nemo"
     with tempfile.TemporaryDirectory() as restore_folder:
         with tempfile.TemporaryDirectory() as save_folder:
             # Save into a directory that is deleted when this block exits...
             model_save_path = os.path.join(save_folder, archive_name)
             model.save_to(save_path=model_save_path)
             # ...and keep a copy in a directory that outlives it.
             model_restore_path = os.path.join(restore_folder, archive_name)
             shutil.copy(model_save_path, model_restore_path)

         # `save_folder` is still bound to its path string, but the directory
         # (and the archive inside it) must no longer exist.
         assert save_folder is not None and not os.path.exists(save_folder)
         assert not os.path.exists(model_save_path)
         assert os.path.exists(model_restore_path)

         # Restoring from the surviving copy must reproduce the weight count.
         model_copy = type(model).restore_from(restore_path=model_restore_path)
         assert model.num_weights == model_copy.num_weights
# Example No. 2
# (listing metadata: 0)
def get_cfg():
    """Build an ``AAYNBaseConfig`` for tests that need no datasets.

    The encoder and decoder tokenizers are both set to the ``'yttm'``
    tokenizer, backed by the same test model file. The ``train_ds``,
    ``validation_ds`` and ``test_ds`` entries are set to ``None``.

    Returns:
        The configured ``AAYNBaseConfig`` instance.
    """
    # Relative path; presumably resolved against the repo root (the test's
    # working directory) -- confirm.
    shared_tokenizer_model = 'tests/.data/yttm.4096.en-de.model'

    config = AAYNBaseConfig()
    for tokenizer_section in (config.encoder_tokenizer, config.decoder_tokenizer):
        tokenizer_section.tokenizer_name = 'yttm'
        tokenizer_section.tokenizer_model = shared_tokenizer_model
    config.train_ds = config.validation_ds = config.test_ds = None
    return config