Esempio n. 1
0
class ORTTransformerTrainer:
    """
    Train/evaluate a transformer model through ONNX Runtime's ``ORTTrainer``.

    ``use_new_api`` selects between two front-ends:

    * the legacy ``ORTTrainer`` API, configured from ``model_desc`` and
      ``args`` (``AdamOptimizer`` with a linear warmup schedule);
    * the ``orttrainer.ORTTrainer`` API, configured from ``new_model_desc``
      and an ``ORTTrainerOptions`` object.

    ``train()`` replaces ``self.model`` with the ORT wrapper. ``save_model()``
    and the new-API evaluation path call ``save_as_onnx`` / ``eval_step``,
    which only that wrapper is expected to provide, so call ``train()`` first.
    """

    model: PreTrainedModel
    args: TrainingArguments
    train_dataset: Dataset
    eval_dataset: Dataset
    compute_metrics: Callable[[EvalPrediction], Dict]

    def __init__(
        self,
        model: PreTrainedModel,
        model_desc: ModelDescription,
        new_model_desc: dict,
        args: TrainingArguments,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        compute_metrics: Callable[[EvalPrediction], Dict],
        use_new_api : Optional[bool] = False,
    ):
        """
        Args:
            model: the PyTorch model to train; wrapped by an ORT trainer in ``train()``.
            model_desc: model description used by the legacy ``ORTTrainer`` API.
            new_model_desc: model description dict passed to the new
                ``orttrainer.ORTTrainer`` API.
            args: training hyper-parameters. ``set_seed(args.seed)`` is called here and
                ``args.output_dir`` is created when ``args.local_rank`` is -1 or 0.
            train_dataset: dataset consumed by ``train()``.
            eval_dataset: dataset consumed by ``evaluate()``.
            compute_metrics: callback mapping an ``EvalPrediction`` to a metrics dict;
                may be None, in which case only the eval loss is reported.
            use_new_api: use the new ``orttrainer.ORTTrainer`` front-end instead of the
                legacy one.
        """

        self.model = model
        self.model_desc = model_desc
        self.new_model_desc = new_model_desc
        self.args = args
        self.data_collator = DefaultDataCollator()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        set_seed(self.args.seed)
        # Create output directory if needed (only on rank 0 or in a non-distributed run).
        if self.args.local_rank in [-1, 0]:
            os.makedirs(self.args.output_dir, exist_ok=True)

        self.use_new_api = use_new_api

    def get_train_dataloader(self) -> DataLoader:
        """
        Build the training DataLoader.

        Uses a SequentialSampler when ``local_rank == -1`` and a
        DistributedSampler otherwise.

        Raises:
            ValueError: if no ``train_dataset`` was provided.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        # NOTE(review): samples are not shuffled in the single-process case;
        # presumably deliberate for reproducible runs — confirm.
        train_sampler = (
            SequentialSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset)
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_eval_dataloader(self) -> DataLoader:
        """
        Build the (unshuffled) evaluation DataLoader.

        Raises:
            ValueError: if no ``eval_dataset`` was provided.
        """
        if self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        return DataLoader(
            self.eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Build the (unshuffled) prediction DataLoader for ``test_dataset``.

        Raises:
            ValueError: if ``test_dataset`` is None.
        """
        if test_dataset is None:
            raise ValueError("Trainer: prediction requires a test_dataset.")
        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )


    def train(self):
        """
        Main training entry point.

        Wraps ``self.model`` in an ORT trainer (new or legacy API). Then it runs
        ``num_train_epochs`` epochs, or stops after ``max_steps`` optimizer steps
        when ``max_steps > 0``. On rank -1/0 it logs the loss every
        ``logging_steps`` optimizer steps, plus the learning rate for the legacy API.

        Returns:
            TrainOutput(global_step, summed micro-batch loss / global_step); the
            second field is 0.0 when no optimizer step was taken.
        """
        train_dataloader = self.get_train_dataloader()

        # The loop below takes one optimizer step per `gradient_accumulation_steps`
        # micro-batches. It still takes one step per epoch when the loader is shorter
        # than that, so never count fewer than one step per epoch. This also avoids a
        # zero divisor below.
        steps_per_epoch = max(len(train_dataloader) // self.args.gradient_accumulation_steps, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // steps_per_epoch + 1
        else:
            t_total = int(steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        if self.use_new_api:
            lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total))

            loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
            # Build a "type:index" device id, using index 0 when the torch device has none.
            device_index = self.args.device.index if self.args.device.index is not None else 0
            device = f'{self.args.device.type}:{device_index}'
            options = orttrainer.ORTTrainerOptions({
                'batch': {'gradient_accumulation_steps': self.args.gradient_accumulation_steps},
                'device': {'id': device},
                'mixed_precision': {'enabled': self.args.fp16, 'loss_scaler': loss_scaler},
                'debug': {'deterministic_compute': True},
                'utils': {'grad_norm_clip': False},
                'distributed': {'allreduce_post_accumulation': True},
                'lr_scheduler': lr_scheduler,
            })

            def is_no_decay(name):
                return "bias" in name or "LayerNorm.weight" in name

            param_names = [n for n, _ in self.model.named_parameters()]
            # NOTE(review): both groups use weight_decay_mode 1 (as does the legacy
            # map_optimizer_attributes below), so the split has no effect — confirm intent.
            params = [
                {'params': [n for n in param_names if is_no_decay(n)], "weight_decay_mode": 1},
                {'params': [n for n in param_names if not is_no_decay(n)], "weight_decay_mode": 1},
            ]

            # Use the configured learning rate, consistent with the legacy branch.
            optim_config = optim.AdamConfig(params=params, lr=self.args.learning_rate, do_bias_correction=True)
            self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
        else:
            def map_optimizer_attributes(name):
                # Every parameter (bias/LayerNorm or not) currently gets the same setting.
                return {"weight_decay_mode": 1}
            get_lr_this_step = get_linear_schedule_with_warmup(self.args.warmup_steps, t_total, self.args.learning_rate)
            loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) if self.args.fp16 else None
            self.model = ORTTrainer(self.model, None,
                self.model_desc,
                "AdamOptimizer",
                map_optimizer_attributes=map_optimizer_attributes,
                learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
                device=self.args.device,
                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
                use_mixed_precision=self.args.fp16,
                allreduce_post_accumulation=True,
                get_lr_this_step=get_lr_this_step,
                loss_scaler=loss_scaler,
                enable_grad_norm_clip=False,
                _opset_version=12,
                _use_deterministic_compute=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                # An optimizer step completes every `gradient_accumulation_steps` micro-batches,
                # or at the end of an epoch that is shorter than that.
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
                            global_step == 1 and self.args.logging_first_step
                        ):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps
                            if not self.use_new_api:
                                learning_rate_scalar = get_lr_this_step(global_step)
                                logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                # Stop as soon as max_steps optimizer steps are done (t_total == max_steps).
                if self.args.max_steps > 0 and global_step >= self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step >= self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        # global_step stays 0 when no optimizer step ran (e.g. an empty dataset).
        average_loss = tr_loss / global_step if global_step > 0 else 0.0
        return TrainOutput(global_step, average_loss)

    def _training_step(
        self, model, inputs: Dict[str, torch.Tensor]) -> float:
        """
        Run one ``train_step`` on ``inputs`` and return the loss as a Python float.

        ``inputs`` is moved to ``self.args.device`` in place. The loss is taken
        from the first output of ``model.train_step``.
        """
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model.train_step(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        return loss.item()

    def save_model(self, output_dir: Optional[str] = None):
        """
        Export the trained model to ``<output_dir>/transformer.onnx``.

        ``output_dir`` defaults to ``self.args.output_dir`` and is created if
        missing. Relies on ``self.model.save_as_onnx``, i.e. on ``train()``
        having wrapped the model.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx"))

    def evaluate(self) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions

        Raises:
            ValueError: if no ``eval_dataset`` was provided.
        """
        eval_dataloader = self.get_eval_dataloader()

        output = self._prediction_loop(eval_dataloader, description="Evaluation")
        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)
        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
        self, dataloader: DataLoader, description: str
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels. When a batch carries ``labels`` or
        ``masked_lm_labels``, the first two model outputs are (loss, logits).
        Otherwise the first output is the logits. Predictions (and ``labels``,
        when present) are concatenated along axis 0 across batches.
        """

        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", len(dataloader.dataset))
        logger.info("  Batch size = %d", dataloader.batch_size)
        eval_losses: List[float] = []
        # Per-batch arrays, concatenated once at the end (np.append per batch would
        # re-copy the accumulated array every time).
        pred_chunks: List[np.ndarray] = []
        label_chunks: List[np.ndarray] = []

        if not self.use_new_api:
            self.model.eval()

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                if self.use_new_api:
                    outputs = self.model.eval_step(**inputs)
                else:
                    outputs = self.model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses.append(step_eval_loss.mean().item())
                else:
                    logits = outputs[0]

            pred_chunks.append(logits.detach().cpu().numpy())
            if inputs.get("labels") is not None:
                label_chunks.append(inputs["labels"].detach().cpu().numpy())

        preds = np.concatenate(pred_chunks, axis=0) if pred_chunks else None
        label_ids = np.concatenate(label_chunks, axis=0) if label_chunks else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["loss"] = np.mean(eval_losses)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
Esempio n. 2
0
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    """Train ``model`` with an ORT trainer on a synthetic dataloader, then run one eval step.

    Wraps the model with the new ``orttrainer.ORTTrainer`` front-end
    (``use_new_api=True``) or the legacy ``ORTTrainer`` front-end. It then trains
    for ``epochs`` passes over a dataloader built by ``create_ort_test_dataloader``
    and evaluates on the last batch seen by the training loop.

    Args:
        model: PyTorch model to wrap.
        model_desc: legacy model description; its ``inputs_`` drive data
            generation and positional/keyword argument splitting.
        device: device passed to the trainer and the dataloader.
        args: namespace providing ``batch_size``, ``seq_len``, ``local_rank``
            and ``world_size``.
        gradient_accumulation_steps: forwarded to the trainer.
        fp16: enable mixed precision.
        allreduce_post_accumulation: forwarded to the trainer (both APIs).
        get_lr_this_step: step -> learning-rate callable.
        use_internal_get_lr_this_step: let the trainer call ``get_lr_this_step``
            instead of feeding ``Learning_Rate`` explicitly each step.
        loss_scaler: external loss scaler (legacy API).
        use_internal_loss_scaler: let the trainer manage loss scaling; must be
            True for the new API.
        batch_args_option: how batch tensors are passed to ``train_step`` /
            ``eval_step`` (all positional, all keyword, or roughly half/half).
        dataset_len: number of samples in the synthetic dataset.
        epochs: number of passes over the dataloader.
        use_new_api: select the new front-end.

    Returns:
        A generator yielding each eval output as a NumPy array.

    Raises:
        ValueError: if the training loop saw no batch (empty dataloader or
            ``epochs == 0``), so there is nothing to evaluate.
    """
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                # Honor the caller's setting, matching the legacy branch below.
                'allreduce_post_accumulation': allreduce_post_accumulation
            },
            'lr_scheduler':
            new_api_lr_scheduler
        })

        def is_no_decay(name):
            return "bias" in name or "LayerNorm.weight" in name

        # Per-group LAMB hyper-parameters, identical for both groups
        # (alpha/beta/epsilon mirror the BertLAMB defaults used by the legacy branch).
        lamb_hyper_params = {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
        param_names = [n for n, _ in model.named_parameters()]
        params = [
            {'params': [n for n in param_names if is_no_decay(n)], **lamb_hyper_params},
            {'params': [n for n in param_names if not is_no_decay(n)], **lamb_hyper_params},
        ]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch']),
            ],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2]),
            ],
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [
                1,
            ], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # Number of leading inputs passed positionally for the mixed (neither List nor
    # Dict) option: approximately half args, half kwargs.
    half_input_count = len(model_desc.inputs_) // 2

    # training loop
    if not use_new_api:
        model.train()
    batch = None
    for _ in range(epochs):
        for step, batch in enumerate(dataloader):
            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                # Rebinding `batch` is intentional: the eval step below reuses it.
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale]
                model.train_step(*batch)
            else:
                positional_count = 0 if batch_args_option == BatchArgsOption.Dict else half_input_count
                batch_args, batch_kwargs = split_batch(batch, model_desc.inputs_, positional_count)
                if not use_internal_get_lr_this_step:
                    batch_kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    batch_kwargs[model.loss_scale_input_name] = loss_scale
                model.train_step(*batch_args, **batch_kwargs)

    if batch is None:
        raise ValueError("run_test: the dataloader produced no batches to evaluate")

    # eval on the last batch seen by the training loop. NOTE(review): for
    # BatchArgsOption.List that is the rebound list, so it includes any appended
    # learning-rate / loss-scale tensors — confirm this is the intended eval input.
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    else:
        positional_count = 0 if batch_args_option == BatchArgsOption.Dict else half_input_count
        batch_args, batch_kwargs = split_batch(batch, model_desc.inputs_, positional_count)
        outputs = model.eval_step(*batch_args, **batch_kwargs)

    return (output.cpu().numpy() for output in outputs)
def runBertTrainingTest(gradient_accumulation_steps,
                        use_mixed_precision,
                        allreduce_post_accumulation,
                        use_simple_model_desc=True,
                        use_internel_loss_scale=False):
    """Run a short seeded BERT training on ``bert_toy_postprocessed.onnx``.

    Seeds torch and onnxruntime with 1 and wraps the ONNX model in the legacy
    ``ORTTrainer`` (``LambOptimizer``, ``cuda:0``). It trains on ``num_batches``
    generated batches with a fixed learning-rate list, then calls ``eval_step``
    (fetching ``loss``) on the final batch.

    Args:
        gradient_accumulation_steps: forwarded to ``ORTTrainer``.
        use_mixed_precision: train through ``train_step`` with loss scaling
            instead of calling the model directly.
        allreduce_post_accumulation: forwarded to ``ORTTrainer``.
        use_simple_model_desc: hand the trainer ``remove_extra_info(desc)``
            instead of the full description.
        use_internel_loss_scale: let the trainer own the ``LossScaler`` rather
            than feeding the loss scale as an explicit input.

    Returns:
        ``(losses, all_finites, eval_loss)`` for mixed precision with an external
        loss scaler, otherwise ``(losses, eval_loss)``.
    """
    full_desc = bert_model_description()
    trainer_desc = remove_extra_info(full_desc) if use_simple_model_desc else full_desc
    lr_description = ort_trainer_learning_rate_description()
    device = torch.device("cuda", 0)

    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx"))

    if use_internel_loss_scale:
        loss_scaler = LossScaler("ort_test_input_loss_scalar", True)
    else:
        loss_scaler = None

    model = ORTTrainer(onnx_model,
                       None,
                       trainer_desc,
                       "LambOptimizer",
                       map_optimizer_attributes,
                       lr_description,
                       device,
                       postprocess_model=None,
                       gradient_accumulation_steps=gradient_accumulation_steps,
                       world_rank=0,
                       world_size=1,
                       loss_scaler=loss_scaler,
                       use_mixed_precision=use_mixed_precision,
                       allreduce_post_accumulation=allreduce_post_accumulation)

    # No internal scaler: keep an external one whose scale is passed as an input.
    if loss_scaler is None:
        loss_scaler = LossScaler(model.loss_scale_input_name, True)

    batch_size = 16
    num_batches = 8
    # input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels
    num_inputs = 5

    def sample_inputs():
        # One generated tensor per model input, produced in input order.
        return [generate_sample_batch(full_desc.inputs_[idx], batch_size, device)
                for idx in range(num_inputs)]

    # These batches are never consumed. Generating them presumably advances the
    # seeded RNG, which determines the training batches below and hence the
    # expected losses — keep them to preserve the data.
    _discarded_batches = [sample_inputs() for _ in range(num_batches)]

    lr_batch_list = [
        0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06,
        1.8404908e-06, 2.3006135e-06, 2.7607362e-06, 3.2208588e-06,
        3.6809815e-06
    ]

    actual_losses = []
    actual_all_finites = []
    last_inputs = None

    for batch_count in range(num_batches):
        batch_inputs = sample_inputs()
        learning_rate = torch.tensor([lr_batch_list[batch_count]]).to(device)
        training_args = batch_inputs + [learning_rate]

        if not use_mixed_precision:
            step_loss = model(*training_args)
            actual_losses.append(step_loss.cpu().numpy().item(0))
        else:
            if not use_internel_loss_scale:
                training_args.append(torch.tensor([loss_scaler.loss_scale_]).to(device))
            step_result = model.train_step(*training_args)
            if isinstance(step_result, (list, tuple)):
                assert len(step_result) == 2
                step_result, step_all_finite = step_result
                if not use_internel_loss_scale:
                    loss_scaler.update_loss_scale(step_all_finite.item())
                    actual_all_finites.append(step_all_finite.cpu().numpy().item(0))
            actual_losses.append(step_result.cpu().numpy().item(0))

        last_inputs = batch_inputs

    # Test the eval_step api with fetches only after training: eval_step during
    # training would affect the training loss (the training session is stateful).
    eval_loss = model.eval_step(*last_inputs, fetches=['loss'])
    eval_loss = eval_loss.cpu().numpy().item(0)

    # With an internal loss scaler, all_finites are handled internally too.
    if use_mixed_precision and not use_internel_loss_scale:
        return actual_losses, actual_all_finites, eval_loss
    return actual_losses, eval_loss