def prepare_model(args, device):
    config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)

    # config.num_hidden_layers = 12
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({'batch': {
                                                'gradient_accumulation_steps': args.gradient_accumulation_steps},
                                            'device': {'id': str(device)},
                                            'mixed_precision': {
                                                'enabled': args.fp16,
                                                'loss_scaler': loss_scaler},
                                            'graph_transformer': {
                                                'attn_dropout_recompute': args.attn_dropout_recompute,
                                                'gelu_recompute': args.gelu_recompute,
                                                'transformer_layer_recompute': args.transformer_layer_recompute,
                                            },
                                            'debug': {'deterministic_compute': True, },
                                            'utils': {
                                                'grad_norm_clip': True},
                                            'distributed': {
                                                'world_rank': max(0, args.local_rank),
                                                'world_size': args.world_size,
                                                'local_rank': max(0, args.local_rank),
                                                'allreduce_post_accumulation': args.allreduce_post_accumulation,
                                                'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage},
                                                'enable_adasum': False},
                                            'lr_scheduler': lr_scheduler
                                            })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
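    # Two parameter groups: the conventional "no decay" set (bias / LayerNorm terms) and
    # everything else; note that in this snippet both groups use identical Adam hyperparameters.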
    params = [{
        'params': [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, {
        'params': [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates,
                                       expected_learning_rates,
                                       rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)


@pytest.mark.parametrize(
    "loss_scaler, expected_losses",
    [(None, [
        10.992018, 10.975699, 11.032809, 11.034765, 10.987625, 11.039452,
        10.971539, 11.10148, 11.047551, 11.077468
    ]),
     (amp.DynamicLossScaler(), [
         10.992018, 10.975699, 11.032809, 11.034765, 10.987625, 11.039452,
         10.971539, 11.10148, 11.047551, 11.077468
     ]),
     (CustomLossScaler(), [
         10.992018, 10.975699, 11.032791, 11.034729, 10.987614, 11.039479,
         10.971532, 11.101475, 11.04761, 11.077413
     ])])
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
Example #3
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer
                       if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer
                       if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
        }]

        vocab_size = 99
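        # ORTTrainer model description: 'inputs' entries are (name, shape) with symbolic axes;
        # 'outputs' entries are (name, shape[, is_loss]), the trailing True marking the training loss.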
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch'])
            ],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])
            ]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = len(model_desc.inputs_) // 2  # approximately half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = len(model_desc.inputs_) // 2  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)


def prepare_model(args, device):
    config = BertConfig.from_pretrained('bert-base-uncased',
                                        cache_dir=args.cache_dir)
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d",
                    args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    model_desc = bert_model_description(config)

    lr_scheduler = PolyWarmupLRScheduler(total_steps=int(args.max_steps))

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({
        'batch': {
            'gradient_accumulation_steps': args.gradient_accumulation_steps
        },
        'device': {
            'id': str(device)
        },
        'mixed_precision': {
            'enabled': args.fp16,
            'loss_scaler': loss_scaler
        },
        'debug': {
            'deterministic_compute': True,
        },
        'utils': {
            'grad_norm_clip': True
        },
        'distributed': {
            'allreduce_post_accumulation': True
        },
        'lr_scheduler': lr_scheduler
    })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer
                   if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
    }, {
        'params': [n for n, p in param_optimizer
                   if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
    }]

    optim_config = optim.AdamConfig(params=params,
                                    lr=2e-5,
                                    do_bias_correction=True)
    model = orttrainer.ORTTrainer(model,
                                  model_desc,
                                  optim_config,
                                  options=options)

    return model
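
# A minimal usage sketch (hypothetical `args` and `device`, plus a dataloader yielding batches
# that match bert_model_description(config)); the trainer is driven with train_step the same
# way as the training loops elsewhere in these examples:
#   trainer = prepare_model(args, device)
#   for step, batch in enumerate(dataloader):
#       loss = trainer.train_step(*batch)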
Example #5
    def train(self):
        """
        Main training entry point.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(
            t_total, self.args.warmup_steps / float(t_total))

        loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
        device = self.args.device.type

        device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0"
        options = orttrainer.ORTTrainerOptions({
            "batch": {
                "gradient_accumulation_steps":
                self.args.gradient_accumulation_steps
            },
            "device": {
                "id": device
            },
            "mixed_precision": {
                "enabled": self.args.fp16,
                "loss_scaler": loss_scaler
            },
            "debug": {
                "deterministic_compute": True,
            },
            "utils": {
                "grad_norm_clip": False
            },
            "distributed": {
                # This is a single-node, multi-GPU test run, so world_rank == local_rank
                # and world_size == self.args.n_gpu.
                "world_rank": max(0, self.args.local_rank),
                "world_size": int(self.world_size),
                "local_rank": max(0, self.args.local_rank),
                "allreduce_post_accumulation": True,
            },
            "lr_scheduler": lr_scheduler,
        })

        param_optimizer = list(self.model.named_parameters())
        params = [
            {
                "params": [n for n, p in param_optimizer
                           if "bias" in n or "LayerNorm.weight" in n],
                "weight_decay_mode": 1,
            },
            {
                "params": [n for n, p in param_optimizer
                           if not ("bias" in n or "LayerNorm.weight" in n)],
                "weight_decay_mode": 1,
            },
        ]

        optim_config = optim.AdamConfig(params=params,
                                        lr=2e-5,
                                        do_bias_correction=True)
        self.model = orttrainer.ORTTrainer(self.model,
                                           self.model_desc,
                                           optim_config,
                                           options=options)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

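                # Step the optimizer once every gradient_accumulation_steps micro-batches, or on
                # the last micro-batch of an epoch shorter than the accumulation window.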
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss
                                           ) / self.args.logging_steps

                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(json.dumps({**logs, "step": global_step}))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        return TrainOutput(global_step, tr_loss / global_step)
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates,
                                       expected_learning_rates,
                                       rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)


@pytest.mark.parametrize("loss_scaler, expected_losses", [
    (None, [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,\
        11.105679512023926, 10.981968879699707, 11.081787109375, 10.997162818908691, 11.107288360595703]),
    (amp.DynamicLossScaler(), [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172,\
        10.988829612731934, 11.105679512023926, 10.981969833374023, 11.081744194030762, 10.997139930725098, 11.107272148132324]),
    (CustomLossScaler(), [10.98803424835205, 10.99240493774414, 11.090554237365723, 11.042823791503906, 10.98877239227295,\
        11.105667114257812, 10.981982231140137, 11.081765174865723, 10.997125625610352, 11.107298851013184])
])
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()
     None,
     [11.041126, 10.986309, 11.101673, 11.013394, 11.037781,
      11.041253, 10.957072, 11.069506, 11.040807, 11.164349],
 ),
 (
     amp.DynamicLossScaler(),
     [11.041126, 10.986309, 11.101673, 11.013394, 11.037781,
      11.041253, 10.957072, 11.069506, 11.040807, 11.164349],
 ),
 (
     CustomLossScaler(),
Example #8
    "inputs": [("x", ["n", "d_in"]), ("target", ["n"])],
    "outputs": [("loss", [], True), ("output", ["n", d_out])],
}
# Actual schema used when running micro-batches.
pipeline_schema = {"x": [n_slice, d_in], "target": [n_slice], "output": [n_slice, d_out], "loss": []}
# Describe which axis to slice along for each sliced tensor.
sliced_axes = {"x": 0, "target": 0, "output": 0}
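# A minimal sketch of how the sliced shapes above relate to the full-batch shapes
# (assumption: the full batch of n rows is split evenly across the micro-batches):
#   n_slice = n // num_pipeline_steps       # rows handled by each micro-batch
#   pipeline_schema["x"] = [n_slice, d_in]  # "x" is sliced along axis 0
#   sliced_axes["x"] = 0                    # so slicing happens along dimension 0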

adam_config = optim.AdamConfig(lr=0.1)

# Specify configuration for pipeline parallel training.
trainer_config = ORTTrainerOptions(
    {
        "batch": {"gradient_accumulation_steps": num_pipeline_steps},
        "device": {"id": cuda_device},
        "mixed_precision": {"enabled": True, "loss_scaler": amp.DynamicLossScaler()},
        "distributed": {
            "world_size": total_ranks,
            "world_rank": rank,
            "data_parallel_size": int(total_ranks / num_pipeline_stages),
            "horizontal_parallel_size": 1,
            "pipeline_parallel": {
                "pipeline_parallel_size": int(num_pipeline_stages),
                "num_pipeline_micro_batches": num_pipeline_steps,
                "sliced_schema": pipeline_schema,
                "sliced_axes": sliced_axes,
                "sliced_tensor_names": ["x", "target", "output"],
                # Define pipeline stage partition by specifying cut points.
                # 2-stage cut. It's a cut on tensor "12".
                "pipeline_cut_info_string": "12",
            },
Example #9
# Describe which axis to slice along for each sliced tensor.
sliced_axes = {'x': 0, 'target': 0, 'output': 0}

adam_config = optim.AdamConfig(lr=0.1)

# Specify configuration for pipeline parallel training.
trainer_config = ORTTrainerOptions({
    'batch': {
        'gradient_accumulation_steps': num_pipeline_steps
    },
    'device': {
        'id': cuda_device
    },
    'mixed_precision': {
        'enabled': True,
        'loss_scaler': amp.DynamicLossScaler()
    },
    'distributed': {
        'world_size': total_ranks,
        'world_rank': rank,
        'data_parallel_size': int(total_ranks / num_pipeline_stages),
        'horizontal_parallel_size': 1,
        'pipeline_parallel': {
            'pipeline_parallel_size': int(num_pipeline_stages),
            'num_pipeline_micro_batches': num_pipeline_steps,
            'sliced_schema': pipeline_schema,
            'sliced_axes': sliced_axes,
            'sliced_tensor_names': ['x', 'target', 'output'],
            # Define pipeline stage partition by specifying cut points.
            # 2-stage cut. It's a cut on tensor "12".
            'pipeline_cut_info_string': '12'
Example #10
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        config = self.model.config
        model_desc = self.gpt2_model_description(
            config.n_head, config.vocab_size, config.n_embd, config.n_layer,
            config.n_ctx, self.args.per_gpu_train_batch_size)

        from onnxruntime.capi._pybind_state import set_arena_extend_strategy, ArenaExtendStrategy
        set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

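        # Only the no-decay parameters get an explicit group (lambda_coef=0.0); all other
        # parameters fall back to the top-level lambda_coef=self.args.weight_decay below.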
        optim_config = optim.AdamConfig(
            params=[{
                'params': [n for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'lambda_coef': 0.0
            }],
            lr=self.args.learning_rate,
            alpha=0.9,
            beta=0.999,
            lambda_coef=self.args.weight_decay,
            epsilon=self.args.adam_epsilon)

        warmup = self.args.warmup_steps / t_total
        lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(
            total_steps=t_total, warmup=warmup)
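        # fp16 dynamic loss scaling: initial scale 2**20, adjusted automatically (the scale can
        # grow after up_scale_window overflow-free steps) and clamped to [1.0, 2**24].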
        loss_scaler = amp.DynamicLossScaler(
            automatic_update=True,
            loss_scale=float(1 << 20),
            up_scale_window=2000,
            min_loss_scale=1.0,
            max_loss_scale=float(1 << 24)) if self.args.fp16 else None

        opts = orttrainer.ORTTrainerOptions({
            'device': {
                'id': str(self.args.device)
            },
            'distributed': {
                'world_rank': self.args.world_rank,
                'world_size': self.args.world_size,
                'local_rank': self.args.local_rank,
                'allreduce_post_accumulation': True
            },
            'mixed_precision': {
                'enabled': self.args.fp16,
                'loss_scaler': loss_scaler
            },
            'batch': {
                'gradient_accumulation_steps':
                self.args.gradient_accumulation_steps
            },
            'lr_scheduler': lr_scheduler
        })

        self.ort_model = orttrainer.ORTTrainer(self.model,
                                               model_desc,
                                               optim_config,
                                               None,
                                               options=opts)

        logger.info("****************************Model converted to ORT")
        model = self.ort_model

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())

        # Train!
        if self.is_world_master():
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_dataloader.dataset))
            logger.info("  Num Epochs = %d", num_train_epochs)
            logger.info("  Instantaneous batch size per GPU = %d",
                        self.args.per_gpu_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d",
                self.args.train_batch_size *
                self.args.gradient_accumulation_steps *
                (self.args.world_size if self.args.local_rank != -1 else 1),
            )
            logger.info("  Gradient Accumulation steps = %d",
                        self.args.gradient_accumulation_steps)
            logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        global_batch_train_start = time.time()

        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if len(inputs['input_ids']) < self.args.per_gpu_train_batch_size:
                    # skip incomplete batch
                    logger.info('Skipping incomplete batch...')
                    continue

                tr_loss += self._training_step(model, inputs)
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step of an epoch that is shorter than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):

                    global_step += 1
                    global_batch_train_duration = time.time() - global_batch_train_start
                    global_batch_train_start = time.time()

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            loss_avg = (tr_loss - logging_loss) / (
                                self.args.logging_steps *
                                self.args.gradient_accumulation_steps)
                            logs["learning_rate"] = lr_scheduler.get_last_lr()[0]
                            logs["loss"] = loss_avg.item()
                            logs["global_step"] = global_step
                            logs["global_step_time"] = global_batch_train_duration
                            logging_loss = tr_loss.clone()

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                                    run.log(k, v)
                            epoch_iterator.write(json.dumps({**logs, "step": global_step}))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.ort_model
                            else:
                                assert model is self.ort_model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            self._rotate_checkpoints()

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()
        self.update_torch_model()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(global_step, tr_loss / global_step)
    def train(self):
        """
        Main training entry point.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        if self.use_new_api:
            lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(
                t_total, self.args.warmup_steps / float(t_total))

            loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
            device = self.args.device.type

            device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0'
            options = orttrainer.ORTTrainerOptions({
                'batch': {
                    'gradient_accumulation_steps':
                    self.args.gradient_accumulation_steps
                },
                'device': {
                    'id': device
                },
                'mixed_precision': {
                    'enabled': self.args.fp16,
                    'loss_scaler': loss_scaler
                },
                'debug': {
                    'deterministic_compute': True,
                },
                'utils': {
                    'grad_norm_clip': False
                },
                'distributed': {
                    # This is a single-node, multi-GPU test run, so world_rank == local_rank
                    # and world_size == self.args.n_gpu.
                    'world_rank': max(0, self.args.local_rank),
                    'world_size': int(self.world_size),
                    'local_rank': max(0, self.args.local_rank),
                    'allreduce_post_accumulation': True
                },
                'lr_scheduler': lr_scheduler
            })

            param_optimizer = list(self.model.named_parameters())
            params = [{
                'params': [n for n, p in param_optimizer
                           if "bias" in n or "LayerNorm.weight" in n],
                "weight_decay_mode": 1,
            }, {
                'params': [n for n, p in param_optimizer
                           if not ("bias" in n or "LayerNorm.weight" in n)],
                "weight_decay_mode": 1,
            }]

            optim_config = optim.AdamConfig(params=params,
                                            lr=2e-5,
                                            do_bias_correction=True)
            self.model = orttrainer.ORTTrainer(self.model,
                                               self.new_model_desc,
                                               optim_config,
                                               options=options)
        else:

            def map_optimizer_attributes(name):
                no_decay = "bias" in name or "LayerNorm.weight" in name
                if no_decay:
                    return {"weight_decay_mode": 1}
                else:
                    return {"weight_decay_mode": 1}

            get_lr_this_step = get_linear_schedule_with_warmup(
                self.args.warmup_steps, t_total, self.args.learning_rate)
            loss_scaler = LossScaler(
                'loss_scale_input_name', True,
                up_scale_window=2000) if self.args.fp16 else None
            self.model = ORTTrainer(
                self.model,
                None,
                self.model_desc,
                "AdamOptimizer",
                map_optimizer_attributes=map_optimizer_attributes,
                learning_rate_description=IODescription('Learning_Rate', [1], torch.float32),
                device=self.args.device,
                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
                world_rank=max(0, self.args.local_rank),
                world_size=int(self.world_size),
                use_mixed_precision=self.args.fp16,
                allreduce_post_accumulation=True,
                get_lr_this_step=get_lr_this_step,
                loss_scaler=loss_scaler,
                enable_grad_norm_clip=False,
                _opset_version=12,
                _use_deterministic_compute=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss
                                           ) / self.args.logging_steps
                            if not self.use_new_api:
                                learning_rate_scalar = get_lr_this_step(
                                    global_step)
                                logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(json.dumps({**logs, "step": global_step}))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        return TrainOutput(global_step, tr_loss / global_step)
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)


@pytest.mark.parametrize("loss_scaler, expected_losses", [
    (None, [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172, 10.988829612731934,\
        11.105679512023926, 10.981968879699707, 11.081787109375, 10.997162818908691, 11.107288360595703]),
    (amp.DynamicLossScaler(), [10.98803424835205, 10.99240493774414, 11.090575218200684, 11.042827606201172,\
        10.988829612731934, 11.105679512023926, 10.981969833374023, 11.081744194030762, 10.997139930725098, 11.107272148132324]),
    (CustomLossScaler(), [10.98803424835205, 10.99240493774414, 11.090554237365723, 11.042823791503906, 10.98877239227295,\
        11.105667114257812, 10.981982231140137, 11.081765174865723, 10.997125625610352, 11.107298851013184])
])
def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses):
    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Modeling
    model_desc = bert_model_description()