# NOTE: the import block and the pytest parametrizations in this file are
# reconstructed assumptions. Helper functions such as bert_model_description(),
# load_bert_onnx_model(), generate_random_input_from_model_desc(),
# legacy_model_params(), optimizer_parameters() and
# _load_pytorch_transformer_model() are assumed to be defined elsewhere in
# this test module.
from functools import partial

import pytest
import torch
from numpy.testing import assert_allclose

import onnxruntime
from onnxruntime import set_seed
from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import amp, optim, orttrainer

import _test_commons
import _test_helpers


# Assumed parametrization: pair each experimental loss scaler with a legacy
# counterpart so both runs scale the loss the same way (None lets the
# experimental API fall back to its default dynamic scaler).
@pytest.mark.parametrize("loss_scaler, legacy_loss_scaler", [
    (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)),
    (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)),
])
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'mixed_precision': {'enabled': True, 'loss_scaler': loss_scaler}
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
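
# Assumed parametrization for the gradient-accumulation test below: the
# accumulation depths are illustrative; any small set exercises the same path.
@pytest.mark.parametrize("gradient_accumulation_steps", [1, 4, 7])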
def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "LambOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, rtol=1e-6)
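
# Assumed parametrization: the experimental optimizer configs that the test
# body below knows how to map onto legacy optimizer names.
@pytest.mark.parametrize("optimizer_config", [
    optim.AdamConfig,
    optim.LambConfig,
    optim.SGDConfig,
])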
def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config):
    # Common setup
    train_steps = 512
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    })
    optim_config = optimizer_config(lr=0.01)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    if optimizer_config == optim.AdamConfig:
        legacy_optimizer = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, legacy_optimizer,
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True)
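
# Assumed parametrization: 'params' holds per-parameter-group overrides for
# the experimental AdamConfig, while 'legacy_optim_map' is the callable the
# legacy trainer uses to look up per-parameter optimizer attributes. The
# concrete values here are illustrative and mirror the AdamConfig defaults
# used in the test body; the attribute keys are assumed from the legacy Adam
# optimizer schema.
@pytest.mark.parametrize("params, legacy_optim_map", [
    ([], lambda name: {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}),
])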
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(
        params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False
    )
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {"id": device},
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        legacy_optim_map,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
    )
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        # The legacy API also accepts a single list that bundles the model
        # inputs together with the learning rate tensor.
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
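
# The test below calls optimizer_parameters(model), assumed to be defined
# elsewhere in this module. A minimal sketch (kept as a comment to avoid
# redefining the real helper), assuming the intent is to build a
# no-weight-decay parameter group from the ONNX initializer names:
#
# def optimizer_parameters(model):
#     no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
#     no_decay_names = [initializer.name for initializer in model.graph.initializer
#                       if any(key in initializer.name for key in no_decay_keys)]
#     return [{"params": no_decay_names, "alpha": 0.9, "beta": 0.999,
#              "lambda_coef": 0.0, "epsilon": 1e-6}]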
# Distinct name to avoid shadowing the parametrized basic-training test above.
def testToyBERTModelLegacyExperimentalBasicTrainingLamb():
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=0.001)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "LambOptimizer",
                                       None, learning_rate_description, device)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True, rtol=1e-5)
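
# Assumed parametrization: a single seed/device pair; CUDA is required here
# because the test exercises mixed precision.
@pytest.mark.parametrize("seed, device", [
    (1, 'cuda'),
])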
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'mixed_precision': {
                                                'enabled': True,
                                                'loss_scaler': loss_scaler},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
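
# Assumed parametrization: the tuples are illustrative; total_steps should
# cover the accumulation depth so both APIs take the same number of optimizer
# steps.
@pytest.mark.parametrize("seed, device, gradient_accumulation_steps, total_steps", [
    (0, 'cuda', 1, 12),
    (42, 'cuda', 3, 12),
    (123, 'cuda', 7, 12),
    (321, 'cuda', 12, 12),
])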
def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps):
    # Common data
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)
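
# Assumed parametrization: a single seed/device pair is enough for the
# weight-equality check.
@pytest.mark.parametrize("seed, device", [
    (321, 'cuda'),
])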
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup for the experimental ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'device': {'id': device},
        'debug': {'deterministic_compute': True},
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _ = trainer.train_step(data, targets)

    # Setup for the legacy ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device,
                                       _use_deterministic_compute=True)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    # Compare legacy vs experimental APIs
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)
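
# Assumed parametrization: pair each experimental warmup scheduler with its
# legacy counterpart from _test_commons; the initial_lr values are illustrative.
@pytest.mark.parametrize("initial_lr, lr_scheduler, legacy_lr_scheduler", [
    (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler),
    (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler),
    (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler),
    (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler),
])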
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler):
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 128
    device = 'cuda'
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7

    # Setup both experimental and legacy LR schedulers before the experimental loop
    if legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler or \
            legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr,
                                      total_steps=total_steps, warmup=warmup)
    elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr,
                                      total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr,
                                      total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid legacy_lr_scheduler")
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or \
            lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
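
# Assumed parametrization: illustrative combinations covering each optimizer
# and each scheduler pairing that the test body can translate to the legacy API.
@pytest.mark.parametrize("seed, device, optimizer_config, lr_scheduler, get_lr_this_step", [
    (0, 'cuda', optim.AdamConfig, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler),
    (0, 'cuda', optim.LambConfig, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler),
    (42, 'cuda', optim.SGDConfig, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler),
    (321, 'cuda', optim.AdamConfig, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler),
])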
def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step):
    # Common data
    total_steps = 10
    lr = 0.001
    warmup = 0.5
    cycles = 0.5
    power = 1.
    lr_end = 1e-7
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or \
            lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'debug': {'deterministic_compute': True},
                                            'lr_scheduler': lr_scheduler})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optimizer_config(lr=lr)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    if optimizer_config == optim.AdamConfig:
        legacy_optimizer_config = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer_config = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer_config = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")
    if get_lr_this_step == _test_commons.legacy_constant_lr_scheduler or \
            get_lr_this_step == _test_commons.legacy_linear_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup)
    elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps,
                                   warmup=warmup, cycles=cycles)
    elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps,
                                   warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid get_lr_this_step")
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, legacy_optimizer_config,
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=get_lr_this_step)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets)
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss)
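
# Assumed parametrization: this variant hard-codes total_steps=10 and
# warmup=0.5, and calls legacy_lr_scheduler(i) with only the step index, so
# the legacy scheduler must arrive pre-bound. The partials below reflect that
# assumption.
@pytest.mark.parametrize("initial_lr, lr_scheduler, legacy_lr_scheduler", [
    (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler,
     partial(_test_commons.legacy_constant_lr_scheduler, initial_lr=1.0, total_steps=10, warmup=0.5)),
    (0.5, optim.lr_scheduler.LinearWarmupLRScheduler,
     partial(_test_commons.legacy_linear_lr_scheduler, initial_lr=0.5, total_steps=10, warmup=0.5)),
])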
# Distinct name to avoid shadowing the 128-step LR-scheduler test above.
def testToyBERTModelLegacyExperimentalLRSchedulerFixedWarmup(initial_lr, lr_scheduler, legacy_lr_scheduler):
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'lr_scheduler': lr_scheduler(total_steps=total_steps, warmup=0.5)
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
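
# Example invocation, assuming this module lives in the orttraining test tree
# on a CUDA-capable machine:
#   python -m pytest -x -s -k "Legacy" .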