Example #1
def train(args, model, train_dataset, tokenizer, labels, pad_token_label_id):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
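    # Freeze the embeddings and the bottom `freeze_bottom_layer` encoder layers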
    no_grad = ["embeddings"] + [
        "layer." + str(layer_i) + "."
        for layer_i in range(12) if layer_i < args.freeze_bottom_layer
    ]
    logger.info("  The frozen parameters are:")
    for n, p in model.named_parameters():
        p.requires_grad = not any(nd in n for nd in no_grad)
        if not p.requires_grad:
            logger.info("    %s", n)
    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay) and p.requires_grad
        ],
        "weight_decay":
        args.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay) and p.requires_grad
        ],
        "weight_decay":
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=int(t_total *
                                                      args.warmup_ratio),
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  GPU IDs for training: %s",
                " ".join([str(id) for id in args.gpu_ids]))
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch_i in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            model.train()
            # batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids":
                batch[0].to(args.device),
                "attention_mask":
                batch[1].to(args.device),
                "token_type_ids":
                batch[2].to(args.device)
                if args.model_type in ["bert", "xlnet"] else None,
                # XLM and RoBERTa don"t use segment_ids
                "labels":
                batch[3].to(args.device) if len(batch) <= 4 else batch[-1]
            }  # add hard-label scheme
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule (after optimizer.step(), as in PyTorch >= 1.1)
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        logger.info("===== evaluate_during_training =====")
                        results, _ = evaluate(args,
                                              model,
                                              tokenizer,
                                              labels,
                                              pad_token_label_id,
                                              mode="dev")
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                            logger.info(
                                "Epoch: {}\t global_step: {}\t eval_{}: {}".
                                format(epoch_i, global_step, key, value))
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logger.info(
                        "Epoch: {}\t global_step: {}\t learning rate: {:.8}\t loss: {:.4f}"
                        .format(epoch_i, global_step,
                                scheduler.get_lr()[0],
                                (tr_loss - logging_loss) / args.logging_steps))
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, "module"
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
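
Note: all four examples build their learning-rate schedule with WarmupLinearSchedule from the pytorch-transformers 1.x API. In current transformers releases that class no longer exists and get_linear_schedule_with_warmup is the usual replacement. The following is a minimal, self-contained sketch of the equivalent setup (the model, step count, and hyperparameter values are placeholders, not taken from the snippets):

# Sketch only: linear warmup + decay with the newer `transformers` API,
# replacing pytorch-transformers' WarmupLinearSchedule. Values are placeholders.
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 2)   # stand-in for the real model
t_total = 1000                  # total optimization steps
warmup_ratio = 0.1              # assumed warmup fraction

optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(t_total * warmup_ratio),
    num_training_steps=t_total)

for _ in range(t_total):
    optimizer.step()            # step the optimizer first, then the scheduler
    scheduler.step()
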
Example #2
def train(args, train_dataset, model, tokenizer, teacher_model=None):
    """ Train the model """

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    if args.model_type == 'roberta':
        args.warmup_steps = int(t_total * 0.06)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=1e-8)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)

    current_best = 0
    output_eval_file = os.path.join(args.output_dir, 'eval_results.txt')

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3],
                'token_type_ids':
                batch[2] if args.model_type in ['bert'] else None
            }

            # prepare the hidden states and logits of the teacher model
            if args.training_phase == 'dynabertw' and teacher_model:
                with torch.no_grad():
                    _, teacher_logit, teacher_reps, _, _ = teacher_model(
                        **inputs)
            elif args.training_phase == 'dynabert' and teacher_model:
                hidden_max_all, logits_max_all = [], []
                for width_mult in sorted(args.width_mult_list, reverse=True):
                    with torch.no_grad():
                        _, teacher_logit, teacher_reps, _, _ = teacher_model(
                            **inputs)
                        hidden_max_all.append(teacher_reps)
                        logits_max_all.append(teacher_logit)

            # accumulate grads for all sub-networks
            for depth_mult in sorted(args.depth_mult_list, reverse=True):
                model.apply(lambda m: setattr(m, 'depth_mult', depth_mult))
                # select teacher model layers for matching
                if args.training_phase in ('dynabert', 'final_finetuning'):
                    model = model.module if hasattr(model, 'module') else model
                    base_model = getattr(model, model.base_model_prefix, model)
                    n_layers = base_model.config.num_hidden_layers
                    depth = round(depth_mult * n_layers)
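                    # teacher hidden-state indices matched against the student's
                    # kept layers (evenly spaced), plus the final layer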
                    kept_layers_index = []
                    for i in range(depth):
                        kept_layers_index.append(math.floor(i / depth_mult))
                    kept_layers_index.append(n_layers)

                # adjust width
                width_idx = 0
                for width_mult in sorted(args.width_mult_list, reverse=True):
                    model.apply(lambda m: setattr(m, 'width_mult', width_mult))
                    # stage 1: width-adaptive
                    if args.training_phase == 'dynabertw':
                        if getattr(args, 'data_aug'):
                            loss, student_logit, student_reps, _, _ = model(
                                **inputs)

                            # distillation loss of logits
                            if args.output_mode == "classification":
                                logit_loss = soft_cross_entropy(
                                    student_logit, teacher_logit.detach())
                            elif args.output_mode == "regression":
                                logit_loss = 0

                            # distillation loss of hidden states
                            rep_loss = 0
                            for student_rep, teacher_rep in zip(
                                    student_reps, teacher_reps):
                                tmp_loss = loss_mse(student_rep,
                                                    teacher_rep.detach())
                                rep_loss += tmp_loss

                            loss = args.width_lambda1 * logit_loss + args.width_lambda2 * rep_loss
                        else:
                            loss = model(**inputs)[0]

                    # stage 2: width- and depth- adaptive
                    elif args.training_phase == 'dynabert':
                        loss, student_logit, student_reps, _, _ = model(
                            **inputs)

                        # distillation loss of logits
                        if args.output_mode == "classification":
                            logit_loss = soft_cross_entropy(
                                student_logit,
                                logits_max_all[width_idx].detach())
                        elif args.output_mode == "regression":
                            logit_loss = 0

                        # distillation loss of hidden states
                        rep_loss = 0
                        for student_rep, teacher_rep in zip(
                                student_reps,
                                list(hidden_max_all[width_idx][i]
                                     for i in kept_layers_index)):
                            tmp_loss = loss_mse(student_rep,
                                                teacher_rep.detach())
                            rep_loss += tmp_loss

                        loss = args.depth_lambda1 * logit_loss + args.depth_lambda2 * rep_loss  # logit and hidden-state distillation
                        width_idx += 1  # move to the next width

                    # stage 3: final finetuning
                    else:
                        loss = model(**inputs)[0]

                    print(loss)
                    if args.n_gpu > 1:
                        loss = loss.mean()
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()

            # clip the accumulated grad from all widths
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # evaluate
                if global_step > 0 and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_during_training:
                        acc = []
                        if args.task_name == "mnli":  # for both MNLI-m and MNLI-mm
                            acc_both = []

                        # collect performance of all sub-networks
                        for depth_mult in sorted(args.depth_mult_list,
                                                 reverse=True):
                            model.apply(
                                lambda m: setattr(m, 'depth_mult', depth_mult))
                            for width_mult in sorted(args.width_mult_list,
                                                     reverse=True):
                                model.apply(lambda m: setattr(
                                    m, 'width_mult', width_mult))
                                results = evaluate(args, model, tokenizer)

                                logger.info(
                                    "********** start evaluate results *********"
                                )
                                logger.info("depth_mult: %s ", depth_mult)
                                logger.info("width_mult: %s ", width_mult)
                                logger.info("results: %s ", results)
                                logger.info(
                                    "********** end evaluate results *********"
                                )

                                acc.append(list(results.values())[0])
                                if args.task_name == "mnli":
                                    acc_both.append(
                                        list(results.values())[0:2])

                        # save model
                        if sum(acc) > current_best:
                            current_best = sum(acc)
                            if args.task_name == "mnli":
                                print("***best***{}\n".format(acc_both))
                                with open(output_eval_file, "a") as writer:
                                    writer.write("{}\n".format(acc_both))
                            else:
                                print("***best***{}\n".format(acc))
                                with open(output_eval_file, "a") as writer:
                                    writer.write("{}\n".format(acc))

                            logger.info("Saving model checkpoint to %s",
                                        args.output_dir)
                            model_to_save = model.module if hasattr(
                                model, 'module') else model
                            model_to_save.save_pretrained(args.output_dir)
                            torch.save(
                                args,
                                os.path.join(args.output_dir,
                                             'training_args.bin'))
                            model_to_save.config.to_json_file(
                                os.path.join(args.output_dir, CONFIG_NAME))
                            tokenizer.save_vocabulary(args.output_dir)

            if 0 < t_total < global_step:
                epoch_iterator.close()
                break

        if 0 < t_total < global_step:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step
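
Example #2 relies on two distillation helpers, soft_cross_entropy and loss_mse, that are never defined in the snippet. Below is one plausible sketch of what it assumes (the exact definitions in the original code may differ): loss_mse matches student and teacher hidden states, and soft_cross_entropy is the cross-entropy between the student's log-probabilities and the teacher's softened target distribution.

# Illustrative sketch of the distillation helpers assumed by Example #2;
# not taken verbatim from the original code base.
import torch
import torch.nn.functional as F

loss_mse = torch.nn.MSELoss()  # hidden-state (representation) matching loss

def soft_cross_entropy(predicts, targets):
    # Cross-entropy between the student's predicted distribution and the
    # teacher's soft (softmax) distribution, averaged over the batch.
    student_log_probs = F.log_softmax(predicts, dim=-1)
    teacher_probs = F.softmax(targets, dim=-1)
    return (-teacher_probs * student_log_probs).sum(dim=-1).mean()
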
Example #3
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in [
                    'bert', 'xlnet'
                ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1
                ) % args.gradient_accumulation_steps == 0 and not args.tpu:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.tpu:
                args.xla_model.optimizer_step(optimizer, barrier=True)
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
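
All of the examples call set_seed(args) for reproducibility without showing its definition. In the HuggingFace example scripts these training loops follow, the helper usually looks like the sketch below (assuming args carries seed and n_gpu fields, as it does above):

# Sketch of the set_seed helper the examples assume; the field names
# (args.seed, args.n_gpu) follow the args objects used above.
import random

import numpy as np
import torch

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
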
Example #4
    def train(self, train_dataset, output_dir, show_running_loss=True):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        tokenizer = self.tokenizer
        device = self.device
        model = self.model
        args = self.args

        tb_writer = SummaryWriter()
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args["train_batch_size"])

        t_total = len(train_dataloader) // args[
            "gradient_accumulation_steps"] * args["num_train_epochs"]

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [{
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args["weight_decay"]
        }, {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        }]

        warmup_steps = math.ceil(t_total * args["warmup_ratio"])
        args["warmup_steps"] = warmup_steps if args[
            "warmup_steps"] == 0 else args["warmup_steps"]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args["learning_rate"],
                          eps=args["adam_epsilon"])
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args["warmup_steps"],
                                         t_total=t_total)

        if args["fp16"]:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args["fp16_opt_level"])

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args["num_train_epochs"]), desc="Epoch")

        for _ in train_iterator:
            # epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Current iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)

                inputs = self._get_inputs_dict(batch)
                outputs = model(**inputs)
                # model outputs are always tuple in pytorch-transformers (see doc)
                loss = outputs[0]
                if show_running_loss:
                    print("\rRunning loss: %f" % loss, end="")

                if args["gradient_accumulation_steps"] > 1:
                    loss = loss / args["gradient_accumulation_steps"]

                if args["fp16"]:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args["max_grad_norm"])
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args["max_grad_norm"])

                tr_loss += loss.item()
                if (step + 1) % args["gradient_accumulation_steps"] == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args[
                            "logging_steps"] == 0:
                        # Log metrics
                        # Only evaluate when single GPU otherwise metrics may not average well
                        tb_writer.add_scalar("lr",
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                             args["logging_steps"],
                                             global_step)
                        logging_loss = tr_loss

                    if args["save_steps"] > 0 and global_step % args[
                            "save_steps"] == 0:
                        # Save model checkpoint
                        # Use a separate variable so `output_dir` is not rebound;
                        # otherwise later checkpoints nest inside earlier ones.
                        checkpoint_dir = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        if not os.path.exists(checkpoint_dir):
                            os.makedirs(checkpoint_dir)

                        # Take care of distributed/parallel training
                        model_to_save = model.module if hasattr(
                            model, "module") else model
                        model_to_save.save_pretrained(checkpoint_dir)

        return global_step, tr_loss / global_step
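
Example #4 builds its model inputs through self._get_inputs_dict(batch), which is not shown. A plausible sketch of that helper, mirroring the inputs dictionaries in Examples #1 and #3 (the exact implementation in the original class may differ), would be:

    # Hypothetical sketch of the _get_inputs_dict method Example #4 assumes;
    # the field layout mirrors Examples #1 and #3.
    def _get_inputs_dict(self, batch):
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "labels": batch[3],
        }
        # XLM, DistilBERT and RoBERTa don't use segment_ids
        if self.args["model_type"] != "distilbert":
            inputs["token_type_ids"] = (batch[2] if self.args["model_type"] in
                                        ["bert", "xlnet"] else None)
        return inputs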