def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'mixed_precision': {
                                                'enabled': True,
                                                'loss_scaler': loss_scaler},
                                            'debug': {'deterministic_compute': True, }})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)

    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
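# --- Assumed imports (sketch, not part of the original test files) ---
# The snippets in this section appear to rely on roughly the following imports. The exact
# module paths for the legacy aliases and the local helpers (_test_helpers, set_seed,
# _load_pytorch_transformer_model, generate_random_input_from_model_desc, ...) are
# assumptions inferred from how the names are used above and below.
import torch
import pytest

import onnxruntime
from onnxruntime.training import amp, optim, orttrainer
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer, \
                                          LossScaler as Legacy_LossScaler
# Local test utilities (assumed): _test_helpers, _load_pytorch_transformer_model, set_seed, ...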
def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches):
    total_steps = len(expected_loss)
    torch.manual_seed(seed)
    set_seed(seed)
    bptt = 35

    # Setup ORTTrainer
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device': {'id': device},
                                            'mixed_precision': {
                                                'enabled': True,
                                                'loss_scaler': loss_scaler},
                                            'debug': {'deterministic_compute': True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Training loop
    actual_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        if fetches:
            trainer._train_step_info.fetches = ['loss']
            loss = trainer.train_step(data, targets)
        else:
            loss, _ = trainer.train_step(data, targets)
        actual_loss.append(loss.cpu())

    # Eval once just to test fetches in action
    val_data, val_targets = batcher_fn(val_data, 0)
    if fetches:
        trainer._train_step_info.fetches = ['loss']
        loss = trainer.eval_step(val_data, val_targets)
        trainer._train_step_info.fetches = []
    loss, preds = trainer.eval_step(val_data, val_targets)

    # Compare loss to ground truth computed from current ORTTrainer API
    _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
    assert trainer._onnx_model is not None
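# The parametrizations below construct a CustomLossScaler(); its definition is not shown in
# this excerpt. A minimal sketch of such a scaler, assuming it derives from amp.LossScaler
# and that the trainer consults loss_scale and calls update(train_step_info) each step
# (illustrative only, not the original helper):
class CustomLossScaler(amp.LossScaler):
    def __init__(self, loss_scale=float(1 << 16)):
        super().__init__(loss_scale)
        self._initial_loss_scale = loss_scale
        self.loss_scale = loss_scale

    def reset(self):
        # Restore the scale chosen at construction time
        self.loss_scale = self._initial_loss_scale

    def update(self, train_step_info):
        # Shrink the scale a little every step instead of the default dynamic up/down policy
        self.loss_scale *= 0.9
        return self.loss_scale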
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=1e-6)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)


# Dynamic Loss Scaler implemented implicitly
@pytest.mark.parametrize("loss_scaler, expected_losses", [
    (None, [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,
            11.105679512023926, 10.981968879699707, 11.081787109375, 10.997162818908691, 11.107288360595703]),
    (amp.DynamicLossScaler(), [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172,
                               10.988829612731934, 11.105679512023926, 10.981969833374023, 11.081744194030762,
                               10.997139930725098, 11.107272148132324]),
    (CustomLossScaler(), [10.98803424835205, 10.99240493774414, 11.090554237365723, 11.042823791503906, 10.98877239227295,
                          11.105667114257812, 10.981982231140137, 11.081765174865723, 10.997125625610352, 11.107298851013184])
])
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=1e-6)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)


# Dynamic Loss Scaler implemented implicitly
@pytest.mark.parametrize("loss_scaler, expected_losses", [
    (None, [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,
            11.105679512023926, 10.981969833374023, 11.08173656463623, 10.997121810913086, 11.10731315612793]),
    (amp.DynamicLossScaler(), [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,
                               11.105679512023926, 10.981969833374023, 11.081737518310547, 10.99714183807373, 11.107304573059082]),
    (CustomLossScaler(), [10.98803424835205, 10.99240493774414, 11.090554237365723, 11.042823791503906, 10.98877239227295,
                          11.105667114257812, 10.981982231140137, 11.081765174865723, 10.997125625610352, 11.107298851013184])
])
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
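    # --- Hypothetical continuation (sketch, not part of the original excerpt) ---
    # The chunk above stops right after loading the model. Based on the surrounding tests,
    # the body would typically wire the parametrized loss_scaler into ORTTrainerOptions and
    # compare per-step losses against expected_losses; the names below mirror the excerpt.
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True},
                                         'device': {'id': device},
                                         'mixed_precision': {'enabled': True,
                                                             'loss_scaler': loss_scaler}})
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    actual_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        actual_losses.append(trainer.train_step(*sample_input).cpu().item())
    _test_helpers.assert_model_outputs(actual_losses, expected_losses, rtol=1e-6)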
def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
             allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step,
             loss_scaler, use_internal_loss_scaler,
             batch_args_option, dataset_len, epochs, use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        # Wrap the legacy LR callable for the new frontend (a sketch of WrapLRScheduler follows run_test below)
        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
                                                'device': {'id': device},
                                                'mixed_precision': {'enabled': fp16, 'loss_scaler': new_api_loss_scaler},
                                                'debug': {'deterministic_compute': True, },
                                                'utils': {'grad_norm_clip': True},
                                                'distributed': {'allreduce_post_accumulation': True},
                                                'lr_scheduler': new_api_lr_scheduler})

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch'],),
                ('attention_mask', ['batch', 'max_seq_len_in_batch'],),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch'],),
                ('next_sentence_label', ['batch', ])],
            'outputs': [
                ('loss', [1, ], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])]}

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                           map_optimizer_attributes=map_optimizer_attributes,
                           learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
                           device=device,
                           _enable_internal_postprocess=True,
                           gradient_accumulation_steps=gradient_accumulation_steps,
                           # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
                           world_rank=args.local_rank,
                           world_size=args.world_size,
                           use_mixed_precision=fp16,
                           allreduce_post_accumulation=allreduce_post_accumulation,
                           get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
                           loss_scaler=loss_scaler if use_internal_loss_scaler else None,
                           _opset_version=12)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate, ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale, ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = int(len(model_desc.inputs_) / 2)  # approx half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

            print(outputs[0])

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) / 2)  # approx half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
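# run_test above hands the legacy get_lr_this_step callable to the new frontend through the
# 'lr_scheduler' option via WrapLRScheduler, whose definition is not shown in this excerpt.
# A minimal sketch, assuming the base class is available as optim.lr_scheduler._LRScheduler
# with a get_lr(train_step_info) hook and that train_step_info carries the current
# optimization_step (illustrative only, not the original helper):
class WrapLRScheduler(optim.lr_scheduler._LRScheduler):
    def __init__(self, get_lr_this_step):
        super().__init__()
        self.get_lr_this_step = get_lr_this_step

    def get_lr(self, train_step_info):
        # Delegate to the user-supplied schedule, keyed on the optimization step
        return [self.get_lr_this_step(train_step_info.optimization_step)]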
def train(self):
    """
    Main training entry point.
    """
    train_dataloader = self.get_train_dataloader()

    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (
            self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        )
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    if self.use_new_api:
        lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total))

        loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
        device = self.args.device.type
        device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0'
        options = orttrainer.ORTTrainerOptions({'batch': {'gradient_accumulation_steps': self.args.gradient_accumulation_steps},
                                                'device': {'id': device},
                                                'mixed_precision': {'enabled': self.args.fp16, 'loss_scaler': loss_scaler},
                                                'debug': {'deterministic_compute': True, },
                                                'utils': {'grad_norm_clip': False},
                                                'distributed': {'allreduce_post_accumulation': True},
                                                'lr_scheduler': lr_scheduler})

        param_optimizer = list(self.model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
            "weight_decay_mode": 1,
        }, {
            'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
            "weight_decay_mode": 1,
        }]

        optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
        self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
    else:
        def map_optimizer_attributes(name):
            no_decay = "bias" in name or "LayerNorm.weight" in name
            if no_decay:
                return {"weight_decay_mode": 1}
            else:
                return {"weight_decay_mode": 1}

        # A sketch of get_linear_schedule_with_warmup follows train() below
        get_lr_this_step = get_linear_schedule_with_warmup(self.args.warmup_steps, t_total, self.args.learning_rate)
        loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) if self.args.fp16 else None
        self.model = ORTTrainer(self.model, None, self.model_desc, "AdamOptimizer",
                                map_optimizer_attributes=map_optimizer_attributes,
                                learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
                                device=self.args.device,
                                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
                                use_mixed_precision=self.args.fp16,
                                allreduce_post_accumulation=True,
                                get_lr_this_step=get_lr_this_step,
                                loss_scaler=loss_scaler,
                                enable_grad_norm_clip=False,
                                _opset_version=12,
                                _use_deterministic_compute=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataloader.dataset))
    logger.info("  Num Epochs = %d", num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
parallel, distributed & accumulation) = %d", self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 tr_loss = 0.0 logging_loss = 0.0 train_iterator = trange( epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self._training_step(self.model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) ): global_step += 1 if self.args.local_rank in [-1, 0]: if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( global_step == 1 and self.args.logging_first_step ): logs = {} if self.args.evaluate_during_training: results = self.evaluate() for key, value in results.items(): eval_key = "eval_{}".format(key) logs[eval_key] = value loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps if not self.use_new_api: learning_rate_scalar = get_lr_this_step(global_step) logs["learning_rate"] = learning_rate_scalar logs["loss"] = loss_scalar logging_loss = tr_loss epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) if self.args.max_steps > 0 and global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break logger.info("\n\nTraining completed. \n\n") return TrainOutput(global_step, tr_loss / global_step)
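# --- Sketch of the legacy LR schedule helper (assumed, not part of the original excerpt) ---
# The legacy branch of train() calls get_linear_schedule_with_warmup(warmup_steps, t_total,
# learning_rate) to obtain get_lr_this_step. A plain-Python closure implementing the usual
# linear warmup followed by linear decay would look roughly like this:
def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr):
    def lr_lambda_linear(current_step):
        # Ramp up linearly during warmup, then decay linearly to zero
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) /
                   float(max(1, num_training_steps - num_warmup_steps)))

    def lr_this_step(current_step):
        return base_lr * lr_lambda_linear(current_step)

    return lr_this_step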
    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)


# Dynamic Loss Scaler implemented implicitly
@pytest.mark.parametrize("loss_scaler, expected_losses", [
    (None, [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,
            11.105679512023926, 10.981969833374023, 11.08173656463623, 10.997121810913086, 11.10731315612793]),
    (amp.DynamicLossScaler(), [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,
                               11.105679512023926, 10.981969833374023, 11.081737518310547, 10.99714183807373, 11.107304573059082]),
    (CustomLossScaler(), [10.98803424835205, 10.99240493774414, 11.090554237365723, 11.042823791503906, 10.98877239227295,
                          11.105667114257812, 10.981982231140137, 11.081765174865723, 10.997125625610352, 11.107298851013184])
])
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'