Example 1
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = 'checkpoint*.ortcp'
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return trainer.state_dict(), model
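For orientation, a minimal sketch of how this helper might be invoked; the device string, options dictionary and checkpoint directory below are placeholders, not values taken from the tests:

# Hypothetical invocation of the helper above; all values are placeholders.
opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
state_dict, pytorch_model = create_orttrainer_and_load_checkpoint(
    'cuda', opts, '/path/to/checkpoint_dir')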
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(
        loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
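The loss_scaler and legacy_loss_scaler arguments are presumably supplied by the test parametrization. One plausible pairing, sketched only for orientation (the amp import path is an assumption; the legacy scaler would mirror the constructor call in the last example of this listing):

from onnxruntime.training import amp  # assumed import path for the experimental API

# Experimental dynamic loss scaler, the same class the training scripts below use.
loss_scaler = amp.DynamicLossScaler()
# The legacy counterpart would be built as in the final example:
# LossScaler('loss_scale_input_name', True, up_scale_window=2000)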
Example 3
def create_orttrainer_and_save_checkpoint(device,
                                          trainer_opts,
                                          checkpoint_dir,
                                          state_dict_key_name='state_dict',
                                          use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    if 'distributed' in trainer_opts:
        train_data = next(
            islice(
                _chunkify(train_data,
                          trainer_opts['distributed']['world_size']),
                trainer_opts['distributed']['world_rank'], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        _save(trainer, checkpoint_dir, state_dict_key_name)
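When trainer_opts contains a 'distributed' section, the helper above shards train_data by world rank before training. A sketch of the options shape that branch expects; only 'world_size' and 'world_rank' are read there, everything else is a placeholder:

# Placeholder options for a hypothetical 4-process run (rank 0 shown); the
# remaining distributed fields are left to ORTTrainerOptions defaults here.
trainer_opts = {
    'device': {'id': 'cuda'},
    'distributed': {
        'world_size': 4,
        'world_rank': 0,
    },
}
create_orttrainer_and_save_checkpoint('cuda', trainer_opts, '/tmp/checkpoint_dir')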
Example 4
def load_model_optim_state_and_eval(device,
                                    trainer_opts,
                                    use_lamb=True,
                                    seed=1,
                                    learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    optimizer_state_dict = trainer.state_dict()
    del optimizer_state_dict["model"]

    return dummy_init_state, optimizer_state_dict
def testToyBERTModelGradientAccumulationLegacyExperimental(
        gradient_accumulation_steps):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        None,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
        gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
Example 6
def create_orttrainer_and_save_checkpoint_bart(
        device,
        trainer_opts,
        checkpoint_dir,
        state_dict_key_name="state_dict",
        use_lamb=True,
        seed=1,
        learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Saves the trainer state as a checkpoint in checkpoint_dir (for model-parallel runs, rank 0 also dumps the expected full state dict)
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][
                        initializer.name] = numpy_helper.to_array(initializer)
                with open(
                        os.path.join(checkpoint_dir,
                                     "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.AdamConfig(
        params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False
    )
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)

    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        legacy_optim_map,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
    )
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def prepare_model(args, device):
    config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)

    # config.num_hidden_layers = 12
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({'batch': {
                                                'gradient_accumulation_steps': args.gradient_accumulation_steps},
                                            'device': {'id': str(device)},
                                            'mixed_precision': {
                                                'enabled': args.fp16,
                                                'loss_scaler': loss_scaler},
                                            'graph_transformer': {
                                                'attn_dropout_recompute': args.attn_dropout_recompute,
                                                'gelu_recompute': args.gelu_recompute,
                                                'transformer_layer_recompute': args.transformer_layer_recompute,
                                            },
                                            'debug': {'deterministic_compute': True, },
                                            'utils': {
                                                'grad_norm_clip': True},
                                            'distributed': {
                                                'world_rank': max(0, args.local_rank),
                                                'world_size': args.world_size,
                                                'local_rank': max(0, args.local_rank),
                                                'allreduce_post_accumulation': args.allreduce_post_accumulation,
                                                'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage},
                                                'enable_adasum': False},
                                            'lr_scheduler': lr_scheduler
                                            })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, {
        'params': [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
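prepare_model reads a number of fields from args; a hypothetical namespace covering exactly those fields (every name below is taken from the function body, every value is a placeholder for a single-GPU, full-precision run):

from types import SimpleNamespace
import torch

args = SimpleNamespace(
    bert_model='bert-base-uncased',      # placeholder checkpoint name
    cache_dir=None,
    force_num_hidden_layers=0,           # falsy -> keep the pretrained depth
    init_state_dict=None,
    max_steps=1000,
    warmup_proportion=0.01,
    fp16=False,
    gradient_accumulation_steps=1,
    attn_dropout_recompute=False,
    gelu_recompute=False,
    transformer_layer_recompute=False,
    local_rank=-1,                       # max(0, -1) -> rank 0
    world_size=1,
    allreduce_post_accumulation=False,
    deepspeed_zero_stage=0,
)
trainer = prepare_model(args, torch.device('cuda'))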
def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses):
    return  # TODO: re-enable after nondeterminism on backend is fixed
    # Common setup
    device = "cuda"
    total_steps = 10
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.0
    lr_end = 1e-7
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Setup LR Schedulers
    if (
        lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler
        or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler
    ):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
            "lr_scheduler": lr_scheduler,
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
Example 10
def create_initialized_orttrainer(device, trainer_opts, use_lamb=True):
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    learning_rate = 1e-10
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    _train(trainer, train_data, batcher_fn)

    return trainer
Example 11
def verify_optimizer_state_match(device,
                                 opts,
                                 checkpoint_dir,
                                 world_rank,
                                 use_lamb=False):
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(
        device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)
    # Roundabout way of checking optimizer states: save the state dicts into a temporary folder, read them back, and aggregate them.
    with open(
            os.path.join(checkpoint_dir,
                         'distributed_state_' + str(world_rank) + '.pkl'),
            "wb") as f:
        pickle.dump(trainer_state, f)
    dist.barrier()

    if world_rank == 0:
        num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
        optimizer_states = dict()
        for rank in range(num_states):
            rank_state_dict = None
            with open(
                    os.path.join(checkpoint_dir,
                                 'distributed_state_' + str(rank) + '.pkl'),
                    'rb') as f:
                rank_state_dict = pickle.load(f)

            # collect optimizer states for later comparison since they are sharded
            aggregate_states(optimizer_states, rank_state_dict['optimizer'])

        # compare optimizer states
        optimizer_config = optim.LambConfig(
        ) if use_lamb else optim.AdamConfig()
        actual_optim_state = get_optim_state_from_state_dict(
            optimizer_states, optimizer_config)
        assert actual_optim_state.keys() == expected_optim_state.keys()
        for param_name, a_state in actual_optim_state.items():
            for k, v in a_state.items():
                assert_allclose(
                    v.reshape(expected_optim_state[param_name][k].shape),
                    expected_optim_state[param_name][k],
                    err_msg=
                    f"Optimizer state mismatch for param {param_name}, key {k}"
                )

    dist.barrier()
    os.remove(
        os.path.join(checkpoint_dir,
                     'distributed_state_' + str(world_rank) + '.pkl'))
Example 12
def create_orttrainer_and_load_checkpoint_bart(device,
                                               trainer_opts,
                                               checkpoint_dir,
                                               use_lamb=True,
                                               seed=1,
                                               learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # model setup
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(
        os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(
        model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    expected_state_dict = None
    fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl")
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            expected_state_dict = pickle.load(f)

    return trainer.state_dict(), expected_state_dict, model
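Examples 6 and 12 are designed as a pair: one saves a BART checkpoint, the other reloads it for comparison. A single-process round-trip sketch; the options, temporary directory and final comparison are illustrative only:

import tempfile

opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
with tempfile.TemporaryDirectory() as checkpoint_dir:
    create_orttrainer_and_save_checkpoint_bart('cuda', opts, checkpoint_dir)
    state_dict, expected_state_dict, onnx_model = create_orttrainer_and_load_checkpoint_bart(
        'cuda', opts, checkpoint_dir)
    # expected_state_dict.pkl is only written by model-parallel runs (rank 0),
    # so expected_state_dict may be None here; the tests compare state_dict instead.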
Example 13
def create_orttrainer_and_save_checkpoint(device,
                                          trainer_opts,
                                          checkpoint_dir,
                                          state_dict_key_name="state_dict",
                                          use_lamb=True,
                                          seed=1,
                                          learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    loss_fn=loss_fn,
                                    options=ort_trainer_opts)

    if "distributed" in trainer_opts:
        train_data = next(
            islice(
                _chunkify(train_data,
                          trainer_opts["distributed"]["world_size"]),
                trainer_opts["distributed"]["world_rank"],
                None,
            ))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer,
                  checkpoint_dir,
                  state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
Example 14
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        config = self.model.config
        model_desc = self.gpt2_model_description(
            config.n_head, config.vocab_size, config.n_embd, config.n_layer,
            config.n_ctx, self.args.per_gpu_train_batch_size)

        from onnxruntime.capi._pybind_state import set_arena_extend_strategy, ArenaExtendStrategy
        set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

        optim_config = optim.AdamConfig(params=[{
            'params':
            [n for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'lambda_coef':
            0.0
        }],
                                        lr=self.args.learning_rate,
                                        alpha=0.9,
                                        beta=0.999,
                                        lambda_coef=self.args.weight_decay,
                                        epsilon=self.args.adam_epsilon)

        warmup = self.args.warmup_steps / t_total
        lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(
            total_steps=t_total, warmup=warmup)
        loss_scaler = amp.DynamicLossScaler(
            automatic_update=True,
            loss_scale=float(1 << 20),
            up_scale_window=2000,
            min_loss_scale=1.0,
            max_loss_scale=float(1 << 24)) if self.args.fp16 else None

        opts = orttrainer.ORTTrainerOptions({
            'device': {
                'id': str(self.args.device)
            },
            'distributed': {
                'world_rank': self.args.world_rank,
                'world_size': self.args.world_size,
                'local_rank': self.args.local_rank,
                'allreduce_post_accumulation': True
            },
            'mixed_precision': {
                'enabled': self.args.fp16,
                'loss_scaler': loss_scaler
            },
            'batch': {
                'gradient_accumulation_steps':
                self.args.gradient_accumulation_steps
            },
            'lr_scheduler': lr_scheduler
        })

        self.ort_model = orttrainer.ORTTrainer(self.model,
                                               model_desc,
                                               optim_config,
                                               None,
                                               options=opts)

        logger.info("****************************Model converted to ORT")
        model = self.ort_model

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())

        # Train!
        if self.is_world_master():
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_dataloader.dataset))
            logger.info("  Num Epochs = %d", num_train_epochs)
            logger.info("  Instantaneous batch size per GPU = %d",
                        self.args.per_gpu_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d",
                self.args.train_batch_size *
                self.args.gradient_accumulation_steps *
                (self.args.world_size if self.args.local_rank != -1 else 1),
            )
            logger.info("  Gradient Accumulation steps = %d",
                        self.args.gradient_accumulation_steps)
            logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        global_batch_train_start = time.time()

        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if len(inputs['input_ids']
                       ) < self.args.per_gpu_train_batch_size:
                    # skip incomplete batch
                    logger.info('Skipping incomplete batch...')
                    continue

                tr_loss += self._training_step(model, inputs)
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):

                    global_step += 1
                    global_batch_train_duration = time.time(
                    ) - global_batch_train_start
                    global_batch_train_start = time.time()

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            loss_avg = (tr_loss - logging_loss) / (
                                self.args.logging_steps *
                                self.args.gradient_accumulation_steps)
                            logs["learning_rate"] = lr_scheduler.get_last_lr(
                            )[0]
                            logs["loss"] = loss_avg.item()
                            logs["global_step"] = global_step
                            logs[
                                "global_step_time"] = global_batch_train_duration
                            logging_loss = tr_loss.clone()

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                                    run.log(k, v)
                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.ort_model
                            else:
                                assert model is self.ort_model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            self._rotate_checkpoints()

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()
        self.update_torch_model()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(global_step, tr_loss / global_step)
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler,
                                                  legacy_lr_scheduler):
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 128
    device = 'cuda'
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7

    # Setup both Experimental and Legacy LR Schedulers before the experimental loop
    if legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup)
    elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup,
                                      cycles=cycles)
    elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup,
                                      power=power,
                                      lr_end=lr_end)
    else:
        raise RuntimeError("Invalid legacy_lr_scheduler")
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps,
                                    warmup=warmup,
                                    cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps,
                                    warmup=warmup,
                                    power=power,
                                    lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0],
                        legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
Example 16
            total_loss += len(data) * loss.item()
    return total_loss / (len(data_source) - 1)


best_val_loss = float("inf")
epochs = 3  # The number of epochs
best_model = None

model_description = {
    'inputs': [('src', ['bptt', 'batch_size']),
               ('label', ['bptt_x_batch_size'])],
    'outputs': [('loss', [], True), ('output', ['bptt', 'batch_size',
                                                ntokens])]
}

optimizer_config = optim.AdamConfig(lr=learning_rate)

trainer = ORTTrainer(
    model,  # model
    model_description,  # model description
    optimizer_config,  # optimizer configuration
    loss_with_flat_output)  # loss function

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
                                     model_desc2,
                                     optim_config2,
                                     options=opts)
    trainer2.load_state_dict(state_dict)
    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = trainer2.state_dict()
    _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict)


@pytest.mark.parametrize("optimizer, mixedprecision_enabled", [
    (optim.LambConfig(), False),
    (optim.AdamConfig(), False),
    (optim.LambConfig(), True),
    (optim.AdamConfig(), True),
])
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
def prepare_model(args, device):
    config = BertConfig.from_pretrained('bert-base-uncased',
                                        cache_dir=args.cache_dir)
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d",
                    args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    model_desc = bert_model_description(config)

    lr_scheduler = PolyWarmupLRScheduler(total_steps=int(args.max_steps))

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({
        'batch': {
            'gradient_accumulation_steps': args.gradient_accumulation_steps
        },
        'device': {
            'id': str(device)
        },
        'mixed_precision': {
            'enabled': args.fp16,
            'loss_scaler': loss_scaler
        },
        'debug': {
            'deterministic_compute': True,
        },
        'utils': {
            'grad_norm_clip': True
        },
        'distributed': {
            'allreduce_post_accumulation': True
        },
        'lr_scheduler': lr_scheduler
    })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [
            n for n, p in param_optimizer
            if any(no_decay_key in n for no_decay_key in no_decay_keys)
        ],
        "alpha":
        0.9,
        "beta":
        0.999,
        "lambda":
        0.0,
        "epsilon":
        1e-6
    }, {
        'params': [
            n for n, p in param_optimizer
            if not any(no_decay_key in n for no_decay_key in no_decay_keys)
        ],
        "alpha":
        0.9,
        "beta":
        0.999,
        "lambda":
        0.0,
        "epsilon":
        1e-6
    }]

    optim_config = optim.AdamConfig(params=params,
                                    lr=2e-5,
                                    do_bias_correction=True)
    model = orttrainer.ORTTrainer(model,
                                  model_desc,
                                  optim_config,
                                  options=options)

    return model
Example 19
num_pipeline_stages = 2

# Compute batch size for micro-batches.
n_slice = int(n / num_pipeline_steps)

cuda_device = 'cuda:' + str(rank)
# Schema used when running the original batch.
schema = {'inputs': [('x', ['n', 'd_in']), ('target', ['n'])], 'outputs': [
    ('loss', [], True), ('output', ['n', d_out])]}
# Actual schema used when running micro-batches.
pipeline_schema = {'x': [n_slice, d_in], 'target': [
    n_slice], 'output': [n_slice, d_out], 'loss': []}
# Describe which axis to slice along for each sliced tensor.
sliced_axes = {'x': 0, 'target': 0, 'output': 0}

adam_config = optim.AdamConfig(lr=0.1)

# Specify configuration for pipeline-parallel training.
trainer_config = ORTTrainerOptions({
    'batch': {
        'gradient_accumulation_steps': num_pipeline_steps
    },
    'device': {
        'id': cuda_device
    },
    'distributed': {
        'world_size': total_ranks,
        'world_rank': rank,
        'data_parallel_size': int(total_ranks / num_pipeline_stages),
        'horizontal_parallel_size': 1,
        'pipeline_parallel': {
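The snippet above is cut off inside the 'distributed' options; the micro-batch arithmetic it is built on can still be illustrated. A small sketch with placeholder sizes showing how pipeline_schema is derived from the full-batch schema:

# Placeholder sizes; num_pipeline_steps is the number of micro-batches per batch.
n, d_in, d_out = 8, 10, 5
num_pipeline_steps = 4
n_slice = int(n / num_pipeline_steps)  # each micro-batch carries n_slice = 2 samples

pipeline_schema = {
    'x': [n_slice, d_in],       # sliced along axis 0, see sliced_axes
    'target': [n_slice],
    'output': [n_slice, d_out],
    'loss': [],                 # scalar outputs are not sliced
}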
Example 20
    def train(self):
        """
        Main training entry point.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(
            t_total, self.args.warmup_steps / float(t_total))

        loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
        device = self.args.device.type

        device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0"
        options = orttrainer.ORTTrainerOptions({
            "batch": {
                "gradient_accumulation_steps":
                self.args.gradient_accumulation_steps
            },
            "device": {
                "id": device
            },
            "mixed_precision": {
                "enabled": self.args.fp16,
                "loss_scaler": loss_scaler
            },
            "debug": {
                "deterministic_compute": True,
            },
            "utils": {
                "grad_norm_clip": False
            },
            "distributed": {
                # We run a single-node multi-GPU test, so world_rank == local_rank
                # and world_size == self.args.n_gpu.
                "world_rank": max(0, self.args.local_rank),
                "world_size": int(self.world_size),
                "local_rank": max(0, self.args.local_rank),
                "allreduce_post_accumulation": True,
            },
            "lr_scheduler": lr_scheduler,
        })

        param_optimizer = list(self.model.named_parameters())
        params = [
            {
                "params": [
                    n for n, p in param_optimizer
                    if "bias" in n or "LayerNorm.weight" in n
                ],
                "weight_decay_mode":
                1,
            },
            {
                "params": [
                    n for n, p in param_optimizer
                    if not ("bias" in n or "LayerNorm.weight" in n)
                ],
                "weight_decay_mode":
                1,
            },
        ]

        optim_config = optim.AdamConfig(params=params,
                                        lr=2e-5,
                                        do_bias_correction=True)
        self.model = orttrainer.ORTTrainer(self.model,
                                           self.model_desc,
                                           optim_config,
                                           options=options)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss
                                           ) / self.args.logging_steps

                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        return TrainOutput(global_step, tr_loss / global_step)
    def train(self):
        """
        Main training entry point.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        if self.use_new_api:
            lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(
                t_total, self.args.warmup_steps / float(t_total))

            loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
            device = self.args.device.type

            device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0'
            options = orttrainer.ORTTrainerOptions({
                'batch': {
                    'gradient_accumulation_steps':
                    self.args.gradient_accumulation_steps
                },
                'device': {
                    'id': device
                },
                'mixed_precision': {
                    'enabled': self.args.fp16,
                    'loss_scaler': loss_scaler
                },
                'debug': {
                    'deterministic_compute': True,
                },
                'utils': {
                    'grad_norm_clip': False
                },
                'distributed': {
                    # We run a single-node multi-GPU test, so world_rank == local_rank
                    # and world_size == self.args.n_gpu.
                    'world_rank': max(0, self.args.local_rank),
                    'world_size': int(self.world_size),
                    'local_rank': max(0, self.args.local_rank),
                    'allreduce_post_accumulation': True
                },
                'lr_scheduler': lr_scheduler
            })

            param_optimizer = list(self.model.named_parameters())
            params = [{
                'params': [
                    n for n, p in param_optimizer
                    if "bias" in n or "LayerNorm.weight" in n
                ],
                "weight_decay_mode":
                1,
            }, {
                'params': [
                    n for n, p in param_optimizer
                    if not ("bias" in n or "LayerNorm.weight" in n)
                ],
                "weight_decay_mode":
                1,
            }]

            optim_config = optim.AdamConfig(params=params,
                                            lr=2e-5,
                                            do_bias_correction=True)
            self.model = orttrainer.ORTTrainer(self.model,
                                               self.new_model_desc,
                                               optim_config,
                                               options=options)
        else:

            def map_optimizer_attributes(name):
                no_decay = "bias" in name or "LayerNorm.weight" in name
                if no_decay:
                    return {"weight_decay_mode": 1}
                else:
                    return {"weight_decay_mode": 1}

            get_lr_this_step = get_linear_schedule_with_warmup(
                self.args.warmup_steps, t_total, self.args.learning_rate)
            loss_scaler = LossScaler(
                'loss_scale_input_name', True,
                up_scale_window=2000) if self.args.fp16 else None
            self.model = ORTTrainer(
                self.model,
                None,
                self.model_desc,
                "AdamOptimizer",
                map_optimizer_attributes=map_optimizer_attributes,
                learning_rate_description=IODescription(
                    'Learning_Rate', [
                        1,
                    ], torch.float32),
                device=self.args.device,
                gradient_accumulation_steps=self.args.
                gradient_accumulation_steps,
                world_rank=max(0, self.args.local_rank),
                world_size=int(self.world_size),
                use_mixed_precision=self.args.fp16,
                allreduce_post_accumulation=True,
                get_lr_this_step=get_lr_this_step,
                loss_scaler=loss_scaler,
                enable_grad_norm_clip=False,
                _opset_version=12,
                _use_deterministic_compute=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss
                                           ) / self.args.logging_steps
                            if not self.use_new_api:
                                learning_rate_scalar = get_lr_this_step(
                                    global_step)
                                logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        return TrainOutput(global_step, tr_loss / global_step)