def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'mixed_precision': {
                                                'enabled': True,
                                                'loss_scaler': loss_scaler},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)

    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
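# A minimal sketch of how a (seed, device) test like the one above is usually
# driven by pytest; the concrete parameter values below are illustrative
# assumptions, not taken from this excerpt:
#
#   @pytest.mark.parametrize("seed, device", [
#       (1234, 'cuda'),
#   ])
#   def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
#       ...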
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)


@pytest.mark.parametrize("loss_scaler, legacy_loss_scaler", [
    (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)),
    (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)),
    (CustomLossScaler(), LegacyCustomLossScaler())
])
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
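# The CustomLossScaler referenced in the parametrization above is not defined
# in this excerpt (in the full test module its definition would have to precede
# the parametrized test). A minimal sketch of what such a custom scaler could
# look like, assuming the onnxruntime.training.amp.LossScaler base class with
# reset()/update() hooks and a settable loss_scale; the 0.9 decay factor is an
# illustrative policy, not taken from this excerpt:
class CustomLossScaler(amp.LossScaler):
    def __init__(self, loss_scale=float(1 << 16)):
        super().__init__(loss_scale)
        self._initial_loss_scale = loss_scale

    def reset(self):
        # Restore the scale chosen at construction time
        self.loss_scale = self._initial_loss_scale

    def update(self, train_step_info):
        # Shrink the loss scale on every train step (illustrative policy)
        self.loss_scale *= 0.9
        return self.loss_scale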