def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = 'checkpoint*.ortcp'
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return trainer.state_dict(), model
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint.experimental_load_checkpoint(trainer, checkpoint_dir)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return checkpoint.experimental_state_dict(trainer), model
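# The two load helpers above pair with `create_orttrainer_and_save_checkpoint`
# defined further down in this file. A minimal roundtrip sketch (hypothetical;
# assumes a CUDA device and a writable temporary directory):
def _example_checkpoint_roundtrip():
    import tempfile
    opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
    checkpoint_dir = tempfile.mkdtemp()
    # train a few steps and write checkpoint*.ortcp files into checkpoint_dir
    create_orttrainer_and_save_checkpoint('cuda', opts, checkpoint_dir)
    # reload the checkpoint into a fresh trainer and return its state
    state_dict, pytorch_model = create_orttrainer_and_load_checkpoint('cuda', opts, checkpoint_dir)
    return state_dict, pytorch_model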
def testORTTransformerModelExport(seed, device):
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'check_model_export': True,
        },
        'device': {
            'id': device,
        }
    })

    # Setup for the first ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
    learning_rate = 0.1
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return dummy_init_state, checkpoint.experimental_state_dict(trainer)
def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir,
                                          state_dict_key_name='state_dict', use_lamb=True):
    learning_rate = 0.1
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    if 'distributed' in trainer_opts:
        train_data = next(islice(_chunkify(train_data, trainer_opts['distributed']['world_size']),
                                 trainer_opts['distributed']['world_rank'], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        _save(trainer, checkpoint_dir, state_dict_key_name)
def _create_trainer(zero_enabled=False):
    """Creates a simple ORTTrainer for ORTTrainer functional tests"""
    device = 'cuda'
    optim_config = optim.LambConfig(lr=0.1)
    opts = {'device': {'id': device}, 'debug': {'deterministic_compute': True}}
    if zero_enabled:
        opts['distributed'] = {
            'world_rank': 0,
            'world_size': 1,
            'horizontal_parallel_size': 1,
            'data_parallel_size': 1,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': {
                'stage': 1
            }
        }
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(opts))
    return trainer
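# A minimal usage sketch for the helper above (hypothetical; ZeRO stage 1 on a
# single rank is only meaningful as a functional smoke test):
def _example_zero_smoke_test():
    trainer = _create_trainer(zero_enabled=True)
    _, _, _, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model('cuda')
    data, targets = batcher_fn(train_data, 0)
    loss = trainer.train_step(data, targets)
    assert loss is not None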
def testToyBertCheckpointLoadZero():
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True},
                                         'device': {'id': device},
                                         'distributed': {'allreduce_post_accumulation': True}})

    # Create ORTTrainer and load the checkpoint
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
                              [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
                             device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]],
                               device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]],
                              device=device)
    masked_lm_labels = torch.tensor([[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
                                     [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
                                    device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                                         next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
            "batch": {"gradient_accumulation_steps": gradient_accumulation_steps},
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def testToyBERTDeterministicCheck(expected_losses):
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol)
def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_dir,
                                               state_dict_key_name="state_dict",
                                               use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Saves the trainer state as a checkpoint in checkpoint_dir; for model-parallel runs,
      rank 0 also pickles the complete expected state dict for later comparison
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer, checkpoint_dir, state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][initializer.name] = numpy_helper.to_array(initializer)
                with open(os.path.join(checkpoint_dir, "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]},
        }
    )

    # Create ORTTrainer
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)

    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)

    # Save the trainer state
    state_dict = trainer.state_dict()

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    trainer2.load_state_dict(state_dict)

    # Evaluate once more; must match as both trainers have the same state
    ckpt_loss = trainer2.eval_step(*sample_input)
    assert_allclose(loss.cpu(), ckpt_loss.cpu())

    loaded_state_dict = trainer2.state_dict()
    _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict)
def testORTTrainerFrozenWeights(model_params):
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    opts_dict.update({'utils': {'frozen_weights': model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params CANNOT be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
def testToyBERTModelBasicTraining(dynamic_shape):
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        assert output.shape == torch.Size([])
def create_initialized_orttrainer(device, trainer_opts, use_lamb=True):
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    learning_rate = 1e-10
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    _train(trainer, train_data, batcher_fn)
    return trainer
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
def testToyBERTSaveAsONNX():
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)

    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)
    for initializer, loaded_initializer in zip(trainer._onnx_model.graph.initializer,
                                               trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == \
        onnx.helper.printable_graph(trainer._onnx_model.graph)
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = trainer.state_dict()

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd['model']['full_precision']

    ## Modify one of the state values and load into ORTTrainer
    sd['model']['full_precision']['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    trainer.load_state_dict(sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    trainer.save_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp'))
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    trainer2.load_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp'))
    loaded_sd = trainer2.state_dict()

    # Assert that the original state and the one loaded from checkpoint match
    _test_commons.assert_all_states_close_ort(sd, loaded_sd)
def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False):
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)

    # roundabout way of checking optimizer states: save each rank's state dict
    # into a temporary folder, then read them back and aggregate them on rank 0
    with open(os.path.join(checkpoint_dir, 'distributed_state_' + str(world_rank) + '.pkl'), "wb") as f:
        pickle.dump(trainer_state, f)
    dist.barrier()

    if world_rank == 0:
        num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
        optimizer_states = dict()
        for rank in range(num_states):
            rank_state_dict = None
            with open(os.path.join(checkpoint_dir, 'distributed_state_' + str(rank) + '.pkl'), 'rb') as f:
                rank_state_dict = pickle.load(f)
            # collect optimizer states for later comparison since they are sharded
            aggregate_states(optimizer_states, rank_state_dict['optimizer'])

        # compare optimizer states
        optimizer_config = optim.LambConfig() if use_lamb else optim.AdamConfig()
        actual_optim_state = get_optim_state_from_state_dict(optimizer_states, optimizer_config)
        assert actual_optim_state.keys() == expected_optim_state.keys()
        for param_name, a_state in actual_optim_state.items():
            for k, v in a_state.items():
                assert_allclose(v.reshape(expected_optim_state[param_name][k].shape),
                                expected_optim_state[param_name][k],
                                err_msg=f"Optimizer state mismatch for param {param_name}, key {k}")

    dist.barrier()
    os.remove(os.path.join(checkpoint_dir, 'distributed_state_' + str(world_rank) + '.pkl'))
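# For readers of the test above: `aggregate_states` comes from the shared test
# helpers. A simplified illustrative sketch of the idea (hypothetical; the real
# helper tracks sharding metadata and nested dict layout):
def _example_merge_sharded_optim_states():
    import numpy as np
    rank_shards = [
        {'weight': {'Moment_1': np.array([1.0, 2.0])}},  # shard held by rank 0
        {'weight': {'Moment_1': np.array([3.0, 4.0])}},  # shard held by rank 1
    ]
    merged = {}
    for shard in rank_shards:
        for param, states in shard.items():
            for key, value in states.items():
                merged.setdefault(param, {}).setdefault(key, []).append(value)
    # concatenate the per-rank pieces back into full tensors
    full = {p: {k: np.concatenate(v) for k, v in s.items()} for p, s in merged.items()}
    assert full['weight']['Moment_1'].tolist() == [1.0, 2.0, 3.0, 4.0]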
def _run_adasum_tests(opts):
    # Common setup
    seed = 42
    optim_config = optim.LambConfig()

    # Setup ORTTrainer
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(opts.device.id)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

    # Train once to check that the Adasum flag goes through
    data, targets = batcher_fn(train_data, 0)
    result = trainer.train_step(data, targets)
    assert result is not None
def create_orttrainer_and_load_checkpoint_bart(device, trainer_opts, checkpoint_dir,
                                               use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # model setup
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    expected_state_dict = None
    fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl")
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            expected_state_dict = pickle.load(f)

    return trainer.state_dict(), expected_state_dict, model
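# A minimal usage sketch pairing the BART save/load helpers above (hypothetical;
# assumes a CUDA device and a writable temporary directory):
def _example_bart_checkpoint_roundtrip():
    import tempfile
    opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
    checkpoint_dir = tempfile.mkdtemp()
    create_orttrainer_and_save_checkpoint_bart('cuda', opts, checkpoint_dir)
    state_dict, expected_state_dict, onnx_model = \
        create_orttrainer_and_load_checkpoint_bart('cuda', opts, checkpoint_dir)
    # expected_state_dict.pkl is only written by model-parallel runs, so it may be None here
    if expected_state_dict is not None:
        assert state_dict.keys() == expected_state_dict.keys()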
def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name="state_dict", use_lamb=True, seed=1, learning_rate=0.1): torch.manual_seed(seed) set_seed(seed) ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts) optim_config = optim.LambConfig( lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model( device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=ort_trainer_opts) if "distributed" in trainer_opts: train_data = next( islice( _chunkify(train_data, trainer_opts["distributed"]["world_size"]), trainer_opts["distributed"]["world_rank"], None, )) # run train steps _train(trainer, train_data, batcher_fn) # save current model parameters as a checkpoint if checkpoint_dir: if _is_model_parallel_run(ort_trainer_opts): _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=ort_trainer_opts.distributed.world_rank) else: _save(trainer, checkpoint_dir, state_dict_key_name)
@pytest.mark.parametrize("optimizer, mixedprecision_enabled", [
    (optim.LambConfig(), False),
    (optim.AdamConfig(), False),
    (optim.LambConfig(), True),
    (optim.AdamConfig(), True),
])
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
             allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step,
             loss_scaler, use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch', ]),
            ],
            'outputs': [
                ('loss', [1, ], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2]),
            ]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate, ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale, ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = int(len(model_desc.inputs_) / 2)  # approx half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) / 2)  # approx half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
def create_ort_trainer(args, device, model):
    # MODEL INPUT AND OUTPUT DESCRIPTION
    vocab_size = 30528
    micro_batch = args.train_batch_size // args.gradient_accumulation_steps
    model_desc = {
        'inputs': [
            ('input_ids', [args.train_batch_size, args.max_seq_length]),
            ('segment_ids', [args.train_batch_size, args.max_seq_length]),
            ('input_mask', [args.train_batch_size, args.max_seq_length]),
            ('masked_lm_labels', [args.train_batch_size, args.max_seq_length]),
            ('next_sentence_labels', [args.train_batch_size, 2]),
        ],
        'outputs': [('loss', [], True)]
    }

    # TRAINING OPTIMIZER SPECIFICATION
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    optim_config = optim.LambConfig(
        lr=args.learning_rate,
        alpha=0.9,
        beta=0.999,
        lambda_coef=0.01,
        epsilon=1e-6,
        do_bias_correction=True,
        params=[{
            'params': [n for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'alpha': 0.9,
            'beta': 0.999,
            'lambda_coef': 0.00,
            'epsilon': 1e-6
        }])

    # LEARNING RATE SCHEDULE SPECIFICATION
    lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(
        total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    # ONNXRUNTIME TRAINER OPTIONS
    trainer_config = ORTTrainerOptions({
        'device': {
            'id': str(device),
            'mem_limit': int(args.gpu_memory_limit_gb * 1024 * 1024 * 1024)
        },
        'batch': {
            'gradient_accumulation_steps': args.gradient_accumulation_steps
        },
        'distributed': {
            'world_size': args.world_size,
            'world_rank': args.world_rank,
            'allreduce_post_accumulation': True if args.allreduce_post_accumulation else False,
            'deepspeed_zero_optimization': {
                'stage': 1 if args.deepspeed_zero_stage else 0,
            }
        },
        'lr_scheduler': lr_scheduler,
        'mixed_precision': {
            'enabled': True if args.fp16 else False,
        }
    })

    # ONNXRUNTIME TRAINER CONSTRUCTION (loss fn embedded in model)
    trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=None, options=trainer_config)

    return trainer
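# A minimal usage sketch for the factory above (hypothetical; the arg values
# are illustrative only, not from the original script):
def _example_create_ort_trainer(model):
    """`model` is assumed to be a PyTorch BERT pre-training module
    (it must expose named_parameters())."""
    from types import SimpleNamespace
    args = SimpleNamespace(
        train_batch_size=16, gradient_accumulation_steps=1, max_seq_length=128,
        learning_rate=2e-5, max_steps=100, warmup_proportion=0.1,
        gpu_memory_limit_gb=4, world_size=1, world_rank=0,
        allreduce_post_accumulation=False, deepspeed_zero_stage=0, fp16=False)
    return create_ort_trainer(args, 'cuda', model)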