Example #1
Loads a pre-saved checkpoint into a freshly created ORT trainer and verifies that a single evaluation step on a fixed toy batch reproduces a recorded loss value.
    def testBertCheckpointingLoadZero(self):
        # fix both PyTorch and ONNX Runtime seeds so the eval loss is deterministic
        torch.manual_seed(1)
        onnxruntime.set_seed(1)
        # build a full-precision trainer with the simple model description
        model, _, device = create_ort_trainer(gradient_accumulation_steps=1,
                                              use_mixed_precision=False,
                                              allreduce_post_accumulation=True,
                                              use_simple_model_desc=True,
                                              loss_scaler=None)

        ckpt_dir = get_name("ort_ckpt")
        # restore the model weights from the pre-saved 'bert_toy_lamb' checkpoint
        load_checkpoint(model, ckpt_dir, 'bert_toy_lamb')

        # golden loss value recorded from a known-good evaluation run
        expected_eval_loss = [10.997552871]

        # fixed toy batch: 16 single-token sequences with matching masks and labels
        input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[  635],[27399],[20647],[18519],[15546]], device=device)
        segment_ids = torch.tensor([[0],[1],[0],[1],[0],[0],[1],[0],[0],[1],[1],[0],[0],[1],[1],[1]], device=device)
        input_mask = torch.tensor([[0],[0],[0],[0],[1],[1],[1],[0],[1],[1],[0],[0],[0],[1],[0],[0]], device=device)
        masked_lm_labels = torch.tensor([[25496],[16184],[11005],[16228],[14884],[21660],[ 8678],[23083],[ 4027],[ 8397],[11921],[ 1333],[26482],[ 1666],[17925],[27978]], device=device)
        next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

        actual_eval_loss = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, fetches=['loss'])
        actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

        rtol = 1e-03
        assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
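
For reference, the same fix-the-seed, run-one-eval-step, compare-against-a-golden-value pattern can be sketched in plain PyTorch, independent of the ORT trainer helpers used above. The tiny linear model, the random data, and the MSE loss below are hypothetical stand-ins, not the test's actual model:

import torch
import torch.nn.functional as F
from numpy.testing import assert_allclose

def eval_loss(seed):
    # rebuild the model and data under a fixed seed so every run is identical
    torch.manual_seed(seed)
    model = torch.nn.Linear(4, 1)  # hypothetical stand-in for the BERT model
    model.eval()
    x = torch.randn(16, 4)
    y = torch.randn(16, 1)
    with torch.no_grad():
        return F.mse_loss(model(x), y).item()

# record a "golden" loss once, then assert that a rerun reproduces it
expected = eval_loss(seed=1)
assert_allclose(expected, eval_loss(seed=1), rtol=1e-03, err_msg="evaluation loss mismatch")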
Example #2
Perturbs one weight, round-trips the state dict through save_checkpoint/load_checkpoint, and checks that a fresh trainer loads back exactly the saved values.
    def testBertCheckpointingBasic(self):
        model, _, _ = create_ort_trainer(gradient_accumulation_steps=1,
                                         use_mixed_precision=False,
                                         allreduce_post_accumulation=True,
                                         use_simple_model_desc=True,
                                         loss_scaler=None)
        sd = model.state_dict()

        # modify one of the default values
        sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 1
        model.load_state_dict(sd)

        ckpt_dir = get_name("ort_ckpt")
        save_checkpoint(model, ckpt_dir, 'bert_toy_save_test')
        del model

        # create new model
        model2, _, _ = create_ort_trainer(gradient_accumulation_steps=1,
                                          use_mixed_precision=False,
                                          allreduce_post_accumulation=True,
                                          use_simple_model_desc=True,
                                          loss_scaler=None)

        # load changed checkpoint
        load_checkpoint(model2, ckpt_dir, 'bert_toy_save_test')
        loaded_sd = model2.state_dict()

        for k, v in loaded_sd.items():
            # every tensor loaded from the checkpoint must match what was saved
            assert torch.all(torch.eq(v, sd[k])), f"mismatch in {k}"
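
The round-trip in Example #2 can likewise be sketched with plain torch.save/torch.load standing in for the ORT save_checkpoint/load_checkpoint helpers; the linear model and checkpoint path below are hypothetical:

import os
import tempfile

import torch

model = torch.nn.Linear(4, 2)  # hypothetical stand-in for the BERT model
sd = model.state_dict()
sd['weight'] += 1  # perturb one tensor so the checkpoint differs from the defaults
model.load_state_dict(sd)

ckpt_path = os.path.join(tempfile.mkdtemp(), 'ckpt.pt')
torch.save(model.state_dict(), ckpt_path)
del model

# a fresh model must pick up the perturbed weights from the checkpoint
model2 = torch.nn.Linear(4, 2)
model2.load_state_dict(torch.load(ckpt_path))
for k, v in model2.state_dict().items():
    assert torch.equal(v, sd[k]), f"mismatch in {k}"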