def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {'debug': {
            'deterministic_compute': True
        }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    sd = trainer.state_dict()

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd['model']['full_precision']

    ## Modify one of the state values and load into ORTTrainer
    sd['model']['full_precision'][
        'bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    trainer.load_state_dict(sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    trainer.save_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp'))
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config,
                                     options=opts)
    trainer2.load_checkpoint(os.path.join(ckpt_dir,
                                          'bert_toy_save_test.ortcp'))
    loaded_sd = trainer2.state_dict()

    # Assert that the original state and the one loaded from the checkpoint match
    _test_commons.assert_all_states_close_ort(sd, loaded_sd)
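# Hedged reference sketch (not part of the original test): a tiny helper that
# walks the nested state_dict layout the assertions above rely on. Only the
# 'model'/'full_precision' keys are taken from the test; printing shapes here
# is purely illustrative.
def _dump_full_precision_params(trainer):
    sd = trainer.state_dict()
    # one entry per ONNX initializer, keyed by the initializer name
    for name, value in sd['model']['full_precision'].items():
        print(name, getattr(value, 'shape', None))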
def testToyBERTSaveAsONNX():
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)

    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model,
                                              model_desc,
                                              optim_config,
                                              options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
    for initializer, loaded_initializer in zip(
            trainer._onnx_model.graph.initializer,
            trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    assert (onnx.helper.printable_graph(
        trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(
            trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
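# Hedged follow-up sketch: a small standalone check that could be run on the
# file written by save_as_onnx before it is removed. Only the standard
# onnx.load API is used; the helper name is illustrative.
def _summarize_saved_onnx(onnx_path):
    reloaded = onnx.load(onnx_path)
    graph = reloaded.graph
    # return the graph interface plus the number of persisted initializers
    return ([i.name for i in graph.input],
            [o.name for o in graph.output],
            len(graph.initializer))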
def testToyBertCheckpointLoadZero():
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True},
                                         'device' : {'id' : device},
                                         'distributed' : {'allreduce_post_accumulation' : True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[635],[27399],[20647],[18519],[15546]], device=device)
    segment_ids = torch.tensor([[0],[1],[0],[1],[0],[0],[1],[0],[0],[1],[1],[0],[0],[1],[1],[1]], device=device)
    input_mask = torch.tensor([[0],[0],[0],[0],[1],[1],[1],[0],[1],[1],[0],[0],[0],[1],[0],[0]], device=device)
    masked_lm_labels = torch.tensor([[25496],[16184],[11005],[16228],[14884],[21660],[ 8678],[23083],[ 4027],[ 8397],[11921],[ 1333],[26482],[ 1666],[17925],[27978]], device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
            "batch": {"gradient_accumulation_steps": gradient_accumulation_steps},
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
Example No. 5
def load_model_optim_state_and_eval(device,
                                    trainer_opts,
                                    use_lamb=True,
                                    seed=1,
                                    learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    optimizer_state_dict = trainer.state_dict()
    del optimizer_state_dict["model"]

    return dummy_init_state, optimizer_state_dict
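# Hedged usage sketch (not part of the original module): comparing the two
# state dicts returned above, mirroring the pattern used in
# testToyBertLoadOptimState later in this document. The minimal option dict
# and the use of _test_commons.assert_all_states_close_ort are assumptions.
def _example_check_optim_state_roundtrip():
    opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
    expected_state, actual_state = load_model_optim_state_and_eval('cuda', opts)
    _test_commons.assert_all_states_close_ort(actual_state, expected_state)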
Example No. 6
def create_orttrainer_and_save_checkpoint(device,
                                          trainer_opts,
                                          checkpoint_dir,
                                          state_dict_key_name='state_dict',
                                          use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    if 'distributed' in trainer_opts:
        train_data = next(
            islice(
                _chunkify(train_data,
                          trainer_opts['distributed']['world_size']),
                trainer_opts['distributed']['world_rank'], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        _save(trainer, checkpoint_dir, state_dict_key_name)
Example No. 7
def _create_trainer(zero_enabled=False):
    """Cerates a simple ORTTrainer for ORTTrainer functional tests"""

    device = 'cuda'
    optim_config = optim.LambConfig(lr=0.1)
    opts = {'device': {'id': device}, 'debug': {'deterministic_compute': True}}
    if zero_enabled:
        opts['distributed'] = {
            'world_rank': 0,
            'world_size': 1,
            'horizontal_parallel_size': 1,
            'data_parallel_size': 1,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': {
                'stage': 1
            }
        }
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(opts))

    return trainer
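# Hedged usage sketch: driving the helper above for a single ZeRO-enabled step.
# The batcher is re-created here purely for illustration; in the tests the data
# flow comes from _load_pytorch_transformer_model exactly as shown above.
def _example_zero_train_step():
    trainer = _create_trainer(zero_enabled=True)
    _, _, _, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model('cuda')
    data, targets = batcher_fn(train_data, 0)
    loss = trainer.train_step(data, targets)
    return loss.cpu().item()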
def testORTTransformerModelExport(seed, device):
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'check_model_export': True,
        },
        'device': {
            'id': device,
        }
    })

    # Setup for the first ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(
        device)
    first_trainer = orttrainer.ORTTrainer(model,
                                          model_desc,
                                          optim_config,
                                          loss_fn=my_loss,
                                          options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None
def testToyBERTDeterministicCheck(expected_losses):
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol)
    def test_external_graph_transformer_triggering(self):
        input_size = 784
        hidden_size = 500
        num_classes = 10
        batch_size = 128
        model = NeuralNet(input_size, hidden_size, num_classes)

        model_desc = {
            'inputs': [('x', [batch_size, input_size]),
                       ('target', [
                           batch_size,
                       ])],
            'outputs': [('loss', [], True)]
        }
        optim_config = optim.SGDConfig()
        opts = orttrainer.ORTTrainerOptions({'device': {'id': 'cpu'}})
        model = orttrainer.ORTTrainer(model,
                                      model_desc,
                                      optim_config,
                                      options=opts)
        # Because ORTTrainer is lazily initialized, feed in random data to trigger the graph transformer
        data = torch.rand(batch_size, input_size)
        target = torch.randint(0, 10, (batch_size, ))

        with OutputGrabber() as out:
            loss = model.train_step(data, target)
        assert '******************Trigger Customized Graph Transformer:  MyGraphTransformer!' in out.capturedtext
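# Hedged sketch of the NeuralNet module referenced in the test above; the exact
# layer layout of the original helper may differ. It only has to match the
# model_desc used there: inputs ('x', 'target') and a single scalar 'loss'
# output, so the forward pass computes the loss itself (no separate loss_fn).
import torch.nn as nn

class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, target):
        logits = self.fc2(self.relu(self.fc1(x)))
        # model_desc declares 'loss' as the only output, marked as the loss
        return nn.functional.cross_entropy(logits, target)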
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device
        },
        'mixed_precision': {
            'enabled': mixedprecision_enabled,
        },
        'distributed': {
            'allreduce_post_accumulation': True
        }
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    dummy_init_state = _test_commons.generate_dummy_optim_state(
        model, optimizer)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    trainer.load_state_dict(dummy_init_state)

    # Expected values
    input_ids = torch.tensor(
        [[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
         [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
        device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0], [0],
                                [1], [1], [0], [0], [1], [1], [1]],
                               device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0], [1],
                               [1], [0], [0], [0], [1], [0], [0]],
                              device=device)
    masked_lm_labels = torch.tensor(
        [[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
         [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
        device=device)
    next_sentence_labels = torch.tensor(
        [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                          next_sentence_labels)

    actual_state_dict = trainer.state_dict()
    del actual_state_dict['model']
    _test_commons.assert_all_states_close_ort(actual_state_dict,
                                              dummy_init_state)
Example No. 12
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = 'checkpoint*.ortcp'
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return trainer.state_dict(), model
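# Hedged end-to-end sketch combining the save helper from Example No. 6 with
# the load helper above. The single-GPU option dict is an assumption, and the
# sketch assumes _save writes files matching the checkpoint*.ortcp pattern the
# loader globs for.
def _example_save_then_load(tmp_checkpoint_dir):
    opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
    create_orttrainer_and_save_checkpoint('cuda', opts, tmp_checkpoint_dir)
    loaded_state_dict, _ = create_orttrainer_and_load_checkpoint('cuda', opts, tmp_checkpoint_dir)
    return loaded_state_dict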
Example No. 13
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint.experimental_load_checkpoint(trainer, checkpoint_dir)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return checkpoint.experimental_state_dict(trainer), model
def _create_trainer(zero_enabled=False):
    """Cerates a simple ORTTrainer for ORTTrainer functional tests"""

    device = "cuda"
    optim_config = optim.LambConfig(lr=0.1)
    opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}
    if zero_enabled:
        opts["distributed"] = {
            "world_rank": 0,
            "world_size": 1,
            "horizontal_parallel_size": 1,
            "data_parallel_size": 1,
            "allreduce_post_accumulation": True,
            "deepspeed_zero_optimization": {
                "stage": 1
            },
        }
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(opts))

    return trainer
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(
        loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testORTTrainerFrozenWeights(model_params):
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)

    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    opts_dict.update({'utils' : {'frozen_weights' : model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # None of the model_params may be present in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
def testToyBERTModelGradientAccumulationLegacyExperimental(
        gradient_accumulation_steps):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        None,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
        gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
Example No. 19
def create_orttrainer_and_save_checkpoint_bart(
        device,
        trainer_opts,
        checkpoint_dir,
        state_dict_key_name="state_dict",
        use_lamb=True,
        seed=1,
        learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Saves the trainer state as a checkpoint in checkpoint_dir (on rank 0 of model-parallel runs, also pickles the expected full state dict)
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][
                        initializer.name] = numpy_helper.to_array(initializer)
                with open(
                        os.path.join(checkpoint_dir,
                                     "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config):
    # Common setup
    train_steps = 512

    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    optim_config = optimizer_config(lr=0.01)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    if optimizer_config == optim.AdamConfig:
        legacy_optimizer = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, legacy_optimizer,
                       None,
                       learning_rate_description,
                       device,
                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.AdamConfig(
        params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False
    )
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)

    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        legacy_optim_map,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
    )
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]},
        }
    )

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)
    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)
    # Save checkpoint
    state_dict = trainer.state_dict()

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    trainer2.load_state_dict(state_dict)
    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # The losses must match since both trainers were loaded with the same state dict
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = trainer2.state_dict()
    _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict)
def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses):
    return  # TODO: re-enable after nondeterminism on backend is fixed
    # Common setup
    device = "cuda"
    total_steps = 10
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.0
    lr_end = 1e-7
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Setup LR Schedulers
    if (
        lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler
        or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler
    ):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
            "lr_scheduler": lr_scheduler,
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def testToyBERTModelBasicTraining(dynamic_shape):
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        assert output.shape == torch.Size([])
def prepare_model(args, device):
    config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)

    # config.num_hidden_layers = 12
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({'batch': {
                                                'gradient_accumulation_steps': args.gradient_accumulation_steps},
                                            'device': {'id': str(device)},
                                            'mixed_precision': {
                                                'enabled': args.fp16,
                                                'loss_scaler': loss_scaler},
                                            'graph_transformer': {
                                                'attn_dropout_recompute': args.attn_dropout_recompute,
                                                'gelu_recompute': args.gelu_recompute,
                                                'transformer_layer_recompute': args.transformer_layer_recompute,
                                            },
                                            'debug': {'deterministic_compute': True, },
                                            'utils': {
                                                'grad_norm_clip': True},
                                            'distributed': {
                                                'world_rank': max(0, args.local_rank),
                                                'world_size': args.world_size,
                                                'local_rank': max(0, args.local_rank),
                                                'allreduce_post_accumulation': args.allreduce_post_accumulation,
                                                'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage},
                                                'enable_adasum': False},
                                            'lr_scheduler': lr_scheduler
                                            })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, {
        'params': [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
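# Hedged usage sketch: one way to drive the ORTTrainer returned by prepare_model.
# The dataloader, the batch layout, and the stopping condition are assumptions
# made for illustration; the input ordering follows the BERT pretraining tests
# earlier in this document.
def run_training_steps(args, device, train_dataloader):
    trainer = prepare_model(args, device)
    loss = None
    for step, batch in enumerate(train_dataloader):
        input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
        loss = trainer.train_step(input_ids, segment_ids, input_mask,
                                  masked_lm_labels, next_sentence_labels)
        if step + 1 >= args.max_steps:
            break
    return loss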
Example No. 26
def create_initialized_orttrainer(device, trainer_opts, use_lamb=True):
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    learning_rate = 1e-10
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    _train(trainer, train_data, batcher_fn)

    return trainer
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from the checkpoint match
    for k,v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
def _run_adasum_tests(opts):
    # Common setup
    seed = 42
    optim_config = optim.LambConfig()
    # Set up ORTTrainer
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        opts.device.id)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=my_loss,
                                    options=opts)
    # Train a single step to verify the Adasum flag is passed through
    data, targets = batcher_fn(train_data, 0)
    result = trainer.train_step(data, targets)
    assert result is not None
Example No. 29
def create_orttrainer_and_load_checkpoint_bart(device,
                                               trainer_opts,
                                               checkpoint_dir,
                                               use_lamb=True,
                                               seed=1,
                                               learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # model setup
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(
        os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    expected_state_dict = None
    fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl")
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            expected_state_dict = pickle.load(f)

    return trainer.state_dict(), expected_state_dict, model
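# Hedged usage sketch: inspecting what the loader above returns. The pickled
# expected state is only present when the save-side helper ran in a
# model-parallel configuration; otherwise expected_state_dict is None.
def _example_inspect_bart_checkpoint(device, trainer_opts, checkpoint_dir):
    state_dict, expected_state_dict, _ = create_orttrainer_and_load_checkpoint_bart(
        device, trainer_opts, checkpoint_dir)
    # model weights live under state_dict['model']['full_precision']
    param_names = sorted(state_dict['model']['full_precision'].keys())
    return param_names, expected_state_dict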
Example No. 30
def create_orttrainer_and_save_checkpoint(device,
                                          trainer_opts,
                                          checkpoint_dir,
                                          state_dict_key_name="state_dict",
                                          use_lamb=True,
                                          seed=1,
                                          learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=ort_trainer_opts)

    if "distributed" in trainer_opts:
        train_data = next(
            islice(
                _chunkify(train_data,
                          trainer_opts["distributed"]["world_size"]),
                trainer_opts["distributed"]["world_rank"],
                None,
            ))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)