def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'utils': {
            'frozen_weights':
            ['bert.encoder.layer.0.attention.self.value.weight']
        }
    })

    # Create ORTTrainer
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(
        model_desc, seed + total_steps + 1)
    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)
    # Save the trainer state into a dict
    state_dict = checkpoint.experimental_state_dict(trainer)

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config2,
                                     options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)
    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
    assert state_dict.keys() == loaded_state_dict.keys()
def testORTDeterministicCompute(seed, device):
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device' : {
            'id' : device,
            'mem_limit' : 10*1024*1024
        }
    })

    # Setup for the first ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None

    # Setup for the second ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    _ = second_trainer.train_step(data, targets)
    assert second_trainer._onnx_model is not None

    # Compare two different instances with identical setup
    assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model)
    _test_helpers.assert_onnx_weights(first_trainer, second_trainer)
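
# The pytest parametrization is not shown in this snippet; below is a minimal
# sketch, assuming an illustrative seed and a CUDA device (the values used in
# the actual test suite may differ).
import pytest

@pytest.mark.parametrize("seed, device", [(24, 'cuda')])
def test_deterministic_compute_example(seed, device):
    # Delegates to the test body defined above.
    testORTDeterministicCompute(seed, device)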
def testToyBERTModelGradientAccumulationLegacyExperimental(
        gradient_accumulation_steps):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "LambOptimizer",
        None,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
        gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses,
                                       legacy_losses,
                                       rtol=1e-6)
def testORTTrainerFrozenWeights(model_params):
    # Common setup
    device = 'cuda'
    total_steps = 10

    # Setup ORTTrainer WITHOUT frozen weights
    options = orttrainer.ORTTrainerOptions({})
    model, model_desc, my_loss, batcher_fn,\
        train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])


    # Setup ORTTrainer WITH frozen weights
    options = orttrainer.ORTTrainerOptions({'utils' : {'frozen_weights' : model_params}})
    model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    # Not all model_params can remain in the session state (frozen weights are excluded)
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not all([param in session_state for param in model_params])
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch' : {
            'gradient_accumulation_steps' : gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)
def testToyBertCheckpointLoadZero():
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True},
                                         'device' : {'id' : device},
                                         'distributed' : {'allreduce_post_accumulation' : True}})

    # Create ORTTrainer and load a previously saved checkpoint
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[635],[27399],[20647],[18519],[15546]], device=device)
    segment_ids = torch.tensor([[0],[1],[0],[1],[0],[0],[1],[0],[0],[1],[1],[0],[0],[1],[1],[1]], device=device)
    input_mask = torch.tensor([[0],[0],[0],[0],[1],[1],[1],[0],[1],[1],[0],[0],[0],[1],[0],[0]], device=device)
    masked_lm_labels = torch.tensor([[25496],[16184],[11005],[16228],[14884],[21660],[ 8678],[23083],[ 4027],[ 8397],[11921],[ 1333],[26482],[ 1666],[17925],[27978]], device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-4)
def testToyBERTDeterministicCheck(expected_losses):
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=1e-6)
def testOptimizerConfigInvalidParams(optim_name):
    # lr is not supported within params
    with pytest.raises(AssertionError) as e:
        params = [{'params': ['layer1.weight'], 'lr': 0.1}]
        if optim_name == 'Adam':
            optim.AdamConfig(params=params, lr=0.2)
        elif optim_name == 'Lamb':
            optim.LambConfig(params=params, lr=0.2)
        else:
            raise ValueError('invalid input')
    assert str(e.value) == "'lr' is not supported inside params"
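
# The pytest parametrization is not shown in this snippet; a minimal sketch
# driving the test over both optimizer names handled in its body:
import pytest

@pytest.mark.parametrize("optim_name", ['Adam', 'Lamb'])
def test_optimizer_config_invalid_params_example(optim_name):
    # Delegates to the test body defined above; any other name raises ValueError.
    testOptimizerConfigInvalidParams(optim_name)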
def testToyBERTModelSimpleTrainStep(dynamic_shape):
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        assert output.shape == torch.Size([])
def testOptimizerConfigParams(optim_name):
    rtol = 1e-5
    params = [{'params': ['layer1.weight'], 'alpha': 0.1}]
    if optim_name == 'Adam':
        cfg = optim.AdamConfig(params=params, alpha=0.2)
    elif optim_name == 'Lamb':
        cfg = optim.LambConfig(params=params, alpha=0.2)
    else:
        raise ValueError('invalid input')
    assert len(cfg.params) == 1, "params should have length 1"
    assert_allclose(cfg.params[0]['alpha'], 0.1,
                    rtol=rtol, err_msg="invalid alpha on params[0]")
def testORTTrainerFrozenWeights(model_params):
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)

    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    opts_dict.update({'utils': {'frozen_weights': model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params CANNOT be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
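
# The model_params parametrization is not shown in this snippet; a minimal
# sketch, assuming a hypothetical list of BERT initializer names to freeze
# (the names used in the actual test suite may differ).
import pytest

@pytest.mark.parametrize("model_params", [
    ['bert.encoder.layer.0.attention.self.value.weight',
     'bert.encoder.layer.0.attention.output.dense.weight'],
])
def test_frozen_weights_example(model_params):
    # Delegates to the BERT-based test body defined above.
    testORTTrainerFrozenWeights(model_params)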
def testToyBERTModelLegacyExperimentalBasicTraining():
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        lr=0.001)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc,
                                       "LambOptimizer", None,
                                       learning_rate_description, device)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses,
                                       legacy_losses,
                                       True,
                                       rtol=1e-5)
def testOptimizerConfigLamb():
    '''Test initialization of Lamb'''
    cfg = optim.LambConfig()
    assert cfg.name == 'LambOptimizer'
    rtol = 1e-05
    assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
    assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch")
    assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch")
    assert_allclose(0.0, cfg.lambda_coef, rtol=rtol,
                    err_msg="lambda_coef mismatch")
    assert cfg.ratio_min == float('-inf'), "ratio_min mismatch"
    assert cfg.ratio_max == float('inf'), "ratio_max mismatch"
    assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
    assert cfg.do_bias_correction == False, "do_bias_correction mismatch"
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {'debug': {
            'deterministic_compute': True
        }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ##  when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir,
                                            'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config,
                                     options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir,
                                            'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Check that the original state and the one loaded from the checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
def testToyBERTSaveAsONNX():
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)

    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model,
                                              model_desc,
                                              optim_config,
                                              options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
    for initializer, loaded_initializer in zip(
            trainer._onnx_model.graph.initializer,
            trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    assert (onnx.helper.printable_graph(
        trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(
            trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt=35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True,}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
def testTrainStepInfoInvalidInput(invalid_input):
    '''Test invalid initialization of TrainStepInfo'''
    optimizer_config = optim.LambConfig()
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, step=invalid_input)
def testORTTrainerDynamicShape(dynamic_axes):
    # Common setup
    device = 'cuda'

    # Setup ORTTrainer
    options = orttrainer.ORTTrainerOptions({})
    model, model_desc, my_loss, batcher_fn,\
        train_data, val_data, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    total_steps = 10
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    assert trainer._onnx_model is not None
def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss):
    torch.manual_seed(seed)
    set_seed(seed)

    # Setup ORTTrainer
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=1e-6)
def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches):
    total_steps = len(expected_loss)
    torch.manual_seed(seed)
    set_seed(seed)
    bptt=35

    # Setup ORTTrainer
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        if fetches:
            trainer._train_step_info.fetches=['loss']
            loss = trainer.train_step(data, targets)
        else:
            loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Eval once just to test fetches in action
    val_data, val_targets = batcher_fn(val_data, 0)
    if fetches:
        trainer._train_step_info.fetches=['loss']
        loss = trainer.eval_step(val_data, val_targets)
        trainer._train_step_info.fetches=[]
    loss, preds = trainer.eval_step(val_data, val_targets)

    # Compare loss to ground truth computed from current ORTTrainer API
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
    assert trainer._onnx_model is not None
def testTrainStepInfo():
    '''Test valid initializations of TrainStepInfo'''

    optimizer_config = optim.LambConfig()
    fetches=['out1','out2']
    step_info = orttrainer.TrainStepInfo(optimizer_config=optimizer_config,
                                         all_finite=False,
                                         fetches=fetches,
                                         optimization_step=123,
                                         step=456)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite == False
    assert step_info.fetches == fetches
    assert step_info.optimization_step == 123
    assert step_info.step == 456

    step_info = orttrainer.TrainStepInfo(optimizer_config)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite == True
    assert step_info.fetches == []
    assert step_info.optimization_step == 0
    assert step_info.step == 0
def testORTTrainerToyBERTModel():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # Modeling
    pytorch_transformer_path = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata')
    bert_onnx_model_path = os.path.join(pytorch_transformer_path, "bert_toy_postprocessed.onnx")
    model = onnx.load(bert_onnx_model_path)
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Generating fake input
    sample_input = generate_random_input_from_model_desc(model_desc)

    # Train
    output = trainer.train_step(*sample_input)

    # Check output
    assert output.shape == torch.Size([])
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup for the experimental ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'device' : {
            'id' : device
        },
        'debug' : {
            'deterministic_compute': True
        },
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _ = trainer.train_step(data, targets)

    # Setup for the legacy ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, lr_desc,
                                       device, _use_deterministic_compute=True)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    # Compare legacy vs experimental APIs
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)
def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps):
    # Common data
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(
        params, legacy_optim_map):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig(params,
                                    alpha=0.9,
                                    beta=0.999,
                                    lambda_coef=0.01,
                                    epsilon=1e-6,
                                    do_bias_correction=False)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        trainer.optim_config.lr)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "LambOptimizer",
                                       legacy_optim_map,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(
            legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device):
    total_steps = 1
    initial_lr = 1.
    tolerance = 1e-4

    # PyTorch Transformer model as example
    opts = {'device' : {'id' : device}}
    if lr_scheduler:
        total_steps = 10
        opts.update({'lr_scheduler' : lr_scheduler(total_steps=total_steps, warmup=0.5)})
    opts = orttrainer.ORTTrainerOptions(opts)
    optim_config = optim.LambConfig(lr=initial_lr)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    # Run a train or evaluation step
    if step_fn == 'eval_step':
        data, targets = batcher_fn(val_data, 0)
    elif step_fn == 'train_step':
        data, targets = batcher_fn(train_data, 0)
    else:
        raise ValueError('Invalid step_fn')

    # Run the step(s); the PyTorch model is exported to ONNX on the first call
    if step_fn == 'eval_step':
        step_fn = trainer.eval_step
        output = trainer.eval_step(data, targets)
    elif step_fn == 'train_step':
        step_fn = trainer.train_step
        for i in range(total_steps):
            output = trainer.train_step(data, targets)
            if lr_scheduler:
                lr_list = trainer.options.lr_scheduler.get_last_lr()
                assert_allclose(lr_list[0], expected_lr_values[i], rtol=tolerance, err_msg="lr mismatch")
    else:
        raise ValueError('Invalid step_fn')
    assert trainer._onnx_model is not None

    # Check output shape after train/eval step
    for out, desc in zip(output, trainer.model_desc.outputs):
        if trainer.loss_fn and desc.is_loss:
            continue
        assert list(out.size()) == desc.shape

    # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs
    sig = inspect.signature(model.forward)
    for i in range(len(sig.parameters.keys())):
        input_name = trainer.model_desc.inputs[i][0]
        input_dim = trainer.model_desc.inputs[i][1]
        input_type = trainer.model_desc.inputs[i][2]

        assert trainer._onnx_model.graph.input[i].name == input_name
        for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim):
            assert input_dim[dim_idx] == dim.dim_value
            assert input_type == _utils.dtype_onnx_to_torch(
                trainer._onnx_model.graph.input[i].type.tensor_type.elem_type)

    # Check name, shape and dtype of the ORT graph outputs
    for i in range(len(trainer.model_desc.outputs)):
        output_name = trainer.model_desc.outputs[i][0]
        output_dim = trainer.model_desc.outputs[i][1]
        output_type = trainer.model_desc.outputs[i][3]

        assert trainer._onnx_model.graph.output[i].name == output_name
        for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
            assert output_dim[dim_idx] == dim.dim_value
            assert output_type == _utils.dtype_onnx_to_torch(
                trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

    # Save current model as ONNX as a file
    file_name = os.path.join('..','..','..','temp_onnx_model.onnx')
    trainer.save_as_onnx(file_name)
    assert os.path.exists(file_name)
    with open(file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)
    step_fn(data, targets)
    assert trainer_from_onnx._onnx_model is not None
    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
    assert (trainer_from_onnx._onnx_model == trainer._onnx_model)
    assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph)
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))
def testDynamicLossScaler():
    rtol = 1e-5
    default_scaler = amp.loss_scaler.DynamicLossScaler()

    # Initial state
    train_step_info = orttrainer.TrainStepInfo(optim.LambConfig())
    assert_allclose(default_scaler.loss_scale, float(1 << 16),
                    rtol=rtol, err_msg="loss scale mismatch")
    assert default_scaler.up_scale_window == 2000
    assert_allclose(default_scaler.min_loss_scale, 1.0,
                    rtol=rtol, err_msg="min loss scale mismatch")
    assert_allclose(default_scaler.max_loss_scale, float(
        1 << 24), rtol=rtol, err_msg="max loss scale mismatch")

    # Perform 9*2000 updates to cover all branches of LossScaler.update() when train_step_info.all_finite is True
    loss_scale = float(1 << 16)
    for cycles in range(1, 10):

        # 1999 updates without overflow produce 1999 stable steps
        for i in range(1, 2000):
            new_loss_scale = default_scaler.update(train_step_info)
            assert default_scaler._stable_steps_count == i
            assert_allclose(new_loss_scale, loss_scale,
                            rtol=rtol, err_msg=f"loss scale mismatch at update {i}")

        # The 2000th update without overflow doubles the loss scale and resets the stable-step count, until max_loss_scale is reached
        new_loss_scale = default_scaler.update(train_step_info)
        if cycles <= 8:
            loss_scale *= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 8 cycles, loss scale should be float(1 << 16)*(2**8)
    assert_allclose(new_loss_scale, float(1 << 16)
                    * (2**8), rtol=rtol, err_msg="loss scale mismatch")

    # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on
    loss_scale = float(1 << 16)*(2**8)
    for count in range(1, 2050):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == (count % 2000)
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # Setting train_step_info.all_finite = False to test down scaling
    train_step_info.all_finite = False

    # Perform 24 updates, halving the loss scale each time
    loss_scale = float(1 << 16)*(2**8)
    for count in range(1, 25):
        new_loss_scale = default_scaler.update(train_step_info)
        loss_scale /= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 24 updates with gradient overflow, loss scale is 1.0
    assert_allclose(new_loss_scale, 1.,
                    rtol=rtol, err_msg="loss scale mismatch")

    # After 25 updates, min_loss_scale is reached and the loss scale is not halved from that point on
    for count in range(1, 5):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")
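
# The policy exercised by the assertions above can be modelled in a few lines
# of plain Python. This is an illustrative sketch of the doubling/halving rule
# only, under assumed defaults; it is not onnxruntime's DynamicLossScaler
# implementation.
def _sketch_dynamic_loss_scale_update(loss_scale, stable_steps, all_finite,
                                      up_scale_window=2000,
                                      min_loss_scale=1.0,
                                      max_loss_scale=float(1 << 24)):
    """Return (new_loss_scale, new_stable_steps) after one update."""
    if all_finite:
        stable_steps += 1
        if stable_steps == up_scale_window:
            # Double after a full window of overflow-free steps, capped at the max.
            loss_scale = min(loss_scale * 2, max_loss_scale)
            stable_steps = 0
    else:
        # Halve on overflow, floored at the min, and reset the stable-step count.
        loss_scale = max(loss_scale / 2, min_loss_scale)
        stable_steps = 0
    return loss_scale, stable_steps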
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer
                       if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer
                       if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch'])
            ],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores',
                 ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])
            ]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [
                1,
            ], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=12)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [
                        learning_rate,
                    ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [
                        loss_scale,
                    ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = int(len(model_desc.inputs_) /
                                 2)  # approximately half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

            print(outputs[0])

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) /
                         2)  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)