def testBertCheckpointingLoadZero(self):
    """Load a pre-saved 'bert_toy_lamb' checkpoint into a fresh trainer and
    verify that a fixed-seed eval step reproduces the recorded loss.

    Fixes vs. previous version:
    - ``rtol`` was assigned but never passed to ``assert_allclose``, so the
      comparison silently ran at the default (much tighter) tolerance; it is
      now passed explicitly.
    - Removed leftover debugging (``pdb`` comment and a bare ``print``).
    """
    # Fix both RNGs so the eval step is deterministic.
    torch.manual_seed(1)
    onnxruntime.set_seed(1)
    model, _, device = create_ort_trainer(
        gradient_accumulation_steps=1,
        use_mixed_precision=False,
        allreduce_post_accumulation=True,
        use_simple_model_desc=True,
        loss_scaler=None,
    )
    ckpt_dir = get_name("ort_ckpt")
    load_checkpoint(model, ckpt_dir, 'bert_toy_lamb')

    # Golden loss value recorded for this checkpoint + fixed inputs.
    expected_eval_loss = [10.997552871]

    # Hand-crafted toy batch (batch of 16 single-token sequences).
    input_ids = torch.tensor(
        [[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
         [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
        device=device)
    segment_ids = torch.tensor(
        [[0], [1], [0], [1], [0], [0], [1], [0],
         [0], [1], [1], [0], [0], [1], [1], [1]],
        device=device)
    input_mask = torch.tensor(
        [[0], [0], [0], [0], [1], [1], [1], [0],
         [1], [1], [0], [0], [0], [1], [0], [0]],
        device=device)
    masked_lm_labels = torch.tensor(
        [[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
         [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
        device=device)
    next_sentence_labels = torch.tensor(
        [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    actual_eval_loss = model.eval_step(
        input_ids, segment_ids, input_mask, masked_lm_labels,
        next_sentence_labels, fetches=['loss'])
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Loose tolerance: loss is a single float accumulated across the batch.
    rtol = 1e-03
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol,
                    err_msg="evaluation loss mismatch")
def testBertCheckpointingBasic(self):
    """Round-trip a state dict through save_checkpoint/load_checkpoint.

    Perturbs one weight so the loaded checkpoint is distinguishable from
    the default initialization, saves it, rebuilds the trainer from
    scratch, loads the checkpoint back, and checks every tensor matches.
    """
    trainer, _, _ = create_ort_trainer(
        gradient_accumulation_steps=1,
        use_mixed_precision=False,
        allreduce_post_accumulation=True,
        use_simple_model_desc=True,
        loss_scaler=None,
    )

    # Nudge one parameter away from its default so a successful load is
    # observable (a no-op load would also "match" untouched defaults).
    state = trainer.state_dict()
    state['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 1
    trainer.load_state_dict(state)

    ckpt_dir = get_name("ort_ckpt")
    save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer

    # Build an entirely new trainer and restore the saved checkpoint.
    restored, _, _ = create_ort_trainer(
        gradient_accumulation_steps=1,
        use_mixed_precision=False,
        allreduce_post_accumulation=True,
        use_simple_model_desc=True,
        loss_scaler=None,
    )
    load_checkpoint(restored, ckpt_dir, 'bert_toy_save_test')

    # Every restored tensor must be elementwise equal to what was saved.
    restored_state = restored.state_dict()
    for key in restored_state:
        assert torch.all(torch.eq(restored_state[key], state[key]))