def run_ort_training_step(args, global_step, training_steps, model, batch):
    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
    if args.fp16:
        loss_scaler = LossScaler(model.loss_scale_input_name, True, up_scale_window=2000)
    lr = get_lr(args, global_step, args.schedule)
    learning_rate = torch.tensor([lr])
    if args.fp16:
        loss_scale = torch.tensor([loss_scaler.loss_scale_])
        loss = model.train_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                                next_sentence_labels, learning_rate, loss_scale)
        all_finite = 1
        if isinstance(loss, (list, tuple)):
            assert len(loss) == 2
            loss, all_finite = loss
    else:
        loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                     next_sentence_labels, learning_rate)
    if training_steps % args.gradient_accumulation_steps == 0:
        if args.fp16:
            loss_scaler.update_loss_scale(all_finite.item())
        global_step += 1
    return loss
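
# A minimal sketch of how run_ort_training_step might be driven from an epoch loop.
# The dataloader, the argument names (args.fp16, args.gradient_accumulation_steps, args.schedule)
# and the model produced by the ORT trainer setup are assumptions for illustration; they are not
# defined by the function above.
def run_ort_training_epoch_sketch(args, model, train_dataloader, global_step):
    training_steps = 0
    for batch in train_dataloader:
        training_steps += 1
        # Each batch is assumed to unpack into the five BERT pre-training tensors
        # expected by run_ort_training_step.
        loss = run_ort_training_step(args, global_step, training_steps, model, batch)
        # run_ort_training_step only advances its local copy of global_step,
        # so the caller tracks completed optimizer steps itself.
        if training_steps % args.gradient_accumulation_steps == 0:
            global_step += 1
    return global_step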
def runBertTrainingTest(gradient_accumulation_steps,
                        use_mixed_precision,
                        allreduce_post_accumulation,
                        use_simple_model_desc=True,
                        use_internel_loss_scale=False):
    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None

    model, model_desc, device = create_ort_trainer(gradient_accumulation_steps,
                                                   use_mixed_precision,
                                                   allreduce_post_accumulation,
                                                   use_simple_model_desc,
                                                   loss_scaler)

    if loss_scaler is None:
        loss_scaler = LossScaler(model.loss_scale_input_name, True)

    # Generate sample batches for each of the five BERT pre-training inputs.
    input_ids_batches = []
    segment_ids_batches = []
    input_mask_batches = []
    masked_lm_labels_batches = []
    next_sentence_labels_batches = []
    batch_size = 16
    num_batches = 8
    for batch in range(num_batches):
        input_ids_batches = [*input_ids_batches, generate_sample_batch(model_desc.inputs_[0], batch_size, device)]
        segment_ids_batches = [*segment_ids_batches, generate_sample_batch(model_desc.inputs_[1], batch_size, device)]
        input_mask_batches = [*input_mask_batches, generate_sample_batch(model_desc.inputs_[2], batch_size, device)]
        masked_lm_labels_batches = [*masked_lm_labels_batches, generate_sample_batch(model_desc.inputs_[3], batch_size, device)]
        next_sentence_labels_batches = [*next_sentence_labels_batches, generate_sample_batch(model_desc.inputs_[4], batch_size, device)]

    lr_batch_list = [0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06, 1.8404908e-06,
                     2.3006135e-06, 2.7607362e-06, 3.2208588e-06, 3.6809815e-06]

    actual_losses = []
    actual_all_finites = []

    for batch_count in range(num_batches):
        input_ids = generate_sample_batch(model_desc.inputs_[0], batch_size, device)
        segment_ids = generate_sample_batch(model_desc.inputs_[1], batch_size, device)
        input_mask = generate_sample_batch(model_desc.inputs_[2], batch_size, device)
        masked_lm_labels = generate_sample_batch(model_desc.inputs_[3], batch_size, device)
        next_sentence_labels = generate_sample_batch(model_desc.inputs_[4], batch_size, device)
        lr = lr_batch_list[batch_count]

        learning_rate = torch.tensor([lr]).to(device)
        training_args = [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, learning_rate]
        if use_mixed_precision:
            if not use_internel_loss_scale:
                loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
                training_args.append(loss_scale)
            actual_loss = model.train_step(*training_args)
            if isinstance(actual_loss, (list, tuple)):
                assert len(actual_loss) == 2
                actual_loss, actual_all_finite = actual_loss
                if not use_internel_loss_scale:
                    loss_scaler.update_loss_scale(actual_all_finite.item())
                    actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)]

            actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)]
        else:
            loss = model(*training_args)
            actual_losses = [*actual_losses, loss.cpu().numpy().item(0)]

        if batch_count == num_batches - 1:
            # Test the eval_step API with fetches at the end of training.
            # If eval_step were called during training, it would affect the actual training loss
            # (the training session is stateful).
            eval_loss = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                                        next_sentence_labels, fetches=['loss'])
            eval_loss = eval_loss.cpu().numpy().item(0)

    # If using the internal loss scale, all_finites are handled internally too.
    if use_mixed_precision and not use_internel_loss_scale:
        return actual_losses, actual_all_finites, eval_loss
    else:
        return actual_losses, eval_loss
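
# A hedged usage sketch showing how runBertTrainingTest might be invoked to compare a
# full-precision run against a mixed-precision run with an external loss scaler. The parameter
# values are illustrative assumptions; expected-loss assertions are omitted because reference
# values depend on the model, seed, and hardware.
def run_bert_training_smoke_checks():
    # Full precision: returns (losses, eval_loss).
    fp32_losses, fp32_eval_loss = runBertTrainingTest(
        gradient_accumulation_steps=1,
        use_mixed_precision=False,
        allreduce_post_accumulation=False)

    # Mixed precision with an external loss scaler: returns (losses, all_finite flags, eval_loss).
    fp16_losses, fp16_all_finites, fp16_eval_loss = runBertTrainingTest(
        gradient_accumulation_steps=1,
        use_mixed_precision=True,
        allreduce_post_accumulation=False,
        use_internel_loss_scale=False)

    print("fp32 losses:", fp32_losses, "eval loss:", fp32_eval_loss)
    print("fp16 losses:", fp16_losses, "all finite flags:", fp16_all_finites, "eval loss:", fp16_eval_loss)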
def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Args:
        model_path: (Optional) Local path to the model if the model to train has been
            instantiated from a local path. If present, we will try reloading the
            optimizer/scheduler states from there.
    """
    train_dataloader = self.get_train_dataloader()

    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    scheduler = linear_schedule_with_warmup(num_warmup_steps=self.args.warmup_steps,
                                            num_training_steps=t_total)

    loss_scaler = LossScaler(self.ort_model.loss_scale_input_name,
                             True,
                             up_scale_window=2000,
                             loss_scale=float(1 << 20)) if self.args.fp16 else 1

    model = self.ort_model

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())

    # Train!
    if self.is_world_master():
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (self.args.world_size if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint.
    if model_path is not None:
        # Set global_step to the global_step of the last saved checkpoint from the model path.
        try:
            global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // self.args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            global_step = 0
            logger.info("  Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    global_batch_train_start = time.time()

    train_iterator = trange(
        epochs_trained,
        int(num_train_epochs),
        desc="Epoch",
        disable=self.args.local_rank not in [-1, 0],
    )
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training.
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            if len(inputs['input_ids']) < self.args.per_gpu_train_batch_size:
                # Skip incomplete batch.
                logger.info('Skipping incomplete batch...')
                continue

            learning_rate = torch.tensor([
                scheduler.get_lr_this_step(global_step, base_lr=self.args.learning_rate)
            ])
            loss, all_finite = self._training_step(model, inputs, learning_rate, loss_scaler)

            tr_loss += loss

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)):
                if self.args.fp16:
                    loss_scaler.update_loss_scale(all_finite.item())

                global_step += 1
                global_batch_train_duration = time.time() - global_batch_train_start
                global_batch_train_start = time.time()

                if self.args.local_rank in [-1, 0]:
                    if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
                            global_step == 1 and self.args.logging_first_step):
                        logs = {}
                        loss_avg = (tr_loss - logging_loss) / (
                            self.args.logging_steps * self.args.gradient_accumulation_steps)
                        logs["learning_rate"] = learning_rate.item()
                        logs["loss"] = loss_avg
                        logs["global_step"] = global_step
                        logs["global_step_time"] = global_batch_train_duration
                        logging_loss = tr_loss

                        if self.tb_writer:
                            for k, v in logs.items():
                                self.tb_writer.add_scalar(k, v, global_step)
                                # `run` is assumed to be an experiment-tracking handle defined
                                # elsewhere in this module (e.g., an Azure ML Run).
                                run.log(k, v)
                        epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                    if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.ort_model
                        else:
                            assert model is self.ort_model
                        # Save model checkpoint.
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                        self.save_model(output_dir)
                        # self._rotate_checkpoints()

            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                epoch_iterator.close()
                break
        if self.args.max_steps > 0 and global_step > self.args.max_steps:
            train_iterator.close()
            break

    if self.tb_writer:
        self.tb_writer.close()
    self.update_torch_model()
    del self.ort_model
    self.ort_model = None

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    return TrainOutput(global_step, tr_loss / global_step)
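
# The scheduler used in train() is constructed with (num_warmup_steps, num_training_steps) and
# queried via get_lr_this_step(global_step, base_lr=...). The real linear_schedule_with_warmup
# helper is not shown in this file; the class below is a hypothetical stand-in that sketches the
# usual linear warmup followed by linear decay, only to make that call pattern concrete.
class LinearScheduleWithWarmupSketch:
    def __init__(self, num_warmup_steps, num_training_steps):
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def get_lr_this_step(self, global_step, base_lr):
        # Linear warmup from 0 to base_lr over num_warmup_steps, then linear decay back to 0.
        if global_step < self.num_warmup_steps:
            scale = global_step / max(1, self.num_warmup_steps)
        else:
            scale = max(0.0, (self.num_training_steps - global_step) /
                        max(1, self.num_training_steps - self.num_warmup_steps))
        return base_lr * scale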