Пример #1
0
 def test_load_checkpoint_no_restore_state(self):
     """Train for one step, save a checkpoint, and make sure it is loaded
     properly WITHOUT loading the extra state from the checkpoint.

     The temporary checkpoint file is removed in a ``finally`` clause so it
     does not leak when an assertion or the training step fails.
     """
     test_save_file = test_utils.make_temp_file()
     try:
         test_args = test_utils.ModelParamsDict()
         test_args.distributed_rank = 0
         saved_extra_state = test_utils.create_dummy_extra_state(epoch=2)
         trainer, _ = test_utils.gpu_train_step(test_args)
         trainer.save_checkpoint(test_save_file, saved_extra_state)
         loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
             test_save_file, trainer, restore_state=False
         )
         # Separate asserts so a failure reports which condition broke.
         assert loaded
         # Loading checkpoint without restore state should reset extra state
         assert loaded_extra_state is None
     finally:
         # Clean up even on failure; don't let a missing file mask the
         # original error (e.g. if saving never created it).
         try:
             os.remove(test_save_file)
         except FileNotFoundError:
             pass
Пример #2
0
 def test_first_layer_multihead_attention_(self):
     """Run ``test_utils.gpu_train_step`` with two-head multihead attention
     on the first layer and require a positive average on the trainer's
     "gnorm" meter."""
     model_options = {
         "attention_type": "multihead",
         "attention_heads": 2,
         "first_layer_attention": True,
     }
     params = test_utils.ModelParamsDict(**model_options)
     trained, _unused = test_utils.gpu_train_step(params)
     gnorm_meter = trained.get_meter("gnorm")
     assert gnorm_meter.avg > 0
Пример #3
0
 def test_layer_norm_lstm_cell(self):
     """Run ``test_utils.gpu_train_step`` with ``cell_type="layer_norm_lstm"``
     and require a positive average on the trainer's "gnorm" meter."""
     params = test_utils.ModelParamsDict(cell_type="layer_norm_lstm")
     trained, _unused = test_utils.gpu_train_step(params)
     gnorm_meter = trained.get_meter("gnorm")
     assert gnorm_meter.avg > 0
Пример #4
0
 def test_sequence_lstm_encoder(self):
     """Run ``test_utils.gpu_train_step`` with a bidirectional encoder and
     ``sequence_lstm`` enabled; require a positive "gnorm" meter average."""
     model_options = {
         "encoder_bidirectional": True,
         "sequence_lstm": True,
     }
     params = test_utils.ModelParamsDict(**model_options)
     trained, _unused = test_utils.gpu_train_step(params)
     gnorm_meter = trained.get_meter("gnorm")
     assert gnorm_meter.avg > 0
Пример #5
0
 def test_gpu_freeze_embedding(self):
     """Train one step with frozen encoder and decoder embeddings.

     Besides not raising, the step must record a positive average on the
     trainer's "gnorm" meter, matching the sibling training-step tests.
     NOTE(review): this assumes freezing only the embeddings leaves the
     remaining parameters trainable, so gnorm stays non-zero — confirm.
     """
     test_args = test_utils.ModelParamsDict(encoder_freeze_embed=True,
                                            decoder_freeze_embed=True)
     trainer, _ = test_utils.gpu_train_step(test_args)
     assert trainer.get_meter("gnorm").avg > 0
Пример #6
0
 def test_gpu_train_step(self):
     """Run ``test_utils.gpu_train_step`` with default model parameters and
     require a positive average on the trainer's "gnorm" meter."""
     default_params = test_utils.ModelParamsDict()
     trained, _unused = test_utils.gpu_train_step(default_params)
     gnorm_meter = trained.get_meter("gnorm")
     assert gnorm_meter.avg > 0