def testToyBertCheckpointLoadZero():
    """Restore the pre-saved 'bert_toy_lamb' checkpoint into a fresh
    ORTTrainer and check that a deterministic eval_step reproduces the
    recorded loss within tolerance."""
    # Deterministic environment so the eval loss is reproducible.
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'distributed': {'allreduce_post_accumulation': True},
    })

    # Build a trainer around the toy BERT model and load its state from
    # the checkpoint directory.
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Reference loss recorded for this checkpoint + fixed input batch.
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor(
        [[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
         [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
        device=device)
    segment_ids = torch.tensor(
        [[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]],
        device=device)
    input_mask = torch.tensor(
        [[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]],
        device=device)
    masked_lm_labels = torch.tensor(
        [[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
         [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
        device=device)
    next_sentence_labels = torch.tensor(
        [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Run one evaluation step and compare against the reference loss.
    actual_eval_loss = trainer.eval_step(
        input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def testToyBertCheckpointBasic():
    """Round-trip ORTTrainer state through the experimental checkpoint
    save/load APIs and verify the restored state_dict equals the saved one."""
    # Deterministic setup.
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create an ORTTrainer and capture its initial state as a dict.
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    # When the model given to ORTTrainer is an ONNX model, every graph
    # initializer must show up in the state_dict.
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    # Perturb one weight so the subsequent load provably alters trainer state.
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    # Persist the (modified) state to a checkpoint on disk.
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer
    del model

    # Build a brand-new trainer and load the checkpoint written above.
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Every restored tensor must match the state that was saved.
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))