Example #1
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = 'checkpoint*.ortcp'
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return trainer.state_dict(), model
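A minimal usage sketch for this helper; the options dictionary below is an assumption modeled on the ORTTrainerOptions dictionaries used in the other examples, and the checkpoint directory is a placeholder:

# Hypothetical call: the directory is assumed to contain 'checkpoint*.ortcp' files
opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
state_dict, pytorch_model = create_orttrainer_and_load_checkpoint(
    'cuda', opts, 'path/to/checkpoint_dir')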
Example #2
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint.experimental_load_checkpoint(trainer, checkpoint_dir)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return checkpoint.experimental_state_dict(trainer), model
Example #3
def testORTTransformerModelExport(seed, device):
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'check_model_export': True,
        },
        'device': {
            'id': device,
        }
    })

    # Setup for the first ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(
        device)
    first_trainer = orttrainer.ORTTrainer(model,
                                          model_desc,
                                          optim_config,
                                          loss_fn=my_loss,
                                          options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None
Example #4
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
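This test is presumably driven by a pytest parametrization along the following lines; amp.DynamicLossScaler is the class used elsewhere in these examples, and the expected-loss lists are placeholders rather than measured baselines:

# Sketch only: loss values elided, not real reference numbers
@pytest.mark.parametrize("loss_scaler, expected_losses", [
    (None, [...]),                     # let ORTTrainer pick a default scaler
    (amp.DynamicLossScaler(), [...]),  # explicit dynamic loss scaler
])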
Example #5
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return dummy_init_state, checkpoint.experimental_state_dict(trainer)
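generate_dummy_optim_state is not shown in these examples. A minimal sketch of what it might build, assuming the per-parameter moment naming used by ORTTrainer optimizers (Moment_1/Moment_2, plus Update_Count for Lamb and Step for Adam); the exact layout is an assumption:

import numpy as np

def generate_dummy_optim_state_sketch(model, optim_config):
    # Hypothetical stand-in: one block of ones per trainable parameter
    shared_key = 'Update_Count' if isinstance(optim_config, optim.LambConfig) else 'Step'
    optim_state = {}
    for name, param in model.named_parameters():
        optim_state[name] = {
            'Moment_1': np.ones(tuple(param.shape), dtype=np.float32),
            'Moment_2': np.ones(tuple(param.shape), dtype=np.float32),
            shared_key: np.array(1, dtype=np.int64),
        }
    return optim_state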
Example #6
def create_orttrainer_and_save_checkpoint(device,
                                          trainer_opts,
                                          checkpoint_dir,
                                          state_dict_key_name='state_dict',
                                          use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    if 'distributed' in trainer_opts:
        train_data = next(
            islice(
                _chunkify(train_data,
                          trainer_opts['distributed']['world_size']),
                trainer_opts['distributed']['world_rank'], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        _save(trainer, checkpoint_dir, state_dict_key_name)
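_chunkify is assumed here to split the training data into world_size shards so that next(islice(...)) picks out the shard belonging to world_rank; one plausible implementation:

def _chunkify_sketch(sequence, num_chunks):
    # Hypothetical: yield num_chunks contiguous slices of the sequence
    chunk_size = (len(sequence) + num_chunks - 1) // num_chunks
    for i in range(num_chunks):
        yield sequence[i * chunk_size:(i + 1) * chunk_size]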
Example #7
def _create_trainer(zero_enabled=False):
    """Cerates a simple ORTTrainer for ORTTrainer functional tests"""

    device = 'cuda'
    optim_config = optim.LambConfig(lr=0.1)
    opts = {'device': {'id': device}, 'debug': {'deterministic_compute': True}}
    if zero_enabled:
        opts['distributed'] = {
            'world_rank': 0,
            'world_size': 1,
            'horizontal_parallel_size': 1,
            'data_parallel_size': 1,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': {
                'stage': 1
            }
        }
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(opts))

    return trainer
Example #8
def _create_trainer(zero_enabled=False):
    """Creates a simple ORTTrainer for ORTTrainer functional tests"""

    device = "cuda"
    optim_config = optim.LambConfig(lr=0.1)
    opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}
    if zero_enabled:
        opts["distributed"] = {
            "world_rank": 0,
            "world_size": 1,
            "horizontal_parallel_size": 1,
            "data_parallel_size": 1,
            "allreduce_post_accumulation": True,
            "deepspeed_zero_optimization": {
                "stage": 1
            },
        }
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(opts))

    return trainer
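A usage sketch: with zero_enabled=True the options above request DeepSpeed ZeRO stage 1 (optimizer-state partitioning) on a single rank, so the functional tests can exercise the ZeRO code path without launching multiple processes:

# Hypothetical calls
baseline_trainer = _create_trainer()
zero_trainer = _create_trainer(zero_enabled=True)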
Example #9
def testToyBertCheckpointLoadZero():
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True},
                                         'device' : {'id' : device},
                                         'distributed' : {'allreduce_post_accumulation' : True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[635],[27399],[20647],[18519],[15546]], device=device)
    segment_ids = torch.tensor([[0],[1],[0],[1],[0],[0],[1],[0],[0],[1],[1],[0],[0],[1],[1],[1]], device=device)
    input_mask = torch.tensor([[0],[0],[0],[0],[1],[1],[1],[0],[1],[1],[0],[0],[0],[1],[0],[0]], device=device)
    masked_lm_labels = torch.tensor([[25496],[16184],[11005],[16228],[14884],[21660],[ 8678],[23083],[ 4027],[ 8397],[11921],[ 1333],[26482],[ 1666],[17925],[27978]], device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
Example #10
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
            "batch": {"gradient_accumulation_steps": gradient_accumulation_steps},
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
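A plausible parametrization for this test; the accumulation-step values are illustrative and the loss lists are placeholders, not recorded baselines:

@pytest.mark.parametrize("gradient_accumulation_steps, expected_losses", [
    (1, [...]),   # optimizer update every step
    (4, [...]),   # accumulate gradients over 4 steps per update
    (16, [...]),
])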
Example #11
def testToyBERTDeterministicCheck(expected_losses):
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol)
Example #12
def create_orttrainer_and_save_checkpoint_bart(
        device,
        trainer_opts,
        checkpoint_dir,
        state_dict_key_name="state_dict",
        use_lamb=True,
        seed=1,
        learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][
                        initializer.name] = numpy_helper.to_array(initializer)
                with open(
                        os.path.join(checkpoint_dir,
                                     "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
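_is_model_parallel_run is not defined in these snippets; a minimal sketch, under the assumption that a run counts as model-parallel (state sharded across ranks, so each rank must save its own piece) when ZeRO or Megatron-style horizontal parallelism is enabled:

def _is_model_parallel_run_sketch(opts):
    # Hypothetical check against the ORTTrainerOptions distributed settings
    d = opts.distributed
    return d.deepspeed_zero_optimization.stage > 0 or d.horizontal_parallel_size > 1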
Example #13
def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]},
        }
    )

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)
    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)
    # Save checkpoint
    state_dict = trainer.state_dict()

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    trainer2.load_state_dict(state_dict)
    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = trainer2.state_dict()
    _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict)
Example #14
def testORTTrainerFrozenWeights(model_params):
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)

    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    opts_dict.update({'utils': {'frozen_weights': model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params CANNOT be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
Example #15
def testToyBERTModelBasicTraining(dynamic_shape):
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        assert output.shape == torch.Size([])
Example #16
def create_initialized_orttrainer(device, trainer_opts, use_lamb=True):
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    learning_rate = 1e-10
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    _train(trainer, train_data, batcher_fn)

    return trainer
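A usage sketch; the options dictionary is an assumption consistent with the other examples:

trainer = create_initialized_orttrainer('cuda', {'device': {'id': 'cuda'}})
sd = trainer.state_dict()  # trained model and optimizer state, new-API layout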
Example #17
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {'debug': {
            'deterministic_compute': True
        }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ##  when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir,
                                            'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config,
                                     options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir,
                                            'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from the checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
Example #18
def testToyBERTSaveAsONNX():
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)

    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model,
                                              model_desc,
                                              optim_config,
                                              options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
    for initializer, loaded_initializer in zip(
            trainer._onnx_model.graph.initializer,
            trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    assert (onnx.helper.printable_graph(
        trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(
            trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
Example #19
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {'debug': {
            'deterministic_compute': True
        }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    sd = trainer.state_dict()

    ## All initializers must be present in the state_dict
    ##  when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd['model']['full_precision']

    ## Modify one of the state values and load into ORTTrainer
    sd['model']['full_precision'][
        'bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    trainer.load_state_dict(sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    trainer.save_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp'))
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config,
                                     options=opts)
    trainer2.load_checkpoint(os.path.join(ckpt_dir,
                                          'bert_toy_save_test.ortcp'))
    loaded_sd = trainer2.state_dict()

    # Assert that the original state and the one loaded from the checkpoint match
    _test_commons.assert_all_states_close_ort(sd, loaded_sd)
Example #20
def verify_optimizer_state_match(device,
                                 opts,
                                 checkpoint_dir,
                                 world_rank,
                                 use_lamb=False):
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(
        device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)
    # Roundabout way of checking optimizer states: save each rank's state dict
    # into a temporary folder, then read them all back and aggregate them.
    with open(os.path.join(checkpoint_dir,
                           f'distributed_state_{world_rank}.pkl'), 'wb') as f:
        pickle.dump(trainer_state, f)
    dist.barrier()

    if world_rank == 0:
        num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
        optimizer_states = dict()
        for rank in range(num_states):
            with open(os.path.join(checkpoint_dir,
                                   f'distributed_state_{rank}.pkl'), 'rb') as f:
                rank_state_dict = pickle.load(f)

            # collect optimizer states for later comparison since they are sharded
            aggregate_states(optimizer_states, rank_state_dict['optimizer'])

        # compare optimizer states
        optimizer_config = optim.LambConfig(
        ) if use_lamb else optim.AdamConfig()
        actual_optim_state = get_optim_state_from_state_dict(
            optimizer_states, optimizer_config)
        assert actual_optim_state.keys() == expected_optim_state.keys()
        for param_name, a_state in actual_optim_state.items():
            for k, v in a_state.items():
                assert_allclose(
                    v.reshape(expected_optim_state[param_name][k].shape),
                    expected_optim_state[param_name][k],
                    err_msg=f"Optimizer state mismatch for param {param_name}, key {k}")

    dist.barrier()
    os.remove(os.path.join(checkpoint_dir,
                           f'distributed_state_{world_rank}.pkl'))
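aggregate_states is referenced but not defined here. A minimal sketch, assuming each rank's 'optimizer' sub-dict maps parameter names to {moment name: shard} and that shards only need to be concatenated flat for the reshape-based comparison above:

import numpy as np

def aggregate_states_sketch(aggregated, rank_optimizer_state):
    # Hypothetical: append each rank's flattened shard per parameter/moment key
    for param_name, moments in rank_optimizer_state.items():
        dst = aggregated.setdefault(param_name, {})
        for key, shard in moments.items():
            flat = np.asarray(shard).reshape(-1)
            dst[key] = flat if key not in dst else np.concatenate((dst[key], flat))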
Example #21
def _run_adasum_tests(opts):
    # Common setup
    seed = 42
    optim_config = optim.LambConfig()
    # Setup ORTTrainer
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        opts.device.id)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=my_loss,
                                    options=opts)
    # Train once to see flag going through
    data, targets = batcher_fn(train_data, 0)
    result = trainer.train_step(data, targets)
    assert result is not None
Example #22
def create_orttrainer_and_load_checkpoint_bart(device,
                                               trainer_opts,
                                               checkpoint_dir,
                                               use_lamb=True,
                                               seed=1,
                                               learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # model setup
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(
        os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    expected_state_dict = None
    fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl")
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            expected_state_dict = pickle.load(f)

    return trainer.state_dict(), expected_state_dict, model
Example #23
def create_orttrainer_and_save_checkpoint(device,
                                          trainer_opts,
                                          checkpoint_dir,
                                          state_dict_key_name="state_dict",
                                          use_lamb=True,
                                          seed=1,
                                          learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=ort_trainer_opts)

    if "distributed" in trainer_opts:
        train_data = next(
            islice(
                _chunkify(train_data,
                          trainer_opts["distributed"]["world_size"]),
                trainer_opts["distributed"]["world_rank"],
                None,
            ))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
Example #24
@pytest.mark.parametrize("optimizer, mixedprecision_enabled", [
    (optim.LambConfig(), False),
    (optim.AdamConfig(), False),
    (optim.LambConfig(), True),
    (optim.AdamConfig(), True),
])
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
Example #25
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer
                       if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer
                       if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch'])],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [
                        learning_rate,
                    ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [
                        loss_scale,
                    ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = len(model_desc.inputs_) // 2  # approximately half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = len(model_desc.inputs_) // 2  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
Example #26
def create_ort_trainer(args, device, model):

    # MODEL INPUT AND OUTPUT DESCRIPTION
    vocab_size = 30528
    micro_batch = args.train_batch_size // args.gradient_accumulation_steps
    model_desc = {
        'inputs': [
            ('input_ids', [args.train_batch_size, args.max_seq_length]),
            ('segment_ids', [args.train_batch_size, args.max_seq_length]),
            ('input_mask', [args.train_batch_size, args.max_seq_length]),
            ('masked_lm_labels', [args.train_batch_size, args.max_seq_length]),
            ('next_sentence_labels', [args.train_batch_size, 2])],
        'outputs': [('loss', [], True)]
    }

    # TRAINING OPTIMIZER SPECIFICATION
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    optim_config = optim.LambConfig(
        lr=args.learning_rate,
        alpha=0.9,
        beta=0.999,
        lambda_coef=0.01,
        epsilon=1e-6,
        do_bias_correction=True,
        params=[{
            'params': [n for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'alpha': 0.9, 'beta': 0.999, 'lambda_coef': 0.00, 'epsilon': 1e-6
        }])

    # LEARNING RATE SCHEDULE SPECIFICATION
    lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(
        total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    # ONNXRUNTIME TRAINER OPTIONS
    trainer_config = ORTTrainerOptions({
        'device': {
            'id': str(device),
            'mem_limit': int(args.gpu_memory_limit_gb * 1024 * 1024 * 1024)
        },
        'batch': {
            'gradient_accumulation_steps': args.gradient_accumulation_steps
        },
        'distributed': {
            'world_size': args.world_size,
            'world_rank': args.world_rank,
            'allreduce_post_accumulation': bool(args.allreduce_post_accumulation),
            'deepspeed_zero_optimization': {
                'stage': 1 if args.deepspeed_zero_stage else 0,
            }
        },
        'lr_scheduler': lr_scheduler,
        'mixed_precision': {
            'enabled': bool(args.fp16),
        }
    })

    # ONNXRUNTIME TRAINER CONSTRUCTION (loss fn embedded in model)
    trainer = ORTTrainer(model,
                         model_desc,
                         optim_config,
                         loss_fn=None,
                         options=trainer_config)

    return trainer
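A sketch of how the returned trainer might be driven; the dataloader and batch unpacking are placeholders that follow the input order declared in model_desc:

trainer = create_ort_trainer(args, device, model)
for step, batch in enumerate(train_dataloader):  # hypothetical dataloader
    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
    loss = trainer.train_step(input_ids, segment_ids, input_mask,
                              masked_lm_labels, next_sentence_labels)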