Example #1
0
def run_ort_training_step(args, global_step, training_steps, model, batch):
    """Run a single ORT training micro-step on ``batch`` and return its loss.

    Args:
        args: Run configuration. This function reads ``fp16``, ``schedule`` and
            ``gradient_accumulation_steps`` and forwards ``args`` to ``get_lr``.
        global_step: Optimization-step counter used to look up the learning rate.
        training_steps: Micro-step counter. The loss scale is updated whenever
            it is a multiple of ``args.gradient_accumulation_steps``.
        model: Trainer object. It is called directly in full precision; in
            fp16 mode its ``train_step`` method and ``loss_scale_input_name``
            attribute are used instead.
        batch: Five-item sequence unpacked as ``(input_ids, segment_ids,
            input_mask, masked_lm_labels, next_sentence_labels)``.

    Returns:
        The loss produced by the model. When ``train_step`` returns a
        ``(loss, all_finite)`` pair, only the loss is returned.
    """
    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

    if args.fp16:
        # NOTE(review): the scaler is rebuilt on every call, so as far as this
        # function shows, its dynamic scale does not persist between steps --
        # confirm whether the caller is expected to own the scaler instead.
        loss_scaler = LossScaler(model.loss_scale_input_name,
                                 True,
                                 up_scale_window=2000)

    lr = get_lr(args, global_step, args.schedule)
    learning_rate = torch.tensor([lr])
    if args.fp16:
        loss_scale = torch.tensor([loss_scaler.loss_scale_])
        loss = model.train_step(input_ids, segment_ids, input_mask,
                                masked_lm_labels, next_sentence_labels,
                                learning_rate, loss_scale)
        # Default used when train_step returns only the loss.
        all_finite = 1
        if isinstance(loss, (list, tuple)):
            assert len(loss) == 2
            loss, all_finite = loss
    else:
        loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                     next_sentence_labels, learning_rate)

    if training_steps % args.gradient_accumulation_steps == 0:
        if args.fp16:
            # ``all_finite`` is still the plain int default when train_step
            # returned a bare loss; calling ``.item()`` on it used to raise
            # AttributeError, so only unwrap objects that provide ``.item()``.
            if hasattr(all_finite, "item"):
                is_all_finite = all_finite.item()
            else:
                is_all_finite = all_finite
            loss_scaler.update_loss_scale(is_all_finite)
        # NOTE(review): the incremented counter is not returned, so for an int
        # argument the caller never observes this update -- confirm intent.
        global_step += 1

    return loss
Example #2
0
def runBertTrainingTest(gradient_accumulation_steps,
                        use_mixed_precision,
                        allreduce_post_accumulation,
                        use_simple_model_desc=True,
                        use_internel_loss_scale=False):
    """Train the ORT BERT test model for eight batches and collect its losses.

    Args:
        gradient_accumulation_steps: Forwarded to ``create_ort_trainer``.
        use_mixed_precision: When true, steps go through ``model.train_step``;
            otherwise the model is called directly.
        allreduce_post_accumulation: Forwarded to ``create_ort_trainer``.
        use_simple_model_desc: Forwarded to ``create_ort_trainer``.
        use_internel_loss_scale: When true, a ``LossScaler`` is handed to the
            trainer and no loss-scale input is fed or updated by this function.

    Returns:
        ``(actual_losses, actual_all_finites, eval_loss)`` when mixed precision
        is used with an external loss scale, otherwise
        ``(actual_losses, eval_loss)``.
    """
    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    if use_internel_loss_scale:
        loss_scaler = LossScaler("ort_test_input_loss_scalar", True)
    else:
        loss_scaler = None

    model, model_desc, device = create_ort_trainer(
        gradient_accumulation_steps,
        use_mixed_precision,
        allreduce_post_accumulation,
        use_simple_model_desc,
        loss_scaler)

    if loss_scaler is None:
        loss_scaler = LossScaler(model.loss_scale_input_name, True)

    batch_size = 16
    num_batches = 8

    def make_sample(input_index):
        # One sample batch for the ``input_index``-th model input.
        return generate_sample_batch(model_desc.inputs_[input_index], batch_size, device)

    # These pre-generated batches are never read afterwards. The calls are
    # kept, in the same order, so that any state generate_sample_batch
    # consumes (presumably the seeded RNG -- confirm) is unchanged.
    input_ids_batches = []
    segment_ids_batches = []
    input_mask_batches = []
    masked_lm_labels_batches = []
    next_sentence_labels_batches = []
    for batch in range(num_batches):
        input_ids_batches.append(make_sample(0))
        segment_ids_batches.append(make_sample(1))
        input_mask_batches.append(make_sample(2))
        masked_lm_labels_batches.append(make_sample(3))
        next_sentence_labels_batches.append(make_sample(4))

    lr_batch_list = [0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06, 1.8404908e-06,
                     2.3006135e-06, 2.7607362e-06, 3.2208588e-06, 3.6809815e-06]

    actual_losses = []
    actual_all_finites = []
    external_loss_scale = not use_internel_loss_scale

    for batch_count in range(num_batches):
        input_ids = make_sample(0)
        segment_ids = make_sample(1)
        input_mask = make_sample(2)
        masked_lm_labels = make_sample(3)
        next_sentence_labels = make_sample(4)

        learning_rate = torch.tensor([lr_batch_list[batch_count]]).to(device)
        training_args = [input_ids, segment_ids, input_mask,
                         masked_lm_labels, next_sentence_labels, learning_rate]

        if not use_mixed_precision:
            loss = model(*training_args)
            actual_losses.append(loss.cpu().numpy().item(0))
        else:
            if external_loss_scale:
                training_args.append(
                    torch.tensor([loss_scaler.loss_scale_]).to(device))
            actual_loss = model.train_step(*training_args)
            if isinstance(actual_loss, (list, tuple)):
                assert len(actual_loss) == 2
                actual_loss, actual_all_finite = actual_loss
                if external_loss_scale:
                    loss_scaler.update_loss_scale(actual_all_finite.item())
                    actual_all_finites.append(
                        actual_all_finite.cpu().numpy().item(0))
            actual_losses.append(actual_loss.cpu().numpy().item(0))

        is_last_batch = batch_count == num_batches - 1
        if is_last_batch:
            # Exercise the eval_step API (with fetches) only after the final
            # training step: the training session is stateful, so evaluating
            # mid-training would affect the recorded training losses.
            eval_loss = model.eval_step(input_ids, segment_ids, input_mask,
                                        masked_lm_labels, next_sentence_labels,
                                        fetches=['loss'])
            eval_loss = eval_loss.cpu().numpy().item(0)

    # With an internal loss scale the all_finite flags are handled internally
    # as well, so they are not part of the result.
    if use_mixed_precision and external_loss_scale:
        return actual_losses, actual_all_finites, eval_loss
    return actual_losses, eval_loss
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Runs the epoch / micro-batch loop on ``self.ort_model``, updating the
        loss scale (fp16 only), logging and saving checkpoints at every
        gradient-accumulation boundary. Afterwards it calls
        ``self.update_torch_model()`` and drops ``self.ort_model``.

        Args:
            model_path:
                (Optional) Local path the model was instantiated from. Within
                this method only its trailing ``-<global_step>`` suffix is
                used, to resume the step and epoch counters; if the suffix is
                not an integer, training starts from step 0.

        Returns:
            ``TrainOutput(global_step, mean_loss)`` where ``mean_loss`` is the
            accumulated loss divided by ``global_step``.
        """
        train_dataloader = self.get_train_dataloader()

        # Optimizer updates per epoch. Clamped to 1 because the loop below
        # still performs one update per epoch when the dataloader is shorter
        # than the accumulation window; the unclamped value of 0 raised
        # ZeroDivisionError in the epoch and checkpoint arithmetic.
        steps_per_epoch = max(
            len(train_dataloader) // self.args.gradient_accumulation_steps, 1)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // steps_per_epoch + 1
        else:
            t_total = int(steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        scheduler = linear_schedule_with_warmup(
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=t_total)

        # In full precision the placeholder ``1`` is passed to _training_step.
        loss_scaler = LossScaler(
            self.ort_model.loss_scale_input_name,
            True,
            up_scale_window=2000,
            loss_scale=float(1 << 20)) if self.args.fp16 else 1

        model = self.ort_model

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())

        # Train!
        if self.is_world_master():
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_dataloader.dataset))
            logger.info("  Num Epochs = %d", num_train_epochs)
            logger.info("  Instantaneous batch size per GPU = %d",
                        self.args.per_gpu_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d",
                self.args.train_batch_size *
                self.args.gradient_accumulation_steps *
                (self.args.world_size if self.args.local_rank != -1 else 1),
            )
            logger.info("  Gradient Accumulation steps = %d",
                        self.args.gradient_accumulation_steps)
            logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // steps_per_epoch
                steps_trained_in_current_epoch = global_step % steps_per_epoch

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                # The path does not end in "-<int>": not a checkpoint directory.
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        global_batch_train_start = time.time()

        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if len(inputs['input_ids']
                       ) < self.args.per_gpu_train_batch_size:
                    # skip incomplete batch
                    logger.info('Skipping incomplete batch...')
                    continue

                learning_rate = torch.tensor([
                    scheduler.get_lr_this_step(global_step,
                                               base_lr=self.args.learning_rate)
                ])
                loss, all_finite = self._training_step(model, inputs,
                                                       learning_rate,
                                                       loss_scaler)
                tr_loss += loss
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):

                    if self.args.fp16:
                        loss_scaler.update_loss_scale(all_finite.item())

                    global_step += 1
                    global_batch_train_duration = time.time(
                    ) - global_batch_train_start
                    global_batch_train_start = time.time()

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            # The divisor is clamped so that enabling only
                            # logging_first_step (logging_steps == 0) no
                            # longer divides by zero; values are unchanged
                            # whenever logging_steps >= 1.
                            loss_avg = (tr_loss - logging_loss) / (
                                max(self.args.logging_steps, 1) *
                                self.args.gradient_accumulation_steps)
                            logs["learning_rate"] = learning_rate.item()
                            logs["loss"] = loss_avg
                            logs["global_step"] = global_step
                            logs[
                                "global_step_time"] = global_batch_train_duration
                            logging_loss = tr_loss

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                                    run.log(k, v)
                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    "step": global_step
                                }))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.ort_model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.ort_model
                            else:
                                assert model is self.ort_model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            # self._rotate_checkpoints()

                # NOTE(review): with ``>`` the loop stops only after
                # max_steps + 1 optimizer updates, one more than the
                # "Total optimization steps" logged above -- confirm whether
                # ``>=`` was intended.
                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()
        self.update_torch_model()
        del (self.ort_model)
        self.ort_model = None

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        # Guard against a run in which no optimizer update happened (empty
        # dataloader or every batch skipped), which used to raise
        # ZeroDivisionError here.
        return TrainOutput(global_step, tr_loss / max(global_step, 1))