コード例 #1
0
ファイル: test_nmt_model.py プロジェクト: xaiguy/NeMo
 def test_creation_saving_restoring(self):
     """Save an MTEncDecModel to a .nemo archive and restore it from a copy.

     The archive is written into one temporary directory and copied into a
     second one. The first directory is then closed, so the restore can only
     read the copied archive. The restored model must report the same
     ``num_weights`` as the original.
     """
     yttm_model_path = 'tests/.data/yttm.4096.en-de.model'
     cfg = AAYNBaseConfig()
     # Encoder and decoder both use the same 'yttm' tokenizer model file.
     for tokenizer_cfg in (cfg.encoder_tokenizer, cfg.decoder_tokenizer):
         tokenizer_cfg.tokenizer_name = 'yttm'
         tokenizer_cfg.tokenizer_model = yttm_model_path
     # Datasets are disabled; this test only builds and (de)serializes the model.
     for ds_key in ('train_ds', 'validation_ds', 'test_ds'):
         setattr(cfg, ds_key, None)

     model = MTEncDecModel(cfg=cfg)
     assert isinstance(model, MTEncDecModel)

     archive_name = f"{model.__class__.__name__}.nemo"
     with tempfile.TemporaryDirectory() as restore_dir:
         with tempfile.TemporaryDirectory() as save_dir:
             saved_archive = os.path.join(save_dir, archive_name)
             model.save_to(save_path=saved_archive)
             restorable_archive = os.path.join(restore_dir, archive_name)
             shutil.copy(saved_archive, restorable_archive)

         # Leaving the inner block deleted save_dir together with its archive;
         # only the copy inside restore_dir remains.
         assert save_dir is not None
         assert not os.path.exists(save_dir)
         assert not os.path.exists(saved_archive)
         assert os.path.exists(restorable_archive)

         model_copy = model.__class__.restore_from(restore_path=restorable_archive)
         assert model.num_weights == model_copy.num_weights
コード例 #2
0
ファイル: test_nmt_model.py プロジェクト: chenchy/NeMo
def get_cfg():
    """Return an ``AAYNBaseConfig`` prepared for model-only tests.

    Both the encoder and the decoder tokenizer are set to 'yttm' and use the
    same model file under ``tests/.data``. The train, validation and test
    dataset configs are cleared to ``None``.
    """
    shared_model_file = 'tests/.data/yttm.4096.en-de.model'
    config = AAYNBaseConfig()
    for tok in (config.encoder_tokenizer, config.decoder_tokenizer):
        tok.tokenizer_name = 'yttm'
        tok.tokenizer_model = shared_model_file
    for split in ('train_ds', 'validation_ds', 'test_ds'):
        setattr(config, split, None)
    return config
コード例 #3
0
ファイル: enc_dec_nmt.py プロジェクト: zt706/NeMo
class MTEncDecConfig(NemoConfig):
    """Config bundling the model, trainer and experiment-manager settings.

    The model section defaults to ``AAYNBaseConfig()``, the trainer to
    ``TrainerConfig()``, and the experiment manager to an
    ``ExpManagerConfig`` named 'MTEncDec' with no extra files to copy.
    """

    # NOTE(review): each default below is a single object created once, when
    # the class body runs, so it is shared unless the surrounding machinery
    # copies it. If this class is decorated with @dataclass (no decorator is
    # visible here; confirm) and these config types are ordinary non-frozen
    # dataclasses, Python 3.11+ rejects such unhashable defaults with
    # "mutable default ... use default_factory". Switch to
    # field(default_factory=...) in that case.
    model: AAYNBaseConfig = AAYNBaseConfig()
    trainer: Optional[TrainerConfig] = TrainerConfig()
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(
        name='MTEncDec', files_to_copy=[])
コード例 #4
0
ファイル: enc_dec_nmt.py プロジェクト: mia0226/NeMo
class MTEncDecConfig(NemoConfig):
    """Config bundling the model, trainer and experiment-manager settings.

    Adds a run ``name`` (default 'MTEncDec') and a ``do_training`` flag
    (default True) alongside the model (``AAYNBaseConfig()``), trainer
    (``TrainerConfig()``) and experiment-manager (named 'MTEncDec', no extra
    files to copy) sections.
    """

    name: Optional[str] = 'MTEncDec'
    # Presumably toggles whether the calling script runs training; confirm
    # against the caller.
    do_training: bool = True
    # NOTE(review): the three defaults below are single objects created once,
    # when the class body runs. If this class is decorated with @dataclass
    # (no decorator is visible here; confirm) and these config types are
    # ordinary non-frozen dataclasses, Python 3.11+ rejects them as
    # unhashable mutable defaults. Switch to field(default_factory=...) in
    # that case.
    model: AAYNBaseConfig = AAYNBaseConfig()
    trainer: Optional[TrainerConfig] = TrainerConfig()
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name='MTEncDec', files_to_copy=[])