Example #1

Both tests below compare ONNX Runtime's legacy ORTTrainer API against the experimental orttrainer API under mixed precision with loss scaling. The imports and the parametrize values on the first test are a best-effort reconstruction (they were not captured with the snippet); the helpers used throughout (_load_pytorch_transformer_model, bert_model_description, load_bert_onnx_model, generate_random_input_from_model_desc, _test_helpers) come from the ONNX Runtime training test suite and are assumed to be importable alongside these tests.

import pytest
import torch

import onnxruntime
from onnxruntime import set_seed
from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler, \
                                         ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import amp, optim, orttrainer

import _test_helpers


# Seed/device values are illustrative; the original decorator was not captured.
@pytest.mark.parametrize("seed, device", [(1234, 'cuda')])
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True,}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scaler', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
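
The parametrized test below also exercises a user-defined scaler pair, CustomLossScaler and LegacyCustomLossScaler, which this example never defines. The sketch below is a plausible minimal definition, assuming the amp.LossScaler base class exposes a loss_scale attribute plus reset() and update(train_step_info) hooks, and that the legacy trainer drives its scaler through loss_scale_, reset(), and update_loss_scale(is_all_finite); the initial scale and the decay factor are illustrative, not taken from the original file.

# Hypothetical minimal loss scalers for the parametrization below; the
# interfaces are assumed from onnxruntime.training.amp.LossScaler and the
# legacy onnxruntime.capi.ort_trainer.LossScaler contract.
class CustomLossScaler(amp.LossScaler):
    def __init__(self, loss_scale=float(1 << 16)):
        super().__init__(loss_scale)
        self._initial_loss_scale = loss_scale
        self.loss_scale = loss_scale

    def reset(self):
        # Restore the scale chosen at construction time
        self.loss_scale = self._initial_loss_scale

    def update(self, train_step_info):
        # Deterministic decay each step, so both APIs are easy to compare
        self.loss_scale *= 0.9
        return self.loss_scale


class LegacyCustomLossScaler:
    def __init__(self, loss_scale=float(1 << 16)):
        self._initial_loss_scale = loss_scale
        self.loss_scale_ = loss_scale

    def reset(self):
        self.loss_scale_ = self._initial_loss_scale

    def update_loss_scale(self, is_all_finite):
        # Mirror CustomLossScaler.update so both APIs see identical scales
        self.loss_scale_ *= 0.9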


@pytest.mark.parametrize("loss_scaler, legacy_loss_scaler", [
    (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)),
    (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)),
    (CustomLossScaler(), LegacyCustomLossScaler())
])
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()