def run_ort_training_step(args, global_step, training_steps, model, batch):
    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
    lr = get_lr(args, global_step, args.schedule)
    learning_rate = torch.tensor([lr])
    all_finite = None
    if args.fp16:
        # Reuse the loss scaler created once in create_ort_trainer; constructing a
        # fresh LossScaler on every step would discard the dynamic scale updates.
        loss_scaler = args.ort_loss_scale
        loss_scale = torch.tensor([loss_scaler.loss_scale_])
        loss = model.train_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                                next_sentence_labels, learning_rate, loss_scale)
        if isinstance(loss, (list, tuple)):
            assert len(loss) == 2
            loss, all_finite = loss
    else:
        loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                     next_sentence_labels, learning_rate)
    if training_steps % args.gradient_accumulation_steps == 0:
        if args.fp16 and all_finite is not None:
            loss_scaler.update_loss_scale(all_finite.item())
        global_step += 1
    return loss
def create_ort_trainer(args, device, model):
    # set GPU memory limitation
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    ort_cuda_mem_limit_in_gbs = 1
    set_cuda_mem_limit(int(ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024))

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
        no_decay = False
        for no_decay_key in no_decay_keys:
            if no_decay_key in name:
                no_decay = True
                break
        if no_decay:
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
        else:
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

    # we request ORTTrainer to create a LambOptimizer with the given optimizer_attributes.
    # train_step does forward, backward, and optimizer step.
    model = ORTTrainer(model, None, bert_model_description(args), "LambOptimizer",
                       map_optimizer_attributes,
                       IODescription('Learning_Rate', [1, ], torch.float32),
                       device,
                       _opset_version=10)
    if args.fp16:
        setattr(args, 'ort_loss_scale',
                LossScaler(model.loss_scale_input_name, True, up_scale_window=2000))

    return model
def create_ort_trainer(args, device, model):
    # set GPU memory limitation (per card!)
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    ort_cuda_mem_limit_in_gbs = args.gpu_memory_limit_gb
    set_cuda_mem_limit(int(ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024))

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
        no_decay = False
        for no_decay_key in no_decay_keys:
            if no_decay_key in name:
                no_decay = True
                break
        if no_decay:
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
        else:
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

    # we request ORTTrainer to create a LambOptimizer with the given optimizer_attributes.
    # train_step does forward, backward, and optimizer step.
    model = ORTTrainer(
        model,
        None,
        bert_model_description(args),
        "LambOptimizer",
        map_optimizer_attributes,
        IODescription('Learning_Rate', [1, ], torch.float32),
        device,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        world_rank=args.world_rank,
        world_size=args.world_size,
        use_mixed_precision=args.fp16,
        allreduce_post_accumulation=args.allreduce_post_accumulation,
        deepspeed_zero_stage=1 if args.deepspeed_zero_stage else 0,
        _opset_version=12)

    if args.fp16:
        setattr(args, 'ort_loss_scale',
                LossScaler(model.loss_scale_input_name, True, up_scale_window=2000))

    return model
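# --- Usage sketch (not from the original sources): a minimal loop wiring the
# create_ort_trainer and run_ort_training_step helpers above together. `args`,
# `device`, `model`, and `train_dataloader` are assumed to be supplied by the
# caller with the fields the two helpers reference; in fp16 mode,
# create_ort_trainer must run first so that args.ort_loss_scale exists.
def example_bert_ort_training(args, device, model, train_dataloader):
    model = create_ort_trainer(args, device, model)  # wraps the PyTorch model in ORTTrainer
    global_step = 0
    training_steps = 0
    for batch in train_dataloader:
        batch = tuple(t.to(device) for t in batch)
        training_steps += 1
        loss = run_ort_training_step(args, global_step, training_steps, model, batch)
        # run_ort_training_step cannot mutate the caller's int, so mirror its
        # global_step bookkeeping here.
        if training_steps % args.gradient_accumulation_steps == 0:
            global_step += 1
    return model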
def runBertTrainingTest(gradient_accumulation_steps,
                        use_mixed_precision,
                        allreduce_post_accumulation,
                        use_simple_model_desc=True,
                        use_internal_loss_scale=False):
    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internal_loss_scale else None
    model, model_desc, device = create_ort_trainer(gradient_accumulation_steps,
                                                   use_mixed_precision,
                                                   allreduce_post_accumulation,
                                                   use_simple_model_desc,
                                                   loss_scaler)
    if loss_scaler is None:
        loss_scaler = LossScaler(model.loss_scale_input_name, True)

    input_ids_batches = []
    segment_ids_batches = []
    input_mask_batches = []
    masked_lm_labels_batches = []
    next_sentence_labels_batches = []
    batch_size = 16
    num_batches = 8
    for _ in range(num_batches):
        input_ids_batches = [*input_ids_batches,
                             generate_sample_batch(model_desc.inputs_[0], batch_size, device)]
        segment_ids_batches = [*segment_ids_batches,
                               generate_sample_batch(model_desc.inputs_[1], batch_size, device)]
        input_mask_batches = [*input_mask_batches,
                              generate_sample_batch(model_desc.inputs_[2], batch_size, device)]
        masked_lm_labels_batches = [*masked_lm_labels_batches,
                                    generate_sample_batch(model_desc.inputs_[3], batch_size, device)]
        next_sentence_labels_batches = [*next_sentence_labels_batches,
                                        generate_sample_batch(model_desc.inputs_[4], batch_size, device)]

    lr_batch_list = [0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06,
                     1.8404908e-06, 2.3006135e-06, 2.7607362e-06, 3.2208588e-06,
                     3.6809815e-06]

    actual_losses = []
    actual_all_finites = []

    for batch_count in range(num_batches):
        # consume the pre-generated batches instead of regenerating samples
        input_ids = input_ids_batches[batch_count]
        segment_ids = segment_ids_batches[batch_count]
        input_mask = input_mask_batches[batch_count]
        masked_lm_labels = masked_lm_labels_batches[batch_count]
        next_sentence_labels = next_sentence_labels_batches[batch_count]
        lr = lr_batch_list[batch_count]

        learning_rate = torch.tensor([lr]).to(device)
        training_args = [input_ids, segment_ids, input_mask,
                         masked_lm_labels, next_sentence_labels, learning_rate]
        if use_mixed_precision:
            if not use_internal_loss_scale:
                loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
                training_args.append(loss_scale)
            actual_loss = model.train_step(*training_args)
            if isinstance(actual_loss, (list, tuple)):
                assert len(actual_loss) == 2
                actual_loss, actual_all_finite = actual_loss
                if not use_internal_loss_scale:
                    loss_scaler.update_loss_scale(actual_all_finite.item())
                    actual_all_finites = [*actual_all_finites,
                                          actual_all_finite.cpu().numpy().item(0)]
            actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)]
        else:
            loss = model(*training_args)
            actual_losses = [*actual_losses, loss.cpu().numpy().item(0)]

        if batch_count == num_batches - 1:
            # test the eval_step api with fetches at the end of the training.
            # if eval_step is called during the training, it will affect the actual
            # training loss (the training session is stateful).
            eval_loss = model.eval_step(input_ids, segment_ids, input_mask,
                                        masked_lm_labels, next_sentence_labels,
                                        fetches=['loss'])
            eval_loss = eval_loss.cpu().numpy().item(0)

    # If using the internal loss scale, all_finites are handled internally too.
    if use_mixed_precision and not use_internal_loss_scale:
        return actual_losses, actual_all_finites, eval_loss
    else:
        return actual_losses, eval_loss
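# Example invocation (assumed argument values): exercise the test above in mixed
# precision with the external loss scaler. With these flags the function returns
# the per-batch losses, the all_finite flags, and the final eval loss.
if __name__ == "__main__":
    losses, all_finites, eval_loss = runBertTrainingTest(
        gradient_accumulation_steps=4,
        use_mixed_precision=True,
        allreduce_post_accumulation=True,
        use_simple_model_desc=True,
        use_internal_loss_scale=False)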
def train(self): """ Main training entry point. """ train_dataloader = self.get_train_dataloader() if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = (self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1) else: t_total = int( len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs get_lr_this_step = get_linear_schedule_with_warmup( self.args.warmup_steps, t_total, self.args.learning_rate) loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) def map_optimizer_attributes(name): # no_decay_keys = ["bias", "LayerNorm.weight"] no_decay = "bias" in name or "LayerNorm.weight" in name if no_decay: return {"weight_decay": 0.0, "weight_decay_mode": 1} else: return { "weight_decay": self.args.weight_decay, "weight_decay_mode": 1 } self.model = ORTTrainer( self.model, None, self.model_desc, "AdamOptimizer", map_optimizer_attributes=map_optimizer_attributes, learning_rate_description=IODescription('Learning_Rate', [ 1, ], torch.float32), device=self.args.device, gradient_accumulation_steps=self.args.gradient_accumulation_steps, world_rank=0, world_size=1, # only support single GPU cases use_mixed_precision=self.args.fp16, allreduce_post_accumulation=True, get_lr_this_step=get_lr_this_step, loss_scaler=loss_scaler, enable_grad_norm_clip=False, _opset_version=12, _use_deterministic_compute=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataloader.dataset)) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 tr_loss = 0.0 logging_loss = 0.0 train_iterator = trange( epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self._training_step(self.model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): global_step += 1 if self.args.local_rank in [-1, 0]: if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (global_step == 1 and self.args.logging_first_step): logs = {} if self.args.evaluate_during_training: results = self.evaluate() for key, value in results.items(): eval_key = "eval_{}".format(key) logs[eval_key] = value loss_scalar = (tr_loss - logging_loss ) / self.args.logging_steps learning_rate_scalar = get_lr_this_step( global_step) logs["learning_rate"] = learning_rate_scalar logs["loss"] = loss_scalar logging_loss = tr_loss epoch_iterator.write( json.dumps({ **logs, **{ "step": global_step } })) if self.args.max_steps > 0 and 
global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break logger.info("\n\nTraining completed. \n\n") return TrainOutput(global_step, tr_loss / global_step)
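# The train() method above calls self._training_step(model, inputs), which is not
# shown in this section. A minimal sketch of that Trainer method, under the
# assumption that it moves the batch to the configured device, runs one
# ORTTrainer train_step (forward, backward, and optimizer update, since the
# trainer was built with an internal lr schedule and loss scaler), and returns
# the scalar loss.
def _training_step(self, model, inputs):
    for k, v in inputs.items():
        inputs[k] = v.to(self.args.device)
    outputs = model.train_step(**inputs)
    loss = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    return loss.item()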
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids,
                                          input_mask, sequence_labels, token_labels,
                                          choice_labels):
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    onnxruntime.set_seed(seed)

    model = BertForPreTraining(config=config)
    model.eval()
    loss, prediction_scores, seq_relationship_score = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        masked_lm_labels=token_labels,
        next_sentence_label=sequence_labels)
    model_desc = ModelDescription([
        self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc,
        self.masked_lm_labels_desc, self.next_sentence_label_desc
    ], [
        self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc
    ])

    from collections import namedtuple
    MyArgs = namedtuple(
        "MyArgs",
        "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
    args = MyArgs(local_rank=0,
                  world_size=1,
                  max_steps=100,
                  learning_rate=0.00001,
                  warmup_proportion=0.01,
                  batch_size=13,
                  seq_len=7)

    def get_lr_this_step(global_step):
        return get_lr(args, global_step)

    loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)

    # It would be better to test all four combinations of mixed precision and
    # allreduce_post_accumulation. However, a stress test of all four cases is not
    # stable, at least on the test machine. Therefore we only test mixed precision
    # with allreduce_post_accumulation, as that is the most useful use case.
    option_fp16 = [True]
    option_allreduce_post_accumulation = [True]
    option_gradient_accumulation_steps = [1, 8]
    option_use_internal_get_lr_this_step = [True, False]
    option_use_internal_loss_scaler = [True, False]
    option_split_batch = [BatchArgsOption.ListAndDict]

    for fp16 in option_fp16:
        for allreduce_post_accumulation in option_allreduce_post_accumulation:
            for gradient_accumulation_steps in option_gradient_accumulation_steps:
                for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step:
                    for use_internal_loss_scaler in option_use_internal_loss_scaler:
                        for split_batch in option_split_batch:
                            print("gradient_accumulation_steps:", gradient_accumulation_steps)
                            print("use_internal_loss_scaler:", use_internal_loss_scaler)
                            loss_ort, prediction_scores_ort, seq_relationship_score_ort = \
                                run_test(model, model_desc, self.device, args,
                                         gradient_accumulation_steps, fp16,
                                         allreduce_post_accumulation,
                                         get_lr_this_step, use_internal_get_lr_this_step,
                                         loss_scaler, use_internal_loss_scaler,
                                         split_batch)
                            print(loss_ort)
                            print(prediction_scores_ort)
                            print(seq_relationship_score_ort)
def create_and_check_bert_for_pretraining(
    self,
    config,
    input_ids,
    token_type_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
    option_fp16,
    option_allreduce_post_accumulation,
    option_gradient_accumulation_steps,
    option_split_batch,
    option_use_internal_get_lr_this_step=[True],
    option_use_internal_loss_scaler=[True],
):
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    onnxruntime.set_seed(seed)

    model = BertForPreTraining(config=config)
    model.eval()
    loss, prediction_scores, seq_relationship_score = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        masked_lm_labels=token_labels,
        next_sentence_label=sequence_labels,
    )
    model_desc = ModelDescription(
        [
            self.input_ids_desc,
            self.attention_mask_desc,
            self.token_type_ids_desc,
            self.masked_lm_labels_desc,
            self.next_sentence_label_desc,
        ],
        [
            self.loss_desc,
            self.prediction_scores_desc,
            self.seq_relationship_scores_desc,
        ],
    )

    from collections import namedtuple
    MyArgs = namedtuple(
        "MyArgs",
        "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")

    dataset_len = 100
    epochs = 8
    max_steps = epochs * dataset_len
    args = MyArgs(
        local_rank=0,
        world_size=1,
        max_steps=max_steps,
        learning_rate=0.00001,
        warmup_proportion=0.01,
        batch_size=13,
        seq_len=7,
    )

    def get_lr_this_step(global_step):
        return get_lr(args, global_step)

    loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000)

    for fp16 in option_fp16:
        for allreduce_post_accumulation in option_allreduce_post_accumulation:
            for gradient_accumulation_steps in option_gradient_accumulation_steps:
                for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step:
                    for use_internal_loss_scaler in option_use_internal_loss_scaler:
                        for split_batch in option_split_batch:
                            print("gradient_accumulation_steps:", gradient_accumulation_steps)
                            print("split_batch:", split_batch)

                            seed = 42
                            random.seed(seed)
                            np.random.seed(seed)
                            torch.manual_seed(seed)
                            torch.cuda.manual_seed_all(seed)
                            onnxruntime.set_seed(seed)

                            (
                                old_api_loss_ort,
                                old_api_prediction_scores_ort,
                                old_api_seq_relationship_score_ort,
                            ) = run_test(
                                model,
                                model_desc,
                                self.device,
                                args,
                                gradient_accumulation_steps,
                                fp16,
                                allreduce_post_accumulation,
                                get_lr_this_step,
                                use_internal_get_lr_this_step,
                                loss_scaler,
                                use_internal_loss_scaler,
                                split_batch,
                                dataset_len,
                                epochs,
                                use_new_api=False,
                            )

                            random.seed(seed)
                            np.random.seed(seed)
                            torch.manual_seed(seed)
                            torch.cuda.manual_seed_all(seed)
                            onnxruntime.set_seed(seed)
                            if use_internal_get_lr_this_step and use_internal_loss_scaler:
                                (
                                    new_api_loss_ort,
                                    new_api_prediction_scores_ort,
                                    new_api_seq_relationship_score_ort,
                                ) = run_test(
                                    model,
                                    model_desc,
                                    self.device,
                                    args,
                                    gradient_accumulation_steps,
                                    fp16,
                                    allreduce_post_accumulation,
                                    get_lr_this_step,
                                    use_internal_get_lr_this_step,
                                    loss_scaler,
                                    use_internal_loss_scaler,
                                    split_batch,
                                    dataset_len,
                                    epochs,
                                    use_new_api=True,
                                )

                                assert_allclose(old_api_loss_ort, new_api_loss_ort)
                                assert_allclose(old_api_prediction_scores_ort,
                                                new_api_prediction_scores_ort)
                                assert_allclose(old_api_seq_relationship_score_ort,
                                                new_api_seq_relationship_score_ort)
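# Example invocation (assumed values, mirroring the option lists used by the older
# variant above): compare old- and new-API results for fp16 with post-accumulation
# allreduce across two gradient accumulation settings.
def example_compare_old_and_new_api(self, config, input_ids, token_type_ids, input_mask,
                                    sequence_labels, token_labels, choice_labels):
    self.create_and_check_bert_for_pretraining(
        config, input_ids, token_type_ids, input_mask,
        sequence_labels, token_labels, choice_labels,
        option_fp16=[True],
        option_allreduce_post_accumulation=[True],
        option_gradient_accumulation_steps=[1, 8],
        option_split_batch=[BatchArgsOption.ListAndDict])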
def train(self): """ Main training entry point. """ train_dataloader = self.get_train_dataloader() if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = ( self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 ) else: t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs if self.use_new_api: lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total)) loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None device = self.args.device.type device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0' options = orttrainer.ORTTrainerOptions({'batch' : { 'gradient_accumulation_steps' : self.args.gradient_accumulation_steps}, 'device': {'id': device}, 'mixed_precision': { 'enabled': self.args.fp16, 'loss_scaler': loss_scaler}, 'debug': {'deterministic_compute': True, }, 'utils': { 'grad_norm_clip': False}, 'distributed': {'allreduce_post_accumulation': True}, 'lr_scheduler': lr_scheduler }) param_optimizer = list(self.model.named_parameters()) params = [{ 'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], "weight_decay_mode": 1, }, { 'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], "weight_decay_mode": 1, } ] optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options) else: def map_optimizer_attributes(name): no_decay = "bias" in name or "LayerNorm.weight" in name if no_decay: return {"weight_decay_mode" : 1} else: return {"weight_decay_mode" : 1} get_lr_this_step = get_linear_schedule_with_warmup(self.args.warmup_steps, t_total, self.args.learning_rate) loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) if self.args.fp16 else None self.model = ORTTrainer(self.model, None, self.model_desc, "AdamOptimizer", map_optimizer_attributes=map_optimizer_attributes, learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32), device=self.args.device, gradient_accumulation_steps=self.args.gradient_accumulation_steps, use_mixed_precision=self.args.fp16, allreduce_post_accumulation=True, get_lr_this_step=get_lr_this_step, loss_scaler=loss_scaler, enable_grad_norm_clip=False, _opset_version=12, _use_deterministic_compute=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataloader.dataset)) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 tr_loss = 0.0 logging_loss = 0.0 train_iterator = trange( epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self._training_step(self.model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) ): global_step += 1 if self.args.local_rank in [-1, 0]: if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( global_step == 1 and self.args.logging_first_step ): logs = {} if self.args.evaluate_during_training: results = self.evaluate() for key, value in results.items(): eval_key = "eval_{}".format(key) logs[eval_key] = value loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps if not self.use_new_api: learning_rate_scalar = get_lr_this_step(global_step) logs["learning_rate"] = learning_rate_scalar logs["loss"] = loss_scalar logging_loss = tr_loss epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) if self.args.max_steps > 0 and global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break logger.info("\n\nTraining completed. \n\n") return TrainOutput(global_step, tr_loss / global_step)
def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Args:
        model_path: (Optional) Local path to the model if the model to train has been
            instantiated from a local path. If present, we will try reloading the
            optimizer/scheduler states from there.
    """
    train_dataloader = self.get_train_dataloader()
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (
            self.args.max_steps //
            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps *
                      self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    scheduler = linear_schedule_with_warmup(num_warmup_steps=self.args.warmup_steps,
                                            num_training_steps=t_total)

    loss_scaler = LossScaler(self.ort_model.loss_scale_input_name,
                             True,
                             up_scale_window=2000,
                             loss_scale=float(1 << 20)) if self.args.fp16 else 1

    model = self.ort_model

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())

    # Train!
    if self.is_world_master():
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size * self.args.gradient_accumulation_steps *
            (self.args.world_size if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (len(train_dataloader) //
                                             self.args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            global_step = 0
            logger.info("  Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    global_batch_train_start = time.time()

    train_iterator = trange(
        epochs_trained,
        int(num_train_epochs),
        desc="Epoch",
        disable=self.args.local_rank not in [-1, 0],
    )
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration",
                              disable=self.args.local_rank not in [-1, 0])
        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            if len(inputs['input_ids']) < self.args.per_gpu_train_batch_size:
                # skip incomplete batch
                logger.info('Skipping incomplete batch...')
                continue

            learning_rate = torch.tensor([
                scheduler.get_lr_this_step(global_step, base_lr=self.args.learning_rate)
            ])
            loss, all_finite = self._training_step(model, inputs, learning_rate, loss_scaler)
            tr_loss += loss

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than
                    # gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                if self.args.fp16:
                    loss_scaler.update_loss_scale(all_finite.item())
                global_step += 1
                global_batch_train_duration = time.time() - global_batch_train_start
                global_batch_train_start = time.time()

                if self.args.local_rank in [-1, 0]:
                    if (self.args.logging_steps > 0 and
                            global_step % self.args.logging_steps == 0) or (
                                global_step == 1 and self.args.logging_first_step):
                        logs = {}
                        loss_avg = (tr_loss - logging_loss) / (
                            self.args.logging_steps * self.args.gradient_accumulation_steps)
                        logs["learning_rate"] = learning_rate.item()
                        logs["loss"] = loss_avg
                        logs["global_step"] = global_step
                        logs["global_step_time"] = global_batch_train_duration
                        logging_loss = tr_loss

                        if self.tb_writer:
                            for k, v in logs.items():
                                self.tb_writer.add_scalar(k, v, global_step)
                                run.log(k, v)
                        epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                    if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a
                        # reference to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.ort_model
                        else:
                            assert model is self.ort_model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir,
                                                  f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                        self.save_model(output_dir)
                        # self._rotate_checkpoints()

            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                epoch_iterator.close()
                break
        if self.args.max_steps > 0 and global_step > self.args.max_steps:
            train_iterator.close()
            break

    if self.tb_writer:
        self.tb_writer.close()
    self.update_torch_model()
    del self.ort_model
    self.ort_model = None

    logger.info(
        "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    return TrainOutput(global_step, tr_loss / global_step)
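# train() above expects self._training_step(model, inputs, learning_rate, loss_scaler)
# to return (loss, all_finite); that helper is not part of this section, so the body
# below is a sketch under assumptions: it follows the fp16 train_step convention from
# the BERT helpers earlier (batch tensors, then learning rate, then loss scale), and
# it assumes the insertion order of `inputs` matches the model description.
def _training_step(self, model, inputs, learning_rate, loss_scaler):
    inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
    if self.args.fp16:
        loss_scale = torch.tensor([loss_scaler.loss_scale_])
        loss, all_finite = model.train_step(*inputs.values(), learning_rate, loss_scale)
    else:
        loss = model.train_step(*inputs.values(), learning_rate)
        all_finite = torch.tensor(True)  # unused when fp16 is off
    return loss.item(), all_finite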
def create_ort_trainer(args, device, model):
    # set GPU memory limitation
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    set_cuda_mem_limit(int(args.ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024))

    def map_optimizer_attributes(name):
        no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
        no_decay = False
        for no_decay_key in no_decay_keys:
            if no_decay_key in name:
                no_decay = True
                break
        if no_decay:
            return {
                "alpha": 0.9,
                "beta": 0.999,
                "lambda": 0.0,
                "epsilon": 1e-8,
                # Adam optimizer mode:
                # 0: PyTorch's AdamW
                # 1: Hugging Face's AdamW
                "weight_decay_mode": 0,
                "do_bias_correction": 1
            }
        else:
            return {
                "alpha": 0.9,
                "beta": 0.999,
                "lambda": 0.01,
                "epsilon": 1e-8,
                "weight_decay_mode": 0,
                "do_bias_correction": 1
            }

    # we request ORTTrainer to create an AdamOptimizer with the given optimizer_attributes.
    # train_step does forward, backward, and optimizer step.
    model = ORTTrainer(
        model,
        None,
        bart_model_description(args),
        "AdamOptimizer",
        map_optimizer_attributes,
        IODescription('Learning_Rate', [1, ], torch.float32),
        device,
        # _extra_postprocess=postprocess_model,
        gradient_accumulation_steps=args.update_freq[0],
        world_rank=args.distributed_rank,
        world_size=args.distributed_world_size,
        use_mixed_precision=args.fp16,
        allreduce_post_accumulation=True,
        _opset_version=12)

    if args.fp16:
        setattr(args, 'ort_loss_scale',
                LossScaler(model.loss_scale_input_name, True, up_scale_window=2000))

    return model
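# Example args (assumed values; fairseq-style attribute names, matching what the
# BART create_ort_trainer above reads from `args`).
import argparse
example_bart_args = argparse.Namespace(
    update_freq=[4],               # gradient accumulation steps
    distributed_rank=0,
    distributed_world_size=1,
    fp16=True,
    ort_cuda_mem_limit_in_gbs=16)  # per-process GPU memory cap for ORT
# trainer = create_ort_trainer(example_bart_args, device, bart_model)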
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids,
                                          input_mask, sequence_labels, token_labels,
                                          choice_labels):
    model = BertForPreTraining(config=config)
    model.eval()
    loss, prediction_scores, seq_relationship_score = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        masked_lm_labels=token_labels,
        next_sentence_label=sequence_labels)
    model_desc = ModelDescription([
        self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc,
        self.masked_lm_labels_desc, self.next_sentence_label_desc
    ], [
        self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc
    ])

    import argparse
    args_ = argparse.Namespace(fp16=True, amp_opt_level='O1')

    from collections import namedtuple
    MyArgs = namedtuple(
        "MyArgs",
        "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
    args = MyArgs(local_rank=0,
                  world_size=1,
                  max_steps=100,
                  learning_rate=0.00001,
                  warmup_proportion=0.01,
                  batch_size=13,
                  seq_len=7)

    from train_with_ort_trainer import get_lr

    def get_lr_this_step(global_step):
        return get_lr(args, global_step)

    loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)

    option_gradient_accumulation_steps = [8]
    option_fp16 = [True, False]
    option_allreduce_post_accumulation = True
    option_use_internal_get_lr_this_step = False
    option_use_internal_loss_scaler = False
    # TODO: test with fetches

    for gradient_accumulation_steps in option_gradient_accumulation_steps:
        for fp16 in option_fp16:
            for option_split_batch in BatchArgsOption:
                loss_ort, prediction_scores_ort, seq_relationship_score_ort = \
                    run_test(model, model_desc, self.device, args,
                             gradient_accumulation_steps, fp16,
                             option_allreduce_post_accumulation,
                             get_lr_this_step, option_use_internal_get_lr_this_step,
                             loss_scaler, option_use_internal_loss_scaler,
                             option_split_batch)
                print(loss_ort)
                print(prediction_scores_ort)
                print(seq_relationship_score_ort)