def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testToyBERTDeterministicCheck(expected_losses):
    # Common setup
    train_steps = 10
    device = "cuda"
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol)
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return dummy_init_state, checkpoint.experimental_state_dict(trainer)
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device
        },
        'mixed_precision': {
            'enabled': mixedprecision_enabled,
        },
        'distributed': {
            'allreduce_post_accumulation': True
        }
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    trainer.load_state_dict(dummy_init_state)

    # Expected values
    input_ids = torch.tensor([[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
                              [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
                             device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0],
                                [0], [1], [1], [0], [0], [1], [1], [1]], device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0],
                               [1], [1], [0], [0], [0], [1], [0], [0]], device=device)
    masked_lm_labels = torch.tensor([[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
                                     [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
                                    device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)

    actual_state_dict = trainer.state_dict()
    del actual_state_dict['model']
    _test_commons.assert_all_states_close_ort(actual_state_dict, dummy_init_state)
def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name='state_dict', use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    if 'distributed' in trainer_opts:
        train_data = next(islice(_chunkify(train_data, trainer_opts['distributed']['world_size']),
                                 trainer_opts['distributed']['world_rank'], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        _save(trainer, checkpoint_dir, state_dict_key_name)
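# Note: _chunkify, _train, and _save are helpers defined elsewhere in the test utilities
# and are not shown in this excerpt. For readability, here is a minimal sketch of what a
# _chunkify-style helper could look like, assuming it yields `num_chunks` contiguous,
# nearly equal slices so that `next(islice(_chunkify(data, world_size), world_rank, None))`
# picks the shard belonging to the current rank. This is an illustrative assumption, not
# the library's definition.
def _chunkify_sketch(sequence, num_chunks):
    # split `sequence` into `num_chunks` contiguous, nearly equal slices
    quo, rem = divmod(len(sequence), num_chunks)
    return (sequence[i * quo + min(i, rem):(i + 1) * quo + min(i + 1, rem)]
            for i in range(num_chunks))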
def _create_trainer(zero_enabled=False):
    """Creates a simple ORTTrainer for ORTTrainer functional tests"""
    device = "cuda"
    optim_config = optim.LambConfig(lr=0.1)
    opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}
    if zero_enabled:
        opts["distributed"] = {
            "world_rank": 0,
            "world_size": 1,
            "horizontal_parallel_size": 1,
            "data_parallel_size": 1,
            "allreduce_post_accumulation": True,
            "deepspeed_zero_optimization": {
                "stage": 1
            },
        }
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(opts))

    return trainer
def testToyBertCheckpointLoadZero():
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True},
                                         'device': {'id': device},
                                         'distributed': {'allreduce_post_accumulation': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    ckpt_dir = _test_helpers._get_name("ort_ckpt")
    checkpoint.experimental_load_checkpoint(trainer, ckpt_dir, 'bert_toy_lamb')

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
                              [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]], device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0],
                                [0], [1], [1], [0], [0], [1], [1], [1]], device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0],
                               [1], [1], [0], [0], [0], [1], [0], [0]], device=device)
    masked_lm_labels = torch.tensor([[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
                                     [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]], device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    actual_eval_loss = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)
    actual_eval_loss = actual_eval_loss.cpu().numpy().item(0)

    # Check results
    assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
def test_external_graph_transformer_triggering(self):
    input_size = 784
    hidden_size = 500
    num_classes = 10
    batch_size = 128
    model = NeuralNet(input_size, hidden_size, num_classes)
    model_desc = {
        "inputs": [
            ("x", [batch_size, input_size]),
            (
                "target",
                [
                    batch_size,
                ],
            ),
        ],
        "outputs": [("loss", [], True)],
    }

    optim_config = optim.SGDConfig()
    opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}})
    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # because orttrainer is lazily initialized, feed in random data to trigger the graph transformer
    data = torch.rand(batch_size, input_size)
    target = torch.randint(0, 10, (batch_size,))

    with OutputGrabber() as out:
        loss = model.train_step(data, target)
    assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext
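# The NeuralNet module referenced above is not shown in this excerpt. Because no
# loss_fn is passed to ORTTrainer and model_desc declares a single scalar "loss"
# output, the module presumably computes its own loss from (x, target). The class
# below is only an illustrative assumption of such a model, not the original
# definition.
import torch.nn as nn
import torch.nn.functional as F

class _ExampleNeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, target):
        logits = self.fc2(self.relu(self.fc1(x)))
        # return a scalar loss, matching the ("loss", [], True) entry in model_desc
        return F.cross_entropy(logits, target.long())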
def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())

    # Check output
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def testORTTransformerModelExport(seed, device):
    # Common setup
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        "debug": {
            "check_model_export": True,
        },
        "device": {
            "id": device,
        },
    })

    # Setup for the first ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    data, targets = batcher_fn(train_data, 0)
    _ = first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None
def testORTTrainerFrozenWeights(model_params):
    device = 'cuda'
    total_steps = 10
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()

    # Setup ORTTrainer WITHOUT frozen weights
    opts_dict = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    }
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params must be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert all([param in session_state for param in model_params])

    # Setup ORTTrainer WITH frozen weights
    opts_dict.update({'utils': {'frozen_weights': model_params}})
    opts = orttrainer.ORTTrainerOptions(opts_dict)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        trainer.train_step(*sample_input)

    # All model_params CANNOT be in the session state
    assert trainer._onnx_model is not None
    session_state = trainer._training_session.get_state()
    assert not any([param in session_state for param in model_params])
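# Illustrative harness (assumption): testORTTrainerFrozenWeights takes the list of
# initializer names to freeze as an argument, so it is naturally driven by a
# parametrized runner. The sketch below uses pytest and reuses initializer names that
# appear in the checkpoint tests elsewhere in this file; it is not part of the
# original suite.
import pytest

@pytest.mark.parametrize("model_params", [
    ['bert.encoder.layer.0.attention.self.value.weight'],
    ['bert.encoder.layer.0.attention.self.value.weight',
     'bert.encoder.layer.0.attention.output.LayerNorm.weight'],
])
def test_frozen_weights_example(model_params):
    testORTTrainerFrozenWeights(model_params)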
def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config):
    # Common setup
    train_steps = 512
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    optim_config = optimizer_config(lr=0.01)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    if optimizer_config == optim.AdamConfig:
        legacy_optimizer = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, legacy_optimizer,
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True)
def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_dir, state_dict_key_name="state_dict",
                                               use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with the given trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Saves the trainer checkpoint into checkpoint_dir (and, on rank 0 of a model-parallel run,
      pickles the expected full model and optimizer state dict next to it)
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer, checkpoint_dir, state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][initializer.name] = numpy_helper.to_array(initializer)
                with open(os.path.join(checkpoint_dir, "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(
        params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False
    )
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        legacy_optim_map,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
    )
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def testToyBERTModelBasicTraining(dynamic_shape):
    model_desc = bert_model_description(dynamic_shape)
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    for i in range(10):
        sample_input = generate_random_input_from_model_desc(model_desc)
        output = trainer.train_step(*sample_input)
        assert output.shape == torch.Size([])
def prepare_model(args, device):
    config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)

    # config.num_hidden_layers = 12
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({
        'batch': {
            'gradient_accumulation_steps': args.gradient_accumulation_steps},
        'device': {'id': str(device)},
        'mixed_precision': {
            'enabled': args.fp16,
            'loss_scaler': loss_scaler},
        'graph_transformer': {
            'attn_dropout_recompute': args.attn_dropout_recompute,
            'gelu_recompute': args.gelu_recompute,
            'transformer_layer_recompute': args.transformer_layer_recompute,
        },
        'debug': {'deterministic_compute': True, },
        'utils': {
            'grad_norm_clip': True},
        'distributed': {
            'world_rank': max(0, args.local_rank),
            'world_size': args.world_size,
            'local_rank': max(0, args.local_rank),
            'allreduce_post_accumulation': args.allreduce_post_accumulation,
            'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage},
            'enable_adasum': False},
        'lr_scheduler': lr_scheduler
    })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, {
        'params': [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)

    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
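# prepare_model reads a number of attributes off `args`. As a reference for callers,
# the namespace below lists every attribute the function touches; the values are
# illustrative placeholders, not the training script's real defaults.
from types import SimpleNamespace

example_args = SimpleNamespace(
    bert_model="bert-base-uncased",        # name passed to BertConfig.from_pretrained
    cache_dir=None,
    force_num_hidden_layers=None,
    init_state_dict=None,
    max_steps=1000,
    warmup_proportion=0.1,
    fp16=False,
    gradient_accumulation_steps=1,
    attn_dropout_recompute=False,
    gelu_recompute=False,
    transformer_layer_recompute=False,
    local_rank=-1,
    world_size=1,
    allreduce_post_accumulation=False,
    deepspeed_zero_stage=0,
)
# example usage: model = prepare_model(example_args, torch.device("cuda"))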
def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses):
    return  # TODO: re-enable after nondeterminism on backend is fixed

    # Common setup
    device = "cuda"
    total_steps = 10
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.0
    lr_end = 1e-7
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Setup LR Schedulers
    if (
        lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler
        or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler
    ):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
            "lr_scheduler": lr_scheduler,
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'utils': {
            'frozen_weights': ['bert.encoder.layer.0.attention.self.value.weight']
        }
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)

    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)

    # Save checkpoint
    state_dict = trainer.state_dict()

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    trainer2.load_state_dict(state_dict)

    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = trainer2.state_dict()
    _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict)
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert whether original state and the one loaded from checkpoint matches
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
def testToyBERTSaveAsONNX():
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    trainer.save_as_onnx(onnx_file_name)
    assert os.path.exists(onnx_file_name)

    with open(onnx_file_name, "rb") as f:
        bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)
    for initializer, loaded_initializer in zip(trainer._onnx_model.graph.initializer,
                                               trainer_from_onnx._onnx_model.graph.initializer):
        assert initializer.name == loaded_initializer.name
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) ==
            onnx.helper.printable_graph(trainer._onnx_model.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = trainer.state_dict()

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd['model']['full_precision']

    ## Modify one of the state values and load into ORTTrainer
    sd['model']['full_precision']['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    trainer.load_state_dict(sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    trainer.save_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp'))
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    trainer2.load_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp'))
    loaded_sd = trainer2.state_dict()

    # Assert whether original state and the one loaded from checkpoint matches
    _test_commons.assert_all_states_close_ort(sd, loaded_sd)
def create_orttrainer_and_load_checkpoint_bart(device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with the given trainer_opts configuration for a simple BART model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # model setup
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    expected_state_dict = None
    fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl")
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            expected_state_dict = pickle.load(f)

    return trainer.state_dict(), expected_state_dict, model
def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name="state_dict", use_lamb=True, seed=1, learning_rate=0.1): torch.manual_seed(seed) set_seed(seed) ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts) optim_config = optim.LambConfig( lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model( device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=ort_trainer_opts) if "distributed" in trainer_opts: train_data = next( islice( _chunkify(train_data, trainer_opts["distributed"]["world_size"]), trainer_opts["distributed"]["world_rank"], None, )) # run train steps _train(trainer, train_data, batcher_fn) # save current model parameters as a checkpoint if checkpoint_dir: if _is_model_parallel_run(ort_trainer_opts): _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=ort_trainer_opts.distributed.world_rank) else: _save(trainer, checkpoint_dir, state_dict_key_name)
def test_single_precision_adasum_on_gpu(): # Common setup world_rank = get_mpi_context_world_rank() world_size = get_mpi_context_world_size() set_cuda_device_id(world_rank) device = "cuda:" + str(world_rank) opts = orttrainer.ORTTrainerOptions({ "debug": { "deterministic_compute": True }, "device": { "id": device, }, "distributed": { "world_rank": world_rank, "world_size": world_size, "enable_adasum": True, }, }) _run_adasum_tests(opts)
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with the given trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return trainer.state_dict(), model
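# Illustrative round trip (assumption): the save/load helpers above are intended to be
# used in pairs, usually from separate test phases or processes. The single-process
# sketch below assumes _save writes checkpoint*.ortcp files that the loader's glob
# pattern picks up, and uses a throwaway directory; it is not part of the original suite.
import tempfile

def _example_checkpoint_round_trip(device="cuda"):
    opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}
    with tempfile.TemporaryDirectory() as ckpt_dir:
        create_orttrainer_and_save_checkpoint(device, opts, ckpt_dir)
        loaded_state, _ = create_orttrainer_and_load_checkpoint(device, opts, ckpt_dir)
        assert loaded_state is not None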
def create_initialized_orttrainer(device, trainer_opts, use_lamb=True, seed=1, learning_rate=1e-10):
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    _train(trainer, train_data, batcher_fn)

    return trainer