Exemple #1
0
def run_ort_training_step(args, global_step, training_steps, model, batch):
    """Run one ORT training step on ``batch`` and return its loss.

    Args:
        args: Run configuration. Reads ``fp16``, ``schedule`` and
            ``gradient_accumulation_steps``. In fp16 mode an existing
            ``ort_loss_scale`` attribute is reused as the loss scaler so the
            dynamic loss scale persists between calls.
        global_step: Step index forwarded to ``get_lr``. It is NOT advanced
            for the caller (ints are passed by value); the caller owns it.
        training_steps: Number of micro-batches processed so far; the loss
            scale is only updated when it is a multiple of
            ``args.gradient_accumulation_steps``.
        model: Object exposing ``train_step`` and ``loss_scale_input_name``
            (fp16 path) or callable directly (fp32 path).
        batch: 5-tuple ``(input_ids, segment_ids, input_mask,
            masked_lm_labels, next_sentence_labels)``.

    Returns:
        The loss returned by the model for this batch.
    """
    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

    loss_scaler = None
    if args.fp16:
        # Prefer the scaler stored on args so updates made below survive to
        # the next call. Without one, fall back to a per-call scaler (the
        # update is then lost when this function returns).
        loss_scaler = getattr(args, 'ort_loss_scale', None)
        if loss_scaler is None:
            loss_scaler = LossScaler(model.loss_scale_input_name,
                                     True,
                                     up_scale_window=2000)

    lr = get_lr(args, global_step, args.schedule)
    learning_rate = torch.tensor([lr])
    if args.fp16:
        loss_scale = torch.tensor([loss_scaler.loss_scale_])
        loss = model.train_step(input_ids, segment_ids, input_mask,
                                masked_lm_labels, next_sentence_labels,
                                learning_rate, loss_scale)
        # train_step may return either the loss alone or (loss, all_finite).
        all_finite = 1
        if isinstance(loss, (list, tuple)):
            assert len(loss) == 2
            loss, all_finite = loss
    else:
        loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                     next_sentence_labels, learning_rate)

    if args.fp16 and training_steps % args.gradient_accumulation_steps == 0:
        # all_finite is a tensor-like when train_step returned it, but the
        # plain int default when it did not; only the former has .item().
        is_all_finite = all_finite.item() if hasattr(all_finite, 'item') else all_finite
        loss_scaler.update_loss_scale(is_all_finite)

    return loss
Exemple #2
0
def create_ort_trainer(args, device, model):
    """Wrap ``model`` in an ``ORTTrainer`` that uses "LambOptimizer".

    Sets the ORT CUDA memory limit to 1 GiB, builds the trainer with opset 10
    and, when ``args.fp16`` is truthy, stores a ``LossScaler`` on
    ``args.ort_loss_scale``.

    Returns:
        The ``ORTTrainer`` instance.
    """
    # set GPU memory limitation
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    mem_limit_gbs = 1
    mem_limit_bytes = int(mem_limit_gbs * 1024 * 1024 * 1024)
    set_cuda_mem_limit(mem_limit_bytes)

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        # Parameters whose name contains one of these substrings get
        # "lambda" 0.0; all others get 0.01.
        exempt_markers = ("bias", "gamma", "beta", "LayerNorm")
        exempt = any(marker in name for marker in exempt_markers)
        return {
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0 if exempt else 0.01,
            "epsilon": 1e-6,
        }

    # we request ORTTrainer to create a LambOptimizer with given optimizer_attributes.
    # train_step does forward, backward, and optimize step.
    model_desc = bert_model_description(args)
    lr_desc = IODescription('Learning_Rate', [1], torch.float32)
    trainer = ORTTrainer(model, None, model_desc, "LambOptimizer",
                         map_optimizer_attributes, lr_desc, device,
                         _opset_version=10)

    if args.fp16:
        args.ort_loss_scale = LossScaler(trainer.loss_scale_input_name,
                                         True,
                                         up_scale_window=2000)

    return trainer
def create_ort_trainer(args, device, model):
    """Wrap ``model`` in an ``ORTTrainer`` that uses "LambOptimizer".

    Reads from ``args``: ``gpu_memory_limit_gb``,
    ``gradient_accumulation_steps``, ``world_rank``, ``world_size``, ``fp16``,
    ``allreduce_post_accumulation`` and ``deepspeed_zero_stage``. When
    ``args.fp16`` is truthy a ``LossScaler`` is stored on
    ``args.ort_loss_scale``.

    Returns:
        The ``ORTTrainer`` instance.
    """
    # set GPU memory limitation (per card!)
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    mem_limit_bytes = int(args.gpu_memory_limit_gb * 1024 * 1024 * 1024)
    set_cuda_mem_limit(mem_limit_bytes)

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        # Parameters whose name contains one of these substrings get
        # "lambda" 0.0; all others get 0.01.
        exempt_markers = ("bias", "gamma", "beta", "LayerNorm")
        exempt = any(marker in name for marker in exempt_markers)
        return {
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0 if exempt else 0.01,
            "epsilon": 1e-6,
        }

    # we request ORTTrainer to create a LambOptimizer with given optimizer_attributes.
    # train_step does forward, backward, and optimize step.
    model_desc = bert_model_description(args)
    lr_desc = IODescription('Learning_Rate', [1], torch.float32)
    trainer = ORTTrainer(
        model,
        None,
        model_desc,
        "LambOptimizer",
        map_optimizer_attributes,
        lr_desc,
        device,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        world_rank=args.world_rank,
        world_size=args.world_size,
        use_mixed_precision=bool(args.fp16),
        allreduce_post_accumulation=bool(args.allreduce_post_accumulation),
        deepspeed_zero_stage=int(bool(args.deepspeed_zero_stage)),
        _opset_version=12)

    if args.fp16:
        args.ort_loss_scale = LossScaler(trainer.loss_scale_input_name,
                                         True,
                                         up_scale_window=2000)

    return trainer
Exemple #4
0
def runBertTrainingTest(gradient_accumulation_steps,
                        use_mixed_precision,
                        allreduce_post_accumulation,
                        use_simple_model_desc=True,
                        use_internel_loss_scale=False):
    """Train for 8 generated batches and collect the per-batch losses.

    Returns:
        ``(actual_losses, actual_all_finites, eval_loss)`` when mixed
        precision is used with an external loss scaler, otherwise
        ``(actual_losses, eval_loss)``. ``eval_loss`` comes from a single
        ``eval_step`` call issued after the last training batch.
    """
    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    loss_scaler = None
    if use_internel_loss_scale:
        loss_scaler = LossScaler("ort_test_input_loss_scalar", True)

    model, model_desc, device = create_ort_trainer(
        gradient_accumulation_steps,
        use_mixed_precision,
        allreduce_post_accumulation,
        use_simple_model_desc,
        loss_scaler)

    if loss_scaler is None:
        loss_scaler = LossScaler(model.loss_scale_input_name, True)

    batch_size = 16
    num_batches = 8

    def sample(input_index):
        # One generated batch for the input at position ``input_index``.
        return generate_sample_batch(model_desc.inputs_[input_index], batch_size, device)

    # These lists are never read afterwards. The generator calls are kept, in
    # the same order, so whatever state generate_sample_batch consumes
    # (presumably the seeded RNG -- TODO confirm) is left unchanged for the
    # training loop below.
    input_ids_batches = []
    segment_ids_batches = []
    input_mask_batches = []
    masked_lm_labels_batches = []
    next_sentence_labels_batches = []
    for _ in range(num_batches):
        input_ids_batches.append(sample(0))
        segment_ids_batches.append(sample(1))
        input_mask_batches.append(sample(2))
        masked_lm_labels_batches.append(sample(3))
        next_sentence_labels_batches.append(sample(4))

    lr_batch_list = [0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06, 1.8404908e-06,
                     2.3006135e-06, 2.7607362e-06, 3.2208588e-06, 3.6809815e-06]

    actual_losses = []
    actual_all_finites = []
    external_scaling = not use_internel_loss_scale

    for batch_count in range(num_batches):
        input_ids = sample(0)
        segment_ids = sample(1)
        input_mask = sample(2)
        masked_lm_labels = sample(3)
        next_sentence_labels = sample(4)

        learning_rate = torch.tensor([lr_batch_list[batch_count]]).to(device)
        training_args = [input_ids,
                         segment_ids,
                         input_mask,
                         masked_lm_labels,
                         next_sentence_labels,
                         learning_rate]

        if not use_mixed_precision:
            loss = model(*training_args)
            actual_losses.append(loss.cpu().numpy().item(0))
        else:
            if external_scaling:
                training_args.append(torch.tensor([loss_scaler.loss_scale_]).to(device))
            actual_loss = model.train_step(*training_args)
            if isinstance(actual_loss, (list, tuple)):
                assert len(actual_loss) == 2
                actual_loss, actual_all_finite = actual_loss
                if external_scaling:
                    loss_scaler.update_loss_scale(actual_all_finite.item())
                    actual_all_finites.append(actual_all_finite.cpu().numpy().item(0))

            actual_losses.append(actual_loss.cpu().numpy().item(0))

        if batch_count == num_batches - 1:
            # test eval_step api with fetches at the end of the training.
            # if eval_step is called during the training, it will affect the actual training loss (training session is stateful),
            eval_output = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, fetches=['loss'])
            eval_loss = eval_output.cpu().numpy().item(0)

    # If using internal loss scale, all_finites are handled internally too.
    if use_mixed_precision and external_scaling:
        return actual_losses, actual_all_finites, eval_loss
    return actual_losses, eval_loss
Exemple #5
0
    def train(self):
        """
        Main training entry point.

        Replaces ``self.model`` with an ``ORTTrainer`` built around it
        ("AdamOptimizer", linear warmup schedule, world size 1), then runs the
        epoch/step loop. On local rank -1 or 0 a JSON log line is written
        every ``args.logging_steps`` optimization steps (and on the first
        step when ``args.logging_first_step`` is set).

        Returns:
            ``TrainOutput(global_step, mean_loss)`` where ``mean_loss`` is the
            accumulated ``_training_step`` loss divided by ``global_step``
            (by 1 when no optimization step ran).
        """
        train_dataloader = self.get_train_dataloader()

        # A dataloader shorter than the accumulation window still produces one
        # optimization step per epoch (see the step condition in the loop
        # below), so the per-epoch count must never be 0: that would divide by
        # zero in the max_steps branch and yield a zero-length schedule in the
        # other.
        steps_per_epoch = max(
            len(train_dataloader) // self.args.gradient_accumulation_steps, 1)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // steps_per_epoch + 1
        else:
            t_total = int(steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        get_lr_this_step = get_linear_schedule_with_warmup(
            self.args.warmup_steps, t_total, self.args.learning_rate)
        loss_scaler = LossScaler('loss_scale_input_name',
                                 True,
                                 up_scale_window=2000)

        def map_optimizer_attributes(name):
            # no_decay_keys = ["bias", "LayerNorm.weight"]
            no_decay = "bias" in name or "LayerNorm.weight" in name
            if no_decay:
                return {"weight_decay": 0.0, "weight_decay_mode": 1}
            else:
                return {
                    "weight_decay": self.args.weight_decay,
                    "weight_decay_mode": 1
                }

        self.model = ORTTrainer(
            self.model,
            None,
            self.model_desc,
            "AdamOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [
                1,
            ], torch.float32),
            device=self.args.device,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            world_rank=0,
            world_size=1,  # only support single GPU cases
            use_mixed_precision=self.args.fp16,
            allreduce_post_accumulation=True,
            get_lr_this_step=get_lr_this_step,
            loss_scaler=loss_scaler,
            enable_grad_norm_clip=False,
            _opset_version=12,
            _use_deterministic_compute=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                # An optimization step is counted at every accumulation
                # boundary, or at the end of an epoch that is no longer than
                # one accumulation window.
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            # logging_steps may be 0 when this branch is
                            # entered through logging_first_step; clamp the
                            # divisor so that case does not divide by zero.
                            loss_scalar = (tr_loss - logging_loss) / max(
                                self.args.logging_steps, 1)
                            learning_rate_scalar = get_lr_this_step(
                                global_step)
                            logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        # global_step is 0 when the dataloader produced no batches; avoid a
        # ZeroDivisionError in that case (tr_loss is still 0.0 then).
        return TrainOutput(global_step, tr_loss / max(global_step, 1))
Exemple #6
0
        def create_and_check_bert_for_pretraining(self, config, input_ids,
                                                  token_type_ids, input_mask,
                                                  sequence_labels,
                                                  token_labels, choice_labels):
            """Run ``run_test`` over a grid of ORT trainer options and print the outputs.

            Seeds every RNG with 42, builds a ``BertForPreTraining`` from
            ``config``, runs one reference forward pass, then calls
            ``run_test`` once per option combination. Nothing is asserted;
            the ORT loss and scores are only printed. ``choice_labels`` is
            accepted but not used.
            """
            from collections import namedtuple
            from itertools import product

            seed = 42
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            onnxruntime.set_seed(seed)

            model = BertForPreTraining(config=config)
            model.eval()
            # Reference forward pass; its results are not used below.
            loss, prediction_scores, seq_relationship_score = model(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=token_type_ids,
                masked_lm_labels=token_labels,
                next_sentence_label=sequence_labels)
            model_desc = ModelDescription([
                self.input_ids_desc, self.attention_mask_desc,
                self.token_type_ids_desc, self.masked_lm_labels_desc,
                self.next_sentence_label_desc
            ], [
                self.loss_desc, self.prediction_scores_desc,
                self.seq_relationship_scores_desc
            ])

            MyArgs = namedtuple(
                "MyArgs",
                "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
            )
            args = MyArgs(local_rank=0,
                          world_size=1,
                          max_steps=100,
                          learning_rate=0.00001,
                          warmup_proportion=0.01,
                          batch_size=13,
                          seq_len=7)

            def get_lr_this_step(global_step):
                return get_lr(args, global_step)

            loss_scaler = LossScaler('loss_scale_input_name',
                                     True,
                                     up_scale_window=2000)

            # It would be better to test both with/without mixed precision and allreduce_post_accumulation.
            # However, a stress test of all 4 cases is not stable, at least on the test machine.
            # Therefore we only test mixed precision with allreduce_post_accumulation, the most useful case.
            option_fp16 = [True]
            option_allreduce_post_accumulation = [True]
            option_gradient_accumulation_steps = [1, 8]
            option_use_internal_get_lr_this_step = [True, False]
            option_use_internal_loss_scaler = [True, False]
            option_split_batch = [BatchArgsOption.ListAndDict]

            # product() varies the right-most option fastest, i.e. the same
            # visiting order as six nested for-loops in this argument order.
            combinations = product(option_fp16,
                                   option_allreduce_post_accumulation,
                                   option_gradient_accumulation_steps,
                                   option_use_internal_get_lr_this_step,
                                   option_use_internal_loss_scaler,
                                   option_split_batch)

            for (fp16, allreduce_post_accumulation,
                 gradient_accumulation_steps, use_internal_get_lr_this_step,
                 use_internal_loss_scaler, split_batch) in combinations:
                print("gradient_accumulation_steps:",
                      gradient_accumulation_steps)
                print("use_internal_loss_scaler:", use_internal_loss_scaler)
                loss_ort, prediction_scores_ort, seq_relationship_score_ort = run_test(
                    model, model_desc, self.device, args,
                    gradient_accumulation_steps, fp16,
                    allreduce_post_accumulation,
                    get_lr_this_step, use_internal_get_lr_this_step,
                    loss_scaler, use_internal_loss_scaler,
                    split_batch)

                print(loss_ort)
                print(prediction_scores_ort)
                print(seq_relationship_score_ort)
Exemple #7
0
        def create_and_check_bert_for_pretraining(
            self,
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            option_fp16,
            option_allreduce_post_accumulation,
            option_gradient_accumulation_steps,
            option_split_batch,
            option_use_internal_get_lr_this_step=(True,),
            option_use_internal_loss_scaler=(True,),
        ):
            """Check that ``run_test`` gives matching results with the old and new API.

            For every combination of the ``option_*`` iterables, ``run_test``
            is executed with ``use_new_api=False``. When both the internal
            learning-rate schedule and the internal loss scaler are selected,
            it is executed again with ``use_new_api=True`` and the loss,
            prediction scores and sequence-relationship scores of the two
            runs are compared with ``assert_allclose``. All RNGs are reseeded
            with 42 before each run.

            Args:
                config: Configuration passed to ``BertForPreTraining``.
                input_ids, token_type_ids, input_mask, sequence_labels,
                    token_labels: Inputs of the reference forward pass.
                choice_labels: Accepted but not used.
                option_fp16, option_allreduce_post_accumulation,
                    option_gradient_accumulation_steps, option_split_batch,
                    option_use_internal_get_lr_this_step,
                    option_use_internal_loss_scaler: Iterables of values to
                    test; the last two default to ``(True,)``.
            """
            from collections import namedtuple
            from itertools import product

            seed = 42

            def reseed():
                # Put every RNG used here back into the same known state.
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
                onnxruntime.set_seed(seed)

            reseed()

            model = BertForPreTraining(config=config)
            model.eval()
            # Reference forward pass; its results are not used below.
            loss, prediction_scores, seq_relationship_score = model(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=token_type_ids,
                masked_lm_labels=token_labels,
                next_sentence_label=sequence_labels,
            )
            model_desc = ModelDescription(
                [
                    self.input_ids_desc,
                    self.attention_mask_desc,
                    self.token_type_ids_desc,
                    self.masked_lm_labels_desc,
                    self.next_sentence_label_desc,
                ],
                [
                    self.loss_desc, self.prediction_scores_desc,
                    self.seq_relationship_scores_desc
                ],
            )

            MyArgs = namedtuple(
                "MyArgs",
                "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
            )

            dataset_len = 100
            epochs = 8
            max_steps = epochs * dataset_len
            args = MyArgs(
                local_rank=0,
                world_size=1,
                max_steps=max_steps,
                learning_rate=0.00001,
                warmup_proportion=0.01,
                batch_size=13,
                seq_len=7,
            )

            def get_lr_this_step(global_step):
                return get_lr(args, global_step)

            loss_scaler = LossScaler("loss_scale_input_name",
                                     True,
                                     up_scale_window=2000)

            def run(gradient_accumulation_steps, fp16,
                    allreduce_post_accumulation,
                    use_internal_get_lr_this_step, use_internal_loss_scaler,
                    split_batch, use_new_api):
                # Single run_test invocation; only use_new_api differs between
                # the two runs of one option combination.
                return run_test(
                    model,
                    model_desc,
                    self.device,
                    args,
                    gradient_accumulation_steps,
                    fp16,
                    allreduce_post_accumulation,
                    get_lr_this_step,
                    use_internal_get_lr_this_step,
                    loss_scaler,
                    use_internal_loss_scaler,
                    split_batch,
                    dataset_len,
                    epochs,
                    use_new_api=use_new_api,
                )

            # product() varies the right-most option fastest, i.e. the same
            # visiting order as six nested for-loops in this argument order.
            combinations = product(option_fp16,
                                   option_allreduce_post_accumulation,
                                   option_gradient_accumulation_steps,
                                   option_use_internal_get_lr_this_step,
                                   option_use_internal_loss_scaler,
                                   option_split_batch)

            for (fp16, allreduce_post_accumulation,
                 gradient_accumulation_steps, use_internal_get_lr_this_step,
                 use_internal_loss_scaler, split_batch) in combinations:
                print("gradient_accumulation_steps:",
                      gradient_accumulation_steps)
                print("split_batch:", split_batch)

                reseed()
                (
                    old_api_loss_ort,
                    old_api_prediction_scores_ort,
                    old_api_seq_relationship_score_ort,
                ) = run(gradient_accumulation_steps, fp16,
                        allreduce_post_accumulation,
                        use_internal_get_lr_this_step,
                        use_internal_loss_scaler, split_batch,
                        use_new_api=False)

                # Reseed so the new-API run starts from the same RNG state as
                # the old-API run it is compared against.
                reseed()
                if use_internal_get_lr_this_step and use_internal_loss_scaler:
                    (
                        new_api_loss_ort,
                        new_api_prediction_scores_ort,
                        new_api_seq_relationship_score_ort,
                    ) = run(gradient_accumulation_steps, fp16,
                            allreduce_post_accumulation,
                            use_internal_get_lr_this_step,
                            use_internal_loss_scaler, split_batch,
                            use_new_api=True)

                    assert_allclose(old_api_loss_ort, new_api_loss_ort)
                    assert_allclose(old_api_prediction_scores_ort,
                                    new_api_prediction_scores_ort)
                    assert_allclose(old_api_seq_relationship_score_ort,
                                    new_api_seq_relationship_score_ort)
Exemple #8
0
    def train(self):
        """
        Main training entry point.

        Replaces ``self.model`` with an ORT trainer built around it -- the
        ``orttrainer.ORTTrainer`` API when ``self.use_new_api`` is set,
        otherwise the legacy ``ORTTrainer`` -- then runs the epoch/step loop.
        On local rank -1 or 0 a JSON log line is written every
        ``args.logging_steps`` optimization steps (and on the first step when
        ``args.logging_first_step`` is set); the learning rate is only logged
        on the legacy path.

        Returns:
            ``TrainOutput(global_step, mean_loss)`` where ``mean_loss`` is the
            accumulated ``_training_step`` loss divided by ``global_step``
            (by 1 when no optimization step ran).
        """
        train_dataloader = self.get_train_dataloader()

        # A dataloader shorter than the accumulation window still produces one
        # optimization step per epoch (see the step condition in the loop
        # below), so the per-epoch count must never be 0: that would divide by
        # zero in the max_steps branch and yield a zero-length schedule in the
        # other.
        steps_per_epoch = max(len(train_dataloader) // self.args.gradient_accumulation_steps, 1)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // steps_per_epoch + 1
        else:
            t_total = int(steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        if self.use_new_api:
            lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total))

            loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
            # Build a "<type>:<index>" device id; a missing (or zero) index maps to 0.
            device = self.args.device.type
            device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0'
            options = orttrainer.ORTTrainerOptions({
                'batch': {'gradient_accumulation_steps': self.args.gradient_accumulation_steps},
                'device': {'id': device},
                'mixed_precision': {
                    'enabled': self.args.fp16,
                    'loss_scaler': loss_scaler},
                'debug': {'deterministic_compute': True, },
                'utils': {'grad_norm_clip': False},
                'distributed': {'allreduce_post_accumulation': True},
                'lr_scheduler': lr_scheduler
            })

            # Two parameter groups (bias / LayerNorm.weight vs. the rest),
            # both with the same weight_decay_mode.
            param_optimizer = list(self.model.named_parameters())
            params = [{
                'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n],
                "weight_decay_mode": 1, }, {
                'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)],
                "weight_decay_mode": 1, }
                ]

            # NOTE(review): lr is fixed at 2e-5 here while the legacy path
            # uses self.args.learning_rate -- confirm this is intended.
            optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
            self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
        else:
            def map_optimizer_attributes(name):
                # Every parameter gets the same attributes regardless of name.
                return {"weight_decay_mode" : 1}
            get_lr_this_step = get_linear_schedule_with_warmup(self.args.warmup_steps, t_total, self.args.learning_rate)
            loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) if self.args.fp16 else None
            self.model = ORTTrainer(self.model, None,
                self.model_desc,
                "AdamOptimizer",
                map_optimizer_attributes=map_optimizer_attributes,
                learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
                device=self.args.device,
                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
                use_mixed_precision=self.args.fp16,
                allreduce_post_accumulation=True,
                get_lr_this_step=get_lr_this_step,
                loss_scaler=loss_scaler,
                enable_grad_norm_clip=False,
                _opset_version=12,
                _use_deterministic_compute=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                # An optimization step is counted at every accumulation
                # boundary, or at the end of an epoch that is no longer than
                # one accumulation window.
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
                            global_step == 1 and self.args.logging_first_step
                        ):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            # logging_steps may be 0 when this branch is
                            # entered through logging_first_step; clamp the
                            # divisor so that case does not divide by zero.
                            loss_scalar = (tr_loss - logging_loss) / max(self.args.logging_steps, 1)
                            if not self.use_new_api:
                                learning_rate_scalar = get_lr_this_step(global_step)
                                logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        # global_step is 0 when the dataloader produced no batches; avoid a
        # ZeroDivisionError in that case (tr_loss is still 0.0 then).
        return TrainOutput(global_step, tr_loss / max(global_step, 1))
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Runs ``self.ort_model`` over the batches of the training dataloader.
        Each step is given a learning rate from a linear warmup schedule and,
        when ``self.args.fp16`` is set, a dynamic loss scaler. Console logs,
        TensorBoard scalars and checkpoints are only produced when
        ``self.args.local_rank`` is -1 or 0. After the loop the trained state
        is pushed back through ``self.update_torch_model()`` and
        ``self.ort_model`` is released.

        Args:
            model_path:
                (Optional) Local path the model was instantiated from. If the
                text after its last "-" parses as an integer (for example a
                ``checkpoint-500`` directory), that value is used as the
                global step to resume from and the epochs/batches already
                covered are skipped. Otherwise training starts at step 0.
                This method does not reload optimizer or scheduler state.

        Returns:
            ``TrainOutput`` holding the final global step and the accumulated
            loss divided by that step count (the raw accumulated loss when no
            optimizer step was taken).
        """
        train_dataloader = self.get_train_dataloader()
        grad_accum_steps = self.args.gradient_accumulation_steps

        # Optimizer updates per pass over the data. Clamped to 1 so that a
        # dataloader shorter than the accumulation window cannot trigger a
        # ZeroDivisionError in the epoch / resume arithmetic below.
        updates_per_epoch = max(len(train_dataloader) // grad_accum_steps, 1)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // updates_per_epoch + 1
        else:
            t_total = int(
                len(train_dataloader) // grad_accum_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        scheduler = linear_schedule_with_warmup(
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=t_total)

        # Without fp16 a plain 1 is handed to _training_step in place of a
        # LossScaler instance.
        loss_scaler = LossScaler(
            self.ort_model.loss_scale_input_name,
            True,
            up_scale_window=2000,
            loss_scale=float(1 << 20)) if self.args.fp16 else 1

        model = self.ort_model

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())

        # Train!
        if self.is_world_master():
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_dataloader.dataset))
            logger.info("  Num Epochs = %d", num_train_epochs)
            logger.info("  Instantaneous batch size per GPU = %d",
                        self.args.per_gpu_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d",
                self.args.train_batch_size * grad_accum_steps *
                (self.args.world_size if self.args.local_rank != -1 else 1),
            )
            logger.info("  Gradient Accumulation steps = %d",
                        grad_accum_steps)
            logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")
            else:
                epochs_trained = global_step // updates_per_epoch
                # global_step counts optimizer updates, but the skip counter
                # below is decremented once per batch, so convert updates to
                # batches. This also keeps ``step`` aligned with the
                # accumulation window after the skip.
                steps_trained_in_current_epoch = (
                    global_step % updates_per_epoch) * grad_accum_steps

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

        tr_loss = 0.0
        logging_loss = 0.0
        # Global step at which the loss was last logged; used to average the
        # logged loss over the updates that actually happened since then.
        last_logged_step = global_step
        global_batch_train_start = time.time()

        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if len(inputs['input_ids']
                       ) < self.args.per_gpu_train_batch_size:
                    # skip incomplete batch
                    logger.info('Skipping incomplete batch...')
                    continue

                learning_rate = torch.tensor([
                    scheduler.get_lr_this_step(global_step,
                                               base_lr=self.args.learning_rate)
                ])
                loss, all_finite = self._training_step(model, inputs,
                                                       learning_rate,
                                                       loss_scaler)
                tr_loss += loss
                if (step + 1) % grad_accum_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <= grad_accum_steps and
                    (step + 1) == len(epoch_iterator)):

                    if self.args.fp16:
                        loss_scaler.update_loss_scale(all_finite.item())

                    global_step += 1
                    global_batch_train_duration = time.time(
                    ) - global_batch_train_start
                    global_batch_train_start = time.time()

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            # Average over the updates since the previous log
                            # rather than a fixed logging_steps: the
                            # first-step log covers a single update, and
                            # logging_steps may be 0 when only
                            # logging_first_step is enabled. global_step was
                            # just incremented, so this is always >= 1.
                            updates_since_log = global_step - last_logged_step
                            loss_avg = (tr_loss - logging_loss) / (
                                updates_since_log * grad_accum_steps)
                            logs["learning_rate"] = learning_rate.item()
                            logs["loss"] = loss_avg
                            logs["global_step"] = global_step
                            logs[
                                "global_step_time"] = global_batch_train_duration
                            logging_loss = tr_loss
                            last_logged_step = global_step

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                                    run.log(k, v)
                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.ort_model
                            else:
                                assert model is self.ort_model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            # self._rotate_checkpoints()

                # Stop as soon as max_steps updates have been applied; a
                # strict ">" would run one optimizer step beyond the
                # requested (and logged) total.
                if self.args.max_steps > 0 and global_step >= self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step >= self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()
        self.update_torch_model()
        del (self.ort_model)
        self.ort_model = None

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        # Guard against a run in which no optimizer step happened (empty
        # dataloader, every batch skipped, zero epochs).
        return TrainOutput(global_step, tr_loss / max(global_step, 1))
Exemple #10
0
def create_ort_trainer(args, device, model):
    """Wrap ``model`` in an ``ORTTrainer`` that uses the "AdamOptimizer".

    Args:
        args: Run configuration. Reads ``ort_cuda_mem_limit_in_gbs``,
            ``update_freq[0]``, ``distributed_rank``,
            ``distributed_world_size`` and ``fp16``. When ``fp16`` is truthy
            a ``LossScaler`` is stored on it as ``ort_loss_scale``.
        device: Device handed straight to ``ORTTrainer``.
        model: The model to be trained.

    Returns:
        The constructed ``ORTTrainer``.
    """
    # Cap ORT's CUDA memory; the configured value (GB, going by the attribute
    # name) is scaled by 1024**3 before being passed on.
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    mem_limit = int(args.ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024)
    set_cuda_mem_limit(mem_limit)

    def map_optimizer_attributes(name):
        # Parameters whose name contains one of these markers get
        # lambda 0.0; every other parameter gets 0.01.
        skip_decay = any(
            marker in name
            for marker in ("bias", "gamma", "beta", "LayerNorm"))
        return {
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0 if skip_decay else 0.01,
            "epsilon": 1e-8,
            # Adam optimizer mode
            # 0: PyTorch's AdamW
            # 1: Hugging Face's AdamW
            "weight_decay_mode": 0,
            "do_bias_correction": 1,
        }

    # ORTTrainer builds the optimizer itself from the name and the
    # per-parameter attribute callback; the learning rate is fed in as a
    # one-element float32 input on every step.
    trainer = ORTTrainer(
        model,
        None,
        bart_model_description(args),
        "AdamOptimizer",
        map_optimizer_attributes,
        IODescription('Learning_Rate', [1], torch.float32),
        device,
        gradient_accumulation_steps=args.update_freq[0],
        world_rank=args.distributed_rank,
        world_size=args.distributed_world_size,
        use_mixed_precision=bool(args.fp16),
        allreduce_post_accumulation=True,
        _opset_version=12)

    if args.fp16:
        # Keep the loss scaler on args so it is reachable wherever args is.
        args.ort_loss_scale = LossScaler(
            trainer.loss_scale_input_name, True, up_scale_window=2000)

    return trainer
        def create_and_check_bert_for_pretraining(self, config, input_ids,
                                                  token_type_ids, input_mask,
                                                  sequence_labels,
                                                  token_labels, choice_labels):
            """Run BertForPreTraining through ``run_test`` for several option combinations.

            Builds a ``BertForPreTraining`` from ``config``, does one forward
            pass in eval mode, then calls ``run_test`` for every combination
            of gradient accumulation steps, fp16 on/off and
            ``BatchArgsOption`` value, printing what ``run_test`` returns.
            In the code visible here nothing is asserted and
            ``choice_labels`` is not used.
            """
            model = BertForPreTraining(config=config)
            model.eval()
            # Forward pass of the freshly built model; the three results are
            # not referenced again in the visible part of this method.
            loss, prediction_scores, seq_relationship_score = model(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=token_type_ids,
                masked_lm_labels=token_labels,
                next_sentence_label=sequence_labels)
            # First list: input descriptions, second list: output
            # descriptions, all taken from attributes set up elsewhere on self.
            model_desc = ModelDescription([
                self.input_ids_desc, self.attention_mask_desc,
                self.token_type_ids_desc, self.masked_lm_labels_desc,
                self.next_sentence_label_desc
            ], [
                self.loss_desc, self.prediction_scores_desc,
                self.seq_relationship_scores_desc
            ])

            import argparse
            # NOTE(review): args_ is not referenced again in the visible code.
            args_ = argparse.Namespace(fp16=True, amp_opt_level='O1')

            from collections import namedtuple
            # Lightweight stand-in for a parsed-arguments object; it is what
            # get_lr and run_test receive as ``args`` below.
            MyArgs = namedtuple(
                "MyArgs",
                "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
            )
            args = MyArgs(local_rank=0,
                          world_size=1,
                          max_steps=100,
                          learning_rate=0.00001,
                          warmup_proportion=0.01,
                          batch_size=13,
                          seq_len=7)

            from train_with_ort_trainer import get_lr

            # Adapter so run_test can ask for a learning rate given only the step.
            def get_lr_this_step(global_step):
                return get_lr(args, global_step)

            loss_scaler = LossScaler('loss_scale_input_name',
                                     True,
                                     up_scale_window=2000)

            # Option values swept (lists) or held fixed (scalars) in the
            # run_test calls below.
            option_gradient_accumulation_steps = [8]
            option_fp16 = [True, False]
            option_allreduce_post_accumulation = True
            option_use_internal_get_lr_this_step = False
            option_use_internal_loss_scaler = False
            # TODO: with with fetches

            for gradient_accumulation_steps in option_gradient_accumulation_steps:
                for fp16 in option_fp16:
                    for option_split_batch in BatchArgsOption:
                        loss_ort, prediction_scores_ort, seq_relationship_score_ort =\
                            run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16,
                                     option_allreduce_post_accumulation,
                                     get_lr_this_step, option_use_internal_get_lr_this_step,
                                     loss_scaler, option_use_internal_loss_scaler,
                                     option_split_batch)

                        # Results are only printed, not compared with the
                        # forward-pass outputs computed above.
                        print(loss_ort)
                        print(prediction_scores_ort)
                        print(seq_relationship_score_ort)