def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'utils': {'frozen_weights': ['bert.encoder.layer.0.attention.self.value.weight']}
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)

    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)

    # Save checkpoint
    state_dict = checkpoint.experimental_state_dict(trainer)

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)

    # Evaluate the second trainer on the same input
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
    assert state_dict.keys() == loaded_state_dict.keys()
def testORTDeterministicCompute(seed, device):
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {
            'id': device,
            'mem_limit': 10 * 1024 * 1024
        }
    })

    # Setup for the first ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None

    # Setup for the second ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    _ = second_trainer.train_step(data, targets)
    assert second_trainer._onnx_model is not None

    # Compare two different instances with identical setup
    assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model)
    _test_helpers.assert_onnx_weights(first_trainer, second_trainer)
def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "LambOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, rtol=1e-6)
def testORTTrainerFrozenWeights(model_params):
    # Common setup
    device = 'cuda'
    total_steps = 10

    # Setup ORTTrainer WITHOUT frozen weights
    options = orttrainer.ORTTrainerOptions({})
    model, model_desc, my_loss, batcher_fn, \
        train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    options = orttrainer.ORTTrainerOptions({'utils': {'frozen_weights': model_params}})
    model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    # Not all model_params can be in the session state, since the frozen ones are excluded
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not all([param in session_state for param in model_params])
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)
def testToyBertCheckpointLoadZero():
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'distributed': {'allreduce_post_accumulation': True}
    })

    # Create ORTTrainer and load the saved checkpoint into it
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
                              [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
                             device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0],
                                [0], [1], [1], [0], [0], [1], [1], [1]], device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0],
                               [1], [1], [0], [0], [0], [1], [0], [0]], device=device)
    masked_lm_labels = torch.tensor([[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
                                     [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
                                    device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask,
                                         masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-4)
def testToyBERTDeterministicCheck(expected_losses):
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=1e-6)
def testOptimizerConfigInvalidParams(optim_name):
    # 'lr' is not supported inside per-parameter-group entries
    with pytest.raises(AssertionError) as e:
        params = [{'params': ['layer1.weight'], 'lr': 0.1}]
        if optim_name == 'Adam':
            optim.AdamConfig(params=params, lr=0.2)
        elif optim_name == 'Lamb':
            optim.LambConfig(params=params, lr=0.2)
        else:
            raise ValueError('invalid input')
    assert str(e.value) == "'lr' is not supported inside params"
def testToyBERTModelSimpleTrainStep(dynamic_shape):
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        assert output.shape == torch.Size([])
def testOptimizerConfigParams(optim_name):
    rtol = 1e-5
    params = [{'params': ['layer1.weight'], 'alpha': 0.1}]
    if optim_name == 'Adam':
        cfg = optim.AdamConfig(params=params, alpha=0.2)
    elif optim_name == 'Lamb':
        cfg = optim.LambConfig(params=params, alpha=0.2)
    else:
        raise ValueError('invalid input')
    assert len(cfg.params) == 1, "params should have length 1"
    assert_allclose(cfg.params[0]['alpha'], 0.1, rtol=rtol, err_msg="invalid alpha on params[0]")
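# Illustrative sketch (not a test): how the per-parameter-group override exercised in
# testOptimizerConfigParams is expected to combine with the constructor-level default.
# The group entry keeps alpha=0.1 for 'layer1.weight', while every other weight is
# assumed to fall back to the top-level alpha=0.2. Uses the same optim.LambConfig API
# as the tests above; the helper name is purely for illustration.
def _example_param_group_config():
    params = [{'params': ['layer1.weight'], 'alpha': 0.1}]
    return optim.LambConfig(params=params, alpha=0.2)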
def testORTTrainerFrozenWeights(model_params):
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    opts_dict.update({'utils': {'frozen_weights': model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # None of the frozen model_params may be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
def testToyBERTModelLegacyExperimentalBasicTraining():
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=0.001)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "LambOptimizer",
                                       None, learning_rate_description, device)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True, rtol=1e-5)
def testOptimizerConfigLamb():
    '''Test initialization of Lamb'''
    cfg = optim.LambConfig()
    assert cfg.name == 'LambOptimizer'

    rtol = 1e-05
    assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
    assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch")
    assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch")
    assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch")
    assert cfg.ratio_min == float('-inf'), "ratio_min mismatch"
    assert cfg.ratio_max == float('inf'), "ratio_max mismatch"
    assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
    assert cfg.do_bias_correction == False, "do_bias_correction mismatch"
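# Illustrative sketch (not a test): the defaults asserted in testOptimizerConfigLamb can be
# overridden at construction time. Every hyperparameter below appears as a LambConfig
# attribute or keyword elsewhere in this file; the specific values are arbitrary examples.
def _example_lamb_overrides():
    return optim.LambConfig(lr=0.01, alpha=0.95, beta=0.99, lambda_coef=0.01,
                            epsilon=1e-8, do_bias_correction=True)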
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from the previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from the checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
def testToyBERTSaveAsONNX():
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)
    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from the persisted ONNX model and compare with the original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)
    for initializer, loaded_initializer in zip(trainer._onnx_model.graph.initializer,
                                               trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) ==
            onnx.helper.printable_graph(trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'mixed_precision': {
                                                'enabled': True,
                                                'loss_scaler': loss_scaler},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
def testTrainStepInfoInvalidInput(invalid_input):
    '''Test invalid initialization of TrainStepInfo'''
    optimizer_config = optim.LambConfig()
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, step=invalid_input)
def testORTTrainerDynamicShape(dynamic_axes):
    # Common setup
    device = 'cuda'

    # Setup ORTTrainer
    options = orttrainer.ORTTrainerOptions({})
    model, model_desc, my_loss, batcher_fn, \
        train_data, val_data, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    total_steps = 10
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    assert trainer._onnx_model is not None
def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss):
    torch.manual_seed(seed)
    set_seed(seed)

    # Setup ORTTrainer
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Compare actual losses against the expected (precomputed) values
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=1e-6)
def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches):
    total_steps = len(expected_loss)
    torch.manual_seed(seed)
    set_seed(seed)
    bptt = 35

    # Setup ORTTrainer
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'mixed_precision': {
                                                'enabled': True,
                                                'loss_scaler': loss_scaler},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        if fetches:
            trainer._train_step_info.fetches = ['loss']
            loss = trainer.train_step(data, targets)
        else:
            loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Eval once just to test fetches in action
    val_data, val_targets = batcher_fn(val_data, 0)
    if fetches:
        trainer._train_step_info.fetches = ['loss']
        loss = trainer.eval_step(val_data, val_targets)
        trainer._train_step_info.fetches = []
    loss, preds = trainer.eval_step(val_data, val_targets)

    # Compare loss to ground truth computed from current ORTTrainer API
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
    assert trainer._onnx_model is not None
def testTrainStepInfo():
    '''Test valid initializations of TrainStepInfo'''
    optimizer_config = optim.LambConfig()
    fetches = ['out1', 'out2']
    step_info = orttrainer.TrainStepInfo(optimizer_config=optimizer_config,
                                         all_finite=False,
                                         fetches=fetches,
                                         optimization_step=123,
                                         step=456)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite == False
    assert step_info.fetches == fetches
    assert step_info.optimization_step == 123
    assert step_info.step == 456

    step_info = orttrainer.TrainStepInfo(optimizer_config)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite == True
    assert step_info.fetches == []
    assert step_info.optimization_step == 0
    assert step_info.step == 0
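# Illustrative sketch (not a test): TrainStepInfo.fetches is the mechanism used elsewhere in
# this file (see testORTTrainerMixedPrecisionLossScaler) to restrict which outputs a
# train_step call returns. This helper assumes access to the same private
# trainer._train_step_info attribute those tests rely on; the helper name is hypothetical.
def _example_fetch_only_loss(trainer, data, targets):
    trainer._train_step_info.fetches = ['loss']   # request only the loss output
    loss = trainer.train_step(data, targets)
    trainer._train_step_info.fetches = []         # restore the default (all outputs)
    return loss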
def testORTTrainerToyBERTModel():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # Modeling
    pytorch_transformer_path = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata')
    bert_onnx_model_path = os.path.join(pytorch_transformer_path, "bert_toy_postprocessed.onnx")
    model = onnx.load(bert_onnx_model_path)
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Generate fake input
    sample_input = generate_random_input_from_model_desc(model_desc)

    # Train
    output = trainer.train_step(*sample_input)

    # Check output
    assert output.shape == torch.Size([])
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup for the experimental ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'device': {'id': device},
        'debug': {'deterministic_compute': True},
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _ = trainer.train_step(data, targets)

    # Setup for the legacy ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device,
                                       _use_deterministic_compute=True)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    # Compare legacy vs experimental APIs
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)
def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps):
    # Common data
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig(params, alpha=0.9, beta=0.999,
                                    lambda_coef=0.01, epsilon=1e-6,
                                    do_bias_correction=False)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
    })
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "LambOptimizer",
                                       legacy_optim_map, learning_rate_description, device,
                                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device):
    total_steps = 1
    initial_lr = 1.
    tolerance = 1e-4

    # PyTorch Transformer model as example
    opts = {'device': {'id': device}}
    if lr_scheduler:
        total_steps = 10
        opts.update({'lr_scheduler': lr_scheduler(total_steps=total_steps, warmup=0.5)})
    opts = orttrainer.ORTTrainerOptions(opts)
    optim_config = optim.LambConfig(lr=initial_lr)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    # Choose data for a train or evaluation step
    if step_fn == 'eval_step':
        data, targets = batcher_fn(val_data, 0)
    elif step_fn == 'train_step':
        data, targets = batcher_fn(train_data, 0)
    else:
        raise ValueError('Invalid step_fn')

    # Run the selected step; the first call also exports the model to ONNX
    if step_fn == 'eval_step':
        step_fn = trainer.eval_step
        output = trainer.eval_step(data, targets)
    elif step_fn == 'train_step':
        step_fn = trainer.train_step
        for i in range(total_steps):
            output = trainer.train_step(data, targets)
            if lr_scheduler:
                lr_list = trainer.options.lr_scheduler.get_last_lr()
                assert_allclose(lr_list[0], expected_lr_values[i], rtol=tolerance, err_msg="lr mismatch")
    else:
        raise ValueError('Invalid step_fn')
    assert trainer._onnx_model is not None

    # Check output shape after train/eval step
    for out, desc in zip(output, trainer.model_desc.outputs):
        if trainer.loss_fn and desc.is_loss:
            continue
        assert list(out.size()) == desc.shape

    # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs
    sig = inspect.signature(model.forward)
    for i in range(len(sig.parameters.keys())):
        input_name = trainer.model_desc.inputs[i][0]
        input_dim = trainer.model_desc.inputs[i][1]
        input_type = trainer.model_desc.inputs[i][2]

        assert trainer._onnx_model.graph.input[i].name == input_name
        for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim):
            assert input_dim[dim_idx] == dim.dim_value
        assert input_type == _utils.dtype_onnx_to_torch(
            trainer._onnx_model.graph.input[i].type.tensor_type.elem_type)

    # Check name, shape and dtype of the ORT graph outputs
    for i in range(len(trainer.model_desc.outputs)):
        output_name = trainer.model_desc.outputs[i][0]
        output_dim = trainer.model_desc.outputs[i][1]
        output_type = trainer.model_desc.outputs[i][3]

        assert trainer._onnx_model.graph.output[i].name == output_name
        for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
            assert output_dim[dim_idx] == dim.dim_value
        assert output_type == _utils.dtype_onnx_to_torch(
            trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

    # Save the current model as ONNX to a file
    file_name = os.path.join('..', '..', '..', 'temp_onnx_model.onnx')
    trainer.save_as_onnx(file_name)
    assert os.path.exists(file_name)
    with open(file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(file_name)

    # Create a new trainer from the persisted ONNX model and compare with the original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)
    step_fn(data, targets)
    assert trainer_from_onnx._onnx_model is not None
    assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)
    assert trainer_from_onnx._onnx_model == trainer._onnx_model
    assert trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) ==
            onnx.helper.printable_graph(trainer._onnx_model.graph))
def testDynamicLossScaler():
    rtol = 1e-5
    default_scaler = amp.loss_scaler.DynamicLossScaler()

    # Initial state
    train_step_info = orttrainer.TrainStepInfo(optim.LambConfig())
    assert_allclose(default_scaler.loss_scale, float(1 << 16), rtol=rtol, err_msg="loss scale mismatch")
    assert default_scaler.up_scale_window == 2000
    assert_allclose(default_scaler.min_loss_scale, 1.0, rtol=rtol, err_msg="min loss scale mismatch")
    assert_allclose(default_scaler.max_loss_scale, float(1 << 24), rtol=rtol, err_msg="max loss scale mismatch")

    # Perform 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True)
    loss_scale = float(1 << 16)
    for cycles in range(1, 10):
        # 1999 updates without overflow produce 1999 stable steps
        for i in range(1, 2000):
            new_loss_scale = default_scaler.update(train_step_info)
            assert default_scaler._stable_steps_count == i
            assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg=f"loss scale mismatch at update {i}")

        # The 2000th update without overflow doubles the loss scale and resets the stable-step
        # count, until max_loss_scale is reached
        new_loss_scale = default_scaler.update(train_step_info)
        if cycles <= 8:
            loss_scale *= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch")

    # After 8 cycles, the loss scale should be float(1 << 16) * (2**8)
    assert_allclose(new_loss_scale, float(1 << 16) * (2**8), rtol=rtol, err_msg="loss scale mismatch")

    # After 9 cycles, the loss scale has reached max_loss_scale and is not doubled from that point on
    loss_scale = float(1 << 16) * (2**8)
    for count in range(1, 2050):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == (count % 2000)
        assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch")

    # Set train_step_info.all_finite = False to test down-scaling
    train_step_info.all_finite = False

    # Perform 24 updates, each of which halves the loss scale
    loss_scale = float(1 << 16) * (2**8)
    for count in range(1, 25):
        new_loss_scale = default_scaler.update(train_step_info)
        loss_scale /= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch")

    # After 24 updates with gradient overflow, the loss scale is 1.0
    assert_allclose(new_loss_scale, 1., rtol=rtol, err_msg="loss scale mismatch")

    # From the 25th update on, min_loss_scale is reached and the loss scale is not halved further
    for count in range(1, 5):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch")
def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
             allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step,
             loss_scaler, use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)
        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
            'device': {'id': device},
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {'deterministic_compute': True},
            'utils': {'grad_norm_clip': True},
            'distributed': {'allreduce_post_accumulation': True},
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch'],),
                ('attention_mask', ['batch', 'max_seq_len_in_batch'],),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch'],),
                ('next_sentence_label', ['batch', ])],
            'outputs': [
                ('loss', [1, ], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model, None, model_desc, "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=12)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate, ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale, ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

            print(outputs[0])

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)