def testORTTrainerFrozenWeights(model_params):
    """Verify 'frozen_weights' support on the transformer model.

    Without freezing, every parameter in `model_params` must appear in the
    training session state; with the parameters frozen, none of them may.
    """
    # Common setup
    device = 'cuda'
    total_steps = 10

    # Setup ORTTrainer WITHOUT frozen weights
    options = orttrainer.ORTTrainerOptions({})
    model, model_desc, my_loss, batcher_fn,\
        train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    options = orttrainer.ORTTrainerOptions({'utils' : {'frozen_weights' : model_params}})
    model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = trainer.train_step(data, targets)

    # No frozen parameter may remain in the session state.
    # BUG FIX: the original asserted `not all(...)`, which succeeds as soon
    # as a single parameter is missing. `not any(...)` asserts that NONE of
    # the frozen parameters are present, matching the experimental-API
    # variant of this test elsewhere in this file.
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
def testORTDeterministicCompute(seed, device):
    """Two ORTTrainer instances built from identical seeds and options must
    end up with identical ONNX weights after one identical train step."""
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device' : {
            'id' : device,
            'mem_limit' : 10*1024*1024
        }
    })

    # Setup for the first ORTTRainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None

    # Setup for the second ORTTRainer run
    # Seeds are re-applied before reloading so the second model starts from
    # the exact same random initialization as the first.
    torch.manual_seed(seed)
    set_seed(seed)
    model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    _ = second_trainer.train_step(data, targets)
    assert second_trainer._onnx_model is not None

    # Compare two different instances with identical setup
    # (different objects — this is not a trivial identity comparison)
    assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model)
    _test_helpers.assert_onnx_weights(first_trainer, second_trainer)
def testORTTrainerFrozenWeights(model_params):
    """Frozen-weights test for the experimental API on the toy BERT ONNX
    model: frozen parameters must be absent from the session state."""
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    # (reuses opts_dict, adding only the 'utils.frozen_weights' entry)
    opts_dict.update({'utils': {'frozen_weights': model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params CANNOT be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
def testToyBertCheckpointFrozenWeights():
    """Checkpoint round-trip with a frozen weight: a trainer restored from
    another trainer's state dict must produce the same eval loss and expose
    the same state keys."""
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'utils': {
            'frozen_weights': ['bert.encoder.layer.0.attention.self.value.weight']
        }
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    # NOTE(review): the fixed `seed` (not the loop index) is passed to the
    # input generator, so every step trains on the same batch — presumably
    # intentional for reproducibility here; confirm against the generator.
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    # A fresh, unseen batch for evaluation
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)

    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)

    # Save checkpoint
    state_dict = checkpoint.experimental_state_dict(trainer)

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)

    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
    assert state_dict.keys() == loaded_state_dict.keys()
def testToyBertCheckpointBasic():
    """Basic checkpoint lifecycle: extract a state dict, mutate and reload
    it, save a checkpoint to disk, then load it into a fresh trainer and
    verify the state round-trips exactly."""
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {
        'deterministic_compute': True
    }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ##  when the specified model for ORTTRainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    # Drop the first trainer/model so the reload below cannot alias them
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert whether original state and the one loaded from checkpoint matches
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
def testToyBERTSaveAsONNX():
    """save_as_onnx must persist a model that, when reloaded, builds an
    equivalent trainer: same graph structure and same weights."""
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    # Start from a clean slate in case a previous run left the file behind
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)
    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts)
    assert trainer_from_onnx._onnx_model is not None
    # Must be a distinct object, not an alias of the original model
    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
    for initializer, loaded_initializer in zip(trainer._onnx_model.graph.initializer,
                                               trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    # Graph topology must be textually identical
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) ==
            onnx.helper.printable_graph(trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
def testToyBERTModelGradientAccumulationLegacyExperimental(
        gradient_accumulation_steps):
    """Gradient-accumulation parity between the experimental and legacy
    ORTTrainer APIs: per-step losses must match within 1e-6."""
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    # NOTE(review): the same `model` object is reused for the legacy trainer
    # (not reloaded) — assumes the experimental run does not mutate it; confirm.
    legacy_trainer = Legacy_ORTTrainer(
        model, None, legacy_model_desc, "LambOptimizer", None,
        learning_rate_description, device,
        _use_deterministic_compute=True,
        gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, rtol=1e-6)
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    """Train toy BERT for ten steps under the given gradient-accumulation
    setting and compare the per-step losses against golden values."""
    # Deterministic setup
    step_count = 10
    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    # Build the model, its description, and the trainer
    bert_desc = bert_model_description()
    bert_model = load_bert_onnx_model()
    lamb_config = optim.LambConfig()
    trainer_opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': "cuda"},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    })
    trainer = orttrainer.ORTTrainer(bert_model, bert_desc, lamb_config, options=trainer_opts)

    # Run training, collecting one scalar loss per step
    observed_losses = []
    for step in range(step_count):
        batch = generate_random_input_from_model_desc(bert_desc, step)
        observed_losses.append(trainer.train_step(*batch).cpu().item())

    # Validate against the expected values
    _test_helpers.assert_model_outputs(observed_losses, expected_losses, rtol=1e-6)
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    """Run ten mixed-precision training steps with the supplied loss scaler
    and check the observed losses against golden values."""
    # Deterministic setup
    step_count = 10
    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    # Build the model, its description, and a mixed-precision trainer
    bert_desc = bert_model_description()
    bert_model = load_bert_onnx_model()
    lamb_config = optim.LambConfig()
    trainer_opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': 'cuda'},
        'mixed_precision': {'enabled': True, 'loss_scaler': loss_scaler},
    })
    trainer = orttrainer.ORTTrainer(bert_model, bert_desc, lamb_config, options=trainer_opts)

    # Run training, collecting one scalar loss per step
    observed_losses = []
    for step in range(step_count):
        batch = generate_random_input_from_model_desc(bert_desc, step)
        observed_losses.append(trainer.train_step(*batch).cpu().item())

    # fp16 runs tolerate a looser rtol than the fp32 tests
    _test_helpers.assert_model_outputs(observed_losses, expected_losses, rtol=1e-4)
def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates,
                                expected_losses):
    """Train with the given LR-scheduler class and compare both the observed
    learning-rate schedule and the losses against golden values."""
    # Common setup
    device = 'cuda'
    total_steps = 10
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        # warmup is fixed at 50% of total_steps for this test
        'lr_scheduler' : lr_scheduler(total_steps=total_steps, warmup=0.5)
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train, recording the LR actually applied after each step
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=1e-6)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)
def testToyBERTDeterministicCheck(expected_losses):
    """Train toy BERT for 10 deterministic steps and compare the per-step
    losses against pre-recorded golden values."""
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    # FIX: removed unused local `params = optimizer_parameters(model)` — the
    # result was never passed to the optimizer config, so the call was dead code.
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=1e-6)
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(
        loss_scaler, legacy_loss_scaler):
    """Mixed-precision parity between the experimental and legacy APIs over
    128 steps, using equivalent loss scalers for each API."""
    # Common setup
    # 128 steps so dynamic loss scaling has time to adjust in both runs
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    # Reload the model so the legacy run starts from the same fresh weights
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc,
                                       "AdamOptimizer", None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testToyBertCheckpointLoadZero():
    """Load a pre-built ('bert_toy_lamb') checkpoint into a fresh trainer and
    verify the eval loss on a fixed batch matches the recorded value."""
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True},
                                         'device' : {'id' : device},
                                         'distributed' : {'allreduce_post_accumulation' : True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    # Fixed input batch (golden loss above was recorded against exactly this data)
    input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[635],[27399],[20647],[18519],[15546]], device=device)
    segment_ids = torch.tensor([[0],[1],[0],[1],[0],[0],[1],[0],[0],[1],[1],[0],[0],[1],[1],[1]], device=device)
    input_mask = torch.tensor([[0],[0],[0],[0],[1],[1],[1],[0],[1],[1],[0],[0],[0],[1],[0],[0]], device=device)
    masked_lm_labels = torch.tensor([[25496],[16184],[11005],[16228],[14884],[21660],[ 8678],[23083],[ 4027],[ 8397],[11921],[ 1333],[26482],[ 1666],[17925],[27978]], device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses):
    """Instantiate the given scheduler class with its class-specific
    constructor arguments, then verify the LR schedule and losses."""
    # Common setup
    device = 'cuda'
    total_steps = 10
    seed = 1
    # Scheduler hyper-parameters (only the ones each class accepts are used)
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Setup LR Schedulers — each class takes a different constructor signature
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train, recording the LR actually applied after each step
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=1e-6)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)
def testToyBERTModelSimpleTrainStep(dynamic_shape):
    """Smoke-test: each train_step on toy BERT must return a scalar loss.

    Args:
        dynamic_shape: forwarded to bert_model_description() to toggle
            dynamic axes in the model description.
    """
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    # FIX: the loop index was unused (no per-step seed is passed to the input
    # generator), so use the conventional `_` placeholder.
    for _ in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        # A fully-reduced loss is a 0-d tensor
        assert output.shape == torch.Size([])
def testToyBERTModelLegacyExperimentalBasicTraining():
    """Basic training parity between the experimental and legacy ORTTrainer
    APIs: per-step losses must match within 1e-5."""
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    # FIX: removed unused local `params = optimizer_parameters(model)` — the
    # result was never passed to the optimizer config, so the call was dead code.
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        lr=0.001)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc,
                                       "LambOptimizer", None,
                                       learning_rate_description, device)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True, rtol=1e-5)
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    """Mixed-precision parity on the transformer model between the
    experimental and legacy APIs: losses, prediction dtypes, and final
    weights must all match."""
    # Common data
    total_steps = 5
    bptt=35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True,}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop — record both the loss and the dtype of predictions
    # (fp16 training should surface in the prediction dtype)
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
def testToyBertStateDictWrapModelLossFn():
    """state_dict behavior when ORTTrainer wraps a PyTorch model plus a
    separate loss_fn: the dict is empty before the first train_step and
    contains exactly the model's parameters afterwards."""
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 4)

        def forward(self, y=None, x=None):
            # `y` acts as an optional additive term on top of the projection
            if y is not None:
                return self.linear(x) + y
            else:
                return self.linear(x) + torch.ones(2, 4)

    pt_model = LinearModel()
    model_desc = {
        'inputs': [('x', [2, 2]), ('label', [
            2,
        ])],
        'outputs': [('loss', [], True), ('output', [2, 4])]
    }
    optim_config = optim.SGDConfig(lr=0.02)

    def loss_fn(x, label):
        return F.nll_loss(F.log_softmax(x, dim=1), label)

    trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn)

    # Compare resulting state_dict keys before train
    # (the ONNX export happens lazily on the first train_step, so the
    # trainer has no state yet)
    state_dict = checkpoint.experimental_state_dict(trainer)
    assert state_dict == {}

    # Executing train_step() once
    data = torch.randn(2, 2)
    label = torch.tensor([0, 1], dtype=torch.int64)
    trainer.train_step(x=data, label=label)

    # Compare resulting state_dict keys after train
    state_dict = checkpoint.experimental_state_dict(trainer)
    assert state_dict.keys() == {'linear.bias', 'linear.weight'}
def testORTTrainerDynamicShape(dynamic_axes):
    """Smoke-test training the transformer model with dynamic input shapes."""
    target_device = 'cuda'

    # Build model/data and an ORTTrainer with default options
    trainer_opts = orttrainer.ORTTrainerOptions({})
    (model, model_desc, my_loss, batcher_fn,
     train_data, val_data, _) = _load_pytorch_transformer_model(target_device, dynamic_axes=dynamic_axes)
    lamb_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, lamb_config, loss_fn=my_loss, options=trainer_opts)

    # Ten iterations are enough to exercise the exported dynamic-shape graph
    num_steps = 10
    for step in range(num_steps):
        batch, labels = batcher_fn(train_data, step)
        _loss, _preds = trainer.train_step(batch, labels)

    # The ONNX model must have been materialized by the first train_step
    assert trainer._onnx_model is not None
def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss):
    """Gradient accumulation on the transformer model: per-step losses must
    match the pre-recorded expected values."""
    torch.manual_seed(seed)
    set_seed(seed)

    # Setup ORTTrainer
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Compare against the recorded golden losses
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=1e-6)
def train_ort_model(epoch=1):
    """Train the PyTorch TransformerModel with ORTTrainer over one pass of
    the training data, logging loss/perplexity every `log_interval` batches.

    Args:
        epoch: epoch number, used only in the log output.
    """
    device = "cuda"
    ntokens = 28785
    bptt = 35
    batch_size = 20
    initial_lr = 0.001

    train_data, val_data, test_data = prepare_data(device, 20, 20)
    pt_model_path = os.path.join('pt_model.py')
    pt_model = _utils.import_module_from_file(pt_model_path)
    # FIX: use the named constant instead of repeating the magic number 28785
    model = pt_model.TransformerModel(ntokens, 200, 2, 200, 2, 0.2).to(device)
    model_desc = {'inputs': [('input1', [bptt, batch_size]),
                             ('label', [bptt * batch_size])],
                  'outputs': [('loss', [], True),
                              ('predictions', [bptt, batch_size, ntokens])]}
    opts = orttrainer.ORTTrainerOptions({'device' : {'id' : device}})
    optim_config = optim.SGDConfig(lr=initial_lr)
    # NOTE(review): `my_loss` is resolved from module scope, not defined here
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    total_loss = 0.
    start_time = time.time()
    # FIX: the stride limit previously hard-coded 35; use bptt so the window
    # length and the iteration stride cannot silently drift apart.
    for batch, i in enumerate(range(0, train_data.size(0) - bptt, bptt)):
        data, targets = get_batch(train_data, i)
        output = trainer.train_step(data, targets)
        total_loss += output[0].item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {} | epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.3f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                      device, epoch, batch, len(train_data) // bptt, initial_lr,
                      elapsed * 1000 / log_interval,
                      cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches):
    """Mixed-precision training with a dynamic loss scaler; when `fetches`
    is set, also exercise restricting train/eval outputs to just 'loss'."""
    total_steps = len(expected_loss)
    torch.manual_seed(seed)
    set_seed(seed)
    bptt=35

    # Setup ORTTrainer
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        if fetches:
            # Restrict outputs to only the loss — train_step then returns a
            # single value instead of the (loss, preds) pair
            trainer._train_step_info.fetches=['loss']
            loss = trainer.train_step(data, targets)
        else:
            loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Eval once just to test fetches in action
    val_data, val_targets = batcher_fn(val_data, 0)
    if fetches:
        trainer._train_step_info.fetches=['loss']
        loss = trainer.eval_step(val_data, val_targets)
        # Reset fetches so the full (loss, preds) eval below works again
        trainer._train_step_info.fetches=[]
    loss, preds = trainer.eval_step(val_data, val_targets)

    # Compare loss to ground truth computed from current ORTTrainer API
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
    assert trainer._onnx_model is not None
def testORTTrainerToyBERTModel():
    """A single train_step over the toy BERT ONNX model must yield a
    scalar loss tensor."""
    # Deterministic setup
    fixed_seed = 1
    torch.manual_seed(fixed_seed)
    set_seed(fixed_seed)

    # Load the post-processed toy BERT model that ships with the test data
    testdata_dir = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata')
    toy_bert = onnx.load(os.path.join(testdata_dir, "bert_toy_postprocessed.onnx"))
    toy_desc = bert_model_description()

    lamb_config = optim.LambConfig()
    trainer_opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}})
    trainer = orttrainer.ORTTrainer(toy_bert, toy_desc, lamb_config, options=trainer_opts)

    # Generate fake input and run exactly one training step
    fake_batch = generate_random_input_from_model_desc(toy_desc)
    step_output = trainer.train_step(*fake_batch)

    # A fully-reduced loss is a 0-d tensor
    assert step_output.shape == torch.Size([])
def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps):
    """Gradient-accumulation parity on the transformer model between the
    experimental and legacy ORTTrainer APIs."""
    # Common data
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API — seeds re-applied so both runs start identically
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
    """After five identical training steps, the experimental and legacy
    trainers must hold matching ONNX weights."""
    # Common data
    total_steps = 5
    bptt = 35

    # Setup for the experimental ORTTRainer run
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'device' : {
            'id' : device
        },
        'debug' : {
            'deterministic_compute': True
        },
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _ = trainer.train_step(data, targets)

    # Setup for the legacy ORTTrainer run — seeds re-applied so both runs
    # start from identical random state
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device,
                                       _use_deterministic_compute=True)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    # Compare legacy vs experimental APIs
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)
def train(self):
    """
    Main training entry point.

    Wraps ``self.model`` in an ORT-backed trainer (the new ``orttrainer``
    API when ``self.use_new_api`` is set, otherwise the legacy
    ``ORTTrainer``), then runs the epoch/step loop with gradient
    accumulation and periodic logging.

    Returns:
        TrainOutput: the final global step count and average training loss.
    """
    train_dataloader = self.get_train_dataloader()
    # Derive total optimization steps: either capped by args.max_steps
    # or computed as epochs * (batches // accumulation steps).
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (
            self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        )
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    if self.use_new_api:
        # New (experimental) front-end: options object + optimizer config.
        lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total))

        loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
        device = self.args.device.type
        # Normalize torch.device into an "<type>:<index>" string; a missing
        # index defaults to 0.
        device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0'
        options = orttrainer.ORTTrainerOptions({
            'batch': {'gradient_accumulation_steps': self.args.gradient_accumulation_steps},
            'device': {'id': device},
            'mixed_precision': {'enabled': self.args.fp16, 'loss_scaler': loss_scaler},
            'debug': {'deterministic_compute': True, },
            'utils': {'grad_norm_clip': False},
            'distributed': {'allreduce_post_accumulation': True},
            'lr_scheduler': lr_scheduler
        })

        param_optimizer = list(self.model.named_parameters())
        # Two parameter groups: biases/LayerNorm weights vs. everything else.
        # NOTE(review): both groups use weight_decay_mode 1, so the split has
        # no observable effect here — confirm whether one group was meant to
        # use a different mode.
        params = [{
            'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
            "weight_decay_mode": 1, }, {
            'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
            "weight_decay_mode": 1, }
        ]

        optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
        # The ORT trainer replaces the raw model on self from here on.
        self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
    else:
        def map_optimizer_attributes(name):
            # Classify parameters into decay / no-decay groups by name.
            # NOTE(review): both branches return the same dict, so this split
            # is currently a no-op — confirm the intended modes.
            no_decay = "bias" in name or "LayerNorm.weight" in name
            if no_decay:
                return {"weight_decay_mode" : 1}
            else:
                return {"weight_decay_mode" : 1}
        get_lr_this_step = get_linear_schedule_with_warmup(self.args.warmup_steps, t_total, self.args.learning_rate)
        loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) if self.args.fp16 else None
        # Legacy front-end: all configuration passed as constructor kwargs.
        self.model = ORTTrainer(self.model, None,
                                self.model_desc,
                                "AdamOptimizer",
                                map_optimizer_attributes=map_optimizer_attributes,
                                learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
                                device=self.args.device,
                                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
                                use_mixed_precision=self.args.fp16,
                                allreduce_post_accumulation=True,
                                get_lr_this_step=get_lr_this_step,
                                loss_scaler=loss_scaler,
                                enable_grad_norm_clip=False,
                                _opset_version=12,
                                _use_deterministic_compute=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataloader.dataset))
    logger.info(" Num Epochs = %d", num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        self.args.train_batch_size
        * self.args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    tr_loss = 0.0
    logging_loss = 0.0
    train_iterator = trange(
        epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0],
    )
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
        for step, inputs in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            tr_loss += self._training_step(self.model, inputs)

            # An optimization step completes every accumulation boundary, or
            # at the end of a short epoch that never reaches the boundary.
            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                len(epoch_iterator) <= self.args.gradient_accumulation_steps
                and (step + 1) == len(epoch_iterator)
            ):
                global_step += 1

                # Only rank 0 (or single-process runs) logs.
                if self.args.local_rank in [-1, 0]:
                    if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
                        global_step == 1 and self.args.logging_first_step
                    ):
                        logs = {}
                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            for key, value in results.items():
                                eval_key = "eval_{}".format(key)
                                logs[eval_key] = value

                        # Average loss accumulated since the last log line.
                        loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps
                        if not self.use_new_api:
                            # Legacy path computes the LR outside the trainer.
                            learning_rate_scalar = get_lr_this_step(global_step)
                            logs["learning_rate"] = learning_rate_scalar
                        logs["loss"] = loss_scalar
                        logging_loss = tr_loss

                        epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                epoch_iterator.close()
                break
        if self.args.max_steps > 0 and global_step > self.args.max_steps:
            train_iterator.close()
            break

    logger.info("\n\nTraining completed. \n\n")
    return TrainOutput(global_step, tr_loss / global_step)
def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
             allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step,
             loss_scaler, use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    """Train a BERT toy model through either ORTTrainer front-end and return
    its eval-step outputs.

    Exercises the three ways of feeding a batch (all-positional, all-keyword,
    mixed) selected by ``batch_args_option``, with optional fp16 loss scaling
    and an externally- or internally-driven learning rate.

    Returns:
        A generator of numpy arrays, one per eval-step output.
    """
    dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
            'device': {'id': device},
            'mixed_precision': {'enabled': fp16, 'loss_scaler': new_api_loss_scaler},
            'debug': {'deterministic_compute': True, },
            'utils': {'grad_norm_clip': True},
            'distributed': {'allreduce_post_accumulation': True},
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        # Two Lamb parameter groups (bias/LayerNorm vs. the rest) sharing the
        # same hyper-parameters.
        params = [{
            'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }]

        vocab_size = 99
        # Model description for the new API: named inputs/outputs with
        # symbolic batch/sequence dimensions.
        new_model_desc = {
            'inputs': [(
                'input_ids', ['batch', 'max_seq_len_in_batch'],
            ), (
                'attention_mask', ['batch', 'max_seq_len_in_batch'],
            ), (
                'token_type_ids', ['batch', 'max_seq_len_in_batch'],
            ), (
                'masked_lm_labels', ['batch', 'max_seq_len_in_batch'],
            ), ('next_sentence_label', [
                'batch', ])],
            'outputs': [('loss', [
                1, ], True), ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model, None, model_desc, "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [
                1, ], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank, world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=12)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    # NOTE(review): eval_batch captures the first batch but is never used
    # below — the eval section runs on the LAST batch instead. Confirm which
    # was intended.
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            # Feed the batch in the requested calling convention.
            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate, ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale, ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                # NOTE: rebinds the 'args' parameter to positional batch args.
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = int(len(model_desc.inputs_) / 2)  # approx half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            print(outputs[0])

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) / 2)  # approx half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    """Custom per-parameter-group Lamb settings must yield identical losses
    through the experimental and legacy ORTTrainer front-ends."""
    # Shared configuration
    total_steps = 10
    device = "cuda"
    seed = 1

    # --- Experimental API ---
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig(params, alpha=0.9, beta=0.999, lambda_coef=0.01,
                                    epsilon=1e-6, do_bias_correction=False)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device, },
    })
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for step in range(total_steps):
        inputs = generate_random_input_from_model_desc(model_desc, step)
        experimental_losses.append(trainer.train_step(*inputs).cpu().item())

    # --- Legacy API (re-seeded so both runs see identical randomness) ---
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "LambOptimizer",
                                       legacy_optim_map, learning_rate_description, device,
                                       _use_deterministic_compute=True)
    legacy_losses = []
    for step in range(total_steps):
        inputs = generate_random_input_from_model_desc(model_desc, step)
        # Legacy API takes one list with the learning rate appended.
        legacy_losses.append(legacy_trainer.train_step([*inputs, learning_rate]).cpu().item())

    # Both runs must produce matching losses.
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler):
    """LR schedules must match between the experimental and legacy trainers,
    both in per-step LR values and in the resulting losses."""
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Shared configuration
    total_steps = 10
    device = 'cuda'
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7

    # Bind the legacy scheduler callable to its hyper-parameters before the
    # experimental loop (it is evaluated per-step for the LR assertion).
    common = dict(initial_lr=initial_lr, total_steps=total_steps, warmup=warmup)
    if legacy_lr_scheduler in (_test_commons.legacy_constant_lr_scheduler,
                               _test_commons.legacy_linear_lr_scheduler):
        legacy_lr_scheduler = partial(legacy_lr_scheduler, **common)
    elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, cycles=cycles, **common)
    elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, power=power, lr_end=lr_end, **common)
    else:
        raise RuntimeError("Invalid legacy_lr_scheduler")

    # Instantiate the experimental scheduler with matching hyper-parameters.
    if lr_scheduler in (optim.lr_scheduler.ConstantWarmupLRScheduler,
                        optim.lr_scheduler.LinearWarmupLRScheduler):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # --- Experimental API ---
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device, },
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for step in range(total_steps):
        inputs = generate_random_input_from_model_desc(model_desc, step)
        experimental_losses.append(trainer.train_step(*inputs).cpu().item())
        # The experimental scheduler must track the legacy schedule exactly.
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(step))

    # --- Legacy API (re-seeded for an identical run) ---
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer", None,
                                       learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for step in range(total_steps):
        inputs = generate_random_input_from_model_desc(model_desc, step)
        legacy_losses.append(legacy_trainer.train_step(*inputs).cpu().item())

    # Both runs must produce matching losses.
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device):
    """End-to-end smoke test of ORTTrainer instantiation and ONNX export.

    Runs a train or eval step (selected by the ``step_fn`` string), checks the
    exported ONNX graph's input/output names, shapes and dtypes against the
    model description, saves the model to disk, and rebuilds a trainer from
    the persisted ONNX bytes to verify the round-trip is lossless.
    """
    total_steps = 1
    initial_lr = 1.
    tolerance = 1e-4

    # PyTorch Transformer model as example
    opts = {'device' : {'id' : device}}
    if lr_scheduler:
        # With a scheduler, run enough steps to traverse the warmup schedule.
        total_steps = 10
        opts.update({'lr_scheduler' : lr_scheduler(total_steps=total_steps, warmup=0.5)})
    opts = orttrainer.ORTTrainerOptions(opts)
    optim_config = optim.LambConfig(lr=initial_lr)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    # Run a train or evaluation step
    if step_fn == 'eval_step':
        data, targets = batcher_fn(val_data, 0)
    elif step_fn == 'train_step':
        data, targets = batcher_fn(train_data, 0)
    else:
        raise ValueError('Invalid step_fn')

    # Export model to ONNX
    # NOTE: step_fn is rebound here from its selector string to the
    # corresponding bound method of `trainer`.
    if step_fn == 'eval_step':
        step_fn = trainer.eval_step
        output = trainer.eval_step(data, targets)
    elif step_fn == 'train_step':
        step_fn = trainer.train_step
        for i in range(total_steps):
            output = trainer.train_step(data, targets)
            if lr_scheduler:
                # Scheduler must report the expected LR at each step.
                lr_list = trainer.options.lr_scheduler.get_last_lr()
                assert_allclose(lr_list[0], expected_lr_values[i], rtol=tolerance, err_msg="lr mismatch")
    else:
        raise ValueError('Invalid step_fn')
    assert trainer._onnx_model is not None

    # Check output shape after train/eval step
    for out, desc in zip(output, trainer.model_desc.outputs):
        if trainer.loss_fn and desc.is_loss:
            # Loss output shape is not compared when a loss_fn is attached.
            continue
        assert list(out.size()) == desc.shape

    # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs
    sig = inspect.signature(model.forward)
    for i in range(len(sig.parameters.keys())):
        input_name = trainer.model_desc.inputs[i][0]
        input_dim = trainer.model_desc.inputs[i][1]
        input_type = trainer.model_desc.inputs[i][2]

        assert trainer._onnx_model.graph.input[i].name == input_name
        for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim):
            assert input_dim[dim_idx] == dim.dim_value
            assert input_type == _utils.dtype_onnx_to_torch(
                trainer._onnx_model.graph.input[i].type.tensor_type.elem_type)

    # Check name, shape and dtype of the ORT graph outputs
    for i in range(len(trainer.model_desc.outputs)):
        output_name = trainer.model_desc.outputs[i][0]
        output_dim = trainer.model_desc.outputs[i][1]
        output_type = trainer.model_desc.outputs[i][3]

        assert trainer._onnx_model.graph.output[i].name == output_name
        for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
            assert output_dim[dim_idx] == dim.dim_value
            assert output_type == _utils.dtype_onnx_to_torch(
                trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

    # Save current model as ONNX as a file
    file_name = os.path.join('..','..','..','temp_onnx_model.onnx')
    trainer.save_as_onnx(file_name)
    assert os.path.exists(file_name)
    with open(file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)
    # NOTE(review): step_fn is still bound to the ORIGINAL trainer, so this
    # runs an extra step on `trainer`, not on `trainer_from_onnx` — confirm
    # that is intended before the model-equality assertions below.
    step_fn(data, targets)
    assert trainer_from_onnx._onnx_model is not None
    assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
    assert (trainer_from_onnx._onnx_model == trainer._onnx_model)
    assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph)
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))