Example #1
import torch
from onnxruntime import set_seed
from onnxruntime.training import optim, orttrainer

# generate_dummy_optim_state and _load_pytorch_transformer_model are local
# helpers from the onnxruntime training test utilities and are assumed to be
# importable alongside this snippet.


def load_model_optim_state_and_eval(device,
                                    trainer_opts,
                                    use_lamb=True,
                                    seed=1,
                                    learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    optimizer_state_dict = trainer.state_dict()
    del optimizer_state_dict["model"]

    return dummy_init_state, optimizer_state_dict
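A minimal driver for the helper above, shown as a hedged sketch: it assumes a CUDA-enabled onnxruntime training build and the _test_commons helper module used in Example #2, and the options dict follows the ORTTrainerOptions schema shown there.

# Hypothetical usage (not part of the original snippet): the dummy optimizer
# state loaded into the trainer should come back unchanged from state_dict()
# after the eval step.
opts = {
    'device': {'id': 'cuda'},
    'debug': {'deterministic_compute': True},
}
expected_state, actual_state = load_model_optim_state_and_eval('cuda', opts)
_test_commons.assert_all_states_close_ort(actual_state, expected_state)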
Example #2
import torch
import onnxruntime
from onnxruntime.training import orttrainer

# load_bert_onnx_model, bert_model_description and _test_commons are local
# helpers from the onnxruntime training test suite and are assumed to be
# importable alongside this snippet.


def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device
        },
        'mixed_precision': {
            'enabled': mixedprecision_enabled,
        },
        'distributed': {
            'allreduce_post_accumulation': True
        }
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    dummy_init_state = _test_commons.generate_dummy_optim_state(
        model, optimizer)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    trainer.load_state_dict(dummy_init_state)

    # Expected values
    input_ids = torch.tensor(
        [[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
         [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
        device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0], [0],
                                [1], [1], [0], [0], [1], [1], [1]],
                               device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0], [1],
                               [1], [0], [0], [0], [1], [0], [0]],
                              device=device)
    masked_lm_labels = torch.tensor(
        [[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
         [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
        device=device)
    next_sentence_labels = torch.tensor(
        [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                          next_sentence_labels)

    actual_state_dict = trainer.state_dict()
    del actual_state_dict['model']
    _test_commons.assert_all_states_close_ort(actual_state_dict,
                                              dummy_init_state)
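In the test suite this kind of function is usually driven by a pytest parameter grid over the optimizer and precision arguments. The sketch below illustrates one plausible grid; the exact combinations are an assumption, not taken from the original.

# Hedged sketch of a pytest driver for testToyBertLoadOptimState; the
# (Lamb/Adam x fp32/fp16) grid is assumed for illustration only.
import pytest
from onnxruntime.training import optim

@pytest.mark.parametrize("optimizer, mixedprecision_enabled", [
    (optim.LambConfig(), False),
    (optim.AdamConfig(), False),
    (optim.LambConfig(), True),
    (optim.AdamConfig(), True),
])
def test_toy_bert_load_optim_state(optimizer, mixedprecision_enabled):
    testToyBertLoadOptimState(optimizer, mixedprecision_enabled)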
Example #3
import copy
import os
import pickle

import torch
from onnx import numpy_helper
from onnxruntime import set_seed
from onnxruntime.training import optim, orttrainer

# _load_bart_model, generate_dummy_optim_state,
# generate_random_input_from_bart_model_desc, _is_model_parallel_run and _save
# are local helpers from the onnxruntime training test utilities.


def create_orttrainer_and_save_checkpoint_bart(
        device,
        trainer_opts,
        checkpoint_dir,
        state_dict_key_name="state_dict",
        use_lamb=True,
        seed=1,
        learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer's ONNX graph is initialized
    - Saves the trainer state_dict (and, on rank 0 of a model-parallel run, a
      pickled expected full state dict) into checkpoint_dir
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][
                        initializer.name] = numpy_helper.to_array(initializer)
                with open(
                        os.path.join(checkpoint_dir,
                                     "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
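A hypothetical counterpart for reading the checkpoint back, assuming only the file layout this function writes (expected_state_dict.pkl under checkpoint_dir); a follow-up test could compare the result against a restored trainer with assert_all_states_close_ort.

# Sketch (not in the original): reload the expected full state that rank 0
# pickled in create_orttrainer_and_save_checkpoint_bart.
import os
import pickle


def load_expected_state_dict(checkpoint_dir):
    with open(os.path.join(checkpoint_dir, "expected_state_dict.pkl"),
              "rb") as f:
        return pickle.load(f)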